summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2019-10-08 15:23:51 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2019-10-08 15:23:51 +0000
commit89da96a99c12e9deeb583504563224c7156bbcc8 (patch)
treeffc83416c0d2192fa0c06d9345263c408bdee857 /lib/sqlalchemy/sql
parented3818fef710d8ab12b261a23dd1f7f05b6bf7c5 (diff)
parent65aee6cce57fd1cca3a95814feff3ed99a5a51ee (diff)
downloadsqlalchemy-89da96a99c12e9deeb583504563224c7156bbcc8.tar.gz
Merge "Add result map targeting for custom compiled, text objects"
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/compiler.py72
-rw-r--r--lib/sqlalchemy/sql/selectable.py12
2 files changed, 60 insertions, 24 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 320c7b782..453ff56d2 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -871,12 +871,11 @@ class SQLCompiler(Compiled):
name = self._truncated_identifier("colident", name)
if add_to_result_map is not None:
- add_to_result_map(
- name,
- orig_name,
- (column, name, column.key, column._label) + result_map_targets,
- column.type,
- )
+ targets = (column, name, column.key) + result_map_targets
+ if column._label:
+ targets += (column._label,)
+
+ add_to_result_map(name, orig_name, targets, column.type)
if is_literal:
# note we are not currently accommodating for
@@ -925,7 +924,7 @@ class SQLCompiler(Compiled):
text = text.replace("%", "%%")
return text
- def visit_textclause(self, textclause, **kw):
+ def visit_textclause(self, textclause, add_to_result_map=None, **kw):
def do_bindparam(m):
name = m.group(1)
if name in textclause._bindparams:
@@ -936,6 +935,12 @@ class SQLCompiler(Compiled):
if not self.stack:
self.isplaintext = True
+ if add_to_result_map:
+ # text() object is present in the columns clause of a
+ # select(). Add a no-name entry to the result map so that
+ # row[text()] produces a result
+ add_to_result_map(None, None, (textclause,), sqltypes.NULLTYPE)
+
# un-escape any \:params
return BIND_PARAMS_ESC.sub(
lambda m: m.group(1),
@@ -1938,6 +1943,9 @@ class SQLCompiler(Compiled):
return " AS " + alias_name_text
def _add_to_result_map(self, keyname, name, objects, type_):
+ if keyname is None:
+ self._ordered_columns = False
+ self._textual_ordered_columns = True
self._result_columns.append((keyname, name, objects, type_))
def _label_select_column(
@@ -1949,6 +1957,7 @@ class SQLCompiler(Compiled):
column_clause_args,
name=None,
within_columns_clause=True,
+ column_is_repeated=False,
need_column_expressions=False,
):
"""produce labeled columns present in a select()."""
@@ -1959,22 +1968,37 @@ class SQLCompiler(Compiled):
need_column_expressions or populate_result_map
):
col_expr = impl.column_expression(column)
+ else:
+ col_expr = column
- if populate_result_map:
+ if populate_result_map:
+ # pass an "add_to_result_map" callable into the compilation
+ # of embedded columns. this collects information about the
+ # column as it will be fetched in the result and is coordinated
+ # with cursor.description when the query is executed.
+ add_to_result_map = self._add_to_result_map
+
+ # if the SELECT statement told us this column is a repeat,
+ # wrap the callable with one that prevents the addition of the
+ # targets
+ if column_is_repeated:
+ _add_to_result_map = add_to_result_map
def add_to_result_map(keyname, name, objects, type_):
- self._add_to_result_map(
+ _add_to_result_map(keyname, name, (), type_)
+
+ # if we redefined col_expr for type expressions, wrap the
+ # callable with one that adds the original column to the targets
+ elif col_expr is not column:
+ _add_to_result_map = add_to_result_map
+
+ def add_to_result_map(keyname, name, objects, type_):
+ _add_to_result_map(
keyname, name, (column,) + objects, type_
)
- else:
- add_to_result_map = None
else:
- col_expr = column
- if populate_result_map:
- add_to_result_map = self._add_to_result_map
- else:
- add_to_result_map = None
+ add_to_result_map = None
if not within_columns_clause:
result_expr = col_expr
@@ -2010,7 +2034,7 @@ class SQLCompiler(Compiled):
)
and (
not hasattr(column, "name")
- or isinstance(column, functions.Function)
+ or isinstance(column, functions.FunctionElement)
)
):
result_expr = _CompileLabel(col_expr, column.anon_label)
@@ -2138,9 +2162,10 @@ class SQLCompiler(Compiled):
asfrom,
column_clause_args,
name=name,
+ column_is_repeated=repeated,
need_column_expressions=need_column_expressions,
)
- for name, column in select._columns_plus_names
+ for name, column, repeated in select._columns_plus_names
]
if c is not None
]
@@ -2151,10 +2176,17 @@ class SQLCompiler(Compiled):
translate = dict(
zip(
- [name for (key, name) in select._columns_plus_names],
[
name
- for (key, name) in select_wraps_for._columns_plus_names
+ for (key, name, repeated) in select._columns_plus_names
+ ],
+ [
+ name
+ for (
+ key,
+ name,
+ repeated,
+ ) in select_wraps_for._columns_plus_names
],
)
)
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index ddbcdf91d..6282cf2ee 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -4191,8 +4191,9 @@ class Select(
def name_for_col(c):
if c._label is None or not c._render_label_in_columns_clause:
- return (None, c)
+ return (None, c, False)
+ repeated = False
name = c._label
if name in names:
@@ -4218,19 +4219,22 @@ class Select(
# subsequent occurrences of the column so that the
# original stays non-ambiguous
name = c._dedupe_label_anon_label
+ repeated = True
else:
names[name] = c
elif anon_for_dupe_key:
# same column under the same name. apply the "dedupe"
# label so that the original stays non-ambiguous
name = c._dedupe_label_anon_label
+ repeated = True
else:
names[name] = c
- return name, c
+ return name, c, repeated
return [name_for_col(c) for c in cols]
else:
- return [(None, c) for c in cols]
+ # repeated name logic only for use labels at the moment
+ return [(None, c, False) for c in cols]
@_memoized_property
def _columns_plus_names(self):
@@ -4245,7 +4249,7 @@ class Select(
keys_seen = set()
prox = []
- for name, c in self._generate_columns_plus_names(False):
+ for name, c, repeated in self._generate_columns_plus_names(False):
if not hasattr(c, "_make_proxy"):
continue
if name is None: