summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2020-04-01 18:31:16 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2020-04-01 19:25:23 -0400
commit49b6c50016c8a038a6df7104560bb3945debe064 (patch)
tree9b5b6b9ad6a6aba5374768afd52783fd8c2170f3
parenta9b62055bfa61c11e9fe0b2984437e2c3e32bf0e (diff)
downloadsqlalchemy-49b6c50016c8a038a6df7104560bb3945debe064.tar.gz
Repair caching / traversals for values
The test suite wasn't running the copy_internals most fixtures, enable that and try to get all cases working. Set up selectable.values to do tuple conversion within compilation step. at the same time, disable caching for selectable.values for the moment and make it equivalent to dml_multi_values. fix cache / compare / copy cases for dml_values and dml_multi_values which weren't fully tested or covered. Change-Id: I484ca6e9cb2b66c2e6a321698f2abc0838db1460
-rw-r--r--lib/sqlalchemy/sql/compiler.py9
-rw-r--r--lib/sqlalchemy/sql/selectable.py24
-rw-r--r--lib/sqlalchemy/sql/traversals.py77
-rw-r--r--test/sql/test_compare.py189
4 files changed, 165 insertions, 134 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 799fca2f5..b93ed8890 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -2324,9 +2324,14 @@ class SQLCompiler(Compiled):
return text
def visit_values(self, element, asfrom=False, from_linter=None, **kw):
+
v = "VALUES %s" % ", ".join(
- self.process(elem, literal_binds=element.literal_binds)
- for elem in element._data
+ self.process(
+ elements.Tuple(*elem).self_group(),
+ literal_binds=element.literal_binds,
+ )
+ for chunk in element._data
+ for elem in chunk
)
if isinstance(element.name, elements._truncated_label):
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index e39d61fdb..a0df45b52 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -47,7 +47,6 @@ from .elements import ColumnClause
from .elements import GroupedElement
from .elements import Grouping
from .elements import literal_column
-from .elements import Tuple
from .elements import UnaryExpression
from .visitors import InternalTraversal
from .. import exc
@@ -1264,14 +1263,16 @@ class AliasedReturnsRows(NoInit, FromClause):
self.element._generate_fromclause_column_proxies(self)
def _copy_internals(self, clone=_clone, **kw):
- element = clone(self.element, **kw)
+ existing_element = self.element
+
+ super(AliasedReturnsRows, self)._copy_internals(clone=clone, **kw)
# the element clone is usually against a Table that returns the
# same object. don't reset exported .c. collections and other
- # memoized details if nothing changed
- if element is not self.element:
+ # memoized details if it was not changed. this saves a lot on
+ # performance.
+ if existing_element is not self.element:
self._reset_column_collection()
- self.element = element
@property
def _from_objects(self):
@@ -1528,15 +1529,6 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows):
self._suffixes = _suffixes
super(CTE, self)._init(selectable, name=name)
- def _copy_internals(self, clone=_clone, **kw):
- super(CTE, self)._copy_internals(clone, **kw)
- # TODO: I don't like that we can't use the traversal data here
- if self._cte_alias is not None:
- self._cte_alias = clone(self._cte_alias, **kw)
- self._restates = frozenset(
- [clone(elem, **kw) for elem in self._restates]
- )
-
def alias(self, name=None, flat=False):
"""Return an :class:`.Alias` of this :class:`.CTE`.
@@ -2064,7 +2056,7 @@ class Values(Generative, FromClause):
_traverse_internals = [
("_column_args", InternalTraversal.dp_clauseelement_list,),
- ("_data", InternalTraversal.dp_clauseelement_list),
+ ("_data", InternalTraversal.dp_dml_multi_values),
("name", InternalTraversal.dp_string),
("literal_binds", InternalTraversal.dp_boolean),
]
@@ -2155,7 +2147,7 @@ class Values(Generative, FromClause):
"""
- self._data += tuple(Tuple(*row).self_group() for row in values)
+ self._data += (values,)
def _populate_column_collection(self):
for c in self._column_args:
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 9ac6cda97..032488826 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -7,6 +7,7 @@ from .visitors import ExtendedInternalTraversal
from .visitors import InternalTraversal
from .. import util
from ..inspection import inspect
+from ..util import collections_abc
from ..util import HasMemoized
SKIP_TRAVERSE = util.symbol("skip_traverse")
@@ -533,18 +534,12 @@ class _CopyInternals(InternalTraversal):
]
def visit_dml_values(self, parent, element, clone=_clone, **kw):
- # sequence of dictionaries
- return [
- {
- (
- clone(key, **kw)
- if hasattr(key, "__clause_element__")
- else key
- ): clone(value, **kw)
- for key, value in sub_element.items()
- }
- for sub_element in element
- ]
+ return {
+ (
+ clone(key, **kw) if hasattr(key, "__clause_element__") else key
+ ): clone(value, **kw)
+ for key, value in element.items()
+ }
def visit_dml_multi_values(self, parent, element, clone=_clone, **kw):
# sequence of sequences, each sequence contains a list/dict/tuple
@@ -552,15 +547,10 @@ class _CopyInternals(InternalTraversal):
def copy(elem):
if isinstance(elem, (list, tuple)):
return [
- (
- clone(key, **kw)
- if hasattr(key, "__clause_element__")
- else key,
- clone(value, **kw)
- if hasattr(value, "__clause_element__")
- else value,
- )
- for key, value in elem
+ clone(value, **kw)
+ if hasattr(value, "__clause_element__")
+ else value
+ for value in elem
]
elif isinstance(elem, dict):
return {
@@ -573,7 +563,7 @@ class _CopyInternals(InternalTraversal):
if hasattr(value, "__clause_element__")
else value
)
- for key, value in elem
+ for key, value in elem.items()
}
else:
# TODO: use abc classes
@@ -939,30 +929,41 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
for (lk, lv), (rk, rv) in util.zip_longest(
left, right, fillvalue=(None, None)
):
- lkce = hasattr(lk, "__clause_element__")
- rkce = hasattr(rk, "__clause_element__")
- if lkce != rkce:
- return COMPARE_FAILED
- elif lkce and not self.compare_inner(lk, rk, **kw):
- return COMPARE_FAILED
- elif not lkce and lk != rk:
- return COMPARE_FAILED
- elif not self.compare_inner(lv, rv, **kw):
+ if not self._compare_dml_values_or_ce(lk, rk, **kw):
return COMPARE_FAILED
+ def _compare_dml_values_or_ce(self, lv, rv, **kw):
+ lvce = hasattr(lv, "__clause_element__")
+ rvce = hasattr(rv, "__clause_element__")
+ if lvce != rvce:
+ return False
+ elif lvce and not self.compare_inner(lv, rv, **kw):
+ return False
+ elif not lvce and lv != rv:
+ return False
+ elif not self.compare_inner(lv, rv, **kw):
+ return False
+
+ return True
+
def visit_dml_values(self, left_parent, left, right_parent, right, **kw):
if left is None or right is None or len(left) != len(right):
return COMPARE_FAILED
- for lk in left:
- lv = left[lk]
+ if isinstance(left, collections_abc.Sequence):
+ for lv, rv in zip(left, right):
+ if not self._compare_dml_values_or_ce(lv, rv, **kw):
+ return COMPARE_FAILED
+ else:
+ for lk in left:
+ lv = left[lk]
- if lk not in right:
- return COMPARE_FAILED
- rv = right[lk]
+ if lk not in right:
+ return COMPARE_FAILED
+ rv = right[lk]
- if not self.compare_inner(lv, rv, **kw):
- return COMPARE_FAILED
+ if not self._compare_dml_values_or_ce(lv, rv, **kw):
+ return COMPARE_FAILED
def visit_dml_multi_values(
self, left_parent, left, right_parent, right, **kw
diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py
index 2800f8248..3a6feac01 100644
--- a/test/sql/test_compare.py
+++ b/test/sql/test_compare.py
@@ -55,11 +55,13 @@ from sqlalchemy.sql.functions import ReturnTypeFromArgs
from sqlalchemy.sql.selectable import _OffsetLimitParam
from sqlalchemy.sql.selectable import AliasedReturnsRows
from sqlalchemy.sql.selectable import FromGrouping
+from sqlalchemy.sql.selectable import Select
from sqlalchemy.sql.selectable import Selectable
from sqlalchemy.sql.selectable import SelectStatementGrouping
from sqlalchemy.sql.visitors import InternalTraversal
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_
from sqlalchemy.testing import is_false
from sqlalchemy.testing import is_not_
from sqlalchemy.testing import is_true
@@ -372,6 +374,15 @@ class CoreFixtures(object):
table_b.insert()
.values([{"a": 5, "b": 10}, {"a": 8, "b": 12}])
._annotate({"nocache": True}),
+ table_b.insert()
+ .values([{"a": 9, "b": 10}, {"a": 8, "b": 7}])
+ ._annotate({"nocache": True}),
+ table_b.insert()
+ .values([(5, 10), (8, 12)])
+ ._annotate({"nocache": True}),
+ table_b.insert()
+ .values([(5, 9), (5, 12)])
+ ._annotate({"nocache": True}),
),
lambda: (
table_b.update(),
@@ -405,6 +416,51 @@ class CoreFixtures(object):
table_b.delete().where(table_b.c.b == 5),
),
lambda: (
+ values(
+ column("mykey", Integer),
+ column("mytext", String),
+ column("myint", Integer),
+ name="myvalues",
+ )
+ .data([(1, "textA", 99), (2, "textB", 88)])
+ ._annotate({"nocache": True}),
+ values(
+ column("mykey", Integer),
+ column("mytext", String),
+ column("myint", Integer),
+ name="myothervalues",
+ )
+ .data([(1, "textA", 99), (2, "textB", 88)])
+ ._annotate({"nocache": True}),
+ values(
+ column("mykey", Integer),
+ column("mytext", String),
+ column("myint", Integer),
+ name="myvalues",
+ )
+ .data([(1, "textA", 89), (2, "textG", 88)])
+ ._annotate({"nocache": True}),
+ values(
+ column("mykey", Integer),
+ column("mynottext", String),
+ column("myint", Integer),
+ name="myvalues",
+ )
+ .data([(1, "textA", 99), (2, "textB", 88)])
+ ._annotate({"nocache": True}),
+ # TODO: difference in type
+ # values(
+ # [
+ # column("mykey", Integer),
+ # column("mytext", Text),
+ # column("myint", Integer),
+ # ],
+ # (1, "textA", 99),
+ # (2, "textB", 88),
+ # alias_name="myvalues",
+ # ),
+ ),
+ lambda: (
select([table_a.c.a]),
select([table_a.c.a]).prefix_with("foo"),
select([table_a.c.a]).prefix_with("foo", dialect="mysql"),
@@ -482,43 +538,6 @@ class CoreFixtures(object):
table("a", column("q"), column("y", Integer)),
),
lambda: (table_a, table_b),
- lambda: (
- values(
- column("mykey", Integer),
- column("mytext", String),
- column("myint", Integer),
- name="myvalues",
- ).data([(1, "textA", 99), (2, "textB", 88)]),
- values(
- column("mykey", Integer),
- column("mytext", String),
- column("myint", Integer),
- name="myothervalues",
- ).data([(1, "textA", 99), (2, "textB", 88)]),
- values(
- column("mykey", Integer),
- column("mytext", String),
- column("myint", Integer),
- name="myvalues",
- ).data([(1, "textA", 89), (2, "textG", 88)]),
- values(
- column("mykey", Integer),
- column("mynottext", String),
- column("myint", Integer),
- name="myvalues",
- ).data([(1, "textA", 99), (2, "textB", 88)]),
- # TODO: difference in type
- # values(
- # [
- # column("mykey", Integer),
- # column("mytext", Text),
- # column("myint", Integer),
- # ],
- # (1, "textA", 99),
- # (2, "textB", 88),
- # alias_name="myvalues",
- # ),
- ),
]
dont_compare_values_fixtures = [
@@ -697,10 +716,36 @@ class CacheKeyTest(CacheKeyFixture, CoreFixtures, fixtures.TestBase):
index_elements=[table_a.c.a], set_={"name": "foo"}
),
mysql.insert(table_a).on_duplicate_key_update(updated_once=None),
+ table_a.insert().values( # multivalues doesn't cache
+ [
+ {"name": "some name"},
+ {"name": "some other name"},
+ {"name": "yet another name"},
+ ]
+ ),
)
def test_dml_not_cached_yet(self, dml_stmt):
eq_(dml_stmt._generate_cache_key(), None)
+ def test_values_doesnt_caches_right_now(self):
+ v1 = values(
+ column("mykey", Integer),
+ column("mytext", String),
+ column("myint", Integer),
+ name="myvalues",
+ ).data([(1, "textA", 99), (2, "textB", 88)])
+
+ is_(v1._generate_cache_key(), None)
+
+ large_v1 = values(
+ column("mykey", Integer),
+ column("mytext", String),
+ column("myint", Integer),
+ name="myvalues",
+ ).data([(i, "data %s" % i, i * 5) for i in range(500)])
+
+ is_(large_v1._generate_cache_key(), None)
+
def test_cache_key(self):
for fixtures_, compare_values in [
(self.fixtures, True),
@@ -912,50 +957,38 @@ class CompareAndCopyTest(CoreFixtures, fixtures.TestBase):
case_a = fixture()
case_b = fixture()
- assert case_a[0].compare(
- case_b[0], compare_values=compare_values
- )
+ for idx in range(len(case_a)):
+ assert case_a[idx].compare(
+ case_b[idx], compare_values=compare_values
+ )
- clone = visitors.replacement_traverse(
- case_a[0], {}, lambda elem: None
- )
+ clone = visitors.replacement_traverse(
+ case_a[idx], {}, lambda elem: None
+ )
- assert clone.compare(case_b[0], compare_values=compare_values)
-
- stack = [clone]
- seen = {clone}
- found_elements = False
- while stack:
- obj = stack.pop(0)
-
- items = [
- subelem
- for key, elem in clone.__dict__.items()
- if key != "_is_clone_of" and elem is not None
- for subelem in util.to_list(elem)
- if (
- isinstance(subelem, (ColumnElement, ClauseList))
- and subelem not in seen
- and not isinstance(subelem, Immutable)
- and subelem is not case_a[0]
- )
- ]
- stack.extend(items)
- seen.update(items)
-
- if obj is not clone:
- found_elements = True
- # ensure the element will not compare as true
- obj.compare = lambda other, **kw: False
- obj.__visit_name__ = "dont_match"
-
- if found_elements:
- assert not clone.compare(
- case_b[0], compare_values=compare_values
+ assert clone.compare(
+ case_b[idx], compare_values=compare_values
)
- assert case_a[0].compare(
- case_b[0], compare_values=compare_values
- )
+
+ assert case_a[idx].compare(
+ case_b[idx], compare_values=compare_values
+ )
+
+ # copy internals of Select is very different than other
+ # elements and additionally this is extremely well tested
+ # in test_selectable and test_external_traversal, so
+ # skip these
+ if isinstance(case_a[idx], Select):
+ continue
+
+ for elema, elemb in zip(
+ visitors.iterate(case_a[idx], {}),
+ visitors.iterate(clone, {}),
+ ):
+ if isinstance(elema, ClauseElement) and not isinstance(
+ elema, Immutable
+ ):
+ assert elema is not elemb
class CompareClausesTest(fixtures.TestBase):