From 77c9534dcaf3723f7b2baf42442eda3e1d8c3332 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 12 Dec 2020 18:56:58 -0500 Subject: Major revisals to lambdas 1. Improve coercions._deep_is_literal to check sequences for clause elements, thus allowing a phrase like lambda: col.in_([literal("x"), literal("y")]) to be handled 2. revise closure variable caching completely. All variables entering must be part of a closure cache key or rejected. Only objects that can be resolved to HasCacheKey or FunctionType are accepted; all other types are rejected. This adds a high degree of strictness to lambdas and will make them a little more awkward to use in some cases, but it prevents several classes of critical issues: a. previously, a lambda that had an expression derived from some kind of state, like "self.x", or "execution_context.session.foo" would produce a closure cache key from "self" or "execution_context", objects that can very well be per-execution and would therefore cause the cache of AnalyzedFunction objects to overflow. (memory won't leak as it looks like an LRUCache is already used for these) b. a lambda, such as one used within DeferredLambdaElement, that produces different SQL expressions based on the arguments (which is in fact what it's supposed to do) could, through the use of conditionals, also produce different bound parameter combinations, leading to literal parameters not being tracked properly. These are now rejected as uncacheable whereas previously they would again be part of the closure cache key, causing an overflow of AnalyzedFunction objects. 3. Ensure non-mapped mixins are handled correctly by with_loader_criteria(). 4. Fixed bug in lambda SQL system where we are not supposed to allow a Python function to be embedded in the lambda, since we can't predict a bound value from it. While there was an error condition added for this, it was not tested and wasn't working; an informative error is now raised. 5. new docs for lambdas 6. 
consolidated changelog for all of these Fixes: #5760 Fixes: #5765 Fixes: #5766 Fixes: #5768 Fixes: #5770 Change-Id: Iedaa636c3225fad496df23b612c516c8ab247ab7 --- test/sql/test_compare.py | 5 +- test/sql/test_lambdas.py | 572 ++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 565 insertions(+), 12 deletions(-) (limited to 'test/sql') diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 6fa961e4d..70281d4e8 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -57,6 +57,7 @@ from sqlalchemy.sql.functions import GenericFunction from sqlalchemy.sql.functions import ReturnTypeFromArgs from sqlalchemy.sql.lambdas import lambda_stmt from sqlalchemy.sql.lambdas import LambdaElement +from sqlalchemy.sql.lambdas import LambdaOptions from sqlalchemy.sql.selectable import _OffsetLimitParam from sqlalchemy.sql.selectable import AliasedReturnsRows from sqlalchemy.sql.selectable import FromGrouping @@ -859,7 +860,9 @@ class CoreFixtures(object): d = {"g": random.randint(40, 45)} return LambdaElement( - lambda: and_(table_a.c.b == d["g"]), roles.WhereHavingRole + lambda: and_(table_a.c.b == d["g"]), + roles.WhereHavingRole, + opts=LambdaOptions(track_closure_variables=False), ) def seven(): diff --git a/test/sql/test_lambdas.py b/test/sql/test_lambdas.py index a70dc0511..e8e4a8d2a 100644 --- a/test/sql/test_lambdas.py +++ b/test/sql/test_lambdas.py @@ -11,6 +11,8 @@ from sqlalchemy.sql import column from sqlalchemy.sql import join from sqlalchemy.sql import lambda_stmt from sqlalchemy.sql import lambdas +from sqlalchemy.sql import literal +from sqlalchemy.sql import null from sqlalchemy.sql import roles from sqlalchemy.sql import select from sqlalchemy.sql import table @@ -27,7 +29,7 @@ from sqlalchemy.types import Integer from sqlalchemy.types import String -class DeferredLambdaTest( +class LambdaElementTest( fixtures.TestBase, testing.AssertsExecutionResults, AssertsCompiledSQL ): __dialect__ = "default" @@ -274,6 +276,75 @@ 
class DeferredLambdaTest( checkparams={"x_1": 10, "x_2": 15}, ) + def test_conditional_must_be_tracked(self): + tab = table("foo", column("id"), column("col")) + + def run_my_statement(parameter, add_criteria=False): + stmt = lambda_stmt(lambda: select(tab)) + + stmt = stmt.add_criteria( + lambda s: s.where(tab.c.col > parameter) + if add_criteria + else s.where(tab.c.col == parameter), + ) + + stmt += lambda s: s.order_by(tab.c.id) + + return stmt + + assert_raises_message( + exc.InvalidRequestError, + "Closure variable named 'add_criteria' inside of lambda callable", + run_my_statement, + 5, + False, + ) + + def test_boolean_conditionals(self): + + tab = table("foo", column("id"), column("col")) + + def run_my_statement(parameter, add_criteria=False): + stmt = lambda_stmt(lambda: select(tab)) + + stmt = stmt.add_criteria( + lambda s: s.where(tab.c.col > parameter) + if add_criteria + else s.where(tab.c.col == parameter), + track_on=[add_criteria], + ) + + stmt += lambda s: s.order_by(tab.c.id) + + return stmt + + c1 = run_my_statement(5, False) + c2 = run_my_statement(10, True) + c3 = run_my_statement(18, False) + + ck1 = c1._generate_cache_key() + ck2 = c2._generate_cache_key() + ck3 = c3._generate_cache_key() + + eq_(ck1[0], ck3[0]) + ne_(ck1[0], ck2[0]) + + self.assert_compile( + c1, + "SELECT foo.id, foo.col FROM foo WHERE " + "foo.col = :parameter_1 ORDER BY foo.id", + ) + self.assert_compile( + c2, + "SELECT foo.id, foo.col FROM foo " + "WHERE foo.col > :parameter_1 ORDER BY foo.id", + ) + self.assert_compile( + c3, + "SELECT foo.id, foo.col FROM foo WHERE " + "foo.col = :parameter_1 ORDER BY foo.id", + ) + def test_stmt_lambda_plain_customtrack(self): c2 = column("y") @@ -487,10 +558,11 @@ class DeferredLambdaTest( self.assert_compile(s1, "SELECT x WHERE :x") - def test_stmt_lambda_w_additional_hashable_variants(self): - # note a Python 2 old style class would fail here because it - # isn't hashable. 
right now we do a hard check for __hash__ which - # will raise if the attr isn't present + def test_reject_plain_object(self): + # with #5765 we move to no longer allow closure variables that + # refer to unknown types of objects inside the lambda. these have + # to be resolved outside of the lambda because we otherwise can't + # be sure they can be safely used as cache keys. class Thing(object): def __init__(self, col_expr): self.col_expr = col_expr @@ -501,6 +573,83 @@ class DeferredLambdaTest( return stmt + c1 = Thing(column("x")) + assert_raises_message( + exc.InvalidRequestError, + "Closure variable named 'thing' inside of lambda callable", + go, + c1, + 5, + ) + + def test_plain_object_ok_w_tracking_disabled(self): + # with #5765 we move to no longer allow closure variables that + # refer to unknown types of objects inside the lambda. these have + # to be resolved outside of the lambda because we otherwise can't + # be sure they can be safely used as cache keys. + class Thing(object): + def __init__(self, col_expr): + self.col_expr = col_expr + + def go(thing, q): + stmt = lambdas.lambda_stmt( + lambda: select(thing.col_expr), track_closure_variables=False + ) + stmt = stmt.add_criteria( + lambda stmt: stmt.where(thing.col_expr == q), + track_closure_variables=False, + ) + + return stmt + + c1 = Thing(column("x")) + c2 = Thing(column("y")) + + s1 = go(c1, 5) + s2 = go(c2, 10) + s3 = go(c1, 8) + s4 = go(c2, 12) + + self.assert_compile( + s1, "SELECT x WHERE x = :q_1", checkparams={"q_1": 5} + ) + # note this is wrong, because no tracking + self.assert_compile( + s2, "SELECT x WHERE x = :q_1", checkparams={"q_1": 10} + ) + self.assert_compile( + s3, "SELECT x WHERE x = :q_1", checkparams={"q_1": 8} + ) + # also wrong + self.assert_compile( + s4, "SELECT x WHERE x = :q_1", checkparams={"q_1": 12} + ) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + s3key = s3._generate_cache_key() + s4key = s4._generate_cache_key() + + # all one cache 
key + eq_(s1key[0], s3key[0]) + eq_(s2key[0], s4key[0]) + eq_(s1key[0], s2key[0]) + + def test_plain_object_used_outside_lambda(self): + # test the above 'test_reject_plain_object' with the expected + # workaround + + class Thing(object): + def __init__(self, col_expr): + self.col_expr = col_expr + + def go(thing, q): + col_expr = thing.col_expr + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt += lambda stmt: stmt.where(col_expr == q) + + return stmt + c1 = Thing(column("x")) c2 = Thing(column("y")) @@ -538,13 +687,92 @@ class DeferredLambdaTest( opts = {column("x"), column("y")} assert_raises_message( - exc.ArgumentError, - 'Can\'t create a cache key for lambda closure variable "opts" ' - "because it's a set. try using a list", + exc.InvalidRequestError, + "Closure variable named 'opts' inside of lambda callable ", stmt.__add__, lambda stmt: stmt.options(*opts), ) + def test_detect_embedded_callables_one(self): + t1 = table("t1", column("q")) + + x = 1 + + def go(): + def foo(): + return x + + stmt = select(t1).where(lambda: t1.c.q == foo()) + return stmt + + assert_raises_message( + exc.InvalidRequestError, + r"Can't invoke Python callable foo\(\) inside of lambda " + "expression ", + go, + ) + + def test_detect_embedded_callables_two(self): + t1 = table("t1", column("q"), column("y")) + + def go(): + def foo(): + return t1.c.y + + stmt = select(t1).where(lambda: t1.c.q == foo()) + return stmt + + self.assert_compile( + go(), "SELECT t1.q, t1.y FROM t1 WHERE t1.q = t1.y" + ) + + def test_detect_embedded_callables_three(self): + t1 = table("t1", column("q"), column("y")) + + def go(): + def foo(): + t1.c.y + + stmt = select(t1).where(lambda: t1.c.q == getattr(t1.c, "y")) + return stmt + + self.assert_compile( + go(), "SELECT t1.q, t1.y FROM t1 WHERE t1.q = t1.y" + ) + + def test_detect_embedded_callables_four(self): + t1 = table("t1", column("q")) + + x = 1 + + def go(): + def foo(): + return x + + stmt = select(t1).where( + lambdas.LambdaElement( + 
lambda: t1.c.q == foo(), + roles.WhereHavingRole, + lambdas.LambdaOptions(track_bound_values=False), + ) + ) + return stmt + + self.assert_compile( + go(), + "SELECT t1.q FROM t1 WHERE t1.q = :q_1", + checkparams={"q_1": 1}, + ) + + # we're not tracking it + x = 2 + + self.assert_compile( + go(), + "SELECT t1.q FROM t1 WHERE t1.q = :q_1", + checkparams={"q_1": 1}, + ) + def test_stmt_lambda_w_list_of_opts(self): def go(opts): stmt = lambdas.lambda_stmt(lambda: select(column("x"))) @@ -755,6 +983,23 @@ class DeferredLambdaTest( }, ) + def test_in_columnelement(self): + # test issue #5768 + + def go(): + v = [literal("a"), literal("b")] + expr1 = select(1).where(lambda: column("q").in_(v)) + return expr1 + + self.assert_compile(go(), "SELECT 1 WHERE q IN (:param_1, :param_2)") + + self.assert_compile( + go(), + "SELECT 1 WHERE q IN (:param_1, :param_2)", + render_postcompile=True, + checkparams={"param_1": "a", "param_2": "b"}, + ) + def test_select_columns_clause(self): t1 = table("t1", column("q"), column("p")) @@ -854,14 +1099,28 @@ class DeferredLambdaTest( expr, ) - def test_dict_literal_keys(self, user_address_fixture): + def test_reject_dict_literal_keys(self, user_address_fixture): users, addresses = user_address_fixture names = {"x": "some name"} lmb = lambda: users.c.name == names["x"] # noqa - expr = coercions.expect(roles.WhereHavingRole, lmb) + assert_raises_message( + exc.InvalidRequestError, + "Closure variable named 'names' inside of lambda callable", + coercions.expect, + roles.WhereHavingRole, + lmb, + ) + + def test_dict_literal_keys_proper_use(self, user_address_fixture): + users, addresses = user_address_fixture + names = {"x": "some name"} + x = names["x"] + lmb = lambda: users.c.name == x # noqa + + expr = coercions.expect(roles.WhereHavingRole, lmb) self.assert_compile( expr, "users.name = :x_1", @@ -1158,7 +1417,7 @@ class DeferredLambdaTest( ), ) - def test_cache_key_thing(self): + def test_cache_key_bindparam_matches(self): t1 = table("t1", 
column("q"), column("p")) def go(x): @@ -1169,3 +1428,294 @@ class DeferredLambdaTest( is_(expr1._generate_cache_key().bindparams[0], expr1._resolved.right) is_(expr2._generate_cache_key().bindparams[0], expr2._resolved.right) + + def test_cache_key_instance_variable_issue_incorrect(self): + t1 = table("t1", column("q"), column("p")) + + class Foo(object): + def __init__(self, value): + self.value = value + + def go(foo): + return coercions.expect( + roles.WhereHavingRole, lambda: t1.c.q == foo.value + ) + + assert_raises_message( + exc.InvalidRequestError, + "Closure variable named 'foo' inside of lambda callable", + go, + Foo(5), + ) + + def test_cache_key_instance_variable_issue_correct_one(self): + t1 = table("t1", column("q"), column("p")) + + class Foo(object): + def __init__(self, value): + self.value = value + + def go(foo): + value = foo.value + return coercions.expect( + roles.WhereHavingRole, lambda: t1.c.q == value + ) + + expr1 = go(Foo(5)) + expr2 = go(Foo(10)) + + c1 = expr1._generate_cache_key() + c2 = expr2._generate_cache_key() + eq_(c1, c2) + + def test_cache_key_instance_variable_issue_correct_two(self): + t1 = table("t1", column("q"), column("p")) + + class Foo(object): + def __init__(self, value): + self.value = value + + def go(foo): + return coercions.expect( + roles.WhereHavingRole, + lambda: t1.c.q == foo.value, + track_on=[self], + ) + + expr1 = go(Foo(5)) + expr2 = go(Foo(10)) + + c1 = expr1._generate_cache_key() + c2 = expr2._generate_cache_key() + eq_(c1, c2) + + def test_insert_statement(self, user_address_fixture): + users, addresses = user_address_fixture + + def ins(id_, name): + stmt = lambda_stmt(lambda: users.insert()) + stmt += lambda s: s.values(id=id_, name=name) + return stmt + + with testing.db.begin() as conn: + conn.execute(ins(12, "foo")) + + eq_( + conn.execute(select(users).where(users.c.id == 12)).first(), + (12, "foo"), + ) + + def test_update_statement(self, user_address_fixture): + users, addresses = 
user_address_fixture + + def upd(id_, newname): + stmt = lambda_stmt(lambda: users.update()) + stmt += lambda s: s.values(name=newname) + stmt += lambda s: s.where(users.c.id == id_) + return stmt + + with testing.db.begin() as conn: + conn.execute(users.insert().values(id=7, name="bar")) + conn.execute(upd(7, "foo")) + + eq_( + conn.execute(select(users).where(users.c.id == 7)).first(), + (7, "foo"), + ) + + +class DeferredLambdaElementTest( + fixtures.TestBase, testing.AssertsExecutionResults, AssertsCompiledSQL +): + __dialect__ = "default" + + @testing.fails("wontfix issue #5767") + def test_detect_change_in_binds_no_tracking(self): + t1 = table("t1", column("q"), column("p")) + t2 = table("t2", column("q"), column("p")) + + vv = [1, 2, 3] + # lambda produces either "t1 IN vv" or "NULL" based on the + # argument. will not produce a consistent cache key + elem = lambdas.DeferredLambdaElement( + lambda tab: tab.c.q.in_(vv) if tab.name == "t2" else null(), + roles.WhereHavingRole, + lambda_args=(t1,), + opts=lambdas.LambdaOptions(track_closure_variables=False), + ) + + self.assert_compile(elem.expr, "NULL") + + assert_raises_message( + exc.InvalidRequestError, + r"Lambda callable at %s produced " + "a different set of bound parameters " + "than its original run: vv" % (elem.fn.__code__), + elem._resolve_with_args, + t2, + ) + + def test_detect_change_in_binds_tracking_positive(self): + t1 = table("t1", column("q"), column("p")) + + vv = [1, 2, 3] + + # lambda produces either "t1 IN vv" or "NULL" based on the + # argument. 
will not produce a consistent cache key + assert_raises_message( + exc.InvalidRequestError, + "Closure variable named 'vv' inside of lambda callable", + lambdas.DeferredLambdaElement, + lambda tab: tab.c.q.in_(vv) if tab.name == "t2" else None, + roles.WhereHavingRole, + opts=lambdas.LambdaOptions, + lambda_args=(t1,), + ) + + @testing.fails("wontfix issue #5767") + def test_detect_change_in_binds_tracking_negative(self): + t1 = table("t1", column("q"), column("p")) + t2 = table("t2", column("q"), column("p")) + + vv = [1, 2, 3] + qq = [3, 4, 5] + + # lambda produces either "t1 IN vv" or "t2 IN qq" based on the + # argument. will not produce a consistent cache key + elem = lambdas.DeferredLambdaElement( + lambda tab: tab.c.q.in_(vv) + if tab.name == "t1" + else tab.c.q.in_(qq), + roles.WhereHavingRole, + lambda_args=(t1,), + opts=lambdas.LambdaOptions(track_closure_variables=False), + ) + + self.assert_compile(elem.expr, "t1.q IN ([POSTCOMPILE_vv_1])") + + assert_raises_message( + exc.InvalidRequestError, + r"Lambda callable at %s produced " + "a different set of bound parameters " + "than its original run: qq" % (elem.fn.__code__), + elem._resolve_with_args, + t2, + ) + + def _fixture_one(self, t1): + vv = [1, 2, 3] + + def go(): + elem = lambdas.DeferredLambdaElement( + lambda tab: tab.c.q.in_(vv), + roles.WhereHavingRole, + lambda_args=(t1,), + opts=lambdas.LambdaOptions, + ) + return elem + + return go + + def _fixture_two(self, t1): + def go(): + elem = lambdas.DeferredLambdaElement( + lambda tab: tab.c.q == "x", + roles.WhereHavingRole, + lambda_args=(t1,), + opts=lambdas.LambdaOptions, + ) + return elem + + return go + + def _fixture_three(self, t1): + def go(): + elem = lambdas.DeferredLambdaElement( + lambda tab: tab.c.q != "x", + roles.WhereHavingRole, + lambda_args=(t1,), + opts=lambdas.LambdaOptions, + ) + return elem + + return go + + def _fixture_four(self, t1): + def go(): + elem = lambdas.DeferredLambdaElement( + lambda tab: tab.c.q.in_([1, 2, 3]), 
+ roles.WhereHavingRole, + lambda_args=(t1,), + opts=lambdas.LambdaOptions, + ) + return elem + + return go + + def _fixture_five(self, t1): + def go(): + x = "x" + elem = lambdas.DeferredLambdaElement( + lambda tab: tab.c.q == x, + roles.WhereHavingRole, + lambda_args=(t1,), + opts=lambdas.LambdaOptions, + ) + return elem + + return go + + def _fixture_six(self, t1): + def go(): + x = "x" + elem = lambdas.DeferredLambdaElement( + lambda tab: tab.c.q != x, + roles.WhereHavingRole, + lambda_args=(t1,), + opts=lambdas.LambdaOptions, + ) + return elem + + return go + + @testing.combinations( + ("_fixture_one",), + ("_fixture_two",), + ("_fixture_three",), + ("_fixture_four",), + ("_fixture_five",), + ("_fixture_six",), + ) + def test_cache_key_many_different_args(self, fixture_name): + t1 = table("t1", column("q"), column("p")) + t2 = table("t2", column("q"), column("p")) + t3 = table("t3", column("q"), column("p")) + + go = getattr(self, fixture_name)(t1) + + g1 = go() + g2 = go() + + g1key = g1._generate_cache_key() + g2key = g2._generate_cache_key() + eq_(g1key[0], g2key[0]) + + e1 = go()._resolve_with_args(t1) + e2 = go()._resolve_with_args(t2) + e3 = go()._resolve_with_args(t3) + + e1key = e1._generate_cache_key() + e2key = e2._generate_cache_key() + e3key = e3._generate_cache_key() + + e12 = go()._resolve_with_args(t1) + e32 = go()._resolve_with_args(t3) + + e12key = e12._generate_cache_key() + e32key = e32._generate_cache_key() + + ne_(e1key[0], e2key[0]) + ne_(e2key[0], e3key[0]) + + eq_(e12key[0], e1key[0]) + eq_(e32key[0], e3key[0]) -- cgit v1.2.1