diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 24 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 77 |
3 files changed, 54 insertions, 56 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 799fca2f5..b93ed8890 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -2324,9 +2324,14 @@ class SQLCompiler(Compiled): return text def visit_values(self, element, asfrom=False, from_linter=None, **kw): + v = "VALUES %s" % ", ".join( - self.process(elem, literal_binds=element.literal_binds) - for elem in element._data + self.process( + elements.Tuple(*elem).self_group(), + literal_binds=element.literal_binds, + ) + for chunk in element._data + for elem in chunk ) if isinstance(element.name, elements._truncated_label): diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index e39d61fdb..a0df45b52 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -47,7 +47,6 @@ from .elements import ColumnClause from .elements import GroupedElement from .elements import Grouping from .elements import literal_column -from .elements import Tuple from .elements import UnaryExpression from .visitors import InternalTraversal from .. import exc @@ -1264,14 +1263,16 @@ class AliasedReturnsRows(NoInit, FromClause): self.element._generate_fromclause_column_proxies(self) def _copy_internals(self, clone=_clone, **kw): - element = clone(self.element, **kw) + existing_element = self.element + + super(AliasedReturnsRows, self)._copy_internals(clone=clone, **kw) # the element clone is usually against a Table that returns the # same object. don't reset exported .c. collections and other - # memoized details if nothing changed - if element is not self.element: + # memoized details if it was not changed. this saves a lot on + # performance. + if existing_element is not self.element: self._reset_column_collection() - self.element = element @property def _from_objects(self): @@ -1528,15 +1529,6 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows): self._suffixes = _suffixes super(CTE, self)._init(selectable, name=name) - def _copy_internals(self, clone=_clone, **kw): - super(CTE, self)._copy_internals(clone, **kw) - # TODO: I don't like that we can't use the traversal data here - if self._cte_alias is not None: - self._cte_alias = clone(self._cte_alias, **kw) - self._restates = frozenset( - [clone(elem, **kw) for elem in self._restates] - ) - def alias(self, name=None, flat=False): """Return an :class:`.Alias` of this :class:`.CTE`. @@ -2064,7 +2056,7 @@ class Values(Generative, FromClause): _traverse_internals = [ ("_column_args", InternalTraversal.dp_clauseelement_list,), - ("_data", InternalTraversal.dp_clauseelement_list), + ("_data", InternalTraversal.dp_dml_multi_values), ("name", InternalTraversal.dp_string), ("literal_binds", InternalTraversal.dp_boolean), ] @@ -2155,7 +2147,7 @@ class Values(Generative, FromClause): """ - self._data += tuple(Tuple(*row).self_group() for row in values) + self._data += (values,) def _populate_column_collection(self): for c in self._column_args: diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 9ac6cda97..032488826 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -7,6 +7,7 @@ from .visitors import ExtendedInternalTraversal from .visitors import InternalTraversal from .. import util from ..inspection import inspect +from ..util import collections_abc from ..util import HasMemoized SKIP_TRAVERSE = util.symbol("skip_traverse") @@ -533,18 +534,12 @@ class _CopyInternals(InternalTraversal): ] def visit_dml_values(self, parent, element, clone=_clone, **kw): - # sequence of dictionaries - return [ - { - ( - clone(key, **kw) - if hasattr(key, "__clause_element__") - else key - ): clone(value, **kw) - for key, value in sub_element.items() - } - for sub_element in element - ] + return { + ( + clone(key, **kw) if hasattr(key, "__clause_element__") else key + ): clone(value, **kw) + for key, value in element.items() + } def visit_dml_multi_values(self, parent, element, clone=_clone, **kw): # sequence of sequences, each sequence contains a list/dict/tuple @@ -552,15 +547,10 @@ class _CopyInternals(InternalTraversal): def copy(elem): if isinstance(elem, (list, tuple)): return [ - ( - clone(key, **kw) - if hasattr(key, "__clause_element__") - else key, - clone(value, **kw) - if hasattr(value, "__clause_element__") - else value, - ) - for key, value in elem + clone(value, **kw) + if hasattr(value, "__clause_element__") + else value + for value in elem ] elif isinstance(elem, dict): return { @@ -573,7 +563,7 @@ class _CopyInternals(InternalTraversal): if hasattr(value, "__clause_element__") else value ) - for key, value in elem + for key, value in elem.items() } else: # TODO: use abc classes @@ -939,30 +929,41 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): for (lk, lv), (rk, rv) in util.zip_longest( left, right, fillvalue=(None, None) ): - lkce = hasattr(lk, "__clause_element__") - rkce = hasattr(rk, "__clause_element__") - if lkce != rkce: - return COMPARE_FAILED - elif lkce and not self.compare_inner(lk, rk, **kw): - return COMPARE_FAILED - elif not lkce and lk != rk: - return COMPARE_FAILED - elif not self.compare_inner(lv, rv, **kw): + if not self._compare_dml_values_or_ce(lk, rk, **kw): return COMPARE_FAILED + def _compare_dml_values_or_ce(self, lv, rv, **kw): + lvce = hasattr(lv, "__clause_element__") + rvce = hasattr(rv, "__clause_element__") + if lvce != rvce: + return False + elif lvce and not self.compare_inner(lv, rv, **kw): + return False + elif not lvce and lv != rv: + return False + elif not self.compare_inner(lv, rv, **kw): + return False + + return True + def visit_dml_values(self, left_parent, left, right_parent, right, **kw): if left is None or right is None or len(left) != len(right): return COMPARE_FAILED - for lk in left: - lv = left[lk] + if isinstance(left, collections_abc.Sequence): + for lv, rv in zip(left, right): + if not self._compare_dml_values_or_ce(lv, rv, **kw): + return COMPARE_FAILED + else: + for lk in left: + lv = left[lk] - if lk not in right: - return COMPARE_FAILED - rv = right[lk] + if lk not in right: + return COMPARE_FAILED + rv = right[lk] - if not self.compare_inner(lv, rv, **kw): - return COMPARE_FAILED + if not self._compare_dml_values_or_ce(lv, rv, **kw): + return COMPARE_FAILED def visit_dml_multi_values( self, left_parent, left, right_parent, right, **kw |
