diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2014-09-08 16:31:11 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2014-09-08 16:31:11 -0400 |
commit | 7904ebc62e0a75d1ea31e1a4ae67654c7681a737 (patch) | |
tree | a0e162ea74d3bb25390643b7db84bb288ca4e841 | |
parent | e4996d4f5432657639798c1b286ee811a36e2a10 (diff) | |
download | sqlalchemy-7904ebc62e0a75d1ea31e1a4ae67654c7681a737.tar.gz |
- rework the previous "order by" system in terms of the new one,
unify everything.
- create a new layer of separation between the "from order bys" and "column order bys",
so that an OVER doesn't ORDER BY a label in the same columns clause
- identify another issue with polymorphic for ref #3148, match on label
keys rather than the objects
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 76 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/elements.py | 38 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 10 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 4 | ||||
-rw-r--r-- | test/orm/test_query.py | 20 | ||||
-rw-r--r-- | test/sql/test_compiler.py | 21 |
6 files changed, 123 insertions, 46 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 72dd11eaf..5149fa4fe 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -503,7 +503,35 @@ class SQLCompiler(Compiled): def visit_grouping(self, grouping, asfrom=False, **kwargs): return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" - def visit_label_reference(self, element, **kwargs): + def visit_label_reference( + self, element, within_columns_clause=False, **kwargs): + if self.stack and self.dialect.supports_simple_order_by_label: + selectable = self.stack[-1]['selectable'] + + with_cols, only_froms = selectable._label_resolve_dict + if within_columns_clause: + resolve_dict = only_froms + else: + resolve_dict = with_cols + + # this can be None in the case that a _label_reference() + # were subject to a replacement operation, in which case + # the replacement of the Label element may have changed + # to something else like a ColumnClause expression. + order_by_elem = element.element._order_by_label_element + + if order_by_elem is not None and order_by_elem.name in \ + resolve_dict: + + kwargs['render_label_as_label'] = \ + element.element._order_by_label_element + + return self.process( + element.element, within_columns_clause=within_columns_clause, + **kwargs) + + def visit_textual_label_reference( + self, element, within_columns_clause=False, **kwargs): if not self.stack: # compiling the element outside of the context of a SELECT return self.process( @@ -511,19 +539,25 @@ class SQLCompiler(Compiled): ) selectable = self.stack[-1]['selectable'] + with_cols, only_froms = selectable._label_resolve_dict + try: - col = selectable._label_resolve_dict[element.text] + if within_columns_clause: + col = only_froms[element.element] + else: + col = with_cols[element.element] except KeyError: # treat it like text() util.warn_limited( "Can't resolve label reference %r; converting to text()", - util.ellipses_string(element.text)) + util.ellipses_string(element.element)) return self.process( element._text_clause ) else: kwargs['render_label_as_label'] = col - return self.process(col, **kwargs) + return self.process( + col, within_columns_clause=within_columns_clause, **kwargs) def visit_label(self, label, add_to_result_map=None, @@ -678,11 +712,7 @@ class SQLCompiler(Compiled): else: return "0" - def visit_clauselist(self, clauselist, order_by_select=None, **kw): - if order_by_select is not None: - return self._order_by_clauselist( - clauselist, order_by_select, **kw) - + def visit_clauselist(self, clauselist, **kw): sep = clauselist.operator if sep is None: sep = " " @@ -695,26 +725,6 @@ class SQLCompiler(Compiled): for c in clauselist.clauses) if s) - def _order_by_clauselist(self, clauselist, order_by_select, **kw): - # look through raw columns collection for labels. - # note that its OK we aren't expanding tables and other selectables - # here; we can only add a label in the ORDER BY for an individual - # label expression in the columns clause. - - raw_col = set(order_by_select._label_resolve_dict.keys()) - - return ", ".join( - s for s in - ( - c._compiler_dispatch( - self, - render_label_as_label=c._order_by_label_element if - c._order_by_label_element is not None and - c._order_by_label_element._label in raw_col - else None, - **kw) - for c in clauselist.clauses) - if s) def visit_case(self, clause, **kwargs): x = "CASE " @@ -1590,13 +1600,7 @@ class SQLCompiler(Compiled): text += " \nHAVING " + t if select._order_by_clause.clauses: - if self.dialect.supports_simple_order_by_label: - order_by_select = select - else: - order_by_select = None - - text += self.order_by_clause( - select, order_by_select=order_by_select, **kwargs) + text += self.order_by_clause(select, **kwargs) if (select._limit_clause is not None or select._offset_clause is not None): diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index cf8de936d..8ec0aa700 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -2356,14 +2356,39 @@ class Extract(ColumnElement): class _label_reference(ColumnElement): + """Wrap a column expression as it appears in a 'reference' context. + + This expression is any that inclues an _order_by_label_element, + which is a Label, or a DESC / ASC construct wrapping a Label. + + The production of _label_reference() should occur when an expression + is added to this context; this includes the ORDER BY or GROUP BY of a + SELECT statement, as well as a few other places, such as the ORDER BY + within an OVER clause. + + """ __visit_name__ = 'label_reference' - def __init__(self, text): - self.text = self.key = text + def __init__(self, element): + self.element = element + + def _copy_internals(self, clone=_clone, **kw): + self.element = clone(self.element, **kw) + + @property + def _from_objects(self): + return () + + +class _textual_label_reference(ColumnElement): + __visit_name__ = 'textual_label_reference' + + def __init__(self, element): + self.element = element @util.memoized_property def _text_clause(self): - return TextClause._create_text(self.text) + return TextClause._create_text(self.element) class UnaryExpression(ColumnElement): @@ -3556,6 +3581,13 @@ def _clause_element_as_expr(element): def _literal_as_label_reference(element): if isinstance(element, util.string_types): + return _textual_label_reference(element) + + elif hasattr(element, '__clause_element__'): + element = element.__clause_element__() + + if isinstance(element, ColumnElement) and \ + element._order_by_label_element is not None: return _label_reference(element) else: return _literal_as_text(element) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 57b16f45f..0f2926350 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1885,9 +1885,10 @@ class CompoundSelect(GenerativeSelect): @property def _label_resolve_dict(self): - return dict( + d = dict( (c.key, c) for c in self.c ) + return d, d @classmethod def _create_union(cls, *selects, **kwargs): @@ -2499,15 +2500,16 @@ class Select(HasPrefixes, GenerativeSelect): @_memoized_property def _label_resolve_dict(self): - d = dict( + with_cols = dict( (c._resolve_label or c._label or c.key, c) for c in _select_iterables(self._raw_columns) if c._allow_label_resolve) - d.update( + only_froms = dict( (c.key, c) for c in _select_iterables(self.froms) if c._allow_label_resolve) + with_cols.update(only_froms) - return d + return with_cols, only_froms def is_derived_from(self, fromclause): if self in fromclause._cloned_set: diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index d6f3b5915..fbbe15da3 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -16,7 +16,7 @@ from itertools import chain from collections import deque from .elements import BindParameter, ColumnClause, ColumnElement, \ - Null, UnaryExpression, literal_column, Label + Null, UnaryExpression, literal_column, Label, _label_reference from .selectable import ScalarSelect, Join, FromClause, FromGrouping from .schema import Column @@ -161,6 +161,8 @@ def unwrap_order_by(clause): not isinstance(t, UnaryExpression) or not operators.is_ordering_modifier(t.modifier) ): + if isinstance(t, _label_reference): + t = t.element cols.add(t) else: for c in t.get_children(): diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 3f6813138..c9f0a5db0 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -1236,7 +1236,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): __dialect__ = 'default' run_setup_mappers = 'each' - def _fixture(self, label=True): + def _fixture(self, label=True, polymorphic=False): User, Address = self.classes("User", "Address") users, addresses = self.tables("users", "addresses") stmt = select([func.max(addresses.c.email_address)]).\ @@ -1247,7 +1247,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): mapper(User, users, properties={ "ead": column_property(stmt) - }) + }, with_polymorphic="*" if polymorphic else None) mapper(Address, addresses) def test_order_by_column_prop_string(self): @@ -1355,6 +1355,22 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): "users AS users_1 ORDER BY email_ad, anon_1" ) + def test_order_by_column_labeled_prop_attr_aliased_four(self): + User = self.classes.User + self._fixture(label=True, polymorphic=True) + + ua = aliased(User) + s = Session() + q = s.query(ua, User.id).order_by(ua.ead) + self.assert_compile( + q, + "SELECT (SELECT max(addresses.email_address) AS max_1 FROM " + "addresses WHERE addresses.user_id = users_1.id) AS anon_1, " + "users_1.id AS users_1_id, users_1.name AS users_1_name, " + "users.id AS users_id FROM users AS users_1, users ORDER BY anon_1" + ) + + def test_order_by_column_unlabeled_prop_attr_aliased_one(self): User = self.classes.User self._fixture(label=False) diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 4f8ced72c..d47b58f1f 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -2169,6 +2169,27 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT x + foo() OVER () AS anon_1" ) + # test a reference to a label that in the referecned selectable; + # this resolves + expr = (table1.c.myid + 5).label('sum') + stmt = select([expr]).alias() + self.assert_compile( + select([stmt.c.sum, func.row_number().over(order_by=stmt.c.sum)]), + "SELECT anon_1.sum, row_number() OVER (ORDER BY anon_1.sum) " + "AS anon_2 FROM (SELECT mytable.myid + :myid_1 AS sum " + "FROM mytable) AS anon_1" + ) + + # test a reference to a label that's at the same level as the OVER + # in the columns clause; doesn't resolve + expr = (table1.c.myid + 5).label('sum') + self.assert_compile( + select([expr, func.row_number().over(order_by=expr)]), + "SELECT mytable.myid + :myid_1 AS sum, " + "row_number() OVER " + "(ORDER BY mytable.myid + :myid_1) AS anon_1 FROM mytable" + ) + def test_date_between(self): import datetime table = Table('dt', metadata, |