summaryrefslogtreecommitdiff
path: root/test/sql
diff options
context:
space:
mode:
Diffstat (limited to 'test/sql')
-rw-r--r--test/sql/test_compare.py79
-rw-r--r--test/sql/test_functions.py20
-rw-r--r--test/sql/test_labels.py2
-rw-r--r--test/sql/test_lambdas.py25
-rw-r--r--test/sql/test_operators.py8
-rw-r--r--test/sql/test_resultset.py2
-rw-r--r--test/sql/test_types.py2
7 files changed, 127 insertions, 11 deletions
diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py
index af78ea19b..ca1eff62b 100644
--- a/test/sql/test_compare.py
+++ b/test/sql/test_compare.py
@@ -16,6 +16,7 @@ from sqlalchemy import Integer
from sqlalchemy import literal_column
from sqlalchemy import MetaData
from sqlalchemy import or_
+from sqlalchemy import PickleType
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import Table
@@ -1264,13 +1265,20 @@ class CacheKeyTest(CacheKeyFixture, CoreFixtures, fixtures.TestBase):
# the None for cache key will prevent objects
# which contain these elements from being cached.
f1 = Foobar1()
- eq_(f1._generate_cache_key(), None)
+ with expect_warnings(
+ "Class Foobar1 will not make use of SQL compilation caching"
+ ):
+ eq_(f1._generate_cache_key(), None)
f2 = Foobar2()
- eq_(f2._generate_cache_key(), None)
+ with expect_warnings(
+ "Class Foobar2 will not make use of SQL compilation caching"
+ ):
+ eq_(f2._generate_cache_key(), None)
s1 = select(column("q"), Foobar2())
+ # warning is memoized, won't happen the second time
eq_(s1._generate_cache_key(), None)
def test_get_children_no_method(self):
@@ -1355,6 +1363,7 @@ class CompareAndCopyTest(CoreFixtures, fixtures.TestBase):
and (
"__init__" in cls.__dict__
or issubclass(cls, AliasedReturnsRows)
+ or "inherit_cache" not in cls.__dict__
)
and not issubclass(cls, (Annotated))
and "orm" not in cls.__module__
@@ -1819,3 +1828,69 @@ class TypesTest(fixtures.TestBase):
eq_(c1, c2)
ne_(c1, c3)
eq_(c1, c4)
+
+ def test_thirdparty_sub_subclass_no_cache(self):
+ class MyType(PickleType):
+ pass
+
+ expr = column("q", MyType()) == 1
+
+ with expect_warnings(
+ r"TypeDecorator MyType\(\) will not produce a cache key"
+ ):
+ is_(expr._generate_cache_key(), None)
+
+ def test_userdefined_sub_subclass_no_cache(self):
+ class MyType(UserDefinedType):
+ cache_ok = True
+
+ class MySubType(MyType):
+ pass
+
+ expr = column("q", MySubType()) == 1
+
+ with expect_warnings(
+ r"UserDefinedType MySubType\(\) will not produce a cache key"
+ ):
+ is_(expr._generate_cache_key(), None)
+
+ def test_userdefined_sub_subclass_cache_ok(self):
+ class MyType(UserDefinedType):
+ cache_ok = True
+
+ class MySubType(MyType):
+ cache_ok = True
+
+ def go1():
+ expr = column("q", MySubType()) == 1
+ return expr
+
+ def go2():
+ expr = column("p", MySubType()) == 1
+ return expr
+
+ c1 = go1()._generate_cache_key()[0]
+ c2 = go1()._generate_cache_key()[0]
+ c3 = go2()._generate_cache_key()[0]
+
+ eq_(c1, c2)
+ ne_(c1, c3)
+
+ def test_thirdparty_sub_subclass_cache_ok(self):
+ class MyType(PickleType):
+ cache_ok = True
+
+ def go1():
+ expr = column("q", MyType()) == 1
+ return expr
+
+ def go2():
+ expr = column("p", MyType()) == 1
+ return expr
+
+ c1 = go1()._generate_cache_key()[0]
+ c2 = go1()._generate_cache_key()[0]
+ c3 = go2()._generate_cache_key()[0]
+
+ eq_(c1, c2)
+ ne_(c1, c3)
diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py
index 9378cfc38..e08526419 100644
--- a/test/sql/test_functions.py
+++ b/test/sql/test_functions.py
@@ -81,6 +81,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
# test generic function compile
class fake_func(GenericFunction):
+ inherit_cache = True
__return_type__ = sqltypes.Integer
def __init__(self, arg, **kwargs):
@@ -107,6 +108,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
if use_custom:
class MyFunc(FunctionElement):
+ inherit_cache = True
name = "myfunc"
type = Integer()
@@ -135,6 +137,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
def test_use_labels_function_element(self):
class max_(FunctionElement):
name = "max"
+ inherit_cache = True
@compiles(max_)
def visit_max(element, compiler, **kw):
@@ -260,7 +263,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
def test_custom_default_namespace(self):
class myfunc(GenericFunction):
- pass
+ inherit_cache = True
assert isinstance(func.myfunc(), myfunc)
self.assert_compile(func.myfunc(), "myfunc()")
@@ -268,6 +271,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
def test_custom_type(self):
class myfunc(GenericFunction):
type = DateTime
+ inherit_cache = True
assert isinstance(func.myfunc().type, DateTime)
self.assert_compile(func.myfunc(), "myfunc()")
@@ -275,12 +279,14 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
def test_custom_legacy_type(self):
# in case someone was using this system
class myfunc(GenericFunction):
+ inherit_cache = True
__return_type__ = DateTime
assert isinstance(func.myfunc().type, DateTime)
def test_case_sensitive(self):
class MYFUNC(GenericFunction):
+ inherit_cache = True
type = DateTime
assert isinstance(func.MYFUNC().type, DateTime)
@@ -336,6 +342,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
def test_custom_w_custom_name(self):
class myfunc(GenericFunction):
+ inherit_cache = True
name = "notmyfunc"
assert isinstance(func.notmyfunc(), myfunc)
@@ -343,6 +350,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
def test_custom_w_quoted_name(self):
class myfunc(GenericFunction):
+ inherit_cache = True
name = quoted_name("NotMyFunc", quote=True)
identifier = "myfunc"
@@ -350,6 +358,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
def test_custom_w_quoted_name_no_identifier(self):
class myfunc(GenericFunction):
+ inherit_cache = True
name = quoted_name("NotMyFunc", quote=True)
# note this requires that the quoted name be lower cased for
@@ -359,6 +368,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
def test_custom_package_namespace(self):
def cls1(pk_name):
class myfunc(GenericFunction):
+ inherit_cache = True
package = pk_name
return myfunc
@@ -372,6 +382,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
def test_custom_name(self):
class MyFunction(GenericFunction):
name = "my_func"
+ inherit_cache = True
def __init__(self, *args):
args = args + (3,)
@@ -387,20 +398,24 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
package = "geo"
name = "BufferOne"
identifier = "buf1"
+ inherit_cache = True
class GeoBuffer2(GenericFunction):
type = Integer
name = "BufferTwo"
identifier = "buf2"
+ inherit_cache = True
class BufferThree(GenericFunction):
type = Integer
identifier = "buf3"
+ inherit_cache = True
class GeoBufferFour(GenericFunction):
type = Integer
name = "BufferFour"
identifier = "Buf4"
+ inherit_cache = True
self.assert_compile(func.geo.buf1(), "BufferOne()")
self.assert_compile(func.buf2(), "BufferTwo()")
@@ -413,7 +428,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
def test_custom_args(self):
class myfunc(GenericFunction):
- pass
+ inherit_cache = True
self.assert_compile(
myfunc(1, 2, 3), "myfunc(:myfunc_1, :myfunc_2, :myfunc_3)"
@@ -1010,6 +1025,7 @@ class ExecuteTest(fixtures.TestBase):
from sqlalchemy.ext.compiler import compiles
class myfunc(FunctionElement):
+ inherit_cache = True
type = Date()
@compiles(myfunc)
diff --git a/test/sql/test_labels.py b/test/sql/test_labels.py
index 535d4dd0b..8c8e9dbed 100644
--- a/test/sql/test_labels.py
+++ b/test/sql/test_labels.py
@@ -805,6 +805,8 @@ class ColExprLabelTest(fixtures.TestBase, AssertsCompiledSQL):
def _fixture(self):
class SomeColThing(WrapsColumnExpression, ColumnElement):
+ inherit_cache = False
+
def __init__(self, expression):
self.clause = coercions.expect(
roles.ExpressionElementRole, expression
diff --git a/test/sql/test_lambdas.py b/test/sql/test_lambdas.py
index fd6b1eb41..bbf9716f5 100644
--- a/test/sql/test_lambdas.py
+++ b/test/sql/test_lambdas.py
@@ -17,6 +17,7 @@ from sqlalchemy.sql import roles
from sqlalchemy.sql import select
from sqlalchemy.sql import table
from sqlalchemy.sql import util as sql_util
+from sqlalchemy.sql.base import ExecutableOption
from sqlalchemy.sql.traversals import HasCacheKey
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import AssertsCompiledSQL
@@ -810,7 +811,10 @@ class LambdaElementTest(
stmt = lambdas.lambda_stmt(lambda: select(column("x")))
- opts = {column("x"), column("y")}
+ class MyUncacheable(ExecutableOption):
+ pass
+
+ opts = {MyUncacheable()}
assert_raises_message(
exc.InvalidRequestError,
@@ -942,11 +946,18 @@ class LambdaElementTest(
return stmt
- s1 = go([column("a"), column("b")])
+ class SomeOpt(HasCacheKey, ExecutableOption):
+ def __init__(self, x):
+ self.x = x
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ return (SomeOpt, self.x)
+
+ s1 = go([SomeOpt("a"), SomeOpt("b")])
- s2 = go([column("a"), column("b")])
+ s2 = go([SomeOpt("a"), SomeOpt("b")])
- s3 = go([column("q"), column("b")])
+ s3 = go([SomeOpt("q"), SomeOpt("b")])
s1key = s1._generate_cache_key()
s2key = s2._generate_cache_key()
@@ -964,7 +975,7 @@ class LambdaElementTest(
return stmt
- class SomeOpt(HasCacheKey):
+ class SomeOpt(HasCacheKey, ExecutableOption):
def _gen_cache_key(self, anon_map, bindparams):
return ("fixed_key",)
@@ -994,8 +1005,8 @@ class LambdaElementTest(
return stmt
- class SomeOpt(HasCacheKey):
- pass
+ class SomeOpt(HasCacheKey, ExecutableOption):
+ inherit_cache = False
# generates no key, will not be cached
eq_(SomeOpt()._generate_cache_key(), None)
diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py
index 831b75e7e..6e943d236 100644
--- a/test/sql/test_operators.py
+++ b/test/sql/test_operators.py
@@ -656,6 +656,8 @@ class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
def test_contains(self):
class MyType(UserDefinedType):
+ cache_ok = True
+
class comparator_factory(UserDefinedType.Comparator):
def contains(self, other, **kw):
return self.op("->")(other)
@@ -664,6 +666,8 @@ class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
def test_getitem(self):
class MyType(UserDefinedType):
+ cache_ok = True
+
class comparator_factory(UserDefinedType.Comparator):
def __getitem__(self, index):
return self.op("->")(index)
@@ -682,6 +686,8 @@ class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
def test_lshift(self):
class MyType(UserDefinedType):
+ cache_ok = True
+
class comparator_factory(UserDefinedType.Comparator):
def __lshift__(self, other):
return self.op("->")(other)
@@ -690,6 +696,8 @@ class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
def test_rshift(self):
class MyType(UserDefinedType):
+ cache_ok = True
+
class comparator_factory(UserDefinedType.Comparator):
def __rshift__(self, other):
return self.op("->")(other)
diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py
index 4aa932b47..e4f07a758 100644
--- a/test/sql/test_resultset.py
+++ b/test/sql/test_resultset.py
@@ -1797,6 +1797,7 @@ class KeyTargetingTest(fixtures.TablesTest):
def test_keyed_targeting_no_label_at_all_one(self, connection):
class not_named_max(expression.ColumnElement):
name = "not_named_max"
+ inherit_cache = True
@compiles(not_named_max)
def visit_max(element, compiler, **kw):
@@ -1814,6 +1815,7 @@ class KeyTargetingTest(fixtures.TablesTest):
def test_keyed_targeting_no_label_at_all_two(self, connection):
class not_named_max(expression.ColumnElement):
name = "not_named_max"
+ inherit_cache = True
@compiles(not_named_max)
def visit_max(element, compiler, **kw):
diff --git a/test/sql/test_types.py b/test/sql/test_types.py
index 35467de94..dc47cca46 100644
--- a/test/sql/test_types.py
+++ b/test/sql/test_types.py
@@ -1497,6 +1497,8 @@ class VariantTest(fixtures.TestBase, AssertsCompiledSQL):
return process
class UTypeThree(types.UserDefinedType):
+ cache_ok = True
+
def get_col_spec(self):
return "UTYPETHREE"