diff options
| author | Gaëtan de Menten <gdementen@gmail.com> | 2010-04-11 20:58:37 +0200 |
|---|---|---|
| committer | Gaëtan de Menten <gdementen@gmail.com> | 2010-04-11 20:58:37 +0200 |
| commit | d3bc6c08d4484538171456cd4253bb3d47fa99e0 (patch) | |
| tree | 6139becf2c84144d00eb9a359ee305c6b796e62c | |
| parent | aa9506a0da35b78b2c4b1c62b1700fe33f5bea2e (diff) | |
| parent | 4f2b05ff2cb28d6741d66c7302c3349f064a9857 (diff) | |
| download | sqlalchemy-d3bc6c08d4484538171456cd4253bb3d47fa99e0.tar.gz | |
merge branch
38 files changed, 1019 insertions, 503 deletions
@@ -3,3 +3,4 @@ syntax:regexp ^doc/build/output .pyc$ .orig$ +.egg-info
\ No newline at end of file @@ -34,6 +34,31 @@ CHANGES - id(obj) is no longer used internally within topological.py, as the sorting functions now require hashable objects only. [ticket:1756] + + - The ORM will set the docstring of all generated descriptors + to None by default. This can be overridden using 'doc' + (or if using Sphinx, attribute docstrings work too). + + - Added kw argument 'doc' to all mapper property callables + as well as Column(). Will assemble the string 'doc' as + the '__doc__' attribute on the descriptor. + + - Usage of version_id_col on a backend that supports + cursor.rowcount for execute() but not executemany() now works + when a delete is issued (already worked for saves, since those + don't use executemany()). For a backend that doesn't support + cursor.rowcount at all, a warning is emitted the same + as with saves. [ticket:1761] + + - The ORM now short-term caches the "compiled" form of + insert() and update() constructs when flushing lists of + objects of all the same class, thereby avoiding redundant + compilation per individual INSERT/UPDATE within an + individual flush() call. + + - internal getattr(), setattr(), getcommitted() methods + on ColumnProperty, CompositeProperty, RelationshipProperty + have been underscored, signature has changed. - engines - The C extension now also works with DBAPIs which use custom @@ -46,6 +71,31 @@ CHANGES errors if column._label is used as a bind name during an UPDATE. Test coverage which wasn't present in 0.5 has been added. [ticket:1755] + + - somejoin.select(fold_equivalents=True) is no longer + deprecated, and will eventually be rolled into a more + comprehensive version of the feature for [ticket:1729]. + + - the Numeric type raises an *enormous* warning when expected + to convert floats to Decimal from a DBAPI that returns floats. + This includes SQLite, Oracle, Sybase, MS-SQL. + [ticket:1759] + + - Fixed an error in expression typing which caused an endless + loop for expressions with two NULL types. + + - Fixed bug in execution_options() feature whereby the existing + Transaction and other state information from the parent + connection would not be propagated to the sub-connection. + + - Added new 'compiled_cache' execution option. A dictionary + where Compiled objects will be cached when the Connection + compiles a clause expression into a dialect- and parameter- + specific Compiled object. It is the user's responsibility to + manage the size of this dictionary, which will have keys + corresponding to the dialect, clause element, the column + names within the VALUES or SET clause of an INSERT or UPDATE, + as well as the "batch" mode for an INSERT or UPDATE statement. - ext - the compiler extension now allows @compiles decorators @@ -57,6 +107,34 @@ CHANGES if a non-mapped class attribute is referenced in the string-based relationship() arguments. + - Further reworked the "mixin" logic in declarative to + additionally allow __mapper_args__ as a @classproperty + on a mixin, such as to dynamically assign polymorphic_identity. + +- postgresql + - Postgresql now reflects sequence names associated with + SERIAL columns correctly, after the name of of the sequence + has been changed. Thanks to Kumar McMillan for the patch. + [ticket:1071] + + - Repaired missing import in psycopg2._PGNumeric type when + unknown numeric is received. + + - psycopg2/pg8000 dialects now aware of REAL[], FLOAT[], + DOUBLE_PRECISION[], NUMERIC[] return types without + raising an exception. + +- oracle + - Now using cx_oracle output converters so that the + DBAPI returns natively the kinds of values we prefer: + - NUMBER values with positive precision + scale convert + to cx_oracle.STRING and then to Decimal. This + allows perfect precision for the Numeric type when + using cx_oracle. [ticket:1759] + - STRING/FIXED_CHAR now convert to unicode natively. + SQLAlchemy's String types then don't need to + apply any kind of conversions. + - examples - Updated attribute_shard.py example to use a more robust method of searching a Query for binary expressions which diff --git a/README.unittests b/README.unittests index dee34a106..dd2f6ab1b 100644 --- a/README.unittests +++ b/README.unittests @@ -13,6 +13,13 @@ http://somethingaboutorange.com/mrl/projects/nose/0.11.1/index.html SQLAlchemy implements a nose plugin that must be present when tests are run. This plugin is available when SQLAlchemy is installed via setuptools. +NB: You will need to manually install nose, it is unlikely to be pulled + down as a dependency of installing SQLAlchemy. + + Nose can be installed with: + + $ easy_install nose + INSTANT TEST RUNNER ------------------- diff --git a/doc/build/mappers.rst b/doc/build/mappers.rst index 7e320c26a..81d67f217 100644 --- a/doc/build/mappers.rst +++ b/doc/build/mappers.rst @@ -477,7 +477,9 @@ It also accepts a second argument ``selectable`` which replaces the automatic jo # custom selectable query.with_polymorphic([Engineer, Manager], employees.outerjoin(managers).outerjoin(engineers)) -:func:`~sqlalchemy.orm.query.Query.with_polymorphic` is also needed when you wish to add filter criterion that is specific to one or more subclasses, so that those columns are available to the WHERE clause: +:func:`~sqlalchemy.orm.query.Query.with_polymorphic` is also needed +when you wish to add filter criteria that are specific to one or more +subclasses; It makes the subclasses' columns available to the WHERE clause: .. sourcecode:: python+sql diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index 91af6620b..7502ed1d5 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -6,6 +6,9 @@ Driver The Oracle dialect uses the cx_oracle driver, available at http://cx-oracle.sourceforge.net/ . The dialect has several behaviors which are specifically tailored towards compatibility with this module. +Version 5.0 or greater is **strongly** recommended, as SQLAlchemy makes +extensive use of the cx_oracle output converters for numeric and +string conversions. Connecting ---------- @@ -38,33 +41,21 @@ URL, or as keyword arguments to :func:`~sqlalchemy.create_engine()` are: Unicode ------- -As of cx_oracle 5, Python unicode objects can be bound directly to statements, -and it appears that cx_oracle can handle these even without NLS_LANG being set. -SQLAlchemy tests for version 5 and will pass unicode objects straight to cx_oracle -if this is the case. For older versions of cx_oracle, SQLAlchemy will encode bind -parameters normally using dialect.encoding as the encoding. +cx_oracle 5 fully supports Python unicode objects. SQLAlchemy will pass +all unicode strings directly to cx_oracle, and additionally uses an output +handler so that all string based result values are returned as unicode as well. LOB Objects ----------- -cx_oracle presents some challenges when fetching LOB objects. A LOB object in a result set -is presented by cx_oracle as a cx_oracle.LOB object which has a read() method. By default, -SQLAlchemy converts these LOB objects into Python strings. This is for two reasons. First, -the LOB object requires an active cursor association, meaning if you were to fetch many rows -at once such that cx_oracle had to go back to the database and fetch a new batch of rows, -the LOB objects in the already-fetched rows are now unreadable and will raise an error. -SQLA "pre-reads" all LOBs so that their data is fetched before further rows are read. -The size of a "batch of rows" is controlled by the cursor.arraysize value, which SQLAlchemy -defaults to 50 (cx_oracle normally defaults this to one). - -Secondly, the LOB object is not a standard DBAPI return value so SQLAlchemy seeks to -"normalize" the results to look more like that of other DBAPIs. - -The conversion of LOB objects by this dialect is unique in SQLAlchemy in that it takes place -for all statement executions, even plain string-based statements for which SQLA has no awareness -of result typing. This is so that calls like fetchmany() and fetchall() can work in all cases -without raising cursor errors. The conversion of LOB in all cases, as well as the "prefetch" -of LOB objects, can be disabled using auto_convert_lobs=False. +cx_oracle returns oracle LOBs using the cx_oracle.LOB object. SQLAlchemy converts +these to strings so that the interface of the Binary type is consistent with that of +other backends, and so that the linkage to a live cursor is not needed in scenarios +like result.fetchmany() and result.fetchall(). This means that by default, LOB +objects are fully fetched unconditionally by SQLAlchemy, and the linkage to a live +cursor is broken. + +To disable this processing, pass ``auto_convert_lobs=False`` to :func:`create_engine()`. Two Phase Transaction Support ----------------------------- @@ -78,16 +69,33 @@ from sqlalchemy.dialects.oracle.base import OracleCompiler, OracleDialect, \ RESERVED_WORDS, OracleExecutionContext from sqlalchemy.dialects.oracle import base as oracle from sqlalchemy.engine import base -from sqlalchemy import types as sqltypes, util, exc +from sqlalchemy import types as sqltypes, util, exc, processors from datetime import datetime import random +from decimal import Decimal class _OracleNumeric(sqltypes.Numeric): - # cx_oracle accepts Decimal objects, but returns - # floats def bind_processor(self, dialect): + # cx_oracle accepts Decimal objects and floats return None - + + def result_processor(self, dialect, coltype): + # we apply a connection output handler that + # returns Decimal for positive precision + scale NUMBER + # types + if dialect.supports_native_decimal: + if self.asdecimal and self.scale is None: + processors.to_decimal_processor_factory(Decimal) + elif not self.asdecimal and self.scale > 0: + return processors.to_float + else: + return None + else: + # cx_oracle 4 behavior, will assume + # floats + return super(_OracleNumeric, self).\ + result_processor(dialect, coltype) + class _OracleDate(sqltypes.Date): def bind_processor(self, dialect): return None @@ -127,17 +135,9 @@ class _NativeUnicodeMixin(object): return super(_NativeUnicodeMixin, self).bind_processor(dialect) # end Py2K - def result_processor(self, dialect, coltype): - # if we know cx_Oracle will return unicode, - # don't process results - if dialect._cx_oracle_with_unicode: - return None - elif self.convert_unicode != 'force' and \ - dialect._cx_oracle_native_nvarchar and \ - coltype in dialect._cx_oracle_unicode_types: - return None - else: - return super(_NativeUnicodeMixin, self).result_processor(dialect, coltype) + # we apply a connection output handler that returns + # unicode in all cases, so the "native_unicode" flag + # will be set for the default String.result_processor. class _OracleChar(_NativeUnicodeMixin, sqltypes.CHAR): def get_dbapi_type(self, dbapi): @@ -163,7 +163,7 @@ class _OracleUnicodeText(_LOBMixin, _NativeUnicodeMixin, sqltypes.UnicodeText): if lob_processor is None: return None - string_processor = _NativeUnicodeMixin.result_processor(self, dialect, coltype) + string_processor = sqltypes.UnicodeText.result_processor(self, dialect, coltype) if string_processor is None: return lob_processor @@ -253,6 +253,7 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): c = self._connection.connection.cursor() if self.dialect.arraysize: c.arraysize = self.dialect.arraysize + return c def get_result_proxy(self): @@ -345,6 +346,7 @@ class ReturningResultProxy(base.FullyBufferedResultProxy): class OracleDialect_cx_oracle(OracleDialect): execution_ctx_cls = OracleExecutionContext_cx_oracle statement_compiler = OracleCompiler_cx_oracle + driver = "cx_oracle" colspecs = colspecs = { @@ -361,7 +363,6 @@ class OracleDialect_cx_oracle(OracleDialect): sqltypes.CHAR : _OracleChar, sqltypes.Integer : _OracleInteger, # this is only needed for OUT parameters. # it would be nice if we could not use it otherwise. - oracle.NUMBER : oracle.NUMBER, # don't let this get converted oracle.RAW: _OracleRaw, sqltypes.Unicode: _OracleNVarChar, sqltypes.NVARCHAR : _OracleNVarChar, @@ -388,7 +389,7 @@ class OracleDialect_cx_oracle(OracleDialect): cx_oracle_ver = tuple([int(x) for x in self.dbapi.version.split('.')]) else: cx_oracle_ver = (0, 0, 0) - + def types(*names): return set([ getattr(self.dbapi, name, None) for name in names @@ -398,6 +399,7 @@ class OracleDialect_cx_oracle(OracleDialect): self._cx_oracle_unicode_types = types("UNICODE", "NCLOB") self._cx_oracle_binary_types = types("BFILE", "CLOB", "NCLOB", "BLOB") self.supports_unicode_binds = cx_oracle_ver >= (5, 0) + self.supports_native_decimal = cx_oracle_ver >= (5, 0) self._cx_oracle_native_nvarchar = cx_oracle_ver >= (5, 0) if cx_oracle_ver is None: @@ -446,6 +448,26 @@ class OracleDialect_cx_oracle(OracleDialect): import cx_Oracle return cx_Oracle + def on_connect(self): + cx_Oracle = self.dbapi + def output_type_handler(cursor, name, defaultType, size, precision, scale): + # convert all NUMBER with precision + positive scale to Decimal. + # this effectively allows "native decimal" mode. + if defaultType == cx_Oracle.NUMBER and precision and scale > 0: + return cursor.var( + cx_Oracle.STRING, + 255, + outconverter=Decimal, + arraysize=cursor.arraysize) + # allow all strings to come back natively as Unicode + elif defaultType in (cx_Oracle.STRING, cx_Oracle.FIXED_CHAR): + return cursor.var(unicode, size, cursor.arraysize) + + def on_connect(conn): + conn.outputtypehandler = output_type_handler + + return on_connect + def create_connect_args(self, url): dialect_opts = dict(url.query) for opt in ('use_ansi', 'auto_setinputsizes', 'auto_convert_lobs', diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index bef2f1c61..312ae9aa8 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -80,6 +80,10 @@ from sqlalchemy.types import INTEGER, BIGINT, SMALLINT, VARCHAR, \ CHAR, TEXT, FLOAT, NUMERIC, \ DATE, BOOLEAN +_DECIMAL_TYPES = (1700, 1231) +_FLOAT_TYPES = (700, 701, 1021, 1022) + + class REAL(sqltypes.Float): __visit_name__ = "REAL" @@ -814,7 +818,8 @@ class PGDialect(default.DefaultDialect): result = connection.execute( sql.text(u"SELECT relname FROM pg_class c " "WHERE relkind = 'r' " - "AND '%s' = (select nspname from pg_namespace n where n.oid = c.relnamespace) " % + "AND '%s' = (select nspname from pg_namespace n " + "where n.oid = c.relnamespace) " % current_schema, typemap = {'relname':sqltypes.Unicode} ) @@ -832,7 +837,8 @@ class PGDialect(default.DefaultDialect): SELECT relname FROM pg_class c WHERE relkind = 'v' - AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace) + AND '%(schema)s' = (select nspname from pg_namespace n + where n.oid = c.relnamespace) """ % dict(schema=current_schema) # Py3K #view_names = [row[0] for row in connection.execute(s)] @@ -870,7 +876,8 @@ class PGDialect(default.DefaultDialect): SQL_COLS = """ SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), - (SELECT substring(d.adsrc for 128) FROM pg_catalog.pg_attrdef d + (SELECT substring(pg_catalog.pg_get_expr(d.adbin, d.adrelid) for 128) + FROM pg_catalog.pg_attrdef d WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef) AS DEFAULT, a.attnotnull, a.attnum, a.attrelid as table_oid diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index a620daac6..862c915aa 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -26,23 +26,24 @@ from sqlalchemy import util, exc from sqlalchemy import processors from sqlalchemy import types as sqltypes from sqlalchemy.dialects.postgresql.base import PGDialect, \ - PGCompiler, PGIdentifierPreparer, PGExecutionContext + PGCompiler, PGIdentifierPreparer, PGExecutionContext,\ + _DECIMAL_TYPES, _FLOAT_TYPES class _PGNumeric(sqltypes.Numeric): def result_processor(self, dialect, coltype): if self.asdecimal: - if coltype in (700, 701): + if coltype in _FLOAT_TYPES: return processors.to_decimal_processor_factory(decimal.Decimal) - elif coltype == 1700: + elif coltype in _DECIMAL_TYPES: # pg8000 returns Decimal natively for 1700 return None else: raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype) else: - if coltype in (700, 701): + if coltype in _FLOAT_TYPES: # pg8000 returns float natively for 701 return None - elif coltype == 1700: + elif coltype in _DECIMAL_TYPES: return processors.to_float else: raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype) diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index f21c9a558..2a51a7239 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -60,7 +60,7 @@ import re import decimal import logging -from sqlalchemy import util +from sqlalchemy import util, exc from sqlalchemy import processors from sqlalchemy.engine import base, default from sqlalchemy.sql import expression @@ -68,7 +68,7 @@ from sqlalchemy.sql import operators as sql_operators from sqlalchemy import types as sqltypes from sqlalchemy.dialects.postgresql.base import PGDialect, PGCompiler, \ PGIdentifierPreparer, PGExecutionContext, \ - ENUM, ARRAY + ENUM, ARRAY, _DECIMAL_TYPES, _FLOAT_TYPES logger = logging.getLogger('sqlalchemy.dialects.postgresql') @@ -80,18 +80,18 @@ class _PGNumeric(sqltypes.Numeric): def result_processor(self, dialect, coltype): if self.asdecimal: - if coltype in (700, 701): + if coltype in _FLOAT_TYPES: return processors.to_decimal_processor_factory(decimal.Decimal) - elif coltype == 1700: + elif coltype in _DECIMAL_TYPES: # pg8000 returns Decimal natively for 1700 return None else: raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype) else: - if coltype in (700, 701): + if coltype in _FLOAT_TYPES: # pg8000 returns float natively for 701 return None - elif coltype == 1700: + elif coltype in _DECIMAL_TYPES: return processors.to_float else: raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype) diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index dc42ed957..4c5a6a82b 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -794,6 +794,14 @@ class Connection(Connectable): """ return self.engine.Connection(self.engine, self.__connection, _branch=True) + + def _clone(self): + """Create a shallow copy of this Connection. + + """ + c = self.__class__.__new__(self.__class__) + c.__dict__ = self.__dict__.copy() + return c def execution_options(self, **opt): """ Set non-SQL options for the connection which take effect during execution. @@ -811,9 +819,9 @@ class Connection(Connectable): :meth:`sqlalchemy.sql.expression.Executable.execution_options`. """ - return self.engine.Connection( - self.engine, self.__connection, - _branch=self.__branch, _execution_options=opt) + c = self._clone() + c._execution_options = c._execution_options.union(opt) + return c @property def dialect(self): @@ -1142,10 +1150,22 @@ class Connection(Connectable): else: keys = [] + if 'compiled_cache' in self._execution_options: + key = self.dialect, elem, tuple(keys), len(params) > 1 + if key in self._execution_options['compiled_cache']: + compiled_sql = self._execution_options['compiled_cache'][key] + else: + compiled_sql = elem.compile( + dialect=self.dialect, column_keys=keys, + inline=len(params) > 1) + self._execution_options['compiled_cache'][key] = compiled_sql + else: + compiled_sql = elem.compile( + dialect=self.dialect, column_keys=keys, + inline=len(params) > 1) + context = self.__create_execution_context( - compiled_sql=elem.compile( - dialect=self.dialect, column_keys=keys, - inline=len(params) > 1), + compiled_sql=compiled_sql, parameters=params ) return self.__execute_context(context) diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index dde49e232..68c434fd9 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -119,6 +119,23 @@ overriding routine and cause an endless loop. Such as, to add "prefix" to all The above compiler will prefix all INSERT statements with "some prefix" when compiled. +Changing Compilation of Types +============================= + +``compiler`` works for types, too, such as below where we implement the MS-SQL specific 'max' keyword for ``String``/``VARCHAR``:: + + @compiles(String, 'mssql') + @compiles(VARCHAR, 'mssql') + def compile_varchar(element, compiler, **kw): + if element.length == 'max': + return "VARCHAR('max')" + else: + return compiler.visit_VARCHAR(element, **kw) + + foo = Table('foo', metadata, + Column('data', VARCHAR('max')) + ) + Subclassing Guidelines ====================== @@ -147,8 +164,23 @@ A big part of using the compiler extension is subclassing SQLAlchemy expression function or stored procedure type of call. Since most databases support statements along the line of "SELECT FROM <some function>" ``FunctionElement`` adds in the ability to be used in the FROM clause of a - ``select()`` construct. - + ``select()`` construct:: + + from sqlalchemy.sql.expression import FunctionElement + + class coalesce(FunctionElement): + name = 'coalesce' + + @compiles(coalesce) + def compile(element, compiler, **kw): + return "coalesce(%s)" % compiler.process(element.clauses) + + @compiles(coalesce, 'oracle') + def compile(element, compiler, **kw): + if len(element.clauses) > 2: + raise TypeError("coalesce only supports two arguments on Oracle") + return "nvl(%s)" % compiler.process(element.clauses) + * :class:`~sqlalchemy.schema.DDLElement` - The root of all DDL expressions, like CREATE TABLE, ALTER TABLE, etc. Compilation of ``DDLElement`` subclasses is issued by a ``DDLCompiler`` instead of a ``SQLCompiler``. @@ -160,7 +192,7 @@ A big part of using the compiler extension is subclassing SQLAlchemy expression used with any expression class that represents a "standalone" SQL statement that can be passed directly to an ``execute()`` method. It is already implicit within ``DDLElement`` and ``FunctionElement``. - + """ def compiles(class_, *specs): diff --git a/lib/sqlalchemy/ext/declarative.py b/lib/sqlalchemy/ext/declarative.py index 407de1004..94dba77fa 100644 --- a/lib/sqlalchemy/ext/declarative.py +++ b/lib/sqlalchemy/ext/declarative.py @@ -535,43 +535,46 @@ def _as_declarative(cls, classname, dict_): dict_ = dict(dict_) column_copies = dict() - unmapped_mixins = False - for base in cls.__bases__: - names = dir(base) - if not _is_mapped_class(base): - unmapped_mixins = True - for name in names: - obj = getattr(base,name, None) - if isinstance(obj, Column): - if obj.foreign_keys: - raise exceptions.InvalidRequestError( - "Columns with foreign keys to other columns " - "are not allowed on declarative mixins at this time." - ) - dict_[name]=column_copies[obj]=obj.copy() - elif isinstance(obj, RelationshipProperty): - raise exceptions.InvalidRequestError( - "relationships are not allowed on " - "declarative mixins at this time.") - - # doing it this way enables these attributes to be descriptors - get_mapper_args = '__mapper_args__' in dict_ - get_table_args = '__table_args__' in dict_ - if unmapped_mixins: - get_mapper_args = get_mapper_args or getattr(cls,'__mapper_args__',None) - get_table_args = get_table_args or getattr(cls,'__table_args__',None) - tablename = getattr(cls,'__tablename__',None) - if tablename: - # subtle: if tablename is a descriptor here, we actually - # put the wrong value in, but it serves as a marker to get - # the right value value... - dict_['__tablename__']=tablename - - # now that we know whether or not to get these, get them from the class - # if we should, enabling them to be decorators - mapper_args = get_mapper_args and cls.__mapper_args__ or {} - table_args = get_table_args and cls.__table_args__ or None + mixin_table_args = None + mapper_args = {} + table_args = None + def _is_mixin(klass): + return not _is_mapped_class(klass) and klass is not cls + + for base in cls.__mro__: + if _is_mixin(base): + for name in dir(base): + if name == '__mapper_args__': + if not mapper_args: + mapper_args = cls.__mapper_args__ + elif name == '__table_args__': + if not table_args: + table_args = mixin_table_args = cls.__table_args__ + elif name == '__tablename__': + if '__tablename__' not in dict_: + dict_['__tablename__'] = cls.__tablename__ + else: + obj = getattr(base,name, None) + if isinstance(obj, Column): + if obj.foreign_keys: + raise exceptions.InvalidRequestError( + "Columns with foreign keys to other columns " + "are not allowed on declarative mixins at this time." + ) + dict_[name]=column_copies[obj]=obj.copy() + elif isinstance(obj, RelationshipProperty): + raise exceptions.InvalidRequestError( + "relationships are not allowed on " + "declarative mixins at this time.") + elif base is cls: + if '__mapper_args__' in dict_: + mapper_args = cls.__mapper_args__ + if '__table_args__' in dict_: + table_args = cls.__table_args__ + if '__tablename__' in dict_: + dict_['__tablename__'] = cls.__tablename__ + # make sure that column copies are used rather than the original columns # from any mixins for k, v in mapper_args.iteritems(): @@ -681,7 +684,7 @@ def _as_declarative(cls, classname, dict_): if table is None: # single table inheritance. # ensure no table args - if table_args is not None: + if table_args is not None and table_args is not mixin_table_args: raise exceptions.ArgumentError( "Can't place __table_args__ on an inherited class with no table." ) diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 206c8d0c2..c2f6337bc 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -266,6 +266,9 @@ def relationship(argument, secondary=None, **kwargs): a class which extends :class:`RelationshipProperty.Comparator` which provides custom SQL clause generation for comparison operations. + :param doc: + docstring which will be applied to the resulting descriptor. + :param extension: an :class:`AttributeExtension` instance, or list of extensions, which will be prepended to the list of attribute listeners for @@ -469,7 +472,7 @@ def relation(*arg, **kw): def dynamic_loader(argument, secondary=None, primaryjoin=None, secondaryjoin=None, foreign_keys=None, backref=None, post_update=False, cascade=False, remote_side=None, - enable_typechecks=True, passive_deletes=False, + enable_typechecks=True, passive_deletes=False, doc=None, order_by=None, comparator_factory=None, query_class=None): """Construct a dynamically-loading mapper property. @@ -508,7 +511,7 @@ def dynamic_loader(argument, secondary=None, primaryjoin=None, secondaryjoin=secondaryjoin, foreign_keys=foreign_keys, backref=backref, post_update=post_update, cascade=cascade, remote_side=remote_side, enable_typechecks=enable_typechecks, passive_deletes=passive_deletes, - order_by=order_by, comparator_factory=comparator_factory, + order_by=order_by, comparator_factory=comparator_factory,doc=doc, strategy_class=DynaLoader, query_class=query_class) def column_property(*args, **kwargs): @@ -538,7 +541,11 @@ def column_property(*args, **kwargs): it does not load immediately, and is instead loaded when the attribute is first accessed on an instance. See also :func:`~sqlalchemy.orm.deferred`. - + + doc + optional string that will be applied as the doc on the + class-bound descriptor. + extension an :class:`~sqlalchemy.orm.interfaces.AttributeExtension` instance, or list of extensions, which will be prepended to the list of @@ -612,6 +619,10 @@ def composite(class_, *cols, **kwargs): a class which extends ``sqlalchemy.orm.properties.CompositeProperty.Comparator`` which provides custom SQL clause generation for comparison operations. + doc + optional string that will be applied as the doc on the + class-bound descriptor. + extension an :class:`~sqlalchemy.orm.interfaces.AttributeExtension` instance, or list of extensions, which will be prepended to the list of @@ -813,7 +824,7 @@ def mapper(class_, local_table=None, *args, **params): """ return Mapper(class_, local_table, *args, **params) -def synonym(name, map_column=False, descriptor=None, comparator_factory=None): +def synonym(name, map_column=False, descriptor=None, comparator_factory=None, doc=None): """Set up `name` as a synonym to another mapped property. Used with the ``properties`` dictionary sent to :func:`~sqlalchemy.orm.mapper`. @@ -851,7 +862,10 @@ def synonym(name, map_column=False, descriptor=None, comparator_factory=None): proxy access to the column-based attribute. """ - return SynonymProperty(name, map_column=map_column, descriptor=descriptor, comparator_factory=comparator_factory) + return SynonymProperty(name, map_column=map_column, + descriptor=descriptor, + comparator_factory=comparator_factory, + doc=doc) def comparable_property(comparator_factory, descriptor=None): """Provide query semantics for an unmanaged attribute. diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 887d9a9c1..b631ea2c9 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -1366,12 +1366,12 @@ def unregister_class(class_): instrumentation_registry.unregister(class_) def register_attribute(class_, key, **kw): - proxy_property = kw.pop('proxy_property', None) comparator = kw.pop('comparator', None) parententity = kw.pop('parententity', None) - register_descriptor(class_, key, proxy_property, comparator, parententity) + doc = kw.pop('doc', None) + register_descriptor(class_, key, proxy_property, comparator, parententity, doc=doc) if not proxy_property: register_attribute_impl(class_, key, **kw) @@ -1405,7 +1405,8 @@ def register_attribute_impl(class_, key, manager.post_configure_attribute(key) -def register_descriptor(class_, key, proxy_property=None, comparator=None, parententity=None, property_=None): +def register_descriptor(class_, key, proxy_property=None, comparator=None, + parententity=None, property_=None, doc=None): manager = manager_of_class(class_) if proxy_property: @@ -1413,7 +1414,9 @@ def register_descriptor(class_, key, proxy_property=None, comparator=None, paren descriptor = proxy_type(key, proxy_property, comparator, parententity) else: descriptor = InstrumentedAttribute(key, comparator=comparator, parententity=parententity) - + + descriptor.__doc__ = doc + manager.instrument_attribute(key, descriptor) def unregister_attribute(class_, key): diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 616f2510a..65a24843d 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -134,13 +134,13 @@ def column_mapped_collection(mapping_spec): def keyfunc(value): state = instance_state(value) m = _state_mapper(state) - return m._get_state_attr_by_column(state, cols[0]) + return m._get_state_attr_by_column(state, state.dict, cols[0]) else: mapping_spec = tuple(cols) def keyfunc(value): state = instance_state(value) m = _state_mapper(state) - return tuple(m._get_state_attr_by_column(state, c) + return tuple(m._get_state_attr_by_column(state, state.dict, c) for c in mapping_spec) return lambda: MappedCollection(keyfunc) diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index cbbfb0883..1ec22127c 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -6,10 +6,13 @@ """Relationship dependencies. -Bridges the ``PropertyLoader`` (i.e. a ``relationship()``) and the +Bridges the ``RelationshipLoader`` (i.e. a ``relationship()``) and the ``UOWTransaction`` together to allow processing of relationship()-based dependencies at flush time. +A large portion of this module will be reworked in an +upcoming release. See [ticket:1742] for details. + """ from sqlalchemy import sql, util @@ -326,22 +329,25 @@ class DetectKeySwitch(DependencyProcessor): def _process_key_switches(self, deplist, uowcommit): switchers = set(s for s in deplist if self._pks_changed(uowcommit, s)) if switchers: - # yes, we're doing a linear search right now through the UOW. only - # takes effect when primary key values have actually changed. - # a possible optimization might be to enhance the "hasparents" capability of - # attributes to actually store all parent references, but this introduces - # more complicated attribute accounting. - for s in [elem for elem in uowcommit.session.identity_map.all_states() - if issubclass(elem.class_, self.parent.class_) and - self.key in elem.dict and - elem.dict[self.key] is not None and - attributes.instance_state(elem.dict[self.key]) in switchers - ]: - uowcommit.register_object(s) - sync.populate( - attributes.instance_state(s.dict[self.key]), - self.mapper, s, self.parent, self.prop.synchronize_pairs, - uowcommit, self.passive_updates) + # if primary key values have actually changed somewhere, perform + # a linear search through the UOW in search of a parent. + # possible optimizations here include additional accounting within + # the attribute system, or allowing a one-to-many attr to circumvent + # the need for the search in this direction. + for state in uowcommit.session.identity_map.all_states(): + if not issubclass(state.class_, self.parent.class_): + continue + dict_ = state.dict + related = dict_.get(self.key) + if related is not None: + related_state = attributes.instance_state(dict_[self.key]) + if related_state in switchers: + uowcommit.register_object(state) + sync.populate( + related_state, + self.mapper, state, + self.parent, self.prop.synchronize_pairs, + uowcommit, self.passive_updates) def _pks_changed(self, uowcommit, state): return sync.source_modified(uowcommit, state, self.mapper, self.prop.synchronize_pairs) @@ -543,11 +549,8 @@ class ManyToManyDP(DependencyProcessor): class MapperStub(object): """Represent a many-to-many dependency within a flush context. - - The UOWTransaction corresponds dependencies to mappers. - MapperStub takes the place of the "association table" - so that a depedendency can be corresponded to it. - + + This object is deprecated. """ def __init__(self, parent, mapper, key): diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 7fbb0862d..1a82b96b1 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -510,6 +510,8 @@ class MapperProperty(object): Establishes a topological dependency between two mappers which will affect the order in which mappers persist data. + This method is deprecated. + """ pass @@ -518,10 +520,11 @@ class MapperProperty(object): """Called by the ``Mapper`` in response to the UnitOfWork calling the ``Mapper``'s register_processors operation. Establishes a processor object between two mappers which - will link data and state between parent/child objects. + will synchronize state between parent/child objects. + + This method is deprecated. """ - pass def is_primary(self): diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 8f0f2128b..6d0a51a52 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1116,7 +1116,8 @@ class Mapper(object): return self._primary_key_from_state(state) def _primary_key_from_state(self, state): - return [self._get_state_attr_by_column(state, column) for column in self.primary_key] + dict_ = state.dict + return [self._get_state_attr_by_column(state, dict_, column) for column in self.primary_key] def _get_col_to_prop(self, column): try: @@ -1129,18 +1130,19 @@ class Mapper(object): raise orm_exc.UnmappedColumnError("No column %s is configured on mapper %s..." % (column, self)) # TODO: improve names? - def _get_state_attr_by_column(self, state, column): - return self._get_col_to_prop(column).getattr(state, column) + def _get_state_attr_by_column(self, state, dict_, column): + return self._get_col_to_prop(column)._getattr(state, dict_, column) - def _set_state_attr_by_column(self, state, column, value): - return self._get_col_to_prop(column).setattr(state, value, column) + def _set_state_attr_by_column(self, state, dict_, column, value): + return self._get_col_to_prop(column)._setattr(state, dict_, value, column) def _get_committed_attr_by_column(self, obj, column): state = attributes.instance_state(obj) - return self._get_committed_state_attr_by_column(state, column) + dict_ = attributes.instance_dict(obj) + return self._get_committed_state_attr_by_column(state, dict_, column) - def _get_committed_state_attr_by_column(self, state, column, passive=False): - return self._get_col_to_prop(column).getcommitted(state, column, passive=passive) + def _get_committed_state_attr_by_column(self, state, dict_, column, passive=False): + return self._get_col_to_prop(column)._getcommitted(state, dict_, column, passive=passive) def _optimized_get_statement(self, state, attribute_names): """assemble a WHERE clause which retrieves a given state by primary key, using a minimized set of tables. @@ -1171,12 +1173,12 @@ class Mapper(object): return if leftcol.table not in tables: - leftval = self._get_committed_state_attr_by_column(state, leftcol, passive=True) + leftval = self._get_committed_state_attr_by_column(state, state.dict, leftcol, passive=True) if leftval is attributes.PASSIVE_NO_RESULT: raise ColumnsNotAvailable() binary.left = sql.bindparam(None, leftval, type_=binary.right.type) elif rightcol.table not in tables: - rightval = self._get_committed_state_attr_by_column(state, rightcol, passive=True) + rightval = self._get_committed_state_attr_by_column(state, state.dict, rightcol, passive=True) if rightval is attributes.PASSIVE_NO_RESULT: raise ColumnsNotAvailable() binary.right = sql.bindparam(None, rightval, type_=binary.right.type) @@ -1224,11 +1226,13 @@ class Mapper(object): try: if item_type == 'property': prop = iterator.next() - visitables.append((prop.cascade_iterator(type_, parent_state, visited_instances, halt_on), 'mapper', None)) + visitables.append((prop.cascade_iterator(type_, parent_state, + visited_instances, halt_on), 'mapper', None)) elif item_type == 'mapper': instance, instance_mapper, corresponding_state = iterator.next() yield (instance, instance_mapper) - visitables.append((instance_mapper._props.itervalues(), 'property', corresponding_state)) + visitables.append((instance_mapper._props.itervalues(), + 'property', corresponding_state)) except StopIteration: visitables.pop() @@ -1263,55 +1267,46 @@ class Mapper(object): # if batch=false, call _save_obj separately for each object if not single and not self.batch: for state in _sort_states(states): - self._save_obj([state], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True) + self._save_obj([state], + uowtransaction, + postupdate=postupdate, + post_update_cols=post_update_cols, + single=True) return - + # if session has a connection callable, - # organize individual states with the connection to use for insert/update - tups = [] + # organize individual states with the connection + # to use for insert/update if 'connection_callable' in uowtransaction.mapper_flush_opts: - connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] - for state in _sort_states(states): - m = _state_mapper(state) - tups.append( - ( - state, - m, - connection_callable(self, state.obj()), - _state_has_identity(state), - state.key or m._identity_key_from_state(state) - ) - ) + connection_callable = \ + uowtransaction.mapper_flush_opts['connection_callable'] else: connection = uowtransaction.transaction.connection(self) - for state in _sort_states(states): - m = _state_mapper(state) - tups.append( - ( - state, - m, - connection, - _state_has_identity(state), - state.key or m._identity_key_from_state(state) - ) - ) + connection_callable = None - if not postupdate: - # call before_XXX extensions - for state, mapper, connection, has_identity, instance_key in tups: + tups = [] + for state in _sort_states(states): + conn = connection_callable and \ + connection_callable(self, state.obj()) or \ + connection + + has_identity = _state_has_identity(state) + mapper = _state_mapper(state) + instance_key = state.key or mapper._identity_key_from_state(state) + + row_switch = None + if not postupdate: + # call before_XXX extensions if not has_identity: if 'before_insert' in mapper.extension: - mapper.extension.before_insert(mapper, connection, state.obj()) + mapper.extension.before_insert(mapper, conn, state.obj()) else: if 'before_update' in mapper.extension: - mapper.extension.before_update(mapper, connection, state.obj()) + mapper.extension.before_update(mapper, conn, state.obj()) - row_switches = {} - if not postupdate: - for state, mapper, connection, has_identity, instance_key in tups: # detect if we have a "pending" instance (i.e. has no instance_key attached to it), - # and another instance with the same identity key already exists as persistent. convert to an - # UPDATE if so. + # and another instance with the same identity key already exists as persistent. + # convert to an UPDATE if so. if not has_identity and instance_key in uowtransaction.session.identity_map: instance = uowtransaction.session.identity_map[instance_key] existing = attributes.instance_state(instance) @@ -1320,28 +1315,43 @@ class Mapper(object): "New instance %s with identity key %s conflicts " "with persistent instance %s" % (state_str(state), instance_key, state_str(existing))) - + self._log_debug( - "detected row switch for identity %s. will update %s, remove %s from " - "transaction", instance_key, state_str(state), state_str(existing)) - + "detected row switch for identity %s. " + "will update %s, remove %s from " + "transaction", instance_key, + state_str(state), state_str(existing)) + # remove the "delete" flag from the existing element uowtransaction.set_row_switch(existing) - row_switches[state] = existing - + row_switch = existing + + tups.append( + (state, + state.dict, + mapper, + conn, + has_identity, + instance_key, + row_switch) + ) + table_to_mapper = self._sorted_tables - for table in table_to_mapper.iterkeys(): + for table in table_to_mapper: insert = [] update = [] - for state, mapper, connection, has_identity, instance_key in tups: + for state, state_dict, mapper, connection, has_identity, \ + instance_key, row_switch in tups: if table not in mapper._pks_by_table: continue pks = mapper._pks_by_table[table] - isinsert = not has_identity and not postupdate and state not in row_switches + isinsert = not has_identity and \ + not postupdate and \ + not row_switch params = {} value_params = {} @@ -1359,11 +1369,11 @@ class Mapper(object): value is not None): params[col.key] = value elif col in pks: - value = mapper._get_state_attr_by_column(state, col) + value = mapper._get_state_attr_by_column(state, state_dict, col) if value is not None: params[col.key] = value else: - value = mapper._get_state_attr_by_column(state, col) + value = mapper._get_state_attr_by_column(state, state_dict, col) if ((col.default is None and col.server_default is None) or value is not None): @@ -1371,23 +1381,37 @@ class Mapper(object): value_params[col] = value else: params[col.key] = value - insert.append((state, params, mapper, connection, value_params)) + insert.append((state, state_dict, params, mapper, + connection, value_params)) else: for col in mapper._cols_by_table[table]: if col is mapper.version_id_col: - params[col._label] = mapper._get_state_attr_by_column(row_switches.get(state, state), col) - params[col.key] = mapper.version_id_generator(params[col._label]) + params[col._label] = \ + mapper._get_state_attr_by_column( + row_switch or state, + row_switch and row_switch.dict or state_dict, + col) + params[col.key] = \ + mapper.version_id_generator(params[col._label]) + + # HACK: check for history, in case the history is only + # in a different table than the one where the version_id_col + # is. for prop in mapper._columntoproperty.itervalues(): - history = attributes.get_state_history(state, prop.key, passive=True) + history = attributes.get_state_history( + state, prop.key, passive=True) if history.added: hasdata = True elif mapper.polymorphic_on is not None and \ - mapper.polymorphic_on.shares_lineage(col) and col not in pks: + mapper.polymorphic_on.shares_lineage(col) and \ + col not in pks: pass else: - if post_update_cols is not None and col not in post_update_cols: + if post_update_cols is not None and \ + col not in post_update_cols: if col in pks: - params[col._label] = mapper._get_state_attr_by_column(state, col) + params[col._label] = \ + mapper._get_state_attr_by_column(state, state_dict, col) continue prop = mapper._columntoproperty[col] @@ -1422,63 +1446,94 @@ class Mapper(object): else: hasdata = True elif col in pks: - params[col._label] = mapper._get_state_attr_by_column(state, col) + params[col._label] = mapper._get_state_attr_by_column(state, state_dict, col) if hasdata: - update.append((state, params, mapper, connection, value_params)) + update.append((state, state_dict, params, mapper, + connection, value_params)) + if update: mapper = table_to_mapper[table] clause = sql.and_() for col in mapper._pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col._label, type_=col.type)) + clause.clauses.append( + col == + sql.bindparam(col._label, type_=col.type) + ) - if mapper.version_id_col is not None and \ - table.c.contains_column(mapper.version_id_col): - + needs_version_id = mapper.version_id_col is not None and \ + table.c.contains_column(mapper.version_id_col) + + if needs_version_id: clause.clauses.append(mapper.version_id_col ==\ sql.bindparam(mapper.version_id_col._label, type_=col.type)) statement = table.update(clause) - + + if len(update) > 1: + compiled_cache = {} + else: + compiled_cache = None + rows = 0 - for state, params, mapper, connection, value_params in update: - c = connection.execute(statement.values(value_params), params) - mapper._postfetch(uowtransaction, connection, table, - state, c, c.last_updated_params(), value_params) + for state, state_dict, params, mapper, connection, value_params in update: + if not value_params and compiled_cache is not None: + c = connection.\ + execution_options( + compiled_cache=compiled_cache).\ + execute(statement, params) + else: + c = connection.execute(statement.values(value_params), params) + + mapper._postfetch(uowtransaction, table, + state, state_dict, c, c.last_updated_params(), value_params) rows += c.rowcount if connection.dialect.supports_sane_rowcount: if rows != len(update): raise orm_exc.ConcurrentModificationError( - "Updated rowcount %d does not match number of objects updated %d" % + "Updated rowcount %d does not match number " + "of objects updated %d" % (rows, len(update))) - - elif mapper.version_id_col is not None: + + elif needs_version_id: util.warn("Dialect %s does not support updated rowcount " - "- versioning cannot be verified." % c.dialect.dialect_description, - stacklevel=12) + "- versioning cannot be verified." % + c.dialect.dialect_description, + stacklevel=12) if insert: statement = table.insert() - for state, params, mapper, connection, value_params in insert: - c = connection.execute(statement.values(value_params), params) + if len(insert) > 1: + compiled_cache = {} + else: + compiled_cache = None + + for state, state_dict, params, mapper, connection, value_params in insert: + if not value_params and compiled_cache is not None: + c = connection.\ + execution_options( + compiled_cache=compiled_cache).\ + execute(statement, params) + else: + c = connection.execute(statement.values(value_params), params) primary_key = c.inserted_primary_key if primary_key is not None: # set primary key attributes for i, col in enumerate(mapper._pks_by_table[table]): - if mapper._get_state_attr_by_column(state, col) is None and \ + if mapper._get_state_attr_by_column(state, state_dict, col) is None and \ len(primary_key) > i: - mapper._set_state_attr_by_column(state, col, primary_key[i]) + mapper._set_state_attr_by_column(state, state_dict, col, primary_key[i]) - mapper._postfetch(uowtransaction, connection, table, - state, c, c.last_inserted_params(), value_params) + mapper._postfetch(uowtransaction, table, + state, state_dict, c, c.last_inserted_params(), value_params) - if not postupdate: - for state, mapper, connection, has_identity, instance_key in tups: + for state, state_dict, mapper, connection, has_identity, \ + instance_key, row_switch in tups: # expire readonly attributes readonly = state.unmodified.intersection( @@ -1488,8 +1543,8 @@ class Mapper(object): if readonly: _expire_state(state, state.dict, readonly) - # if specified, eagerly refresh whatever has - # been expired. + # if eager_defaults option is enabled, + # refresh whatever has been expired. if self.eager_defaults and state.unloaded: state.key = self._identity_key_from_state(state) uowtransaction.session.query(self)._get( @@ -1504,8 +1559,8 @@ class Mapper(object): if 'after_update' in mapper.extension: mapper.extension.after_update(mapper, connection, state.obj()) - def _postfetch(self, uowtransaction, connection, table, - state, resultproxy, params, value_params): + def _postfetch(self, uowtransaction, table, + state, dict_, resultproxy, params, value_params): """Expire attributes in need of newly persisted database state.""" postfetch_cols = resultproxy.postfetch_cols() @@ -1521,25 +1576,39 @@ class Mapper(object): for c in generated_cols: if c.key in params and c in self._columntoproperty: - self._set_state_attr_by_column(state, c, params[c.key]) + self._set_state_attr_by_column(state, dict_, c, params[c.key]) - deferred_props = [prop.key for prop in [self._columntoproperty[c] for c in postfetch_cols]] - - if deferred_props: - _expire_state(state, state.dict, deferred_props) + if postfetch_cols: + _expire_state(state, state.dict, + [self._columntoproperty[c].key + for c in postfetch_cols] + ) # synchronize newly inserted ids from one table to the next # TODO: this still goes a little too often. would be nice to # have definitive list of "columns that changed" here - cols = set(table.c) - for m in self.iterate_to_root(): - if m._inherits_equated_pairs and \ - cols.intersection([l for l, r in m._inherits_equated_pairs]): - sync.populate(state, m, state, m, - m._inherits_equated_pairs, - uowtransaction, - self.passive_updates) - + for m, equated_pairs in self._table_to_equated[table]: + sync.populate(state, m, state, m, + equated_pairs, + uowtransaction, + self.passive_updates) + + @util.memoized_property + def _table_to_equated(self): + """memoized map of tables to collections of columns to be + synchronized upwards to the base mapper.""" + + result = util.defaultdict(list) + + for table in self._sorted_tables: + cols = set(table.c) + for m in self.iterate_to_root(): + if m._inherits_equated_pairs and \ + cols.intersection([l for l, r in m._inherits_equated_pairs]): + result[table].append((m, m._inherits_equated_pairs)) + + return result + def _delete_obj(self, states, uowtransaction): """Issue ``DELETE`` statements for a list of objects. @@ -1548,50 +1617,96 @@ class Mapper(object): """ if 'connection_callable' in uowtransaction.mapper_flush_opts: - connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] - tups = [(state, _state_mapper(state), connection_callable(self, state.obj())) for state in _sort_states(states)] + connection_callable = \ + uowtransaction.mapper_flush_opts['connection_callable'] else: connection = uowtransaction.transaction.connection(self) - tups = [(state, _state_mapper(state), connection) for state in _sort_states(states)] - - for state, mapper, connection in tups: + connection_callable = None + + tups = [] + for state in _sort_states(states): + mapper = _state_mapper(state) + + conn = connection_callable and \ + connection_callable(self, state.obj()) or \ + connection + if 'before_delete' in mapper.extension: - mapper.extension.before_delete(mapper, connection, state.obj()) + mapper.extension.before_delete(mapper, conn, state.obj()) + + tups.append((state, + state.dict, + _state_mapper(state), + _state_has_identity(state), + conn)) table_to_mapper = self._sorted_tables for table in reversed(table_to_mapper.keys()): - delete = {} - for state, mapper, connection in tups: - if table not in mapper._pks_by_table: + delete = util.defaultdict(list) + for state, state_dict, mapper, has_identity, connection in tups: + if not has_identity or table not in mapper._pks_by_table: continue params = {} - if not _state_has_identity(state): - continue - else: - delete.setdefault(connection, []).append(params) + delete[connection].append(params) for col in mapper._pks_by_table[table]: - params[col.key] = mapper._get_state_attr_by_column(state, col) - if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col): - params[mapper.version_id_col.key] = mapper._get_state_attr_by_column(state, mapper.version_id_col) + params[col.key] = mapper._get_state_attr_by_column(state, state_dict, col) + if mapper.version_id_col is not None and \ + table.c.contains_column(mapper.version_id_col): + params[mapper.version_id_col.key] = \ + mapper._get_state_attr_by_column(state, state_dict, mapper.version_id_col) for connection, del_objects in delete.iteritems(): mapper = table_to_mapper[table] clause = sql.and_() for col in mapper._pks_by_table[table]: clause.clauses.append(col == sql.bindparam(col.key, type_=col.type)) - if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col): + + need_version_id = mapper.version_id_col is not None and \ + table.c.contains_column(mapper.version_id_col) + + if need_version_id: clause.clauses.append( mapper.version_id_col == - sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type)) + sql.bindparam( + mapper.version_id_col.key, + type_=mapper.version_id_col.type + ) + ) + statement = table.delete(clause) - c = connection.execute(statement, del_objects) - if c.supports_sane_multi_rowcount() and c.rowcount != len(del_objects): - raise orm_exc.ConcurrentModificationError("Deleted rowcount %d does not match " - "number of objects deleted %d" % (c.rowcount, len(del_objects))) + rows = -1 + + if need_version_id and \ + not connection.dialect.supports_sane_multi_rowcount: + # TODO: need test coverage for this [ticket:1761] + if connection.dialect.supports_sane_rowcount: + rows = 0 + # execute deletes individually so that versioned + # rows can be verified + for params in del_objects: + c = connection.execute(statement, params) + rows += c.rowcount + else: + util.warn("Dialect %s does not support deleted rowcount " + "- versioning cannot be verified." % + c.dialect.dialect_description, + stacklevel=12) + connection.execute(statement, del_objects) + else: + c = connection.execute(statement, del_objects) + if connection.dialect.supports_sane_multi_rowcount: + rows = c.rowcount + + if rows != -1 and rows != len(del_objects): + raise orm_exc.ConcurrentModificationError( + "Deleted rowcount %d does not match " + "number of objects deleted %d" % + (c.rowcount, len(del_objects)) + ) - for state, mapper, connection in tups: + for state, state_dict, mapper, has_identity, connection in tups: if 'after_delete' in mapper.extension: mapper.extension.after_delete(mapper, connection, state.obj()) @@ -1723,12 +1838,13 @@ class Mapper(object): context.version_check and \ self._get_state_attr_by_column( state, + dict_, self.version_id_col) != row[version_id_col]: raise orm_exc.ConcurrentModificationError( "Instance '%s' version of %s does not match %s" % (state_str(state), - self._get_state_attr_by_column(state, self.version_id_col), + self._get_state_attr_by_column(state, dict_, self.version_id_col), row[version_id_col])) elif refresh_state: # out of band refresh_state detected (i.e. its not in the session.identity_map) @@ -1916,7 +2032,7 @@ def _event_on_resurrect(state, instance): # of the dict based on the mapping. instrumenting_mapper = state.manager.info[_INSTRUMENTOR] for col, val in zip(instrumenting_mapper.primary_key, state.key[1]): - instrumenting_mapper._set_state_attr_by_column(state, col, val) + instrumenting_mapper._set_state_attr_by_column(state, state.dict, col, val) def _sort_states(states): diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 2a5e92c1a..6d5cc0524 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -58,6 +58,8 @@ class ColumnProperty(StrategizedProperty): self.comparator_factory = kwargs.pop('comparator_factory', self.__class__.Comparator) self.descriptor = kwargs.pop('descriptor', None) self.extension = kwargs.pop('extension', None) + self.doc = kwargs.pop('doc', getattr(columns[0], 'doc', None)) + if kwargs: raise TypeError( "%s received unexpected keyword argument(s): %s" % ( @@ -80,7 +82,8 @@ class ColumnProperty(StrategizedProperty): self.key, comparator=self.comparator_factory(self, mapper), parententity=mapper, - property_=self + property_=self, + doc=self.doc ) def do_init(self): @@ -96,14 +99,14 @@ class ColumnProperty(StrategizedProperty): def copy(self): return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns) - def getattr(self, state, column): - return state.get_impl(self.key).get(state, state.dict) + def _getattr(self, state, dict_, column): + return state.get_impl(self.key).get(state, dict_) - def getcommitted(self, state, column, passive=False): - return state.get_impl(self.key).get_committed_value(state, state.dict, passive=passive) + def _getcommitted(self, state, dict_, column, passive=False): + return state.get_impl(self.key).get_committed_value(state, dict_, passive=passive) - def setattr(self, state, value, column): - state.get_impl(self.key).set(state, state.dict, value, None) + def _setattr(self, state, dict_, value, column): + state.get_impl(self.key).set(state, dict_, value, None) def merge(self, session, source_state, source_dict, dest_state, dest_dict, load, _recursive): if self.key in source_dict: @@ -161,18 +164,18 @@ class CompositeProperty(ColumnProperty): # which issues assertions that do not apply to CompositeColumnProperty super(ColumnProperty, self).do_init() - def getattr(self, state, column): - obj = state.get_impl(self.key).get(state, state.dict) + def _getattr(self, state, dict_, column): + obj = state.get_impl(self.key).get(state, dict_) return self.get_col_value(column, obj) - def getcommitted(self, state, column, passive=False): + def _getcommitted(self, state, dict_, column, passive=False): # TODO: no coverage here - obj = state.get_impl(self.key).get_committed_value(state, state.dict, passive=passive) + obj = state.get_impl(self.key).get_committed_value(state, dict_, passive=passive) return self.get_col_value(column, obj) - def setattr(self, state, value, column): + def _setattr(self, state, dict_, value, column): - obj = state.get_impl(self.key).get(state, state.dict) + obj = state.get_impl(self.key).get(state, dict_) if obj is None: obj = self.composite_class(*[None for c in self.columns]) state.get_impl(self.key).set(state, state.dict, obj, None) @@ -259,11 +262,12 @@ class SynonymProperty(MapperProperty): extension = None - def __init__(self, name, map_column=None, descriptor=None, comparator_factory=None): + def __init__(self, name, map_column=None, descriptor=None, comparator_factory=None, doc=None): self.name = name self.map_column = map_column self.descriptor = descriptor self.comparator_factory = comparator_factory + self.doc = doc or (descriptor and descriptor.__doc__) or None util.set_creation_order(self) def setup(self, context, entity, path, adapter, **kwargs): @@ -303,7 +307,8 @@ class SynonymProperty(MapperProperty): comparator=comparator_callable(self, mapper), parententity=mapper, property_=self, - proxy_property=self.descriptor + proxy_property=self.descriptor, + doc=self.doc ) def merge(self, session, source_state, source_dict, dest_state, dest_dict, load, _recursive): @@ -316,9 +321,10 @@ class ComparableProperty(MapperProperty): extension = None - def __init__(self, comparator_factory, descriptor=None): + def __init__(self, comparator_factory, descriptor=None, doc=None): self.descriptor = descriptor self.comparator_factory = comparator_factory + self.doc = doc or (descriptor and descriptor.__doc__) or None util.set_creation_order(self) def instrument_class(self, mapper): @@ -330,7 +336,8 @@ class ComparableProperty(MapperProperty): comparator=self.comparator_factory(self, mapper), parententity=mapper, property_=self, - proxy_property=self.descriptor + proxy_property=self.descriptor, + doc=self.doc, ) def setup(self, context, entity, path, adapter, **kwargs): @@ -364,6 +371,7 @@ class RelationshipProperty(StrategizedProperty): enable_typechecks=True, join_depth=None, comparator_factory=None, single_parent=False, innerjoin=False, + doc=None, strategy_class=None, _local_remote_pairs=None, query_class=None): self.uselist = uselist @@ -384,7 +392,7 @@ class RelationshipProperty(StrategizedProperty): self.enable_typechecks = enable_typechecks self.query_class = query_class self.innerjoin = innerjoin - + self.doc = doc self.join_depth = join_depth self.local_remote_pairs = _local_remote_pairs self.extension = extension @@ -433,7 +441,8 @@ class RelationshipProperty(StrategizedProperty): self.key, comparator=self.comparator_factory(self, mapper), parententity=mapper, - property_=self + property_=self, + doc=self.doc, ) class Comparator(PropComparator): @@ -1149,7 +1158,7 @@ class RelationshipProperty(StrategizedProperty): parent = self.parent.primary_mapper() kwargs.setdefault('viewonly', self.viewonly) kwargs.setdefault('post_update', self.post_update) - + self.back_populates = backref_key relationship = RelationshipProperty( parent, diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 93b1170f4..96aac7d3a 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -66,6 +66,7 @@ def _register_attribute(strategy, mapper, useobject, callable_=callable_, active_history=active_history, impl_class=impl_class, + doc=prop.doc, **kw ) @@ -591,6 +592,7 @@ class LoadLazyAttribute(object): val = instance_mapper.\ _get_committed_state_attr_by_column( state, + state.dict, strategy._equated_columns[primary_key], **kw) if val is attributes.PASSIVE_NO_RESULT: diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index 30daacbdf..b9ddbb6e7 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -14,12 +14,12 @@ def populate(source, source_mapper, dest, dest_mapper, synchronize_pairs, uowcommit, passive_updates): for l, r in synchronize_pairs: try: - value = source_mapper._get_state_attr_by_column(source, l) + value = source_mapper._get_state_attr_by_column(source, source.dict, l) except exc.UnmappedColumnError: _raise_col_to_prop(False, source_mapper, l, dest_mapper, r) try: - dest_mapper._set_state_attr_by_column(dest, r, value) + dest_mapper._set_state_attr_by_column(dest, dest.dict, r, value) except exc.UnmappedColumnError: _raise_col_to_prop(True, source_mapper, l, dest_mapper, r) @@ -41,7 +41,7 @@ def clear(dest, dest_mapper, synchronize_pairs): (r, mapperutil.state_str(dest)) ) try: - dest_mapper._set_state_attr_by_column(dest, r, None) + dest_mapper._set_state_attr_by_column(dest, dest.dict, r, None) except exc.UnmappedColumnError: _raise_col_to_prop(True, None, l, dest_mapper, r) @@ -49,7 +49,7 @@ def update(source, source_mapper, dest, old_prefix, synchronize_pairs): for l, r in synchronize_pairs: try: oldvalue = source_mapper._get_committed_attr_by_column(source.obj(), l) - value = source_mapper._get_state_attr_by_column(source, l) + value = source_mapper._get_state_attr_by_column(source, source.dict, l) except exc.UnmappedColumnError: _raise_col_to_prop(False, source_mapper, l, None, r) dest[r.key] = value @@ -58,7 +58,7 @@ def update(source, source_mapper, dest, old_prefix, synchronize_pairs): def populate_dict(source, source_mapper, dict_, synchronize_pairs): for l, r in synchronize_pairs: try: - value = source_mapper._get_state_attr_by_column(source, l) + value = source_mapper._get_state_attr_by_column(source, source.dict, l) except exc.UnmappedColumnError: _raise_col_to_prop(False, source_mapper, l, None, r) diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 30b0b61e5..ea4e0b1da 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -4,19 +4,17 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""The internals for the Unit Of Work system. +"""The internals for the unit of work system. -Includes hooks into the attributes package enabling the routing of -change events to Unit Of Work objects, as well as the flush() -mechanism which creates a dependency structure that executes change -operations. +The session's flush() process passes objects to a contextual object +here, which assembles flush tasks based on mappers and their properties, +organizes them in order of dependency, and executes. + +Most of the code in this module is obsolete, and will +be replaced by a much simpler and more efficient +system in an upcoming release. See [ticket:1742] for +details. -A Unit of Work is essentially a system of maintaining a graph of -in-memory objects and their modified state. Objects are maintained as -unique against their primary key identity using an *identity map* -pattern. The Unit of Work then maintains lists of objects that are -new, dirty, or deleted and provides the capability to flush all those -changes at once. """ @@ -40,7 +38,8 @@ class UOWEventHandler(interfaces.AttributeExtension): self.key = key def append(self, state, item, initiator): - # process "save_update" cascade rules for when an instance is appended to the list of another instance + # process "save_update" cascade rules for when + # an instance is appended to the list of another instance sess = _state_session(state) if sess: prop = _state_mapper(state).get_property(self.key) @@ -74,15 +73,8 @@ class UOWEventHandler(interfaces.AttributeExtension): class UOWTransaction(object): - """Handles the details of organizing and executing transaction - tasks during a UnitOfWork object's flush() operation. - - The central operation is to form a graph of nodes represented by the - ``UOWTask`` class, which is then traversed by a ``UOWExecutor`` object - that issues SQL and instance-synchronizing operations via the related - packages. - """ - + """Represent the state of a flush operation in progress.""" + def __init__(self, session): self.session = session self.mapper_flush_opts = session._mapper_flush_opts @@ -94,7 +86,8 @@ class UOWTransaction(object): # dictionary of mappers to UOWTasks self.tasks = {} - # dictionary used by external actors to store arbitrary state + # dictionary used by external actors + # to store arbitrary state # information. self.attributes = {} @@ -169,8 +162,6 @@ class UOWTransaction(object): def get_task_by_mapper(self, mapper, dontcreate=False): """return UOWTask element corresponding to the given mapper. - Will create a new UOWTask, including a UOWTask corresponding to the - "base" inherited mapper, if needed, unless the dontcreate flag is True. """ try: @@ -195,16 +186,13 @@ class UOWTransaction(object): return task def register_dependency(self, mapper, dependency): - """register a dependency between two mappers. - - Called by ``mapper.PropertyLoader`` to register the objects - handled by one mapper being dependent on the objects handled - by another. - - """ + """register a dependency between two mappers.""" + # correct for primary mapper - # also convert to the "base mapper", the parentmost task at the top of an inheritance chain - # dependency sorting is done via non-inheriting mappers only, dependencies between mappers + # also convert to the "base mapper", the parentmost + # task at the top of an inheritance chain + # dependency sorting is done via non-inheriting + # mappers only, dependencies between mappers # in the same inheritance chain is done at the per-object level mapper = mapper.primary_mapper().base_mapper dependency = dependency.primary_mapper().base_mapper @@ -226,14 +214,7 @@ class UOWTransaction(object): task.dependencies.add(up) def execute(self): - """Execute this UOWTransaction. - - This will organize all collected UOWTasks into a dependency-sorted - list which is then traversed using the traversal scheme - encoded in the UOWExecutor class. Operations to mappers and dependency - processors are fired off in order to issue SQL to the database and - synchronize instance attributes with database values and related - foreign key values.""" + """Execute this steps assembled into this UOWTransaction.""" # pre-execute dependency processors. this process may # result in new tasks, objects and/or dependency processors being added, @@ -281,7 +262,7 @@ class UOWTransaction(object): self.session._register_newly_persistent(elem.state) def _sort_dependencies(self): - nodes = topological.sort_with_cycles(self.dependencies, + nodes = topological._sort_with_cycles(self.dependencies, [t.mapper for t in self.tasks.itervalues() if t.base_task is t] ) @@ -302,7 +283,10 @@ class UOWTransaction(object): log.class_logger(UOWTransaction) class UOWTask(object): - """A collection of mapped states corresponding to a particular mapper.""" + """A collection of mapped states corresponding to a particular mapper. + + This object is deprecated. + """ def __init__(self, uowtransaction, mapper, base_task=None): self.uowtransaction = uowtransaction @@ -339,30 +323,6 @@ class UOWTask(object): """Return an iterator of UOWTask objects corresponding to the inheritance sequence of this UOWTask's mapper. - e.g. if mapper B and mapper C inherit from mapper A, and - mapper D inherits from B: - - mapperA -> mapperB -> mapperD - -> mapperC - - the inheritance sequence starting at mapper A is a depth-first - traversal: - - [mapperA, mapperB, mapperD, mapperC] - - this method will therefore return - - [UOWTask(mapperA), UOWTask(mapperB), UOWTask(mapperD), - UOWTask(mapperC)] - - The concept of "polymporphic iteration" is adapted into - several property-based iterators which return object - instances, UOWTaskElements and UOWDependencyProcessors in an - order corresponding to this sequence of parent UOWTasks. This - is used to issue operations related to inheritance-chains of - mappers in the proper order based on dependencies between - those mappers. - """ for mapper in self.inheriting_mappers: t = self.base_task._inheriting_tasks.get(mapper, None) @@ -370,9 +330,9 @@ class UOWTask(object): yield t def is_empty(self): - """return True if this UOWTask is 'empty', meaning it has no child items. + """return True if this UOWTask is 'empty', + meaning it has no child items. - used only for debugging output. """ return not self._objects and not self.dependencies @@ -386,19 +346,17 @@ class UOWTask(object): rec.update(listonly, isdelete) def append_postupdate(self, state, post_update_cols): - """issue a 'post update' UPDATE statement via this object's mapper immediately. + """issue a 'post update' UPDATE statement via + this object's mapper immediately. - this operation is used only with relationships that specify the `post_update=True` - flag. """ - # postupdates are UPDATED immeditely (for now) - # convert post_update_cols list to a Set so that __hash__() is used to compare columns - # instead of __eq__() - self.mapper._save_obj([state], self.uowtransaction, postupdate=True, post_update_cols=set(post_update_cols)) + self.mapper._save_obj([state], self.uowtransaction, + postupdate=True, post_update_cols=set(post_update_cols)) def __contains__(self, state): - """return True if the given object is contained within this UOWTask or inheriting tasks.""" + """return True if the given object is contained + within this UOWTask or inheriting tasks.""" for task in self.polymorphic_tasks: if state in task._objects: @@ -407,7 +365,8 @@ class UOWTask(object): return False def is_deleted(self, state): - """return True if the given object is marked as to be deleted within this UOWTask.""" + """return True if the given object is marked + as to be deleted within this UOWTask.""" try: return self._objects[state].isdelete @@ -477,10 +436,15 @@ class UOWTask(object): return self.cyclical_dependencies def _sort_circular_dependencies(self, trans, cycles): - """Topologically sort individual entities with row-level dependencies. + """sort row-level dependencies. - Builds a modified UOWTask structure, and is invoked when the - per-mapper topological structure is found to have cycles. + Note that this method is a total disaster, as it was + bolted onto the originally simple unit-of-work + system after more complex mappings revealed + the presence of inter-row rependencies - this occured + well within version 0.1 and despite many fixes + has remained the most legacy code within SQLAlchemy. + It is gone without a trace after [ticket:1742]. """ @@ -492,7 +456,8 @@ class UOWTask(object): if depprocessor not in tasks: tasks[depprocessor] = UOWDependencyProcessor( depprocessor.processor, - UOWTask(self.uowtransaction, depprocessor.targettask.mapper) + UOWTask(self.uowtransaction, + depprocessor.targettask.mapper) ) tasks[depprocessor].targettask.append(target_state, isdelete=isdelete) @@ -529,7 +494,8 @@ class UOWTask(object): isdelete = taskelement.isdelete # list of dependent objects from this object - (added, unchanged, deleted) = dep.get_object_dependencies(state, trans, passive=True) + (added, unchanged, deleted) = dep.get_object_dependencies( + state, trans, passive=True) if not added and not unchanged and not deleted: continue @@ -551,9 +517,11 @@ class UOWTask(object): tuples.append(whosdep) if whosdep[0] is state: - set_processor_for_state(whosdep[0], dep, whosdep[0], isdelete=isdelete) + set_processor_for_state(whosdep[0], dep, whosdep[0], + isdelete=isdelete) else: - set_processor_for_state(whosdep[0], dep, whosdep[1], isdelete=isdelete) + set_processor_for_state(whosdep[0], dep, whosdep[1], + isdelete=isdelete) else: # TODO: no test coverage here set_processor_for_state(state, dep, state, isdelete=isdelete) @@ -567,11 +535,7 @@ class UOWTask(object): # dependency - keep non-dependent objects # grouped together, so that insert ordering as determined # by session.add() is maintained. - # An alternative might be to represent the "insert order" - # as part of the topological sort itself, which would - # eliminate the need for this step (but may make the original - # topological sort more expensive) - head = topological.sort_as_tree(tuples, object_to_original_task.iterkeys()) + head = topological._sort_as_tree(tuples, object_to_original_task.iterkeys()) if head is not None: original_to_tasks = {} stack = [(head, t)] @@ -588,7 +552,8 @@ class UOWTask(object): else: task = original_to_tasks[(parenttask, originating_task)] - task.append(state, originating_task._objects[state].listonly, isdelete=originating_task._objects[state].isdelete) + task.append(state, originating_task._objects[state].listonly, + isdelete=originating_task._objects[state].isdelete) if state in dependencies: task.cyclical_dependencies.update(dependencies[state].itervalues()) @@ -614,10 +579,9 @@ class UOWTask(object): return ("UOWTask(%s) Mapper: '%r'" % (hex(id(self)), self.mapper)) class UOWTaskElement(object): - """Corresponds to a single InstanceState to be saved, deleted, - or otherwise marked as having dependencies. A collection of - UOWTaskElements are held by a UOWTask. - + """Represent a single state to be saved. + + This object is deprecated. """ def __init__(self, state): self.state = state @@ -642,11 +606,9 @@ class UOWTaskElement(object): ) class UOWDependencyProcessor(object): - """In between the saving and deleting of objects, process - dependent data, such as filling in a foreign key on a child item - from a new primary key, or deleting association rows before a - delete. This object acts as a proxy to a DependencyProcessor. + """Represent tasks in between inserts/updates/deletes. + This object is deprecated. """ def __init__(self, processor, targettask): self.processor = processor @@ -674,19 +636,8 @@ class UOWDependencyProcessor(object): return hash((self.processor, self.targettask)) def preexecute(self, trans): - """preprocess all objects contained within this ``UOWDependencyProcessor``s target task. - - This may locate additional objects which should be part of the - transaction, such as those affected deletes, orphans to be - deleted, etc. - - Once an object is preprocessed, its ``UOWTaskElement`` is marked as processed. If subsequent - changes occur to the ``UOWTaskElement``, its processed flag is reset, and will require processing - again. - - Return True if any objects were preprocessed, or False if no - objects were preprocessed. If True is returned, the parent ``UOWTransaction`` will - ultimately call ``preexecute()`` again on all processors until no new objects are processed. + """preprocess all objects contained within + this ``UOWDependencyProcessor``s target task. """ def getobj(elem): @@ -710,7 +661,8 @@ class UOWDependencyProcessor(object): return ret def execute(self, trans, delete): - """process all objects contained within this ``UOWDependencyProcessor``s target task.""" + """process all objects contained within this + ``UOWDependencyProcessor``s target task.""" elements = [e for e in @@ -729,15 +681,18 @@ class UOWDependencyProcessor(object): def whose_dependent_on_who(self, state1, state2): """establish which object is operationally dependent amongst a parent/child using the semantics stated by the dependency processor. - - This method is used to establish a partial ordering (set of dependency tuples) - when toplogically sorting on a per-instance basis. - """ return self.processor.whose_dependent_on_who(state1, state2) class UOWExecutor(object): - """Encapsulates the execution traversal of a UOWTransaction structure.""" + """Encapsulates the execution traversal + of a UOWTransaction structure. + + This part of the approach is the core flaw that's + being removed with [ticket:1742], as it necessitates + deep levels of recursion. + + """ def execute(self, trans, tasks, isdelete=None): if isdelete is not True: diff --git a/lib/sqlalchemy/orm/uowdumper.py b/lib/sqlalchemy/orm/uowdumper.py index dd96b6b9a..7884cabd4 100644 --- a/lib/sqlalchemy/orm/uowdumper.py +++ b/lib/sqlalchemy/orm/uowdumper.py @@ -4,7 +4,12 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""Dumps out a string representation of a UOWTask structure""" +"""Dumps out a string representation of a UOWTask structure. + +This module is deprecated and will be removed once +[ticket:1742] is complete. + +""" from sqlalchemy.orm import unitofwork from sqlalchemy.orm import util as mapperutil diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 8ffb68a4e..0e03be686 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -535,6 +535,10 @@ class Column(SchemaItem, expression.ColumnClause): Contrast this argument to ``server_default`` which creates a default generator on the database side. + :param doc: optional String that can be used by the ORM or similar + to document attributes. This attribute does not render SQL + comments (a future attribute 'comment' will achieve that). + :param key: An optional string identifier which will identify this ``Column`` object on the :class:`Table`. When a key is provided, this is the only identifier referencing the ``Column`` within the @@ -651,6 +655,7 @@ class Column(SchemaItem, expression.ColumnClause): self.index = kwargs.pop('index', None) self.unique = kwargs.pop('unique', None) self.quote = kwargs.pop('quote', None) + self.doc = kwargs.pop('doc', None) self.onupdate = kwargs.pop('onupdate', None) self.autoincrement = kwargs.pop('autoincrement', True) self.constraints = set() diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 5958a0bc4..fc6b5ad97 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -2276,6 +2276,23 @@ class Executable(_Generative): of many DBAPIs. The flag is currently understood only by the psycopg2 dialect. + * compiled_cache - a dictionary where :class:`Compiled` objects + will be cached when the :class:`Connection` compiles a clause + expression into a dialect- and parameter-specific + :class:`Compiled` object. It is the user's responsibility to + manage the size of this dictionary, which will have keys + corresponding to the dialect, clause element, the column + names within the VALUES or SET clause of an INSERT or UPDATE, + as well as the "batch" mode for an INSERT or UPDATE statement. + The format of this dictionary is not guaranteed to stay the + same in future releases. + + This option is usually more appropriate + to use via the + :meth:`sqlalchemy.engine.base.Connection.execution_options()` + method of :class:`Connection`, rather than upon individual + statement objects, though the effect is the same. + See also: :meth:`sqlalchemy.engine.base.Connection.execution_options()` @@ -2875,18 +2892,13 @@ class Join(FromClause): select, for columns that are calculated to be "equivalent" based on the join criterion of this :class:`Join`. This will recursively apply to any joins directly nested by this one - as well. This flag is specific to a particular use case - by the ORM and is deprecated as of 0.6. + as well. :param \**kwargs: all other kwargs are sent to the underlying :func:`select()` function. """ if fold_equivalents: - global sql_util - if not sql_util: - from sqlalchemy.sql import util as sql_util - util.warn_deprecated("fold_equivalents is deprecated.") collist = sql_util.folded_equivalents(self) else: collist = [self.left, self.right] diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 5a439b099..ccbeea371 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -20,7 +20,7 @@ def sort_tables(tables): for table in tables: visitors.traverse(table, {'schema_visitor':True}, {'foreign_key':visit_foreign_key}) - return topological.sort(tuples, tables) + return topological._sort(tuples, tables) def find_join_source(clauses, join_to): """Given a list of FROM clauses and a selectable, diff --git a/lib/sqlalchemy/test/testing.py b/lib/sqlalchemy/test/testing.py index 771b8c90f..70ddc7ba2 100644 --- a/lib/sqlalchemy/test/testing.py +++ b/lib/sqlalchemy/test/testing.py @@ -546,6 +546,23 @@ def fixture(table, columns, *rows): for column_values in rows]) table.append_ddl_listener('after-create', onload) +def provide_metadata(fn): + """Provides a bound MetaData object for a single test, + drops it afterwards.""" + def maybe(*args, **kw): + metadata = schema.MetaData(db) + context = dict(fn.func_globals) + context['metadata'] = metadata + # jython bug #1034 + rebound = types.FunctionType( + fn.func_code, context, fn.func_name, fn.func_defaults, + fn.func_closure) + try: + return rebound(*args, **kw) + finally: + metadata.drop_all() + return function_named(maybe, fn.__name__) + def resolve_artifact_names(fn): """Decorator, augment function globals with tables and classes. diff --git a/lib/sqlalchemy/topological.py b/lib/sqlalchemy/topological.py index d061aec04..cfc07e861 100644 --- a/lib/sqlalchemy/topological.py +++ b/lib/sqlalchemy/topological.py @@ -6,57 +6,46 @@ """Topological sorting algorithms. -The topological sort is an algorithm that receives this list of -dependencies as a *partial ordering*, that is a list of pairs which -might say, *X is dependent on Y*, *Q is dependent on Z*, but does not -necessarily tell you anything about Q being dependent on X. Therefore, -its not a straight sort where every element can be compared to -another... only some of the elements have any sorting preference, and -then only towards just some of the other elements. For a particular -partial ordering, there can be many possible sorts that satisfy the -conditions. +All functions and classes in this module are currently deprecated, +and will be replaced by a much simpler and more efficient +system in an upcoming release. See [ticket:1742] for +details. """ from sqlalchemy.exc import CircularDependencyError from sqlalchemy import util -__all__ = ['sort', 'sort_with_cycles', 'sort_as_tree'] - -def sort(tuples, allitems): +def _sort(tuples, allitems): """sort the given list of items by dependency. - 'tuples' is a list of tuples representing a partial ordering. + deprecated. a new sort with slightly different + behavior will replace this method in an upcoming release. """ - return [n.item for n in _sort(tuples, allitems, allow_cycles=False, ignore_self_cycles=True)] + return [n.item for n in _sort_impl(tuples, allitems, + allow_cycles=False, + ignore_self_cycles=True)] -def sort_with_cycles(tuples, allitems): +def _sort_with_cycles(tuples, allitems): """sort the given list of items by dependency, cutting out cycles. - - returns results as an iterable of 2-tuples, containing the item, - and a list containing items involved in a cycle with this item, if any. - - 'tuples' is a list of tuples representing a partial ordering. + + deprecated. a new approach to cycle detection will + be introduced in an upcoming release. """ - return [(n.item, [n.item for n in n.cycles or []]) for n in _sort(tuples, allitems, allow_cycles=True)] + return [(n.item, [n.item for n in n.cycles or []]) + for n in _sort_impl(tuples, allitems, allow_cycles=True)] -def sort_as_tree(tuples, allitems, with_cycles=False): +def _sort_as_tree(tuples, allitems, with_cycles=False): """sort the given list of items by dependency, and return results as a hierarchical tree structure. - returns results as an iterable of 3-tuples, containing the item, - a list containing items involved in a cycle with this item, if any, - and a list of child tuples. - - if with_cycles is False, the returned structure is of the same form - but the second element of each tuple, i.e. the 'cycles', is an empty list. - - 'tuples' is a list of tuples representing a partial ordering. + deprecated. a new approach to "grouped" topological sorting + will be introduced in an upcoming release. """ - return _organize_as_tree(_sort(tuples, allitems, allow_cycles=with_cycles)) + return _organize_as_tree(_sort_impl(tuples, allitems, allow_cycles=with_cycles)) class _Node(object): @@ -156,7 +145,7 @@ class _EdgeCollection(object): def __repr__(self): return repr(list(self)) -def _sort(tuples, allitems, allow_cycles=False, ignore_self_cycles=False): +def _sort_impl(tuples, allitems, allow_cycles=False, ignore_self_cycles=False): nodes = {} edges = _EdgeCollection() @@ -221,6 +210,8 @@ def _organize_as_tree(nodes): set as siblings to each other as possible. returns nodes as 3-tuples (item, cycles, children). + + this function is deprecated. """ if not nodes: @@ -263,6 +254,9 @@ def _organize_as_tree(nodes): return (head.item, [n.item for n in head.cycles or []], head.children) def _find_cycles(edges): + """ + this function is deprecated. + """ cycles = {} def traverse(node, cycle, goal): diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 16cd57f26..198835562 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -524,7 +524,7 @@ class NullType(TypeEngine): __visit_name__ = 'null' def _adapt_expression(self, op, othertype): - if othertype is NullType or not operators.is_commutative(op): + if othertype is NULLTYPE or not operators.is_commutative(op): return op, self else: return othertype._adapt_expression(op, self) @@ -939,6 +939,13 @@ class Numeric(_DateAffinity, TypeEngine): # we're a "numeric", DBAPI will give us Decimal directly return None else: + util.warn("Dialect %s+%s does *not* support Decimal objects natively, " + "and SQLAlchemy must convert from floating point - " + "rounding errors and other issues may occur. " + "Please consider storing Decimal numbers as strings or " + "integers on this platform for lossless storage." % + (dialect.name, dialect.driver)) + # we're a "numeric", DBAPI returns floats, convert. if self.scale is not None: return processors.to_decimal_processor_factory(_python_Decimal, self.scale) @@ -976,7 +983,8 @@ class Float(Numeric): :param precision: the numeric precision for use in DDL ``CREATE TABLE``. :param asdecimal: the same flag as that of :class:`Numeric`, but - defaults to ``False``. + defaults to ``False``. Note that setting this flag to ``True`` + results in floating point conversion. """ self.precision = precision diff --git a/test/aaa_profiling/test_zoomark_orm.py b/test/aaa_profiling/test_zoomark_orm.py index 5b962b695..9b31d82f5 100644 --- a/test/aaa_profiling/test_zoomark_orm.py +++ b/test/aaa_profiling/test_zoomark_orm.py @@ -291,7 +291,7 @@ class ZooMarkTest(TestBase): def test_profile_1_create_tables(self): self.test_baseline_1_create_tables() - @profiling.function_call_count(12178, {'2.4':12178}) + @profiling.function_call_count(9225) def test_profile_1a_populate(self): self.test_baseline_1a_populate() diff --git a/test/base/test_dependency.py b/test/base/test_dependency.py index 890dd7607..bd7d81112 100644 --- a/test/base/test_dependency.py +++ b/test/base/test_dependency.py @@ -1,4 +1,4 @@ -import sqlalchemy.topological as topological +from sqlalchemy import topological from sqlalchemy.test import TestBase @@ -56,7 +56,7 @@ class DependencySortTest(TestBase): (node4, subnode3), (node4, subnode4) ] - head = topological.sort_as_tree(tuples, []) + head = topological._sort_as_tree(tuples, []) self.assert_sort(tuples, head) def testsort2(self): @@ -74,7 +74,7 @@ class DependencySortTest(TestBase): (node5, node6), (node6, node2) ] - head = topological.sort_as_tree(tuples, [node7]) + head = topological._sort_as_tree(tuples, [node7]) self.assert_sort(tuples, head, [node7]) def testsort3(self): @@ -87,9 +87,9 @@ class DependencySortTest(TestBase): (node3, node2), (node1,node3) ] - head1 = topological.sort_as_tree(tuples, [node1, node2, node3]) - head2 = topological.sort_as_tree(tuples, [node3, node1, node2]) - head3 = topological.sort_as_tree(tuples, [node3, node2, node1]) + head1 = topological._sort_as_tree(tuples, [node1, node2, node3]) + head2 = topological._sort_as_tree(tuples, [node3, node1, node2]) + head3 = topological._sort_as_tree(tuples, [node3, node2, node1]) # TODO: figure out a "node == node2" function #self.assert_(str(head1) == str(head2) == str(head3)) @@ -108,7 +108,7 @@ class DependencySortTest(TestBase): (node1, node3), (node3, node2) ] - head = topological.sort_as_tree(tuples, []) + head = topological._sort_as_tree(tuples, []) self.assert_sort(tuples, head) def testsort5(self): @@ -131,7 +131,7 @@ class DependencySortTest(TestBase): node3, node4 ] - head = topological.sort_as_tree(tuples, allitems, with_cycles=True) + head = topological._sort_as_tree(tuples, allitems, with_cycles=True) self.assert_sort(tuples, head) def testcircular(self): @@ -149,7 +149,7 @@ class DependencySortTest(TestBase): (node4, node1) ] allitems = [node1, node2, node3, node4] - head = topological.sort_as_tree(tuples, allitems, with_cycles=True) + head = topological._sort_as_tree(tuples, allitems, with_cycles=True) self.assert_sort(tuples, head) def testcircular2(self): @@ -166,7 +166,7 @@ class DependencySortTest(TestBase): (node3, node2), (node2, node3) ] - head = topological.sort_as_tree(tuples, [], with_cycles=True) + head = topological._sort_as_tree(tuples, [], with_cycles=True) self.assert_sort(tuples, head) def testcircular3(self): @@ -174,16 +174,16 @@ class DependencySortTest(TestBase): tuples = [(question, issue), (providerservice, issue), (provider, question), (question, provider), (providerservice, question), (provider, providerservice), (question, answer), (issue, question)] - head = topological.sort_as_tree(tuples, [], with_cycles=True) + head = topological._sort_as_tree(tuples, [], with_cycles=True) self.assert_sort(tuples, head) def testbigsort(self): tuples = [(i, i + 1) for i in range(0, 1500, 2)] - head = topological.sort_as_tree(tuples, []) + head = topological._sort_as_tree(tuples, []) def testids(self): # ticket:1380 regression: would raise a KeyError - topological.sort([(id(i), i) for i in range(3)], []) + topological._sort([(id(i), i) for i in range(3)], []) diff --git a/test/dialect/test_oracle.py b/test/dialect/test_oracle.py index bcb34b05c..31e95f57f 100644 --- a/test/dialect/test_oracle.py +++ b/test/dialect/test_oracle.py @@ -34,9 +34,9 @@ create or replace procedure foo(x_in IN number, x_out OUT number, y_out OUT numb def test_out_params(self): result = testing.db.execute(text("begin foo(:x_in, :x_out, :y_out, :z_out); end;", bindparams=[ - bindparam('x_in', Numeric), + bindparam('x_in', Float), outparam('x_out', Integer), - outparam('y_out', Numeric), + outparam('y_out', Float), outparam('z_out', String)]), x_in=5) eq_( @@ -671,7 +671,7 @@ class TypesTest(TestBase, AssertsCompiledSQL): (15.76, float), )): eq_(row[i], val) - assert isinstance(row[i], type_) + assert isinstance(row[i], type_), "%r is not %r" % (row[i], type_) finally: t1.drop() diff --git a/test/dialect/test_postgresql.py b/test/dialect/test_postgresql.py index 1d39d5653..14814bc20 100644 --- a/test/dialect/test_postgresql.py +++ b/test/dialect/test_postgresql.py @@ -228,7 +228,21 @@ class FloatCoercionTest(TablesTest, AssertsExecutionResults): ).scalar() eq_(round_decimal(ret, 9), result) - + @testing.provide_metadata + def test_arrays(self): + t1 = Table('t', metadata, + Column('x', postgresql.ARRAY(Float)), + Column('y', postgresql.ARRAY(postgresql.REAL)), + Column('z', postgresql.ARRAY(postgresql.DOUBLE_PRECISION)), + Column('q', postgresql.ARRAY(Numeric)) + ) + metadata.create_all() + t1.insert().execute(x=[5], y=[5], z=[6], q=[6.4]) + row = t1.select().execute().first() + eq_( + row, + ([5], [5], [6], [decimal.Decimal("6.4")]) + ) class EnumTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL): __only_on__ = 'postgresql' @@ -1069,6 +1083,35 @@ class MiscTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL): finally: t.drop(checkfirst=True) + def test_renamed_sequence_reflection(self): + m1 = MetaData(testing.db) + t = Table('t', m1, + Column('id', Integer, primary_key=True) + ) + m1.create_all() + try: + m2 = MetaData(testing.db) + t2 = Table('t', m2, autoload=True, implicit_returning=False) + eq_(t2.c.id.server_default.arg.text, "nextval('t_id_seq'::regclass)") + + r = t2.insert().execute() + eq_(r.inserted_primary_key, [1]) + + testing.db.connect().\ + execution_options(autocommit=True).\ + execute("alter table t_id_seq rename to foobar_id_seq") + + m3 = MetaData(testing.db) + t3 = Table('t', m3, autoload=True, implicit_returning=False) + eq_(t3.c.id.server_default.arg.text, "nextval('foobar_id_seq'::regclass)") + + r = t3.insert().execute() + eq_(r.inserted_primary_key, [2]) + + finally: + m1.drop_all() + + def test_distinct_on(self): t = Table('mytable', MetaData(testing.db), Column('id', Integer, primary_key=True), @@ -1282,8 +1325,18 @@ class MiscTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL): else: exception_cls = eng.dialect.dbapi.ProgrammingError assert_raises(exception_cls, eng.execute, "show transaction isolation level") - - + + @testing.fails_on('+zxjdbc', + "psycopg2/pg8000 specific assertion") + @testing.fails_on('pypostgresql', + "psycopg2/pg8000 specific assertion") + def test_numeric_raise(self): + stmt = text("select cast('hi' as char) as hi", typemap={'hi':Numeric}) + assert_raises( + exc.InvalidRequestError, + testing.db.execute, stmt + ) + class TimezoneTest(TestBase): """Test timezone-aware datetimes. diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 8fd5e7eb6..04d4a06d5 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -32,7 +32,9 @@ class ExecuteTest(TestBase): def teardown_class(cls): metadata.drop_all() - @testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite', '+pyodbc', '+mxodbc', '+zxjdbc', 'mysql+oursql') + @testing.fails_on_everything_except('firebird', 'maxdb', + 'sqlite', '+pyodbc', + '+mxodbc', '+zxjdbc', 'mysql+oursql') def test_raw_qmark(self): for conn in (testing.db, testing.db.connect()): conn.execute("insert into users (user_id, user_name) values (?, ?)", (1,"jack")) @@ -70,7 +72,8 @@ class ExecuteTest(TestBase): # pyformat is supported for mysql, but skipping because a few driver # versions have a bug that bombs out on this test. (1.2.2b3, 1.2.2c1, 1.2.2) @testing.skip_if(lambda: testing.against('mysql+mysqldb'), 'db-api flaky') - @testing.fails_on_everything_except('postgresql+psycopg2', 'postgresql+pypostgresql', 'mysql+mysqlconnector') + @testing.fails_on_everything_except('postgresql+psycopg2', + 'postgresql+pypostgresql', 'mysql+mysqlconnector') def test_raw_python(self): for conn in (testing.db, testing.db.connect()): conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", @@ -111,6 +114,37 @@ class ExecuteTest(TestBase): (1, None) ]) +class CompiledCacheTest(TestBase): + @classmethod + def setup_class(cls): + global users, metadata + metadata = MetaData(testing.db) + users = Table('users', metadata, + Column('user_id', INT, primary_key=True, test_needs_autoincrement=True), + Column('user_name', VARCHAR(20)), + ) + metadata.create_all() + + @engines.close_first + def teardown(self): + testing.db.connect().execute(users.delete()) + + @classmethod + def teardown_class(cls): + metadata.drop_all() + + def test_cache(self): + conn = testing.db.connect() + cache = {} + cached_conn = conn.execution_options(compiled_cache=cache) + + ins = users.insert() + cached_conn.execute(ins, {'user_name':'u1'}) + cached_conn.execute(ins, {'user_name':'u2'}) + cached_conn.execute(ins, {'user_name':'u3'}) + assert len(cache) == 1 + eq_(conn.execute("select count(1) from users").scalar(), 3) + class LogTest(TestBase): def _test_logger(self, eng, eng_name, pool_name): buf = logging.handlers.BufferingHandler(100) @@ -238,8 +272,26 @@ class ProxyConnectionTest(TestBase): assert_stmts(compiled, stmts) assert_stmts(cursor, cursor_stmts) - - @testing.fails_on('mysql+oursql', 'oursql dialect has some extra steps here') + + def test_options(self): + track = [] + class TrackProxy(ConnectionProxy): + def __getattribute__(self, key): + fn = object.__getattribute__(self, key) + def go(*arg, **kw): + track.append(fn.__name__) + return fn(*arg, **kw) + return go + engine = engines.testing_engine(options={'proxy':TrackProxy()}) + conn = engine.connect() + c2 = conn.execution_options(foo='bar') + eq_(c2._execution_options, {'foo':'bar'}) + c2.execute(select([1])) + c3 = c2.execution_options(bar='bat') + eq_(c3._execution_options, {'foo':'bar', 'bar':'bat'}) + eq_(track, ['execute', 'cursor_execute']) + + def test_transactional(self): track = [] class TrackProxy(ConnectionProxy): diff --git a/test/engine/test_transaction.py b/test/engine/test_transaction.py index e8da89438..f6cb9a473 100644 --- a/test/engine/test_transaction.py +++ b/test/engine/test_transaction.py @@ -120,7 +120,18 @@ class TransactionTest(TestBase): finally: connection.close() - + def test_retains_through_options(self): + connection = testing.db.connect() + try: + transaction = connection.begin() + connection.execute(users.insert(), user_id=1, user_name='user1') + conn2 = connection.execution_options(dummy=True) + conn2.execute(users.insert(), user_id=2, user_name='user2') + transaction.rollback() + eq_(connection.scalar("select count(1) from query_users"), 0) + finally: + connection.close() + def test_nesting(self): connection = testing.db.connect() transaction = connection.begin() diff --git a/test/ext/test_declarative.py b/test/ext/test_declarative.py index d5d837da7..67e650c34 100644 --- a/test/ext/test_declarative.py +++ b/test/ext/test_declarative.py @@ -192,11 +192,6 @@ class DeclarativeTest(DeclarativeTestBase): __tablename__ = 'users' id = Column(Integer, primary_key=True, test_needs_autoincrement=True) name = Column(String(50)) - addresses = relationship("Address", order_by="desc(Address.email)", - primaryjoin="User.id==Address.user_id", foreign_keys="[Address.user_id]", - backref=backref('user', primaryjoin="User.id==Address.user_id", foreign_keys="[Address.user_id]") - ) - class Bar(Base, ComparableEntity): __tablename__ = 'bar' @@ -2057,7 +2052,55 @@ class DeclarativeMixinTest(DeclarativeTestBase): id = Column(Integer, primary_key=True) eq_(MyModel.__table__.kwargs,{'mysql_engine': 'InnoDB'}) + + def test_mapper_args_classproperty(self): + class ComputedMapperArgs: + @classproperty + def __mapper_args__(cls): + if cls.__name__=='Person': + return dict(polymorphic_on=cls.discriminator) + else: + return dict(polymorphic_identity=cls.__name__) + class Person(Base,ComputedMapperArgs): + __tablename__ = 'people' + id = Column(Integer, primary_key=True) + discriminator = Column('type', String(50)) + + class Engineer(Person): + pass + + compile_mappers() + + assert class_mapper(Person).polymorphic_on is Person.__table__.c.type + eq_(class_mapper(Engineer).polymorphic_identity, 'Engineer') + + def test_mapper_args_classproperty_two(self): + # same as test_mapper_args_classproperty, but + # we repeat ComputedMapperArgs on both classes + # for no apparent reason. + + class ComputedMapperArgs: + @classproperty + def __mapper_args__(cls): + if cls.__name__=='Person': + return dict(polymorphic_on=cls.discriminator) + else: + return dict(polymorphic_identity=cls.__name__) + + class Person(Base,ComputedMapperArgs): + __tablename__ = 'people' + id = Column(Integer, primary_key=True) + discriminator = Column('type', String(50)) + + class Engineer(Person, ComputedMapperArgs): + pass + + compile_mappers() + + assert class_mapper(Person).polymorphic_on is Person.__table__.c.type + eq_(class_mapper(Engineer).polymorphic_identity, 'Engineer') + def test_table_args_composite(self): class MyMixin1: diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index 02be04edc..ef9214a8b 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -1079,7 +1079,50 @@ class MapperTest(_fixtures.FixtureTest): mapper(B, users) +class DocumentTest(testing.TestBase): + + def test_doc_propagate(self): + metadata = MetaData() + t1 = Table('t1', metadata, + Column('col1', Integer, primary_key=True, doc="primary key column"), + Column('col2', String, doc="data col"), + Column('col3', String, doc="data col 2"), + Column('col4', String, doc="data col 3"), + Column('col5', String), + ) + t2 = Table('t2', metadata, + Column('col1', Integer, primary_key=True, doc="primary key column"), + Column('col2', String, doc="data col"), + Column('col3', Integer, ForeignKey('t1.col1'), doc="foreign key to t1.col1") + ) + class Foo(object): + pass + + class Bar(object): + pass + + mapper(Foo, t1, properties={ + 'bars':relationship(Bar, + doc="bar relationship", + backref=backref('foo',doc='foo relationship') + ), + 'foober':column_property(t1.c.col3, doc='alternate data col'), + 'hoho':synonym(t1.c.col4, doc="syn of col4") + }) + mapper(Bar, t2) + compile_mappers() + eq_(Foo.col1.__doc__, "primary key column") + eq_(Foo.col2.__doc__, "data col") + eq_(Foo.col5.__doc__, None) + eq_(Foo.foober.__doc__, "alternate data col") + eq_(Foo.bars.__doc__, "bar relationship") + eq_(Foo.hoho.__doc__, "syn of col4") + eq_(Bar.col1.__doc__, "primary key column") + eq_(Bar.foo.__doc__, "foo relationship") + + + class OptionsTest(_fixtures.FixtureTest): @testing.fails_on('maxdb', 'FIXME: unknown') diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py index 07e545bd1..ccb1c0177 100644 --- a/test/orm/test_versioning.py +++ b/test/orm/test_versioning.py @@ -102,7 +102,7 @@ class VersioningTest(_base.MappedTest): s1.delete(f1) s1.delete(f2) - if testing.db.dialect.supports_sane_multi_rowcount: + if testing.db.dialect.supports_sane_rowcount: assert_raises(sa.orm.exc.ConcurrentModificationError, s1.commit) else: s1.commit() diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 089ef727d..fb9b3912a 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -275,8 +275,7 @@ class UnicodeTest(TestBase, AssertsExecutionResults): """assert expected values for 'native unicode' mode""" if \ - (testing.against('mssql+pyodbc') and not testing.db.dialect.freetds) or \ - testing.against('oracle+cx_oracle'): + (testing.against('mssql+pyodbc') and not testing.db.dialect.freetds): assert testing.db.dialect.returns_unicode_strings == 'conditional' return @@ -296,6 +295,7 @@ class UnicodeTest(TestBase, AssertsExecutionResults): ('mysql','mysqlconnector'), ('sqlite','pysqlite'), ('oracle','zxjdbc'), + ('oracle','cx_oracle'), )), \ "name: %s driver %s returns_unicode_strings=%s" % \ (testing.db.name, @@ -481,16 +481,7 @@ class UnicodeTest(TestBase, AssertsExecutionResults): eq_(a, b) x = utf8_row['plain_varchar_no_coding_error'] - if testing.against('oracle+cx_oracle'): - # TODO: not sure yet what produces this exact string as of yet - # ('replace' does not AFAICT) - eq_( - x, - 'Alors vous imaginez ma surprise, au lever du jour, quand une ' - 'drole de petit voix m?a reveille. Elle disait: < S?il vous plait? ' - 'dessine-moi un mouton! >' - ) - elif testing.against('mssql+pyodbc') and not testing.db.dialect.freetds: + if testing.against('mssql+pyodbc') and not testing.db.dialect.freetds: # TODO: no clue what this is eq_( x, @@ -893,6 +884,9 @@ class ExpressionTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL): expr = column('bar', Integer) - 3 eq_(expr.type._type_affinity, Integer) + + expr = bindparam('bar') + bindparam('foo') + eq_(expr.type, types.NULLTYPE) def test_distinct(self): s = select([distinct(test_table.c.avalue)]) @@ -903,6 +897,7 @@ class ExpressionTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL): assert distinct(test_table.c.data).type == test_table.c.data.type assert test_table.c.data.distinct().type == test_table.c.data.type + class DateTest(TestBase, AssertsExecutionResults): @classmethod @@ -1057,6 +1052,7 @@ class NumericTest(TestBase): def teardown(self): metadata.drop_all() + @testing.emits_warning(r".*does \*not\* support Decimal objects natively") def _do_test(self, type_, input_, output, filter_ = None): t = Table('t', metadata, Column('x', type_)) t.create() @@ -1067,10 +1063,10 @@ class NumericTest(TestBase): if filter_: result = set(filter_(x) for x in result) output = set(filter_(x) for x in output) - print result - print output + #print result + #print output eq_(result, output) - + def test_numeric_as_decimal(self): self._do_test( Numeric(precision=8, scale=4), @@ -1171,7 +1167,6 @@ class NumericTest(TestBase): ) @testing.fails_on('sqlite', 'TODO') - @testing.fails_on('oracle', 'TODO') @testing.fails_on('postgresql+pg8000', 'TODO') @testing.fails_on("firebird", "Precision must be from 1 to 18") @testing.fails_on("sybase+pysybase", "TODO") |
