diff options
Diffstat (limited to 'lib/sqlalchemy/orm/persistence.py')
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 69 |
1 files changed, 40 insertions, 29 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 87bc8ea1d..d14f6c27b 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -23,7 +23,6 @@ from . import evaluator from . import exc as orm_exc from . import loading from . import sync -from .base import _entity_descriptor from .base import state_str from .. import exc as sa_exc from .. import sql @@ -1653,15 +1652,14 @@ class BulkUD(object): def __init__(self, query): self.query = query.enable_eagerloads(False) - self.mapper = self.query._bind_mapper() self._validate_query_state() def _validate_query_state(self): for attr, methname, notset, op in ( - ("_limit", "limit()", None, operator.is_), - ("_offset", "offset()", None, operator.is_), - ("_order_by", "order_by()", False, operator.is_), - ("_group_by", "group_by()", False, operator.is_), + ("_limit_clause", "limit()", None, operator.is_), + ("_offset_clause", "offset()", None, operator.is_), + ("_order_by_clauses", "order_by()", (), operator.eq), + ("_group_by_clauses", "group_by()", (), operator.eq), ("_distinct", "distinct()", False, operator.is_), ( "_from_obj", @@ -1669,6 +1667,12 @@ class BulkUD(object): (), operator.eq, ), + ( + "_legacy_setup_joins", + "join(), outerjoin(), select_from(), or from_self()", + (), + operator.eq, + ), ): if not op(getattr(self.query, attr), notset): raise sa_exc.InvalidRequestError( @@ -1710,18 +1714,24 @@ class BulkUD(object): def _do_before_compile(self): raise NotImplementedError() - @util.preload_module("sqlalchemy.orm.query") + @util.preload_module("sqlalchemy.orm.context") def _do_pre(self): - querylib = util.preloaded.orm_query + query_context = util.preloaded.orm_context query = self.query - self.context = querylib.QueryContext(query) + self.compile_state = ( + self.context + ) = compile_state = query._compile_state() + + self.mapper = compile_state._bind_mapper() - if isinstance(query._entities[0], querylib._ColumnEntity): + if isinstance( + compile_state._entities[0], query_context._RawColumnEntity, + ): # check for special case of query(table) tables = set() - for ent in query._entities: - if not isinstance(ent, querylib._ColumnEntity): + for ent in compile_state._entities: + if not isinstance(ent, query_context._RawColumnEntity,): tables.clear() break else: @@ -1736,14 +1746,14 @@ class BulkUD(object): self.primary_table = tables.pop() else: - self.primary_table = query._only_entity_zero( + self.primary_table = compile_state._only_entity_zero( "This operation requires only one Table or " "entity be specified as the target." ).mapper.local_table session = query.session - if query._autoflush: + if query.load_options._autoflush: session._autoflush() def _do_pre_synchronize(self): @@ -1761,12 +1771,14 @@ class BulkEvaluate(BulkUD): def _do_pre_synchronize(self): query = self.query - target_cls = query._mapper_zero().class_ + target_cls = self.compile_state._mapper_zero().class_ try: evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) - if query.whereclause is not None: - eval_condition = evaluator_compiler.process(query.whereclause) + if query._where_criteria: + eval_condition = evaluator_compiler.process( + *query._where_criteria + ) else: def eval_condition(obj): @@ -1802,12 +1814,11 @@ class BulkFetch(BulkUD): def _do_pre_synchronize(self): query = self.query session = query.session - context = query._compile_context() - select_stmt = context.statement.with_only_columns( + select_stmt = self.compile_state.statement.with_only_columns( self.primary_table.primary_key ) self.matched_rows = session.execute( - select_stmt, mapper=self.mapper, params=query._params + select_stmt, mapper=self.mapper, params=query.load_options._params ).fetchall() @@ -1850,7 +1861,7 @@ class BulkUpdate(BulkUD): ): if self.mapper: if isinstance(k, util.string_types): - desc = _entity_descriptor(self.mapper, k) + desc = sql.util._entity_namespace_key(self.mapper, k) values.extend(desc._bulk_update_tuples(v)) elif isinstance(k, attributes.QueryableAttribute): values.extend(k._bulk_update_tuples(v)) @@ -1890,11 +1901,10 @@ class BulkUpdate(BulkUD): values = dict(values) update_stmt = sql.update( - self.primary_table, - self.context.whereclause, - values, - **self.update_kwargs - ) + self.primary_table, **self.update_kwargs + ).values(values) + + update_stmt._where_criteria = self.compile_state._where_criteria self._execute_stmt(update_stmt) @@ -1929,7 +1939,8 @@ class BulkDelete(BulkUD): self.query = new_query def _do_exec(self): - delete_stmt = sql.delete(self.primary_table, self.context.whereclause) + delete_stmt = sql.delete(self.primary_table,) + delete_stmt._where_criteria = self.compile_state._where_criteria self._execute_stmt(delete_stmt) @@ -1994,7 +2005,7 @@ class BulkUpdateFetch(BulkFetch, BulkUpdate): def _do_post_synchronize(self): session = self.query.session - target_mapper = self.query._mapper_zero() + target_mapper = self.compile_state._mapper_zero() states = set( [ @@ -2024,7 +2035,7 @@ class BulkDeleteFetch(BulkFetch, BulkDelete): def _do_post_synchronize(self): session = self.query.session - target_mapper = self.query._mapper_zero() + target_mapper = self.compile_state._mapper_zero() for primary_key in self.matched_rows: # TODO: inline this and call remove_newly_deleted # once |
