diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/engine/result.py | 62 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/compiler.py | 21 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 72 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 12 |
4 files changed, 111 insertions, 56 deletions
diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index af5303658..733bd6f6a 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -321,40 +321,41 @@ class ResultMetaData(object): # dupe records with "None" for index which results in # ambiguous column exception when accessed. if len(by_key) != num_ctx_cols: - seen = set() + # new in 1.4: get the complete set of all possible keys, + # strings, objects, whatever, that are dupes across two + # different records, first. + index_by_key = {} + dupes = set() for metadata_entry in raw: - key = metadata_entry[MD_RENDERED_NAME] - if key in seen: - # this is an "ambiguous" element, replacing - # the full record in the map - key = key.lower() if not self.case_sensitive else key - by_key[key] = (None, (), key) - seen.add(key) - - # copy secondary elements from compiled columns - # into self._keymap, write in the potentially "ambiguous" - # element + for key in (metadata_entry[MD_RENDERED_NAME],) + ( + metadata_entry[MD_OBJECTS] or () + ): + if not self.case_sensitive and isinstance( + key, util.string_types + ): + key = key.lower() + idx = metadata_entry[MD_INDEX] + # if this key has been associated with more than one + # positional index, it's a dupe + if index_by_key.setdefault(key, idx) != idx: + dupes.add(key) + + # then put everything we have into the keymap excluding only + # those keys that are dupes. self._keymap.update( [ - (obj_elem, by_key[metadata_entry[MD_LOOKUP_KEY]]) + (obj_elem, metadata_entry) for metadata_entry in raw if metadata_entry[MD_OBJECTS] for obj_elem in metadata_entry[MD_OBJECTS] + if obj_elem not in dupes ] ) - # if we did a pure positional match, then reset the - # original "expression element" back to the "unambiguous" - # entry. This is a new behavior in 1.1 which impacts - # TextualSelect but also straight compiled SQL constructs. - if not self.matched_on_name: - self._keymap.update( - [ - (metadata_entry[MD_OBJECTS][0], metadata_entry) - for metadata_entry in raw - if metadata_entry[MD_OBJECTS] - ] - ) + # then for the dupe keys, put the "ambiguous column" + # record into by_key. + by_key.update({key: (None, (), key) for key in dupes}) + else: # no dupes - copy secondary elements from compiled # columns into self._keymap @@ -502,16 +503,16 @@ class ResultMetaData(object): ( idx, obj, - colname, - colname, + cursor_colname, + cursor_colname, context.get_result_processor( - mapped_type, colname, coltype + mapped_type, cursor_colname, coltype ), untranslated, ) for ( idx, - colname, + cursor_colname, mapped_type, coltype, obj, @@ -592,7 +593,6 @@ class ResultMetaData(object): else: mapped_type = sqltypes.NULLTYPE obj = None - yield idx, colname, mapped_type, coltype, obj, untranslated def _merge_cols_by_name( @@ -758,7 +758,7 @@ class ResultMetaData(object): if index is None: raise exc.InvalidRequestError( "Ambiguous column name '%s' in " - "result set column descriptions" % obj + "result set column descriptions" % rec[MD_LOOKUP_KEY] ) return operator.methodcaller("_get_by_key_impl", index) diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index 572c62b8e..4a5a8ba9c 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -398,6 +398,7 @@ Example usage:: """ from .. import exc +from ..sql import sqltypes from ..sql import visitors @@ -475,4 +476,22 @@ class _dispatcher(object): "compilation handler." % type(element) ) - return fn(element, compiler, **kw) + # if compilation includes add_to_result_map, collect add_to_result_map + # arguments from the user-defined callable, which are probably none + # because this is not public API. if it wasn't called, then call it + # ourselves. + arm = kw.get("add_to_result_map", None) + if arm: + arm_collection = [] + kw["add_to_result_map"] = lambda *args: arm_collection.append(args) + + expr = fn(element, compiler, **kw) + + if arm: + if not arm_collection: + arm_collection.append( + (None, None, (element,), sqltypes.NULLTYPE) + ) + for tup in arm_collection: + arm(*tup) + return expr diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 320c7b782..453ff56d2 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -871,12 +871,11 @@ class SQLCompiler(Compiled): name = self._truncated_identifier("colident", name) if add_to_result_map is not None: - add_to_result_map( - name, - orig_name, - (column, name, column.key, column._label) + result_map_targets, - column.type, - ) + targets = (column, name, column.key) + result_map_targets + if column._label: + targets += (column._label,) + + add_to_result_map(name, orig_name, targets, column.type) if is_literal: # note we are not currently accommodating for @@ -925,7 +924,7 @@ class SQLCompiler(Compiled): text = text.replace("%", "%%") return text - def visit_textclause(self, textclause, **kw): + def visit_textclause(self, textclause, add_to_result_map=None, **kw): def do_bindparam(m): name = m.group(1) if name in textclause._bindparams: @@ -936,6 +935,12 @@ class SQLCompiler(Compiled): if not self.stack: self.isplaintext = True + if add_to_result_map: + # text() object is present in the columns clause of a + # select(). Add a no-name entry to the result map so that + # row[text()] produces a result + add_to_result_map(None, None, (textclause,), sqltypes.NULLTYPE) + # un-escape any \:params return BIND_PARAMS_ESC.sub( lambda m: m.group(1), @@ -1938,6 +1943,9 @@ class SQLCompiler(Compiled): return " AS " + alias_name_text def _add_to_result_map(self, keyname, name, objects, type_): + if keyname is None: + self._ordered_columns = False + self._textual_ordered_columns = True self._result_columns.append((keyname, name, objects, type_)) def _label_select_column( @@ -1949,6 +1957,7 @@ class SQLCompiler(Compiled): column_clause_args, name=None, within_columns_clause=True, + column_is_repeated=False, need_column_expressions=False, ): """produce labeled columns present in a select().""" @@ -1959,22 +1968,37 @@ class SQLCompiler(Compiled): need_column_expressions or populate_result_map ): col_expr = impl.column_expression(column) + else: + col_expr = column - if populate_result_map: + if populate_result_map: + # pass an "add_to_result_map" callable into the compilation + # of embedded columns. this collects information about the + # column as it will be fetched in the result and is coordinated + # with cursor.description when the query is executed. + add_to_result_map = self._add_to_result_map + + # if the SELECT statement told us this column is a repeat, + # wrap the callable with one that prevents the addition of the + # targets + if column_is_repeated: + _add_to_result_map = add_to_result_map def add_to_result_map(keyname, name, objects, type_): - self._add_to_result_map( + _add_to_result_map(keyname, name, (), type_) + + # if we redefined col_expr for type expressions, wrap the + # callable with one that adds the original column to the targets + elif col_expr is not column: + _add_to_result_map = add_to_result_map + + def add_to_result_map(keyname, name, objects, type_): + _add_to_result_map( keyname, name, (column,) + objects, type_ ) - else: - add_to_result_map = None else: - col_expr = column - if populate_result_map: - add_to_result_map = self._add_to_result_map - else: - add_to_result_map = None + add_to_result_map = None if not within_columns_clause: result_expr = col_expr @@ -2010,7 +2034,7 @@ class SQLCompiler(Compiled): ) and ( not hasattr(column, "name") - or isinstance(column, functions.Function) + or isinstance(column, functions.FunctionElement) ) ): result_expr = _CompileLabel(col_expr, column.anon_label) @@ -2138,9 +2162,10 @@ class SQLCompiler(Compiled): asfrom, column_clause_args, name=name, + column_is_repeated=repeated, need_column_expressions=need_column_expressions, ) - for name, column in select._columns_plus_names + for name, column, repeated in select._columns_plus_names ] if c is not None ] @@ -2151,10 +2176,17 @@ class SQLCompiler(Compiled): translate = dict( zip( - [name for (key, name) in select._columns_plus_names], [ name - for (key, name) in select_wraps_for._columns_plus_names + for (key, name, repeated) in select._columns_plus_names + ], + [ + name + for ( + key, + name, + repeated, + ) in select_wraps_for._columns_plus_names ], ) ) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index ddbcdf91d..6282cf2ee 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -4191,8 +4191,9 @@ class Select( def name_for_col(c): if c._label is None or not c._render_label_in_columns_clause: - return (None, c) + return (None, c, False) + repeated = False name = c._label if name in names: @@ -4218,19 +4219,22 @@ class Select( # subsequent occurrences of the column so that the # original stays non-ambiguous name = c._dedupe_label_anon_label + repeated = True else: names[name] = c elif anon_for_dupe_key: # same column under the same name. apply the "dedupe" # label so that the original stays non-ambiguous name = c._dedupe_label_anon_label + repeated = True else: names[name] = c - return name, c + return name, c, repeated return [name_for_col(c) for c in cols] else: - return [(None, c) for c in cols] + # repeated name logic only for use labels at the moment + return [(None, c, False) for c in cols] @_memoized_property def _columns_plus_names(self): @@ -4245,7 +4249,7 @@ class Select( keys_seen = set() prox = [] - for name, c in self._generate_columns_plus_names(False): + for name, c, repeated in self._generate_columns_plus_names(False): if not hasattr(c, "_make_proxy"): continue if name is None: |
