summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/unreleased_14/lmb_no_cache.rst6
-rw-r--r--lib/sqlalchemy/sql/lambdas.py99
-rw-r--r--test/sql/test_lambdas.py58
3 files changed, 134 insertions, 29 deletions
diff --git a/doc/build/changelog/unreleased_14/lmb_no_cache.rst b/doc/build/changelog/unreleased_14/lmb_no_cache.rst
new file mode 100644
index 000000000..4f6e19320
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/lmb_no_cache.rst
@@ -0,0 +1,6 @@
+.. change::
+ :tags: bug, sql
+
+ Fixed issue in lambda caching system where an element of a query that
+ produces no cache key, like a custom option or clause element, would still
+ populate the expression in the "lambda cache" inappropriately.
diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py
index d33e8ebfb..36e470ce7 100644
--- a/lib/sqlalchemy/sql/lambdas.py
+++ b/lib/sqlalchemy/sql/lambdas.py
@@ -182,28 +182,49 @@ class LambdaElement(elements.ClauseElement):
self._resolved_bindparams = bindparams = []
- anon_map = traversals.anon_map()
- cache_key = tuple(
- [
- getter(closure, opts, anon_map, bindparams)
- for getter in tracker.closure_trackers
- ]
- )
-
if self.parent_lambda is not None:
- cache_key = self.parent_lambda.closure_cache_key + cache_key
+ parent_closure_cache_key = self.parent_lambda.closure_cache_key
+ else:
+ parent_closure_cache_key = ()
+
+ if parent_closure_cache_key is not traversals.NO_CACHE:
+ anon_map = traversals.anon_map()
+ cache_key = tuple(
+ [
+ getter(closure, opts, anon_map, bindparams)
+ for getter in tracker.closure_trackers
+ ]
+ )
- self.closure_cache_key = cache_key
+ if traversals.NO_CACHE not in anon_map:
+ cache_key = parent_closure_cache_key + cache_key
- try:
- rec = lambda_cache[tracker_key + cache_key]
- except KeyError:
+ self.closure_cache_key = cache_key
+
+ try:
+ rec = lambda_cache[tracker_key + cache_key]
+ except KeyError:
+ rec = None
+ else:
+ cache_key = traversals.NO_CACHE
+ rec = None
+
+ else:
+ cache_key = traversals.NO_CACHE
rec = None
+ self.closure_cache_key = cache_key
+
if rec is None:
- rec = AnalyzedFunction(tracker, self, apply_propagate_attrs, fn)
- rec.closure_bindparams = bindparams
- lambda_cache[tracker_key + cache_key] = rec
+ if cache_key is not traversals.NO_CACHE:
+ rec = AnalyzedFunction(
+ tracker, self, apply_propagate_attrs, fn
+ )
+ rec.closure_bindparams = bindparams
+ lambda_cache[tracker_key + cache_key] = rec
+ else:
+ rec = NonAnalyzedFunction(self._invoke_user_fn(fn))
+
else:
bindparams[:] = [
orig_bind._with_value(new_bind.value, maintain_key=True)
@@ -212,21 +233,24 @@ class LambdaElement(elements.ClauseElement):
)
]
- if self.parent_lambda is not None:
- bindparams[:0] = self.parent_lambda._resolved_bindparams
-
self._rec = rec
- lambda_element = self
- while lambda_element is not None:
- rec = lambda_element._rec
- if rec.bindparam_trackers:
- tracker_instrumented_fn = rec.tracker_instrumented_fn
- for tracker in rec.bindparam_trackers:
- tracker(
- lambda_element.fn, tracker_instrumented_fn, bindparams
- )
- lambda_element = lambda_element.parent_lambda
+ if cache_key is not traversals.NO_CACHE:
+ if self.parent_lambda is not None:
+ bindparams[:0] = self.parent_lambda._resolved_bindparams
+
+ lambda_element = self
+ while lambda_element is not None:
+ rec = lambda_element._rec
+ if rec.bindparam_trackers:
+ tracker_instrumented_fn = rec.tracker_instrumented_fn
+ for tracker in rec.bindparam_trackers:
+ tracker(
+ lambda_element.fn,
+ tracker_instrumented_fn,
+ bindparams,
+ )
+ lambda_element = lambda_element.parent_lambda
return rec
@@ -304,6 +328,9 @@ class LambdaElement(elements.ClauseElement):
return expr
def _gen_cache_key(self, anon_map, bindparams):
+ if self.closure_cache_key is traversals.NO_CACHE:
+ anon_map[traversals.NO_CACHE] = True
+ return None
cache_key = (
self.fn.__code__,
@@ -914,6 +941,20 @@ class AnalyzedCode(object):
)
+class NonAnalyzedFunction(object):
+ __slots__ = ("expr",)
+
+ closure_bindparams = None
+ bindparam_trackers = None
+
+ def __init__(self, expr):
+ self.expr = expr
+
+ @property
+ def expected_expr(self):
+ return self.expr
+
+
class AnalyzedFunction(object):
__slots__ = (
"analyzed_code",
diff --git a/test/sql/test_lambdas.py b/test/sql/test_lambdas.py
index 51530b079..2e794d7bc 100644
--- a/test/sql/test_lambdas.py
+++ b/test/sql/test_lambdas.py
@@ -17,6 +17,7 @@ from sqlalchemy.sql import roles
from sqlalchemy.sql import select
from sqlalchemy.sql import table
from sqlalchemy.sql import util as sql_util
+from sqlalchemy.sql.traversals import HasCacheKey
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import eq_
@@ -953,6 +954,63 @@ class LambdaElementTest(
eq_(s1key.key, s2key.key)
ne_(s1key.key, s3key.key)
+ def test_stmt_lambda_opt_w_key(self):
+ """test issue related to #6887"""
+
+ def go(opts):
+ stmt = lambdas.lambda_stmt(lambda: select(column("x")))
+ stmt += lambda stmt: stmt.options(*opts)
+
+ return stmt
+
+ class SomeOpt(HasCacheKey):
+ def _gen_cache_key(self, anon_map, bindparams):
+ return ("fixed_key",)
+
+ # generates no key, will not be cached
+ eq_(SomeOpt()._generate_cache_key().key, ("fixed_key",))
+
+ s1o, s2o = SomeOpt(), SomeOpt()
+ s1 = go([s1o])
+ s2 = go([s2o])
+
+ s1key = s1._generate_cache_key()
+ s2key = s2._generate_cache_key()
+
+ eq_(s1key.key[-1], (("fixed_key",),))
+ eq_(s1key.key, s2key.key)
+
+ eq_(s1._resolved._with_options, (s1o,))
+ eq_(s2._resolved._with_options, (s1o,))
+ ne_(s2._resolved._with_options, (s2o,))
+
+ def test_stmt_lambda_opt_w_no_key(self):
+ """test issue related to #6887"""
+
+ def go(opts):
+ stmt = lambdas.lambda_stmt(lambda: select(column("x")))
+ stmt += lambda stmt: stmt.options(*opts)
+
+ return stmt
+
+ class SomeOpt(HasCacheKey):
+ pass
+
+ # generates no key, will not be cached
+ eq_(SomeOpt()._generate_cache_key(), None)
+
+ s1o, s2o = SomeOpt(), SomeOpt()
+ s1 = go([s1o])
+ s2 = go([s2o])
+
+ s1key = s1._generate_cache_key()
+
+ eq_(s1key, None)
+
+ eq_(s1._resolved._with_options, (s1o,))
+ eq_(s2._resolved._with_options, (s2o,))
+ ne_(s2._resolved._with_options, (s1o,))
+
def test_stmt_lambda_hey_theres_multiple_paths(self):
def go(x, y):
stmt = lambdas.lambda_stmt(lambda: select(column("x")))