author    Mike Bayer <mike_mp@zzzcomputing.com>  2020-06-06 20:40:43 -0400
committer Mike Bayer <mike_mp@zzzcomputing.com>  2020-06-10 15:29:01 -0400
commit    b0cfa7379cf8513a821a3dbe3028c4965d9f85bd
tree      19a79632b4f159092d955765ff9f7e842808bce7
parent    3ab2364e78641c4f0e4b6456afc2cbed39b0d0e6
Turn on caching everywhere, add logging
A variety of caching issues found by running all tests with statement
caching turned on.

The cache system now has a more conservative approach where any subclass
of a SQL element will by default invalidate the cache key unless it adds
the flag inherit_cache=True at the class level, or if it implements its
own caching.

Add working caching to a few elements that were omitted previously; fix
some caching implementations to suit lesser-used edge cases such as JSON
casts and array slices.

Refine the way BaseCursorResult and CursorMetaData interact with caching;
to suit cases like Alembic modifying table structures, don't cache the
cursor metadata if it was created against a cursor.description using
non-positional matching, e.g. "select *". If a table re-ordered its
columns or added/removed one, that data is now obsolete. Additionally, we
have to adapt the cursor metadata _keymap regardless of whether we just
processed cursor.description, because if we ran against a cached
SQLCompiler we won't have the right columns in _keymap. Other refinements
to how and when we do this adaptation, as some weird cases were exposed
in the PostgreSQL dialect: a text() construct that names just one column
that is not actually in the statement. Fixed that also, as it looks like
a cut-and-paste artifact that doesn't actually affect anything.

Fix various issues with re-use of compiled result maps and cursor
metadata in conjunction with tables being changed, such as a change in
the order of columns.

Mappers can be cleared but the class remains, meaning a mapper has to use
itself as the cache key, not the class.

Fix lots of bound parameter / literal issues, due to Alembic creating a
straight subclass of bindparam that renders inline directly. While we can
update Alembic to not do this, we have to assume other people might be
doing this, so bindparam() implements the inherit_cache=True logic as
well; that was a bit involved.

Turn on cache stats in logging.

Includes a fix to subqueryloader which moves all setup to the
create_row_processor() phase and eliminates any storage within the
compiled context. This includes some changes to the
create_row_processor() signature and a revising of the technique used to
determine if the loader can participate in polymorphic queries, which is
also applied to selectin loading.

DML update.values() and ordered_values() now coerce the keys, as we have
tests that pass an arbitrary class here which only includes
__clause_element__(), so the key can't be cached unless it is coerced.
This in turn changed how composite attributes support bulk update, to use
the standard approach of ClauseElement with annotations that are parsed
in the ORM context.

Memory profiling successfully caught that the Session from Query was
getting passed into _statement_20(), so that was a big win for that test
suite.

Apparently Compiler had .execute() and .scalar() methods stuck on it;
these date back to version 0.4, and there was a single test in the
PostgreSQL dialect tests that exercised them for no apparent reason.
Removed these methods, as well as the concept of a Compiler holding onto
a "bind".

Fixes: #5386
Change-Id: I990b43aab96b42665af1b2187ad6020bee778784
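For reference, the opt-in for a third-party subclass of a Core SQL
element (such as the Alembic bindparam subclass mentioned above) looks
like the following; a minimal sketch against the 1.4 API, with the class
name being illustrative:

    from sqlalchemy.sql.expression import ColumnClause

    class MyColumn(ColumnClause):
        """Illustrative third-party subclass of a Core SQL element."""

        # opt back into caching: this declares that the subclass adds no
        # cache-relevant state beyond ColumnClause, so the parent class's
        # cache key logic remains valid.  without it, any statement
        # containing this element will not be cached.
        inherit_cache = True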
Diffstat (limited to 'lib')
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/base.py | 2
-rw-r--r--  lib/sqlalchemy/engine/base.py | 105
-rw-r--r--  lib/sqlalchemy/engine/create.py | 13
-rw-r--r--  lib/sqlalchemy/engine/cursor.py | 89
-rw-r--r--  lib/sqlalchemy/engine/default.py | 15
-rw-r--r--  lib/sqlalchemy/ext/baked.py | 1
-rw-r--r--  lib/sqlalchemy/future/selectable.py | 1
-rw-r--r--  lib/sqlalchemy/orm/attributes.py | 30
-rw-r--r--  lib/sqlalchemy/orm/context.py | 37
-rw-r--r--  lib/sqlalchemy/orm/descriptor_props.py | 1
-rw-r--r--  lib/sqlalchemy/orm/interfaces.py | 32
-rw-r--r--  lib/sqlalchemy/orm/loading.py | 11
-rw-r--r--  lib/sqlalchemy/orm/mapper.py | 2
-rw-r--r--  lib/sqlalchemy/orm/path_registry.py | 8
-rw-r--r--  lib/sqlalchemy/orm/persistence.py | 9
-rw-r--r--  lib/sqlalchemy/orm/properties.py | 1
-rw-r--r--  lib/sqlalchemy/orm/query.py | 20
-rw-r--r--  lib/sqlalchemy/orm/relationships.py | 1
-rw-r--r--  lib/sqlalchemy/orm/strategies.py | 500
-rw-r--r--  lib/sqlalchemy/orm/util.py | 8
-rw-r--r--  lib/sqlalchemy/sql/__init__.py | 3
-rw-r--r--  lib/sqlalchemy/sql/annotation.py | 9
-rw-r--r--  lib/sqlalchemy/sql/base.py | 7
-rw-r--r--  lib/sqlalchemy/sql/compiler.py | 54
-rw-r--r--  lib/sqlalchemy/sql/ddl.py | 3
-rw-r--r--  lib/sqlalchemy/sql/dml.py | 9
-rw-r--r--  lib/sqlalchemy/sql/elements.py | 138
-rw-r--r--  lib/sqlalchemy/sql/functions.py | 39
-rw-r--r--  lib/sqlalchemy/sql/schema.py | 2
-rw-r--r--  lib/sqlalchemy/sql/selectable.py | 32
-rw-r--r--  lib/sqlalchemy/sql/sqltypes.py | 24
-rw-r--r--  lib/sqlalchemy/sql/traversals.py | 129
-rw-r--r--  lib/sqlalchemy/sql/type_api.py | 7
-rw-r--r--  lib/sqlalchemy/sql/visitors.py | 19
-rw-r--r--  lib/sqlalchemy/testing/assertsql.py | 6
35 files changed, 920 insertions, 447 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 441e77a37..24e2d13d8 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -3681,7 +3681,7 @@ class PGDialect(default.DefaultDialect):
WHERE t.typtype = 'd'
"""
- s = sql.text(SQL_DOMAINS).columns(attname=sqltypes.Unicode)
+ s = sql.text(SQL_DOMAINS)
c = connection.execution_options(future_result=True).execute(s)
domains = {}
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index a36f4eee2..3e02a29fe 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -1175,46 +1175,17 @@ class Connection(Connectable):
)
compiled_cache = execution_options.get(
- "compiled_cache", self.dialect._compiled_cache
+ "compiled_cache", self.engine._compiled_cache
)
- if compiled_cache is not None:
- elem_cache_key = elem._generate_cache_key()
- else:
- elem_cache_key = None
-
- if elem_cache_key:
- cache_key, extracted_params = elem_cache_key
- key = (
- dialect,
- cache_key,
- tuple(keys),
- bool(schema_translate_map),
- inline,
- )
- compiled_sql = compiled_cache.get(key)
-
- if compiled_sql is None:
- compiled_sql = elem.compile(
- dialect=dialect,
- cache_key=elem_cache_key,
- column_keys=keys,
- inline=inline,
- schema_translate_map=schema_translate_map,
- linting=self.dialect.compiler_linting
- | compiler.WARN_LINTING,
- )
- compiled_cache[key] = compiled_sql
- else:
- extracted_params = None
- compiled_sql = elem.compile(
- dialect=dialect,
- column_keys=keys,
- inline=inline,
- schema_translate_map=schema_translate_map,
- linting=self.dialect.compiler_linting | compiler.WARN_LINTING,
- )
-
+ compiled_sql, extracted_params, cache_hit = elem._compile_w_cache(
+ dialect=dialect,
+ compiled_cache=compiled_cache,
+ column_keys=keys,
+ inline=inline,
+ schema_translate_map=schema_translate_map,
+ linting=self.dialect.compiler_linting | compiler.WARN_LINTING,
+ )
ret = self._execute_context(
dialect,
dialect.execution_ctx_cls._init_compiled,
@@ -1225,6 +1196,7 @@ class Connection(Connectable):
distilled_params,
elem,
extracted_params,
+ cache_hit=cache_hit,
)
if has_events:
self.dispatch.after_execute(
@@ -1389,7 +1361,8 @@ class Connection(Connectable):
statement,
parameters,
execution_options,
- *args
+ *args,
+ **kw
):
"""Create an :class:`.ExecutionContext` and execute, returning
a :class:`_engine.CursorResult`."""
@@ -1407,7 +1380,7 @@ class Connection(Connectable):
conn = self._revalidate_connection()
context = constructor(
- dialect, self, conn, execution_options, *args
+ dialect, self, conn, execution_options, *args, **kw
)
except (exc.PendingRollbackError, exc.ResourceClosedError):
raise
@@ -1455,32 +1428,21 @@ class Connection(Connectable):
self.engine.logger.info(statement)
- # stats = context._get_cache_stats()
+ stats = context._get_cache_stats()
if not self.engine.hide_parameters:
- # TODO: I love the stats but a ton of tests that are hardcoded.
- # to certain log output are failing.
self.engine.logger.info(
- "%r",
+ "[%s] %r",
+ stats,
sql_util._repr_params(
parameters, batches=10, ismulti=context.executemany
),
)
- # self.engine.logger.info(
- # "[%s] %r",
- # stats,
- # sql_util._repr_params(
- # parameters, batches=10, ismulti=context.executemany
- # ),
- # )
else:
self.engine.logger.info(
- "[SQL parameters hidden due to hide_parameters=True]"
+ "[%s] [SQL parameters hidden due to hide_parameters=True]"
+ % (stats,)
)
- # self.engine.logger.info(
- # "[%s] [SQL parameters hidden due to hide_parameters=True]"
- # % (stats,)
- # )
evt_handled = False
try:
@@ -2369,6 +2331,7 @@ class Engine(Connectable, log.Identified):
url,
logging_name=None,
echo=None,
+ query_cache_size=500,
execution_options=None,
hide_parameters=False,
):
@@ -2379,14 +2342,43 @@ class Engine(Connectable, log.Identified):
self.logging_name = logging_name
self.echo = echo
self.hide_parameters = hide_parameters
+ if query_cache_size != 0:
+ self._compiled_cache = util.LRUCache(
+ query_cache_size, size_alert=self._lru_size_alert
+ )
+ else:
+ self._compiled_cache = None
log.instance_logger(self, echoflag=echo)
if execution_options:
self.update_execution_options(**execution_options)
+ def _lru_size_alert(self, cache):
+ if self._should_log_info:
+ self.logger.info(
+ "Compiled cache size pruning from %d items to %d. "
+ "Increase cache size to reduce the frequency of pruning.",
+ len(cache),
+ cache.capacity,
+ )
+
@property
def engine(self):
return self
+ def clear_compiled_cache(self):
+ """Clear the compiled cache associated with the dialect.
+
+ This applies **only** to the built-in cache that is established
+ via the :paramref:`.create_engine.query_cache_size` parameter.
+ It will not impact any dictionary caches that were passed via the
+ :paramref:`.Connection.execution_options.compiled_cache` parameter.
+
+ .. versionadded:: 1.4
+
+ """
+ if self._compiled_cache:
+ self._compiled_cache.clear()
+
def update_execution_options(self, **opt):
r"""Update the default execution_options dictionary
of this :class:`_engine.Engine`.
@@ -2874,6 +2866,7 @@ class OptionEngineMixin(object):
self.dialect = proxied.dialect
self.logging_name = proxied.logging_name
self.echo = proxied.echo
+ self._compiled_cache = proxied._compiled_cache
self.hide_parameters = proxied.hide_parameters
log.instance_logger(self, echoflag=self.echo)
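The engine-level cache wired up above is reachable through ordinary
Engine usage; a brief sketch of the new knobs (the database URL is
illustrative):

    from sqlalchemy import create_engine

    # the built-in LRU cache holds up to query_cache_size entries; per
    # _lru_size_alert above, it logs when it prunes from 1.5x capacity
    # back down to capacity
    engine = create_engine("sqlite://", query_cache_size=500)

    # empty the built-in cache entirely, e.g. after a schema migration
    engine.clear_compiled_cache()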
diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py
index 4c912349e..9bf72eb06 100644
--- a/lib/sqlalchemy/engine/create.py
+++ b/lib/sqlalchemy/engine/create.py
@@ -436,7 +436,13 @@ def create_engine(url, **kwargs):
.. versionadded:: 1.2.3
:param query_cache_size: size of the cache used to cache the SQL string
- form of queries. Defaults to zero, which disables caching.
+ form of queries. Set to zero to disable caching.
+
+ The cache is pruned of its least recently used items when its size reaches
+ N * 1.5. Defaults to 500, meaning the cache will always store at least
+ 500 SQL statements when filled, and will grow up to 750 items at which
+ point it is pruned back down to 500 by removing the 250 least recently
+ used items.
Caching is accomplished on a per-statement basis by generating a
cache key that represents the statement's structure, then generating
@@ -446,6 +452,11 @@ def create_engine(url, **kwargs):
bypass the cache. SQL logging will indicate statistics for each
statement whether or not it was pulled from the cache.
+ .. note:: Some ORM functions related to unit-of-work persistence as well
+ as some attribute loading strategies will make use of individual
+ per-mapper caches outside of the main cache.
+
+
.. seealso::
``engine_caching`` - TODO: this will be an upcoming section describing
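Since the lookup in Connection shown earlier falls back from the
"compiled_cache" execution option to the engine-wide cache, a plain
dictionary can also be supplied per connection; a hedged sketch (URL and
table illustrative):

    from sqlalchemy import Column, Integer, MetaData, Table, create_engine

    engine = create_engine("sqlite://")
    t = Table("t", MetaData(), Column("x", Integer))
    t.create(engine)

    my_cache = {}  # an ordinary dict serving as the statement cache
    with engine.connect() as conn:
        cached_conn = conn.execution_options(compiled_cache=my_cache)
        cached_conn.execute(t.select())
        # the compiled form of t.select() is now stored in my_cache
        # rather than in the engine-wide LRU cache
        assert my_cache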
diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py
index d03d79df7..abffe0d1f 100644
--- a/lib/sqlalchemy/engine/cursor.py
+++ b/lib/sqlalchemy/engine/cursor.py
@@ -51,6 +51,7 @@ class CursorResultMetaData(ResultMetaData):
"_keys",
"_tuplefilter",
"_translated_indexes",
+ "_safe_for_cache"
# don't need _unique_filters support here for now. Can be added
# if a need arises.
)
@@ -104,11 +105,11 @@ class CursorResultMetaData(ResultMetaData):
return new_metadata
def _adapt_to_context(self, context):
- """When using a cached result metadata against a new context,
- we need to rewrite the _keymap so that it has the specific
- Column objects in the new context inside of it. this accommodates
- for select() constructs that contain anonymized columns and
- are cached.
+ """When using a cached Compiled construct that has a _result_map,
+ for a new statement that used the cached Compiled, we need to ensure
+ the keymap has the Column objects from our new statement as keys.
+ So here we rewrite keymap with new entries for the new columns
+ as matched to those of the cached statement.
"""
if not context.compiled._result_columns:
@@ -124,14 +125,15 @@ class CursorResultMetaData(ResultMetaData):
# to the result map.
md = self.__class__.__new__(self.__class__)
- md._keymap = self._keymap.copy()
+ md._keymap = dict(self._keymap)
# match up new columns positionally to the result columns
for existing, new in zip(
context.compiled._result_columns,
invoked_statement._exported_columns_iterator(),
):
- md._keymap[new] = md._keymap[existing[RM_NAME]]
+ if existing[RM_NAME] in md._keymap:
+ md._keymap[new] = md._keymap[existing[RM_NAME]]
md.case_sensitive = self.case_sensitive
md._processors = self._processors
@@ -147,6 +149,7 @@ class CursorResultMetaData(ResultMetaData):
self._tuplefilter = None
self._translated_indexes = None
self.case_sensitive = dialect.case_sensitive
+ self._safe_for_cache = False
if context.result_column_struct:
(
@@ -341,6 +344,10 @@ class CursorResultMetaData(ResultMetaData):
self._keys = [elem[0] for elem in result_columns]
# pure positional 1-1 case; doesn't need to read
# the names from cursor.description
+
+ # this metadata is safe to cache because we are guaranteed
+ # to have the columns in the same order for new executions
+ self._safe_for_cache = True
return [
(
idx,
@@ -359,9 +366,12 @@ class CursorResultMetaData(ResultMetaData):
for idx, rmap_entry in enumerate(result_columns)
]
else:
+
# name-based or text-positional cases, where we need
# to read cursor.description names
+
if textual_ordered:
+ self._safe_for_cache = True
# textual positional case
raw_iterator = self._merge_textual_cols_by_position(
context, cursor_description, result_columns
@@ -369,6 +379,9 @@ class CursorResultMetaData(ResultMetaData):
elif num_ctx_cols:
# compiled SQL with a mismatch of description cols
# vs. compiled cols, or textual w/ unordered columns
+ # the order of columns can change if the query is
+ # against a "select *", so not safe to cache
+ self._safe_for_cache = False
raw_iterator = self._merge_cols_by_name(
context,
cursor_description,
@@ -376,7 +389,9 @@ class CursorResultMetaData(ResultMetaData):
loose_column_name_matching,
)
else:
- # no compiled SQL, just a raw string
+ # no compiled SQL, just a raw string, order of columns
+ # can change for "select *"
+ self._safe_for_cache = False
raw_iterator = self._merge_cols_by_none(
context, cursor_description
)
@@ -1152,7 +1167,6 @@ class BaseCursorResult(object):
out_parameters = None
_metadata = None
- _metadata_from_cache = False
_soft_closed = False
closed = False
@@ -1209,33 +1223,38 @@ class BaseCursorResult(object):
def _init_metadata(self, context, cursor_description):
if context.compiled:
if context.compiled._cached_metadata:
- cached_md = self.context.compiled._cached_metadata
- self._metadata_from_cache = True
-
- # result rewrite/ adapt step. two translations can occur here.
- # one is if we are invoked against a cached statement, we want
- # to rewrite the ResultMetaData to reflect the column objects
- # that are in our current selectable, not the cached one. the
- # other is, the CompileState can return an alternative Result
- # object. Finally, CompileState might want to tell us to not
- # actually do the ResultMetaData adapt step if it in fact has
- # changed the selected columns in any case.
- compiled = context.compiled
- if (
- compiled
- and not compiled._rewrites_selected_columns
- and compiled.statement is not context.invoked_statement
- ):
- cached_md = cached_md._adapt_to_context(context)
-
- self._metadata = metadata = cached_md
-
+ metadata = self.context.compiled._cached_metadata
else:
- self._metadata = (
- metadata
- ) = context.compiled._cached_metadata = self._cursor_metadata(
- self, cursor_description
- )
+ metadata = self._cursor_metadata(self, cursor_description)
+ if metadata._safe_for_cache:
+ context.compiled._cached_metadata = metadata
+
+ # result rewrite/ adapt step. this is to suit the case
+ # when we are invoked against a cached Compiled object, we want
+ # to rewrite the ResultMetaData to reflect the Column objects
+ # that are in our current SQL statement object, not the one
+ # that is associated with the cached Compiled object.
+ # the Compiled object may also tell us to not
+ # actually do this step; this is to support the ORM where
+ # it is to produce a new Result object in any case, and will
+ # be using the cached Column objects against this database result
+ # so we don't want to rewrite them.
+ #
+ # Basically this step suits the use case where the end user
+ # is using Core SQL expressions and is accessing columns in the
+ # result row using row._mapping[table.c.column].
+ compiled = context.compiled
+ if (
+ compiled
+ and compiled._result_columns
+ and context.cache_hit
+ and not compiled._rewrites_selected_columns
+ and compiled.statement is not context.invoked_statement
+ ):
+ metadata = metadata._adapt_to_context(context)
+
+ self._metadata = metadata
+
else:
self._metadata = metadata = self._cursor_metadata(
self, cursor_description
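The comment block above describes the Core pattern that the adapt step
preserves; a minimal runnable sketch of that pattern (URL and table
illustrative):

    from sqlalchemy import Column, Integer, MetaData, Table, create_engine

    engine = create_engine("sqlite://")
    t = Table("t", MetaData(), Column("x", Integer))
    t.create(engine)

    with engine.connect() as conn:
        conn.execute(t.insert(), {"x": 1})
        # two structurally identical but distinct select() constructs;
        # the second execution reuses the cached Compiled, and rows must
        # remain addressable by the current statement's Column objects
        for stmt in (t.select(), t.select()):
            row = conn.execute(stmt).first()
            assert row._mapping[t.c.x] == 1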
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index c682a8ee1..4d516e97c 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -230,7 +230,6 @@ class DefaultDialect(interfaces.Dialect):
supports_native_boolean=None,
max_identifier_length=None,
label_length=None,
- query_cache_size=0,
# int() is because the @deprecated_params decorator cannot accommodate
# the direct reference to the "NO_LINTING" object
compiler_linting=int(compiler.NO_LINTING),
@@ -262,10 +261,6 @@ class DefaultDialect(interfaces.Dialect):
if supports_native_boolean is not None:
self.supports_native_boolean = supports_native_boolean
self.case_sensitive = case_sensitive
- if query_cache_size != 0:
- self._compiled_cache = util.LRUCache(query_cache_size)
- else:
- self._compiled_cache = None
self._user_defined_max_identifier_length = max_identifier_length
if self._user_defined_max_identifier_length:
@@ -794,6 +789,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
parameters,
invoked_statement,
extracted_parameters,
+ cache_hit=False,
):
"""Initialize execution context for a Compiled construct."""
@@ -804,6 +800,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
self.extracted_parameters = extracted_parameters
self.invoked_statement = invoked_statement
self.compiled = compiled
+ self.cache_hit = cache_hit
self.execution_options = execution_options
@@ -1027,13 +1024,15 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
def _get_cache_stats(self):
if self.compiled is None:
- return "raw SQL"
+ return "raw sql"
now = time.time()
if self.compiled.cache_key is None:
- return "gen %.5fs" % (now - self.compiled._gen_time,)
+ return "no key %.5fs" % (now - self.compiled._gen_time,)
+ elif self.cache_hit:
+ return "cached for %.4gs" % (now - self.compiled._gen_time,)
else:
- return "cached %.5fs" % (now - self.compiled._gen_time,)
+ return "generated in %.5fs" % (now - self.compiled._gen_time,)
@util.memoized_property
def engine(self):
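With engine echoing enabled, the token returned by _get_cache_stats is
what lands in the brackets ahead of the parameter log line emitted by
Connection (shown earlier); illustrative output, with made-up timings:

    SELECT t.x FROM t
    [generated in 0.00032s] ()
    SELECT t.x FROM t
    [cached for 0.002155s] ()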
diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py
index f95a30fda..4f40637c5 100644
--- a/lib/sqlalchemy/ext/baked.py
+++ b/lib/sqlalchemy/ext/baked.py
@@ -412,7 +412,6 @@ class Result(object):
result = self.session.execute(
statement, params, execution_options=execution_options
)
-
if result._attributes.get("is_single_entity", False):
result = result.scalars()
diff --git a/lib/sqlalchemy/future/selectable.py b/lib/sqlalchemy/future/selectable.py
index 407ec9633..53fc7c107 100644
--- a/lib/sqlalchemy/future/selectable.py
+++ b/lib/sqlalchemy/future/selectable.py
@@ -11,6 +11,7 @@ class Select(_LegacySelect):
_is_future = True
_setup_joins = ()
_legacy_setup_joins = ()
+ inherit_cache = True
@classmethod
def _create_select(cls, *entities):
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 262a1efc9..bf07061c6 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -85,16 +85,16 @@ class QueryableAttribute(
self,
class_,
key,
+ parententity,
impl=None,
comparator=None,
- parententity=None,
of_type=None,
):
self.class_ = class_
self.key = key
+ self._parententity = parententity
self.impl = impl
self.comparator = comparator
- self._parententity = parententity
self._of_type = of_type
manager = manager_of_class(class_)
@@ -197,10 +197,14 @@ class QueryableAttribute(
@util.memoized_property
def expression(self):
return self.comparator.__clause_element__()._annotate(
- {"orm_key": self.key}
+ {"orm_key": self.key, "entity_namespace": self._entity_namespace}
)
@property
+ def _entity_namespace(self):
+ return self._parententity
+
+ @property
def _annotations(self):
return self.__clause_element__()._annotations
@@ -230,9 +234,9 @@ class QueryableAttribute(
return QueryableAttribute(
self.class_,
self.key,
- self.impl,
- self.comparator.of_type(entity),
self._parententity,
+ impl=self.impl,
+ comparator=self.comparator.of_type(entity),
of_type=inspection.inspect(entity),
)
@@ -301,6 +305,8 @@ class InstrumentedAttribute(QueryableAttribute):
"""
+ inherit_cache = True
+
def __set__(self, instance, value):
self.impl.set(
instance_state(instance), instance_dict(instance), value, None
@@ -320,6 +326,11 @@ class InstrumentedAttribute(QueryableAttribute):
return self.impl.get(instance_state(instance), dict_)
+HasEntityNamespace = util.namedtuple(
+ "HasEntityNamespace", ["entity_namespace"]
+)
+
+
def create_proxied_attribute(descriptor):
"""Create an QueryableAttribute / user descriptor hybrid.
@@ -365,6 +376,15 @@ def create_proxied_attribute(descriptor):
)
@property
+ def _entity_namespace(self):
+ if hasattr(self._comparator, "_parententity"):
+ return self._comparator._parententity
+ else:
+ # used by hybrid attributes which try to remain
+ # agnostic of any ORM concepts like mappers
+ return HasEntityNamespace(self.class_)
+
+ @property
def property(self):
return self.comparator.property
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py
index a16db66f6..588b83571 100644
--- a/lib/sqlalchemy/orm/context.py
+++ b/lib/sqlalchemy/orm/context.py
@@ -63,6 +63,8 @@ class QueryContext(object):
"post_load_paths",
"identity_token",
"yield_per",
+ "loaders_require_buffering",
+ "loaders_require_uniquing",
)
class default_load_options(Options):
@@ -80,21 +82,23 @@ class QueryContext(object):
def __init__(
self,
compile_state,
+ statement,
session,
load_options,
execution_options=None,
bind_arguments=None,
):
-
self.load_options = load_options
self.execution_options = execution_options or _EMPTY_DICT
self.bind_arguments = bind_arguments or _EMPTY_DICT
self.compile_state = compile_state
- self.query = query = compile_state.select_statement
+ self.query = statement
self.session = session
+ self.loaders_require_buffering = False
+ self.loaders_require_uniquing = False
self.propagated_loader_options = {
- o for o in query._with_options if o.propagate_to_loaders
+ o for o in statement._with_options if o.propagate_to_loaders
}
self.attributes = dict(compile_state.attributes)
@@ -237,6 +241,7 @@ class ORMCompileState(CompileState):
)
querycontext = QueryContext(
compile_state,
+ statement,
session,
load_options,
execution_options,
@@ -278,8 +283,6 @@ class ORMFromStatementCompileState(ORMCompileState):
_has_orm_entities = False
multi_row_eager_loaders = False
compound_eager_adapter = None
- loaders_require_buffering = False
- loaders_require_uniquing = False
@classmethod
def create_for_statement(cls, statement_container, compiler, **kw):
@@ -386,8 +389,6 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
_has_orm_entities = False
multi_row_eager_loaders = False
compound_eager_adapter = None
- loaders_require_buffering = False
- loaders_require_uniquing = False
correlate = None
_where_criteria = ()
@@ -416,7 +417,14 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
self = cls.__new__(cls)
- self.select_statement = select_statement
+ if select_statement._execution_options:
+ # execution options should not impact the compilation of a
+ # query, and at the moment subqueryloader is putting some things
+ # in here that we explicitly don't want stuck in a cache.
+ self.select_statement = select_statement._clone()
+ self.select_statement._execution_options = util.immutabledict()
+ else:
+ self.select_statement = select_statement
# indicates this select() came from Query.statement
self.for_statement = (
@@ -654,6 +662,8 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
)
self._setup_with_polymorphics()
+ # entities will also set up polymorphic adapters for mappers
+ # that have with_polymorphic configured
_QueryEntity.to_compile_state(self, query._raw_columns)
return self
@@ -1810,10 +1820,12 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
self._where_criteria += (single_crit,)
-def _column_descriptions(query_or_select_stmt):
- ctx = ORMSelectCompileState._create_entities_collection(
- query_or_select_stmt
- )
+def _column_descriptions(query_or_select_stmt, compile_state=None):
+ if compile_state is None:
+ compile_state = ORMSelectCompileState._create_entities_collection(
+ query_or_select_stmt
+ )
+ ctx = compile_state
return [
{
"name": ent._label_name,
@@ -2097,6 +2109,7 @@ class _MapperEntity(_QueryEntity):
only_load_props = refresh_state = None
_instance = loading._instance_processor(
+ self,
self.mapper,
context,
result,
diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py
index 027f2521b..39cf86e34 100644
--- a/lib/sqlalchemy/orm/descriptor_props.py
+++ b/lib/sqlalchemy/orm/descriptor_props.py
@@ -411,7 +411,6 @@ class CompositeProperty(DescriptorProperty):
def expression(self):
clauses = self.clauses._annotate(
{
- "bundle": True,
"parententity": self._parententity,
"parentmapper": self._parententity,
"orm_key": self.prop.key,
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index 6c0f5d3ef..9782d92b7 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -158,7 +158,7 @@ class MapperProperty(
"""
def create_row_processor(
- self, context, path, mapper, result, adapter, populators
+ self, context, query_entity, path, mapper, result, adapter, populators
):
"""Produce row processing functions and append to the given
set of populators lists.
@@ -539,7 +539,7 @@ class StrategizedProperty(MapperProperty):
"_wildcard_token",
"_default_path_loader_key",
)
-
+ inherit_cache = True
strategy_wildcard_key = None
def _memoized_attr__wildcard_token(self):
@@ -600,7 +600,7 @@ class StrategizedProperty(MapperProperty):
)
def create_row_processor(
- self, context, path, mapper, result, adapter, populators
+ self, context, query_entity, path, mapper, result, adapter, populators
):
loader = self._get_context_loader(context, path)
if loader and loader.strategy:
@@ -608,7 +608,14 @@ class StrategizedProperty(MapperProperty):
else:
strat = self.strategy
strat.create_row_processor(
- context, path, loader, mapper, result, adapter, populators
+ context,
+ query_entity,
+ path,
+ loader,
+ mapper,
+ result,
+ adapter,
+ populators,
)
def do_init(self):
@@ -668,7 +675,7 @@ class StrategizedProperty(MapperProperty):
)
-class ORMOption(object):
+class ORMOption(HasCacheKey):
"""Base class for option objects that are passed to ORM queries.
These options may be consumed by :meth:`.Query.options`,
@@ -696,7 +703,7 @@ class ORMOption(object):
_is_compile_state = False
-class LoaderOption(HasCacheKey, ORMOption):
+class LoaderOption(ORMOption):
"""Describe a loader modification to an ORM statement at compilation time.
.. versionadded:: 1.4
@@ -736,9 +743,6 @@ class UserDefinedOption(ORMOption):
def __init__(self, payload=None):
self.payload = payload
- def _gen_cache_key(self, *arg, **kw):
- return ()
-
@util.deprecated_cls(
"1.4",
@@ -855,7 +859,15 @@ class LoaderStrategy(object):
"""
def create_row_processor(
- self, context, path, loadopt, mapper, result, adapter, populators
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
):
"""Establish row processing functions for a given QueryContext.
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py
index 424ed5dfe..a33e1b77d 100644
--- a/lib/sqlalchemy/orm/loading.py
+++ b/lib/sqlalchemy/orm/loading.py
@@ -72,8 +72,8 @@ def instances(cursor, context):
)
if context.yield_per and (
- context.compile_state.loaders_require_buffering
- or context.compile_state.loaders_require_uniquing
+ context.loaders_require_buffering
+ or context.loaders_require_uniquing
):
raise sa_exc.InvalidRequestError(
"Can't use yield_per with eager loaders that require uniquing "
@@ -545,6 +545,7 @@ def _warn_for_runid_changed(state):
def _instance_processor(
+ query_entity,
mapper,
context,
result,
@@ -648,6 +649,7 @@ def _instance_processor(
# to see if one fits
prop.create_row_processor(
context,
+ query_entity,
path,
mapper,
result,
@@ -667,7 +669,7 @@ def _instance_processor(
populators = {key: list(value) for key, value in cached_populators.items()}
for prop in getters["todo"]:
prop.create_row_processor(
- context, path, mapper, result, adapter, populators
+ context, query_entity, path, mapper, result, adapter, populators
)
propagated_loader_options = context.propagated_loader_options
@@ -925,6 +927,7 @@ def _instance_processor(
_instance = _decorate_polymorphic_switch(
_instance,
context,
+ query_entity,
mapper,
result,
path,
@@ -1081,6 +1084,7 @@ def _validate_version_id(mapper, state, dict_, row, getter):
def _decorate_polymorphic_switch(
instance_fn,
context,
+ query_entity,
mapper,
result,
path,
@@ -1112,6 +1116,7 @@ def _decorate_polymorphic_switch(
return False
return _instance_processor(
+ query_entity,
sub_mapper,
context,
result,
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index c4cb89c03..bec6da74d 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -720,7 +720,7 @@ class Mapper(
return self
_cache_key_traversal = [
- ("class_", visitors.ExtendedInternalTraversal.dp_plain_obj)
+ ("mapper", visitors.ExtendedInternalTraversal.dp_plain_obj),
]
@property
diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py
index 2e5941713..ac7a64c30 100644
--- a/lib/sqlalchemy/orm/path_registry.py
+++ b/lib/sqlalchemy/orm/path_registry.py
@@ -216,6 +216,8 @@ class RootRegistry(PathRegistry):
"""
+ inherit_cache = True
+
path = natural_path = ()
has_entity = False
is_aliased_class = False
@@ -248,6 +250,8 @@ class PathToken(HasCacheKey, str):
class TokenRegistry(PathRegistry):
__slots__ = ("token", "parent", "path", "natural_path")
+ inherit_cache = True
+
def __init__(self, parent, token):
token = PathToken.intern(token)
@@ -280,6 +284,7 @@ class TokenRegistry(PathRegistry):
class PropRegistry(PathRegistry):
is_unnatural = False
+ inherit_cache = True
def __init__(self, parent, prop):
# restate this path in terms of the
@@ -439,6 +444,7 @@ class AbstractEntityRegistry(PathRegistry):
class SlotsEntityRegistry(AbstractEntityRegistry):
# for aliased class, return lightweight, no-cycles created
# version
+ inherit_cache = True
__slots__ = (
"key",
@@ -454,6 +460,8 @@ class CachingEntityRegistry(AbstractEntityRegistry, dict):
# for long lived mapper, return dict based caching
# version that creates reference cycles
+ inherit_cache = True
+
def __getitem__(self, entity):
if isinstance(entity, (int, slice)):
return self.path[entity]
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 19d43d354..8393eaf74 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -38,6 +38,7 @@ from ..sql.base import Options
from ..sql.dml import DeleteDMLState
from ..sql.dml import UpdateDMLState
from ..sql.elements import BooleanClauseList
+from ..sql.util import _entity_namespace_key
def _bulk_insert(
@@ -1820,8 +1821,12 @@ class BulkUDCompileState(CompileState):
if isinstance(k, util.string_types):
desc = sql.util._entity_namespace_key(mapper, k)
values.extend(desc._bulk_update_tuples(v))
- elif isinstance(k, attributes.QueryableAttribute):
- values.extend(k._bulk_update_tuples(v))
+ elif "entity_namespace" in k._annotations:
+ k_anno = k._annotations
+ attr = _entity_namespace_key(
+ k_anno["entity_namespace"], k_anno["orm_key"]
+ )
+ values.extend(attr._bulk_update_tuples(v))
else:
values.append((k, v))
else:
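The new entity_namespace branch above is what lets a bulk update accept
mapped attributes as keys; a self-contained sketch against the 1.4 ORM
(model and URL illustrative):

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(engine)
    session.add(User(name="x"))
    session.commit()

    # the key is the mapped attribute, not a plain string; its
    # __clause_element__() carries the entity_namespace annotation that
    # is resolved back to the attribute for _bulk_update_tuples()
    session.query(User).update({User.name: "y"}, synchronize_session=False)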
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 02f0752a5..5fb3beca3 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -45,6 +45,7 @@ class ColumnProperty(StrategizedProperty):
"""
strategy_wildcard_key = "column"
+ inherit_cache = True
__slots__ = (
"_orig_columns",
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 284ea9d72..cdad55320 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -61,6 +61,7 @@ from ..sql.selectable import LABEL_STYLE_NONE
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from ..sql.selectable import SelectStatementGrouping
from ..sql.util import _entity_namespace_key
+from ..sql.visitors import InternalTraversal
from ..util import collections_abc
__all__ = ["Query", "QueryContext", "aliased"]
@@ -423,6 +424,7 @@ class Query(
_label_style=self._label_style,
compile_options=compile_options,
)
+ stmt.__dict__.pop("session", None)
stmt._propagate_attrs = self._propagate_attrs
return stmt
@@ -1725,7 +1727,6 @@ class Query(
"""
from_entity = self._filter_by_zero()
-
if from_entity is None:
raise sa_exc.InvalidRequestError(
"Can't use filter_by when the first entity '%s' of a query "
@@ -2900,7 +2901,10 @@ class Query(
compile_state = self._compile_state(for_statement=False)
context = QueryContext(
- compile_state, self.session, self.load_options
+ compile_state,
+ compile_state.statement,
+ self.session,
+ self.load_options,
)
result = loading.instances(result_proxy, context)
@@ -3376,7 +3380,12 @@ class Query(
def _compile_context(self, for_statement=False):
compile_state = self._compile_state(for_statement=for_statement)
- context = QueryContext(compile_state, self.session, self.load_options)
+ context = QueryContext(
+ compile_state,
+ compile_state.statement,
+ self.session,
+ self.load_options,
+ )
return context
@@ -3397,6 +3406,11 @@ class FromStatement(SelectStatementGrouping, Executable):
_for_update_arg = None
+ _traverse_internals = [
+ ("_raw_columns", InternalTraversal.dp_clauseelement_list),
+ ("element", InternalTraversal.dp_clauseelement),
+ ] + Executable._executable_traverse_internals
+
def __init__(self, entities, element):
self._raw_columns = [
coercions.expect(
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py
index 683f2b978..bedc54153 100644
--- a/lib/sqlalchemy/orm/relationships.py
+++ b/lib/sqlalchemy/orm/relationships.py
@@ -107,6 +107,7 @@ class RelationshipProperty(StrategizedProperty):
"""
strategy_wildcard_key = "relationship"
+ inherit_cache = True
_persistence_only = dict(
passive_deletes=False,
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index f67c23aab..5f039aff7 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -25,6 +25,7 @@ from .base import _DEFER_FOR_STATE
from .base import _RAISE_FOR_STATE
from .base import _SET_DEFERRED_EXPIRED
from .context import _column_descriptions
+from .context import ORMCompileState
from .interfaces import LoaderStrategy
from .interfaces import StrategizedProperty
from .session import _state_session
@@ -156,7 +157,15 @@ class UninstrumentedColumnLoader(LoaderStrategy):
column_collection.append(c)
def create_row_processor(
- self, context, path, loadopt, mapper, result, adapter, populators
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
):
pass
@@ -224,7 +233,15 @@ class ColumnLoader(LoaderStrategy):
)
def create_row_processor(
- self, context, path, loadopt, mapper, result, adapter, populators
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
):
# look through list of columns represented here
# to see which, if any, is present in the row.
@@ -281,7 +298,15 @@ class ExpressionColumnLoader(ColumnLoader):
memoized_populators[self.parent_property] = fetch
def create_row_processor(
- self, context, path, loadopt, mapper, result, adapter, populators
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
):
# look through list of columns represented here
# to see which, if any, is present in the row.
@@ -332,7 +357,15 @@ class DeferredColumnLoader(LoaderStrategy):
self.group = self.parent_property.group
def create_row_processor(
- self, context, path, loadopt, mapper, result, adapter, populators
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
):
# for a DeferredColumnLoader, this method is only used during a
@@ -542,7 +575,15 @@ class NoLoader(AbstractRelationshipLoader):
)
def create_row_processor(
- self, context, path, loadopt, mapper, result, adapter, populators
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
):
def invoke_no_load(state, dict_, row):
if self.uselist:
@@ -985,7 +1026,15 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
return None
def create_row_processor(
- self, context, path, loadopt, mapper, result, adapter, populators
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
):
key = self.key
@@ -1039,12 +1088,27 @@ class PostLoader(AbstractRelationshipLoader):
"""A relationship loader that emits a second SELECT statement."""
def _immediateload_create_row_processor(
- self, context, path, loadopt, mapper, result, adapter, populators
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
):
return self.parent_property._get_strategy(
(("lazy", "immediate"),)
).create_row_processor(
- context, path, loadopt, mapper, result, adapter, populators
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
)
@@ -1057,21 +1121,16 @@ class ImmediateLoader(PostLoader):
(("lazy", "select"),)
).init_class_attribute(mapper)
- def setup_query(
+ def create_row_processor(
self,
- compile_state,
- entity,
+ context,
+ query_entity,
path,
loadopt,
+ mapper,
+ result,
adapter,
- column_collection=None,
- parentmapper=None,
- **kwargs
- ):
- pass
-
- def create_row_processor(
- self, context, path, loadopt, mapper, result, adapter, populators
+ populators,
):
def load_immediate(state, dict_, row):
state.get_impl(self.key).get(state, dict_)
@@ -1093,120 +1152,6 @@ class SubqueryLoader(PostLoader):
(("lazy", "select"),)
).init_class_attribute(mapper)
- def setup_query(
- self,
- compile_state,
- entity,
- path,
- loadopt,
- adapter,
- column_collection=None,
- parentmapper=None,
- **kwargs
- ):
- if (
- not compile_state.compile_options._enable_eagerloads
- or compile_state.compile_options._for_refresh_state
- ):
- return
-
- compile_state.loaders_require_buffering = True
-
- path = path[self.parent_property]
-
- # build up a path indicating the path from the leftmost
- # entity to the thing we're subquery loading.
- with_poly_entity = path.get(
- compile_state.attributes, "path_with_polymorphic", None
- )
- if with_poly_entity is not None:
- effective_entity = with_poly_entity
- else:
- effective_entity = self.entity
-
- subq_path = compile_state.attributes.get(
- ("subquery_path", None), orm_util.PathRegistry.root
- )
-
- subq_path = subq_path + path
-
- # if not via query option, check for
- # a cycle
- if not path.contains(compile_state.attributes, "loader"):
- if self.join_depth:
- if (
- (
- compile_state.current_path.length
- if compile_state.current_path
- else 0
- )
- + path.length
- ) / 2 > self.join_depth:
- return
- elif subq_path.contains_mapper(self.mapper):
- return
-
- (
- leftmost_mapper,
- leftmost_attr,
- leftmost_relationship,
- ) = self._get_leftmost(subq_path)
-
- orig_query = compile_state.attributes.get(
- ("orig_query", SubqueryLoader), compile_state.select_statement
- )
-
- # generate a new Query from the original, then
- # produce a subquery from it.
- left_alias = self._generate_from_original_query(
- compile_state,
- orig_query,
- leftmost_mapper,
- leftmost_attr,
- leftmost_relationship,
- entity.entity_zero,
- )
-
- # generate another Query that will join the
- # left alias to the target relationships.
- # basically doing a longhand
- # "from_self()". (from_self() itself not quite industrial
- # strength enough for all contingencies...but very close)
-
- q = query.Query(effective_entity)
-
- def set_state_options(compile_state):
- compile_state.attributes.update(
- {
- ("orig_query", SubqueryLoader): orig_query,
- ("subquery_path", None): subq_path,
- }
- )
-
- q = q._add_context_option(set_state_options, None)._disable_caching()
-
- q = q._set_enable_single_crit(False)
- to_join, local_attr, parent_alias = self._prep_for_joins(
- left_alias, subq_path
- )
-
- q = q.add_columns(*local_attr)
- q = self._apply_joins(
- q, to_join, left_alias, parent_alias, effective_entity
- )
-
- q = self._setup_options(q, subq_path, orig_query, effective_entity)
- q = self._setup_outermost_orderby(q)
-
- # add new query to attributes to be picked up
- # by create_row_processor
- # NOTE: be sure to consult baked.py for some hardcoded logic
- # about this structure as well
- assert q.session is None
- path.set(
- compile_state.attributes, "subqueryload_data", {"query": q},
- )
-
def _get_leftmost(self, subq_path):
subq_path = subq_path.path
subq_mapper = orm_util._class_to_mapper(subq_path[0])
@@ -1267,27 +1212,34 @@ class SubqueryLoader(PostLoader):
q,
*{
ent["entity"]
- for ent in _column_descriptions(orig_query)
+ for ent in _column_descriptions(
+ orig_query, compile_state=orig_compile_state
+ )
if ent["entity"] is not None
}
)
- # for column information, look to the compile state that is
- # already being passed through
- compile_state = orig_compile_state
-
# select from the identity columns of the outer (specifically, these
- # are the 'local_cols' of the property). This will remove
- # other columns from the query that might suggest the right entity
- # which is why we do _set_select_from above.
- target_cols = compile_state._adapt_col_list(
+ # are the 'local_cols' of the property). This will remove other
+ # columns from the query that might suggest the right entity which is
+ # why we do set select_from above. The attributes we have are
+ # coerced and adapted using the original query's adapter, which is
+ # needed only for the case of adapting a subclass column to
+ # that of a polymorphic selectable, e.g. we have
+ # Engineer.primary_language and the entity is Person. All other
+ # adaptations, e.g. from_self, select_entity_from(), will occur
+ # within the new query when it compiles, as the compile_state we are
+ # using here is only a partial one. If the subqueryload is from a
+ # with_polymorphic() or other aliased() object, left_attr will already
+ # be the correct attributes so no adaptation is needed.
+ target_cols = orig_compile_state._adapt_col_list(
[
- sql.coercions.expect(sql.roles.ByOfRole, o)
+ sql.coercions.expect(sql.roles.ColumnsClauseRole, o)
for o in leftmost_attr
],
- compile_state._get_current_adapter(),
+ orig_compile_state._get_current_adapter(),
)
- q._set_entities(target_cols)
+ q._raw_columns = target_cols
distinct_target_key = leftmost_relationship.distinct_target_key
@@ -1461,13 +1413,13 @@ class SubqueryLoader(PostLoader):
"_data",
)
- def __init__(self, context, subq_info):
+ def __init__(self, context, subq):
# avoid creating a cycle by storing context
# even though that's preferable
self.session = context.session
self.execution_options = context.execution_options
self.load_options = context.load_options
- self.subq = subq_info["query"]
+ self.subq = subq
self._data = None
def get(self, key, default):
@@ -1499,12 +1451,148 @@ class SubqueryLoader(PostLoader):
if self._data is None:
self._load()
+ def _setup_query_from_rowproc(
+ self, context, path, entity, loadopt, adapter,
+ ):
+ compile_state = context.compile_state
+ if (
+ not compile_state.compile_options._enable_eagerloads
+ or compile_state.compile_options._for_refresh_state
+ ):
+ return
+
+ context.loaders_require_buffering = True
+
+ path = path[self.parent_property]
+
+ # build up a path indicating the path from the leftmost
+ # entity to the thing we're subquery loading.
+ with_poly_entity = path.get(
+ compile_state.attributes, "path_with_polymorphic", None
+ )
+ if with_poly_entity is not None:
+ effective_entity = with_poly_entity
+ else:
+ effective_entity = self.entity
+
+ subq_path = context.query._execution_options.get(
+ ("subquery_path", None), orm_util.PathRegistry.root
+ )
+
+ subq_path = subq_path + path
+
+ # if not via query option, check for
+ # a cycle
+ if not path.contains(compile_state.attributes, "loader"):
+ if self.join_depth:
+ if (
+ (
+ compile_state.current_path.length
+ if compile_state.current_path
+ else 0
+ )
+ + path.length
+ ) / 2 > self.join_depth:
+ return
+ elif subq_path.contains_mapper(self.mapper):
+ return
+
+ (
+ leftmost_mapper,
+ leftmost_attr,
+ leftmost_relationship,
+ ) = self._get_leftmost(subq_path)
+
+ # use the current query being invoked, not the compile state
+ # one. this is so that we get the current parameters. however,
+ # it means we can't use the existing compile state, we have to make
+ # a new one. other approaches include possibly using the
+ # compiled query but swapping the params, seems only marginally
+ # less time spent but more complicated
+ orig_query = context.query._execution_options.get(
+ ("orig_query", SubqueryLoader), context.query
+ )
+
+ # make a new compile_state for the query that's probably cached, but
+ # we're sort of undoing a bit of that caching :(
+ compile_state_cls = ORMCompileState._get_plugin_class_for_plugin(
+ orig_query, "orm"
+ )
+
+ # this would create the full blown compile state, which we don't
+ # need
+ # orig_compile_state = compile_state_cls.create_for_statement(
+ # orig_query, None)
+
+ # this is the more "quick" version, however it's not clear how
+ # much of this we need. in particular I can't get a test to
+ # fail if the "set_base_alias" is missing and not sure why that is.
+ orig_compile_state = compile_state_cls._create_entities_collection(
+ orig_query
+ )
+
+ # generate a new Query from the original, then
+ # produce a subquery from it.
+ left_alias = self._generate_from_original_query(
+ orig_compile_state,
+ orig_query,
+ leftmost_mapper,
+ leftmost_attr,
+ leftmost_relationship,
+ entity,
+ )
+
+ # generate another Query that will join the
+ # left alias to the target relationships.
+ # basically doing a longhand
+ # "from_self()". (from_self() itself not quite industrial
+ # strength enough for all contingencies...but very close)
+
+ q = query.Query(effective_entity)
+
+ q._execution_options = q._execution_options.union(
+ {
+ ("orig_query", SubqueryLoader): orig_query,
+ ("subquery_path", None): subq_path,
+ }
+ )
+
+ q = q._set_enable_single_crit(False)
+ to_join, local_attr, parent_alias = self._prep_for_joins(
+ left_alias, subq_path
+ )
+
+ q = q.add_columns(*local_attr)
+ q = self._apply_joins(
+ q, to_join, left_alias, parent_alias, effective_entity
+ )
+
+ q = self._setup_options(q, subq_path, orig_query, effective_entity)
+ q = self._setup_outermost_orderby(q)
+
+ return q
+
def create_row_processor(
- self, context, path, loadopt, mapper, result, adapter, populators
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
):
if context.refresh_state:
return self._immediateload_create_row_processor(
- context, path, loadopt, mapper, result, adapter, populators
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
)
if not self.parent.class_manager[self.key].impl.supports_population:
@@ -1513,16 +1601,27 @@ class SubqueryLoader(PostLoader):
"population - eager loading cannot be applied." % self
)
- path = path[self.parent_property]
+ # a little dance here as the "path" is still something that only
+ # semi-tracks the exact series of things we are loading, still not
+ # telling us about with_polymorphic() and stuff like that when it's at
+ # the root.. the initial MapperEntity is more accurate for this case.
+ if len(path) == 1:
+ if not orm_util._entity_isa(query_entity.entity_zero, self.parent):
+ return
+ elif not orm_util._entity_isa(path[-1], self.parent):
+ return
- subq_info = path.get(context.attributes, "subqueryload_data")
+ subq = self._setup_query_from_rowproc(
+ context, path, path[-1], loadopt, adapter,
+ )
- if subq_info is None:
+ if subq is None:
return
- subq = subq_info["query"]
-
assert subq.session is None
+
+ path = path[self.parent_property]
+
local_cols = self.parent_property.local_columns
# cache the loaded collections in the context
@@ -1530,7 +1629,7 @@ class SubqueryLoader(PostLoader):
# call upon create_row_processor again
collections = path.get(context.attributes, "collections")
if collections is None:
- collections = self._SubqCollections(context, subq_info)
+ collections = self._SubqCollections(context, subq)
path.set(context.attributes, "collections", collections)
if adapter:
@@ -1634,7 +1733,6 @@ class JoinedLoader(AbstractRelationshipLoader):
if not compile_state.compile_options._enable_eagerloads:
return
elif self.uselist:
- compile_state.loaders_require_uniquing = True
compile_state.multi_row_eager_loaders = True
path = path[self.parent_property]
@@ -2142,7 +2240,15 @@ class JoinedLoader(AbstractRelationshipLoader):
return False
def create_row_processor(
- self, context, path, loadopt, mapper, result, adapter, populators
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
):
if not self.parent.class_manager[self.key].impl.supports_population:
raise sa_exc.InvalidRequestError(
@@ -2150,6 +2256,9 @@ class JoinedLoader(AbstractRelationshipLoader):
"population - eager loading cannot be applied." % self
)
+ if self.uselist:
+ context.loaders_require_uniquing = True
+
our_path = path[self.parent_property]
eager_adapter = self._create_eager_adapter(
@@ -2160,6 +2269,7 @@ class JoinedLoader(AbstractRelationshipLoader):
key = self.key
_instance = loading._instance_processor(
+ query_entity,
self.mapper,
context,
result,
@@ -2177,7 +2287,14 @@ class JoinedLoader(AbstractRelationshipLoader):
self.parent_property._get_strategy(
(("lazy", "select"),)
).create_row_processor(
- context, path, loadopt, mapper, result, adapter, populators
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
)
def _create_collection_loader(self, context, key, _instance, populators):
@@ -2382,11 +2499,26 @@ class SelectInLoader(PostLoader, util.MemoizedSlots):
return util.preloaded.ext_baked.bakery(size=50)
def create_row_processor(
- self, context, path, loadopt, mapper, result, adapter, populators
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
):
if context.refresh_state:
return self._immediateload_create_row_processor(
- context, path, loadopt, mapper, result, adapter, populators
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
)
if not self.parent.class_manager[self.key].impl.supports_population:
@@ -2395,13 +2527,20 @@ class SelectInLoader(PostLoader, util.MemoizedSlots):
"population - eager loading cannot be applied." % self
)
+ # a little dance here as the "path" is still something that only
+ # semi-tracks the exact series of things we are loading, still not
+ # telling us about with_polymorphic() and stuff like that when it's at
+ # the root.. the initial MapperEntity is more accurate for this case.
+ if len(path) == 1:
+ if not orm_util._entity_isa(query_entity.entity_zero, self.parent):
+ return
+ elif not orm_util._entity_isa(path[-1], self.parent):
+ return
+
selectin_path = (
context.compile_state.current_path or orm_util.PathRegistry.root
) + path
- if not orm_util._entity_isa(path[-1], self.parent):
- return
-
if loading.PostLoad.path_exists(
context, selectin_path, self.parent_property
):
@@ -2427,7 +2566,6 @@ class SelectInLoader(PostLoader, util.MemoizedSlots):
return
elif selectin_path_w_prop.contains_mapper(self.mapper):
return
-
loading.PostLoad.callable_for_path(
context,
selectin_path,
@@ -2543,7 +2681,39 @@ class SelectInLoader(PostLoader, util.MemoizedSlots):
)
)
- orig_query = context.query
+ # a test which exercises what these comments talk about is
+ # test_selectin_relations.py -> test_twolevel_selectin_w_polymorphic
+ #
+ # effective_entity above is given to us in terms of the cached
+ # statement, namely this one:
+ orig_query = context.compile_state.select_statement
+
+ # the actual statement that was requested is this one:
+ # context_query = context.query
+ #
+ # that's not the cached one, however. So while it is of the identical
+ # structure, if it has entities like AliasedInsp, which we get from
+ # aliased() or with_polymorphic(), the AliasedInsp will likely be a
+ # different object identity each time, and will not match up
+ # hashing-wise to the corresponding AliasedInsp that's in the
+ # cached query, meaning it won't match on paths and loader lookups
+ # and loaders like this one will be skipped if it is used in options.
+ #
+ # Now we want to transfer loader options from the parent query to the
+ # "selectinload" query we're about to run. Which query do we transfer
+ # the options from? We use the cached query, because the options in
+ # that query will be in terms of the effective entity we were just
+ # handed.
+ #
+ # But now the selectinload/ baked query we are running is *also*
+ # cached. What if it's cached and running from some previous iteration
+ # of that AliasedInsp? Well in that case it will also use the previous
+ # iteration of the loader options. If the baked query expires and
+ # gets generated again, it will be handed the current effective_entity
+ # and the current _with_options, again in terms of whatever
+ # compile_state.select_statement happens to be right now, so the
+ # query will still be internally consistent and loader callables
+ # will be correctly invoked.
q._add_lazyload_options(
orig_query._with_options, path[self.parent_property]
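The reworked loaders above are exercised through ordinary relationship
loading; a self-contained sketch against the 1.4 ORM (schema
illustrative):

    from sqlalchemy import Column, ForeignKey, Integer, create_engine
    from sqlalchemy.orm import Session, declarative_base, relationship
    from sqlalchemy.orm import selectinload

    Base = declarative_base()

    class A(Base):
        __tablename__ = "a"
        id = Column(Integer, primary_key=True)
        bs = relationship("B")

    class B(Base):
        __tablename__ = "b"
        id = Column(Integer, primary_key=True)
        a_id = Column(ForeignKey("a.id"))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(engine)
    session.add(A(bs=[B(), B()]))
    session.commit()

    # with setup moved into create_row_processor(), repeated runs of this
    # query can hit the statement cache without the loaders stashing
    # per-run state in the compiled form
    for a in session.query(A).options(selectinload(A.bs)):
        assert len(a.bs) == 2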
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 85f4f85d1..f7a97bfe5 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -1187,9 +1187,9 @@ class Bundle(ORMColumnsClauseRole, SupportsCloneAnnotations, InspectionAttr):
return cloned
def __clause_element__(self):
- annotations = self._annotations.union(
- {"bundle": self, "entity_namespace": self}
- )
+ # ensure existing entity_namespace remains
+ annotations = {"bundle": self, "entity_namespace": self}
+ annotations.update(self._annotations)
return expression.ClauseList(
_literal_as_text_role=roles.ColumnsClauseRole,
group=False,
@@ -1258,6 +1258,8 @@ class _ORMJoin(expression.Join):
__visit_name__ = expression.Join.__visit_name__
+ inherit_cache = True
+
def __init__(
self,
left,
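The Bundle.__clause_element__ change above relies on plain dict
precedence: defaults first, then the existing annotations overlaid so
that a pre-existing entity_namespace wins. In miniature (values are
illustrative stand-ins):

    defaults = {"bundle": "the_bundle", "entity_namespace": "the_bundle"}
    existing = {"entity_namespace": "custom_namespace"}

    annotations = dict(defaults)
    annotations.update(existing)
    assert annotations["entity_namespace"] == "custom_namespace"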
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py
index 78de80734..a25c1b083 100644
--- a/lib/sqlalchemy/sql/__init__.py
+++ b/lib/sqlalchemy/sql/__init__.py
@@ -100,6 +100,7 @@ def __go(lcls):
from .elements import AnnotatedColumnElement
from .elements import ClauseList # noqa
from .selectable import AnnotatedFromClause # noqa
+ from .traversals import _preconfigure_traversals
from . import base
from . import coercions
@@ -122,6 +123,8 @@ def __go(lcls):
_prepare_annotations(FromClause, AnnotatedFromClause)
_prepare_annotations(ClauseList, Annotated)
+ _preconfigure_traversals(ClauseElement)
+
_sa_util.preloaded.import_prefix("sqlalchemy.sql")
from . import naming # noqa
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index 08ed121d3..8a0d6ec28 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -338,6 +338,15 @@ def _new_annotation_type(cls, base_cls):
anno_cls._traverse_internals = list(cls._traverse_internals) + [
("_annotations", InternalTraversal.dp_annotations_key)
]
+ elif cls.__dict__.get("inherit_cache", False):
+ anno_cls._traverse_internals = list(cls._traverse_internals) + [
+ ("_annotations", InternalTraversal.dp_annotations_key)
+ ]
+
+ # some classes set inherit_cache even though they have their own
+ # _traverse_internals, e.g. BindParameter; add the flag if present.
+ if cls.__dict__.get("inherit_cache", False):
+ anno_cls.inherit_cache = True
anno_cls._is_column_operators = issubclass(cls, operators.ColumnOperators)
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 5dd3b519a..5f2ce8f14 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -624,19 +624,14 @@ class Executable(Generative):
_bind = None
_with_options = ()
_with_context_options = ()
- _cache_enable = True
_executable_traverse_internals = [
("_with_options", ExtendedInternalTraversal.dp_has_cache_key_list),
("_with_context_options", ExtendedInternalTraversal.dp_plain_obj),
- ("_cache_enable", ExtendedInternalTraversal.dp_plain_obj),
+ ("_propagate_attrs", ExtendedInternalTraversal.dp_propagate_attrs),
]
@_generative
- def _disable_caching(self):
- self._cache_enable = HasCacheKey()
-
- @_generative
def options(self, *options):
"""Apply options to this statement.
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 2519438d1..61178291a 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -373,6 +373,8 @@ class Compiled(object):
_cached_metadata = None
+ _result_columns = None
+
schema_translate_map = None
execution_options = util.immutabledict()
@@ -433,7 +435,6 @@ class Compiled(object):
self,
dialect,
statement,
- bind=None,
schema_translate_map=None,
render_schema_translate=False,
compile_kwargs=util.immutabledict(),
@@ -463,7 +464,6 @@ class Compiled(object):
"""
self.dialect = dialect
- self.bind = bind
self.preparer = self.dialect.identifier_preparer
if schema_translate_map:
self.schema_translate_map = schema_translate_map
@@ -527,24 +527,6 @@ class Compiled(object):
"""Return the bind params for this compiled object."""
return self.construct_params()
- def execute(self, *multiparams, **params):
- """Execute this compiled object."""
-
- e = self.bind
- if e is None:
- raise exc.UnboundExecutionError(
- "This Compiled object is not bound to any Engine "
- "or Connection.",
- code="2afi",
- )
- return e._execute_compiled(self, multiparams, params)
-
- def scalar(self, *multiparams, **params):
- """Execute this compiled object and return the result's
- scalar value."""
-
- return self.execute(*multiparams, **params).scalar()
-
class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)):
"""Produces DDL specification for TypeEngine objects."""
@@ -687,6 +669,13 @@ class SQLCompiler(Compiled):
insert_prefetch = update_prefetch = ()
+ _cache_key_bind_match = None
+ """a mapping that will relate the BindParameter object we compile
+ to those that are part of the extracted collection of parameters
+ in the cache key, if we were given a cache key.
+
+ """
+
def __init__(
self,
dialect,
@@ -717,6 +706,9 @@ class SQLCompiler(Compiled):
self.cache_key = cache_key
+ if cache_key:
+ self._cache_key_bind_match = {b: b for b in cache_key[1]}
+
# compile INSERT/UPDATE defaults/sequences inlined (no pre-
# execute)
self.inline = inline or getattr(statement, "_inline", False)
@@ -875,8 +867,9 @@ class SQLCompiler(Compiled):
replace_context=err,
)
+ ckbm = self._cache_key_bind_match
resolved_extracted = {
- b.key: extracted
+ ckbm[b]: extracted
for b, extracted in zip(orig_extracted, extracted_parameters)
}
else:
@@ -907,7 +900,7 @@ class SQLCompiler(Compiled):
else:
if resolved_extracted:
value_param = resolved_extracted.get(
- bindparam.key, bindparam
+ bindparam, bindparam
)
else:
value_param = bindparam
@@ -936,9 +929,7 @@ class SQLCompiler(Compiled):
)
if resolved_extracted:
- value_param = resolved_extracted.get(
- bindparam.key, bindparam
- )
+ value_param = resolved_extracted.get(bindparam, bindparam)
else:
value_param = bindparam
@@ -2021,6 +2012,19 @@ class SQLCompiler(Compiled):
)
self.binds[bindparam.key] = self.binds[name] = bindparam
+
+ # if we are given a cache key that we're going to match against,
+ # relate the bindparam here to one that is most likely present
+ # in the "extracted params" portion of the cache key. this is used
+ # to set up a positional mapping that is used to determine the
+ # correct parameters for a subsequent use of this compiled with
+ # a different set of parameter values. here, we account for
+ # parameters that may have been cloned both before and after the cache
+ # key was generated.
+ ckbm = self._cache_key_bind_match
+ if ckbm:
+ ckbm.update({bp: bindparam for bp in bindparam._cloned_set})
+
if bindparam.isoutparam:
self.has_out_parameters = True
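
The reason the mapping is keyed on the BindParameter objects themselves rather than on .key: two distinct parameters may share the same key, and only object identity tells them apart. A minimal illustration, not taken from this diff:

    from sqlalchemy import bindparam

    b1 = bindparam("x", value=1)
    b2 = bindparam("x", value=2)

    # keyed on .key the two entries collapse into one; keyed on the
    # objects they stay distinct, which is what the compiler needs here
    assert len({b1.key: 1, b2.key: 2}) == 1
    assert len({b1: 1, b2: 2}) == 2
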
diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py
index 569030651..d3730b124 100644
--- a/lib/sqlalchemy/sql/ddl.py
+++ b/lib/sqlalchemy/sql/ddl.py
@@ -28,6 +28,9 @@ class _DDLCompiles(ClauseElement):
return dialect.ddl_compiler(dialect, self, **kw)
+ def _compile_w_cache(self, *arg, **kw):
+ raise NotImplementedError()
+
class DDLElement(roles.DDLRole, Executable, _DDLCompiles):
"""Base class for DDL expression constructs.
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index a82641d77..50b2a935a 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -641,7 +641,7 @@ class ValuesBase(UpdateBase):
if self._preserve_parameter_order:
arg = [
(
- k,
+ coercions.expect(roles.DMLColumnRole, k),
coercions.expect(
roles.ExpressionElementRole,
v,
@@ -654,7 +654,7 @@ class ValuesBase(UpdateBase):
self._ordered_values = arg
else:
arg = {
- k: coercions.expect(
+ coercions.expect(roles.DMLColumnRole, k): coercions.expect(
roles.ExpressionElementRole,
v,
type_=NullType(),
@@ -772,6 +772,7 @@ class Insert(ValuesBase):
]
+ HasPrefixes._has_prefixes_traverse_internals
+ DialectKWArgs._dialect_kwargs_traverse_internals
+ + Executable._executable_traverse_internals
)
@ValuesBase._constructor_20_deprecations(
@@ -997,6 +998,7 @@ class Update(DMLWhereBase, ValuesBase):
]
+ HasPrefixes._has_prefixes_traverse_internals
+ DialectKWArgs._dialect_kwargs_traverse_internals
+ + Executable._executable_traverse_internals
)
@ValuesBase._constructor_20_deprecations(
@@ -1187,7 +1189,7 @@ class Update(DMLWhereBase, ValuesBase):
)
arg = [
(
- k,
+ coercions.expect(roles.DMLColumnRole, k),
coercions.expect(
roles.ExpressionElementRole,
v,
@@ -1238,6 +1240,7 @@ class Delete(DMLWhereBase, UpdateBase):
]
+ HasPrefixes._has_prefixes_traverse_internals
+ DialectKWArgs._dialect_kwargs_traverse_internals
+ + Executable._executable_traverse_internals
)
@ValuesBase._constructor_20_deprecations(
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 8e1b623a7..60c816ee6 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -381,6 +381,7 @@ class ClauseElement(
try:
traverse_internals = self._traverse_internals
except AttributeError:
+ # user-defined classes may not have a _traverse_internals
return
for attrname, obj, meth in _copy_internals.run_generated_dispatch(
@@ -410,6 +411,7 @@ class ClauseElement(
try:
traverse_internals = self._traverse_internals
except AttributeError:
+ # user-defined classes may not have a _traverse_internals
return []
return itertools.chain.from_iterable(
@@ -516,10 +518,62 @@ class ClauseElement(
dialect = bind.dialect
elif self.bind:
dialect = self.bind.dialect
- bind = self.bind
else:
dialect = default.StrCompileDialect()
- return self._compiler(dialect, bind=bind, **kw)
+
+ return self._compiler(dialect, **kw)
+
+ def _compile_w_cache(
+ self,
+ dialect,
+ compiled_cache=None,
+ column_keys=None,
+ inline=False,
+ schema_translate_map=None,
+ **kw
+ ):
+ if compiled_cache is not None:
+ elem_cache_key = self._generate_cache_key()
+ else:
+ elem_cache_key = None
+
+ cache_hit = False
+
+ if elem_cache_key:
+ cache_key, extracted_params = elem_cache_key
+ key = (
+ dialect,
+ cache_key,
+ tuple(column_keys),
+ bool(schema_translate_map),
+ inline,
+ )
+ compiled_sql = compiled_cache.get(key)
+
+ if compiled_sql is None:
+ compiled_sql = self._compiler(
+ dialect,
+ cache_key=elem_cache_key,
+ column_keys=column_keys,
+ inline=inline,
+ schema_translate_map=schema_translate_map,
+ **kw
+ )
+ compiled_cache[key] = compiled_sql
+ else:
+ cache_hit = True
+ else:
+ extracted_params = None
+ compiled_sql = self._compiler(
+ dialect,
+ cache_key=elem_cache_key,
+ column_keys=column_keys,
+ inline=inline,
+ schema_translate_map=schema_translate_map,
+ **kw
+ )
+
+ return compiled_sql, extracted_params, cache_hit
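
A hedged sketch of how an execution layer might drive this private method; dialect, stmt, and the dict-like compiled_cache are assumptions standing in for what the engine supplies:

    compiled_cache = {}  # e.g. an LRU mapping in practice

    compiled_sql, extracted_params, cache_hit = stmt._compile_w_cache(
        dialect,
        compiled_cache=compiled_cache,
        column_keys=[],
    )
    # on a cache hit, extracted_params carries this statement's bound
    # values, to be matched against the cached compiled's parameters
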
def _compiler(self, dialect, **kw):
"""Return a compiler appropriate for this ClauseElement, given a
@@ -1035,6 +1089,10 @@ class BindParameter(roles.InElementRole, ColumnElement):
_is_bind_parameter = True
_key_is_anon = False
+ # bindparam implements its own _gen_cache_key() method; however,
+ # we still check subclasses for this flag, else no cache key is generated
+ inherit_cache = True
+
def __init__(
self,
key,
@@ -1396,6 +1454,13 @@ class BindParameter(roles.InElementRole, ColumnElement):
return c
def _gen_cache_key(self, anon_map, bindparams):
+ _gen_cache_ok = self.__class__.__dict__.get("inherit_cache", False)
+
+ if not _gen_cache_ok:
+ if anon_map is not None:
+ anon_map[NO_CACHE] = True
+ return None
+
idself = id(self)
if idself in anon_map:
return (anon_map[idself], self.__class__)
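
The guard above enforces the new subclass contract noted in the commit message: a straight subclass of bindparam (as Alembic creates) must opt in before statements using it can be cached. A minimal sketch; both classes are hypothetical:

    from sqlalchemy.sql.elements import BindParameter

    class MyInlineBind(BindParameter):
        # no inherit_cache: _gen_cache_key() records NO_CACHE in the
        # anon_map and the enclosing statement is never cached
        pass

    class MySafeBind(BindParameter):
        inherit_cache = True  # opts back in to the parent's cache scheme
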
@@ -2082,6 +2147,7 @@ class ClauseList(
roles.InElementRole,
roles.OrderByRole,
roles.ColumnsClauseRole,
+ roles.DMLColumnRole,
ClauseElement,
):
"""Describe a list of clauses, separated by an operator.
@@ -2174,6 +2240,7 @@ class ClauseList(
class BooleanClauseList(ClauseList, ColumnElement):
__visit_name__ = "clauselist"
+ inherit_cache = True
_tuple_values = False
@@ -3428,6 +3495,8 @@ class CollectionAggregate(UnaryExpression):
class AsBoolean(WrapsColumnExpression, UnaryExpression):
+ inherit_cache = True
+
def __init__(self, element, operator, negate):
self.element = element
self.type = type_api.BOOLEANTYPE
@@ -3474,6 +3543,7 @@ class BinaryExpression(ColumnElement):
("operator", InternalTraversal.dp_operator),
("negate", InternalTraversal.dp_operator),
("modifiers", InternalTraversal.dp_plain_dict),
+ ("type", InternalTraversal.dp_type,), # affects JSON CAST operators
]
_is_implicitly_boolean = True
@@ -3482,41 +3552,6 @@ class BinaryExpression(ColumnElement):
"""
- def _gen_cache_key(self, anon_map, bindparams):
- # inlined for performance
-
- idself = id(self)
-
- if idself in anon_map:
- return (anon_map[idself], self.__class__)
- else:
- # inline of
- # id_ = anon_map[idself]
- anon_map[idself] = id_ = str(anon_map.index)
- anon_map.index += 1
-
- if self._cache_key_traversal is NO_CACHE:
- anon_map[NO_CACHE] = True
- return None
-
- result = (id_, self.__class__)
-
- return result + (
- ("left", self.left._gen_cache_key(anon_map, bindparams)),
- ("right", self.right._gen_cache_key(anon_map, bindparams)),
- ("operator", self.operator),
- ("negate", self.negate),
- (
- "modifiers",
- tuple(
- (key, self.modifiers[key])
- for key in sorted(self.modifiers)
- )
- if self.modifiers
- else None,
- ),
- )
-
def __init__(
self, left, right, operator, type_=None, negate=None, modifiers=None
):
@@ -3587,15 +3622,30 @@ class Slice(ColumnElement):
__visit_name__ = "slice"
_traverse_internals = [
- ("start", InternalTraversal.dp_plain_obj),
- ("stop", InternalTraversal.dp_plain_obj),
- ("step", InternalTraversal.dp_plain_obj),
+ ("start", InternalTraversal.dp_clauseelement),
+ ("stop", InternalTraversal.dp_clauseelement),
+ ("step", InternalTraversal.dp_clauseelement),
]
- def __init__(self, start, stop, step):
- self.start = start
- self.stop = stop
- self.step = step
+ def __init__(self, start, stop, step, _name=None):
+ self.start = coercions.expect(
+ roles.ExpressionElementRole,
+ start,
+ name=_name,
+ type_=type_api.INTEGERTYPE,
+ )
+ self.stop = coercions.expect(
+ roles.ExpressionElementRole,
+ stop,
+ name=_name,
+ type_=type_api.INTEGERTYPE,
+ )
+ self.step = coercions.expect(
+ roles.ExpressionElementRole,
+ step,
+ name=_name,
+ type_=type_api.INTEGERTYPE,
+ )
self.type = type_api.NULLTYPE
def self_group(self, against=None):
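
With start/stop/step coerced to INTEGER-typed bind parameters, slices with different bounds now compile to the same cached string. A minimal sketch, assuming a core ARRAY column:

    from sqlalchemy import ARRAY, Column, Integer, MetaData, Table

    t = Table("t", MetaData(), Column("data", ARRAY(Integer)))

    # both expressions render the bounds as bound parameters, so they
    # generate the same cache key and can share one compilation
    expr_a = t.c.data[2:5]
    expr_b = t.c.data[3:7]
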
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 6b1172eba..7b723f371 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -744,6 +744,7 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)):
coerce_arguments = True
_register = False
+ inherit_cache = True
def __init__(self, *args, **kwargs):
parsed_args = kwargs.pop("_parsed_args", None)
@@ -808,6 +809,8 @@ class next_value(GenericFunction):
class AnsiFunction(GenericFunction):
+ inherit_cache = True
+
def __init__(self, *args, **kwargs):
GenericFunction.__init__(self, *args, **kwargs)
@@ -815,6 +818,8 @@ class AnsiFunction(GenericFunction):
class ReturnTypeFromArgs(GenericFunction):
"""Define a function whose return type is the same as its arguments."""
+ inherit_cache = True
+
def __init__(self, *args, **kwargs):
args = [
coercions.expect(
@@ -832,30 +837,34 @@ class ReturnTypeFromArgs(GenericFunction):
class coalesce(ReturnTypeFromArgs):
_has_args = True
+ inherit_cache = True
class max(ReturnTypeFromArgs): # noqa
- pass
+ inherit_cache = True
class min(ReturnTypeFromArgs): # noqa
- pass
+ inherit_cache = True
class sum(ReturnTypeFromArgs): # noqa
- pass
+ inherit_cache = True
class now(GenericFunction): # noqa
type = sqltypes.DateTime
+ inherit_cache = True
class concat(GenericFunction):
type = sqltypes.String
+ inherit_cache = True
class char_length(GenericFunction):
type = sqltypes.Integer
+ inherit_cache = True
def __init__(self, arg, **kwargs):
GenericFunction.__init__(self, arg, **kwargs)
@@ -863,6 +872,7 @@ class char_length(GenericFunction):
class random(GenericFunction):
_has_args = True
+ inherit_cache = True
class count(GenericFunction):
@@ -887,6 +897,7 @@ class count(GenericFunction):
"""
type = sqltypes.Integer
+ inherit_cache = True
def __init__(self, expression=None, **kwargs):
if expression is None:
@@ -896,38 +907,47 @@ class count(GenericFunction):
class current_date(AnsiFunction):
type = sqltypes.Date
+ inherit_cache = True
class current_time(AnsiFunction):
type = sqltypes.Time
+ inherit_cache = True
class current_timestamp(AnsiFunction):
type = sqltypes.DateTime
+ inherit_cache = True
class current_user(AnsiFunction):
type = sqltypes.String
+ inherit_cache = True
class localtime(AnsiFunction):
type = sqltypes.DateTime
+ inherit_cache = True
class localtimestamp(AnsiFunction):
type = sqltypes.DateTime
+ inherit_cache = True
class session_user(AnsiFunction):
type = sqltypes.String
+ inherit_cache = True
class sysdate(AnsiFunction):
type = sqltypes.DateTime
+ inherit_cache = True
class user(AnsiFunction):
type = sqltypes.String
+ inherit_cache = True
class array_agg(GenericFunction):
@@ -951,6 +971,7 @@ class array_agg(GenericFunction):
"""
type = sqltypes.ARRAY
+ inherit_cache = True
def __init__(self, *args, **kwargs):
args = [
@@ -978,6 +999,7 @@ class OrderedSetAgg(GenericFunction):
:meth:`.FunctionElement.within_group` method."""
array_for_multi_clause = False
+ inherit_cache = True
def within_group_type(self, within_group):
func_clauses = self.clause_expr.element
@@ -1000,6 +1022,8 @@ class mode(OrderedSetAgg):
"""
+ inherit_cache = True
+
class percentile_cont(OrderedSetAgg):
"""implement the ``percentile_cont`` ordered-set aggregate function.
@@ -1016,6 +1040,7 @@ class percentile_cont(OrderedSetAgg):
"""
array_for_multi_clause = True
+ inherit_cache = True
class percentile_disc(OrderedSetAgg):
@@ -1033,6 +1058,7 @@ class percentile_disc(OrderedSetAgg):
"""
array_for_multi_clause = True
+ inherit_cache = True
class rank(GenericFunction):
@@ -1048,6 +1074,7 @@ class rank(GenericFunction):
"""
type = sqltypes.Integer()
+ inherit_cache = True
class dense_rank(GenericFunction):
@@ -1063,6 +1090,7 @@ class dense_rank(GenericFunction):
"""
type = sqltypes.Integer()
+ inherit_cache = True
class percent_rank(GenericFunction):
@@ -1078,6 +1106,7 @@ class percent_rank(GenericFunction):
"""
type = sqltypes.Numeric()
+ inherit_cache = True
class cume_dist(GenericFunction):
@@ -1093,6 +1122,7 @@ class cume_dist(GenericFunction):
"""
type = sqltypes.Numeric()
+ inherit_cache = True
class cube(GenericFunction):
@@ -1109,6 +1139,7 @@ class cube(GenericFunction):
"""
_has_args = True
+ inherit_cache = True
class rollup(GenericFunction):
@@ -1125,6 +1156,7 @@ class rollup(GenericFunction):
"""
_has_args = True
+ inherit_cache = True
class grouping_sets(GenericFunction):
@@ -1158,3 +1190,4 @@ class grouping_sets(GenericFunction):
"""
_has_args = True
+ inherit_cache = True
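
The same opt-in applies to user-defined generic functions. A minimal sketch; as_geojson is a hypothetical function name:

    from sqlalchemy import String
    from sqlalchemy.sql.functions import GenericFunction

    class as_geojson(GenericFunction):
        type = String()
        inherit_cache = True  # without this, statements using it skip the cache
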
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index ee411174c..29ca81d26 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -1013,6 +1013,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
__visit_name__ = "column"
+ inherit_cache = True
+
def __init__(self, *args, **kwargs):
r"""
Construct a new ``Column`` object.
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index a95fc561a..54f293967 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -61,6 +61,8 @@ if util.TYPE_CHECKING:
class _OffsetLimitParam(BindParameter):
+ inherit_cache = True
+
@property
def _limit_offset_value(self):
return self.effective_value
@@ -1426,6 +1428,8 @@ class Alias(roles.DMLTableRole, AliasedReturnsRows):
__visit_name__ = "alias"
+ inherit_cache = True
+
@classmethod
def _factory(cls, selectable, name=None, flat=False):
"""Return an :class:`_expression.Alias` object.
@@ -1500,6 +1504,8 @@ class Lateral(AliasedReturnsRows):
__visit_name__ = "lateral"
_is_lateral = True
+ inherit_cache = True
+
@classmethod
def _factory(cls, selectable, name=None):
"""Return a :class:`_expression.Lateral` object.
@@ -1626,7 +1632,7 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows):
AliasedReturnsRows._traverse_internals
+ [
("_cte_alias", InternalTraversal.dp_clauseelement),
- ("_restates", InternalTraversal.dp_clauseelement_unordered_set),
+ ("_restates", InternalTraversal.dp_clauseelement_list),
("recursive", InternalTraversal.dp_boolean),
]
+ HasPrefixes._has_prefixes_traverse_internals
@@ -1651,7 +1657,7 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows):
name=None,
recursive=False,
_cte_alias=None,
- _restates=frozenset(),
+ _restates=(),
_prefixes=None,
_suffixes=None,
):
@@ -1692,7 +1698,7 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows):
self.element.union(other),
name=self.name,
recursive=self.recursive,
- _restates=self._restates.union([self]),
+ _restates=self._restates + (self,),
_prefixes=self._prefixes,
_suffixes=self._suffixes,
)
@@ -1702,7 +1708,7 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows):
self.element.union_all(other),
name=self.name,
recursive=self.recursive,
- _restates=self._restates.union([self]),
+ _restates=self._restates + (self,),
_prefixes=self._prefixes,
_suffixes=self._suffixes,
)
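
Set iteration order is not deterministic, so a frozenset here could yield differently ordered cache keys for equivalent CTEs; a tuple preserves insertion order. A minimal sketch of a recursive CTE appending to _restates, using an illustrative table and 1.4-style select():

    from sqlalchemy import Column, Integer, MetaData, Table, select

    parts = Table(
        "parts", MetaData(),
        Column("part", Integer), Column("sub_part", Integer),
    )

    cte = select(parts.c.sub_part).cte(name="included", recursive=True)
    # union_all() returns a new CTE whose _restates tuple ends with the
    # original CTE, in a stable, cacheable order
    cte = cte.union_all(
        select(parts.c.sub_part).where(parts.c.part == cte.c.sub_part)
    )
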
@@ -1918,6 +1924,8 @@ class Subquery(AliasedReturnsRows):
_is_subquery = True
+ inherit_cache = True
+
@classmethod
def _factory(cls, selectable, name=None):
"""Return a :class:`.Subquery` object.
@@ -3783,15 +3791,15 @@ class Select(
("_group_by_clauses", InternalTraversal.dp_clauseelement_list,),
("_setup_joins", InternalTraversal.dp_setup_join_tuple,),
("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple,),
- ("_correlate", InternalTraversal.dp_clauseelement_unordered_set),
- (
- "_correlate_except",
- InternalTraversal.dp_clauseelement_unordered_set,
- ),
+ ("_correlate", InternalTraversal.dp_clauseelement_list),
+ ("_correlate_except", InternalTraversal.dp_clauseelement_list,),
+ ("_limit_clause", InternalTraversal.dp_clauseelement),
+ ("_offset_clause", InternalTraversal.dp_clauseelement),
("_for_update_arg", InternalTraversal.dp_clauseelement),
("_distinct", InternalTraversal.dp_boolean),
("_distinct_on", InternalTraversal.dp_clauseelement_list),
("_label_style", InternalTraversal.dp_plain_obj),
+ ("_is_future", InternalTraversal.dp_boolean),
]
+ HasPrefixes._has_prefixes_traverse_internals
+ HasSuffixes._has_suffixes_traverse_internals
@@ -4522,7 +4530,7 @@ class Select(
if fromclauses and fromclauses[0] is None:
self._correlate = ()
else:
- self._correlate = set(self._correlate).union(
+ self._correlate = self._correlate + tuple(
coercions.expect(roles.FromClauseRole, f) for f in fromclauses
)
@@ -4560,7 +4568,7 @@ class Select(
if fromclauses and fromclauses[0] is None:
self._correlate_except = ()
else:
- self._correlate_except = set(self._correlate_except or ()).union(
+ self._correlate_except = (self._correlate_except or ()) + tuple(
coercions.expect(roles.FromClauseRole, f) for f in fromclauses
)
@@ -4866,6 +4874,7 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping):
_from_objects = []
_is_from_container = True
_is_implicitly_boolean = False
+ inherit_cache = True
def __init__(self, element):
self.element = element
@@ -4899,6 +4908,7 @@ class Exists(UnaryExpression):
"""
_from_objects = []
+ inherit_cache = True
def __init__(self, *args, **kwargs):
"""Construct a new :class:`_expression.Exists` against an existing
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index 732b775f6..9cd9d5058 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -2616,26 +2616,10 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
return_type = self.type
if self.type.zero_indexes:
index = slice(index.start + 1, index.stop + 1, index.step)
- index = Slice(
- coercions.expect(
- roles.ExpressionElementRole,
- index.start,
- name=self.expr.key,
- type_=type_api.INTEGERTYPE,
- ),
- coercions.expect(
- roles.ExpressionElementRole,
- index.stop,
- name=self.expr.key,
- type_=type_api.INTEGERTYPE,
- ),
- coercions.expect(
- roles.ExpressionElementRole,
- index.step,
- name=self.expr.key,
- type_=type_api.INTEGERTYPE,
- ),
+ slice_ = Slice(
+ index.start, index.stop, index.step, _name=self.expr.key
)
+ return operators.getitem, slice_, return_type
else:
if self.type.zero_indexes:
index += 1
@@ -2647,7 +2631,7 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
self.type.__class__, **adapt_kw
)
-        return operators.getitem, index, return_type
+            return operators.getitem, index, return_type
def contains(self, *arg, **kw):
raise NotImplementedError(
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 68281f33d..ed0bfa27a 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -19,6 +19,7 @@ NO_CACHE = util.symbol("no_cache")
CACHE_IN_PLACE = util.symbol("cache_in_place")
CALL_GEN_CACHE_KEY = util.symbol("call_gen_cache_key")
STATIC_CACHE_KEY = util.symbol("static_cache_key")
+PROPAGATE_ATTRS = util.symbol("propagate_attrs")
ANON_NAME = util.symbol("anon_name")
@@ -31,10 +32,74 @@ def compare(obj1, obj2, **kw):
return strategy.compare(obj1, obj2, **kw)
+def _preconfigure_traversals(target_hierarchy):
+
+ stack = [target_hierarchy]
+ while stack:
+ cls = stack.pop()
+ stack.extend(cls.__subclasses__())
+
+ if hasattr(cls, "_traverse_internals"):
+ cls._generate_cache_attrs()
+ _copy_internals.generate_dispatch(
+ cls,
+ cls._traverse_internals,
+ "_generated_copy_internals_traversal",
+ )
+ _get_children.generate_dispatch(
+ cls,
+ cls._traverse_internals,
+ "_generated_get_children_traversal",
+ )
+
+
class HasCacheKey(object):
_cache_key_traversal = NO_CACHE
__slots__ = ()
+ @classmethod
+ def _generate_cache_attrs(cls):
+ """generate cache key dispatcher for a new class.
+
+ This sets the _generated_cache_key_traversal attribute once it is
+ called, so it should only be called once per class.
+
+ """
+ inherit = cls.__dict__.get("inherit_cache", False)
+
+ if inherit:
+ _cache_key_traversal = getattr(cls, "_cache_key_traversal", None)
+ if _cache_key_traversal is None:
+ try:
+ _cache_key_traversal = cls._traverse_internals
+ except AttributeError:
+ cls._generated_cache_key_traversal = NO_CACHE
+ return NO_CACHE
+
+ # TODO: wouldn't we instead get this from our superclass? also, the
+ # superclass may not have generated it yet; in any case we'd want to
+ # generate it for the superclass that has the traversal. that is a
+ # little more complicated, so for the moment this approach is a little
+ # less efficient on startup but simpler.
+ return _cache_key_traversal_visitor.generate_dispatch(
+ cls, _cache_key_traversal, "_generated_cache_key_traversal"
+ )
+ else:
+ _cache_key_traversal = cls.__dict__.get(
+ "_cache_key_traversal", None
+ )
+ if _cache_key_traversal is None:
+ _cache_key_traversal = cls.__dict__.get(
+ "_traverse_internals", None
+ )
+ if _cache_key_traversal is None:
+ cls._generated_cache_key_traversal = NO_CACHE
+ return NO_CACHE
+
+ return _cache_key_traversal_visitor.generate_dispatch(
+ cls, _cache_key_traversal, "_generated_cache_key_traversal"
+ )
+
@util.preload_module("sqlalchemy.sql.elements")
def _gen_cache_key(self, anon_map, bindparams):
"""return an optional cache key.
@@ -72,14 +137,18 @@ class HasCacheKey(object):
else:
id_ = None
- _cache_key_traversal = self._cache_key_traversal
- if _cache_key_traversal is None:
- try:
- _cache_key_traversal = self._traverse_internals
- except AttributeError:
- _cache_key_traversal = NO_CACHE
+ try:
+ dispatcher = self.__class__.__dict__[
+ "_generated_cache_key_traversal"
+ ]
+ except KeyError:
+ # most of the dispatchers are generated up front
+ # in sqlalchemy/sql/__init__.py ->
+ # traversals.py -> _preconfigure_traversals().
+ # this block will generate any remaining dispatchers.
+ dispatcher = self.__class__._generate_cache_attrs()
- if _cache_key_traversal is NO_CACHE:
+ if dispatcher is NO_CACHE:
if anon_map is not None:
anon_map[NO_CACHE] = True
return None
@@ -87,19 +156,13 @@ class HasCacheKey(object):
result = (id_, self.__class__)
# inline of _cache_key_traversal_visitor.run_generated_dispatch()
- try:
- dispatcher = self.__class__.__dict__[
- "_generated_cache_key_traversal"
- ]
- except KeyError:
- dispatcher = _cache_key_traversal_visitor.generate_dispatch(
- self, _cache_key_traversal, "_generated_cache_key_traversal"
- )
for attrname, obj, meth in dispatcher(
self, _cache_key_traversal_visitor
):
if obj is not None:
+ # TODO: see if C code can help here as Python lacks an
+ # efficient switch construct
if meth is CACHE_IN_PLACE:
# cache in place is always going to be a Python
# tuple, dict, list, etc. so we can do a boolean check
@@ -116,6 +179,15 @@ class HasCacheKey(object):
attrname,
obj._gen_cache_key(anon_map, bindparams),
)
+ elif meth is PROPAGATE_ATTRS:
+ if obj:
+ result += (
+ attrname,
+ obj["compile_state_plugin"],
+ obj["plugin_subject"]._gen_cache_key(
+ anon_map, bindparams
+ ),
+ )
elif meth is InternalTraversal.dp_annotations_key:
# obj is here is the _annotations dict. however,
# we want to use the memoized cache key version of it.
@@ -332,6 +404,8 @@ class _CacheKey(ExtendedInternalTraversal):
visit_type = STATIC_CACHE_KEY
visit_anon_name = ANON_NAME
+ visit_propagate_attrs = PROPAGATE_ATTRS
+
def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams))
@@ -445,10 +519,16 @@ class _CacheKey(ExtendedInternalTraversal):
def visit_setup_join_tuple(
self, attrname, obj, parent, anon_map, bindparams
):
+ is_legacy = "legacy" in attrname
+
return tuple(
(
- target._gen_cache_key(anon_map, bindparams),
- onclause._gen_cache_key(anon_map, bindparams)
+ target
+ if is_legacy and isinstance(target, str)
+ else target._gen_cache_key(anon_map, bindparams),
+ onclause
+ if is_legacy and isinstance(onclause, str)
+ else onclause._gen_cache_key(anon_map, bindparams)
if onclause is not None
else None,
from_._gen_cache_key(anon_map, bindparams)
@@ -711,6 +791,11 @@ class _CopyInternals(InternalTraversal):
for sequence in element
]
+ def visit_propagate_attrs(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return element
+
_copy_internals = _CopyInternals()
@@ -782,6 +867,9 @@ class _GetChildren(InternalTraversal):
def visit_dml_multi_values(self, element, **kw):
return ()
+ def visit_propagate_attrs(self, element, **kw):
+ return ()
+
_get_children = _GetChildren()
@@ -916,6 +1004,13 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
):
return COMPARE_FAILED
+ def visit_propagate_attrs(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return self.compare_inner(
+ left.get("plugin_subject", None), right.get("plugin_subject", None)
+ )
+
def visit_has_cache_key_list(
self, attrname, left_parent, left, right_parent, right, **kw
):
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index ccda21e11..fe3634bad 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -555,7 +555,12 @@ class TypeEngine(Traversible):
def _static_cache_key(self):
names = util.get_cls_kwargs(self.__class__)
return (self.__class__,) + tuple(
- (k, self.__dict__[k])
+ (
+ k,
+ self.__dict__[k]._static_cache_key
+ if isinstance(self.__dict__[k], TypeEngine)
+ else self.__dict__[k],
+ )
for k in names
if k in self.__dict__ and not k.startswith("_")
)
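
The effect of the change, sketched: a type that composes another type now folds the inner type's static key into its own, instead of embedding the instance itself, whose identity differs per construction:

    from sqlalchemy import ARRAY, Integer

    k1 = ARRAY(Integer)._static_cache_key
    k2 = ARRAY(Integer)._static_cache_key
    # item_type now contributes Integer's own _static_cache_key, so two
    # separately constructed ARRAY(Integer) types produce equal keys
    assert k1 == k2
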
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 5de68f504..904702003 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -217,18 +217,23 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
try:
dispatcher = target.__class__.__dict__[generate_dispatcher_name]
except KeyError:
+ # most of the dispatchers are generated up front
+ # in sqlalchemy/sql/__init__.py ->
+ # traversals.py -> _preconfigure_traversals().
+ # this block will generate any remaining dispatchers.
dispatcher = self.generate_dispatch(
- target, internal_dispatch, generate_dispatcher_name
+ target.__class__, internal_dispatch, generate_dispatcher_name
)
return dispatcher(target, self)
def generate_dispatch(
- self, target, internal_dispatch, generate_dispatcher_name
+ self, target_cls, internal_dispatch, generate_dispatcher_name
):
dispatcher = _generate_dispatcher(
self, internal_dispatch, generate_dispatcher_name
)
- setattr(target.__class__, generate_dispatcher_name, dispatcher)
+ # assert isinstance(target_cls, type)
+ setattr(target_cls, generate_dispatcher_name, dispatcher)
return dispatcher
dp_has_cache_key = symbol("HC")
@@ -263,10 +268,6 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
"""
- dp_clauseelement_unordered_set = symbol("CU")
- """Visit an unordered set of :class:`_expression.ClauseElement`
- objects. """
-
dp_fromclause_ordered_set = symbol("CO")
"""Visit an ordered set of :class:`_expression.FromClause` objects. """
@@ -414,6 +415,10 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
"""
+ dp_propagate_attrs = symbol("PA")
+ """Visit the propagate attrs dict. this hardcodes to the particular
+ elements we care about right now."""
+
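
For reference, an illustrative sketch of the dict this symbol traverses; the cache-key visitor added above reads exactly these two keys:

    # contents are illustrative; plugin_subject is an inspected ORM entity
    propagate_attrs = {
        "compile_state_plugin": "orm",
        "plugin_subject": insp,  # e.g. inspect(User); both hypothetical here
    }
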
class ExtendedInternalTraversal(InternalTraversal):
"""defines additional symbols that are useful in caching applications.
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index 7988b4ec9..48cbb4694 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -97,13 +97,13 @@ class CompiledSQL(SQLMatchRule):
else:
map_ = None
- if isinstance(context.compiled.statement, _DDLCompiles):
- compiled = context.compiled.statement.compile(
+ if isinstance(execute_observed.clauseelement, _DDLCompiles):
+ compiled = execute_observed.clauseelement.compile(
dialect=compare_dialect, schema_translate_map=map_
)
else:
- compiled = context.compiled.statement.compile(
+ compiled = execute_observed.clauseelement.compile(
dialect=compare_dialect,
column_keys=context.compiled.column_keys,
inline=context.compiled.inline,