diff options
| author | jonathan vanasco <jonathan@2xlp.com> | 2015-12-16 11:04:25 -0500 |
|---|---|---|
| committer | jonathan vanasco <jonathan@2xlp.com> | 2015-12-16 11:04:25 -0500 |
| commit | ce25ac172d3b1be81025b7b541a9aa32b0286974 (patch) | |
| tree | 7920084df122b2df19a44b2946ab0e52d4fe5958 /lib/sqlalchemy | |
| parent | 0a5dcdc2c4112478d87e5cd68c187e302f586834 (diff) | |
| parent | 03ee22f342bbef9b15bfc989edda6a4ac3910508 (diff) | |
| download | sqlalchemy-ce25ac172d3b1be81025b7b541a9aa32b0286974.tar.gz | |
Merge branch 'master' of bitbucket.org:zzzeek/sqlalchemy
Diffstat (limited to 'lib/sqlalchemy')
80 files changed, 4377 insertions, 1799 deletions
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index 093e90bbf..12d4e8d1c 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -8,7 +8,9 @@ from .sql import ( alias, + all_, and_, + any_, asc, between, bindparam, @@ -52,6 +54,7 @@ from .sql import ( ) from .types import ( + Array, BIGINT, BINARY, BLOB, @@ -120,7 +123,7 @@ from .schema import ( from .inspection import inspect from .engine import create_engine, engine_from_config -__version__ = '1.0.7' +__version__ = '1.1.0b1' def __go(lcls): diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py index c34829cd3..acd419e85 100644 --- a/lib/sqlalchemy/dialects/firebird/base.py +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -648,7 +648,7 @@ class FBDialect(default.DefaultDialect): 'type': coltype, 'nullable': not bool(row['null_flag']), 'default': defvalue, - 'autoincrement': defvalue is None + 'autoincrement': 'auto', } if orig_colname.lower() == orig_colname: diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index bd41c19bf..1ee328e83 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -166,56 +166,6 @@ how SQLAlchemy handles this: This is an auxilliary use case suitable for testing and bulk insert scenarios. -.. _legacy_schema_rendering: - -Rendering of SQL statements that include schema qualifiers ---------------------------------------------------------- - -When using :class:`.Table` metadata that includes a "schema" qualifier, -such as:: - - account_table = Table( - 'account', metadata, - Column('id', Integer, primary_key=True), - Column('info', String(100)), - schema="customer_schema" - ) - -The SQL Server dialect has a long-standing behavior that it will attempt -to turn a schema-qualified table name into an alias, such as:: - - >>> eng = create_engine("mssql+pymssql://mydsn") - >>> print(account_table.select().compile(eng)) - SELECT account_1.id, account_1.info - FROM customer_schema.account AS account_1 - -This behavior is legacy, does not function correctly for many forms -of SQL statements, and will be disabled by default in the 1.1 series -of SQLAlchemy. As of 1.0.5, the above statement will produce the following -warning:: - - SAWarning: legacy_schema_aliasing flag is defaulted to True; - some schema-qualified queries may not function correctly. - Consider setting this flag to False for modern SQL Server versions; - this flag will default to False in version 1.1 - -This warning encourages the :class:`.Engine` to be created as follows:: - - >>> eng = create_engine("mssql+pymssql://mydsn", legacy_schema_aliasing=False) - -Where the above SELECT statement will produce:: - - >>> print(account_table.select().compile(eng)) - SELECT customer_schema.account.id, customer_schema.account.info - FROM customer_schema.account - -The warning will not emit if the ``legacy_schema_aliasing`` flag is set -to either True or False. - -.. versionadded:: 1.0.5 - Added the ``legacy_schema_aliasing`` flag to disable - the SQL Server dialect's legacy behavior with schema-qualified table - names. This flag will default to False in version 1.1. - Collation Support ----------------- @@ -236,7 +186,7 @@ CREATE TABLE statement for this column will yield:: LIMIT/OFFSET Support -------------------- -MSSQL has no support for the LIMIT or OFFSET keysowrds. LIMIT is +MSSQL has no support for the LIMIT or OFFSET keywords. LIMIT is supported directly through the ``TOP`` Transact SQL keyword:: select.limit @@ -322,6 +272,41 @@ behavior of this flag is as follows: .. versionadded:: 1.0.0 +.. _legacy_schema_rendering: + +Legacy Schema Mode +------------------ + +Very old versions of the MSSQL dialect introduced the behavior such that a +schema-qualified table would be auto-aliased when used in a +SELECT statement; given a table:: + + account_table = Table( + 'account', metadata, + Column('id', Integer, primary_key=True), + Column('info', String(100)), + schema="customer_schema" + ) + +this legacy mode of rendering would assume that "customer_schema.account" +would not be accepted by all parts of the SQL statement, as illustrated +below:: + + >>> eng = create_engine("mssql+pymssql://mydsn", legacy_schema_aliasing=True) + >>> print(account_table.select().compile(eng)) + SELECT account_1.id, account_1.info + FROM customer_schema.account AS account_1 + +This mode of behavior is now off by default, as it appears to have served +no purpose; however in the case that legacy applications rely upon it, +it is available using the ``legacy_schema_aliasing`` argument to +:func:`.create_engine` as illustrated above. + +.. versionchanged:: 1.1 the ``legacy_schema_aliasing`` flag introduced + in version 1.0.5 to allow disabling of legacy mode for schemas now + defaults to False. + + .. _mssql_indexes: Clustered Index Support @@ -548,9 +533,13 @@ class _MSDate(sqltypes.Date): if isinstance(value, datetime.datetime): return value.date() elif isinstance(value, util.string_types): + m = self._reg.match(value) + if not m: + raise ValueError( + "could not parse %r as a date value" % (value, )) return datetime.date(*[ int(x or 0) - for x in self._reg.match(value).groups() + for x in m.groups() ]) else: return value @@ -582,9 +571,13 @@ class TIME(sqltypes.TIME): if isinstance(value, datetime.datetime): return value.time() elif isinstance(value, util.string_types): + m = self._reg.match(value) + if not m: + raise ValueError( + "could not parse %r as a time value" % (value, )) return datetime.time(*[ int(x or 0) - for x in self._reg.match(value).groups()]) + for x in m.groups()]) else: return value return process @@ -774,21 +767,21 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return "TINYINT" def visit_DATETIMEOFFSET(self, type_, **kw): - if type_.precision: + if type_.precision is not None: return "DATETIMEOFFSET(%s)" % type_.precision else: return "DATETIMEOFFSET" def visit_TIME(self, type_, **kw): precision = getattr(type_, 'precision', None) - if precision: + if precision is not None: return "TIME(%s)" % precision else: return "TIME" def visit_DATETIME2(self, type_, **kw): precision = getattr(type_, 'precision', None) - if precision: + if precision is not None: return "DATETIME2(%s)" % precision else: return "DATETIME2" @@ -1156,15 +1149,6 @@ class MSSQLCompiler(compiler.SQLCompiler): def _schema_aliased_table(self, table): if getattr(table, 'schema', None) is not None: - if self.dialect._warn_schema_aliasing and \ - table.schema.lower() != 'information_schema': - util.warn( - "legacy_schema_aliasing flag is defaulted to True; " - "some schema-qualified queries may not function " - "correctly. Consider setting this flag to False for " - "modern SQL Server versions; this flag will default to " - "False in version 1.1") - if table not in self.tablealiases: self.tablealiases[table] = table.alias() return self.tablealiases[table] @@ -1530,7 +1514,7 @@ class MSDialect(default.DefaultDialect): max_identifier_length=None, schema_name="dbo", deprecate_large_types=None, - legacy_schema_aliasing=None, **opts): + legacy_schema_aliasing=False, **opts): self.query_timeout = int(query_timeout or 0) self.schema_name = schema_name @@ -1538,13 +1522,7 @@ class MSDialect(default.DefaultDialect): self.max_identifier_length = int(max_identifier_length or 0) or \ self.max_identifier_length self.deprecate_large_types = deprecate_large_types - - if legacy_schema_aliasing is None: - self.legacy_schema_aliasing = True - self._warn_schema_aliasing = True - else: - self.legacy_schema_aliasing = legacy_schema_aliasing - self._warn_schema_aliasing = False + self.legacy_schema_aliasing = legacy_schema_aliasing super(MSDialect, self).__init__(**opts) @@ -1772,7 +1750,7 @@ class MSDialect(default.DefaultDialect): MSNText, MSBinary, MSVarBinary, sqltypes.LargeBinary): if charlen == -1: - charlen = 'max' + charlen = None kwargs['length'] = charlen if collation: kwargs['collation'] = collation diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py index 324b3770c..1d7635c7f 100644 --- a/lib/sqlalchemy/dialects/mssql/pymssql.py +++ b/lib/sqlalchemy/dialects/mssql/pymssql.py @@ -85,7 +85,8 @@ class MSDialect_pymssql(MSDialect): "message 20003", # connection timeout "Error 10054", "Not connected to any MS SQL server", - "Connection is closed" + "Connection is closed", + "message 20006", # Write to the server failed ): if msg in str(e): return True diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index fee05fd2d..988746403 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -32,6 +32,11 @@ the ``pool_recycle`` option which controls the maximum age of any connection:: engine = create_engine('mysql+mysqldb://...', pool_recycle=3600) +.. seealso:: + + :ref:`pool_setting_recycle` - full description of the pool recycle feature. + + .. _mysql_storage_engines: CREATE TABLE arguments including Storage Engines @@ -1584,7 +1589,10 @@ class SET(_EnumeratedValues): def column_expression(self, colexpr): if self.retrieve_as_bitwise: - return colexpr + 0 + return sql.type_coerce( + sql.type_coerce(colexpr, sqltypes.Integer) + 0, + self + ) else: return colexpr @@ -1913,38 +1921,7 @@ class MySQLCompiler(compiler.SQLCompiler): return None -# ug. "InnoDB needs indexes on foreign keys and referenced keys [...]. -# Starting with MySQL 4.1.2, these indexes are created automatically. -# In older versions, the indexes must be created explicitly or the -# creation of foreign key constraints fails." - class MySQLDDLCompiler(compiler.DDLCompiler): - def create_table_constraints(self, table, **kw): - """Get table constraints.""" - constraint_string = super( - MySQLDDLCompiler, self).create_table_constraints(table, **kw) - - # why self.dialect.name and not 'mysql'? because of drizzle - is_innodb = 'engine' in table.dialect_options[self.dialect.name] and \ - table.dialect_options[self.dialect.name][ - 'engine'].lower() == 'innodb' - - auto_inc_column = table._autoincrement_column - - if is_innodb and \ - auto_inc_column is not None and \ - auto_inc_column is not list(table.primary_key)[0]: - if constraint_string: - constraint_string += ", \n\t" - constraint_string += "KEY %s (%s)" % ( - self.preparer.quote( - "idx_autoinc_%s" % auto_inc_column.name - ), - self.preparer.format_column(auto_inc_column) - ) - - return constraint_string - def get_column_specification(self, column, **kw): """Builds column DDL.""" @@ -3117,6 +3094,11 @@ class MySQLTableDefinitionParser(object): # Column type keyword options type_kw = {} + + if issubclass(col_type, (DATETIME, TIME, TIMESTAMP)): + if type_args: + type_kw['fsp'] = type_args.pop(0) + for kw in ('unsigned', 'zerofill'): if spec.get(kw, False): type_kw[kw] = True diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index c605bd510..82ec72f2b 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -287,6 +287,7 @@ from sqlalchemy import util, sql from sqlalchemy.engine import default, reflection from sqlalchemy.sql import compiler, visitors, expression from sqlalchemy.sql import operators as sql_operators +from sqlalchemy.sql.elements import quoted_name from sqlalchemy import types as sqltypes, schema as sa_schema from sqlalchemy.types import VARCHAR, NVARCHAR, CHAR, \ BLOB, CLOB, TIMESTAMP, FLOAT @@ -1032,6 +1033,8 @@ class OracleDialect(default.DefaultDialect): if name.upper() == name and not \ self.identifier_preparer._requires_quotes(name.lower()): return name.lower() + elif name.lower() == name: + return quoted_name(name, quote=True) else: return name @@ -1280,7 +1283,7 @@ class OracleDialect(default.DefaultDialect): 'type': coltype, 'nullable': nullable, 'default': default, - 'autoincrement': default is None + 'autoincrement': 'auto', } if orig_colname.lower() == orig_colname: cdict['quote'] = True diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index 4aed45c14..dede3b21a 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -293,6 +293,7 @@ from .base import OracleCompiler, OracleDialect, OracleExecutionContext from . import base as oracle from ...engine import result as _result from sqlalchemy import types as sqltypes, util, exc, processors +from sqlalchemy import util import random import collections import decimal @@ -719,8 +720,10 @@ class OracleDialect_cx_oracle(OracleDialect): # this occurs in tests with mock DBAPIs self._cx_oracle_string_types = set() self._cx_oracle_with_unicode = False - elif self.cx_oracle_ver >= (5,) and not \ - hasattr(self.dbapi, 'UNICODE'): + elif util.py3k or ( + self.cx_oracle_ver >= (5,) and not \ + hasattr(self.dbapi, 'UNICODE') + ): # cx_Oracle WITH_UNICODE mode. *only* python # unicode objects accepted for anything self.supports_unicode_statements = True diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index 98fe6f085..d67f2a07e 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -12,11 +12,13 @@ base.dialect = psycopg2.dialect from .base import \ INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, \ INET, CIDR, UUID, BIT, MACADDR, OID, DOUBLE_PRECISION, TIMESTAMP, TIME, \ - DATE, BYTEA, BOOLEAN, INTERVAL, ARRAY, ENUM, dialect, array, Any, All, \ - TSVECTOR, DropEnumType -from .constraints import ExcludeConstraint + DATE, BYTEA, BOOLEAN, INTERVAL, ENUM, dialect, TSVECTOR, DropEnumType, \ + CreateEnumType from .hstore import HSTORE, hstore -from .json import JSON, JSONElement, JSONB +from .json import JSON, JSONB +from .array import array, ARRAY, Any, All +from .ext import aggregate_order_by, ExcludeConstraint, array_agg + from .ranges import INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, \ TSTZRANGE @@ -24,8 +26,9 @@ __all__ = ( 'INTEGER', 'BIGINT', 'SMALLINT', 'VARCHAR', 'CHAR', 'TEXT', 'NUMERIC', 'FLOAT', 'REAL', 'INET', 'CIDR', 'UUID', 'BIT', 'MACADDR', 'OID', 'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME', 'DATE', 'BYTEA', 'BOOLEAN', - 'INTERVAL', 'ARRAY', 'ENUM', 'dialect', 'Any', 'All', 'array', 'HSTORE', + 'INTERVAL', 'ARRAY', 'ENUM', 'dialect', 'array', 'HSTORE', 'hstore', 'INT4RANGE', 'INT8RANGE', 'NUMRANGE', 'DATERANGE', - 'TSRANGE', 'TSTZRANGE', 'json', 'JSON', 'JSONB', 'JSONElement', - 'DropEnumType' + 'TSRANGE', 'TSTZRANGE', 'json', 'JSON', 'JSONB', 'Any', 'All', + 'DropEnumType', 'CreateEnumType', 'ExcludeConstraint', + 'aggregate_order_by', 'array_agg' ) diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py new file mode 100644 index 000000000..b88f139de --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -0,0 +1,306 @@ +# postgresql/array.py +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from .base import ischema_names +from ...sql import expression, operators +from ...sql.base import SchemaEventTarget +from ... import types as sqltypes + +try: + from uuid import UUID as _python_UUID +except ImportError: + _python_UUID = None + + +def Any(other, arrexpr, operator=operators.eq): + """A synonym for the :meth:`.ARRAY.Comparator.any` method. + + This method is legacy and is here for backwards-compatiblity. + + .. seealso:: + + :func:`.expression.any_` + + """ + + return arrexpr.any(other, operator) + + +def All(other, arrexpr, operator=operators.eq): + """A synonym for the :meth:`.ARRAY.Comparator.all` method. + + This method is legacy and is here for backwards-compatiblity. + + .. seealso:: + + :func:`.expression.all_` + + """ + + return arrexpr.all(other, operator) + + +class array(expression.Tuple): + + """A Postgresql ARRAY literal. + + This is used to produce ARRAY literals in SQL expressions, e.g.:: + + from sqlalchemy.dialects.postgresql import array + from sqlalchemy.dialects import postgresql + from sqlalchemy import select, func + + stmt = select([ + array([1,2]) + array([3,4,5]) + ]) + + print stmt.compile(dialect=postgresql.dialect()) + + Produces the SQL:: + + SELECT ARRAY[%(param_1)s, %(param_2)s] || + ARRAY[%(param_3)s, %(param_4)s, %(param_5)s]) AS anon_1 + + An instance of :class:`.array` will always have the datatype + :class:`.ARRAY`. The "inner" type of the array is inferred from + the values present, unless the ``type_`` keyword argument is passed:: + + array(['foo', 'bar'], type_=CHAR) + + .. versionadded:: 0.8 Added the :class:`~.postgresql.array` literal type. + + See also: + + :class:`.postgresql.ARRAY` + + """ + __visit_name__ = 'array' + + def __init__(self, clauses, **kw): + super(array, self).__init__(*clauses, **kw) + self.type = ARRAY(self.type) + + def _bind_param(self, operator, obj): + return array([ + expression.BindParameter(None, o, _compared_to_operator=operator, + _compared_to_type=self.type, unique=True) + for o in obj + ]) + + def self_group(self, against=None): + if (against in ( + operators.any_op, operators.all_op, operators.getitem)): + return expression.Grouping(self) + else: + return self + + +CONTAINS = operators.custom_op("@>", precedence=5) + +CONTAINED_BY = operators.custom_op("<@", precedence=5) + +OVERLAP = operators.custom_op("&&", precedence=5) + + +class ARRAY(SchemaEventTarget, sqltypes.Array): + + """Postgresql ARRAY type. + + .. versionchanged:: 1.1 The :class:`.postgresql.ARRAY` type is now + a subclass of the core :class:`.Array` type. + + The :class:`.postgresql.ARRAY` type is constructed in the same way + as the core :class:`.Array` type; a member type is required, and a + number of dimensions is recommended if the type is to be used for more + than one dimension:: + + from sqlalchemy.dialects import postgresql + + mytable = Table("mytable", metadata, + Column("data", postgresql.ARRAY(Integer, dimensions=2)) + ) + + The :class:`.postgresql.ARRAY` type provides all operations defined on the + core :class:`.Array` type, including support for "dimensions", indexed + access, and simple matching such as :meth:`.Array.Comparator.any` + and :meth:`.Array.Comparator.all`. :class:`.postgresql.ARRAY` class also + provides PostgreSQL-specific methods for containment operations, including + :meth:`.postgresql.ARRAY.Comparator.contains` + :meth:`.postgresql.ARRAY.Comparator.contained_by`, + and :meth:`.postgresql.ARRAY.Comparator.overlap`, e.g.:: + + mytable.c.data.contains([1, 2]) + + The :class:`.postgresql.ARRAY` type may not be supported on all + PostgreSQL DBAPIs; it is currently known to work on psycopg2 only. + + Additionally, the :class:`.postgresql.ARRAY` type does not work directly in + conjunction with the :class:`.ENUM` type. For a workaround, see the + special type at :ref:`postgresql_array_of_enum`. + + .. seealso:: + + :class:`.types.Array` - base array type + + :class:`.postgresql.array` - produces a literal array value. + + """ + + class Comparator(sqltypes.Array.Comparator): + + """Define comparison operations for :class:`.ARRAY`. + + Note that these operations are in addition to those provided + by the base :class:`.types.Array.Comparator` class, including + :meth:`.types.Array.Comparator.any` and + :meth:`.types.Array.Comparator.all`. + + """ + + def contains(self, other, **kwargs): + """Boolean expression. Test if elements are a superset of the + elements of the argument array expression. + """ + return self.operate(CONTAINS, other, result_type=sqltypes.Boolean) + + def contained_by(self, other): + """Boolean expression. Test if elements are a proper subset of the + elements of the argument array expression. + """ + return self.operate( + CONTAINED_BY, other, result_type=sqltypes.Boolean) + + def overlap(self, other): + """Boolean expression. Test if array has elements in common with + an argument array expression. + """ + return self.operate(OVERLAP, other, result_type=sqltypes.Boolean) + + comparator_factory = Comparator + + def __init__(self, item_type, as_tuple=False, dimensions=None, + zero_indexes=False): + """Construct an ARRAY. + + E.g.:: + + Column('myarray', ARRAY(Integer)) + + Arguments are: + + :param item_type: The data type of items of this array. Note that + dimensionality is irrelevant here, so multi-dimensional arrays like + ``INTEGER[][]``, are constructed as ``ARRAY(Integer)``, not as + ``ARRAY(ARRAY(Integer))`` or such. + + :param as_tuple=False: Specify whether return results + should be converted to tuples from lists. DBAPIs such + as psycopg2 return lists by default. When tuples are + returned, the results are hashable. + + :param dimensions: if non-None, the ARRAY will assume a fixed + number of dimensions. This will cause the DDL emitted for this + ARRAY to include the exact number of bracket clauses ``[]``, + and will also optimize the performance of the type overall. + Note that PG arrays are always implicitly "non-dimensioned", + meaning they can store any number of dimensions no matter how + they were declared. + + :param zero_indexes=False: when True, index values will be converted + between Python zero-based and Postgresql one-based indexes, e.g. + a value of one will be added to all index values before passing + to the database. + + .. versionadded:: 0.9.5 + + + """ + if isinstance(item_type, ARRAY): + raise ValueError("Do not nest ARRAY types; ARRAY(basetype) " + "handles multi-dimensional arrays of basetype") + if isinstance(item_type, type): + item_type = item_type() + self.item_type = item_type + self.as_tuple = as_tuple + self.dimensions = dimensions + self.zero_indexes = zero_indexes + + @property + def hashable(self): + return self.as_tuple + + @property + def python_type(self): + return list + + def compare_values(self, x, y): + return x == y + + def _set_parent(self, column): + """Support SchemaEentTarget""" + + if isinstance(self.item_type, SchemaEventTarget): + self.item_type._set_parent(column) + + def _set_parent_with_dispatch(self, parent): + """Support SchemaEentTarget""" + + if isinstance(self.item_type, SchemaEventTarget): + self.item_type._set_parent_with_dispatch(parent) + + def _proc_array(self, arr, itemproc, dim, collection): + if dim is None: + arr = list(arr) + if dim == 1 or dim is None and ( + # this has to be (list, tuple), or at least + # not hasattr('__iter__'), since Py3K strings + # etc. have __iter__ + not arr or not isinstance(arr[0], (list, tuple))): + if itemproc: + return collection(itemproc(x) for x in arr) + else: + return collection(arr) + else: + return collection( + self._proc_array( + x, itemproc, + dim - 1 if dim is not None else None, + collection) + for x in arr + ) + + def bind_processor(self, dialect): + item_proc = self.item_type.dialect_impl(dialect).\ + bind_processor(dialect) + + def process(value): + if value is None: + return value + else: + return self._proc_array( + value, + item_proc, + self.dimensions, + list) + return process + + def result_processor(self, dialect, coltype): + item_proc = self.item_type.dialect_impl(dialect).\ + result_processor(dialect, coltype) + + def process(value): + if value is None: + return value + else: + return self._proc_array( + value, + item_proc, + self.dimensions, + tuple if self.as_tuple else list) + return process + +ischema_names['_array'] = ARRAY diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 22c66dbbb..e9001f79a 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -102,7 +102,7 @@ via foreign key constraint, a decision must be made as to how the ``.schema`` is represented in those remote tables, in the case where that remote schema name is also a member of the current `Postgresql search path -<http://www.postgresql.org/docs/9.0/static/ddl-schemas.html#DDL-SCHEMAS-PATH>`_. +<http://www.postgresql.org/docs/current/static/ddl-schemas.html#DDL-SCHEMAS-PATH>`_. By default, the Postgresql dialect mimics the behavior encouraged by Postgresql's own ``pg_get_constraintdef()`` builtin procedure. This function @@ -506,7 +506,42 @@ dialect in conjunction with the :class:`.Table` construct: .. seealso:: `Postgresql CREATE TABLE options - <http://www.postgresql.org/docs/9.3/static/sql-createtable.html>`_ + <http://www.postgresql.org/docs/current/static/sql-createtable.html>`_ + +ARRAY Types +----------- + +The Postgresql dialect supports arrays, both as multidimensional column types +as well as array literals: + +* :class:`.postgresql.ARRAY` - ARRAY datatype + +* :class:`.postgresql.array` - array literal + +* :func:`.postgresql.array_agg` - ARRAY_AGG SQL function + +* :class:`.postgresql.aggregate_order_by` - helper for PG's ORDER BY aggregate + function syntax. + +JSON Types +---------- + +The Postgresql dialect supports both JSON and JSONB datatypes, including +psycopg2's native support and support for all of Postgresql's special +operators: + +* :class:`.postgresql.JSON` + +* :class:`.postgresql.JSONB` + +HSTORE Type +----------- + +The Postgresql HSTORE type as well as hstore literals are supported: + +* :class:`.postgresql.HSTORE` - HSTORE datatype + +* :class:`.postgresql.hstore` - hstore literal ENUM Types ---------- @@ -524,13 +559,56 @@ entity. The following sections should be consulted: * :meth:`.postgresql.ENUM.create` , :meth:`.postgresql.ENUM.drop` - individual CREATE and DROP commands for ENUM. +.. _postgresql_array_of_enum: + +Using ENUM with ARRAY +^^^^^^^^^^^^^^^^^^^^^ + +The combination of ENUM and ARRAY is not directly supported by backend +DBAPIs at this time. In order to send and receive an ARRAY of ENUM, +use the following workaround type:: + + class ArrayOfEnum(ARRAY): + + def bind_expression(self, bindvalue): + return sa.cast(bindvalue, self) + + def result_processor(self, dialect, coltype): + super_rp = super(ArrayOfEnum, self).result_processor( + dialect, coltype) + + def handle_raw_string(value): + inner = re.match(r"^{(.*)}$", value).group(1) + return inner.split(",") + + def process(value): + if value is None: + return None + return super_rp(handle_raw_string(value)) + return process + +E.g.:: + + Table( + 'mydata', metadata, + Column('id', Integer, primary_key=True), + Column('data', ArrayOfEnum(ENUM('a', 'b, 'c', name='myenum'))) + + ) + +This type is not included as a built-in type as it would be incompatible +with a DBAPI that suddenly decides to support ARRAY of ENUM directly in +a new version. + """ from collections import defaultdict import re +import datetime as dt + from ... import sql, schema, exc, util from ...engine import default, reflection -from ...sql import compiler, expression, operators, default_comparator +from ...sql import compiler, expression from ... import types as sqltypes try: @@ -633,6 +711,10 @@ class INTERVAL(sqltypes.TypeEngine): def _type_affinity(self): return sqltypes.Interval + @property + def python_type(self): + return dt.timedelta + PGInterval = INTERVAL @@ -722,407 +804,6 @@ class TSVECTOR(sqltypes.TypeEngine): __visit_name__ = 'TSVECTOR' -class _Slice(expression.ColumnElement): - __visit_name__ = 'slice' - type = sqltypes.NULLTYPE - - def __init__(self, slice_, source_comparator): - self.start = default_comparator._check_literal( - source_comparator.expr, - operators.getitem, slice_.start) - self.stop = default_comparator._check_literal( - source_comparator.expr, - operators.getitem, slice_.stop) - - -class Any(expression.ColumnElement): - - """Represent the clause ``left operator ANY (right)``. ``right`` must be - an array expression. - - .. seealso:: - - :class:`.postgresql.ARRAY` - - :meth:`.postgresql.ARRAY.Comparator.any` - ARRAY-bound method - - """ - __visit_name__ = 'any' - - def __init__(self, left, right, operator=operators.eq): - self.type = sqltypes.Boolean() - self.left = expression._literal_as_binds(left) - self.right = right - self.operator = operator - - -class All(expression.ColumnElement): - - """Represent the clause ``left operator ALL (right)``. ``right`` must be - an array expression. - - .. seealso:: - - :class:`.postgresql.ARRAY` - - :meth:`.postgresql.ARRAY.Comparator.all` - ARRAY-bound method - - """ - __visit_name__ = 'all' - - def __init__(self, left, right, operator=operators.eq): - self.type = sqltypes.Boolean() - self.left = expression._literal_as_binds(left) - self.right = right - self.operator = operator - - -class array(expression.Tuple): - - """A Postgresql ARRAY literal. - - This is used to produce ARRAY literals in SQL expressions, e.g.:: - - from sqlalchemy.dialects.postgresql import array - from sqlalchemy.dialects import postgresql - from sqlalchemy import select, func - - stmt = select([ - array([1,2]) + array([3,4,5]) - ]) - - print stmt.compile(dialect=postgresql.dialect()) - - Produces the SQL:: - - SELECT ARRAY[%(param_1)s, %(param_2)s] || - ARRAY[%(param_3)s, %(param_4)s, %(param_5)s]) AS anon_1 - - An instance of :class:`.array` will always have the datatype - :class:`.ARRAY`. The "inner" type of the array is inferred from - the values present, unless the ``type_`` keyword argument is passed:: - - array(['foo', 'bar'], type_=CHAR) - - .. versionadded:: 0.8 Added the :class:`~.postgresql.array` literal type. - - See also: - - :class:`.postgresql.ARRAY` - - """ - __visit_name__ = 'array' - - def __init__(self, clauses, **kw): - super(array, self).__init__(*clauses, **kw) - self.type = ARRAY(self.type) - - def _bind_param(self, operator, obj): - return array([ - expression.BindParameter(None, o, _compared_to_operator=operator, - _compared_to_type=self.type, unique=True) - for o in obj - ]) - - def self_group(self, against=None): - return self - - -class ARRAY(sqltypes.Concatenable, sqltypes.TypeEngine): - - """Postgresql ARRAY type. - - Represents values as Python lists. - - An :class:`.ARRAY` type is constructed given the "type" - of element:: - - mytable = Table("mytable", metadata, - Column("data", ARRAY(Integer)) - ) - - The above type represents an N-dimensional array, - meaning Postgresql will interpret values with any number - of dimensions automatically. To produce an INSERT - construct that passes in a 1-dimensional array of integers:: - - connection.execute( - mytable.insert(), - data=[1,2,3] - ) - - The :class:`.ARRAY` type can be constructed given a fixed number - of dimensions:: - - mytable = Table("mytable", metadata, - Column("data", ARRAY(Integer, dimensions=2)) - ) - - This has the effect of the :class:`.ARRAY` type - specifying that number of bracketed blocks when a :class:`.Table` - is used in a CREATE TABLE statement, or when the type is used - within a :func:`.expression.cast` construct; it also causes - the bind parameter and result set processing of the type - to optimize itself to expect exactly that number of dimensions. - Note that Postgresql itself still allows N dimensions with such a type. - - SQL expressions of type :class:`.ARRAY` have support for "index" and - "slice" behavior. The Python ``[]`` operator works normally here, given - integer indexes or slices. Note that Postgresql arrays default - to 1-based indexing. The operator produces binary expression - constructs which will produce the appropriate SQL, both for - SELECT statements:: - - select([mytable.c.data[5], mytable.c.data[2:7]]) - - as well as UPDATE statements when the :meth:`.Update.values` method - is used:: - - mytable.update().values({ - mytable.c.data[5]: 7, - mytable.c.data[2:7]: [1, 2, 3] - }) - - :class:`.ARRAY` provides special methods for containment operations, - e.g.:: - - mytable.c.data.contains([1, 2]) - - For a full list of special methods see :class:`.ARRAY.Comparator`. - - .. versionadded:: 0.8 Added support for index and slice operations - to the :class:`.ARRAY` type, including support for UPDATE - statements, and special array containment operations. - - The :class:`.ARRAY` type may not be supported on all DBAPIs. - It is known to work on psycopg2 and not pg8000. - - See also: - - :class:`.postgresql.array` - produce a literal array value. - - """ - __visit_name__ = 'ARRAY' - - class Comparator(sqltypes.Concatenable.Comparator): - - """Define comparison operations for :class:`.ARRAY`.""" - - def __getitem__(self, index): - shift_indexes = 1 if self.expr.type.zero_indexes else 0 - if isinstance(index, slice): - if shift_indexes: - index = slice( - index.start + shift_indexes, - index.stop + shift_indexes, - index.step - ) - index = _Slice(index, self) - return_type = self.type - else: - index += shift_indexes - return_type = self.type.item_type - - return default_comparator._binary_operate( - self.expr, operators.getitem, index, - result_type=return_type) - - def any(self, other, operator=operators.eq): - """Return ``other operator ANY (array)`` clause. - - Argument places are switched, because ANY requires array - expression to be on the right hand-side. - - E.g.:: - - from sqlalchemy.sql import operators - - conn.execute( - select([table.c.data]).where( - table.c.data.any(7, operator=operators.lt) - ) - ) - - :param other: expression to be compared - :param operator: an operator object from the - :mod:`sqlalchemy.sql.operators` - package, defaults to :func:`.operators.eq`. - - .. seealso:: - - :class:`.postgresql.Any` - - :meth:`.postgresql.ARRAY.Comparator.all` - - """ - return Any(other, self.expr, operator=operator) - - def all(self, other, operator=operators.eq): - """Return ``other operator ALL (array)`` clause. - - Argument places are switched, because ALL requires array - expression to be on the right hand-side. - - E.g.:: - - from sqlalchemy.sql import operators - - conn.execute( - select([table.c.data]).where( - table.c.data.all(7, operator=operators.lt) - ) - ) - - :param other: expression to be compared - :param operator: an operator object from the - :mod:`sqlalchemy.sql.operators` - package, defaults to :func:`.operators.eq`. - - .. seealso:: - - :class:`.postgresql.All` - - :meth:`.postgresql.ARRAY.Comparator.any` - - """ - return All(other, self.expr, operator=operator) - - def contains(self, other, **kwargs): - """Boolean expression. Test if elements are a superset of the - elements of the argument array expression. - """ - return self.expr.op('@>')(other) - - def contained_by(self, other): - """Boolean expression. Test if elements are a proper subset of the - elements of the argument array expression. - """ - return self.expr.op('<@')(other) - - def overlap(self, other): - """Boolean expression. Test if array has elements in common with - an argument array expression. - """ - return self.expr.op('&&')(other) - - def _adapt_expression(self, op, other_comparator): - if isinstance(op, operators.custom_op): - if op.opstring in ['@>', '<@', '&&']: - return op, sqltypes.Boolean - return sqltypes.Concatenable.Comparator.\ - _adapt_expression(self, op, other_comparator) - - comparator_factory = Comparator - - def __init__(self, item_type, as_tuple=False, dimensions=None, - zero_indexes=False): - """Construct an ARRAY. - - E.g.:: - - Column('myarray', ARRAY(Integer)) - - Arguments are: - - :param item_type: The data type of items of this array. Note that - dimensionality is irrelevant here, so multi-dimensional arrays like - ``INTEGER[][]``, are constructed as ``ARRAY(Integer)``, not as - ``ARRAY(ARRAY(Integer))`` or such. - - :param as_tuple=False: Specify whether return results - should be converted to tuples from lists. DBAPIs such - as psycopg2 return lists by default. When tuples are - returned, the results are hashable. - - :param dimensions: if non-None, the ARRAY will assume a fixed - number of dimensions. This will cause the DDL emitted for this - ARRAY to include the exact number of bracket clauses ``[]``, - and will also optimize the performance of the type overall. - Note that PG arrays are always implicitly "non-dimensioned", - meaning they can store any number of dimensions no matter how - they were declared. - - :param zero_indexes=False: when True, index values will be converted - between Python zero-based and Postgresql one-based indexes, e.g. - a value of one will be added to all index values before passing - to the database. - - .. versionadded:: 0.9.5 - - """ - if isinstance(item_type, ARRAY): - raise ValueError("Do not nest ARRAY types; ARRAY(basetype) " - "handles multi-dimensional arrays of basetype") - if isinstance(item_type, type): - item_type = item_type() - self.item_type = item_type - self.as_tuple = as_tuple - self.dimensions = dimensions - self.zero_indexes = zero_indexes - - @property - def python_type(self): - return list - - def compare_values(self, x, y): - return x == y - - def _proc_array(self, arr, itemproc, dim, collection): - if dim is None: - arr = list(arr) - if dim == 1 or dim is None and ( - # this has to be (list, tuple), or at least - # not hasattr('__iter__'), since Py3K strings - # etc. have __iter__ - not arr or not isinstance(arr[0], (list, tuple))): - if itemproc: - return collection(itemproc(x) for x in arr) - else: - return collection(arr) - else: - return collection( - self._proc_array( - x, itemproc, - dim - 1 if dim is not None else None, - collection) - for x in arr - ) - - def bind_processor(self, dialect): - item_proc = self.item_type.\ - dialect_impl(dialect).\ - bind_processor(dialect) - - def process(value): - if value is None: - return value - else: - return self._proc_array( - value, - item_proc, - self.dimensions, - list) - return process - - def result_processor(self, dialect, coltype): - item_proc = self.item_type.\ - dialect_impl(dialect).\ - result_processor(dialect, coltype) - - def process(value): - if value is None: - return value - else: - return self._proc_array( - value, - item_proc, - self.dimensions, - tuple if self.as_tuple else list) - return process - -PGArray = ARRAY - - class ENUM(sqltypes.Enum): """Postgresql ENUM type. @@ -1375,26 +1056,18 @@ class PGCompiler(compiler.SQLCompiler): self.process(element.stop, **kw), ) - def visit_any(self, element, **kw): - return "%s%sANY (%s)" % ( - self.process(element.left, **kw), - compiler.OPERATORS[element.operator], - self.process(element.right, **kw) - ) - - def visit_all(self, element, **kw): - return "%s%sALL (%s)" % ( - self.process(element.left, **kw), - compiler.OPERATORS[element.operator], - self.process(element.right, **kw) - ) - def visit_getitem_binary(self, binary, operator, **kw): return "%s[%s]" % ( self.process(binary.left, **kw), self.process(binary.right, **kw) ) + def visit_aggregate_order_by(self, element, **kw): + return "%s ORDER BY %s" % ( + self.process(element.target, **kw), + self.process(element.order_by, **kw) + ) + def visit_match_op_binary(self, binary, operator, **kw): if "postgresql_regconfig" in binary.modifiers: regconfig = self.render_literal_value( @@ -1485,7 +1158,7 @@ class PGCompiler(compiler.SQLCompiler): c.table if isinstance(c, expression.ColumnClause) else c for c in select._for_update_arg.of) tmp += " OF " + ", ".join( - self.process(table, ashint=True, **kw) + self.process(table, ashint=True, use_schema=False, **kw) for table in tables ) @@ -1537,8 +1210,8 @@ class PGDDLCompiler(compiler.DDLCompiler): else: colspec += " SERIAL" else: - colspec += " " + self.dialect.type_compiler.process(column.type, - type_expression=column) + colspec += " " + self.dialect.type_compiler.process( + column.type, type_expression=column) default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -2294,11 +1967,27 @@ class PGDialect(default.DefaultDialect): current_schema = schema else: current_schema = self.default_schema_name - s = """ - SELECT definition FROM pg_views - WHERE schemaname = :schema - AND viewname = :view_name - """ + + if self.server_version_info >= (9, 3): + s = """ + SELECT definition FROM pg_views + WHERE schemaname = :schema + AND viewname = :view_name + + UNION + + SELECT definition FROM pg_matviews + WHERE schemaname = :schema + AND matviewname = :view_name + + """ + else: + s = """ + SELECT definition FROM pg_views + WHERE schemaname = :schema + AND viewname = :view_name + """ + rp = connection.execute(sql.text(s), view_name=view_name, schema=current_schema) if rp: @@ -2438,7 +2127,7 @@ class PGDialect(default.DefaultDialect): if coltype: coltype = coltype(*args, **kwargs) if is_array: - coltype = ARRAY(coltype) + coltype = self.ischema_names['_array'](coltype) else: util.warn("Did not recognize type '%s' of column '%s'" % (attype, name)) @@ -2631,7 +2320,7 @@ class PGDialect(default.DefaultDialect): i.relname as relname, ix.indisunique, ix.indexprs, ix.indpred, a.attname, a.attnum, NULL, ix.indkey%s, - i.reloptions, am.amname + %s, am.amname FROM pg_class t join pg_index ix on t.oid = ix.indrelid @@ -2654,6 +2343,8 @@ class PGDialect(default.DefaultDialect): # cast does not work in PG 8.2.4, does work in 8.3.0. # nothing in PG changelogs regarding this. "::varchar" if self.server_version_info >= (8, 3) else "", + "i.reloptions" if self.server_version_info >= (8, 2) + else "NULL", self._pg_index_any("a.attnum", "ix.indkey") ) else: diff --git a/lib/sqlalchemy/dialects/postgresql/constraints.py b/lib/sqlalchemy/dialects/postgresql/ext.py index 4cfc050de..1a443c2d7 100644 --- a/lib/sqlalchemy/dialects/postgresql/constraints.py +++ b/lib/sqlalchemy/dialects/postgresql/ext.py @@ -1,11 +1,69 @@ -# Copyright (C) 2013-2015 the SQLAlchemy authors and contributors +# postgresql/ext.py +# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from ...sql.schema import ColumnCollectionConstraint + from ...sql import expression -from ... import util +from ...sql import elements +from ...sql import functions +from ...sql.schema import ColumnCollectionConstraint +from .array import ARRAY + + +class aggregate_order_by(expression.ColumnElement): + """Represent a Postgresql aggregate order by expression. + + E.g.:: + + from sqlalchemy.dialects.postgresql import aggregate_order_by + expr = func.array_agg(aggregate_order_by(table.c.a, table.c.b.desc())) + stmt = select([expr]) + + would represent the expression:: + + SELECT array_agg(a ORDER BY b DESC) FROM table; + + Similarly:: + + expr = func.string_agg( + table.c.a, + aggregate_order_by(literal_column("','"), table.c.a) + ) + stmt = select([expr]) + + Would represent:: + + SELECT string_agg(a, ',' ORDER BY a) FROM table; + + .. versionadded:: 1.1 + + .. seealso:: + + :class:`.array_agg` + + """ + + __visit_name__ = 'aggregate_order_by' + + def __init__(self, target, order_by): + self.target = elements._literal_as_binds(target) + self.order_by = elements._literal_as_binds(order_by) + + def self_group(self, against=None): + return self + + def get_children(self, **kwargs): + return self.target, self.order_by + + def _copy_internals(self, clone=elements._clone, **kw): + self.target = clone(self.target, **kw) + self.order_by = clone(self.order_by, **kw) + + @property + def _from_objects(self): + return self.target._from_objects + self.order_by._from_objects class ExcludeConstraint(ColumnCollectionConstraint): @@ -84,7 +142,7 @@ static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE ) self.using = kw.get('using', 'gist') where = kw.get('where') - if where: + if where is not None: self.where = expression._literal_as_text(where) def copy(self, **kw): @@ -96,3 +154,15 @@ static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE initially=self.initially) c.dispatch._update(self.dispatch) return c + + +def array_agg(*arg, **kw): + """Postgresql-specific form of :class:`.array_agg`, ensures + return type is :class:`.postgresql.ARRAY` and not + the plain :class:`.types.Array`. + + .. versionadded:: 1.1 + + """ + kw['type_'] = ARRAY(functions._type_from_args(arg)) + return functions.func.array_agg(*arg, **kw) diff --git a/lib/sqlalchemy/dialects/postgresql/hstore.py b/lib/sqlalchemy/dialects/postgresql/hstore.py index 9f369cb5b..b7b0fc007 100644 --- a/lib/sqlalchemy/dialects/postgresql/hstore.py +++ b/lib/sqlalchemy/dialects/postgresql/hstore.py @@ -7,110 +7,43 @@ import re -from .base import ARRAY, ischema_names +from .base import ischema_names +from .array import ARRAY from ... import types as sqltypes from ...sql import functions as sqlfunc +from ...sql import operators from ...sql.operators import custom_op from ... import util __all__ = ('HSTORE', 'hstore') -# My best guess at the parsing rules of hstore literals, since no formal -# grammar is given. This is mostly reverse engineered from PG's input parser -# behavior. -HSTORE_PAIR_RE = re.compile(r""" -( - "(?P<key> (\\ . | [^"])* )" # Quoted key -) -[ ]* => [ ]* # Pair operator, optional adjoining whitespace -( - (?P<value_null> NULL ) # NULL value - | "(?P<value> (\\ . | [^"])* )" # Quoted value -) -""", re.VERBOSE) - -HSTORE_DELIMITER_RE = re.compile(r""" -[ ]* , [ ]* -""", re.VERBOSE) - - -def _parse_error(hstore_str, pos): - """format an unmarshalling error.""" - - ctx = 20 - hslen = len(hstore_str) - - parsed_tail = hstore_str[max(pos - ctx - 1, 0):min(pos, hslen)] - residual = hstore_str[min(pos, hslen):min(pos + ctx + 1, hslen)] - if len(parsed_tail) > ctx: - parsed_tail = '[...]' + parsed_tail[1:] - if len(residual) > ctx: - residual = residual[:-1] + '[...]' - - return "After %r, could not parse residual at position %d: %r" % ( - parsed_tail, pos, residual) - - -def _parse_hstore(hstore_str): - """Parse an hstore from its literal string representation. - - Attempts to approximate PG's hstore input parsing rules as closely as - possible. Although currently this is not strictly necessary, since the - current implementation of hstore's output syntax is stricter than what it - accepts as input, the documentation makes no guarantees that will always - be the case. - - - - """ - result = {} - pos = 0 - pair_match = HSTORE_PAIR_RE.match(hstore_str) - - while pair_match is not None: - key = pair_match.group('key').replace(r'\"', '"').replace( - "\\\\", "\\") - if pair_match.group('value_null'): - value = None - else: - value = pair_match.group('value').replace( - r'\"', '"').replace("\\\\", "\\") - result[key] = value - - pos += pair_match.end() - - delim_match = HSTORE_DELIMITER_RE.match(hstore_str[pos:]) - if delim_match is not None: - pos += delim_match.end() - - pair_match = HSTORE_PAIR_RE.match(hstore_str[pos:]) - - if pos != len(hstore_str): - raise ValueError(_parse_error(hstore_str, pos)) +INDEX = custom_op( + "->", precedence=5, natural_self_precedent=True +) - return result +HAS_KEY = operators.custom_op( + "?", precedence=5, natural_self_precedent=True +) +HAS_ALL = operators.custom_op( + "?&", precedence=5, natural_self_precedent=True +) -def _serialize_hstore(val): - """Serialize a dictionary into an hstore literal. Keys and values must - both be strings (except None for values). +HAS_ANY = operators.custom_op( + "?|", precedence=5, natural_self_precedent=True +) - """ - def esc(s, position): - if position == 'value' and s is None: - return 'NULL' - elif isinstance(s, util.string_types): - return '"%s"' % s.replace("\\", "\\\\").replace('"', r'\"') - else: - raise ValueError("%r in %s position is not a string." % - (s, position)) +CONTAINS = operators.custom_op( + "@>", precedence=5, natural_self_precedent=True +) - return ', '.join('%s=>%s' % (esc(k, 'key'), esc(v, 'value')) - for k, v in val.items()) +CONTAINED_BY = operators.custom_op( + "<@", precedence=5, natural_self_precedent=True +) -class HSTORE(sqltypes.Concatenable, sqltypes.TypeEngine): +class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): """Represent the Postgresql HSTORE type. The :class:`.HSTORE` type stores dictionaries containing strings, e.g.:: @@ -185,51 +118,61 @@ class HSTORE(sqltypes.Concatenable, sqltypes.TypeEngine): __visit_name__ = 'HSTORE' hashable = False + text_type = sqltypes.Text() + + def __init__(self, text_type=None): + """Construct a new :class:`.HSTORE`. + + :param text_type: the type that should be used for indexed values. + Defaults to :class:`.types.Text`. + + .. versionadded:: 1.1.0 - class comparator_factory(sqltypes.Concatenable.Comparator): + """ + if text_type is not None: + self.text_type = text_type + + class Comparator( + sqltypes.Indexable.Comparator, sqltypes.Concatenable.Comparator): """Define comparison operations for :class:`.HSTORE`.""" def has_key(self, other): """Boolean expression. Test for presence of a key. Note that the key may be a SQLA expression. """ - return self.expr.op('?')(other) + return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean) def has_all(self, other): - """Boolean expression. Test for presence of all keys in the PG - array. + """Boolean expression. Test for presence of all keys in jsonb """ - return self.expr.op('?&')(other) + return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean) def has_any(self, other): - """Boolean expression. Test for presence of any key in the PG - array. + """Boolean expression. Test for presence of any key in jsonb """ - return self.expr.op('?|')(other) - - def defined(self, key): - """Boolean expression. Test for presence of a non-NULL value for - the key. Note that the key may be a SQLA expression. - """ - return _HStoreDefinedFunction(self.expr, key) + return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean) def contains(self, other, **kwargs): - """Boolean expression. Test if keys are a superset of the keys of - the argument hstore expression. + """Boolean expression. Test if keys (or array) are a superset + of/contained the keys of the argument jsonb expression. """ - return self.expr.op('@>')(other) + return self.operate(CONTAINS, other, result_type=sqltypes.Boolean) def contained_by(self, other): """Boolean expression. Test if keys are a proper subset of the - keys of the argument hstore expression. + keys of the argument jsonb expression. """ - return self.expr.op('<@')(other) + return self.operate( + CONTAINED_BY, other, result_type=sqltypes.Boolean) - def __getitem__(self, other): - """Text expression. Get the value at a given key. Note that the - key may be a SQLA expression. + def _setup_getitem(self, index): + return INDEX, index, self.type.text_type + + def defined(self, key): + """Boolean expression. Test for presence of a non-NULL value for + the key. Note that the key may be a SQLA expression. """ - return self.expr.op('->', precedence=5)(other) + return _HStoreDefinedFunction(self.expr, key) def delete(self, key): """HStore expression. Returns the contents of this hstore with the @@ -263,14 +206,7 @@ class HSTORE(sqltypes.Concatenable, sqltypes.TypeEngine): """Text array expression. Returns array of [key, value] pairs.""" return _HStoreMatrixFunction(self.expr) - def _adapt_expression(self, op, other_comparator): - if isinstance(op, custom_op): - if op.opstring in ['?', '?&', '?|', '@>', '<@']: - return op, sqltypes.Boolean - elif op.opstring == '->': - return op, sqltypes.Text - return sqltypes.Concatenable.Comparator.\ - _adapt_expression(self, op, other_comparator) + comparator_factory = Comparator def bind_processor(self, dialect): if util.py2k: @@ -374,3 +310,105 @@ class _HStoreArrayFunction(sqlfunc.GenericFunction): class _HStoreMatrixFunction(sqlfunc.GenericFunction): type = ARRAY(sqltypes.Text) name = 'hstore_to_matrix' + + +# +# parsing. note that none of this is used with the psycopg2 backend, +# which provides its own native extensions. +# + +# My best guess at the parsing rules of hstore literals, since no formal +# grammar is given. This is mostly reverse engineered from PG's input parser +# behavior. +HSTORE_PAIR_RE = re.compile(r""" +( + "(?P<key> (\\ . | [^"])* )" # Quoted key +) +[ ]* => [ ]* # Pair operator, optional adjoining whitespace +( + (?P<value_null> NULL ) # NULL value + | "(?P<value> (\\ . | [^"])* )" # Quoted value +) +""", re.VERBOSE) + +HSTORE_DELIMITER_RE = re.compile(r""" +[ ]* , [ ]* +""", re.VERBOSE) + + +def _parse_error(hstore_str, pos): + """format an unmarshalling error.""" + + ctx = 20 + hslen = len(hstore_str) + + parsed_tail = hstore_str[max(pos - ctx - 1, 0):min(pos, hslen)] + residual = hstore_str[min(pos, hslen):min(pos + ctx + 1, hslen)] + + if len(parsed_tail) > ctx: + parsed_tail = '[...]' + parsed_tail[1:] + if len(residual) > ctx: + residual = residual[:-1] + '[...]' + + return "After %r, could not parse residual at position %d: %r" % ( + parsed_tail, pos, residual) + + +def _parse_hstore(hstore_str): + """Parse an hstore from its literal string representation. + + Attempts to approximate PG's hstore input parsing rules as closely as + possible. Although currently this is not strictly necessary, since the + current implementation of hstore's output syntax is stricter than what it + accepts as input, the documentation makes no guarantees that will always + be the case. + + + + """ + result = {} + pos = 0 + pair_match = HSTORE_PAIR_RE.match(hstore_str) + + while pair_match is not None: + key = pair_match.group('key').replace(r'\"', '"').replace( + "\\\\", "\\") + if pair_match.group('value_null'): + value = None + else: + value = pair_match.group('value').replace( + r'\"', '"').replace("\\\\", "\\") + result[key] = value + + pos += pair_match.end() + + delim_match = HSTORE_DELIMITER_RE.match(hstore_str[pos:]) + if delim_match is not None: + pos += delim_match.end() + + pair_match = HSTORE_PAIR_RE.match(hstore_str[pos:]) + + if pos != len(hstore_str): + raise ValueError(_parse_error(hstore_str, pos)) + + return result + + +def _serialize_hstore(val): + """Serialize a dictionary into an hstore literal. Keys and values must + both be strings (except None for values). + + """ + def esc(s, position): + if position == 'value' and s is None: + return 'NULL' + elif isinstance(s, util.string_types): + return '"%s"' % s.replace("\\", "\\\\").replace('"', r'\"') + else: + raise ValueError("%r in %s position is not a string." % + (s, position)) + + return ', '.join('%s=>%s' % (esc(k, 'key'), esc(v, 'value')) + for k, v in val.items()) + + diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py index 13ebc4afe..8a50270f5 100644 --- a/lib/sqlalchemy/dialects/postgresql/json.py +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -6,96 +6,60 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php from __future__ import absolute_import +import collections import json from .base import ischema_names from ... import types as sqltypes -from ...sql.operators import custom_op -from ... import sql -from ...sql import elements, default_comparator +from ...sql import operators +from ...sql import elements from ... import util -__all__ = ('JSON', 'JSONElement', 'JSONB') +__all__ = ('JSON', 'JSONB') -class JSONElement(elements.BinaryExpression): - """Represents accessing an element of a :class:`.JSON` value. +# json : returns json +INDEX = operators.custom_op( + "->", precedence=5, natural_self_precedent=True +) - The :class:`.JSONElement` is produced whenever using the Python index - operator on an expression that has the type :class:`.JSON`:: +# path operator: returns json +PATHIDX = operators.custom_op( + "#>", precedence=5, natural_self_precedent=True +) - expr = mytable.c.json_data['some_key'] +# json + astext: returns text +ASTEXT = operators.custom_op( + "->>", precedence=5, natural_self_precedent=True +) - The expression typically compiles to a JSON access such as ``col -> key``. - Modifiers are then available for typing behavior, including - :meth:`.JSONElement.cast` and :attr:`.JSONElement.astext`. +# path operator + astext: returns text +ASTEXT_PATHIDX = operators.custom_op( + "#>>", precedence=5, natural_self_precedent=True +) - """ - - def __init__(self, left, right, astext=False, - opstring=None, result_type=None): - self._astext = astext - if opstring is None: - if hasattr(right, '__iter__') and \ - not isinstance(right, util.string_types): - opstring = "#>" - right = "{%s}" % ( - ", ".join(util.text_type(elem) for elem in right)) - else: - opstring = "->" - - self._json_opstring = opstring - operator = custom_op(opstring, precedence=5) - right = default_comparator._check_literal( - left, operator, right) - super(JSONElement, self).__init__( - left, right, operator, type_=result_type) - - @property - def astext(self): - """Convert this :class:`.JSONElement` to use the 'astext' operator - when evaluated. - - E.g.:: - - select([data_table.c.data['some key'].astext]) - - .. seealso:: - - :meth:`.JSONElement.cast` - - """ - if self._astext: - return self - else: - return JSONElement( - self.left, - self.right, - astext=True, - opstring=self._json_opstring + ">", - result_type=sqltypes.String(convert_unicode=True) - ) - - def cast(self, type_): - """Convert this :class:`.JSONElement` to apply both the 'astext' operator - as well as an explicit type cast when evaluated. - - E.g.:: +HAS_KEY = operators.custom_op( + "?", precedence=5, natural_self_precedent=True +) - select([data_table.c.data['some key'].cast(Integer)]) +HAS_ALL = operators.custom_op( + "?&", precedence=5, natural_self_precedent=True +) - .. seealso:: +HAS_ANY = operators.custom_op( + "?|", precedence=5, natural_self_precedent=True +) - :attr:`.JSONElement.astext` +CONTAINS = operators.custom_op( + "@>", precedence=5, natural_self_precedent=True +) - """ - if not self._astext: - return self.astext.cast(type_) - else: - return sql.cast(self, type_) +CONTAINED_BY = operators.custom_op( + "<@", precedence=5, natural_self_precedent=True +) -class JSON(sqltypes.TypeEngine): +class JSON(sqltypes.Indexable, sqltypes.TypeEngine): """Represent the Postgresql JSON type. The :class:`.JSON` type stores arbitrary JSON format data, e.g.:: @@ -113,31 +77,36 @@ class JSON(sqltypes.TypeEngine): :class:`.JSON` provides several operations: - * Index operations:: + * Index operations (the ``->`` operator):: data_table.c.data['some key'] - * Index operations returning text (required for text comparison):: + * Index operations returning text (the ``->>`` operator):: data_table.c.data['some key'].astext == 'some value' - * Index operations with a built-in CAST call:: + * Index operations with CAST + (equivalent to ``CAST(col ->> ['some key'] AS <type>)``):: - data_table.c.data['some key'].cast(Integer) == 5 + data_table.c.data['some key'].astext.cast(Integer) == 5 - * Path index operations:: + * Path index operations (the ``#>`` operator):: data_table.c.data[('key_1', 'key_2', ..., 'key_n')] - * Path index operations returning text (required for text comparison):: + * Path index operations returning text (the ``#>>`` operator):: + + data_table.c.data[('key_1', 'key_2', ..., 'key_n')].astext == \ +'some value' - data_table.c.data[('key_1', 'key_2', ..., 'key_n')].astext == \\ - 'some value' + .. versionchanged:: 1.1 The :meth:`.ColumnElement.cast` operator on + JSON objects now requires that the :attr:`.JSON.Comparator.astext` + modifier be called explicitly, if the cast works only from a textual + string. - Index operations return an instance of :class:`.JSONElement`, which - represents an expression such as ``column -> index``. This element then - defines methods such as :attr:`.JSONElement.astext` and - :meth:`.JSONElement.cast` for setting up type behavior. + Index operations return an expression object whose type defaults to + :class:`.JSON` by default, so that further JSON-oriented instructions + may be called upon the result type. The :class:`.JSON` type, when used with the SQLAlchemy ORM, does not detect in-place mutations to the structure. In order to detect these, the @@ -146,6 +115,29 @@ class JSON(sqltypes.TypeEngine): will be detected by the unit of work. See the example at :class:`.HSTORE` for a simple example involving a dictionary. + When working with NULL values, the :class:`.JSON` type recommends the + use of two specific constants in order to differentiate between a column + that evaluates to SQL NULL, e.g. no value, vs. the JSON-encoded string + of ``"null"``. To insert or select against a value that is SQL NULL, + use the constant :func:`.null`:: + + conn.execute(table.insert(), json_value=null()) + + To insert or select against a value that is JSON ``"null"``, use the + constant :attr:`.JSON.NULL`:: + + conn.execute(table.insert(), json_value=JSON.NULL) + + The :class:`.JSON` type supports a flag + :paramref:`.JSON.none_as_null` which when set to True will result + in the Python constant ``None`` evaluating to the value of SQL + NULL, and when set to False results in the Python constant + ``None`` evaluating to the value of JSON ``"null"``. The Python + value ``None`` may be used in conjunction with either + :attr:`.JSON.NULL` and :func:`.null` in order to indicate NULL + values, but care must be taken as to the value of the + :paramref:`.JSON.none_as_null` in these cases. + Custom serializers and deserializers are specified at the dialect level, that is using :func:`.create_engine`. The reason for this is that when using psycopg2, the DBAPI only allows serializers at the per-cursor @@ -161,11 +153,42 @@ class JSON(sqltypes.TypeEngine): .. versionadded:: 0.9 + .. seealso:: + + :class:`.JSONB` + """ __visit_name__ = 'JSON' - def __init__(self, none_as_null=False): + hashable = False + astext_type = sqltypes.Text() + + NULL = util.symbol('JSON_NULL') + """Describe the json value of NULL. + + This value is used to force the JSON value of ``"null"`` to be + used as the value. A value of Python ``None`` will be recognized + either as SQL NULL or JSON ``"null"``, based on the setting + of the :paramref:`.JSON.none_as_null` flag; the :attr:`.JSON.NULL` + constant can be used to always resolve to JSON ``"null"`` regardless + of this setting. This is in contrast to the :func:`.sql.null` construct, + which always resolves to SQL NULL. E.g.:: + + from sqlalchemy import null + from sqlalchemy.dialects.postgresql import JSON + + obj1 = MyObject(json_value=null()) # will *always* insert SQL NULL + obj2 = MyObject(json_value=JSON.NULL) # will *always* insert JSON string "null" + + session.add_all([obj1, obj2]) + session.commit() + + .. versionadded:: 1.1 + + """ + + def __init__(self, none_as_null=False, astext_type=None): """Construct a :class:`.JSON` type. :param none_as_null: if True, persist the value ``None`` as a @@ -179,58 +202,99 @@ class JSON(sqltypes.TypeEngine): .. versionchanged:: 0.9.8 - Added ``none_as_null``, and :func:`.null` is now supported in order to persist a NULL value. + .. seealso:: + + :attr:`.JSON.NULL` + + :param astext_type: the type to use for the + :attr:`.JSON.Comparator.astext` + accessor on indexed attributes. Defaults to :class:`.types.Text`. + + .. versionadded:: 1.1.0 + """ self.none_as_null = none_as_null + if astext_type is not None: + self.astext_type = astext_type - class comparator_factory(sqltypes.Concatenable.Comparator): + class Comparator( + sqltypes.Indexable.Comparator, sqltypes.Concatenable.Comparator): """Define comparison operations for :class:`.JSON`.""" - def __getitem__(self, other): - """Get the value at a given key.""" + @property + def astext(self): + """On an indexed expression, use the "astext" (e.g. "->>") + conversion when rendered in SQL. + + E.g.:: + + select([data_table.c.data['some key'].astext]) + + .. seealso:: + + :meth:`.ColumnElement.cast` + + """ + against = self.expr.operator + if against is PATHIDX: + against = ASTEXT_PATHIDX + else: + against = ASTEXT + + return self.expr.left.operate( + against, self.expr.right, result_type=self.type.astext_type) + + def _setup_getitem(self, index): + if not isinstance(index, util.string_types): + assert isinstance(index, collections.Sequence) + tokens = [util.text_type(elem) for elem in index] + index = "{%s}" % (", ".join(tokens)) + operator = PATHIDX + else: + operator = INDEX - return JSONElement(self.expr, other) + return operator, index, self.type - def _adapt_expression(self, op, other_comparator): - if isinstance(op, custom_op): - if op.opstring == '->': - return op, sqltypes.Text - return sqltypes.Concatenable.Comparator.\ - _adapt_expression(self, op, other_comparator) + comparator_factory = Comparator + + @property + def should_evaluate_none(self): + return not self.none_as_null def bind_processor(self, dialect): json_serializer = dialect._json_serializer or json.dumps if util.py2k: encoding = dialect.encoding - - def process(value): - if isinstance(value, elements.Null) or ( - value is None and self.none_as_null - ): - return None - return json_serializer(value).encode(encoding) else: - def process(value): - if isinstance(value, elements.Null) or ( - value is None and self.none_as_null - ): - return None + encoding = None + + def process(value): + if value is self.NULL: + value = None + elif isinstance(value, elements.Null) or ( + value is None and self.none_as_null + ): + return None + if encoding: + return json_serializer(value).encode(encoding) + else: return json_serializer(value) + return process def result_processor(self, dialect, coltype): json_deserializer = dialect._json_deserializer or json.loads if util.py2k: encoding = dialect.encoding - - def process(value): - if value is None: - return None - return json_deserializer(value.decode(encoding)) else: - def process(value): - if value is None: - return None - return json_deserializer(value) + encoding = None + + def process(value): + if value is None: + return None + if encoding: + value = value.decode(encoding) + return json_deserializer(value) return process @@ -253,106 +317,68 @@ class JSONB(JSON): data = {"key1": "value1", "key2": "value2"} ) - :class:`.JSONB` provides several operations: - - * Index operations:: - - data_table.c.data['some key'] - - * Index operations returning text (required for text comparison):: + The :class:`.JSONB` type includes all operations provided by + :class:`.JSON`, including the same behaviors for indexing operations. + It also adds additional operators specific to JSONB, including + :meth:`.JSONB.Comparator.has_key`, :meth:`.JSONB.Comparator.has_all`, + :meth:`.JSONB.Comparator.has_any`, :meth:`.JSONB.Comparator.contains`, + and :meth:`.JSONB.Comparator.contained_by`. + + Like the :class:`.JSON` type, the :class:`.JSONB` type does not detect + in-place changes when used with the ORM, unless the + :mod:`sqlalchemy.ext.mutable` extension is used. + + Custom serializers and deserializers + are shared with the :class:`.JSON` class, using the ``json_serializer`` + and ``json_deserializer`` keyword arguments. These must be specified + at the dialect level using :func:`.create_engine`. When using + psycopg2, the serializers are associated with the jsonb type using + ``psycopg2.extras.register_default_jsonb`` on a per-connection basis, + in the same way that ``psycopg2.extras.register_default_json`` is used + to register these handlers with the json type. - data_table.c.data['some key'].astext == 'some value' - - * Index operations with a built-in CAST call:: - - data_table.c.data['some key'].cast(Integer) == 5 - - * Path index operations:: - - data_table.c.data[('key_1', 'key_2', ..., 'key_n')] - - * Path index operations returning text (required for text comparison):: - - data_table.c.data[('key_1', 'key_2', ..., 'key_n')].astext == \\ - 'some value' - - Index operations return an instance of :class:`.JSONElement`, which - represents an expression such as ``column -> index``. This element then - defines methods such as :attr:`.JSONElement.astext` and - :meth:`.JSONElement.cast` for setting up type behavior. - - The :class:`.JSON` type, when used with the SQLAlchemy ORM, does not - detect in-place mutations to the structure. In order to detect these, the - :mod:`sqlalchemy.ext.mutable` extension must be used. This extension will - allow "in-place" changes to the datastructure to produce events which - will be detected by the unit of work. See the example at :class:`.HSTORE` - for a simple example involving a dictionary. - - Custom serializers and deserializers are specified at the dialect level, - that is using :func:`.create_engine`. The reason for this is that when - using psycopg2, the DBAPI only allows serializers at the per-cursor - or per-connection level. E.g.:: - - engine = create_engine("postgresql://scott:tiger@localhost/test", - json_serializer=my_serialize_fn, - json_deserializer=my_deserialize_fn - ) + .. versionadded:: 0.9.7 - When using the psycopg2 dialect, the json_deserializer is registered - against the database using ``psycopg2.extras.register_default_json``. + .. seealso:: - .. versionadded:: 0.9.7 + :class:`.JSON` """ __visit_name__ = 'JSONB' - hashable = False - class comparator_factory(sqltypes.Concatenable.Comparator): + class Comparator(JSON.Comparator): """Define comparison operations for :class:`.JSON`.""" - def __getitem__(self, other): - """Get the value at a given key.""" - - return JSONElement(self.expr, other) - - def _adapt_expression(self, op, other_comparator): - # How does one do equality?? jsonb also has "=" eg. - # '[1,2,3]'::jsonb = '[1,2,3]'::jsonb - if isinstance(op, custom_op): - if op.opstring in ['?', '?&', '?|', '@>', '<@']: - return op, sqltypes.Boolean - if op.opstring == '->': - return op, sqltypes.Text - return sqltypes.Concatenable.Comparator.\ - _adapt_expression(self, op, other_comparator) - def has_key(self, other): """Boolean expression. Test for presence of a key. Note that the key may be a SQLA expression. """ - return self.expr.op('?')(other) + return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean) def has_all(self, other): """Boolean expression. Test for presence of all keys in jsonb """ - return self.expr.op('?&')(other) + return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean) def has_any(self, other): """Boolean expression. Test for presence of any key in jsonb """ - return self.expr.op('?|')(other) + return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean) def contains(self, other, **kwargs): - """Boolean expression. Test if keys (or array) are a superset of/contained - the keys of the argument jsonb expression. + """Boolean expression. Test if keys (or array) are a superset + of/contained the keys of the argument jsonb expression. """ - return self.expr.op('@>')(other) + return self.operate(CONTAINS, other, result_type=sqltypes.Boolean) def contained_by(self, other): """Boolean expression. Test if keys are a proper subset of the keys of the argument jsonb expression. """ - return self.expr.op('<@')(other) + return self.operate( + CONTAINED_BY, other, result_type=sqltypes.Boolean) + + comparator_factory = Comparator ischema_names['jsonb'] = JSONB diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index 36a9d7bf7..d33554922 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -320,7 +320,7 @@ from ...sql import expression from ... import types as sqltypes from .base import PGDialect, PGCompiler, \ PGIdentifierPreparer, PGExecutionContext, \ - ENUM, ARRAY, _DECIMAL_TYPES, _FLOAT_TYPES,\ + ENUM, _DECIMAL_TYPES, _FLOAT_TYPES,\ _INT_TYPES, UUID from .hstore import HSTORE from .json import JSON, JSONB diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index d9da46f4c..a1786d16c 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -853,12 +853,20 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): if not column.nullable: colspec += " NOT NULL" - if (column.primary_key and - column.table.dialect_options['sqlite']['autoincrement'] and - len(column.table.primary_key.columns) == 1 and - issubclass(column.type._type_affinity, sqltypes.Integer) and - not column.foreign_keys): - colspec += " PRIMARY KEY AUTOINCREMENT" + if column.primary_key: + if ( + column.autoincrement is True and + len(column.table.primary_key.columns) != 1 + ): + raise exc.CompileError( + "SQLite does not support autoincrement for " + "composite primary keys") + + if (column.table.dialect_options['sqlite']['autoincrement'] and + len(column.table.primary_key.columns) == 1 and + issubclass(column.type._type_affinity, sqltypes.Integer) and + not column.foreign_keys): + colspec += " PRIMARY KEY AUTOINCREMENT" return colspec @@ -894,11 +902,25 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return preparer.format_table(table, use_schema=False) - def visit_create_index(self, create): + def visit_create_index(self, create, include_schema=False, + include_table_schema=True): index = create.element - - text = super(SQLiteDDLCompiler, self).visit_create_index( - create, include_table_schema=False) + self._verify_index_table(index) + preparer = self.preparer + text = "CREATE " + if index.unique: + text += "UNIQUE " + text += "INDEX %s ON %s (%s)" \ + % ( + self._prepared_index_name(index, + include_schema=True), + preparer.format_table(index.table, + use_schema=False), + ', '.join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True) for + expr in index.expressions) + ) whereclause = index.dialect_options["sqlite"]["where"] if whereclause is not None: @@ -1095,6 +1117,13 @@ class SQLiteDialect(default.DefaultDialect): return None @reflection.cache + def get_schema_names(self, connection, **kw): + s = "PRAGMA database_list" + dl = connection.execute(s) + + return [db[1] for db in dl if db[1] != "temp"] + + @reflection.cache def get_table_names(self, connection, schema=None, **kw): if schema is not None: qschema = self.identifier_preparer.quote_identifier(schema) @@ -1190,7 +1219,7 @@ class SQLiteDialect(default.DefaultDialect): 'type': coltype, 'nullable': nullable, 'default': default, - 'autoincrement': default is None, + 'autoincrement': 'auto', 'primary_key': primary_key, } @@ -1283,7 +1312,7 @@ class SQLiteDialect(default.DefaultDialect): fk = fks[numerical_id] = { 'name': None, 'constrained_columns': [], - 'referred_schema': None, + 'referred_schema': schema, 'referred_table': rtbl, 'referred_columns': [], } @@ -1387,7 +1416,7 @@ class SQLiteDialect(default.DefaultDialect): unique_constraints = [] def parse_uqs(): - UNIQUE_PATTERN = '(?:CONSTRAINT (\w+) +)?UNIQUE *\((.+?)\)' + UNIQUE_PATTERN = '(?:CONSTRAINT "?(.+?)"? +)?UNIQUE *\((.+?)\)' INLINE_UNIQUE_PATTERN = ( '(?:(".+?")|([a-z0-9]+)) ' '+[a-z0-9_ ]+? +UNIQUE') diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py index ae0473a3e..b3f8e307a 100644 --- a/lib/sqlalchemy/dialects/sybase/base.py +++ b/lib/sqlalchemy/dialects/sybase/base.py @@ -608,8 +608,8 @@ class SybaseDialect(default.DefaultDialect): FROM sysreferences r JOIN sysobjects o on r.tableid = o.id WHERE r.tableid = :table_id """) - referential_constraints = connection.execute(REFCONSTRAINT_SQL, - table_id=table_id) + referential_constraints = connection.execute( + REFCONSTRAINT_SQL, table_id=table_id).fetchall() REFTABLE_SQL = text(""" SELECT o.name AS name, u.name AS 'schema' @@ -740,10 +740,13 @@ class SybaseDialect(default.DefaultDialect): results.close() constrained_columns = [] - for i in range(1, pks["count"] + 1): - constrained_columns.append(pks["pk_%i" % (i,)]) - return {"constrained_columns": constrained_columns, - "name": pks["name"]} + if pks: + for i in range(1, pks["count"] + 1): + constrained_columns.append(pks["pk_%i" % (i,)]) + return {"constrained_columns": constrained_columns, + "name": pks["name"]} + else: + return {"constrained_columns": [], "name": None} @reflection.cache def get_schema_names(self, connection, **kw): diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index f1eacf6a6..0b0d50329 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -389,14 +389,33 @@ def create_engine(*args, **kwargs): def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs): """Create a new Engine instance using a configuration dictionary. - The dictionary is typically produced from a config file where keys - are prefixed, such as sqlalchemy.url, sqlalchemy.echo, etc. The - 'prefix' argument indicates the prefix to be searched for. + The dictionary is typically produced from a config file. + + The keys of interest to ``engine_from_config()`` should be prefixed, e.g. + ``sqlalchemy.url``, ``sqlalchemy.echo``, etc. The 'prefix' argument + indicates the prefix to be searched for. Each matching key (after the + prefix is stripped) is treated as though it were the corresponding keyword + argument to a :func:`.create_engine` call. + + The only required key is (assuming the default prefix) ``sqlalchemy.url``, + which provides the :ref:`database URL <database_urls>`. A select set of keyword arguments will be "coerced" to their expected type based on string values. The set of arguments is extensible per-dialect using the ``engine_config_types`` accessor. + :param configuration: A dictionary (typically produced from a config file, + but this is not a requirement). Items whose keys start with the value + of 'prefix' will have that prefix stripped, and will then be passed to + :ref:`create_engine`. + + :param prefix: Prefix to match and then strip from keys + in 'configuration'. + + :param kwargs: Each keyword argument to ``engine_from_config()`` itself + overrides the corresponding item taken from the 'configuration' + dictionary. Keyword arguments should *not* be prefixed. + """ options = dict((key[len(prefix):], configuration[key]) diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index c5eabac0d..eaa435d45 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1531,9 +1531,13 @@ class Transaction(object): def __init__(self, connection, parent): self.connection = connection - self._parent = parent or self + self._actual_parent = parent self.is_active = True + @property + def _parent(self): + return self._actual_parent or self + def close(self): """Close this :class:`.Transaction`. diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 73a8b4635..3bad765df 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -252,7 +252,9 @@ class Dialect(object): sequence a dictionary of the form - {'name' : str, 'start' :int, 'increment': int} + {'name' : str, 'start' :int, 'increment': int, 'minvalue': int, + 'maxvalue': int, 'nominvalue': bool, 'nomaxvalue': bool, + 'cycle': bool} Additional column attributes may be present. """ @@ -1147,4 +1149,4 @@ class ExceptionContext(object): .. versionadded:: 1.0.3 - """
\ No newline at end of file + """ diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index b2b78dee8..7d1425c28 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -221,7 +221,7 @@ class ResultMetaData(object): in enumerate(result_columns) ] self.keys = [ - elem[1] for elem in result_columns + elem[0] for elem in result_columns ] else: # case 2 - raw string, or number of columns in result does @@ -236,7 +236,8 @@ class ResultMetaData(object): # that SQLAlchemy has used up through 0.9. if num_ctx_cols: - result_map = self._create_result_map(result_columns) + result_map = self._create_result_map( + result_columns, case_sensitive) raw = [] self.keys = [] @@ -329,10 +330,12 @@ class ResultMetaData(object): ]) @classmethod - def _create_result_map(cls, result_columns): + def _create_result_map(cls, result_columns, case_sensitive=True): d = {} for elem in result_columns: key, rec = elem[0], elem[1:] + if not case_sensitive: + key = key.lower() if key in d: # conflicting keyname, just double up the list # of objects. this will cause an "ambiguous name" @@ -492,10 +495,20 @@ class ResultProxy(object): self._init_metadata() def _getter(self, key): - return self._metadata._getter(key) + try: + getter = self._metadata._getter + except AttributeError: + return self._non_result(None) + else: + return getter(key) def _has_key(self, key): - return self._metadata._has_key(key) + try: + has_key = self._metadata._has_key + except AttributeError: + return self._non_result(None) + else: + return has_key(key) def _init_metadata(self): metadata = self._cursor_description() @@ -699,7 +712,7 @@ class ResultProxy(object): while True: row = self.fetchone() if row is None: - raise StopIteration + return else: yield row diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py index a64c7d08d..8a88e40ef 100644 --- a/lib/sqlalchemy/event/attr.py +++ b/lib/sqlalchemy/event/attr.py @@ -51,7 +51,7 @@ class _ClsLevelDispatch(RefCollection): """Class-level events on :class:`._Dispatch` classes.""" __slots__ = ('name', 'arg_names', 'has_kw', - 'legacy_signatures', '_clslevel') + 'legacy_signatures', '_clslevel', '__weakref__') def __init__(self, parent_dispatch_cls, fn): self.name = fn.__name__ @@ -230,9 +230,7 @@ class _EmptyListener(_InstanceLevelDispatch): class _CompoundListener(_InstanceLevelDispatch): - _exec_once = False - - __slots__ = '_exec_once_mutex', + __slots__ = '_exec_once_mutex', '_exec_once' def _memoized_attr__exec_once_mutex(self): return threading.Lock() @@ -279,11 +277,14 @@ class _ListenerCollection(_CompoundListener): """ - __slots__ = 'parent_listeners', 'parent', 'name', 'listeners', 'propagate' + __slots__ = ( + 'parent_listeners', 'parent', 'name', 'listeners', + 'propagate', '__weakref__') def __init__(self, parent, target_cls): if target_cls not in parent._clslevel: parent.update_subclass(target_cls) + self._exec_once = False self.parent_listeners = parent._clslevel[target_cls] self.parent = parent self.name = parent.name @@ -339,11 +340,10 @@ class _ListenerCollection(_CompoundListener): class _JoinedListener(_CompoundListener): - _exec_once = False - __slots__ = 'parent', 'name', 'local', 'parent_listeners' def __init__(self, parent, name, local): + self._exec_once = False self.parent = parent self.name = name self.local = local diff --git a/lib/sqlalchemy/events.py b/lib/sqlalchemy/events.py index f439d554f..0249b2623 100644 --- a/lib/sqlalchemy/events.py +++ b/lib/sqlalchemy/events.py @@ -819,6 +819,11 @@ class ConnectionEvents(event.Events): .. seealso:: + :ref:`pool_disconnects_pessimistic` - illustrates how to use + :meth:`.ConnectionEvents.engine_connect` + to transparently ensure pooled connections are connected to the + database. + :meth:`.PoolEvents.checkout` the lower-level pool checkout event for an individual DBAPI connection diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index d837aab52..31f16287d 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -94,7 +94,7 @@ class AssociationProxy(interfaces.InspectionAttrInfo): def __init__(self, target_collection, attr, creator=None, getset_factory=None, proxy_factory=None, - proxy_bulk_set=None): + proxy_bulk_set=None, info=None): """Construct a new :class:`.AssociationProxy`. The :func:`.association_proxy` function is provided as the usual @@ -138,6 +138,11 @@ class AssociationProxy(interfaces.InspectionAttrInfo): :param proxy_bulk_set: Optional, use with proxy_factory. See the _set() method for details. + :param info: optional, will be assigned to + :attr:`.AssociationProxy.info` if present. + + .. versionadded:: 1.0.9 + """ self.target_collection = target_collection self.value_attr = attr @@ -150,6 +155,8 @@ class AssociationProxy(interfaces.InspectionAttrInfo): self.key = '_%s_%s_%s' % ( type(self).__name__, target_collection, id(self)) self.collection_class = None + if info: + self.info = info @property def remote_attr(self): @@ -596,7 +603,7 @@ class _AssociationList(_AssociationCollection): for member in self.col: yield self._get(member) - raise StopIteration + return def append(self, value): item = self._create(value) @@ -900,7 +907,7 @@ class _AssociationSet(_AssociationCollection): """ for member in self.col: yield self._get(member) - raise StopIteration + return def add(self, value): if value not in self: diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py index 330992e56..218ed64e1 100644 --- a/lib/sqlalchemy/ext/automap.py +++ b/lib/sqlalchemy/ext/automap.py @@ -111,7 +111,7 @@ explicit table declaration:: User, Address, Order = Base.classes.user, Base.classes.address,\ Base.classes.user_order -Specifying Classes Explcitly +Specifying Classes Explicitly ============================ The :mod:`.sqlalchemy.ext.automap` extension allows classes to be defined diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index f01e0b348..d255b5ee4 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -272,16 +272,35 @@ class Result(object): Equivalent to :meth:`.Query.one`. """ + try: + ret = self.one_or_none() + except orm_exc.MultipleResultsFound: + raise orm_exc.MultipleResultsFound( + "Multiple rows were found for one()") + else: + if ret is None: + raise orm_exc.NoResultFound("No row was found for one()") + return ret + + def one_or_none(self): + """Return one or zero results, or raise an exception for multiple + rows. + + Equivalent to :meth:`.Query.one_or_none`. + + .. versionadded:: 1.0.9 + + """ ret = list(self) l = len(ret) if l == 1: return ret[0] elif l == 0: - raise orm_exc.NoResultFound("No row was found for one()") + return None else: raise orm_exc.MultipleResultsFound( - "Multiple rows were found for one()") + "Multiple rows were found for one_or_none()") def all(self): """Return all rows. @@ -335,6 +354,12 @@ class Result(object): # (remember, we can map to an OUTER JOIN) bq = self.bq + # add the clause we got from mapper._get_clause to the cache + # key so that if a race causes multiple calls to _get_clause, + # we've cached on ours + bq = bq._clone() + bq._cache_key += (_get_clause, ) + bq = bq.with_criteria(setup, tuple(elem is None for elem in ident)) params = dict([ @@ -359,7 +384,6 @@ def bake_lazy_loaders(): Python overhead for these operations. """ - strategies.LazyLoader._strategy_keys[:] = [] BakedLazyLoader._strategy_keys[:] = [] properties.RelationshipProperty.strategy_for( @@ -369,6 +393,8 @@ def bake_lazy_loaders(): properties.RelationshipProperty.strategy_for( lazy="baked_select")(BakedLazyLoader) + strategies.LazyLoader._strategy_keys[:] = BakedLazyLoader._strategy_keys[:] + def unbake_lazy_loaders(): """Disable the use of baked queries for all lazyloaders systemwide. diff --git a/lib/sqlalchemy/ext/declarative/api.py b/lib/sqlalchemy/ext/declarative/api.py index 3d46bd4cb..dfc47ce95 100644 --- a/lib/sqlalchemy/ext/declarative/api.py +++ b/lib/sqlalchemy/ext/declarative/api.py @@ -7,7 +7,7 @@ """Public API functions and helpers for declarative.""" -from ...schema import Table, MetaData +from ...schema import Table, MetaData, Column from ...orm import synonym as _orm_synonym, \ comparable_property,\ interfaces, properties, attributes @@ -525,6 +525,17 @@ class AbstractConcreteBase(ConcreteBase): mappers.append(mn) pjoin = cls._create_polymorphic_union(mappers) + # For columns that were declared on the class, these + # are normally ignored with the "__no_table__" mapping, + # unless they have a different attribute key vs. col name + # and are in the properties argument. + # In that case, ensure we update the properties entry + # to the correct column from the pjoin target table. + declared_cols = set(to_map.declared_columns) + for k, v in list(to_map.properties.items()): + if v in declared_cols: + to_map.properties[k] = pjoin.c[v.key] + to_map.local_table = pjoin m_args = to_map.mapper_args_fn or dict diff --git a/lib/sqlalchemy/ext/declarative/base.py b/lib/sqlalchemy/ext/declarative/base.py index 57eb54f63..57305748c 100644 --- a/lib/sqlalchemy/ext/declarative/base.py +++ b/lib/sqlalchemy/ext/declarative/base.py @@ -463,7 +463,6 @@ class _MapperConfig(object): def _prepare_mapper_arguments(self): properties = self.properties - if self.mapper_args_fn: mapper_args = self.mapper_args_fn() else: diff --git a/lib/sqlalchemy/ext/declarative/clsregistry.py b/lib/sqlalchemy/ext/declarative/clsregistry.py index c3887d6cf..050923980 100644 --- a/lib/sqlalchemy/ext/declarative/clsregistry.py +++ b/lib/sqlalchemy/ext/declarative/clsregistry.py @@ -321,7 +321,8 @@ def _deferred_relationship(cls, prop): key, kwargs = prop.backref for attr in ('primaryjoin', 'secondaryjoin', 'secondary', 'foreign_keys', 'remote_side', 'order_by'): - if attr in kwargs and isinstance(kwargs[attr], str): + if attr in kwargs and isinstance(kwargs[attr], + util.string_types): kwargs[attr] = resolve_arg(kwargs[attr]) return prop diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index 9c6178264..0073494b8 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -46,7 +46,7 @@ as the class itself:: @hybrid_method def contains(self, point): - return (self.start <= point) & (point < self.end) + return (self.start <= point) & (point <= self.end) @hybrid_method def intersects(self, other): diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py index 501b18f39..88b653f60 100644 --- a/lib/sqlalchemy/ext/mutable.py +++ b/lib/sqlalchemy/ext/mutable.py @@ -658,6 +658,16 @@ class MutableDict(Mutable, dict): dict.update(self, *a, **kw) self.changed() + def pop(self, key): + result = dict.pop(self, key) + self.changed() + return result + + def popitem(self): + result = dict.popitem(self) + self.changed() + return result + def clear(self): dict.clear(self) self.changed() diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index e02a271e3..d9910a070 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -149,7 +149,12 @@ def backref(name, **kwargs): 'items':relationship( SomeItem, backref=backref('parent', lazy='subquery')) + .. seealso:: + + :ref:`relationships_backref` + """ + return (name, kwargs) diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index a45c22394..8605df785 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -551,6 +551,11 @@ class AttributeImpl(object): def initialize(self, state, dict_): """Initialize the given state's attribute with an empty value.""" + # As of 1.0, we don't actually set a value in + # dict_. This is so that the state of the object does not get + # modified without emitting the appropriate events. + + return None def get(self, state, dict_, passive=PASSIVE_OFF): @@ -848,7 +853,10 @@ class CollectionAttributeImpl(AttributeImpl): supports_population = True collection = True - __slots__ = 'copy', 'collection_factory', '_append_token', '_remove_token' + __slots__ = ( + 'copy', 'collection_factory', '_append_token', '_remove_token', + '_duck_typed_as' + ) def __init__(self, class_, key, callable_, dispatch, typecallable=None, trackparent=False, extension=None, @@ -868,6 +876,8 @@ class CollectionAttributeImpl(AttributeImpl): self.collection_factory = typecallable self._append_token = None self._remove_token = None + self._duck_typed_as = util.duck_type_collection( + self.collection_factory()) if getattr(self.collection_factory, "_sa_linker", None): @@ -1011,38 +1021,46 @@ class CollectionAttributeImpl(AttributeImpl): except (ValueError, KeyError, IndexError): pass - def set(self, state, dict_, value, initiator, - passive=PASSIVE_OFF, pop=False): - """Set a value on the given object. - - """ - - self._set_iterable( - state, dict_, value, - lambda adapter, i: adapter.adapt_like_to_iterable(i)) - - def _set_iterable(self, state, dict_, iterable, adapter=None): - """Set a collection value from an iterable of state-bearers. + def set(self, state, dict_, value, initiator=None, + passive=PASSIVE_OFF, pop=False, _adapt=True): + iterable = orig_iterable = value - ``adapter`` is an optional callable invoked with a CollectionAdapter - and the iterable. Should return an iterable of state-bearing - instances suitable for appending via a CollectionAdapter. Can be used - for, e.g., adapting an incoming dictionary into an iterator of values - rather than keys. - - """ # pulling a new collection first so that an adaptation exception does # not trigger a lazy load of the old collection. new_collection, user_data = self._initialize_collection(state) - if adapter: - new_values = list(adapter(new_collection, iterable)) - else: - new_values = list(iterable) + if _adapt: + if new_collection._converter is not None: + iterable = new_collection._converter(iterable) + else: + setting_type = util.duck_type_collection(iterable) + receiving_type = self._duck_typed_as + + if setting_type is not receiving_type: + given = iterable is None and 'None' or \ + iterable.__class__.__name__ + wanted = self._duck_typed_as.__name__ + raise TypeError( + "Incompatible collection type: %s is not %s-like" % ( + given, wanted)) + + # If the object is an adapted collection, return the (iterable) + # adapter. + if hasattr(iterable, '_sa_iterator'): + iterable = iterable._sa_iterator() + elif setting_type is dict: + if util.py3k: + iterable = iterable.values() + else: + iterable = getattr( + iterable, 'itervalues', iterable.values)() + else: + iterable = iter(iterable) + new_values = list(iterable) old = self.get(state, dict_, passive=PASSIVE_ONLY_PERSISTENT) if old is PASSIVE_NO_RESULT: old = self.initialize(state, dict_) - elif old is iterable: + elif old is orig_iterable: # ignore re-assignment of the current collection, as happens # implicitly with in-place operators (foo.collection |= other) return @@ -1054,7 +1072,8 @@ class CollectionAttributeImpl(AttributeImpl): dict_[self.key] = user_data - collections.bulk_replace(new_values, old_collection, new_collection) + collections.bulk_replace( + new_values, old_collection, new_collection) del old._sa_adapter self.dispatch.dispose_collection(state, old, old_collection) diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 4f988a8d4..58a69227c 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -111,6 +111,7 @@ from ..sql import expression from .. import util, exc as sa_exc from . import base +from sqlalchemy.util.compat import inspect_getargspec __all__ = ['collection', 'collection_adapter', 'mapped_collection', 'column_mapped_collection', @@ -573,13 +574,18 @@ class CollectionAdapter(object): """ - invalidated = False + + __slots__ = ( + 'attr', '_key', '_data', 'owner_state', '_converter', 'invalidated') def __init__(self, attr, owner_state, data): + self.attr = attr self._key = attr.key self._data = weakref.ref(data) self.owner_state = owner_state data._sa_adapter = self + self._converter = data._sa_converter + self.invalidated = False def _warn_invalidated(self): util.warn("This collection has been invalidated.") @@ -599,53 +605,8 @@ class CollectionAdapter(object): """ return self.owner_state.dict[self._key] is self._data() - @util.memoized_property - def attr(self): - return self.owner_state.manager[self._key].impl - - def adapt_like_to_iterable(self, obj): - """Converts collection-compatible objects to an iterable of values. - - Can be passed any type of object, and if the underlying collection - determines that it can be adapted into a stream of values it can - use, returns an iterable of values suitable for append()ing. - - This method may raise TypeError or any other suitable exception - if adaptation fails. - - If a converter implementation is not supplied on the collection, - a default duck-typing-based implementation is used. - - """ - converter = self._data()._sa_converter - if converter is not None: - return converter(obj) - - setting_type = util.duck_type_collection(obj) - receiving_type = util.duck_type_collection(self._data()) - - if obj is None or setting_type != receiving_type: - given = obj is None and 'None' or obj.__class__.__name__ - if receiving_type is None: - wanted = self._data().__class__.__name__ - else: - wanted = receiving_type.__name__ - - raise TypeError( - "Incompatible collection type: %s is not %s-like" % ( - given, wanted)) - - # If the object is an adapted collection, return the (iterable) - # adapter. - if getattr(obj, '_sa_adapter', None) is not None: - return obj._sa_adapter - elif setting_type == dict: - if util.py3k: - return obj.values() - else: - return getattr(obj, 'itervalues', obj.values)() - else: - return iter(obj) + def bulk_appender(self): + return self._data()._sa_appender def append_with_event(self, item, initiator=None): """Add an entity to the collection, firing mutation events.""" @@ -662,6 +623,9 @@ class CollectionAdapter(object): for item in items: appender(item, _sa_initiator=False) + def bulk_remover(self): + return self._data()._sa_remover + def remove_with_event(self, item, initiator=None): """Remove an entity from the collection, firing mutation events.""" self._data()._sa_remover(item, _sa_initiator=initiator) @@ -776,8 +740,8 @@ def bulk_replace(values, existing_adapter, new_adapter): """ - if not isinstance(values, list): - values = list(values) + + assert isinstance(values, list) idset = util.IdentitySet existing_idset = idset(existing_adapter or ()) @@ -785,15 +749,18 @@ def bulk_replace(values, existing_adapter, new_adapter): additions = idset(values or ()).difference(constants) removals = existing_idset.difference(constants) + appender = new_adapter.bulk_appender() + for member in values or (): if member in additions: - new_adapter.append_with_event(member) + appender(member) elif member in constants: - new_adapter.append_without_event(member) + appender(member, _sa_initiator=False) if existing_adapter: + remover = existing_adapter.bulk_remover() for member in removals: - existing_adapter.remove_with_event(member) + remover(member) def prepare_instrumentation(factory): @@ -982,7 +949,7 @@ def _instrument_membership_mutator(method, before, argument, after): adapter.""" # This isn't smart enough to handle @adds(1) for 'def fn(self, (a, b))' if before: - fn_args = list(util.flatten_iterator(inspect.getargspec(method)[0])) + fn_args = list(util.flatten_iterator(inspect_getargspec(method)[0])) if isinstance(argument, int): pos_arg = argument named_arg = len(fn_args) > argument and fn_args[argument] or None diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index d8989939b..f3325203e 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -303,9 +303,9 @@ class DependencyProcessor(object): set ) - def _post_update(self, state, uowcommit, related): + def _post_update(self, state, uowcommit, related, is_m2o_delete=False): for x in related: - if x is not None: + if not is_m2o_delete or x is not None: uowcommit.issue_post_update( state, [r for l, r in self.prop.synchronize_pairs] @@ -740,7 +740,9 @@ class ManyToOneDP(DependencyProcessor): self.key, self._passive_delete_flag) if history: - self._post_update(state, uowcommit, history.sum()) + self._post_update( + state, uowcommit, history.sum(), + is_m2o_delete=True) def process_saves(self, uowcommit, states): for state in states: diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index aedd863f8..ca593765f 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -128,17 +128,16 @@ class DynamicAttributeImpl(attributes.AttributeImpl): dict_[self.key] = True return state.committed_state[self.key] - def set(self, state, dict_, value, initiator, + def set(self, state, dict_, value, initiator=None, passive=attributes.PASSIVE_OFF, - check_old=None, pop=False): + check_old=None, pop=False, _adapt=True): if initiator and initiator.parent_token is self.parent_token: return if pop and value is None: return - self._set_iterable(state, dict_, value) - def _set_iterable(self, state, dict_, iterable, adapter=None): + iterable = value new_values = list(iterable) if state.has_identity: old_collection = util.IdentitySet(self.get(state, dict_)) diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index 801701be9..5b0cbfdad 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -18,6 +18,7 @@ from .session import Session, sessionmaker from .scoping import scoped_session from .attributes import QueryableAttribute from .query import Query +from sqlalchemy.util.compat import inspect_getargspec class InstrumentationEvents(event.Events): """Events related to class instrumentation events. @@ -216,14 +217,41 @@ class InstanceEvents(event.Events): def first_init(self, manager, cls): """Called when the first instance of a particular mapping is called. + This event is called when the ``__init__`` method of a class + is called the first time for that particular class. The event + invokes before ``__init__`` actually proceeds as well as before + the :meth:`.InstanceEvents.init` event is invoked. + """ def init(self, target, args, kwargs): """Receive an instance when its constructor is called. This method is only called during a userland construction of - an object. It is not called when an object is loaded from the - database. + an object, in conjunction with the object's constructor, e.g. + its ``__init__`` method. It is not called when an object is + loaded from the database; see the :meth:`.InstanceEvents.load` + event in order to intercept a database load. + + The event is called before the actual ``__init__`` constructor + of the object is called. The ``kwargs`` dictionary may be + modified in-place in order to affect what is passed to + ``__init__``. + + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param args: positional arguments passed to the ``__init__`` method. + This is passed as a tuple and is currently immutable. + :param kwargs: keyword arguments passed to the ``__init__`` method. + This structure *can* be altered in place. + + .. seealso:: + + :meth:`.InstanceEvents.init_failure` + + :meth:`.InstanceEvents.load` """ @@ -232,8 +260,31 @@ class InstanceEvents(event.Events): and raised an exception. This method is only called during a userland construction of - an object. It is not called when an object is loaded from the - database. + an object, in conjunction with the object's constructor, e.g. + its ``__init__`` method. It is not called when an object is loaded + from the database. + + The event is invoked after an exception raised by the ``__init__`` + method is caught. After the event + is invoked, the original exception is re-raised outwards, so that + the construction of the object still raises an exception. The + actual exception and stack trace raised should be present in + ``sys.exc_info()``. + + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param args: positional arguments that were passed to the ``__init__`` + method. + :param kwargs: keyword arguments that were passed to the ``__init__`` + method. + + .. seealso:: + + :meth:`.InstanceEvents.init` + + :meth:`.InstanceEvents.load` """ @@ -260,12 +311,23 @@ class InstanceEvents(event.Events): ``None`` if the load does not correspond to a :class:`.Query`, such as during :meth:`.Session.merge`. + .. seealso:: + + :meth:`.InstanceEvents.init` + + :meth:`.InstanceEvents.refresh` + + :meth:`.SessionEvents.loaded_as_persistent` + """ def refresh(self, target, context, attrs): """Receive an object instance after one or more attributes have been refreshed from a query. + Contrast this to the :meth:`.InstanceEvents.load` method, which + is invoked when the object is first loaded from a query. + :param target: the mapped instance. If the event is configured with ``raw=True``, this will instead be the :class:`.InstanceState` state-management @@ -276,6 +338,10 @@ class InstanceEvents(event.Events): were populated, or None if all column-mapped, non-deferred attributes were populated. + .. seealso:: + + :meth:`.InstanceEvents.load` + """ def refresh_flush(self, target, flush_context, attrs): @@ -538,7 +604,7 @@ class MapperEvents(event.Events): meth = getattr(cls, identifier) try: target_index = \ - inspect.getargspec(meth)[0].index('target') - 1 + inspect_getargspec(meth)[0].index('target') - 1 except ValueError: target_index = None @@ -589,32 +655,67 @@ class MapperEvents(event.Events): """ def mapper_configured(self, mapper, class_): - """Called when the mapper for the class is fully configured. - - This event is the latest phase of mapper construction, and - is invoked when the mapped classes are first used, so that - relationships between mappers can be resolved. When the event is - called, the mapper should be in its final state. - - While the configuration event normally occurs automatically, - it can be forced to occur ahead of time, in the case where the event - is needed before any actual mapper usage, by using the - :func:`.configure_mappers` function. + """Called when a specific mapper has completed its own configuration + within the scope of the :func:`.configure_mappers` call. + + The :meth:`.MapperEvents.mapper_configured` event is invoked + for each mapper that is encountered when the + :func:`.orm.configure_mappers` function proceeds through the current + list of not-yet-configured mappers. + :func:`.orm.configure_mappers` is typically invoked + automatically as mappings are first used, as well as each time + new mappers have been made available and new mapper use is + detected. + + When the event is called, the mapper should be in its final + state, but **not including backrefs** that may be invoked from + other mappers; they might still be pending within the + configuration operation. Bidirectional relationships that + are instead configured via the + :paramref:`.orm.relationship.back_populates` argument + *will* be fully available, since this style of relationship does not + rely upon other possibly-not-configured mappers to know that they + exist. + For an event that is guaranteed to have **all** mappers ready + to go including backrefs that are defined only on other + mappings, use the :meth:`.MapperEvents.after_configured` + event; this event invokes only after all known mappings have been + fully configured. + + The :meth:`.MapperEvents.mapper_configured` event, unlike + :meth:`.MapperEvents.before_configured` or + :meth:`.MapperEvents.after_configured`, + is called for each mapper/class individually, and the mapper is + passed to the event itself. It also is called exactly once for + a particular mapper. The event is therefore useful for + configurational steps that benefit from being invoked just once + on a specific mapper basis, which don't require that "backref" + configurations are necessarily ready yet. :param mapper: the :class:`.Mapper` which is the target of this event. :param class\_: the mapped class. + .. seealso:: + + :meth:`.MapperEvents.before_configured` + + :meth:`.MapperEvents.after_configured` + """ # TODO: need coverage for this event def before_configured(self): """Called before a series of mappers have been configured. - This corresponds to the :func:`.orm.configure_mappers` call, which - note is usually called automatically as mappings are first - used. + The :meth:`.MapperEvents.before_configured` event is invoked + each time the :func:`.orm.configure_mappers` function is + invoked, before the function has done any of its work. + :func:`.orm.configure_mappers` is typically invoked + automatically as mappings are first used, as well as each time + new mappers have been made available and new mapper use is + detected. This event can **only** be applied to the :class:`.Mapper` class or :func:`.mapper` function, and not to individual mappings or @@ -626,11 +727,16 @@ class MapperEvents(event.Events): def go(): # ... + Constrast this event to :meth:`.MapperEvents.after_configured`, + which is invoked after the series of mappers has been configured, + as well as :meth:`.MapperEvents.mapper_configured`, which is invoked + on a per-mapper basis as each one is configured to the extent possible. + Theoretically this event is called once per application, but is actually called any time new mappers are to be affected by a :func:`.orm.configure_mappers` call. If new mappings are constructed after existing ones have - already been used, this event can be called again. To ensure + already been used, this event will likely be called again. To ensure that a particular event is only called once and no further, the ``once=True`` argument (new in 0.9.4) can be applied:: @@ -643,14 +749,33 @@ class MapperEvents(event.Events): .. versionadded:: 0.9.3 + + .. seealso:: + + :meth:`.MapperEvents.mapper_configured` + + :meth:`.MapperEvents.after_configured` + """ def after_configured(self): """Called after a series of mappers have been configured. - This corresponds to the :func:`.orm.configure_mappers` call, which - note is usually called automatically as mappings are first - used. + The :meth:`.MapperEvents.after_configured` event is invoked + each time the :func:`.orm.configure_mappers` function is + invoked, after the function has completed its work. + :func:`.orm.configure_mappers` is typically invoked + automatically as mappings are first used, as well as each time + new mappers have been made available and new mapper use is + detected. + + Contrast this event to the :meth:`.MapperEvents.mapper_configured` + event, which is called on a per-mapper basis while the configuration + operation proceeds; unlike that event, when this event is invoked, + all cross-configurations (e.g. backrefs) will also have been made + available for any mappers that were pending. + Also constrast to :meth:`.MapperEvents.before_configured`, + which is invoked before the series of mappers has been configured. This event can **only** be applied to the :class:`.Mapper` class or :func:`.mapper` function, and not to individual mappings or @@ -666,7 +791,7 @@ class MapperEvents(event.Events): application, but is actually called any time new mappers have been affected by a :func:`.orm.configure_mappers` call. If new mappings are constructed after existing ones have - already been used, this event can be called again. To ensure + already been used, this event will likely be called again. To ensure that a particular event is only called once and no further, the ``once=True`` argument (new in 0.9.4) can be applied:: @@ -676,6 +801,12 @@ class MapperEvents(event.Events): def go(): # ... + .. seealso:: + + :meth:`.MapperEvents.mapper_configured` + + :meth:`.MapperEvents.before_configured` + """ def before_insert(self, mapper, connection, target): @@ -697,30 +828,14 @@ class MapperEvents(event.Events): steps. .. warning:: - Mapper-level flush events are designed to operate **on attributes - local to the immediate object being handled - and via SQL operations with the given** - :class:`.Connection` **only.** Handlers here should **not** make - alterations to the state of the :class:`.Session` overall, and - in general should not affect any :func:`.relationship` -mapped - attributes, as session cascade rules will not function properly, - nor is it always known if the related class has already been - handled. Operations that **are not supported in mapper - events** include: - - * :meth:`.Session.add` - * :meth:`.Session.delete` - * Mapped collection append, add, remove, delete, discard, etc. - * Mapped relationship attribute set/del events, - i.e. ``someobject.related = someotherobject`` - - Operations which manipulate the state of the object - relative to other objects are better handled: - - * In the ``__init__()`` method of the mapped object itself, or - another method designed to establish some particular state. - * In a ``@validates`` handler, see :ref:`simple_validators` - * Within the :meth:`.SessionEvents.before_flush` event. + + Mapper-level flush events only allow **very limited operations**, + on attributes local to the row being operated upon only, + as well as allowing any SQL to be emitted on the given + :class:`.Connection`. **Please read fully** the notes + at :ref:`session_persistence_mapper` for guidelines on using + these methods; generally, the :meth:`.SessionEvents.before_flush` + method should be preferred for general on-flush changes. :param mapper: the :class:`.Mapper` which is the target of this event. @@ -734,6 +849,10 @@ class MapperEvents(event.Events): object associated with the instance. :return: No return value is supported by this event. + .. seealso:: + + :ref:`session_persistence_events` + """ def after_insert(self, mapper, connection, target): @@ -755,30 +874,14 @@ class MapperEvents(event.Events): event->persist->event steps. .. warning:: - Mapper-level flush events are designed to operate **on attributes - local to the immediate object being handled - and via SQL operations with the given** - :class:`.Connection` **only.** Handlers here should **not** make - alterations to the state of the :class:`.Session` overall, and in - general should not affect any :func:`.relationship` -mapped - attributes, as session cascade rules will not function properly, - nor is it always known if the related class has already been - handled. Operations that **are not supported in mapper - events** include: - - * :meth:`.Session.add` - * :meth:`.Session.delete` - * Mapped collection append, add, remove, delete, discard, etc. - * Mapped relationship attribute set/del events, - i.e. ``someobject.related = someotherobject`` - - Operations which manipulate the state of the object - relative to other objects are better handled: - - * In the ``__init__()`` method of the mapped object itself, - or another method designed to establish some particular state. - * In a ``@validates`` handler, see :ref:`simple_validators` - * Within the :meth:`.SessionEvents.before_flush` event. + + Mapper-level flush events only allow **very limited operations**, + on attributes local to the row being operated upon only, + as well as allowing any SQL to be emitted on the given + :class:`.Connection`. **Please read fully** the notes + at :ref:`session_persistence_mapper` for guidelines on using + these methods; generally, the :meth:`.SessionEvents.before_flush` + method should be preferred for general on-flush changes. :param mapper: the :class:`.Mapper` which is the target of this event. @@ -792,6 +895,10 @@ class MapperEvents(event.Events): object associated with the instance. :return: No return value is supported by this event. + .. seealso:: + + :ref:`session_persistence_events` + """ def before_update(self, mapper, connection, target): @@ -832,29 +939,14 @@ class MapperEvents(event.Events): steps. .. warning:: - Mapper-level flush events are designed to operate **on attributes - local to the immediate object being handled - and via SQL operations with the given** :class:`.Connection` - **only.** Handlers here should **not** make alterations to the - state of the :class:`.Session` overall, and in general should not - affect any :func:`.relationship` -mapped attributes, as - session cascade rules will not function properly, nor is it - always known if the related class has already been handled. - Operations that **are not supported in mapper events** include: - - * :meth:`.Session.add` - * :meth:`.Session.delete` - * Mapped collection append, add, remove, delete, discard, etc. - * Mapped relationship attribute set/del events, - i.e. ``someobject.related = someotherobject`` - - Operations which manipulate the state of the object - relative to other objects are better handled: - - * In the ``__init__()`` method of the mapped object itself, - or another method designed to establish some particular state. - * In a ``@validates`` handler, see :ref:`simple_validators` - * Within the :meth:`.SessionEvents.before_flush` event. + + Mapper-level flush events only allow **very limited operations**, + on attributes local to the row being operated upon only, + as well as allowing any SQL to be emitted on the given + :class:`.Connection`. **Please read fully** the notes + at :ref:`session_persistence_mapper` for guidelines on using + these methods; generally, the :meth:`.SessionEvents.before_flush` + method should be preferred for general on-flush changes. :param mapper: the :class:`.Mapper` which is the target of this event. @@ -867,6 +959,11 @@ class MapperEvents(event.Events): instead be the :class:`.InstanceState` state-management object associated with the instance. :return: No return value is supported by this event. + + .. seealso:: + + :ref:`session_persistence_events` + """ def after_update(self, mapper, connection, target): @@ -906,29 +1003,14 @@ class MapperEvents(event.Events): steps. .. warning:: - Mapper-level flush events are designed to operate **on attributes - local to the immediate object being handled - and via SQL operations with the given** :class:`.Connection` - **only.** Handlers here should **not** make alterations to the - state of the :class:`.Session` overall, and in general should not - affect any :func:`.relationship` -mapped attributes, as - session cascade rules will not function properly, nor is it - always known if the related class has already been handled. - Operations that **are not supported in mapper events** include: - - * :meth:`.Session.add` - * :meth:`.Session.delete` - * Mapped collection append, add, remove, delete, discard, etc. - * Mapped relationship attribute set/del events, - i.e. ``someobject.related = someotherobject`` - - Operations which manipulate the state of the object - relative to other objects are better handled: - - * In the ``__init__()`` method of the mapped object itself, - or another method designed to establish some particular state. - * In a ``@validates`` handler, see :ref:`simple_validators` - * Within the :meth:`.SessionEvents.before_flush` event. + + Mapper-level flush events only allow **very limited operations**, + on attributes local to the row being operated upon only, + as well as allowing any SQL to be emitted on the given + :class:`.Connection`. **Please read fully** the notes + at :ref:`session_persistence_mapper` for guidelines on using + these methods; generally, the :meth:`.SessionEvents.before_flush` + method should be preferred for general on-flush changes. :param mapper: the :class:`.Mapper` which is the target of this event. @@ -942,6 +1024,10 @@ class MapperEvents(event.Events): object associated with the instance. :return: No return value is supported by this event. + .. seealso:: + + :ref:`session_persistence_events` + """ def before_delete(self, mapper, connection, target): @@ -957,29 +1043,14 @@ class MapperEvents(event.Events): once in a later step. .. warning:: - Mapper-level flush events are designed to operate **on attributes - local to the immediate object being handled - and via SQL operations with the given** :class:`.Connection` - **only.** Handlers here should **not** make alterations to the - state of the :class:`.Session` overall, and in general should not - affect any :func:`.relationship` -mapped attributes, as - session cascade rules will not function properly, nor is it - always known if the related class has already been handled. - Operations that **are not supported in mapper events** include: - - * :meth:`.Session.add` - * :meth:`.Session.delete` - * Mapped collection append, add, remove, delete, discard, etc. - * Mapped relationship attribute set/del events, - i.e. ``someobject.related = someotherobject`` - - Operations which manipulate the state of the object - relative to other objects are better handled: - - * In the ``__init__()`` method of the mapped object itself, - or another method designed to establish some particular state. - * In a ``@validates`` handler, see :ref:`simple_validators` - * Within the :meth:`.SessionEvents.before_flush` event. + + Mapper-level flush events only allow **very limited operations**, + on attributes local to the row being operated upon only, + as well as allowing any SQL to be emitted on the given + :class:`.Connection`. **Please read fully** the notes + at :ref:`session_persistence_mapper` for guidelines on using + these methods; generally, the :meth:`.SessionEvents.before_flush` + method should be preferred for general on-flush changes. :param mapper: the :class:`.Mapper` which is the target of this event. @@ -993,6 +1064,10 @@ class MapperEvents(event.Events): object associated with the instance. :return: No return value is supported by this event. + .. seealso:: + + :ref:`session_persistence_events` + """ def after_delete(self, mapper, connection, target): @@ -1008,29 +1083,14 @@ class MapperEvents(event.Events): once in a previous step. .. warning:: - Mapper-level flush events are designed to operate **on attributes - local to the immediate object being handled - and via SQL operations with the given** :class:`.Connection` - **only.** Handlers here should **not** make alterations to the - state of the :class:`.Session` overall, and in general should not - affect any :func:`.relationship` -mapped attributes, as - session cascade rules will not function properly, nor is it - always known if the related class has already been handled. - Operations that **are not supported in mapper events** include: - - * :meth:`.Session.add` - * :meth:`.Session.delete` - * Mapped collection append, add, remove, delete, discard, etc. - * Mapped relationship attribute set/del events, - i.e. ``someobject.related = someotherobject`` - - Operations which manipulate the state of the object - relative to other objects are better handled: - - * In the ``__init__()`` method of the mapped object itself, - or another method designed to establish some particular state. - * In a ``@validates`` handler, see :ref:`simple_validators` - * Within the :meth:`.SessionEvents.before_flush` event. + + Mapper-level flush events only allow **very limited operations**, + on attributes local to the row being operated upon only, + as well as allowing any SQL to be emitted on the given + :class:`.Connection`. **Please read fully** the notes + at :ref:`session_persistence_mapper` for guidelines on using + these methods; generally, the :meth:`.SessionEvents.before_flush` + method should be preferred for general on-flush changes. :param mapper: the :class:`.Mapper` which is the target of this event. @@ -1044,6 +1104,10 @@ class MapperEvents(event.Events): object associated with the instance. :return: No return value is supported by this event. + .. seealso:: + + :ref:`session_persistence_events` + """ @@ -1284,6 +1348,8 @@ class SessionEvents(event.Events): :meth:`~.SessionEvents.after_flush_postexec` + :ref:`session_persistence_events` + """ def after_flush(self, session, flush_context): @@ -1304,6 +1370,8 @@ class SessionEvents(event.Events): :meth:`~.SessionEvents.after_flush_postexec` + :ref:`session_persistence_events` + """ def after_flush_postexec(self, session, flush_context): @@ -1326,6 +1394,8 @@ class SessionEvents(event.Events): :meth:`~.SessionEvents.after_flush` + :ref:`session_persistence_events` + """ def after_begin(self, session, transaction, connection): @@ -1363,6 +1433,8 @@ class SessionEvents(event.Events): :meth:`~.SessionEvents.after_attach` + :ref:`session_lifecycle_events` + """ def after_attach(self, session, instance): @@ -1385,6 +1457,8 @@ class SessionEvents(event.Events): :meth:`~.SessionEvents.before_attach` + :ref:`session_lifecycle_events` + """ @event._legacy_signature("0.9", @@ -1439,6 +1513,244 @@ class SessionEvents(event.Events): """ + def transient_to_pending(self, session, instance): + """Intercept the "transient to pending" transition for a specific object. + + This event is a specialization of the + :meth:`.SessionEvents.after_attach` event which is only invoked + for this specific transition. It is invoked typically during the + :meth:`.Session.add` call. + + :param session: target :class:`.Session` + + :param instance: the ORM-mapped instance being operated upon. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + def pending_to_transient(self, session, instance): + """Intercept the "pending to transient" transition for a specific object. + + This less common transition occurs when an pending object that has + not been flushed is evicted from the session; this can occur + when the :meth:`.Session.rollback` method rolls back the transaction, + or when the :meth:`.Session.expunge` method is used. + + :param session: target :class:`.Session` + + :param instance: the ORM-mapped instance being operated upon. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + def persistent_to_transient(self, session, instance): + """Intercept the "persistent to transient" transition for a specific object. + + This less common transition occurs when an pending object that has + has been flushed is evicted from the session; this can occur + when the :meth:`.Session.rollback` method rolls back the transaction. + + :param session: target :class:`.Session` + + :param instance: the ORM-mapped instance being operated upon. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + def pending_to_persistent(self, session, instance): + """Intercept the "pending to persistent"" transition for a specific object. + + This event is invoked within the flush process, and is + similar to scanning the :attr:`.Session.new` collection within + the :meth:`.SessionEvents.after_flush` event. However, in this + case the object has already been moved to the persistent state + when the event is called. + + :param session: target :class:`.Session` + + :param instance: the ORM-mapped instance being operated upon. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + def detached_to_persistent(self, session, instance): + """Intercept the "detached to persistent" transition for a specific object. + + This event is a specialization of the + :meth:`.SessionEvents.after_attach` event which is only invoked + for this specific transition. It is invoked typically during the + :meth:`.Session.add` call, as well as during the + :meth:`.Session.delete` call if the object was not previously + associated with the + :class:`.Session` (note that an object marked as "deleted" remains + in the "persistent" state until the flush proceeds). + + .. note:: + + If the object becomes persistent as part of a call to + :meth:`.Session.delete`, the object is **not** yet marked as + deleted when this event is called. To detect deleted objects, + check the ``deleted`` flag sent to the + :meth:`.SessionEvents.persistent_to_detached` to event after the + flush proceeds, or check the :attr:`.Session.deleted` collection + within the :meth:`.SessionEvents.before_flush` event if deleted + objects need to be intercepted before the flush. + + :param session: target :class:`.Session` + + :param instance: the ORM-mapped instance being operated upon. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + def loaded_as_persistent(self, session, instance): + """Intercept the "loaded as peristent" transition for a specific object. + + This event is invoked within the ORM loading process, and is invoked + very similarly to the :meth:`.InstanceEvents.load` event. However, + the event here is linkable to a :class:`.Session` class or instance, + rather than to a mapper or class hierarchy, and integrates + with the other session lifecycle events smoothly. The object + is guaranteed to be present in the session's identity map when + this event is called. + + + :param session: target :class:`.Session` + + :param instance: the ORM-mapped instance being operated upon. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + def persistent_to_deleted(self, session, instance): + """Intercept the "persistent to deleted" transition for a specific object. + + This event is invoked when a persistent object's identity + is deleted from the database within a flush, however the object + still remains associated with the :class:`.Session` until the + transaction completes. + + If the transaction is rolled back, the object moves again + to the persistent state, and the + :meth:`.SessionEvents.deleted_to_persistent` event is called. + If the transaction is committed, the object becomes detached, + which will emit the :meth:`.SessionEvents.deleted_to_detached` + event. + + Note that while the :meth:`.Session.delete` method is the primary + public interface to mark an object as deleted, many objects + get deleted due to cascade rules, which are not always determined + until flush time. Therefore, there's no way to catch + every object that will be deleted until the flush has proceeded. + the :meth:`.SessionEvents.persistent_to_deleted` event is therefore + invoked at the end of a flush. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + def deleted_to_persistent(self, session, instance): + """Intercept the "deleted to persistent" transition for a specific object. + + This transition occurs only when an object that's been deleted + successfully in a flush is restored due to a call to + :meth:`.Session.rollback`. The event is not called under + any other circumstances. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + def deleted_to_detached(self, session, instance): + """Intercept the "deleted to detached" transition for a specific object. + + This event is invoked when a deleted object is evicted + from the session. The typical case when this occurs is when + the transaction for a :class:`.Session` in which the object + was deleted is committed; the object moves from the deleted + state to the detached state. + + It is also invoked for objects that were deleted in a flush + when the :meth:`.Session.expunge_all` or :meth:`.Session.close` + events are called, as well as if the object is individually + expunged from its deleted state via :meth:`.Session.expunge`. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + def persistent_to_detached(self, session, instance): + """Intercept the "persistent to detached" transition for a specific object. + + This event is invoked when a persistent object is evicted + from the session. There are many conditions that cause this + to happen, including: + + * using a method such as :meth:`.Session.expunge` + or :meth:`.Session.close` + + * Calling the :meth:`.Session.rollback` method, when the object + was part of an INSERT statement for that session's transaction + + + :param session: target :class:`.Session` + + :param instance: the ORM-mapped instance being operated upon. + + :param deleted: boolean. If True, indicates this object moved + to the detached state because it was marked as deleted and flushed. + + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + class AttributeEvents(event.Events): """Define events for object attributes. @@ -1638,7 +1950,7 @@ class AttributeEvents(event.Events): and also during replace operations:: - u1.addresess = [a2, a3] # <- new collection + u1.addresses = [a2, a3] # <- new collection :param target: the object instance receiving the event. If the listener is registered with ``raw=True``, this will @@ -1701,7 +2013,7 @@ class QueryEvents(event.Events): def no_deleted(query): for desc in query.column_descriptions: if desc['type'] is User: - entity = desc['expr'] + entity = desc['entity'] query = query.filter(entity.deleted == False) return query diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py index 46be2b719..2dfe3fd5c 100644 --- a/lib/sqlalchemy/orm/identity.py +++ b/lib/sqlalchemy/orm/identity.py @@ -8,7 +8,8 @@ import weakref from . import attributes from .. import util - +from .. import exc as sa_exc +from . import util as orm_util class IdentityMap(object): def __init__(self): @@ -126,16 +127,18 @@ class WeakInstanceDict(IdentityMap): if existing_state is not state: o = existing_state.obj() if o is not None: - raise AssertionError( - "A conflicting state is already " - "present in the identity map for key %r" - % (key, )) + raise sa_exc.InvalidRequestError( + "Can't attach instance " + "%s; another instance with key %s is already " + "present in this session." % ( + orm_util.state_str(state), state.key)) else: - return + return False except KeyError: pass self._dict[key] = state self._manage_incoming_state(state) + return True def _add_unpresent(self, state, key): # inlined form of add() called by loading.py @@ -208,6 +211,18 @@ class WeakInstanceDict(IdentityMap): class StrongInstanceDict(IdentityMap): + """A 'strong-referencing' version of the identity map. + + .. deprecated 1.1:: + The strong + reference identity map is legacy. See the + recipe at :ref:`session_referencing_behavior` for + an event-based approach to maintaining strong identity + references. + + + """ + if util.py2k: def itervalues(self): return self._dict.itervalues() @@ -256,12 +271,16 @@ class StrongInstanceDict(IdentityMap): def add(self, state): if state.key in self: if attributes.instance_state(self._dict[state.key]) is not state: - raise AssertionError('A conflicting state is already ' - 'present in the identity map for key %r' - % (state.key, )) + raise sa_exc.InvalidRequestError( + "Can't attach instance " + "%s; another instance with key %s is already " + "present in this session." % ( + orm_util.state_str(state), state.key)) + return False else: self._dict[state.key] = state.obj() self._manage_incoming_state(state) + return True def _add_unpresent(self, state, key): # inlined form of add() called by loading.py diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index cd4a0116d..ed8f27332 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -234,7 +234,7 @@ class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots): """ def merge(self, session, source_state, source_dict, dest_state, - dest_dict, load, _recursive): + dest_dict, load, _recursive, _resolve_conflict_map): """Merge the attribute represented by this ``MapperProperty`` from source to destination object. diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index b81e98a58..b5a62d6b2 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -32,8 +32,7 @@ def instances(query, cursor, context): context.runid = _new_runid() - filter_fns = [ent.filter_fn for ent in query._entities] - filtered = id in filter_fns + filtered = query._has_mapper_entities single_entity = len(query._entities) == 1 and \ query._entities[0].supports_single_entity @@ -43,7 +42,12 @@ def instances(query, cursor, context): filter_fn = id else: def filter_fn(row): - return tuple(fn(x) for x, fn in zip(row, filter_fns)) + return tuple( + id(item) + if ent.use_id_for_hash + else item + for ent, item in zip(query._entities, row) + ) try: (process, labels) = \ @@ -104,7 +108,7 @@ def merge_result(querylib, query, iterator, load=True): result = [session._merge( attributes.instance_state(instance), attributes.instance_dict(instance), - load=load, _recursive={}) + load=load, _recursive={}, _resolve_conflict_map={}) for instance in iterator] else: result = list(iterator) @@ -121,7 +125,7 @@ def merge_result(querylib, query, iterator, load=True): newrow[i] = session._merge( attributes.instance_state(newrow[i]), attributes.instance_dict(newrow[i]), - load=load, _recursive={}) + load=load, _recursive={}, _resolve_conflict_map={}) result.append(keyed_tuple(newrow)) return iter(result) @@ -335,6 +339,9 @@ def _instance_processor( populate_existing = context.populate_existing or mapper.always_refresh load_evt = bool(mapper.class_manager.dispatch.load) refresh_evt = bool(mapper.class_manager.dispatch.refresh) + persistent_evt = bool(context.session.dispatch.loaded_as_persistent) + if persistent_evt: + loaded_as_persistent = context.session.dispatch.loaded_as_persistent instance_state = attributes.instance_state instance_dict = attributes.instance_dict session_id = context.session.hash_key @@ -428,8 +435,11 @@ def _instance_processor( loaded_instance, populate_existing, populators) if isnew: - if loaded_instance and load_evt: - state.manager.dispatch.load(state, context) + if loaded_instance: + if load_evt: + state.manager.dispatch.load(state, context) + if persistent_evt: + loaded_as_persistent(context.session, state.obj()) elif refresh_evt: state.manager.dispatch.refresh( state, context, only_load_props) diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 48fbaae32..95aa14a26 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1915,6 +1915,19 @@ class Mapper(InspectionAttr): """ @_memoized_configured_property + def _insert_cols_evaluating_none(self): + return dict( + ( + table, + frozenset( + col.key for col in columns + if col.type.should_evaluate_none + ) + ) + for table, columns in self._cols_by_table.items() + ) + + @_memoized_configured_property def _insert_cols_as_none(self): return dict( ( @@ -1922,7 +1935,8 @@ class Mapper(InspectionAttr): frozenset( col.key for col in columns if not col.primary_key and - not col.server_default and not col.default) + not col.server_default and not col.default + and not col.type.should_evaluate_none) ) for table, columns in self._cols_by_table.items() ) @@ -1956,12 +1970,24 @@ class Mapper(InspectionAttr): ( table, frozenset([ - col for col in columns + col.key for col in columns if col.server_default is not None]) ) for table, columns in self._cols_by_table.items() ) + @_memoized_configured_property + def _server_onupdate_default_cols(self): + return dict( + ( + table, + frozenset([ + col.key for col in columns + if col.server_onupdate is not None]) + ) + for table, columns in self._cols_by_table.items() + ) + @property def selectable(self): """The :func:`.select` construct this :class:`.Mapper` selects from @@ -2557,15 +2583,24 @@ class Mapper(InspectionAttr): for all relationships that meet the given cascade rule. :param type_: - The name of the cascade rule (i.e. save-update, delete, - etc.) + The name of the cascade rule (i.e. ``"save-update"``, ``"delete"``, + etc.). + + .. note:: the ``"all"`` cascade is not accepted here. For a generic + object traversal function, see :ref:`faq_walk_objects`. :param state: The lead InstanceState. child items will be processed per the relationships defined for this object's mapper. - the return value are object instances; this provides a strong - reference so that they don't fall out of scope immediately. + :return: the method yields individual object instances. + + .. seealso:: + + :ref:`unitofwork_cascades` + + :ref:`faq_walk_objects` - illustrates a generic function to + traverse all objects without relying on cascades. """ visited_states = set() @@ -2682,7 +2717,33 @@ def configure_mappers(): have been constructed thus far. This function can be called any number of times, but in - most cases is handled internally. + most cases is invoked automatically, the first time mappings are used, + as well as whenever mappings are used and additional not-yet-configured + mappers have been constructed. + + Points at which this occur include when a mapped class is instantiated + into an instance, as well as when the :meth:`.Session.query` method + is used. + + The :func:`.configure_mappers` function provides several event hooks + that can be used to augment its functionality. These methods include: + + * :meth:`.MapperEvents.before_configured` - called once before + :func:`.configure_mappers` does any work; this can be used to establish + additional options, properties, or related mappings before the operation + proceeds. + + * :meth:`.MapperEvents.mapper_configured` - called as each indivudal + :class:`.Mapper` is configured within the process; will include all + mapper state except for backrefs set up by other mappers that are still + to be configured. + + * :meth:`.MapperEvents.after_configured` - called once after + :func:`.configure_mappers` is complete; at this stage, all + :class:`.Mapper` objects that are known to SQLAlchemy will be fully + configured. Note that the calling application may still have other + mappings that haven't been produced yet, such as if they are in modules + as yet unimported. """ diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 0bfee2ece..e6a2c0634 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -375,10 +375,12 @@ def _collect_insert_commands( propkey_to_col = mapper._propkey_to_col[table] + eval_none = mapper._insert_cols_evaluating_none[table] + for propkey in set(propkey_to_col).intersection(state_dict): value = state_dict[propkey] col = propkey_to_col[propkey] - if value is None: + if value is None and propkey not in eval_none: continue elif not bulk and isinstance(value, sql.ClauseElement): value_params[col.key] = value @@ -446,6 +448,7 @@ def _collect_update_commands( set(propkey_to_col).intersection(state_dict).difference( mapper._pk_keys_by_table[table]) ) + has_all_defaults = True else: params = {} for propkey in set(propkey_to_col).intersection( @@ -461,6 +464,12 @@ def _collect_update_commands( value, state.committed_state[propkey]) is not True: params[col.key] = value + if mapper.base_mapper.eager_defaults: + has_all_defaults = mapper._server_onupdate_default_cols[table].\ + issubset(params) + else: + has_all_defaults = True + if update_version_id is not None and \ mapper.version_id_col in mapper._cols_by_table[table]: @@ -483,7 +492,7 @@ def _collect_update_commands( col = mapper.version_id_col params[col._label] = update_version_id - if col.key not in params and \ + if (bulk or col.key not in params) and \ mapper.version_id_generator is not False: val = mapper.version_id_generator(update_version_id) params[col.key] = val @@ -527,7 +536,7 @@ def _collect_update_commands( params.update(pk_params) yield ( state, state_dict, params, mapper, - connection, value_params) + connection, value_params, has_all_defaults) def _collect_post_update_commands(base_mapper, uowtransaction, table, @@ -617,37 +626,42 @@ def _emit_update_statements(base_mapper, uowtransaction, type_=mapper.version_id_col.type)) stmt = table.update(clause) - if mapper.base_mapper.eager_defaults: - stmt = stmt.return_defaults() - elif mapper.version_id_col is not None: - stmt = stmt.return_defaults(mapper.version_id_col) - return stmt - statement = base_mapper._memo(('update', table), update_stmt) + cached_stmt = base_mapper._memo(('update', table), update_stmt) - for (connection, paramkeys, hasvalue), \ + for (connection, paramkeys, hasvalue, has_all_defaults), \ records in groupby( update, lambda rec: ( rec[4], # connection set(rec[2]), # set of parameter keys - bool(rec[5]))): # whether or not we have "value" parameters - + bool(rec[5]), # whether or not we have "value" parameters + rec[6] # has_all_defaults + ) + ): rows = 0 records = list(records) + statement = cached_stmt + # TODO: would be super-nice to not have to determine this boolean # inside the loop here, in the 99.9999% of the time there's only # one connection in use assert_singlerow = connection.dialect.supports_sane_rowcount assert_multirow = assert_singlerow and \ connection.dialect.supports_sane_multi_rowcount - allow_multirow = not needs_version_id or assert_multirow + allow_multirow = has_all_defaults and not needs_version_id + + if bookkeeping and not has_all_defaults and \ + mapper.base_mapper.eager_defaults: + statement = statement.return_defaults() + elif mapper.version_id_col is not None: + statement = statement.return_defaults(mapper.version_id_col) if hasvalue: for state, state_dict, params, mapper, \ - connection, value_params in records: + connection, value_params, has_all_defaults in records: c = connection.execute( statement.values(value_params), params) @@ -667,18 +681,21 @@ def _emit_update_statements(base_mapper, uowtransaction, if not allow_multirow: check_rowcount = assert_singlerow for state, state_dict, params, mapper, \ - connection, value_params in records: + connection, value_params, has_all_defaults in records: c = cached_connections[connection].\ execute(statement, params) - _postfetch( - mapper, - uowtransaction, - table, - state, - state_dict, - c, - c.context.compiled_parameters[0], - value_params) + + # TODO: why with bookkeeping=False? + if bookkeeping: + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + value_params) rows += c.rowcount else: multiparams = [rec[2] for rec in records] @@ -692,17 +709,19 @@ def _emit_update_statements(base_mapper, uowtransaction, execute(statement, multiparams) rows += c.rowcount + for state, state_dict, params, mapper, \ - connection, value_params in records: - _postfetch( - mapper, - uowtransaction, - table, - state, - state_dict, - c, - c.context.compiled_parameters[0], - value_params) + connection, value_params, has_all_defaults in records: + if bookkeeping: + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + value_params) if check_rowcount: if rows != len(records): @@ -723,7 +742,7 @@ def _emit_insert_statements(base_mapper, uowtransaction, """Emit INSERT statements corresponding to value lists collected by _collect_insert_commands().""" - statement = base_mapper._memo(('insert', table), table.insert) + cached_stmt = base_mapper._memo(('insert', table), table.insert) for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \ records in groupby( @@ -734,6 +753,9 @@ def _emit_insert_statements(base_mapper, uowtransaction, bool(rec[5]), # whether we have "value" parameters rec[6], rec[7])): + + statement = cached_stmt + if not bookkeeping or \ ( has_all_defaults @@ -752,15 +774,18 @@ def _emit_insert_statements(base_mapper, uowtransaction, conn, value_params, has_all_pks, has_all_defaults), \ last_inserted_params in \ zip(records, c.context.compiled_parameters): - _postfetch( - mapper_rec, - uowtransaction, - table, - state, - state_dict, - c, - last_inserted_params, - value_params) + if state: + _postfetch( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + c, + last_inserted_params, + value_params) + else: + _postfetch_bulk_save(mapper_rec, state_dict, table) else: if not has_all_defaults and base_mapper.eager_defaults: @@ -789,15 +814,19 @@ def _emit_insert_statements(base_mapper, uowtransaction, prop = mapper_rec._columntoproperty[col] if state_dict.get(prop.key) is None: state_dict[prop.key] = pk - _postfetch( - mapper_rec, - uowtransaction, - table, - state, - state_dict, - result, - result.context.compiled_parameters[0], - value_params) + if bookkeeping: + if state: + _postfetch( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + result, + result.context.compiled_parameters[0], + value_params) + else: + _postfetch_bulk_save(mapper_rec, state_dict, table) def _emit_post_update_statements(base_mapper, uowtransaction, @@ -957,7 +986,7 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): def _postfetch(mapper, uowtransaction, table, - state, dict_, result, params, value_params, bulk=False): + state, dict_, result, params, value_params): """Expire attributes in need of newly persisted database state, after an INSERT or UPDATE statement has proceeded for that state.""" @@ -1005,13 +1034,15 @@ def _postfetch(mapper, uowtransaction, table, # TODO: this still goes a little too often. would be nice to # have definitive list of "columns that changed" here for m, equated_pairs in mapper._table_to_equated[table]: - if state is None: - sync.bulk_populate_inherit_keys(dict_, m, equated_pairs) - else: - sync.populate(state, m, state, m, - equated_pairs, - uowtransaction, - mapper.passive_updates) + sync.populate(state, m, state, m, + equated_pairs, + uowtransaction, + mapper.passive_updates) + + +def _postfetch_bulk_save(mapper, dict_, table): + for m, equated_pairs in mapper._table_to_equated[table]: + sync.bulk_populate_inherit_keys(dict_, m, equated_pairs) def _connections_for_states(base_mapper, uowtransaction, states): @@ -1242,10 +1273,16 @@ class BulkUpdate(BulkUD): "Invalid expression type: %r" % key) def _do_exec(self): - values = dict( + + values = [ (self._resolve_string_to_expr(k), v) - for k, v in self.values.items() - ) + for k, v in ( + self.values.items() if hasattr(self.values, 'items') + else self.values) + ] + if not self.update_kwargs.get('preserve_parameter_order', False): + values = dict(values) + update_stmt = sql.update(self.primary_table, self.context.whereclause, values, **self.update_kwargs) @@ -1295,7 +1332,9 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate): def _additional_evaluators(self, evaluator_compiler): self.value_evaluators = {} - for key, value in self.values.items(): + values = (self.values.items() if hasattr(self.values, 'items') + else self.values) + for key, value in values: key = self._resolve_key_to_attrname(key) if key is not None: self.value_evaluators[key] = evaluator_compiler.process( diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 55e02984b..0d4e1b771 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -39,7 +39,7 @@ class ColumnProperty(StrategizedProperty): 'instrument', 'comparator_factory', 'descriptor', 'extension', 'active_history', 'expire_on_flush', 'info', 'doc', 'strategy_class', '_creation_order', '_is_polymorphic_discriminator', - '_mapped_by_synonym', '_deferred_loader') + '_mapped_by_synonym', '_deferred_column_loader') def __init__(self, *columns, **kwargs): """Provide a column-level property for use with a Mapper. @@ -206,7 +206,7 @@ class ColumnProperty(StrategizedProperty): get_committed_value(state, dict_, passive=passive) def merge(self, session, source_state, source_dict, dest_state, - dest_dict, load, _recursive): + dest_dict, load, _recursive, _resolve_conflict_map): if not self.instrument: return elif self.key in source_dict: diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 8b3df08e7..e1b920bbb 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -103,6 +103,7 @@ class Query(object): _orm_only_adapt = True _orm_only_from_obj_alias = True _current_path = _path_registry + _has_mapper_entities = False def __init__(self, entities, session=None): self.session = session @@ -114,6 +115,7 @@ class Query(object): entity_wrapper = _QueryEntity self._entities = [] self._primary_entity = None + self._has_mapper_entities = False for ent in util.to_list(entities): entity_wrapper(self, ent) @@ -287,6 +289,8 @@ class Query(object): return self._entities[0] def _mapper_zero(self): + # TODO: self._select_from_entity is not a mapper + # so this method is misnamed return self._select_from_entity \ if self._select_from_entity is not None \ else self._entity_zero().entity_zero @@ -608,6 +612,16 @@ class Query(object): When the `Query` actually issues SQL to load rows, it always uses column labeling. + .. note:: The :meth:`.Query.with_labels` method *only* applies + the output of :attr:`.Query.statement`, and *not* to any of + the result-row invoking systems of :class:`.Query` itself, e.g. + :meth:`.Query.first`, :meth:`.Query.all`, etc. To execute + a query using :meth:`.Query.with_labels`, invoke the + :attr:`.Query.statement` using :meth:`.Session.execute`:: + + result = session.execute(query.with_labels().statement) + + """ self._with_labels = True @@ -930,11 +944,13 @@ class Query(object): """ if property is None: + mapper_zero = inspect(self._mapper_zero()).mapper + mapper = object_mapper(instance) for prop in mapper.iterate_properties: if isinstance(prop, properties.RelationshipProperty) and \ - prop.mapper is self._mapper_zero(): + prop.mapper is mapper_zero: property = prop break else: @@ -972,8 +988,169 @@ class Query(object): """return a Query that selects from this Query's SELECT statement. - \*entities - optional list of entities which will replace - those being selected. + :meth:`.Query.from_self` essentially turns the SELECT statement + into a SELECT of itself. Given a query such as:: + + q = session.query(User).filter(User.name.like('e%')) + + Given the :meth:`.Query.from_self` version:: + + q = session.query(User).filter(User.name.like('e%')).from_self() + + This query renders as: + + .. sourcecode:: sql + + SELECT anon_1.user_id AS anon_1_user_id, + anon_1.user_name AS anon_1_user_name + FROM (SELECT "user".id AS user_id, "user".name AS user_name + FROM "user" + WHERE "user".name LIKE :name_1) AS anon_1 + + There are lots of cases where :meth:`.Query.from_self` may be useful. + A simple one is where above, we may want to apply a row LIMIT to + the set of user objects we query against, and then apply additional + joins against that row-limited set:: + + q = session.query(User).filter(User.name.like('e%')).\\ + limit(5).from_self().\\ + join(User.addresses).filter(Address.email.like('q%')) + + The above query joins to the ``Address`` entity but only against the + first five results of the ``User`` query: + + .. sourcecode:: sql + + SELECT anon_1.user_id AS anon_1_user_id, + anon_1.user_name AS anon_1_user_name + FROM (SELECT "user".id AS user_id, "user".name AS user_name + FROM "user" + WHERE "user".name LIKE :name_1 + LIMIT :param_1) AS anon_1 + JOIN address ON anon_1.user_id = address.user_id + WHERE address.email LIKE :email_1 + + **Automatic Aliasing** + + Another key behavior of :meth:`.Query.from_self` is that it applies + **automatic aliasing** to the entities inside the subquery, when + they are referenced on the outside. Above, if we continue to + refer to the ``User`` entity without any additional aliasing applied + to it, those references wil be in terms of the subquery:: + + q = session.query(User).filter(User.name.like('e%')).\\ + limit(5).from_self().\\ + join(User.addresses).filter(Address.email.like('q%')).\\ + order_by(User.name) + + The ORDER BY against ``User.name`` is aliased to be in terms of the + inner subquery: + + .. sourcecode:: sql + + SELECT anon_1.user_id AS anon_1_user_id, + anon_1.user_name AS anon_1_user_name + FROM (SELECT "user".id AS user_id, "user".name AS user_name + FROM "user" + WHERE "user".name LIKE :name_1 + LIMIT :param_1) AS anon_1 + JOIN address ON anon_1.user_id = address.user_id + WHERE address.email LIKE :email_1 ORDER BY anon_1.user_name + + The automatic aliasing feature only works in a **limited** way, + for simple filters and orderings. More ambitious constructions + such as referring to the entity in joins should prefer to use + explicit subquery objects, typically making use of the + :meth:`.Query.subquery` method to produce an explicit subquery object. + Always test the structure of queries by viewing the SQL to ensure + a particular structure does what's expected! + + **Changing the Entities** + + :meth:`.Query.from_self` also includes the ability to modify what + columns are being queried. In our example, we want ``User.id`` + to be queried by the inner query, so that we can join to the + ``Address`` entity on the outside, but we only wanted the outer + query to return the ``Address.email`` column:: + + q = session.query(User).filter(User.name.like('e%')).\\ + limit(5).from_self(Address.email).\\ + join(User.addresses).filter(Address.email.like('q%')) + + yielding: + + .. sourcecode:: sql + + SELECT address.email AS address_email + FROM (SELECT "user".id AS user_id, "user".name AS user_name + FROM "user" + WHERE "user".name LIKE :name_1 + LIMIT :param_1) AS anon_1 + JOIN address ON anon_1.user_id = address.user_id + WHERE address.email LIKE :email_1 + + **Looking out for Inner / Outer Columns** + + Keep in mind that when referring to columns that originate from + inside the subquery, we need to ensure they are present in the + columns clause of the subquery itself; this is an ordinary aspect of + SQL. For example, if we wanted to load from a joined entity inside + the subquery using :func:`.contains_eager`, we need to add those + columns. Below illustrates a join of ``Address`` to ``User``, + then a subquery, and then we'd like :func:`.contains_eager` to access + the ``User`` columns:: + + q = session.query(Address).join(Address.user).\\ + filter(User.name.like('e%')) + + q = q.add_entity(User).from_self().\\ + options(contains_eager(Address.user)) + + We use :meth:`.Query.add_entity` above **before** we call + :meth:`.Query.from_self` so that the ``User`` columns are present + in the inner subquery, so that they are available to the + :func:`.contains_eager` modifier we are using on the outside, + producing: + + .. sourcecode:: sql + + SELECT anon_1.address_id AS anon_1_address_id, + anon_1.address_email AS anon_1_address_email, + anon_1.address_user_id AS anon_1_address_user_id, + anon_1.user_id AS anon_1_user_id, + anon_1.user_name AS anon_1_user_name + FROM ( + SELECT address.id AS address_id, + address.email AS address_email, + address.user_id AS address_user_id, + "user".id AS user_id, + "user".name AS user_name + FROM address JOIN "user" ON "user".id = address.user_id + WHERE "user".name LIKE :name_1) AS anon_1 + + If we didn't call ``add_entity(User)``, but still asked + :func:`.contains_eager` to load the ``User`` entity, it would be + forced to add the table on the outside without the correct + join criteria - note the ``anon1, "user"`` phrase at + the end: + + .. sourcecode:: sql + + -- incorrect query + SELECT anon_1.address_id AS anon_1_address_id, + anon_1.address_email AS anon_1_address_email, + anon_1.address_user_id AS anon_1_address_user_id, + "user".id AS user_id, + "user".name AS user_name + FROM ( + SELECT address.id AS address_id, + address.email AS address_email, + address.user_id AS address_user_id + FROM address JOIN "user" ON "user".id = address.user_id + WHERE "user".name LIKE :name_1) AS anon_1, "user" + + :param \*entities: optional list of entities which will replace + those being selected. """ fromclause = self.with_labels().enable_eagerloads(False).\ @@ -1280,7 +1457,9 @@ class Query(object): session.query(MyClass).filter(MyClass.name == 'some name') - Multiple criteria are joined together by AND:: + Multiple criteria may be specified as comma separated; the effect + is that they will be joined together using the :func:`.and_` + function:: session.query(MyClass).\\ filter(MyClass.name == 'some name', MyClass.id > 5) @@ -1289,9 +1468,6 @@ class Query(object): WHERE clause of a select. String expressions are coerced into SQL expression constructs via the :func:`.text` construct. - .. versionchanged:: 0.7.5 - Multiple criteria joined by AND. - .. seealso:: :meth:`.Query.filter_by` - filter on keyword expressions. @@ -1315,7 +1491,9 @@ class Query(object): session.query(MyClass).filter_by(name = 'some name') - Multiple criteria are joined together by AND:: + Multiple criteria may be specified as comma separated; the effect + is that they will be joined together using the :func:`.and_` + function:: session.query(MyClass).\\ filter_by(name = 'some name', id = 5) @@ -2323,6 +2501,19 @@ class Query(object): """Apply a ``DISTINCT`` to the query and return the newly resulting ``Query``. + + .. note:: + + The :meth:`.distinct` call includes logic that will automatically + add columns from the ORDER BY of the query to the columns + clause of the SELECT statement, to satisfy the common need + of the database backend that ORDER BY columns be part of the + SELECT list when DISTINCT is used. These columns *are not* + added to the list of columns actually fetched by the + :class:`.Query`, however, so would not affect results. + The columns are passed through when using the + :attr:`.Query.statement` accessor, however. + :param \*expr: optional column expressions. When present, the Postgresql dialect will render a ``DISTINCT ON (<expressions>>)`` construct. @@ -2436,7 +2627,13 @@ class Query(object): (note this may consist of multiple result rows if join-loaded collections are present). - Calling ``first()`` results in an execution of the underlying query. + Calling :meth:`.Query.first` results in an execution of the underlying query. + + .. seealso:: + + :meth:`.Query.one` + + :meth:`.Query.one_or_none` """ if self._statement is not None: @@ -2448,26 +2645,27 @@ class Query(object): else: return None - def one(self): - """Return exactly one result or raise an exception. + def one_or_none(self): + """Return at most one result or raise an exception. - Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects + Returns ``None`` if the query selects no rows. Raises ``sqlalchemy.orm.exc.MultipleResultsFound`` if multiple object identities are returned, or if multiple - rows are returned for a query that does not return object - identities. + rows are returned for a query that returns only scalar values + as opposed to full identity-mapped entities. + + Calling :meth:`.Query.one_or_none` results in an execution of the + underlying query. + + .. versionadded:: 1.0.9 - Note that an entity query, that is, one which selects one or - more mapped classes as opposed to individual column attributes, - may ultimately represent many rows but only one row of - unique entity or entities - this is a successful result for one(). + Added :meth:`.Query.one_or_none` - Calling ``one()`` results in an execution of the underlying query. + .. seealso:: + + :meth:`.Query.first` - .. versionchanged:: 0.6 - ``one()`` fully fetches all results instead of applying - any kind of limit, so that the "unique"-ing of entities does not - conceal multiple object identities. + :meth:`.Query.one` """ ret = list(self) @@ -2476,10 +2674,38 @@ class Query(object): if l == 1: return ret[0] elif l == 0: - raise orm_exc.NoResultFound("No row was found for one()") + return None else: raise orm_exc.MultipleResultsFound( + "Multiple rows were found for one_or_none()") + + def one(self): + """Return exactly one result or raise an exception. + + Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects + no rows. Raises ``sqlalchemy.orm.exc.MultipleResultsFound`` + if multiple object identities are returned, or if multiple + rows are returned for a query that returns only scalar values + as opposed to full identity-mapped entities. + + Calling :meth:`.one` results in an execution of the underlying query. + + .. seealso:: + + :meth:`.Query.first` + + :meth:`.Query.one_or_none` + + """ + try: + ret = self.one_or_none() + except orm_exc.MultipleResultsFound: + raise orm_exc.MultipleResultsFound( "Multiple rows were found for one()") + else: + if ret is None: + raise orm_exc.NoResultFound("No row was found for one()") + return ret def scalar(self): """Return the first element of the first result or None @@ -2849,7 +3075,12 @@ class Query(object): :param values: a dictionary with attributes names, or alternatively mapped attributes or SQL expressions, as keys, and literal - values or sql expressions as values. + values or sql expressions as values. If :ref:`parameter-ordered + mode <updates_order_parameters>` is desired, the values can be + passed as a list of 2-tuples; + this requires that the :paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order` + flag is passed to the :paramref:`.Query.update.update_args` dictionary + as well. .. versionchanged:: 1.0.0 - string names in the values dictionary are now resolved against the mapped entity; previously, these @@ -2880,7 +3111,8 @@ class Query(object): :param update_args: Optional dictionary, if present will be passed to the underlying :func:`.update` construct as the ``**kw`` for the object. May be used to pass dialect-specific arguments such - as ``mysql_limit``. + as ``mysql_limit``, as well as other special arguments such as + :paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order`. .. versionadded:: 1.0.0 @@ -3181,12 +3413,14 @@ class _MapperEntity(_QueryEntity): if not query._primary_entity: query._primary_entity = self query._entities.append(self) - + query._has_mapper_entities = True self.entities = [entity] self.expr = entity supports_single_entity = True + use_id_for_hash = True + def setup_entity(self, ext_info, aliased_adapter): self.mapper = ext_info.mapper self.aliased_adapter = aliased_adapter @@ -3232,8 +3466,6 @@ class _MapperEntity(_QueryEntity): self.mapper, sql_util.ColumnAdapter( from_obj, self.mapper._equivalent_columns)) - filter_fn = id - @property def type(self): return self.mapper.class_ @@ -3462,6 +3694,8 @@ class Bundle(InspectionAttr): class _BundleEntity(_QueryEntity): + use_id_for_hash = False + def __init__(self, query, bundle, setup_entities=True): query._entities.append(self) self.bundle = self.expr = bundle @@ -3478,8 +3712,6 @@ class _BundleEntity(_QueryEntity): self.entities = () - self.filter_fn = lambda item: item - self.supports_single_entity = self.bundle.single_entity @property @@ -3582,11 +3814,7 @@ class _ColumnEntity(_QueryEntity): search_entities = True self.type = type_ = column.type - if type_.hashable: - self.filter_fn = lambda item: item - else: - counter = util.counter() - self.filter_fn = lambda item: counter() + self.use_id_for_hash = not type_.hashable # If the Column is unnamed, give it a # label() so that mutable column expressions @@ -3619,7 +3847,7 @@ class _ColumnEntity(_QueryEntity): self._from_entities = set(self.entities) else: all_elements = [ - elem for elem in visitors.iterate(column, {}) + elem for elem in sql_util.surface_column_elements(column) if 'parententity' in elem._annotations ] diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index da0730f46..f822071c4 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -275,15 +275,31 @@ class RelationshipProperty(StrategizedProperty): :paramref:`~.relationship.backref` - alternative form of backref specification. - :param bake_queries: - Use the :class:`.BakedQuery` cache to cache queries used in lazy - loads. True by default, as this typically improves performance - significantly. Set to False to reduce ORM memory use, or - if unresolved stability issues are observed with the baked query + :param bake_queries=True: + Use the :class:`.BakedQuery` cache to cache the construction of SQL + used in lazy loads, when the :func:`.bake_lazy_loaders` function has + first been called. Defaults to True and is intended to provide an + "opt out" flag per-relationship when the baked query cache system is + in use. + + .. warning:: + + This flag **only** has an effect when the application-wide + :func:`.bake_lazy_loaders` function has been called. It + defaults to True so is an "opt out" flag. + + Setting this flag to False when baked queries are otherwise in + use might be to reduce + ORM memory use for this :func:`.relationship`, or to work around + unresolved stability issues observed within the baked query cache system. .. versionadded:: 1.0.0 + .. seealso:: + + :ref:`baked_toplevel` + :param cascade: a comma-separated list of cascade rules which determines how Session operations should be "cascaded" from parent to child. @@ -604,30 +620,26 @@ class RelationshipProperty(StrategizedProperty): and examples. :param passive_updates=True: - Indicates loading and INSERT/UPDATE/DELETE behavior when the - source of a foreign key value changes (i.e. an "on update" - cascade), which are typically the primary key columns of the - source row. + Indicates the persistence behavior to take when a referenced + primary key value changes in place, indicating that the referencing + foreign key columns will also need their value changed. - When True, it is assumed that ON UPDATE CASCADE is configured on + When True, it is assumed that ``ON UPDATE CASCADE`` is configured on the foreign key in the database, and that the database will handle propagation of an UPDATE from a source column to - dependent rows. Note that with databases which enforce - referential integrity (i.e. PostgreSQL, MySQL with InnoDB tables), - ON UPDATE CASCADE is required for this operation. The - relationship() will update the value of the attribute on related - items which are locally present in the session during a flush. - - When False, it is assumed that the database does not enforce - referential integrity and will not be issuing its own CASCADE - operation for an update. The relationship() will issue the - appropriate UPDATE statements to the database in response to the - change of a referenced key, and items locally present in the - session during a flush will also be refreshed. - - This flag should probably be set to False if primary key changes - are expected and the database in use doesn't support CASCADE - (i.e. SQLite, MySQL MyISAM tables). + dependent rows. When False, the SQLAlchemy :func:`.relationship` + construct will attempt to emit its own UPDATE statements to + modify related targets. However note that SQLAlchemy **cannot** + emit an UPDATE for more than one level of cascade. Also, + setting this flag to False is not compatible in the case where + the database is in fact enforcing referential integrity, unless + those constraints are explicitly "deferred", if the target backend + supports it. + + It is highly advised that an application which is employing + mutable primary keys keeps ``passive_updates`` set to True, + and instead uses the referential integrity features of the database + itself in order to handle the change efficiently and fully. .. seealso:: @@ -1418,7 +1430,7 @@ class RelationshipProperty(StrategizedProperty): source_dict, dest_state, dest_dict, - load, _recursive): + load, _recursive, _resolve_conflict_map): if load: for r in self._reverse_property: @@ -1451,8 +1463,10 @@ class RelationshipProperty(StrategizedProperty): current_state = attributes.instance_state(current) current_dict = attributes.instance_dict(current) _recursive[(current_state, self)] = True - obj = session._merge(current_state, current_dict, - load=load, _recursive=_recursive) + obj = session._merge( + current_state, current_dict, + load=load, _recursive=_recursive, + _resolve_conflict_map=_resolve_conflict_map) if obj is not None: dest_list.append(obj) @@ -1462,16 +1476,19 @@ class RelationshipProperty(StrategizedProperty): for c in dest_list: coll.append_without_event(c) else: - dest_state.get_impl(self.key)._set_iterable( - dest_state, dest_dict, dest_list) + dest_state.get_impl(self.key).set( + dest_state, dest_dict, dest_list, + _adapt=False) else: current = source_dict[self.key] if current is not None: current_state = attributes.instance_state(current) current_dict = attributes.instance_dict(current) _recursive[(current_state, self)] = True - obj = session._merge(current_state, current_dict, - load=load, _recursive=_recursive) + obj = session._merge( + current_state, current_dict, + load=load, _recursive=_recursive, + _resolve_conflict_map=_resolve_conflict_map) else: obj = None diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 4619027e5..56513860a 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -180,8 +180,7 @@ class SessionTransaction(object): if self.session._enable_transaction_accounting: self._take_snapshot() - if self.session.dispatch.after_transaction_create: - self.session.dispatch.after_transaction_create(self.session, self) + self.session.dispatch.after_transaction_create(self.session, self) @property def is_active(self): @@ -272,10 +271,9 @@ class SessionTransaction(object): def _restore_snapshot(self, dirty_only=False): assert self._is_transaction_boundary - for s in set(self._new).union(self.session._new): - self.session._expunge_state(s) - if s.key: - del s.key + self.session._expunge_states( + set(self._new).union(self.session._new), + to_transient=True) for s, (oldkey, newkey) in self._key_switches.items(): self.session.identity_map.safe_discard(s) @@ -283,10 +281,7 @@ class SessionTransaction(object): self.session.identity_map.replace(s) for s in set(self._deleted).union(self.session._deleted): - if s.deleted: - # assert s in self._deleted - del s.deleted - self.session._update_impl(s, discard_existing=True) + self.session._update_impl(s, revert_deletion=True) assert not self.session._deleted @@ -300,8 +295,9 @@ class SessionTransaction(object): if not self.nested and self.session.expire_on_commit: for s in self.session.identity_map.all_states(): s._expire(s.dict, self.session.identity_map._modified) - for s in list(self._deleted): - s._detach() + + statelib.InstanceState._detach_states( + list(self._deleted), self.session) self._deleted.clear() elif self.nested: self._parent._new.update(self._new) @@ -412,11 +408,23 @@ class SessionTransaction(object): for subtransaction in stx._iterate_parents(upto=self): subtransaction.close() + if _capture_exception: + captured_exception = sys.exc_info()[1] + boundary = self if self._state in (ACTIVE, PREPARED): for transaction in self._iterate_parents(): if transaction._parent is None or transaction.nested: - transaction._rollback_impl() + try: + transaction._rollback_impl() + except Exception: + if _capture_exception: + util.warn( + "An exception raised during a Session " + "persistence operation cannot be raised " + "due to an additional ROLLBACK exception; " + "the exception is: %s" % captured_exception) + raise transaction._state = DEACTIVE boundary = transaction break @@ -438,7 +446,7 @@ class SessionTransaction(object): self.close() if self._parent and _capture_exception: - self._parent._rollback_exception = sys.exc_info()[1] + self._parent._rollback_exception = captured_exception sess.dispatch.after_soft_rollback(sess, self) @@ -466,8 +474,7 @@ class SessionTransaction(object): transaction.close() self._state = CLOSED - if self.session.dispatch.after_transaction_end: - self.session.dispatch.after_transaction_end(self.session, self) + self.session.dispatch.after_transaction_end(self.session, self) if self._parent is None: if not self.session.autocommit: @@ -629,16 +636,23 @@ class Session(_SessionClassMethods): :param weak_identity_map: Defaults to ``True`` - when set to ``False``, objects placed in the :class:`.Session` will be strongly referenced until explicitly removed or the - :class:`.Session` is closed. **Deprecated** - this option - is obsolete. + :class:`.Session` is closed. **Deprecated** - The strong + reference identity map is legacy. See the + recipe at :ref:`session_referencing_behavior` for + an event-based approach to maintaining strong identity + references. """ if weak_identity_map: self._identity_cls = identity.WeakInstanceDict else: - util.warn_deprecated("weak_identity_map=False is deprecated. " - "This feature is not needed.") + util.warn_deprecated( + "weak_identity_map=False is deprecated. " + "See the documentation on 'Session Referencing Behavior' " + "for an event-based approach to maintaining strong identity " + "references.") + self._identity_cls = identity.StrongInstanceDict self.identity_map = self._identity_cls() @@ -680,7 +694,7 @@ class Session(_SessionClassMethods): def info(self): """A user-modifiable dictionary. - The initial value of this dictioanry can be populated using the + The initial value of this dictionary can be populated using the ``info`` argument to the :class:`.Session` constructor or :class:`.sessionmaker` constructor or factory methods. The dictionary here is always local to this :class:`.Session` and can be modified @@ -1086,16 +1100,15 @@ class Session(_SessionClassMethods): ``Session``. """ - for state in self.identity_map.all_states() + list(self._new): - state._detach() + all_states = self.identity_map.all_states() + list(self._new) self.identity_map = self._identity_cls() self._new = {} self._deleted = {} - # TODO: need much more test coverage for bind_mapper() and similar ! - # TODO: + crystallize + document resolution order - # vis. bind_mapper/bind_table + statelib.InstanceState._detach_states( + all_states, self + ) def _add_bind(self, key, bind): try: @@ -1437,7 +1450,7 @@ class Session(_SessionClassMethods): state._expire(state.dict, self.identity_map._modified) elif state in self._new: self._new.pop(state) - state._detach() + state._detach(self) @util.deprecated("0.7", "The non-weak-referencing identity map " "feature is no longer needed.") @@ -1472,23 +1485,26 @@ class Session(_SessionClassMethods): cascaded = list(state.manager.mapper.cascade_iterator( 'expunge', state)) - self._expunge_state(state) - for o, m, st_, dct_ in cascaded: - self._expunge_state(st_) + self._expunge_states( + [state] + [st_ for o, m, st_, dct_ in cascaded] + ) - def _expunge_state(self, state): - if state in self._new: - self._new.pop(state) - state._detach() - elif self.identity_map.contains_state(state): - self.identity_map.safe_discard(state) - self._deleted.pop(state, None) - state._detach() - elif self.transaction: - self.transaction._deleted.pop(state, None) - state._detach() + def _expunge_states(self, states, to_transient=False): + for state in states: + if state in self._new: + self._new.pop(state) + elif self.identity_map.contains_state(state): + self.identity_map.safe_discard(state) + self._deleted.pop(state, None) + elif self.transaction: + # state is "detached" from being deleted, but still present + # in the transaction snapshot + self.transaction._deleted.pop(state, None) + statelib.InstanceState._detach_states( + states, self, to_transient=to_transient) def _register_newly_persistent(self, states): + pending_to_persistent = self.dispatch.pending_to_persistent or None for state in states: mapper = _state_mapper(state) @@ -1535,6 +1551,11 @@ class Session(_SessionClassMethods): ) self._register_altered(states) + + if pending_to_persistent is not None: + for state in states: + pending_to_persistent(self, state.obj()) + # remove from new last, might be the last strong ref for state in set(states).intersection(self._new): self._new.pop(state) @@ -1548,13 +1569,19 @@ class Session(_SessionClassMethods): self.transaction._dirty[state] = True def _remove_newly_deleted(self, states): + persistent_to_deleted = self.dispatch.persistent_to_deleted or None for state in states: if self._enable_transaction_accounting and self.transaction: self.transaction._deleted[state] = True self.identity_map.safe_discard(state) self._deleted.pop(state, None) - state.deleted = True + state._deleted = True + # can't call state._detach() here, because this state + # is still in the transaction snapshot and needs to be + # tracked as part of that + if persistent_to_deleted is not None: + persistent_to_deleted(self, state.obj()) def add(self, instance, _warn=True): """Place an object in the ``Session``. @@ -1609,30 +1636,39 @@ class Session(_SessionClassMethods): except exc.NO_STATE: raise exc.UnmappedInstanceError(instance) + self._delete_impl(state, instance, head=True) + + def _delete_impl(self, state, obj, head): + if state.key is None: - raise sa_exc.InvalidRequestError( - "Instance '%s' is not persisted" % - state_str(state)) + if head: + raise sa_exc.InvalidRequestError( + "Instance '%s' is not persisted" % + state_str(state)) + else: + return + + to_attach = self._before_attach(state, obj) if state in self._deleted: return - # ensure object is attached to allow the - # cascade operation to load deferred attributes - # and collections - self._attach(state, include_before=True) + if to_attach: + self.identity_map.add(state) + self._after_attach(state, obj) - # grab the cascades before adding the item to the deleted list - # so that autoflush does not delete the item - # the strong reference to the instance itself is significant here - cascade_states = list(state.manager.mapper.cascade_iterator( - 'delete', state)) + if head: + # grab the cascades before adding the item to the deleted list + # so that autoflush does not delete the item + # the strong reference to the instance itself is significant here + cascade_states = list(state.manager.mapper.cascade_iterator( + 'delete', state)) - self._deleted[state] = state.obj() - self.identity_map.add(state) + self._deleted[state] = obj - for o, m, st_, dct_ in cascade_states: - self._delete_impl(st_) + if head: + for o, m, st_, dct_ in cascade_states: + self._delete_impl(st_, o, False) def merge(self, instance, load=True): """Copy the state of a given instance into a corresponding instance @@ -1653,6 +1689,10 @@ class Session(_SessionClassMethods): See :ref:`unitofwork_merging` for a detailed discussion of merging. + .. versionchanged:: 1.1 - :meth:`.Session.merge` will now reconcile + pending objects with overlapping primary keys in the same way + as persistent. See :ref:`change_3601` for discussion. + :param instance: Instance to be merged. :param load: Boolean, when False, :meth:`.merge` switches into a "high performance" mode which causes it to forego emitting history @@ -1677,12 +1717,14 @@ class Session(_SessionClassMethods): should be "clean" as well, else this suggests a mis-use of the method. + """ if self._warn_on_events: self._flush_warning("Session.merge()") _recursive = {} + _resolve_conflict_map = {} if load: # flush current contents if we expect to load data @@ -1695,11 +1737,13 @@ class Session(_SessionClassMethods): return self._merge( attributes.instance_state(instance), attributes.instance_dict(instance), - load=load, _recursive=_recursive) + load=load, _recursive=_recursive, + _resolve_conflict_map=_resolve_conflict_map) finally: self.autoflush = autoflush - def _merge(self, state, state_dict, load=True, _recursive=None): + def _merge(self, state, state_dict, load=True, _recursive=None, + _resolve_conflict_map=None): mapper = _state_mapper(state) if state in _recursive: return _recursive[state] @@ -1715,9 +1759,14 @@ class Session(_SessionClassMethods): "all changes on mapped instances before merging with " "load=False.") key = mapper._identity_key_from_state(state) + key_is_persistent = attributes.NEVER_SET not in key[1] + else: + key_is_persistent = True if key in self.identity_map: merged = self.identity_map[key] + elif key_is_persistent and key in _resolve_conflict_map: + merged = _resolve_conflict_map[key] elif not load: if state.modified: @@ -1749,6 +1798,7 @@ class Session(_SessionClassMethods): merged_dict = attributes.instance_dict(merged) _recursive[state] = merged + _resolve_conflict_map[key] = merged # check that we didn't just pull the exact same # state out. @@ -1787,7 +1837,7 @@ class Session(_SessionClassMethods): for prop in mapper.iterate_properties: prop.merge(self, state, state_dict, merged_state, merged_dict, - load, _recursive) + load, _recursive, _resolve_conflict_map) if not load: # remove any history @@ -1809,35 +1859,47 @@ class Session(_SessionClassMethods): "Object '%s' already has an identity - " "it can't be registered as pending" % state_str(state)) - self._before_attach(state) + obj = state.obj() + to_attach = self._before_attach(state, obj) if state not in self._new: - self._new[state] = state.obj() + self._new[state] = obj state.insert_order = len(self._new) - self._attach(state) - - def _update_impl(self, state, discard_existing=False): - if (self.identity_map.contains_state(state) and - state not in self._deleted): - return + if to_attach: + self._after_attach(state, obj) + def _update_impl(self, state, revert_deletion=False): if state.key is None: raise sa_exc.InvalidRequestError( "Instance '%s' is not persisted" % state_str(state)) - if state.deleted: - raise sa_exc.InvalidRequestError( - "Instance '%s' has been deleted. Use the make_transient() " - "function to send this object back to the transient state." % - state_str(state) - ) - self._before_attach(state, check_identity_map=False) + if state._deleted: + if revert_deletion: + if not state._attached: + return + del state._deleted + else: + raise sa_exc.InvalidRequestError( + "Instance '%s' has been deleted. " + "Use the make_transient() " + "function to send this object back " + "to the transient state." % + state_str(state) + ) + + obj = state.obj() + to_attach = self._before_attach(state, obj) + self._deleted.pop(state, None) - if discard_existing: + if revert_deletion: self.identity_map.replace(state) else: self.identity_map.add(state) - self._attach(state) + + if to_attach: + self._after_attach(state, obj) + elif revert_deletion: + self.dispatch.deleted_to_persistent(self, obj) def _save_or_update_impl(self, state): if state.key is None: @@ -1845,17 +1907,6 @@ class Session(_SessionClassMethods): else: self._update_impl(state) - def _delete_impl(self, state): - if state in self._deleted: - return - - if state.key is None: - return - - self._attach(state, include_before=True) - self._deleted[state] = state.obj() - self.identity_map.add(state) - def enable_relationship_loading(self, obj): """Associate an object with this :class:`.Session` for related object loading. @@ -1908,40 +1959,35 @@ class Session(_SessionClassMethods): """ state = attributes.instance_state(obj) - self._attach(state, include_before=True) + to_attach = self._before_attach(state, obj) state._load_pending = True + if to_attach: + self._after_attach(state, obj) - def _before_attach(self, state, check_identity_map=True): - if state.session_id != self.hash_key and \ - self.dispatch.before_attach: - self.dispatch.before_attach(self, state.obj()) - - if check_identity_map and state.key and \ - state.key in self.identity_map and \ - not self.identity_map.contains_state(state): - raise sa_exc.InvalidRequestError( - "Can't attach instance " - "%s; another instance with key %s is already " - "present in this session." % (state_str(state), state.key)) + def _before_attach(self, state, obj): + if state.session_id == self.hash_key: + return False - if state.session_id and \ - state.session_id is not self.hash_key and \ - state.session_id in _sessions: + if state.session_id and state.session_id in _sessions: raise sa_exc.InvalidRequestError( "Object '%s' is already attached to session '%s' " "(this is '%s')" % (state_str(state), state.session_id, self.hash_key)) - def _attach(self, state, include_before=False): + self.dispatch.before_attach(self, obj) + + return True - if state.session_id != self.hash_key: - if include_before: - self._before_attach(state) - state.session_id = self.hash_key - if state.modified and state._strong_obj is None: - state._strong_obj = state.obj() - if self.dispatch.after_attach: - self.dispatch.after_attach(self, state.obj()) + def _after_attach(self, state, obj): + state.session_id = self.hash_key + if state.modified and state._strong_obj is None: + state._strong_obj = obj + self.dispatch.after_attach(self, obj) + + if state.key: + self.dispatch.detached_to_persistent(self, obj) + else: + self.dispatch.transient_to_pending(self, obj) def __contains__(self, instance): """Return True if the instance is associated with this session. @@ -1983,7 +2029,7 @@ class Session(_SessionClassMethods): For ``autocommit`` Sessions with no active manual transaction, flush() will create a transaction on the fly that surrounds the entire set of - operations int the flush. + operations into the flush. :param objects: Optional; restricts the flush operation to operate only on elements that are in the given collection. @@ -2700,7 +2746,7 @@ def make_transient(instance): state = attributes.instance_state(instance) s = _state_session(state) if s: - s._expunge_state(state) + s._expunge_states([state]) # remove expired state state.expired_attributes.clear() @@ -2711,8 +2757,8 @@ def make_transient(instance): if state.key: del state.key - if state.deleted: - del state.deleted + if state._deleted: + del state._deleted def make_transient_to_detached(instance): @@ -2744,8 +2790,8 @@ def make_transient_to_detached(instance): raise sa_exc.InvalidRequestError( "Given object must be transient") state.key = state.mapper._identity_key_from_state(state) - if state.deleted: - del state.deleted + if state._deleted: + del state._deleted state._commit_all(state.dict) state._expire_attributes(state.dict, state.unloaded) diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index 6034e74de..b648ffa3b 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -14,6 +14,7 @@ defines a large part of the ORM's interactivity. import weakref from .. import util +from .. import inspection from . import exc as orm_exc, interfaces from .path_registry import PathRegistry from .base import PASSIVE_NO_RESULT, SQL_OK, NEVER_SET, ATTR_WAS_SET, \ @@ -21,6 +22,7 @@ from .base import PASSIVE_NO_RESULT, SQL_OK, NEVER_SET, ATTR_WAS_SET, \ from . import base +@inspection._self_inspects class InstanceState(interfaces.InspectionAttr): """tracks state information at the instance level. @@ -56,7 +58,7 @@ class InstanceState(interfaces.InspectionAttr): _strong_obj = None modified = False expired = False - deleted = False + _deleted = False _load_pending = False is_instance = True @@ -87,7 +89,6 @@ class InstanceState(interfaces.InspectionAttr): see also the ``unmodified`` collection which is intersected against this set when a refresh operation occurs.""" - @util.memoized_property def attrs(self): """Return a namespace representing each attribute on @@ -133,16 +134,80 @@ class InstanceState(interfaces.InspectionAttr): self._attached @property + def deleted(self): + """Return true if the object is :term:`deleted`. + + An object that is in the deleted state is guaranteed to + not be within the :attr:`.Session.identity_map` of its parent + :class:`.Session`; however if the session's transaction is rolled + back, the object will be restored to the persistent state and + the identity map. + + .. note:: + + The :attr:`.InstanceState.deleted` attribute refers to a specific + state of the object that occurs between the "persistent" and + "detached" states; once the object is :term:`detached`, the + :attr:`.InstanceState.deleted` attribute **no longer returns + True**; in order to detect that a state was deleted, regardless + of whether or not the object is associated with a :class:`.Session`, + use the :attr:`.InstanceState.was_deleted` accessor. + + .. versionadded: 1.1 + + .. seealso:: + + :ref:`session_object_states` + + """ + return self.key is not None and \ + self._attached and self._deleted + + @property + def was_deleted(self): + """Return True if this object is or was previously in the + "deleted" state and has not been reverted to persistent. + + This flag returns True once the object was deleted in flush. + When the object is expunged from the session either explicitly + or via transaction commit and enters the "detached" state, + this flag will continue to report True. + + .. versionadded:: 1.1 - added a local method form of + :func:`.orm.util.was_deleted`. + + .. seealso:: + + :attr:`.InstanceState.deleted` - refers to the "deleted" state + + :func:`.orm.util.was_deleted` - standalone function + + :ref:`session_object_states` + + """ + return self._deleted + + @property def persistent(self): """Return true if the object is :term:`persistent`. + An object that is in the persistent state is guaranteed to + be within the :attr:`.Session.identity_map` of its parent + :class:`.Session`. + + .. versionchanged:: 1.1 The :attr:`.InstanceState.persistent` + accessor no longer returns True for an object that was + "deleted" within a flush; use the :attr:`.InstanceState.deleted` + accessor to detect this state. This allows the "persistent" + state to guarantee membership in the identity map. + .. seealso:: :ref:`session_object_states` """ return self.key is not None and \ - self._attached + self._attached and not self._deleted @property def detached(self): @@ -153,8 +218,7 @@ class InstanceState(interfaces.InspectionAttr): :ref:`session_object_states` """ - return self.key is not None and \ - not self._attached + return self.key is not None and not self._attached @property @util.dependencies("sqlalchemy.orm.session") @@ -241,8 +305,44 @@ class InstanceState(interfaces.InspectionAttr): """ return bool(self.key) - def _detach(self): - self.session_id = self._strong_obj = None + @classmethod + def _detach_states(self, states, session, to_transient=False): + persistent_to_detached = \ + session.dispatch.persistent_to_detached or None + deleted_to_detached = \ + session.dispatch.deleted_to_detached or None + pending_to_transient = \ + session.dispatch.pending_to_transient or None + persistent_to_transient = \ + session.dispatch.persistent_to_transient or None + + for state in states: + deleted = state._deleted + pending = state.key is None + persistent = not pending and not deleted + + state.session_id = None + + if to_transient and state.key: + del state.key + if persistent: + if to_transient: + if persistent_to_transient is not None: + persistent_to_transient(session, state.obj()) + elif persistent_to_detached is not None: + persistent_to_detached(session, state.obj()) + elif deleted and deleted_to_detached is not None: + deleted_to_detached(session, state.obj()) + elif pending and pending_to_transient is not None: + pending_to_transient(session, state.obj()) + + state._strong_obj = None + + def _detach(self, session=None): + if session: + InstanceState._detach_states([self], session) + else: + self.session_id = self._strong_obj = None def _dispose(self): self._detach() @@ -294,7 +394,7 @@ class InstanceState(interfaces.InspectionAttr): return {} def _initialize_instance(*mixed, **kwargs): - self, instance, args = mixed[0], mixed[1], mixed[2:] + self, instance, args = mixed[0], mixed[1], mixed[2:] # noqa manager = self.manager manager.dispatch.init(self, args, kwargs) @@ -374,12 +474,6 @@ class InstanceState(interfaces.InspectionAttr): state_dict['manager'](self, inst, state_dict) - def _initialize(self, key): - """Set this attribute to an empty value or collection, - based on the AttributeImpl in use.""" - - self.manager.get_impl(key).initialize(self, self.dict) - def _reset(self, dict_, key): """Remove the given attribute and any callables associated with it.""" diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 78e929345..b60e47bb3 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -346,7 +346,10 @@ class NoLoader(AbstractRelationshipLoader): self, context, path, loadopt, mapper, result, adapter, populators): def invoke_no_load(state, dict_, row): - state._initialize(self.key) + if self.uselist: + state.manager.get_impl(self.key).initialize(state, dict_) + else: + dict_[self.key] = None populators["new"].append((self.key, invoke_no_load)) @@ -361,7 +364,8 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): __slots__ = ( '_lazywhere', '_rev_lazywhere', 'use_get', '_bind_to_col', - '_equated_columns', '_rev_bind_to_col', '_rev_equated_columns') + '_equated_columns', '_rev_bind_to_col', '_rev_equated_columns', + '_simple_lazy_clause') def __init__(self, parent): super(LazyLoader, self).__init__(parent) @@ -1321,8 +1325,19 @@ class JoinedLoader(AbstractRelationshipLoader): if adapter: if getattr(adapter, 'aliased_class', None): + # joining from an adapted entity. The adapted entity + # might be a "with_polymorphic", so resolve that to our + # specific mapper's entity before looking for our attribute + # name on it. + efm = inspect(adapter.aliased_class).\ + _entity_for_mapper( + parentmapper + if parentmapper.isa(self.parent) else self.parent) + + # look for our attribute on the adapted entity, else fall back + # to our straight property onclause = getattr( - adapter.aliased_class, self.key, + efm.entity, self.key, self.parent_property) else: onclause = getattr( @@ -1363,8 +1378,7 @@ class JoinedLoader(AbstractRelationshipLoader): # send a hint to the Query as to where it may "splice" this join eagerjoin.stop_on = entity.selectable - if self.parent_property.secondary is None and \ - not parentmapper: + if not parentmapper: # for parentclause that is the non-eager end of the join, # ensure all the parent cols in the primaryjoin are actually # in the diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index cb7a5fef7..3467328e3 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -180,7 +180,7 @@ class Load(Generative, MapperOption): return path def __str__(self): - return "Load(strategy=%r)" % self.strategy + return "Load(strategy=%r)" % (self.strategy, ) def _coerce_strat(self, strategy): if strategy is not None: diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 6d3869679..46183a47d 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -537,7 +537,11 @@ class AliasedInsp(InspectionAttr): def _entity_for_mapper(self, mapper): self_poly = self.with_polymorphic_mappers if mapper in self_poly: - return getattr(self.entity, mapper.class_.__name__)._aliased_insp + if mapper is self.mapper: + return self + else: + return getattr( + self.entity, mapper.class_.__name__)._aliased_insp elif mapper.isa(self.mapper): return self else: @@ -985,12 +989,19 @@ def was_deleted(object): """Return True if the given object was deleted within a session flush. + This is regardless of whether or not the object is + persistent or detached. + .. versionadded:: 0.8.0 + .. seealso:: + + :attr:`.InstanceState.was_deleted` + """ state = attributes.instance_state(object) - return state.deleted + return state.was_deleted def randomize_unitofwork(): diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py index b38aefb3d..4dd954fc4 100644 --- a/lib/sqlalchemy/pool.py +++ b/lib/sqlalchemy/pool.py @@ -587,7 +587,12 @@ class _ConnectionRecord(object): if recycle: self.__close() self.info.clear() + + # ensure that if self.__connect() fails, + # we are not referring to the previous stale connection here + self.connection = None self.connection = self.__connect() + if self.__pool.dispatch.connect: self.__pool.dispatch.connect(self.connection, self) return self.connection diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index e8b70061d..fa2cf2399 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -21,6 +21,8 @@ from .expression import ( Update, alias, and_, + any_, + all_, asc, between, bindparam, diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index e9c3d0efa..6766c99b7 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -97,6 +97,8 @@ OPERATORS = { operators.exists: 'EXISTS ', operators.distinct_op: 'DISTINCT ', operators.inv: 'NOT ', + operators.any_op: 'ANY ', + operators.all_op: 'ALL ', # modifiers operators.desc_op: ' DESC', @@ -281,6 +283,8 @@ class _CompileLabel(visitors.Visitable): def type(self): return self.element.type + def self_group(self, **kw): + return self class SQLCompiler(Compiled): @@ -761,6 +765,9 @@ class SQLCompiler(Compiled): x += "END" return x + def visit_type_coerce(self, type_coerce, **kw): + return type_coerce.typed_expression._compiler_dispatch(self, **kw) + def visit_cast(self, cast, **kwargs): return "CAST(%s AS %s)" % \ (cast.clause._compiler_dispatch(self, **kwargs), @@ -768,7 +775,7 @@ class SQLCompiler(Compiled): def visit_over(self, over, **kwargs): return "%s OVER (%s)" % ( - over.func._compiler_dispatch(self, **kwargs), + over.element._compiler_dispatch(self, **kwargs), ' '.join( '%s BY %s' % (word, clause._compiler_dispatch(self, **kwargs)) for word, clause in ( @@ -779,6 +786,12 @@ class SQLCompiler(Compiled): ) ) + def visit_withingroup(self, withingroup, **kwargs): + return "%s WITHIN GROUP (ORDER BY %s)" % ( + withingroup.element._compiler_dispatch(self, **kwargs), + withingroup.order_by._compiler_dispatch(self, **kwargs) + ) + def visit_funcfilter(self, funcfilter, **kwargs): return "%s FILTER (WHERE %s)" % ( funcfilter.func._compiler_dispatch(self, **kwargs), @@ -1270,9 +1283,6 @@ class SQLCompiler(Compiled): return " AS " + alias_name_text def _add_to_result_map(self, keyname, name, objects, type_): - if not self.dialect.case_sensitive: - keyname = keyname.lower() - self._result_columns.append((keyname, name, objects, type_)) def _label_select_column(self, select, column, @@ -1789,9 +1799,9 @@ class SQLCompiler(Compiled): return text def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, - fromhints=None, **kwargs): + fromhints=None, use_schema=True, **kwargs): if asfrom or ashint: - if getattr(table, "schema", None): + if use_schema and getattr(table, "schema", None): ret = self.preparer.quote_schema(table.schema) + \ "." + self.preparer.quote(table.name) else: @@ -1812,6 +1822,22 @@ class SQLCompiler(Compiled): join.onclause._compiler_dispatch(self, **kwargs) ) + def _setup_crud_hints(self, stmt, table_text): + dialect_hints = dict([ + (table, hint_text) + for (table, dialect), hint_text in + stmt._hints.items() + if dialect in ('*', self.dialect.name) + ]) + if stmt.table in dialect_hints: + table_text = self.format_from_hint_text( + table_text, + stmt.table, + dialect_hints[stmt.table], + True + ) + return dialect_hints, table_text + def visit_insert(self, insert_stmt, **kw): self.stack.append( {'correlate_froms': set(), @@ -1853,19 +1879,10 @@ class SQLCompiler(Compiled): table_text = preparer.format_table(insert_stmt.table) if insert_stmt._hints: - dialect_hints = dict([ - (table, hint_text) - for (table, dialect), hint_text in - insert_stmt._hints.items() - if dialect in ('*', self.dialect.name) - ]) - if insert_stmt.table in dialect_hints: - table_text = self.format_from_hint_text( - table_text, - insert_stmt.table, - dialect_hints[insert_stmt.table], - True - ) + dialect_hints, table_text = self._setup_crud_hints( + insert_stmt, table_text) + else: + dialect_hints = None text += table_text @@ -1957,19 +1974,8 @@ class SQLCompiler(Compiled): crud_params = crud._get_crud_params(self, update_stmt, **kw) if update_stmt._hints: - dialect_hints = dict([ - (table, hint_text) - for (table, dialect), hint_text in - update_stmt._hints.items() - if dialect in ('*', self.dialect.name) - ]) - if update_stmt.table in dialect_hints: - table_text = self.format_from_hint_text( - table_text, - update_stmt.table, - dialect_hints[update_stmt.table], - True - ) + dialect_hints, table_text = self._setup_crud_hints( + update_stmt, table_text) else: dialect_hints = None @@ -2038,22 +2044,8 @@ class SQLCompiler(Compiled): self, asfrom=True, iscrud=True) if delete_stmt._hints: - dialect_hints = dict([ - (table, hint_text) - for (table, dialect), hint_text in - delete_stmt._hints.items() - if dialect in ('*', self.dialect.name) - ]) - if delete_stmt.table in dialect_hints: - table_text = self.format_from_hint_text( - table_text, - delete_stmt.table, - dialect_hints[delete_stmt.table], - True - ) - - else: - dialect_hints = None + dialect_hints, table_text = self._setup_crud_hints( + delete_stmt, table_text) text += table_text @@ -2139,11 +2131,11 @@ class DDLCompiler(Compiled): table = create.element preparer = self.dialect.identifier_preparer - text = "\n" + " ".join(['CREATE'] + - table._prefixes + - ['TABLE', - preparer.format_table(table), - "("]) + text = "\nCREATE " + if table._prefixes: + text += " ".join(table._prefixes) + " " + text += "TABLE " + preparer.format_table(table) + " (" + separator = "\n" # if only one primary key, specify it along with the column @@ -2168,10 +2160,10 @@ class DDLCompiler(Compiled): )) const = self.create_table_constraints( - table, _include_foreign_key_constraints= - create.include_foreign_key_constraints) + table, _include_foreign_key_constraints= # noqa + create.include_foreign_key_constraints) if const: - text += ", \n\t" + const + text += separator + "\t" + const text += "\n)%s\n\n" % self.post_create_table(table) return text @@ -2223,7 +2215,7 @@ class DDLCompiler(Compiled): and ( not self.dialect.supports_alter or not getattr(constraint, 'use_alter', False) - )) if p is not None + )) if p is not None ) def visit_drop_table(self, drop): @@ -2299,6 +2291,16 @@ class DDLCompiler(Compiled): text += " INCREMENT BY %d" % create.element.increment if create.element.start is not None: text += " START WITH %d" % create.element.start + if create.element.minvalue is not None: + text += " MINVALUE %d" % create.element.minvalue + if create.element.maxvalue is not None: + text += " MAXVALUE %d" % create.element.maxvalue + if create.element.nominvalue is not None: + text += " NO MINVALUE" + if create.element.nomaxvalue is not None: + text += " NO MAXVALUE" + if create.element.cycle is not None: + text += " CYCLE" return text def visit_drop_sequence(self, drop): @@ -2379,7 +2381,7 @@ class DDLCompiler(Compiled): text += "CONSTRAINT %s " % formatted_name text += "PRIMARY KEY " text += "(%s)" % ', '.join(self.preparer.quote(c.name) - for c in constraint) + for c in constraint.columns_autoinc_first) text += self.define_constraint_deferrability(constraint) return text diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 2e39f6b36..c5495ccde 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -196,8 +196,9 @@ def _scan_insert_from_select_cols( if add_select_cols: values.extend(add_select_cols) compiler._insert_from_select = compiler._insert_from_select._generate() - compiler._insert_from_select._raw_columns += tuple( - expr for col, expr in add_select_cols) + compiler._insert_from_select._raw_columns = \ + tuple(compiler._insert_from_select._raw_columns) + tuple( + expr for col, expr in add_select_cols) def _scan_cols( @@ -208,10 +209,22 @@ def _scan_cols( implicit_return_defaults, postfetch_lastrowid = \ _get_returning_modifiers(compiler, stmt) - cols = stmt.table.columns + if stmt._parameter_ordering: + parameter_ordering = [ + _column_as_key(key) for key in stmt._parameter_ordering + ] + ordered_keys = set(parameter_ordering) + cols = [ + stmt.table.c[key] for key in parameter_ordering + ] + [ + c for c in stmt.table.c if c.key not in ordered_keys + ] + else: + cols = stmt.table.columns for c in cols: col_key = _getattr_col_key(c) + if col_key in parameters and col_key not in check_columns: _append_param_parameter( @@ -248,6 +261,10 @@ def _scan_cols( elif implicit_return_defaults and \ c in implicit_return_defaults: compiler.returning.append(c) + elif c.primary_key and \ + c is not stmt.table._autoincrement_column and \ + not c.nullable: + _raise_pk_with_no_anticipated_value(c) elif compiler.isupdate: _append_param_update( @@ -285,6 +302,22 @@ def _append_param_parameter( def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): + """Create a primary key expression in the INSERT statement and + possibly a RETURNING clause for it. + + If the column has a Python-side default, we will create a bound + parameter for it and "pre-execute" the Python function. If + the column has a SQL expression default, or is a sequence, + we will add it directly into the INSERT statement and add a + RETURNING element to get the new value. If the column has a + server side default or is marked as the "autoincrement" column, + we will add a RETRUNING element to get at the value. + + If all the above tests fail, that indicates a primary key column with no + noted default generation capabilities that has no parameter passed; + raise an exception. + + """ if c.default is not None: if c.default.is_sequence: if compiler.dialect.supports_sequences and \ @@ -303,9 +336,12 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): values.append( (c, _create_prefetch_bind_param(compiler, c)) ) - - else: + elif c is stmt.table._autoincrement_column or c.server_default is not None: compiler.returning.append(c) + elif not c.nullable: + # no .default, no .server_default, not autoincrement, we have + # no indication this primary key column will have any value + _raise_pk_with_no_anticipated_value(c) def _create_prefetch_bind_param(compiler, c, process=True, name=None): @@ -319,6 +355,7 @@ class _multiparam_column(elements.ColumnElement): self.key = "%s_%d" % (original.key, index + 1) self.original = original self.default = original.default + self.type = original.type def __eq__(self, other): return isinstance(other, _multiparam_column) and \ @@ -341,18 +378,46 @@ def _process_multiparam_default_bind(compiler, c, index, kw): def _append_param_insert_pk(compiler, stmt, c, values, kw): + """Create a bound parameter in the INSERT statement to receive a + 'prefetched' default value. + + The 'prefetched' value indicates that we are to invoke a Python-side + default function or expliclt SQL expression before the INSERT statement + proceeds, so that we have a primary key value available. + + if the column has no noted default generation capabilities, it has + no value passed in either; raise an exception. + + """ if ( - (c.default is not None and - (not c.default.is_sequence or - compiler.dialect.supports_sequences)) or - c is stmt.table._autoincrement_column and - (compiler.dialect.supports_sequences or - compiler.dialect. - preexecute_autoincrement_sequences) + ( + # column has a Python-side default + c.default is not None and + ( + # and it won't be a Sequence + not c.default.is_sequence or + compiler.dialect.supports_sequences + ) + ) + or + ( + # column is the "autoincrement column" + c is stmt.table._autoincrement_column and + ( + # and it's either a "sequence" or a + # pre-executable "autoincrement" sequence + compiler.dialect.supports_sequences or + compiler.dialect.preexecute_autoincrement_sequences + ) + ) ): values.append( (c, _create_prefetch_bind_param(compiler, c)) ) + elif c.default is None and c.server_default is None and not c.nullable: + # no .default, no .server_default, not autoincrement, we have + # no indication this primary key column will have any value + _raise_pk_with_no_anticipated_value(c) def _append_param_insert_hasdefault( @@ -428,6 +493,7 @@ def _append_param_update( else: compiler.postfetch.append(c) elif implicit_return_defaults and \ + stmt._return_defaults is not True and \ c in implicit_return_defaults: compiler.returning.append(c) @@ -554,3 +620,24 @@ def _get_returning_modifiers(compiler, stmt): return need_pks, implicit_returning, \ implicit_return_defaults, postfetch_lastrowid + + +def _raise_pk_with_no_anticipated_value(c): + msg = ( + "Column '%s.%s' is marked as a member of the " + "primary key for table '%s', " + "but has no Python-side or server-side default generator indicated, " + "nor does it indicate 'autoincrement=True' or 'nullable=True', " + "and no explicit value is passed. " + "Primary key columns typically may not store NULL." + % + (c.table.fullname, c.name, c.table.fullname)) + if len(c.table.primary_key.columns) > 1: + msg += ( + " Note that as of SQLAlchemy 1.1, 'autoincrement=True' must be " + "indicated explicitly for composite (e.g. multicolumn) primary " + "keys if AUTO_INCREMENT/SERIAL/IDENTITY " + "behavior is expected for one of the columns in the primary key. " + "CREATE TABLE statements are impacted by this change as well on " + "most backends.") + raise exc.CompileError(msg) diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index e77ad765c..68ea5624e 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -14,7 +14,8 @@ from . import operators from .elements import BindParameter, True_, False_, BinaryExpression, \ Null, _const_expr, _clause_element_as_expr, \ ClauseList, ColumnElement, TextClause, UnaryExpression, \ - collate, _is_literal, _literal_as_text, ClauseElement, and_, or_ + collate, _is_literal, _literal_as_text, ClauseElement, and_, or_, \ + Slice, Visitable, _literal_as_binds from .selectable import SelectBase, Alias, Selectable, ScalarSelect @@ -161,6 +162,34 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw): negate=negate_op) +def _getitem_impl(expr, op, other, **kw): + if isinstance(expr.type, type_api.INDEXABLE): + if isinstance(other, slice): + if expr.type.zero_indexes: + other = slice( + other.start + 1, + other.stop + 1, + other.step + ) + other = Slice( + _literal_as_binds( + other.start, name=expr.key, type_=type_api.INTEGERTYPE), + _literal_as_binds( + other.stop, name=expr.key, type_=type_api.INTEGERTYPE), + _literal_as_binds( + other.step, name=expr.key, type_=type_api.INTEGERTYPE) + ) + else: + if expr.type.zero_indexes: + other += 1 + + other = _literal_as_binds( + other, name=expr.key, type_=type_api.INTEGERTYPE) + return _binary_operate(expr, op, other, **kw) + else: + _unsupported_impl(expr, op, other, **kw) + + def _unsupported_impl(expr, op, *arg, **kw): raise NotImplementedError("Operator '%s' is not supported on " "this expression" % op.__name__) @@ -260,7 +289,7 @@ operator_lookup = { "between_op": (_between_impl, ), "notbetween_op": (_between_impl, ), "neg": (_neg_impl,), - "getitem": (_unsupported_impl,), + "getitem": (_getitem_impl,), "lshift": (_unsupported_impl,), "rshift": (_unsupported_impl,), } @@ -280,7 +309,7 @@ def _check_literal(expr, operator, other): if isinstance(other, (SelectBase, Alias)): return other.as_scalar() - elif not isinstance(other, (ColumnElement, TextClause)): + elif not isinstance(other, Visitable): return expr._bind_param(operator, other) else: return other diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 6756f1554..22c534153 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -27,6 +27,7 @@ class UpdateBase(DialectKWArgs, HasPrefixes, Executable, ClauseElement): _execution_options = \ Executable._execution_options.union({'autocommit': True}) _hints = util.immutabledict() + _parameter_ordering = None _prefixes = () def _process_colparams(self, parameters): @@ -39,6 +40,16 @@ class UpdateBase(DialectKWArgs, HasPrefixes, Executable, ClauseElement): else: return p + if self._preserve_parameter_order and parameters is not None: + if not isinstance(parameters, list) or \ + (parameters and not isinstance(parameters[0], tuple)): + raise ValueError( + "When preserve_parameter_order is True, " + "values() only accepts a list of 2-tuples") + self._parameter_ordering = [key for key, value in parameters] + + return dict(parameters), False + if (isinstance(parameters, (list, tuple)) and parameters and isinstance(parameters[0], (list, tuple, dict))): @@ -178,6 +189,7 @@ class ValuesBase(UpdateBase): _supports_multi_parameters = False _has_multi_parameters = False + _preserve_parameter_order = False select = None def __init__(self, table, values, prefixes): @@ -214,23 +226,32 @@ class ValuesBase(UpdateBase): users.update().where(users.c.id==5).values(name="some name") - :param \*args: Alternatively, a dictionary, tuple or list - of dictionaries or tuples can be passed as a single positional - argument in order to form the VALUES or - SET clause of the statement. The single dictionary form - works the same as the kwargs form:: + :param \*args: As an alternative to passing key/value parameters, + a dictionary, tuple, or list of dictionaries or tuples can be passed + as a single positional argument in order to form the VALUES or + SET clause of the statement. The forms that are accepted vary + based on whether this is an :class:`.Insert` or an :class:`.Update` + construct. + + For either an :class:`.Insert` or :class:`.Update` construct, a + single dictionary can be passed, which works the same as that of + the kwargs form:: users.insert().values({"name": "some name"}) - If a tuple is passed, the tuple should contain the same number - of columns as the target :class:`.Table`:: + users.update().values({"name": "some new name"}) + + Also for either form but more typically for the :class:`.Insert` + construct, a tuple that contains an entry for every column in the + table is also accepted:: users.insert().values((5, "some name")) - The :class:`.Insert` construct also supports multiply-rendered VALUES - construct, for those backends which support this SQL syntax - (SQLite, Postgresql, MySQL). This mode is indicated by passing a - list of one or more dictionaries/tuples:: + The :class:`.Insert` construct also supports being passed a list + of dictionaries or full-table-tuples, which on the server will + render the less common SQL syntax of "multiple values" - this + syntax is supported on backends such as SQLite, Postgresql, MySQL, + but not necessarily others:: users.insert().values([ {"name": "some name"}, @@ -238,55 +259,61 @@ class ValuesBase(UpdateBase): {"name": "yet another name"}, ]) - In the case of an :class:`.Update` - construct, only the single dictionary/tuple form is accepted, - else an exception is raised. It is also an exception case to - attempt to mix the single-/multiple- value styles together, - either through multiple :meth:`.ValuesBase.values` calls - or by sending a list + kwargs at the same time. - - .. note:: - - Passing a multiple values list is *not* the same - as passing a multiple values list to the - :meth:`.Connection.execute` method. Passing a list of parameter - sets to :meth:`.ValuesBase.values` produces a construct of this - form:: - - INSERT INTO table (col1, col2, col3) VALUES - (col1_0, col2_0, col3_0), - (col1_1, col2_1, col3_1), - ... - - whereas a multiple list passed to :meth:`.Connection.execute` - has the effect of using the DBAPI - `executemany() <http://www.python.org/dev/peps/pep-0249/#id18>`_ - method, which provides a high-performance system of invoking - a single-row INSERT or single-criteria UPDATE or DELETE statement - many times against a series - of parameter sets. The "executemany" style is supported by - all database backends, and works equally well for INSERT, - UPDATE, and DELETE, as it does not depend on a special SQL - syntax. See :ref:`execute_multiple` for an introduction to - the traditional Core method of multiple parameter set invocation - using this system. - - .. versionadded:: 0.8 - Support for multiple-VALUES INSERT statements. - - .. versionchanged:: 1.0.0 an INSERT that uses a multiple-VALUES - clause, even a list of length one, - implies that the :paramref:`.Insert.inline` flag is set to - True, indicating that the statement will not attempt to fetch - the "last inserted primary key" or other defaults. The statement - deals with an arbitrary number of rows, so the - :attr:`.ResultProxy.inserted_primary_key` accessor does not apply. - - .. versionchanged:: 1.0.0 A multiple-VALUES INSERT now supports - columns with Python side default values and callables in the - same way as that of an "executemany" style of invocation; the - callable is invoked for each row. See :ref:`bug_3288` - for other details. + The above form would render a multiple VALUES statement similar to:: + + INSERT INTO users (name) VALUES + (:name_1), + (:name_2), + (:name_3) + + It is essential to note that **passing multiple values is + NOT the same as using traditional executemany() form**. The above + syntax is a **special** syntax not typically used. To emit an + INSERT statement against mutliple rows, the normal method is + to pass a mutiple values list to the :meth:`.Connection.execute` + method, which is supported by all database backends and is generally + more efficient for a very large number of parameters. + + .. seealso:: + + :ref:`execute_multiple` - an introduction to + the traditional Core method of multiple parameter set + invocation for INSERTs and other statements. + + .. versionchanged:: 1.0.0 an INSERT that uses a multiple-VALUES + clause, even a list of length one, + implies that the :paramref:`.Insert.inline` flag is set to + True, indicating that the statement will not attempt to fetch + the "last inserted primary key" or other defaults. The + statement deals with an arbitrary number of rows, so the + :attr:`.ResultProxy.inserted_primary_key` accessor does not + apply. + + .. versionchanged:: 1.0.0 A multiple-VALUES INSERT now supports + columns with Python side default values and callables in the + same way as that of an "executemany" style of invocation; the + callable is invoked for each row. See :ref:`bug_3288` + for other details. + + The :class:`.Update` construct supports a special form which is a + list of 2-tuples, which when provided must be passed in conjunction + with the + :paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order` + parameter. + This form causes the UPDATE statement to render the SET clauses + using the order of parameters given to :meth:`.Update.values`, rather + than the ordering of columns given in the :class:`.Table`. + + .. versionadded:: 1.0.10 - added support for parameter-ordered + UPDATE statements via the + :paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order` + flag. + + .. seealso:: + + :ref:`updates_order_parameters` - full example of the + :paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order` + flag .. seealso:: @@ -582,6 +609,7 @@ class Update(ValuesBase): prefixes=None, returning=None, return_defaults=False, + preserve_parameter_order=False, **dialect_kw): """Construct an :class:`.Update` object. @@ -644,6 +672,19 @@ class Update(ValuesBase): be available in the dictionary returned from :meth:`.ResultProxy.last_updated_params`. + :param preserve_parameter_order: if True, the update statement is + expected to receive parameters **only** via the :meth:`.Update.values` + method, and they must be passed as a Python ``list`` of 2-tuples. + The rendered UPDATE statement will emit the SET clause for each + referenced column maintaining this order. + + .. versionadded:: 1.0.10 + + .. seealso:: + + :ref:`updates_order_parameters` - full example of the + :paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order` flag + If both ``values`` and compile-time bind parameters are present, the compile-time bind parameters override the information specified within ``values`` on a per-key basis. @@ -685,6 +726,7 @@ class Update(ValuesBase): """ + self._preserve_parameter_order = preserve_parameter_order ValuesBase.__init__(self, table, values, prefixes) self._bind = bind self._returning = returning diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 27ecce2b0..70046c66b 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -124,67 +124,6 @@ def literal(value, type_=None): return BindParameter(None, value, type_=type_, unique=True) -def type_coerce(expression, type_): - """Associate a SQL expression with a particular type, without rendering - ``CAST``. - - E.g.:: - - from sqlalchemy import type_coerce - - stmt = select([type_coerce(log_table.date_string, StringDateTime())]) - - The above construct will produce SQL that is usually otherwise unaffected - by the :func:`.type_coerce` call:: - - SELECT date_string FROM log - - However, when result rows are fetched, the ``StringDateTime`` type - will be applied to result rows on behalf of the ``date_string`` column. - - A type that features bound-value handling will also have that behavior - take effect when literal values or :func:`.bindparam` constructs are - passed to :func:`.type_coerce` as targets. - For example, if a type implements the :meth:`.TypeEngine.bind_expression` - method or :meth:`.TypeEngine.bind_processor` method or equivalent, - these functions will take effect at statement compilation/execution time - when a literal value is passed, as in:: - - # bound-value handling of MyStringType will be applied to the - # literal value "some string" - stmt = select([type_coerce("some string", MyStringType)]) - - :func:`.type_coerce` is similar to the :func:`.cast` function, - except that it does not render the ``CAST`` expression in the resulting - statement. - - :param expression: A SQL expression, such as a :class:`.ColumnElement` - expression or a Python string which will be coerced into a bound literal - value. - - :param type_: A :class:`.TypeEngine` class or instance indicating - the type to which the expression is coerced. - - .. seealso:: - - :func:`.cast` - - """ - type_ = type_api.to_instance(type_) - - if hasattr(expression, '__clause_element__'): - return type_coerce(expression.__clause_element__(), type_) - elif isinstance(expression, BindParameter): - bp = expression._clone() - bp.type = type_ - return bp - elif not isinstance(expression, Visitable): - if expression is None: - return Null() - else: - return literal(expression, type_=type_) - else: - return Label(None, expression, type_=type_) def outparam(key, type_=None): @@ -700,6 +639,8 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity): return AsBoolean(self, operators.istrue, operators.isfalse) + elif (against in (operators.any_op, operators.all_op)): + return Grouping(self) else: return self @@ -715,7 +656,14 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): @util.memoized_property def comparator(self): - return self.type.comparator_factory(self) + try: + comparator_factory = self.type.comparator_factory + except AttributeError: + raise TypeError( + "Object %r associated with '.type' attribute " + "is not a TypeEngine class or object" % self.type) + else: + return comparator_factory(self) def __getattr__(self, key): try: @@ -837,6 +785,16 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): else: return False + def cast(self, type_): + """Produce a type cast, i.e. ``CAST(<expression> AS <type>)``. + + This is a shortcut to the :func:`~.expression.cast` function. + + .. versionadded:: 1.0.7 + + """ + return Cast(self, type_) + def label(self, name): """Produce a column label, i.e. ``<columnname> AS <name>``. @@ -1128,8 +1086,7 @@ class BindParameter(ColumnElement): _compared_to_type.coerce_compared_value( _compared_to_operator, value) else: - self.type = type_api._type_map.get(type(value), - type_api.NULLTYPE) + self.type = type_api._resolve_value_to_type(value) elif isinstance(type_, type): self.type = type_() else: @@ -1144,8 +1101,7 @@ class BindParameter(ColumnElement): cloned.callable = None cloned.required = False if cloned.type is type_api.NULLTYPE: - cloned.type = type_api._type_map.get(type(value), - type_api.NULLTYPE) + cloned.type = type_api._resolve_value_to_type(value) return cloned @property @@ -1840,9 +1796,12 @@ class BooleanClauseList(ClauseList, ColumnElement): def _construct(cls, operator, continue_on, skip_on, *clauses, **kw): convert_clauses = [] - clauses = util.coerce_generator_arg(clauses) + clauses = [ + _expression_literal_as_text(clause) + for clause in + util.coerce_generator_arg(clauses) + ] for clause in clauses: - clause = _expression_literal_as_text(clause) if isinstance(clause, continue_on): continue @@ -2327,6 +2286,109 @@ class Cast(ColumnElement): return self.clause._from_objects +class TypeCoerce(ColumnElement): + """Represent a Python-side type-coercion wrapper. + + :class:`.TypeCoerce` supplies the :func:`.expression.type_coerce` + function; see that function for usage details. + + .. versionchanged:: 1.1 The :func:`.type_coerce` function now produces + a persistent :class:`.TypeCoerce` wrapper object rather than + translating the given object in place. + + .. seealso:: + + :func:`.expression.type_coerce` + + """ + + __visit_name__ = 'type_coerce' + + def __init__(self, expression, type_): + """Associate a SQL expression with a particular type, without rendering + ``CAST``. + + E.g.:: + + from sqlalchemy import type_coerce + + stmt = select([ + type_coerce(log_table.date_string, StringDateTime()) + ]) + + The above construct will produce a :class:`.TypeCoerce` object, which + renders SQL that labels the expression, but otherwise does not + modify its value on the SQL side:: + + SELECT date_string AS anon_1 FROM log + + When result rows are fetched, the ``StringDateTime`` type + will be applied to result rows on behalf of the ``date_string`` column. + The rationale for the "anon_1" label is so that the type-coerced + column remains separate in the list of result columns vs. other + type-coerced or direct values of the target column. In order to + provide a named label for the expression, use + :meth:`.ColumnElement.label`:: + + stmt = select([ + type_coerce( + log_table.date_string, StringDateTime()).label('date') + ]) + + + A type that features bound-value handling will also have that behavior + take effect when literal values or :func:`.bindparam` constructs are + passed to :func:`.type_coerce` as targets. + For example, if a type implements the + :meth:`.TypeEngine.bind_expression` + method or :meth:`.TypeEngine.bind_processor` method or equivalent, + these functions will take effect at statement compilation/execution + time when a literal value is passed, as in:: + + # bound-value handling of MyStringType will be applied to the + # literal value "some string" + stmt = select([type_coerce("some string", MyStringType)]) + + :func:`.type_coerce` is similar to the :func:`.cast` function, + except that it does not render the ``CAST`` expression in the resulting + statement. + + :param expression: A SQL expression, such as a :class:`.ColumnElement` + expression or a Python string which will be coerced into a bound + literal value. + + :param type_: A :class:`.TypeEngine` class or instance indicating + the type to which the expression is coerced. + + .. seealso:: + + :func:`.cast` + + """ + self.type = type_api.to_instance(type_) + self.clause = _literal_as_binds(expression, type_=self.type) + + def _copy_internals(self, clone=_clone, **kw): + self.clause = clone(self.clause, **kw) + self.__dict__.pop('typed_expression', None) + + def get_children(self, **kwargs): + return self.clause, + + @property + def _from_objects(self): + return self.clause._from_objects + + @util.memoized_property + def typed_expression(self): + if isinstance(self.clause, BindParameter): + bp = self.clause._clone() + bp.type = self.type + return bp + else: + return self.clause + + class Extract(ColumnElement): """Represent a SQL EXTRACT clause, ``extract(field FROM expr)``.""" @@ -2668,6 +2730,91 @@ class UnaryExpression(ColumnElement): return self +class CollectionAggregate(UnaryExpression): + """Forms the basis for right-hand collection operator modifiers + ANY and ALL. + + The ANY and ALL keywords are available in different ways on different + backends. On Postgresql, they only work for an ARRAY type. On + MySQL, they only work for subqueries. + + """ + @classmethod + def _create_any(cls, expr): + """Produce an ANY expression. + + This may apply to an array type for some dialects (e.g. postgresql), + or to a subquery for others (e.g. mysql). e.g.:: + + # postgresql '5 = ANY (somearray)' + expr = 5 == any_(mytable.c.somearray) + + # mysql '5 = ANY (SELECT value FROM table)' + expr = 5 == any_(select([table.c.value])) + + .. versionadded:: 1.1 + + .. seealso:: + + :func:`.expression.all_` + + """ + + expr = _literal_as_binds(expr) + + if expr.is_selectable and hasattr(expr, 'as_scalar'): + expr = expr.as_scalar() + expr = expr.self_group() + return CollectionAggregate( + expr, operator=operators.any_op, + type_=type_api.NULLTYPE, wraps_column_expression=False) + + @classmethod + def _create_all(cls, expr): + """Produce an ALL expression. + + This may apply to an array type for some dialects (e.g. postgresql), + or to a subquery for others (e.g. mysql). e.g.:: + + # postgresql '5 = ALL (somearray)' + expr = 5 == all_(mytable.c.somearray) + + # mysql '5 = ALL (SELECT value FROM table)' + expr = 5 == all_(select([table.c.value])) + + .. versionadded:: 1.1 + + .. seealso:: + + :func:`.expression.any_` + + """ + + expr = _literal_as_binds(expr) + if expr.is_selectable and hasattr(expr, 'as_scalar'): + expr = expr.as_scalar() + expr = expr.self_group() + return CollectionAggregate( + expr, operator=operators.all_op, + type_=type_api.NULLTYPE, wraps_column_expression=False) + + # operate and reverse_operate are hardwired to + # dispatch onto the type comparator directly, so that we can + # ensure "reversed" behavior. + def operate(self, op, *other, **kwargs): + if not operators.is_comparison(op): + raise exc.ArgumentError( + "Only comparison operators may be used with ANY/ALL") + kwargs['reverse'] = True + return self.comparator.operate(operators.mirror(op), *other, **kwargs) + + def reverse_operate(self, op, other, **kwargs): + # comparison operators should never call reverse_operate + assert not operators.is_comparison(op) + raise exc.ArgumentError( + "Only comparison operators may be used with ANY/ALL") + + class AsBoolean(UnaryExpression): def __init__(self, element, operator, negate): @@ -2779,6 +2926,32 @@ class BinaryExpression(ColumnElement): return super(BinaryExpression, self)._negate() +class Slice(ColumnElement): + """Represent SQL for a Python array-slice object. + + This is not a specific SQL construct at this level, but + may be interpreted by specific dialects, e.g. Postgresql. + + """ + __visit_name__ = 'slice' + + def __init__(self, start, stop, step): + self.start = start + self.stop = stop + self.step = step + self.type = type_api.NULLTYPE + + def self_group(self, against=None): + assert against is operator.getitem + return self + + +class IndexExpression(BinaryExpression): + """Represent the class of expressions that are like an "index" operation. + """ + pass + + class Grouping(ColumnElement): """Represent a grouping within a column expression""" @@ -2839,21 +3012,21 @@ class Over(ColumnElement): order_by = None partition_by = None - def __init__(self, func, partition_by=None, order_by=None): + def __init__(self, element, partition_by=None, order_by=None): """Produce an :class:`.Over` object against a function. Used against aggregate or so-called "window" functions, for database backends that support window functions. - E.g.:: + :func:`~.expression.over` is usually called using + the :meth:`.FunctionElement.over` method, e.g.:: - from sqlalchemy import over - over(func.row_number(), order_by='x') + func.row_number().over(order_by='x') - Would produce "ROW_NUMBER() OVER(ORDER BY x)". + Would produce ``ROW_NUMBER() OVER(ORDER BY x)``. - :param func: a :class:`.FunctionElement` construct, typically - generated by :data:`~.expression.func`. + :param element: a :class:`.FunctionElement`, :class:`.WithinGroup`, + or other compatible construct. :param partition_by: a column element or string, or a list of such, that will be used as the PARTITION BY clause of the OVER construct. @@ -2866,8 +3039,14 @@ class Over(ColumnElement): .. versionadded:: 0.7 + .. seealso:: + + :data:`.expression.func` + + :func:`.expression.within_group` + """ - self.func = func + self.element = element if order_by is not None: self.order_by = ClauseList( *util.to_list(order_by), @@ -2877,17 +3056,29 @@ class Over(ColumnElement): *util.to_list(partition_by), _literal_as_text=_literal_as_label_reference) + @property + def func(self): + """the element referred to by this :class:`.Over` + clause. + + .. deprecated:: 1.1 the ``func`` element has been renamed to + ``.element``. The two attributes are synonymous though + ``.func`` is read-only. + + """ + return self.element + @util.memoized_property def type(self): - return self.func.type + return self.element.type def get_children(self, **kwargs): return [c for c in - (self.func, self.partition_by, self.order_by) + (self.element, self.partition_by, self.order_by) if c is not None] def _copy_internals(self, clone=_clone, **kw): - self.func = clone(self.func, **kw) + self.element = clone(self.element, **kw) if self.partition_by is not None: self.partition_by = clone(self.partition_by, **kw) if self.order_by is not None: @@ -2897,7 +3088,106 @@ class Over(ColumnElement): def _from_objects(self): return list(itertools.chain( *[c._from_objects for c in - (self.func, self.partition_by, self.order_by) + (self.element, self.partition_by, self.order_by) + if c is not None] + )) + + +class WithinGroup(ColumnElement): + """Represent a WITHIN GROUP (ORDER BY) clause. + + This is a special operator against so-called + so-called "ordered set aggregate" and "hypothetical + set aggregate" functions, including ``percentile_cont()``, + ``rank()``, ``dense_rank()``, etc. + + It's supported only by certain database backends, such as PostgreSQL, + Oracle and MS SQL Server. + + The :class:`.WithinGroup` consturct extracts its type from the + method :meth:`.FunctionElement.within_group_type`. If this returns + ``None``, the function's ``.type`` is used. + + """ + __visit_name__ = 'withingroup' + + order_by = None + + def __init__(self, element, *order_by): + """Produce a :class:`.WithinGroup` object against a function. + + Used against so-called "ordered set aggregate" and "hypothetical + set aggregate" functions, including :class:`.percentile_cont`, + :class:`.rank`, :class:`.dense_rank`, etc. + + :func:`~.expression.within_group` is usually called using + the :meth:`.FunctionElement.within_group` method, e.g.:: + + from sqlalchemy import within_group + stmt = select([ + department.c.id, + func.percentile_cont(0.5).within_group( + department.c.salary.desc() + ) + ]) + + The above statement would produce SQL similar to + ``SELECT department.id, percentile_cont(0.5) + WITHIN GROUP (ORDER BY department.salary DESC)``. + + :param element: a :class:`.FunctionElement` construct, typically + generated by :data:`~.expression.func`. + :param \*order_by: one or more column elements that will be used + as the ORDER BY clause of the WITHIN GROUP construct. + + .. versionadded:: 1.1 + + .. seealso:: + + :data:`.expression.func` + + :func:`.expression.over` + + """ + self.element = element + if order_by is not None: + self.order_by = ClauseList( + *util.to_list(order_by), + _literal_as_text=_literal_as_label_reference) + + def over(self, partition_by=None, order_by=None): + """Produce an OVER clause against this :class:`.WithinGroup` + construct. + + This function has the same signature as that of + :meth:`.FunctionElement.over`. + + """ + return Over(self, partition_by=partition_by, order_by=order_by) + + @util.memoized_property + def type(self): + wgt = self.element.within_group_type(self) + if wgt is not None: + return wgt + else: + return self.element.type + + def get_children(self, **kwargs): + return [c for c in + (self.func, self.order_by) + if c is not None] + + def _copy_internals(self, clone=_clone, **kw): + self.element = clone(self.element, **kw) + if self.order_by is not None: + self.order_by = clone(self.order_by, **kw) + + @property + def _from_objects(self): + return list(itertools.chain( + *[c._from_objects for c in + (self.element, self.order_by) if c is not None] )) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 74b827d7e..27fae8ca4 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -15,7 +15,7 @@ class. """ __all__ = [ - 'Alias', 'ClauseElement', 'ColumnCollection', 'ColumnElement', + 'Alias', 'Any', 'All', 'ClauseElement', 'ColumnCollection', 'ColumnElement', 'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join', 'Select', 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc', 'between', 'bindparam', 'case', 'cast', 'column', 'delete', 'desc', 'distinct', @@ -24,19 +24,19 @@ __all__ = [ 'literal', 'literal_column', 'not_', 'null', 'nullsfirst', 'nullslast', 'or_', 'outparam', 'outerjoin', 'over', 'select', 'subquery', 'table', 'text', - 'tuple_', 'type_coerce', 'union', 'union_all', 'update'] + 'tuple_', 'type_coerce', 'union', 'union_all', 'update', 'within_group'] from .visitors import Visitable from .functions import func, modifier, FunctionElement, Function from ..util.langhelpers import public_factory from .elements import ClauseElement, ColumnElement,\ - BindParameter, UnaryExpression, BooleanClauseList, \ + BindParameter, CollectionAggregate, UnaryExpression, BooleanClauseList, \ Label, Cast, Case, ColumnClause, TextClause, Over, Null, \ True_, False_, BinaryExpression, Tuple, TypeClause, Extract, \ - Grouping, not_, \ + Grouping, WithinGroup, not_, \ collate, literal_column, between,\ - literal, outparam, type_coerce, ClauseList, FunctionFilter + literal, outparam, TypeCoerce, ClauseList, FunctionFilter from .elements import SavepointClause, RollbackToSavepointClause, \ ReleaseSavepointClause @@ -57,6 +57,8 @@ from .dml import Insert, Update, Delete, UpdateBase, ValuesBase # the functions to be available in the sqlalchemy.sql.* namespace and # to be auto-cross-documenting from the function to the class itself. +all_ = public_factory(CollectionAggregate._create_all, ".expression.all_") +any_ = public_factory(CollectionAggregate._create_any, ".expression.any_") and_ = public_factory(BooleanClauseList.and_, ".expression.and_") or_ = public_factory(BooleanClauseList.or_, ".expression.or_") bindparam = public_factory(BindParameter, ".expression.bindparam") @@ -65,6 +67,7 @@ text = public_factory(TextClause._create_text, ".expression.text") table = public_factory(TableClause, ".expression.table") column = public_factory(ColumnClause, ".expression.column") over = public_factory(Over, ".expression.over") +within_group = public_factory(WithinGroup, ".expression.within_group") label = public_factory(Label, ".expression.label") case = public_factory(Case, ".expression.case") cast = public_factory(Cast, ".expression.cast") @@ -89,6 +92,7 @@ asc = public_factory(UnaryExpression._create_asc, ".expression.asc") desc = public_factory(UnaryExpression._create_desc, ".expression.desc") distinct = public_factory( UnaryExpression._create_distinct, ".expression.distinct") +type_coerce = public_factory(TypeCoerce, ".expression.type_coerce") true = public_factory(True_._instance, ".expression.true") false = public_factory(False_._instance, ".expression.false") null = public_factory(Null._instance, ".expression.null") diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 538a2c549..6cfbd12b3 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -12,9 +12,9 @@ from . import sqltypes, schema from .base import Executable, ColumnCollection from .elements import ClauseList, Cast, Extract, _literal_as_binds, \ literal_column, _type_from_args, ColumnElement, _clone,\ - Over, BindParameter, FunctionFilter + Over, BindParameter, FunctionFilter, Grouping, WithinGroup from .selectable import FromClause, Select, Alias - +from . import util as sqlutil from . import operators from .visitors import VisitableType from .. import util @@ -116,6 +116,21 @@ class FunctionElement(Executable, ColumnElement, FromClause): """ return Over(self, partition_by=partition_by, order_by=order_by) + def within_group(self, *order_by): + """Produce a WITHIN GROUP (ORDER BY expr) clause against this function. + + Used against so-called "ordered set aggregate" and "hypothetical + set aggregate" functions, including :class:`.percentile_cont`, + :class:`.rank`, :class:`.dense_rank`, etc. + + See :func:`~.expression.within_group` for a full description. + + .. versionadded:: 1.1 + + + """ + return WithinGroup(self, *order_by) + def filter(self, *criterion): """Produce a FILTER clause against this function. @@ -157,6 +172,18 @@ class FunctionElement(Executable, ColumnElement, FromClause): self._reset_exported() FunctionElement.clauses._reset(self) + def within_group_type(self, within_group): + """For types that define their return type as based on the criteria + within a WITHIN GROUP (ORDER BY) expression, called by the + :class:`.WithinGroup` construct. + + Returns None by default, in which case the function's normal ``.type`` + is used. + + """ + + return None + def alias(self, name=None, flat=False): """Produce a :class:`.Alias` construct against this :class:`.FunctionElement`. @@ -233,6 +260,16 @@ class FunctionElement(Executable, ColumnElement, FromClause): return BindParameter(None, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True) + def self_group(self, against=None): + # for the moment, we are parenthesizing all array-returning + # expressions against getitem. This may need to be made + # more portable if in the future we support other DBs + # besides postgresql. + if against is operators.getitem: + return Grouping(self) + else: + return super(FunctionElement, self).self_group(against=against) + class _FunctionGenerator(object): """Generate :class:`.Function` objects based on getattr calls.""" @@ -483,7 +520,7 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)): def __init__(self, *args, **kwargs): parsed_args = kwargs.pop('_parsed_args', None) if parsed_args is None: - parsed_args = [_literal_as_binds(c) for c in args] + parsed_args = [_literal_as_binds(c, self.name) for c in args] self.packagenames = [] self._bind = kwargs.get('bind', None) self.clause_expr = ClauseList( @@ -528,10 +565,10 @@ class ReturnTypeFromArgs(GenericFunction): """Define a function whose return type is the same as its arguments.""" def __init__(self, *args, **kwargs): - args = [_literal_as_binds(c) for c in args] + args = [_literal_as_binds(c, self.name) for c in args] kwargs.setdefault('type_', _type_from_args(args)) kwargs['_parsed_args'] = args - GenericFunction.__init__(self, *args, **kwargs) + super(ReturnTypeFromArgs, self).__init__(*args, **kwargs) class coalesce(ReturnTypeFromArgs): @@ -579,7 +616,7 @@ class count(GenericFunction): def __init__(self, expression=None, **kwargs): if expression is None: expression = literal_column('*') - GenericFunction.__init__(self, expression, **kwargs) + super(count, self).__init__(expression, **kwargs) class current_date(AnsiFunction): @@ -616,3 +653,150 @@ class sysdate(AnsiFunction): class user(AnsiFunction): type = sqltypes.String + + +class array_agg(GenericFunction): + """support for the ARRAY_AGG function. + + The ``func.array_agg(expr)`` construct returns an expression of + type :class:`.Array`. + + e.g.:: + + stmt = select([func.array_agg(table.c.values)[2:5]]) + + .. versionadded:: 1.1 + + .. seealso:: + + :func:`.postgresql.array_agg` - PostgreSQL-specific version that + returns :class:`.ARRAY`, which has PG-specific operators added. + + """ + + type = sqltypes.Array + + def __init__(self, *args, **kwargs): + args = [_literal_as_binds(c) for c in args] + kwargs.setdefault('type_', self.type(_type_from_args(args))) + kwargs['_parsed_args'] = args + super(array_agg, self).__init__(*args, **kwargs) + + +class OrderedSetAgg(GenericFunction): + """Define a function where the return type is based on the sort + expression type as defined by the expression passed to the + :meth:`.FunctionElement.within_group` method.""" + + array_for_multi_clause = False + + def within_group_type(self, within_group): + func_clauses = self.clause_expr.element + order_by = sqlutil.unwrap_order_by(within_group.order_by) + if self.array_for_multi_clause and len(func_clauses.clauses) > 1: + return sqltypes.Array(order_by[0].type) + else: + return order_by[0].type + + +class mode(OrderedSetAgg): + """implement the ``mode`` ordered-set aggregate function. + + This function must be used with the :meth:`.FunctionElement.within_group` + modifier to supply a sort expression to operate upon. + + The return type of this function is the same as the sort expression. + + .. versionadded:: 1.1 + + """ + + +class percentile_cont(OrderedSetAgg): + """implement the ``percentile_cont`` ordered-set aggregate function. + + This function must be used with the :meth:`.FunctionElement.within_group` + modifier to supply a sort expression to operate upon. + + The return type of this function is the same as the sort expression, + or if the arguments are an array, an :class:`.Array` of the sort + expression's type. + + .. versionadded:: 1.1 + + """ + + array_for_multi_clause = True + + +class percentile_disc(OrderedSetAgg): + """implement the ``percentile_disc`` ordered-set aggregate function. + + This function must be used with the :meth:`.FunctionElement.within_group` + modifier to supply a sort expression to operate upon. + + The return type of this function is the same as the sort expression, + or if the arguments are an array, an :class:`.Array` of the sort + expression's type. + + .. versionadded:: 1.1 + + """ + + array_for_multi_clause = True + + +class rank(GenericFunction): + """Implement the ``rank`` hypothetical-set aggregate function. + + This function must be used with the :meth:`.FunctionElement.within_group` + modifier to supply a sort expression to operate upon. + + The return type of this function is :class:`.Integer`. + + .. versionadded:: 1.1 + + """ + type = sqltypes.Integer() + + +class dense_rank(GenericFunction): + """Implement the ``dense_rank`` hypothetical-set aggregate function. + + This function must be used with the :meth:`.FunctionElement.within_group` + modifier to supply a sort expression to operate upon. + + The return type of this function is :class:`.Integer`. + + .. versionadded:: 1.1 + + """ + type = sqltypes.Integer() + + +class percent_rank(GenericFunction): + """Implement the ``percent_rank`` hypothetical-set aggregate function. + + This function must be used with the :meth:`.FunctionElement.within_group` + modifier to supply a sort expression to operate upon. + + The return type of this function is :class:`.Numeric`. + + .. versionadded:: 1.1 + + """ + type = sqltypes.Numeric() + + +class cume_dist(GenericFunction): + """Implement the ``cume_dist`` hypothetical-set aggregate function. + + This function must be used with the :meth:`.FunctionElement.within_group` + modifier to supply a sort expression to operate upon. + + The return type of this function is :class:`.Numeric`. + + .. versionadded:: 1.1 + + """ + type = sqltypes.Numeric() diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 51f162c98..da3576466 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -214,10 +214,13 @@ class custom_op(object): """ __name__ = 'custom_op' - def __init__(self, opstring, precedence=0, is_comparison=False): + def __init__( + self, opstring, precedence=0, is_comparison=False, + natural_self_precedent=False): self.opstring = opstring self.precedence = precedence self.is_comparison = is_comparison + self.natural_self_precedent = natural_self_precedent def __eq__(self, other): return isinstance(other, custom_op) and \ @@ -597,6 +600,14 @@ class ColumnOperators(Operators): """ return self.reverse_operate(div, other) + def __rmod__(self, other): + """Implement the ``%`` operator in reverse. + + See :meth:`.ColumnOperators.__mod__`. + + """ + return self.reverse_operate(mod, other) + def between(self, cleft, cright, symmetric=False): """Produce a :func:`~.expression.between` clause against the parent object, given the lower and upper range. @@ -611,6 +622,24 @@ class ColumnOperators(Operators): """ return self.operate(distinct_op) + def any_(self): + """Produce a :func:`~.expression.any_` clause against the + parent object. + + .. versionadded:: 1.1 + + """ + return self.operate(any_op) + + def all_(self): + """Produce a :func:`~.expression.all_` clause against the + parent object. + + .. versionadded:: 1.1 + + """ + return self.operate(all_op) + def __add__(self, other): """Implement the ``+`` operator. @@ -744,6 +773,14 @@ def distinct_op(a): return a.distinct() +def any_op(a): + return a.any_() + + +def all_op(a): + return a.all_() + + def startswith_op(a, b, escape=None): return a.startswith(b, escape=escape) @@ -818,6 +855,28 @@ def is_ordering_modifier(op): return op in (asc_op, desc_op, nullsfirst_op, nullslast_op) + +def is_natural_self_precedent(op): + return op in _natural_self_precedent or \ + isinstance(op, custom_op) and op.natural_self_precedent + +_mirror = { + gt: lt, + ge: le, + lt: gt, + le: ge +} + + +def mirror(op): + """rotate a comparison operator 180 degrees. + + Note this is not the same as negation. + + """ + return _mirror.get(op, op) + + _associative = _commutative.union([concat_op, and_, or_]) _natural_self_precedent = _associative.union([getitem]) @@ -826,12 +885,15 @@ parenthesize (a op b). """ + _asbool = util.symbol('_asbool', canonical=-10) _smallest = util.symbol('_smallest', canonical=-100) _largest = util.symbol('_largest', canonical=100) _PRECEDENCE = { from_: 15, + any_op: 15, + all_op: 15, getitem: 15, mul: 8, truediv: 8, @@ -885,7 +947,7 @@ _PRECEDENCE = { def is_precedent(operator, against): - if operator is against and operator in _natural_self_precedent: + if operator is against and is_natural_self_precedent(operator): return False else: return (_PRECEDENCE.get(operator, diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index a8989627d..42dbe72b2 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -572,18 +572,9 @@ class Table(DialectKWArgs, SchemaItem, TableClause): def _init_collections(self): pass - @util.memoized_property + @property def _autoincrement_column(self): - for col in self.primary_key: - if (col.autoincrement and col.type._type_affinity is not None and - issubclass(col.type._type_affinity, - type_api.INTEGERTYPE._type_affinity) and - (not col.foreign_keys or - col.autoincrement == 'ignore_fk') and - isinstance(col.default, (type(None), Sequence)) and - (col.server_default is None or - col.server_default.reflected)): - return col + return self.primary_key._autoincrement_column @property def key(self): @@ -913,17 +904,40 @@ class Column(SchemaItem, ColumnClause): argument is available such as ``server_default``, ``default`` and ``unique``. - :param autoincrement: This flag may be set to ``False`` to - indicate an integer primary key column that should not be - considered to be the "autoincrement" column, that is - the integer primary key column which generates values - implicitly upon INSERT and whose value is usually returned - via the DBAPI cursor.lastrowid attribute. It defaults - to ``True`` to satisfy the common use case of a table - with a single integer primary key column. If the table - has a composite primary key consisting of more than one - integer column, set this flag to True only on the - column that should be considered "autoincrement". + :param autoincrement: Set up "auto increment" semantics for an integer + primary key column. The default value is the string ``"auto"`` + which indicates that a single-column primary key that is of + an INTEGER type with no stated client-side or python-side defaults + should receive auto increment semantics automatically; + all other varieties of primary key columns will not. This + includes that :term:`DDL` such as Postgresql SERIAL or MySQL + AUTO_INCREMENT will be emitted for this column during a table + create, as well as that the column is assumed to generate new + integer primary key values when an INSERT statement invokes which + will be retrieved by the dialect. + + The flag may be set to ``True`` to indicate that a column which + is part of a composite (e.g. multi-column) primary key should + have autoincrement semantics, though note that only one column + within a primary key may have this setting. It can also + be set to ``True`` to indicate autoincrement semantics on a + column that has a client-side or server-side default configured, + however note that not all dialects can accommodate all styles + of default as an "autoincrement". It can also be + set to ``False`` on a single-column primary key that has a + datatype of INTEGER in order to disable auto increment semantics + for that column. + + .. versionchanged:: 1.1 The autoincrement flag now defaults to + ``"auto"`` which indicates autoincrement semantics by default + for single-column integer primary keys only; for composite + (multi-column) primary keys, autoincrement is never implicitly + enabled; as always, ``autoincrement=True`` will allow for + at most one of those columns to be an "autoincrement" column. + ``autoincrement=True`` may also be set on a :class:`.Column` + that has an explicit client-side or server-side default, + subject to limitations of the backend database and dialect. + The setting *only* has an effect for columns which are: @@ -940,11 +954,8 @@ class Column(SchemaItem, ColumnClause): primary_key=True, autoincrement='ignore_fk') It is typically not desirable to have "autoincrement" enabled - on such a column as its value intends to mirror that of a - primary key column elsewhere. - - * have no server side or client side defaults (with the exception - of Postgresql SERIAL). + on a column that refers to another via foreign key, as such a column + is required to refer to a value that originates from elsewhere. The setting has these two effects on columns that meet the above criteria: @@ -961,20 +972,15 @@ class Column(SchemaItem, ColumnClause): :ref:`sqlite_autoincrement` - * The column will be considered to be available as - cursor.lastrowid or equivalent, for those dialects which - "post fetch" newly inserted identifiers after a row has - been inserted (SQLite, MySQL, MS-SQL). It does not have - any effect in this regard for databases that use sequences - to generate primary key identifiers (i.e. Firebird, Postgresql, - Oracle). - - .. versionchanged:: 0.7.4 - ``autoincrement`` accepts a special value ``'ignore_fk'`` - to indicate that autoincrementing status regardless of foreign - key references. This applies to certain composite foreign key - setups, such as the one demonstrated in the ORM documentation - at :ref:`post_update`. + * The column will be considered to be available using an + "autoincrement" method specific to the backend database, such + as calling upon ``cursor.lastrowid``, using RETURNING in an + INSERT statement to get at a sequence-generated value, or using + special functions such as "SELECT scope_identity()". + These methods are highly specific to the DBAPIs and databases in + use and vary greatly, so care should be taken when associating + ``autoincrement=True`` with a custom default generation function. + :param default: A scalar, Python callable, or :class:`.ColumnElement` expression representing the @@ -984,8 +990,12 @@ class Column(SchemaItem, ColumnClause): a positional argument; see that class for full detail on the structure of the argument. - Contrast this argument to ``server_default`` which creates a - default generator on the database side. + Contrast this argument to :paramref:`.Column.server_default` + which creates a default generator on the database side. + + .. seealso:: + + :ref:`metadata_defaults_toplevel` :param doc: optional String that can be used by the ORM or similar to document attributes. This attribute does not render SQL @@ -1051,6 +1061,10 @@ class Column(SchemaItem, ColumnClause): construct does not specify any DDL and the implementation is left to the database, such as via a trigger. + .. seealso:: + + :ref:`server_defaults` + :param server_onupdate: A :class:`.FetchedValue` instance representing a database-side default generation function. This indicates to SQLAlchemy that a newly generated value will be @@ -1128,7 +1142,7 @@ class Column(SchemaItem, ColumnClause): self.system = kwargs.pop('system', False) self.doc = kwargs.pop('doc', None) self.onupdate = kwargs.pop('onupdate', None) - self.autoincrement = kwargs.pop('autoincrement', True) + self.autoincrement = kwargs.pop('autoincrement', "auto") self.constraints = set() self.foreign_keys = set() @@ -1263,12 +1277,12 @@ class Column(SchemaItem, ColumnClause): if self.primary_key: table.primary_key._replace(self) - Table._autoincrement_column._reset(table) elif self.key in table.primary_key: raise exc.ArgumentError( "Trying to redefine primary-key column '%s' as a " "non-primary-key column on table '%s'" % ( self.key, table.fullname)) + self.table = table if self.index: @@ -1981,13 +1995,14 @@ class ColumnDefault(DefaultGenerator): try: argspec = util.get_callable_argspec(fn, no_self=True) except TypeError: - return lambda ctx: fn() + return util.wrap_callable(lambda ctx: fn(), fn) defaulted = argspec[3] is not None and len(argspec[3]) or 0 positionals = len(argspec[0]) - defaulted if positionals == 0: - return lambda ctx: fn() + return util.wrap_callable(lambda ctx: fn(), fn) + elif positionals == 1: return fn else: @@ -2040,8 +2055,9 @@ class Sequence(DefaultGenerator): is_sequence = True - def __init__(self, name, start=None, increment=None, schema=None, - optional=False, quote=None, metadata=None, + def __init__(self, name, start=None, increment=None, minvalue=None, + maxvalue=None, nominvalue=None, nomaxvalue=None, cycle=None, + schema=None, optional=False, quote=None, metadata=None, quote_schema=None, for_update=False): """Construct a :class:`.Sequence` object. @@ -2057,6 +2073,53 @@ class Sequence(DefaultGenerator): the database as the value of the "INCREMENT BY" clause. If ``None``, the clause is omitted, which on most platforms indicates an increment of 1. + :param minvalue: the minimum value of the sequence. This + value is used when the CREATE SEQUENCE command is emitted to + the database as the value of the "MINVALUE" clause. If ``None``, + the clause is omitted, which on most platforms indicates a + minvalue of 1 and -2^63-1 for ascending and descending sequences, + respectively. + + .. versionadded:: 1.0.7 + + :param maxvalue: the maximum value of the sequence. This + value is used when the CREATE SEQUENCE command is emitted to + the database as the value of the "MAXVALUE" clause. If ``None``, + the clause is omitted, which on most platforms indicates a + maxvalue of 2^63-1 and -1 for ascending and descending sequences, + respectively. + + .. versionadded:: 1.0.7 + + :param nominvalue: no minimum value of the sequence. This + value is used when the CREATE SEQUENCE command is emitted to + the database as the value of the "NO MINVALUE" clause. If ``None``, + the clause is omitted, which on most platforms indicates a + minvalue of 1 and -2^63-1 for ascending and descending sequences, + respectively. + + .. versionadded:: 1.0.7 + + :param nomaxvalue: no maximum value of the sequence. This + value is used when the CREATE SEQUENCE command is emitted to + the database as the value of the "NO MAXVALUE" clause. If ``None``, + the clause is omitted, which on most platforms indicates a + maxvalue of 2^63-1 and -1 for ascending and descending sequences, + respectively. + + .. versionadded:: 1.0.7 + + :param cycle: allows the sequence to wrap around when the maxvalue + or minvalue has been reached by an ascending or descending sequence + respectively. This value is used when the CREATE SEQUENCE command + is emitted to the database as the "CYCLE" clause. If the limit is + reached, the next number generated will be the minvalue or maxvalue, + respectively. If cycle=False (the default) any calls to nextval + after the sequence has reached its maximum value will return an + error. + + .. versionadded:: 1.0.7 + :param schema: Optional schema name for the sequence, if located in a schema other than the default. :param optional: boolean value, when ``True``, indicates that this @@ -2101,6 +2164,11 @@ class Sequence(DefaultGenerator): self.name = quoted_name(name, quote) self.start = start self.increment = increment + self.minvalue = minvalue + self.maxvalue = maxvalue + self.nominvalue = nominvalue + self.nomaxvalue = nomaxvalue + self.cycle = cycle self.optional = optional if metadata is not None and schema is None and metadata.schema: self.schema = schema = metadata.schema @@ -2972,11 +3040,77 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): self.columns.extend(columns) + PrimaryKeyConstraint._autoincrement_column._reset(self) self._set_parent_with_dispatch(self.table) def _replace(self, col): + PrimaryKeyConstraint._autoincrement_column._reset(self) self.columns.replace(col) + @property + def columns_autoinc_first(self): + autoinc = self._autoincrement_column + + if autoinc is not None: + return [autoinc] + [c for c in self.columns if c is not autoinc] + else: + return list(self.columns) + + @util.memoized_property + def _autoincrement_column(self): + + def _validate_autoinc(col, autoinc_true): + if col.type._type_affinity is None or not issubclass( + col.type._type_affinity, + type_api.INTEGERTYPE._type_affinity): + if autoinc_true: + raise exc.ArgumentError( + "Column type %s on column '%s' is not " + "compatible with autoincrement=True" % ( + col.type, + col + )) + else: + return False + elif not isinstance(col.default, (type(None), Sequence)) and \ + not autoinc_true: + return False + elif col.server_default is not None and not autoinc_true: + return False + elif ( + col.foreign_keys and col.autoincrement + not in (True, 'ignore_fk')): + return False + return True + + if len(self.columns) == 1: + col = list(self.columns)[0] + + if col.autoincrement is True: + _validate_autoinc(col, True) + return col + elif ( + col.autoincrement in ('auto', 'ignore_fk') and + _validate_autoinc(col, False) + ): + return col + + else: + autoinc = None + for col in self.columns: + if col.autoincrement is True: + _validate_autoinc(col, True) + if autoinc is not None: + raise exc.ArgumentError( + "Only one Column may be marked " + "autoincrement=True, found both %s and %s." % + (col.name, autoinc.name) + ) + else: + autoinc = col + + return autoinc + class UniqueConstraint(ColumnCollectionConstraint): """A table-level UNIQUE constraint. diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 245c54817..73341053d 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -224,7 +224,7 @@ class HasSuffixes(object): stmt = select([col1, col2]).cte().suffix_with( "cycle empno set y_cycle to 1 default 0", dialect="oracle") - Multiple prefixes can be specified by multiple calls + Multiple suffixes can be specified by multiple calls to :meth:`.suffix_with`. :param \*expr: textual or :class:`.ClauseElement` construct which @@ -1101,6 +1101,14 @@ class Alias(FromClause): or 'anon')) self.name = name + def self_group(self, target=None): + if isinstance(target, CompoundSelect) and \ + isinstance(self.original, Select) and \ + self.original._needs_parens_for_grouping(): + return FromGrouping(self) + + return super(Alias, self).self_group(target) + @property def description(self): if util.py3k: @@ -3208,6 +3216,13 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): return None return None + def _needs_parens_for_grouping(self): + return ( + self._limit_clause is not None or + self._offset_clause is not None or + bool(self._order_by_clause.clauses) + ) + def self_group(self, against=None): """return a 'grouping' construct as per the ClauseElement specification. @@ -3217,7 +3232,8 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): expressions and should not require explicit use. """ - if isinstance(against, CompoundSelect): + if isinstance(against, CompoundSelect) and \ + not self._needs_parens_for_grouping(): return self return FromGrouping(self) diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 7e2e601e2..4abb9b15a 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -13,10 +13,11 @@ import datetime as dt import codecs from .type_api import TypeEngine, TypeDecorator, to_instance -from .elements import quoted_name, type_coerce, _defer_name +from .elements import quoted_name, TypeCoerce as type_coerce, _defer_name from .. import exc, util, processors from .base import _bind_or_error, SchemaEventTarget from . import operators +from .. import inspection from .. import event from ..util import pickle import decimal @@ -68,7 +69,39 @@ class Concatenable(object): )): return operators.concat_op, self.expr.type else: - return op, self.expr.type + return super(Concatenable.Comparator, self)._adapt_expression( + op, other_comparator) + + comparator_factory = Comparator + + +class Indexable(object): + """A mixin that marks a type as supporting indexing operations, + such as array or JSON structures. + + + .. versionadded:: 1.1.0 + + + """ + + zero_indexes = False + """if True, Python zero-based indexes should be interpreted as one-based + on the SQL expression side.""" + + class Comparator(TypeEngine.Comparator): + + def _setup_getitem(self, index): + raise NotImplementedError() + + def __getitem__(self, index): + operator, adjusted_right_expr, result_type = \ + self._setup_getitem(index) + return self.operate( + operator, + adjusted_right_expr, + result_type=result_type + ) comparator_factory = Comparator @@ -215,9 +248,6 @@ class String(Concatenable, TypeEngine): self.convert_unicode != 'force_nocheck' ) if needs_convert: - to_unicode = processors.to_unicode_processor_factory( - dialect.encoding, self.unicode_error) - if needs_isinstance: return processors.to_conditional_unicode_processor_factory( dialect.encoding, self.unicode_error) @@ -1466,6 +1496,246 @@ class Interval(_DateAffinity, TypeDecorator): return self.impl.coerce_compared_value(op, value) +class Array(Indexable, Concatenable, TypeEngine): + """Represent a SQL Array type. + + .. note:: This type serves as the basis for all ARRAY operations. + However, currently **only the Postgresql backend has support + for SQL arrays in SQLAlchemy**. It is recommended to use the + :class:`.postgresql.ARRAY` type directly when using ARRAY types + with PostgreSQL, as it provides additional operators specific + to that backend. + + :class:`.Array` is part of the Core in support of various SQL standard + functions such as :class:`.array_agg` which explicitly involve arrays; + however, with the exception of the PostgreSQL backend and possibly + some third-party dialects, no other SQLAlchemy built-in dialect has + support for this type. + + An :class:`.Array` type is constructed given the "type" + of element:: + + mytable = Table("mytable", metadata, + Column("data", Array(Integer)) + ) + + The above type represents an N-dimensional array, + meaning a supporting backend such as Postgresql will interpret values + with any number of dimensions automatically. To produce an INSERT + construct that passes in a 1-dimensional array of integers:: + + connection.execute( + mytable.insert(), + data=[1,2,3] + ) + + The :class:`.Array` type can be constructed given a fixed number + of dimensions:: + + mytable = Table("mytable", metadata, + Column("data", Array(Integer, dimensions=2)) + ) + + Sending a number of dimensions is optional, but recommended if the + datatype is to represent arrays of more than one dimension. This number + is used: + + * When emitting the type declaration itself to the database, e.g. + ``INTEGER[][]`` + + * When translating Python values to database values, and vice versa, e.g. + an ARRAY of :class:`.Unicode` objects uses this number to efficiently + access the string values inside of array structures without resorting + to per-row type inspection + + * When used with the Python ``getitem`` accessor, the number of dimensions + serves to define the kind of type that the ``[]`` operator should + return, e.g. for an ARRAY of INTEGER with two dimensions:: + + >>> expr = table.c.column[5] # returns ARRAY(Integer, dimensions=1) + >>> expr = expr[6] # returns Integer + + For 1-dimensional arrays, an :class:`.Array` instance with no + dimension parameter will generally assume single-dimensional behaviors. + + SQL expressions of type :class:`.Array` have support for "index" and + "slice" behavior. The Python ``[]`` operator works normally here, given + integer indexes or slices. Arrays default to 1-based indexing. + The operator produces binary expression + constructs which will produce the appropriate SQL, both for + SELECT statements:: + + select([mytable.c.data[5], mytable.c.data[2:7]]) + + as well as UPDATE statements when the :meth:`.Update.values` method + is used:: + + mytable.update().values({ + mytable.c.data[5]: 7, + mytable.c.data[2:7]: [1, 2, 3] + }) + + The :class:`.Array` type also provides for the operators + :meth:`.Array.Comparator.any` and :meth:`.Array.Comparator.all`. + The PostgreSQL-specific version of :class:`.Array` also provides additional + operators. + + .. versionadded:: 1.1.0 + + .. seealso:: + + :class:`.postgresql.ARRAY` + + """ + __visit_name__ = 'ARRAY' + + class Comparator(Indexable.Comparator, Concatenable.Comparator): + + """Define comparison operations for :class:`.Array`. + + More operators are available on the dialect-specific form + of this type. See :class:`.postgresql.ARRAY.Comparator`. + + """ + + def _setup_getitem(self, index): + if isinstance(index, slice): + return_type = self.type + elif self.type.dimensions is None or self.type.dimensions == 1: + return_type = self.type.item_type + else: + adapt_kw = {'dimensions': self.type.dimensions - 1} + return_type = self.type.adapt(self.type.__class__, **adapt_kw) + + return operators.getitem, index, return_type + + @util.dependencies("sqlalchemy.sql.elements") + def any(self, elements, other, operator=None): + """Return ``other operator ANY (array)`` clause. + + Argument places are switched, because ANY requires array + expression to be on the right hand-side. + + E.g.:: + + from sqlalchemy.sql import operators + + conn.execute( + select([table.c.data]).where( + table.c.data.any(7, operator=operators.lt) + ) + ) + + :param other: expression to be compared + :param operator: an operator object from the + :mod:`sqlalchemy.sql.operators` + package, defaults to :func:`.operators.eq`. + + .. seealso:: + + :func:`.sql.expression.any_` + + :meth:`.Array.Comparator.all` + + """ + operator = operator if operator else operators.eq + return operator( + elements._literal_as_binds(other), + elements.CollectionAggregate._create_any(self.expr) + ) + + @util.dependencies("sqlalchemy.sql.elements") + def all(self, elements, other, operator=None): + """Return ``other operator ALL (array)`` clause. + + Argument places are switched, because ALL requires array + expression to be on the right hand-side. + + E.g.:: + + from sqlalchemy.sql import operators + + conn.execute( + select([table.c.data]).where( + table.c.data.all(7, operator=operators.lt) + ) + ) + + :param other: expression to be compared + :param operator: an operator object from the + :mod:`sqlalchemy.sql.operators` + package, defaults to :func:`.operators.eq`. + + .. seealso:: + + :func:`.sql.expression.all_` + + :meth:`.Array.Comparator.any` + + """ + operator = operator if operator else operators.eq + return operator( + elements._literal_as_binds(other), + elements.CollectionAggregate._create_all(self.expr) + ) + + comparator_factory = Comparator + + def __init__(self, item_type, as_tuple=False, dimensions=None, + zero_indexes=False): + """Construct an :class:`.Array`. + + E.g.:: + + Column('myarray', Array(Integer)) + + Arguments are: + + :param item_type: The data type of items of this array. Note that + dimensionality is irrelevant here, so multi-dimensional arrays like + ``INTEGER[][]``, are constructed as ``Array(Integer)``, not as + ``Array(Array(Integer))`` or such. + + :param as_tuple=False: Specify whether return results + should be converted to tuples from lists. This parameter is + not generally needed as a Python list corresponds well + to a SQL array. + + :param dimensions: if non-None, the ARRAY will assume a fixed + number of dimensions. This impacts how the array is declared + on the database, how it goes about interpreting Python and + result values, as well as how expression behavior in conjunction + with the "getitem" operator works. See the description at + :class:`.Array` for additional detail. + + :param zero_indexes=False: when True, index values will be converted + between Python zero-based and SQL one-based indexes, e.g. + a value of one will be added to all index values before passing + to the database. + + """ + if isinstance(item_type, Array): + raise ValueError("Do not nest ARRAY types; ARRAY(basetype) " + "handles multi-dimensional arrays of basetype") + if isinstance(item_type, type): + item_type = item_type() + self.item_type = item_type + self.as_tuple = as_tuple + self.dimensions = dimensions + self.zero_indexes = zero_indexes + + @property + def hashable(self): + return self.as_tuple + + @property + def python_type(self): + return list + + def compare_values(self, x, y): + return x == y + + class REAL(Float): """The SQL REAL type.""" @@ -1648,6 +1918,8 @@ class NullType(TypeEngine): _isnull = True + hashable = False + def literal_processor(self, dialect): def process(value): return "NULL" @@ -1704,6 +1976,26 @@ else: _type_map[unicode] = Unicode() _type_map[str] = String() +_type_map_get = _type_map.get + + +def _resolve_value_to_type(value): + _result_type = _type_map_get(type(value), False) + if _result_type is False: + # use inspect() to detect SQLAlchemy built-in + # objects. + insp = inspection.inspect(value, False) + if ( + insp is not None and + # foil mock.Mock() and other impostors by ensuring + # the inspection target itself self-inspects + insp.__class__ in inspection._registrars + ): + raise exc.ArgumentError( + "Object %r is not legal as a SQL literal value" % value) + return NULLTYPE + else: + return _result_type # back-assign to type_api from . import type_api @@ -1712,6 +2004,6 @@ type_api.STRINGTYPE = STRINGTYPE type_api.INTEGERTYPE = INTEGERTYPE type_api.NULLTYPE = NULLTYPE type_api.MATCHTYPE = MATCHTYPE -type_api._type_map = _type_map - +type_api.INDEXABLE = Indexable +type_api._resolve_value_to_type = _resolve_value_to_type TypeEngine.Comparator.BOOLEANTYPE = BOOLEANTYPE diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index a55eed981..c367bc73e 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -13,6 +13,7 @@ from .. import exc, util from . import operators from .visitors import Visitable, VisitableType +from .base import SchemaEventTarget # these are back-assigned by sqltypes. BOOLEANTYPE = None @@ -20,6 +21,8 @@ INTEGERTYPE = None NULLTYPE = None STRINGTYPE = None MATCHTYPE = None +INDEXABLE = None +_resolve_value_to_type = None class TypeEngine(Visitable): @@ -90,7 +93,7 @@ class TypeEngine(Visitable): boolean comparison or special SQL keywords like MATCH or BETWEEN. """ - return op, other_comparator.type + return op, self.type def __reduce__(self): return _reconstitute_comparator, (self.expr, ) @@ -128,6 +131,76 @@ class TypeEngine(Visitable): """ + should_evaluate_none = False + """If True, the Python constant ``None`` is considered to be handled + explicitly by this type. + + The ORM uses this flag to indicate that a positive value of ``None`` + is passed to the column in an INSERT statement, rather than omitting + the column from the INSERT statement which has the effect of firing + off column-level defaults. It also allows types which have special + behavior for Python None, such as a JSON type, to indicate that + they'd like to handle the None value explicitly. + + To set this flag on an existing type, use the + :meth:`.TypeEngine.evaluates_none` method. + + .. seealso:: + + :meth:`.TypeEngine.evaluates_none` + + .. versionadded:: 1.1 + + + """ + + def evaluates_none(self): + """Return a copy of this type which has the :attr:`.should_evaluate_none` + flag set to True. + + E.g.:: + + Table( + 'some_table', metadata, + Column( + String(50).evaluates_none(), + nullable=True, + server_default='no value') + ) + + The ORM uses this flag to indicate that a positive value of ``None`` + is passed to the column in an INSERT statement, rather than omitting + the column from the INSERT statement which has the effect of firing + off column-level defaults. It also allows for types which have + special behavior associated with the Python None value to indicate + that the value doesn't necessarily translate into SQL NULL; a + prime example of this is a JSON type which may wish to persist the + JSON value ``'null'``. + + In all cases, the actual NULL SQL value can be always be + persisted in any column by using + the :obj:`~.expression.null` SQL construct in an INSERT statement + or associated with an ORM-mapped attribute. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`session_forcing_null` - in the ORM documentation + + :paramref:`.postgresql.JSON.none_as_null` - Postgresql JSON + interaction with this flag. + + :attr:`.TypeEngine.should_evaluate_none` - class-level flag + + """ + typ = self.copy() + typ.should_evaluate_none = True + return typ + + def copy(self, **kw): + return self.adapt(self.__class__) + def compare_against_backend(self, dialect, conn_type): """Compare this type against the given backend type. @@ -440,7 +513,7 @@ class TypeEngine(Visitable): end-user customization of this behavior. """ - _coerced_type = _type_map.get(type(value), NULLTYPE) + _coerced_type = _resolve_value_to_type(value) if _coerced_type is NULLTYPE or _coerced_type._type_affinity \ is self._type_affinity: return self @@ -577,7 +650,7 @@ class UserDefinedType(util.with_metaclass(VisitableCheckKWArg, TypeEngine)): return self -class TypeDecorator(TypeEngine): +class TypeDecorator(SchemaEventTarget, TypeEngine): """Allows the creation of types which add additional functionality to an existing type. @@ -602,7 +675,7 @@ class TypeDecorator(TypeEngine): def process_result_value(self, value, dialect): return value[7:] - def copy(self): + def copy(self, **kw): return MyType(self.impl.length) The class-level "impl" attribute is required, and can reference any @@ -656,6 +729,26 @@ class TypeDecorator(TypeEngine): else: return self + .. warning:: + + Note that the **behavior of coerce_compared_value is not inherited + by default from that of the base type**. + If the :class:`.TypeDecorator` is augmenting a + type that requires special logic for certain types of operators, + this method **must** be overridden. A key example is when decorating + the :class:`.postgresql.JSON` and :class:`.postgresql.JSONB` types; + the default rules of :meth:`.TypeEngine.coerce_compared_value` should + be used in order to deal with operators like index operations:: + + class MyJsonType(TypeDecorator): + impl = postgresql.JSON + + def coerce_compared_value(self, op, value): + return self.impl.coerce_compared_value(op, value) + + Without the above step, index operations such as ``mycol['foo']`` + will cause the index value ``'foo'`` to be JSON encoded. + """ __visit_name__ = "type_decorator" @@ -757,6 +850,18 @@ class TypeDecorator(TypeEngine): """ return self.impl._type_affinity + def _set_parent(self, column): + """Support SchemaEentTarget""" + + if isinstance(self.impl, SchemaEventTarget): + self.impl._set_parent(column) + + def _set_parent_with_dispatch(self, parent): + """Support SchemaEentTarget""" + + if isinstance(self.impl, SchemaEventTarget): + self.impl._set_parent_with_dispatch(parent) + def type_engine(self, dialect): """Return a dialect-specific :class:`.TypeEngine` instance for this :class:`.TypeDecorator`. @@ -1031,7 +1136,7 @@ class TypeDecorator(TypeEngine): """ return self - def copy(self): + def copy(self, **kw): """Produce a copy of this :class:`.TypeDecorator` instance. This is a shallow copy and is provided to fulfill part of diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 8f502fc86..f5aa9f228 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -154,6 +154,7 @@ def unwrap_order_by(clause): without DESC/ASC/NULLS FIRST/NULLS LAST""" cols = util.column_set() + result = [] stack = deque([clause]) while stack: t = stack.popleft() @@ -166,11 +167,13 @@ def unwrap_order_by(clause): t = t.element if isinstance(t, (_textual_label_reference)): continue - cols.add(t) + if t not in cols: + cols.add(t) + result.append(t) else: for c in t.get_children(): stack.append(c) - return cols + return result def clause_is_present(clause, search): @@ -200,6 +203,21 @@ def surface_selectables(clause): stack.append(elem.element) +def surface_column_elements(clause): + """traverse and yield only outer-exposed column elements, such as would + be addressable in the WHERE clause of a SELECT if this element were + in the columns clause.""" + + stack = deque([clause]) + while stack: + elem = stack.popleft() + yield elem + for sub in elem.get_children(): + if isinstance(sub, FromGrouping): + continue + stack.append(sub) + + def selectables_overlap(left, right): """Return True if left/right have some overlapping selectable""" @@ -433,7 +451,6 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, return pairs - class ClauseAdapter(visitors.ReplacingCloningVisitor): """Clones and modifies clauses based on column correspondence. diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index 7482e32a1..bd6377eb7 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -21,7 +21,8 @@ def against(*queries): from .assertions import emits_warning, emits_warning_on, uses_deprecated, \ eq_, ne_, le_, is_, is_not_, startswith_, assert_raises, \ assert_raises_message, AssertsCompiledSQL, ComparesTables, \ - AssertsExecutionResults, expect_deprecated, expect_warnings + AssertsExecutionResults, expect_deprecated, expect_warnings, \ + in_, not_in_ from .util import run_as_contextmanager, rowset, fail, \ provide_metadata, adict, force_drop_names, \ diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 01fa0b8a9..63667654d 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -121,7 +121,7 @@ def uses_deprecated(*messages): def _expect_warnings(exc_cls, messages, regex=True, assert_=True): if regex: - filters = [re.compile(msg, re.I) for msg in messages] + filters = [re.compile(msg, re.I | re.S) for msg in messages] else: filters = messages @@ -229,6 +229,16 @@ def is_not_(a, b, msg=None): assert a is not b, msg or "%r is %r" % (a, b) +def in_(a, b, msg=None): + """Assert a in b, with repr messaging on failure.""" + assert a in b, msg or "%r not in %r" % (a, b) + + +def not_in_(a, b, msg=None): + """Assert a in not b, with repr messaging on failure.""" + assert a not in b, msg or "%r is in %r" % (a, b) + + def startswith_(a, fragment, msg=None): """Assert a.startswith(fragment), with repr messaging on failure.""" assert a.startswith(fragment), msg or "%r does not start with %r" % ( diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index 243493607..39d078985 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -13,6 +13,7 @@ import contextlib from .. import event from sqlalchemy.schema import _DDLCompiles from sqlalchemy.engine.util import _distill_params +from sqlalchemy.engine import url class AssertRule(object): @@ -58,16 +59,25 @@ class CursorSQL(SQLMatchRule): class CompiledSQL(SQLMatchRule): - def __init__(self, statement, params=None): + def __init__(self, statement, params=None, dialect='default'): self.statement = statement self.params = params + self.dialect = dialect def _compare_sql(self, execute_observed, received_statement): stmt = re.sub(r'[\n\t]', '', self.statement) return received_statement == stmt def _compile_dialect(self, execute_observed): - return DefaultDialect() + if self.dialect == 'default': + return DefaultDialect() + else: + # ugh + if self.dialect == 'postgresql': + params = {'implicit_returning': True} + else: + params = {} + return url.URL(self.dialect).get_dialect()(**params) def _received_statement(self, execute_observed): """reconstruct the statement and params in terms @@ -159,7 +169,7 @@ class CompiledSQL(SQLMatchRule): 'Testing for compiled statement %r partial params %r, ' 'received %%(received_statement)r with params ' '%%(received_parameters)r' % ( - self.statement, expected_params + self.statement.replace('%', '%%'), expected_params ) ) @@ -170,6 +180,7 @@ class RegexSQL(CompiledSQL): self.regex = re.compile(regex) self.orig_regex = regex self.params = params + self.dialect = 'default' def _failure_message(self, expected_params): return ( diff --git a/lib/sqlalchemy/testing/distutils_run.py b/lib/sqlalchemy/testing/distutils_run.py deleted file mode 100644 index 38de8872c..000000000 --- a/lib/sqlalchemy/testing/distutils_run.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Quick and easy way to get setup.py test to run py.test without any -custom setuptools/distutils code. - -""" -import unittest -import pytest - - -class TestSuite(unittest.TestCase): - def test_sqlalchemy(self): - pytest.main(["-n", "4", "-q"]) diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index 972dec3a9..5d7baeb9c 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -12,6 +12,7 @@ from . import config from .. import util import inspect import contextlib +from sqlalchemy.util.compat import inspect_getargspec def skip_if(predicate, reason=None): @@ -295,7 +296,7 @@ class SpecPredicate(Predicate): class LambdaPredicate(Predicate): def __init__(self, lambda_, description=None, args=None, kw=None): - spec = inspect.getargspec(lambda_) + spec = inspect_getargspec(lambda_) if not spec[0]: self.lambda_ = lambda db: lambda_() else: @@ -397,8 +398,8 @@ def closed(): return skip_if(BooleanPredicate(True, "marked as skip")) -def fails(): - return fails_if(BooleanPredicate(True, "expected to fail")) +def fails(reason=None): + return fails_if(BooleanPredicate(True, reason or "expected to fail")) @decorator @@ -407,19 +408,19 @@ def future(fn, *arg): def fails_on(db, reason=None): - return fails_if(SpecPredicate(db), reason) + return fails_if(Predicate.as_predicate(db), reason) def fails_on_everything_except(*dbs): return succeeds_if( OrPredicate([ - SpecPredicate(db) for db in dbs + Predicate.as_predicate(db) for db in dbs ]) ) def skip(db, reason=None): - return skip_if(SpecPredicate(db), reason) + return skip_if(Predicate.as_predicate(db), reason) def only_on(dbs, reason=None): diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index e16bc77c0..5cd0244ef 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -275,12 +275,14 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): def setup(self): self._setup_each_tables() + self._setup_each_classes() self._setup_each_mappers() self._setup_each_inserts() def teardown(self): sa.orm.session.Session.close_all() self._teardown_each_mappers() + self._teardown_each_classes() self._teardown_each_tables() @classmethod @@ -302,6 +304,10 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): if self.run_setup_mappers == 'each': self._with_register_classes(self.setup_mappers) + def _setup_each_classes(self): + if self.run_setup_classes == 'each': + self._with_register_classes(self.setup_classes) + @classmethod def _with_register_classes(cls, fn): """Run a setup method, framing the operation with a Base class @@ -336,6 +342,10 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): if self.run_setup_mappers != 'once': sa.orm.clear_mappers() + def _teardown_each_classes(self): + if self.run_setup_classes != 'once': + self.classes.clear() + @classmethod def setup_classes(cls): pass diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index ef304afa6..6cdec05ad 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -40,7 +40,6 @@ file_config = None logging = None -db_opts = {} include_tags = set() exclude_tags = set() options = None @@ -115,7 +114,6 @@ def memoize_important_follower_config(dict_): """ dict_['memoized_config'] = { - 'db_opts': db_opts, 'include_tags': include_tags, 'exclude_tags': exclude_tags } @@ -127,8 +125,7 @@ def restore_important_follower_config(dict_): This invokes in the follower process. """ - global db_opts, include_tags, exclude_tags - db_opts.update(dict_['memoized_config']['db_opts']) + global include_tags, exclude_tags include_tags.update(dict_['memoized_config']['include_tags']) exclude_tags.update(dict_['memoized_config']['exclude_tags']) @@ -268,7 +265,7 @@ def _engine_uri(options, file_config): for db_url in db_urls: cfg = provision.setup_config( - db_url, db_opts, options, file_config, provision.FOLLOWER_IDENT) + db_url, options, file_config, provision.FOLLOWER_IDENT) if not config._current: cfg.set_as_current(cfg, testing) diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py index 8469a0658..3f9ddae73 100644 --- a/lib/sqlalchemy/testing/provision.py +++ b/lib/sqlalchemy/testing/provision.py @@ -2,7 +2,7 @@ from sqlalchemy.engine import url as sa_url from sqlalchemy import text from sqlalchemy.util import compat from . import config, engines - +import os FOLLOWER_IDENT = None @@ -46,11 +46,13 @@ def configure_follower(follower_ident): _configure_follower(cfg, follower_ident) -def setup_config(db_url, db_opts, options, file_config, follower_ident): +def setup_config(db_url, options, file_config, follower_ident): if follower_ident: db_url = _follower_url_from_main(db_url, follower_ident) + db_opts = {} _update_db_opts(db_url, db_opts) eng = engines.testing_engine(db_url, db_opts) + _post_configure_engine(db_url, eng, follower_ident) eng.connect().close() cfg = config.Config.register(eng, db_opts, options, file_config) if follower_ident: @@ -105,6 +107,11 @@ def _configure_follower(cfg, ident): @register.init +def _post_configure_engine(url, engine, follower_ident): + pass + + +@register.init def _follower_url_from_main(url, ident): url = sa_url.make_url(url) url.database = ident @@ -125,6 +132,23 @@ def _sqlite_follower_url_from_main(url, ident): return sa_url.make_url("sqlite:///%s.db" % ident) +@_post_configure_engine.for_db("sqlite") +def _sqlite_post_configure_engine(url, engine, follower_ident): + from sqlalchemy import event + + @event.listens_for(engine, "connect") + def connect(dbapi_connection, connection_record): + # use file DBs in all cases, memory acts kind of strangely + # as an attached + if not follower_ident: + dbapi_connection.execute( + 'ATTACH DATABASE "test_schema.db" AS test_schema') + else: + dbapi_connection.execute( + 'ATTACH DATABASE "%s_test_schema.db" AS test_schema' + % follower_ident) + + @_create_db.for_db("postgresql") def _pg_create_db(cfg, eng, ident): with eng.connect().execution_options( @@ -175,8 +199,10 @@ def _pg_drop_db(cfg, eng, ident): @_drop_db.for_db("sqlite") def _sqlite_drop_db(cfg, eng, ident): - pass - #os.remove("%s.db" % ident) + if ident: + os.remove("%s_test_schema.db" % ident) + else: + os.remove("%s.db" % ident) @_drop_db.for_db("mysql") diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index e8b3a995f..15bfad831 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -111,6 +111,32 @@ class SuiteRequirements(Requirements): return exclusions.open() @property + def parens_in_union_contained_select_w_limit_offset(self): + """Target database must support parenthesized SELECT in UNION + when LIMIT/OFFSET is specifically present. + + E.g. (SELECT ...) UNION (SELECT ..) + + This is known to fail on SQLite. + + """ + return exclusions.open() + + @property + def parens_in_union_contained_select_wo_limit_offset(self): + """Target database must support parenthesized SELECT in UNION + when OFFSET/LIMIT is specifically not present. + + E.g. (SELECT ... LIMIT ..) UNION (SELECT .. OFFSET ..) + + This is known to fail on SQLite. It also fails on Oracle + because without LIMIT/OFFSET, there is currently no step that + creates an additional subquery. + + """ + return exclusions.open() + + @property def boolean_col_expressions(self): """Target database must support boolean expressions as columns""" diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py index 93b52ad58..257578668 100644 --- a/lib/sqlalchemy/testing/schema.py +++ b/lib/sqlalchemy/testing/schema.py @@ -71,9 +71,12 @@ def Column(*args, **kw): args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)] col = schema.Column(*args, **kw) - if 'test_needs_autoincrement' in test_opts and \ + if test_opts.get('test_needs_autoincrement', False) and \ kw.get('primary_key', False): + if col.default is None and col.server_default is None: + col.autoincrement = True + # allow any test suite to pick up on this col.info['test_needs_autoincrement'] = True diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index 3edbdeb8c..288a85973 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -531,12 +531,20 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.provide_metadata def _test_get_unique_constraints(self, schema=None): + # SQLite dialect needs to parse the names of the constraints + # separately from what it gets from PRAGMA index_list(), and + # then matches them up. so same set of column_names in two + # constraints will confuse it. Perhaps we should no longer + # bother with index_list() here since we have the whole + # CREATE TABLE? uniques = sorted( [ {'name': 'unique_a', 'column_names': ['a']}, {'name': 'unique_a_b_c', 'column_names': ['a', 'b', 'c']}, {'name': 'unique_c_a_b', 'column_names': ['c', 'a', 'b']}, {'name': 'unique_asc_key', 'column_names': ['asc', 'key']}, + {'name': 'i.have.dots', 'column_names': ['b']}, + {'name': 'i have spaces', 'column_names': ['c']}, ], key=operator.itemgetter('name') ) diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index d4bf63b55..e7de356b8 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -2,7 +2,7 @@ from .. import fixtures, config from ..assertions import eq_ from sqlalchemy import util -from sqlalchemy import Integer, String, select, func, bindparam +from sqlalchemy import Integer, String, select, func, bindparam, union from sqlalchemy import testing from ..schema import Table, Column @@ -146,7 +146,7 @@ class LimitOffsetTest(fixtures.TablesTest): select([table]).order_by(table.c.id).limit(2).offset(1), [(2, 2, 3), (3, 3, 4)] ) - + @testing.requires.offset def test_limit_offset_nobinds(self): """test that 'literal binds' mode works - no bound params.""" @@ -190,3 +190,123 @@ class LimitOffsetTest(fixtures.TablesTest): [(2, 2, 3), (3, 3, 4)], params={"l": 2, "o": 1} ) + + +class CompoundSelectTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table("some_table", metadata, + Column('id', Integer, primary_key=True), + Column('x', Integer), + Column('y', Integer)) + + @classmethod + def insert_data(cls): + config.db.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2}, + {"id": 2, "x": 2, "y": 3}, + {"id": 3, "x": 3, "y": 4}, + {"id": 4, "x": 4, "y": 5}, + ] + ) + + def _assert_result(self, select, result, params=()): + eq_( + config.db.execute(select, params).fetchall(), + result + ) + + def test_plain_union(self): + table = self.tables.some_table + s1 = select([table]).where(table.c.id == 2) + s2 = select([table]).where(table.c.id == 3) + + u1 = union(s1, s2) + self._assert_result( + u1.order_by(u1.c.id), + [(2, 2, 3), (3, 3, 4)] + ) + + def test_select_from_plain_union(self): + table = self.tables.some_table + s1 = select([table]).where(table.c.id == 2) + s2 = select([table]).where(table.c.id == 3) + + u1 = union(s1, s2).alias().select() + self._assert_result( + u1.order_by(u1.c.id), + [(2, 2, 3), (3, 3, 4)] + ) + + @testing.requires.parens_in_union_contained_select_w_limit_offset + def test_limit_offset_selectable_in_unions(self): + table = self.tables.some_table + s1 = select([table]).where(table.c.id == 2).\ + limit(1).order_by(table.c.id) + s2 = select([table]).where(table.c.id == 3).\ + limit(1).order_by(table.c.id) + + u1 = union(s1, s2).limit(2) + self._assert_result( + u1.order_by(u1.c.id), + [(2, 2, 3), (3, 3, 4)] + ) + + @testing.requires.parens_in_union_contained_select_wo_limit_offset + def test_order_by_selectable_in_unions(self): + table = self.tables.some_table + s1 = select([table]).where(table.c.id == 2).\ + order_by(table.c.id) + s2 = select([table]).where(table.c.id == 3).\ + order_by(table.c.id) + + u1 = union(s1, s2).limit(2) + self._assert_result( + u1.order_by(u1.c.id), + [(2, 2, 3), (3, 3, 4)] + ) + + def test_distinct_selectable_in_unions(self): + table = self.tables.some_table + s1 = select([table]).where(table.c.id == 2).\ + distinct() + s2 = select([table]).where(table.c.id == 3).\ + distinct() + + u1 = union(s1, s2).limit(2) + self._assert_result( + u1.order_by(u1.c.id), + [(2, 2, 3), (3, 3, 4)] + ) + + @testing.requires.parens_in_union_contained_select_w_limit_offset + def test_limit_offset_in_unions_from_alias(self): + table = self.tables.some_table + s1 = select([table]).where(table.c.id == 2).\ + limit(1).order_by(table.c.id) + s2 = select([table]).where(table.c.id == 3).\ + limit(1).order_by(table.c.id) + + # this necessarily has double parens + u1 = union(s1, s2).alias() + self._assert_result( + u1.select().limit(2).order_by(u1.c.id), + [(2, 2, 3), (3, 3, 4)] + ) + + def test_limit_offset_aliased_selectable_in_unions(self): + table = self.tables.some_table + s1 = select([table]).where(table.c.id == 2).\ + limit(1).order_by(table.c.id).alias().select() + s2 = select([table]).where(table.c.id == 3).\ + limit(1).order_by(table.c.id).alias().select() + + u1 = union(s1, s2).limit(2) + self._assert_result( + u1.order_by(u1.c.id), + [(2, 2, 3), (3, 3, 4)] + ) diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 9ab92e90b..d82e683d9 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -16,7 +16,8 @@ __all__ = ['TypeEngine', 'TypeDecorator', 'UserDefinedType', 'SMALLINT', 'INTEGER', 'DATE', 'TIME', 'String', 'Integer', 'SmallInteger', 'BigInteger', 'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'LargeBinary', 'Binary', 'Boolean', 'Unicode', - 'Concatenable', 'UnicodeText', 'PickleType', 'Interval', 'Enum'] + 'Concatenable', 'UnicodeText', 'PickleType', 'Interval', 'Enum', + 'Indexable', 'Array'] from .sql.type_api import ( adapt_type, @@ -27,6 +28,7 @@ from .sql.type_api import ( UserDefinedType ) from .sql.sqltypes import ( + Array, BIGINT, BINARY, BLOB, @@ -46,6 +48,7 @@ from .sql.sqltypes import ( Enum, FLOAT, Float, + Indexable, INT, INTEGER, Integer, @@ -74,5 +77,4 @@ from .sql.sqltypes import ( UnicodeText, VARBINARY, VARCHAR, - _type_map ) diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index ed968f168..a15ca8efa 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -6,7 +6,7 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php from .compat import callable, cmp, reduce, \ - threading, py3k, py33, py2k, jython, pypy, cpython, win32, \ + threading, py3k, py33, py36, py2k, jython, pypy, cpython, win32, \ pickle, dottedgetter, parse_qsl, namedtuple, next, reraise, \ raise_from_cause, text_type, safe_kwarg, string_types, int_types, \ binary_type, nested, \ @@ -36,7 +36,7 @@ from .langhelpers import iterate_attributes, class_hierarchy, \ generic_repr, counter, PluginLoader, hybridproperty, hybridmethod, \ safe_reraise,\ get_callable_argspec, only_once, attrsetter, ellipses_string, \ - warn_limited, map_bits, MemoizedSlots, EnsureKWArgType + warn_limited, map_bits, MemoizedSlots, EnsureKWArgType, wrap_callable from .deprecations import warn_deprecated, warn_pending_deprecation, \ deprecated, pending_deprecation, inject_docstring_text diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 5b6f691f1..25c88c662 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -14,6 +14,7 @@ try: except ImportError: import dummy_threading as threading +py36 = sys.version_info >= (3, 6) py33 = sys.version_info >= (3, 3) py32 = sys.version_info >= (3, 2) py3k = sys.version_info >= (3, 0) diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 499515142..11aa9384d 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -426,7 +426,7 @@ def getargspec_init(method): """ try: - return inspect.getargspec(method) + return compat.inspect_getargspec(method) except TypeError: if method is object.__init__: return (['self'], None, None, None) @@ -464,7 +464,7 @@ def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()): for i, insp in enumerate(to_inspect): try: (_args, _vargs, vkw, defaults) = \ - inspect.getargspec(insp.__init__) + compat.inspect_getargspec(insp.__init__) except TypeError: continue else: @@ -625,7 +625,7 @@ def monkeypatch_proxied_specials(into_cls, from_cls, skip=None, only=None, except AttributeError: continue try: - spec = inspect.getargspec(fn) + spec = compat.inspect_getargspec(fn) fn_args = inspect.formatargspec(spec[0]) d_args = inspect.formatargspec(spec[0][1:]) except TypeError: @@ -805,6 +805,8 @@ class MemoizedSlots(object): """ + __slots__ = () + def _fallback_getattr(self, key): raise AttributeError(key) @@ -1017,7 +1019,9 @@ def constructor_copy(obj, cls, *args, **kw): """ names = get_cls_kwargs(cls) - kw.update((k, obj.__dict__[k]) for k in names if k in obj.__dict__) + kw.update( + (k, obj.__dict__[k]) for k in names.difference(kw) + if k in obj.__dict__) return cls(*args, **kw) @@ -1361,7 +1365,7 @@ class EnsureKWArgType(type): m = re.match(fn_reg, key) if m: fn = clsdict[key] - spec = inspect.getargspec(fn) + spec = compat.inspect_getargspec(fn) if not spec.keywords: clsdict[key] = wrapped = cls._wrap_w_kw(fn) setattr(cls, key, wrapped) @@ -1373,3 +1377,25 @@ class EnsureKWArgType(type): return fn(*arg) return update_wrapper(wrap, fn) + +def wrap_callable(wrapper, fn): + """Augment functools.update_wrapper() to work with objects with + a ``__call__()`` method. + + :param fn: + object with __call__ method + + """ + if hasattr(fn, '__name__'): + return update_wrapper(wrapper, fn) + else: + _f = wrapper + _f.__name__ = fn.__class__.__name__ + _f.__module__ = fn.__module__ + + if hasattr(fn.__call__, '__doc__') and fn.__call__.__doc__: + _f.__doc__ = fn.__call__.__doc__ + elif fn.__doc__: + _f.__doc__ = fn.__doc__ + + return _f |
