diff options
Diffstat (limited to 'test/sql')
| -rw-r--r-- | test/sql/test_compare.py | 79 | ||||
| -rw-r--r-- | test/sql/test_functions.py | 20 | ||||
| -rw-r--r-- | test/sql/test_labels.py | 2 | ||||
| -rw-r--r-- | test/sql/test_lambdas.py | 25 | ||||
| -rw-r--r-- | test/sql/test_operators.py | 8 | ||||
| -rw-r--r-- | test/sql/test_resultset.py | 2 | ||||
| -rw-r--r-- | test/sql/test_types.py | 2 |
7 files changed, 127 insertions, 11 deletions
diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index af78ea19b..ca1eff62b 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -16,6 +16,7 @@ from sqlalchemy import Integer from sqlalchemy import literal_column from sqlalchemy import MetaData from sqlalchemy import or_ +from sqlalchemy import PickleType from sqlalchemy import select from sqlalchemy import String from sqlalchemy import Table @@ -1264,13 +1265,20 @@ class CacheKeyTest(CacheKeyFixture, CoreFixtures, fixtures.TestBase): # the None for cache key will prevent objects # which contain these elements from being cached. f1 = Foobar1() - eq_(f1._generate_cache_key(), None) + with expect_warnings( + "Class Foobar1 will not make use of SQL compilation caching" + ): + eq_(f1._generate_cache_key(), None) f2 = Foobar2() - eq_(f2._generate_cache_key(), None) + with expect_warnings( + "Class Foobar2 will not make use of SQL compilation caching" + ): + eq_(f2._generate_cache_key(), None) s1 = select(column("q"), Foobar2()) + # warning is memoized, won't happen the second time eq_(s1._generate_cache_key(), None) def test_get_children_no_method(self): @@ -1355,6 +1363,7 @@ class CompareAndCopyTest(CoreFixtures, fixtures.TestBase): and ( "__init__" in cls.__dict__ or issubclass(cls, AliasedReturnsRows) + or "inherit_cache" not in cls.__dict__ ) and not issubclass(cls, (Annotated)) and "orm" not in cls.__module__ @@ -1819,3 +1828,69 @@ class TypesTest(fixtures.TestBase): eq_(c1, c2) ne_(c1, c3) eq_(c1, c4) + + def test_thirdparty_sub_subclass_no_cache(self): + class MyType(PickleType): + pass + + expr = column("q", MyType()) == 1 + + with expect_warnings( + r"TypeDecorator MyType\(\) will not produce a cache key" + ): + is_(expr._generate_cache_key(), None) + + def test_userdefined_sub_subclass_no_cache(self): + class MyType(UserDefinedType): + cache_ok = True + + class MySubType(MyType): + pass + + expr = column("q", MySubType()) == 1 + + with expect_warnings( + r"UserDefinedType MySubType\(\) will not produce a cache key" + ): + is_(expr._generate_cache_key(), None) + + def test_userdefined_sub_subclass_cache_ok(self): + class MyType(UserDefinedType): + cache_ok = True + + class MySubType(MyType): + cache_ok = True + + def go1(): + expr = column("q", MySubType()) == 1 + return expr + + def go2(): + expr = column("p", MySubType()) == 1 + return expr + + c1 = go1()._generate_cache_key()[0] + c2 = go1()._generate_cache_key()[0] + c3 = go2()._generate_cache_key()[0] + + eq_(c1, c2) + ne_(c1, c3) + + def test_thirdparty_sub_subclass_cache_ok(self): + class MyType(PickleType): + cache_ok = True + + def go1(): + expr = column("q", MyType()) == 1 + return expr + + def go2(): + expr = column("p", MyType()) == 1 + return expr + + c1 = go1()._generate_cache_key()[0] + c2 = go1()._generate_cache_key()[0] + c3 = go2()._generate_cache_key()[0] + + eq_(c1, c2) + ne_(c1, c3) diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index 9378cfc38..e08526419 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -81,6 +81,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): # test generic function compile class fake_func(GenericFunction): + inherit_cache = True __return_type__ = sqltypes.Integer def __init__(self, arg, **kwargs): @@ -107,6 +108,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): if use_custom: class MyFunc(FunctionElement): + inherit_cache = True name = "myfunc" type = Integer() @@ -135,6 +137,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_use_labels_function_element(self): class max_(FunctionElement): name = "max" + inherit_cache = True @compiles(max_) def visit_max(element, compiler, **kw): @@ -260,7 +263,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_custom_default_namespace(self): class myfunc(GenericFunction): - pass + inherit_cache = True assert isinstance(func.myfunc(), myfunc) self.assert_compile(func.myfunc(), "myfunc()") @@ -268,6 +271,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_custom_type(self): class myfunc(GenericFunction): type = DateTime + inherit_cache = True assert isinstance(func.myfunc().type, DateTime) self.assert_compile(func.myfunc(), "myfunc()") @@ -275,12 +279,14 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_custom_legacy_type(self): # in case someone was using this system class myfunc(GenericFunction): + inherit_cache = True __return_type__ = DateTime assert isinstance(func.myfunc().type, DateTime) def test_case_sensitive(self): class MYFUNC(GenericFunction): + inherit_cache = True type = DateTime assert isinstance(func.MYFUNC().type, DateTime) @@ -336,6 +342,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_custom_w_custom_name(self): class myfunc(GenericFunction): + inherit_cache = True name = "notmyfunc" assert isinstance(func.notmyfunc(), myfunc) @@ -343,6 +350,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_custom_w_quoted_name(self): class myfunc(GenericFunction): + inherit_cache = True name = quoted_name("NotMyFunc", quote=True) identifier = "myfunc" @@ -350,6 +358,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_custom_w_quoted_name_no_identifier(self): class myfunc(GenericFunction): + inherit_cache = True name = quoted_name("NotMyFunc", quote=True) # note this requires that the quoted name be lower cased for @@ -359,6 +368,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_custom_package_namespace(self): def cls1(pk_name): class myfunc(GenericFunction): + inherit_cache = True package = pk_name return myfunc @@ -372,6 +382,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_custom_name(self): class MyFunction(GenericFunction): name = "my_func" + inherit_cache = True def __init__(self, *args): args = args + (3,) @@ -387,20 +398,24 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): package = "geo" name = "BufferOne" identifier = "buf1" + inherit_cache = True class GeoBuffer2(GenericFunction): type = Integer name = "BufferTwo" identifier = "buf2" + inherit_cache = True class BufferThree(GenericFunction): type = Integer identifier = "buf3" + inherit_cache = True class GeoBufferFour(GenericFunction): type = Integer name = "BufferFour" identifier = "Buf4" + inherit_cache = True self.assert_compile(func.geo.buf1(), "BufferOne()") self.assert_compile(func.buf2(), "BufferTwo()") @@ -413,7 +428,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_custom_args(self): class myfunc(GenericFunction): - pass + inherit_cache = True self.assert_compile( myfunc(1, 2, 3), "myfunc(:myfunc_1, :myfunc_2, :myfunc_3)" @@ -1010,6 +1025,7 @@ class ExecuteTest(fixtures.TestBase): from sqlalchemy.ext.compiler import compiles class myfunc(FunctionElement): + inherit_cache = True type = Date() @compiles(myfunc) diff --git a/test/sql/test_labels.py b/test/sql/test_labels.py index 535d4dd0b..8c8e9dbed 100644 --- a/test/sql/test_labels.py +++ b/test/sql/test_labels.py @@ -805,6 +805,8 @@ class ColExprLabelTest(fixtures.TestBase, AssertsCompiledSQL): def _fixture(self): class SomeColThing(WrapsColumnExpression, ColumnElement): + inherit_cache = False + def __init__(self, expression): self.clause = coercions.expect( roles.ExpressionElementRole, expression diff --git a/test/sql/test_lambdas.py b/test/sql/test_lambdas.py index fd6b1eb41..bbf9716f5 100644 --- a/test/sql/test_lambdas.py +++ b/test/sql/test_lambdas.py @@ -17,6 +17,7 @@ from sqlalchemy.sql import roles from sqlalchemy.sql import select from sqlalchemy.sql import table from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql.base import ExecutableOption from sqlalchemy.sql.traversals import HasCacheKey from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL @@ -810,7 +811,10 @@ class LambdaElementTest( stmt = lambdas.lambda_stmt(lambda: select(column("x"))) - opts = {column("x"), column("y")} + class MyUncacheable(ExecutableOption): + pass + + opts = {MyUncacheable()} assert_raises_message( exc.InvalidRequestError, @@ -942,11 +946,18 @@ class LambdaElementTest( return stmt - s1 = go([column("a"), column("b")]) + class SomeOpt(HasCacheKey, ExecutableOption): + def __init__(self, x): + self.x = x + + def _gen_cache_key(self, anon_map, bindparams): + return (SomeOpt, self.x) + + s1 = go([SomeOpt("a"), SomeOpt("b")]) - s2 = go([column("a"), column("b")]) + s2 = go([SomeOpt("a"), SomeOpt("b")]) - s3 = go([column("q"), column("b")]) + s3 = go([SomeOpt("q"), SomeOpt("b")]) s1key = s1._generate_cache_key() s2key = s2._generate_cache_key() @@ -964,7 +975,7 @@ class LambdaElementTest( return stmt - class SomeOpt(HasCacheKey): + class SomeOpt(HasCacheKey, ExecutableOption): def _gen_cache_key(self, anon_map, bindparams): return ("fixed_key",) @@ -994,8 +1005,8 @@ class LambdaElementTest( return stmt - class SomeOpt(HasCacheKey): - pass + class SomeOpt(HasCacheKey, ExecutableOption): + inherit_cache = False # generates no key, will not be cached eq_(SomeOpt()._generate_cache_key(), None) diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index 831b75e7e..6e943d236 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -656,6 +656,8 @@ class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): def test_contains(self): class MyType(UserDefinedType): + cache_ok = True + class comparator_factory(UserDefinedType.Comparator): def contains(self, other, **kw): return self.op("->")(other) @@ -664,6 +666,8 @@ class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): def test_getitem(self): class MyType(UserDefinedType): + cache_ok = True + class comparator_factory(UserDefinedType.Comparator): def __getitem__(self, index): return self.op("->")(index) @@ -682,6 +686,8 @@ class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): def test_lshift(self): class MyType(UserDefinedType): + cache_ok = True + class comparator_factory(UserDefinedType.Comparator): def __lshift__(self, other): return self.op("->")(other) @@ -690,6 +696,8 @@ class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): def test_rshift(self): class MyType(UserDefinedType): + cache_ok = True + class comparator_factory(UserDefinedType.Comparator): def __rshift__(self, other): return self.op("->")(other) diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index 4aa932b47..e4f07a758 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -1797,6 +1797,7 @@ class KeyTargetingTest(fixtures.TablesTest): def test_keyed_targeting_no_label_at_all_one(self, connection): class not_named_max(expression.ColumnElement): name = "not_named_max" + inherit_cache = True @compiles(not_named_max) def visit_max(element, compiler, **kw): @@ -1814,6 +1815,7 @@ class KeyTargetingTest(fixtures.TablesTest): def test_keyed_targeting_no_label_at_all_two(self, connection): class not_named_max(expression.ColumnElement): name = "not_named_max" + inherit_cache = True @compiles(not_named_max) def visit_max(element, compiler, **kw): diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 35467de94..dc47cca46 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -1497,6 +1497,8 @@ class VariantTest(fixtures.TestBase, AssertsCompiledSQL): return process class UTypeThree(types.UserDefinedType): + cache_ok = True + def get_col_spec(self): return "UTYPETHREE" |
