diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/context.py | 13 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 26 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/dml.py | 8 |
5 files changed, 49 insertions, 9 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 674d54179..a0aa67c69 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -1526,7 +1526,7 @@ class MSExecutionContext(default.DefaultExecutionContext): """Activate IDENTITY_INSERT if needed.""" if self.isinsert: - tbl = self.compiled.statement.table + tbl = self.compiled.compile_state.dml_table id_column = tbl._autoincrement_column insert_has_identity = (id_column is not None) and ( not isinstance(id_column.default, Sequence) @@ -1607,7 +1607,7 @@ class MSExecutionContext(default.DefaultExecutionContext): self._opt_encode( "SET IDENTITY_INSERT %s OFF" % self.identifier_preparer.format_table( - self.compiled.statement.table + self.compiled.compile_state.dml_table ) ), (), @@ -1631,7 +1631,7 @@ class MSExecutionContext(default.DefaultExecutionContext): self._opt_encode( "SET IDENTITY_INSERT %s OFF" % self.identifier_preparer.format_table( - self.compiled.statement.table + self.compiled.compile_state.dml_table ) ) ) diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index fa192a17e..23bae5cc0 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -357,6 +357,9 @@ class ORMFromStatementCompileState(ORMCompileState): self.statement_container = self.select_statement = statement_container self.requested_statement = statement = statement_container.element + if statement.is_dml: + self.dml_table = statement.table + self._entities = [] self._polymorphic_adapters = {} self._no_yield_pers = set() @@ -367,6 +370,7 @@ class ORMFromStatementCompileState(ORMCompileState): self.use_legacy_query_style and isinstance(statement, expression.SelectBase) and not statement._is_textual + and not statement.is_dml and statement._label_style is LABEL_STYLE_NONE ): self.statement = statement.set_label_style( @@ -377,7 +381,7 @@ class ORMFromStatementCompileState(ORMCompileState): self._label_convention = self._column_naming_convention( statement._label_style - if not statement._is_textual + if not statement._is_textual and not statement.is_dml else LABEL_STYLE_NONE, self.use_legacy_query_style, ) @@ -409,7 +413,9 @@ class ORMFromStatementCompileState(ORMCompileState): self.order_by = None - if isinstance(self.statement, expression.TextClause): + if isinstance( + self.statement, (expression.TextClause, expression.UpdateBase) + ): # setup for all entities. Currently, this is not useful # for eager loaders, as the eager loaders that work are able # to do their work entirely in row_processor. @@ -790,12 +796,13 @@ class ORMSelectCompileState(ORMCompileState, SelectState): query = util.preloaded.orm_query from_statement = coercions.expect( - roles.SelectStatementRole, + roles.ReturnsRowsRole, from_statement, apply_propagate_attrs=statement, ) stmt = query.FromStatement(statement._raw_columns, from_statement) + stmt.__dict__.update( _with_options=statement._with_options, _with_context_options=statement._with_context_options, diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index f19f29daa..7ab9eeda7 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -2179,6 +2179,11 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState): compiler._annotations.get("synchronize_session", None) == "fetch" and compiler.dialect.full_returning ): + if new_stmt._returning: + raise sa_exc.InvalidRequestError( + "Can't use synchronize_session='fetch' " + "with explicit returning()" + ) new_stmt = new_stmt.returning(*mapper.primary_key) UpdateDMLState.__init__(self, new_stmt, compiler, **kw) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 30cb9e730..c444b557b 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -57,10 +57,12 @@ from ..sql.base import _generative from ..sql.base import Executable from ..sql.selectable import _SelectFromElements from ..sql.selectable import ForUpdateArg +from ..sql.selectable import GroupedElement from ..sql.selectable import HasHints from ..sql.selectable import HasPrefixes from ..sql.selectable import HasSuffixes from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..sql.selectable import SelectBase from ..sql.selectable import SelectStatementGrouping from ..sql.visitors import InternalTraversal from ..util import collections_abc @@ -3178,7 +3180,7 @@ class Query( return context -class FromStatement(SelectStatementGrouping, Executable): +class FromStatement(GroupedElement, SelectBase, Executable): """Core construct that represents a load of ORM objects from a finished select or text construct. @@ -3210,7 +3212,19 @@ class FromStatement(SelectStatementGrouping, Executable): ) for ent in util.to_list(entities) ] - super(FromStatement, self).__init__(element) + self.element = element + + def get_label_style(self): + return self._label_style + + def set_label_style(self, label_style): + return SelectStatementGrouping( + self.element.set_label_style(label_style) + ) + + @property + def _label_style(self): + return self.element._label_style def _compiler_dispatch(self, compiler, **kw): @@ -3241,6 +3255,14 @@ class FromStatement(SelectStatementGrouping, Executable): for elem in super(FromStatement, self).get_children(**kw): yield elem + @property + def _returning(self): + return self.element._returning if self.element.is_dml else None + + @property + def _inline(self): + return self.element._inline if self.element.is_dml else None + class AliasOption(interfaces.LoaderOption): @util.deprecated( diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 3f492a490..ea10bfc27 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -47,6 +47,10 @@ class DMLState(CompileState): def __init__(self, statement, compiler, **kw): raise NotImplementedError() + @property + def dml_table(self): + return self.statement.table + def _make_extra_froms(self, statement): froms = [] @@ -407,7 +411,9 @@ class UpdateBase( raise exc.InvalidRequestError( "return_defaults() is already configured on this statement" ) - self._returning += cols + self._returning += tuple( + coercions.expect(roles.ColumnsClauseRole, c) for c in cols + ) def _exported_columns_iterator(self): """Return the RETURNING columns as a sequence for this statement. |
