diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2022-01-14 23:00:59 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2022-01-14 23:00:59 +0000 |
| commit | 3625ac21798a2a1ff082e9bcdfde7263ca51ab49 (patch) | |
| tree | 12331656ffbda07009d0e28e465d783afdb02407 /lib/sqlalchemy | |
| parent | f67f93db3cc5bb1980f0836f4ecbb6aada8b4618 (diff) | |
| parent | 06f83c26ea3636eaec0b85fc9d733ab4bfb827ec (diff) | |
| download | sqlalchemy-3625ac21798a2a1ff082e9bcdfde7263ca51ab49.tar.gz | |
Merge "track item schema names to identify name collisions w/ default schema" into main
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/orm/context.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/base.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 28 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 28 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_select.py | 99 |
6 files changed, 156 insertions, 3 deletions
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index e8ac892ce..8e9cf66e2 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -684,7 +684,6 @@ class ORMSelectCompileState(ORMCompileState, SelectState): self._setup_for_generate() SelectState.__init__(self, self.statement, compiler, **kw) - return self def _dump_option_struct(self): diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 7841ce88a..74469b035 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -500,7 +500,7 @@ class CompileState: """ - __slots__ = ("statement",) + __slots__ = ("statement", "_ambiguous_table_name_map") plugins = {} diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index cb10811c6..af39f0672 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1466,6 +1466,7 @@ class SQLCompiler(Compiled): add_to_result_map=None, include_table=True, result_map_targets=(), + ambiguous_table_name_map=None, **kwargs, ): name = orig_name = column.name @@ -1502,6 +1503,14 @@ class SQLCompiler(Compiled): else: schema_prefix = "" tablename = table.name + + if ( + not effective_schema + and ambiguous_table_name_map + and tablename in ambiguous_table_name_map + ): + tablename = ambiguous_table_name_map[tablename] + if isinstance(tablename, elements._truncated_label): tablename = self._truncated_identifier("alias", tablename) @@ -3252,6 +3261,10 @@ class SQLCompiler(Compiled): compile_state = select_stmt._compile_state_factory( select_stmt, self, **kwargs ) + kwargs[ + "ambiguous_table_name_map" + ] = compile_state._ambiguous_table_name_map + select_stmt = compile_state.statement toplevel = not self.stack @@ -3732,6 +3745,7 @@ class SQLCompiler(Compiled): fromhints=None, use_schema=True, from_linter=None, + ambiguous_table_name_map=None, **kwargs, ): if from_linter: @@ -3748,6 +3762,20 @@ class SQLCompiler(Compiled): ) else: ret = self.preparer.quote(table.name) + + if ( + not effective_schema + and ambiguous_table_name_map + and table.name in ambiguous_table_name_map + ): + anon_name = self._truncated_identifier( + "alias", ambiguous_table_name_map[table.name] + ) + + ret = ret + self.get_render_as_alias_suffix( + self.preparer.format_alias(None, anon_name) + ) + if fromhints and table in fromhints: ret = self.format_from_hint_text( ret, table, fromhints[table], iscrud diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 65f345fb3..43979b4ae 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -246,6 +246,7 @@ class ClauseElement( is_clause_element = True is_selectable = False + _is_table = False _is_textual = False _is_from_clause = False _is_returns_rows = False diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 00e20e3fb..e1bbcffec 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -2287,6 +2287,8 @@ class TableClause(roles.DMLTableRole, Immutable, FromClause): named_with_column = True + _is_table = True + implicit_returning = False """:class:`_expression.TableClause` doesn't support having a primary key or column @@ -3660,6 +3662,8 @@ class SelectState(util.MemoizedSlots, CompileState): return go def _get_froms(self, statement): + self._ambiguous_table_name_map = ambiguous_table_name_map = {} + return self._normalize_froms( itertools.chain( itertools.chain.from_iterable( @@ -3677,10 +3681,16 @@ class SelectState(util.MemoizedSlots, CompileState): self.from_clauses, ), check_statement=statement, + ambiguous_table_name_map=ambiguous_table_name_map, ) @classmethod - def _normalize_froms(cls, iterable_of_froms, check_statement=None): + def _normalize_froms( + cls, + iterable_of_froms, + check_statement=None, + ambiguous_table_name_map=None, + ): """given an iterable of things to select FROM, reduce them to what would actually render in the FROM clause of a SELECT. @@ -3693,6 +3703,7 @@ class SelectState(util.MemoizedSlots, CompileState): froms = [] for item in iterable_of_froms: + if item._is_subquery and item.element is check_statement: raise exc.InvalidRequestError( "select() construct refers to itself as a FROM" @@ -3713,6 +3724,21 @@ class SelectState(util.MemoizedSlots, CompileState): # using a list to maintain ordering froms = [f for f in froms if f not in toremove] + if ambiguous_table_name_map is not None: + ambiguous_table_name_map.update( + ( + fr.name, + _anonymous_label.safe_construct( + hash(fr.name), fr.name + ), + ) + for item in froms + for fr in item._from_objects + if fr._is_table + and fr.schema + and fr.name not in ambiguous_table_name_map + ) + return froms def _get_display_froms( diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index c1228f5df..92fd29503 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -624,6 +624,105 @@ class FetchLimitOffsetTest(fixtures.TablesTest): eq_(set(fa), set([(3, 3, 4), (4, 4, 5), (5, 4, 6)])) +class SameNamedSchemaTableTest(fixtures.TablesTest): + """tests for #7471""" + + __backend__ = True + + __requires__ = ("schemas",) + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + schema=config.test_schema, + ) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column( + "some_table_id", + Integer, + # ForeignKey("%s.some_table.id" % config.test_schema), + nullable=False, + ), + ) + + @classmethod + def insert_data(cls, connection): + some_table, some_table_schema = cls.tables( + "some_table", "%s.some_table" % config.test_schema + ) + connection.execute(some_table_schema.insert(), {"id": 1}) + connection.execute(some_table.insert(), {"id": 1, "some_table_id": 1}) + + def test_simple_join_both_tables(self, connection): + some_table, some_table_schema = self.tables( + "some_table", "%s.some_table" % config.test_schema + ) + + eq_( + connection.execute( + select(some_table, some_table_schema).join_from( + some_table, + some_table_schema, + some_table.c.some_table_id == some_table_schema.c.id, + ) + ).first(), + (1, 1, 1), + ) + + def test_simple_join_whereclause_only(self, connection): + some_table, some_table_schema = self.tables( + "some_table", "%s.some_table" % config.test_schema + ) + + eq_( + connection.execute( + select(some_table) + .join_from( + some_table, + some_table_schema, + some_table.c.some_table_id == some_table_schema.c.id, + ) + .where(some_table.c.id == 1) + ).first(), + (1, 1), + ) + + def test_subquery(self, connection): + some_table, some_table_schema = self.tables( + "some_table", "%s.some_table" % config.test_schema + ) + + subq = ( + select(some_table) + .join_from( + some_table, + some_table_schema, + some_table.c.some_table_id == some_table_schema.c.id, + ) + .where(some_table.c.id == 1) + .subquery() + ) + + eq_( + connection.execute( + select(some_table, subq.c.id) + .join_from( + some_table, + subq, + some_table.c.some_table_id == subq.c.id, + ) + .where(some_table.c.id == 1) + ).first(), + (1, 1, 1), + ) + + class JoinTest(fixtures.TablesTest): __backend__ = True |
