diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2019-10-08 15:23:51 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2019-10-08 15:23:51 +0000 |
| commit | 89da96a99c12e9deeb583504563224c7156bbcc8 (patch) | |
| tree | ffc83416c0d2192fa0c06d9345263c408bdee857 /lib/sqlalchemy/sql | |
| parent | ed3818fef710d8ab12b261a23dd1f7f05b6bf7c5 (diff) | |
| parent | 65aee6cce57fd1cca3a95814feff3ed99a5a51ee (diff) | |
| download | sqlalchemy-89da96a99c12e9deeb583504563224c7156bbcc8.tar.gz | |
Merge "Add result map targeting for custom compiled, text objects"
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 72 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 12 |
2 files changed, 60 insertions, 24 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 320c7b782..453ff56d2 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -871,12 +871,11 @@ class SQLCompiler(Compiled): name = self._truncated_identifier("colident", name) if add_to_result_map is not None: - add_to_result_map( - name, - orig_name, - (column, name, column.key, column._label) + result_map_targets, - column.type, - ) + targets = (column, name, column.key) + result_map_targets + if column._label: + targets += (column._label,) + + add_to_result_map(name, orig_name, targets, column.type) if is_literal: # note we are not currently accommodating for @@ -925,7 +924,7 @@ class SQLCompiler(Compiled): text = text.replace("%", "%%") return text - def visit_textclause(self, textclause, **kw): + def visit_textclause(self, textclause, add_to_result_map=None, **kw): def do_bindparam(m): name = m.group(1) if name in textclause._bindparams: @@ -936,6 +935,12 @@ class SQLCompiler(Compiled): if not self.stack: self.isplaintext = True + if add_to_result_map: + # text() object is present in the columns clause of a + # select(). Add a no-name entry to the result map so that + # row[text()] produces a result + add_to_result_map(None, None, (textclause,), sqltypes.NULLTYPE) + # un-escape any \:params return BIND_PARAMS_ESC.sub( lambda m: m.group(1), @@ -1938,6 +1943,9 @@ class SQLCompiler(Compiled): return " AS " + alias_name_text def _add_to_result_map(self, keyname, name, objects, type_): + if keyname is None: + self._ordered_columns = False + self._textual_ordered_columns = True self._result_columns.append((keyname, name, objects, type_)) def _label_select_column( @@ -1949,6 +1957,7 @@ class SQLCompiler(Compiled): column_clause_args, name=None, within_columns_clause=True, + column_is_repeated=False, need_column_expressions=False, ): """produce labeled columns present in a select().""" @@ -1959,22 +1968,37 @@ class SQLCompiler(Compiled): need_column_expressions or populate_result_map ): col_expr = impl.column_expression(column) + else: + col_expr = column - if populate_result_map: + if populate_result_map: + # pass an "add_to_result_map" callable into the compilation + # of embedded columns. this collects information about the + # column as it will be fetched in the result and is coordinated + # with cursor.description when the query is executed. + add_to_result_map = self._add_to_result_map + + # if the SELECT statement told us this column is a repeat, + # wrap the callable with one that prevents the addition of the + # targets + if column_is_repeated: + _add_to_result_map = add_to_result_map def add_to_result_map(keyname, name, objects, type_): - self._add_to_result_map( + _add_to_result_map(keyname, name, (), type_) + + # if we redefined col_expr for type expressions, wrap the + # callable with one that adds the original column to the targets + elif col_expr is not column: + _add_to_result_map = add_to_result_map + + def add_to_result_map(keyname, name, objects, type_): + _add_to_result_map( keyname, name, (column,) + objects, type_ ) - else: - add_to_result_map = None else: - col_expr = column - if populate_result_map: - add_to_result_map = self._add_to_result_map - else: - add_to_result_map = None + add_to_result_map = None if not within_columns_clause: result_expr = col_expr @@ -2010,7 +2034,7 @@ class SQLCompiler(Compiled): ) and ( not hasattr(column, "name") - or isinstance(column, functions.Function) + or isinstance(column, functions.FunctionElement) ) ): result_expr = _CompileLabel(col_expr, column.anon_label) @@ -2138,9 +2162,10 @@ class SQLCompiler(Compiled): asfrom, column_clause_args, name=name, + column_is_repeated=repeated, need_column_expressions=need_column_expressions, ) - for name, column in select._columns_plus_names + for name, column, repeated in select._columns_plus_names ] if c is not None ] @@ -2151,10 +2176,17 @@ class SQLCompiler(Compiled): translate = dict( zip( - [name for (key, name) in select._columns_plus_names], [ name - for (key, name) in select_wraps_for._columns_plus_names + for (key, name, repeated) in select._columns_plus_names + ], + [ + name + for ( + key, + name, + repeated, + ) in select_wraps_for._columns_plus_names ], ) ) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index ddbcdf91d..6282cf2ee 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -4191,8 +4191,9 @@ class Select( def name_for_col(c): if c._label is None or not c._render_label_in_columns_clause: - return (None, c) + return (None, c, False) + repeated = False name = c._label if name in names: @@ -4218,19 +4219,22 @@ class Select( # subsequent occurrences of the column so that the # original stays non-ambiguous name = c._dedupe_label_anon_label + repeated = True else: names[name] = c elif anon_for_dupe_key: # same column under the same name. apply the "dedupe" # label so that the original stays non-ambiguous name = c._dedupe_label_anon_label + repeated = True else: names[name] = c - return name, c + return name, c, repeated return [name_for_col(c) for c in cols] else: - return [(None, c) for c in cols] + # repeated name logic only for use labels at the moment + return [(None, c, False) for c in cols] @_memoized_property def _columns_plus_names(self): @@ -4245,7 +4249,7 @@ class Select( keys_seen = set() prox = [] - for name, c in self._generate_columns_plus_names(False): + for name, c, repeated in self._generate_columns_plus_names(False): if not hasattr(c, "_make_proxy"): continue if name is None: |
