summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py3
-rw-r--r--lib/sqlalchemy/dialects/oracle/base.py6
-rw-r--r--lib/sqlalchemy/engine/default.py2
-rw-r--r--lib/sqlalchemy/sql/compiler.py78
-rw-r--r--test/dialect/mssql/test_compiler.py18
-rw-r--r--test/dialect/test_oracle.py14
6 files changed, 79 insertions, 42 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index 92d7e4ab3..a35ab80d3 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -1031,6 +1031,7 @@ class MSSQLCompiler(compiler.SQLCompiler):
_order_by_clauses = select._order_by_clause.clauses
limit_clause = select._limit_clause
offset_clause = select._offset_clause
+ kwargs['_select_wraps'] = select
select = select._generate()
select._mssql_visit = True
select = select.column(
@@ -1048,7 +1049,7 @@ class MSSQLCompiler(compiler.SQLCompiler):
else:
limitselect.append_whereclause(
mssql_rn <= (limit_clause))
- return self.process(limitselect, iswrapper=True, **kwargs)
+ return self.process(limitselect, **kwargs)
else:
return compiler.SQLCompiler.visit_select(self, select, **kwargs)
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index a5e071148..9ec84d268 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -665,8 +665,8 @@ class OracleCompiler(compiler.SQLCompiler):
else:
return sql.and_(*clauses)
- def visit_outer_join_column(self, vc):
- return self.process(vc.column) + "(+)"
+ def visit_outer_join_column(self, vc, **kw):
+ return self.process(vc.column, **kw) + "(+)"
def visit_sequence(self, seq):
return (self.dialect.identifier_preparer.format_sequence(seq) +
@@ -738,6 +738,7 @@ class OracleCompiler(compiler.SQLCompiler):
# limit=0
# TODO: use annotations instead of clone + attr set ?
+ kwargs['_select_wraps'] = select
select = select._generate()
select._oracle_visit = True
@@ -794,7 +795,6 @@ class OracleCompiler(compiler.SQLCompiler):
offsetselect._for_update_arg = select._for_update_arg
select = offsetselect
- kwargs['iswrapper'] = getattr(select, '_is_wrapper', False)
return compiler.SQLCompiler.visit_select(self, select, **kwargs)
def limit_clause(self, select, **kw):
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 9d2bbfb15..62469d720 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -663,7 +663,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
self.cursor = self.create_cursor()
return self
- @property
+ @util.memoized_property
def result_map(self):
if self._result_columns:
return self.compiled.result_map
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 61b6d22d0..e37fa646c 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -683,20 +683,18 @@ class SQLCompiler(Compiled):
self.post_process_text(textclause.text))
)
- def visit_text_as_from(self, taf, iswrapper=False,
- compound_index=0, force_result_map=False,
+ def visit_text_as_from(self, taf,
+ compound_index=None, force_result_map=False,
asfrom=False,
parens=True, **kw):
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
- populate_result_map = force_result_map or (
- compound_index == 0 and (
- toplevel or
- entry['iswrapper']
- )
- )
+ populate_result_map = force_result_map or \
+ toplevel or \
+ (compound_index == 0 and entry.get(
+ 'need_result_map_for_compound', False))
if populate_result_map:
self._ordered_columns = False
@@ -812,13 +810,16 @@ class SQLCompiler(Compiled):
parens=True, compound_index=0, **kwargs):
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
+ need_result_map = toplevel or \
+ (compound_index == 0
+ and entry.get('need_result_map_for_compound', False))
self.stack.append(
{
'correlate_froms': entry['correlate_froms'],
- 'iswrapper': toplevel,
'asfrom_froms': entry['asfrom_froms'],
- 'selectable': cs
+ 'selectable': cs,
+ 'need_result_map_for_compound': need_result_map
})
keyword = self.compound_keywords.get(cs.keyword)
@@ -840,8 +841,7 @@ class SQLCompiler(Compiled):
or cs._offset_clause is not None) and \
self.limit_clause(cs, **kwargs) or ""
- if self.ctes and \
- compound_index == 0 and toplevel:
+ if self.ctes and toplevel:
text = self._render_cte_clause() + text
self.stack.pop(-1)
@@ -1460,7 +1460,6 @@ class SQLCompiler(Compiled):
]
_default_stack_entry = util.immutabledict([
- ('iswrapper', False),
('correlate_froms', frozenset()),
('asfrom_froms', frozenset())
])
@@ -1488,10 +1487,11 @@ class SQLCompiler(Compiled):
return froms
def visit_select(self, select, asfrom=False, parens=True,
- iswrapper=False, fromhints=None,
+ fromhints=None,
compound_index=0,
force_result_map=False,
nested_join_translation=False,
+ _select_wraps=None,
**kwargs):
needs_nested_translation = \
@@ -1505,7 +1505,7 @@ class SQLCompiler(Compiled):
select)
text = self.visit_select(
transformed_select, asfrom=asfrom, parens=parens,
- iswrapper=iswrapper, fromhints=fromhints,
+ fromhints=fromhints,
compound_index=compound_index,
force_result_map=force_result_map,
nested_join_translation=True, **kwargs
@@ -1514,12 +1514,11 @@ class SQLCompiler(Compiled):
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
- populate_result_map = force_result_map or (
- compound_index == 0 and (
- toplevel or
- entry['iswrapper']
+ populate_result_map = force_result_map or \
+ toplevel or (
+ compound_index == 0 and entry.get(
+ 'need_result_map_for_compound', False)
)
- )
if needs_nested_translation:
if populate_result_map:
@@ -1527,7 +1526,7 @@ class SQLCompiler(Compiled):
select, transformed_select)
return text
- froms = self._setup_select_stack(select, entry, asfrom, iswrapper)
+ froms = self._setup_select_stack(select, entry, asfrom)
column_clause_args = kwargs.copy()
column_clause_args.update({
@@ -1553,16 +1552,34 @@ class SQLCompiler(Compiled):
# the actual list of columns to print in the SELECT column list.
inner_columns = [
c for c in [
- self._label_select_column(select,
- column,
- populate_result_map, asfrom,
- column_clause_args,
- name=name)
+ self._label_select_column(
+ select,
+ column,
+ populate_result_map, asfrom,
+ column_clause_args,
+ name=name)
for name, column in select._columns_plus_names
]
if c is not None
]
+ if populate_result_map and _select_wraps is not None:
+ # if this select is a compiler-generated wrapper,
+ # rewrite the targeted columns in the result map
+ wrapped_inner_columns = set(_select_wraps.inner_columns)
+ translate = dict(
+ (outer, inner.pop()) for outer, inner in [
+ (
+ outer,
+ outer.proxy_set.intersection(wrapped_inner_columns))
+ for outer in select.inner_columns
+ ] if inner
+ )
+ self._result_columns = [
+ (key, name, tuple(translate.get(o, o) for o in obj), type_)
+ for key, name, obj, type_ in self._result_columns
+ ]
+
text = self._compose_select_body(
text, select, inner_columns, froms, byfrom, kwargs)
@@ -1575,8 +1592,7 @@ class SQLCompiler(Compiled):
if per_dialect:
text += " " + self.get_statement_hint_text(per_dialect)
- if self.ctes and \
- compound_index == 0 and toplevel:
+ if self.ctes and toplevel:
text = self._render_cte_clause() + text
if select._suffixes:
@@ -1603,7 +1619,7 @@ class SQLCompiler(Compiled):
hint_text = self.get_select_hint_text(byfrom)
return hint_text, byfrom
- def _setup_select_stack(self, select, entry, asfrom, iswrapper):
+ def _setup_select_stack(self, select, entry, asfrom):
correlate_froms = entry['correlate_froms']
asfrom_froms = entry['asfrom_froms']
@@ -1622,7 +1638,6 @@ class SQLCompiler(Compiled):
new_entry = {
'asfrom_froms': new_correlate_froms,
- 'iswrapper': iswrapper,
'correlate_froms': all_correlate_froms,
'selectable': select,
}
@@ -1765,7 +1780,6 @@ class SQLCompiler(Compiled):
def visit_insert(self, insert_stmt, **kw):
self.stack.append(
{'correlate_froms': set(),
- "iswrapper": False,
"asfrom_froms": set(),
"selectable": insert_stmt})
@@ -1889,7 +1903,6 @@ class SQLCompiler(Compiled):
def visit_update(self, update_stmt, **kw):
self.stack.append(
{'correlate_froms': set([update_stmt.table]),
- "iswrapper": False,
"asfrom_froms": set([update_stmt.table]),
"selectable": update_stmt})
@@ -1975,7 +1988,6 @@ class SQLCompiler(Compiled):
def visit_delete(self, delete_stmt, **kw):
self.stack.append({'correlate_froms': set([delete_stmt.table]),
- "iswrapper": False,
"asfrom_froms": set([delete_stmt.table]),
"selectable": delete_stmt})
self.isdelete = True
diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py
index 3de8ea5c9..54a23ee6e 100644
--- a/test/dialect/mssql/test_compiler.py
+++ b/test/dialect/mssql/test_compiler.py
@@ -416,11 +416,14 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
"SELECT TOP 0 t.x, t.y FROM t WHERE t.x = :x_1 ORDER BY t.y",
checkparams={'x_1': 5}
)
+ c = s.compile(dialect=mssql.MSDialect())
+ eq_(len(c._result_columns), 2)
+ assert t.c.x in set(c.result_map['x'][1])
def test_offset_using_window(self):
t = table('t', column('x', Integer), column('y', Integer))
- s = select([t]).where(t.c.x==5).order_by(t.c.y).offset(20)
+ s = select([t]).where(t.c.x == 5).order_by(t.c.y).offset(20)
# test that the select is not altered with subsequent compile
# calls
@@ -434,6 +437,10 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
checkparams={'param_1': 20, 'x_1': 5}
)
+ c = s.compile(dialect=mssql.MSDialect())
+ eq_(len(c._result_columns), 2)
+ assert t.c.x in set(c.result_map['x'][1])
+
def test_limit_offset_using_window(self):
t = table('t', column('x', Integer), column('y', Integer))
@@ -449,6 +456,10 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
"WHERE mssql_rn > :param_1 AND mssql_rn <= :param_2 + :param_1",
checkparams={'param_1': 20, 'param_2': 10, 'x_1': 5}
)
+ c = s.compile(dialect=mssql.MSDialect())
+ eq_(len(c._result_columns), 2)
+ assert t.c.x in set(c.result_map['x'][1])
+ assert t.c.y in set(c.result_map['y'][1])
def test_limit_offset_with_correlated_order_by(self):
t1 = table('t1', column('x', Integer), column('y', Integer))
@@ -471,6 +482,11 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
checkparams={'param_1': 20, 'param_2': 10, 'x_1': 5}
)
+ c = s.compile(dialect=mssql.MSDialect())
+ eq_(len(c._result_columns), 2)
+ assert t1.c.x in set(c.result_map['x'][1])
+ assert t1.c.y in set(c.result_map['y'][1])
+
def test_limit_zero_offset_using_window(self):
t = table('t', column('x', Integer), column('y', Integer))
diff --git a/test/dialect/test_oracle.py b/test/dialect/test_oracle.py
index 3c67f1590..58ea058c2 100644
--- a/test/dialect/test_oracle.py
+++ b/test/dialect/test_oracle.py
@@ -240,9 +240,11 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
checkparams={'param_1': 10, 'param_2': 20})
c = s.compile(dialect=oracle.OracleDialect())
+ eq_(len(c._result_columns), 2)
assert t.c.col1 in set(c.result_map['col1'][1])
- s = select([s.c.col1, s.c.col2])
- self.assert_compile(s,
+
+ s2 = select([s.c.col1, s.c.col2])
+ self.assert_compile(s2,
'SELECT col1, col2 FROM (SELECT col1, col2 '
'FROM (SELECT col1, col2, ROWNUM AS ora_rn '
'FROM (SELECT sometable.col1 AS col1, '
@@ -251,13 +253,16 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
':param_2)',
checkparams={'param_1': 10, 'param_2': 20})
- self.assert_compile(s,
+ self.assert_compile(s2,
'SELECT col1, col2 FROM (SELECT col1, col2 '
'FROM (SELECT col1, col2, ROWNUM AS ora_rn '
'FROM (SELECT sometable.col1 AS col1, '
'sometable.col2 AS col2 FROM sometable) '
'WHERE ROWNUM <= :param_1 + :param_2) WHERE ora_rn > '
':param_2)')
+ c = s2.compile(dialect=oracle.OracleDialect())
+ eq_(len(c._result_columns), 2)
+ assert s.c.col1 in set(c.result_map['col1'][1])
s = select([t]).limit(10).offset(20).order_by(t.c.col2)
self.assert_compile(s,
@@ -269,6 +274,9 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
':param_1 + :param_2) WHERE ora_rn > :param_2',
checkparams={'param_1': 10, 'param_2': 20}
)
+ c = s.compile(dialect=oracle.OracleDialect())
+ eq_(len(c._result_columns), 2)
+ assert t.c.col1 in set(c.result_map['col1'][1])
s = select([t], for_update=True).limit(10).order_by(t.c.col2)
self.assert_compile(s,