summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2014-09-08 16:31:11 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2014-09-08 16:31:11 -0400
commit7904ebc62e0a75d1ea31e1a4ae67654c7681a737 (patch)
treea0e162ea74d3bb25390643b7db84bb288ca4e841
parente4996d4f5432657639798c1b286ee811a36e2a10 (diff)
downloadsqlalchemy-7904ebc62e0a75d1ea31e1a4ae67654c7681a737.tar.gz
- rework the previous "order by" system in terms of the new one,
unify everything. - create a new layer of separation between the "from order bys" and "column order bys", so that an OVER doesn't ORDER BY a label in the same columns clause - identify another issue with polymorphic for ref #3148, match on label keys rather than the objects
-rw-r--r--lib/sqlalchemy/sql/compiler.py76
-rw-r--r--lib/sqlalchemy/sql/elements.py38
-rw-r--r--lib/sqlalchemy/sql/selectable.py10
-rw-r--r--lib/sqlalchemy/sql/util.py4
-rw-r--r--test/orm/test_query.py20
-rw-r--r--test/sql/test_compiler.py21
6 files changed, 123 insertions, 46 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 72dd11eaf..5149fa4fe 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -503,7 +503,35 @@ class SQLCompiler(Compiled):
def visit_grouping(self, grouping, asfrom=False, **kwargs):
return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
- def visit_label_reference(self, element, **kwargs):
+ def visit_label_reference(
+ self, element, within_columns_clause=False, **kwargs):
+ if self.stack and self.dialect.supports_simple_order_by_label:
+ selectable = self.stack[-1]['selectable']
+
+ with_cols, only_froms = selectable._label_resolve_dict
+ if within_columns_clause:
+ resolve_dict = only_froms
+ else:
+ resolve_dict = with_cols
+
+ # this can be None in the case that a _label_reference()
+ # were subject to a replacement operation, in which case
+ # the replacement of the Label element may have changed
+ # to something else like a ColumnClause expression.
+ order_by_elem = element.element._order_by_label_element
+
+ if order_by_elem is not None and order_by_elem.name in \
+ resolve_dict:
+
+ kwargs['render_label_as_label'] = \
+ element.element._order_by_label_element
+
+ return self.process(
+ element.element, within_columns_clause=within_columns_clause,
+ **kwargs)
+
+ def visit_textual_label_reference(
+ self, element, within_columns_clause=False, **kwargs):
if not self.stack:
# compiling the element outside of the context of a SELECT
return self.process(
@@ -511,19 +539,25 @@ class SQLCompiler(Compiled):
)
selectable = self.stack[-1]['selectable']
+ with_cols, only_froms = selectable._label_resolve_dict
+
try:
- col = selectable._label_resolve_dict[element.text]
+ if within_columns_clause:
+ col = only_froms[element.element]
+ else:
+ col = with_cols[element.element]
except KeyError:
# treat it like text()
util.warn_limited(
"Can't resolve label reference %r; converting to text()",
- util.ellipses_string(element.text))
+ util.ellipses_string(element.element))
return self.process(
element._text_clause
)
else:
kwargs['render_label_as_label'] = col
- return self.process(col, **kwargs)
+ return self.process(
+ col, within_columns_clause=within_columns_clause, **kwargs)
def visit_label(self, label,
add_to_result_map=None,
@@ -678,11 +712,7 @@ class SQLCompiler(Compiled):
else:
return "0"
- def visit_clauselist(self, clauselist, order_by_select=None, **kw):
- if order_by_select is not None:
- return self._order_by_clauselist(
- clauselist, order_by_select, **kw)
-
+ def visit_clauselist(self, clauselist, **kw):
sep = clauselist.operator
if sep is None:
sep = " "
@@ -695,26 +725,6 @@ class SQLCompiler(Compiled):
for c in clauselist.clauses)
if s)
- def _order_by_clauselist(self, clauselist, order_by_select, **kw):
- # look through raw columns collection for labels.
- # note that its OK we aren't expanding tables and other selectables
- # here; we can only add a label in the ORDER BY for an individual
- # label expression in the columns clause.
-
- raw_col = set(order_by_select._label_resolve_dict.keys())
-
- return ", ".join(
- s for s in
- (
- c._compiler_dispatch(
- self,
- render_label_as_label=c._order_by_label_element if
- c._order_by_label_element is not None and
- c._order_by_label_element._label in raw_col
- else None,
- **kw)
- for c in clauselist.clauses)
- if s)
def visit_case(self, clause, **kwargs):
x = "CASE "
@@ -1590,13 +1600,7 @@ class SQLCompiler(Compiled):
text += " \nHAVING " + t
if select._order_by_clause.clauses:
- if self.dialect.supports_simple_order_by_label:
- order_by_select = select
- else:
- order_by_select = None
-
- text += self.order_by_clause(
- select, order_by_select=order_by_select, **kwargs)
+ text += self.order_by_clause(select, **kwargs)
if (select._limit_clause is not None or
select._offset_clause is not None):
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index cf8de936d..8ec0aa700 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -2356,14 +2356,39 @@ class Extract(ColumnElement):
class _label_reference(ColumnElement):
+ """Wrap a column expression as it appears in a 'reference' context.
+
+ This expression is any that inclues an _order_by_label_element,
+ which is a Label, or a DESC / ASC construct wrapping a Label.
+
+ The production of _label_reference() should occur when an expression
+ is added to this context; this includes the ORDER BY or GROUP BY of a
+ SELECT statement, as well as a few other places, such as the ORDER BY
+ within an OVER clause.
+
+ """
__visit_name__ = 'label_reference'
- def __init__(self, text):
- self.text = self.key = text
+ def __init__(self, element):
+ self.element = element
+
+ def _copy_internals(self, clone=_clone, **kw):
+ self.element = clone(self.element, **kw)
+
+ @property
+ def _from_objects(self):
+ return ()
+
+
+class _textual_label_reference(ColumnElement):
+ __visit_name__ = 'textual_label_reference'
+
+ def __init__(self, element):
+ self.element = element
@util.memoized_property
def _text_clause(self):
- return TextClause._create_text(self.text)
+ return TextClause._create_text(self.element)
class UnaryExpression(ColumnElement):
@@ -3556,6 +3581,13 @@ def _clause_element_as_expr(element):
def _literal_as_label_reference(element):
if isinstance(element, util.string_types):
+ return _textual_label_reference(element)
+
+ elif hasattr(element, '__clause_element__'):
+ element = element.__clause_element__()
+
+ if isinstance(element, ColumnElement) and \
+ element._order_by_label_element is not None:
return _label_reference(element)
else:
return _literal_as_text(element)
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 57b16f45f..0f2926350 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -1885,9 +1885,10 @@ class CompoundSelect(GenerativeSelect):
@property
def _label_resolve_dict(self):
- return dict(
+ d = dict(
(c.key, c) for c in self.c
)
+ return d, d
@classmethod
def _create_union(cls, *selects, **kwargs):
@@ -2499,15 +2500,16 @@ class Select(HasPrefixes, GenerativeSelect):
@_memoized_property
def _label_resolve_dict(self):
- d = dict(
+ with_cols = dict(
(c._resolve_label or c._label or c.key, c)
for c in _select_iterables(self._raw_columns)
if c._allow_label_resolve)
- d.update(
+ only_froms = dict(
(c.key, c) for c in
_select_iterables(self.froms) if c._allow_label_resolve)
+ with_cols.update(only_froms)
- return d
+ return with_cols, only_froms
def is_derived_from(self, fromclause):
if self in fromclause._cloned_set:
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index d6f3b5915..fbbe15da3 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -16,7 +16,7 @@ from itertools import chain
from collections import deque
from .elements import BindParameter, ColumnClause, ColumnElement, \
- Null, UnaryExpression, literal_column, Label
+ Null, UnaryExpression, literal_column, Label, _label_reference
from .selectable import ScalarSelect, Join, FromClause, FromGrouping
from .schema import Column
@@ -161,6 +161,8 @@ def unwrap_order_by(clause):
not isinstance(t, UnaryExpression) or
not operators.is_ordering_modifier(t.modifier)
):
+ if isinstance(t, _label_reference):
+ t = t.element
cols.add(t)
else:
for c in t.get_children():
diff --git a/test/orm/test_query.py b/test/orm/test_query.py
index 3f6813138..c9f0a5db0 100644
--- a/test/orm/test_query.py
+++ b/test/orm/test_query.py
@@ -1236,7 +1236,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL):
__dialect__ = 'default'
run_setup_mappers = 'each'
- def _fixture(self, label=True):
+ def _fixture(self, label=True, polymorphic=False):
User, Address = self.classes("User", "Address")
users, addresses = self.tables("users", "addresses")
stmt = select([func.max(addresses.c.email_address)]).\
@@ -1247,7 +1247,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL):
mapper(User, users, properties={
"ead": column_property(stmt)
- })
+ }, with_polymorphic="*" if polymorphic else None)
mapper(Address, addresses)
def test_order_by_column_prop_string(self):
@@ -1355,6 +1355,22 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL):
"users AS users_1 ORDER BY email_ad, anon_1"
)
+ def test_order_by_column_labeled_prop_attr_aliased_four(self):
+ User = self.classes.User
+ self._fixture(label=True, polymorphic=True)
+
+ ua = aliased(User)
+ s = Session()
+ q = s.query(ua, User.id).order_by(ua.ead)
+ self.assert_compile(
+ q,
+ "SELECT (SELECT max(addresses.email_address) AS max_1 FROM "
+ "addresses WHERE addresses.user_id = users_1.id) AS anon_1, "
+ "users_1.id AS users_1_id, users_1.name AS users_1_name, "
+ "users.id AS users_id FROM users AS users_1, users ORDER BY anon_1"
+ )
+
+
def test_order_by_column_unlabeled_prop_attr_aliased_one(self):
User = self.classes.User
self._fixture(label=False)
diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py
index 4f8ced72c..d47b58f1f 100644
--- a/test/sql/test_compiler.py
+++ b/test/sql/test_compiler.py
@@ -2169,6 +2169,27 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
"SELECT x + foo() OVER () AS anon_1"
)
+ # test a reference to a label that in the referecned selectable;
+ # this resolves
+ expr = (table1.c.myid + 5).label('sum')
+ stmt = select([expr]).alias()
+ self.assert_compile(
+ select([stmt.c.sum, func.row_number().over(order_by=stmt.c.sum)]),
+ "SELECT anon_1.sum, row_number() OVER (ORDER BY anon_1.sum) "
+ "AS anon_2 FROM (SELECT mytable.myid + :myid_1 AS sum "
+ "FROM mytable) AS anon_1"
+ )
+
+ # test a reference to a label that's at the same level as the OVER
+ # in the columns clause; doesn't resolve
+ expr = (table1.c.myid + 5).label('sum')
+ self.assert_compile(
+ select([expr, func.row_number().over(order_by=expr)]),
+ "SELECT mytable.myid + :myid_1 AS sum, "
+ "row_number() OVER "
+ "(ORDER BY mytable.myid + :myid_1) AS anon_1 FROM mytable"
+ )
+
def test_date_between(self):
import datetime
table = Table('dt', metadata,