diff options
Diffstat (limited to 'lib')
82 files changed, 4088 insertions, 3210 deletions
diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py index 36229a105..74e8abfc2 100644 --- a/lib/sqlalchemy/dialects/firebird/base.py +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -180,16 +180,16 @@ ischema_names = { # _FBDate, etc. as bind/result functionality is required) class FBTypeCompiler(compiler.GenericTypeCompiler): - def visit_boolean(self, type_): - return self.visit_SMALLINT(type_) + def visit_boolean(self, type_, **kw): + return self.visit_SMALLINT(type_, **kw) - def visit_datetime(self, type_): - return self.visit_TIMESTAMP(type_) + def visit_datetime(self, type_, **kw): + return self.visit_TIMESTAMP(type_, **kw) - def visit_TEXT(self, type_): + def visit_TEXT(self, type_, **kw): return "BLOB SUB_TYPE 1" - def visit_BLOB(self, type_): + def visit_BLOB(self, type_, **kw): return "BLOB SUB_TYPE 0" def _extend_string(self, type_, basic): @@ -199,16 +199,16 @@ class FBTypeCompiler(compiler.GenericTypeCompiler): else: return '%s CHARACTER SET %s' % (basic, charset) - def visit_CHAR(self, type_): - basic = super(FBTypeCompiler, self).visit_CHAR(type_) + def visit_CHAR(self, type_, **kw): + basic = super(FBTypeCompiler, self).visit_CHAR(type_, **kw) return self._extend_string(type_, basic) - def visit_VARCHAR(self, type_): + def visit_VARCHAR(self, type_, **kw): if not type_.length: raise exc.CompileError( "VARCHAR requires a length on dialect %s" % self.dialect.name) - basic = super(FBTypeCompiler, self).visit_VARCHAR(type_) + basic = super(FBTypeCompiler, self).visit_VARCHAR(type_, **kw) return self._extend_string(type_, basic) diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index dad02ee0f..92d7e4ab3 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -226,6 +226,53 @@ The DATE and TIME types are not available for MSSQL 2005 and previous - if a server version below 2008 is detected, DDL for these types will be issued as DATETIME. +.. _mssql_large_type_deprecation: + +Large Text/Binary Type Deprecation +---------------------------------- + +Per `SQL Server 2012/2014 Documentation <http://technet.microsoft.com/en-us/library/ms187993.aspx>`_, +the ``NTEXT``, ``TEXT`` and ``IMAGE`` datatypes are to be removed from SQL Server +in a future release. SQLAlchemy normally relates these types to the +:class:`.UnicodeText`, :class:`.Text` and :class:`.LargeBinary` datatypes. + +In order to accommodate this change, a new flag ``deprecate_large_types`` +is added to the dialect, which will be automatically set based on detection +of the server version in use, if not otherwise set by the user. The +behavior of this flag is as follows: + +* When this flag is ``True``, the :class:`.UnicodeText`, :class:`.Text` and + :class:`.LargeBinary` datatypes, when used to render DDL, will render the + types ``NVARCHAR(max)``, ``VARCHAR(max)``, and ``VARBINARY(max)``, + respectively. This is a new behavior as of the addition of this flag. + +* When this flag is ``False``, the :class:`.UnicodeText`, :class:`.Text` and + :class:`.LargeBinary` datatypes, when used to render DDL, will render the + types ``NTEXT``, ``TEXT``, and ``IMAGE``, + respectively. This is the long-standing behavior of these types. + +* The flag begins with the value ``None``, before a database connection is + established. If the dialect is used to render DDL without the flag being + set, it is interpreted the same as ``False``. + +* On first connection, the dialect detects if SQL Server version 2012 or greater + is in use; if the flag is still at ``None``, it sets it to ``True`` or + ``False`` based on whether 2012 or greater is detected. + +* The flag can be set to either ``True`` or ``False`` when the dialect + is created, typically via :func:`.create_engine`:: + + eng = create_engine("mssql+pymssql://user:pass@host/db", + deprecate_large_types=True) + +* Complete control over whether the "old" or "new" types are rendered is + available in all SQLAlchemy versions by using the UPPERCASE type objects + instead: :class:`.NVARCHAR`, :class:`.VARCHAR`, :class:`.types.VARBINARY`, + :class:`.TEXT`, :class:`.mssql.NTEXT`, :class:`.mssql.IMAGE` will always remain + fixed and always output exactly that type. + +.. versionadded:: 1.0.0 + .. _mssql_indexes: Clustered Index Support @@ -367,19 +414,20 @@ import operator import re from ... import sql, schema as sa_schema, exc, util -from ...sql import compiler, expression, \ - util as sql_util, cast +from ...sql import compiler, expression, util as sql_util from ... import engine from ...engine import reflection, default from ... import types as sqltypes from ...types import INTEGER, BIGINT, SMALLINT, DECIMAL, NUMERIC, \ FLOAT, TIMESTAMP, DATETIME, DATE, BINARY,\ - VARBINARY, TEXT, VARCHAR, NVARCHAR, CHAR, NCHAR + TEXT, VARCHAR, NVARCHAR, CHAR, NCHAR from ...util import update_wrapper from . import information_schema as ischema +# http://sqlserverbuilds.blogspot.com/ +MS_2012_VERSION = (11,) MS_2008_VERSION = (10,) MS_2005_VERSION = (9,) MS_2000_VERSION = (8,) @@ -545,6 +593,26 @@ class NTEXT(sqltypes.UnicodeText): __visit_name__ = 'NTEXT' +class VARBINARY(sqltypes.VARBINARY, sqltypes.LargeBinary): + """The MSSQL VARBINARY type. + + This type extends both :class:`.types.VARBINARY` and + :class:`.types.LargeBinary`. In "deprecate_large_types" mode, + the :class:`.types.LargeBinary` type will produce ``VARBINARY(max)`` + on SQL Server. + + .. versionadded:: 1.0.0 + + .. seealso:: + + :ref:`mssql_large_type_deprecation` + + + + """ + __visit_name__ = 'VARBINARY' + + class IMAGE(sqltypes.LargeBinary): __visit_name__ = 'IMAGE' @@ -626,7 +694,6 @@ ischema_names = { class MSTypeCompiler(compiler.GenericTypeCompiler): - def _extend(self, spec, type_, length=None): """Extend a string-type declaration with standard SQL COLLATE annotations. @@ -647,103 +714,115 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return ' '.join([c for c in (spec, collation) if c is not None]) - def visit_FLOAT(self, type_): + def visit_FLOAT(self, type_, **kw): precision = getattr(type_, 'precision', None) if precision is None: return "FLOAT" else: return "FLOAT(%(precision)s)" % {'precision': precision} - def visit_TINYINT(self, type_): + def visit_TINYINT(self, type_, **kw): return "TINYINT" - def visit_DATETIMEOFFSET(self, type_): + def visit_DATETIMEOFFSET(self, type_, **kw): if type_.precision: return "DATETIMEOFFSET(%s)" % type_.precision else: return "DATETIMEOFFSET" - def visit_TIME(self, type_): + def visit_TIME(self, type_, **kw): precision = getattr(type_, 'precision', None) if precision: return "TIME(%s)" % precision else: return "TIME" - def visit_DATETIME2(self, type_): + def visit_DATETIME2(self, type_, **kw): precision = getattr(type_, 'precision', None) if precision: return "DATETIME2(%s)" % precision else: return "DATETIME2" - def visit_SMALLDATETIME(self, type_): + def visit_SMALLDATETIME(self, type_, **kw): return "SMALLDATETIME" - def visit_unicode(self, type_): - return self.visit_NVARCHAR(type_) + def visit_unicode(self, type_, **kw): + return self.visit_NVARCHAR(type_, **kw) + + def visit_text(self, type_, **kw): + if self.dialect.deprecate_large_types: + return self.visit_VARCHAR(type_, **kw) + else: + return self.visit_TEXT(type_, **kw) - def visit_unicode_text(self, type_): - return self.visit_NTEXT(type_) + def visit_unicode_text(self, type_, **kw): + if self.dialect.deprecate_large_types: + return self.visit_NVARCHAR(type_, **kw) + else: + return self.visit_NTEXT(type_, **kw) - def visit_NTEXT(self, type_): + def visit_NTEXT(self, type_, **kw): return self._extend("NTEXT", type_) - def visit_TEXT(self, type_): + def visit_TEXT(self, type_, **kw): return self._extend("TEXT", type_) - def visit_VARCHAR(self, type_): + def visit_VARCHAR(self, type_, **kw): return self._extend("VARCHAR", type_, length=type_.length or 'max') - def visit_CHAR(self, type_): + def visit_CHAR(self, type_, **kw): return self._extend("CHAR", type_) - def visit_NCHAR(self, type_): + def visit_NCHAR(self, type_, **kw): return self._extend("NCHAR", type_) - def visit_NVARCHAR(self, type_): + def visit_NVARCHAR(self, type_, **kw): return self._extend("NVARCHAR", type_, length=type_.length or 'max') - def visit_date(self, type_): + def visit_date(self, type_, **kw): if self.dialect.server_version_info < MS_2008_VERSION: - return self.visit_DATETIME(type_) + return self.visit_DATETIME(type_, **kw) else: - return self.visit_DATE(type_) + return self.visit_DATE(type_, **kw) - def visit_time(self, type_): + def visit_time(self, type_, **kw): if self.dialect.server_version_info < MS_2008_VERSION: - return self.visit_DATETIME(type_) + return self.visit_DATETIME(type_, **kw) else: - return self.visit_TIME(type_) + return self.visit_TIME(type_, **kw) - def visit_large_binary(self, type_): - return self.visit_IMAGE(type_) + def visit_large_binary(self, type_, **kw): + if self.dialect.deprecate_large_types: + return self.visit_VARBINARY(type_, **kw) + else: + return self.visit_IMAGE(type_, **kw) - def visit_IMAGE(self, type_): + def visit_IMAGE(self, type_, **kw): return "IMAGE" - def visit_VARBINARY(self, type_): + def visit_VARBINARY(self, type_, **kw): return self._extend( "VARBINARY", type_, length=type_.length or 'max') - def visit_boolean(self, type_): + def visit_boolean(self, type_, **kw): return self.visit_BIT(type_) - def visit_BIT(self, type_): + def visit_BIT(self, type_, **kw): return "BIT" - def visit_MONEY(self, type_): + def visit_MONEY(self, type_, **kw): return "MONEY" - def visit_SMALLMONEY(self, type_): + def visit_SMALLMONEY(self, type_, **kw): return 'SMALLMONEY' - def visit_UNIQUEIDENTIFIER(self, type_): + def visit_UNIQUEIDENTIFIER(self, type_, **kw): return "UNIQUEIDENTIFIER" - def visit_SQL_VARIANT(self, type_): + def visit_SQL_VARIANT(self, type_, **kw): return 'SQL_VARIANT' @@ -1160,8 +1239,11 @@ class MSSQLStrictCompiler(MSSQLCompiler): class MSDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): - colspec = (self.preparer.format_column(column) + " " - + self.dialect.type_compiler.process(column.type)) + colspec = ( + self.preparer.format_column(column) + " " + + self.dialect.type_compiler.process( + column.type, type_expression=column) + ) if column.nullable is not None: if not column.nullable or column.primary_key or \ @@ -1370,13 +1452,15 @@ class MSDialect(default.DefaultDialect): query_timeout=None, use_scope_identity=True, max_identifier_length=None, - schema_name="dbo", **opts): + schema_name="dbo", + deprecate_large_types=None, **opts): self.query_timeout = int(query_timeout or 0) self.schema_name = schema_name self.use_scope_identity = use_scope_identity self.max_identifier_length = int(max_identifier_length or 0) or \ self.max_identifier_length + self.deprecate_large_types = deprecate_large_types super(MSDialect, self).__init__(**opts) def do_savepoint(self, connection, name): @@ -1390,6 +1474,9 @@ class MSDialect(default.DefaultDialect): def initialize(self, connection): super(MSDialect, self).initialize(connection) + self._setup_version_attributes() + + def _setup_version_attributes(self): if self.server_version_info[0] not in list(range(8, 17)): # FreeTDS with version 4.2 seems to report here # a number like "95.10.255". Don't know what @@ -1405,6 +1492,9 @@ class MSDialect(default.DefaultDialect): self.implicit_returning = True if self.server_version_info >= MS_2008_VERSION: self.supports_multivalues_insert = True + if self.deprecate_large_types is None: + self.deprecate_large_types = \ + self.server_version_info >= MS_2012_VERSION def _get_default_schema_name(self, connection): if self.server_version_info < MS_2005_VERSION: @@ -1592,12 +1682,11 @@ class MSDialect(default.DefaultDialect): if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText, MSNText, MSBinary, MSVarBinary, sqltypes.LargeBinary): + if charlen == -1: + charlen = 'max' kwargs['length'] = charlen if collation: kwargs['collation'] = collation - if coltype == MSText or \ - (coltype in (MSString, MSNVarchar) and charlen == -1): - kwargs.pop('length') if coltype is None: util.warn( diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 58eb3afa0..c8e33bfb2 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -106,7 +106,7 @@ to be used. Transaction Isolation Level --------------------------- -:func:`.create_engine` accepts an ``isolation_level`` +:func:`.create_engine` accepts an :paramref:`.create_engine.isolation_level` parameter which results in the command ``SET SESSION TRANSACTION ISOLATION LEVEL <level>`` being invoked for every new connection. Valid values for this parameter are @@ -602,6 +602,14 @@ class _StringType(sqltypes.String): to_inspect=[_StringType, sqltypes.String]) +class _MatchType(sqltypes.Float, sqltypes.MatchType): + def __init__(self, **kw): + # TODO: float arguments? + sqltypes.Float.__init__(self) + sqltypes.MatchType.__init__(self) + + + class NUMERIC(_NumericType, sqltypes.NUMERIC): """MySQL NUMERIC type.""" @@ -1420,32 +1428,28 @@ class SET(_EnumeratedValues): Column('myset', SET("foo", "bar", "baz")) - :param values: The range of valid values for this SET. Values will be - quoted when generating the schema according to the quoting flag (see - below). - .. versionchanged:: 0.9.0 quoting is applied automatically to - :class:`.mysql.SET` in the same way as for :class:`.mysql.ENUM`. + The list of potential values is required in the case that this + set will be used to generate DDL for a table, or if the + :paramref:`.SET.retrieve_as_bitwise` flag is set to True. - :param charset: Optional, a column-level character set for this string - value. Takes precedence to 'ascii' or 'unicode' short-hand. + :param values: The range of valid values for this SET. - :param collation: Optional, a column-level collation for this string - value. Takes precedence to 'binary' short-hand. + :param convert_unicode: Same flag as that of + :paramref:`.String.convert_unicode`. - :param ascii: Defaults to False: short-hand for the ``latin1`` - character set, generates ASCII in schema. + :param collation: same as that of :paramref:`.String.collation` - :param unicode: Defaults to False: short-hand for the ``ucs2`` - character set, generates UNICODE in schema. + :param charset: same as that of :paramref:`.VARCHAR.charset`. - :param binary: Defaults to False: short-hand, pick the binary - collation type that matches the column's character set. Generates - BINARY in schema. This does not affect the type of data stored, - only the collation of character data. + :param ascii: same as that of :paramref:`.VARCHAR.ascii`. - :param quoting: Defaults to 'auto': automatically determine enum value - quoting. If all enum values are surrounded by the same quoting + :param unicode: same as that of :paramref:`.VARCHAR.unicode`. + + :param binary: same as that of :paramref:`.VARCHAR.binary`. + + :param quoting: Defaults to 'auto': automatically determine set value + quoting. If all values are surrounded by the same quoting character, then use 'quoted' mode. Otherwise, use 'unquoted' mode. 'quoted': values in enums are already quoted, they will be used @@ -1460,50 +1464,117 @@ class SET(_EnumeratedValues): .. versionadded:: 0.9.0 + :param retrieve_as_bitwise: if True, the data for the set type will be + persisted and selected using an integer value, where a set is coerced + into a bitwise mask for persistence. MySQL allows this mode which + has the advantage of being able to store values unambiguously, + such as the blank string ``''``. The datatype will appear + as the expression ``col + 0`` in a SELECT statement, so that the + value is coerced into an integer value in result sets. + This flag is required if one wishes + to persist a set that can store the blank string ``''`` as a value. + + .. warning:: + + When using :paramref:`.mysql.SET.retrieve_as_bitwise`, it is + essential that the list of set values is expressed in the + **exact same order** as exists on the MySQL database. + + .. versionadded:: 1.0.0 + + """ + self.retrieve_as_bitwise = kw.pop('retrieve_as_bitwise', False) values, length = self._init_values(values, kw) self.values = tuple(values) - + if not self.retrieve_as_bitwise and '' in values: + raise exc.ArgumentError( + "Can't use the blank value '' in a SET without " + "setting retrieve_as_bitwise=True") + if self.retrieve_as_bitwise: + self._bitmap = dict( + (value, 2 ** idx) + for idx, value in enumerate(self.values) + ) + self._bitmap.update( + (2 ** idx, value) + for idx, value in enumerate(self.values) + ) kw.setdefault('length', length) super(SET, self).__init__(**kw) + def column_expression(self, colexpr): + if self.retrieve_as_bitwise: + return colexpr + 0 + else: + return colexpr + def result_processor(self, dialect, coltype): - def process(value): - # The good news: - # No ',' quoting issues- commas aren't allowed in SET values - # The bad news: - # Plenty of driver inconsistencies here. - if isinstance(value, set): - # ..some versions convert '' to an empty set - if not value: - value.add('') - return value - # ...and some versions return strings - if value is not None: - return set(value.split(',')) - else: - return value + if self.retrieve_as_bitwise: + def process(value): + if value is not None: + value = int(value) + + return set( + util.map_bits(self._bitmap.__getitem__, value) + ) + else: + return None + else: + super_convert = super(SET, self).result_processor(dialect, coltype) + + def process(value): + if isinstance(value, util.string_types): + # MySQLdb returns a string, let's parse + if super_convert: + value = super_convert(value) + return set(re.findall(r'[^,]+', value)) + else: + # mysql-connector-python does a naive + # split(",") which throws in an empty string + if value is not None: + value.discard('') + return value return process def bind_processor(self, dialect): super_convert = super(SET, self).bind_processor(dialect) + if self.retrieve_as_bitwise: + def process(value): + if value is None: + return None + elif isinstance(value, util.int_types + util.string_types): + if super_convert: + return super_convert(value) + else: + return value + else: + int_value = 0 + for v in value: + int_value |= self._bitmap[v] + return int_value + else: - def process(value): - if value is None or isinstance( - value, util.int_types + util.string_types): - pass - else: - if None in value: - value = set(value) - value.remove(None) - value.add('') - value = ','.join(value) - if super_convert: - return super_convert(value) - else: - return value + def process(value): + # accept strings and int (actually bitflag) values directly + if value is not None and not isinstance( + value, util.int_types + util.string_types): + value = ",".join(value) + + if super_convert: + return super_convert(value) + else: + return value return process + def adapt(self, impltype, **kw): + kw['retrieve_as_bitwise'] = self.retrieve_as_bitwise + return util.constructor_copy( + self, impltype, + *self.values, + **kw + ) + # old names MSTime = TIME MSSet = SET @@ -1544,6 +1615,7 @@ colspecs = { sqltypes.Float: FLOAT, sqltypes.Time: TIME, sqltypes.Enum: ENUM, + sqltypes.MatchType: _MatchType } # Everything 3.23 through 5.1 excepting OpenGIS types. @@ -1758,10 +1830,10 @@ class MySQLCompiler(compiler.SQLCompiler): # creation of foreign key constraints fails." class MySQLDDLCompiler(compiler.DDLCompiler): - def create_table_constraints(self, table): + def create_table_constraints(self, table, **kw): """Get table constraints.""" constraint_string = super( - MySQLDDLCompiler, self).create_table_constraints(table) + MySQLDDLCompiler, self).create_table_constraints(table, **kw) # why self.dialect.name and not 'mysql'? because of drizzle is_innodb = 'engine' in table.dialect_options[self.dialect.name] and \ @@ -1787,9 +1859,11 @@ class MySQLDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kw): """Builds column DDL.""" - colspec = [self.preparer.format_column(column), - self.dialect.type_compiler.process(column.type) - ] + colspec = [ + self.preparer.format_column(column), + self.dialect.type_compiler.process( + column.type, type_expression=column) + ] default = self.get_column_default_string(column) if default is not None: @@ -1987,7 +2061,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): def _mysql_type(self, type_): return isinstance(type_, (_StringType, _NumericType)) - def visit_NUMERIC(self, type_): + def visit_NUMERIC(self, type_, **kw): if type_.precision is None: return self._extend_numeric(type_, "NUMERIC") elif type_.scale is None: @@ -2000,7 +2074,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): {'precision': type_.precision, 'scale': type_.scale}) - def visit_DECIMAL(self, type_): + def visit_DECIMAL(self, type_, **kw): if type_.precision is None: return self._extend_numeric(type_, "DECIMAL") elif type_.scale is None: @@ -2013,7 +2087,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): {'precision': type_.precision, 'scale': type_.scale}) - def visit_DOUBLE(self, type_): + def visit_DOUBLE(self, type_, **kw): if type_.precision is not None and type_.scale is not None: return self._extend_numeric(type_, "DOUBLE(%(precision)s, %(scale)s)" % @@ -2022,7 +2096,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, 'DOUBLE') - def visit_REAL(self, type_): + def visit_REAL(self, type_, **kw): if type_.precision is not None and type_.scale is not None: return self._extend_numeric(type_, "REAL(%(precision)s, %(scale)s)" % @@ -2031,7 +2105,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, 'REAL') - def visit_FLOAT(self, type_): + def visit_FLOAT(self, type_, **kw): if self._mysql_type(type_) and \ type_.scale is not None and \ type_.precision is not None: @@ -2043,7 +2117,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, "FLOAT") - def visit_INTEGER(self, type_): + def visit_INTEGER(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( type_, "INTEGER(%(display_width)s)" % @@ -2051,7 +2125,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, "INTEGER") - def visit_BIGINT(self, type_): + def visit_BIGINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( type_, "BIGINT(%(display_width)s)" % @@ -2059,7 +2133,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, "BIGINT") - def visit_MEDIUMINT(self, type_): + def visit_MEDIUMINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( type_, "MEDIUMINT(%(display_width)s)" % @@ -2067,14 +2141,14 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, "MEDIUMINT") - def visit_TINYINT(self, type_): + def visit_TINYINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric(type_, "TINYINT(%s)" % type_.display_width) else: return self._extend_numeric(type_, "TINYINT") - def visit_SMALLINT(self, type_): + def visit_SMALLINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric(type_, "SMALLINT(%(display_width)s)" % @@ -2083,55 +2157,55 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, "SMALLINT") - def visit_BIT(self, type_): + def visit_BIT(self, type_, **kw): if type_.length is not None: return "BIT(%s)" % type_.length else: return "BIT" - def visit_DATETIME(self, type_): + def visit_DATETIME(self, type_, **kw): if getattr(type_, 'fsp', None): return "DATETIME(%d)" % type_.fsp else: return "DATETIME" - def visit_DATE(self, type_): + def visit_DATE(self, type_, **kw): return "DATE" - def visit_TIME(self, type_): + def visit_TIME(self, type_, **kw): if getattr(type_, 'fsp', None): return "TIME(%d)" % type_.fsp else: return "TIME" - def visit_TIMESTAMP(self, type_): + def visit_TIMESTAMP(self, type_, **kw): if getattr(type_, 'fsp', None): return "TIMESTAMP(%d)" % type_.fsp else: return "TIMESTAMP" - def visit_YEAR(self, type_): + def visit_YEAR(self, type_, **kw): if type_.display_width is None: return "YEAR" else: return "YEAR(%s)" % type_.display_width - def visit_TEXT(self, type_): + def visit_TEXT(self, type_, **kw): if type_.length: return self._extend_string(type_, {}, "TEXT(%d)" % type_.length) else: return self._extend_string(type_, {}, "TEXT") - def visit_TINYTEXT(self, type_): + def visit_TINYTEXT(self, type_, **kw): return self._extend_string(type_, {}, "TINYTEXT") - def visit_MEDIUMTEXT(self, type_): + def visit_MEDIUMTEXT(self, type_, **kw): return self._extend_string(type_, {}, "MEDIUMTEXT") - def visit_LONGTEXT(self, type_): + def visit_LONGTEXT(self, type_, **kw): return self._extend_string(type_, {}, "LONGTEXT") - def visit_VARCHAR(self, type_): + def visit_VARCHAR(self, type_, **kw): if type_.length: return self._extend_string( type_, {}, "VARCHAR(%d)" % type_.length) @@ -2140,14 +2214,14 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): "VARCHAR requires a length on dialect %s" % self.dialect.name) - def visit_CHAR(self, type_): + def visit_CHAR(self, type_, **kw): if type_.length: return self._extend_string(type_, {}, "CHAR(%(length)s)" % {'length': type_.length}) else: return self._extend_string(type_, {}, "CHAR") - def visit_NVARCHAR(self, type_): + def visit_NVARCHAR(self, type_, **kw): # We'll actually generate the equiv. "NATIONAL VARCHAR" instead # of "NVARCHAR". if type_.length: @@ -2159,7 +2233,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): "NVARCHAR requires a length on dialect %s" % self.dialect.name) - def visit_NCHAR(self, type_): + def visit_NCHAR(self, type_, **kw): # We'll actually generate the equiv. # "NATIONAL CHAR" instead of "NCHAR". if type_.length: @@ -2169,31 +2243,31 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_string(type_, {'national': True}, "CHAR") - def visit_VARBINARY(self, type_): + def visit_VARBINARY(self, type_, **kw): return "VARBINARY(%d)" % type_.length - def visit_large_binary(self, type_): + def visit_large_binary(self, type_, **kw): return self.visit_BLOB(type_) - def visit_enum(self, type_): + def visit_enum(self, type_, **kw): if not type_.native_enum: return super(MySQLTypeCompiler, self).visit_enum(type_) else: return self._visit_enumerated_values("ENUM", type_, type_.enums) - def visit_BLOB(self, type_): + def visit_BLOB(self, type_, **kw): if type_.length: return "BLOB(%d)" % type_.length else: return "BLOB" - def visit_TINYBLOB(self, type_): + def visit_TINYBLOB(self, type_, **kw): return "TINYBLOB" - def visit_MEDIUMBLOB(self, type_): + def visit_MEDIUMBLOB(self, type_, **kw): return "MEDIUMBLOB" - def visit_LONGBLOB(self, type_): + def visit_LONGBLOB(self, type_, **kw): return "LONGBLOB" def _visit_enumerated_values(self, name, type_, enumerated_values): @@ -2204,15 +2278,15 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): name, ",".join(quoted_enums)) ) - def visit_ENUM(self, type_): + def visit_ENUM(self, type_, **kw): return self._visit_enumerated_values("ENUM", type_, type_._enumerated_values) - def visit_SET(self, type_): + def visit_SET(self, type_, **kw): return self._visit_enumerated_values("SET", type_, type_._enumerated_values) - def visit_BOOLEAN(self, type): + def visit_BOOLEAN(self, type, **kw): return "BOOL" @@ -2963,6 +3037,9 @@ class MySQLTableDefinitionParser(object): if issubclass(col_type, _EnumeratedValues): type_args = _EnumeratedValues._strip_values(type_args) + if issubclass(col_type, SET) and '' in type_args: + type_kw['retrieve_as_bitwise'] = True + type_instance = col_type(*type_args, **type_kw) col_args, col_kw = [], {} diff --git a/lib/sqlalchemy/dialects/mysql/gaerdbms.py b/lib/sqlalchemy/dialects/mysql/gaerdbms.py index 0059f5a65..284b51bde 100644 --- a/lib/sqlalchemy/dialects/mysql/gaerdbms.py +++ b/lib/sqlalchemy/dialects/mysql/gaerdbms.py @@ -17,6 +17,13 @@ developers-guide .. versionadded:: 0.7.8 + .. deprecated:: 1.0 This dialect is **no longer necessary** for + Google Cloud SQL; the MySQLdb dialect can be used directly. + Cloud SQL now recommends creating connections via the + mysql dialect using the URL format + + `mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/<projectid>:<instancename>` + Pooling ------- @@ -33,6 +40,7 @@ import os from .mysqldb import MySQLDialect_mysqldb from ...pool import NullPool import re +from sqlalchemy.util import warn_deprecated def _is_dev_environment(): @@ -43,6 +51,14 @@ class MySQLDialect_gaerdbms(MySQLDialect_mysqldb): @classmethod def dbapi(cls): + + warn_deprecated( + "Google Cloud SQL now recommends creating connections via the " + "MySQLdb dialect directly, using the URL format " + "mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/" + "<projectid>:<instancename>" + ) + # from django: # http://code.google.com/p/googleappengine/source/ # browse/trunk/python/google/storage/speckle/ diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index 73210d67a..929317467 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -39,6 +39,14 @@ MySQL-python version 1.2.2 has a serious memory leak related to unicode conversion, a feature which is disabled via ``use_unicode=0``. It is strongly advised to use the latest version of MySQL-Python. +Using MySQLdb with Google Cloud SQL +----------------------------------- + +Google Cloud SQL now recommends use of the MySQLdb dialect. Connect +using a URL like the following:: + + mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/<projectid>:<instancename> + """ from .base import (MySQLDialect, MySQLExecutionContext, @@ -77,7 +85,7 @@ class MySQLIdentifierPreparer_mysqldb(MySQLIdentifierPreparer): class MySQLDialect_mysqldb(MySQLDialect): driver = 'mysqldb' - supports_unicode_statements = False + supports_unicode_statements = True supports_sane_rowcount = True supports_sane_multi_rowcount = True @@ -102,12 +110,13 @@ class MySQLDialect_mysqldb(MySQLDialect): # https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8 # specific issue w/ the utf8_bin collation and unicode returns - has_utf8_bin = connection.scalar( - "show collation where %s = 'utf8' and %s = 'utf8_bin'" - % ( - self.identifier_preparer.quote("Charset"), - self.identifier_preparer.quote("Collation") - )) + has_utf8_bin = self.server_version_info > (5, ) and \ + connection.scalar( + "show collation where %s = 'utf8' and %s = 'utf8_bin'" + % ( + self.identifier_preparer.quote("Charset"), + self.identifier_preparer.quote("Collation") + )) if has_utf8_bin: additional_tests = [ sql.collate(sql.cast( diff --git a/lib/sqlalchemy/dialects/mysql/pymysql.py b/lib/sqlalchemy/dialects/mysql/pymysql.py index 31226cea0..8df2ba03f 100644 --- a/lib/sqlalchemy/dialects/mysql/pymysql.py +++ b/lib/sqlalchemy/dialects/mysql/pymysql.py @@ -31,8 +31,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb): driver = 'pymysql' description_encoding = None - if py3k: - supports_unicode_statements = True + supports_unicode_statements = True @classmethod def dbapi(cls): diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 6df38e57e..b482c9069 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -213,6 +213,8 @@ is reflected and the type is reported as ``DATE``, the time-supporting examining the type of column for use in special Python translations or for migrating schemas to other database backends. +.. _oracle_table_options: + Oracle Table Options ------------------------- @@ -228,15 +230,63 @@ in conjunction with the :class:`.Table` construct: .. versionadded:: 1.0.0 +* ``COMPRESS``:: + + Table('mytable', metadata, Column('data', String(32)), + oracle_compress=True) + + Table('mytable', metadata, Column('data', String(32)), + oracle_compress=6) + + The ``oracle_compress`` parameter accepts either an integer compression + level, or ``True`` to use the default compression level. + +.. versionadded:: 1.0.0 + +.. _oracle_index_options: + +Oracle Specific Index Options +----------------------------- + +Bitmap Indexes +~~~~~~~~~~~~~~ + +You can specify the ``oracle_bitmap`` parameter to create a bitmap index +instead of a B-tree index:: + + Index('my_index', my_table.c.data, oracle_bitmap=True) + +Bitmap indexes cannot be unique and cannot be compressed. SQLAlchemy will not +check for such limitations, only the database will. + +.. versionadded:: 1.0.0 + +Index compression +~~~~~~~~~~~~~~~~~ + +Oracle has a more efficient storage mode for indexes containing lots of +repeated values. Use the ``oracle_compress`` parameter to turn on key c +ompression:: + + Index('my_index', my_table.c.data, oracle_compress=True) + + Index('my_index', my_table.c.data1, my_table.c.data2, unique=True, + oracle_compress=1) + +The ``oracle_compress`` parameter accepts either an integer specifying the +number of prefix columns to compress, or ``True`` to use the default (all +columns for non-unique indexes, all but the last column for unique indexes). + +.. versionadded:: 1.0.0 + """ import re from sqlalchemy import util, sql -from sqlalchemy.engine import default, base, reflection +from sqlalchemy.engine import default, reflection from sqlalchemy.sql import compiler, visitors, expression -from sqlalchemy.sql import (operators as sql_operators, - functions as sql_functions) +from sqlalchemy.sql import operators as sql_operators from sqlalchemy import types as sqltypes, schema as sa_schema from sqlalchemy.types import VARCHAR, NVARCHAR, CHAR, \ BLOB, CLOB, TIMESTAMP, FLOAT @@ -407,19 +457,19 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): # Oracle does not allow milliseconds in DATE # Oracle does not support TIME columns - def visit_datetime(self, type_): - return self.visit_DATE(type_) + def visit_datetime(self, type_, **kw): + return self.visit_DATE(type_, **kw) - def visit_float(self, type_): - return self.visit_FLOAT(type_) + def visit_float(self, type_, **kw): + return self.visit_FLOAT(type_, **kw) - def visit_unicode(self, type_): + def visit_unicode(self, type_, **kw): if self.dialect._supports_nchar: - return self.visit_NVARCHAR2(type_) + return self.visit_NVARCHAR2(type_, **kw) else: - return self.visit_VARCHAR2(type_) + return self.visit_VARCHAR2(type_, **kw) - def visit_INTERVAL(self, type_): + def visit_INTERVAL(self, type_, **kw): return "INTERVAL DAY%s TO SECOND%s" % ( type_.day_precision is not None and "(%d)" % type_.day_precision or @@ -429,22 +479,22 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): "", ) - def visit_LONG(self, type_): + def visit_LONG(self, type_, **kw): return "LONG" - def visit_TIMESTAMP(self, type_): + def visit_TIMESTAMP(self, type_, **kw): if type_.timezone: return "TIMESTAMP WITH TIME ZONE" else: return "TIMESTAMP" - def visit_DOUBLE_PRECISION(self, type_): - return self._generate_numeric(type_, "DOUBLE PRECISION") + def visit_DOUBLE_PRECISION(self, type_, **kw): + return self._generate_numeric(type_, "DOUBLE PRECISION", **kw) def visit_NUMBER(self, type_, **kw): return self._generate_numeric(type_, "NUMBER", **kw) - def _generate_numeric(self, type_, name, precision=None, scale=None): + def _generate_numeric(self, type_, name, precision=None, scale=None, **kw): if precision is None: precision = type_.precision @@ -460,17 +510,17 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): n = "%(name)s(%(precision)s, %(scale)s)" return n % {'name': name, 'precision': precision, 'scale': scale} - def visit_string(self, type_): - return self.visit_VARCHAR2(type_) + def visit_string(self, type_, **kw): + return self.visit_VARCHAR2(type_, **kw) - def visit_VARCHAR2(self, type_): + def visit_VARCHAR2(self, type_, **kw): return self._visit_varchar(type_, '', '2') - def visit_NVARCHAR2(self, type_): + def visit_NVARCHAR2(self, type_, **kw): return self._visit_varchar(type_, 'N', '2') visit_NVARCHAR = visit_NVARCHAR2 - def visit_VARCHAR(self, type_): + def visit_VARCHAR(self, type_, **kw): return self._visit_varchar(type_, '', '') def _visit_varchar(self, type_, n, num): @@ -483,31 +533,31 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): varchar = "%(n)sVARCHAR%(two)s(%(length)s)" return varchar % {'length': type_.length, 'two': num, 'n': n} - def visit_text(self, type_): - return self.visit_CLOB(type_) + def visit_text(self, type_, **kw): + return self.visit_CLOB(type_, **kw) - def visit_unicode_text(self, type_): + def visit_unicode_text(self, type_, **kw): if self.dialect._supports_nchar: - return self.visit_NCLOB(type_) + return self.visit_NCLOB(type_, **kw) else: - return self.visit_CLOB(type_) + return self.visit_CLOB(type_, **kw) - def visit_large_binary(self, type_): - return self.visit_BLOB(type_) + def visit_large_binary(self, type_, **kw): + return self.visit_BLOB(type_, **kw) - def visit_big_integer(self, type_): - return self.visit_NUMBER(type_, precision=19) + def visit_big_integer(self, type_, **kw): + return self.visit_NUMBER(type_, precision=19, **kw) - def visit_boolean(self, type_): - return self.visit_SMALLINT(type_) + def visit_boolean(self, type_, **kw): + return self.visit_SMALLINT(type_, **kw) - def visit_RAW(self, type_): + def visit_RAW(self, type_, **kw): if type_.length: return "RAW(%(length)s)" % {'length': type_.length} else: return "RAW" - def visit_ROWID(self, type_): + def visit_ROWID(self, type_, **kw): return "ROWID" @@ -549,6 +599,9 @@ class OracleCompiler(compiler.SQLCompiler): def visit_false(self, expr, **kw): return '0' + def get_cte_preamble(self, recursive): + return "WITH" + def get_select_hint_text(self, byfroms): return " ".join( "/*+ %s */" % text for table, text in byfroms.items() @@ -619,22 +672,10 @@ class OracleCompiler(compiler.SQLCompiler): return (self.dialect.identifier_preparer.format_sequence(seq) + ".nextval") - def visit_alias(self, alias, asfrom=False, ashint=False, **kwargs): - """Oracle doesn't like ``FROM table AS alias``. Is the AS standard - SQL?? - """ - - if asfrom or ashint: - alias_name = isinstance(alias.name, expression._truncated_label) and \ - self._truncated_identifier("alias", alias.name) or alias.name + def get_render_as_alias_suffix(self, alias_name_text): + """Oracle doesn't like ``FROM table AS alias``""" - if ashint: - return alias_name - elif asfrom: - return self.process(alias.original, asfrom=asfrom, **kwargs) + \ - " " + self.preparer.format_alias(alias, alias_name) - else: - return self.process(alias.original, **kwargs) + return " " + alias_name_text def returning_clause(self, stmt, returning_cols): columns = [] @@ -795,9 +836,32 @@ class OracleDDLCompiler(compiler.DDLCompiler): return text - def visit_create_index(self, create, **kw): - return super(OracleDDLCompiler, self).\ - visit_create_index(create, include_schema=True) + def visit_create_index(self, create): + index = create.element + self._verify_index_table(index) + preparer = self.preparer + text = "CREATE " + if index.unique: + text += "UNIQUE " + if index.dialect_options['oracle']['bitmap']: + text += "BITMAP " + text += "INDEX %s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=True), + preparer.format_table(index.table, use_schema=True), + ', '.join( + self.sql_compiler.process( + expr, + include_table=False, literal_binds=True) + for expr in index.expressions) + ) + if index.dialect_options['oracle']['compress'] is not False: + if index.dialect_options['oracle']['compress'] is True: + text += " COMPRESS" + else: + text += " COMPRESS %d" % ( + index.dialect_options['oracle']['compress'] + ) + return text def post_create_table(self, table): table_opts = [] @@ -807,6 +871,14 @@ class OracleDDLCompiler(compiler.DDLCompiler): on_commit_options = opts['on_commit'].replace("_", " ").upper() table_opts.append('\n ON COMMIT %s' % on_commit_options) + if opts['compress']: + if opts['compress'] is True: + table_opts.append("\n COMPRESS") + else: + table_opts.append("\n COMPRESS FOR %s" % ( + opts['compress'] + )) + return ''.join(table_opts) @@ -870,7 +942,12 @@ class OracleDialect(default.DefaultDialect): construct_arguments = [ (sa_schema.Table, { "resolve_synonyms": False, - "on_commit": None + "on_commit": None, + "compress": False + }), + (sa_schema.Index, { + "bitmap": False, + "compress": False }) ] @@ -902,6 +979,16 @@ class OracleDialect(default.DefaultDialect): self.server_version_info < (9, ) @property + def _supports_table_compression(self): + return self.server_version_info and \ + self.server_version_info >= (9, 2, ) + + @property + def _supports_table_compress_for(self): + return self.server_version_info and \ + self.server_version_info >= (11, ) + + @property def _supports_char_length(self): return not self._is_oracle_8 @@ -1084,6 +1171,50 @@ class OracleDialect(default.DefaultDialect): return [self.normalize_name(row[0]) for row in cursor] @reflection.cache + def get_table_options(self, connection, table_name, schema=None, **kw): + options = {} + + resolve_synonyms = kw.get('oracle_resolve_synonyms', False) + dblink = kw.get('dblink', '') + info_cache = kw.get('info_cache') + + (table_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, table_name, schema, + resolve_synonyms, dblink, + info_cache=info_cache) + + params = {"table_name": table_name} + + columns = ["table_name"] + if self._supports_table_compression: + columns.append("compression") + if self._supports_table_compress_for: + columns.append("compress_for") + + text = "SELECT %(columns)s "\ + "FROM ALL_TABLES%(dblink)s "\ + "WHERE table_name = :table_name" + + if schema is not None: + params['owner'] = schema + text += " AND owner = :owner " + text = text % {'dblink': dblink, 'columns': ", ".join(columns)} + + result = connection.execute(sql.text(text), **params) + + enabled = dict(DISABLED=False, ENABLED=True) + + row = result.first() + if row: + if "compression" in row and enabled.get(row.compression, False): + if "compress_for" in row: + options['oracle_compress'] = row.compress_for + else: + options['oracle_compress'] = True + + return options + + @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): """ @@ -1168,7 +1299,8 @@ class OracleDialect(default.DefaultDialect): params = {'table_name': table_name} text = \ - "SELECT a.index_name, a.column_name, b.uniqueness "\ + "SELECT a.index_name, a.column_name, "\ + "\nb.index_type, b.uniqueness, b.compression, b.prefix_length "\ "\nFROM ALL_IND_COLUMNS%(dblink)s a, "\ "\nALL_INDEXES%(dblink)s b "\ "\nWHERE "\ @@ -1194,6 +1326,7 @@ class OracleDialect(default.DefaultDialect): dblink=dblink, info_cache=kw.get('info_cache')) pkeys = pk_constraint['constrained_columns'] uniqueness = dict(NONUNIQUE=False, UNIQUE=True) + enabled = dict(DISABLED=False, ENABLED=True) oracle_sys_col = re.compile(r'SYS_NC\d+\$', re.IGNORECASE) @@ -1213,10 +1346,15 @@ class OracleDialect(default.DefaultDialect): if rset.index_name != last_index_name: remove_if_primary_key(index) index = dict(name=self.normalize_name(rset.index_name), - column_names=[]) + column_names=[], dialect_options={}) indexes.append(index) index['unique'] = uniqueness.get(rset.uniqueness, False) + if rset.index_type in ('BITMAP', 'FUNCTION-BASED BITMAP'): + index['dialect_options']['oracle_bitmap'] = True + if enabled.get(rset.compression, False): + index['dialect_options']['oracle_compress'] = rset.prefix_length + # filter out Oracle SYS_NC names. could also do an outer join # to the all_tab_columns table and check for real col names there. if not oracle_sys_col.match(rset.column_name): diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index 1cff8e3a0..01a846314 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -5,7 +5,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from . import base, psycopg2, pg8000, pypostgresql, zxjdbc +from . import base, psycopg2, pg8000, pypostgresql, zxjdbc, psycopg2cffi base.dialect = psycopg2.dialect diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index baa640eaa..1935d0cad 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -48,7 +48,7 @@ Transaction Isolation Level --------------------------- All Postgresql dialects support setting of transaction isolation level -both via a dialect-specific parameter ``isolation_level`` +both via a dialect-specific parameter :paramref:`.create_engine.isolation_level` accepted by :func:`.create_engine`, as well as the ``isolation_level`` argument as passed to :meth:`.Connection.execution_options`. When using a non-psycopg2 dialect, @@ -266,7 +266,7 @@ will emit to the database:: The Postgresql text search functions such as ``to_tsquery()`` and ``to_tsvector()`` are available -explicitly using the standard :attr:`.func` construct. For example:: +explicitly using the standard :data:`.func` construct. For example:: select([ func.to_tsvector('fat cats ate rats').match('cat & rat') @@ -299,7 +299,7 @@ not re-compute the column on demand. In order to provide for this explicit query planning, or to use different search strategies, the ``match`` method accepts a ``postgresql_regconfig`` -keyword argument. +keyword argument:: select([mytable.c.id]).where( mytable.c.title.match('somestring', postgresql_regconfig='english') @@ -311,7 +311,7 @@ Emits the equivalent of:: WHERE mytable.title @@ to_tsquery('english', 'somestring') One can also specifically pass in a `'regconfig'` value to the -``to_tsvector()`` command as the initial argument. +``to_tsvector()`` command as the initial argument:: select([mytable.c.id]).where( func.to_tsvector('english', mytable.c.title )\ @@ -483,7 +483,7 @@ import re from ... import sql, schema, exc, util from ...engine import default, reflection -from ...sql import compiler, expression, operators +from ...sql import compiler, expression, operators, default_comparator from ... import types as sqltypes try: @@ -680,10 +680,10 @@ class _Slice(expression.ColumnElement): type = sqltypes.NULLTYPE def __init__(self, slice_, source_comparator): - self.start = source_comparator._check_literal( + self.start = default_comparator._check_literal( source_comparator.expr, operators.getitem, slice_.start) - self.stop = source_comparator._check_literal( + self.stop = default_comparator._check_literal( source_comparator.expr, operators.getitem, slice_.stop) @@ -876,8 +876,9 @@ class ARRAY(sqltypes.Concatenable, sqltypes.TypeEngine): index += shift_indexes return_type = self.type.item_type - return self._binary_operate(self.expr, operators.getitem, index, - result_type=return_type) + return default_comparator._binary_operate( + self.expr, operators.getitem, index, + result_type=return_type) def any(self, other, operator=operators.eq): """Return ``other operator ANY (array)`` clause. @@ -1424,7 +1425,8 @@ class PGDDLCompiler(compiler.DDLCompiler): else: colspec += " SERIAL" else: - colspec += " " + self.dialect.type_compiler.process(column.type) + colspec += " " + self.dialect.type_compiler.process(column.type, + type_expression=column) default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -1476,8 +1478,13 @@ class PGDDLCompiler(compiler.DDLCompiler): if not isinstance(expr, expression.ColumnClause) else expr, include_table=False, literal_binds=True) + - (c.key in ops and (' ' + ops[c.key]) or '') - for expr, c in zip(index.expressions, index.columns)]) + ( + (' ' + ops[expr.key]) + if hasattr(expr, 'key') + and expr.key in ops else '' + ) + for expr in index.expressions + ]) ) whereclause = index.dialect_options["postgresql"]["where"] @@ -1539,94 +1546,93 @@ class PGDDLCompiler(compiler.DDLCompiler): class PGTypeCompiler(compiler.GenericTypeCompiler): - - def visit_TSVECTOR(self, type): + def visit_TSVECTOR(self, type, **kw): return "TSVECTOR" - def visit_INET(self, type_): + def visit_INET(self, type_, **kw): return "INET" - def visit_CIDR(self, type_): + def visit_CIDR(self, type_, **kw): return "CIDR" - def visit_MACADDR(self, type_): + def visit_MACADDR(self, type_, **kw): return "MACADDR" - def visit_OID(self, type_): + def visit_OID(self, type_, **kw): return "OID" - def visit_FLOAT(self, type_): + def visit_FLOAT(self, type_, **kw): if not type_.precision: return "FLOAT" else: return "FLOAT(%(precision)s)" % {'precision': type_.precision} - def visit_DOUBLE_PRECISION(self, type_): + def visit_DOUBLE_PRECISION(self, type_, **kw): return "DOUBLE PRECISION" - def visit_BIGINT(self, type_): + def visit_BIGINT(self, type_, **kw): return "BIGINT" - def visit_HSTORE(self, type_): + def visit_HSTORE(self, type_, **kw): return "HSTORE" - def visit_JSON(self, type_): + def visit_JSON(self, type_, **kw): return "JSON" - def visit_JSONB(self, type_): + def visit_JSONB(self, type_, **kw): return "JSONB" - def visit_INT4RANGE(self, type_): + def visit_INT4RANGE(self, type_, **kw): return "INT4RANGE" - def visit_INT8RANGE(self, type_): + def visit_INT8RANGE(self, type_, **kw): return "INT8RANGE" - def visit_NUMRANGE(self, type_): + def visit_NUMRANGE(self, type_, **kw): return "NUMRANGE" - def visit_DATERANGE(self, type_): + def visit_DATERANGE(self, type_, **kw): return "DATERANGE" - def visit_TSRANGE(self, type_): + def visit_TSRANGE(self, type_, **kw): return "TSRANGE" - def visit_TSTZRANGE(self, type_): + def visit_TSTZRANGE(self, type_, **kw): return "TSTZRANGE" - def visit_datetime(self, type_): - return self.visit_TIMESTAMP(type_) + def visit_datetime(self, type_, **kw): + return self.visit_TIMESTAMP(type_, **kw) - def visit_enum(self, type_): + def visit_enum(self, type_, **kw): if not type_.native_enum or not self.dialect.supports_native_enum: - return super(PGTypeCompiler, self).visit_enum(type_) + return super(PGTypeCompiler, self).visit_enum(type_, **kw) else: - return self.visit_ENUM(type_) + return self.visit_ENUM(type_, **kw) - def visit_ENUM(self, type_): + def visit_ENUM(self, type_, **kw): return self.dialect.identifier_preparer.format_type(type_) - def visit_TIMESTAMP(self, type_): + def visit_TIMESTAMP(self, type_, **kw): return "TIMESTAMP%s %s" % ( getattr(type_, 'precision', None) and "(%d)" % type_.precision or "", (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE" ) - def visit_TIME(self, type_): + def visit_TIME(self, type_, **kw): return "TIME%s %s" % ( getattr(type_, 'precision', None) and "(%d)" % type_.precision or "", (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE" ) - def visit_INTERVAL(self, type_): + def visit_INTERVAL(self, type_, **kw): if type_.precision is not None: return "INTERVAL(%d)" % type_.precision else: return "INTERVAL" - def visit_BIT(self, type_): + def visit_BIT(self, type_, **kw): if type_.varying: compiled = "BIT VARYING" if type_.length is not None: @@ -1635,16 +1641,16 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): compiled = "BIT(%d)" % type_.length return compiled - def visit_UUID(self, type_): + def visit_UUID(self, type_, **kw): return "UUID" - def visit_large_binary(self, type_): - return self.visit_BYTEA(type_) + def visit_large_binary(self, type_, **kw): + return self.visit_BYTEA(type_, **kw) - def visit_BYTEA(self, type_): + def visit_BYTEA(self, type_, **kw): return "BYTEA" - def visit_ARRAY(self, type_): + def visit_ARRAY(self, type_, **kw): return self.process(type_.item_type) + ('[]' * (type_.dimensions if type_.dimensions is not None else 1)) @@ -1942,7 +1948,8 @@ class PGDialect(default.DefaultDialect): cursor = connection.execute( sql.text( "select relname from pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where n.nspname=current_schema() " + "n.oid=c.relnamespace where " + "pg_catalog.pg_table_is_visible(c.oid) " "and relname=:name", bindparams=[ sql.bindparam('name', util.text_type(table_name), diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py index 250bf5e9d..f38c4a56a 100644 --- a/lib/sqlalchemy/dialects/postgresql/json.py +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -12,7 +12,7 @@ from .base import ischema_names from ... import types as sqltypes from ...sql.operators import custom_op from ... import sql -from ...sql import elements +from ...sql import elements, default_comparator from ... import util __all__ = ('JSON', 'JSONElement', 'JSONB') @@ -46,7 +46,8 @@ class JSONElement(elements.BinaryExpression): self._json_opstring = opstring operator = custom_op(opstring, precedence=5) - right = left._check_literal(left, operator, right) + right = default_comparator._check_literal( + left, operator, right) super(JSONElement, self).__init__( left, right, operator, type_=result_type) @@ -77,7 +78,7 @@ class JSONElement(elements.BinaryExpression): def cast(self, type_): """Convert this :class:`.JSONElement` to apply both the 'astext' operator - as well as an explicit type cast when evaulated. + as well as an explicit type cast when evaluated. E.g.:: diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index f67b2e3b0..5246abf1c 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -66,12 +66,13 @@ in ``/tmp``, or whatever socket directory was specified when PostgreSQL was built. This value can be overridden by passing a pathname to psycopg2, using ``host`` as an additional keyword argument:: - create_engine("postgresql+psycopg2://user:password@/dbname?host=/var/lib/postgresql") + create_engine("postgresql+psycopg2://user:password@/dbname?\ +host=/var/lib/postgresql") See also: -`PQconnectdbParams <http://www.postgresql.org/docs/9.1/static\ -/libpq-connect.html#LIBPQ-PQCONNECTDBPARAMS>`_ +`PQconnectdbParams <http://www.postgresql.org/docs/9.1/static/\ +libpq-connect.html#LIBPQ-PQCONNECTDBPARAMS>`_ Per-Statement/Connection Execution Options ------------------------------------------- @@ -237,7 +238,7 @@ The psycopg2 dialect supports these constants for isolation level: * ``AUTOCOMMIT`` .. versionadded:: 0.8.2 support for AUTOCOMMIT isolation level when using - psycopg2. + psycopg2. .. seealso:: @@ -511,9 +512,19 @@ class PGDialect_psycopg2(PGDialect): import psycopg2 return psycopg2 + @classmethod + def _psycopg2_extensions(cls): + from psycopg2 import extensions + return extensions + + @classmethod + def _psycopg2_extras(cls): + from psycopg2 import extras + return extras + @util.memoized_property def _isolation_lookup(self): - from psycopg2 import extensions + extensions = self._psycopg2_extensions() return { 'AUTOCOMMIT': extensions.ISOLATION_LEVEL_AUTOCOMMIT, 'READ COMMITTED': extensions.ISOLATION_LEVEL_READ_COMMITTED, @@ -535,7 +546,8 @@ class PGDialect_psycopg2(PGDialect): connection.set_isolation_level(level) def on_connect(self): - from psycopg2 import extras, extensions + extras = self._psycopg2_extras() + extensions = self._psycopg2_extensions() fns = [] if self.client_encoding is not None: @@ -585,7 +597,7 @@ class PGDialect_psycopg2(PGDialect): @util.memoized_instancemethod def _hstore_oids(self, conn): if self.psycopg2_version >= (2, 4): - from psycopg2 import extras + extras = self._psycopg2_extras() oids = extras.HstoreAdapter.get_oids(conn) if oids is not None and oids[0]: return oids[0:2] diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py new file mode 100644 index 000000000..f5c475d90 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py @@ -0,0 +1,49 @@ +# testing/engines.py +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +""" +.. dialect:: postgresql+psycopg2cffi + :name: psycopg2cffi + :dbapi: psycopg2cffi + :connectstring: \ +postgresql+psycopg2cffi://user:password@host:port/dbname\ +[?key=value&key=value...] + :url: http://pypi.python.org/pypi/psycopg2cffi/ + +``psycopg2cffi`` is an adaptation of ``psycopg2``, using CFFI for the C +layer. This makes it suitable for use in e.g. PyPy. Documentation +is as per ``psycopg2``. + +.. versionadded:: 1.0.0 + +.. seealso:: + + :mod:`sqlalchemy.dialects.postgresql.psycopg2` + +""" +from .psycopg2 import PGDialect_psycopg2 + + +class PGDialect_psycopg2cffi(PGDialect_psycopg2): + driver = 'psycopg2cffi' + supports_unicode_statements = True + + @classmethod + def dbapi(cls): + return __import__('psycopg2cffi') + + @classmethod + def _psycopg2_extensions(cls): + root = __import__('psycopg2cffi', fromlist=['extensions']) + return root.extensions + + @classmethod + def _psycopg2_extras(cls): + root = __import__('psycopg2cffi', fromlist=['extras']) + return root.extras + + +dialect = PGDialect_psycopg2cffi diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 33003297c..1ed89bacb 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -9,6 +9,7 @@ .. dialect:: sqlite :name: SQLite +.. _sqlite_datetime: Date and Time Types ------------------- @@ -23,6 +24,20 @@ These types represent dates and times as ISO formatted strings, which also nicely support ordering. There's no reliance on typical "libc" internals for these functions so historical dates are fully supported. +Ensuring Text affinity +^^^^^^^^^^^^^^^^^^^^^^ + +The DDL rendered for these types is the standard ``DATE``, ``TIME`` +and ``DATETIME`` indicators. However, custom storage formats can also be +applied to these types. When the +storage format is detected as containing no alpha characters, the DDL for +these types is rendered as ``DATE_CHAR``, ``TIME_CHAR``, and ``DATETIME_CHAR``, +so that the column continues to have textual affinity. + +.. seealso:: + + `Type Affinity <http://www.sqlite.org/datatype3.html#affinity>`_ - in the SQLite documentation + .. _sqlite_autoincrement: SQLite Auto Incrementing Behavior @@ -92,8 +107,10 @@ The following subsections introduce areas that are impacted by SQLite's file-based architecture and additionally will usually require workarounds to work when using the pysqlite driver. +.. _sqlite_isolation_level: + Transaction Isolation Level -=========================== +---------------------------- SQLite supports "transaction isolation" in a non-standard way, along two axes. One is that of the `PRAGMA read_uncommitted <http://www.sqlite.org/pragma.html#pragma_read_uncommitted>`_ @@ -126,7 +143,7 @@ by *not even emitting BEGIN* until the first write operation. for techniques to work around this behavior. SAVEPOINT Support -================= +---------------------------- SQLite supports SAVEPOINTs, which only function once a transaction is begun. SQLAlchemy's SAVEPOINT support is available using the @@ -142,7 +159,7 @@ won't work at all with pysqlite unless workarounds are taken. for techniques to work around this behavior. Transactional DDL -================= +---------------------------- The SQLite database supports transactional :term:`DDL` as well. In this case, the pysqlite driver is not only failing to start transactions, @@ -186,6 +203,15 @@ new connections through the usage of events:: cursor.execute("PRAGMA foreign_keys=ON") cursor.close() +.. warning:: + + When SQLite foreign keys are enabled, it is **not possible** + to emit CREATE or DROP statements for tables that contain + mutually-dependent foreign key constraints; + to emit the DDL for these tables requires that ALTER TABLE be used to + create or drop these constraints separately, for which SQLite has + no support. + .. seealso:: `SQLite Foreign Key Support <http://www.sqlite.org/foreignkeys.html>`_ @@ -193,6 +219,9 @@ new connections through the usage of events:: :ref:`event_toplevel` - SQLAlchemy event API. + :ref:`use_alter` - more information on SQLAlchemy's facilities for handling + mutually-dependent foreign key constraints. + .. _sqlite_type_reflection: Type Reflection @@ -255,7 +284,7 @@ from ... import util from ...engine import default, reflection from ...sql import compiler -from ...types import (BLOB, BOOLEAN, CHAR, DATE, DECIMAL, FLOAT, +from ...types import (BLOB, BOOLEAN, CHAR, DECIMAL, FLOAT, INTEGER, REAL, NUMERIC, SMALLINT, TEXT, TIMESTAMP, VARCHAR) @@ -271,6 +300,25 @@ class _DateTimeMixin(object): if storage_format is not None: self._storage_format = storage_format + @property + def format_is_text_affinity(self): + """return True if the storage format will automatically imply + a TEXT affinity. + + If the storage format contains no non-numeric characters, + it will imply a NUMERIC storage format on SQLite; in this case, + the type will generate its DDL as DATE_CHAR, DATETIME_CHAR, + TIME_CHAR. + + .. versionadded:: 1.0.0 + + """ + spec = self._storage_format % { + "year": 0, "month": 0, "day": 0, "hour": 0, + "minute": 0, "second": 0, "microsecond": 0 + } + return bool(re.search(r'[^0-9]', spec)) + def adapt(self, cls, **kw): if issubclass(cls, _DateTimeMixin): if self._storage_format: @@ -526,7 +574,9 @@ ischema_names = { 'BOOLEAN': sqltypes.BOOLEAN, 'CHAR': sqltypes.CHAR, 'DATE': sqltypes.DATE, + 'DATE_CHAR': sqltypes.DATE, 'DATETIME': sqltypes.DATETIME, + 'DATETIME_CHAR': sqltypes.DATETIME, 'DOUBLE': sqltypes.FLOAT, 'DECIMAL': sqltypes.DECIMAL, 'FLOAT': sqltypes.FLOAT, @@ -537,6 +587,7 @@ ischema_names = { 'SMALLINT': sqltypes.SMALLINT, 'TEXT': sqltypes.TEXT, 'TIME': sqltypes.TIME, + 'TIME_CHAR': sqltypes.TIME, 'TIMESTAMP': sqltypes.TIMESTAMP, 'VARCHAR': sqltypes.VARCHAR, 'NVARCHAR': sqltypes.NVARCHAR, @@ -611,7 +662,8 @@ class SQLiteCompiler(compiler.SQLCompiler): class SQLiteDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): - coltype = self.dialect.type_compiler.process(column.type) + coltype = self.dialect.type_compiler.process( + column.type, type_expression=column) colspec = self.preparer.format_column(column) + " " + coltype default = self.get_column_default_string(column) if default is not None: @@ -667,9 +719,30 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): class SQLiteTypeCompiler(compiler.GenericTypeCompiler): - def visit_large_binary(self, type_): + def visit_large_binary(self, type_, **kw): return self.visit_BLOB(type_) + def visit_DATETIME(self, type_, **kw): + if not isinstance(type_, _DateTimeMixin) or \ + type_.format_is_text_affinity: + return super(SQLiteTypeCompiler, self).visit_DATETIME(type_) + else: + return "DATETIME_CHAR" + + def visit_DATE(self, type_, **kw): + if not isinstance(type_, _DateTimeMixin) or \ + type_.format_is_text_affinity: + return super(SQLiteTypeCompiler, self).visit_DATE(type_) + else: + return "DATE_CHAR" + + def visit_TIME(self, type_, **kw): + if not isinstance(type_, _DateTimeMixin) or \ + type_.format_is_text_affinity: + return super(SQLiteTypeCompiler, self).visit_TIME(type_) + else: + return "TIME_CHAR" + class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = set([ @@ -855,22 +928,9 @@ class SQLiteDialect(default.DefaultDialect): return [row[0] for row in rs] def has_table(self, connection, table_name, schema=None): - quote = self.identifier_preparer.quote_identifier - if schema is not None: - pragma = "PRAGMA %s." % quote(schema) - else: - pragma = "PRAGMA " - qtable = quote(table_name) - statement = "%stable_info(%s)" % (pragma, qtable) - cursor = _pragma_cursor(connection.execute(statement)) - row = cursor.fetchone() - - # consume remaining rows, to work around - # http://www.sqlite.org/cvstrac/tktview?tn=1884 - while not cursor.closed and cursor.fetchone() is not None: - pass - - return row is not None + info = self._get_table_pragma( + connection, "table_info", table_name, schema=schema) + return bool(info) @reflection.cache def get_view_names(self, connection, schema=None, **kw): @@ -912,18 +972,11 @@ class SQLiteDialect(default.DefaultDialect): @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): - quote = self.identifier_preparer.quote_identifier - if schema is not None: - pragma = "PRAGMA %s." % quote(schema) - else: - pragma = "PRAGMA " - qtable = quote(table_name) - statement = "%stable_info(%s)" % (pragma, qtable) - c = _pragma_cursor(connection.execute(statement)) + info = self._get_table_pragma( + connection, "table_info", table_name, schema=schema) - rows = c.fetchall() columns = [] - for row in rows: + for row in info: (name, type_, nullable, default, primary_key) = ( row[1], row[2].upper(), not row[3], row[4], row[5]) @@ -1010,92 +1063,192 @@ class SQLiteDialect(default.DefaultDialect): @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): - quote = self.identifier_preparer.quote_identifier - if schema is not None: - pragma = "PRAGMA %s." % quote(schema) - else: - pragma = "PRAGMA " - qtable = quote(table_name) - statement = "%sforeign_key_list(%s)" % (pragma, qtable) - c = _pragma_cursor(connection.execute(statement)) - fkeys = [] + # sqlite makes this *extremely difficult*. + # First, use the pragma to get the actual FKs. + pragma_fks = self._get_table_pragma( + connection, "foreign_key_list", + table_name, schema=schema + ) + fks = {} - while True: - row = c.fetchone() - if row is None: - break + + for row in pragma_fks: (numerical_id, rtbl, lcol, rcol) = ( row[0], row[2], row[3], row[4]) - self._parse_fk(fks, fkeys, numerical_id, rtbl, lcol, rcol) - return fkeys + if rcol is None: + rcol = lcol - def _parse_fk(self, fks, fkeys, numerical_id, rtbl, lcol, rcol): - # sqlite won't return rcol if the table was created with REFERENCES - # <tablename>, no col - if rcol is None: - rcol = lcol + if self._broken_fk_pragma_quotes: + rtbl = re.sub(r'^[\"\[`\']|[\"\]`\']$', '', rtbl) - if self._broken_fk_pragma_quotes: - rtbl = re.sub(r'^[\"\[`\']|[\"\]`\']$', '', rtbl) + if numerical_id in fks: + fk = fks[numerical_id] + else: + fk = fks[numerical_id] = { + 'name': None, + 'constrained_columns': [], + 'referred_schema': None, + 'referred_table': rtbl, + 'referred_columns': [], + } + fks[numerical_id] = fk - try: - fk = fks[numerical_id] - except KeyError: - fk = { - 'name': None, - 'constrained_columns': [], - 'referred_schema': None, - 'referred_table': rtbl, - 'referred_columns': [], - } - fkeys.append(fk) - fks[numerical_id] = fk - - if lcol not in fk['constrained_columns']: fk['constrained_columns'].append(lcol) - if rcol not in fk['referred_columns']: fk['referred_columns'].append(rcol) - return fk + + def fk_sig(constrained_columns, referred_table, referred_columns): + return tuple(constrained_columns) + (referred_table,) + \ + tuple(referred_columns) + + # then, parse the actual SQL and attempt to find DDL that matches + # the names as well. SQLite saves the DDL in whatever format + # it was typed in as, so need to be liberal here. + + keys_by_signature = dict( + ( + fk_sig( + fk['constrained_columns'], + fk['referred_table'], fk['referred_columns']), + fk + ) for fk in fks.values() + ) + + table_data = self._get_table_sql(connection, table_name, schema=schema) + if table_data is None: + # system tables, etc. + return [] + + def parse_fks(): + FK_PATTERN = ( + '(?:CONSTRAINT (\w+) +)?' + 'FOREIGN KEY *\( *(.+?) *\) +' + 'REFERENCES +(?:(?:"(.+?)")|([a-z0-9_]+)) *\((.+?)\)' + ) + + for match in re.finditer(FK_PATTERN, table_data, re.I): + ( + constraint_name, constrained_columns, + referred_quoted_name, referred_name, + referred_columns) = match.group(1, 2, 3, 4, 5) + constrained_columns = list( + self._find_cols_in_sig(constrained_columns)) + if not referred_columns: + referred_columns = constrained_columns + else: + referred_columns = list( + self._find_cols_in_sig(referred_columns)) + referred_name = referred_quoted_name or referred_name + yield ( + constraint_name, constrained_columns, + referred_name, referred_columns) + fkeys = [] + + for ( + constraint_name, constrained_columns, + referred_name, referred_columns) in parse_fks(): + sig = fk_sig( + constrained_columns, referred_name, referred_columns) + if sig not in keys_by_signature: + util.warn( + "WARNING: SQL-parsed foreign key constraint " + "'%s' could not be located in PRAGMA " + "foreign_keys for table %s" % ( + sig, + table_name + )) + continue + key = keys_by_signature.pop(sig) + key['name'] = constraint_name + fkeys.append(key) + # assume the remainders are the unnamed, inline constraints, just + # use them as is as it's extremely difficult to parse inline + # constraints + fkeys.extend(keys_by_signature.values()) + return fkeys + + def _find_cols_in_sig(self, sig): + for match in re.finditer(r'(?:"(.+?)")|([a-z0-9_]+)', sig, re.I): + yield match.group(1) or match.group(2) + + @reflection.cache + def get_unique_constraints(self, connection, table_name, + schema=None, **kw): + + auto_index_by_sig = {} + for idx in self.get_indexes( + connection, table_name, schema=schema, + include_auto_indexes=True, **kw): + if not idx['name'].startswith("sqlite_autoindex"): + continue + sig = tuple(idx['column_names']) + auto_index_by_sig[sig] = idx + + table_data = self._get_table_sql( + connection, table_name, schema=schema, **kw) + if not table_data: + return [] + + unique_constraints = [] + + def parse_uqs(): + UNIQUE_PATTERN = '(?:CONSTRAINT (\w+) +)?UNIQUE *\((.+?)\)' + INLINE_UNIQUE_PATTERN = ( + '(?:(".+?")|([a-z0-9]+)) ' + '+[a-z0-9_ ]+? +UNIQUE') + + for match in re.finditer(UNIQUE_PATTERN, table_data, re.I): + name, cols = match.group(1, 2) + yield name, list(self._find_cols_in_sig(cols)) + + # we need to match inlines as well, as we seek to differentiate + # a UNIQUE constraint from a UNIQUE INDEX, even though these + # are kind of the same thing :) + for match in re.finditer(INLINE_UNIQUE_PATTERN, table_data, re.I): + cols = list( + self._find_cols_in_sig(match.group(1) or match.group(2))) + yield None, cols + + for name, cols in parse_uqs(): + sig = tuple(cols) + if sig in auto_index_by_sig: + auto_index_by_sig.pop(sig) + parsed_constraint = { + 'name': name, + 'column_names': cols + } + unique_constraints.append(parsed_constraint) + # NOTE: auto_index_by_sig might not be empty here, + # the PRIMARY KEY may have an entry. + return unique_constraints @reflection.cache def get_indexes(self, connection, table_name, schema=None, **kw): - quote = self.identifier_preparer.quote_identifier - if schema is not None: - pragma = "PRAGMA %s." % quote(schema) - else: - pragma = "PRAGMA " - include_auto_indexes = kw.pop('include_auto_indexes', False) - qtable = quote(table_name) - statement = "%sindex_list(%s)" % (pragma, qtable) - c = _pragma_cursor(connection.execute(statement)) + pragma_indexes = self._get_table_pragma( + connection, "index_list", table_name, schema=schema) indexes = [] - while True: - row = c.fetchone() - if row is None: - break + + include_auto_indexes = kw.pop('include_auto_indexes', False) + for row in pragma_indexes: # ignore implicit primary key index. # http://www.mail-archive.com/sqlite-users@sqlite.org/msg30517.html - elif (not include_auto_indexes and - row[1].startswith('sqlite_autoindex')): + if (not include_auto_indexes and + row[1].startswith('sqlite_autoindex')): continue indexes.append(dict(name=row[1], column_names=[], unique=row[2])) + # loop thru unique indexes to get the column names. for idx in indexes: - statement = "%sindex_info(%s)" % (pragma, quote(idx['name'])) - c = connection.execute(statement) - cols = idx['column_names'] - while True: - row = c.fetchone() - if row is None: - break - cols.append(row[2]) + pragma_index = self._get_table_pragma( + connection, "index_info", idx['name']) + + for row in pragma_index: + idx['column_names'].append(row[2]) return indexes @reflection.cache - def get_unique_constraints(self, connection, table_name, - schema=None, **kw): + def _get_table_sql(self, connection, table_name, schema=None, **kw): try: s = ("SELECT sql FROM " " (SELECT * FROM sqlite_master UNION ALL " @@ -1107,27 +1260,22 @@ class SQLiteDialect(default.DefaultDialect): s = ("SELECT sql FROM sqlite_master WHERE name = '%s' " "AND type = 'table'") % table_name rs = connection.execute(s) - row = rs.fetchone() - if row is None: - # sqlite won't return the schema for the sqlite_master or - # sqlite_temp_master tables from this query. These tables - # don't have any unique constraints anyway. - return [] - table_data = row[0] - - UNIQUE_PATTERN = 'CONSTRAINT (\w+) UNIQUE \(([^\)]+)\)' - return [ - {'name': name, - 'column_names': [col.strip(' "') for col in cols.split(',')]} - for name, cols in re.findall(UNIQUE_PATTERN, table_data) - ] + return rs.scalar() - -def _pragma_cursor(cursor): - """work around SQLite issue whereby cursor.description - is blank when PRAGMA returns no rows.""" - - if cursor.closed: - cursor.fetchone = lambda: None - cursor.fetchall = lambda: [] - return cursor + def _get_table_pragma(self, connection, pragma, table_name, schema=None): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + statement = "PRAGMA %s." % quote(schema) + else: + statement = "PRAGMA " + qtable = quote(table_name) + statement = "%s%s(%s)" % (statement, pragma, qtable) + cursor = connection.execute(statement) + if not cursor.closed: + # work around SQLite issue whereby cursor.description + # is blank when PRAGMA returns no rows: + # http://www.sqlite.org/cvstrac/tktview?tn=1884 + result = cursor.fetchall() + else: + result = [] + return result diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py index f65a76a27..369420358 100644 --- a/lib/sqlalchemy/dialects/sybase/base.py +++ b/lib/sqlalchemy/dialects/sybase/base.py @@ -146,40 +146,40 @@ class IMAGE(sqltypes.LargeBinary): class SybaseTypeCompiler(compiler.GenericTypeCompiler): - def visit_large_binary(self, type_): + def visit_large_binary(self, type_, **kw): return self.visit_IMAGE(type_) - def visit_boolean(self, type_): + def visit_boolean(self, type_, **kw): return self.visit_BIT(type_) - def visit_unicode(self, type_): + def visit_unicode(self, type_, **kw): return self.visit_NVARCHAR(type_) - def visit_UNICHAR(self, type_): + def visit_UNICHAR(self, type_, **kw): return "UNICHAR(%d)" % type_.length - def visit_UNIVARCHAR(self, type_): + def visit_UNIVARCHAR(self, type_, **kw): return "UNIVARCHAR(%d)" % type_.length - def visit_UNITEXT(self, type_): + def visit_UNITEXT(self, type_, **kw): return "UNITEXT" - def visit_TINYINT(self, type_): + def visit_TINYINT(self, type_, **kw): return "TINYINT" - def visit_IMAGE(self, type_): + def visit_IMAGE(self, type_, **kw): return "IMAGE" - def visit_BIT(self, type_): + def visit_BIT(self, type_, **kw): return "BIT" - def visit_MONEY(self, type_): + def visit_MONEY(self, type_, **kw): return "MONEY" - def visit_SMALLMONEY(self, type_): + def visit_SMALLMONEY(self, type_, **kw): return "SMALLMONEY" - def visit_UNIQUEIDENTIFIER(self, type_): + def visit_UNIQUEIDENTIFIER(self, type_, **kw): return "UNIQUEIDENTIFIER" ischema_names = { @@ -377,7 +377,8 @@ class SybaseSQLCompiler(compiler.SQLCompiler): class SybaseDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) + " " + \ - self.dialect.type_compiler.process(column.type) + self.dialect.type_compiler.process( + column.type, type_expression=column) if column.table is None: raise exc.CompileError( diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index cf75871bf..f512e260a 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -72,6 +72,7 @@ from .base import ( ) from .result import ( + BaseRowProxy, BufferedColumnResultProxy, BufferedColumnRow, BufferedRowResultProxy, @@ -256,9 +257,19 @@ def create_engine(*args, **kwargs): Behavior here varies per backend, and individual dialects should be consulted directly. + Note that the isolation level can also be set on a per-:class:`.Connection` + basis as well, using the + :paramref:`.Connection.execution_options.isolation_level` + feature. + .. seealso:: - :ref:`SQLite Concurrency <sqlite_concurrency>` + :attr:`.Connection.default_isolation_level` - view default level + + :paramref:`.Connection.execution_options.isolation_level` + - set per :class:`.Connection` isolation level + + :ref:`SQLite Transaction Isolation <sqlite_isolation_level>` :ref:`Postgresql Transaction Isolation <postgresql_isolation_level>` diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index dd82be1d1..8d816b7fd 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -201,14 +201,19 @@ class Connection(Connectable): used by the ORM internally supersedes a cache dictionary specified here. - :param isolation_level: Available on: Connection. + :param isolation_level: Available on: :class:`.Connection`. Set the transaction isolation level for - the lifespan of this connection. Valid values include - those string values accepted by the ``isolation_level`` - parameter passed to :func:`.create_engine`, and are - database specific, including those for :ref:`sqlite_toplevel`, - :ref:`postgresql_toplevel` - see those dialect's documentation - for further info. + the lifespan of this :class:`.Connection` object (*not* the + underyling DBAPI connection, for which the level is reset + to its original setting upon termination of this + :class:`.Connection` object). + + Valid values include + those string values accepted by the + :paramref:`.create_engine.isolation_level` + parameter passed to :func:`.create_engine`. These levels are + semi-database specific; see individual dialect documentation for + valid levels. Note that this option necessarily affects the underlying DBAPI connection for the lifespan of the originating @@ -217,6 +222,20 @@ class Connection(Connectable): is returned to the connection pool, i.e. the :meth:`.Connection.close` method is called. + .. seealso:: + + :paramref:`.create_engine.isolation_level` + - set per :class:`.Engine` isolation level + + :meth:`.Connection.get_isolation_level` - view current level + + :ref:`SQLite Transaction Isolation <sqlite_isolation_level>` + + :ref:`Postgresql Transaction Isolation <postgresql_isolation_level>` + + :ref:`MySQL Transaction Isolation <mysql_isolation_level>` + + :param no_parameters: When ``True``, if the final parameter list or dictionary is totally empty, will invoke the statement on the cursor as ``cursor.execute(statement)``, @@ -260,23 +279,97 @@ class Connection(Connectable): @property def connection(self): - "The underlying DB-API connection managed by this Connection." + """The underlying DB-API connection managed by this Connection. + + .. seealso:: + + + :ref:`dbapi_connections` + + """ try: return self.__connection except AttributeError: - return self._revalidate_connection() + try: + return self._revalidate_connection() + except Exception as e: + self._handle_dbapi_exception(e, None, None, None, None) + + def get_isolation_level(self): + """Return the current isolation level assigned to this + :class:`.Connection`. + + This will typically be the default isolation level as determined + by the dialect, unless if the + :paramref:`.Connection.execution_options.isolation_level` + feature has been used to alter the isolation level on a + per-:class:`.Connection` basis. + + This attribute will typically perform a live SQL operation in order + to procure the current isolation level, so the value returned is the + actual level on the underlying DBAPI connection regardless of how + this state was set. Compare to the + :attr:`.Connection.default_isolation_level` accessor + which returns the dialect-level setting without performing a SQL + query. + + .. versionadded:: 0.9.9 + + .. seealso:: + + :attr:`.Connection.default_isolation_level` - view default level + + :paramref:`.create_engine.isolation_level` + - set per :class:`.Engine` isolation level + + :paramref:`.Connection.execution_options.isolation_level` + - set per :class:`.Connection` isolation level + + """ + try: + return self.dialect.get_isolation_level(self.connection) + except Exception as e: + self._handle_dbapi_exception(e, None, None, None, None) + + @property + def default_isolation_level(self): + """The default isolation level assigned to this :class:`.Connection`. + + This is the isolation level setting that the :class:`.Connection` + has when first procured via the :meth:`.Engine.connect` method. + This level stays in place until the + :paramref:`.Connection.execution_options.isolation_level` is used + to change the setting on a per-:class:`.Connection` basis. + + Unlike :meth:`.Connection.get_isolation_level`, this attribute is set + ahead of time from the first connection procured by the dialect, + so SQL query is not invoked when this accessor is called. + + .. versionadded:: 0.9.9 + + .. seealso:: + + :meth:`.Connection.get_isolation_level` - view current level + + :paramref:`.create_engine.isolation_level` + - set per :class:`.Engine` isolation level + + :paramref:`.Connection.execution_options.isolation_level` + - set per :class:`.Connection` isolation level + + """ + return self.dialect.default_isolation_level def _revalidate_connection(self): if self.__branch_from: return self.__branch_from._revalidate_connection() - if self.__can_reconnect and self.__invalid: if self.__transaction is not None: raise exc.InvalidRequestError( "Can't reconnect until invalid " "transaction is rolled back") - self.__connection = self.engine.raw_connection() + self.__connection = self.engine.raw_connection(_connection=self) self.__invalid = False return self.__connection raise exc.ResourceClosedError("This Connection is closed") @@ -741,7 +834,7 @@ class Connection(Connectable): a subclass of :class:`.Executable`, such as a :func:`~.expression.select` construct * a :class:`.FunctionElement`, such as that generated - by :attr:`.func`, will be automatically wrapped in + by :data:`.func`, will be automatically wrapped in a SELECT statement, which is then executed. * a :class:`.DDLElement` object * a :class:`.DefaultGenerator` object @@ -959,9 +1052,10 @@ class Connection(Connectable): context = constructor(dialect, self, conn, *args) except Exception as e: - self._handle_dbapi_exception(e, - util.text_type(statement), parameters, - None, None) + self._handle_dbapi_exception( + e, + util.text_type(statement), parameters, + None, None) if context.compiled: context.pre_exec() @@ -985,36 +1079,39 @@ class Connection(Connectable): "%r", sql_util._repr_params(parameters, batches=10) ) + + evt_handled = False try: if context.executemany: - for fn in () if not self.dialect._has_events \ - else self.dialect.dispatch.do_executemany: - if fn(cursor, statement, parameters, context): - break - else: + if self.dialect._has_events: + for fn in self.dialect.dispatch.do_executemany: + if fn(cursor, statement, parameters, context): + evt_handled = True + break + if not evt_handled: self.dialect.do_executemany( cursor, statement, parameters, context) - elif not parameters and context.no_parameters: - for fn in () if not self.dialect._has_events \ - else self.dialect.dispatch.do_execute_no_params: - if fn(cursor, statement, context): - break - else: + if self.dialect._has_events: + for fn in self.dialect.dispatch.do_execute_no_params: + if fn(cursor, statement, context): + evt_handled = True + break + if not evt_handled: self.dialect.do_execute_no_params( cursor, statement, context) - else: - for fn in () if not self.dialect._has_events \ - else self.dialect.dispatch.do_execute: - if fn(cursor, statement, parameters, context): - break - else: + if self.dialect._has_events: + for fn in self.dialect.dispatch.do_execute: + if fn(cursor, statement, parameters, context): + evt_handled = True + break + if not evt_handled: self.dialect.do_execute( cursor, statement, @@ -1038,31 +1135,12 @@ class Connection(Connectable): if context.compiled: context.post_exec() - if context.isinsert and not context.executemany: - context.post_insert() - - # create a resultproxy, get rowcount/implicit RETURNING - # rows, close cursor if no further results pending - result = context.get_result_proxy() - if context.isinsert: - if context._is_implicit_returning: - context._fetch_implicit_returning(result) - result.close(_autoclose_connection=False) - result._metadata = None - elif not context._is_explicit_returning: + if context.is_crud: + result = context._setup_crud_result_proxy() + else: + result = context.get_result_proxy() + if result._metadata is None: result.close(_autoclose_connection=False) - result._metadata = None - elif context.isupdate and context._is_implicit_returning: - context._fetch_implicit_update_returning(result) - result.close(_autoclose_connection=False) - result._metadata = None - - elif result._metadata is None: - # no results, get rowcount - # (which requires open cursor on some drivers - # such as kintersbasdb, mxodbc), - result.rowcount - result.close(_autoclose_connection=False) if context.should_autocommit and self._root.__transaction is None: self._root._commit_impl(autocommit=True) @@ -1149,7 +1227,10 @@ class Connection(Connectable): self._is_disconnect = \ isinstance(e, self.dialect.dbapi.Error) and \ not self.closed and \ - self.dialect.is_disconnect(e, self.__connection, cursor) + self.dialect.is_disconnect( + e, + self.__connection if not self.invalidated else None, + cursor) if context: context.is_disconnect = self._is_disconnect @@ -1194,7 +1275,8 @@ class Connection(Connectable): # new handle_error event ctx = ExceptionContextImpl( - e, sqlalchemy_exception, self, cursor, statement, + e, sqlalchemy_exception, self.engine, + self, cursor, statement, parameters, context, self._is_disconnect) for fn in self.dispatch.handle_error: @@ -1236,12 +1318,65 @@ class Connection(Connectable): del self._reentrant_error if self._is_disconnect: del self._is_disconnect - dbapi_conn_wrapper = self.connection - self.engine.pool._invalidate(dbapi_conn_wrapper, e) - self.invalidate(e) + if not self.invalidated: + dbapi_conn_wrapper = self.__connection + self.engine.pool._invalidate(dbapi_conn_wrapper, e) + self.invalidate(e) if self.should_close_with_result: self.close() + @classmethod + def _handle_dbapi_exception_noconnection(cls, e, dialect, engine): + + exc_info = sys.exc_info() + + is_disconnect = dialect.is_disconnect(e, None, None) + + should_wrap = isinstance(e, dialect.dbapi.Error) + + if should_wrap: + sqlalchemy_exception = exc.DBAPIError.instance( + None, + None, + e, + dialect.dbapi.Error, + connection_invalidated=is_disconnect) + else: + sqlalchemy_exception = None + + newraise = None + + if engine._has_events: + ctx = ExceptionContextImpl( + e, sqlalchemy_exception, engine, None, None, None, + None, None, is_disconnect) + for fn in engine.dispatch.handle_error: + try: + # handler returns an exception; + # call next handler in a chain + per_fn = fn(ctx) + if per_fn is not None: + ctx.chained_exception = newraise = per_fn + except Exception as _raised: + # handler raises an exception - stop processing + newraise = _raised + break + + if sqlalchemy_exception and \ + is_disconnect != ctx.is_disconnect: + sqlalchemy_exception.connection_invalidated = \ + is_disconnect = ctx.is_disconnect + + if newraise: + util.raise_from_cause(newraise, exc_info) + elif should_wrap: + util.raise_from_cause( + sqlalchemy_exception, + exc_info + ) + else: + util.reraise(*exc_info) + def default_schema_name(self): return self.engine.dialect.get_default_schema_name(self) @@ -1320,8 +1455,9 @@ class ExceptionContextImpl(ExceptionContext): """Implement the :class:`.ExceptionContext` interface.""" def __init__(self, exception, sqlalchemy_exception, - connection, cursor, statement, parameters, + engine, connection, cursor, statement, parameters, context, is_disconnect): + self.engine = engine self.connection = connection self.sqlalchemy_exception = sqlalchemy_exception self.original_exception = exception @@ -1865,10 +2001,11 @@ class Engine(Connectable, log.Identified): """ - return self._connection_cls(self, - self.pool.connect(), - close_with_result=close_with_result, - **kwargs) + return self._connection_cls( + self, + self._wrap_pool_connect(self.pool.connect, None), + close_with_result=close_with_result, + **kwargs) def table_names(self, schema=None, connection=None): """Return a list of all table names available in the database. @@ -1898,7 +2035,18 @@ class Engine(Connectable, log.Identified): """ return self.run_callable(self.dialect.has_table, table_name, schema) - def raw_connection(self): + def _wrap_pool_connect(self, fn, connection): + dialect = self.dialect + try: + return fn() + except dialect.dbapi.Error as e: + if connection is None: + Connection._handle_dbapi_exception_noconnection( + e, dialect, self) + else: + util.reraise(*sys.exc_info()) + + def raw_connection(self, _connection=None): """Return a "raw" DBAPI connection from the connection pool. The returned object is a proxied version of the DBAPI @@ -1909,13 +2057,18 @@ class Engine(Connectable, log.Identified): for real. This method provides direct DBAPI connection access for - special situations. In most situations, the :class:`.Connection` - object should be used, which is procured using the - :meth:`.Engine.connect` method. + special situations when the API provided by :class:`.Connection` + is not needed. When a :class:`.Connection` object is already + present, the DBAPI connection is available using + the :attr:`.Connection.connection` accessor. - """ + .. seealso:: - return self.pool.unique_connection() + :ref:`dbapi_connections` + + """ + return self._wrap_pool_connect( + self.pool.unique_connection, _connection) class OptionEngine(Engine): diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index a5af6ff19..f6c2263b3 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -452,14 +452,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext): isinsert = False isupdate = False isdelete = False + is_crud = False isddl = False executemany = False result_map = None compiled = None statement = None - postfetch_cols = None - prefetch_cols = None - returning_cols = None _is_implicit_returning = False _is_explicit_returning = False @@ -515,10 +513,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if not compiled.can_execute: raise exc.ArgumentError("Not an executable clause") - self.execution_options = compiled.statement._execution_options - if connection._execution_options: - self.execution_options = dict(self.execution_options) - self.execution_options.update(connection._execution_options) + self.execution_options = compiled.statement._execution_options.union( + connection._execution_options) # compiled clauseelement. process bind params, process table defaults, # track collections used by ResultProxy to target and process results @@ -548,6 +544,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.cursor = self.create_cursor() if self.isinsert or self.isupdate or self.isdelete: + self.is_crud = True self._is_explicit_returning = bool(compiled.statement._returning) self._is_implicit_returning = bool( compiled.returning and not compiled.statement._returning) @@ -681,10 +678,6 @@ class DefaultExecutionContext(interfaces.ExecutionContext): return self.execution_options.get("no_parameters", False) @util.memoized_property - def is_crud(self): - return self.isinsert or self.isupdate or self.isdelete - - @util.memoized_property def should_autocommit(self): autocommit = self.execution_options.get('autocommit', not self.compiled and @@ -799,52 +792,84 @@ class DefaultExecutionContext(interfaces.ExecutionContext): def supports_sane_multi_rowcount(self): return self.dialect.supports_sane_multi_rowcount - def post_insert(self): - + def _setup_crud_result_proxy(self): + if self.isinsert and \ + not self.executemany: + if not self._is_implicit_returning and \ + not self.compiled.inline and \ + self.dialect.postfetch_lastrowid: + + self._setup_ins_pk_from_lastrowid() + + elif not self._is_implicit_returning: + self._setup_ins_pk_from_empty() + + result = self.get_result_proxy() + + if self.isinsert: + if self._is_implicit_returning: + row = result.fetchone() + self.returned_defaults = row + self._setup_ins_pk_from_implicit_returning(row) + result.close(_autoclose_connection=False) + result._metadata = None + elif not self._is_explicit_returning: + result.close(_autoclose_connection=False) + result._metadata = None + elif self.isupdate and self._is_implicit_returning: + row = result.fetchone() + self.returned_defaults = row + result.close(_autoclose_connection=False) + result._metadata = None + + elif result._metadata is None: + # no results, get rowcount + # (which requires open cursor on some drivers + # such as kintersbasdb, mxodbc) + result.rowcount + result.close(_autoclose_connection=False) + return result + + def _setup_ins_pk_from_lastrowid(self): key_getter = self.compiled._key_getters_for_crud_column[2] table = self.compiled.statement.table + compiled_params = self.compiled_parameters[0] + + lastrowid = self.get_lastrowid() + autoinc_col = table._autoincrement_column + if autoinc_col is not None: + # apply type post processors to the lastrowid + proc = autoinc_col.type._cached_result_processor( + self.dialect, None) + if proc is not None: + lastrowid = proc(lastrowid) + self.inserted_primary_key = [ + lastrowid if c is autoinc_col else + compiled_params.get(key_getter(c), None) + for c in table.primary_key + ] - if not self._is_implicit_returning and \ - not self._is_explicit_returning and \ - not self.compiled.inline and \ - self.dialect.postfetch_lastrowid: - - lastrowid = self.get_lastrowid() - autoinc_col = table._autoincrement_column - if autoinc_col is not None: - # apply type post processors to the lastrowid - proc = autoinc_col.type._cached_result_processor( - self.dialect, None) - if proc is not None: - lastrowid = proc(lastrowid) - self.inserted_primary_key = [ - lastrowid if c is autoinc_col else - self.compiled_parameters[0].get(key_getter(c), None) - for c in table.primary_key - ] - else: - self.inserted_primary_key = [ - self.compiled_parameters[0].get(key_getter(c), None) - for c in table.primary_key - ] - - def _fetch_implicit_returning(self, resultproxy): + def _setup_ins_pk_from_empty(self): + key_getter = self.compiled._key_getters_for_crud_column[2] table = self.compiled.statement.table - row = resultproxy.fetchone() - - ipk = [] - for c, v in zip(table.primary_key, self.inserted_primary_key): - if v is not None: - ipk.append(v) - else: - ipk.append(row[c]) + compiled_params = self.compiled_parameters[0] + self.inserted_primary_key = [ + compiled_params.get(key_getter(c), None) + for c in table.primary_key + ] - self.inserted_primary_key = ipk - self.returned_defaults = row + def _setup_ins_pk_from_implicit_returning(self, row): + key_getter = self.compiled._key_getters_for_crud_column[2] + table = self.compiled.statement.table + compiled_params = self.compiled_parameters[0] - def _fetch_implicit_update_returning(self, resultproxy): - row = resultproxy.fetchone() - self.returned_defaults = row + self.inserted_primary_key = [ + row[col] if value is None else value + for col, value in [ + (col, compiled_params.get(key_getter(col), None)) + for col in table.primary_key + ] + ] def lastrow_has_defaults(self): return (self.isinsert or self.isupdate) and \ @@ -956,14 +981,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext): def _process_executesingle_defaults(self): key_getter = self.compiled._key_getters_for_crud_column[2] - prefetch = self.compiled.prefetch self.current_parameters = compiled_parameters = \ self.compiled_parameters[0] for c in prefetch: if self.isinsert: - val = self.get_insert_default(c) + if c.default and \ + not c.default.is_sequence and c.default.is_scalar: + val = c.default.arg + else: + val = self.get_insert_default(c) else: val = self.get_update_default(c) @@ -972,6 +1000,4 @@ class DefaultExecutionContext(interfaces.ExecutionContext): del self.current_parameters - - DefaultDialect.execution_ctx_cls = DefaultExecutionContext diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 0ad2efae0..5f0d74328 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -654,17 +654,82 @@ class Dialect(object): return None def reset_isolation_level(self, dbapi_conn): - """Given a DBAPI connection, revert its isolation to the default.""" + """Given a DBAPI connection, revert its isolation to the default. + + Note that this is a dialect-level method which is used as part + of the implementation of the :class:`.Connection` and + :class:`.Engine` + isolation level facilities; these APIs should be preferred for + most typical use cases. + + .. seealso:: + + :meth:`.Connection.get_isolation_level` - view current level + + :attr:`.Connection.default_isolation_level` - view default level + + :paramref:`.Connection.execution_options.isolation_level` - + set per :class:`.Connection` isolation level + + :paramref:`.create_engine.isolation_level` - + set per :class:`.Engine` isolation level + + """ raise NotImplementedError() def set_isolation_level(self, dbapi_conn, level): - """Given a DBAPI connection, set its isolation level.""" + """Given a DBAPI connection, set its isolation level. + + Note that this is a dialect-level method which is used as part + of the implementation of the :class:`.Connection` and + :class:`.Engine` + isolation level facilities; these APIs should be preferred for + most typical use cases. + + .. seealso:: + + :meth:`.Connection.get_isolation_level` - view current level + + :attr:`.Connection.default_isolation_level` - view default level + + :paramref:`.Connection.execution_options.isolation_level` - + set per :class:`.Connection` isolation level + + :paramref:`.create_engine.isolation_level` - + set per :class:`.Engine` isolation level + + """ raise NotImplementedError() def get_isolation_level(self, dbapi_conn): - """Given a DBAPI connection, return its isolation level.""" + """Given a DBAPI connection, return its isolation level. + + When working with a :class:`.Connection` object, the corresponding + DBAPI connection may be procured using the + :attr:`.Connection.connection` accessor. + + Note that this is a dialect-level method which is used as part + of the implementation of the :class:`.Connection` and + :class:`.Engine` isolation level facilities; + these APIs should be preferred for most typical use cases. + + + .. seealso:: + + :meth:`.Connection.get_isolation_level` - view current level + + :attr:`.Connection.default_isolation_level` - view default level + + :paramref:`.Connection.execution_options.isolation_level` - + set per :class:`.Connection` isolation level + + :paramref:`.create_engine.isolation_level` - + set per :class:`.Engine` isolation level + + + """ raise NotImplementedError() @@ -917,7 +982,23 @@ class ExceptionContext(object): connection = None """The :class:`.Connection` in use during the exception. - This member is always present. + This member is present, except in the case of a failure when + first connecting. + + .. seealso:: + + :attr:`.ExceptionContext.engine` + + + """ + + engine = None + """The :class:`.Engine` in use during the exception. + + This member should always be present, even in the case of a failure + when first connecting. + + .. versionadded:: 1.0.0 """ diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index 2a1def86a..6e102aad6 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -173,7 +173,14 @@ class Inspector(object): passed as ``None``. For special quoting, use :class:`.quoted_name`. :param order_by: Optional, may be the string "foreign_key" to sort - the result on foreign key dependencies. + the result on foreign key dependencies. Does not automatically + resolve cycles, and will raise :class:`.CircularDependencyError` + if cycles exist. + + .. deprecated:: 1.0.0 - see + :meth:`.Inspector.get_sorted_table_and_fkc_names` for a version + of this which resolves foreign key cycles between tables + automatically. .. versionchanged:: 0.8 the "foreign_key" sorting sorts tables in order of dependee to dependent; that is, in creation @@ -183,6 +190,8 @@ class Inspector(object): .. seealso:: + :meth:`.Inspector.get_sorted_table_and_fkc_names` + :attr:`.MetaData.sorted_tables` """ @@ -201,6 +210,64 @@ class Inspector(object): tnames = list(topological.sort(tuples, tnames)) return tnames + def get_sorted_table_and_fkc_names(self, schema=None): + """Return dependency-sorted table and foreign key constraint names in + referred to within a particular schema. + + This will yield 2-tuples of + ``(tablename, [(tname, fkname), (tname, fkname), ...])`` + consisting of table names in CREATE order grouped with the foreign key + constraint names that are not detected as belonging to a cycle. + The final element + will be ``(None, [(tname, fkname), (tname, fkname), ..])`` + which will consist of remaining + foreign key constraint names that would require a separate CREATE + step after-the-fact, based on dependencies between tables. + + .. versionadded:: 1.0.- + + .. seealso:: + + :meth:`.Inspector.get_table_names` + + :func:`.sort_tables_and_constraints` - similar method which works + with an already-given :class:`.MetaData`. + + """ + if hasattr(self.dialect, 'get_table_names'): + tnames = self.dialect.get_table_names( + self.bind, schema, info_cache=self.info_cache) + else: + tnames = self.engine.table_names(schema) + + tuples = set() + remaining_fkcs = set() + + fknames_for_table = {} + for tname in tnames: + fkeys = self.get_foreign_keys(tname, schema) + fknames_for_table[tname] = set( + [fk['name'] for fk in fkeys] + ) + for fkey in fkeys: + if tname != fkey['referred_table']: + tuples.add((fkey['referred_table'], tname)) + try: + candidate_sort = list(topological.sort(tuples, tnames)) + except exc.CircularDependencyError as err: + for edge in err.edges: + tuples.remove(edge) + remaining_fkcs.update( + (edge[1], fkc) + for fkc in fknames_for_table[edge[1]] + ) + + candidate_sort = list(topological.sort(tuples, tnames)) + return [ + (tname, fknames_for_table[tname].difference(remaining_fkcs)) + for tname in candidate_sort + ] + [(None, list(remaining_fkcs))] + def get_temp_table_names(self): """return a list of temporary table names for the current bind. @@ -394,6 +461,12 @@ class Inspector(object): unique boolean + dialect_options + dict of dialect-specific index options. May not be present + for all dialects. + + .. versionadded:: 1.0.0 + :param table_name: string name of the table. For special quoting, use :class:`.quoted_name`. @@ -642,6 +715,8 @@ class Inspector(object): columns = index_d['column_names'] unique = index_d['unique'] flavor = index_d.get('type', 'index') + dialect_options = index_d.get('dialect_options', {}) + duplicates = index_d.get('duplicates_constraint') if include_columns and \ not set(columns).issubset(include_columns): @@ -667,7 +742,10 @@ class Inspector(object): else: idx_cols.append(idx_col) - sa_schema.Index(name, *idx_cols, **dict(unique=unique)) + sa_schema.Index( + name, *idx_cols, + **dict(list(dialect_options.items()) + [('unique', unique)]) + ) def _reflect_unique_constraints( self, table_name, schema, table, cols_by_orig_name, diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index 398ef8df6..fd665ad03 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -86,16 +86,7 @@ class DefaultEngineStrategy(EngineStrategy): pool = pop_kwarg('pool', None) if pool is None: def connect(): - try: - return dialect.connect(*cargs, **cparams) - except dialect.dbapi.Error as e: - invalidated = dialect.is_disconnect(e, None, None) - util.raise_from_cause( - exc.DBAPIError.instance( - None, None, e, dialect.dbapi.Error, - connection_invalidated=invalidated - ) - ) + return dialect.connect(*cargs, **cparams) creator = pop_kwarg('creator', connect) diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py index 637523a0e..e64ab09f4 100644 --- a/lib/sqlalchemy/engine/threadlocal.py +++ b/lib/sqlalchemy/engine/threadlocal.py @@ -59,7 +59,10 @@ class TLEngine(base.Engine): # guards against pool-level reapers, if desired. # or not connection.connection.is_valid: connection = self._tl_connection_cls( - self, self.pool.connect(), **kw) + self, + self._wrap_pool_connect( + self.pool.connect, connection), + **kw) self._connections.conn = weakref.ref(connection) return connection._increment_connect() diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py index be2a82208..ed1dca644 100644 --- a/lib/sqlalchemy/event/attr.py +++ b/lib/sqlalchemy/event/attr.py @@ -40,17 +40,21 @@ import weakref import collections -class RefCollection(object): - @util.memoized_property - def ref(self): +class RefCollection(util.MemoizedSlots): + __slots__ = 'ref', + + def _memoized_attr_ref(self): return weakref.ref(self, registry._collection_gced) -class _DispatchDescriptor(RefCollection): - """Class-level attributes on :class:`._Dispatch` classes.""" +class _ClsLevelDispatch(RefCollection): + """Class-level events on :class:`._Dispatch` classes.""" + + __slots__ = ('name', 'arg_names', 'has_kw', + 'legacy_signatures', '_clslevel') def __init__(self, parent_dispatch_cls, fn): - self.__name__ = fn.__name__ + self.name = fn.__name__ argspec = util.inspect_getargspec(fn) self.arg_names = argspec.args[1:] self.has_kw = bool(argspec.keywords) @@ -60,11 +64,9 @@ class _DispatchDescriptor(RefCollection): key=lambda s: s[0] ) )) - self.__doc__ = fn.__doc__ = legacy._augment_fn_docs( - self, parent_dispatch_cls, fn) + fn.__doc__ = legacy._augment_fn_docs(self, parent_dispatch_cls, fn) self._clslevel = weakref.WeakKeyDictionary() - self._empty_listeners = weakref.WeakKeyDictionary() def _adjust_fn_spec(self, fn, named): if named: @@ -152,34 +154,23 @@ class _DispatchDescriptor(RefCollection): def for_modify(self, obj): """Return an event collection which can be modified. - For _DispatchDescriptor at the class level of + For _ClsLevelDispatch at the class level of a dispatcher, this returns self. """ return self - def __get__(self, obj, cls): - if obj is None: - return self - elif obj._parent_cls in self._empty_listeners: - ret = self._empty_listeners[obj._parent_cls] - else: - self._empty_listeners[obj._parent_cls] = ret = \ - _EmptyListener(self, obj._parent_cls) - # assigning it to __dict__ means - # memoized for fast re-access. but more memory. - obj.__dict__[self.__name__] = ret - return ret +class _InstanceLevelDispatch(RefCollection): + __slots__ = () -class _HasParentDispatchDescriptor(object): def _adjust_fn_spec(self, fn, named): return self.parent._adjust_fn_spec(fn, named) -class _EmptyListener(_HasParentDispatchDescriptor): - """Serves as a class-level interface to the events - served by a _DispatchDescriptor, when there are no +class _EmptyListener(_InstanceLevelDispatch): + """Serves as a proxy interface to the events + served by a _ClsLevelDispatch, when there are no instance-level events present. Is replaced by _ListenerCollection when instance-level @@ -187,14 +178,17 @@ class _EmptyListener(_HasParentDispatchDescriptor): """ + propagate = frozenset() + listeners = () + + __slots__ = 'parent', 'parent_listeners', 'name' + def __init__(self, parent, target_cls): if target_cls not in parent._clslevel: parent.update_subclass(target_cls) - self.parent = parent # _DispatchDescriptor + self.parent = parent # _ClsLevelDispatch self.parent_listeners = parent._clslevel[target_cls] - self.name = parent.__name__ - self.propagate = frozenset() - self.listeners = () + self.name = parent.name def for_modify(self, obj): """Return an event collection which can be modified. @@ -205,9 +199,11 @@ class _EmptyListener(_HasParentDispatchDescriptor): and returns it. """ - result = _ListenerCollection(self.parent, obj._parent_cls) - if obj.__dict__[self.name] is self: - obj.__dict__[self.name] = result + result = _ListenerCollection(self.parent, obj._instance_cls) + if getattr(obj, self.name) is self: + setattr(obj, self.name, result) + else: + assert isinstance(getattr(obj, self.name), _JoinedListener) return result def _needs_modify(self, *args, **kw): @@ -233,11 +229,12 @@ class _EmptyListener(_HasParentDispatchDescriptor): __nonzero__ = __bool__ -class _CompoundListener(_HasParentDispatchDescriptor): +class _CompoundListener(_InstanceLevelDispatch): _exec_once = False - @util.memoized_property - def _exec_once_mutex(self): + __slots__ = '_exec_once_mutex', + + def _memoized_attr__exec_once_mutex(self): return threading.Lock() def exec_once(self, *args, **kw): @@ -272,7 +269,7 @@ class _CompoundListener(_HasParentDispatchDescriptor): __nonzero__ = __bool__ -class _ListenerCollection(RefCollection, _CompoundListener): +class _ListenerCollection(_CompoundListener): """Instance-level attributes on instances of :class:`._Dispatch`. Represents a collection of listeners. @@ -282,12 +279,14 @@ class _ListenerCollection(RefCollection, _CompoundListener): """ + __slots__ = 'parent_listeners', 'parent', 'name', 'listeners', 'propagate' + def __init__(self, parent, target_cls): if target_cls not in parent._clslevel: parent.update_subclass(target_cls) self.parent_listeners = parent._clslevel[target_cls] self.parent = parent - self.name = parent.__name__ + self.name = parent.name self.listeners = collections.deque() self.propagate = set() @@ -339,24 +338,11 @@ class _ListenerCollection(RefCollection, _CompoundListener): self.listeners.clear() -class _JoinedDispatchDescriptor(object): - def __init__(self, name): - self.name = name - - def __get__(self, obj, cls): - if obj is None: - return self - else: - obj.__dict__[self.name] = ret = _JoinedListener( - obj.parent, self.name, - getattr(obj.local, self.name) - ) - return ret - - class _JoinedListener(_CompoundListener): _exec_once = False + __slots__ = 'parent', 'name', 'local', 'parent_listeners' + def __init__(self, parent, name, local): self.parent = parent self.name = name diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py index 4925f6ffa..2d5468886 100644 --- a/lib/sqlalchemy/event/base.py +++ b/lib/sqlalchemy/event/base.py @@ -17,9 +17,11 @@ instances of ``_Dispatch``. """ from __future__ import absolute_import +import weakref + from .. import util -from .attr import _JoinedDispatchDescriptor, \ - _EmptyListener, _DispatchDescriptor +from .attr import _JoinedListener, \ + _EmptyListener, _ClsLevelDispatch _registrars = util.defaultdict(list) @@ -34,10 +36,11 @@ class _UnpickleDispatch(object): """ - def __call__(self, _parent_cls): - for cls in _parent_cls.__mro__: + def __call__(self, _instance_cls): + for cls in _instance_cls.__mro__: if 'dispatch' in cls.__dict__: - return cls.__dict__['dispatch'].dispatch_cls(_parent_cls) + return cls.__dict__['dispatch'].\ + dispatch_cls._for_class(_instance_cls) else: raise AttributeError("No class with a 'dispatch' member present.") @@ -62,16 +65,53 @@ class _Dispatch(object): """ - _events = None - """reference the :class:`.Events` class which this - :class:`._Dispatch` is created for.""" + # in one ORM edge case, an attribute is added to _Dispatch, + # so __dict__ is used in just that case and potentially others. + __slots__ = '_parent', '_instance_cls', '__dict__', '_empty_listeners' + + _empty_listener_reg = weakref.WeakKeyDictionary() + + def __init__(self, parent, instance_cls=None): + self._parent = parent + self._instance_cls = instance_cls + if instance_cls: + try: + self._empty_listeners = self._empty_listener_reg[instance_cls] + except KeyError: + self._empty_listeners = \ + self._empty_listener_reg[instance_cls] = dict( + (ls.name, _EmptyListener(ls, instance_cls)) + for ls in parent._event_descriptors + ) + else: + self._empty_listeners = {} + + def __getattr__(self, name): + # assign EmptyListeners as attributes on demand + # to reduce startup time for new dispatch objects + try: + ls = self._empty_listeners[name] + except KeyError: + raise AttributeError(name) + else: + setattr(self, ls.name, ls) + return ls + + @property + def _event_descriptors(self): + for k in self._event_names: + yield getattr(self, k) + + def _for_class(self, instance_cls): + return self.__class__(self, instance_cls) - def __init__(self, _parent_cls): - self._parent_cls = _parent_cls + def _for_instance(self, instance): + instance_cls = instance.__class__ + return self._for_class(instance_cls) - @util.classproperty - def _listen(cls): - return cls._events._listen + @property + def _listen(self): + return self._events._listen def _join(self, other): """Create a 'join' of this :class:`._Dispatch` and another. @@ -83,36 +123,27 @@ class _Dispatch(object): if '_joined_dispatch_cls' not in self.__class__.__dict__: cls = type( "Joined%s" % self.__class__.__name__, - (_JoinedDispatcher, self.__class__), {} + (_JoinedDispatcher, ), {'__slots__': self._event_names} ) - for ls in _event_descriptors(self): - setattr(cls, ls.name, _JoinedDispatchDescriptor(ls.name)) self.__class__._joined_dispatch_cls = cls return self._joined_dispatch_cls(self, other) def __reduce__(self): - return _UnpickleDispatch(), (self._parent_cls, ) + return _UnpickleDispatch(), (self._instance_cls, ) def _update(self, other, only_propagate=True): """Populate from the listeners in another :class:`_Dispatch` object.""" - - for ls in _event_descriptors(other): + for ls in other._event_descriptors: if isinstance(ls, _EmptyListener): continue getattr(self, ls.name).\ for_modify(self)._update(ls, only_propagate=only_propagate) - @util.hybridmethod def _clear(self): - for attr in dir(self): - if _is_event_name(attr): - getattr(self, attr).for_modify(self).clear() - - -def _event_descriptors(target): - return [getattr(target, k) for k in dir(target) if _is_event_name(k)] + for ls in self._event_descriptors: + ls.for_modify(self).clear() class _EventMeta(type): @@ -131,26 +162,37 @@ def _create_dispatcher_class(cls, classname, bases, dict_): # there's all kinds of ways to do this, # i.e. make a Dispatch class that shares the '_listen' method # of the Event class, this is the straight monkeypatch. - dispatch_base = getattr(cls, 'dispatch', _Dispatch) + if hasattr(cls, 'dispatch'): + dispatch_base = cls.dispatch.__class__ + else: + dispatch_base = _Dispatch + + event_names = [k for k in dict_ if _is_event_name(k)] dispatch_cls = type("%sDispatch" % classname, - (dispatch_base, ), {}) - cls._set_dispatch(cls, dispatch_cls) + (dispatch_base, ), {'__slots__': event_names}) + + dispatch_cls._event_names = event_names - for k in dict_: - if _is_event_name(k): - setattr(dispatch_cls, k, _DispatchDescriptor(cls, dict_[k])) - _registrars[k].append(cls) + dispatch_inst = cls._set_dispatch(cls, dispatch_cls) + for k in dispatch_cls._event_names: + setattr(dispatch_inst, k, _ClsLevelDispatch(cls, dict_[k])) + _registrars[k].append(cls) + + for super_ in dispatch_cls.__bases__: + if issubclass(super_, _Dispatch) and super_ is not _Dispatch: + for ls in super_._events.dispatch._event_descriptors: + setattr(dispatch_inst, ls.name, ls) + dispatch_cls._event_names.append(ls.name) if getattr(cls, '_dispatch_target', None): cls._dispatch_target.dispatch = dispatcher(cls) def _remove_dispatcher(cls): - for k in dir(cls): - if _is_event_name(k): - _registrars[k].remove(cls) - if not _registrars[k]: - del _registrars[k] + for k in cls.dispatch._event_names: + _registrars[k].remove(cls) + if not _registrars[k]: + del _registrars[k] class Events(util.with_metaclass(_EventMeta, object)): @@ -163,17 +205,30 @@ class Events(util.with_metaclass(_EventMeta, object)): # "self.dispatch._events.<utilitymethod>" # @staticemethod to allow easy "super" calls while in a metaclass # constructor. - cls.dispatch = dispatch_cls + cls.dispatch = dispatch_cls(None) dispatch_cls._events = cls + return cls.dispatch @classmethod def _accept_with(cls, target): # Mapper, ClassManager, Session override this to # also accept classes, scoped_sessions, sessionmakers, etc. if hasattr(target, 'dispatch') and ( - isinstance(target.dispatch, cls.dispatch) or - isinstance(target.dispatch, type) and - issubclass(target.dispatch, cls.dispatch) + + isinstance(target.dispatch, cls.dispatch.__class__) or + + + ( + isinstance(target.dispatch, type) and + isinstance(target.dispatch, cls.dispatch.__class__) + ) or + + ( + isinstance(target.dispatch, _JoinedDispatcher) and + isinstance(target.dispatch.parent, cls.dispatch.__class__) + ) + + ): return target else: @@ -195,10 +250,24 @@ class Events(util.with_metaclass(_EventMeta, object)): class _JoinedDispatcher(object): """Represent a connection between two _Dispatch objects.""" + __slots__ = 'local', 'parent', '_instance_cls' + def __init__(self, local, parent): self.local = local self.parent = parent - self._parent_cls = local._parent_cls + self._instance_cls = self.local._instance_cls + + def __getattr__(self, name): + # assign _JoinedListeners as attributes on demand + # to reduce startup time for new dispatch objects + ls = getattr(self.local, name) + jl = _JoinedListener(self.parent, ls.name, ls) + setattr(self, ls.name, jl) + return jl + + @property + def _listen(self): + return self.parent._listen class dispatcher(object): @@ -216,5 +285,5 @@ class dispatcher(object): def __get__(self, obj, cls): if obj is None: return self.dispatch_cls - obj.__dict__['dispatch'] = disp = self.dispatch_cls(cls) + obj.__dict__['dispatch'] = disp = self.dispatch_cls._for_instance(obj) return disp diff --git a/lib/sqlalchemy/event/legacy.py b/lib/sqlalchemy/event/legacy.py index 3b1519cb6..7513c7d4d 100644 --- a/lib/sqlalchemy/event/legacy.py +++ b/lib/sqlalchemy/event/legacy.py @@ -22,8 +22,8 @@ def _legacy_signature(since, argnames, converter=None): return leg -def _wrap_fn_for_legacy(dispatch_descriptor, fn, argspec): - for since, argnames, conv in dispatch_descriptor.legacy_signatures: +def _wrap_fn_for_legacy(dispatch_collection, fn, argspec): + for since, argnames, conv in dispatch_collection.legacy_signatures: if argnames[-1] == "**kw": has_kw = True argnames = argnames[0:-1] @@ -40,7 +40,7 @@ def _wrap_fn_for_legacy(dispatch_descriptor, fn, argspec): return fn(*conv(*args)) else: def wrap_leg(*args, **kw): - argdict = dict(zip(dispatch_descriptor.arg_names, args)) + argdict = dict(zip(dispatch_collection.arg_names, args)) args = [argdict[name] for name in argnames] if has_kw: return fn(*args, **kw) @@ -58,16 +58,16 @@ def _indent(text, indent): ) -def _standard_listen_example(dispatch_descriptor, sample_target, fn): +def _standard_listen_example(dispatch_collection, sample_target, fn): example_kw_arg = _indent( "\n".join( "%(arg)s = kw['%(arg)s']" % {"arg": arg} - for arg in dispatch_descriptor.arg_names[0:2] + for arg in dispatch_collection.arg_names[0:2] ), " ") - if dispatch_descriptor.legacy_signatures: + if dispatch_collection.legacy_signatures: current_since = max(since for since, args, conv - in dispatch_descriptor.legacy_signatures) + in dispatch_collection.legacy_signatures) else: current_since = None text = ( @@ -80,7 +80,7 @@ def _standard_listen_example(dispatch_descriptor, sample_target, fn): "\n # ... (event handling logic) ...\n" ) - if len(dispatch_descriptor.arg_names) > 3: + if len(dispatch_collection.arg_names) > 3: text += ( "\n# named argument style (new in 0.9)\n" @@ -96,17 +96,17 @@ def _standard_listen_example(dispatch_descriptor, sample_target, fn): "current_since": " (arguments as of %s)" % current_since if current_since else "", "event_name": fn.__name__, - "has_kw_arguments": ", **kw" if dispatch_descriptor.has_kw else "", - "named_event_arguments": ", ".join(dispatch_descriptor.arg_names), + "has_kw_arguments": ", **kw" if dispatch_collection.has_kw else "", + "named_event_arguments": ", ".join(dispatch_collection.arg_names), "example_kw_arg": example_kw_arg, "sample_target": sample_target } return text -def _legacy_listen_examples(dispatch_descriptor, sample_target, fn): +def _legacy_listen_examples(dispatch_collection, sample_target, fn): text = "" - for since, args, conv in dispatch_descriptor.legacy_signatures: + for since, args, conv in dispatch_collection.legacy_signatures: text += ( "\n# legacy calling style (pre-%(since)s)\n" "@event.listens_for(%(sample_target)s, '%(event_name)s')\n" @@ -117,7 +117,7 @@ def _legacy_listen_examples(dispatch_descriptor, sample_target, fn): "since": since, "event_name": fn.__name__, "has_kw_arguments": " **kw" - if dispatch_descriptor.has_kw else "", + if dispatch_collection.has_kw else "", "named_event_arguments": ", ".join(args), "sample_target": sample_target } @@ -125,8 +125,8 @@ def _legacy_listen_examples(dispatch_descriptor, sample_target, fn): return text -def _version_signature_changes(dispatch_descriptor): - since, args, conv = dispatch_descriptor.legacy_signatures[0] +def _version_signature_changes(dispatch_collection): + since, args, conv = dispatch_collection.legacy_signatures[0] return ( "\n.. versionchanged:: %(since)s\n" " The ``%(event_name)s`` event now accepts the \n" @@ -135,14 +135,14 @@ def _version_signature_changes(dispatch_descriptor): " signature(s) listed above will be automatically \n" " adapted to the new signature." % { "since": since, - "event_name": dispatch_descriptor.__name__, - "named_event_arguments": ", ".join(dispatch_descriptor.arg_names), - "has_kw_arguments": ", **kw" if dispatch_descriptor.has_kw else "" + "event_name": dispatch_collection.name, + "named_event_arguments": ", ".join(dispatch_collection.arg_names), + "has_kw_arguments": ", **kw" if dispatch_collection.has_kw else "" } ) -def _augment_fn_docs(dispatch_descriptor, parent_dispatch_cls, fn): +def _augment_fn_docs(dispatch_collection, parent_dispatch_cls, fn): header = ".. container:: event_signatures\n\n"\ " Example argument forms::\n"\ "\n" @@ -152,16 +152,16 @@ def _augment_fn_docs(dispatch_descriptor, parent_dispatch_cls, fn): header + _indent( _standard_listen_example( - dispatch_descriptor, sample_target, fn), + dispatch_collection, sample_target, fn), " " * 8) ) - if dispatch_descriptor.legacy_signatures: + if dispatch_collection.legacy_signatures: text += _indent( _legacy_listen_examples( - dispatch_descriptor, sample_target, fn), + dispatch_collection, sample_target, fn), " " * 8) - text += _version_signature_changes(dispatch_descriptor) + text += _version_signature_changes(dispatch_collection) return util.inject_docstring_text(fn.__doc__, text, diff --git a/lib/sqlalchemy/event/registry.py b/lib/sqlalchemy/event/registry.py index 5b422c401..ebc0e6d18 100644 --- a/lib/sqlalchemy/event/registry.py +++ b/lib/sqlalchemy/event/registry.py @@ -37,7 +37,7 @@ listener collections and the listener fn contained _collection_to_key = collections.defaultdict(dict) """ -Given a _ListenerCollection or _DispatchDescriptor, can locate +Given a _ListenerCollection or _ClsLevelListener, can locate all the original listen() arguments and the listener fn contained ref(listenercollection) -> { @@ -140,6 +140,10 @@ class _EventKey(object): """Represent :func:`.listen` arguments. """ + __slots__ = ( + 'target', 'identifier', 'fn', 'fn_key', 'fn_wrap', 'dispatch_target' + ) + def __init__(self, target, identifier, fn, dispatch_target, _fn_wrap=None): self.target = target @@ -187,9 +191,9 @@ class _EventKey(object): target, identifier, fn = \ self.dispatch_target, self.identifier, self._listen_fn - dispatch_descriptor = getattr(target.dispatch, identifier) + dispatch_collection = getattr(target.dispatch, identifier) - adjusted_fn = dispatch_descriptor._adjust_fn_spec(fn, named) + adjusted_fn = dispatch_collection._adjust_fn_spec(fn, named) self = self.with_wrapper(adjusted_fn) @@ -226,13 +230,13 @@ class _EventKey(object): target, identifier, fn = \ self.dispatch_target, self.identifier, self._listen_fn - dispatch_descriptor = getattr(target.dispatch, identifier) + dispatch_collection = getattr(target.dispatch, identifier) if insert: - dispatch_descriptor.\ + dispatch_collection.\ for_modify(target.dispatch).insert(self, propagate) else: - dispatch_descriptor.\ + dispatch_collection.\ for_modify(target.dispatch).append(self, propagate) @property diff --git a/lib/sqlalchemy/events.py b/lib/sqlalchemy/events.py index c144902cd..8600c20f5 100644 --- a/lib/sqlalchemy/events.py +++ b/lib/sqlalchemy/events.py @@ -739,6 +739,12 @@ class ConnectionEvents(event.Events): .. versionadded:: 0.9.7 Added the :meth:`.ConnectionEvents.handle_error` hook. + .. versionchanged:: 1.0.0 The :meth:`.handle_error` event is now + invoked when an :class:`.Engine` fails during the initial + call to :meth:`.Engine.connect`, as well as when a + :class:`.Connection` object encounters an error during a + reconnect operation. + .. versionchanged:: 1.0.0 The :meth:`.handle_error` event is not fired off when a dialect makes use of the ``skip_user_error_events`` execution option. This is used diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index 3271d09d4..d6355a212 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -63,7 +63,7 @@ class CircularDependencyError(SQLAlchemyError): """ def __init__(self, message, cycles, edges, msg=None): if msg is None: - message += " Cycles: %r all edges: %r" % (cycles, edges) + message += " (%s)" % ", ".join(repr(s) for s in cycles) else: message = msg SQLAlchemyError.__init__(self, message) diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 1aa68ac32..bb08ce9ba 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -86,7 +86,7 @@ ASSOCIATION_PROXY = util.symbol('ASSOCIATION_PROXY') """ -class AssociationProxy(interfaces.InspectionAttr): +class AssociationProxy(interfaces.InspectionAttrInfo): """A descriptor that presents a read/write view of an object attribute.""" is_attribute = False diff --git a/lib/sqlalchemy/ext/declarative/__init__.py b/lib/sqlalchemy/ext/declarative/__init__.py index 2b611252a..cbde6f9d2 100644 --- a/lib/sqlalchemy/ext/declarative/__init__.py +++ b/lib/sqlalchemy/ext/declarative/__init__.py @@ -5,1377 +5,6 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -""" -Synopsis -======== - -SQLAlchemy object-relational configuration involves the -combination of :class:`.Table`, :func:`.mapper`, and class -objects to define a mapped class. -:mod:`~sqlalchemy.ext.declarative` allows all three to be -expressed at once within the class declaration. As much as -possible, regular SQLAlchemy schema and ORM constructs are -used directly, so that configuration between "classical" ORM -usage and declarative remain highly similar. - -As a simple example:: - - from sqlalchemy.ext.declarative import declarative_base - - Base = declarative_base() - - class SomeClass(Base): - __tablename__ = 'some_table' - id = Column(Integer, primary_key=True) - name = Column(String(50)) - -Above, the :func:`declarative_base` callable returns a new base class from -which all mapped classes should inherit. When the class definition is -completed, a new :class:`.Table` and :func:`.mapper` will have been generated. - -The resulting table and mapper are accessible via -``__table__`` and ``__mapper__`` attributes on the -``SomeClass`` class:: - - # access the mapped Table - SomeClass.__table__ - - # access the Mapper - SomeClass.__mapper__ - -Defining Attributes -=================== - -In the previous example, the :class:`.Column` objects are -automatically named with the name of the attribute to which they are -assigned. - -To name columns explicitly with a name distinct from their mapped attribute, -just give the column a name. Below, column "some_table_id" is mapped to the -"id" attribute of `SomeClass`, but in SQL will be represented as -"some_table_id":: - - class SomeClass(Base): - __tablename__ = 'some_table' - id = Column("some_table_id", Integer, primary_key=True) - -Attributes may be added to the class after its construction, and they will be -added to the underlying :class:`.Table` and -:func:`.mapper` definitions as appropriate:: - - SomeClass.data = Column('data', Unicode) - SomeClass.related = relationship(RelatedInfo) - -Classes which are constructed using declarative can interact freely -with classes that are mapped explicitly with :func:`.mapper`. - -It is recommended, though not required, that all tables -share the same underlying :class:`~sqlalchemy.schema.MetaData` object, -so that string-configured :class:`~sqlalchemy.schema.ForeignKey` -references can be resolved without issue. - -Accessing the MetaData -======================= - -The :func:`declarative_base` base class contains a -:class:`.MetaData` object where newly defined -:class:`.Table` objects are collected. This object is -intended to be accessed directly for -:class:`.MetaData`-specific operations. Such as, to issue -CREATE statements for all tables:: - - engine = create_engine('sqlite://') - Base.metadata.create_all(engine) - -:func:`declarative_base` can also receive a pre-existing -:class:`.MetaData` object, which allows a -declarative setup to be associated with an already -existing traditional collection of :class:`~sqlalchemy.schema.Table` -objects:: - - mymetadata = MetaData() - Base = declarative_base(metadata=mymetadata) - - -.. _declarative_configuring_relationships: - -Configuring Relationships -========================= - -Relationships to other classes are done in the usual way, with the added -feature that the class specified to :func:`~sqlalchemy.orm.relationship` -may be a string name. The "class registry" associated with ``Base`` -is used at mapper compilation time to resolve the name into the actual -class object, which is expected to have been defined once the mapper -configuration is used:: - - class User(Base): - __tablename__ = 'users' - - id = Column(Integer, primary_key=True) - name = Column(String(50)) - addresses = relationship("Address", backref="user") - - class Address(Base): - __tablename__ = 'addresses' - - id = Column(Integer, primary_key=True) - email = Column(String(50)) - user_id = Column(Integer, ForeignKey('users.id')) - -Column constructs, since they are just that, are immediately usable, -as below where we define a primary join condition on the ``Address`` -class using them:: - - class Address(Base): - __tablename__ = 'addresses' - - id = Column(Integer, primary_key=True) - email = Column(String(50)) - user_id = Column(Integer, ForeignKey('users.id')) - user = relationship(User, primaryjoin=user_id == User.id) - -In addition to the main argument for :func:`~sqlalchemy.orm.relationship`, -other arguments which depend upon the columns present on an as-yet -undefined class may also be specified as strings. These strings are -evaluated as Python expressions. The full namespace available within -this evaluation includes all classes mapped for this declarative base, -as well as the contents of the ``sqlalchemy`` package, including -expression functions like :func:`~sqlalchemy.sql.expression.desc` and -:attr:`~sqlalchemy.sql.expression.func`:: - - class User(Base): - # .... - addresses = relationship("Address", - order_by="desc(Address.email)", - primaryjoin="Address.user_id==User.id") - -For the case where more than one module contains a class of the same name, -string class names can also be specified as module-qualified paths -within any of these string expressions:: - - class User(Base): - # .... - addresses = relationship("myapp.model.address.Address", - order_by="desc(myapp.model.address.Address.email)", - primaryjoin="myapp.model.address.Address.user_id==" - "myapp.model.user.User.id") - -The qualified path can be any partial path that removes ambiguity between -the names. For example, to disambiguate between -``myapp.model.address.Address`` and ``myapp.model.lookup.Address``, -we can specify ``address.Address`` or ``lookup.Address``:: - - class User(Base): - # .... - addresses = relationship("address.Address", - order_by="desc(address.Address.email)", - primaryjoin="address.Address.user_id==" - "User.id") - -.. versionadded:: 0.8 - module-qualified paths can be used when specifying string arguments - with Declarative, in order to specify specific modules. - -Two alternatives also exist to using string-based attributes. A lambda -can also be used, which will be evaluated after all mappers have been -configured:: - - class User(Base): - # ... - addresses = relationship(lambda: Address, - order_by=lambda: desc(Address.email), - primaryjoin=lambda: Address.user_id==User.id) - -Or, the relationship can be added to the class explicitly after the classes -are available:: - - User.addresses = relationship(Address, - primaryjoin=Address.user_id==User.id) - - - -.. _declarative_many_to_many: - -Configuring Many-to-Many Relationships -====================================== - -Many-to-many relationships are also declared in the same way -with declarative as with traditional mappings. The -``secondary`` argument to -:func:`.relationship` is as usual passed a -:class:`.Table` object, which is typically declared in the -traditional way. The :class:`.Table` usually shares -the :class:`.MetaData` object used by the declarative base:: - - keywords = Table( - 'keywords', Base.metadata, - Column('author_id', Integer, ForeignKey('authors.id')), - Column('keyword_id', Integer, ForeignKey('keywords.id')) - ) - - class Author(Base): - __tablename__ = 'authors' - id = Column(Integer, primary_key=True) - keywords = relationship("Keyword", secondary=keywords) - -Like other :func:`~sqlalchemy.orm.relationship` arguments, a string is accepted -as well, passing the string name of the table as defined in the -``Base.metadata.tables`` collection:: - - class Author(Base): - __tablename__ = 'authors' - id = Column(Integer, primary_key=True) - keywords = relationship("Keyword", secondary="keywords") - -As with traditional mapping, its generally not a good idea to use -a :class:`.Table` as the "secondary" argument which is also mapped to -a class, unless the :func:`.relationship` is declared with ``viewonly=True``. -Otherwise, the unit-of-work system may attempt duplicate INSERT and -DELETE statements against the underlying table. - -.. _declarative_sql_expressions: - -Defining SQL Expressions -======================== - -See :ref:`mapper_sql_expressions` for examples on declaratively -mapping attributes to SQL expressions. - -.. _declarative_table_args: - -Table Configuration -=================== - -Table arguments other than the name, metadata, and mapped Column -arguments are specified using the ``__table_args__`` class attribute. -This attribute accommodates both positional as well as keyword -arguments that are normally sent to the -:class:`~sqlalchemy.schema.Table` constructor. -The attribute can be specified in one of two forms. One is as a -dictionary:: - - class MyClass(Base): - __tablename__ = 'sometable' - __table_args__ = {'mysql_engine':'InnoDB'} - -The other, a tuple, where each argument is positional -(usually constraints):: - - class MyClass(Base): - __tablename__ = 'sometable' - __table_args__ = ( - ForeignKeyConstraint(['id'], ['remote_table.id']), - UniqueConstraint('foo'), - ) - -Keyword arguments can be specified with the above form by -specifying the last argument as a dictionary:: - - class MyClass(Base): - __tablename__ = 'sometable' - __table_args__ = ( - ForeignKeyConstraint(['id'], ['remote_table.id']), - UniqueConstraint('foo'), - {'autoload':True} - ) - -Using a Hybrid Approach with __table__ -======================================= - -As an alternative to ``__tablename__``, a direct -:class:`~sqlalchemy.schema.Table` construct may be used. The -:class:`~sqlalchemy.schema.Column` objects, which in this case require -their names, will be added to the mapping just like a regular mapping -to a table:: - - class MyClass(Base): - __table__ = Table('my_table', Base.metadata, - Column('id', Integer, primary_key=True), - Column('name', String(50)) - ) - -``__table__`` provides a more focused point of control for establishing -table metadata, while still getting most of the benefits of using declarative. -An application that uses reflection might want to load table metadata elsewhere -and pass it to declarative classes:: - - from sqlalchemy.ext.declarative import declarative_base - - Base = declarative_base() - Base.metadata.reflect(some_engine) - - class User(Base): - __table__ = metadata.tables['user'] - - class Address(Base): - __table__ = metadata.tables['address'] - -Some configuration schemes may find it more appropriate to use ``__table__``, -such as those which already take advantage of the data-driven nature of -:class:`.Table` to customize and/or automate schema definition. - -Note that when the ``__table__`` approach is used, the object is immediately -usable as a plain :class:`.Table` within the class declaration body itself, -as a Python class is only another syntactical block. Below this is illustrated -by using the ``id`` column in the ``primaryjoin`` condition of a -:func:`.relationship`:: - - class MyClass(Base): - __table__ = Table('my_table', Base.metadata, - Column('id', Integer, primary_key=True), - Column('name', String(50)) - ) - - widgets = relationship(Widget, - primaryjoin=Widget.myclass_id==__table__.c.id) - -Similarly, mapped attributes which refer to ``__table__`` can be placed inline, -as below where we assign the ``name`` column to the attribute ``_name``, -generating a synonym for ``name``:: - - from sqlalchemy.ext.declarative import synonym_for - - class MyClass(Base): - __table__ = Table('my_table', Base.metadata, - Column('id', Integer, primary_key=True), - Column('name', String(50)) - ) - - _name = __table__.c.name - - @synonym_for("_name") - def name(self): - return "Name: %s" % _name - -Using Reflection with Declarative -================================= - -It's easy to set up a :class:`.Table` that uses ``autoload=True`` -in conjunction with a mapped class:: - - class MyClass(Base): - __table__ = Table('mytable', Base.metadata, - autoload=True, autoload_with=some_engine) - -However, one improvement that can be made here is to not -require the :class:`.Engine` to be available when classes are -being first declared. To achieve this, use the -:class:`.DeferredReflection` mixin, which sets up mappings -only after a special ``prepare(engine)`` step is called:: - - from sqlalchemy.ext.declarative import declarative_base, DeferredReflection - - Base = declarative_base(cls=DeferredReflection) - - class Foo(Base): - __tablename__ = 'foo' - bars = relationship("Bar") - - class Bar(Base): - __tablename__ = 'bar' - - # illustrate overriding of "bar.foo_id" to have - # a foreign key constraint otherwise not - # reflected, such as when using MySQL - foo_id = Column(Integer, ForeignKey('foo.id')) - - Base.prepare(e) - -.. versionadded:: 0.8 - Added :class:`.DeferredReflection`. - -Mapper Configuration -==================== - -Declarative makes use of the :func:`~.orm.mapper` function internally -when it creates the mapping to the declared table. The options -for :func:`~.orm.mapper` are passed directly through via the -``__mapper_args__`` class attribute. As always, arguments which reference -locally mapped columns can reference them directly from within the -class declaration:: - - from datetime import datetime - - class Widget(Base): - __tablename__ = 'widgets' - - id = Column(Integer, primary_key=True) - timestamp = Column(DateTime, nullable=False) - - __mapper_args__ = { - 'version_id_col': timestamp, - 'version_id_generator': lambda v:datetime.now() - } - -.. _declarative_inheritance: - -Inheritance Configuration -========================= - -Declarative supports all three forms of inheritance as intuitively -as possible. The ``inherits`` mapper keyword argument is not needed -as declarative will determine this from the class itself. The various -"polymorphic" keyword arguments are specified using ``__mapper_args__``. - -Joined Table Inheritance -~~~~~~~~~~~~~~~~~~~~~~~~ - -Joined table inheritance is defined as a subclass that defines its own -table:: - - class Person(Base): - __tablename__ = 'people' - id = Column(Integer, primary_key=True) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} - - class Engineer(Person): - __tablename__ = 'engineers' - __mapper_args__ = {'polymorphic_identity': 'engineer'} - id = Column(Integer, ForeignKey('people.id'), primary_key=True) - primary_language = Column(String(50)) - -Note that above, the ``Engineer.id`` attribute, since it shares the -same attribute name as the ``Person.id`` attribute, will in fact -represent the ``people.id`` and ``engineers.id`` columns together, -with the "Engineer.id" column taking precedence if queried directly. -To provide the ``Engineer`` class with an attribute that represents -only the ``engineers.id`` column, give it a different attribute name:: - - class Engineer(Person): - __tablename__ = 'engineers' - __mapper_args__ = {'polymorphic_identity': 'engineer'} - engineer_id = Column('id', Integer, ForeignKey('people.id'), - primary_key=True) - primary_language = Column(String(50)) - - -.. versionchanged:: 0.7 joined table inheritance favors the subclass - column over that of the superclass, such as querying above - for ``Engineer.id``. Prior to 0.7 this was the reverse. - -.. _declarative_single_table: - -Single Table Inheritance -~~~~~~~~~~~~~~~~~~~~~~~~ - -Single table inheritance is defined as a subclass that does not have -its own table; you just leave out the ``__table__`` and ``__tablename__`` -attributes:: - - class Person(Base): - __tablename__ = 'people' - id = Column(Integer, primary_key=True) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} - - class Engineer(Person): - __mapper_args__ = {'polymorphic_identity': 'engineer'} - primary_language = Column(String(50)) - -When the above mappers are configured, the ``Person`` class is mapped -to the ``people`` table *before* the ``primary_language`` column is -defined, and this column will not be included in its own mapping. -When ``Engineer`` then defines the ``primary_language`` column, the -column is added to the ``people`` table so that it is included in the -mapping for ``Engineer`` and is also part of the table's full set of -columns. Columns which are not mapped to ``Person`` are also excluded -from any other single or joined inheriting classes using the -``exclude_properties`` mapper argument. Below, ``Manager`` will have -all the attributes of ``Person`` and ``Manager`` but *not* the -``primary_language`` attribute of ``Engineer``:: - - class Manager(Person): - __mapper_args__ = {'polymorphic_identity': 'manager'} - golf_swing = Column(String(50)) - -The attribute exclusion logic is provided by the -``exclude_properties`` mapper argument, and declarative's default -behavior can be disabled by passing an explicit ``exclude_properties`` -collection (empty or otherwise) to the ``__mapper_args__``. - -Resolving Column Conflicts -^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Note above that the ``primary_language`` and ``golf_swing`` columns -are "moved up" to be applied to ``Person.__table__``, as a result of their -declaration on a subclass that has no table of its own. A tricky case -comes up when two subclasses want to specify *the same* column, as below:: - - class Person(Base): - __tablename__ = 'people' - id = Column(Integer, primary_key=True) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} - - class Engineer(Person): - __mapper_args__ = {'polymorphic_identity': 'engineer'} - start_date = Column(DateTime) - - class Manager(Person): - __mapper_args__ = {'polymorphic_identity': 'manager'} - start_date = Column(DateTime) - -Above, the ``start_date`` column declared on both ``Engineer`` and ``Manager`` -will result in an error:: - - sqlalchemy.exc.ArgumentError: Column 'start_date' on class - <class '__main__.Manager'> conflicts with existing - column 'people.start_date' - -In a situation like this, Declarative can't be sure -of the intent, especially if the ``start_date`` columns had, for example, -different types. A situation like this can be resolved by using -:class:`.declared_attr` to define the :class:`.Column` conditionally, taking -care to return the **existing column** via the parent ``__table__`` if it -already exists:: - - from sqlalchemy.ext.declarative import declared_attr - - class Person(Base): - __tablename__ = 'people' - id = Column(Integer, primary_key=True) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} - - class Engineer(Person): - __mapper_args__ = {'polymorphic_identity': 'engineer'} - - @declared_attr - def start_date(cls): - "Start date column, if not present already." - return Person.__table__.c.get('start_date', Column(DateTime)) - - class Manager(Person): - __mapper_args__ = {'polymorphic_identity': 'manager'} - - @declared_attr - def start_date(cls): - "Start date column, if not present already." - return Person.__table__.c.get('start_date', Column(DateTime)) - -Above, when ``Manager`` is mapped, the ``start_date`` column is -already present on the ``Person`` class. Declarative lets us return -that :class:`.Column` as a result in this case, where it knows to skip -re-assigning the same column. If the mapping is mis-configured such -that the ``start_date`` column is accidentally re-assigned to a -different table (such as, if we changed ``Manager`` to be joined -inheritance without fixing ``start_date``), an error is raised which -indicates an existing :class:`.Column` is trying to be re-assigned to -a different owning :class:`.Table`. - -.. versionadded:: 0.8 :class:`.declared_attr` can be used on a non-mixin - class, and the returned :class:`.Column` or other mapped attribute - will be applied to the mapping as any other attribute. Previously, - the resulting attribute would be ignored, and also result in a warning - being emitted when a subclass was created. - -.. versionadded:: 0.8 :class:`.declared_attr`, when used either with a - mixin or non-mixin declarative class, can return an existing - :class:`.Column` already assigned to the parent :class:`.Table`, - to indicate that the re-assignment of the :class:`.Column` should be - skipped, however should still be mapped on the target class, - in order to resolve duplicate column conflicts. - -The same concept can be used with mixin classes (see -:ref:`declarative_mixins`):: - - class Person(Base): - __tablename__ = 'people' - id = Column(Integer, primary_key=True) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} - - class HasStartDate(object): - @declared_attr - def start_date(cls): - return cls.__table__.c.get('start_date', Column(DateTime)) - - class Engineer(HasStartDate, Person): - __mapper_args__ = {'polymorphic_identity': 'engineer'} - - class Manager(HasStartDate, Person): - __mapper_args__ = {'polymorphic_identity': 'manager'} - -The above mixin checks the local ``__table__`` attribute for the column. -Because we're using single table inheritance, we're sure that in this case, -``cls.__table__`` refers to ``People.__table__``. If we were mixing joined- -and single-table inheritance, we might want our mixin to check more carefully -if ``cls.__table__`` is really the :class:`.Table` we're looking for. - -Concrete Table Inheritance -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Concrete is defined as a subclass which has its own table and sets the -``concrete`` keyword argument to ``True``:: - - class Person(Base): - __tablename__ = 'people' - id = Column(Integer, primary_key=True) - name = Column(String(50)) - - class Engineer(Person): - __tablename__ = 'engineers' - __mapper_args__ = {'concrete':True} - id = Column(Integer, primary_key=True) - primary_language = Column(String(50)) - name = Column(String(50)) - -Usage of an abstract base class is a little less straightforward as it -requires usage of :func:`~sqlalchemy.orm.util.polymorphic_union`, -which needs to be created with the :class:`.Table` objects -before the class is built:: - - engineers = Table('engineers', Base.metadata, - Column('id', Integer, primary_key=True), - Column('name', String(50)), - Column('primary_language', String(50)) - ) - managers = Table('managers', Base.metadata, - Column('id', Integer, primary_key=True), - Column('name', String(50)), - Column('golf_swing', String(50)) - ) - - punion = polymorphic_union({ - 'engineer':engineers, - 'manager':managers - }, 'type', 'punion') - - class Person(Base): - __table__ = punion - __mapper_args__ = {'polymorphic_on':punion.c.type} - - class Engineer(Person): - __table__ = engineers - __mapper_args__ = {'polymorphic_identity':'engineer', 'concrete':True} - - class Manager(Person): - __table__ = managers - __mapper_args__ = {'polymorphic_identity':'manager', 'concrete':True} - -.. _declarative_concrete_helpers: - -Using the Concrete Helpers -^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Helper classes provides a simpler pattern for concrete inheritance. -With these objects, the ``__declare_first__`` helper is used to configure the -"polymorphic" loader for the mapper after all subclasses have been declared. - -.. versionadded:: 0.7.3 - -An abstract base can be declared using the -:class:`.AbstractConcreteBase` class:: - - from sqlalchemy.ext.declarative import AbstractConcreteBase - - class Employee(AbstractConcreteBase, Base): - pass - -To have a concrete ``employee`` table, use :class:`.ConcreteBase` instead:: - - from sqlalchemy.ext.declarative import ConcreteBase - - class Employee(ConcreteBase, Base): - __tablename__ = 'employee' - employee_id = Column(Integer, primary_key=True) - name = Column(String(50)) - __mapper_args__ = { - 'polymorphic_identity':'employee', - 'concrete':True} - - -Either ``Employee`` base can be used in the normal fashion:: - - class Manager(Employee): - __tablename__ = 'manager' - employee_id = Column(Integer, primary_key=True) - name = Column(String(50)) - manager_data = Column(String(40)) - __mapper_args__ = { - 'polymorphic_identity':'manager', - 'concrete':True} - - class Engineer(Employee): - __tablename__ = 'engineer' - employee_id = Column(Integer, primary_key=True) - name = Column(String(50)) - engineer_info = Column(String(40)) - __mapper_args__ = {'polymorphic_identity':'engineer', - 'concrete':True} - - -The :class:`.AbstractConcreteBase` class is itself mapped, and can be -used as a target of relationships:: - - class Company(Base): - __tablename__ = 'company' - - id = Column(Integer, primary_key=True) - employees = relationship("Employee", - primaryjoin="Company.id == Employee.company_id") - - -.. versionchanged:: 0.9.3 Support for use of :class:`.AbstractConcreteBase` - as the target of a :func:`.relationship` has been improved. - -It can also be queried directly:: - - for employee in session.query(Employee).filter(Employee.name == 'qbert'): - print(employee) - - -.. _declarative_mixins: - -Mixin and Custom Base Classes -============================== - -A common need when using :mod:`~sqlalchemy.ext.declarative` is to -share some functionality, such as a set of common columns, some common -table options, or other mapped properties, across many -classes. The standard Python idioms for this is to have the classes -inherit from a base which includes these common features. - -When using :mod:`~sqlalchemy.ext.declarative`, this idiom is allowed -via the usage of a custom declarative base class, as well as a "mixin" class -which is inherited from in addition to the primary base. Declarative -includes several helper features to make this work in terms of how -mappings are declared. An example of some commonly mixed-in -idioms is below:: - - from sqlalchemy.ext.declarative import declared_attr - - class MyMixin(object): - - @declared_attr - def __tablename__(cls): - return cls.__name__.lower() - - __table_args__ = {'mysql_engine': 'InnoDB'} - __mapper_args__= {'always_refresh': True} - - id = Column(Integer, primary_key=True) - - class MyModel(MyMixin, Base): - name = Column(String(1000)) - -Where above, the class ``MyModel`` will contain an "id" column -as the primary key, a ``__tablename__`` attribute that derives -from the name of the class itself, as well as ``__table_args__`` -and ``__mapper_args__`` defined by the ``MyMixin`` mixin class. - -There's no fixed convention over whether ``MyMixin`` precedes -``Base`` or not. Normal Python method resolution rules apply, and -the above example would work just as well with:: - - class MyModel(Base, MyMixin): - name = Column(String(1000)) - -This works because ``Base`` here doesn't define any of the -variables that ``MyMixin`` defines, i.e. ``__tablename__``, -``__table_args__``, ``id``, etc. If the ``Base`` did define -an attribute of the same name, the class placed first in the -inherits list would determine which attribute is used on the -newly defined class. - -Augmenting the Base -~~~~~~~~~~~~~~~~~~~ - -In addition to using a pure mixin, most of the techniques in this -section can also be applied to the base class itself, for patterns that -should apply to all classes derived from a particular base. This is achieved -using the ``cls`` argument of the :func:`.declarative_base` function:: - - from sqlalchemy.ext.declarative import declared_attr - - class Base(object): - @declared_attr - def __tablename__(cls): - return cls.__name__.lower() - - __table_args__ = {'mysql_engine': 'InnoDB'} - - id = Column(Integer, primary_key=True) - - from sqlalchemy.ext.declarative import declarative_base - - Base = declarative_base(cls=Base) - - class MyModel(Base): - name = Column(String(1000)) - -Where above, ``MyModel`` and all other classes that derive from ``Base`` will -have a table name derived from the class name, an ``id`` primary key column, -as well as the "InnoDB" engine for MySQL. - -Mixing in Columns -~~~~~~~~~~~~~~~~~ - -The most basic way to specify a column on a mixin is by simple -declaration:: - - class TimestampMixin(object): - created_at = Column(DateTime, default=func.now()) - - class MyModel(TimestampMixin, Base): - __tablename__ = 'test' - - id = Column(Integer, primary_key=True) - name = Column(String(1000)) - -Where above, all declarative classes that include ``TimestampMixin`` -will also have a column ``created_at`` that applies a timestamp to -all row insertions. - -Those familiar with the SQLAlchemy expression language know that -the object identity of clause elements defines their role in a schema. -Two ``Table`` objects ``a`` and ``b`` may both have a column called -``id``, but the way these are differentiated is that ``a.c.id`` -and ``b.c.id`` are two distinct Python objects, referencing their -parent tables ``a`` and ``b`` respectively. - -In the case of the mixin column, it seems that only one -:class:`.Column` object is explicitly created, yet the ultimate -``created_at`` column above must exist as a distinct Python object -for each separate destination class. To accomplish this, the declarative -extension creates a **copy** of each :class:`.Column` object encountered on -a class that is detected as a mixin. - -This copy mechanism is limited to simple columns that have no foreign -keys, as a :class:`.ForeignKey` itself contains references to columns -which can't be properly recreated at this level. For columns that -have foreign keys, as well as for the variety of mapper-level constructs -that require destination-explicit context, the -:class:`~.declared_attr` decorator is provided so that -patterns common to many classes can be defined as callables:: - - from sqlalchemy.ext.declarative import declared_attr - - class ReferenceAddressMixin(object): - @declared_attr - def address_id(cls): - return Column(Integer, ForeignKey('address.id')) - - class User(ReferenceAddressMixin, Base): - __tablename__ = 'user' - id = Column(Integer, primary_key=True) - -Where above, the ``address_id`` class-level callable is executed at the -point at which the ``User`` class is constructed, and the declarative -extension can use the resulting :class:`.Column` object as returned by -the method without the need to copy it. - -.. versionchanged:: > 0.6.5 - Rename 0.6.5 ``sqlalchemy.util.classproperty`` - into :class:`~.declared_attr`. - -Columns generated by :class:`~.declared_attr` can also be -referenced by ``__mapper_args__`` to a limited degree, currently -by ``polymorphic_on`` and ``version_id_col``; the declarative extension -will resolve them at class construction time:: - - class MyMixin: - @declared_attr - def type_(cls): - return Column(String(50)) - - __mapper_args__= {'polymorphic_on':type_} - - class MyModel(MyMixin, Base): - __tablename__='test' - id = Column(Integer, primary_key=True) - - -Mixing in Relationships -~~~~~~~~~~~~~~~~~~~~~~~ - -Relationships created by :func:`~sqlalchemy.orm.relationship` are provided -with declarative mixin classes exclusively using the -:class:`.declared_attr` approach, eliminating any ambiguity -which could arise when copying a relationship and its possibly column-bound -contents. Below is an example which combines a foreign key column and a -relationship so that two classes ``Foo`` and ``Bar`` can both be configured to -reference a common target class via many-to-one:: - - class RefTargetMixin(object): - @declared_attr - def target_id(cls): - return Column('target_id', ForeignKey('target.id')) - - @declared_attr - def target(cls): - return relationship("Target") - - class Foo(RefTargetMixin, Base): - __tablename__ = 'foo' - id = Column(Integer, primary_key=True) - - class Bar(RefTargetMixin, Base): - __tablename__ = 'bar' - id = Column(Integer, primary_key=True) - - class Target(Base): - __tablename__ = 'target' - id = Column(Integer, primary_key=True) - - -Using Advanced Relationship Arguments (e.g. ``primaryjoin``, etc.) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -:func:`~sqlalchemy.orm.relationship` definitions which require explicit -primaryjoin, order_by etc. expressions should in all but the most -simplistic cases use **late bound** forms -for these arguments, meaning, using either the string form or a lambda. -The reason for this is that the related :class:`.Column` objects which are to -be configured using ``@declared_attr`` are not available to another -``@declared_attr`` attribute; while the methods will work and return new -:class:`.Column` objects, those are not the :class:`.Column` objects that -Declarative will be using as it calls the methods on its own, thus using -*different* :class:`.Column` objects. - -The canonical example is the primaryjoin condition that depends upon -another mixed-in column:: - - class RefTargetMixin(object): - @declared_attr - def target_id(cls): - return Column('target_id', ForeignKey('target.id')) - - @declared_attr - def target(cls): - return relationship(Target, - primaryjoin=Target.id==cls.target_id # this is *incorrect* - ) - -Mapping a class using the above mixin, we will get an error like:: - - sqlalchemy.exc.InvalidRequestError: this ForeignKey's parent column is not - yet associated with a Table. - -This is because the ``target_id`` :class:`.Column` we've called upon in our -``target()`` method is not the same :class:`.Column` that declarative is -actually going to map to our table. - -The condition above is resolved using a lambda:: - - class RefTargetMixin(object): - @declared_attr - def target_id(cls): - return Column('target_id', ForeignKey('target.id')) - - @declared_attr - def target(cls): - return relationship(Target, - primaryjoin=lambda: Target.id==cls.target_id - ) - -or alternatively, the string form (which ultimately generates a lambda):: - - class RefTargetMixin(object): - @declared_attr - def target_id(cls): - return Column('target_id', ForeignKey('target.id')) - - @declared_attr - def target(cls): - return relationship("Target", - primaryjoin="Target.id==%s.target_id" % cls.__name__ - ) - -Mixing in deferred(), column_property(), and other MapperProperty classes -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Like :func:`~sqlalchemy.orm.relationship`, all -:class:`~sqlalchemy.orm.interfaces.MapperProperty` subclasses such as -:func:`~sqlalchemy.orm.deferred`, :func:`~sqlalchemy.orm.column_property`, -etc. ultimately involve references to columns, and therefore, when -used with declarative mixins, have the :class:`.declared_attr` -requirement so that no reliance on copying is needed:: - - class SomethingMixin(object): - - @declared_attr - def dprop(cls): - return deferred(Column(Integer)) - - class Something(SomethingMixin, Base): - __tablename__ = "something" - -The :func:`.column_property` or other construct may refer -to other columns from the mixin. These are copied ahead of time before -the :class:`.declared_attr` is invoked:: - - class SomethingMixin(object): - x = Column(Integer) - - y = Column(Integer) - - @declared_attr - def x_plus_y(cls): - return column_property(cls.x + cls.y) - - -.. versionchanged:: 1.0.0 mixin columns are copied to the final mapped class - so that :class:`.declared_attr` methods can access the actual column - that will be mapped. - -Mixing in Association Proxy and Other Attributes -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Mixins can specify user-defined attributes as well as other extension -units such as :func:`.association_proxy`. The usage of -:class:`.declared_attr` is required in those cases where the attribute must -be tailored specifically to the target subclass. An example is when -constructing multiple :func:`.association_proxy` attributes which each -target a different type of child object. Below is an -:func:`.association_proxy` / mixin example which provides a scalar list of -string values to an implementing class:: - - from sqlalchemy import Column, Integer, ForeignKey, String - from sqlalchemy.orm import relationship - from sqlalchemy.ext.associationproxy import association_proxy - from sqlalchemy.ext.declarative import declarative_base, declared_attr - - Base = declarative_base() - - class HasStringCollection(object): - @declared_attr - def _strings(cls): - class StringAttribute(Base): - __tablename__ = cls.string_table_name - id = Column(Integer, primary_key=True) - value = Column(String(50), nullable=False) - parent_id = Column(Integer, - ForeignKey('%s.id' % cls.__tablename__), - nullable=False) - def __init__(self, value): - self.value = value - - return relationship(StringAttribute) - - @declared_attr - def strings(cls): - return association_proxy('_strings', 'value') - - class TypeA(HasStringCollection, Base): - __tablename__ = 'type_a' - string_table_name = 'type_a_strings' - id = Column(Integer(), primary_key=True) - - class TypeB(HasStringCollection, Base): - __tablename__ = 'type_b' - string_table_name = 'type_b_strings' - id = Column(Integer(), primary_key=True) - -Above, the ``HasStringCollection`` mixin produces a :func:`.relationship` -which refers to a newly generated class called ``StringAttribute``. The -``StringAttribute`` class is generated with its own :class:`.Table` -definition which is local to the parent class making usage of the -``HasStringCollection`` mixin. It also produces an :func:`.association_proxy` -object which proxies references to the ``strings`` attribute onto the ``value`` -attribute of each ``StringAttribute`` instance. - -``TypeA`` or ``TypeB`` can be instantiated given the constructor -argument ``strings``, a list of strings:: - - ta = TypeA(strings=['foo', 'bar']) - tb = TypeA(strings=['bat', 'bar']) - -This list will generate a collection -of ``StringAttribute`` objects, which are persisted into a table that's -local to either the ``type_a_strings`` or ``type_b_strings`` table:: - - >>> print ta._strings - [<__main__.StringAttribute object at 0x10151cd90>, - <__main__.StringAttribute object at 0x10151ce10>] - -When constructing the :func:`.association_proxy`, the -:class:`.declared_attr` decorator must be used so that a distinct -:func:`.association_proxy` object is created for each of the ``TypeA`` -and ``TypeB`` classes. - -.. versionadded:: 0.8 :class:`.declared_attr` is usable with non-mapped - attributes, including user-defined attributes as well as - :func:`.association_proxy`. - - -Controlling table inheritance with mixins -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The ``__tablename__`` attribute may be used to provide a function that -will determine the name of the table used for each class in an inheritance -hierarchy, as well as whether a class has its own distinct table. - -This is achieved using the :class:`.declared_attr` indicator in conjunction -with a method named ``__tablename__()``. Declarative will always -invoke :class:`.declared_attr` for the special names -``__tablename__``, ``__mapper_args__`` and ``__table_args__`` -function **for each mapped class in the hierarchy**. The function therefore -needs to expect to receive each class individually and to provide the -correct answer for each. - -For example, to create a mixin that gives every class a simple table -name based on class name:: - - from sqlalchemy.ext.declarative import declared_attr - - class Tablename: - @declared_attr - def __tablename__(cls): - return cls.__name__.lower() - - class Person(Tablename, Base): - id = Column(Integer, primary_key=True) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} - - class Engineer(Person): - __tablename__ = None - __mapper_args__ = {'polymorphic_identity': 'engineer'} - primary_language = Column(String(50)) - -Alternatively, we can modify our ``__tablename__`` function to return -``None`` for subclasses, using :func:`.has_inherited_table`. This has -the effect of those subclasses being mapped with single table inheritance -agaisnt the parent:: - - from sqlalchemy.ext.declarative import declared_attr - from sqlalchemy.ext.declarative import has_inherited_table - - class Tablename(object): - @declared_attr - def __tablename__(cls): - if has_inherited_table(cls): - return None - return cls.__name__.lower() - - class Person(Tablename, Base): - id = Column(Integer, primary_key=True) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} - - class Engineer(Person): - primary_language = Column(String(50)) - __mapper_args__ = {'polymorphic_identity': 'engineer'} - -.. _mixin_inheritance_columns: - -Mixing in Columns in Inheritance Scenarios -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -In constrast to how ``__tablename__`` and other special names are handled when -used with :class:`.declared_attr`, when we mix in columns and properties (e.g. -relationships, column properties, etc.), the function is -invoked for the **base class only** in the hierarchy. Below, only the -``Person`` class will receive a column -called ``id``; the mapping will fail on ``Engineer``, which is not given -a primary key:: - - class HasId(object): - @declared_attr - def id(cls): - return Column('id', Integer, primary_key=True) - - class Person(HasId, Base): - __tablename__ = 'person' - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} - - class Engineer(Person): - __tablename__ = 'engineer' - primary_language = Column(String(50)) - __mapper_args__ = {'polymorphic_identity': 'engineer'} - -It is usually the case in joined-table inheritance that we want distinctly -named columns on each subclass. However in this case, we may want to have -an ``id`` column on every table, and have them refer to each other via -foreign key. We can achieve this as a mixin by using the -:attr:`.declared_attr.cascading` modifier, which indicates that the -function should be invoked **for each class in the hierarchy**, just like -it does for ``__tablename__``:: - - class HasId(object): - @declared_attr.cascading - def id(cls): - if has_inherited_table(cls): - return Column('id', - Integer, - ForeignKey('person.id'), primary_key=True) - else: - return Column('id', Integer, primary_key=True) - - class Person(HasId, Base): - __tablename__ = 'person' - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} - - class Engineer(Person): - __tablename__ = 'engineer' - primary_language = Column(String(50)) - __mapper_args__ = {'polymorphic_identity': 'engineer'} - - -.. versionadded:: 1.0.0 added :attr:`.declared_attr.cascading`. - -Combining Table/Mapper Arguments from Multiple Mixins -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -In the case of ``__table_args__`` or ``__mapper_args__`` -specified with declarative mixins, you may want to combine -some parameters from several mixins with those you wish to -define on the class iteself. The -:class:`.declared_attr` decorator can be used -here to create user-defined collation routines that pull -from multiple collections:: - - from sqlalchemy.ext.declarative import declared_attr - - class MySQLSettings(object): - __table_args__ = {'mysql_engine':'InnoDB'} - - class MyOtherMixin(object): - __table_args__ = {'info':'foo'} - - class MyModel(MySQLSettings, MyOtherMixin, Base): - __tablename__='my_model' - - @declared_attr - def __table_args__(cls): - args = dict() - args.update(MySQLSettings.__table_args__) - args.update(MyOtherMixin.__table_args__) - return args - - id = Column(Integer, primary_key=True) - -Creating Indexes with Mixins -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -To define a named, potentially multicolumn :class:`.Index` that applies to all -tables derived from a mixin, use the "inline" form of :class:`.Index` and -establish it as part of ``__table_args__``:: - - class MyMixin(object): - a = Column(Integer) - b = Column(Integer) - - @declared_attr - def __table_args__(cls): - return (Index('test_idx_%s' % cls.__tablename__, 'a', 'b'),) - - class MyModel(MyMixin, Base): - __tablename__ = 'atable' - c = Column(Integer,primary_key=True) - -Special Directives -================== - -``__declare_last__()`` -~~~~~~~~~~~~~~~~~~~~~~ - -The ``__declare_last__()`` hook allows definition of -a class level function that is automatically called by the -:meth:`.MapperEvents.after_configured` event, which occurs after mappings are -assumed to be completed and the 'configure' step has finished:: - - class MyClass(Base): - @classmethod - def __declare_last__(cls): - "" - # do something with mappings - -.. versionadded:: 0.7.3 - -``__declare_first__()`` -~~~~~~~~~~~~~~~~~~~~~~~ - -Like ``__declare_last__()``, but is called at the beginning of mapper -configuration via the :meth:`.MapperEvents.before_configured` event:: - - class MyClass(Base): - @classmethod - def __declare_first__(cls): - "" - # do something before mappings are configured - -.. versionadded:: 0.9.3 - -.. _declarative_abstract: - -``__abstract__`` -~~~~~~~~~~~~~~~~~~~ - -``__abstract__`` causes declarative to skip the production -of a table or mapper for the class entirely. A class can be added within a -hierarchy in the same way as mixin (see :ref:`declarative_mixins`), allowing -subclasses to extend just from the special class:: - - class SomeAbstractBase(Base): - __abstract__ = True - - def some_helpful_method(self): - "" - - @declared_attr - def __mapper_args__(cls): - return {"helpful mapper arguments":True} - - class MyMappedClass(SomeAbstractBase): - "" - -One possible use of ``__abstract__`` is to use a distinct -:class:`.MetaData` for different bases:: - - Base = declarative_base() - - class DefaultBase(Base): - __abstract__ = True - metadata = MetaData() - - class OtherBase(Base): - __abstract__ = True - metadata = MetaData() - -Above, classes which inherit from ``DefaultBase`` will use one -:class:`.MetaData` as the registry of tables, and those which inherit from -``OtherBase`` will use a different one. The tables themselves can then be -created perhaps within distinct databases:: - - DefaultBase.metadata.create_all(some_engine) - OtherBase.metadata_create_all(some_other_engine) - -.. versionadded:: 0.7.3 - -Class Constructor -================= - -As a convenience feature, the :func:`declarative_base` sets a default -constructor on classes which takes keyword arguments, and assigns them -to the named attributes:: - - e = Engineer(primary_language='python') - -Sessions -======== - -Note that ``declarative`` does nothing special with sessions, and is -only intended as an easier way to configure mappers and -:class:`~sqlalchemy.schema.Table` objects. A typical application -setup using :class:`~sqlalchemy.orm.scoping.scoped_session` might look like:: - - engine = create_engine('postgresql://scott:tiger@localhost/test') - Session = scoped_session(sessionmaker(autocommit=False, - autoflush=False, - bind=engine)) - Base = declarative_base() - -Mapped instances then make usage of -:class:`~sqlalchemy.orm.session.Session` in the usual way. - -""" - from .api import declarative_base, synonym_for, comparable_using, \ instrument_declarative, ConcreteBase, AbstractConcreteBase, \ DeclarativeMeta, DeferredReflection, has_inherited_table,\ diff --git a/lib/sqlalchemy/ext/declarative/base.py b/lib/sqlalchemy/ext/declarative/base.py index 291608b6c..6735abf4c 100644 --- a/lib/sqlalchemy/ext/declarative/base.py +++ b/lib/sqlalchemy/ext/declarative/base.py @@ -278,7 +278,7 @@ class _MapperConfig(object): elif not isinstance(value, (Column, MapperProperty)): # using @declared_attr for some object that # isn't Column/MapperProperty; remove from the dict_ - # and place the evaulated value onto the class. + # and place the evaluated value onto the class. if not k.startswith('__'): dict_.pop(k) setattr(cls, k, value) diff --git a/lib/sqlalchemy/ext/declarative/clsregistry.py b/lib/sqlalchemy/ext/declarative/clsregistry.py index 3ef63a5ae..d2a09d823 100644 --- a/lib/sqlalchemy/ext/declarative/clsregistry.py +++ b/lib/sqlalchemy/ext/declarative/clsregistry.py @@ -71,6 +71,8 @@ class _MultipleClassMarker(object): """ + __slots__ = 'on_remove', 'contents', '__weakref__' + def __init__(self, classes, on_remove=None): self.on_remove = on_remove self.contents = set([ @@ -127,6 +129,8 @@ class _ModuleMarker(object): """ + __slots__ = 'parent', 'name', 'contents', 'mod_ns', 'path', '__weakref__' + def __init__(self, name, parent): self.parent = parent self.name = name @@ -172,6 +176,8 @@ class _ModuleMarker(object): class _ModNS(object): + __slots__ = '__parent', + def __init__(self, parent): self.__parent = parent @@ -193,6 +199,8 @@ class _ModNS(object): class _GetColumns(object): + __slots__ = 'cls', + def __init__(self, cls): self.cls = cls @@ -221,6 +229,8 @@ inspection._inspects(_GetColumns)( class _GetTable(object): + __slots__ = 'key', 'metadata' + def __init__(self, key, metadata): self.key = key self.metadata = metadata diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index e2739d1de..f72de6099 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -145,7 +145,7 @@ usage of the absolute value function:: return func.abs(cls.length) / 2 Above the Python function ``abs()`` is used for instance-level -operations, the SQL function ``ABS()`` is used via the :attr:`.func` +operations, the SQL function ``ABS()`` is used via the :data:`.func` object for class-level expressions:: >>> i1.radius @@ -660,7 +660,7 @@ HYBRID_PROPERTY = util.symbol('HYBRID_PROPERTY') """ -class hybrid_method(interfaces.InspectionAttr): +class hybrid_method(interfaces.InspectionAttrInfo): """A decorator which allows definition of a Python object method with both instance-level and class-level behavior. @@ -703,7 +703,7 @@ class hybrid_method(interfaces.InspectionAttr): return self -class hybrid_property(interfaces.InspectionAttr): +class hybrid_property(interfaces.InspectionAttrInfo): """A decorator which allows definition of a Python descriptor with both instance-level and class-level behavior. diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 2b4c3ec75..e9c8c511a 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -345,18 +345,16 @@ class Event(object): .. versionadded:: 0.9.0 - """ - - impl = None - """The :class:`.AttributeImpl` which is the current event initiator. - """ + :var impl: The :class:`.AttributeImpl` which is the current event + initiator. - op = None - """The symbol :attr:`.OP_APPEND`, :attr:`.OP_REMOVE` or :attr:`.OP_REPLACE`, - indicating the source operation. + :var op: The symbol :attr:`.OP_APPEND`, :attr:`.OP_REMOVE` or + :attr:`.OP_REPLACE`, indicating the source operation. """ + __slots__ = 'impl', 'op', 'parent_token' + def __init__(self, attribute_impl, op): self.impl = attribute_impl self.op = op @@ -455,6 +453,11 @@ class AttributeImpl(object): self.expire_missing = expire_missing + __slots__ = ( + 'class_', 'key', 'callable_', 'dispatch', 'trackparent', + 'parent_token', 'send_modified_events', 'is_equal', 'expire_missing' + ) + def __str__(self): return "%s.%s" % (self.class_.__name__, self.key) @@ -654,6 +657,23 @@ class ScalarAttributeImpl(AttributeImpl): supports_population = True collection = False + __slots__ = '_replace_token', '_append_token', '_remove_token' + + def __init__(self, *arg, **kw): + super(ScalarAttributeImpl, self).__init__(*arg, **kw) + self._replace_token = self._append_token = None + self._remove_token = None + + def _init_append_token(self): + self._replace_token = self._append_token = Event(self, OP_REPLACE) + return self._replace_token + + _init_append_or_replace_token = _init_append_token + + def _init_remove_token(self): + self._remove_token = Event(self, OP_REMOVE) + return self._remove_token + def delete(self, state, dict_): # TODO: catch key errors, convert to attributeerror? @@ -692,27 +712,18 @@ class ScalarAttributeImpl(AttributeImpl): state._modified_event(dict_, self, old) dict_[self.key] = value - @util.memoized_property - def _replace_token(self): - return Event(self, OP_REPLACE) - - @util.memoized_property - def _append_token(self): - return Event(self, OP_REPLACE) - - @util.memoized_property - def _remove_token(self): - return Event(self, OP_REMOVE) - def fire_replace_event(self, state, dict_, value, previous, initiator): for fn in self.dispatch.set: value = fn( - state, value, previous, initiator or self._replace_token) + state, value, previous, + initiator or self._replace_token or + self._init_append_or_replace_token()) return value def fire_remove_event(self, state, dict_, value, initiator): for fn in self.dispatch.remove: - fn(state, value, initiator or self._remove_token) + fn(state, value, + initiator or self._remove_token or self._init_remove_token()) @property def type(self): @@ -732,9 +743,13 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): supports_population = True collection = False + __slots__ = () + def delete(self, state, dict_): old = self.get(state, dict_) - self.fire_remove_event(state, dict_, old, self._remove_token) + self.fire_remove_event( + state, dict_, old, + self._remove_token or self._init_remove_token()) del dict_[self.key] def get_history(self, state, dict_, passive=PASSIVE_OFF): @@ -807,7 +822,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): self.sethasparent(instance_state(value), state, False) for fn in self.dispatch.remove: - fn(state, value, initiator or self._remove_token) + fn(state, value, initiator or + self._remove_token or self._init_remove_token()) state._modified_event(dict_, self, value) @@ -819,7 +835,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): for fn in self.dispatch.set: value = fn( - state, value, previous, initiator or self._replace_token) + state, value, previous, initiator or + self._replace_token or self._init_append_or_replace_token()) state._modified_event(dict_, self, previous) @@ -846,6 +863,8 @@ class CollectionAttributeImpl(AttributeImpl): supports_population = True collection = True + __slots__ = 'copy', 'collection_factory', '_append_token', '_remove_token' + def __init__(self, class_, key, callable_, dispatch, typecallable=None, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs): @@ -862,6 +881,8 @@ class CollectionAttributeImpl(AttributeImpl): copy_function = self.__copy self.copy = copy_function self.collection_factory = typecallable + self._append_token = None + self._remove_token = None if getattr(self.collection_factory, "_sa_linker", None): @@ -873,6 +894,14 @@ class CollectionAttributeImpl(AttributeImpl): def unlink(target, collection, collection_adapter): collection._sa_linker(None) + def _init_append_token(self): + self._append_token = Event(self, OP_APPEND) + return self._append_token + + def _init_remove_token(self): + self._remove_token = Event(self, OP_REMOVE) + return self._remove_token + def __copy(self, item): return [y for y in collections.collection_adapter(item)] @@ -915,17 +944,11 @@ class CollectionAttributeImpl(AttributeImpl): return [(instance_state(o), o) for o in current] - @util.memoized_property - def _append_token(self): - return Event(self, OP_APPEND) - - @util.memoized_property - def _remove_token(self): - return Event(self, OP_REMOVE) - def fire_append_event(self, state, dict_, value, initiator): for fn in self.dispatch.append: - value = fn(state, value, initiator or self._append_token) + value = fn( + state, value, + initiator or self._append_token or self._init_append_token()) state._modified_event(dict_, self, NEVER_SET, True) @@ -942,7 +965,8 @@ class CollectionAttributeImpl(AttributeImpl): self.sethasparent(instance_state(value), state, False) for fn in self.dispatch.remove: - fn(state, value, initiator or self._remove_token) + fn(state, value, + initiator or self._remove_token or self._init_remove_token()) state._modified_event(dict_, self, NEVER_SET, True) @@ -1134,7 +1158,8 @@ def backref_listeners(attribute, key, uselist): impl.pop(old_state, old_dict, state.obj(), - parent_impl._append_token, + parent_impl._append_token or + parent_impl._init_append_token(), passive=PASSIVE_NO_FETCH) if child is not None: diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 3390ceec4..7bfafdc2b 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -437,6 +437,7 @@ class InspectionAttr(object): here intact for forwards-compatibility. """ + __slots__ = () is_selectable = False """Return True if this object is an instance of :class:`.Selectable`.""" @@ -488,6 +489,16 @@ class InspectionAttr(object): """ + +class InspectionAttrInfo(InspectionAttr): + """Adds the ``.info`` attribute to :class:`.InspectionAttr`. + + The rationale for :class:`.InspectionAttr` vs. :class:`.InspectionAttrInfo` + is that the former is compatible as a mixin for classes that specify + ``__slots__``; this is essentially an implementation artifact. + + """ + @util.memoized_property def info(self): """Info dictionary associated with the object, allowing user-defined @@ -501,9 +512,10 @@ class InspectionAttr(object): .. versionadded:: 0.8 Added support for .info to all :class:`.MapperProperty` subclasses. - .. versionchanged:: 1.0.0 :attr:`.InspectionAttr.info` moved - from :class:`.MapperProperty` so that it can apply to a wider - variety of ORM and extension constructs. + .. versionchanged:: 1.0.0 :attr:`.MapperProperty.info` is also + available on extension types via the + :attr:`.InspectionAttrInfo.info` attribute, so that it can apply + to a wider variety of ORM and extension constructs. .. seealso:: @@ -520,3 +532,4 @@ class _MappedAttribute(object): attributes. """ + __slots__ = () diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 19ff71f73..e68ff1bea 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -143,6 +143,7 @@ class CompositeProperty(DescriptorProperty): class. **Deprecated.** Please see :class:`.AttributeEvents`. """ + super(CompositeProperty, self).__init__() self.attrs = attrs self.composite_class = class_ @@ -471,6 +472,7 @@ class ConcreteInheritedProperty(DescriptorProperty): return comparator_callable def __init__(self): + super(ConcreteInheritedProperty, self).__init__() def warn(): raise AttributeError("Concrete %s does not implement " "attribute %r at the instance level. Add " @@ -555,6 +557,7 @@ class SynonymProperty(DescriptorProperty): more complicated attribute-wrapping schemes than synonyms. """ + super(SynonymProperty, self).__init__() self.name = name self.map_column = map_column @@ -684,6 +687,7 @@ class ComparableProperty(DescriptorProperty): .. versionadded:: 1.0.0 """ + super(ComparableProperty, self).__init__() self.descriptor = descriptor self.comparator_factory = comparator_factory self.doc = doc or (descriptor and descriptor.__doc__) or None diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index 9ea0dd834..4d888a350 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -1479,8 +1479,9 @@ class AttributeEvents(event.Events): @staticmethod def _set_dispatch(cls, dispatch_cls): - event.Events._set_dispatch(cls, dispatch_cls) + dispatch = event.Events._set_dispatch(cls, dispatch_cls) dispatch_cls._active_history = False + return dispatch @classmethod def _accept_with(cls, target): diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index ad2452c1b..299ccaaaf 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -24,7 +24,8 @@ from .. import util from ..sql import operators from .base import (ONETOMANY, MANYTOONE, MANYTOMANY, EXT_CONTINUE, EXT_STOP, NOT_EXTENSION) -from .base import InspectionAttr, _MappedAttribute +from .base import (InspectionAttr, InspectionAttr, + InspectionAttrInfo, _MappedAttribute) import collections # imported later @@ -48,11 +49,8 @@ __all__ = ( ) -class MapperProperty(_MappedAttribute, InspectionAttr): - """Manage the relationship of a ``Mapper`` to a single class - attribute, as well as that attribute as it appears on individual - instances of the class, including attribute instrumentation, - attribute access, loading behavior, and dependency calculations. +class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots): + """Represent a particular class attribute mapped by :class:`.Mapper`. The most common occurrences of :class:`.MapperProperty` are the mapped :class:`.Column`, which is represented in a mapping as @@ -63,6 +61,11 @@ class MapperProperty(_MappedAttribute, InspectionAttr): """ + __slots__ = ( + '_configure_started', '_configure_finished', 'parent', 'key', + 'info' + ) + cascade = frozenset() """The set of 'cascade' attribute names. @@ -78,6 +81,32 @@ class MapperProperty(_MappedAttribute, InspectionAttr): """ + def _memoized_attr_info(self): + """Info dictionary associated with the object, allowing user-defined + data to be associated with this :class:`.InspectionAttr`. + + The dictionary is generated when first accessed. Alternatively, + it can be specified as a constructor argument to the + :func:`.column_property`, :func:`.relationship`, or :func:`.composite` + functions. + + .. versionadded:: 0.8 Added support for .info to all + :class:`.MapperProperty` subclasses. + + .. versionchanged:: 1.0.0 :attr:`.MapperProperty.info` is also + available on extension types via the + :attr:`.InspectionAttrInfo.info` attribute, so that it can apply + to a wider variety of ORM and extension constructs. + + .. seealso:: + + :attr:`.QueryableAttribute.info` + + :attr:`.SchemaItem.info` + + """ + return {} + def setup(self, context, entity, path, adapter, **kwargs): """Called by Query for the purposes of constructing a SQL statement. @@ -139,8 +168,9 @@ class MapperProperty(_MappedAttribute, InspectionAttr): """ - _configure_started = False - _configure_finished = False + def __init__(self): + self._configure_started = False + self._configure_finished = False def init(self): """Called after all mappers are created to assemble @@ -303,6 +333,8 @@ class PropComparator(operators.ColumnOperators): """ + __slots__ = 'prop', 'property', '_parentmapper', '_adapt_to_entity' + def __init__(self, prop, parentmapper, adapt_to_entity=None): self.prop = self.property = prop self._parentmapper = parentmapper @@ -331,7 +363,7 @@ class PropComparator(operators.ColumnOperators): else: return self._adapt_to_entity._adapt_element - @util.memoized_property + @property def info(self): return self.property.info @@ -420,6 +452,8 @@ class StrategizedProperty(MapperProperty): """ + __slots__ = '_strategies', 'strategy' + strategy_wildcard_key = None def _get_context_loader(self, context, path): @@ -483,14 +517,14 @@ class StrategizedProperty(MapperProperty): not mapper.class_manager._attr_has_impl(self.key): self.strategy.init_class_attribute(mapper) - _strategies = collections.defaultdict(dict) + _all_strategies = collections.defaultdict(dict) @classmethod def strategy_for(cls, **kw): def decorate(dec_cls): dec_cls._strategy_keys = [] key = tuple(sorted(kw.items())) - cls._strategies[cls][key] = dec_cls + cls._all_strategies[cls][key] = dec_cls dec_cls._strategy_keys.append(key) return dec_cls return decorate @@ -498,8 +532,8 @@ class StrategizedProperty(MapperProperty): @classmethod def _strategy_lookup(cls, *key): for prop_cls in cls.__mro__: - if prop_cls in cls._strategies: - strategies = cls._strategies[prop_cls] + if prop_cls in cls._all_strategies: + strategies = cls._all_strategies[prop_cls] try: return strategies[key] except KeyError: @@ -558,6 +592,8 @@ class LoaderStrategy(object): """ + __slots__ = 'parent_property', 'is_class_level', 'parent', 'key' + def __init__(self, parent): self.parent_property = parent self.is_class_level = False diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 380afcdc7..fdc787545 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -42,41 +42,45 @@ def instances(query, cursor, context): def filter_fn(row): return tuple(fn(x) for x, fn in zip(row, filter_fns)) - (process, labels) = \ - list(zip(*[ - query_entity.row_processor(query, - context, cursor) - for query_entity in query._entities - ])) - - if not single_entity: - keyed_tuple = util.lightweight_named_tuple('result', labels) - - while True: - context.partials = {} - - if query._yield_per: - fetch = cursor.fetchmany(query._yield_per) - if not fetch: - break - else: - fetch = cursor.fetchall() + try: + (process, labels) = \ + list(zip(*[ + query_entity.row_processor(query, + context, cursor) + for query_entity in query._entities + ])) + + if not single_entity: + keyed_tuple = util.lightweight_named_tuple('result', labels) + + while True: + context.partials = {} + + if query._yield_per: + fetch = cursor.fetchmany(query._yield_per) + if not fetch: + break + else: + fetch = cursor.fetchall() - if single_entity: - proc = process[0] - rows = [proc(row) for row in fetch] - else: - rows = [keyed_tuple([proc(row) for proc in process]) - for row in fetch] + if single_entity: + proc = process[0] + rows = [proc(row) for row in fetch] + else: + rows = [keyed_tuple([proc(row) for proc in process]) + for row in fetch] - if filtered: - rows = util.unique_list(rows, filter_fn) + if filtered: + rows = util.unique_list(rows, filter_fn) - for row in rows: - yield row + for row in rows: + yield row - if not query._yield_per: - break + if not query._yield_per: + break + except Exception as err: + cursor.close() + util.raise_from_cause(err) @util.dependencies("sqlalchemy.orm.query") diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 863dab5cb..eb5abbd4f 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -974,6 +974,15 @@ class Mapper(InspectionAttr): self._all_tables = self.inherits._all_tables if self.polymorphic_identity is not None: + if self.polymorphic_identity in self.polymorphic_map: + util.warn( + "Reassigning polymorphic association for identity %r " + "from %r to %r: Check for duplicate use of %r as " + "value for polymorphic_identity." % + (self.polymorphic_identity, + self.polymorphic_map[self.polymorphic_identity], + self, self.polymorphic_identity) + ) self.polymorphic_map[self.polymorphic_identity] = self else: @@ -1248,7 +1257,7 @@ class Mapper(InspectionAttr): self._readonly_props = set( self._columntoproperty[col] for col in self._columntoproperty - if self._columntoproperty[col] not in self._primary_key_props and + if self._columntoproperty[col] not in self._identity_key_props and (not hasattr(col, 'table') or col.table not in self._cols_by_table)) @@ -2373,16 +2382,31 @@ class Mapper(InspectionAttr): manager[prop.key]. impl.get(state, dict_, attributes.PASSIVE_RETURN_NEVER_SET) - for prop in self._primary_key_props + for prop in self._identity_key_props ] @_memoized_configured_property - def _primary_key_props(self): - # TODO: this should really be called "identity key props", - # as it does not necessarily include primary key columns within - # individual tables + def _identity_key_props(self): return [self._columntoproperty[col] for col in self.primary_key] + @_memoized_configured_property + def _all_pk_props(self): + collection = set() + for table in self.tables: + collection.update(self._pks_by_table[table]) + return collection + + @_memoized_configured_property + def _should_undefer_in_wildcard(self): + cols = set(self.primary_key) + if self.polymorphic_on is not None: + cols.add(self.polymorphic_on) + return cols + + @_memoized_configured_property + def _primary_key_propkeys(self): + return set([prop.key for prop in self._all_pk_props]) + def _get_state_attr_by_column( self, state, dict_, column, passive=attributes.PASSIVE_RETURN_NEVER_SET): @@ -2635,7 +2659,7 @@ def configure_mappers(): if not Mapper._new_mappers: return - Mapper.dispatch(Mapper).before_configured() + Mapper.dispatch._for_class(Mapper).before_configured() # initialize properties on all mappers # note that _mapper_registry is unordered, which # may randomly conceal/reveal issues related to @@ -2667,7 +2691,7 @@ def configure_mappers(): _already_compiling = False finally: _CONFIGURE_MUTEX.release() - Mapper.dispatch(Mapper).after_configured() + Mapper.dispatch._for_class(Mapper).after_configured() def reconstructor(fn): @@ -2779,6 +2803,8 @@ def _event_on_init(state, args, kwargs): class _ColumnMapping(dict): """Error reporting helper for mapper._columntoproperty.""" + __slots__ = 'mapper', + def __init__(self, mapper): self.mapper = mapper diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py index d4dbf29a0..ec80c70cc 100644 --- a/lib/sqlalchemy/orm/path_registry.py +++ b/lib/sqlalchemy/orm/path_registry.py @@ -52,6 +52,9 @@ class PathRegistry(object): """ + is_token = False + is_root = False + def __eq__(self, other): return other is not None and \ self.path == other.path @@ -153,6 +156,8 @@ class RootRegistry(PathRegistry): """ path = () has_entity = False + is_aliased_class = False + is_root = True def __getitem__(self, entity): return entity._path_registry @@ -168,6 +173,15 @@ class TokenRegistry(PathRegistry): has_entity = False + is_token = True + + def generate_for_superclasses(self): + if not self.parent.is_aliased_class and not self.parent.is_root: + for ent in self.parent.mapper.iterate_to_root(): + yield TokenRegistry(self.parent.parent[ent], self.token) + else: + yield self + def __getitem__(self, entity): raise NotImplementedError() diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 6b8d5af14..c3b2d7bcb 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -15,7 +15,7 @@ in unitofwork.py. """ import operator -from itertools import groupby +from itertools import groupby, chain from .. import sql, util, exc as sa_exc, schema from . import attributes, sync, exc as orm_exc, evaluator from .base import state_str, _attr_as_key, _entity_descriptor @@ -23,7 +23,105 @@ from ..sql import expression from . import loading -def save_obj(base_mapper, states, uowtransaction, single=False): +def _bulk_insert( + mapper, mappings, session_transaction, isstates, return_defaults): + base_mapper = mapper.base_mapper + + cached_connections = _cached_connection_dict(base_mapper) + + if session_transaction.session.connection_callable: + raise NotImplementedError( + "connection_callable / per-instance sharding " + "not supported in bulk_insert()") + + if isstates: + if return_defaults: + states = [(state, state.dict) for state in mappings] + mappings = [dict_ for (state, dict_) in states] + else: + mappings = [state.dict for state in mappings] + else: + mappings = list(mappings) + + connection = session_transaction.connection(base_mapper) + for table, super_mapper in base_mapper._sorted_tables.items(): + if not mapper.isa(super_mapper): + continue + + records = ( + (None, state_dict, params, mapper, + connection, value_params, has_all_pks, has_all_defaults) + for + state, state_dict, params, mp, + conn, value_params, has_all_pks, + has_all_defaults in _collect_insert_commands(table, ( + (None, mapping, mapper, connection) + for mapping in mappings), + bulk=True, return_defaults=return_defaults + ) + ) + _emit_insert_statements(base_mapper, None, + cached_connections, + super_mapper, table, records, + bookkeeping=return_defaults) + + if return_defaults and isstates: + identity_cls = mapper._identity_class + identity_props = [p.key for p in mapper._identity_key_props] + for state, dict_ in states: + state.key = ( + identity_cls, + tuple([dict_[key] for key in identity_props]) + ) + + +def _bulk_update(mapper, mappings, session_transaction, + isstates, update_changed_only): + base_mapper = mapper.base_mapper + + cached_connections = _cached_connection_dict(base_mapper) + + def _changed_dict(mapper, state): + return dict( + (k, v) + for k, v in state.dict.items() if k in state.committed_state or k + in mapper._primary_key_propkeys + ) + + if isstates: + if update_changed_only: + mappings = [_changed_dict(mapper, state) for state in mappings] + else: + mappings = [state.dict for state in mappings] + else: + mappings = list(mappings) + + if session_transaction.session.connection_callable: + raise NotImplementedError( + "connection_callable / per-instance sharding " + "not supported in bulk_update()") + + connection = session_transaction.connection(base_mapper) + + for table, super_mapper in base_mapper._sorted_tables.items(): + if not mapper.isa(super_mapper): + continue + + records = _collect_update_commands(None, table, ( + (None, mapping, mapper, connection, + (mapping[mapper._version_id_prop.key] + if mapper._version_id_prop else None)) + for mapping in mappings + ), bulk=True) + + _emit_update_statements(base_mapper, None, + cached_connections, + super_mapper, table, records, + bookkeeping=False) + + +def save_obj( + base_mapper, states, uowtransaction, single=False): """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects. @@ -76,17 +174,16 @@ def save_obj(base_mapper, states, uowtransaction, single=False): _finalize_insert_update_commands( base_mapper, uowtransaction, - ( - (state, state_dict, mapper, connection, False) - for state, state_dict, mapper, connection in states_to_insert - ) - ) - _finalize_insert_update_commands( - base_mapper, uowtransaction, - ( - (state, state_dict, mapper, connection, True) - for state, state_dict, mapper, connection, - update_version_id in states_to_update + chain( + ( + (state, state_dict, mapper, connection, False) + for state, state_dict, mapper, connection in states_to_insert + ), + ( + (state, state_dict, mapper, connection, True) + for state, state_dict, mapper, connection, + update_version_id in states_to_update + ) ) ) @@ -261,7 +358,9 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction): state, dict_, mapper, connection, update_version_id) -def _collect_insert_commands(table, states_to_insert): +def _collect_insert_commands( + table, states_to_insert, + bulk=False, return_defaults=False): """Identify sets of values to use in INSERT statements for a list of states. @@ -280,22 +379,26 @@ def _collect_insert_commands(table, states_to_insert): col = propkey_to_col[propkey] if value is None: continue - elif isinstance(value, sql.ClauseElement): + elif not bulk and isinstance(value, sql.ClauseElement): value_params[col.key] = value else: params[col.key] = value - for colkey in mapper._insert_cols_as_none[table].\ - difference(params).difference(value_params): - params[colkey] = None + if not bulk: + for colkey in mapper._insert_cols_as_none[table].\ + difference(params).difference(value_params): + params[colkey] = None - has_all_pks = mapper._pk_keys_by_table[table].issubset(params) + if not bulk or return_defaults: + has_all_pks = mapper._pk_keys_by_table[table].issubset(params) - if mapper.base_mapper.eager_defaults: - has_all_defaults = mapper._server_default_cols[table].\ - issubset(params) + if mapper.base_mapper.eager_defaults: + has_all_defaults = mapper._server_default_cols[table].\ + issubset(params) + else: + has_all_defaults = True else: - has_all_defaults = True + has_all_defaults = has_all_pks = True if mapper.version_id_generator is not False \ and mapper.version_id_col is not None and \ @@ -309,7 +412,9 @@ def _collect_insert_commands(table, states_to_insert): has_all_defaults) -def _collect_update_commands(uowtransaction, table, states_to_update): +def _collect_update_commands( + uowtransaction, table, states_to_update, + bulk=False): """Identify sets of values to use in UPDATE statements for a list of states. @@ -329,23 +434,32 @@ def _collect_update_commands(uowtransaction, table, states_to_update): pks = mapper._pks_by_table[table] - params = {} value_params = {} propkey_to_col = mapper._propkey_to_col[table] - for propkey in set(propkey_to_col).intersection(state.committed_state): - value = state_dict[propkey] - col = propkey_to_col[propkey] - - if not state.manager[propkey].impl.is_equal( - value, state.committed_state[propkey]): - if isinstance(value, sql.ClauseElement): - value_params[col] = value - else: - params[col.key] = value + if bulk: + params = dict( + (propkey_to_col[propkey].key, state_dict[propkey]) + for propkey in + set(propkey_to_col).intersection(state_dict) + ) + else: + params = {} + for propkey in set(propkey_to_col).intersection( + state.committed_state): + value = state_dict[propkey] + col = propkey_to_col[propkey] + + if not state.manager[propkey].impl.is_equal( + value, state.committed_state[propkey]): + if isinstance(value, sql.ClauseElement): + value_params[col] = value + else: + params[col.key] = value - if update_version_id is not None: + if update_version_id is not None and \ + mapper.version_id_col in mapper._cols_by_table[table]: col = mapper.version_id_col params[col._label] = update_version_id @@ -357,28 +471,37 @@ def _collect_update_commands(uowtransaction, table, states_to_update): if not (params or value_params): continue - pk_params = {} - for col in pks: - propkey = mapper._columntoproperty[col].key - history = state.manager[propkey].impl.get_history( - state, state_dict, attributes.PASSIVE_OFF) - - if history.added: - if not history.deleted or \ - ("pk_cascaded", state, col) in \ - uowtransaction.attributes: - pk_params[col._label] = history.added[0] - params.pop(col.key, None) + if bulk: + pk_params = dict( + (propkey_to_col[propkey]._label, state_dict.get(propkey)) + for propkey in + set(propkey_to_col). + intersection(mapper._pk_keys_by_table[table]) + ) + else: + pk_params = {} + for col in pks: + propkey = mapper._columntoproperty[col].key + + history = state.manager[propkey].impl.get_history( + state, state_dict, attributes.PASSIVE_OFF) + + if history.added: + if not history.deleted or \ + ("pk_cascaded", state, col) in \ + uowtransaction.attributes: + pk_params[col._label] = history.added[0] + params.pop(col.key, None) + else: + # else, use the old value to locate the row + pk_params[col._label] = history.deleted[0] + params[col.key] = history.added[0] else: - # else, use the old value to locate the row - pk_params[col._label] = history.deleted[0] - params[col.key] = history.added[0] - else: - pk_params[col._label] = history.unchanged[0] - if pk_params[col._label] is None: - raise orm_exc.FlushError( - "Can't update table %s using NULL for primary " - "key value on column %s" % (table, col)) + pk_params[col._label] = history.unchanged[0] + if pk_params[col._label] is None: + raise orm_exc.FlushError( + "Can't update table %s using NULL for primary " + "key value on column %s" % (table, col)) if params or value_params: params.update(pk_params) @@ -446,18 +569,19 @@ def _collect_delete_commands(base_mapper, uowtransaction, table, "key value on column %s" % (table, col)) if update_version_id is not None and \ - table.c.contains_column(mapper.version_id_col): + mapper.version_id_col in mapper._cols_by_table[table]: params[mapper.version_id_col.key] = update_version_id yield params, connection def _emit_update_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, update): + cached_connections, mapper, table, update, + bookkeeping=True): """Emit UPDATE statements corresponding to value lists collected by _collect_update_commands().""" needs_version_id = mapper.version_id_col is not None and \ - table.c.contains_column(mapper.version_id_col) + mapper.version_id_col in mapper._cols_by_table[table] def update_stmt(): clause = sql.and_() @@ -486,32 +610,42 @@ def _emit_update_statements(base_mapper, uowtransaction, records in groupby( update, lambda rec: ( - rec[4], - tuple(sorted(rec[2])), - bool(rec[5]))): + rec[4], # connection + set(rec[2]), # set of parameter keys + bool(rec[5]))): # whether or not we have "value" parameters rows = 0 records = list(records) + + # TODO: would be super-nice to not have to determine this boolean + # inside the loop here, in the 99.9999% of the time there's only + # one connection in use + assert_singlerow = connection.dialect.supports_sane_rowcount + assert_multirow = assert_singlerow and \ + connection.dialect.supports_sane_multi_rowcount + allow_multirow = not needs_version_id or assert_multirow + if hasvalue: for state, state_dict, params, mapper, \ connection, value_params in records: c = connection.execute( statement.values(value_params), params) - _postfetch( - mapper, - uowtransaction, - table, - state, - state_dict, - c, - c.context.compiled_parameters[0], - value_params) + if bookkeeping: + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + value_params) rows += c.rowcount + check_rowcount = True else: - if needs_version_id and \ - not connection.dialect.supports_sane_multi_rowcount and \ - connection.dialect.supports_sane_rowcount: + if not allow_multirow: + check_rowcount = assert_singlerow for state, state_dict, params, mapper, \ connection, value_params in records: c = cached_connections[connection].\ @@ -528,6 +662,12 @@ def _emit_update_statements(base_mapper, uowtransaction, rows += c.rowcount else: multiparams = [rec[2] for rec in records] + + check_rowcount = assert_multirow or ( + assert_singlerow and + len(multiparams) == 1 + ) + c = cached_connections[connection].\ execute(statement, multiparams) @@ -544,7 +684,7 @@ def _emit_update_statements(base_mapper, uowtransaction, c.context.compiled_parameters[0], value_params) - if connection.dialect.supports_sane_rowcount: + if check_rowcount: if rows != len(records): raise orm_exc.StaleDataError( "UPDATE statement on table '%s' expected to " @@ -558,20 +698,23 @@ def _emit_update_statements(base_mapper, uowtransaction, def _emit_insert_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, insert): + cached_connections, mapper, table, insert, + bookkeeping=True): """Emit INSERT statements corresponding to value lists collected by _collect_insert_commands().""" statement = base_mapper._memo(('insert', table), table.insert) for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \ - records in groupby(insert, - lambda rec: (rec[4], - tuple(sorted(rec[2].keys())), - bool(rec[5]), - rec[6], rec[7]) - ): - if \ + records in groupby( + insert, + lambda rec: ( + rec[4], # connection + set(rec[2]), # parameter keys + bool(rec[5]), # whether we have "value" parameters + rec[6], + rec[7])): + if not bookkeeping or \ ( has_all_defaults or not base_mapper.eager_defaults @@ -584,19 +727,20 @@ def _emit_insert_statements(base_mapper, uowtransaction, c = cached_connections[connection].\ execute(statement, multiparams) - for (state, state_dict, params, mapper_rec, - conn, value_params, has_all_pks, has_all_defaults), \ - last_inserted_params in \ - zip(records, c.context.compiled_parameters): - _postfetch( - mapper_rec, - uowtransaction, - table, - state, - state_dict, - c, - last_inserted_params, - value_params) + if bookkeeping: + for (state, state_dict, params, mapper_rec, + conn, value_params, has_all_pks, has_all_defaults), \ + last_inserted_params in \ + zip(records, c.context.compiled_parameters): + _postfetch( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + c, + last_inserted_params, + value_params) else: if not has_all_defaults and base_mapper.eager_defaults: @@ -657,7 +801,10 @@ def _emit_post_update_statements(base_mapper, uowtransaction, # also group them into common (connection, cols) sets # to support executemany(). for key, grouper in groupby( - update, lambda rec: (rec[1], sorted(rec[0])) + update, lambda rec: ( + rec[1], # connection + set(rec[0]) # parameter keys + ) ): connection = key[0] multiparams = [params for params, conn in grouper] @@ -671,7 +818,7 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, by _collect_delete_commands().""" need_version_id = mapper.version_id_col is not None and \ - table.c.contains_column(mapper.version_id_col) + mapper.version_id_col in mapper._cols_by_table[table] def delete_stmt(): clause = sql.and_() @@ -693,12 +840,9 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, statement = base_mapper._memo(('delete', table), delete_stmt) for connection, recs in groupby( delete, - lambda rec: rec[1] + lambda rec: rec[1] # connection ): - del_objects = [ - params - for params, connection in recs - ] + del_objects = [params for params, connection in recs] connection = cached_connections[connection] @@ -775,9 +919,8 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): toload_now.extend(state._unloaded_non_object) elif mapper.version_id_col is not None and \ mapper.version_id_generator is False: - prop = mapper._columntoproperty[mapper.version_id_col] - if prop.key in state.unloaded: - toload_now.extend([prop.key]) + if mapper._version_id_prop.key in state.unloaded: + toload_now.extend([mapper._version_id_prop.key]) if toload_now: state.key = base_mapper._identity_key_from_state(state) @@ -794,7 +937,7 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): def _postfetch(mapper, uowtransaction, table, - state, dict_, result, params, value_params): + state, dict_, result, params, value_params, bulk=False): """Expire attributes in need of newly persisted database state, after an INSERT or UPDATE statement has proceeded for that state.""" @@ -803,7 +946,8 @@ def _postfetch(mapper, uowtransaction, table, postfetch_cols = result.context.compiled.postfetch returning_cols = result.context.compiled.returning - if mapper.version_id_col is not None: + if mapper.version_id_col is not None and \ + mapper.version_id_col in mapper._cols_by_table[table]: prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] if returning_cols: @@ -829,10 +973,13 @@ def _postfetch(mapper, uowtransaction, table, # TODO: this still goes a little too often. would be nice to # have definitive list of "columns that changed" here for m, equated_pairs in mapper._table_to_equated[table]: - sync.populate(state, m, state, m, - equated_pairs, - uowtransaction, - mapper.passive_updates) + if state is None: + sync.bulk_populate_inherit_keys(dict_, m, equated_pairs) + else: + sync.populate(state, m, state, m, + equated_pairs, + uowtransaction, + mapper.passive_updates) def _connections_for_states(base_mapper, uowtransaction, states): @@ -883,6 +1030,7 @@ class BulkUD(object): def __init__(self, query): self.query = query.enable_eagerloads(False) + self.mapper = self.query._bind_mapper() @property def session(self): @@ -977,6 +1125,7 @@ class BulkFetch(BulkUD): self.primary_table.primary_key) self.matched_rows = session.execute( select_stmt, + mapper=self.mapper, params=query._params).fetchall() @@ -987,7 +1136,6 @@ class BulkUpdate(BulkUD): super(BulkUpdate, self).__init__(query) self.query._no_select_modifiers("update") self.values = values - self.mapper = self.query._mapper_zero_or_none() @classmethod def factory(cls, query, synchronize_session, values): @@ -1033,7 +1181,8 @@ class BulkUpdate(BulkUD): self.context.whereclause, values) self.result = self.query.session.execute( - update_stmt, params=self.query._params) + update_stmt, params=self.query._params, + mapper=self.mapper) self.rowcount = self.result.rowcount def _do_post(self): @@ -1060,8 +1209,10 @@ class BulkDelete(BulkUD): delete_stmt = sql.delete(self.primary_table, self.context.whereclause) - self.result = self.query.session.execute(delete_stmt, - params=self.query._params) + self.result = self.query.session.execute( + delete_stmt, + params=self.query._params, + mapper=self.mapper) self.rowcount = self.result.rowcount def _do_post(self): diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 62ea93fb3..d51b6920d 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -34,6 +34,13 @@ class ColumnProperty(StrategizedProperty): strategy_wildcard_key = 'column' + __slots__ = ( + '_orig_columns', 'columns', 'group', 'deferred', + 'instrument', 'comparator_factory', 'descriptor', 'extension', + 'active_history', 'expire_on_flush', 'info', 'doc', + 'strategy_class', '_creation_order', '_is_polymorphic_discriminator', + '_mapped_by_synonym') + def __init__(self, *columns, **kwargs): """Provide a column-level property for use with a Mapper. @@ -109,6 +116,7 @@ class ColumnProperty(StrategizedProperty): **Deprecated.** Please see :class:`.AttributeEvents`. """ + super(ColumnProperty, self).__init__() self._orig_columns = [expression._labeled(c) for c in columns] self.columns = [expression._labeled(_orm_full_deannotate(c)) for c in columns] @@ -206,7 +214,7 @@ class ColumnProperty(StrategizedProperty): elif dest_state.has_identity and self.key not in dest_dict: dest_state._expire_attributes(dest_dict, [self.key]) - class Comparator(PropComparator): + class Comparator(util.MemoizedSlots, PropComparator): """Produce boolean, comparison, and other operators for :class:`.ColumnProperty` attributes. @@ -224,8 +232,10 @@ class ColumnProperty(StrategizedProperty): :attr:`.TypeEngine.comparator_factory` """ - @util.memoized_instancemethod - def __clause_element__(self): + + __slots__ = '__clause_element__', 'info' + + def _memoized_method___clause_element__(self): if self.adapter: return self.adapter(self.prop.columns[0]) else: @@ -233,15 +243,14 @@ class ColumnProperty(StrategizedProperty): "parententity": self._parentmapper, "parentmapper": self._parentmapper}) - @util.memoized_property - def info(self): + def _memoized_attr_info(self): ce = self.__clause_element__() try: return ce.info except AttributeError: return self.prop.info - def __getattr__(self, key): + def _fallback_getattr(self, key): """proxy attribute access down to the mapped column. this allows user-defined comparison methods to be accessed. diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 790686288..60a637952 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -75,6 +75,7 @@ class Query(object): _having = None _distinct = False _prefixes = None + _suffixes = None _offset = None _limit = None _for_update_arg = None @@ -159,7 +160,6 @@ class Query(object): for from_obj in obj: info = inspect(from_obj) - if hasattr(info, 'mapper') and \ (info.is_mapper or info.is_aliased_class): self._select_from_entity = from_obj @@ -285,8 +285,9 @@ class Query(object): return self._entities[0] def _mapper_zero(self): - return self._select_from_entity or \ - self._entity_zero().entity_zero + return self._select_from_entity \ + if self._select_from_entity is not None \ + else self._entity_zero().entity_zero @property def _mapper_entities(self): @@ -300,11 +301,14 @@ class Query(object): self._mapper_zero() ) - def _mapper_zero_or_none(self): - if self._primary_entity: - return self._primary_entity.mapper - else: - return None + def _bind_mapper(self): + ezero = self._mapper_zero() + if ezero is not None: + insp = inspect(ezero) + if hasattr(insp, 'mapper'): + return insp.mapper + + return None def _only_mapper_zero(self, rationale=None): if len(self._entities) > 1: @@ -810,7 +814,7 @@ class Query(object): foreign-key-to-primary-key criterion, will also use an operation equivalent to :meth:`~.Query.get` in order to retrieve the target value from the local identity map - before querying the database. See :doc:`/orm/loading` + before querying the database. See :doc:`/orm/loading_relationships` for further details on relationship loading. :param ident: A scalar or tuple value representing @@ -987,6 +991,7 @@ class Query(object): statement.correlate(None) q = self._from_selectable(fromclause) q._enable_single_crit = False + q._select_from_entity = self._mapper_zero() if entities: q._set_entities(entities) return q @@ -1003,7 +1008,7 @@ class Query(object): '_limit', '_offset', '_joinpath', '_joinpoint', '_distinct', '_having', - '_prefixes', + '_prefixes', '_suffixes' ): self.__dict__.pop(attr, None) self._set_select_from([fromclause], True) @@ -1099,7 +1104,7 @@ class Query(object): Most supplied options regard changing how column- and relationship-mapped attributes are loaded. See the sections - :ref:`deferred` and :doc:`/orm/loading` for reference + :ref:`deferred` and :doc:`/orm/loading_relationships` for reference documentation. """ @@ -2359,12 +2364,38 @@ class Query(object): .. versionadded:: 0.7.7 + .. seealso:: + + :meth:`.HasPrefixes.prefix_with` + """ if self._prefixes: self._prefixes += prefixes else: self._prefixes = prefixes + @_generative() + def suffix_with(self, *suffixes): + """Apply the suffix to the query and return the newly resulting + ``Query``. + + :param \*suffixes: optional suffixes, typically strings, + not using any commas. + + .. versionadded:: 1.0.0 + + .. seealso:: + + :meth:`.Query.prefix_with` + + :meth:`.HasSuffixes.suffix_with` + + """ + if self._suffixes: + self._suffixes += suffixes + else: + self._suffixes = suffixes + def all(self): """Return the results represented by this ``Query`` as a list. @@ -2499,7 +2530,7 @@ class Query(object): def _execute_and_instances(self, querycontext): conn = self._connection_from_session( - mapper=self._mapper_zero_or_none(), + mapper=self._bind_mapper(), clause=querycontext.statement, close_with_result=True) @@ -2601,6 +2632,7 @@ class Query(object): 'offset': self._offset, 'distinct': self._distinct, 'prefixes': self._prefixes, + 'suffixes': self._suffixes, 'group_by': self._group_by or None, 'having': self._having } @@ -2697,6 +2729,18 @@ class Query(object): Deletes rows matched by this query from the database. + E.g.:: + + sess.query(User).filter(User.age == 25).\\ + delete(synchronize_session=False) + + sess.query(User).filter(User.age == 25).\\ + delete(synchronize_session='evaluate') + + .. warning:: The :meth:`.Query.delete` method is a "bulk" operation, + which bypasses ORM unit-of-work automation in favor of greater + performance. **Please read all caveats and warnings below.** + :param synchronize_session: chooses the strategy for the removal of matched objects from the session. Valid values are: @@ -2715,8 +2759,7 @@ class Query(object): ``'evaluate'`` - Evaluate the query's criteria in Python straight on the objects in the session. If evaluation of the criteria isn't - implemented, an error is raised. In that case you probably - want to use the 'fetch' strategy as a fallback. + implemented, an error is raised. The expression evaluator currently doesn't account for differing string collations between the database and Python. @@ -2724,29 +2767,42 @@ class Query(object): :return: the count of rows matched as returned by the database's "row count" feature. - This method has several key caveats: - - * The method does **not** offer in-Python cascading of relationships - - it is assumed that ON DELETE CASCADE/SET NULL/etc. is configured - for any foreign key references which require it, otherwise the - database may emit an integrity violation if foreign key references - are being enforced. - - After the DELETE, dependent objects in the :class:`.Session` which - were impacted by an ON DELETE may not contain the current - state, or may have been deleted. This issue is resolved once the - :class:`.Session` is expired, - which normally occurs upon :meth:`.Session.commit` or can be forced - by using :meth:`.Session.expire_all`. Accessing an expired object - whose row has been deleted will invoke a SELECT to locate the - row; when the row is not found, an - :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised. - - * The :meth:`.MapperEvents.before_delete` and - :meth:`.MapperEvents.after_delete` - events are **not** invoked from this method. Instead, the - :meth:`.SessionEvents.after_bulk_delete` method is provided to act - upon a mass DELETE of entity rows. + .. warning:: **Additional Caveats for bulk query deletes** + + * The method does **not** offer in-Python cascading of + relationships - it is assumed that ON DELETE CASCADE/SET + NULL/etc. is configured for any foreign key references + which require it, otherwise the database may emit an + integrity violation if foreign key references are being + enforced. + + After the DELETE, dependent objects in the + :class:`.Session` which were impacted by an ON DELETE + may not contain the current state, or may have been + deleted. This issue is resolved once the + :class:`.Session` is expired, which normally occurs upon + :meth:`.Session.commit` or can be forced by using + :meth:`.Session.expire_all`. Accessing an expired + object whose row has been deleted will invoke a SELECT + to locate the row; when the row is not found, an + :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is + raised. + + * The ``'fetch'`` strategy results in an additional + SELECT statement emitted and will significantly reduce + performance. + + * The ``'evaluate'`` strategy performs a scan of + all matching objects within the :class:`.Session`; if the + contents of the :class:`.Session` are expired, such as + via a proceeding :meth:`.Session.commit` call, **this will + result in SELECT queries emitted for every matching object**. + + * The :meth:`.MapperEvents.before_delete` and + :meth:`.MapperEvents.after_delete` + events **are not invoked** from this method. Instead, the + :meth:`.SessionEvents.after_bulk_delete` method is provided to + act upon a mass DELETE of entity rows. .. seealso:: @@ -2769,17 +2825,21 @@ class Query(object): E.g.:: - sess.query(User).filter(User.age == 25).\ - update({User.age: User.age - 10}, synchronize_session='fetch') - + sess.query(User).filter(User.age == 25).\\ + update({User.age: User.age - 10}, synchronize_session=False) - sess.query(User).filter(User.age == 25).\ + sess.query(User).filter(User.age == 25).\\ update({"age": User.age - 10}, synchronize_session='evaluate') + .. warning:: The :meth:`.Query.update` method is a "bulk" operation, + which bypasses ORM unit-of-work automation in favor of greater + performance. **Please read all caveats and warnings below.** + + :param values: a dictionary with attributes names, or alternatively - mapped attributes or SQL expressions, as keys, and literal - values or sql expressions as values. + mapped attributes or SQL expressions, as keys, and literal + values or sql expressions as values. .. versionchanged:: 1.0.0 - string names in the values dictionary are now resolved against the mapped entity; previously, these @@ -2787,7 +2847,7 @@ class Query(object): translation. :param synchronize_session: chooses the strategy to update the - attributes on objects in the session. Valid values are: + attributes on objects in the session. Valid values are: ``False`` - don't synchronize the session. This option is the most efficient and is reliable once the session is expired, which @@ -2808,43 +2868,56 @@ class Query(object): string collations between the database and Python. :return: the count of rows matched as returned by the database's - "row count" feature. - - This method has several key caveats: - - * The method does **not** offer in-Python cascading of relationships - - it is assumed that ON UPDATE CASCADE is configured for any foreign - key references which require it, otherwise the database may emit an - integrity violation if foreign key references are being enforced. - - After the UPDATE, dependent objects in the :class:`.Session` which - were impacted by an ON UPDATE CASCADE may not contain the current - state; this issue is resolved once the :class:`.Session` is expired, - which normally occurs upon :meth:`.Session.commit` or can be forced - by using :meth:`.Session.expire_all`. - - * The method supports multiple table updates, as - detailed in :ref:`multi_table_updates`, and this behavior does - extend to support updates of joined-inheritance and other multiple - table mappings. However, the **join condition of an inheritance - mapper is currently not automatically rendered**. - Care must be taken in any multiple-table update to explicitly - include the joining condition between those tables, even in mappings - where this is normally automatic. - E.g. if a class ``Engineer`` subclasses ``Employee``, an UPDATE of - the ``Engineer`` local table using criteria against the ``Employee`` - local table might look like:: - - session.query(Engineer).\\ - filter(Engineer.id == Employee.id).\\ - filter(Employee.name == 'dilbert').\\ - update({"engineer_type": "programmer"}) - - * The :meth:`.MapperEvents.before_update` and - :meth:`.MapperEvents.after_update` - events are **not** invoked from this method. Instead, the - :meth:`.SessionEvents.after_bulk_update` method is provided to act - upon a mass UPDATE of entity rows. + "row count" feature. + + .. warning:: **Additional Caveats for bulk query updates** + + * The method does **not** offer in-Python cascading of + relationships - it is assumed that ON UPDATE CASCADE is + configured for any foreign key references which require + it, otherwise the database may emit an integrity + violation if foreign key references are being enforced. + + After the UPDATE, dependent objects in the + :class:`.Session` which were impacted by an ON UPDATE + CASCADE may not contain the current state; this issue is + resolved once the :class:`.Session` is expired, which + normally occurs upon :meth:`.Session.commit` or can be + forced by using :meth:`.Session.expire_all`. + + * The ``'fetch'`` strategy results in an additional + SELECT statement emitted and will significantly reduce + performance. + + * The ``'evaluate'`` strategy performs a scan of + all matching objects within the :class:`.Session`; if the + contents of the :class:`.Session` are expired, such as + via a proceeding :meth:`.Session.commit` call, **this will + result in SELECT queries emitted for every matching object**. + + * The method supports multiple table updates, as detailed + in :ref:`multi_table_updates`, and this behavior does + extend to support updates of joined-inheritance and + other multiple table mappings. However, the **join + condition of an inheritance mapper is not + automatically rendered**. Care must be taken in any + multiple-table update to explicitly include the joining + condition between those tables, even in mappings where + this is normally automatic. E.g. if a class ``Engineer`` + subclasses ``Employee``, an UPDATE of the ``Engineer`` + local table using criteria against the ``Employee`` + local table might look like:: + + session.query(Engineer).\\ + filter(Engineer.id == Employee.id).\\ + filter(Employee.name == 'dilbert').\\ + update({"engineer_type": "programmer"}) + + * The :meth:`.MapperEvents.before_update` and + :meth:`.MapperEvents.after_update` + events **are not invoked from this method**. Instead, the + :meth:`.SessionEvents.after_bulk_update` method is provided to + act upon a mass UPDATE of entity rows. .. seealso:: @@ -3473,26 +3546,26 @@ class _ColumnEntity(_QueryEntity): )): self._label_name = column.key column = column._query_clause_element() - else: - self._label_name = getattr(column, 'key', None) - - if not isinstance(column, expression.ColumnElement) and \ - hasattr(column, '_select_iterable'): - for c in column._select_iterable: - if c is column: - break - _ColumnEntity(query, c, namespace=column) - else: + if isinstance(column, Bundle): + _BundleEntity(query, column) return - elif isinstance(column, Bundle): - _BundleEntity(query, column) - return + elif not isinstance(column, sql.ColumnElement): + if hasattr(column, '_select_iterable'): + # break out an object like Table into + # individual columns + for c in column._select_iterable: + if c is column: + break + _ColumnEntity(query, c, namespace=column) + else: + return - if not isinstance(column, sql.ColumnElement): raise sa_exc.InvalidRequestError( "SQL expression, column, or mapped entity " "expected - got '%r'" % (column, ) ) + else: + self._label_name = getattr(column, 'key', None) self.type = type_ = column.type if type_.hashable: @@ -3523,15 +3596,26 @@ class _ColumnEntity(_QueryEntity): # leaking out their entities into the main select construct self.actual_froms = actual_froms = set(column._from_objects) - self.entities = util.OrderedSet( + all_elements = [ + elem for elem in visitors.iterate(column, {}) + if 'parententity' in elem._annotations + ] + + self.entities = util.unique_list( + elem._annotations['parententity'] + for elem in all_elements + if 'parententity' in elem._annotations + ) + + self._from_entities = set( elem._annotations['parententity'] - for elem in visitors.iterate(column, {}) + for elem in all_elements if 'parententity' in elem._annotations and actual_froms.intersection(elem._from_objects) ) if self.entities: - self.entity_zero = list(self.entities)[0] + self.entity_zero = self.entities[0] elif self.namespace is not None: self.entity_zero = self.namespace else: @@ -3557,7 +3641,9 @@ class _ColumnEntity(_QueryEntity): def setup_entity(self, ext_info, aliased_adapter): if 'selectable' not in self.__dict__: self.selectable = ext_info.selectable - self.froms.add(ext_info.selectable) + + if self.actual_froms.intersection(ext_info.selectable._from_objects): + self.froms.add(ext_info.selectable) def corresponds_to(self, entity): # TODO: just returning False here, diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 86f1b3f82..df2250a4c 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -528,7 +528,7 @@ class RelationshipProperty(StrategizedProperty): .. seealso:: - :doc:`/orm/loading` - Full documentation on relationship loader + :doc:`/orm/loading_relationships` - Full documentation on relationship loader configuration. :ref:`dynamic_relationship` - detail on the ``dynamic`` option. @@ -775,6 +775,7 @@ class RelationshipProperty(StrategizedProperty): """ + super(RelationshipProperty, self).__init__() self.uselist = uselist self.argument = argument diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index f23983cbc..0e272dc95 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -20,6 +20,8 @@ from .base import ( _class_to_mapper, _state_mapper, object_state, _none_set, state_str, instance_str ) +import itertools +from . import persistence from .unitofwork import UOWTransaction from . import state as statelib import sys @@ -433,11 +435,13 @@ class SessionTransaction(object): self.session.dispatch.after_rollback(self.session) - def close(self): + def close(self, invalidate=False): self.session.transaction = self._parent if self._parent is None: for connection, transaction, autoclose in \ set(self._connections.values()): + if invalidate: + connection.invalidate() if autoclose: connection.close() else: @@ -482,7 +486,8 @@ class Session(_SessionClassMethods): '__contains__', '__iter__', 'add', 'add_all', 'begin', 'begin_nested', 'close', 'commit', 'connection', 'delete', 'execute', 'expire', 'expire_all', 'expunge', 'expunge_all', 'flush', 'get_bind', - 'is_modified', + 'is_modified', 'bulk_save_objects', 'bulk_insert_mappings', + 'bulk_update_mappings', 'merge', 'query', 'refresh', 'rollback', 'scalar') @@ -591,8 +596,8 @@ class Session(_SessionClassMethods): .. versionadded:: 0.9.0 :param query_cls: Class which should be used to create new Query - objects, as returned by the :meth:`~.Session.query` method. - Defaults to :class:`.Query`. + objects, as returned by the :meth:`~.Session.query` method. + Defaults to :class:`.Query`. :param twophase: When ``True``, all transactions will be started as a "two phase" transaction, i.e. using the "two phase" semantics @@ -997,10 +1002,46 @@ class Session(_SessionClassMethods): not use any connection resources until they are first needed. """ + self._close_impl(invalidate=False) + + def invalidate(self): + """Close this Session, using connection invalidation. + + This is a variant of :meth:`.Session.close` that will additionally + ensure that the :meth:`.Connection.invalidate` method will be called + on all :class:`.Connection` objects. This can be called when + the database is known to be in a state where the connections are + no longer safe to be used. + + E.g.:: + + try: + sess = Session() + sess.add(User()) + sess.commit() + except gevent.Timeout: + sess.invalidate() + raise + except: + sess.rollback() + raise + + This clears all items and ends any transaction in progress. + + If this session were created with ``autocommit=False``, a new + transaction is immediately begun. Note that this new transaction does + not use any connection resources until they are first needed. + + .. versionadded:: 0.9.9 + + """ + self._close_impl(invalidate=True) + + def _close_impl(self, invalidate): self.expunge_all() if self.transaction is not None: for transaction in self.transaction._iterate_parents(): - transaction.close() + transaction.close(invalidate) def expunge_all(self): """Remove all object instances from this ``Session``. @@ -2044,6 +2085,226 @@ class Session(_SessionClassMethods): with util.safe_reraise(): transaction.rollback(_capture_exception=True) + def bulk_save_objects( + self, objects, return_defaults=False, update_changed_only=True): + """Perform a bulk save of the given list of objects. + + The bulk save feature allows mapped objects to be used as the + source of simple INSERT and UPDATE operations which can be more easily + grouped together into higher performing "executemany" + operations; the extraction of data from the objects is also performed + using a lower-latency process that ignores whether or not attributes + have actually been modified in the case of UPDATEs, and also ignores + SQL expressions. + + The objects as given are not added to the session and no additional + state is established on them, unless the ``return_defaults`` flag + is also set, in which case primary key attributes and server-side + default values will be populated. + + .. versionadded:: 1.0.0 + + .. warning:: + + The bulk save feature allows for a lower-latency INSERT/UPDATE + of rows at the expense of most other unit-of-work features. + Features such as object management, relationship handling, + and SQL clause support are **silently omitted** in favor of raw + INSERT/UPDATES of records. + + **Please read the list of caveats at** :ref:`bulk_operations` + **before using this method, and fully test and confirm the + functionality of all code developed using these systems.** + + :param objects: a list of mapped object instances. The mapped + objects are persisted as is, and are **not** associated with the + :class:`.Session` afterwards. + + For each object, whether the object is sent as an INSERT or an + UPDATE is dependent on the same rules used by the :class:`.Session` + in traditional operation; if the object has the + :attr:`.InstanceState.key` + attribute set, then the object is assumed to be "detached" and + will result in an UPDATE. Otherwise, an INSERT is used. + + In the case of an UPDATE, statements are grouped based on which + attributes have changed, and are thus to be the subject of each + SET clause. If ``update_changed_only`` is False, then all + attributes present within each object are applied to the UPDATE + statement, which may help in allowing the statements to be grouped + together into a larger executemany(), and will also reduce the + overhead of checking history on attributes. + + :param return_defaults: when True, rows that are missing values which + generate defaults, namely integer primary key defaults and sequences, + will be inserted **one at a time**, so that the primary key value + is available. In particular this will allow joined-inheritance + and other multi-table mappings to insert correctly without the need + to provide primary key values ahead of time; however, + :paramref:`.Session.bulk_save_objects.return_defaults` **greatly + reduces the performance gains** of the method overall. + + :param update_changed_only: when True, UPDATE statements are rendered + based on those attributes in each state that have logged changes. + When False, all attributes present are rendered into the SET clause + with the exception of primary key attributes. + + .. seealso:: + + :ref:`bulk_operations` + + :meth:`.Session.bulk_insert_mappings` + + :meth:`.Session.bulk_update_mappings` + + """ + for (mapper, isupdate), states in itertools.groupby( + (attributes.instance_state(obj) for obj in objects), + lambda state: (state.mapper, state.key is not None) + ): + self._bulk_save_mappings( + mapper, states, isupdate, True, + return_defaults, update_changed_only) + + def bulk_insert_mappings(self, mapper, mappings, return_defaults=False): + """Perform a bulk insert of the given list of mapping dictionaries. + + The bulk insert feature allows plain Python dictionaries to be used as + the source of simple INSERT operations which can be more easily + grouped together into higher performing "executemany" + operations. Using dictionaries, there is no "history" or session + state management features in use, reducing latency when inserting + large numbers of simple rows. + + The values within the dictionaries as given are typically passed + without modification into Core :meth:`.Insert` constructs, after + organizing the values within them across the tables to which + the given mapper is mapped. + + .. versionadded:: 1.0.0 + + .. warning:: + + The bulk insert feature allows for a lower-latency INSERT + of rows at the expense of most other unit-of-work features. + Features such as object management, relationship handling, + and SQL clause support are **silently omitted** in favor of raw + INSERT of records. + + **Please read the list of caveats at** :ref:`bulk_operations` + **before using this method, and fully test and confirm the + functionality of all code developed using these systems.** + + :param mapper: a mapped class, or the actual :class:`.Mapper` object, + representing the single kind of object represented within the mapping + list. + + :param mappings: a list of dictionaries, each one containing the state + of the mapped row to be inserted, in terms of the attribute names + on the mapped class. If the mapping refers to multiple tables, + such as a joined-inheritance mapping, each dictionary must contain + all keys to be populated into all tables. + + :param return_defaults: when True, rows that are missing values which + generate defaults, namely integer primary key defaults and sequences, + will be inserted **one at a time**, so that the primary key value + is available. In particular this will allow joined-inheritance + and other multi-table mappings to insert correctly without the need + to provide primary + key values ahead of time; however, + :paramref:`.Session.bulk_insert_mappings.return_defaults` + **greatly reduces the performance gains** of the method overall. + If the rows + to be inserted only refer to a single table, then there is no + reason this flag should be set as the returned default information + is not used. + + + .. seealso:: + + :ref:`bulk_operations` + + :meth:`.Session.bulk_save_objects` + + :meth:`.Session.bulk_update_mappings` + + """ + self._bulk_save_mappings( + mapper, mappings, False, False, return_defaults, False) + + def bulk_update_mappings(self, mapper, mappings): + """Perform a bulk update of the given list of mapping dictionaries. + + The bulk update feature allows plain Python dictionaries to be used as + the source of simple UPDATE operations which can be more easily + grouped together into higher performing "executemany" + operations. Using dictionaries, there is no "history" or session + state management features in use, reducing latency when updating + large numbers of simple rows. + + .. versionadded:: 1.0.0 + + .. warning:: + + The bulk update feature allows for a lower-latency UPDATE + of rows at the expense of most other unit-of-work features. + Features such as object management, relationship handling, + and SQL clause support are **silently omitted** in favor of raw + UPDATES of records. + + **Please read the list of caveats at** :ref:`bulk_operations` + **before using this method, and fully test and confirm the + functionality of all code developed using these systems.** + + :param mapper: a mapped class, or the actual :class:`.Mapper` object, + representing the single kind of object represented within the mapping + list. + + :param mappings: a list of dictionaries, each one containing the state + of the mapped row to be updated, in terms of the attribute names + on the mapped class. If the mapping refers to multiple tables, + such as a joined-inheritance mapping, each dictionary may contain + keys corresponding to all tables. All those keys which are present + and are not part of the primary key are applied to the SET clause + of the UPDATE statement; the primary key values, which are required, + are applied to the WHERE clause. + + + .. seealso:: + + :ref:`bulk_operations` + + :meth:`.Session.bulk_insert_mappings` + + :meth:`.Session.bulk_save_objects` + + """ + self._bulk_save_mappings(mapper, mappings, True, False, False, False) + + def _bulk_save_mappings( + self, mapper, mappings, isupdate, isstates, + return_defaults, update_changed_only): + mapper = _class_to_mapper(mapper) + self._flushing = True + + transaction = self.begin( + subtransactions=True) + try: + if isupdate: + persistence._bulk_update( + mapper, mappings, transaction, + isstates, update_changed_only) + else: + persistence._bulk_insert( + mapper, mappings, transaction, isstates, return_defaults) + transaction.commit() + + except: + with util.safe_reraise(): + transaction.rollback(_capture_exception=True) + finally: + self._flushing = False + def is_modified(self, instance, include_collections=True, passive=True): """Return ``True`` if the given instance has locally diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index d95f17f64..8a4c8e731 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -105,6 +105,8 @@ class UninstrumentedColumnLoader(LoaderStrategy): if the argument is against the with_polymorphic selectable. """ + __slots__ = 'columns', + def __init__(self, parent): super(UninstrumentedColumnLoader, self).__init__(parent) self.columns = self.parent_property.columns @@ -128,6 +130,8 @@ class UninstrumentedColumnLoader(LoaderStrategy): class ColumnLoader(LoaderStrategy): """Provide loading behavior for a :class:`.ColumnProperty`.""" + __slots__ = 'columns', 'is_composite' + def __init__(self, parent): super(ColumnLoader, self).__init__(parent) self.columns = self.parent_property.columns @@ -176,6 +180,8 @@ class ColumnLoader(LoaderStrategy): class DeferredColumnLoader(LoaderStrategy): """Provide loading behavior for a deferred :class:`.ColumnProperty`.""" + __slots__ = 'columns', 'group' + def __init__(self, parent): super(DeferredColumnLoader, self).__init__(parent) if hasattr(self.parent_property, 'composite_class'): @@ -225,7 +231,8 @@ class DeferredColumnLoader(LoaderStrategy): ( loadopt and 'undefer_pks' in loadopt.local_opts and - set(self.columns).intersection(self.parent.primary_key) + set(self.columns).intersection( + self.parent._should_undefer_in_wildcard) ) or ( @@ -300,6 +307,8 @@ class LoadDeferredColumns(object): class AbstractRelationshipLoader(LoaderStrategy): """LoaderStratgies which deal with related objects.""" + __slots__ = 'mapper', 'target', 'uselist' + def __init__(self, parent): super(AbstractRelationshipLoader, self).__init__(parent) self.mapper = self.parent_property.mapper @@ -316,6 +325,8 @@ class NoLoader(AbstractRelationshipLoader): """ + __slots__ = () + def init_class_attribute(self, mapper): self.is_class_level = True @@ -343,6 +354,10 @@ class LazyLoader(AbstractRelationshipLoader): """ + __slots__ = ( + '_lazywhere', '_rev_lazywhere', 'use_get', '_bind_to_col', + '_equated_columns', '_rev_bind_to_col', '_rev_equated_columns') + def __init__(self, parent): super(LazyLoader, self).__init__(parent) join_condition = self.parent_property._join_condition @@ -661,6 +676,8 @@ class LoadLazyAttribute(object): @properties.RelationshipProperty.strategy_for(lazy="immediate") class ImmediateLoader(AbstractRelationshipLoader): + __slots__ = () + def init_class_attribute(self, mapper): self.parent_property.\ _get_strategy_by_cls(LazyLoader).\ @@ -684,6 +701,8 @@ class ImmediateLoader(AbstractRelationshipLoader): @log.class_logger @properties.RelationshipProperty.strategy_for(lazy="subquery") class SubqueryLoader(AbstractRelationshipLoader): + __slots__ = 'join_depth', + def __init__(self, parent): super(SubqueryLoader, self).__init__(parent) self.join_depth = self.parent_property.join_depth @@ -1069,6 +1088,9 @@ class JoinedLoader(AbstractRelationshipLoader): using joined eager loading. """ + + __slots__ = 'join_depth', + def __init__(self, parent): super(JoinedLoader, self).__init__(parent) self.join_depth = self.parent_property.join_depth diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index 276da2ae0..90e4e9661 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -364,6 +364,7 @@ class _UnboundLoad(Load): return None token = start_path[0] + if isinstance(token, util.string_types): entity = self._find_entity_basestring(query, token, raiseerr) elif isinstance(token, PropComparator): @@ -407,10 +408,18 @@ class _UnboundLoad(Load): # prioritize "first class" options over those # that were "links in the chain", e.g. "x" and "y" in # someload("x.y.z") versus someload("x") / someload("x.y") - if self._is_chain_link: - effective_path.setdefault(context, "loader", loader) + + if effective_path.is_token: + for path in effective_path.generate_for_superclasses(): + if self._is_chain_link: + path.setdefault(context, "loader", loader) + else: + path.set(context, "loader", loader) else: - effective_path.set(context, "loader", loader) + if self._is_chain_link: + effective_path.setdefault(context, "loader", loader) + else: + effective_path.set(context, "loader", loader) def _find_entity_prop_comparator(self, query, token, mapper, raiseerr): if _is_aliased_class(mapper): diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index e1ef85c1d..671c7c067 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -45,6 +45,23 @@ def populate(source, source_mapper, dest, dest_mapper, uowcommit.attributes[("pk_cascaded", dest, r)] = True +def bulk_populate_inherit_keys( + source_dict, source_mapper, synchronize_pairs): + # a simplified version of populate() used by bulk insert mode + for l, r in synchronize_pairs: + try: + prop = source_mapper._columntoproperty[l] + value = source_dict[prop.key] + except exc.UnmappedColumnError: + _raise_col_to_prop(False, source_mapper, l, source_mapper, r) + + try: + prop = source_mapper._columntoproperty[r] + source_dict[prop.key] = value + except exc.UnmappedColumnError: + _raise_col_to_prop(True, source_mapper, l, source_mapper, r) + + def clear(dest, dest_mapper, synchronize_pairs): for l, r in synchronize_pairs: if r.primary_key and \ diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 71e61827b..05265b13f 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -16,6 +16,7 @@ organizes them in order of dependency, and executes. from .. import util, event from ..util import topological from . import attributes, persistence, util as orm_util +import itertools def track_cascade_events(descriptor, prop): @@ -379,14 +380,19 @@ class UOWTransaction(object): execute() method has succeeded and the transaction has been committed. """ + if not self.states: + return + states = set(self.states) isdel = set( s for (s, (isdelete, listonly)) in self.states.items() if isdelete ) other = states.difference(isdel) - self.session._remove_newly_deleted(isdel) - self.session._register_newly_persistent(other) + if isdel: + self.session._remove_newly_deleted(isdel) + if other: + self.session._register_newly_persistent(other) class IterateMappersMixin(object): diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 4be8d19ff..ee629b034 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -30,6 +30,10 @@ class CascadeOptions(frozenset): 'all', 'none', 'delete-orphan']) _allowed_cascades = all_cascades + __slots__ = ( + 'save_update', 'delete', 'refresh_expire', 'merge', + 'expunge', 'delete_orphan') + def __new__(cls, value_list): if isinstance(value_list, str) or value_list is None: return cls.from_string(value_list) @@ -38,10 +42,7 @@ class CascadeOptions(frozenset): raise sa_exc.ArgumentError( "Invalid cascade option(s): %s" % ", ".join([repr(x) for x in - sorted( - values.difference(cls._allowed_cascades) - )]) - ) + sorted(values.difference(cls._allowed_cascades))])) if "all" in values: values.update(cls._add_w_all_cascades) @@ -76,6 +77,7 @@ class CascadeOptions(frozenset): ] return cls(values) + def _validator_events( desc, key, validator, include_removes, include_backrefs): """Runs a validation method on an attribute value to be set or diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py index a174df784..253bd77b8 100644 --- a/lib/sqlalchemy/pool.py +++ b/lib/sqlalchemy/pool.py @@ -230,6 +230,7 @@ class Pool(log.Identified): % reset_on_return) self.echo = echo + if _dispatch: self.dispatch._update(_dispatch, only_propagate=False) if _dialect: @@ -917,9 +918,9 @@ class QueuePool(Pool): on returning a connection. Defaults to 30. :param \**kw: Other keyword arguments including - :paramref:`.Pool.recycle`, :paramref:`.Pool.echo`, - :paramref:`.Pool.reset_on_return` and others are passed to the - :class:`.Pool` constructor. + :paramref:`.Pool.recycle`, :paramref:`.Pool.echo`, + :paramref:`.Pool.reset_on_return` and others are passed to the + :class:`.Pool` constructor. """ Pool.__init__(self, creator, **kw) diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 4b6ad1988..285ae579f 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -35,6 +35,7 @@ from .sql.schema import ( UniqueConstraint, _get_table_key, ColumnCollectionConstraint, + ColumnCollectionMixin ) @@ -58,5 +59,7 @@ from .sql.ddl import ( DDLBase, DDLElement, _CreateDropBase, - _DDLCompiles + _DDLCompiles, + sort_tables, + sort_tables_and_constraints ) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 2d06109b9..0f6405309 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -449,10 +449,12 @@ class ColumnCollection(util.OrderedProperties): """ + __slots__ = '_all_col_set', '_all_columns' + def __init__(self, *columns): super(ColumnCollection, self).__init__() - self.__dict__['_all_col_set'] = util.column_set() - self.__dict__['_all_columns'] = [] + object.__setattr__(self, '_all_col_set', util.column_set()) + object.__setattr__(self, '_all_columns', []) for c in columns: self.add(c) @@ -576,13 +578,14 @@ class ColumnCollection(util.OrderedProperties): return util.OrderedProperties.__contains__(self, other) def __getstate__(self): - return {'_data': self.__dict__['_data'], - '_all_columns': self.__dict__['_all_columns']} + return {'_data': self._data, + '_all_columns': self._all_columns} def __setstate__(self, state): - self.__dict__['_data'] = state['_data'] - self.__dict__['_all_columns'] = state['_all_columns'] - self.__dict__['_all_col_set'] = util.column_set(state['_all_columns']) + object.__setattr__(self, '_data', state['_data']) + object.__setattr__(self, '_all_columns', state['_all_columns']) + object.__setattr__( + self, '_all_col_set', util.column_set(state['_all_columns'])) def contains_column(self, col): # this has to be done via set() membership @@ -596,8 +599,8 @@ class ColumnCollection(util.OrderedProperties): class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection): def __init__(self, data, colset, all_columns): util.ImmutableProperties.__init__(self, data) - self.__dict__['_all_col_set'] = colset - self.__dict__['_all_columns'] = all_columns + object.__setattr__(self, '_all_col_set', colset) + object.__setattr__(self, '_all_columns', all_columns) extend = remove = util.ImmutableProperties._immutable diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index b102f0240..da62b1434 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -82,6 +82,7 @@ OPERATORS = { operators.eq: ' = ', operators.concat_op: ' || ', operators.match_op: ' MATCH ', + operators.notmatch_op: ' NOT MATCH ', operators.in_op: ' IN ', operators.notin_op: ' NOT IN ', operators.comma_op: ', ', @@ -247,15 +248,16 @@ class Compiled(object): return self.execute(*multiparams, **params).scalar() -class TypeCompiler(object): - +class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)): """Produces DDL specification for TypeEngine objects.""" + ensure_kwarg = 'visit_\w+' + def __init__(self, dialect): self.dialect = dialect - def process(self, type_): - return type_._compiler_dispatch(self) + def process(self, type_, **kw): + return type_._compiler_dispatch(self, **kw) class _CompileLabel(visitors.Visitable): @@ -637,8 +639,9 @@ class SQLCompiler(Compiled): def visit_index(self, index, **kwargs): return index.name - def visit_typeclause(self, typeclause, **kwargs): - return self.dialect.type_compiler.process(typeclause.type) + def visit_typeclause(self, typeclause, **kw): + kw['type_expression'] = typeclause + return self.dialect.type_compiler.process(typeclause.type, **kw) def post_process_text(self, text): return text @@ -862,14 +865,18 @@ class SQLCompiler(Compiled): else: return "%s = 0" % self.process(element.element, **kw) - def visit_binary(self, binary, **kw): + def visit_notmatch_op_binary(self, binary, operator, **kw): + return "NOT %s" % self.visit_binary( + binary, override_operator=operators.match_op) + + def visit_binary(self, binary, override_operator=None, **kw): # don't allow "? = ?" to render if self.ansi_bind_rules and \ isinstance(binary.left, elements.BindParameter) and \ isinstance(binary.right, elements.BindParameter): kw['literal_binds'] = True - operator_ = binary.operator + operator_ = override_operator or binary.operator disp = getattr(self, "visit_%s_binary" % operator_.__name__, None) if disp: return disp(binary, operator_, **kw) @@ -1188,12 +1195,16 @@ class SQLCompiler(Compiled): self, asfrom=True, **kwargs ) + if cte._suffixes: + text += " " + self._generate_prefixes( + cte, cte._suffixes, **kwargs) + self.ctes[cte] = text if asfrom: if cte_alias_name: text = self.preparer.format_alias(cte, cte_alias_name) - text += " AS " + cte_name + text += self.get_render_as_alias_suffix(cte_name) else: return self.preparer.format_alias(cte, cte_name) return text @@ -1212,8 +1223,8 @@ class SQLCompiler(Compiled): elif asfrom: ret = alias.original._compiler_dispatch(self, asfrom=True, **kwargs) + \ - " AS " + \ - self.preparer.format_alias(alias, alias_name) + self.get_render_as_alias_suffix( + self.preparer.format_alias(alias, alias_name)) if fromhints and alias in fromhints: ret = self.format_from_hint_text(ret, alias, @@ -1223,6 +1234,9 @@ class SQLCompiler(Compiled): else: return alias.original._compiler_dispatch(self, **kwargs) + def get_render_as_alias_suffix(self, alias_name_text): + return " AS " + alias_name_text + def _add_to_result_map(self, keyname, name, objects, type_): if not self.dialect.case_sensitive: keyname = keyname.lower() @@ -1549,6 +1563,10 @@ class SQLCompiler(Compiled): compound_index == 0 and toplevel: text = self._render_cte_clause() + text + if select._suffixes: + text += " " + self._generate_prefixes( + select, select._suffixes, **kwargs) + self.stack.pop(-1) if asfrom and parens: @@ -2086,7 +2104,9 @@ class DDLCompiler(Compiled): (table.description, column.name, ce.args[0]) )) - const = self.create_table_constraints(table) + const = self.create_table_constraints( + table, _include_foreign_key_constraints= + create.include_foreign_key_constraints) if const: text += ", \n\t" + const @@ -2110,7 +2130,9 @@ class DDLCompiler(Compiled): return text - def create_table_constraints(self, table): + def create_table_constraints( + self, table, + _include_foreign_key_constraints=None): # On some DB order is significant: visit PK first, then the # other constraints (engine.ReflectionTest.testbasic failed on FB2) @@ -2118,8 +2140,15 @@ class DDLCompiler(Compiled): if table.primary_key: constraints.append(table.primary_key) + all_fkcs = table.foreign_key_constraints + if _include_foreign_key_constraints is not None: + omit_fkcs = all_fkcs.difference(_include_foreign_key_constraints) + else: + omit_fkcs = set() + constraints.extend([c for c in table._sorted_constraints - if c is not table.primary_key]) + if c is not table.primary_key and + c not in omit_fkcs]) return ", \n\t".join( p for p in @@ -2214,15 +2243,26 @@ class DDLCompiler(Compiled): self.preparer.format_sequence(drop.element) def visit_drop_constraint(self, drop): + constraint = drop.element + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + else: + formatted_name = None + + if formatted_name is None: + raise exc.CompileError( + "Can't emit DROP CONSTRAINT for constraint %r; " + "it has no name" % drop.element) return "ALTER TABLE %s DROP CONSTRAINT %s%s" % ( self.preparer.format_table(drop.element.table), - self.preparer.format_constraint(drop.element), + formatted_name, drop.cascade and " CASCADE" or "" ) def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) + " " + \ - self.dialect.type_compiler.process(column.type) + self.dialect.type_compiler.process( + column.type, type_expression=column) default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -2346,13 +2386,13 @@ class DDLCompiler(Compiled): class GenericTypeCompiler(TypeCompiler): - def visit_FLOAT(self, type_): + def visit_FLOAT(self, type_, **kw): return "FLOAT" - def visit_REAL(self, type_): + def visit_REAL(self, type_, **kw): return "REAL" - def visit_NUMERIC(self, type_): + def visit_NUMERIC(self, type_, **kw): if type_.precision is None: return "NUMERIC" elif type_.scale is None: @@ -2363,7 +2403,7 @@ class GenericTypeCompiler(TypeCompiler): {'precision': type_.precision, 'scale': type_.scale} - def visit_DECIMAL(self, type_): + def visit_DECIMAL(self, type_, **kw): if type_.precision is None: return "DECIMAL" elif type_.scale is None: @@ -2374,31 +2414,31 @@ class GenericTypeCompiler(TypeCompiler): {'precision': type_.precision, 'scale': type_.scale} - def visit_INTEGER(self, type_): + def visit_INTEGER(self, type_, **kw): return "INTEGER" - def visit_SMALLINT(self, type_): + def visit_SMALLINT(self, type_, **kw): return "SMALLINT" - def visit_BIGINT(self, type_): + def visit_BIGINT(self, type_, **kw): return "BIGINT" - def visit_TIMESTAMP(self, type_): + def visit_TIMESTAMP(self, type_, **kw): return 'TIMESTAMP' - def visit_DATETIME(self, type_): + def visit_DATETIME(self, type_, **kw): return "DATETIME" - def visit_DATE(self, type_): + def visit_DATE(self, type_, **kw): return "DATE" - def visit_TIME(self, type_): + def visit_TIME(self, type_, **kw): return "TIME" - def visit_CLOB(self, type_): + def visit_CLOB(self, type_, **kw): return "CLOB" - def visit_NCLOB(self, type_): + def visit_NCLOB(self, type_, **kw): return "NCLOB" def _render_string_type(self, type_, name): @@ -2410,91 +2450,91 @@ class GenericTypeCompiler(TypeCompiler): text += ' COLLATE "%s"' % type_.collation return text - def visit_CHAR(self, type_): + def visit_CHAR(self, type_, **kw): return self._render_string_type(type_, "CHAR") - def visit_NCHAR(self, type_): + def visit_NCHAR(self, type_, **kw): return self._render_string_type(type_, "NCHAR") - def visit_VARCHAR(self, type_): + def visit_VARCHAR(self, type_, **kw): return self._render_string_type(type_, "VARCHAR") - def visit_NVARCHAR(self, type_): + def visit_NVARCHAR(self, type_, **kw): return self._render_string_type(type_, "NVARCHAR") - def visit_TEXT(self, type_): + def visit_TEXT(self, type_, **kw): return self._render_string_type(type_, "TEXT") - def visit_BLOB(self, type_): + def visit_BLOB(self, type_, **kw): return "BLOB" - def visit_BINARY(self, type_): + def visit_BINARY(self, type_, **kw): return "BINARY" + (type_.length and "(%d)" % type_.length or "") - def visit_VARBINARY(self, type_): + def visit_VARBINARY(self, type_, **kw): return "VARBINARY" + (type_.length and "(%d)" % type_.length or "") - def visit_BOOLEAN(self, type_): + def visit_BOOLEAN(self, type_, **kw): return "BOOLEAN" - def visit_large_binary(self, type_): - return self.visit_BLOB(type_) + def visit_large_binary(self, type_, **kw): + return self.visit_BLOB(type_, **kw) - def visit_boolean(self, type_): - return self.visit_BOOLEAN(type_) + def visit_boolean(self, type_, **kw): + return self.visit_BOOLEAN(type_, **kw) - def visit_time(self, type_): - return self.visit_TIME(type_) + def visit_time(self, type_, **kw): + return self.visit_TIME(type_, **kw) - def visit_datetime(self, type_): - return self.visit_DATETIME(type_) + def visit_datetime(self, type_, **kw): + return self.visit_DATETIME(type_, **kw) - def visit_date(self, type_): - return self.visit_DATE(type_) + def visit_date(self, type_, **kw): + return self.visit_DATE(type_, **kw) - def visit_big_integer(self, type_): - return self.visit_BIGINT(type_) + def visit_big_integer(self, type_, **kw): + return self.visit_BIGINT(type_, **kw) - def visit_small_integer(self, type_): - return self.visit_SMALLINT(type_) + def visit_small_integer(self, type_, **kw): + return self.visit_SMALLINT(type_, **kw) - def visit_integer(self, type_): - return self.visit_INTEGER(type_) + def visit_integer(self, type_, **kw): + return self.visit_INTEGER(type_, **kw) - def visit_real(self, type_): - return self.visit_REAL(type_) + def visit_real(self, type_, **kw): + return self.visit_REAL(type_, **kw) - def visit_float(self, type_): - return self.visit_FLOAT(type_) + def visit_float(self, type_, **kw): + return self.visit_FLOAT(type_, **kw) - def visit_numeric(self, type_): - return self.visit_NUMERIC(type_) + def visit_numeric(self, type_, **kw): + return self.visit_NUMERIC(type_, **kw) - def visit_string(self, type_): - return self.visit_VARCHAR(type_) + def visit_string(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) - def visit_unicode(self, type_): - return self.visit_VARCHAR(type_) + def visit_unicode(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) - def visit_text(self, type_): - return self.visit_TEXT(type_) + def visit_text(self, type_, **kw): + return self.visit_TEXT(type_, **kw) - def visit_unicode_text(self, type_): - return self.visit_TEXT(type_) + def visit_unicode_text(self, type_, **kw): + return self.visit_TEXT(type_, **kw) - def visit_enum(self, type_): - return self.visit_VARCHAR(type_) + def visit_enum(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) - def visit_null(self, type_): + def visit_null(self, type_, **kw): raise exc.CompileError("Can't generate DDL for %r; " "did you forget to specify a " "type on this Column?" % type_) - def visit_type_decorator(self, type_): - return self.process(type_.type_engine(self.dialect)) + def visit_type_decorator(self, type_, **kw): + return self.process(type_.type_engine(self.dialect), **kw) - def visit_user_defined(self, type_): - return type_.get_col_spec() + def visit_user_defined(self, type_, **kw): + return type_.get_col_spec(**kw) class IdentifierPreparer(object): diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 831d05be1..2961f579f 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -116,11 +116,12 @@ def _get_crud_params(compiler, stmt, **kw): def _create_bind_param( - compiler, col, value, process=True, required=False, name=None): + compiler, col, value, process=True, + required=False, name=None): if name is None: name = col.key - bindparam = elements.BindParameter(name, value, - type_=col.type, required=required) + bindparam = elements.BindParameter( + name, value, type_=col.type, required=required) bindparam._is_crud = True if process: bindparam = bindparam._compiler_dispatch(compiler) @@ -300,13 +301,45 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): compiler.returning.append(c) else: values.append( - (c, _create_bind_param(compiler, c, None)) + (c, _create_prefetch_bind_param(compiler, c)) ) - compiler.prefetch.append(c) + else: compiler.returning.append(c) +def _create_prefetch_bind_param(compiler, c, process=True, name=None): + param = _create_bind_param(compiler, c, None, process=process, name=name) + compiler.prefetch.append(c) + return param + + +class _multiparam_column(elements.ColumnElement): + def __init__(self, original, index): + self.key = "%s_%d" % (original.key, index + 1) + self.original = original + self.default = original.default + + def __eq__(self, other): + return isinstance(other, _multiparam_column) and \ + other.key == self.key and \ + other.original == self.original + + +def _process_multiparam_default_bind(compiler, c, index, kw): + + if not c.default: + raise exc.CompileError( + "INSERT value for column %s is explicitly rendered as a bound" + "parameter in the VALUES clause; " + "a Python-side value or SQL expression is required" % c) + elif c.default.is_clause_element: + return compiler.process(c.default.arg.self_group(), **kw) + else: + col = _multiparam_column(c, index) + return _create_prefetch_bind_param(compiler, col) + + def _append_param_insert_pk(compiler, stmt, c, values, kw): if ( (c.default is not None and @@ -318,11 +351,9 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw): preexecute_autoincrement_sequences) ): values.append( - (c, _create_bind_param(compiler, c, None)) + (c, _create_prefetch_bind_param(compiler, c)) ) - compiler.prefetch.append(c) - def _append_param_insert_hasdefault( compiler, stmt, c, implicit_return_defaults, values, kw): @@ -350,9 +381,8 @@ def _append_param_insert_hasdefault( compiler.postfetch.append(c) else: values.append( - (c, _create_bind_param(compiler, c, None)) + (c, _create_prefetch_bind_param(compiler, c)) ) - compiler.prefetch.append(c) def _append_param_insert_select_hasdefault( @@ -369,9 +399,8 @@ def _append_param_insert_select_hasdefault( values.append((c, proc)) else: values.append( - (c, _create_bind_param(compiler, c, None, process=False)) + (c, _create_prefetch_bind_param(compiler, c, process=False)) ) - compiler.prefetch.append(c) def _append_param_update( @@ -390,9 +419,8 @@ def _append_param_update( compiler.postfetch.append(c) else: values.append( - (c, _create_bind_param(compiler, c, None)) + (c, _create_prefetch_bind_param(compiler, c)) ) - compiler.prefetch.append(c) elif c.server_onupdate is not None: if implicit_return_defaults and \ c in implicit_return_defaults: @@ -445,12 +473,9 @@ def _get_multitable_params( compiler.postfetch.append(c) else: values.append( - (c, _create_bind_param( - compiler, c, None, name=_col_bind_name(c) - ) - ) + (c, _create_prefetch_bind_param( + compiler, c, name=_col_bind_name(c))) ) - compiler.prefetch.append(c) elif c.server_onupdate is not None: compiler.postfetch.append(c) @@ -469,7 +494,8 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw): ) if elements._is_literal(row[c.key]) else compiler.process( row[c.key].self_group(), **kw)) - if c.key in row else param + if c.key in row else + _process_multiparam_default_bind(compiler, c, i, kw) ) for (c, param) in values_0 ] diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 1f2c448ea..7a1c7fef6 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -12,7 +12,6 @@ to invoke them for a create/drop call. from .. import util from .elements import ClauseElement -from .visitors import traverse from .base import Executable, _generative, SchemaVisitor, _bind_or_error from ..util import topological from .. import event @@ -370,7 +369,7 @@ class DDL(DDLElement): :class:`.DDLEvents` - :mod:`sqlalchemy.event` + :ref:`event_toplevel` """ @@ -464,19 +463,28 @@ class CreateTable(_CreateDropBase): __visit_name__ = "create_table" - def __init__(self, element, on=None, bind=None): + def __init__( + self, element, on=None, bind=None, + include_foreign_key_constraints=None): """Create a :class:`.CreateTable` construct. :param element: a :class:`.Table` that's the subject of the CREATE :param on: See the description for 'on' in :class:`.DDL`. :param bind: See the description for 'bind' in :class:`.DDL`. + :param include_foreign_key_constraints: optional sequence of + :class:`.ForeignKeyConstraint` objects that will be included + inline within the CREATE construct; if omitted, all foreign key + constraints that do not specify use_alter=True are included. + + .. versionadded:: 1.0.0 """ super(CreateTable, self).__init__(element, on=on, bind=bind) self.columns = [CreateColumn(column) for column in element.columns ] + self.include_foreign_key_constraints = include_foreign_key_constraints class _DropView(_CreateDropBase): @@ -696,8 +704,10 @@ class SchemaGenerator(DDLBase): tables = self.tables else: tables = list(metadata.tables.values()) - collection = [t for t in sort_tables(tables) - if self._can_create_table(t)] + + collection = sort_tables_and_constraints( + [t for t in tables if self._can_create_table(t)]) + seq_coll = [s for s in metadata._sequences.values() if s.column is None and self._can_create_sequence(s)] @@ -709,15 +719,23 @@ class SchemaGenerator(DDLBase): for seq in seq_coll: self.traverse_single(seq, create_ok=True) - for table in collection: - self.traverse_single(table, create_ok=True) + for table, fkcs in collection: + if table is not None: + self.traverse_single( + table, create_ok=True, + include_foreign_key_constraints=fkcs) + else: + for fkc in fkcs: + self.traverse_single(fkc) metadata.dispatch.after_create(metadata, self.connection, tables=collection, checkfirst=self.checkfirst, _ddl_runner=self) - def visit_table(self, table, create_ok=False): + def visit_table( + self, table, create_ok=False, + include_foreign_key_constraints=None): if not create_ok and not self._can_create_table(table): return @@ -729,7 +747,15 @@ class SchemaGenerator(DDLBase): if column.default is not None: self.traverse_single(column.default) - self.connection.execute(CreateTable(table)) + if not self.dialect.supports_alter: + # e.g., don't omit any foreign key constraints + include_foreign_key_constraints = None + + self.connection.execute( + CreateTable( + table, + include_foreign_key_constraints=include_foreign_key_constraints + )) if hasattr(table, 'indexes'): for index in table.indexes: @@ -739,6 +765,11 @@ class SchemaGenerator(DDLBase): checkfirst=self.checkfirst, _ddl_runner=self) + def visit_foreign_key_constraint(self, constraint): + if not self.dialect.supports_alter: + return + self.connection.execute(AddConstraint(constraint)) + def visit_sequence(self, sequence, create_ok=False): if not create_ok and not self._can_create_sequence(sequence): return @@ -765,11 +796,33 @@ class SchemaDropper(DDLBase): else: tables = list(metadata.tables.values()) - collection = [ - t - for t in reversed(sort_tables(tables)) - if self._can_drop_table(t) - ] + try: + collection = reversed( + sort_tables_and_constraints( + [t for t in tables if self._can_drop_table(t)], + filter_fn= + lambda constraint: True if not self.dialect.supports_alter + else False if constraint.name is None + else None + ) + ) + except exc.CircularDependencyError as err2: + util.raise_from_cause( + exc.CircularDependencyError( + err2.args[0], + err2.cycles, err2.edges, + msg="Can't sort tables for DROP; an " + "unresolvable foreign key " + "dependency exists between tables: %s. Please ensure " + "that the ForeignKey and ForeignKeyConstraint objects " + "involved in the cycle have " + "names so that they can be dropped using DROP CONSTRAINT." + % ( + ", ".join(sorted([t.fullname for t in err2.cycles])) + ) + + ) + ) seq_coll = [ s @@ -781,8 +834,13 @@ class SchemaDropper(DDLBase): metadata, self.connection, tables=collection, checkfirst=self.checkfirst, _ddl_runner=self) - for table in collection: - self.traverse_single(table, drop_ok=True) + for table, fkcs in collection: + if table is not None: + self.traverse_single( + table, drop_ok=True) + else: + for fkc in fkcs: + self.traverse_single(fkc) for seq in seq_coll: self.traverse_single(seq, drop_ok=True) @@ -830,6 +888,11 @@ class SchemaDropper(DDLBase): checkfirst=self.checkfirst, _ddl_runner=self) + def visit_foreign_key_constraint(self, constraint): + if not self.dialect.supports_alter: + return + self.connection.execute(DropConstraint(constraint)) + def visit_sequence(self, sequence, drop_ok=False): if not drop_ok and not self._can_drop_sequence(sequence): return @@ -837,32 +900,159 @@ class SchemaDropper(DDLBase): def sort_tables(tables, skip_fn=None, extra_dependencies=None): - """sort a collection of Table objects in order of - their foreign-key dependency.""" + """sort a collection of :class:`.Table` objects based on dependency. - tables = list(tables) - tuples = [] - if extra_dependencies is not None: - tuples.extend(extra_dependencies) + This is a dependency-ordered sort which will emit :class:`.Table` + objects such that they will follow their dependent :class:`.Table` objects. + Tables are dependent on another based on the presence of + :class:`.ForeignKeyConstraint` objects as well as explicit dependencies + added by :meth:`.Table.add_is_dependent_on`. - def visit_foreign_key(fkey): - if fkey.use_alter: - return - elif skip_fn and skip_fn(fkey): - return - parent_table = fkey.column.table - if parent_table in tables: - child_table = fkey.parent.table - if parent_table is not child_table: - tuples.append((parent_table, child_table)) + .. warning:: + + The :func:`.sort_tables` function cannot by itself accommodate + automatic resolution of dependency cycles between tables, which + are usually caused by mutually dependent foreign key constraints. + To resolve these cycles, either the + :paramref:`.ForeignKeyConstraint.use_alter` parameter may be appled + to those constraints, or use the + :func:`.sql.sort_tables_and_constraints` function which will break + out foreign key constraints involved in cycles separately. + + :param tables: a sequence of :class:`.Table` objects. + :param skip_fn: optional callable which will be passed a + :class:`.ForeignKey` object; if it returns True, this + constraint will not be considered as a dependency. Note this is + **different** from the same parameter in + :func:`.sort_tables_and_constraints`, which is + instead passed the owning :class:`.ForeignKeyConstraint` object. + + :param extra_dependencies: a sequence of 2-tuples of tables which will + also be considered as dependent on each other. + + .. seealso:: + + :func:`.sort_tables_and_constraints` + + :meth:`.MetaData.sorted_tables` - uses this function to sort + + + """ + + if skip_fn is not None: + def _skip_fn(fkc): + for fk in fkc.elements: + if skip_fn(fk): + return True + else: + return None + else: + _skip_fn = None + + return [ + t for (t, fkcs) in + sort_tables_and_constraints( + tables, filter_fn=_skip_fn, extra_dependencies=extra_dependencies) + if t is not None + ] + + +def sort_tables_and_constraints( + tables, filter_fn=None, extra_dependencies=None): + """sort a collection of :class:`.Table` / :class:`.ForeignKeyConstraint` + objects. + + This is a dependency-ordered sort which will emit tuples of + ``(Table, [ForeignKeyConstraint, ...])`` such that each + :class:`.Table` follows its dependent :class:`.Table` objects. + Remaining :class:`.ForeignKeyConstraint` objects that are separate due to + dependency rules not satisifed by the sort are emitted afterwards + as ``(None, [ForeignKeyConstraint ...])``. + + Tables are dependent on another based on the presence of + :class:`.ForeignKeyConstraint` objects, explicit dependencies + added by :meth:`.Table.add_is_dependent_on`, as well as dependencies + stated here using the :paramref:`~.sort_tables_and_constraints.skip_fn` + and/or :paramref:`~.sort_tables_and_constraints.extra_dependencies` + parameters. + + :param tables: a sequence of :class:`.Table` objects. + + :param filter_fn: optional callable which will be passed a + :class:`.ForeignKeyConstraint` object, and returns a value based on + whether this constraint should definitely be included or excluded as + an inline constraint, or neither. If it returns False, the constraint + will definitely be included as a dependency that cannot be subject + to ALTER; if True, it will **only** be included as an ALTER result at + the end. Returning None means the constraint is included in the + table-based result unless it is detected as part of a dependency cycle. + + :param extra_dependencies: a sequence of 2-tuples of tables which will + also be considered as dependent on each other. + + .. versionadded:: 1.0.0 + + .. seealso:: + + :func:`.sort_tables` + + + """ + + fixed_dependencies = set() + mutable_dependencies = set() + + if extra_dependencies is not None: + fixed_dependencies.update(extra_dependencies) + + remaining_fkcs = set() for table in tables: - traverse(table, - {'schema_visitor': True}, - {'foreign_key': visit_foreign_key}) + for fkc in table.foreign_key_constraints: + if fkc.use_alter is True: + remaining_fkcs.add(fkc) + continue + + if filter_fn: + filtered = filter_fn(fkc) + + if filtered is True: + remaining_fkcs.add(fkc) + continue - tuples.extend( - [parent, table] for parent in table._extra_dependencies + dependent_on = fkc.referred_table + if dependent_on is not table: + mutable_dependencies.add((dependent_on, table)) + + fixed_dependencies.update( + (parent, table) for parent in table._extra_dependencies + ) + + try: + candidate_sort = list( + topological.sort( + fixed_dependencies.union(mutable_dependencies), tables + ) + ) + except exc.CircularDependencyError as err: + for edge in err.edges: + if edge in mutable_dependencies: + table = edge[1] + can_remove = [ + fkc for fkc in table.foreign_key_constraints + if filter_fn is None or filter_fn(fkc) is not False] + remaining_fkcs.update(can_remove) + for fkc in can_remove: + dependent_on = fkc.referred_table + if dependent_on is not table: + mutable_dependencies.discard((dependent_on, table)) + candidate_sort = list( + topological.sort( + fixed_dependencies.union(mutable_dependencies), tables + ) ) - return list(topological.sort(tuples, tables)) + return [ + (table, table.foreign_key_constraints.difference(remaining_fkcs)) + for table in candidate_sort + ] + [(None, list(remaining_fkcs))] diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 4f53e2979..bb9e53aae 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -9,8 +9,8 @@ """ from .. import exc, util -from . import operators from . import type_api +from . import operators from .elements import BindParameter, True_, False_, BinaryExpression, \ Null, _const_expr, _clause_element_as_expr, \ ClauseList, ColumnElement, TextClause, UnaryExpression, \ @@ -18,294 +18,270 @@ from .elements import BindParameter, True_, False_, BinaryExpression, \ from .selectable import SelectBase, Alias, Selectable, ScalarSelect -class _DefaultColumnComparator(operators.ColumnOperators): - """Defines comparison and math operations. - - See :class:`.ColumnOperators` and :class:`.Operators` for descriptions - of all operations. - - """ - - @util.memoized_property - def type(self): - return self.expr.type - - def operate(self, op, *other, **kwargs): - o = self.operators[op.__name__] - return o[0](self, self.expr, op, *(other + o[1:]), **kwargs) - - def reverse_operate(self, op, other, **kwargs): - o = self.operators[op.__name__] - return o[0](self, self.expr, op, other, - reverse=True, *o[1:], **kwargs) - - def _adapt_expression(self, op, other_comparator): - """evaluate the return type of <self> <op> <othertype>, - and apply any adaptations to the given operator. - - This method determines the type of a resulting binary expression - given two source types and an operator. For example, two - :class:`.Column` objects, both of the type :class:`.Integer`, will - produce a :class:`.BinaryExpression` that also has the type - :class:`.Integer` when compared via the addition (``+``) operator. - However, using the addition operator with an :class:`.Integer` - and a :class:`.Date` object will produce a :class:`.Date`, assuming - "days delta" behavior by the database (in reality, most databases - other than Postgresql don't accept this particular operation). - - The method returns a tuple of the form <operator>, <type>. - The resulting operator and type will be those applied to the - resulting :class:`.BinaryExpression` as the final operator and the - right-hand side of the expression. - - Note that only a subset of operators make usage of - :meth:`._adapt_expression`, - including math operators and user-defined operators, but not - boolean comparison or special SQL keywords like MATCH or BETWEEN. - - """ - return op, other_comparator.type - - def _boolean_compare(self, expr, op, obj, negate=None, reverse=False, - _python_is_types=(util.NoneType, bool), - **kwargs): - - if isinstance(obj, _python_is_types + (Null, True_, False_)): - - # allow x ==/!= True/False to be treated as a literal. - # this comes out to "== / != true/false" or "1/0" if those - # constants aren't supported and works on all platforms - if op in (operators.eq, operators.ne) and \ - isinstance(obj, (bool, True_, False_)): - return BinaryExpression(expr, - _literal_as_text(obj), - op, - type_=type_api.BOOLEANTYPE, - negate=negate, modifiers=kwargs) - else: - # all other None/True/False uses IS, IS NOT - if op in (operators.eq, operators.is_): - return BinaryExpression(expr, _const_expr(obj), - operators.is_, - negate=operators.isnot) - elif op in (operators.ne, operators.isnot): - return BinaryExpression(expr, _const_expr(obj), - operators.isnot, - negate=operators.is_) - else: - raise exc.ArgumentError( - "Only '=', '!=', 'is_()', 'isnot()' operators can " - "be used with None/True/False") - else: - obj = self._check_literal(expr, op, obj) +def _boolean_compare(expr, op, obj, negate=None, reverse=False, + _python_is_types=(util.NoneType, bool), + result_type = None, + **kwargs): - if reverse: - return BinaryExpression(obj, - expr, - op, - type_=type_api.BOOLEANTYPE, - negate=negate, modifiers=kwargs) - else: + if result_type is None: + result_type = type_api.BOOLEANTYPE + + if isinstance(obj, _python_is_types + (Null, True_, False_)): + + # allow x ==/!= True/False to be treated as a literal. + # this comes out to "== / != true/false" or "1/0" if those + # constants aren't supported and works on all platforms + if op in (operators.eq, operators.ne) and \ + isinstance(obj, (bool, True_, False_)): return BinaryExpression(expr, - obj, + _literal_as_text(obj), op, - type_=type_api.BOOLEANTYPE, + type_=result_type, negate=negate, modifiers=kwargs) - - def _binary_operate(self, expr, op, obj, reverse=False, result_type=None, - **kw): - obj = self._check_literal(expr, op, obj) - - if reverse: - left, right = obj, expr - else: - left, right = expr, obj - - if result_type is None: - op, result_type = left.comparator._adapt_expression( - op, right.comparator) - - return BinaryExpression(left, right, op, type_=result_type) - - def _conjunction_operate(self, expr, op, other, **kw): - if op is operators.and_: - return and_(expr, other) - elif op is operators.or_: - return or_(expr, other) else: - raise NotImplementedError() - - def _scalar(self, expr, op, fn, **kw): - return fn(expr) - - def _in_impl(self, expr, op, seq_or_selectable, negate_op, **kw): - seq_or_selectable = _clause_element_as_expr(seq_or_selectable) - - if isinstance(seq_or_selectable, ScalarSelect): - return self._boolean_compare(expr, op, seq_or_selectable, - negate=negate_op) - elif isinstance(seq_or_selectable, SelectBase): - - # TODO: if we ever want to support (x, y, z) IN (select x, - # y, z from table), we would need a multi-column version of - # as_scalar() to produce a multi- column selectable that - # does not export itself as a FROM clause - - return self._boolean_compare( - expr, op, seq_or_selectable.as_scalar(), - negate=negate_op, **kw) - elif isinstance(seq_or_selectable, (Selectable, TextClause)): - return self._boolean_compare(expr, op, seq_or_selectable, - negate=negate_op, **kw) - elif isinstance(seq_or_selectable, ClauseElement): - raise exc.InvalidRequestError( - 'in_() accepts' - ' either a list of expressions ' - 'or a selectable: %r' % seq_or_selectable) - - # Handle non selectable arguments as sequences - args = [] - for o in seq_or_selectable: - if not _is_literal(o): - if not isinstance(o, operators.ColumnOperators): - raise exc.InvalidRequestError( - 'in_() accepts' - ' either a list of expressions ' - 'or a selectable: %r' % o) - elif o is None: - o = Null() + # all other None/True/False uses IS, IS NOT + if op in (operators.eq, operators.is_): + return BinaryExpression(expr, _const_expr(obj), + operators.is_, + negate=operators.isnot) + elif op in (operators.ne, operators.isnot): + return BinaryExpression(expr, _const_expr(obj), + operators.isnot, + negate=operators.is_) else: - o = expr._bind_param(op, o) - args.append(o) - if len(args) == 0: - - # Special case handling for empty IN's, behave like - # comparison against zero row selectable. We use != to - # build the contradiction as it handles NULL values - # appropriately, i.e. "not (x IN ())" should not return NULL - # values for x. - - util.warn('The IN-predicate on "%s" was invoked with an ' - 'empty sequence. This results in a ' - 'contradiction, which nonetheless can be ' - 'expensive to evaluate. Consider alternative ' - 'strategies for improved performance.' % expr) - if op is operators.in_op: - return expr != expr - else: - return expr == expr - - return self._boolean_compare(expr, op, - ClauseList(*args).self_group(against=op), - negate=negate_op) - - def _unsupported_impl(self, expr, op, *arg, **kw): - raise NotImplementedError("Operator '%s' is not supported on " - "this expression" % op.__name__) - - def _inv_impl(self, expr, op, **kw): - """See :meth:`.ColumnOperators.__inv__`.""" - if hasattr(expr, 'negation_clause'): - return expr.negation_clause + raise exc.ArgumentError( + "Only '=', '!=', 'is_()', 'isnot()' operators can " + "be used with None/True/False") + else: + obj = _check_literal(expr, op, obj) + + if reverse: + return BinaryExpression(obj, + expr, + op, + type_=result_type, + negate=negate, modifiers=kwargs) + else: + return BinaryExpression(expr, + obj, + op, + type_=result_type, + negate=negate, modifiers=kwargs) + + +def _binary_operate(expr, op, obj, reverse=False, result_type=None, + **kw): + obj = _check_literal(expr, op, obj) + + if reverse: + left, right = obj, expr + else: + left, right = expr, obj + + if result_type is None: + op, result_type = left.comparator._adapt_expression( + op, right.comparator) + + return BinaryExpression( + left, right, op, type_=result_type, modifiers=kw) + + +def _conjunction_operate(expr, op, other, **kw): + if op is operators.and_: + return and_(expr, other) + elif op is operators.or_: + return or_(expr, other) + else: + raise NotImplementedError() + + +def _scalar(expr, op, fn, **kw): + return fn(expr) + + +def _in_impl(expr, op, seq_or_selectable, negate_op, **kw): + seq_or_selectable = _clause_element_as_expr(seq_or_selectable) + + if isinstance(seq_or_selectable, ScalarSelect): + return _boolean_compare(expr, op, seq_or_selectable, + negate=negate_op) + elif isinstance(seq_or_selectable, SelectBase): + + # TODO: if we ever want to support (x, y, z) IN (select x, + # y, z from table), we would need a multi-column version of + # as_scalar() to produce a multi- column selectable that + # does not export itself as a FROM clause + + return _boolean_compare( + expr, op, seq_or_selectable.as_scalar(), + negate=negate_op, **kw) + elif isinstance(seq_or_selectable, (Selectable, TextClause)): + return _boolean_compare(expr, op, seq_or_selectable, + negate=negate_op, **kw) + elif isinstance(seq_or_selectable, ClauseElement): + raise exc.InvalidRequestError( + 'in_() accepts' + ' either a list of expressions ' + 'or a selectable: %r' % seq_or_selectable) + + # Handle non selectable arguments as sequences + args = [] + for o in seq_or_selectable: + if not _is_literal(o): + if not isinstance(o, operators.ColumnOperators): + raise exc.InvalidRequestError( + 'in_() accepts' + ' either a list of expressions ' + 'or a selectable: %r' % o) + elif o is None: + o = Null() else: - return expr._negate() - - def _neg_impl(self, expr, op, **kw): - """See :meth:`.ColumnOperators.__neg__`.""" - return UnaryExpression(expr, operator=operators.neg) - - def _match_impl(self, expr, op, other, **kw): - """See :meth:`.ColumnOperators.match`.""" - return self._boolean_compare( - expr, operators.match_op, - self._check_literal( - expr, operators.match_op, other), - **kw) - - def _distinct_impl(self, expr, op, **kw): - """See :meth:`.ColumnOperators.distinct`.""" - return UnaryExpression(expr, operator=operators.distinct_op, - type_=expr.type) - - def _between_impl(self, expr, op, cleft, cright, **kw): - """See :meth:`.ColumnOperators.between`.""" - return BinaryExpression( - expr, - ClauseList( - self._check_literal(expr, operators.and_, cleft), - self._check_literal(expr, operators.and_, cright), - operator=operators.and_, - group=False, group_contents=False), - op, - negate=operators.notbetween_op - if op is operators.between_op - else operators.between_op, - modifiers=kw) - - def _collate_impl(self, expr, op, other, **kw): - return collate(expr, other) - - # a mapping of operators with the method they use, along with - # their negated operator for comparison operators - operators = { - "and_": (_conjunction_operate,), - "or_": (_conjunction_operate,), - "inv": (_inv_impl,), - "add": (_binary_operate,), - "mul": (_binary_operate,), - "sub": (_binary_operate,), - "div": (_binary_operate,), - "mod": (_binary_operate,), - "truediv": (_binary_operate,), - "custom_op": (_binary_operate,), - "concat_op": (_binary_operate,), - "lt": (_boolean_compare, operators.ge), - "le": (_boolean_compare, operators.gt), - "ne": (_boolean_compare, operators.eq), - "gt": (_boolean_compare, operators.le), - "ge": (_boolean_compare, operators.lt), - "eq": (_boolean_compare, operators.ne), - "like_op": (_boolean_compare, operators.notlike_op), - "ilike_op": (_boolean_compare, operators.notilike_op), - "notlike_op": (_boolean_compare, operators.like_op), - "notilike_op": (_boolean_compare, operators.ilike_op), - "contains_op": (_boolean_compare, operators.notcontains_op), - "startswith_op": (_boolean_compare, operators.notstartswith_op), - "endswith_op": (_boolean_compare, operators.notendswith_op), - "desc_op": (_scalar, UnaryExpression._create_desc), - "asc_op": (_scalar, UnaryExpression._create_asc), - "nullsfirst_op": (_scalar, UnaryExpression._create_nullsfirst), - "nullslast_op": (_scalar, UnaryExpression._create_nullslast), - "in_op": (_in_impl, operators.notin_op), - "notin_op": (_in_impl, operators.in_op), - "is_": (_boolean_compare, operators.is_), - "isnot": (_boolean_compare, operators.isnot), - "collate": (_collate_impl,), - "match_op": (_match_impl,), - "distinct_op": (_distinct_impl,), - "between_op": (_between_impl, ), - "notbetween_op": (_between_impl, ), - "neg": (_neg_impl,), - "getitem": (_unsupported_impl,), - "lshift": (_unsupported_impl,), - "rshift": (_unsupported_impl,), - } - - def _check_literal(self, expr, operator, other): - if isinstance(other, (ColumnElement, TextClause)): - if isinstance(other, BindParameter) and \ - other.type._isnull: - other = other._clone() - other.type = expr.type - return other - elif hasattr(other, '__clause_element__'): - other = other.__clause_element__() - elif isinstance(other, type_api.TypeEngine.Comparator): - other = other.expr - - if isinstance(other, (SelectBase, Alias)): - return other.as_scalar() - elif not isinstance(other, (ColumnElement, TextClause)): - return expr._bind_param(operator, other) + o = expr._bind_param(op, o) + args.append(o) + if len(args) == 0: + + # Special case handling for empty IN's, behave like + # comparison against zero row selectable. We use != to + # build the contradiction as it handles NULL values + # appropriately, i.e. "not (x IN ())" should not return NULL + # values for x. + + util.warn('The IN-predicate on "%s" was invoked with an ' + 'empty sequence. This results in a ' + 'contradiction, which nonetheless can be ' + 'expensive to evaluate. Consider alternative ' + 'strategies for improved performance.' % expr) + if op is operators.in_op: + return expr != expr else: - return other + return expr == expr + + return _boolean_compare(expr, op, + ClauseList(*args).self_group(against=op), + negate=negate_op) + + +def _unsupported_impl(expr, op, *arg, **kw): + raise NotImplementedError("Operator '%s' is not supported on " + "this expression" % op.__name__) + + +def _inv_impl(expr, op, **kw): + """See :meth:`.ColumnOperators.__inv__`.""" + if hasattr(expr, 'negation_clause'): + return expr.negation_clause + else: + return expr._negate() + + +def _neg_impl(expr, op, **kw): + """See :meth:`.ColumnOperators.__neg__`.""" + return UnaryExpression(expr, operator=operators.neg) + + +def _match_impl(expr, op, other, **kw): + """See :meth:`.ColumnOperators.match`.""" + + return _boolean_compare( + expr, operators.match_op, + _check_literal( + expr, operators.match_op, other), + result_type=type_api.MATCHTYPE, + negate=operators.notmatch_op + if op is operators.match_op else operators.match_op, + **kw + ) + + +def _distinct_impl(expr, op, **kw): + """See :meth:`.ColumnOperators.distinct`.""" + return UnaryExpression(expr, operator=operators.distinct_op, + type_=expr.type) + + +def _between_impl(expr, op, cleft, cright, **kw): + """See :meth:`.ColumnOperators.between`.""" + return BinaryExpression( + expr, + ClauseList( + _check_literal(expr, operators.and_, cleft), + _check_literal(expr, operators.and_, cright), + operator=operators.and_, + group=False, group_contents=False), + op, + negate=operators.notbetween_op + if op is operators.between_op + else operators.between_op, + modifiers=kw) + + +def _collate_impl(expr, op, other, **kw): + return collate(expr, other) + +# a mapping of operators with the method they use, along with +# their negated operator for comparison operators +operator_lookup = { + "and_": (_conjunction_operate,), + "or_": (_conjunction_operate,), + "inv": (_inv_impl,), + "add": (_binary_operate,), + "mul": (_binary_operate,), + "sub": (_binary_operate,), + "div": (_binary_operate,), + "mod": (_binary_operate,), + "truediv": (_binary_operate,), + "custom_op": (_binary_operate,), + "concat_op": (_binary_operate,), + "lt": (_boolean_compare, operators.ge), + "le": (_boolean_compare, operators.gt), + "ne": (_boolean_compare, operators.eq), + "gt": (_boolean_compare, operators.le), + "ge": (_boolean_compare, operators.lt), + "eq": (_boolean_compare, operators.ne), + "like_op": (_boolean_compare, operators.notlike_op), + "ilike_op": (_boolean_compare, operators.notilike_op), + "notlike_op": (_boolean_compare, operators.like_op), + "notilike_op": (_boolean_compare, operators.ilike_op), + "contains_op": (_boolean_compare, operators.notcontains_op), + "startswith_op": (_boolean_compare, operators.notstartswith_op), + "endswith_op": (_boolean_compare, operators.notendswith_op), + "desc_op": (_scalar, UnaryExpression._create_desc), + "asc_op": (_scalar, UnaryExpression._create_asc), + "nullsfirst_op": (_scalar, UnaryExpression._create_nullsfirst), + "nullslast_op": (_scalar, UnaryExpression._create_nullslast), + "in_op": (_in_impl, operators.notin_op), + "notin_op": (_in_impl, operators.in_op), + "is_": (_boolean_compare, operators.is_), + "isnot": (_boolean_compare, operators.isnot), + "collate": (_collate_impl,), + "match_op": (_match_impl,), + "notmatch_op": (_match_impl,), + "distinct_op": (_distinct_impl,), + "between_op": (_between_impl, ), + "notbetween_op": (_between_impl, ), + "neg": (_neg_impl,), + "getitem": (_unsupported_impl,), + "lshift": (_unsupported_impl,), + "rshift": (_unsupported_impl,), +} + + +def _check_literal(expr, operator, other): + if isinstance(other, (ColumnElement, TextClause)): + if isinstance(other, BindParameter) and \ + other.type._isnull: + other = other._clone() + other.type = expr.type + return other + elif hasattr(other, '__clause_element__'): + other = other.__clause_element__() + elif isinstance(other, type_api.TypeEngine.Comparator): + other = other.expr + + if isinstance(other, (SelectBase, Alias)): + return other.as_scalar() + elif not isinstance(other, (ColumnElement, TextClause)): + return expr._bind_param(operator, other) + else: + return other + diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 9f2ce7ce3..38b3b8c44 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -277,6 +277,12 @@ class ValuesBase(UpdateBase): deals with an arbitrary number of rows, so the :attr:`.ResultProxy.inserted_primary_key` accessor does not apply. + .. versionchanged:: 1.0.0 A multiple-VALUES INSERT now supports + columns with Python side default values and callables in the + same way as that of an "executemany" style of invocation; the + callable is invoked for each row. See :ref:`bug_3288` + for other details. + .. seealso:: :ref:`inserts_and_updates` - SQL Expression @@ -387,7 +393,7 @@ class ValuesBase(UpdateBase): :func:`.mapper`. :param cols: optional list of column key names or :class:`.Column` - objects. If omitted, all column expressions evaulated on the server + objects. If omitted, all column expressions evaluated on the server are added to the returning list. .. versionadded:: 0.9.0 diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 734f78632..fa4f14fb9 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -1279,7 +1279,7 @@ class TextClause(Executable, ClauseElement): E.g.:: - fom sqlalchemy import text + from sqlalchemy import text t = text("SELECT * FROM users") result = connection.execute(t) @@ -2146,7 +2146,7 @@ class Case(ColumnElement): result of the ``CASE`` construct if all expressions within :paramref:`.case.whens` evaluate to false. When omitted, most databases will produce a result of NULL if none of the "when" - expressions evaulate to true. + expressions evaluate to true. """ @@ -2763,7 +2763,7 @@ class BinaryExpression(ColumnElement): self.right, self.negate, negate=self.operator, - type_=type_api.BOOLEANTYPE, + type_=self.type, modifiers=self.modifiers) else: return super(BinaryExpression, self)._negate() @@ -3387,7 +3387,7 @@ class ReleaseSavepointClause(_IdentifiedClause): __visit_name__ = 'release_savepoint' -class quoted_name(util.text_type): +class quoted_name(util.MemoizedSlots, util.text_type): """Represent a SQL identifier combined with quoting preferences. :class:`.quoted_name` is a Python unicode/str subclass which @@ -3431,6 +3431,8 @@ class quoted_name(util.text_type): """ + __slots__ = 'quote', 'lower', 'upper' + def __new__(cls, value, quote): if value is None: return None @@ -3450,15 +3452,13 @@ class quoted_name(util.text_type): def __reduce__(self): return quoted_name, (util.text_type(self), self.quote) - @util.memoized_instancemethod - def lower(self): + def _memoized_method_lower(self): if self.quote: return self else: return util.text_type(self).lower() - @util.memoized_instancemethod - def upper(self): + def _memoized_method_upper(self): if self.quote: return self else: @@ -3475,6 +3475,8 @@ class _truncated_label(quoted_name): """A unicode subclass used to identify symbolic " "names that may require truncation.""" + __slots__ = () + def __new__(cls, value, quote=None): quote = getattr(value, "quote", quote) # return super(_truncated_label, cls).__new__(cls, value, quote, True) @@ -3531,6 +3533,7 @@ class conv(_truncated_label): :ref:`constraint_naming_conventions` """ + __slots__ = () class _defer_name(_truncated_label): @@ -3538,6 +3541,8 @@ class _defer_name(_truncated_label): generation. """ + __slots__ = () + def __new__(cls, value): if value is None: return _NONE_NAME @@ -3552,6 +3557,7 @@ class _defer_name(_truncated_label): class _defer_none_name(_defer_name): """indicate a 'deferred' name that was ultimately the value None.""" + __slots__ = () _NONE_NAME = _defer_none_name("_unnamed_") @@ -3566,6 +3572,8 @@ class _anonymous_label(_truncated_label): """A unicode subclass used to identify anonymously generated names.""" + __slots__ = () + def __add__(self, other): return _anonymous_label( quoted_name( @@ -3732,7 +3740,8 @@ def _literal_as_text(element, warn=False): return _const_expr(element) else: raise exc.ArgumentError( - "SQL expression object or string expected." + "SQL expression object or string expected, got object of type %r " + "instead" % type(element) ) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 2ffc5468c..2218bd660 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -47,7 +47,7 @@ from .base import ColumnCollection, Generative, Executable, \ from .selectable import Alias, Join, Select, Selectable, TableClause, \ CompoundSelect, CTE, FromClause, FromGrouping, SelectBase, \ alias, GenerativeSelect, \ - subquery, HasPrefixes, Exists, ScalarSelect, TextAsFrom + subquery, HasPrefixes, HasSuffixes, Exists, ScalarSelect, TextAsFrom from .dml import Insert, Update, Delete, UpdateBase, ValuesBase diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 945356328..f71cba913 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -38,6 +38,7 @@ class Operators(object): :class:`.ColumnOperators`. """ + __slots__ = () def __and__(self, other): """Implement the ``&`` operator. @@ -137,7 +138,7 @@ class Operators(object): .. versionadded:: 0.8 - added the 'precedence' argument. :param is_comparison: if True, the operator will be considered as a - "comparison" operator, that is which evaulates to a boolean + "comparison" operator, that is which evaluates to a boolean true/false value, like ``==``, ``>``, etc. This flag should be set so that ORM relationships can establish that the operator is a comparison operator when used in a custom join condition. @@ -267,6 +268,8 @@ class ColumnOperators(Operators): """ + __slots__ = () + timetuple = None """Hack, allows datetime objects to be compared on the LHS.""" @@ -767,6 +770,10 @@ def match_op(a, b, **kw): return a.match(b, **kw) +def notmatch_op(a, b, **kw): + return a.notmatch(b, **kw) + + def comma_op(a, b): raise NotImplementedError() @@ -834,6 +841,7 @@ _PRECEDENCE = { concat_op: 6, match_op: 6, + notmatch_op: 6, ilike_op: 6, notilike_op: 6, diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index b90f7fc53..65a1da877 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -516,6 +516,19 @@ class Table(DialectKWArgs, SchemaItem, TableClause): """ return sorted(self.constraints, key=lambda c: c._creation_order) + @property + def foreign_key_constraints(self): + """:class:`.ForeignKeyConstraint` objects referred to by this + :class:`.Table`. + + This list is produced from the collection of :class:`.ForeignKey` + objects currently associated. + + .. versionadded:: 1.0.0 + + """ + return set(fkc.constraint for fkc in self.foreign_keys) + def _init_existing(self, *args, **kwargs): autoload_with = kwargs.pop('autoload_with', None) autoload = kwargs.pop('autoload', autoload_with is not None) @@ -1463,7 +1476,14 @@ class ForeignKey(DialectKWArgs, SchemaItem): :param use_alter: passed to the underlying :class:`.ForeignKeyConstraint` to indicate the constraint should be generated/dropped externally from the CREATE TABLE/ DROP TABLE - statement. See that classes' constructor for details. + statement. See :paramref:`.ForeignKeyConstraint.use_alter` + for further description. + + .. seealso:: + + :paramref:`.ForeignKeyConstraint.use_alter` + + :ref:`use_alter` :param match: Optional string. If set, emit MATCH <value> when issuing DDL for this constraint. Typical values include SIMPLE, PARTIAL @@ -2345,6 +2365,14 @@ def _to_schema_column_or_string(element): class ColumnCollectionMixin(object): + columns = None + """A :class:`.ColumnCollection` of :class:`.Column` objects. + + This collection represents the columns which are referred to by + this object. + + """ + def __init__(self, *columns): self.columns = ColumnCollection() self._pending_colargs = [_to_schema_column_or_string(c) @@ -2545,11 +2573,23 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): part of the CREATE TABLE definition. Instead, generate it via an ALTER TABLE statement issued after the full collection of tables have been created, and drop it via an ALTER TABLE statement before - the full collection of tables are dropped. This is shorthand for the - usage of :class:`.AddConstraint` and :class:`.DropConstraint` - applied as "after-create" and "before-drop" events on the MetaData - object. This is normally used to generate/drop constraints on - objects that are mutually dependent on each other. + the full collection of tables are dropped. + + The use of :paramref:`.ForeignKeyConstraint.use_alter` is + particularly geared towards the case where two or more tables + are established within a mutually-dependent foreign key constraint + relationship; however, the :meth:`.MetaData.create_all` and + :meth:`.MetaData.drop_all` methods will perform this resolution + automatically, so the flag is normally not needed. + + .. versionchanged:: 1.0.0 Automatic resolution of foreign key + cycles has been added, removing the need to use the + :paramref:`.ForeignKeyConstraint.use_alter` in typical use + cases. + + .. seealso:: + + :ref:`use_alter` :param match: Optional string. If set, emit MATCH <value> when issuing DDL for this constraint. Typical values include SIMPLE, PARTIAL @@ -2575,8 +2615,6 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): self.onupdate = onupdate self.ondelete = ondelete self.link_to_name = link_to_name - if self.name is None and use_alter: - raise exc.ArgumentError("Alterable Constraint requires a name") self.use_alter = use_alter self.match = match @@ -2624,6 +2662,20 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): else: return None + @property + def referred_table(self): + """The :class:`.Table` object to which this + :class:`.ForeignKeyConstraint` references. + + This is a dynamically calculated attribute which may not be available + if the constraint and/or parent table is not yet associated with + a metadata collection that contains the referred table. + + .. versionadded:: 1.0.0 + + """ + return self.elements[0].column.table + def _validate_dest_table(self, table): table_keys = set([elem._table_key() for elem in self.elements]) @@ -2681,15 +2733,6 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): self._validate_dest_table(table) - if self.use_alter: - def supports_alter(ddl, event, schema_item, bind, **kw): - return table in set(kw['tables']) and \ - bind.dialect.supports_alter - - event.listen(table.metadata, "after_create", - ddl.AddConstraint(self, on=supports_alter)) - event.listen(table.metadata, "before_drop", - ddl.DropConstraint(self, on=supports_alter)) def copy(self, schema=None, target_table=None, **kw): fkc = ForeignKeyConstraint( @@ -3333,12 +3376,30 @@ class MetaData(SchemaItem): order in which they can be created. To get the order in which the tables would be dropped, use the ``reversed()`` Python built-in. + .. warning:: + + The :attr:`.sorted_tables` accessor cannot by itself accommodate + automatic resolution of dependency cycles between tables, which + are usually caused by mutually dependent foreign key constraints. + To resolve these cycles, either the + :paramref:`.ForeignKeyConstraint.use_alter` parameter may be appled + to those constraints, or use the + :func:`.schema.sort_tables_and_constraints` function which will break + out foreign key constraints involved in cycles separately. + .. seealso:: + :func:`.schema.sort_tables` + + :func:`.schema.sort_tables_and_constraints` + :attr:`.MetaData.tables` :meth:`.Inspector.get_table_names` + :meth:`.Inspector.get_sorted_table_and_fkc_names` + + """ return ddl.sort_tables(self.tables.values()) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 8198a6733..87029ec2b 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -171,6 +171,79 @@ class Selectable(ClauseElement): return self +class HasPrefixes(object): + _prefixes = () + + @_generative + def prefix_with(self, *expr, **kw): + """Add one or more expressions following the statement keyword, i.e. + SELECT, INSERT, UPDATE, or DELETE. Generative. + + This is used to support backend-specific prefix keywords such as those + provided by MySQL. + + E.g.:: + + stmt = table.insert().prefix_with("LOW_PRIORITY", dialect="mysql") + + Multiple prefixes can be specified by multiple calls + to :meth:`.prefix_with`. + + :param \*expr: textual or :class:`.ClauseElement` construct which + will be rendered following the INSERT, UPDATE, or DELETE + keyword. + :param \**kw: A single keyword 'dialect' is accepted. This is an + optional string dialect name which will + limit rendering of this prefix to only that dialect. + + """ + dialect = kw.pop('dialect', None) + if kw: + raise exc.ArgumentError("Unsupported argument(s): %s" % + ",".join(kw)) + self._setup_prefixes(expr, dialect) + + def _setup_prefixes(self, prefixes, dialect=None): + self._prefixes = self._prefixes + tuple( + [(_literal_as_text(p, warn=False), dialect) for p in prefixes]) + + +class HasSuffixes(object): + _suffixes = () + + @_generative + def suffix_with(self, *expr, **kw): + """Add one or more expressions following the statement as a whole. + + This is used to support backend-specific suffix keywords on + certain constructs. + + E.g.:: + + stmt = select([col1, col2]).cte().suffix_with( + "cycle empno set y_cycle to 1 default 0", dialect="oracle") + + Multiple prefixes can be specified by multiple calls + to :meth:`.suffix_with`. + + :param \*expr: textual or :class:`.ClauseElement` construct which + will be rendered following the target clause. + :param \**kw: A single keyword 'dialect' is accepted. This is an + optional string dialect name which will + limit rendering of this suffix to only that dialect. + + """ + dialect = kw.pop('dialect', None) + if kw: + raise exc.ArgumentError("Unsupported argument(s): %s" % + ",".join(kw)) + self._setup_suffixes(expr, dialect) + + def _setup_suffixes(self, suffixes, dialect=None): + self._suffixes = self._suffixes + tuple( + [(_literal_as_text(p, warn=False), dialect) for p in suffixes]) + + class FromClause(Selectable): """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement. @@ -1088,7 +1161,7 @@ class Alias(FromClause): return self.element.bind -class CTE(Alias): +class CTE(Generative, HasSuffixes, Alias): """Represent a Common Table Expression. The :class:`.CTE` object is obtained using the @@ -1104,10 +1177,13 @@ class CTE(Alias): name=None, recursive=False, _cte_alias=None, - _restates=frozenset()): + _restates=frozenset(), + _suffixes=None): self.recursive = recursive self._cte_alias = _cte_alias self._restates = _restates + if _suffixes: + self._suffixes = _suffixes super(CTE, self).__init__(selectable, name=name) def alias(self, name=None, flat=False): @@ -1116,6 +1192,7 @@ class CTE(Alias): name=name, recursive=self.recursive, _cte_alias=self, + _suffixes=self._suffixes ) def union(self, other): @@ -1123,7 +1200,8 @@ class CTE(Alias): self.original.union(other), name=self.name, recursive=self.recursive, - _restates=self._restates.union([self]) + _restates=self._restates.union([self]), + _suffixes=self._suffixes ) def union_all(self, other): @@ -1131,7 +1209,8 @@ class CTE(Alias): self.original.union_all(other), name=self.name, recursive=self.recursive, - _restates=self._restates.union([self]) + _restates=self._restates.union([self]), + _suffixes=self._suffixes ) @@ -2118,44 +2197,7 @@ class CompoundSelect(GenerativeSelect): bind = property(bind, _set_bind) -class HasPrefixes(object): - _prefixes = () - - @_generative - def prefix_with(self, *expr, **kw): - """Add one or more expressions following the statement keyword, i.e. - SELECT, INSERT, UPDATE, or DELETE. Generative. - - This is used to support backend-specific prefix keywords such as those - provided by MySQL. - - E.g.:: - - stmt = table.insert().prefix_with("LOW_PRIORITY", dialect="mysql") - - Multiple prefixes can be specified by multiple calls - to :meth:`.prefix_with`. - - :param \*expr: textual or :class:`.ClauseElement` construct which - will be rendered following the INSERT, UPDATE, or DELETE - keyword. - :param \**kw: A single keyword 'dialect' is accepted. This is an - optional string dialect name which will - limit rendering of this prefix to only that dialect. - - """ - dialect = kw.pop('dialect', None) - if kw: - raise exc.ArgumentError("Unsupported argument(s): %s" % - ",".join(kw)) - self._setup_prefixes(expr, dialect) - - def _setup_prefixes(self, prefixes, dialect=None): - self._prefixes = self._prefixes + tuple( - [(_literal_as_text(p, warn=False), dialect) for p in prefixes]) - - -class Select(HasPrefixes, GenerativeSelect): +class Select(HasPrefixes, HasSuffixes, GenerativeSelect): """Represents a ``SELECT`` statement. """ @@ -2163,6 +2205,7 @@ class Select(HasPrefixes, GenerativeSelect): __visit_name__ = 'select' _prefixes = () + _suffixes = () _hints = util.immutabledict() _statement_hints = () _distinct = False @@ -2180,6 +2223,7 @@ class Select(HasPrefixes, GenerativeSelect): having=None, correlate=True, prefixes=None, + suffixes=None, **kwargs): """Construct a new :class:`.Select`. @@ -2425,6 +2469,9 @@ class Select(HasPrefixes, GenerativeSelect): if prefixes: self._setup_prefixes(prefixes) + if suffixes: + self._setup_suffixes(suffixes) + GenerativeSelect.__init__(self, **kwargs) @property diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 7bf2f337c..bd1914da3 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -14,7 +14,6 @@ import codecs from .type_api import TypeEngine, TypeDecorator, to_instance from .elements import quoted_name, type_coerce, _defer_name -from .default_comparator import _DefaultColumnComparator from .. import exc, util, processors from .base import _bind_or_error, SchemaEventTarget from . import operators @@ -894,7 +893,7 @@ class LargeBinary(_Binary): :param length: optional, a length for the column for use in DDL statements, for those BLOB types that accept a length - (i.e. MySQL). It does *not* produce a small BINARY/VARBINARY + (i.e. MySQL). It does *not* produce a *lengthed* BINARY/VARBINARY type - use the BINARY/VARBINARY types specifically for those. May be safely omitted if no ``CREATE TABLE`` will be issued. Certain databases may require a @@ -1147,6 +1146,7 @@ class Enum(String, SchemaType): def __repr__(self): return util.generic_repr(self, + additional_kw=[('native_enum', True)], to_inspect=[Enum, SchemaType], ) @@ -1654,10 +1654,26 @@ class NullType(TypeEngine): comparator_factory = Comparator +class MatchType(Boolean): + """Refers to the return type of the MATCH operator. + + As the :meth:`.ColumnOperators.match` is probably the most open-ended + operator in generic SQLAlchemy Core, we can't assume the return type + at SQL evaluation time, as MySQL returns a floating point, not a boolean, + and other backends might do something different. So this type + acts as a placeholder, currently subclassing :class:`.Boolean`. + The type allows dialects to inject result-processing functionality + if needed, and on MySQL will return floating-point values. + + .. versionadded:: 1.0.0 + + """ + NULLTYPE = NullType() BOOLEANTYPE = Boolean() STRINGTYPE = String() INTEGERTYPE = Integer() +MATCHTYPE = MatchType() _type_map = { int: Integer(), @@ -1685,21 +1701,7 @@ type_api.BOOLEANTYPE = BOOLEANTYPE type_api.STRINGTYPE = STRINGTYPE type_api.INTEGERTYPE = INTEGERTYPE type_api.NULLTYPE = NULLTYPE +type_api.MATCHTYPE = MATCHTYPE type_api._type_map = _type_map -# this one, there's all kinds of ways to play it, but at the EOD -# there's just a giant dependency cycle between the typing system and -# the expression element system, as you might expect. We can use -# importlaters or whatnot, but the typing system just necessarily has -# to have some kind of connection like this. right now we're injecting the -# _DefaultColumnComparator implementation into the TypeEngine.Comparator -# interface. Alternatively TypeEngine.Comparator could have an "impl" -# injected, though just injecting the base is simpler, error free, and more -# performant. - - -class Comparator(_DefaultColumnComparator): - BOOLEANTYPE = BOOLEANTYPE - -TypeEngine.Comparator.__bases__ = ( - Comparator, ) + TypeEngine.Comparator.__bases__ +TypeEngine.Comparator.BOOLEANTYPE = BOOLEANTYPE diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 77c6e1b1e..19398ae96 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -12,13 +12,14 @@ from .. import exc, util from . import operators -from .visitors import Visitable +from .visitors import Visitable, VisitableType # these are back-assigned by sqltypes. BOOLEANTYPE = None INTEGERTYPE = None NULLTYPE = None STRINGTYPE = None +MATCHTYPE = None class TypeEngine(Visitable): @@ -45,9 +46,51 @@ class TypeEngine(Visitable): """ + __slots__ = 'expr', 'type' + + default_comparator = None def __init__(self, expr): self.expr = expr + self.type = expr.type + + @util.dependencies('sqlalchemy.sql.default_comparator') + def operate(self, default_comparator, op, *other, **kwargs): + o = default_comparator.operator_lookup[op.__name__] + return o[0](self.expr, op, *(other + o[1:]), **kwargs) + + @util.dependencies('sqlalchemy.sql.default_comparator') + def reverse_operate(self, default_comparator, op, other, **kwargs): + o = default_comparator.operator_lookup[op.__name__] + return o[0](self.expr, op, other, + reverse=True, *o[1:], **kwargs) + + def _adapt_expression(self, op, other_comparator): + """evaluate the return type of <self> <op> <othertype>, + and apply any adaptations to the given operator. + + This method determines the type of a resulting binary expression + given two source types and an operator. For example, two + :class:`.Column` objects, both of the type :class:`.Integer`, will + produce a :class:`.BinaryExpression` that also has the type + :class:`.Integer` when compared via the addition (``+``) operator. + However, using the addition operator with an :class:`.Integer` + and a :class:`.Date` object will produce a :class:`.Date`, assuming + "days delta" behavior by the database (in reality, most databases + other than Postgresql don't accept this particular operation). + + The method returns a tuple of the form <operator>, <type>. + The resulting operator and type will be those applied to the + resulting :class:`.BinaryExpression` as the final operator and the + right-hand side of the expression. + + Note that only a subset of operators make usage of + :meth:`._adapt_expression`, + including math operators and user-defined operators, but not + boolean comparison or special SQL keywords like MATCH or BETWEEN. + + """ + return op, other_comparator.type def __reduce__(self): return _reconstitute_comparator, (self.expr, ) @@ -252,7 +295,7 @@ class TypeEngine(Visitable): The construction of :meth:`.TypeEngine.with_variant` is always from the "fallback" type to that which is dialect specific. The returned type is an instance of :class:`.Variant`, which - itself provides a :meth:`~sqlalchemy.types.Variant.with_variant` + itself provides a :meth:`.Variant.with_variant` that can be called repeatedly. :param type_: a :class:`.TypeEngine` that will be selected @@ -417,7 +460,11 @@ class TypeEngine(Visitable): return util.generic_repr(self) -class UserDefinedType(TypeEngine): +class VisitableCheckKWArg(util.EnsureKWArgType, VisitableType): + pass + + +class UserDefinedType(util.with_metaclass(VisitableCheckKWArg, TypeEngine)): """Base for user defined types. This should be the base of new types. Note that @@ -430,7 +477,7 @@ class UserDefinedType(TypeEngine): def __init__(self, precision = 8): self.precision = precision - def get_col_spec(self): + def get_col_spec(self, **kw): return "MYTYPE(%s)" % self.precision def bind_processor(self, dialect): @@ -450,10 +497,26 @@ class UserDefinedType(TypeEngine): Column('data', MyType(16)) ) + The ``get_col_spec()`` method will in most cases receive a keyword + argument ``type_expression`` which refers to the owning expression + of the type as being compiled, such as a :class:`.Column` or + :func:`.cast` construct. This keyword is only sent if the method + accepts keyword arguments (e.g. ``**kw``) in its argument signature; + introspection is used to check for this in order to support legacy + forms of this function. + + .. versionadded:: 1.0.0 the owning expression is passed to + the ``get_col_spec()`` method via the keyword argument + ``type_expression``, if it receives ``**kw`` in its signature. + """ __visit_name__ = "user_defined" + ensure_kwarg = 'get_col_spec' + class Comparator(TypeEngine.Comparator): + __slots__ = () + def _adapt_expression(self, op, other_comparator): if hasattr(self.type, 'adapt_operator'): util.warn_deprecated( @@ -617,6 +680,7 @@ class TypeDecorator(TypeEngine): """ class Comparator(TypeEngine.Comparator): + __slots__ = () def operate(self, op, *other, **kwargs): kwargs['_python_is_types'] = self.expr.type.coerce_to_is_types @@ -630,9 +694,13 @@ class TypeDecorator(TypeEngine): @property def comparator_factory(self): - return type("TDComparator", - (TypeDecorator.Comparator, self.impl.comparator_factory), - {}) + if TypeDecorator.Comparator in self.impl.comparator_factory.__mro__: + return self.impl.comparator_factory + else: + return type("TDComparator", + (TypeDecorator.Comparator, + self.impl.comparator_factory), + {}) def _gen_dialect_impl(self, dialect): """ diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index bb525744a..d09b82148 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -51,6 +51,7 @@ class VisitableType(type): Classes having no __visit_name__ attribute will remain unaffected. """ + def __init__(cls, clsname, bases, clsdict): if clsname != 'Visitable' and \ hasattr(cls, '__visit_name__'): diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index 1f37b4b45..2375a13a9 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -23,7 +23,8 @@ from .assertions import emits_warning, emits_warning_on, uses_deprecated, \ assert_raises_message, AssertsCompiledSQL, ComparesTables, \ AssertsExecutionResults, expect_deprecated, expect_warnings -from .util import run_as_contextmanager, rowset, fail, provide_metadata, adict +from .util import run_as_contextmanager, rowset, fail, \ + provide_metadata, adict, force_drop_names crashes = skip diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index bf7c27a89..635f6c539 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -229,6 +229,7 @@ class AssertsCompiledSQL(object): def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None, checkpositional=None, + check_prefetch=None, use_default_dialect=False, allow_dialect_select=False, literal_binds=False): @@ -289,6 +290,8 @@ class AssertsCompiledSQL(object): if checkpositional is not None: p = c.construct_params(params) eq_(tuple([p[x] for x in c.positiontup]), checkpositional) + if check_prefetch is not None: + eq_(c.prefetch, check_prefetch) class ComparesTables(object): @@ -405,29 +408,27 @@ class AssertsExecutionResults(object): cls.__name__, repr(expected_item))) return True + def sql_execution_asserter(self, db=None): + if db is None: + from . import db as db + + return assertsql.assert_engine(db) + def assert_sql_execution(self, db, callable_, *rules): - assertsql.asserter.add_rules(rules) - try: + with self.sql_execution_asserter(db) as asserter: callable_() - assertsql.asserter.statement_complete() - finally: - assertsql.asserter.clear_rules() + asserter.assert_(*rules) - def assert_sql(self, db, callable_, list_, with_sequences=None): - if (with_sequences is not None and - config.db.dialect.supports_sequences): - rules = with_sequences - else: - rules = list_ + def assert_sql(self, db, callable_, rules): newrules = [] for rule in rules: if isinstance(rule, dict): newrule = assertsql.AllOf(*[ - assertsql.ExactSQL(k, v) for k, v in rule.items() + assertsql.CompiledSQL(k, v) for k, v in rule.items() ]) else: - newrule = assertsql.ExactSQL(*rule) + newrule = assertsql.CompiledSQL(*rule) newrules.append(newrule) self.assert_sql_execution(db, callable_, *newrules) diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index bcc999fe3..5c746e8f1 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -8,84 +8,141 @@ from ..engine.default import DefaultDialect from .. import util import re +import collections +import contextlib +from .. import event +from sqlalchemy.schema import _DDLCompiles +from sqlalchemy.engine.util import _distill_params class AssertRule(object): - def process_execute(self, clauseelement, *multiparams, **params): - pass + is_consumed = False + errormessage = None + consume_statement = True - def process_cursor_execute(self, statement, parameters, context, - executemany): + def process_statement(self, execute_observed): pass - def is_consumed(self): - """Return True if this rule has been consumed, False if not. - - Should raise an AssertionError if this rule's condition has - definitely failed. - - """ - - raise NotImplementedError() + def no_more_statements(self): + assert False, 'All statements are complete, but pending '\ + 'assertion rules remain' - def rule_passed(self): - """Return True if the last test of this rule passed, False if - failed, None if no test was applied.""" - raise NotImplementedError() - - def consume_final(self): - """Return True if this rule has been consumed. - - Should raise an AssertionError if this rule's condition has not - been consumed or has failed. +class SQLMatchRule(AssertRule): + pass - """ - if self._result is None: - assert False, 'Rule has not been consumed' - return self.is_consumed() +class CursorSQL(SQLMatchRule): + consume_statement = False + def __init__(self, statement, params=None): + self.statement = statement + self.params = params -class SQLMatchRule(AssertRule): - def __init__(self): - self._result = None - self._errmsg = "" + def process_statement(self, execute_observed): + stmt = execute_observed.statements[0] + if self.statement != stmt.statement or ( + self.params is not None and self.params != stmt.parameters): + self.errormessage = \ + "Testing for exact SQL %s parameters %s received %s %s" % ( + self.statement, self.params, + stmt.statement, stmt.parameters + ) + else: + execute_observed.statements.pop(0) + self.is_consumed = True + if not execute_observed.statements: + self.consume_statement = True - def rule_passed(self): - return self._result - def is_consumed(self): - if self._result is None: - return False +class CompiledSQL(SQLMatchRule): - assert self._result, self._errmsg + def __init__(self, statement, params=None): + self.statement = statement + self.params = params - return True + def _compare_sql(self, execute_observed, received_statement): + stmt = re.sub(r'[\n\t]', '', self.statement) + return received_statement == stmt + def _compile_dialect(self, execute_observed): + return DefaultDialect() -class ExactSQL(SQLMatchRule): + def _received_statement(self, execute_observed): + """reconstruct the statement and params in terms + of a target dialect, which for CompiledSQL is just DefaultDialect.""" - def __init__(self, sql, params=None): - SQLMatchRule.__init__(self) - self.sql = sql - self.params = params + context = execute_observed.context + compare_dialect = self._compile_dialect(execute_observed) + if isinstance(context.compiled.statement, _DDLCompiles): + compiled = \ + context.compiled.statement.compile(dialect=compare_dialect) + else: + compiled = ( + context.compiled.statement.compile( + dialect=compare_dialect, + column_keys=context.compiled.column_keys, + inline=context.compiled.inline) + ) + _received_statement = re.sub(r'[\n\t]', '', str(compiled)) + parameters = execute_observed.parameters - def process_cursor_execute(self, statement, parameters, context, - executemany): - if not context: - return - _received_statement = \ - _process_engine_statement(context.unicode_statement, - context) - _received_parameters = context.compiled_parameters + if not parameters: + _received_parameters = [compiled.construct_params()] + else: + _received_parameters = [ + compiled.construct_params(m) for m in parameters] + + return _received_statement, _received_parameters + + def process_statement(self, execute_observed): + context = execute_observed.context + + _received_statement, _received_parameters = \ + self._received_statement(execute_observed) + params = self._all_params(context) + + equivalent = self._compare_sql(execute_observed, _received_statement) + + if equivalent: + if params is not None: + all_params = list(params) + all_received = list(_received_parameters) + while all_params and all_received: + param = dict(all_params.pop(0)) + + for idx, received in enumerate(list(all_received)): + # do a positive compare only + for param_key in param: + # a key in param did not match current + # 'received' + if param_key not in received or \ + received[param_key] != param[param_key]: + break + else: + # all keys in param matched 'received'; + # onto next param + del all_received[idx] + break + else: + # param did not match any entry + # in all_received + equivalent = False + break + if all_params or all_received: + equivalent = False - # TODO: remove this step once all unit tests are migrated, as - # ExactSQL should really be *exact* SQL + if equivalent: + self.is_consumed = True + self.errormessage = None + else: + self.errormessage = self._failure_message(params) % { + 'received_statement': _received_statement, + 'received_parameters': _received_parameters + } - sql = _process_assertion_statement(self.sql, context) - equivalent = _received_statement == sql + def _all_params(self, context): if self.params: if util.callable(self.params): params = self.params(context) @@ -93,127 +150,77 @@ class ExactSQL(SQLMatchRule): params = self.params if not isinstance(params, list): params = [params] - equivalent = equivalent and params \ - == context.compiled_parameters + return params else: - params = {} - self._result = equivalent - if not self._result: - self._errmsg = ( - 'Testing for exact statement %r exact params %r, ' - 'received %r with params %r' % - (sql, params, _received_statement, _received_parameters)) - + return None + + def _failure_message(self, expected_params): + return ( + 'Testing for compiled statement %r partial params %r, ' + 'received %%(received_statement)r with params ' + '%%(received_parameters)r' % ( + self.statement, expected_params + ) + ) -class RegexSQL(SQLMatchRule): +class RegexSQL(CompiledSQL): def __init__(self, regex, params=None): SQLMatchRule.__init__(self) self.regex = re.compile(regex) self.orig_regex = regex self.params = params - def process_cursor_execute(self, statement, parameters, context, - executemany): - if not context: - return - _received_statement = \ - _process_engine_statement(context.unicode_statement, - context) - _received_parameters = context.compiled_parameters - equivalent = bool(self.regex.match(_received_statement)) - if self.params: - if util.callable(self.params): - params = self.params(context) - else: - params = self.params - if not isinstance(params, list): - params = [params] - - # do a positive compare only - - for param, received in zip(params, _received_parameters): - for k, v in param.items(): - if k not in received or received[k] != v: - equivalent = False - break - else: - params = {} - self._result = equivalent - if not self._result: - self._errmsg = \ - 'Testing for regex %r partial params %r, received %r '\ - 'with params %r' % (self.orig_regex, params, - _received_statement, - _received_parameters) - + def _failure_message(self, expected_params): + return ( + 'Testing for compiled statement ~%r partial params %r, ' + 'received %%(received_statement)r with params ' + '%%(received_parameters)r' % ( + self.orig_regex, expected_params + ) + ) -class CompiledSQL(SQLMatchRule): + def _compare_sql(self, execute_observed, received_statement): + return bool(self.regex.match(received_statement)) - def __init__(self, statement, params=None): - SQLMatchRule.__init__(self) - self.statement = statement - self.params = params - def process_cursor_execute(self, statement, parameters, context, - executemany): - if not context: - return - from sqlalchemy.schema import _DDLCompiles - _received_parameters = list(context.compiled_parameters) - - # recompile from the context, using the default dialect +class DialectSQL(CompiledSQL): + def _compile_dialect(self, execute_observed): + return execute_observed.context.dialect - if isinstance(context.compiled.statement, _DDLCompiles): - compiled = \ - context.compiled.statement.compile(dialect=DefaultDialect()) + def _received_statement(self, execute_observed): + received_stmt, received_params = super(DialectSQL, self).\ + _received_statement(execute_observed) + for real_stmt in execute_observed.statements: + if real_stmt.statement == received_stmt: + break else: - compiled = ( - context.compiled.statement.compile( - dialect=DefaultDialect(), - column_keys=context.compiled.column_keys) - ) - _received_statement = re.sub(r'[\n\t]', '', str(compiled)) - equivalent = self.statement == _received_statement - if self.params: - if util.callable(self.params): - params = self.params(context) - else: - params = self.params - if not isinstance(params, list): - params = [params] - else: - params = list(params) - all_params = list(params) - all_received = list(_received_parameters) - while params: - param = dict(params.pop(0)) - for k, v in context.compiled.params.items(): - param.setdefault(k, v) - if param not in _received_parameters: - equivalent = False - break - else: - _received_parameters.remove(param) - if _received_parameters: - equivalent = False + raise AssertionError( + "Can't locate compiled statement %r in list of " + "statements actually invoked" % received_stmt) + return received_stmt, execute_observed.context.compiled_parameters + + def _compare_sql(self, execute_observed, received_statement): + stmt = re.sub(r'[\n\t]', '', self.statement) + + # convert our comparison statement to have the + # paramstyle of the received + paramstyle = execute_observed.context.dialect.paramstyle + if paramstyle == 'pyformat': + stmt = re.sub( + r':([\w_]+)', r"%(\1)s", stmt) else: - params = {} - all_params = {} - all_received = [] - self._result = equivalent - if not self._result: - print('Testing for compiled statement %r partial params ' - '%r, received %r with params %r' % - (self.statement, all_params, - _received_statement, all_received)) - self._errmsg = ( - 'Testing for compiled statement %r partial params %r, ' - 'received %r with params %r' % - (self.statement, all_params, - _received_statement, all_received)) - - # print self._errmsg + # positional params + repl = None + if paramstyle == 'qmark': + repl = "?" + elif paramstyle == 'format': + repl = r"%s" + elif paramstyle == 'numeric': + repl = None + stmt = re.sub(r':([\w_]+)', repl, stmt) + + return received_statement == stmt class CountStatements(AssertRule): @@ -222,21 +229,13 @@ class CountStatements(AssertRule): self.count = count self._statement_count = 0 - def process_execute(self, clauseelement, *multiparams, **params): + def process_statement(self, execute_observed): self._statement_count += 1 - def process_cursor_execute(self, statement, parameters, context, - executemany): - pass - - def is_consumed(self): - return False - - def consume_final(self): - assert self.count == self._statement_count, \ - 'desired statement count %d does not match %d' \ - % (self.count, self._statement_count) - return True + def no_more_statements(self): + if self.count != self._statement_count: + assert False, 'desired statement count %d does not match %d' \ + % (self.count, self._statement_count) class AllOf(AssertRule): @@ -244,116 +243,113 @@ class AllOf(AssertRule): def __init__(self, *rules): self.rules = set(rules) - def process_execute(self, clauseelement, *multiparams, **params): - for rule in self.rules: - rule.process_execute(clauseelement, *multiparams, **params) - - def process_cursor_execute(self, statement, parameters, context, - executemany): - for rule in self.rules: - rule.process_cursor_execute(statement, parameters, context, - executemany) - - def is_consumed(self): - if not self.rules: - return True + def process_statement(self, execute_observed): for rule in list(self.rules): - if rule.rule_passed(): # a rule passed, move on - self.rules.remove(rule) - return len(self.rules) == 0 - return False + rule.errormessage = None + rule.process_statement(execute_observed) + if rule.is_consumed: + self.rules.discard(rule) + if not self.rules: + self.is_consumed = True + break + elif not rule.errormessage: + # rule is not done yet + self.errormessage = None + break + else: + self.errormessage = list(self.rules)[0].errormessage - def rule_passed(self): - return self.is_consumed() - def consume_final(self): - return len(self.rules) == 0 +class Or(AllOf): + def process_statement(self, execute_observed): + for rule in self.rules: + rule.process_statement(execute_observed) + if rule.is_consumed: + self.is_consumed = True + break + else: + self.errormessage = list(self.rules)[0].errormessage -class Or(AllOf): - def __init__(self, *rules): - self.rules = set(rules) - self._consume_final = False - def is_consumed(self): - if not self.rules: - return True - for rule in list(self.rules): - if rule.rule_passed(): # a rule passed - self._consume_final = True - return True - return False +class SQLExecuteObserved(object): + def __init__(self, context, clauseelement, multiparams, params): + self.context = context + self.clauseelement = clauseelement + self.parameters = _distill_params(multiparams, params) + self.statements = [] - def consume_final(self): - assert self._consume_final, "Unsatisified rules remain" +class SQLCursorExecuteObserved( + collections.namedtuple( + "SQLCursorExecuteObserved", + ["statement", "parameters", "context", "executemany"]) +): + pass -def _process_engine_statement(query, context): - if util.jython: - # oracle+zxjdbc passes a PyStatement when returning into +class SQLAsserter(object): + def __init__(self): + self.accumulated = [] - query = str(query) - if context.engine.name == 'mssql' \ - and query.endswith('; select scope_identity()'): - query = query[:-25] - query = re.sub(r'\n', '', query) - return query + def _close(self): + self._final = self.accumulated + del self.accumulated + def assert_(self, *rules): + rules = list(rules) + observed = list(self._final) -def _process_assertion_statement(query, context): - paramstyle = context.dialect.paramstyle - if paramstyle == 'named': - pass - elif paramstyle == 'pyformat': - query = re.sub(r':([\w_]+)', r"%(\1)s", query) - else: - # positional params - repl = None - if paramstyle == 'qmark': - repl = "?" - elif paramstyle == 'format': - repl = r"%s" - elif paramstyle == 'numeric': - repl = None - query = re.sub(r':([\w_]+)', repl, query) + while observed and rules: + rule = rules[0] + rule.process_statement(observed[0]) + if rule.is_consumed: + rules.pop(0) + elif rule.errormessage: + assert False, rule.errormessage - return query + if rule.consume_statement: + observed.pop(0) + if not observed and rules: + rules[0].no_more_statements() + elif not rules and observed: + assert False, "Additional SQL statements remain" -class SQLAssert(object): - rules = None +@contextlib.contextmanager +def assert_engine(engine): + asserter = SQLAsserter() - def add_rules(self, rules): - self.rules = list(rules) + orig = [] - def statement_complete(self): - for rule in self.rules: - if not rule.consume_final(): - assert False, \ - 'All statements are complete, but pending '\ - 'assertion rules remain' - - def clear_rules(self): - del self.rules - - def execute(self, conn, clauseelement, multiparams, params, result): - if self.rules is not None: - if not self.rules: - assert False, \ - 'All rules have been exhausted, but further '\ - 'statements remain' - rule = self.rules[0] - rule.process_execute(clauseelement, *multiparams, **params) - if rule.is_consumed(): - self.rules.pop(0) - - def cursor_execute(self, conn, cursor, statement, parameters, - context, executemany): - if self.rules: - rule = self.rules[0] - rule.process_cursor_execute(statement, parameters, context, - executemany) + @event.listens_for(engine, "before_execute") + def connection_execute(conn, clauseelement, multiparams, params): + # grab the original statement + params before any cursor + # execution + orig[:] = clauseelement, multiparams, params -asserter = SQLAssert() + @event.listens_for(engine, "after_cursor_execute") + def cursor_execute(conn, cursor, statement, parameters, + context, executemany): + if not context: + return + # then grab real cursor statements and associate them all + # around a single context + if asserter.accumulated and \ + asserter.accumulated[-1].context is context: + obs = asserter.accumulated[-1] + else: + obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2]) + asserter.accumulated.append(obs) + obs.statements.append( + SQLCursorExecuteObserved( + statement, parameters, context, executemany) + ) + + try: + yield asserter + finally: + event.remove(engine, "after_cursor_execute", cursor_execute) + event.remove(engine, "before_execute", connection_execute) + asserter._close() diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index 0f6f59401..444a79b70 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -204,7 +204,6 @@ def testing_engine(url=None, options=None): """Produce an engine configured by --options with optional overrides.""" from sqlalchemy import create_engine - from .assertsql import asserter if not options: use_reaper = True @@ -216,11 +215,12 @@ def testing_engine(url=None, options=None): options = config.db_opts engine = create_engine(url, **options) + engine._has_events = True # enable event blocks, helps with + # profiling + if isinstance(engine.pool, pool.QueuePool): engine.pool._timeout = 0 engine.pool._max_overflow = 0 - event.listen(engine, 'after_execute', asserter.execute) - event.listen(engine, 'after_cursor_execute', asserter.cursor_execute) if use_reaper: event.listen(engine.pool, 'connect', testing_reaper.connect) event.listen(engine.pool, 'checkout', testing_reaper.checkout) diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index f94724608..0aff43ae1 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -425,7 +425,7 @@ def skip(db, reason=None): def only_on(dbs, reason=None): return only_if( - OrPredicate([SpecPredicate(db) for db in util.to_list(dbs)]) + OrPredicate([Predicate.as_predicate(db) for db in util.to_list(dbs)]) ) diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index d86049da7..48d4d9c9b 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -192,9 +192,8 @@ class TablesTest(TestBase): def sql_count_(self, count, fn): self.assert_sql_count(self.bind, fn, count) - def sql_eq_(self, callable_, statements, with_sequences=None): - self.assert_sql(self.bind, - callable_, statements, with_sequences) + def sql_eq_(self, callable_, statements): + self.assert_sql(self.bind, callable_, statements) @classmethod def _load_fixtures(cls): diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index 614a12133..b0188aa5a 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -294,7 +294,7 @@ def _setup_requirements(argument): @post def _prep_testing_database(options, file_config): - from sqlalchemy.testing import config + from sqlalchemy.testing import config, util from sqlalchemy.testing.exclusions import against from sqlalchemy import schema, inspect @@ -325,19 +325,10 @@ def _prep_testing_database(options, file_config): schema="test_schema") )) - for tname in reversed(inspector.get_table_names( - order_by="foreign_key")): - e.execute(schema.DropTable( - schema.Table(tname, schema.MetaData()) - )) + util.drop_all_tables(e, inspector) if config.requirements.schemas.enabled_for_config(cfg): - for tname in reversed(inspector.get_table_names( - order_by="foreign_key", schema="test_schema")): - e.execute(schema.DropTable( - schema.Table(tname, schema.MetaData(), - schema="test_schema") - )) + util.drop_all_tables(e, inspector, schema=cfg.test_schema) if against(cfg, "postgresql"): from sqlalchemy.dialects import postgresql diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 4bbc8ed9a..fbab4966c 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -84,7 +84,8 @@ def pytest_collection_modifyitems(session, config, items): rebuilt_items = collections.defaultdict(list) items[:] = [ item for item in - items if isinstance(item.parent, pytest.Instance)] + items if isinstance(item.parent, pytest.Instance) + and not item.parent.parent.name.startswith("_")] test_classes = set(item.parent for item in items) for test_class in test_classes: for sub_cls in plugin_base.generate_sub_tests( diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py index 671bbe32d..57308925e 100644 --- a/lib/sqlalchemy/testing/profiling.py +++ b/lib/sqlalchemy/testing/profiling.py @@ -226,6 +226,7 @@ def count_functions(variance=0.05): callcount = stats.total_calls expected = _profile_stats.result(callcount) + if expected is None: expected_count = None else: @@ -249,10 +250,11 @@ def count_functions(variance=0.05): else: raise AssertionError( "Adjusted function call count %s not within %s%% " - "of expected %s. Rerun with --write-profiles to " + "of expected %s, platform %s. Rerun with " + "--write-profiles to " "regenerate this callcount." % ( callcount, (variance * 100), - expected_count)) + expected_count, _profile_stats.platform_key)) diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index da3e3128a..5744431cb 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -323,6 +323,11 @@ class SuiteRequirements(Requirements): return exclusions.closed() @property + def temporary_tables(self): + """target database supports temporary tables""" + return exclusions.open() + + @property def temporary_views(self): """target database supports temporary views""" return exclusions.closed() diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index 08b858b47..3edbdeb8c 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -128,6 +128,10 @@ class ComponentReflectionTest(fixtures.TablesTest): DDL("create temporary view user_tmp_v as " "select * from user_tmp") ) + event.listen( + user_tmp, "before_drop", + DDL("drop view user_tmp_v") + ) @classmethod def define_index(cls, metadata, users): @@ -511,6 +515,8 @@ class ComponentReflectionTest(fixtures.TablesTest): def test_get_temp_table_indexes(self): insp = inspect(self.metadata.bind) indexes = insp.get_indexes('user_tmp') + for ind in indexes: + ind.pop('dialect_options', None) eq_( # TODO: we need to add better filtering for indexes/uq constraints # that are doubled up diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index 7b3f721a6..8230f923a 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -147,6 +147,10 @@ def run_as_contextmanager(ctx, fn, *arg, **kw): simulating the behavior of 'with' to support older Python versions. + This is not necessary anymore as we have placed 2.6 + as minimum Python version, however some tests are still using + this structure. + """ obj = ctx.__enter__() @@ -194,6 +198,25 @@ def provide_metadata(fn, *args, **kw): self.metadata = prev_meta +def force_drop_names(*names): + """Force the given table names to be dropped after test complete, + isolating for foreign key cycles + + """ + from . import config + from sqlalchemy import inspect + + @decorator + def go(fn, *args, **kw): + + try: + return fn(*args, **kw) + finally: + drop_all_tables( + config.db, inspect(config.db), include_names=names) + return go + + class adict(dict): """Dict keys available as attributes. Shadows.""" @@ -207,3 +230,39 @@ class adict(dict): return tuple([self[key] for key in keys]) get_all = __call__ + + +def drop_all_tables(engine, inspector, schema=None, include_names=None): + from sqlalchemy import Column, Table, Integer, MetaData, \ + ForeignKeyConstraint + from sqlalchemy.schema import DropTable, DropConstraint + + if include_names is not None: + include_names = set(include_names) + + with engine.connect() as conn: + for tname, fkcs in reversed( + inspector.get_sorted_table_and_fkc_names(schema=schema)): + if tname: + if include_names is not None and tname not in include_names: + continue + conn.execute(DropTable( + Table(tname, MetaData()) + )) + elif fkcs: + if not engine.dialect.supports_alter: + continue + for tname, fkc in fkcs: + if include_names is not None and \ + tname not in include_names: + continue + tb = Table( + tname, MetaData(), + Column('x', Integer), + Column('y', Integer), + schema=schema + ) + conn.execute(DropConstraint( + ForeignKeyConstraint( + [tb.c.x], [tb.c.y], name=fkc) + )) diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index b49e389ac..1215bd790 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -51,6 +51,7 @@ from .sql.sqltypes import ( Integer, Interval, LargeBinary, + MatchType, NCHAR, NVARCHAR, NullType, diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index dfed5b90a..ceee18d86 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -36,7 +36,7 @@ from .langhelpers import iterate_attributes, class_hierarchy, \ generic_repr, counter, PluginLoader, hybridproperty, hybridmethod, \ safe_reraise,\ get_callable_argspec, only_once, attrsetter, ellipses_string, \ - warn_limited + warn_limited, map_bits, MemoizedSlots, EnsureKWArgType from .deprecations import warn_deprecated, warn_pending_deprecation, \ deprecated, pending_deprecation, inject_docstring_text diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index d36852698..a49848d08 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -165,8 +165,13 @@ class immutabledict(ImmutableContainer, dict): return immutabledict, (dict(self), ) def union(self, d): - if not self: - return immutabledict(d) + if not d: + return self + elif not self: + if isinstance(d, immutabledict): + return d + else: + return immutabledict(d) else: d2 = immutabledict(self) dict.update(d2, d) @@ -179,8 +184,10 @@ class immutabledict(ImmutableContainer, dict): class Properties(object): """Provide a __getattr__/__setattr__ interface over a dict.""" + __slots__ = '_data', + def __init__(self, data): - self.__dict__['_data'] = data + object.__setattr__(self, '_data', data) def __len__(self): return len(self._data) @@ -200,8 +207,8 @@ class Properties(object): def __delitem__(self, key): del self._data[key] - def __setattr__(self, key, object): - self._data[key] = object + def __setattr__(self, key, obj): + self._data[key] = obj def __getstate__(self): return {'_data': self.__dict__['_data']} @@ -252,6 +259,8 @@ class OrderedProperties(Properties): """Provide a __getattr__/__setattr__ interface with an OrderedDict as backing store.""" + __slots__ = () + def __init__(self): Properties.__init__(self, OrderedDict()) @@ -259,10 +268,17 @@ class OrderedProperties(Properties): class ImmutableProperties(ImmutableContainer, Properties): """Provide immutable dict/object attribute to an underlying dictionary.""" + __slots__ = () + class OrderedDict(dict): """A dict that returns keys/values/items in the order they were added.""" + __slots__ = '_list', + + def __reduce__(self): + return OrderedDict, (self.items(),) + def __init__(self, ____sequence=None, **kwargs): self._list = [] if ____sequence is None: diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 5c17bea88..5a938501a 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -92,6 +92,15 @@ def _unique_symbols(used, *bases): raise NameError("exhausted namespace for symbol base %s" % base) +def map_bits(fn, n): + """Call the given function given each nonzero bit from n.""" + + while n: + b = n & (~n + 1) + yield fn(b) + n ^= b + + def decorator(target): """A signature-matching decorator factory.""" @@ -513,6 +522,15 @@ class portable_instancemethod(object): """ + __slots__ = 'target', 'name', '__weakref__' + + def __getstate__(self): + return {'target': self.target, 'name': self.name} + + def __setstate__(self, state): + self.target = state['target'] + self.name = state['name'] + def __init__(self, meth): self.target = meth.__self__ self.name = meth.__name__ @@ -791,6 +809,40 @@ class group_expirable_memoized_property(object): return memoized_instancemethod(fn) +class MemoizedSlots(object): + """Apply memoized items to an object using a __getattr__ scheme. + + This allows the functionality of memoized_property and + memoized_instancemethod to be available to a class using __slots__. + + """ + + def _fallback_getattr(self, key): + raise AttributeError(key) + + def __getattr__(self, key): + if key.startswith('_memoized'): + raise AttributeError(key) + elif hasattr(self, '_memoized_attr_%s' % key): + value = getattr(self, '_memoized_attr_%s' % key)() + setattr(self, key, value) + return value + elif hasattr(self, '_memoized_method_%s' % key): + fn = getattr(self, '_memoized_method_%s' % key) + + def oneshot(*args, **kw): + result = fn(*args, **kw) + memo = lambda *a, **kw: result + memo.__name__ = fn.__name__ + memo.__doc__ = fn.__doc__ + setattr(self, key, memo) + return result + oneshot.__doc__ = fn.__doc__ + return oneshot + else: + return self._fallback_getattr(key) + + def dependency_for(modulename): def decorate(obj): # TODO: would be nice to improve on this import silliness, @@ -936,7 +988,7 @@ def asbool(obj): def bool_or_str(*text): - """Return a callable that will evaulate a string as + """Return a callable that will evaluate a string as boolean, or one of a set of "alternate" string values. """ @@ -969,7 +1021,7 @@ def coerce_kw_type(kw, key, type_, flexi_bool=True): kw[key] = type_(kw[key]) -def constructor_copy(obj, cls, **kw): +def constructor_copy(obj, cls, *args, **kw): """Instantiate cls using the __dict__ of obj as constructor arguments. Uses inspect to match the named arguments of ``cls``. @@ -978,7 +1030,7 @@ def constructor_copy(obj, cls, **kw): names = get_cls_kwargs(cls) kw.update((k, obj.__dict__[k]) for k in names if k in obj.__dict__) - return cls(**kw) + return cls(*args, **kw) def counter(): @@ -1296,6 +1348,7 @@ def chop_traceback(tb, exclude_prefix=_UNITTEST_RE, exclude_suffix=_SQLA_RE): NoneType = type(None) + def attrsetter(attrname): code = \ "def set(obj, value):"\ @@ -1303,3 +1356,29 @@ def attrsetter(attrname): env = locals().copy() exec(code, env) return env['set'] + + +class EnsureKWArgType(type): + """Apply translation of functions to accept **kw arguments if they + don't already. + + """ + def __init__(cls, clsname, bases, clsdict): + fn_reg = cls.ensure_kwarg + if fn_reg: + for key in clsdict: + m = re.match(fn_reg, key) + if m: + fn = clsdict[key] + spec = inspect.getargspec(fn) + if not spec.keywords: + clsdict[key] = wrapped = cls._wrap_w_kw(fn) + setattr(cls, key, wrapped) + super(EnsureKWArgType, cls).__init__(clsname, bases, clsdict) + + def _wrap_w_kw(self, fn): + + def wrap(*arg, **kw): + return fn(*arg) + return update_wrapper(wrap, fn) + |