author     Mike Bayer <mike_mp@zzzcomputing.com>    2010-03-30 18:15:02 -0400
committer  Mike Bayer <mike_mp@zzzcomputing.com>    2010-03-30 18:15:02 -0400
commit     00738b252c280111dafc8a034eade1507c1dddd8 (patch)
tree       84250759b0e653e7b72278b649ccc00ce3d074a7 /lib/sqlalchemy
parent     62d6bf4cc33171ac21cd9b4d52701d6af39cfb42 (diff)
parent     4cbe117eb2feb7cff28c66d849d3a0613448fdce (diff)
download   sqlalchemy-00738b252c280111dafc8a034eade1507c1dddd8.tar.gz
merge trunk. Re-instating topological._find_cycles for the moment
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--  lib/sqlalchemy/__init__.py                           |    2
-rw-r--r--  lib/sqlalchemy/connectors/mxodbc.py                  |   45
-rw-r--r--  lib/sqlalchemy/connectors/pyodbc.py                  |   43
-rw-r--r--  lib/sqlalchemy/dialects/access/base.py               |    5
-rw-r--r--  lib/sqlalchemy/dialects/firebird/base.py             |    7
-rw-r--r--  lib/sqlalchemy/dialects/informix/base.py             |    3
-rw-r--r--  lib/sqlalchemy/dialects/maxdb/base.py                |    5
-rw-r--r--  lib/sqlalchemy/dialects/mssql/adodbapi.py            |    4
-rw-r--r--  lib/sqlalchemy/dialects/mssql/base.py                |  127
-rw-r--r--  lib/sqlalchemy/dialects/mssql/information_schema.py  |    4
-rw-r--r--  lib/sqlalchemy/dialects/mssql/mxodbc.py              |   36
-rw-r--r--  lib/sqlalchemy/dialects/mssql/pymssql.py             |   85
-rw-r--r--  lib/sqlalchemy/dialects/mssql/pyodbc.py              |  126
-rw-r--r--  lib/sqlalchemy/dialects/mysql/__init__.py            |    4
-rw-r--r--  lib/sqlalchemy/dialects/mysql/base.py                |  111
-rw-r--r--  lib/sqlalchemy/dialects/mysql/mysqlconnector.py      |   11
-rw-r--r--  lib/sqlalchemy/dialects/mysql/mysqldb.py             |   28
-rw-r--r--  lib/sqlalchemy/dialects/mysql/oursql.py              |   15
-rw-r--r--  lib/sqlalchemy/dialects/mysql/pyodbc.py              |   19
-rw-r--r--  lib/sqlalchemy/dialects/mysql/zxjdbc.py              |    7
-rw-r--r--  lib/sqlalchemy/dialects/oracle/base.py               |   81
-rw-r--r--  lib/sqlalchemy/dialects/oracle/cx_oracle.py          |   59
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/base.py           |  105
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/psycopg2.py       |   32
-rw-r--r--  lib/sqlalchemy/dialects/sqlite/base.py               |   19
-rw-r--r--  lib/sqlalchemy/dialects/sqlite/pysqlite.py           |    9
-rw-r--r--  lib/sqlalchemy/dialects/sybase/base.py               |   17
-rw-r--r--  lib/sqlalchemy/dialects/sybase/pyodbc.py             |   32
-rw-r--r--  lib/sqlalchemy/engine/__init__.py                    |   51
-rw-r--r--  lib/sqlalchemy/engine/base.py                        |    5
-rw-r--r--  lib/sqlalchemy/engine/default.py                     |    5
-rw-r--r--  lib/sqlalchemy/ext/compiler.py                       |    2
-rw-r--r--  lib/sqlalchemy/ext/declarative.py                    |   44
-rw-r--r--  lib/sqlalchemy/ext/horizontal_shard.py               |  125
-rw-r--r--  lib/sqlalchemy/ext/orderinglist.py                   |  153
-rw-r--r--  lib/sqlalchemy/orm/__init__.py                       |  221
-rw-r--r--  lib/sqlalchemy/orm/interfaces.py                     |  168
-rw-r--r--  lib/sqlalchemy/orm/properties.py                     |  192
-rw-r--r--  lib/sqlalchemy/orm/query.py                          |  126
-rw-r--r--  lib/sqlalchemy/orm/session.py                        |   27
-rw-r--r--  lib/sqlalchemy/orm/shard.py                          |  112
-rw-r--r--  lib/sqlalchemy/orm/strategies.py                     |  574
-rw-r--r--  lib/sqlalchemy/pool.py                               |   39
-rw-r--r--  lib/sqlalchemy/sql/compiler.py                       |   78
-rw-r--r--  lib/sqlalchemy/sql/expression.py                     |   30
-rw-r--r--  lib/sqlalchemy/sql/util.py                           |   10
-rw-r--r--  lib/sqlalchemy/sql/visitors.py                       |   21
-rw-r--r--  lib/sqlalchemy/test/requires.py                      |   13
-rw-r--r--  lib/sqlalchemy/topological.py                        |   33
49 files changed, 2064 insertions(+), 1006 deletions(-)
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py
index 13e843801..376b13e64 100644
--- a/lib/sqlalchemy/__init__.py
+++ b/lib/sqlalchemy/__init__.py
@@ -114,6 +114,6 @@ from sqlalchemy.engine import create_engine, engine_from_config
__all__ = sorted(name for name, obj in locals().items()
if not (name.startswith('_') or inspect.ismodule(obj)))
-__version__ = '0.6beta2'
+__version__ = '0.6beta3'
del inspect, sys
diff --git a/lib/sqlalchemy/connectors/mxodbc.py b/lib/sqlalchemy/connectors/mxodbc.py
index 484c11d49..816474d43 100644
--- a/lib/sqlalchemy/connectors/mxodbc.py
+++ b/lib/sqlalchemy/connectors/mxodbc.py
@@ -9,6 +9,7 @@ and 2008, using the SQL Server Native driver. However, it is
possible for this to be used on other database platforms.
For more info on mxODBC, see http://www.egenix.com/
+
"""
import sys
@@ -31,6 +32,9 @@ class MxODBCConnector(Connector):
@classmethod
def dbapi(cls):
+ # this classmethod will normally be replaced by an instance
+ # attribute of the same name, so this is normally only called once.
+ cls._load_mx_exceptions()
platform = sys.platform
if platform == 'win32':
from mx.ODBC import Windows as module
@@ -43,6 +47,16 @@ class MxODBCConnector(Connector):
raise ImportError, "Unrecognized platform for mxODBC import"
return module
+ @classmethod
+ def _load_mx_exceptions(cls):
+ """ Import mxODBC exception classes into the module namespace,
+ as if they had been imported normally. This is done here
+ to avoid requiring all SQLAlchemy users to install mxODBC.
+ """
+ global InterfaceError, ProgrammingError
+ from mx.ODBC import InterfaceError
+ from mx.ODBC import ProgrammingError
+
def on_connect(self):
def connect(conn):
conn.stringformat = self.dbapi.MIXED_STRINGFORMAT
@@ -52,10 +66,9 @@ class MxODBCConnector(Connector):
return connect
def _error_handler(self):
- """Return a handler that adjusts mxODBC's raised Warnings to
+ """ Return a handler that adjusts mxODBC's raised Warnings to
emit Python standard warnings.
"""
-
from mx.ODBC.Error import Warning as MxOdbcWarning
def error_handler(connection, cursor, errorclass, errorvalue):
@@ -85,10 +98,10 @@ class MxODBCConnector(Connector):
"""
opts = url.translate_connect_args(username='user')
opts.update(url.query)
- args = opts['host'],
- kwargs = {'user':opts['user'],
- 'password': opts['password']}
- return args, kwargs
+ args = opts.pop('host')
+ opts.pop('port', None)
+ opts.pop('database', None)
+ return (args,), opts
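For illustration, a sketch of the argument mapping this produces (URL and
credentials hypothetical)::

    # mssql+mxodbc://scott:tiger@mydsn
    #   args   -> ('mydsn',)
    #   kwargs -> {'user': 'scott', 'password': 'tiger'}
    #             (any URL query parameters are passed through as well)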
def is_disconnect(self, e):
# eGenix recommends checking connection.closed here,
@@ -101,6 +114,7 @@ class MxODBCConnector(Connector):
return False
def _get_server_version_info(self, connection):
+ # eGenix suggests using conn.dbms_version instead of what we're doing here
dbapi_con = connection.connection
version = []
r = re.compile('[.\-]')
@@ -112,4 +126,21 @@ class MxODBCConnector(Connector):
version.append(n)
return tuple(version)
-
+ def do_execute(self, cursor, statement, parameters, context=None):
+ if context:
+ native_odbc_execute = context.execution_options.\
+ get('native_odbc_execute', 'auto')
+ if native_odbc_execute is True:
+ # user specified native_odbc_execute=True
+ cursor.execute(statement, parameters)
+ elif native_odbc_execute is False:
+ # user specified native_odbc_execute=False
+ cursor.executedirect(statement, parameters)
+ elif context.is_crud:
+ # statement is UPDATE, DELETE, INSERT
+ cursor.execute(statement, parameters)
+ else:
+ # all other statements
+ cursor.executedirect(statement, parameters)
+ else:
+ cursor.executedirect(statement, parameters)
diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py
index 5cf00bc92..b291f3e16 100644
--- a/lib/sqlalchemy/connectors/pyodbc.py
+++ b/lib/sqlalchemy/connectors/pyodbc.py
@@ -5,49 +5,6 @@ import sys
import re
import urllib
import decimal
-from sqlalchemy import processors, types as sqltypes
-
-class PyODBCNumeric(sqltypes.Numeric):
- """Turns Decimals with adjusted() < -6 into floats, > 7 into strings"""
-
- convert_large_decimals_to_string = False
-
- def bind_processor(self, dialect):
- super_process = super(PyODBCNumeric, self).bind_processor(dialect)
-
- def process(value):
- if self.asdecimal and \
- isinstance(value, decimal.Decimal):
-
- if value.adjusted() < -6:
- return processors.to_float(value)
- elif self.convert_large_decimals_to_string and \
- value.adjusted() > 7:
- return self._large_dec_to_string(value)
-
- if super_process:
- return super_process(value)
- else:
- return value
- return process
-
- def _large_dec_to_string(self, value):
- if 'E' in str(value):
- result = "%s%s%s" % (
- (value < 0 and '-' or ''),
- "".join([str(s) for s in value._int]),
- "0" * (value.adjusted() - (len(value._int)-1)))
- else:
- if (len(value._int) - 1) > value.adjusted():
- result = "%s%s.%s" % (
- (value < 0 and '-' or ''),
- "".join([str(s) for s in value._int][0:value.adjusted() + 1]),
- "".join([str(s) for s in value._int][value.adjusted() + 1:]))
- else:
- result = "%s%s" % (
- (value < 0 and '-' or ''),
- "".join([str(s) for s in value._int][0:value.adjusted() + 1]))
- return result
class PyODBCConnector(Connector):
driver='pyodbc'
diff --git a/lib/sqlalchemy/dialects/access/base.py b/lib/sqlalchemy/dialects/access/base.py
index 7dfb3153e..2b76b93d0 100644
--- a/lib/sqlalchemy/dialects/access/base.py
+++ b/lib/sqlalchemy/dialects/access/base.py
@@ -16,7 +16,7 @@ This dialect is *not* tested on SQLAlchemy 0.6.
"""
from sqlalchemy import sql, schema, types, exc, pool
from sqlalchemy.sql import compiler, expression
-from sqlalchemy.engine import default, base
+from sqlalchemy.engine import default, base, reflection
from sqlalchemy import processors
class AcNumeric(types.Numeric):
@@ -299,7 +299,8 @@ class AccessDialect(default.DefaultDialect):
finally:
dtbs.Close()
- def table_names(self, connection, schema):
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
# A fresh DAO connection is opened for each reflection
# This is necessary, so we get the latest updates
dtbs = daoEngine.OpenDatabase(connection.engine.url.database)
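These ``table_names()`` to ``get_table_names()`` conversions are consumed by
the 0.6 ``Inspector`` interface; a minimal sketch, assuming ``engine`` is an
already-configured ``Engine``::

    from sqlalchemy.engine import reflection

    insp = reflection.Inspector.from_engine(engine)
    # results are memoized per-Inspector via the @reflection.cache decorator
    tables = insp.get_table_names()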
diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py
index a2da132da..70318157c 100644
--- a/lib/sqlalchemy/dialects/firebird/base.py
+++ b/lib/sqlalchemy/dialects/firebird/base.py
@@ -378,7 +378,8 @@ class FBDialect(default.DefaultDialect):
c = connection.execute(genqry, [self.denormalize_name(sequence_name)])
return c.first() is not None
- def table_names(self, connection, schema):
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
s = """
SELECT DISTINCT rdb$relation_name
FROM rdb$relation_fields
@@ -387,10 +388,6 @@ class FBDialect(default.DefaultDialect):
return [self.normalize_name(row[0]) for row in connection.execute(s)]
@reflection.cache
- def get_table_names(self, connection, schema=None, **kw):
- return self.table_names(connection, schema)
-
- @reflection.cache
def get_view_names(self, connection, schema=None, **kw):
s = """
SELECT distinct rdb$view_name
diff --git a/lib/sqlalchemy/dialects/informix/base.py b/lib/sqlalchemy/dialects/informix/base.py
index 54aae6eb3..266a74a7b 100644
--- a/lib/sqlalchemy/dialects/informix/base.py
+++ b/lib/sqlalchemy/dialects/informix/base.py
@@ -193,7 +193,8 @@ class InformixDialect(default.DefaultDialect):
cu.execute('SET LOCK MODE TO WAIT')
#cu.execute('SET ISOLATION TO REPEATABLE READ')
- def table_names(self, connection, schema):
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
s = "select tabname from systables"
return [row[0] for row in connection.execute(s)]
diff --git a/lib/sqlalchemy/dialects/maxdb/base.py b/lib/sqlalchemy/dialects/maxdb/base.py
index 758cfaf05..2e1d6a58f 100644
--- a/lib/sqlalchemy/dialects/maxdb/base.py
+++ b/lib/sqlalchemy/dialects/maxdb/base.py
@@ -63,7 +63,7 @@ import datetime, itertools, re
from sqlalchemy import exc, schema, sql, util, processors
from sqlalchemy.sql import operators as sql_operators, expression as sql_expr
from sqlalchemy.sql import compiler, visitors
-from sqlalchemy.engine import base as engine_base, default
+from sqlalchemy.engine import base as engine_base, default, reflection
from sqlalchemy import types as sqltypes
@@ -880,7 +880,8 @@ class MaxDBDialect(default.DefaultDialect):
rp = connection.execute(sql, bind)
return bool(rp.first())
- def table_names(self, connection, schema):
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
if schema is None:
sql = (" SELECT TABLENAME FROM TABLES WHERE "
" SCHEMANAME=CURRENT_SCHEMA ")
diff --git a/lib/sqlalchemy/dialects/mssql/adodbapi.py b/lib/sqlalchemy/dialects/mssql/adodbapi.py
index 9e12a944d..502a02acc 100644
--- a/lib/sqlalchemy/dialects/mssql/adodbapi.py
+++ b/lib/sqlalchemy/dialects/mssql/adodbapi.py
@@ -1,3 +1,7 @@
+"""
+The adodbapi dialect is not implemented for 0.6 at this time.
+
+"""
from sqlalchemy import types as sqltypes, util
from sqlalchemy.dialects.mssql.base import MSDateTime, MSDialect
import sys
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index 7660fe9f7..066ab8d04 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -2,119 +2,10 @@
"""Support for the Microsoft SQL Server database.
-Driver
-------
-
-The MSSQL dialect will work with three different available drivers:
-
-* *pyodbc* - http://pyodbc.sourceforge.net/. This is the recommeded
- driver.
-
-* *pymssql* - http://pymssql.sourceforge.net/
-
-* *adodbapi* - http://adodbapi.sourceforge.net/
-
-Drivers are loaded in the order listed above based on availability.
-
-If you need to load a specific driver pass ``module_name`` when
-creating the engine::
-
- engine = create_engine('mssql+module_name://dsn')
-
-``module_name`` currently accepts: ``pyodbc``, ``pymssql``, and
-``adodbapi``.
-
-Currently the pyodbc driver offers the greatest level of
-compatibility.
-
Connecting
----------
-Connecting with create_engine() uses the standard URL approach of
-``mssql://user:pass@host/dbname[?key=value&key=value...]``.
-
-If the database name is present, the tokens are converted to a
-connection string with the specified values. If the database is not
-present, then the host token is taken directly as the DSN name.
-
-Examples of pyodbc connection string URLs:
-
-* *mssql+pyodbc://mydsn* - connects using the specified DSN named ``mydsn``.
- The connection string that is created will appear like::
-
- dsn=mydsn;Trusted_Connection=Yes
-
-* *mssql+pyodbc://user:pass@mydsn* - connects using the DSN named
- ``mydsn`` passing in the ``UID`` and ``PWD`` information. The
- connection string that is created will appear like::
-
- dsn=mydsn;UID=user;PWD=pass
-
-* *mssql+pyodbc://user:pass@mydsn/?LANGUAGE=us_english* - connects
- using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD``
- information, plus the additional connection configuration option
- ``LANGUAGE``. The connection string that is created will appear
- like::
-
- dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english
-
-* *mssql+pyodbc://user:pass@host/db* - connects using a connection string
- dynamically created that would appear like::
-
- DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass
-
-* *mssql+pyodbc://user:pass@host:123/db* - connects using a connection
- string that is dynamically created, which also includes the port
- information using the comma syntax. If your connection string
- requires the port information to be passed as a ``port`` keyword
- see the next example. This will create the following connection
- string::
-
- DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass
-
-* *mssql+pyodbc://user:pass@host/db?port=123* - connects using a connection
- string that is dynamically created that includes the port
- information as a separate ``port`` keyword. This will create the
- following connection string::
-
- DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass;port=123
-
-If you require a connection string that is outside the options
-presented above, use the ``odbc_connect`` keyword to pass in a
-urlencoded connection string. What gets passed in will be urldecoded
-and passed directly.
-
-For example::
-
- mssql+pyodbc:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb
-
-would create the following connection string::
-
- dsn=mydsn;Database=db
-
-Encoding your connection string can be easily accomplished through
-the python shell. For example::
-
- >>> import urllib
- >>> urllib.quote_plus('dsn=mydsn;Database=db')
- 'dsn%3Dmydsn%3BDatabase%3Ddb'
-
-Additional arguments which may be specified either as query string
-arguments on the URL, or as keyword argument to
-:func:`~sqlalchemy.create_engine()` are:
-
-* *query_timeout* - allows you to override the default query timeout.
- Defaults to ``None``. This is only supported on pymssql.
-
-* *use_scope_identity* - allows you to specify that SCOPE_IDENTITY
- should be used in place of the non-scoped version @@IDENTITY.
- Defaults to True.
-
-* *max_identifier_length* - allows you to se the maximum length of
- identfiers supported by the database. Defaults to 128. For pymssql
- the default is 30.
-
-* *schema_name* - use to set the schema name. Defaults to ``dbo``.
+See the individual driver sections below for details on connecting.
Auto Increment Behavior
-----------------------
@@ -220,9 +111,6 @@ Known Issues
* No support for more than one ``IDENTITY`` column per table
-* pymssql has problems with binary and unicode data that this module
- does **not** work around
-
"""
import datetime, decimal, inspect, operator, sys, re
import itertools
@@ -1149,11 +1037,6 @@ class MSDialect(default.DefaultDialect):
pass
return self.schema_name
- def table_names(self, connection, schema):
- s = select([ischema.tables.c.table_name],
- ischema.tables.c.table_schema==schema)
- return [row[0] for row in connection.execute(s)]
-
def has_table(self, connection, tablename, schema=None):
current_schema = schema or self.default_schema_name
@@ -1182,7 +1065,7 @@ class MSDialect(default.DefaultDialect):
s = sql.select([tables.c.table_name],
sql.and_(
tables.c.table_schema == current_schema,
- tables.c.table_type == 'BASE TABLE'
+ tables.c.table_type == u'BASE TABLE'
),
order_by=[tables.c.table_name]
)
@@ -1196,7 +1079,7 @@ class MSDialect(default.DefaultDialect):
s = sql.select([tables.c.table_name],
sql.and_(
tables.c.table_schema == current_schema,
- tables.c.table_type == 'VIEW'
+ tables.c.table_type == u'VIEW'
),
order_by=[tables.c.table_name]
)
@@ -1320,11 +1203,11 @@ class MSDialect(default.DefaultDialect):
table_fullname = "%s.%s" % (current_schema, tablename)
cursor = connection.execute(
"select ident_seed('%s'), ident_incr('%s')"
- % (tablename, tablename)
+ % (table_fullname, table_fullname)
)
row = cursor.first()
- if not row is None:
+ if row is not None and row[0] is not None:
colmap[ic]['sequence'].update({
'start' : int(row[0]),
'increment' : int(row[1])
diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py
index bb6ff315a..312e83cb1 100644
--- a/lib/sqlalchemy/dialects/mssql/information_schema.py
+++ b/lib/sqlalchemy/dialects/mssql/information_schema.py
@@ -21,7 +21,7 @@ tables = Table("TABLES", ischema,
Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
- Column("TABLE_TYPE", String, key="table_type"),
+ Column("TABLE_TYPE", String(convert_unicode=True), key="table_type"),
schema="INFORMATION_SCHEMA")
columns = Table("COLUMNS", ischema,
@@ -42,7 +42,7 @@ constraints = Table("TABLE_CONSTRAINTS", ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
- Column("CONSTRAINT_TYPE", String, key="constraint_type"),
+ Column("CONSTRAINT_TYPE", String(convert_unicode=True), key="constraint_type"),
schema="INFORMATION_SCHEMA")
column_constraints = Table("CONSTRAINT_COLUMN_USAGE", ischema,
diff --git a/lib/sqlalchemy/dialects/mssql/mxodbc.py b/lib/sqlalchemy/dialects/mssql/mxodbc.py
index 7148a3628..efe763659 100644
--- a/lib/sqlalchemy/dialects/mssql/mxodbc.py
+++ b/lib/sqlalchemy/dialects/mssql/mxodbc.py
@@ -1,9 +1,41 @@
"""
-MSSQL dialect tweaked to work with mxODBC, mainly by making use
-of the MSSQLStrictCompiler.
+Support for MS-SQL via mxODBC.
+
+mxODBC is available at:
+
+ http://www.egenix.com/
This was tested with mxODBC 3.1.2 and the SQL Server Native
Client connected to MSSQL 2005 and 2008 Express Editions.
+
+Connecting
+~~~~~~~~~~
+
+Connection is via DSN::
+
+ mssql+mxodbc://<username>:<password>@<dsnname>
+
+Execution Modes
+~~~~~~~~~~~~~~~
+
+mxODBC features two styles of statement execution, using the ``cursor.execute()``
+and ``cursor.executedirect()`` methods (the second being an extension to the
+DBAPI specification). The former makes use of the native
+parameter binding services of the ODBC driver, while the latter uses string escaping.
+The primary advantage to native parameter binding is that the same statement, when
+executed many times, is only prepared once. The primary advantage of string
+escaping is that the rules for bind parameter placement are relaxed. MS-SQL has very
+strict rules for native binds, including that they cannot be placed within the argument
+lists of function calls, anywhere outside the FROM, or even within subqueries within the
+FROM clause - making the usage of bind parameters within SELECT statements impossible for
+all but the most simplistic statements. For this reason, the mxODBC dialect uses the
+"native" mode by default only for INSERT, UPDATE, and DELETE statements, and uses the
+escaped string mode for all other statements. This behavior can be controlled completely
+via :meth:`~sqlalchemy.sql.expression.Executable.execution_options`
+using the ``native_odbc_execute`` flag with a value of ``True`` or ``False``, where a value of
+``True`` will unconditionally use native bind parameters and a value of ``False`` will
+unconditionally use string-escaped parameters.
+
"""
import re
diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py
index b3a57d318..ca1c4a142 100644
--- a/lib/sqlalchemy/dialects/mssql/pymssql.py
+++ b/lib/sqlalchemy/dialects/mssql/pymssql.py
@@ -1,40 +1,101 @@
"""
Support for the pymssql dialect.
-Going forward we will be supporting the 1.0 release of pymssql.
+This dialect supports pymssql 1.0 and greater.
+
+pymssql is available at:
+
+ http://pymssql.sourceforge.net/
+
+Connecting
+^^^^^^^^^^
+
+Sample connect string::
+
+ mssql+pymssql://<username>:<password>@<freetds_name>
+
+Adding "?charset=utf8" or similar will cause pymssql to return
+strings as Python unicode objects. This can potentially improve
+performance in some scenarios as decoding of strings is
+handled natively.
+
+Limitations
+^^^^^^^^^^^
+
+pymssql inherits a lot of limitations from FreeTDS, including:
+
+* no support for multibyte schema identifiers
+* poor support for large decimals
+* poor support for binary fields
+* poor support for VARCHAR/CHAR fields over 255 characters
+
+Please consult the pymssql documentation for further information.
"""
from sqlalchemy.dialects.mssql.base import MSDialect
-from sqlalchemy import types as sqltypes
+from sqlalchemy import types as sqltypes, util, processors
+import re
+import decimal
+class _MSNumeric_pymssql(sqltypes.Numeric):
+ def result_processor(self, dialect, type_):
+ if not self.asdecimal:
+ return processors.to_float
+ else:
+ return sqltypes.Numeric.result_processor(self, dialect, type_)
class MSDialect_pymssql(MSDialect):
supports_sane_rowcount = False
max_identifier_length = 30
driver = 'pymssql'
-
+
+ colspecs = util.update_copy(
+ MSDialect.colspecs,
+ {
+ sqltypes.Numeric:_MSNumeric_pymssql,
+ sqltypes.Float:sqltypes.Float,
+ }
+ )
@classmethod
def dbapi(cls):
- import pymssql as module
+ module = __import__('pymssql')
# pymssql doesn't have a Binary method. we use string
# TODO: monkeypatching here is less than ideal
- module.Binary = lambda st: str(st)
+ module.Binary = str
+
+ client_ver = tuple(int(x) for x in module.__version__.split("."))
+ if client_ver < (1, ):
+ util.warn("The pymssql dialect expects at least "
+ "the 1.0 series of the pymssql DBAPI.")
return module
def __init__(self, **params):
super(MSDialect_pymssql, self).__init__(**params)
self.use_scope_identity = True
+ def _get_server_version_info(self, connection):
+ vers = connection.scalar("select @@version")
+ m = re.match(r"Microsoft SQL Server.*? - (\d+).(\d+).(\d+).(\d+)", vers)
+ if m:
+ return tuple(int(x) for x in m.group(1, 2, 3, 4))
+ else:
+ return None
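For orientation, the regular expression above applied to a typical
``@@version`` banner (the banner text here is an assumed sample)::

    >>> import re
    >>> vers = "Microsoft SQL Server 2005 - 9.00.3042.00 (Intel X86)"
    >>> m = re.match(r"Microsoft SQL Server.*? - (\d+).(\d+).(\d+).(\d+)", vers)
    >>> tuple(int(x) for x in m.group(1, 2, 3, 4))
    (9, 0, 3042, 0)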
def create_connect_args(self, url):
- keys = url.query
- if keys.get('port'):
- # pymssql expects port as host:port, not a separate arg
- keys['host'] = ''.join([keys.get('host', ''), ':', str(keys['port'])])
- del keys['port']
- return [[], keys]
+ opts = url.translate_connect_args(username='user')
+ opts.update(url.query)
+ opts.pop('port', None)
+ return [[], opts]
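A sketch of the resulting DBAPI arguments (URL hypothetical)::

    # mssql+pymssql://scott:tiger@myfreetds/test
    #   -> [[], {'user': 'scott', 'password': 'tiger',
    #            'host': 'myfreetds', 'database': 'test'}]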
def is_disconnect(self, e):
- return isinstance(e, self.dbapi.DatabaseError) and "Error 10054" in str(e)
+ for msg in (
+ "Error 10054",
+ "Not connected to any MS SQL server",
+ "Connection is closed"
+ ):
+ if msg in str(e):
+ return True
+ else:
+ return False
dialect = MSDialect_pymssql
\ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py
index 8e7e90629..c74be0e53 100644
--- a/lib/sqlalchemy/dialects/mssql/pyodbc.py
+++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py
@@ -1,22 +1,134 @@
"""
Support for MS-SQL via pyodbc.
-http://pypi.python.org/pypi/pyodbc/
+pyodbc is available at:
-Connect strings are of the form::
+ http://pypi.python.org/pypi/pyodbc/
- mssql+pyodbc://<username>:<password>@<dsn>/
- mssql+pyodbc://<username>:<password>@<host>/<database>
+Connecting
+^^^^^^^^^^
+
+Examples of pyodbc connection string URLs:
+
+* ``mssql+pyodbc://mydsn`` - connects using the specified DSN named ``mydsn``.
+ The connection string that is created will appear like::
+
+ dsn=mydsn;Trusted_Connection=Yes
+
+* ``mssql+pyodbc://user:pass@mydsn`` - connects using the DSN named
+ ``mydsn`` passing in the ``UID`` and ``PWD`` information. The
+ connection string that is created will appear like::
+
+ dsn=mydsn;UID=user;PWD=pass
+
+* ``mssql+pyodbc://user:pass@mydsn/?LANGUAGE=us_english`` - connects
+ using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD``
+ information, plus the additional connection configuration option
+ ``LANGUAGE``. The connection string that is created will appear
+ like::
+
+ dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english
+
+* ``mssql+pyodbc://user:pass@host/db`` - connects using a connection string
+ dynamically created that would appear like::
+
+ DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass
+
+* ``mssql+pyodbc://user:pass@host:123/db`` - connects using a connection
+ string that is dynamically created, which also includes the port
+ information using the comma syntax. If your connection string
+  requires the port information to be passed as a ``port`` keyword,
+ see the next example. This will create the following connection
+ string::
+
+ DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass
+
+* ``mssql+pyodbc://user:pass@host/db?port=123`` - connects using a connection
+  string that is dynamically created and includes the port
+ information as a separate ``port`` keyword. This will create the
+ following connection string::
+
+ DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass;port=123
+
+If you require a connection string that is outside the options
+presented above, use the ``odbc_connect`` keyword to pass in a
+urlencoded connection string. What gets passed in will be urldecoded
+and passed directly.
+
+For example::
+
+ mssql+pyodbc:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb
+
+would create the following connection string::
+
+ dsn=mydsn;Database=db
+
+Encoding your connection string can be easily accomplished through
+the Python shell. For example::
+
+ >>> import urllib
+ >>> urllib.quote_plus('dsn=mydsn;Database=db')
+ 'dsn%3Dmydsn%3BDatabase%3Ddb'
"""
from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect
-from sqlalchemy.connectors.pyodbc import PyODBCConnector, PyODBCNumeric
+from sqlalchemy.connectors.pyodbc import PyODBCConnector
from sqlalchemy import types as sqltypes, util
+import decimal
+
+class _MSNumeric_pyodbc(sqltypes.Numeric):
+ """Turns Decimals with adjusted() < 0 or > 7 into strings.
+
+ This is the only method that is proven to work with Pyodbc+MSSQL
+ without crashing (floats can be used but seem to cause sporadic
+ crashes).
+
+ """
+
+ def bind_processor(self, dialect):
+ super_process = super(_MSNumeric_pyodbc, self).bind_processor(dialect)
+
+ def process(value):
+ if self.asdecimal and \
+ isinstance(value, decimal.Decimal):
+
+ adjusted = value.adjusted()
+ if adjusted < 0:
+ return self._small_dec_to_string(value)
+ elif adjusted > 7:
+ return self._large_dec_to_string(value)
-class _MSNumeric_pyodbc(PyODBCNumeric):
- convert_large_decimals_to_string = True
+ if super_process:
+ return super_process(value)
+ else:
+ return value
+ return process
+
+ def _small_dec_to_string(self, value):
+ return "%s0.%s%s" % (
+ (value < 0 and '-' or ''),
+ '0' * (abs(value.adjusted()) - 1),
+ "".join([str(nint) for nint in value._int]))
+
+ def _large_dec_to_string(self, value):
+ if 'E' in str(value):
+ result = "%s%s%s" % (
+ (value < 0 and '-' or ''),
+ "".join([str(s) for s in value._int]),
+ "0" * (value.adjusted() - (len(value._int)-1)))
+ else:
+ if (len(value._int) - 1) > value.adjusted():
+ result = "%s%s.%s" % (
+ (value < 0 and '-' or ''),
+ "".join([str(s) for s in value._int][0:value.adjusted() + 1]),
+ "".join([str(s) for s in value._int][value.adjusted() + 1:]))
+ else:
+ result = "%s%s" % (
+ (value < 0 and '-' or ''),
+ "".join([str(s) for s in value._int][0:value.adjusted() + 1]))
+ return result
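For orientation, ``Decimal.adjusted()`` returns the exponent of the most
significant digit, which is what the thresholds above test::

    >>> import decimal
    >>> decimal.Decimal('0.01').adjusted()       # < 0 -> _small_dec_to_string()
    -2
    >>> decimal.Decimal('123456789').adjusted()  # > 7 -> _large_dec_to_string()
    8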
class MSExecutionContext_pyodbc(MSExecutionContext):
diff --git a/lib/sqlalchemy/dialects/mysql/__init__.py b/lib/sqlalchemy/dialects/mysql/__init__.py
index e4ecccdfc..f37a0c766 100644
--- a/lib/sqlalchemy/dialects/mysql/__init__.py
+++ b/lib/sqlalchemy/dialects/mysql/__init__.py
@@ -6,12 +6,12 @@ base.dialect = mysqldb.dialect
from sqlalchemy.dialects.mysql.base import \
BIGINT, BINARY, BIT, BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL, DOUBLE, ENUM, DECIMAL,\
FLOAT, INTEGER, INTEGER, LONGBLOB, LONGTEXT, MEDIUMBLOB, MEDIUMINT, MEDIUMTEXT, NCHAR, \
- NVARCHAR, NUMERIC, SET, SMALLINT, TEXT, TIME, TIMESTAMP, TINYBLOB, TINYINT, TINYTEXT,\
+ NVARCHAR, NUMERIC, SET, SMALLINT, REAL, TEXT, TIME, TIMESTAMP, TINYBLOB, TINYINT, TINYTEXT,\
VARBINARY, VARCHAR, YEAR, dialect
__all__ = (
'BIGINT', 'BINARY', 'BIT', 'BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', 'DECIMAL', 'DOUBLE',
'ENUM', 'DECIMAL', 'FLOAT', 'INTEGER', 'INTEGER', 'LONGBLOB', 'LONGTEXT', 'MEDIUMBLOB', 'MEDIUMINT',
-'MEDIUMTEXT', 'NCHAR', 'NVARCHAR', 'NUMERIC', 'SET', 'SMALLINT', 'TEXT', 'TIME', 'TIMESTAMP',
+'MEDIUMTEXT', 'NCHAR', 'NVARCHAR', 'NUMERIC', 'SET', 'SMALLINT', 'REAL', 'TEXT', 'TIME', 'TIMESTAMP',
'TINYBLOB', 'TINYINT', 'TINYTEXT', 'VARBINARY', 'VARCHAR', 'YEAR', 'dialect'
)
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 2311b06df..6a0761476 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -1,37 +1,13 @@
# -*- fill-column: 78 -*-
-# mysql.py
+# mysql/base.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
+# and Jason Kirtland.
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Support for the MySQL database.
-Overview
---------
-
-For normal SQLAlchemy usage, importing this module is unnecessary. It will be
-loaded on-demand when a MySQL connection is needed. The generic column types
-like :class:`~sqlalchemy.String` and :class:`~sqlalchemy.Integer` will
-automatically be adapted to the optimal matching MySQL column type.
-
-But if you would like to use one of the MySQL-specific or enhanced column
-types when creating tables with your :class:`~sqlalchemy.Table` definitions,
-then you will need to import them from this module::
-
- from sqlalchemy.dialect.mysql import base as mysql
-
- Table('mytable', metadata,
- Column('id', Integer, primary_key=True),
- Column('ittybittyblob', mysql.TINYBLOB),
- Column('biggy', mysql.BIGINT(unsigned=True)))
-
-All standard MySQL column types are supported. The OpenGIS types are
-available for use via table reflection but have no special support or mapping
-to Python classes. If you're using these types and have opinions about how
-OpenGIS can be smartly integrated into SQLAlchemy please join the mailing
-list!
-
Supported Versions and Features
-------------------------------
@@ -44,10 +20,7 @@ in the suite 100%. No heroic measures are taken to work around major missing
SQL features- if your server version does not support sub-selects, for
example, they won't work in SQLAlchemy either.
-Currently, the only DB-API driver supported is `MySQL-Python` (also referred to
-as `MySQLdb`). Either 1.2.1 or 1.2.2 are recommended. The alpha, beta and
-gamma releases of 1.2.1 and 1.2.2 should be avoided. Support for Jython and
-IronPython is planned.
+Most available DBAPI drivers are supported; see below.
===================================== ===============
Feature Minimum Version
@@ -64,6 +37,37 @@ Nested Transactions 5.0.3
See the official MySQL documentation for detailed information about features
supported in any given server release.
+Connecting
+----------
+
+See the API documentation on individual drivers for details on connecting.
+
+Data Types
+----------
+
+All of MySQL's standard types are supported. These can also be specified within
+table metadata, for the purpose of issuing CREATE TABLE statements
+which include MySQL-specific extensions. The types are available
+from the module, as in::
+
+ from sqlalchemy.dialects import mysql
+
+ Table('mytable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('ittybittyblob', mysql.TINYBLOB),
+ Column('biggy', mysql.BIGINT(unsigned=True)))
+
+See the API documentation on specific column types for further details.
+
+Connection Timeouts
+-------------------
+
+MySQL features an automatic connection close behavior, for connections that have
+been idle for eight hours or more. To circumvent this issue, use the
+``pool_recycle`` option which controls the maximum age of any connection::
+
+ engine = create_engine('mysql+mysqldb://...', pool_recycle=3600)
+
Storage Engines
---------------
@@ -159,20 +163,13 @@ And of course any valid MySQL statement can be executed as a string as well.
Some limited direct support for MySQL extensions to SQL is currently
available.
- * SELECT pragma::
-
- select(..., prefixes=['HIGH_PRIORITY', 'SQL_SMALL_RESULT'])
+* SELECT pragma::
- * UPDATE with LIMIT::
+ select(..., prefixes=['HIGH_PRIORITY', 'SQL_SMALL_RESULT'])
- update(..., mysql_limit=10)
+* UPDATE with LIMIT::
-Boolean Types
--------------
-
-MySQL's BOOL type is a synonym for SMALLINT, so is actually a numeric value,
-and additionally MySQL doesn't support CHECK constraints. Therefore SQLA's
-Boolean type cannot fully constrain values to just "True" and "False" the way it does for most other backends.
+ update(..., mysql_limit=10)
Troubleshooting
---------------
@@ -1154,7 +1151,10 @@ class MySQLCompiler(compiler.SQLCompiler):
def visit_match_op(self, binary, **kw):
return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (self.process(binary.left), self.process(binary.right))
-
+
+ def get_from_hint_text(self, table, text):
+ return text
+
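A hedged sketch of how this hook is reached through the new
``Select.with_hint()`` API (table and index names hypothetical)::

    from sqlalchemy import table, column, select

    t = table('mytable', column('x'))
    stmt = select([t]).with_hint(t, 'USE INDEX (ix_x)', 'mysql')
    # renders roughly: SELECT mytable.x FROM mytable USE INDEX (ix_x)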
def visit_typeclause(self, typeclause):
type_ = typeclause.type.dialect_impl(self.dialect)
if isinstance(type_, sqltypes.Integer):
@@ -1204,11 +1204,11 @@ class MySQLCompiler(compiler.SQLCompiler):
# support can be added, preferably after dialects are
# refactored to be version-sensitive.
return ''.join(
- (self.process(join.left, asfrom=True),
+ (self.process(join.left, asfrom=True, **kwargs),
(join.isouter and " LEFT OUTER JOIN " or " INNER JOIN "),
- self.process(join.right, asfrom=True),
+ self.process(join.right, asfrom=True, **kwargs),
" ON ",
- self.process(join.onclause)))
+ self.process(join.onclause, **kwargs)))
def for_update_clause(self, select):
if select.for_update == 'read':
@@ -1766,24 +1766,20 @@ class MySQLDialect(default.DefaultDialect):
@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
+ """Return a Unicode SHOW TABLES from a given schema."""
if schema is not None:
current_schema = schema
else:
current_schema = self.default_schema_name
- table_names = self.table_names(connection, current_schema)
- return table_names
-
- def table_names(self, connection, schema):
- """Return a Unicode SHOW TABLES from a given schema."""
charset = self._connection_charset
if self.server_version_info < (5, 0, 2):
rp = connection.execute("SHOW TABLES FROM %s" %
- self.identifier_preparer.quote_identifier(schema))
+ self.identifier_preparer.quote_identifier(current_schema))
return [row[0] for row in self._compat_fetchall(rp, charset=charset)]
else:
rp = connection.execute("SHOW FULL TABLES FROM %s" %
- self.identifier_preparer.quote_identifier(schema))
+ self.identifier_preparer.quote_identifier(current_schema))
return [row[0] for row in self._compat_fetchall(rp, charset=charset)\
if row[1] == 'BASE TABLE']
@@ -1796,7 +1792,7 @@ class MySQLDialect(default.DefaultDialect):
if schema is None:
schema = self.default_schema_name
if self.server_version_info < (5, 0, 2):
- return self.table_names(connection, schema)
+ return self.get_table_names(connection, schema)
charset = self._connection_charset
rp = connection.execute("SHOW FULL TABLES FROM %s" %
self.identifier_preparer.quote_identifier(schema))
@@ -1946,7 +1942,7 @@ class MySQLDialect(default.DefaultDialect):
# For winxx database hosts. TODO: is this really needed?
if casing == 1 and table.name != table.name.lower():
table.name = table.name.lower()
- lc_alias = schema._get_table_key(table.name, table.schema)
+ lc_alias = sa_schema._get_table_key(table.name, table.schema)
table.metadata.tables[lc_alias] = table
def _detect_charset(self, connection):
@@ -2208,13 +2204,6 @@ class MySQLTableDefinitionParser(object):
name, type_, args, notnull = \
spec['name'], spec['coltype'], spec['arg'], spec['notnull']
- # Convention says that TINYINT(1) columns == BOOLEAN
- if type_ == 'tinyint' and args == '1':
- type_ = 'boolean'
- args = None
- spec['unsigned'] = None
- spec['zerofill'] = None
-
try:
col_type = self.dialect.ischema_names[type_]
except KeyError:
diff --git a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py
index 981e1e204..2da18e50f 100644
--- a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py
+++ b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py
@@ -1,6 +1,15 @@
"""Support for the MySQL database via the MySQL Connector/Python adapter.
-# TODO: add docs/notes here regarding MySQL Connector/Python
+MySQL Connector/Python is available at:
+
+ https://launchpad.net/myconnpy
+
+Connecting
+-----------
+
+Connect string format::
+
+ mysql+mysqlconnector://<user>:<password>@<host>[:<port>]/<dbname>
"""
diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py
index 9d34939a1..6e6bb0ecc 100644
--- a/lib/sqlalchemy/dialects/mysql/mysqldb.py
+++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py
@@ -1,5 +1,18 @@
"""Support for the MySQL database via the MySQL-python adapter.
+MySQL-Python is available at:
+
+ http://sourceforge.net/projects/mysql-python
+
+At least version 1.2.1 or 1.2.2 should be used.
+
+Connecting
+-----------
+
+Connect string format::
+
+ mysql+mysqldb://<user>:<password>@<host>[:<port>]/<dbname>
+
Character Sets
--------------
@@ -14,10 +27,21 @@ enabling ``use_unicode`` in the driver by default. For regular encoded
strings, also pass ``use_unicode=0`` in the connection arguments::
# set client encoding to utf8; all strings come back as unicode
- create_engine('mysql:///mydb?charset=utf8')
+ create_engine('mysql+mysqldb:///mydb?charset=utf8')
# set client encoding to utf8; all strings come back as utf8 str
- create_engine('mysql:///mydb?charset=utf8&use_unicode=0')
+ create_engine('mysql+mysqldb:///mydb?charset=utf8&use_unicode=0')
+
+Known Issues
+-------------
+
+MySQL-python, at least as of version 1.2.2, has a serious memory leak related
+to unicode conversion; this feature can be disabled by passing ``use_unicode=0``.
+The recommended connection form with SQLAlchemy is::
+
+    engine = create_engine('mysql+mysqldb://scott:tiger@localhost/test?charset=utf8&use_unicode=0', pool_recycle=3600)
+
+
"""
import re
diff --git a/lib/sqlalchemy/dialects/mysql/oursql.py b/lib/sqlalchemy/dialects/mysql/oursql.py
index f26bc4da2..ebc726482 100644
--- a/lib/sqlalchemy/dialects/mysql/oursql.py
+++ b/lib/sqlalchemy/dialects/mysql/oursql.py
@@ -1,5 +1,16 @@
"""Support for the MySQL database via the oursql adapter.
+OurSQL is available at:
+
+ http://packages.python.org/oursql/
+
+Connecting
+-----------
+
+Connect string format::
+
+ mysql+oursql://<user>:<password>@<host>[:<port>]/<dbname>
+
Character Sets
--------------
@@ -151,8 +162,8 @@ class MySQLDialect_oursql(MySQLDialect):
**kw
)
- def table_names(self, connection, schema):
- return MySQLDialect.table_names(self,
+ def get_table_names(self, connection, schema=None, **kw):
+ return MySQLDialect.get_table_names(self,
connection.connect().\
execution_options(_oursql_plain_query=True),
schema
diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py
index 5add45b21..1f73c6ef1 100644
--- a/lib/sqlalchemy/dialects/mysql/pyodbc.py
+++ b/lib/sqlalchemy/dialects/mysql/pyodbc.py
@@ -1,5 +1,24 @@
"""Support for the MySQL database via the pyodbc adapter.
+pyodbc is available at:
+
+ http://pypi.python.org/pypi/pyodbc/
+
+Connecting
+----------
+
+Connect string::
+
+ mysql+pyodbc://<username>:<password>@<dsnname>
+
+Limitations
+-----------
+
+The mysql-pyodbc dialect is subject to unresolved character encoding issues
+which exist within the currently available ODBC drivers
+(see http://code.google.com/p/pyodbc/issues/detail?id=25). Consider using
+OurSQL, MySQLdb, or MySQL Connector/Python instead.
+
"""
from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext
diff --git a/lib/sqlalchemy/dialects/mysql/zxjdbc.py b/lib/sqlalchemy/dialects/mysql/zxjdbc.py
index f4cf0013c..06d3e6616 100644
--- a/lib/sqlalchemy/dialects/mysql/zxjdbc.py
+++ b/lib/sqlalchemy/dialects/mysql/zxjdbc.py
@@ -6,6 +6,13 @@ JDBC Driver
The official MySQL JDBC driver is at
http://dev.mysql.com/downloads/connector/j/.
+Connecting
+----------
+
+Connect string format::
+
+ mysql+zxjdbc://<user>:<password>@<hostname>[:<port>]/<database>
+
Character Sets
--------------
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index f76edabf2..475730988 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -225,6 +225,8 @@ ischema_names = {
'CLOB' : CLOB,
'NCLOB' : NCLOB,
'TIMESTAMP' : TIMESTAMP,
+ 'TIMESTAMP WITH TIME ZONE' : TIMESTAMP,
+ 'INTERVAL DAY TO SECOND' : INTERVAL,
'RAW' : RAW,
'FLOAT' : FLOAT,
'DOUBLE PRECISION' : DOUBLE_PRECISION,
@@ -256,7 +258,13 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler):
"(%d)" % type_.second_precision or
"",
)
-
+
+ def visit_TIMESTAMP(self, type_):
+ if type_.timezone:
+ return "TIMESTAMP WITH TIME ZONE"
+ else:
+ return "TIMESTAMP"
+
def visit_DOUBLE_PRECISION(self, type_):
return self._generate_numeric(type_, "DOUBLE PRECISION")
@@ -278,7 +286,10 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler):
return "%(name)s(%(precision)s, %(scale)s)" % {'name':name,'precision': precision, 'scale' : scale}
def visit_VARCHAR(self, type_):
- return "VARCHAR(%(length)s)" % {'length' : type_.length}
+ if self.dialect.supports_char_length:
+ return "VARCHAR(%(length)s CHAR)" % {'length' : type_.length}
+ else:
+ return "VARCHAR(%(length)s)" % {'length' : type_.length}
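Illustrative DDL output under this change, as a sketch::

    # String(50) / VARCHAR(50) rendered by this type compiler:
    #   Oracle 9i and later (supports_char_length=True): VARCHAR(50 CHAR)
    #   Oracle 8 and earlier:                            VARCHAR(50)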
def visit_NVARCHAR(self, type_):
return "NVARCHAR2(%(length)s)" % {'length' : type_.length}
@@ -331,6 +342,11 @@ class OracleCompiler(compiler.SQLCompiler):
def visit_match_op(self, binary, **kw):
return "CONTAINS (%s, %s)" % (self.process(binary.left), self.process(binary.right))
+ def get_select_hint_text(self, byfroms):
+ return " ".join(
+ "/*+ %s */" % text for table, text in byfroms.items()
+ )
+
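A hedged sketch of the corresponding ``Select.with_hint()`` usage (table name
hypothetical); the ``%(name)s`` token is expanded to the quoted table or
alias name::

    from sqlalchemy import table, column, select

    t = table('mytable', column('x'))
    stmt = select([t]).with_hint(t, 'FULL(%(name)s)', 'oracle')
    # renders roughly: SELECT /*+ FULL(mytable) */ mytable.x FROM mytable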
def function_argspec(self, fn, **kw):
if len(fn.clauses) > 0:
return compiler.SQLCompiler.function_argspec(self, fn, **kw)
@@ -349,7 +365,9 @@ class OracleCompiler(compiler.SQLCompiler):
if self.dialect.use_ansi:
return compiler.SQLCompiler.visit_join(self, join, **kwargs)
else:
- return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True)
+ kwargs['asfrom'] = True
+ return self.process(join.left, **kwargs) + \
+ ", " + self.process(join.right, **kwargs)
def _get_nonansi_join_whereclause(self, froms):
clauses = []
@@ -381,14 +399,18 @@ class OracleCompiler(compiler.SQLCompiler):
def visit_sequence(self, seq):
return self.dialect.identifier_preparer.format_sequence(seq) + ".nextval"
- def visit_alias(self, alias, asfrom=False, **kwargs):
+ def visit_alias(self, alias, asfrom=False, ashint=False, **kwargs):
"""Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??"""
-
- if asfrom:
+
+ if asfrom or ashint:
alias_name = isinstance(alias.name, expression._generated_label) and \
self._truncated_identifier("alias", alias.name) or alias.name
-
- return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + self.preparer.format_alias(alias, alias_name)
+
+ if ashint:
+ return alias_name
+ elif asfrom:
+ return self.process(alias.original, asfrom=asfrom, **kwargs) + \
+ " " + self.preparer.format_alias(alias, alias_name)
else:
return self.process(alias.original, **kwargs)
@@ -561,7 +583,8 @@ class OracleDialect(default.DefaultDialect):
execution_ctx_cls = OracleExecutionContext
reflection_options = ('oracle_resolve_synonyms', )
-
+
+ supports_char_length = True
def __init__(self,
use_ansi=True,
@@ -576,6 +599,8 @@ class OracleDialect(default.DefaultDialect):
self.implicit_returning = self.server_version_info > (10, ) and \
self.__dict__.get('implicit_returning', True)
+ self.supports_char_length = self.server_version_info >= (9, )
+
if self.server_version_info < (9,):
self.colspecs = self.colspecs.copy()
self.colspecs.pop(sqltypes.Interval)
@@ -631,18 +656,6 @@ class OracleDialect(default.DefaultDialect):
def _get_default_schema_name(self, connection):
return self.normalize_name(connection.execute(u'SELECT USER FROM DUAL').scalar())
- def table_names(self, connection, schema):
- # note that table_names() isnt loading DBLINKed or synonym'ed tables
- if schema is None:
- schema = self.default_schema_name
- s = sql.text(
- "SELECT table_name FROM all_tables "
- "WHERE nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM', 'SYSAUX') "
- "AND OWNER = :owner "
- "AND IOT_NAME IS NULL")
- cursor = connection.execute(s, owner=self.denormalize_name(schema))
- return [self.normalize_name(row[0]) for row in cursor]
-
def _resolve_synonym(self, connection, desired_owner=None, desired_synonym=None, desired_table=None):
"""search for a local synonym matching the given desired owner/name.
@@ -712,7 +725,18 @@ class OracleDialect(default.DefaultDialect):
@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
schema = self.denormalize_name(schema or self.default_schema_name)
- return self.table_names(connection, schema)
+
+        # note that this isn't loading DBLINKed or synonym'ed tables
+ if schema is None:
+ schema = self.default_schema_name
+ s = sql.text(
+ "SELECT table_name FROM all_tables "
+ "WHERE nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM', 'SYSAUX') "
+ "AND OWNER = :owner "
+ "AND IOT_NAME IS NULL")
+ cursor = connection.execute(s, owner=schema)
+ return [self.normalize_name(row[0]) for row in cursor]
+
@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
@@ -742,11 +766,16 @@ class OracleDialect(default.DefaultDialect):
resolve_synonyms, dblink,
info_cache=info_cache)
columns = []
+ if self.supports_char_length:
+ char_length_col = 'char_length'
+ else:
+ char_length_col = 'data_length'
+
c = connection.execute(sql.text(
- "SELECT column_name, data_type, data_length, data_precision, data_scale, "
+ "SELECT column_name, data_type, %(char_length_col)s, data_precision, data_scale, "
"nullable, data_default FROM ALL_TAB_COLUMNS%(dblink)s "
"WHERE table_name = :table_name AND owner = :owner "
- "ORDER BY column_id" % {'dblink': dblink}),
+ "ORDER BY column_id" % {'dblink': dblink, 'char_length_col':char_length_col}),
table_name=table_name, owner=schema)
for row in c:
@@ -755,8 +784,10 @@ class OracleDialect(default.DefaultDialect):
if coltype == 'NUMBER' :
coltype = NUMBER(precision, scale)
- elif coltype=='CHAR' or coltype=='VARCHAR2':
+ elif coltype in ('VARCHAR2', 'NVARCHAR2', 'CHAR'):
coltype = self.ischema_names.get(coltype)(length)
+ elif 'WITH TIME ZONE' in coltype:
+ coltype = TIMESTAMP(timezone=True)
else:
coltype = re.sub(r'\(\d+\)', '', coltype)
try:
diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
index c6e9cea5d..91af6620b 100644
--- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py
+++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
@@ -207,11 +207,19 @@ class OracleCompiler_cx_oracle(OracleCompiler):
class OracleExecutionContext_cx_oracle(OracleExecutionContext):
def pre_exec(self):
- quoted_bind_names = getattr(self.compiled, '_quoted_bind_names', {})
+ quoted_bind_names = \
+ getattr(self.compiled, '_quoted_bind_names', None)
if quoted_bind_names:
+ if not self.dialect.supports_unicode_binds:
+ quoted_bind_names = \
+ dict(
+ (fromname, toname.encode(self.dialect.encoding))
+ for fromname, toname in
+ quoted_bind_names.items()
+ )
for param in self.parameters:
- for fromname, toname in self.compiled._quoted_bind_names.iteritems():
- param[toname.encode(self.dialect.encoding)] = param[fromname]
+ for fromname, toname in quoted_bind_names.items():
+ param[toname] = param[fromname]
del param[fromname]
if self.dialect.auto_setinputsizes:
@@ -219,14 +227,12 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
# on String, including that outparams/RETURNING
# breaks for varchars
self.set_input_sizes(quoted_bind_names,
- exclude_types=self.dialect._cx_oracle_string_types
+ exclude_types=self.dialect._cx_oracle_string_types
)
-
+
+ # if a single execute, check for outparams
if len(self.compiled_parameters) == 1:
- for key in self.compiled.binds:
- bindparam = self.compiled.binds[key]
- name = self.compiled.bind_names[bindparam]
- value = self.compiled_parameters[0][name]
+ for bindparam in self.compiled.binds.values():
if bindparam.isoutparam:
dbtype = bindparam.type.dialect_impl(self.dialect).\
get_dbapi_type(self.dialect.dbapi)
@@ -238,6 +244,7 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
" cx_oracle" %
(name, bindparam.type)
)
+ name = self.compiled.bind_names[bindparam]
self.out_parameters[name] = self.cursor.var(dbtype)
self.parameters[0][quoted_bind_names.get(name, name)] = \
self.out_parameters[name]
@@ -250,7 +257,10 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
def get_result_proxy(self):
if hasattr(self, 'out_parameters') and self.compiled.returning:
- returning_params = dict((k, v.getvalue()) for k, v in self.out_parameters.items())
+ returning_params = dict(
+ (k, v.getvalue())
+ for k, v in self.out_parameters.items()
+ )
return ReturningResultProxy(self, returning_params)
result = None
@@ -264,10 +274,11 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
result = base.ResultProxy(self)
if hasattr(self, 'out_parameters'):
- if self.compiled_parameters is not None and len(self.compiled_parameters) == 1:
+ if self.compiled_parameters is not None and \
+ len(self.compiled_parameters) == 1:
result.out_parameters = out_parameters = {}
- for bind, name in self.compiled.bind_names.iteritems():
+ for bind, name in self.compiled.bind_names.items():
if name in self.out_parameters:
type = bind.type
impl_type = type.dialect_impl(self.dialect)
@@ -291,12 +302,14 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
class OracleExecutionContext_cx_oracle_with_unicode(OracleExecutionContext_cx_oracle):
"""Support WITH_UNICODE in Python 2.xx.
- WITH_UNICODE allows cx_Oracle's Python 3 unicode handling behavior under Python 2.x.
- This mode in some cases disallows and in other cases silently
- passes corrupted data when non-Python-unicode strings (a.k.a. plain old Python strings)
- are passed as arguments to connect(), the statement sent to execute(), or any of the bind
- parameter keys or values sent to execute(). This optional context
- therefore ensures that all statements are passed as Python unicode objects.
+ WITH_UNICODE allows cx_Oracle's Python 3 unicode handling
+ behavior under Python 2.x. This mode in some cases disallows
+ and in other cases silently passes corrupted data when
+ non-Python-unicode strings (a.k.a. plain old Python strings)
+ are passed as arguments to connect(), the statement sent to execute(),
+ or any of the bind parameter keys or values sent to execute().
+ This optional context therefore ensures that all statements are
+ passed as Python unicode objects.
"""
def __init__(self, *arg, **kw):
@@ -373,17 +386,19 @@ class OracleDialect_cx_oracle(OracleDialect):
if hasattr(self.dbapi, 'version'):
cx_oracle_ver = tuple([int(x) for x in self.dbapi.version.split('.')])
- self.supports_unicode_binds = cx_oracle_ver >= (5, 0)
- self._cx_oracle_native_nvarchar = cx_oracle_ver >= (5, 0)
else:
- cx_oracle_ver = None
+ cx_oracle_ver = (0, 0, 0)
def types(*names):
- return set([getattr(self.dbapi, name, None) for name in names]).difference([None])
+ return set([
+ getattr(self.dbapi, name, None) for name in names
+ ]).difference([None])
self._cx_oracle_string_types = types("STRING", "UNICODE", "NCLOB", "CLOB")
self._cx_oracle_unicode_types = types("UNICODE", "NCLOB")
self._cx_oracle_binary_types = types("BFILE", "CLOB", "NCLOB", "BLOB")
+ self.supports_unicode_binds = cx_oracle_ver >= (5, 0)
+ self._cx_oracle_native_nvarchar = cx_oracle_ver >= (5, 0)
if cx_oracle_ver is None:
# this occurs in tests with mock DBAPIs
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index cbd92ccfe..bef2f1c61 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -78,7 +78,7 @@ from sqlalchemy import types as sqltypes
from sqlalchemy.types import INTEGER, BIGINT, SMALLINT, VARCHAR, \
CHAR, TEXT, FLOAT, NUMERIC, \
- TIMESTAMP, TIME, DATE, BOOLEAN
+ DATE, BOOLEAN
class REAL(sqltypes.Float):
__visit_name__ = "REAL"
@@ -101,6 +101,16 @@ class MACADDR(sqltypes.TypeEngine):
__visit_name__ = "MACADDR"
PGMacAddr = MACADDR
+class TIMESTAMP(sqltypes.TIMESTAMP):
+ def __init__(self, timezone=False, precision=None):
+ super(TIMESTAMP, self).__init__(timezone=timezone)
+ self.precision = precision
+
+class TIME(sqltypes.TIME):
+ def __init__(self, timezone=False, precision=None):
+ super(TIME, self).__init__(timezone=timezone)
+ self.precision = precision
+
class INTERVAL(sqltypes.TypeEngine):
__visit_name__ = 'INTERVAL'
def __init__(self, precision=None):
@@ -466,10 +476,16 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
return self.dialect.identifier_preparer.format_type(type_)
def visit_TIMESTAMP(self, type_):
- return "TIMESTAMP " + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+ return "TIMESTAMP%s %s" % (
+ getattr(type_, 'precision', None) and "(%d)" % type_.precision or "",
+ (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+ )
def visit_TIME(self, type_):
- return "TIME " + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+ return "TIME%s %s" % (
+ getattr(type_, 'precision', None) and "(%d)" % type_.precision or "",
+ (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+ )
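A sketch of the DDL these two methods emit, using the module's own types::

    from sqlalchemy.dialects.postgresql.base import TIMESTAMP, TIME

    # TIMESTAMP(timezone=True, precision=6) -> TIMESTAMP(6) WITH TIME ZONE
    # TIME()                                -> TIME WITHOUT TIME ZONE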
def visit_INTERVAL(self, type_):
if type_.precision is not None:
@@ -725,17 +741,6 @@ class PGDialect(default.DefaultDialect):
cursor = connection.execute(sql.text(query, bindparams=bindparams))
return bool(cursor.scalar())
- def table_names(self, connection, schema):
- result = connection.execute(
- sql.text(u"SELECT relname FROM pg_class c "
- "WHERE relkind = 'r' "
- "AND '%s' = (select nspname from pg_namespace n where n.oid = c.relnamespace) " %
- schema,
- typemap = {'relname':sqltypes.Unicode}
- )
- )
- return [row[0] for row in result]
-
def _get_server_version_info(self, connection):
v = connection.execute("select version()").scalar()
m = re.match('PostgreSQL (\d+)\.(\d+)(?:\.(\d+))?(?:devel)?', v)
@@ -805,8 +810,17 @@ class PGDialect(default.DefaultDialect):
current_schema = schema
else:
current_schema = self.default_schema_name
- table_names = self.table_names(connection, current_schema)
- return table_names
+
+ result = connection.execute(
+ sql.text(u"SELECT relname FROM pg_class c "
+ "WHERE relkind = 'r' "
+ "AND '%s' = (select nspname from pg_namespace n where n.oid = c.relnamespace) " %
+ current_schema,
+ typemap = {'relname':sqltypes.Unicode}
+ )
+ )
+ return [row[0] for row in result]
+
@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
@@ -877,39 +891,48 @@ class PGDialect(default.DefaultDialect):
# format columns
columns = []
for name, format_type, default, notnull, attnum, table_oid in rows:
- ## strip (30) from character varying(30)
- attype = re.search('([^\([]+)', format_type).group(1)
+ ## strip (5) from character varying(5), timestamp(5) with time zone, etc
+ attype = re.sub(r'\([\d,]+\)', '', format_type)
+
+ # strip '[]' from integer[], etc.
+ attype = re.sub(r'\[\]', '', attype)
+
nullable = not notnull
is_array = format_type.endswith('[]')
- try:
- charlen = re.search('\(([\d,]+)\)', format_type).group(1)
- except:
- charlen = False
- numericprec = False
- numericscale = False
+ charlen = re.search('\(([\d,]+)\)', format_type)
+ if charlen:
+ charlen = charlen.group(1)
+ kwargs = {}
+
if attype == 'numeric':
- if charlen is False:
- numericprec, numericscale = (None, None)
+ if charlen:
+ prec, scale = charlen.split(',')
+ args = (int(prec), int(scale))
else:
- numericprec, numericscale = charlen.split(',')
- charlen = False
+ args = ()
elif attype == 'double precision':
- numericprec, numericscale = (53, False)
- charlen = False
+ args = (53, )
elif attype == 'integer':
- numericprec, numericscale = (32, 0)
- charlen = False
- args = []
- for a in (charlen, numericprec, numericscale):
- if a is None:
- args.append(None)
- elif a is not False:
- args.append(int(a))
- kwargs = {}
- if attype == 'timestamp with time zone':
+ args = (32, 0)
+ elif attype in ('timestamp with time zone', 'time with time zone'):
kwargs['timezone'] = True
- elif attype == 'timestamp without time zone':
+ if charlen:
+ kwargs['precision'] = int(charlen)
+ args = ()
+ elif attype in ('timestamp without time zone', 'time without time zone', 'time'):
kwargs['timezone'] = False
+ if charlen:
+ kwargs['precision'] = int(charlen)
+ args = ()
+ elif attype in ('interval','interval year to month','interval day to second'):
+ if charlen:
+ kwargs['precision'] = int(charlen)
+ args = ()
+ elif charlen:
+ args = (int(charlen),)
+ else:
+ args = ()
+
if attype in self.ischema_names:
coltype = self.ischema_names[attype]
elif attype in enums:
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
index c239a3ee0..f21c9a558 100644
--- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py
+++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
@@ -12,7 +12,7 @@ Note that psycopg1 is **not** supported.
Connecting
----------
-URLs are of the form `postgresql+psycopg2://user@password@host:port/dbname[?key=value&key=value...]`.
+URLs are of the form `postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...]`.
psycopg2-specific keyword arguments which are accepted by :func:`~sqlalchemy.create_engine()` are:
@@ -34,6 +34,15 @@ Transactions
The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations.
+NOTICE logging
+---------------
+
+The psycopg2 dialect will log Postgresql NOTICE messages via the
+``sqlalchemy.dialects.postgresql`` logger::
+
+ import logging
+ logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO)
+
Per-Statement Execution Options
-------------------------------
@@ -46,8 +55,10 @@ The following per-statement execution options are respected:
"""
-import random, re
+import random
+import re
import decimal
+import logging
from sqlalchemy import util
from sqlalchemy import processors
@@ -59,6 +70,10 @@ from sqlalchemy.dialects.postgresql.base import PGDialect, PGCompiler, \
PGIdentifierPreparer, PGExecutionContext, \
ENUM, ARRAY
+
+logger = logging.getLogger('sqlalchemy.dialects.postgresql')
+
+
class _PGNumeric(sqltypes.Numeric):
def bind_processor(self, dialect):
return None
@@ -130,11 +145,22 @@ class PGExecutionContext_psycopg2(PGExecutionContext):
return self._connection.connection.cursor()
def get_result_proxy(self):
+ if logger.isEnabledFor(logging.INFO):
+ self._log_notices(self.cursor)
+
if self.__is_server_side:
return base.BufferedRowResultProxy(self)
else:
return base.ResultProxy(self)
+ def _log_notices(self, cursor):
+ for notice in cursor.connection.notices:
+ # NOTICE messages have a
+ # newline character at the end
+ logger.info(notice.rstrip())
+
+ cursor.connection.notices[:] = []
+
class PGCompiler_psycopg2(PGCompiler):
def visit_mod(self, binary, **kw):
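A sketch of the feature end to end (the DSN is hypothetical; the NOTICE shown is typical of older Postgresql versions)::

    import logging
    from sqlalchemy import create_engine

    logging.basicConfig()
    logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO)

    engine = create_engine('postgresql+psycopg2://scott:tiger@localhost/test')
    conn = engine.connect()
    # the server answers with "NOTICE: CREATE TABLE will create implicit
    # sequence ..."; _log_notices() drains connection.notices into the logger
    conn.execute("CREATE TABLE t (id SERIAL PRIMARY KEY)")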
@@ -190,7 +216,7 @@ class PGDialect_psycopg2(PGDialect):
return connect
else:
return base_on_connect
-
+
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
if 'port' in opts:
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
index d7637e71b..ca0a39136 100644
--- a/lib/sqlalchemy/dialects/sqlite/base.py
+++ b/lib/sqlalchemy/dialects/sqlite/base.py
@@ -331,6 +331,9 @@ class SQLiteDialect(default.DefaultDialect):
colspecs = colspecs
isolation_level = None
+ supports_cast = True
+ supports_default_values = True
+
def __init__(self, isolation_level=None, native_datetime=False, **kwargs):
default.DefaultDialect.__init__(self, **kwargs)
if isolation_level and isolation_level not in ('SERIALIZABLE',
@@ -345,6 +348,13 @@ class SQLiteDialect(default.DefaultDialect):
# conversions (and perhaps datetime/time as well on some
# hypothetical driver ?)
self.native_datetime = native_datetime
+
+ if self.dbapi is not None:
+ self.supports_default_values = \
+ self.dbapi.sqlite_version_info >= (3, 3, 8)
+ self.supports_cast = \
+ self.dbapi.sqlite_version_info >= (3, 2, 3)
+
def on_connect(self):
if self.isolation_level is not None:
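The same gating can be checked against the stdlib sqlite3 module directly (a sketch)::

    import sqlite3

    # mirrors what SQLiteDialect.__init__ now derives from the DBAPI
    supports_default_values = sqlite3.sqlite_version_info >= (3, 3, 8)
    supports_cast = sqlite3.sqlite_version_info >= (3, 2, 3)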
@@ -360,8 +370,9 @@ class SQLiteDialect(default.DefaultDialect):
return connect
else:
return None
-
- def table_names(self, connection, schema):
+
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
if schema is not None:
qschema = self.identifier_preparer.quote_identifier(schema)
master = '%s.sqlite_master' % qschema
@@ -401,10 +412,6 @@ class SQLiteDialect(default.DefaultDialect):
return (row is not None)
@reflection.cache
- def get_table_names(self, connection, schema=None, **kw):
- return self.table_names(connection, schema)
-
- @reflection.cache
def get_view_names(self, connection, schema=None, **kw):
if schema is not None:
qschema = self.identifier_preparer.quote_identifier(schema)
diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py
index b48abbb7d..575cb37f2 100644
--- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py
+++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py
@@ -187,20 +187,15 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
def __init__(self, **kwargs):
SQLiteDialect.__init__(self, **kwargs)
- def vers(num):
- return tuple([int(x) for x in num.split('.')])
+
if self.dbapi is not None:
sqlite_ver = self.dbapi.version_info
- if sqlite_ver < (2, 1, '3'):
+ if sqlite_ver < (2, 1, 3):
util.warn(
("The installed version of pysqlite2 (%s) is out-dated "
"and will cause errors in some cases. Version 2.1.3 "
"or greater is recommended.") %
'.'.join([str(subver) for subver in sqlite_ver]))
- if self.dbapi.sqlite_version_info < (3, 3, 8):
- self.supports_default_values = False
- self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3"))
-
@classmethod
def dbapi(cls):
diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py
index bdaab2eb7..6719b422b 100644
--- a/lib/sqlalchemy/dialects/sybase/base.py
+++ b/lib/sqlalchemy/dialects/sybase/base.py
@@ -1,6 +1,9 @@
-# sybase.py
-# Copyright (C) 2007 Fisch Asset Management AG http://www.fam.ch
-# Coding: Alexander Houben alexander.houben@thor-solutions.ch
+# sybase/base.py
+# Copyright (C) 2010 Michael Bayer mike_mp@zzzcomputing.com
+# get_select_precolumns(), limit_clause() implementation
+# copyright (C) 2007 Fisch Asset Management
+# AG http://www.fam.ch, with coding by Alexander Houben
+# alexander.houben@thor-solutions.ch
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -277,6 +280,9 @@ class SybaseSQLCompiler(compiler.SQLCompiler):
s += "START AT %s " % (select._offset+1,)
return s
+ def get_from_hint_text(self, table, text):
+ return text
+
def limit_clause(self, select):
# Limit in sybase is after the select keyword
return ""
@@ -310,8 +316,6 @@ class SybaseDDLCompiler(compiler.DDLCompiler):
"columns in order to generate DDL")
seq_col = column.table._autoincrement_column
-
-
# install a IDENTITY Sequence if we have an implicit IDENTITY column
if seq_col is column:
sequence = isinstance(column.default, sa_schema.Sequence) and column.default
@@ -382,9 +386,6 @@ class SybaseDialect(default.DefaultDialect):
def get_table_names(self, connection, schema=None, **kw):
if schema is None:
schema = self.default_schema_name
- return self.table_names(connection, schema)
-
- def table_names(self, connection, schema):
result = connection.execute(
text("select sysobjects.name from sysobjects, sysusers "
diff --git a/lib/sqlalchemy/dialects/sybase/pyodbc.py b/lib/sqlalchemy/dialects/sybase/pyodbc.py
index 19ad70fe8..e34f2605c 100644
--- a/lib/sqlalchemy/dialects/sybase/pyodbc.py
+++ b/lib/sqlalchemy/dialects/sybase/pyodbc.py
@@ -29,12 +29,34 @@ Currently *not* supported are::
"""
from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext
-from sqlalchemy.connectors.pyodbc import PyODBCConnector, PyODBCNumeric
-from sqlalchemy import types as sqltypes, util
+from sqlalchemy.connectors.pyodbc import PyODBCConnector
+import decimal
+from sqlalchemy import types as sqltypes, util, processors

-class _SybNumeric_pyodbc(PyODBCNumeric):
- convert_large_decimals_to_string = False
+class _SybNumeric_pyodbc(sqltypes.Numeric):
+ """Turns Decimals with adjusted() < -6 into floats.
+
+ It's not yet known how to get decimals with many
+ significant digits or very large adjusted() into Sybase
+ via pyodbc.
+
+ """
+
+ def bind_processor(self, dialect):
+ super_process = super(_SybNumeric_pyodbc, self).bind_processor(dialect)
+
+ def process(value):
+ if self.asdecimal and \
+ isinstance(value, decimal.Decimal):
+ if value.adjusted() < -6:
+ return processors.to_float(value)
+
+ if super_process:
+ return super_process(value)
+ else:
+ return value
+ return process
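For instance, a small Decimal falls below the adjusted() threshold and is handed to pyodbc as a float (a sketch of the conversion only, not the full bind processor)::

    import decimal

    value = decimal.Decimal('0.0000001234')
    assert value.adjusted() == -7            # < -6, so it is converted
    assert isinstance(float(value), float)   # what pyodbc receives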
class SybaseExecutionContext_pyodbc(SybaseExecutionContext):
def set_ddl_autocommit(self, connection, value):
@@ -43,8 +65,6 @@ class SybaseExecutionContext_pyodbc(SybaseExecutionContext):
else:
connection.autocommit = False
-
-
class SybaseDialect_pyodbc(PyODBCConnector, SybaseDialect):
execution_ctx_cls = SybaseExecutionContext_pyodbc
diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py
index 9a53545df..9b3dbedd8 100644
--- a/lib/sqlalchemy/engine/__init__.py
+++ b/lib/sqlalchemy/engine/__init__.py
@@ -107,10 +107,11 @@ def create_engine(*args, **kwargs):
arguments sent as options to the dialect and resulting Engine.
The URL is a string in the form
- ``dialect://user:password@host/dbname[?key=value..]``, where
- ``dialect`` is a name such as ``mysql``, ``oracle``, ``postgresql``,
- etc. Alternatively, the URL can be an instance of
- :class:`~sqlalchemy.engine.url.URL`.
+ ``dialect+driver://user:password@host/dbname[?key=value..]``, where
+ ``dialect`` is a database name such as ``mysql``, ``oracle``,
+ ``postgresql``, etc., and ``driver`` the name of a DBAPI, such as
+ ``psycopg2``, ``pyodbc``, ``cx_oracle``, etc. Alternatively,
+ the URL can be an instance of :class:`~sqlalchemy.engine.url.URL`.
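For example (hypothetical DSNs)::

    from sqlalchemy import create_engine

    # explicit driver selection using the dialect+driver form
    engine = create_engine('postgresql+psycopg2://scott:tiger@localhost/test')
    # omitting the driver falls back to the dialect's default DBAPI
    engine = create_engine('postgresql://scott:tiger@localhost/test')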
`**kwargs` takes a wide variety of options which are routed
towards their appropriate components. Arguments may be
@@ -120,11 +121,11 @@ def create_engine(*args, **kwargs):
that are common to most ``create_engine()`` usage.
:param assert_unicode: Deprecated. A warning is raised in all cases when a non-Unicode
- object is passed when SQLAlchemy would coerce into an encoding
- (note: but **not** when the DBAPI handles unicode objects natively).
- To suppress or raise this warning to an
- error, use the Python warnings filter documented at:
- http://docs.python.org/library/warnings.html
+ object is passed when SQLAlchemy would coerce into an encoding
+ (note: but **not** when the DBAPI handles unicode objects natively).
+ To suppress or raise this warning to an
+ error, use the Python warnings filter documented at:
+ http://docs.python.org/library/warnings.html
:param connect_args: a dictionary of options which will be
passed directly to the DBAPI's ``connect()`` method as
@@ -144,11 +145,6 @@ def create_engine(*args, **kwargs):
connections. Usage of this function causes connection
parameters specified in the URL argument to be bypassed.
- :param logging_name: String identifier which will be used within
- the "name" field of logging records generated within the
- "sqlalchemy.engine" logger. Defaults to a hexstring of the
- object's id.
-
:param echo=False: if True, the Engine will log all statements
as well as a repr() of their parameter lists to the engines
logger, which defaults to sys.stdout. The ``echo`` attribute of
@@ -158,11 +154,6 @@ def create_engine(*args, **kwargs):
controls a Python logger; see :ref:`dbengine_logging` for
information on how to configure logging directly.
- :param pool_logging_name: String identifier which will be used within
- the "name" field of logging records generated within the
- "sqlalchemy.pool" logger. Defaults to a hexstring of the object's
- id.
-
:param echo_pool=False: if True, the connection pool will log
all checkouts/checkins to the logging stream, which defaults to
sys.stdout. This flag ultimately controls a Python logger; see
@@ -178,6 +169,20 @@ def create_engine(*args, **kwargs):
characters. If less than 6, labels are generated as
"_(counter)". If ``None``, the value of
``dialect.max_identifier_length`` is used instead.
+
+ :param listeners: A list of one or more
+ :class:`~sqlalchemy.interfaces.PoolListener` objects which will
+ receive connection pool events.
+
+ :param logging_name: String identifier which will be used within
+ the "name" field of logging records generated within the
+ "sqlalchemy.engine" logger. Defaults to a hexstring of the
+ object's id.
+
+ :param max_overflow=10: the number of connections to allow in
+ connection pool "overflow", that is connections that can be
+ opened above and beyond the pool_size setting, which defaults
+ to five. This is only used with :class:`~sqlalchemy.pool.QueuePool`.
:param module=None: used by database implementations which
support multiple DBAPI modules, this is a reference to a DBAPI2
@@ -199,10 +204,10 @@ def create_engine(*args, **kwargs):
instantiate the pool in this case, you just indicate what type
of pool to be used.
- :param max_overflow=10: the number of connections to allow in
- connection pool "overflow", that is connections that can be
- opened above and beyond the pool_size setting, which defaults
- to five. this is only used with :class:`~sqlalchemy.pool.QueuePool`.
+ :param pool_logging_name: String identifier which will be used within
+ the "name" field of logging records generated within the
+ "sqlalchemy.pool" logger. Defaults to a hexstring of the object's
+ id.
:param pool_size=5: the number of connections to keep open
inside the connection pool. This used with :class:`~sqlalchemy.pool.QueuePool` as
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 095f7a960..dc42ed957 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -1420,6 +1420,9 @@ class Engine(Connectable, log.Identified):
"""
Connects a :class:`~sqlalchemy.pool.Pool` and :class:`~sqlalchemy.engine.base.Dialect`
together to provide a source of database connectivity and behavior.
+
+ An :class:`Engine` object is instantiated publicly using the :func:`~sqlalchemy.create_engine`
+ function.
"""
@@ -1569,7 +1572,7 @@ class Engine(Connectable, log.Identified):
if not schema:
schema = self.dialect.default_schema_name
try:
- return self.dialect.table_names(conn, schema)
+ return self.dialect.get_table_names(conn, schema)
finally:
if connection is None:
conn.close()
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 720edf66c..6fb0a14a5 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -381,7 +381,10 @@ class DefaultExecutionContext(base.ExecutionContext):
self.execution_options = self.execution_options.union(connection._execution_options)
self.cursor = self.create_cursor()
-
+ @util.memoized_property
+ def is_crud(self):
+ return self.isinsert or self.isupdate or self.isdelete
+
@util.memoized_property
def should_autocommit(self):
autocommit = self.execution_options.get('autocommit',
diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py
index 3226b0efd..dde49e232 100644
--- a/lib/sqlalchemy/ext/compiler.py
+++ b/lib/sqlalchemy/ext/compiler.py
@@ -165,7 +165,7 @@ A big part of using the compiler extension is subclassing SQLAlchemy expression
def compiles(class_, *specs):
def decorate(fn):
- existing = getattr(class_, '_compiler_dispatcher', None)
+ existing = class_.__dict__.get('_compiler_dispatcher', None)
if not existing:
existing = _dispatcher()
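The distinction matters for subclasses; roughly (a sketch)::

    class Base(object):
        _compiler_dispatcher = object()   # stand-in for the real dispatcher

    class Sub(Base):
        pass

    # getattr() would find the inherited dispatcher and mutate it in place;
    # __dict__.get() sees only Sub itself, so Sub receives its own dispatcher
    assert getattr(Sub, '_compiler_dispatcher', None) is not None
    assert Sub.__dict__.get('_compiler_dispatcher', None) is None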
diff --git a/lib/sqlalchemy/ext/declarative.py b/lib/sqlalchemy/ext/declarative.py
index 775efbff1..1f4658b60 100644
--- a/lib/sqlalchemy/ext/declarative.py
+++ b/lib/sqlalchemy/ext/declarative.py
@@ -507,7 +507,7 @@ Mapped instances then make usage of
from sqlalchemy.schema import Table, Column, MetaData
from sqlalchemy.orm import synonym as _orm_synonym, mapper, comparable_property, class_mapper
from sqlalchemy.orm.interfaces import MapperProperty
-from sqlalchemy.orm.properties import PropertyLoader, ColumnProperty
+from sqlalchemy.orm.properties import RelationshipProperty, ColumnProperty
from sqlalchemy.orm.util import _is_mapped_class
from sqlalchemy import util, exceptions
from sqlalchemy.sql import util as sql_util
@@ -531,31 +531,41 @@ def instrument_declarative(cls, registry, metadata):
def _as_declarative(cls, classname, dict_):
- # doing it this way enables these attributes to be descriptors,
- # see below...
- get_mapper_args = '__mapper_args__' in dict_
- get_table_args = '__table_args__' in dict_
-
# dict_ will be a dictproxy, which we can't write to, and we need to!
dict_ = dict(dict_)
column_copies = dict()
-
+ unmapped_mixins = False
for base in cls.__bases__:
names = dir(base)
if not _is_mapped_class(base):
+ unmapped_mixins = True
for name in names:
- obj = getattr(base,name)
+ obj = getattr(base, name, None)
if isinstance(obj, Column):
+ if obj.foreign_keys:
+ raise exceptions.InvalidRequestError(
+ "Columns with foreign keys to other columns "
+ "are not allowed on declarative mixins at this time."
+ )
+ dict_[name] = column_copies[obj] = obj.copy()
- get_mapper_args = get_mapper_args or getattr(base,'__mapper_args__',None)
- get_table_args = get_table_args or getattr(base,'__table_args__',None)
- tablename = getattr(base,'__tablename__',None)
- if tablename:
- # subtle: if tablename is a descriptor here, we actually
- # put the wrong value in, but it serves as a marker to get
- # the right value value...
- dict_['__tablename__']=tablename
+ elif isinstance(obj, RelationshipProperty):
+ raise exceptions.InvalidRequestError(
+ "relationships are not allowed on "
+ "declarative mixins at this time.")
+
+ # doing it this way enables these attributes to be descriptors
+ get_mapper_args = '__mapper_args__' in dict_
+ get_table_args = '__table_args__' in dict_
+ if unmapped_mixins:
+ get_mapper_args = get_mapper_args or getattr(cls, '__mapper_args__', None)
+ get_table_args = get_table_args or getattr(cls, '__table_args__', None)
+ tablename = getattr(cls, '__tablename__', None)
+ if tablename:
+ # subtle: if tablename is a descriptor here, we actually
+ # put the wrong value in, but it serves as a marker to get
+ # the right value...
+ dict_['__tablename__'] = tablename
# now that we know whether or not to get these, get them from the class
# if we should, enabling them to be decorators
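A sketch of what the mixin support accepts (class names are hypothetical)::

    from sqlalchemy import Column, Integer, String
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class NameMixin(object):
        # plain, unmapped mixin; each Column is copied onto the subclass
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    class Widget(NameMixin, Base):
        __tablename__ = 'widget'

    # disallowed at this time: a mixin Column bearing a ForeignKey, or a
    # relationship() placed on the mixin; both raise InvalidRequestError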
@@ -777,7 +787,7 @@ def _deferred_relationship(cls, prop):
prop.parent, arg, n.args[0], cls))
return return_cls
- if isinstance(prop, PropertyLoader):
+ if isinstance(prop, RelationshipProperty):
for attr in ('argument', 'order_by', 'primaryjoin', 'secondaryjoin',
'secondary', '_foreign_keys', 'remote_side'):
v = getattr(prop, attr)
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py
new file mode 100644
index 000000000..78e3f5953
--- /dev/null
+++ b/lib/sqlalchemy/ext/horizontal_shard.py
@@ -0,0 +1,125 @@
+# horizontal_shard.py
+# Copyright (C) the SQLAlchemy authors and contributors
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""Horizontal sharding support.
+
+Defines a rudimentary 'horizontal sharding' system which allows a Session to
+distribute queries and persistence operations across multiple databases.
+
+For a usage example, see the :ref:`examples_sharding` example included in
+the source distribution.
+
+"""
+
+import sqlalchemy.exceptions as sa_exc
+from sqlalchemy import util
+from sqlalchemy.orm.session import Session
+from sqlalchemy.orm.query import Query
+
+__all__ = ['ShardedSession', 'ShardedQuery']
+
+
+class ShardedSession(Session):
+ def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, **kwargs):
+ """Construct a ShardedSession.
+
+ :param shard_chooser: A callable which, passed a Mapper, a mapped instance, and possibly a
+ SQL clause, returns a shard ID. This ID may be based on the
+ attributes present within the object, or on some round-robin
+ scheme. If the scheme is based on a selection, it should set
+ whatever state is needed on the instance to mark it as
+ participating in that shard in the future.
+
+ :param id_chooser: A callable, passed a query and a tuple of identity values, which
+ should return a list of shard ids where the ID might reside. The
+ databases will be queried in the order of this listing.
+
+ :param query_chooser: For a given Query, returns the list of shard_ids where the query
+ should be issued. Results from all shards will be combined
+ into a single listing.
+
+ :param shards: A dictionary of string shard names to :class:`~sqlalchemy.engine.base.Engine`
+ objects.
+
+ """
+ super(ShardedSession, self).__init__(**kwargs)
+ self.shard_chooser = shard_chooser
+ self.id_chooser = id_chooser
+ self.query_chooser = query_chooser
+ self.__binds = {}
+ self._mapper_flush_opts = {'connection_callable':self.connection}
+ self._query_cls = ShardedQuery
+ if shards is not None:
+ for k in shards:
+ self.bind_shard(k, shards[k])
+
+ def connection(self, mapper=None, instance=None, shard_id=None, **kwargs):
+ if shard_id is None:
+ shard_id = self.shard_chooser(mapper, instance)
+
+ if self.transaction is not None:
+ return self.transaction.connection(mapper, shard_id=shard_id)
+ else:
+ return self.get_bind(mapper,
+ shard_id=shard_id,
+ instance=instance).contextual_connect(**kwargs)
+
+ def get_bind(self, mapper, shard_id=None, instance=None, clause=None, **kw):
+ if shard_id is None:
+ shard_id = self.shard_chooser(mapper, instance, clause=clause)
+ return self.__binds[shard_id]
+
+ def bind_shard(self, shard_id, bind):
+ self.__binds[shard_id] = bind
+
+class ShardedQuery(Query):
+ def __init__(self, *args, **kwargs):
+ super(ShardedQuery, self).__init__(*args, **kwargs)
+ self.id_chooser = self.session.id_chooser
+ self.query_chooser = self.session.query_chooser
+ self._shard_id = None
+
+ def set_shard(self, shard_id):
+ """return a new query, limited to a single shard ID.
+
+ all subsequent operations with the returned query will
+ be against the single shard regardless of other state.
+ """
+
+ q = self._clone()
+ q._shard_id = shard_id
+ return q
+
+ def _execute_and_instances(self, context):
+ if self._shard_id is not None:
+ result = self.session.connection(
+ mapper=self._mapper_zero(),
+ shard_id=self._shard_id).execute(context.statement, self._params)
+ return self.instances(result, context)
+ else:
+ partial = []
+ for shard_id in self.query_chooser(self):
+ result = self.session.connection(
+ mapper=self._mapper_zero(),
+ shard_id=shard_id).execute(context.statement, self._params)
+ partial = partial + list(self.instances(result, context))
+
+ # if some kind of in memory 'sorting'
+ # were done, this is where it would happen
+ return iter(partial)
+
+ def get(self, ident, **kwargs):
+ if self._shard_id is not None:
+ return super(ShardedQuery, self).get(ident)
+ else:
+ ident = util.to_list(ident)
+ for shard_id in self.id_chooser(self, ident):
+ o = self.set_shard(shard_id).get(ident, **kwargs)
+ if o is not None:
+ return o
+ else:
+ return None
+
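A minimal usage sketch (shard names, engines and routing logic are all hypothetical)::

    from sqlalchemy import create_engine
    from sqlalchemy.ext.horizontal_shard import ShardedSession

    shards = {
        'east': create_engine('sqlite:///east.db'),
        'west': create_engine('sqlite:///west.db'),
    }

    def shard_chooser(mapper, instance, clause=None):
        # route on some attribute of the instance
        if instance is None or getattr(instance, 'region', None) == 'east':
            return 'east'
        return 'west'

    def id_chooser(query, ident):
        # a given primary key could live on either shard
        return ['east', 'west']

    def query_chooser(query):
        # no criterion analysis here; query every shard
        return ['east', 'west']

    session = ShardedSession(shard_chooser, id_chooser, query_chooser,
                             shards=shards)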
diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py
index db0bd2a4e..0d2c3ae5d 100644
--- a/lib/sqlalchemy/ext/orderinglist.py
+++ b/lib/sqlalchemy/ext/orderinglist.py
@@ -1,67 +1,92 @@
"""A custom list that manages index/position information for its children.
-``orderinglist`` is a custom list collection implementation for mapped
-relationships that keeps an arbitrary "position" attribute on contained objects in
-sync with each object's position in the Python list.
-
-The collection acts just like a normal Python ``list``, with the added
-behavior that as you manipulate the list (via ``insert``, ``pop``, assignment,
-deletion, what have you), each of the objects it contains is updated as needed
-to reflect its position. This is very useful for managing ordered relationships
-which have a user-defined, serialized order::
-
- >>> from sqlalchemy import MetaData, Table, Column, Integer, String, ForeignKey
- >>> from sqlalchemy.orm import mapper, relationship
- >>> from sqlalchemy.ext.orderinglist import ordering_list
-
-A simple model of users their "top 10" things::
-
- >>> metadata = MetaData()
- >>> users = Table('users', metadata,
- ... Column('id', Integer, primary_key=True))
- >>> blurbs = Table('user_top_ten_list', metadata,
- ... Column('id', Integer, primary_key=True),
- ... Column('user_id', Integer, ForeignKey('users.id')),
- ... Column('position', Integer),
- ... Column('blurb', String(80)))
- >>> class User(object):
- ... pass
- ...
- >>> class Blurb(object):
- ... def __init__(self, blurb):
- ... self.blurb = blurb
- ...
- >>> mapper(User, users, properties={
- ... 'topten': relationship(Blurb, collection_class=ordering_list('position'),
- ... order_by=[blurbs.c.position])})
- <Mapper ...>
- >>> mapper(Blurb, blurbs)
- <Mapper ...>
-
-Acts just like a regular list::
-
- >>> u = User()
- >>> u.topten.append(Blurb('Number one!'))
- >>> u.topten.append(Blurb('Number two!'))
-
-But the ``.position`` attibute is set automatically behind the scenes::
-
- >>> assert [blurb.position for blurb in u.topten] == [0, 1]
-
-The objects will be renumbered automaticaly after any list-changing operation,
-for example an ``insert()``::
-
- >>> u.topten.insert(1, Blurb('I am the new Number Two.'))
- >>> assert [blurb.position for blurb in u.topten] == [0, 1, 2]
- >>> assert u.topten[1].blurb == 'I am the new Number Two.'
- >>> assert u.topten[1].position == 1
-
-Numbering and serialization are both highly configurable. See the docstrings
-in this module and the main SQLAlchemy documentation for more information and
-examples.
-
-The :class:`~sqlalchemy.ext.orderinglist.ordering_list` factory function is the
-ORM-compatible constructor for `OrderingList` instances.
+:author: Jason Kirtland
+
+``orderinglist`` is a helper for mutable ordered relationships. It will intercept
+list operations performed on a relationship collection and automatically
+synchronize changes in list position with an attribute on the related objects.
+(See :ref:`advdatamapping_entitycollections` for more information on the general pattern.)
+
+Example: Two tables that store slides in a presentation. Each slide
+has a number of bullet points, displayed in order by the 'position'
+column on the bullets table. These bullets can be inserted and re-ordered
+by your end users, and you need to update the 'position' column of all
+affected rows when changes are made.
+
+.. sourcecode:: python+sql
+
+ slides_table = Table('Slides', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('name', String))
+
+ bullets_table = Table('Bullets', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('slide_id', Integer, ForeignKey('Slides.id')),
+ Column('position', Integer),
+ Column('text', String))
+
+ class Slide(object):
+ pass
+ class Bullet(object):
+ pass
+
+ mapper(Slide, slides_table, properties={
+ 'bullets': relationship(Bullet, order_by=[bullets_table.c.position])
+ })
+ mapper(Bullet, bullets_table)
+
+The standard relationship mapping will produce a list-like attribute on each Slide
+containing all related Bullets, but coping with changes in ordering is totally
+your responsibility. If you insert a Bullet into that list, there is no
+magic: it won't have a position attribute unless you assign it one, and
+you'll need to manually renumber all the subsequent Bullets in the list to
+accommodate the insert.
+
+An ``orderinglist`` can automate this and manage the 'position' attribute on all
+related bullets for you.
+
+.. sourcecode:: python+sql
+
+ mapper(Slide, slides_table, properties={
+ 'bullets': relationship(Bullet,
+ collection_class=ordering_list('position'),
+ order_by=[bullets_table.c.position])
+ })
+ mapper(Bullet, bullets_table)
+
+ s = Slide()
+ s.bullets.append(Bullet())
+ s.bullets.append(Bullet())
+ s.bullets[1].position
+ # 1
+ s.bullets.insert(1, Bullet())
+ s.bullets[2].position
+ # 2
+
+Use the ``ordering_list`` function to set up the ``collection_class`` on relationships
+(as in the mapper example above). This implementation depends on the list
+starting in the proper order, so be SURE to put an order_by on your relationship.
+
+.. warning:: ``ordering_list`` only provides limited functionality when a primary
+ key column or unique column is the target of the sort. Since changing the order of
+ entries often means that two rows must trade values, this is not possible when
+ the value is constrained by a primary key or unique constraint, since one of the rows
+ would temporarily have to point to a third available value so that the other row
+ could take its old value. ``ordering_list`` doesn't do any of this for you,
+ nor does SQLAlchemy itself.
+
+``ordering_list`` takes the name of the related object's ordering attribute as
+an argument. By default, the zero-based integer index of the object's
+position in the ``ordering_list`` is synchronized with the ordering attribute:
+index 0 will get position 0, index 1 position 1, etc. To start numbering at 1
+or some other integer, provide ``count_from=1``.
+
+Ordering values are not limited to incrementing integers. Almost any scheme
+can be implemented by supplying a custom ``ordering_func`` that maps a Python list
+index to any value you require.
+
+
+
"""
from sqlalchemy.orm.collections import collection
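A sketch of the two customization hooks described in the docstring above (the stepped numbering function is hypothetical)::

    from sqlalchemy.ext.orderinglist import ordering_list

    # number positions starting from 1 rather than 0
    factory = ordering_list('position', count_from=1)

    # or supply a custom ordering function mapping list index -> stored value
    def stepped(index, collection):
        return index * 10

    factory = ordering_list('position', ordering_func=stepped)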
@@ -288,7 +313,3 @@ class OrderingList(list):
func.__doc__ = getattr(list, func_name).__doc__
del func_name, func
-if __name__ == '__main__':
- import doctest
- doctest.testmod(optionflags=doctest.ELLIPSIS)
-
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
index 3337287d8..206c8d0c2 100644
--- a/lib/sqlalchemy/orm/__init__.py
+++ b/lib/sqlalchemy/orm/__init__.py
@@ -83,6 +83,8 @@ __all__ = (
'eagerload_all',
'extension',
'join',
+ 'joinedload',
+ 'joinedload_all',
'lazyload',
'mapper',
'make_transient',
@@ -96,6 +98,8 @@ __all__ = (
'relation',
'scoped_session',
'sessionmaker',
+ 'subqueryload',
+ 'subqueryload_all',
'synonym',
'undefer',
'undefer_group',
@@ -226,24 +230,32 @@ def relationship(argument, secondary=None, **kwargs):
Available cascades are:
- ``save-update`` - cascade the "add()" operation (formerly
- known as save() and update())
+ * ``save-update`` - cascade the :meth:`~sqlalchemy.orm.session.Session.add`
+ operation. This cascade applies both to future and
+ past calls to :meth:`~sqlalchemy.orm.session.Session.add`,
+ meaning new items added to a collection or scalar relationship
+ get placed into the same session as that of the parent, and
+ also applies to items which have been removed from this
+ relationship but are still part of unflushed history.
- ``merge`` - cascade the "merge()" operation
+ * ``merge`` - cascade the :meth:`~sqlalchemy.orm.session.Session.merge`
+ operation
- ``expunge`` - cascade the "expunge()" operation
+ * ``expunge`` - cascade the :meth:`~sqlalchemy.orm.session.Session.expunge`
+ operation
- ``delete`` - cascade the "delete()" operation
+ * ``delete`` - cascade the :meth:`~sqlalchemy.orm.session.Session.delete`
+ operation
- ``delete-orphan`` - if an item of the child's type with no
+ * ``delete-orphan`` - if an item of the child's type with no
parent is detected, mark it for deletion. Note that this
option prevents a pending item of the child's class from being
persisted without a parent present.
- ``refresh-expire`` - cascade the expire() and refresh()
- operations
+ * ``refresh-expire`` - cascade the :meth:`~sqlalchemy.orm.session.Session.expire`
+ and :meth:`~sqlalchemy.orm.session.Session.refresh` operations
- ``all`` - shorthand for "save-update,merge, refresh-expire,
+ * ``all`` - shorthand for "save-update, merge, refresh-expire,
expunge, delete"
:param collection_class:
@@ -263,7 +275,6 @@ def relationship(argument, secondary=None, **kwargs):
change the value used in the operation.
:param foreign_keys:
-
a list of columns which are to be used as "foreign key" columns.
this parameter should be used in conjunction with explicit
``primaryjoin`` and ``secondaryjoin`` (if needed) arguments, and
@@ -276,7 +287,7 @@ def relationship(argument, secondary=None, **kwargs):
the table-defined foreign keys.
:param innerjoin=False:
- when ``True``, eager loads will use an inner join to join
+ when ``True``, joined eager loads will use an inner join to join
against related tables instead of an outer join. The purpose
of this option is strictly one of performance, as inner joins
generally perform better than outer joins. This flag can
@@ -287,33 +298,47 @@ def relationship(argument, secondary=None, **kwargs):
:param join_depth:
when non-``None``, an integer value indicating how many levels
- deep eagerload joins should be constructed on a self-referring
- or cyclical relationship. The number counts how many times the
- same Mapper shall be present in the loading condition along a
- particular join branch. When left at its default of ``None``,
- eager loads will automatically stop chaining joins when they
- encounter a mapper which is already higher up in the chain.
-
- :param lazy=(True|False|None|'dynamic'):
- specifies how the related items should be loaded. Values include:
-
- True - items should be loaded lazily when the property is first
- accessed.
-
- False - items should be loaded "eagerly" in the same query as
- that of the parent, using a JOIN or LEFT OUTER JOIN.
-
- None - no loading should occur at any time. This is to support
- "write-only" attributes, or attributes which are
- populated in some manner specific to the application.
-
- 'dynamic' - a ``DynaLoader`` will be attached, which returns a
- ``Query`` object for all read operations. The
- dynamic- collection supports only ``append()`` and
- ``remove()`` for write operations; changes to the
- dynamic property will not be visible until the data
- is flushed to the database.
-
+ deep "eager" loaders should join on a self-referring or cyclical
+ relationship. The number counts how many times the same Mapper
+ shall be present in the loading condition along a particular join
+ branch. When left at its default of ``None``, eager loaders
+ will stop chaining when they encounter the same target mapper
+ which is already higher up in the chain. This option applies
+ both to joined- and subquery- eager loaders.
+
+ :param lazy=('select'|'joined'|'subquery'|'noload'|'dynamic'): specifies
+ how the related items should be loaded. Values include:
+
+ * 'select' - items should be loaded lazily when the property is first
+ accessed.
+
+ * 'joined' - items should be loaded "eagerly" in the same query as
+ that of the parent, using a JOIN or LEFT OUTER JOIN.
+
+ * 'subquery' - items should be loaded "eagerly" within the same
+ query as that of the parent, using a second SQL statement
+ which issues a JOIN to a subquery of the original
+ statement.
+
+ * 'noload' - no loading should occur at any time. This is to
+ support "write-only" attributes, or attributes which are
+ populated in some manner specific to the application.
+
+ * 'dynamic' - the attribute will return a pre-configured
+ :class:`~sqlalchemy.orm.query.Query` object for all read
+ operations, onto which further filtering operations can be
+ applied before iterating the results. The dynamic
+ collection supports a limited set of mutation operations,
+ allowing ``append()`` and ``remove()``. Changes to the
+ collection will not be visible until flushed
+ to the database, where it is then refetched upon iteration.
+
+ * True - a synonym for 'select'
+
+ * False - a synonym for 'joined'
+
+ * None - a synonym for 'noload'
+
:param order_by:
indicates the ordering that should be applied when loading these
items.
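A sketch of the ``lazy`` settings described above, against a hypothetical mapping::

    from sqlalchemy import MetaData, Table, Column, Integer, ForeignKey
    from sqlalchemy.orm import mapper, relationship

    metadata = MetaData()
    users = Table('users', metadata,
        Column('id', Integer, primary_key=True))
    addresses = Table('addresses', metadata,
        Column('id', Integer, primary_key=True),
        Column('user_id', Integer, ForeignKey('users.id')))

    class User(object): pass
    class Address(object): pass

    mapper(Address, addresses)
    mapper(User, users, properties={
        # 'select' is the default; 'joined', 'subquery', 'noload' and
        # 'dynamic' select the other loader strategies described above
        'addresses': relationship(Address, lazy='subquery'),
    })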
@@ -904,76 +929,148 @@ def extension(ext):
return ExtensionOption(ext)
@sa_util.accepts_a_list_as_starargs(list_deprecation='deprecated')
-def eagerload(*keys, **kw):
+def joinedload(*keys, **kw):
"""Return a ``MapperOption`` that will convert the property of the given
- name into an eager load.
+ name into a joined eager load.
+
+ .. note:: This function is known as :func:`eagerload` in all versions
+ of SQLAlchemy prior to version 0.6beta3, including the 0.5 and 0.4 series.
+ :func:`eagerload` will remain available for
+ the foreseeable future in order to enable cross-compatibility.
Used with :meth:`~sqlalchemy.orm.query.Query.options`.
examples::
- # eagerload the "orders" colleciton on "User"
- query(User).options(eagerload(User.orders))
+ # joined-load the "orders" colleciton on "User"
+ query(User).options(joinedload(User.orders))
- # eagerload the "keywords" collection on each "Item",
+ # joined-load the "keywords" collection on each "Item",
# but not the "items" collection on "Order" - those
# remain lazily loaded.
- query(Order).options(eagerload(Order.items, Item.keywords))
+ query(Order).options(joinedload(Order.items, Item.keywords))
- # to eagerload across both, use eagerload_all()
- query(Order).options(eagerload_all(Order.items, Item.keywords))
+ # to joined-load across both, use joinedload_all()
+ query(Order).options(joinedload_all(Order.items, Item.keywords))
- :func:`eagerload` also accepts a keyword argument `innerjoin=True` which
+ :func:`joinedload` also accepts a keyword argument `innerjoin=True` which
indicates using an inner join instead of an outer::
- query(Order).options(eagerload(Order.user, innerjoin=True))
+ query(Order).options(joinedload(Order.user, innerjoin=True))
- Note that the join created by :func:`eagerload` is aliased such that
- no other aspects of the query will affect what it loads. To use eager
+ Note that the join created by :func:`joinedload` is aliased such that
+ no other aspects of the query will affect what it loads. To use joined eager
loading with a join that is constructed manually using :meth:`~sqlalchemy.orm.query.Query.join`
or :func:`~sqlalchemy.orm.join`, see :func:`contains_eager`.
+ See also: :func:`subqueryload`, :func:`lazyload`
+
"""
innerjoin = kw.pop('innerjoin', None)
if innerjoin is not None:
return (
- strategies.EagerLazyOption(keys, lazy=False),
+ strategies.EagerLazyOption(keys, lazy='joined'),
strategies.EagerJoinOption(keys, innerjoin)
)
else:
- return strategies.EagerLazyOption(keys, lazy=False)
+ return strategies.EagerLazyOption(keys, lazy='joined')
@sa_util.accepts_a_list_as_starargs(list_deprecation='deprecated')
-def eagerload_all(*keys, **kw):
+def joinedload_all(*keys, **kw):
"""Return a ``MapperOption`` that will convert all properties along the
- given dot-separated path into an eager load.
+ given dot-separated path into a joined eager load.
+
+ .. note:: This function is known as :func:`eagerload_all` in all versions
+ of SQLAlchemy prior to version 0.6beta3, including the 0.5 and 0.4 series.
+ :func:`eagerload_all` will remain available for
+ the foreseeable future in order to enable cross-compatibility.
Used with :meth:`~sqlalchemy.orm.query.Query.options`.
For example::
- query.options(eagerload_all('orders.items.keywords'))...
+ query.options(joinedload_all('orders.items.keywords'))...
will set all of 'orders', 'orders.items', and 'orders.items.keywords' to
- load in one eager load.
+ load in one joined eager load.
Individual descriptors are accepted as arguments as well::
- query.options(eagerload_all(User.orders, Order.items, Item.keywords))
+ query.options(joinedload_all(User.orders, Order.items, Item.keywords))
The keyword arguments accept a flag `innerjoin=True|False` which will
override the value of the `innerjoin` flag specified on the relationship().
+ See also: :func:`subqueryload_all`, :func:`lazyload`
+
"""
innerjoin = kw.pop('innerjoin', None)
if innerjoin is not None:
return (
- strategies.EagerLazyOption(keys, lazy=False, chained=True),
+ strategies.EagerLazyOption(keys, lazy='joined', chained=True),
strategies.EagerJoinOption(keys, innerjoin, chained=True)
)
else:
- return strategies.EagerLazyOption(keys, lazy=False, chained=True)
+ return strategies.EagerLazyOption(keys, lazy='joined', chained=True)
+
+def eagerload(*args, **kwargs):
+ """A synonym for :func:`joinedload()`."""
+ return joinedload(*args, **kwargs)
+
+def eagerload_all(*args, **kwargs):
+ """A synonym for :func:`joinedload_all()`"""
+ return joinedload_all(*args, **kwargs)
+
+def subqueryload(*keys):
+ """Return a ``MapperOption`` that will convert the property
+ of the given name into a subquery eager load.
+
+ .. note:: This function is new as of SQLAlchemy version 0.6beta3.
+
+ Used with :meth:`~sqlalchemy.orm.query.Query.options`.
+ examples::
+
+ # subquery-load the "orders" colleciton on "User"
+ query(User).options(subqueryload(User.orders))
+
+ # subquery-load the "keywords" collection on each "Item",
+ # but not the "items" collection on "Order" - those
+ # remain lazily loaded.
+ query(Order).options(subqueryload(Order.items, Item.keywords))
+
+ # to subquery-load across both, use subqueryload_all()
+ query(Order).options(subqueryload_all(Order.items, Item.keywords))
+
+ See also: :func:`joinedload`, :func:`lazyload`
+
+ """
+ return strategies.EagerLazyOption(keys, lazy="subquery")
+
+def subqueryload_all(*keys):
+ """Return a ``MapperOption`` that will convert all properties along the
+ given dot-separated path into a subquery eager load.
+
+ .. note:: This function is new as of SQLAlchemy version 0.6beta3.
+
+ Used with :meth:`~sqlalchemy.orm.query.Query.options`.
+
+ For example::
+
+ query.options(subqueryload_all('orders.items.keywords'))...
+
+ will set all of 'orders', 'orders.items', and 'orders.items.keywords' to
+ load in one subquery eager load.
+
+ Individual descriptors are accepted as arguments as well::
+
+ query.options(subqueryload_all(User.orders, Order.items, Item.keywords))
+
+ See also: :func:`joinedload_all`, :func:`lazyload`
+
+ """
+ return strategies.EagerLazyOption(keys, lazy="subquery", chained=True)
+
@sa_util.accepts_a_list_as_starargs(list_deprecation='deprecated')
def lazyload(*keys):
"""Return a ``MapperOption`` that will convert the property of the given
@@ -981,6 +1078,8 @@ def lazyload(*keys):
Used with :meth:`~sqlalchemy.orm.query.Query.options`.
+ See also: :func:`eagerload`, :func:`subqueryload`
+
"""
return strategies.EagerLazyOption(keys, lazy=True)
@@ -990,6 +1089,8 @@ def noload(*keys):
Used with :meth:`~sqlalchemy.orm.query.Query.options`.
+ See also: :func:`lazyload`, :func:`eagerload`, :func:`subqueryload`
+
"""
return strategies.EagerLazyOption(keys, lazy=None)
@@ -1041,7 +1142,7 @@ def contains_eager(*keys, **kwargs):
raise exceptions.ArgumentError("Invalid kwargs for contains_eager: %r" % kwargs.keys())
return (
- strategies.EagerLazyOption(keys, lazy=False, propagate_to_loaders=False),
+ strategies.EagerLazyOption(keys, lazy='joined', propagate_to_loaders=False),
strategies.LoadEagerFromAliasOption(keys, alias=alias)
)
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index 412fabc23..ca9676469 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -482,6 +482,12 @@ class MapperProperty(object):
self.do_init()
self._compile_finished = True
+ @property
+ def class_attribute(self):
+ """Return the class-bound descriptor corresponding to this MapperProperty."""
+
+ return getattr(self.parent.class_, self.key)
+
def do_init(self):
"""Perform subclass-specific initialization post-mapper-creation steps.
@@ -623,7 +629,7 @@ class StrategizedProperty(MapperProperty):
"""
- def __get_context_strategy(self, context, path):
+ def _get_context_strategy(self, context, path):
cls = context.attributes.get(("loaderstrategy", _reduce_path(path)), None)
if cls:
try:
@@ -645,11 +651,11 @@ class StrategizedProperty(MapperProperty):
return strategy
def setup(self, context, entity, path, adapter, **kwargs):
- self.__get_context_strategy(context, path + (self.key,)).\
+ self._get_context_strategy(context, path + (self.key,)).\
setup_query(context, entity, path, adapter, **kwargs)
def create_row_processor(self, context, path, mapper, row, adapter):
- return self.__get_context_strategy(context, path + (self.key,)).\
+ return self._get_context_strategy(context, path + (self.key,)).\
create_row_processor(context, path, mapper, row, adapter)
def do_init(self):
@@ -734,33 +740,13 @@ class PropertyOption(MapperOption):
self._process(query, False)
def _process(self, query, raiseerr):
- paths, mappers = self.__get_paths(query, raiseerr)
+ paths, mappers = self._get_paths(query, raiseerr)
if paths:
self.process_query_property(query, paths, mappers)
def process_query_property(self, query, paths, mappers):
pass
- def __find_entity(self, query, mapper, raiseerr):
- from sqlalchemy.orm.util import _class_to_mapper, _is_aliased_class
-
- if _is_aliased_class(mapper):
- searchfor = mapper
- isa = False
- else:
- searchfor = _class_to_mapper(mapper)
- isa = True
-
- for ent in query._mapper_entities:
- if searchfor is ent.path_entity or (isa and searchfor.common_parent(ent.path_entity)):
- return ent
- else:
- if raiseerr:
- raise sa_exc.ArgumentError("Can't find entity %s in Query. Current list: %r"
- % (searchfor, [str(m.path_entity) for m in query._entities]))
- else:
- return None
-
def __getstate__(self):
d = self.__dict__.copy()
d['key'] = ret = []
@@ -782,7 +768,32 @@ class PropertyOption(MapperOption):
state['key'] = tuple(ret)
self.__dict__ = state
- def __get_paths(self, query, raiseerr):
+ def _find_entity(self, query, mapper, raiseerr):
+ from sqlalchemy.orm.util import _class_to_mapper, _is_aliased_class
+
+ if _is_aliased_class(mapper):
+ searchfor = mapper
+ isa = False
+ else:
+ searchfor = _class_to_mapper(mapper)
+ isa = True
+
+ for ent in query._mapper_entities:
+ if searchfor is ent.path_entity or (
+ isa and
+ searchfor.common_parent(ent.path_entity)):
+ return ent
+ else:
+ if raiseerr:
+ raise sa_exc.ArgumentError(
+ "Can't find entity %s in Query. Current list: %r"
+ % (searchfor, [
+ str(m.path_entity) for m in query._entities
+ ]))
+ else:
+ return None
+
+ def _get_paths(self, query, raiseerr):
path = None
entity = None
l = []
@@ -792,61 +803,71 @@ class PropertyOption(MapperOption):
# with an existing path
current_path = list(query._current_path)
- if self.mapper:
- entity = self.__find_entity(query, self.mapper, raiseerr)
- mapper = entity.mapper
- path_element = entity.path_entity
-
+ tokens = []
for key in util.to_list(self.key):
if isinstance(key, basestring):
- tokens = key.split('.')
+ tokens += key.split('.')
else:
- tokens = [key]
- for token in tokens:
- if isinstance(token, basestring):
- if not entity:
- entity = query._entity_zero()
- path_element = entity.path_entity
- mapper = entity.mapper
- mappers.append(mapper)
- prop = mapper.get_property(token, resolve_synonyms=True, raiseerr=raiseerr)
- key = token
- elif isinstance(token, PropComparator):
- prop = token.property
- if not entity:
- entity = self.__find_entity(query, token.parententity, raiseerr)
- if not entity:
- return [], []
- path_element = entity.path_entity
- mappers.append(prop.parent)
- key = prop.key
- else:
- raise sa_exc.ArgumentError("mapper option expects string key "
- "or list of attributes")
-
- if current_path and key == current_path[1]:
- current_path = current_path[2:]
- continue
+ tokens += [key]
+
+ for token in tokens:
+ if isinstance(token, basestring):
+ if not entity:
+ if current_path:
+ if current_path[1] == token:
+ current_path = current_path[2:]
+ continue
- if prop is None:
- return [], []
-
- path = build_path(path_element, prop.key, path)
- l.append(path)
- if getattr(token, '_of_type', None):
- path_element = mapper = token._of_type
- else:
- path_element = mapper = getattr(prop, 'mapper', None)
-
- if path_element:
- path_element = path_element
+ entity = query._entity_zero()
+ path_element = entity.path_entity
+ mapper = entity.mapper
+ mappers.append(mapper)
+ prop = mapper.get_property(
+ token,
+ resolve_synonyms=True,
+ raiseerr=raiseerr)
+ key = token
+ elif isinstance(token, PropComparator):
+ prop = token.property
+ if not entity:
+ if current_path:
+ if current_path[0:2] == [token.parententity, prop.key]:
+ current_path = current_path[2:]
+ continue
+
+ entity = self._find_entity(
+ query,
+ token.parententity,
+ raiseerr)
+ if not entity:
+ return [], []
+ path_element = entity.path_entity
+ mapper = entity.mapper
+ mappers.append(prop.parent)
+ key = prop.key
+ else:
+ raise sa_exc.ArgumentError("mapper option expects string key "
+ "or list of attributes")
+
+ if prop is None:
+ return [], []
+
+ path = build_path(path_element, prop.key, path)
+ l.append(path)
+ if getattr(token, '_of_type', None):
+ path_element = mapper = token._of_type
+ else:
+ path_element = mapper = getattr(prop, 'mapper', None)
+
+ if path_element:
+ path_element = path_element
# if current_path tokens remain, then
# we didn't have an exact path match.
if current_path:
return [], []
-
+
return l, mappers
class AttributeExtension(object):
@@ -894,16 +915,15 @@ class StrategizedOption(PropertyOption):
for an operation by a StrategizedProperty.
"""
- def is_chained(self):
- return False
+ is_chained = False
def process_query_property(self, query, paths, mappers):
- # __get_context_strategy may receive the path in terms of
+ # _get_context_strategy may receive the path in terms of
# a base mapper - e.g. options(eagerload_all(Company.employees, Engineer.machines))
# in the polymorphic tests leads to "(Person, 'machines')" in
# the path due to the mechanics of how the eager strategy builds
# up the path
- if self.is_chained():
+ if self.is_chained:
for path in paths:
query._attributes[("loaderstrategy", _reduce_path(path))] = \
self.get_strategy_class()
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 7de02d3f0..ec21b27d6 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -391,19 +391,15 @@ class RelationshipProperty(StrategizedProperty):
self.comparator_factory = comparator_factory or RelationshipProperty.Comparator
self.comparator = self.comparator_factory(self, None)
util.set_creation_order(self)
-
+
if strategy_class:
self.strategy_class = strategy_class
- elif self.lazy == 'dynamic':
+ elif self.lazy == 'dynamic':
from sqlalchemy.orm import dynamic
self.strategy_class = dynamic.DynaLoader
- elif self.lazy is False:
- self.strategy_class = strategies.EagerLoader
- elif self.lazy is None:
- self.strategy_class = strategies.NoLoader
else:
- self.strategy_class = strategies.LazyLoader
-
+ self.strategy_class = strategies.factory(self.lazy)
+
self._reverse_property = set()
if cascade is not False:
@@ -411,8 +407,12 @@ class RelationshipProperty(StrategizedProperty):
else:
self.cascade = CascadeOptions("save-update, merge")
- if self.passive_deletes == 'all' and ("delete" in self.cascade or "delete-orphan" in self.cascade):
- raise sa_exc.ArgumentError("Can't set passive_deletes='all' in conjunction with 'delete' or 'delete-orphan' cascade")
+ if self.passive_deletes == 'all' and \
+ ("delete" in self.cascade or
+ "delete-orphan" in self.cascade):
+ raise sa_exc.ArgumentError(
+ "Can't set passive_deletes='all' in conjunction "
+ "with 'delete' or 'delete-orphan' cascade")
self.order_by = order_by
@@ -420,7 +420,9 @@ class RelationshipProperty(StrategizedProperty):
if self.back_populates:
if backref:
- raise sa_exc.ArgumentError("backref and back_populates keyword arguments are mutually exclusive")
+ raise sa_exc.ArgumentError(
+ "backref and back_populates keyword arguments "
+ "are mutually exclusive")
self.backref = None
else:
self.backref = backref
@@ -467,7 +469,10 @@ class RelationshipProperty(StrategizedProperty):
return op(self, *other, **kwargs)
def of_type(self, cls):
- return RelationshipProperty.Comparator(self.property, self.mapper, cls, adapter=self.adapter)
+ return RelationshipProperty.Comparator(
+ self.property,
+ self.mapper,
+ cls, adapter=self.adapter)
def in_(self, other):
raise NotImplementedError("in_() not yet supported for relationships. For a "
@@ -480,11 +485,21 @@ class RelationshipProperty(StrategizedProperty):
if self.property.direction in [ONETOMANY, MANYTOMANY]:
return ~self._criterion_exists()
else:
- return _orm_annotate(self.property._optimized_compare(None, adapt_source=self.adapter))
+ return _orm_annotate(
+ self.property._optimized_compare(
+ None,
+ adapt_source=self.adapter)
+ )
elif self.property.uselist:
- raise sa_exc.InvalidRequestError("Can't compare a collection to an object or collection; use contains() to test for membership.")
+ raise sa_exc.InvalidRequestError(
+ "Can't compare a collection to an object or "
+ "collection; use contains() to test for membership.")
else:
- return _orm_annotate(self.property._optimized_compare(other, adapt_source=self.adapter))
+ return _orm_annotate(
+ self.property._optimized_compare(
+ other,
+ adapt_source=self.adapter)
+ )
def _criterion_exists(self, criterion=None, **kwargs):
if getattr(self, '_of_type', None):
@@ -508,7 +523,10 @@ class RelationshipProperty(StrategizedProperty):
source_selectable = None
pj, sj, source, dest, secondary, target_adapter = \
- self.property._create_joins(dest_polymorphic=True, dest_selectable=to_selectable, source_selectable=source_selectable)
+ self.property._create_joins(
+ dest_polymorphic=True,
+ dest_selectable=to_selectable,
+ source_selectable=source_selectable)
for k in kwargs:
crit = self.property.mapper.class_manager[k] == kwargs[k]
@@ -517,9 +535,9 @@ class RelationshipProperty(StrategizedProperty):
else:
criterion = criterion & crit
- # annotate the *local* side of the join condition, in the case of pj + sj this
- # is the full primaryjoin, in the case of just pj its the local side of
- # the primaryjoin.
+ # annotate the *local* side of the join condition, in the case
+ # of pj + sj this is the full primaryjoin, in the case of just
+ # pj its the local side of the primaryjoin.
if sj is not None:
j = _orm_annotate(pj) & sj
else:
@@ -529,8 +547,10 @@ class RelationshipProperty(StrategizedProperty):
# limit this adapter to annotated only?
criterion = target_adapter.traverse(criterion)
- # only have the "joined left side" of what we return be subject to Query adaption. The right
- # side of it is used for an exists() subquery and should not correlate or otherwise reach out
+ # only have the "joined left side" of what we
+ # return be subject to Query adaption. The right
+ # side of it is used for an exists() subquery and
+ # should not correlate or otherwise reach out
# to anything in the enclosing query.
if criterion is not None:
criterion = criterion._annotate({'_halt_adapt': True})
@@ -541,18 +561,25 @@ class RelationshipProperty(StrategizedProperty):
def any(self, criterion=None, **kwargs):
if not self.property.uselist:
- raise sa_exc.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().")
+ raise sa_exc.InvalidRequestError(
+ "'any()' not implemented for scalar "
+ "attributes. Use has()."
+ )
return self._criterion_exists(criterion, **kwargs)
def has(self, criterion=None, **kwargs):
if self.property.uselist:
- raise sa_exc.InvalidRequestError("'has()' not implemented for collections. Use any().")
+ raise sa_exc.InvalidRequestError(
+ "'has()' not implemented for collections. "
+ "Use any().")
return self._criterion_exists(criterion, **kwargs)
def contains(self, other, **kwargs):
if not self.property.uselist:
- raise sa_exc.InvalidRequestError("'contains' not implemented for scalar attributes. Use ==")
+ raise sa_exc.InvalidRequestError(
+ "'contains' not implemented for scalar "
+ "attributes. Use ==")
clause = self.property._optimized_compare(other, adapt_source=self.adapter)
if self.property.secondaryjoin is not None:
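As a usage sketch of the comparison rules above (User, Address, someaddress
and someuser are hypothetical placeholders; User.addresses is a collection,
Address.user a scalar many-to-one):

    # collection-valued: use any() / contains(); == raises InvalidRequestError
    session.query(User).filter(
        User.addresses.any(Address.email == 'ed@example.com'))
    session.query(User).filter(User.addresses.contains(someaddress))

    # scalar-valued: use has() / ==; any() raises InvalidRequestError
    session.query(Address).filter(Address.user.has(User.name == 'ed'))
    session.query(Address).filter(Address.user == someuser)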
@@ -563,7 +590,6 @@ class RelationshipProperty(StrategizedProperty):
def __negated_contains_or_equals(self, other):
if self.property.direction == MANYTOONE:
state = attributes.instance_state(other)
- strategy = self.property._get_strategy(strategies.LazyLoader)
def state_bindparam(state, col):
o = state.obj() # strong ref
@@ -575,14 +601,20 @@ class RelationshipProperty(StrategizedProperty):
else:
return col
- if strategy.use_get:
+ if self.property._use_get:
return sql.and_(*[
sql.or_(
adapt(x) != state_bindparam(state, y),
adapt(x) == None)
for (x, y) in self.property.local_remote_pairs])
- criterion = sql.and_(*[x==y for (x, y) in zip(self.property.mapper.primary_key, self.property.mapper.primary_key_from_instance(other))])
+        criterion = sql.and_(*[
+            x == y for (x, y) in zip(
+                self.property.mapper.primary_key,
+                self.property.mapper.primary_key_from_instance(other))
+        ])
return ~self._criterion_exists(criterion)
def __ne__(self, other):
@@ -592,7 +624,9 @@ class RelationshipProperty(StrategizedProperty):
else:
return self._criterion_exists()
elif self.property.uselist:
- raise sa_exc.InvalidRequestError("Can't compare a collection to an object or collection; use contains() to test for membership.")
+ raise sa_exc.InvalidRequestError(
+ "Can't compare a collection to an object or "
+ "collection; use contains() to test for membership.")
else:
return self.__negated_contains_or_equals(other)
@@ -629,7 +663,13 @@ class RelationshipProperty(StrategizedProperty):
def __str__(self):
return str(self.parent.class_.__name__) + "." + self.key
- def merge(self, session, source_state, source_dict, dest_state, dest_dict, load, _recursive):
+ def merge(self,
+ session,
+ source_state,
+ source_dict,
+ dest_state,
+ dest_dict,
+ load, _recursive):
if load:
# TODO: no test coverage for recursive check
for r in self._reverse_property:
@@ -702,6 +742,8 @@ class RelationshipProperty(StrategizedProperty):
else:
instances = state.value_as_iterable(self.key, passive=passive)
+ skip_pending = type_ == 'refresh-expire' and 'delete-orphan' not in self.cascade
+
if instances:
for c in instances:
if c is not None and \
@@ -717,12 +759,17 @@ class RelationshipProperty(StrategizedProperty):
str(self.parent.class_),
str(c.__class__)
))
+ instance_state = attributes.instance_state(c)
+
+ if skip_pending and not instance_state.key:
+ continue
+
visited_instances.add(c)
# cascade using the mapper local to this
# object, so that its individual properties are located
- instance_mapper = object_mapper(c)
- yield (c, instance_mapper, attributes.instance_state(c))
+ instance_mapper = instance_state.manager.mapper
+ yield (c, instance_mapper, instance_state)
def _add_reverse_property(self, key):
other = self.mapper._get_property(key)
@@ -870,7 +917,10 @@ class RelationshipProperty(StrategizedProperty):
]
if not eq_pairs:
- if not self.viewonly and criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=self._foreign_keys, any_operator=True):
+ if not self.viewonly and criterion_as_pairs(
+ self.primaryjoin,
+ consider_as_foreign_keys=self._foreign_keys,
+ any_operator=True):
raise sa_exc.ArgumentError("Could not locate any equated, locally "
"mapped column pairs for primaryjoin condition '%s' on relationship %s. "
"For more relaxed rules on join conditions, the relationship may be "
@@ -891,11 +941,24 @@ class RelationshipProperty(StrategizedProperty):
self.synchronize_pairs = eq_pairs
if self.secondaryjoin is not None:
- sq_pairs = criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=self._foreign_keys, any_operator=self.viewonly)
- sq_pairs = [(l, r) for l, r in sq_pairs if (self._col_is_part_of_mappings(l) and self._col_is_part_of_mappings(r)) or r in self._foreign_keys]
+ sq_pairs = criterion_as_pairs(
+ self.secondaryjoin,
+ consider_as_foreign_keys=self._foreign_keys,
+ any_operator=self.viewonly)
+
+ sq_pairs = [
+ (l, r)
+ for l, r in sq_pairs
+ if (self._col_is_part_of_mappings(l) and
+ self._col_is_part_of_mappings(r)) or
+ r in self._foreign_keys
+ ]
if not sq_pairs:
- if not self.viewonly and criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=self._foreign_keys, any_operator=True):
+ if not self.viewonly and criterion_as_pairs(
+ self.secondaryjoin,
+ consider_as_foreign_keys=self._foreign_keys,
+ any_operator=True):
raise sa_exc.ArgumentError("Could not locate any equated, locally mapped "
"column pairs for secondaryjoin condition '%s' on relationship %s. "
"For more relaxed rules on join conditions, the "
@@ -1004,17 +1067,29 @@ class RelationshipProperty(StrategizedProperty):
if self.secondaryjoin is not None:
eq_pairs += self.secondary_synchronize_pairs
else:
- eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=self._foreign_keys, any_operator=True)
+ eq_pairs = criterion_as_pairs(
+ self.primaryjoin,
+ consider_as_foreign_keys=self._foreign_keys,
+ any_operator=True)
if self.secondaryjoin is not None:
- eq_pairs += criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=self._foreign_keys, any_operator=True)
- eq_pairs = [(l, r) for l, r in eq_pairs if self._col_is_part_of_mappings(l) and self._col_is_part_of_mappings(r)]
+ eq_pairs += criterion_as_pairs(
+ self.secondaryjoin,
+ consider_as_foreign_keys=self._foreign_keys,
+ any_operator=True)
+
+ eq_pairs = [
+ (l, r) for l, r in eq_pairs
+ if self._col_is_part_of_mappings(l) and
+ self._col_is_part_of_mappings(r)
+ ]
if self.direction is MANYTOONE:
self.local_remote_pairs = [(r, l) for l, r in eq_pairs]
else:
self.local_remote_pairs = eq_pairs
elif self.remote_side:
- raise sa_exc.ArgumentError("remote_side argument is redundant against more detailed _local_remote_side argument.")
+ raise sa_exc.ArgumentError("remote_side argument is redundant "
+ "against more detailed _local_remote_side argument.")
for l, r in self.local_remote_pairs:
@@ -1028,16 +1103,20 @@ class RelationshipProperty(StrategizedProperty):
"Specify remote_side argument to indicate which column lazy "
"join condition should bind." % (r, self.mapper))
- self.local_side, self.remote_side = [util.ordered_column_set(x) for x in zip(*list(self.local_remote_pairs))]
+ self.local_side, self.remote_side = [
+ util.ordered_column_set(x) for x in
+ zip(*list(self.local_remote_pairs))]
def _assert_is_primary(self):
if not self.is_primary() and \
- not mapper.class_mapper(self.parent.class_, compile=False)._get_property(self.key, raiseerr=False):
+ not mapper.class_mapper(self.parent.class_, compile=False).\
+ _get_property(self.key, raiseerr=False):
raise sa_exc.ArgumentError("Attempting to assign a new relationship '%s' to "
"a non-primary mapper on class '%s'. New relationships can only be "
"added to the primary mapper, i.e. the very first "
- "mapper created for class '%s' " % (self.key, self.parent.class_.__name__, self.parent.class_.__name__))
+ "mapper created for class '%s' " %
+ (self.key, self.parent.class_.__name__, self.parent.class_.__name__))
def _generate_backref(self):
if not self.is_primary():
@@ -1093,17 +1172,27 @@ class RelationshipProperty(StrategizedProperty):
def _post_init(self):
self.logger.info("%s setup primary join %s", self, self.primaryjoin)
self.logger.info("%s setup secondary join %s", self, self.secondaryjoin)
- self.logger.info("%s synchronize pairs [%s]", self, ",".join("(%s => %s)" % (l, r) for l, r in self.synchronize_pairs))
- self.logger.info("%s secondary synchronize pairs [%s]", self, ",".join(("(%s => %s)" % (l, r) for l, r in self.secondary_synchronize_pairs or [])))
- self.logger.info("%s local/remote pairs [%s]", self, ",".join("(%s / %s)" % (l, r) for l, r in self.local_remote_pairs))
+ self.logger.info("%s synchronize pairs [%s]", self,
+ ",".join("(%s => %s)" % (l, r) for l, r in self.synchronize_pairs))
+ self.logger.info("%s secondary synchronize pairs [%s]", self,
+ ",".join(("(%s => %s)" % (l, r) for l, r in self.secondary_synchronize_pairs or [])))
+ self.logger.info("%s local/remote pairs [%s]", self,
+ ",".join("(%s / %s)" % (l, r) for l, r in self.local_remote_pairs))
self.logger.info("%s relationship direction %s", self, self.direction)
if self.uselist is None:
self.uselist = self.direction is not MANYTOONE
-
+
if not self.viewonly:
self._dependency_processor = dependency.create_dependency_processor(self)
-
+
+ @util.memoized_property
+ def _use_get(self):
+ """memoize the 'use_get' attribute of this RelationshipLoader's lazyloader."""
+
+ strategy = self._get_strategy(strategies.LazyLoader)
+ return strategy.use_get
+
def _refers_to_parent_table(self):
for c, f in self.synchronize_pairs:
if c.table is f.table:
@@ -1114,7 +1203,9 @@ class RelationshipProperty(StrategizedProperty):
def _is_self_referential(self):
return self.mapper.common_parent(self.parent)
- def _create_joins(self, source_polymorphic=False, source_selectable=None, dest_polymorphic=False, dest_selectable=None, of_type=None):
+ def _create_joins(self, source_polymorphic=False,
+ source_selectable=None, dest_polymorphic=False,
+ dest_selectable=None, of_type=None):
if source_selectable is None:
if source_polymorphic and self.parent.with_polymorphic:
source_selectable = self.parent._with_polymorphic_selectable
@@ -1157,7 +1248,10 @@ class RelationshipProperty(StrategizedProperty):
secondary = secondary.alias()
primary_aliasizer = ClauseAdapter(secondary)
if dest_selectable is not None:
- secondary_aliasizer = ClauseAdapter(dest_selectable, equivalents=self.mapper._equivalent_columns).chain(primary_aliasizer)
+ secondary_aliasizer = \
+ ClauseAdapter(dest_selectable,
+ equivalents=self.mapper._equivalent_columns).\
+ chain(primary_aliasizer)
else:
secondary_aliasizer = primary_aliasizer
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 682aa2bbf..e98ad8937 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -84,6 +84,7 @@ class Query(object):
_params = util.frozendict()
_attributes = util.frozendict()
_with_options = ()
+ _with_hints = ()
def __init__(self, entities, session=None):
self.session = session
@@ -114,7 +115,8 @@ class Query(object):
mapper, selectable, is_aliased_class = _entity_info(entity)
if not is_aliased_class and mapper.with_polymorphic:
with_polymorphic = mapper._with_polymorphic_mappers
- self.__mapper_loads_polymorphically_with(mapper,
+ if mapper.mapped_table not in self._polymorphic_adapters:
+ self.__mapper_loads_polymorphically_with(mapper,
sql_util.ColumnAdapter(selectable, mapper._equivalent_columns))
adapter = None
elif is_aliased_class:
@@ -133,7 +135,7 @@ class Query(object):
self._polymorphic_adapters[m.mapped_table] = self._polymorphic_adapters[m.local_table] = adapter
def _set_select_from(self, *obj):
-
+
fa = []
for from_obj in obj:
if isinstance(from_obj, expression._SelectBaseMixin):
@@ -142,9 +144,8 @@ class Query(object):
self._from_obj = tuple(fa)
- # TODO: only use this adapter for from_self() ? right
- # now its usage is somewhat arbitrary.
- if len(self._from_obj) == 1 and isinstance(self._from_obj[0], expression.Alias):
+ if len(self._from_obj) == 1 and \
+ isinstance(self._from_obj[0], expression.Alias):
equivs = self.__all_equivs()
self._from_obj_alias = sql_util.ColumnAdapter(self._from_obj[0], equivs)
@@ -198,7 +199,13 @@ class Query(object):
@_generative()
def _adapt_all_clauses(self):
self._disable_orm_filtering = True
-
+
+ def _adapt_col_list(self, cols):
+ return [
+ self._adapt_clause(expression._literal_as_text(o), True, True)
+ for o in cols
+ ]
+
def _adapt_clause(self, clause, as_filter, orm_only):
adapters = []
if as_filter and self._filter_aliases:
@@ -375,7 +382,8 @@ class Query(object):
statement._annotate({'_halt_adapt': True})
def subquery(self):
- """return the full SELECT statement represented by this Query, embedded within an Alias.
+ """return the full SELECT statement represented by this Query,
+ embedded within an Alias.
Eager JOIN generation within the query is disabled.
@@ -391,11 +399,14 @@ class Query(object):
@_generative()
def enable_eagerloads(self, value):
- """Control whether or not eager joins are rendered.
+ """Control whether or not eager joins and subqueries are
+ rendered.
When set to False, the returned Query will not render
- eager joins regardless of eagerload() options
- or mapper-level lazy=False configurations.
+ eager joins regardless of :func:`~sqlalchemy.orm.joinedload`,
+ :func:`~sqlalchemy.orm.subqueryload` options
+ or mapper-level ``lazy='joined'``/``lazy='subquery'``
+ configurations.
This is used primarily when nesting the Query's
statement into a subquery or other
@@ -502,13 +513,16 @@ class Query(object):
overwritten.
In particular, it's usually impossible to use this setting with
- eagerly loaded collections (i.e. any lazy=False) since those
- collections will be cleared for a new load when encountered in a
- subsequent result batch.
+ eagerly loaded collections (i.e. any lazy='joined' or 'subquery')
+ since those collections will be cleared for a new load when
+ encountered in a subsequent result batch. In the case of 'subquery'
+ loading, the full result for all rows is fetched which generally
+ defeats the purpose of :meth:`~sqlalchemy.orm.query.Query.yield_per`.
Also note that many DBAPIs do not "stream" results, pre-buffering
all rows before making them available, including mysql-python and
- psycopg2. yield_per() will also set the ``stream_results`` execution
+ psycopg2. :meth:`~sqlalchemy.orm.query.Query.yield_per` will also
+ set the ``stream_results`` execution
option to ``True``, which currently is only understood by psycopg2
and causes server side cursors to be used.
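A minimal sketch of the batching behavior described above (User and
process() are hypothetical; per the caveat, the query should avoid
'joined'/'subquery' collection loading):

    for user in session.query(User).yield_per(100):
        # rows arrive from the cursor in batches of 100; with psycopg2,
        # stream_results=True also selects a server-side cursor so the
        # full result is never pre-buffered.
        process(user)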
@@ -618,17 +632,20 @@ class Query(object):
those being selected.
"""
- fromclause = self.with_labels().enable_eagerloads(False).statement.correlate(None)
+ fromclause = self.with_labels().enable_eagerloads(False).\
+ statement.correlate(None)
q = self._from_selectable(fromclause)
if entities:
q._set_entities(entities)
return q
-
+
@_generative()
def _from_selectable(self, fromclause):
- self._statement = self._criterion = None
- self._order_by = self._group_by = self._distinct = False
- self._limit = self._offset = None
+ for attr in ('_statement', '_criterion', '_order_by', '_group_by',
+ '_limit', '_offset', '_joinpath', '_joinpoint',
+ '_distinct'
+ ):
+ self.__dict__.pop(attr, None)
self._set_select_from(fromclause)
old_entities = self._entities
self._entities = []
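_from_selectable() is the backend of Query.from_self(), which restates the
query so far as a subquery and starts clean against it; roughly (User
hypothetical):

    q = session.query(User).filter(User.name.like('e%')).limit(5)
    # the criterion and LIMIT now live inside the subquery; the outer
    # query can be filtered and ordered independently.
    q = q.from_self().order_by(User.id)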
@@ -659,16 +676,25 @@ class Query(object):
return None
@_generative()
- def add_column(self, column):
- """Add a SQL ColumnElement to the list of result columns to be returned."""
+ def add_columns(self, *column):
+ """Add one or more column expressions to the list
+ of result columns to be returned."""
self._entities = list(self._entities)
l = len(self._entities)
- _ColumnEntity(self, column)
+ for c in column:
+ _ColumnEntity(self, c)
# _ColumnEntity may add many entities if the
# given arg is a FROM clause
self._setup_aliasizers(self._entities[l:])
+    @util.pending_deprecation("add_column() superseded by add_columns()")
+ def add_column(self, column):
+ """Add a column expression to the list of result columns
+ to be returned."""
+
+ return self.add_columns(column)
+
def options(self, *args):
"""Return a new Query object, applying the given list of
MapperOptions.
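A sketch of the add_columns() form introduced above (User/Address
hypothetical):

    q = session.query(User).join(User.addresses).\
        add_columns(Address.email, Address.id)
    for user, email, address_id in q:
        pass  # each result row is a (User, email, id) tuple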
@@ -694,6 +720,21 @@ class Query(object):
opt.process_query(self)
@_generative()
+ def with_hint(self, selectable, text, dialect_name=None):
+ """Add an indexing hint for the given entity or selectable to
+ this :class:`Query`.
+
+ Functionality is passed straight through to
+ :meth:`~sqlalchemy.sql.expression.Select.with_hint`,
+ with the addition that ``selectable`` can be a
+        :class:`Table`, :class:`Alias`, or ORM entity /
+        mapped class, etc.
+ """
+ mapper, selectable, is_aliased_class = _entity_info(selectable)
+
+ self._with_hints += ((selectable, text, dialect_name),)
+
+ @_generative()
def execution_options(self, **kwargs):
""" Set non-SQL options which take effect during execution.
@@ -761,7 +802,6 @@ class Query(object):
return self.filter(sql.and_(*clauses))
-
@_generative(_no_statement_condition, _no_limit_offset)
@util.accepts_a_list_as_starargs(list_deprecation='deprecated')
def order_by(self, *criterion):
@@ -770,7 +810,7 @@ class Query(object):
if len(criterion) == 1 and criterion[0] is None:
self._order_by = None
else:
- criterion = [self._adapt_clause(expression._literal_as_text(o), True, True) for o in criterion]
+ criterion = self._adapt_col_list(criterion)
if self._order_by is False or self._order_by is None:
self._order_by = criterion
@@ -784,7 +824,7 @@ class Query(object):
criterion = list(chain(*[_orm_columns(c) for c in criterion]))
- criterion = [self._adapt_clause(expression._literal_as_text(o), True, True) for o in criterion]
+ criterion = self._adapt_col_list(criterion)
if self._group_by is False:
self._group_by = criterion
@@ -1013,6 +1053,18 @@ class Query(object):
descriptor, prop = _entity_descriptor(left_entity, onclause)
onclause = descriptor
+
+ # check for q.join(Class.propname, from_joinpoint=True)
+ # and Class is that of the current joinpoint
+ elif from_joinpoint and isinstance(onclause, interfaces.PropComparator):
+ left_entity = onclause.parententity
+
+ left_mapper, left_selectable, left_is_aliased = \
+ _entity_info(self._joinpoint_zero())
+ if left_mapper is left_entity:
+ left_entity = self._joinpoint_zero()
+ descriptor, prop = _entity_descriptor(left_entity, onclause.key)
+ onclause = descriptor
if isinstance(onclause, interfaces.PropComparator):
if right_entity is None:
@@ -1022,7 +1074,7 @@ class Query(object):
right_entity = of_type
else:
right_entity = onclause.property.mapper
-
+
left_entity = onclause.parententity
prop = onclause.property
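The new from_joinpoint branch above resolves the pattern where the class of
the onclause is that of the current, possibly aliased, joinpoint, e.g.
(User/Order hypothetical):

    q = session.query(User).\
        join(User.orders, aliased=True).\
        join(Order.items, from_joinpoint=True)
    # Order.items is evaluated against the aliased Order established
    # by the previous join(), not against the plain Order entity.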
@@ -1051,6 +1103,12 @@ class Query(object):
if left is None:
left = self._joinpoint_zero()
+ if left is right and \
+ not create_aliases:
+ raise sa_exc.InvalidRequestError(
+ "Can't construct a join from %s to %s, they are the same entity" %
+ (left, right))
+
left_mapper, left_selectable, left_is_aliased = _entity_info(left)
right_mapper, right_selectable, is_aliased_class = _entity_info(right)
@@ -1312,7 +1370,7 @@ class Query(object):
first() applies a limit of one within the generated SQL, so that
only one primary entity row is generated on the server side
- (note this may consist of multiple result rows if eagerly loaded
+ (note this may consist of multiple result rows if join-loaded
collections are present).
Calling ``first()`` results in an execution of the underlying query.
@@ -2011,7 +2069,10 @@ class Query(object):
order_by=context.order_by,
**self._select_args
)
-
+
+ for hint in self._with_hints:
+ inner = inner.with_hint(*hint)
+
if self._correlate:
inner = inner.correlate(*self._correlate)
@@ -2066,6 +2127,10 @@ class Query(object):
order_by=context.order_by,
**self._select_args
)
+
+ for hint in self._with_hints:
+ statement = statement.with_hint(*hint)
+
if self._execution_options:
statement = statement.execution_options(**self._execution_options)
@@ -2166,14 +2231,14 @@ class _MapperEntity(_QueryEntity):
query._entities.append(self)
def _get_entity_clauses(self, query, context):
-
+
adapter = None
if not self.is_aliased_class and query._polymorphic_adapters:
adapter = query._polymorphic_adapters.get(self.mapper, None)
if not adapter and self.adapter:
adapter = self.adapter
-
+
if adapter:
if query._from_obj_alias:
ret = adapter.wrap(query._from_obj_alias)
@@ -2247,7 +2312,6 @@ class _MapperEntity(_QueryEntity):
def __str__(self):
return str(self.mapper)
-
class _ColumnEntity(_QueryEntity):
"""Column/expression based entity."""
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 0a3fbe79e..0810175bf 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -883,7 +883,7 @@ class Session(object):
state.commit_all(dict_, self.identity_map)
def refresh(self, instance, attribute_names=None, lockmode=None):
- """Refresh the attributes on the given instance.
+ """Expire and refresh the attributes on the given instance.
A query will be issued to the database and all attributes will be
refreshed with their current database value.
@@ -907,7 +907,9 @@ class Session(object):
state = attributes.instance_state(instance)
except exc.NO_STATE:
raise exc.UnmappedInstanceError(instance)
- self._validate_persistent(state)
+
+ self._expire_state(state, attribute_names)
+
if self.query(_object_mapper(instance))._get(
state.key, refresh_state=state,
lockmode=lockmode,
@@ -939,18 +941,31 @@ class Session(object):
state = attributes.instance_state(instance)
except exc.NO_STATE:
raise exc.UnmappedInstanceError(instance)
+ self._expire_state(state, attribute_names)
+
+ def _expire_state(self, state, attribute_names):
self._validate_persistent(state)
if attribute_names:
_expire_state(state, state.dict,
- attribute_names=attribute_names, instance_dict=self.identity_map)
+ attribute_names=attribute_names,
+ instance_dict=self.identity_map)
else:
# pre-fetch the full cascade since the expire is going to
# remove associations
cascaded = list(_cascade_state_iterator('refresh-expire', state))
- _expire_state(state, state.dict, None, instance_dict=self.identity_map)
+ self._conditional_expire(state)
for (state, m, o) in cascaded:
- _expire_state(state, state.dict, None, instance_dict=self.identity_map)
-
+ self._conditional_expire(state)
+
+ def _conditional_expire(self, state):
+ """Expire a state if persistent, else expunge if pending"""
+
+ if state.key:
+ _expire_state(state, state.dict, None, instance_dict=self.identity_map)
+ elif state in self._new:
+ self._new.pop(state)
+ state.detach()
+
def prune(self):
"""Remove unreferenced instances cached in the identity map.
diff --git a/lib/sqlalchemy/orm/shard.py b/lib/sqlalchemy/orm/shard.py
index b6026bbc3..9cb26db79 100644
--- a/lib/sqlalchemy/orm/shard.py
+++ b/lib/sqlalchemy/orm/shard.py
@@ -4,114 +4,12 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-"""Horizontal sharding support.
-
-Defines a rudimental 'horizontal sharding' system which allows a Session to
-distribute queries and persistence operations across multiple databases.
-
-For a usage example, see the file ``examples/sharding/attribute_shard.py``
-included in the source distrbution.
-
-"""
-
-import sqlalchemy.exceptions as sa_exc
from sqlalchemy import util
-from sqlalchemy.orm.session import Session
-from sqlalchemy.orm.query import Query
-
-__all__ = ['ShardedSession', 'ShardedQuery']
-
-
-class ShardedSession(Session):
- def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, **kwargs):
- """Construct a ShardedSession.
-
- shard_chooser
- A callable which, passed a Mapper, a mapped instance, and possibly a
- SQL clause, returns a shard ID. This id may be based off of the
- attributes present within the object, or on some round-robin
- scheme. If the scheme is based on a selection, it should set
- whatever state on the instance to mark it in the future as
- participating in that shard.
-
- id_chooser
- A callable, passed a query and a tuple of identity values, which
- should return a list of shard ids where the ID might reside. The
- databases will be queried in the order of this listing.
-
- query_chooser
- For a given Query, returns the list of shard_ids where the query
- should be issued. Results from all shards returned will be combined
- together into a single listing.
-
- """
- super(ShardedSession, self).__init__(**kwargs)
- self.shard_chooser = shard_chooser
- self.id_chooser = id_chooser
- self.query_chooser = query_chooser
- self.__binds = {}
- self._mapper_flush_opts = {'connection_callable':self.connection}
- self._query_cls = ShardedQuery
- if shards is not None:
- for k in shards:
- self.bind_shard(k, shards[k])
-
- def connection(self, mapper=None, instance=None, shard_id=None, **kwargs):
- if shard_id is None:
- shard_id = self.shard_chooser(mapper, instance)
-
- if self.transaction is not None:
- return self.transaction.connection(mapper, shard_id=shard_id)
- else:
- return self.get_bind(mapper, shard_id=shard_id, instance=instance).contextual_connect(**kwargs)
-
- def get_bind(self, mapper, shard_id=None, instance=None, clause=None, **kw):
- if shard_id is None:
- shard_id = self.shard_chooser(mapper, instance, clause=clause)
- return self.__binds[shard_id]
- def bind_shard(self, shard_id, bind):
- self.__binds[shard_id] = bind
+util.warn_deprecated(
+ "Horizontal sharding is now importable via "
+ "'import sqlalchemy.ext.horizontal_shard"
+)
-class ShardedQuery(Query):
- def __init__(self, *args, **kwargs):
- super(ShardedQuery, self).__init__(*args, **kwargs)
- self.id_chooser = self.session.id_chooser
- self.query_chooser = self.session.query_chooser
- self._shard_id = None
-
- def set_shard(self, shard_id):
- """return a new query, limited to a single shard ID.
-
- all subsequent operations with the returned query will
- be against the single shard regardless of other state.
- """
-
- q = self._clone()
- q._shard_id = shard_id
- return q
-
- def _execute_and_instances(self, context):
- if self._shard_id is not None:
- result = self.session.connection(mapper=self._mapper_zero(), shard_id=self._shard_id).execute(context.statement, self._params)
- return self.instances(result, context)
- else:
- partial = []
- for shard_id in self.query_chooser(self):
- result = self.session.connection(mapper=self._mapper_zero(), shard_id=shard_id).execute(context.statement, self._params)
- partial = partial + list(self.instances(result, context))
- # if some kind of in memory 'sorting' were done, this is where it would happen
- return iter(partial)
+from sqlalchemy.ext.horizontal_shard import *
- def get(self, ident, **kwargs):
- if self._shard_id is not None:
- return super(ShardedQuery, self).get(ident)
- else:
- ident = util.to_list(ident)
- for shard_id in self.id_chooser(self, ident):
- o = self.set_shard(shard_id).get(ident, **kwargs)
- if o is not None:
- return o
- else:
- return None
-
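New code should import from the extension location directly; a sketch,
assuming the three chooser callables and engines are defined as in
examples/sharding:

    from sqlalchemy.ext.horizontal_shard import ShardedSession

    session = ShardedSession(
        shard_chooser=shard_chooser,
        id_chooser=id_chooser,
        query_chooser=query_chooser,
        shards={'na': na_engine, 'eu': eu_engine})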
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index ce19667c6..93b1170f4 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -4,7 +4,8 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-"""sqlalchemy.orm.interfaces.LoaderStrategy implementations, and related MapperOptions."""
+"""sqlalchemy.orm.interfaces.LoaderStrategy
+ implementations, and related MapperOptions."""
from sqlalchemy import exc as sa_exc
from sqlalchemy import sql, util, log
@@ -17,6 +18,7 @@ from sqlalchemy.orm.interfaces import (
)
from sqlalchemy.orm import session as sessionlib
from sqlalchemy.orm import util as mapperutil
+import itertools
def _register_attribute(strategy, mapper, useobject,
compare_function=None,
@@ -38,7 +40,9 @@ def _register_attribute(strategy, mapper, useobject,
attribute_ext.insert(0, _SingleParentValidator(prop))
if prop.key in prop.parent._validators:
- attribute_ext.insert(0, mapperutil.Validator(prop.key, prop.parent._validators[prop.key]))
+ attribute_ext.insert(0,
+ mapperutil.Validator(prop.key, prop.parent._validators[prop.key])
+ )
if useobject:
attribute_ext.append(sessionlib.UOWEventHandler(prop.key))
@@ -66,7 +70,7 @@ def _register_attribute(strategy, mapper, useobject,
)
class UninstrumentedColumnLoader(LoaderStrategy):
- """Represent the strategy for a MapperProperty that doesn't instrument the class.
+ """Represent the a non-instrumented MapperProperty.
The polymorphic_on argument of mapper() often results in this,
if the argument is against the with_polymorphic selectable.
@@ -75,14 +79,15 @@ class UninstrumentedColumnLoader(LoaderStrategy):
def init(self):
self.columns = self.parent_property.columns
- def setup_query(self, context, entity, path, adapter, column_collection=None, **kwargs):
+ def setup_query(self, context, entity, path, adapter,
+ column_collection=None, **kwargs):
for c in self.columns:
if adapter:
c = adapter.columns[c]
column_collection.append(c)
def create_row_processor(self, selectcontext, path, mapper, row, adapter):
- return (None, None)
+ return None, None
class ColumnLoader(LoaderStrategy):
"""Strategize the loading of a plain column-based MapperProperty."""
@@ -91,7 +96,8 @@ class ColumnLoader(LoaderStrategy):
self.columns = self.parent_property.columns
self.is_composite = hasattr(self.parent_property, 'composite_class')
- def setup_query(self, context, entity, path, adapter, column_collection=None, **kwargs):
+ def setup_query(self, context, entity, path, adapter,
+ column_collection=None, **kwargs):
for c in self.columns:
if adapter:
c = adapter.columns[c]
@@ -135,7 +141,8 @@ class CompositeColumnLoader(ColumnLoader):
def copy(obj):
if obj is None:
return None
- return self.parent_property.composite_class(*obj.__composite_values__())
+ return self.parent_property.\
+ composite_class(*obj.__composite_values__())
def compare(a, b):
if a is None or b is None:
@@ -156,7 +163,8 @@ class CompositeColumnLoader(ColumnLoader):
#active_history ?
)
- def create_row_processor(self, selectcontext, path, mapper, row, adapter):
+ def create_row_processor(self, selectcontext, path, mapper,
+ row, adapter):
key = self.key
columns = self.columns
composite_class = self.parent_property.composite_class
@@ -203,7 +211,8 @@ class DeferredColumnLoader(LoaderStrategy):
def init(self):
if hasattr(self.parent_property, 'composite_class'):
- raise NotImplementedError("Deferred loading for composite types not implemented yet")
+ raise NotImplementedError("Deferred loading for composite "
+ "types not implemented yet")
self.columns = self.parent_property.columns
self.group = self.parent_property.group
@@ -218,13 +227,15 @@ class DeferredColumnLoader(LoaderStrategy):
expire_missing=False
)
- def setup_query(self, context, entity, path, adapter, only_load_props=None, **kwargs):
- if \
- (self.group is not None and context.attributes.get(('undefer', self.group), False)) or \
- (only_load_props and self.key in only_load_props):
-
+ def setup_query(self, context, entity, path, adapter,
+ only_load_props=None, **kwargs):
+ if (
+ self.group is not None and
+ context.attributes.get(('undefer', self.group), False)
+ ) or (only_load_props and self.key in only_load_props):
self.parent_property._get_strategy(ColumnLoader).\
- setup_query(context, entity, path, adapter, **kwargs)
+ setup_query(context, entity,
+ path, adapter, **kwargs)
def _class_level_loader(self, state):
if not mapperutil._state_has_identity(state):
@@ -276,14 +287,15 @@ class LoadDeferredColumns(object):
session = sessionlib._state_session(state)
if session is None:
raise orm_exc.DetachedInstanceError(
- "Parent instance %s is not bound to a Session; "
- "deferred load operation of attribute '%s' cannot proceed" %
- (mapperutil.state_str(state), self.key)
- )
+ "Parent instance %s is not bound to a Session; "
+ "deferred load operation of attribute '%s' cannot proceed" %
+ (mapperutil.state_str(state), self.key)
+ )
query = session.query(localparent)
ident = state.key[1]
- query._get(None, ident=ident, only_load_props=group, refresh_state=state)
+ query._get(None, ident=ident,
+ only_load_props=group, refresh_state=state)
return attributes.ATTR_WAS_SET
class DeferredOption(StrategizedOption):
@@ -309,7 +321,7 @@ class UndeferGroupOption(MapperOption):
query._attributes[('undefer', self.group)] = True
class AbstractRelationshipLoader(LoaderStrategy):
- """LoaderStratgies which deal with related objects as opposed to scalars."""
+ """LoaderStratgies which deal with related objects."""
def init(self):
self.mapper = self.parent_property.mapper
@@ -363,31 +375,47 @@ class LazyLoader(AbstractRelationshipLoader):
for c in self.mapper._equivalent_columns[col]:
self._equated_columns[c] = self._equated_columns[col]
- self.logger.info("%s will use query.get() to optimize instance loads" % self)
+ self.logger.info("%s will use query.get() to "
+ "optimize instance loads" % self)
def init_class_attribute(self, mapper):
self.is_class_level = True
- # MANYTOONE currently only needs the "old" value for delete-orphan
- # cascades. the required _SingleParentValidator will enable active_history
- # in that case. otherwise we don't need the "old" value during backref operations.
+ # MANYTOONE currently only needs the
+ # "old" value for delete-orphan
+ # cascades. the required _SingleParentValidator
+ # will enable active_history
+ # in that case. otherwise we don't need the
+ # "old" value during backref operations.
_register_attribute(self,
mapper,
useobject=True,
callable_=self._class_level_loader,
uselist = self.parent_property.uselist,
typecallable = self.parent_property.collection_class,
- active_history = self.parent_property.direction is not interfaces.MANYTOONE or not self.use_get,
+ active_history = \
+ self.parent_property.direction is not \
+ interfaces.MANYTOONE or \
+ not self.use_get,
)
- def lazy_clause(self, state, reverse_direction=False, alias_secondary=False, adapt_source=None):
+ def lazy_clause(self, state, reverse_direction=False,
+ alias_secondary=False, adapt_source=None):
if state is None:
- return self._lazy_none_clause(reverse_direction, adapt_source=adapt_source)
+ return self._lazy_none_clause(
+ reverse_direction,
+ adapt_source=adapt_source)
if not reverse_direction:
- (criterion, bind_to_col, rev) = (self.__lazywhere, self.__bind_to_col, self._equated_columns)
+ criterion, bind_to_col, rev = \
+ self.__lazywhere, \
+ self.__bind_to_col, \
+ self._equated_columns
else:
- (criterion, bind_to_col, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
+ criterion, bind_to_col, rev = \
+ LazyLoader._create_lazy_clause(
+ self.parent_property,
+ reverse_direction=reverse_direction)
if reverse_direction:
mapper = self.parent_property.mapper
@@ -396,25 +424,38 @@ class LazyLoader(AbstractRelationshipLoader):
def visit_bindparam(bindparam):
if bindparam.key in bind_to_col:
- # use the "committed" (database) version to get query column values
- # also its a deferred value; so that when used by Query, the committed value is used
+ # use the "committed" (database) version to get
+ # query column values
+ # also its a deferred value; so that when used
+ # by Query, the committed value is used
# after an autoflush occurs
o = state.obj() # strong ref
- bindparam.value = lambda: mapper._get_committed_attr_by_column(o, bind_to_col[bindparam.key])
+ bindparam.value = \
+ lambda: mapper._get_committed_attr_by_column(
+ o, bind_to_col[bindparam.key])
if self.parent_property.secondary is not None and alias_secondary:
- criterion = sql_util.ClauseAdapter(self.parent_property.secondary.alias()).traverse(criterion)
+ criterion = sql_util.ClauseAdapter(
+ self.parent_property.secondary.alias()).\
+ traverse(criterion)
- criterion = visitors.cloned_traverse(criterion, {}, {'bindparam':visit_bindparam})
+ criterion = visitors.cloned_traverse(
+ criterion, {}, {'bindparam':visit_bindparam})
if adapt_source:
criterion = adapt_source(criterion)
return criterion
def _lazy_none_clause(self, reverse_direction=False, adapt_source=None):
if not reverse_direction:
- (criterion, bind_to_col, rev) = (self.__lazywhere, self.__bind_to_col, self._equated_columns)
+ criterion, bind_to_col, rev = \
+ self.__lazywhere, \
+                self.__bind_to_col, \
+ self._equated_columns
else:
- (criterion, bind_to_col, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
+ criterion, bind_to_col, rev = \
+ LazyLoader._create_lazy_clause(
+ self.parent_property,
+ reverse_direction=reverse_direction)
criterion = sql_util.adapt_criterion_to_null(criterion, bind_to_col)
@@ -432,22 +473,30 @@ class LazyLoader(AbstractRelationshipLoader):
key = self.key
if not self.is_class_level:
def new_execute(state, dict_, row):
- # we are not the primary manager for this attribute on this class - set up a
- # per-instance lazyloader, which will override the class-level behavior.
- # this currently only happens when using a "lazyload" option on a "no load"
- # attribute - "eager" attributes always have a class-level lazyloader
- # installed.
+ # we are not the primary manager for this attribute
+ # on this class - set up a
+ # per-instance lazyloader, which will override the
+ # class-level behavior.
+ # this currently only happens when using a
+ # "lazyload" option on a "no load"
+ # attribute - "eager" attributes always have a
+ # class-level lazyloader installed.
state.set_callable(dict_, key, LoadLazyAttribute(state, key))
else:
def new_execute(state, dict_, row):
- # we are the primary manager for this attribute on this class - reset its
- # per-instance attribute state, so that the class-level lazy loader is
- # executed when next referenced on this instance. this is needed in
- # populate_existing() types of scenarios to reset any existing state.
+ # we are the primary manager for this attribute on
+ # this class - reset its
+ # per-instance attribute state, so that the class-level
+ # lazy loader is
+ # executed when next referenced on this instance.
+ # this is needed in
+ # populate_existing() types of scenarios to reset
+ # any existing state.
state.reset(dict_, key)
return new_execute, None
-
+
+ @classmethod
def _create_lazy_clause(cls, prop, reverse_direction=False):
binds = util.column_dict()
lookup = util.column_dict()
@@ -477,18 +526,19 @@ class LazyLoader(AbstractRelationshipLoader):
lazywhere = prop.primaryjoin
if prop.secondaryjoin is None or not reverse_direction:
- lazywhere = visitors.replacement_traverse(lazywhere, {}, col_to_bind)
+ lazywhere = visitors.replacement_traverse(
+ lazywhere, {}, col_to_bind)
if prop.secondaryjoin is not None:
secondaryjoin = prop.secondaryjoin
if reverse_direction:
- secondaryjoin = visitors.replacement_traverse(secondaryjoin, {}, col_to_bind)
+ secondaryjoin = visitors.replacement_traverse(
+ secondaryjoin, {}, col_to_bind)
lazywhere = sql.and_(lazywhere, secondaryjoin)
bind_to_col = dict((binds[col].key, col) for col in binds)
- return (lazywhere, bind_to_col, equated_columns)
- _create_lazy_clause = classmethod(_create_lazy_clause)
+ return lazywhere, bind_to_col, equated_columns
log.class_logger(LazyLoader)
@@ -510,12 +560,14 @@ class LoadLazyAttribute(object):
prop = instance_mapper.get_property(self.key)
strategy = prop._get_strategy(LazyLoader)
- if kw.get('passive') is attributes.PASSIVE_NO_FETCH and not strategy.use_get:
+ if kw.get('passive') is attributes.PASSIVE_NO_FETCH and \
+ not strategy.use_get:
return attributes.PASSIVE_NO_RESULT
if strategy._should_log_debug():
strategy.logger.debug("loading %s",
- mapperutil.state_attribute_str(state, self.key))
+ mapperutil.state_attribute_str(
+ state, self.key))
session = sessionlib._state_session(state)
if session is None:
@@ -536,8 +588,11 @@ class LoadLazyAttribute(object):
ident = []
allnulls = True
for primary_key in prop.mapper.primary_key:
- val = instance_mapper._get_committed_state_attr_by_column(
- state, strategy._equated_columns[primary_key], **kw)
+ val = instance_mapper.\
+ _get_committed_state_attr_by_column(
+ state,
+ strategy._equated_columns[primary_key],
+ **kw)
if val is attributes.PASSIVE_NO_RESULT:
return val
allnulls = allnulls and val is None
@@ -556,8 +611,17 @@ class LoadLazyAttribute(object):
if prop.order_by:
q = q.order_by(*util.to_list(prop.order_by))
+ for rev in prop._reverse_property:
+ # reverse props that are MANYTOONE are loading *this*
+ # object from get(), so don't need to eager out to those.
+ if rev.direction is interfaces.MANYTOONE and \
+ rev._use_get and \
+ not isinstance(rev.strategy, LazyLoader):
+ q = q.options(EagerLazyOption(rev.key, lazy='select'))
+
if state.load_options:
q = q._conditional_options(*state.load_options)
+
q = q.filter(strategy.lazy_clause(state))
result = q.all()
@@ -569,24 +633,244 @@ class LoadLazyAttribute(object):
if l > 1:
util.warn(
"Multiple rows returned with "
- "uselist=False for lazily-loaded attribute '%s' " % prop)
+ "uselist=False for lazily-loaded attribute '%s' "
+ % prop)
return result[0]
else:
return None
+class SubqueryLoader(AbstractRelationshipLoader):
+ def init(self):
+ super(SubqueryLoader, self).init()
+ self.join_depth = self.parent_property.join_depth
+
+ def init_class_attribute(self, mapper):
+ self.parent_property.\
+ _get_strategy(LazyLoader).\
+ init_class_attribute(mapper)
+
+ def setup_query(self, context, entity,
+ path, adapter, column_collection=None,
+ parentmapper=None, **kwargs):
+
+ if not context.query._enable_eagerloads:
+ return
+
+ path = path + (self.key, )
+
+ # build up a path indicating the path from the leftmost
+ # entity to the thing we're subquery loading.
+ subq_path = context.attributes.get(('subquery_path', None), ())
+
+ subq_path = subq_path + path
+
+ reduced_path = interfaces._reduce_path(path)
+
+ # join-depth / recursion check
+ if ("loaderstrategy", reduced_path) not in context.attributes:
+ if self.join_depth:
+ if len(path) / 2 > self.join_depth:
+ return
+ else:
+ if self.mapper.base_mapper in interfaces._reduce_path(subq_path):
+ return
+
+ orig_query = context.attributes.get(
+ ("orig_query", SubqueryLoader),
+ context.query)
+
+ # determine attributes of the leftmost mapper
+ if self.parent.isa(subq_path[0]) and self.key==subq_path[1]:
+ leftmost_mapper, leftmost_prop = \
+ self.parent, self.parent_property
+ else:
+ leftmost_mapper, leftmost_prop = \
+ subq_path[0], \
+ subq_path[0].get_property(subq_path[1])
+ leftmost_cols, remote_cols = self._local_remote_columns(leftmost_prop)
+
+ leftmost_attr = [
+ leftmost_mapper._get_col_to_prop(c).class_attribute
+ for c in leftmost_cols
+ ]
+
+ # reformat the original query
+ # to look only for significant columns
+ q = orig_query._clone()
+        # TODO: why does polymorphic etc. require hardcoding
+        # into _adapt_col_list?  Does query.add_columns(...)
+        # work with polymorphic loading?
+ q._set_entities(q._adapt_col_list(leftmost_attr))
+
+ # don't need ORDER BY if no limit/offset
+ if q._limit is None and q._offset is None:
+ q._order_by = None
+
+ # the original query now becomes a subquery
+ # which we'll join onto.
+ embed_q = q.with_labels().subquery()
+ left_alias = mapperutil.AliasedClass(leftmost_mapper, embed_q)
+
+ # q becomes a new query. basically doing a longhand
+ # "from_self()". (from_self() itself not quite industrial
+ # strength enough for all contingencies...but very close)
+
+ q = q.session.query(self.mapper)
+ q._attributes = {
+ ("orig_query", SubqueryLoader): orig_query,
+ ('subquery_path', None) : subq_path
+ }
+
+ # figure out what's being joined. a.k.a. the fun part
+ to_join = [
+ (subq_path[i], subq_path[i+1])
+ for i in xrange(0, len(subq_path), 2)
+ ]
+
+ if len(to_join) < 2:
+ parent_alias = left_alias
+ else:
+ parent_alias = mapperutil.AliasedClass(self.parent)
+
+ local_cols, remote_cols = \
+ self._local_remote_columns(self.parent_property)
+
+ local_attr = [
+ getattr(parent_alias, self.parent._get_col_to_prop(c).key)
+ for c in local_cols
+ ]
+ q = q.order_by(*local_attr)
+ q = q.add_columns(*local_attr)
+
+ for i, (mapper, key) in enumerate(to_join):
+
+ # we need to use query.join() as opposed to
+ # orm.join() here because of the
+ # rich behavior it brings when dealing with
+ # "with_polymorphic" mappers. "aliased"
+ # and "from_joinpoint" take care of most of
+ # the chaining and aliasing for us.
+
+ first = i == 0
+ middle = i < len(to_join) - 1
+ second_to_last = i == len(to_join) - 2
+
+ if first:
+ attr = getattr(left_alias, key)
+ else:
+ attr = key
+
+ if second_to_last:
+ q = q.join((parent_alias, attr), from_joinpoint=True)
+ else:
+ q = q.join(attr, aliased=middle, from_joinpoint=True)
+
+ # propagate loader options etc. to the new query.
+ # these will fire relative to subq_path.
+ q = q._with_current_path(subq_path)
+ q = q._conditional_options(*orig_query._with_options)
+
+ if self.parent_property.order_by:
+ # if there's an ORDER BY, alias it the same
+ # way joinedloader does, but we have to pull out
+ # the "eagerjoin" from the query.
+ # this really only picks up the "secondary" table
+ # right now.
+ eagerjoin = q._from_obj[0]
+ eager_order_by = \
+ eagerjoin._target_adapter.\
+ copy_and_process(
+ util.to_list(
+ self.parent_property.order_by
+ )
+ )
+ q = q.order_by(*eager_order_by)
+
+ # add new query to attributes to be picked up
+ # by create_row_processor
+ context.attributes[('subquery', reduced_path)] = q
+
+ def _local_remote_columns(self, prop):
+ if prop.secondary is None:
+ return zip(*prop.local_remote_pairs)
+ else:
+            return \
+                [p[0] for p in prop.synchronize_pairs], \
+                [p[0] for p in prop.secondary_synchronize_pairs]
+
+ def create_row_processor(self, context, path, mapper, row, adapter):
+ path = path + (self.key,)
+
+ path = interfaces._reduce_path(path)
+
+ if ('subquery', path) not in context.attributes:
+ return None, None
+
+ local_cols, remote_cols = self._local_remote_columns(self.parent_property)
+
+ remote_attr = [
+ self.mapper._get_col_to_prop(c).key
+ for c in remote_cols]
+
+ q = context.attributes[('subquery', path)]
+
+ collections = dict(
+            (k, [vv[0] for vv in v])
+ for k, v in itertools.groupby(
+ q,
+ lambda x:x[1:]
+ ))
+
+ if adapter:
+ local_cols = [adapter.columns[c] for c in local_cols]
+
+ if self.uselist:
+ def execute(state, dict_, row):
+ collection = collections.get(
+ tuple([row[col] for col in local_cols]),
+ ()
+ )
+ state.get_impl(self.key).\
+ set_committed_value(state, dict_, collection)
+ else:
+ def execute(state, dict_, row):
+ collection = collections.get(
+ tuple([row[col] for col in local_cols]),
+ (None,)
+ )
+ if len(collection) > 1:
+ util.warn(
+ "Multiple rows returned with "
+ "uselist=False for eagerly-loaded attribute '%s' "
+ % self)
+
+ scalar = collection[0]
+ state.get_impl(self.key).\
+ set_committed_value(state, dict_, scalar)
+
+ return execute, None
+
+log.class_logger(SubqueryLoader)
+
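A sketch of what SubqueryLoader does at query time (User/Address
hypothetical; subqueryload is the corresponding new orm option):

    from sqlalchemy.orm import subqueryload

    users = session.query(User).\
        options(subqueryload(User.addresses)).\
        filter(User.name.like('e%')).\
        all()
    # two statements are emitted: the user SELECT itself, then one
    # SELECT for all related Address rows, joined to the original
    # user query restated as a subquery; create_row_processor() then
    # groups the second result per parent via itertools.groupby.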
class EagerLoader(AbstractRelationshipLoader):
- """Strategize a relationship() that loads within the process of the parent object being selected."""
+ """Strategize a relationship() that loads within the process
+ of the parent object being selected."""
def init(self):
super(EagerLoader, self).init()
self.join_depth = self.parent_property.join_depth
def init_class_attribute(self, mapper):
- self.parent_property._get_strategy(LazyLoader).init_class_attribute(mapper)
+ self.parent_property.\
+ _get_strategy(LazyLoader).init_class_attribute(mapper)
def setup_query(self, context, entity, path, adapter, \
- column_collection=None, parentmapper=None, **kwargs):
+ column_collection=None, parentmapper=None,
+ **kwargs):
"""Add a left outer join to the statement thats being constructed."""
if not context.query._enable_eagerloads:
@@ -597,16 +881,21 @@ class EagerLoader(AbstractRelationshipLoader):
reduced_path = interfaces._reduce_path(path)
# check for user-defined eager alias
- if ("user_defined_eager_row_processor", reduced_path) in context.attributes:
- clauses = context.attributes[("user_defined_eager_row_processor", reduced_path)]
+ if ("user_defined_eager_row_processor", reduced_path) in\
+ context.attributes:
+ clauses = context.attributes[
+ ("user_defined_eager_row_processor",
+ reduced_path)]
adapter = entity._get_entity_clauses(context.query, context)
if adapter and clauses:
- context.attributes[("user_defined_eager_row_processor", reduced_path)] = \
- clauses = clauses.wrap(adapter)
+ context.attributes[
+ ("user_defined_eager_row_processor",
+ reduced_path)] = clauses = clauses.wrap(adapter)
elif adapter:
- context.attributes[("user_defined_eager_row_processor", reduced_path)] = \
- clauses = adapter
+ context.attributes[
+ ("user_defined_eager_row_processor",
+ reduced_path)] = clauses = adapter
add_to_collection = context.primary_columns
@@ -622,18 +911,24 @@ class EagerLoader(AbstractRelationshipLoader):
if self.mapper.base_mapper in reduced_path:
return
- clauses = mapperutil.ORMAdapter(mapperutil.AliasedClass(self.mapper),
- equivalents=self.mapper._equivalent_columns, adapt_required=True)
+ clauses = mapperutil.ORMAdapter(
+ mapperutil.AliasedClass(self.mapper),
+ equivalents=self.mapper._equivalent_columns,
+ adapt_required=True)
if self.parent_property.direction != interfaces.MANYTOONE:
context.multi_row_eager_loaders = True
context.create_eager_joins.append(
- (self._create_eager_join, context, entity, path, adapter, parentmapper, clauses)
+ (self._create_eager_join, context,
+ entity, path, adapter,
+ parentmapper, clauses)
)
add_to_collection = context.secondary_columns
- context.attributes[("eager_row_processor", reduced_path)] = clauses
+ context.attributes[
+ ("eager_row_processor", reduced_path)
+ ] = clauses
for value in self.mapper._iterate_polymorphic_properties():
value.setup(
@@ -644,7 +939,8 @@ class EagerLoader(AbstractRelationshipLoader):
parentmapper=self.mapper,
column_collection=add_to_collection)
- def _create_eager_join(self, context, entity, path, adapter, parentmapper, clauses):
+ def _create_eager_join(self, context, entity,
+ path, adapter, parentmapper, clauses):
if parentmapper is None:
localparent = entity.mapper
@@ -662,12 +958,13 @@ class EagerLoader(AbstractRelationshipLoader):
not should_nest_selectable and \
context.from_clause:
index, clause = \
- sql_util.find_join_source(context.from_clause, entity.selectable)
+ sql_util.find_join_source(
+ context.from_clause, entity.selectable)
if clause is not None:
# join to an existing FROM clause on the query.
# key it to its list index in the eager_joins dict.
- # Query._compile_context will adapt as needed and append to the
- # FROM clause of the select().
+ # Query._compile_context will adapt as needed and
+ # append to the FROM clause of the select().
entity_key, default_towrap = index, clause
if entity_key is None:
@@ -678,28 +975,38 @@ class EagerLoader(AbstractRelationshipLoader):
join_to_left = False
if adapter:
if getattr(adapter, 'aliased_class', None):
- onclause = getattr(adapter.aliased_class, self.key, self.parent_property)
+ onclause = getattr(
+ adapter.aliased_class, self.key,
+ self.parent_property)
else:
- onclause = getattr(mapperutil.AliasedClass(self.parent, adapter.selectable),
- self.key, self.parent_property)
+ onclause = getattr(
+ mapperutil.AliasedClass(
+ self.parent,
+ adapter.selectable
+ ),
+ self.key, self.parent_property
+ )
if onclause is self.parent_property:
- # TODO: this is a temporary hack to account for polymorphic eager loads where
+ # TODO: this is a temporary hack to
+ # account for polymorphic eager loads where
# the eagerload is referencing via of_type().
join_to_left = True
else:
onclause = self.parent_property
- innerjoin = context.attributes.get(("eager_join_type", path),
- self.parent_property.innerjoin)
+ innerjoin = context.attributes.get(
+ ("eager_join_type", path),
+ self.parent_property.innerjoin)
- context.eager_joins[entity_key] = eagerjoin = mapperutil.join(
- towrap,
- clauses.aliased_class,
- onclause,
- join_to_left=join_to_left,
- isouter=not innerjoin
- )
+ context.eager_joins[entity_key] = eagerjoin = \
+ mapperutil.join(
+ towrap,
+ clauses.aliased_class,
+ onclause,
+ join_to_left=join_to_left,
+ isouter=not innerjoin
+ )
# send a hint to the Query as to where it may "splice" this join
eagerjoin.stop_on = entity.selectable
@@ -707,11 +1014,14 @@ class EagerLoader(AbstractRelationshipLoader):
if self.parent_property.secondary is None and \
not parentmapper:
# for parentclause that is the non-eager end of the join,
- # ensure all the parent cols in the primaryjoin are actually in the
+ # ensure all the parent cols in the primaryjoin are actually
+ # in the
# columns clause (i.e. are not deferred), so that aliasing applied
- # by the Query propagates those columns outward. This has the effect
+ # by the Query propagates those columns outward.
+ # This has the effect
# of "undefering" those columns.
- for col in sql_util.find_columns(self.parent_property.primaryjoin):
+ for col in sql_util.find_columns(
+ self.parent_property.primaryjoin):
if localparent.mapped_table.c.contains_column(col):
if adapter:
col = adapter.columns[col]
@@ -721,22 +1031,29 @@ class EagerLoader(AbstractRelationshipLoader):
context.eager_order_by += \
eagerjoin._target_adapter.\
copy_and_process(
- util.to_list(self.parent_property.order_by)
+ util.to_list(
+ self.parent_property.order_by
+ )
)
def _create_eager_adapter(self, context, row, adapter, path):
reduced_path = interfaces._reduce_path(path)
- if ("user_defined_eager_row_processor", reduced_path) in context.attributes:
- decorator = context.attributes[("user_defined_eager_row_processor", reduced_path)]
- # user defined eagerloads are part of the "primary" portion of the load.
+ if ("user_defined_eager_row_processor", reduced_path) in \
+ context.attributes:
+ decorator = context.attributes[
+ ("user_defined_eager_row_processor",
+ reduced_path)]
+ # user defined eagerloads are part of the "primary"
+ # portion of the load.
# the adapters applied to the Query should be honored.
if context.adapter and decorator:
decorator = decorator.wrap(context.adapter)
elif context.adapter:
decorator = context.adapter
elif ("eager_row_processor", reduced_path) in context.attributes:
- decorator = context.attributes[("eager_row_processor", reduced_path)]
+ decorator = context.attributes[
+ ("eager_row_processor", reduced_path)]
else:
return False
@@ -751,7 +1068,10 @@ class EagerLoader(AbstractRelationshipLoader):
def create_row_processor(self, context, path, mapper, row, adapter):
path = path + (self.key,)
- eager_adapter = self._create_eager_adapter(context, row, adapter, path)
+ eager_adapter = self._create_eager_adapter(
+ context,
+ row,
+ adapter, path)
if eager_adapter is not False:
key = self.key
@@ -780,8 +1100,8 @@ class EagerLoader(AbstractRelationshipLoader):
return new_execute, existing_execute
else:
def new_execute(state, dict_, row):
- collection = attributes.init_state_collection(state, dict_,
- key)
+ collection = attributes.init_state_collection(
+ state, dict_, key)
result_list = util.UniqueAppender(collection,
'append_without_event')
context.attributes[(state, key)] = result_list
@@ -797,36 +1117,56 @@ class EagerLoader(AbstractRelationshipLoader):
# distinct sets of result columns
collection = attributes.init_state_collection(state,
dict_, key)
- result_list = util.UniqueAppender(collection,
- 'append_without_event')
+ result_list = util.UniqueAppender(
+ collection,
+ 'append_without_event')
context.attributes[(state, key)] = result_list
_instance(row, result_list)
return new_execute, existing_execute
else:
- return self.parent_property._get_strategy(LazyLoader).\
- create_row_processor(context, path, mapper, row, adapter)
+ return self.parent_property.\
+ _get_strategy(LazyLoader).\
+ create_row_processor(
+ context, path,
+ mapper, row, adapter)
log.class_logger(EagerLoader)
class EagerLazyOption(StrategizedOption):
-
- def __init__(self, key, lazy=True, chained=False, mapper=None, propagate_to_loaders=True):
- super(EagerLazyOption, self).__init__(key, mapper)
+ def __init__(self, key, lazy=True, chained=False,
+ propagate_to_loaders=True
+ ):
+ super(EagerLazyOption, self).__init__(key)
self.lazy = lazy
self.chained = chained
self.propagate_to_loaders = propagate_to_loaders
-
+ self.strategy_cls = factory(lazy)
+
+ @property
+ def is_eager(self):
+ return self.lazy in (False, 'joined', 'subquery')
+
+ @property
def is_chained(self):
- return not self.lazy and self.chained
-
- def get_strategy_class(self):
- if self.lazy:
- return LazyLoader
- elif self.lazy is False:
- return EagerLoader
- elif self.lazy is None:
- return NoLoader
+ return self.is_eager and self.chained
+ def get_strategy_class(self):
+ return self.strategy_cls
+
+def factory(identifier):
+ if identifier is False or identifier == 'joined':
+ return EagerLoader
+ elif identifier is None or identifier == 'noload':
+ return NoLoader
+    elif identifier is True or identifier == 'select':
+ return LazyLoader
+ elif identifier == 'subquery':
+ return SubqueryLoader
+ else:
+ return LazyLoader
+
+
+
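At mapper configuration time, factory() gives the expanded lazy= vocabulary
its strategy classes; a sketch (User, users_table and Address hypothetical):

    from sqlalchemy.orm import mapper, relationship

    mapper(User, users_table, properties={
        # 'select'/True -> LazyLoader, 'joined'/False -> EagerLoader,
        # 'subquery' -> SubqueryLoader, 'noload'/None -> NoLoader
        'addresses': relationship(Address, lazy='subquery'),
    })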
class EagerJoinOption(PropertyOption):
def __init__(self, key, innerjoin, chained=False):
@@ -881,8 +1221,10 @@ class _SingleParentValidator(interfaces.AttributeExtension):
if value is not None:
hasparent = initiator.hasparent(attributes.instance_state(value))
if hasparent and oldvalue is not value:
- raise sa_exc.InvalidRequestError("Instance %s is already associated with an instance "
- "of %s via its %s attribute, and is only allowed a single parent." %
+ raise sa_exc.InvalidRequestError(
+ "Instance %s is already associated with an instance "
+ "of %s via its %s attribute, and is only allowed a "
+ "single parent." %
(mapperutil.instance_str(value), state.class_, self.prop)
)
return value
diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py
index 3be63ced3..31ab7facc 100644
--- a/lib/sqlalchemy/pool.py
+++ b/lib/sqlalchemy/pool.py
@@ -747,35 +747,10 @@ class StaticPool(Pool):
"""
- def __init__(self, creator, **params):
- """
- Construct a StaticPool.
-
- :param creator: a callable function that returns a DB-API
- connection object. The function will be called with
- parameters.
-
- :param echo: If True, connections being pulled and retrieved
- from the pool will be logged to the standard output, as well
- as pool sizing information. Echoing can also be achieved by
- enabling logging for the "sqlalchemy.pool"
- namespace. Defaults to False.
-
- :param reset_on_return: If true, reset the database state of
- connections returned to the pool. This is typically a
- ROLLBACK to release locks and transaction resources.
- Disable at your own peril. Defaults to True.
-
- :param listeners: A list of
- :class:`~sqlalchemy.interfaces.PoolListener`-like objects or
- dictionaries of callables that receive events when DB-API
- connections are created, checked out and checked in to the
- pool.
+ @memoized_property
+ def _conn(self):
+ return self._creator()
- """
- Pool.__init__(self, creator, **params)
- self._conn = creator()
-
@memoized_property
def connection(self):
return _ConnectionRecord(self)
@@ -784,8 +759,9 @@ class StaticPool(Pool):
return "StaticPool"
def dispose(self):
- self._conn.close()
- self._conn = None
+ if '_conn' in self.__dict__:
+ self._conn.close()
+ self._conn = None
def recreate(self):
self.logger.info("Pool recreating")
@@ -837,7 +813,8 @@ class AssertionPool(Pool):
def dispose(self):
self._checked_out = False
- self._conn.close()
+ if self._conn:
+ self._conn.close()
def recreate(self):
self.logger.info("Pool recreating")
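The StaticPool change above defers connection creation from ``__init__`` to a
``memoized_property``, so ``dispose()`` can test ``'_conn' in self.__dict__``
to know whether a connection was ever created. A stand-alone sketch of the
pattern — the descriptor and pool class here are illustrative stand-ins, not
the ``sqlalchemy.util`` and ``sqlalchemy.pool`` implementations::

    class memoized_property(object):
        """Non-data descriptor: computes once, caches in the instance."""
        def __init__(self, fget):
            self.fget = fget
            self.__name__ = fget.__name__

        def __get__(self, obj, cls):
            if obj is None:
                return self
            obj.__dict__[self.__name__] = result = self.fget(obj)
            return result

    class LazyStaticPool(object):
        def __init__(self, creator):
            self._creator = creator

        @memoized_property
        def _conn(self):
            return self._creator()

        def dispose(self):
            # close only if the property was ever invoked
            if '_conn' in self.__dict__:
                self._conn.close()

    import sqlite3
    pool = LazyStaticPool(lambda: sqlite3.connect(':memory:'))
    pool.dispose()             # no-op: nothing created yet
    conn = pool._conn          # first access invokes the creator
    assert pool._conn is conn  # cached thereafter
    pool.dispose()             # now actually closes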
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 4e9175ae8..78c65771b 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -305,11 +305,13 @@ class SQLCompiler(engine.Compiled):
def visit_grouping(self, grouping, asfrom=False, **kwargs):
return "(" + self.process(grouping.element, **kwargs) + ")"
- def visit_label(self, label, result_map=None, within_columns_clause=False, **kw):
+ def visit_label(self, label, result_map=None,
+ within_label_clause=False,
+ within_columns_clause=False, **kw):
# only render labels within the columns clause
# or ORDER BY clause of a select. dialect-specific compilers
# can modify this behavior.
- if within_columns_clause:
+ if within_columns_clause and not within_label_clause:
labelname = isinstance(label.name, sql._generated_label) and \
self._truncated_identifier("colident", label.name) or label.name
@@ -318,13 +320,14 @@ class SQLCompiler(engine.Compiled):
(label.name, (label, label.element, labelname), label.element.type)
return self.process(label.element,
- within_columns_clause=within_columns_clause,
+ within_columns_clause=True,
+ within_label_clause=True,
**kw) + \
OPERATORS[operators.as_] + \
self.preparer.format_label(label, labelname)
else:
return self.process(label.element,
- within_columns_clause=within_columns_clause,
+ within_columns_clause=False,
**kw)
def visit_column(self, column, result_map=None, **kwargs):
@@ -625,13 +628,22 @@ class SQLCompiler(engine.Compiled):
else:
return self.bindtemplate % {'name':name}
- def visit_alias(self, alias, asfrom=False, **kwargs):
- if asfrom:
+ def visit_alias(self, alias, asfrom=False, ashint=False, fromhints=None, **kwargs):
+ if asfrom or ashint:
alias_name = isinstance(alias.name, sql._generated_label) and \
self._truncated_identifier("alias", alias.name) or alias.name
-
- return self.process(alias.original, asfrom=True, **kwargs) + " AS " + \
+ if ashint:
+ return self.preparer.format_alias(alias, alias_name)
+ elif asfrom:
+ ret = self.process(alias.original, asfrom=True, **kwargs) + " AS " + \
self.preparer.format_alias(alias, alias_name)
+
+ if fromhints and alias in fromhints:
+ hinttext = self.get_from_hint_text(alias, fromhints[alias])
+ if hinttext:
+ ret += " " + hinttext
+
+ return ret
else:
return self.process(alias.original, **kwargs)
@@ -658,8 +670,15 @@ class SQLCompiler(engine.Compiled):
else:
return column
+ def get_select_hint_text(self, byfroms):
+ return None
+
+ def get_from_hint_text(self, table, text):
+ return None
+
def visit_select(self, select, asfrom=False, parens=True,
- iswrapper=False, compound_index=1, **kwargs):
+ iswrapper=False, fromhints=None,
+ compound_index=1, **kwargs):
entry = self.stack and self.stack[-1] or {}
@@ -694,6 +713,18 @@ class SQLCompiler(engine.Compiled):
]
text = "SELECT " # we're off to a good start !
+
+ if select._hints:
+ byfrom = dict([
+ (from_, hinttext % {'name':self.process(from_, ashint=True)})
+ for (from_, dialect), hinttext in
+ select._hints.iteritems()
+ if dialect in ('*', self.dialect.name)
+ ])
+ hint_text = self.get_select_hint_text(byfrom)
+ if hint_text:
+ text += hint_text + " "
+
if select._prefixes:
text += " ".join(self.process(x, **kwargs) for x in select._prefixes) + " "
text += self.get_select_precolumns(select)
@@ -701,7 +732,16 @@ class SQLCompiler(engine.Compiled):
if froms:
text += " \nFROM "
- text += ', '.join(self.process(f, asfrom=True, **kwargs) for f in froms)
+
+ if select._hints:
+ text += ', '.join([self.process(f,
+ asfrom=True, fromhints=byfrom,
+ **kwargs)
+ for f in froms])
+ else:
+ text += ', '.join([self.process(f,
+ asfrom=True, **kwargs)
+ for f in froms])
else:
text += self.default_from()
@@ -764,20 +804,26 @@ class SQLCompiler(engine.Compiled):
text += " OFFSET " + str(select._offset)
return text
- def visit_table(self, table, asfrom=False, **kwargs):
- if asfrom:
+ def visit_table(self, table, asfrom=False, ashint=False, fromhints=None, **kwargs):
+ if asfrom or ashint:
if getattr(table, "schema", None):
- return self.preparer.quote_schema(table.schema, table.quote_schema) + \
+ ret = self.preparer.quote_schema(table.schema, table.quote_schema) + \
"." + self.preparer.quote(table.name, table.quote)
else:
- return self.preparer.quote(table.name, table.quote)
+ ret = self.preparer.quote(table.name, table.quote)
+ if fromhints and table in fromhints:
+ hinttext = self.get_from_hint_text(table, fromhints[table])
+ if hinttext:
+ ret += " " + hinttext
+ return ret
else:
return ""
def visit_join(self, join, asfrom=False, **kwargs):
- return (self.process(join.left, asfrom=True) + \
+ return (self.process(join.left, asfrom=True, **kwargs) + \
(join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \
- self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause))
+ self.process(join.right, asfrom=True, **kwargs) + " ON " + \
+ self.process(join.onclause, **kwargs))
def visit_sequence(self, seq):
return None
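The ``within_label_clause`` flag introduced above keeps a label's inner
element from emitting its own ``AS`` clause. A toy renderer illustrating the
guard, using hypothetical ``Column``/``Label`` stand-ins rather than the real
compiler::

    class Column(object):
        def __init__(self, name):
            self.name = name

    class Label(object):
        def __init__(self, element, name):
            self.element, self.name = element, name

    def process(node, within_columns_clause=False, within_label_clause=False):
        if isinstance(node, Label):
            # only the outermost label in the columns clause renders "AS";
            # without the flag, nested labels would yield "x AS a AS b"
            if within_columns_clause and not within_label_clause:
                return "%s AS %s" % (
                    process(node.element,
                            within_columns_clause=True,
                            within_label_clause=True),
                    node.name)
            return process(node.element)
        return node.name

    doubly = Label(Label(Column("x"), "inner"), "outer")
    assert process(doubly, within_columns_clause=True) == "x AS outer"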
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 1e02ba96a..3aaa06fd6 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -3557,6 +3557,7 @@ class Select(_SelectBaseMixin, FromClause):
__visit_name__ = 'select'
_prefixes = ()
+ _hints = util.frozendict()
def __init__(self,
columns,
@@ -3659,7 +3660,34 @@ class Select(_SelectBaseMixin, FromClause):
"""Return the displayed list of FromClause elements."""
return self._get_display_froms()
-
+
+ @_generative
+ def with_hint(self, selectable, text, dialect_name=None):
+ """Add an indexing hint for the given selectable to this :class:`Select`.
+
+        The text of the hint is written in the form expected by a specific
+        backend, and typically uses Python string substitution syntax to
+        render the name of the table or alias, such as for Oracle::
+
+ select([mytable]).with_hint(mytable, "+ index(%(name)s ix_mytable)")
+
+        This would render SQL as::
+
+ select /*+ index(mytable ix_mytable) */ ... from mytable
+
+        The ``dialect_name`` option limits the rendering of a particular
+        hint to a particular backend. For example, to add hints for both
+        Oracle and Sybase simultaneously::
+
+ select([mytable]).\
+ with_hint(mytable, "+ index(%(name)s ix_mytable)", 'oracle').\
+ with_hint(mytable, "WITH INDEX ix_mytable", 'sybase')
+
+ """
+ if not dialect_name:
+ dialect_name = '*'
+ self._hints = self._hints.union({(selectable, dialect_name):text})
+
@property
def type(self):
raise exc.InvalidRequestError("Select objects don't have a type. "
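A short usage sketch of the generative ``with_hint()`` added above; the table
and index names are hypothetical. Hints are keyed by ``(selectable,
dialect_name)``, so one statement can carry hints for several backends and
each compiler renders only its own::

    from sqlalchemy import MetaData, Table, Column, Integer, select

    metadata = MetaData()
    mytable = Table('mytable', metadata, Column('x', Integer))

    stmt = select([mytable]).\
        with_hint(mytable, "+ index(%(name)s ix_mytable)", 'oracle').\
        with_hint(mytable, "WITH INDEX ix_mytable", 'sybase')

    # the generic dialect matches neither 'oracle' nor 'sybase', so the
    # hints are filtered out and plain SQL renders
    print(stmt)   # SELECT mytable.x FROM mytable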
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 74651a9d1..d5575e0e7 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -579,7 +579,7 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
return None
elif self.exclude and col in self.exclude:
return None
-
+
return self._corresponding_column(col, True)
class ColumnAdapter(ClauseAdapter):
@@ -587,11 +587,13 @@ class ColumnAdapter(ClauseAdapter):
Provides the ability to "wrap" this ClauseAdapter
around another, a columns dictionary which returns
- cached, adapted elements given an original, and an
+ adapted elements given an original, and an
adapted_row() factory.
"""
- def __init__(self, selectable, equivalents=None, chain_to=None, include=None, exclude=None, adapt_required=False):
+ def __init__(self, selectable, equivalents=None,
+ chain_to=None, include=None,
+ exclude=None, adapt_required=False):
ClauseAdapter.__init__(self, selectable, equivalents, include, exclude)
if chain_to:
self.chain(chain_to)
@@ -617,7 +619,7 @@ class ColumnAdapter(ClauseAdapter):
return locate
def _locate_col(self, col):
- c = self._corresponding_column(col, False)
+ c = self._corresponding_column(col, True)
if c is None:
c = self.adapt_clause(col)
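The ``_locate_col`` change above passes ``require_embedded=True``, so a
column is adapted only when its counterpart is genuinely embedded in the
target selectable. The public ``corresponding_column()`` API demonstrates the
correspondence being relied upon, against a hypothetical table::

    from sqlalchemy import MetaData, Table, Column, Integer

    metadata = MetaData()
    t = Table('t', metadata, Column('x', Integer))
    a = t.alias('t_1')

    # the alias locates its own version of t.c.x
    col = a.corresponding_column(t.c.x)
    assert col is a.c.x
    print(col)   # t_1.x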
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 4a54375f8..799486c02 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -40,16 +40,17 @@ class VisitableType(type):
# set up an optimized visit dispatch function
# for use by the compiler
- visit_name = cls.__visit_name__
- if isinstance(visit_name, str):
- getter = operator.attrgetter("visit_%s" % visit_name)
- def _compiler_dispatch(self, visitor, **kw):
- return getter(visitor)(self, **kw)
- else:
- def _compiler_dispatch(self, visitor, **kw):
- return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)
-
- cls._compiler_dispatch = _compiler_dispatch
+ if '__visit_name__' in cls.__dict__:
+ visit_name = cls.__visit_name__
+ if isinstance(visit_name, str):
+ getter = operator.attrgetter("visit_%s" % visit_name)
+ def _compiler_dispatch(self, visitor, **kw):
+ return getter(visitor)(self, **kw)
+ else:
+ def _compiler_dispatch(self, visitor, **kw):
+ return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)
+
+ cls._compiler_dispatch = _compiler_dispatch
super(VisitableType, cls).__init__(clsname, bases, clsdict)
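The guard added above builds ``_compiler_dispatch`` only for classes that
declare ``__visit_name__`` in their own ``__dict__``; subclasses now inherit
the parent's precomputed dispatcher instead of regenerating it. A stand-alone
sketch of the string-name fast path, using stand-in names and omitting the
``getattr`` fallback for computed visit names::

    import operator

    class VisitableType(type):
        def __init__(cls, clsname, bases, clsdict):
            if '__visit_name__' in clsdict:
                visit_name = cls.__visit_name__
                if isinstance(visit_name, str):
                    # precompute the attribute lookup once per class
                    getter = operator.attrgetter("visit_%s" % visit_name)
                    def _compiler_dispatch(self, visitor, **kw):
                        return getter(visitor)(self, **kw)
                    cls._compiler_dispatch = _compiler_dispatch
            super(VisitableType, cls).__init__(clsname, bases, clsdict)

    # instantiate through the metaclass's three-argument call directly
    Select = VisitableType('Select', (object,), {'__visit_name__': 'select'})

    class Visitor(object):
        def visit_select(self, element, **kw):
            return "visited %r" % element

    print(Select()._compiler_dispatch(Visitor()))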
diff --git a/lib/sqlalchemy/test/requires.py b/lib/sqlalchemy/test/requires.py
index c4c745c54..bf911c2c2 100644
--- a/lib/sqlalchemy/test/requires.py
+++ b/lib/sqlalchemy/test/requires.py
@@ -149,6 +149,18 @@ def sequences(fn):
no_support('sybase', 'no SEQUENCE support'),
)
+def update_nowait(fn):
+ """Target database must support SELECT...FOR UPDATE NOWAIT"""
+ return _chain_decorators_on(
+ fn,
+ no_support('access', 'no FOR UPDATE NOWAIT support'),
+ no_support('firebird', 'no FOR UPDATE NOWAIT support'),
+ no_support('mssql', 'no FOR UPDATE NOWAIT support'),
+ no_support('mysql', 'no FOR UPDATE NOWAIT support'),
+ no_support('sqlite', 'no FOR UPDATE NOWAIT support'),
+ no_support('sybase', 'no FOR UPDATE NOWAIT support'),
+ )
+
def subqueries(fn):
"""Target database must support subqueries."""
return _chain_decorators_on(
@@ -224,6 +236,7 @@ def unicode_ddl(fn):
no_support('maxdb', 'database support flakey'),
no_support('oracle', 'FIXME: no support in database?'),
no_support('sybase', 'FIXME: guessing, needs confirmation'),
+ no_support('mssql+pymssql', 'no FreeTDS support'),
exclude('mysql', '<', (4, 1, 1), 'no unicode connection support'),
)
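``update_nowait`` follows the same decorator-chaining idiom as the other
requirements here: each ``no_support()`` rule wraps the test function in
turn. A self-contained sketch of the idiom — a stand-in that receives the
backend name as an argument, whereas the real helpers consult the active
test configuration::

    def _chain_decorators_on(fn, *decorators):
        for decorator in reversed(decorators):
            fn = decorator(fn)
        return fn

    def no_support(db, reason):
        def decorate(fn):
            def maybe(current_db, *args, **kw):
                if current_db == db:
                    print("skipped on %s: %s" % (db, reason))
                    return None
                return fn(current_db, *args, **kw)
            return maybe
        return decorate

    def update_nowait(fn):
        """Target database must support SELECT...FOR UPDATE NOWAIT"""
        return _chain_decorators_on(
            fn,
            no_support('sqlite', 'no FOR UPDATE NOWAIT support'),
            no_support('mysql', 'no FOR UPDATE NOWAIT support'),
        )

    @update_nowait
    def test_nowait(current_db):
        return "ran on %s" % current_db

    print(test_nowait('postgresql'))   # ran on postgresql
    test_nowait('sqlite')              # skipped on sqlite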
diff --git a/lib/sqlalchemy/topological.py b/lib/sqlalchemy/topological.py
index 3f2ff6399..324995889 100644
--- a/lib/sqlalchemy/topological.py
+++ b/lib/sqlalchemy/topological.py
@@ -141,3 +141,36 @@ def sort(tuples, allitems):
queue.append(childnode)
return output
+
+def _find_cycles(edges):
+ cycles = {}
+
+ def traverse(node, cycle, goal):
+ for (n, key) in edges.edges_by_parent(node):
+ if key in cycle:
+ continue
+ cycle.add(key)
+ if key is goal:
+ cycset = set(cycle)
+ for x in cycle:
+ if x in cycles:
+ existing_set = cycles[x]
+ existing_set.update(cycset)
+ for y in existing_set:
+ cycles[y] = existing_set
+ cycset = existing_set
+ else:
+ cycles[x] = cycset
+ else:
+ traverse(key, cycle, goal)
+            cycle.remove(key)  # backtrack this key; set.pop() would discard an arbitrary member
+
+ for parent in edges.get_parents():
+ traverse(parent, set(), parent)
+
+ unique_cycles = set(tuple(s) for s in cycles.values())
+
+ for cycle in unique_cycles:
+ edgecollection = [edge for edge in edges
+ if edge[0] in cycle and edge[1] in cycle]
+ yield edgecollection
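A small driver for the re-instated ``_find_cycles`` above: depth-first walks
lead from each parent back to itself, and overlapping cycles are merged into
one shared set. The ``_EdgeCollection`` here is a hypothetical stand-in that
supplies only the two accessors the function needs, not the real structure::

    class _EdgeCollection(object):
        def __init__(self, edges):
            self._edges = list(edges)

        def __iter__(self):
            return iter(self._edges)

        def get_parents(self):
            return set(parent for parent, child in self._edges)

        def edges_by_parent(self, node):
            return [(node, child) for parent, child in self._edges
                    if parent == node]

    edges = _EdgeCollection([('a', 'b'), ('b', 'a'), ('b', 'c'), ('c', 'b')])
    for group in _find_cycles(edges):
        print(sorted(group))
    # the two overlapping two-cycles share node 'b', so they merge into a
    # single edge collection:
    # [('a', 'b'), ('b', 'a'), ('b', 'c'), ('c', 'b')]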