diff options
Diffstat (limited to 'test/sql')
| -rw-r--r-- | test/sql/test_compare.py | 22 | ||||
| -rw-r--r-- | test/sql/test_external_traversal.py | 84 | ||||
| -rw-r--r-- | test/sql/test_lambdas.py | 461 | ||||
| -rw-r--r-- | test/sql/test_selectable.py | 28 |
4 files changed, 585 insertions, 10 deletions
diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 7ac716dbe..b573accbd 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -131,6 +131,11 @@ class MyEntity(HasCacheKey): ] +class Foo: + x = 10 + y = 15 + + dml.Insert.argument_for("sqlite", "foo", None) dml.Update.argument_for("sqlite", "foo", None) dml.Delete.argument_for("sqlite", "foo", None) @@ -790,7 +795,7 @@ class CoreFixtures(object): def two(): r = random.randint(1, 10) - q = 20 + q = 408 return LambdaElement( lambda: table_a.c.a + q == r, roles.WhereHavingRole ) @@ -803,10 +808,6 @@ class CoreFixtures(object): roles.WhereHavingRole, ) - class Foo: - x = 10 - y = 15 - def four(): return LambdaElement( lambda: and_(table_a.c.a == Foo.x), roles.WhereHavingRole @@ -833,6 +834,16 @@ class CoreFixtures(object): lambda s: s.where(table_a.c.a == value) ) + from sqlalchemy.sql import lambdas + + def eight(): + q = 5 + return lambdas.DeferredLambdaElement( + lambda t: t.c.a > q, + roles.WhereHavingRole, + lambda_args=(table_a,), + ) + return [ one(), two(), @@ -841,6 +852,7 @@ class CoreFixtures(object): five(), six(), seven(), + eight(), ] dont_compare_values_fixtures.append(_lambda_fixtures) diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py index aefcaf252..4918afc9c 100644 --- a/test/sql/test_external_traversal.py +++ b/test/sql/test_external_traversal.py @@ -791,6 +791,90 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): "JOIN table2 ON table1.col1 = table2.col2) AS anon_1", ) + def test_this_thing_using_setup_joins_three(self): + + j = t1.join(t2, t1.c.col1 == t2.c.col2) + + s1 = select(j) + + s2 = s1.join(t3, t1.c.col1 == t3.c.col1) + + self.assert_compile( + s2, + "SELECT table1.col1, table1.col2, table1.col3, " + "table2.col1, table2.col2, table2.col3 FROM table1 " + "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 " + "ON table3.col1 = table1.col1", + ) + + vis = sql_util.ClauseAdapter(j) + + s3 = vis.traverse(s1) + + s4 = s3.join(t3, t1.c.col1 == t3.c.col1) + + self.assert_compile( + s4, + "SELECT table1.col1, table1.col2, table1.col3, " + "table2.col1, table2.col2, table2.col3 FROM table1 " + "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 " + "ON table3.col1 = table1.col1", + ) + + s5 = vis.traverse(s3) + + s6 = s5.join(t3, t1.c.col1 == t3.c.col1) + + self.assert_compile( + s6, + "SELECT table1.col1, table1.col2, table1.col3, " + "table2.col1, table2.col2, table2.col3 FROM table1 " + "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 " + "ON table3.col1 = table1.col1", + ) + + def test_this_thing_using_setup_joins_four(self): + + j = t1.join(t2, t1.c.col1 == t2.c.col2) + + s1 = select(j) + + assert not s1._from_obj + + s2 = s1.join(t3, t1.c.col1 == t3.c.col1) + + self.assert_compile( + s2, + "SELECT table1.col1, table1.col2, table1.col3, " + "table2.col1, table2.col2, table2.col3 FROM table1 " + "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 " + "ON table3.col1 = table1.col1", + ) + + s3 = visitors.replacement_traverse(s1, {}, lambda elem: None) + + s4 = s3.join(t3, t1.c.col1 == t3.c.col1) + + self.assert_compile( + s4, + "SELECT table1.col1, table1.col2, table1.col3, " + "table2.col1, table2.col2, table2.col3 FROM table1 " + "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 " + "ON table3.col1 = table1.col1", + ) + + s5 = visitors.replacement_traverse(s3, {}, lambda elem: None) + + s6 = s5.join(t3, t1.c.col1 == t3.c.col1) + + self.assert_compile( + s6, + "SELECT table1.col1, table1.col2, table1.col3, " + "table2.col1, table2.col2, table2.col3 FROM table1 " + "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 " + "ON table3.col1 = table1.col1", + ) + def test_select_fromtwice_one(self): t1a = t1.alias() diff --git a/test/sql/test_lambdas.py b/test/sql/test_lambdas.py index 53f6a9544..a91242de5 100644 --- a/test/sql/test_lambdas.py +++ b/test/sql/test_lambdas.py @@ -5,6 +5,7 @@ from sqlalchemy.schema import Column from sqlalchemy.schema import ForeignKey from sqlalchemy.schema import Table from sqlalchemy.sql import and_ +from sqlalchemy.sql import bindparam from sqlalchemy.sql import coercions from sqlalchemy.sql import column from sqlalchemy.sql import join @@ -19,6 +20,7 @@ from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing import ne_ from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.types import Integer from sqlalchemy.types import String @@ -46,10 +48,39 @@ class DeferredLambdaTest( go(), "SELECT t1.q, t1.p FROM t1 WHERE t1.q = :x_1 AND t1.p = :y_1" ) + def test_global_tracking(self): + t1 = table("t1", column("q"), column("p")) + + global global_x, global_y + + global_x = 10 + global_y = 17 + + def go(): + return select([t1]).where( + lambda: and_(t1.c.q == global_x, t1.c.p == global_y) + ) + + self.assert_compile( + go(), + "SELECT t1.q, t1.p FROM t1 WHERE t1.q = :global_x_1 " + "AND t1.p = :global_y_1", + checkparams={"global_x_1": 10, "global_y_1": 17}, + ) + + global_y = 9 + + self.assert_compile( + go(), + "SELECT t1.q, t1.p FROM t1 WHERE t1.q = :global_x_1 " + "AND t1.p = :global_y_1", + checkparams={"global_x_1": 10, "global_y_1": 9}, + ) + def test_stale_checker_embedded(self): def go(x): - stmt = select([lambda: x]) + stmt = select(lambda: x) return stmt c1 = column("x") @@ -67,7 +98,7 @@ class DeferredLambdaTest( def test_stale_checker_statement(self): def go(x): - stmt = lambdas.lambda_stmt(lambda: select([x])) + stmt = lambdas.lambda_stmt(lambda: select(x)) return stmt c1 = column("x") @@ -85,13 +116,13 @@ class DeferredLambdaTest( def test_stale_checker_linked(self): def go(x, y): - stmt = lambdas.lambda_stmt(lambda: select([x])) + ( + stmt = lambdas.lambda_stmt(lambda: select(x)) + ( lambda s: s.where(y > 5) ) return stmt - c1 = column("x") - c2 = column("y") + c1 = oldc1 = column("x") + c2 = oldc2 = column("y") s1 = go(c1, c2) s2 = go(c1, c2) @@ -104,6 +135,426 @@ class DeferredLambdaTest( s3 = go(c1, c2) self.assert_compile(s3, "SELECT q WHERE p > :p_1") + s4 = go(c1, c2) + self.assert_compile(s4, "SELECT q WHERE p > :p_1") + + s5 = go(oldc1, oldc2) + self.assert_compile(s5, "SELECT x WHERE y > :y_1") + + def test_stmt_lambda_w_additional_hascachekey_variants(self): + def go(col_expr, q): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt += lambda stmt: stmt.where(col_expr == q) + + return stmt + + c1 = column("x") + c2 = column("y") + + s1 = go(c1, 5) + s2 = go(c2, 10) + s3 = go(c1, 8) + s4 = go(c2, 12) + + self.assert_compile( + s1, "SELECT x WHERE x = :q_1", checkparams={"q_1": 5} + ) + self.assert_compile( + s2, "SELECT y WHERE y = :q_1", checkparams={"q_1": 10} + ) + self.assert_compile( + s3, "SELECT x WHERE x = :q_1", checkparams={"q_1": 8} + ) + self.assert_compile( + s4, "SELECT y WHERE y = :q_1", checkparams={"q_1": 12} + ) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + s3key = s3._generate_cache_key() + s4key = s4._generate_cache_key() + + eq_(s1key[0], s3key[0]) + eq_(s2key[0], s4key[0]) + ne_(s1key[0], s2key[0]) + + def test_stmt_lambda_w_atonce_whereclause_values_notrack(self): + def go(col_expr, whereclause): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt = stmt.add_criteria( + lambda stmt: stmt.where(whereclause), enable_tracking=False + ) + + return stmt + + c1 = column("x") + + s1 = go(c1, c1 == 5) + s2 = go(c1, c1 == 10) + + self.assert_compile( + s1, "SELECT x WHERE x = :x_1", checkparams={"x_1": 5} + ) + + # and as we see, this is wrong. Because whereclause + # is fixed for the lambda and we do not re-evaluate the closure + # for this value changing. this can't be passed unless + # enable_tracking=False. + self.assert_compile( + s2, "SELECT x WHERE x = :x_1", checkparams={"x_1": 5} + ) + + def test_stmt_lambda_w_atonce_whereclause_values(self): + c2 = column("y") + + def go(col_expr, whereclause, x): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt = stmt.add_criteria( + lambda stmt: stmt.where(whereclause).order_by(c2 > x), + ) + + return stmt + + c1 = column("x") + + s1 = go(c1, c1 == 5, 9) + s2 = go(c1, c1 == 10, 15) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + + eq_([b.value for b in s1key.bindparams], [5, 9]) + eq_([b.value for b in s2key.bindparams], [10, 15]) + + self.assert_compile( + s1, + "SELECT x WHERE x = :x_1 ORDER BY y > :x_2", + checkparams={"x_1": 5, "x_2": 9}, + ) + + self.assert_compile( + s2, + "SELECT x WHERE x = :x_1 ORDER BY y > :x_2", + checkparams={"x_1": 10, "x_2": 15}, + ) + + def test_stmt_lambda_plain_customtrack(self): + c2 = column("y") + + def go(col_expr, whereclause, p): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt = stmt.add_criteria(lambda stmt: stmt.where(whereclause)) + stmt = stmt.add_criteria( + lambda stmt: stmt.order_by(col_expr), track_on=(col_expr,) + ) + stmt = stmt.add_criteria(lambda stmt: stmt.where(col_expr == p)) + return stmt + + c1 = column("x") + c2 = column("y") + + s1 = go(c1, c1 == 5, 9) + s2 = go(c1, c1 == 10, 15) + s3 = go(c2, c2 == 18, 12) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + s3key = s3._generate_cache_key() + + eq_([b.value for b in s1key.bindparams], [5, 9]) + eq_([b.value for b in s2key.bindparams], [10, 15]) + eq_([b.value for b in s3key.bindparams], [18, 12]) + + self.assert_compile( + s1, + "SELECT x WHERE x = :x_1 AND x = :p_1 ORDER BY x", + checkparams={"x_1": 5, "p_1": 9}, + ) + + self.assert_compile( + s2, + "SELECT x WHERE x = :x_1 AND x = :p_1 ORDER BY x", + checkparams={"x_1": 10, "p_1": 15}, + ) + + self.assert_compile( + s3, + "SELECT y WHERE y = :y_1 AND y = :p_1 ORDER BY y", + checkparams={"y_1": 18, "p_1": 12}, + ) + + def test_stmt_lambda_w_atonce_whereclause_customtrack_binds(self): + c2 = column("y") + + # this pattern is *completely unnecessary*, and I would prefer + # if we can detect this and just raise, because when it is not done + # correctly, it is *extremely* difficult to catch it failing. + # however I also can't come up with a reliable way to catch it. + # so we will keep the use of "track_on" to be internal. + + def go(col_expr, whereclause, p): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt = stmt.add_criteria( + lambda stmt: stmt.where(whereclause).order_by(col_expr > p), + track_on=(whereclause, whereclause.right.value), + ) + + return stmt + + c1 = column("x") + c2 = column("y") + + s1 = go(c1, c1 == 5, 9) + s2 = go(c1, c1 == 10, 15) + s3 = go(c2, c2 == 18, 12) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + s3key = s3._generate_cache_key() + + eq_([b.value for b in s1key.bindparams], [5, 9]) + eq_([b.value for b in s2key.bindparams], [10, 15]) + eq_([b.value for b in s3key.bindparams], [18, 12]) + + self.assert_compile( + s1, + "SELECT x WHERE x = :x_1 ORDER BY x > :p_1", + checkparams={"x_1": 5, "p_1": 9}, + ) + + self.assert_compile( + s2, + "SELECT x WHERE x = :x_1 ORDER BY x > :p_1", + checkparams={"x_1": 10, "p_1": 15}, + ) + + self.assert_compile( + s3, + "SELECT y WHERE y = :y_1 ORDER BY y > :p_1", + checkparams={"y_1": 18, "p_1": 12}, + ) + + def test_stmt_lambda_track_closure_binds_one(self): + def go(col_expr, whereclause): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt += lambda stmt: stmt.where(whereclause) + + return stmt + + c1 = column("x") + + s1 = go(c1, c1 == 5) + s2 = go(c1, c1 == 10) + + self.assert_compile( + s1, "SELECT x WHERE x = :x_1", checkparams={"x_1": 5} + ) + self.assert_compile( + s2, "SELECT x WHERE x = :x_1", checkparams={"x_1": 10} + ) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + + eq_(s1key.key, s2key.key) + + eq_([b.value for b in s1key.bindparams], [5]) + eq_([b.value for b in s2key.bindparams], [10]) + + def test_stmt_lambda_track_closure_binds_two(self): + def go(col_expr, whereclause, x, y): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt += lambda stmt: stmt.where(whereclause).where( + and_(c1 == x, c1 < y) + ) + + return stmt + + c1 = column("x") + + s1 = go(c1, c1 == 5, 8, 9) + s2 = go(c1, c1 == 10, 12, 14) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + + self.assert_compile( + s1, + "SELECT x WHERE x = :x_1 AND x = :x_2 AND x < :y_1", + checkparams={"x_1": 5, "x_2": 8, "y_1": 9}, + ) + self.assert_compile( + s2, + "SELECT x WHERE x = :x_1 AND x = :x_2 AND x < :y_1", + checkparams={"x_1": 10, "x_2": 12, "y_1": 14}, + ) + + eq_([b.value for b in s1key.bindparams], [5, 8, 9]) + eq_([b.value for b in s2key.bindparams], [10, 12, 14]) + + s1_compiled_cached = s1.compile(cache_key=s1key) + + params = s1_compiled_cached.construct_params( + extracted_parameters=s2key[1] + ) + + eq_(params, {"x_1": 10, "x_2": 12, "y_1": 14}) + + def test_stmt_lambda_track_closure_binds_three(self): + def go(col_expr, whereclause, x, y): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt += lambda stmt: stmt.where(whereclause) + stmt += lambda stmt: stmt.where(and_(c1 == x, c1 < y)) + + return stmt + + c1 = column("x") + + s1 = go(c1, c1 == 5, 8, 9) + s2 = go(c1, c1 == 10, 12, 14) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + + self.assert_compile( + s1, + "SELECT x WHERE x = :x_1 AND x = :x_2 AND x < :y_1", + checkparams={"x_1": 5, "x_2": 8, "y_1": 9}, + ) + self.assert_compile( + s2, + "SELECT x WHERE x = :x_1 AND x = :x_2 AND x < :y_1", + checkparams={"x_1": 10, "x_2": 12, "y_1": 14}, + ) + + eq_([b.value for b in s1key.bindparams], [5, 8, 9]) + eq_([b.value for b in s2key.bindparams], [10, 12, 14]) + + s1_compiled_cached = s1.compile(cache_key=s1key) + + params = s1_compiled_cached.construct_params( + extracted_parameters=s2key[1] + ) + + eq_(params, {"x_1": 10, "x_2": 12, "y_1": 14}) + + def test_stmt_lambda_w_atonce_whereclause_novalue(self): + def go(col_expr, whereclause): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt += lambda stmt: stmt.where(whereclause) + + return stmt + + c1 = column("x") + + s1 = go(c1, bindparam("x")) + + self.assert_compile(s1, "SELECT x WHERE :x") + + def test_stmt_lambda_w_additional_hashable_variants(self): + # note a Python 2 old style class would fail here because it + # isn't hashable. right now we do a hard check for __hash__ which + # will raise if the attr isn't present + class Thing(object): + def __init__(self, col_expr): + self.col_expr = col_expr + + def go(thing, q): + stmt = lambdas.lambda_stmt(lambda: select(thing.col_expr)) + stmt += lambda stmt: stmt.where(thing.col_expr == q) + + return stmt + + c1 = Thing(column("x")) + c2 = Thing(column("y")) + + s1 = go(c1, 5) + s2 = go(c2, 10) + s3 = go(c1, 8) + s4 = go(c2, 12) + + self.assert_compile( + s1, "SELECT x WHERE x = :q_1", checkparams={"q_1": 5} + ) + self.assert_compile( + s2, "SELECT y WHERE y = :q_1", checkparams={"q_1": 10} + ) + self.assert_compile( + s3, "SELECT x WHERE x = :q_1", checkparams={"q_1": 8} + ) + self.assert_compile( + s4, "SELECT y WHERE y = :q_1", checkparams={"q_1": 12} + ) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + s3key = s3._generate_cache_key() + s4key = s4._generate_cache_key() + + eq_(s1key[0], s3key[0]) + eq_(s2key[0], s4key[0]) + ne_(s1key[0], s2key[0]) + + def test_stmt_lambda_w_set_of_opts(self): + + stmt = lambdas.lambda_stmt(lambda: select(column("x"))) + + opts = {column("x"), column("y")} + + assert_raises_message( + exc.ArgumentError, + 'Can\'t create a cache key for lambda closure variable "opts" ' + "because it's a set. try using a list", + stmt.__add__, + lambda stmt: stmt.options(*opts), + ) + + def test_stmt_lambda_w_list_of_opts(self): + def go(opts): + stmt = lambdas.lambda_stmt(lambda: select(column("x"))) + stmt += lambda stmt: stmt.options(*opts) + + return stmt + + s1 = go([column("a"), column("b")]) + + s2 = go([column("a"), column("b")]) + + s3 = go([column("q"), column("b")]) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + s3key = s3._generate_cache_key() + + eq_(s1key.key, s2key.key) + ne_(s1key.key, s3key.key) + + def test_stmt_lambda_hey_theres_multiple_paths(self): + def go(x, y): + stmt = lambdas.lambda_stmt(lambda: select(column("x"))) + + if x > 5: + stmt += lambda stmt: stmt.where(column("x") == x) + else: + stmt += lambda stmt: stmt.where(column("y") == y) + + stmt += lambda stmt: stmt.order_by(column("q")) + + # TODO: need more path variety here to exercise + # using a full path key + + return stmt + + s1 = go(2, 5) + s2 = go(8, 7) + s3 = go(4, 9) + s4 = go(10, 1) + + self.assert_compile(s1, "SELECT x WHERE y = :y_1 ORDER BY q") + self.assert_compile(s2, "SELECT x WHERE x = :x_1 ORDER BY q") + self.assert_compile(s3, "SELECT x WHERE y = :y_1 ORDER BY q") + self.assert_compile(s4, "SELECT x WHERE x = :x_1 ORDER BY q") + def test_coercion_cols_clause(self): assert_raises_message( exc.ArgumentError, diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index 01c8d7ca6..58280bb67 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -29,6 +29,7 @@ from sqlalchemy import TypeDecorator from sqlalchemy import union from sqlalchemy import util from sqlalchemy.sql import Alias +from sqlalchemy.sql import annotation from sqlalchemy.sql import base from sqlalchemy.sql import column from sqlalchemy.sql import elements @@ -2352,6 +2353,33 @@ class AnnotationsTest(fixtures.TestBase): annot = obj._annotate({}) ne_(set([obj]), set([annot])) + def test_replacement_traverse_preserve(self): + """test that replacement traverse that hits an unannotated column + does not use it when replacing an annotated column. + + this requires that replacement traverse store elements in the + "seen" hash based on id(), not hash. + + """ + t = table("t", column("x")) + + stmt = select([t.c.x]) + + whereclause = annotation._deep_annotate(t.c.x == 5, {"foo": "bar"}) + + eq_(whereclause._annotations, {"foo": "bar"}) + eq_(whereclause.left._annotations, {"foo": "bar"}) + eq_(whereclause.right._annotations, {"foo": "bar"}) + + stmt = stmt.where(whereclause) + + s2 = visitors.replacement_traverse(stmt, {}, lambda elem: None) + + whereclause = s2._where_criteria[0] + eq_(whereclause._annotations, {"foo": "bar"}) + eq_(whereclause.left._annotations, {"foo": "bar"}) + eq_(whereclause.right._annotations, {"foo": "bar"}) + def test_proxy_set_iteration_includes_annotated(self): from sqlalchemy.schema import Column |
