summaryrefslogtreecommitdiff
path: root/test/sql
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2020-08-05 16:42:26 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2020-08-05 16:42:26 -0400
commitcc57ea495f6460dd56daa6de57e40047ed999369 (patch)
tree837f5a84363c387d7f8fdeabc06928cd078028e1 /test/sql
parent2a946254023135eddd222974cf300ffaa5583f02 (diff)
downloadsqlalchemy-cc57ea495f6460dd56daa6de57e40047ed999369.tar.gz
Robustness for lambdas, lambda statements
in order to accommodate relationship loaders with lambda caching, a lot more is needed. This is a full refactor of the lambda system such that it now has two levels of caching; the first level caches what can be known from the __code__ element, then the next level of caching is against the lambda itself and the contents of __closure__. This allows for the elements inside the lambdas, like columns and entities, to change and then be part of the cache key. Lazy/selectinloads' use of baked queries had to add distinct cache key elements, which was attempted here but overall things needed to be more robust than that. This commit is broken out from the very long and sprawling commit at Id6b5c03b1ce9ddb7b280f66792212a0ef0a1c541 . Change-Id: I29a513c98917b1d503abfdd61e6b6e8800851aa8
Diffstat (limited to 'test/sql')
-rw-r--r--test/sql/test_compare.py22
-rw-r--r--test/sql/test_external_traversal.py84
-rw-r--r--test/sql/test_lambdas.py461
-rw-r--r--test/sql/test_selectable.py28
4 files changed, 585 insertions, 10 deletions
diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py
index 7ac716dbe..b573accbd 100644
--- a/test/sql/test_compare.py
+++ b/test/sql/test_compare.py
@@ -131,6 +131,11 @@ class MyEntity(HasCacheKey):
]
+class Foo:
+ x = 10
+ y = 15
+
+
dml.Insert.argument_for("sqlite", "foo", None)
dml.Update.argument_for("sqlite", "foo", None)
dml.Delete.argument_for("sqlite", "foo", None)
@@ -790,7 +795,7 @@ class CoreFixtures(object):
def two():
r = random.randint(1, 10)
- q = 20
+ q = 408
return LambdaElement(
lambda: table_a.c.a + q == r, roles.WhereHavingRole
)
@@ -803,10 +808,6 @@ class CoreFixtures(object):
roles.WhereHavingRole,
)
- class Foo:
- x = 10
- y = 15
-
def four():
return LambdaElement(
lambda: and_(table_a.c.a == Foo.x), roles.WhereHavingRole
@@ -833,6 +834,16 @@ class CoreFixtures(object):
lambda s: s.where(table_a.c.a == value)
)
+ from sqlalchemy.sql import lambdas
+
+ def eight():
+ q = 5
+ return lambdas.DeferredLambdaElement(
+ lambda t: t.c.a > q,
+ roles.WhereHavingRole,
+ lambda_args=(table_a,),
+ )
+
return [
one(),
two(),
@@ -841,6 +852,7 @@ class CoreFixtures(object):
five(),
six(),
seven(),
+ eight(),
]
dont_compare_values_fixtures.append(_lambda_fixtures)
diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py
index aefcaf252..4918afc9c 100644
--- a/test/sql/test_external_traversal.py
+++ b/test/sql/test_external_traversal.py
@@ -791,6 +791,90 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
"JOIN table2 ON table1.col1 = table2.col2) AS anon_1",
)
+ def test_this_thing_using_setup_joins_three(self):
+
+ j = t1.join(t2, t1.c.col1 == t2.c.col2)
+
+ s1 = select(j)
+
+ s2 = s1.join(t3, t1.c.col1 == t3.c.col1)
+
+ self.assert_compile(
+ s2,
+ "SELECT table1.col1, table1.col2, table1.col3, "
+ "table2.col1, table2.col2, table2.col3 FROM table1 "
+ "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 "
+ "ON table3.col1 = table1.col1",
+ )
+
+ vis = sql_util.ClauseAdapter(j)
+
+ s3 = vis.traverse(s1)
+
+ s4 = s3.join(t3, t1.c.col1 == t3.c.col1)
+
+ self.assert_compile(
+ s4,
+ "SELECT table1.col1, table1.col2, table1.col3, "
+ "table2.col1, table2.col2, table2.col3 FROM table1 "
+ "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 "
+ "ON table3.col1 = table1.col1",
+ )
+
+ s5 = vis.traverse(s3)
+
+ s6 = s5.join(t3, t1.c.col1 == t3.c.col1)
+
+ self.assert_compile(
+ s6,
+ "SELECT table1.col1, table1.col2, table1.col3, "
+ "table2.col1, table2.col2, table2.col3 FROM table1 "
+ "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 "
+ "ON table3.col1 = table1.col1",
+ )
+
+ def test_this_thing_using_setup_joins_four(self):
+
+ j = t1.join(t2, t1.c.col1 == t2.c.col2)
+
+ s1 = select(j)
+
+ assert not s1._from_obj
+
+ s2 = s1.join(t3, t1.c.col1 == t3.c.col1)
+
+ self.assert_compile(
+ s2,
+ "SELECT table1.col1, table1.col2, table1.col3, "
+ "table2.col1, table2.col2, table2.col3 FROM table1 "
+ "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 "
+ "ON table3.col1 = table1.col1",
+ )
+
+ s3 = visitors.replacement_traverse(s1, {}, lambda elem: None)
+
+ s4 = s3.join(t3, t1.c.col1 == t3.c.col1)
+
+ self.assert_compile(
+ s4,
+ "SELECT table1.col1, table1.col2, table1.col3, "
+ "table2.col1, table2.col2, table2.col3 FROM table1 "
+ "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 "
+ "ON table3.col1 = table1.col1",
+ )
+
+ s5 = visitors.replacement_traverse(s3, {}, lambda elem: None)
+
+ s6 = s5.join(t3, t1.c.col1 == t3.c.col1)
+
+ self.assert_compile(
+ s6,
+ "SELECT table1.col1, table1.col2, table1.col3, "
+ "table2.col1, table2.col2, table2.col3 FROM table1 "
+ "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 "
+ "ON table3.col1 = table1.col1",
+ )
+
def test_select_fromtwice_one(self):
t1a = t1.alias()
diff --git a/test/sql/test_lambdas.py b/test/sql/test_lambdas.py
index 53f6a9544..a91242de5 100644
--- a/test/sql/test_lambdas.py
+++ b/test/sql/test_lambdas.py
@@ -5,6 +5,7 @@ from sqlalchemy.schema import Column
from sqlalchemy.schema import ForeignKey
from sqlalchemy.schema import Table
from sqlalchemy.sql import and_
+from sqlalchemy.sql import bindparam
from sqlalchemy.sql import coercions
from sqlalchemy.sql import column
from sqlalchemy.sql import join
@@ -19,6 +20,7 @@ from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
+from sqlalchemy.testing import ne_
from sqlalchemy.testing.assertsql import CompiledSQL
from sqlalchemy.types import Integer
from sqlalchemy.types import String
@@ -46,10 +48,39 @@ class DeferredLambdaTest(
go(), "SELECT t1.q, t1.p FROM t1 WHERE t1.q = :x_1 AND t1.p = :y_1"
)
+ def test_global_tracking(self):
+ t1 = table("t1", column("q"), column("p"))
+
+ global global_x, global_y
+
+ global_x = 10
+ global_y = 17
+
+ def go():
+ return select([t1]).where(
+ lambda: and_(t1.c.q == global_x, t1.c.p == global_y)
+ )
+
+ self.assert_compile(
+ go(),
+ "SELECT t1.q, t1.p FROM t1 WHERE t1.q = :global_x_1 "
+ "AND t1.p = :global_y_1",
+ checkparams={"global_x_1": 10, "global_y_1": 17},
+ )
+
+ global_y = 9
+
+ self.assert_compile(
+ go(),
+ "SELECT t1.q, t1.p FROM t1 WHERE t1.q = :global_x_1 "
+ "AND t1.p = :global_y_1",
+ checkparams={"global_x_1": 10, "global_y_1": 9},
+ )
+
def test_stale_checker_embedded(self):
def go(x):
- stmt = select([lambda: x])
+ stmt = select(lambda: x)
return stmt
c1 = column("x")
@@ -67,7 +98,7 @@ class DeferredLambdaTest(
def test_stale_checker_statement(self):
def go(x):
- stmt = lambdas.lambda_stmt(lambda: select([x]))
+ stmt = lambdas.lambda_stmt(lambda: select(x))
return stmt
c1 = column("x")
@@ -85,13 +116,13 @@ class DeferredLambdaTest(
def test_stale_checker_linked(self):
def go(x, y):
- stmt = lambdas.lambda_stmt(lambda: select([x])) + (
+ stmt = lambdas.lambda_stmt(lambda: select(x)) + (
lambda s: s.where(y > 5)
)
return stmt
- c1 = column("x")
- c2 = column("y")
+ c1 = oldc1 = column("x")
+ c2 = oldc2 = column("y")
s1 = go(c1, c2)
s2 = go(c1, c2)
@@ -104,6 +135,426 @@ class DeferredLambdaTest(
s3 = go(c1, c2)
self.assert_compile(s3, "SELECT q WHERE p > :p_1")
+ s4 = go(c1, c2)
+ self.assert_compile(s4, "SELECT q WHERE p > :p_1")
+
+ s5 = go(oldc1, oldc2)
+ self.assert_compile(s5, "SELECT x WHERE y > :y_1")
+
+ def test_stmt_lambda_w_additional_hascachekey_variants(self):
+ def go(col_expr, q):
+ stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+ stmt += lambda stmt: stmt.where(col_expr == q)
+
+ return stmt
+
+ c1 = column("x")
+ c2 = column("y")
+
+ s1 = go(c1, 5)
+ s2 = go(c2, 10)
+ s3 = go(c1, 8)
+ s4 = go(c2, 12)
+
+ self.assert_compile(
+ s1, "SELECT x WHERE x = :q_1", checkparams={"q_1": 5}
+ )
+ self.assert_compile(
+ s2, "SELECT y WHERE y = :q_1", checkparams={"q_1": 10}
+ )
+ self.assert_compile(
+ s3, "SELECT x WHERE x = :q_1", checkparams={"q_1": 8}
+ )
+ self.assert_compile(
+ s4, "SELECT y WHERE y = :q_1", checkparams={"q_1": 12}
+ )
+
+ s1key = s1._generate_cache_key()
+ s2key = s2._generate_cache_key()
+ s3key = s3._generate_cache_key()
+ s4key = s4._generate_cache_key()
+
+ eq_(s1key[0], s3key[0])
+ eq_(s2key[0], s4key[0])
+ ne_(s1key[0], s2key[0])
+
+ def test_stmt_lambda_w_atonce_whereclause_values_notrack(self):
+ def go(col_expr, whereclause):
+ stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+ stmt = stmt.add_criteria(
+ lambda stmt: stmt.where(whereclause), enable_tracking=False
+ )
+
+ return stmt
+
+ c1 = column("x")
+
+ s1 = go(c1, c1 == 5)
+ s2 = go(c1, c1 == 10)
+
+ self.assert_compile(
+ s1, "SELECT x WHERE x = :x_1", checkparams={"x_1": 5}
+ )
+
+ # and as we see, this is wrong. Because whereclause
+ # is fixed for the lambda and we do not re-evaluate the closure
+ # for this value changing. this can't be passed unless
+ # enable_tracking=False.
+ self.assert_compile(
+ s2, "SELECT x WHERE x = :x_1", checkparams={"x_1": 5}
+ )
+
+ def test_stmt_lambda_w_atonce_whereclause_values(self):
+ c2 = column("y")
+
+ def go(col_expr, whereclause, x):
+ stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+ stmt = stmt.add_criteria(
+ lambda stmt: stmt.where(whereclause).order_by(c2 > x),
+ )
+
+ return stmt
+
+ c1 = column("x")
+
+ s1 = go(c1, c1 == 5, 9)
+ s2 = go(c1, c1 == 10, 15)
+
+ s1key = s1._generate_cache_key()
+ s2key = s2._generate_cache_key()
+
+ eq_([b.value for b in s1key.bindparams], [5, 9])
+ eq_([b.value for b in s2key.bindparams], [10, 15])
+
+ self.assert_compile(
+ s1,
+ "SELECT x WHERE x = :x_1 ORDER BY y > :x_2",
+ checkparams={"x_1": 5, "x_2": 9},
+ )
+
+ self.assert_compile(
+ s2,
+ "SELECT x WHERE x = :x_1 ORDER BY y > :x_2",
+ checkparams={"x_1": 10, "x_2": 15},
+ )
+
+ def test_stmt_lambda_plain_customtrack(self):
+ c2 = column("y")
+
+ def go(col_expr, whereclause, p):
+ stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+ stmt = stmt.add_criteria(lambda stmt: stmt.where(whereclause))
+ stmt = stmt.add_criteria(
+ lambda stmt: stmt.order_by(col_expr), track_on=(col_expr,)
+ )
+ stmt = stmt.add_criteria(lambda stmt: stmt.where(col_expr == p))
+ return stmt
+
+ c1 = column("x")
+ c2 = column("y")
+
+ s1 = go(c1, c1 == 5, 9)
+ s2 = go(c1, c1 == 10, 15)
+ s3 = go(c2, c2 == 18, 12)
+
+ s1key = s1._generate_cache_key()
+ s2key = s2._generate_cache_key()
+ s3key = s3._generate_cache_key()
+
+ eq_([b.value for b in s1key.bindparams], [5, 9])
+ eq_([b.value for b in s2key.bindparams], [10, 15])
+ eq_([b.value for b in s3key.bindparams], [18, 12])
+
+ self.assert_compile(
+ s1,
+ "SELECT x WHERE x = :x_1 AND x = :p_1 ORDER BY x",
+ checkparams={"x_1": 5, "p_1": 9},
+ )
+
+ self.assert_compile(
+ s2,
+ "SELECT x WHERE x = :x_1 AND x = :p_1 ORDER BY x",
+ checkparams={"x_1": 10, "p_1": 15},
+ )
+
+ self.assert_compile(
+ s3,
+ "SELECT y WHERE y = :y_1 AND y = :p_1 ORDER BY y",
+ checkparams={"y_1": 18, "p_1": 12},
+ )
+
+ def test_stmt_lambda_w_atonce_whereclause_customtrack_binds(self):
+ c2 = column("y")
+
+ # this pattern is *completely unnecessary*, and I would prefer
+ # if we can detect this and just raise, because when it is not done
+ # correctly, it is *extremely* difficult to catch it failing.
+ # however I also can't come up with a reliable way to catch it.
+ # so we will keep the use of "track_on" to be internal.
+
+ def go(col_expr, whereclause, p):
+ stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+ stmt = stmt.add_criteria(
+ lambda stmt: stmt.where(whereclause).order_by(col_expr > p),
+ track_on=(whereclause, whereclause.right.value),
+ )
+
+ return stmt
+
+ c1 = column("x")
+ c2 = column("y")
+
+ s1 = go(c1, c1 == 5, 9)
+ s2 = go(c1, c1 == 10, 15)
+ s3 = go(c2, c2 == 18, 12)
+
+ s1key = s1._generate_cache_key()
+ s2key = s2._generate_cache_key()
+ s3key = s3._generate_cache_key()
+
+ eq_([b.value for b in s1key.bindparams], [5, 9])
+ eq_([b.value for b in s2key.bindparams], [10, 15])
+ eq_([b.value for b in s3key.bindparams], [18, 12])
+
+ self.assert_compile(
+ s1,
+ "SELECT x WHERE x = :x_1 ORDER BY x > :p_1",
+ checkparams={"x_1": 5, "p_1": 9},
+ )
+
+ self.assert_compile(
+ s2,
+ "SELECT x WHERE x = :x_1 ORDER BY x > :p_1",
+ checkparams={"x_1": 10, "p_1": 15},
+ )
+
+ self.assert_compile(
+ s3,
+ "SELECT y WHERE y = :y_1 ORDER BY y > :p_1",
+ checkparams={"y_1": 18, "p_1": 12},
+ )
+
+ def test_stmt_lambda_track_closure_binds_one(self):
+ def go(col_expr, whereclause):
+ stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+ stmt += lambda stmt: stmt.where(whereclause)
+
+ return stmt
+
+ c1 = column("x")
+
+ s1 = go(c1, c1 == 5)
+ s2 = go(c1, c1 == 10)
+
+ self.assert_compile(
+ s1, "SELECT x WHERE x = :x_1", checkparams={"x_1": 5}
+ )
+ self.assert_compile(
+ s2, "SELECT x WHERE x = :x_1", checkparams={"x_1": 10}
+ )
+
+ s1key = s1._generate_cache_key()
+ s2key = s2._generate_cache_key()
+
+ eq_(s1key.key, s2key.key)
+
+ eq_([b.value for b in s1key.bindparams], [5])
+ eq_([b.value for b in s2key.bindparams], [10])
+
+ def test_stmt_lambda_track_closure_binds_two(self):
+ def go(col_expr, whereclause, x, y):
+ stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+ stmt += lambda stmt: stmt.where(whereclause).where(
+ and_(c1 == x, c1 < y)
+ )
+
+ return stmt
+
+ c1 = column("x")
+
+ s1 = go(c1, c1 == 5, 8, 9)
+ s2 = go(c1, c1 == 10, 12, 14)
+
+ s1key = s1._generate_cache_key()
+ s2key = s2._generate_cache_key()
+
+ self.assert_compile(
+ s1,
+ "SELECT x WHERE x = :x_1 AND x = :x_2 AND x < :y_1",
+ checkparams={"x_1": 5, "x_2": 8, "y_1": 9},
+ )
+ self.assert_compile(
+ s2,
+ "SELECT x WHERE x = :x_1 AND x = :x_2 AND x < :y_1",
+ checkparams={"x_1": 10, "x_2": 12, "y_1": 14},
+ )
+
+ eq_([b.value for b in s1key.bindparams], [5, 8, 9])
+ eq_([b.value for b in s2key.bindparams], [10, 12, 14])
+
+ s1_compiled_cached = s1.compile(cache_key=s1key)
+
+ params = s1_compiled_cached.construct_params(
+ extracted_parameters=s2key[1]
+ )
+
+ eq_(params, {"x_1": 10, "x_2": 12, "y_1": 14})
+
+ def test_stmt_lambda_track_closure_binds_three(self):
+ def go(col_expr, whereclause, x, y):
+ stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+ stmt += lambda stmt: stmt.where(whereclause)
+ stmt += lambda stmt: stmt.where(and_(c1 == x, c1 < y))
+
+ return stmt
+
+ c1 = column("x")
+
+ s1 = go(c1, c1 == 5, 8, 9)
+ s2 = go(c1, c1 == 10, 12, 14)
+
+ s1key = s1._generate_cache_key()
+ s2key = s2._generate_cache_key()
+
+ self.assert_compile(
+ s1,
+ "SELECT x WHERE x = :x_1 AND x = :x_2 AND x < :y_1",
+ checkparams={"x_1": 5, "x_2": 8, "y_1": 9},
+ )
+ self.assert_compile(
+ s2,
+ "SELECT x WHERE x = :x_1 AND x = :x_2 AND x < :y_1",
+ checkparams={"x_1": 10, "x_2": 12, "y_1": 14},
+ )
+
+ eq_([b.value for b in s1key.bindparams], [5, 8, 9])
+ eq_([b.value for b in s2key.bindparams], [10, 12, 14])
+
+ s1_compiled_cached = s1.compile(cache_key=s1key)
+
+ params = s1_compiled_cached.construct_params(
+ extracted_parameters=s2key[1]
+ )
+
+ eq_(params, {"x_1": 10, "x_2": 12, "y_1": 14})
+
+ def test_stmt_lambda_w_atonce_whereclause_novalue(self):
+ def go(col_expr, whereclause):
+ stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+ stmt += lambda stmt: stmt.where(whereclause)
+
+ return stmt
+
+ c1 = column("x")
+
+ s1 = go(c1, bindparam("x"))
+
+ self.assert_compile(s1, "SELECT x WHERE :x")
+
+ def test_stmt_lambda_w_additional_hashable_variants(self):
+ # note a Python 2 old style class would fail here because it
+ # isn't hashable. right now we do a hard check for __hash__ which
+ # will raise if the attr isn't present
+ class Thing(object):
+ def __init__(self, col_expr):
+ self.col_expr = col_expr
+
+ def go(thing, q):
+ stmt = lambdas.lambda_stmt(lambda: select(thing.col_expr))
+ stmt += lambda stmt: stmt.where(thing.col_expr == q)
+
+ return stmt
+
+ c1 = Thing(column("x"))
+ c2 = Thing(column("y"))
+
+ s1 = go(c1, 5)
+ s2 = go(c2, 10)
+ s3 = go(c1, 8)
+ s4 = go(c2, 12)
+
+ self.assert_compile(
+ s1, "SELECT x WHERE x = :q_1", checkparams={"q_1": 5}
+ )
+ self.assert_compile(
+ s2, "SELECT y WHERE y = :q_1", checkparams={"q_1": 10}
+ )
+ self.assert_compile(
+ s3, "SELECT x WHERE x = :q_1", checkparams={"q_1": 8}
+ )
+ self.assert_compile(
+ s4, "SELECT y WHERE y = :q_1", checkparams={"q_1": 12}
+ )
+
+ s1key = s1._generate_cache_key()
+ s2key = s2._generate_cache_key()
+ s3key = s3._generate_cache_key()
+ s4key = s4._generate_cache_key()
+
+ eq_(s1key[0], s3key[0])
+ eq_(s2key[0], s4key[0])
+ ne_(s1key[0], s2key[0])
+
+ def test_stmt_lambda_w_set_of_opts(self):
+
+ stmt = lambdas.lambda_stmt(lambda: select(column("x")))
+
+ opts = {column("x"), column("y")}
+
+ assert_raises_message(
+ exc.ArgumentError,
+ 'Can\'t create a cache key for lambda closure variable "opts" '
+ "because it's a set. try using a list",
+ stmt.__add__,
+ lambda stmt: stmt.options(*opts),
+ )
+
+ def test_stmt_lambda_w_list_of_opts(self):
+ def go(opts):
+ stmt = lambdas.lambda_stmt(lambda: select(column("x")))
+ stmt += lambda stmt: stmt.options(*opts)
+
+ return stmt
+
+ s1 = go([column("a"), column("b")])
+
+ s2 = go([column("a"), column("b")])
+
+ s3 = go([column("q"), column("b")])
+
+ s1key = s1._generate_cache_key()
+ s2key = s2._generate_cache_key()
+ s3key = s3._generate_cache_key()
+
+ eq_(s1key.key, s2key.key)
+ ne_(s1key.key, s3key.key)
+
+ def test_stmt_lambda_hey_theres_multiple_paths(self):
+ def go(x, y):
+ stmt = lambdas.lambda_stmt(lambda: select(column("x")))
+
+ if x > 5:
+ stmt += lambda stmt: stmt.where(column("x") == x)
+ else:
+ stmt += lambda stmt: stmt.where(column("y") == y)
+
+ stmt += lambda stmt: stmt.order_by(column("q"))
+
+ # TODO: need more path variety here to exercise
+ # using a full path key
+
+ return stmt
+
+ s1 = go(2, 5)
+ s2 = go(8, 7)
+ s3 = go(4, 9)
+ s4 = go(10, 1)
+
+ self.assert_compile(s1, "SELECT x WHERE y = :y_1 ORDER BY q")
+ self.assert_compile(s2, "SELECT x WHERE x = :x_1 ORDER BY q")
+ self.assert_compile(s3, "SELECT x WHERE y = :y_1 ORDER BY q")
+ self.assert_compile(s4, "SELECT x WHERE x = :x_1 ORDER BY q")
+
def test_coercion_cols_clause(self):
assert_raises_message(
exc.ArgumentError,
diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py
index 01c8d7ca6..58280bb67 100644
--- a/test/sql/test_selectable.py
+++ b/test/sql/test_selectable.py
@@ -29,6 +29,7 @@ from sqlalchemy import TypeDecorator
from sqlalchemy import union
from sqlalchemy import util
from sqlalchemy.sql import Alias
+from sqlalchemy.sql import annotation
from sqlalchemy.sql import base
from sqlalchemy.sql import column
from sqlalchemy.sql import elements
@@ -2352,6 +2353,33 @@ class AnnotationsTest(fixtures.TestBase):
annot = obj._annotate({})
ne_(set([obj]), set([annot]))
+ def test_replacement_traverse_preserve(self):
+ """test that replacement traverse that hits an unannotated column
+ does not use it when replacing an annotated column.
+
+ this requires that replacement traverse store elements in the
+ "seen" hash based on id(), not hash.
+
+ """
+ t = table("t", column("x"))
+
+ stmt = select([t.c.x])
+
+ whereclause = annotation._deep_annotate(t.c.x == 5, {"foo": "bar"})
+
+ eq_(whereclause._annotations, {"foo": "bar"})
+ eq_(whereclause.left._annotations, {"foo": "bar"})
+ eq_(whereclause.right._annotations, {"foo": "bar"})
+
+ stmt = stmt.where(whereclause)
+
+ s2 = visitors.replacement_traverse(stmt, {}, lambda elem: None)
+
+ whereclause = s2._where_criteria[0]
+ eq_(whereclause._annotations, {"foo": "bar"})
+ eq_(whereclause.left._annotations, {"foo": "bar"})
+ eq_(whereclause.right._annotations, {"foo": "bar"})
+
def test_proxy_set_iteration_includes_annotated(self):
from sqlalchemy.schema import Column