diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-08-05 16:42:26 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-08-05 16:42:26 -0400 |
| commit | cc57ea495f6460dd56daa6de57e40047ed999369 (patch) | |
| tree | 837f5a84363c387d7f8fdeabc06928cd078028e1 /test/sql | |
| parent | 2a946254023135eddd222974cf300ffaa5583f02 (diff) | |
| download | sqlalchemy-cc57ea495f6460dd56daa6de57e40047ed999369.tar.gz | |
Robustness for lambdas, lambda statements
in order to accommodate relationship loaders
with lambda caching, a lot more is needed. This is
a full refactor of the lambda system such that it
now has two levels of caching; the first level caches what
can be known from the __code__ element, then the next level
of caching is against the lambda itself and the contents
of __closure__. This allows for the elements inside
the lambdas, like columns and entities, to change and
then be part of the cache key. Lazy/selectinloads' use of
baked queries had to add distinct cache key elements,
which was attempted here but overall things needed to be
more robust than that.
This commit is broken out from the very long and sprawling
commit at Id6b5c03b1ce9ddb7b280f66792212a0ef0a1c541 .
Change-Id: I29a513c98917b1d503abfdd61e6b6e8800851aa8
Diffstat (limited to 'test/sql')
| -rw-r--r-- | test/sql/test_compare.py | 22 | ||||
| -rw-r--r-- | test/sql/test_external_traversal.py | 84 | ||||
| -rw-r--r-- | test/sql/test_lambdas.py | 461 | ||||
| -rw-r--r-- | test/sql/test_selectable.py | 28 |
4 files changed, 585 insertions, 10 deletions
diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 7ac716dbe..b573accbd 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -131,6 +131,11 @@ class MyEntity(HasCacheKey): ] +class Foo: + x = 10 + y = 15 + + dml.Insert.argument_for("sqlite", "foo", None) dml.Update.argument_for("sqlite", "foo", None) dml.Delete.argument_for("sqlite", "foo", None) @@ -790,7 +795,7 @@ class CoreFixtures(object): def two(): r = random.randint(1, 10) - q = 20 + q = 408 return LambdaElement( lambda: table_a.c.a + q == r, roles.WhereHavingRole ) @@ -803,10 +808,6 @@ class CoreFixtures(object): roles.WhereHavingRole, ) - class Foo: - x = 10 - y = 15 - def four(): return LambdaElement( lambda: and_(table_a.c.a == Foo.x), roles.WhereHavingRole @@ -833,6 +834,16 @@ class CoreFixtures(object): lambda s: s.where(table_a.c.a == value) ) + from sqlalchemy.sql import lambdas + + def eight(): + q = 5 + return lambdas.DeferredLambdaElement( + lambda t: t.c.a > q, + roles.WhereHavingRole, + lambda_args=(table_a,), + ) + return [ one(), two(), @@ -841,6 +852,7 @@ class CoreFixtures(object): five(), six(), seven(), + eight(), ] dont_compare_values_fixtures.append(_lambda_fixtures) diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py index aefcaf252..4918afc9c 100644 --- a/test/sql/test_external_traversal.py +++ b/test/sql/test_external_traversal.py @@ -791,6 +791,90 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): "JOIN table2 ON table1.col1 = table2.col2) AS anon_1", ) + def test_this_thing_using_setup_joins_three(self): + + j = t1.join(t2, t1.c.col1 == t2.c.col2) + + s1 = select(j) + + s2 = s1.join(t3, t1.c.col1 == t3.c.col1) + + self.assert_compile( + s2, + "SELECT table1.col1, table1.col2, table1.col3, " + "table2.col1, table2.col2, table2.col3 FROM table1 " + "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 " + "ON table3.col1 = table1.col1", + ) + + vis = sql_util.ClauseAdapter(j) + + s3 = vis.traverse(s1) + + s4 = s3.join(t3, t1.c.col1 == t3.c.col1) + + self.assert_compile( + s4, + "SELECT table1.col1, table1.col2, table1.col3, " + "table2.col1, table2.col2, table2.col3 FROM table1 " + "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 " + "ON table3.col1 = table1.col1", + ) + + s5 = vis.traverse(s3) + + s6 = s5.join(t3, t1.c.col1 == t3.c.col1) + + self.assert_compile( + s6, + "SELECT table1.col1, table1.col2, table1.col3, " + "table2.col1, table2.col2, table2.col3 FROM table1 " + "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 " + "ON table3.col1 = table1.col1", + ) + + def test_this_thing_using_setup_joins_four(self): + + j = t1.join(t2, t1.c.col1 == t2.c.col2) + + s1 = select(j) + + assert not s1._from_obj + + s2 = s1.join(t3, t1.c.col1 == t3.c.col1) + + self.assert_compile( + s2, + "SELECT table1.col1, table1.col2, table1.col3, " + "table2.col1, table2.col2, table2.col3 FROM table1 " + "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 " + "ON table3.col1 = table1.col1", + ) + + s3 = visitors.replacement_traverse(s1, {}, lambda elem: None) + + s4 = s3.join(t3, t1.c.col1 == t3.c.col1) + + self.assert_compile( + s4, + "SELECT table1.col1, table1.col2, table1.col3, " + "table2.col1, table2.col2, table2.col3 FROM table1 " + "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 " + "ON table3.col1 = table1.col1", + ) + + s5 = visitors.replacement_traverse(s3, {}, lambda elem: None) + + s6 = s5.join(t3, t1.c.col1 == t3.c.col1) + + self.assert_compile( + s6, + "SELECT table1.col1, table1.col2, table1.col3, " + "table2.col1, table2.col2, table2.col3 FROM table1 " + "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 " + "ON table3.col1 = table1.col1", + ) + def test_select_fromtwice_one(self): t1a = t1.alias() diff --git a/test/sql/test_lambdas.py b/test/sql/test_lambdas.py index 53f6a9544..a91242de5 100644 --- a/test/sql/test_lambdas.py +++ b/test/sql/test_lambdas.py @@ -5,6 +5,7 @@ from sqlalchemy.schema import Column from sqlalchemy.schema import ForeignKey from sqlalchemy.schema import Table from sqlalchemy.sql import and_ +from sqlalchemy.sql import bindparam from sqlalchemy.sql import coercions from sqlalchemy.sql import column from sqlalchemy.sql import join @@ -19,6 +20,7 @@ from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing import ne_ from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.types import Integer from sqlalchemy.types import String @@ -46,10 +48,39 @@ class DeferredLambdaTest( go(), "SELECT t1.q, t1.p FROM t1 WHERE t1.q = :x_1 AND t1.p = :y_1" ) + def test_global_tracking(self): + t1 = table("t1", column("q"), column("p")) + + global global_x, global_y + + global_x = 10 + global_y = 17 + + def go(): + return select([t1]).where( + lambda: and_(t1.c.q == global_x, t1.c.p == global_y) + ) + + self.assert_compile( + go(), + "SELECT t1.q, t1.p FROM t1 WHERE t1.q = :global_x_1 " + "AND t1.p = :global_y_1", + checkparams={"global_x_1": 10, "global_y_1": 17}, + ) + + global_y = 9 + + self.assert_compile( + go(), + "SELECT t1.q, t1.p FROM t1 WHERE t1.q = :global_x_1 " + "AND t1.p = :global_y_1", + checkparams={"global_x_1": 10, "global_y_1": 9}, + ) + def test_stale_checker_embedded(self): def go(x): - stmt = select([lambda: x]) + stmt = select(lambda: x) return stmt c1 = column("x") @@ -67,7 +98,7 @@ class DeferredLambdaTest( def test_stale_checker_statement(self): def go(x): - stmt = lambdas.lambda_stmt(lambda: select([x])) + stmt = lambdas.lambda_stmt(lambda: select(x)) return stmt c1 = column("x") @@ -85,13 +116,13 @@ class DeferredLambdaTest( def test_stale_checker_linked(self): def go(x, y): - stmt = lambdas.lambda_stmt(lambda: select([x])) + ( + stmt = lambdas.lambda_stmt(lambda: select(x)) + ( lambda s: s.where(y > 5) ) return stmt - c1 = column("x") - c2 = column("y") + c1 = oldc1 = column("x") + c2 = oldc2 = column("y") s1 = go(c1, c2) s2 = go(c1, c2) @@ -104,6 +135,426 @@ class DeferredLambdaTest( s3 = go(c1, c2) self.assert_compile(s3, "SELECT q WHERE p > :p_1") + s4 = go(c1, c2) + self.assert_compile(s4, "SELECT q WHERE p > :p_1") + + s5 = go(oldc1, oldc2) + self.assert_compile(s5, "SELECT x WHERE y > :y_1") + + def test_stmt_lambda_w_additional_hascachekey_variants(self): + def go(col_expr, q): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt += lambda stmt: stmt.where(col_expr == q) + + return stmt + + c1 = column("x") + c2 = column("y") + + s1 = go(c1, 5) + s2 = go(c2, 10) + s3 = go(c1, 8) + s4 = go(c2, 12) + + self.assert_compile( + s1, "SELECT x WHERE x = :q_1", checkparams={"q_1": 5} + ) + self.assert_compile( + s2, "SELECT y WHERE y = :q_1", checkparams={"q_1": 10} + ) + self.assert_compile( + s3, "SELECT x WHERE x = :q_1", checkparams={"q_1": 8} + ) + self.assert_compile( + s4, "SELECT y WHERE y = :q_1", checkparams={"q_1": 12} + ) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + s3key = s3._generate_cache_key() + s4key = s4._generate_cache_key() + + eq_(s1key[0], s3key[0]) + eq_(s2key[0], s4key[0]) + ne_(s1key[0], s2key[0]) + + def test_stmt_lambda_w_atonce_whereclause_values_notrack(self): + def go(col_expr, whereclause): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt = stmt.add_criteria( + lambda stmt: stmt.where(whereclause), enable_tracking=False + ) + + return stmt + + c1 = column("x") + + s1 = go(c1, c1 == 5) + s2 = go(c1, c1 == 10) + + self.assert_compile( + s1, "SELECT x WHERE x = :x_1", checkparams={"x_1": 5} + ) + + # and as we see, this is wrong. Because whereclause + # is fixed for the lambda and we do not re-evaluate the closure + # for this value changing. this can't be passed unless + # enable_tracking=False. + self.assert_compile( + s2, "SELECT x WHERE x = :x_1", checkparams={"x_1": 5} + ) + + def test_stmt_lambda_w_atonce_whereclause_values(self): + c2 = column("y") + + def go(col_expr, whereclause, x): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt = stmt.add_criteria( + lambda stmt: stmt.where(whereclause).order_by(c2 > x), + ) + + return stmt + + c1 = column("x") + + s1 = go(c1, c1 == 5, 9) + s2 = go(c1, c1 == 10, 15) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + + eq_([b.value for b in s1key.bindparams], [5, 9]) + eq_([b.value for b in s2key.bindparams], [10, 15]) + + self.assert_compile( + s1, + "SELECT x WHERE x = :x_1 ORDER BY y > :x_2", + checkparams={"x_1": 5, "x_2": 9}, + ) + + self.assert_compile( + s2, + "SELECT x WHERE x = :x_1 ORDER BY y > :x_2", + checkparams={"x_1": 10, "x_2": 15}, + ) + + def test_stmt_lambda_plain_customtrack(self): + c2 = column("y") + + def go(col_expr, whereclause, p): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt = stmt.add_criteria(lambda stmt: stmt.where(whereclause)) + stmt = stmt.add_criteria( + lambda stmt: stmt.order_by(col_expr), track_on=(col_expr,) + ) + stmt = stmt.add_criteria(lambda stmt: stmt.where(col_expr == p)) + return stmt + + c1 = column("x") + c2 = column("y") + + s1 = go(c1, c1 == 5, 9) + s2 = go(c1, c1 == 10, 15) + s3 = go(c2, c2 == 18, 12) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + s3key = s3._generate_cache_key() + + eq_([b.value for b in s1key.bindparams], [5, 9]) + eq_([b.value for b in s2key.bindparams], [10, 15]) + eq_([b.value for b in s3key.bindparams], [18, 12]) + + self.assert_compile( + s1, + "SELECT x WHERE x = :x_1 AND x = :p_1 ORDER BY x", + checkparams={"x_1": 5, "p_1": 9}, + ) + + self.assert_compile( + s2, + "SELECT x WHERE x = :x_1 AND x = :p_1 ORDER BY x", + checkparams={"x_1": 10, "p_1": 15}, + ) + + self.assert_compile( + s3, + "SELECT y WHERE y = :y_1 AND y = :p_1 ORDER BY y", + checkparams={"y_1": 18, "p_1": 12}, + ) + + def test_stmt_lambda_w_atonce_whereclause_customtrack_binds(self): + c2 = column("y") + + # this pattern is *completely unnecessary*, and I would prefer + # if we can detect this and just raise, because when it is not done + # correctly, it is *extremely* difficult to catch it failing. + # however I also can't come up with a reliable way to catch it. + # so we will keep the use of "track_on" to be internal. + + def go(col_expr, whereclause, p): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt = stmt.add_criteria( + lambda stmt: stmt.where(whereclause).order_by(col_expr > p), + track_on=(whereclause, whereclause.right.value), + ) + + return stmt + + c1 = column("x") + c2 = column("y") + + s1 = go(c1, c1 == 5, 9) + s2 = go(c1, c1 == 10, 15) + s3 = go(c2, c2 == 18, 12) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + s3key = s3._generate_cache_key() + + eq_([b.value for b in s1key.bindparams], [5, 9]) + eq_([b.value for b in s2key.bindparams], [10, 15]) + eq_([b.value for b in s3key.bindparams], [18, 12]) + + self.assert_compile( + s1, + "SELECT x WHERE x = :x_1 ORDER BY x > :p_1", + checkparams={"x_1": 5, "p_1": 9}, + ) + + self.assert_compile( + s2, + "SELECT x WHERE x = :x_1 ORDER BY x > :p_1", + checkparams={"x_1": 10, "p_1": 15}, + ) + + self.assert_compile( + s3, + "SELECT y WHERE y = :y_1 ORDER BY y > :p_1", + checkparams={"y_1": 18, "p_1": 12}, + ) + + def test_stmt_lambda_track_closure_binds_one(self): + def go(col_expr, whereclause): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt += lambda stmt: stmt.where(whereclause) + + return stmt + + c1 = column("x") + + s1 = go(c1, c1 == 5) + s2 = go(c1, c1 == 10) + + self.assert_compile( + s1, "SELECT x WHERE x = :x_1", checkparams={"x_1": 5} + ) + self.assert_compile( + s2, "SELECT x WHERE x = :x_1", checkparams={"x_1": 10} + ) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + + eq_(s1key.key, s2key.key) + + eq_([b.value for b in s1key.bindparams], [5]) + eq_([b.value for b in s2key.bindparams], [10]) + + def test_stmt_lambda_track_closure_binds_two(self): + def go(col_expr, whereclause, x, y): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt += lambda stmt: stmt.where(whereclause).where( + and_(c1 == x, c1 < y) + ) + + return stmt + + c1 = column("x") + + s1 = go(c1, c1 == 5, 8, 9) + s2 = go(c1, c1 == 10, 12, 14) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + + self.assert_compile( + s1, + "SELECT x WHERE x = :x_1 AND x = :x_2 AND x < :y_1", + checkparams={"x_1": 5, "x_2": 8, "y_1": 9}, + ) + self.assert_compile( + s2, + "SELECT x WHERE x = :x_1 AND x = :x_2 AND x < :y_1", + checkparams={"x_1": 10, "x_2": 12, "y_1": 14}, + ) + + eq_([b.value for b in s1key.bindparams], [5, 8, 9]) + eq_([b.value for b in s2key.bindparams], [10, 12, 14]) + + s1_compiled_cached = s1.compile(cache_key=s1key) + + params = s1_compiled_cached.construct_params( + extracted_parameters=s2key[1] + ) + + eq_(params, {"x_1": 10, "x_2": 12, "y_1": 14}) + + def test_stmt_lambda_track_closure_binds_three(self): + def go(col_expr, whereclause, x, y): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt += lambda stmt: stmt.where(whereclause) + stmt += lambda stmt: stmt.where(and_(c1 == x, c1 < y)) + + return stmt + + c1 = column("x") + + s1 = go(c1, c1 == 5, 8, 9) + s2 = go(c1, c1 == 10, 12, 14) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + + self.assert_compile( + s1, + "SELECT x WHERE x = :x_1 AND x = :x_2 AND x < :y_1", + checkparams={"x_1": 5, "x_2": 8, "y_1": 9}, + ) + self.assert_compile( + s2, + "SELECT x WHERE x = :x_1 AND x = :x_2 AND x < :y_1", + checkparams={"x_1": 10, "x_2": 12, "y_1": 14}, + ) + + eq_([b.value for b in s1key.bindparams], [5, 8, 9]) + eq_([b.value for b in s2key.bindparams], [10, 12, 14]) + + s1_compiled_cached = s1.compile(cache_key=s1key) + + params = s1_compiled_cached.construct_params( + extracted_parameters=s2key[1] + ) + + eq_(params, {"x_1": 10, "x_2": 12, "y_1": 14}) + + def test_stmt_lambda_w_atonce_whereclause_novalue(self): + def go(col_expr, whereclause): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt += lambda stmt: stmt.where(whereclause) + + return stmt + + c1 = column("x") + + s1 = go(c1, bindparam("x")) + + self.assert_compile(s1, "SELECT x WHERE :x") + + def test_stmt_lambda_w_additional_hashable_variants(self): + # note a Python 2 old style class would fail here because it + # isn't hashable. right now we do a hard check for __hash__ which + # will raise if the attr isn't present + class Thing(object): + def __init__(self, col_expr): + self.col_expr = col_expr + + def go(thing, q): + stmt = lambdas.lambda_stmt(lambda: select(thing.col_expr)) + stmt += lambda stmt: stmt.where(thing.col_expr == q) + + return stmt + + c1 = Thing(column("x")) + c2 = Thing(column("y")) + + s1 = go(c1, 5) + s2 = go(c2, 10) + s3 = go(c1, 8) + s4 = go(c2, 12) + + self.assert_compile( + s1, "SELECT x WHERE x = :q_1", checkparams={"q_1": 5} + ) + self.assert_compile( + s2, "SELECT y WHERE y = :q_1", checkparams={"q_1": 10} + ) + self.assert_compile( + s3, "SELECT x WHERE x = :q_1", checkparams={"q_1": 8} + ) + self.assert_compile( + s4, "SELECT y WHERE y = :q_1", checkparams={"q_1": 12} + ) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + s3key = s3._generate_cache_key() + s4key = s4._generate_cache_key() + + eq_(s1key[0], s3key[0]) + eq_(s2key[0], s4key[0]) + ne_(s1key[0], s2key[0]) + + def test_stmt_lambda_w_set_of_opts(self): + + stmt = lambdas.lambda_stmt(lambda: select(column("x"))) + + opts = {column("x"), column("y")} + + assert_raises_message( + exc.ArgumentError, + 'Can\'t create a cache key for lambda closure variable "opts" ' + "because it's a set. try using a list", + stmt.__add__, + lambda stmt: stmt.options(*opts), + ) + + def test_stmt_lambda_w_list_of_opts(self): + def go(opts): + stmt = lambdas.lambda_stmt(lambda: select(column("x"))) + stmt += lambda stmt: stmt.options(*opts) + + return stmt + + s1 = go([column("a"), column("b")]) + + s2 = go([column("a"), column("b")]) + + s3 = go([column("q"), column("b")]) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + s3key = s3._generate_cache_key() + + eq_(s1key.key, s2key.key) + ne_(s1key.key, s3key.key) + + def test_stmt_lambda_hey_theres_multiple_paths(self): + def go(x, y): + stmt = lambdas.lambda_stmt(lambda: select(column("x"))) + + if x > 5: + stmt += lambda stmt: stmt.where(column("x") == x) + else: + stmt += lambda stmt: stmt.where(column("y") == y) + + stmt += lambda stmt: stmt.order_by(column("q")) + + # TODO: need more path variety here to exercise + # using a full path key + + return stmt + + s1 = go(2, 5) + s2 = go(8, 7) + s3 = go(4, 9) + s4 = go(10, 1) + + self.assert_compile(s1, "SELECT x WHERE y = :y_1 ORDER BY q") + self.assert_compile(s2, "SELECT x WHERE x = :x_1 ORDER BY q") + self.assert_compile(s3, "SELECT x WHERE y = :y_1 ORDER BY q") + self.assert_compile(s4, "SELECT x WHERE x = :x_1 ORDER BY q") + def test_coercion_cols_clause(self): assert_raises_message( exc.ArgumentError, diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index 01c8d7ca6..58280bb67 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -29,6 +29,7 @@ from sqlalchemy import TypeDecorator from sqlalchemy import union from sqlalchemy import util from sqlalchemy.sql import Alias +from sqlalchemy.sql import annotation from sqlalchemy.sql import base from sqlalchemy.sql import column from sqlalchemy.sql import elements @@ -2352,6 +2353,33 @@ class AnnotationsTest(fixtures.TestBase): annot = obj._annotate({}) ne_(set([obj]), set([annot])) + def test_replacement_traverse_preserve(self): + """test that replacement traverse that hits an unannotated column + does not use it when replacing an annotated column. + + this requires that replacement traverse store elements in the + "seen" hash based on id(), not hash. + + """ + t = table("t", column("x")) + + stmt = select([t.c.x]) + + whereclause = annotation._deep_annotate(t.c.x == 5, {"foo": "bar"}) + + eq_(whereclause._annotations, {"foo": "bar"}) + eq_(whereclause.left._annotations, {"foo": "bar"}) + eq_(whereclause.right._annotations, {"foo": "bar"}) + + stmt = stmt.where(whereclause) + + s2 = visitors.replacement_traverse(stmt, {}, lambda elem: None) + + whereclause = s2._where_criteria[0] + eq_(whereclause._annotations, {"foo": "bar"}) + eq_(whereclause.left._annotations, {"foo": "bar"}) + eq_(whereclause.right._annotations, {"foo": "bar"}) + def test_proxy_set_iteration_includes_annotated(self): from sqlalchemy.schema import Column |
