diff options
| -rw-r--r-- | doc/build/changelog/unreleased_14/6124.rst | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/attributes.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/interfaces.py | 2 | ||||
| -rw-r--r-- | test/orm/test_core_compilation.py | 48 | ||||
| -rw-r--r-- | test/orm/test_mapper.py | 2 |
5 files changed, 63 insertions, 1 deletions
diff --git a/doc/build/changelog/unreleased_14/6124.rst b/doc/build/changelog/unreleased_14/6124.rst new file mode 100644 index 000000000..ac08eaf9f --- /dev/null +++ b/doc/build/changelog/unreleased_14/6124.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, orm + :tickets: 6124 + + Repaired support so that the :meth:`_sql.Select.params` method can work + correctly with a :class:`_sql.Select` object that includes joins across ORM + relationship structures, which is a new feature in 1.4. + diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 2e48695f5..610ee2726 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -85,6 +85,10 @@ class QueryableAttribute( is_attribute = True + # PropComparator has a __visit_name__ to participate within + # traversals. Disambiguate the attribute vs. a comparator. + __visit_name__ = "orm_instrumented_attribute" + def __init__( self, class_, diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index a660d7e1a..e2cc36999 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -393,6 +393,8 @@ class PropComparator(operators.ColumnOperators): __slots__ = "prop", "property", "_parententity", "_adapt_to_entity" + __visit_name__ = "orm_prop_comparator" + def __init__( self, prop, # type: MapperProperty diff --git a/test/orm/test_core_compilation.py b/test/orm/test_core_compilation.py index 1a58356e3..a53b15bcb 100644 --- a/test/orm/test_core_compilation.py +++ b/test/orm/test_core_compilation.py @@ -1,3 +1,4 @@ +from sqlalchemy import bindparam from sqlalchemy import exc from sqlalchemy import func from sqlalchemy import insert @@ -24,6 +25,7 @@ from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing.fixtures import fixture_session +from sqlalchemy.testing.util import resolve_lambda from .inheritance import _poly_fixtures from .test_query import QueryTest @@ -257,6 +259,52 @@ class JoinTest(QueryTest, AssertsCompiledSQL): }, ) + @testing.combinations( + ( + lambda User: select(User).where(User.id == bindparam("foo")), + "SELECT users.id, users.name FROM users WHERE users.id = :foo", + {"foo": "bar"}, + {"foo": "bar"}, + ), + ( + lambda User, Address: select(User) + .join_from(User, Address) + .where(User.id == bindparam("foo")), + "SELECT users.id, users.name FROM users JOIN addresses " + "ON users.id = addresses.user_id WHERE users.id = :foo", + {"foo": "bar"}, + {"foo": "bar"}, + ), + ( + lambda User, Address: select(User) + .join_from(User, Address, User.addresses) + .where(User.id == bindparam("foo")), + "SELECT users.id, users.name FROM users JOIN addresses " + "ON users.id = addresses.user_id WHERE users.id = :foo", + {"foo": "bar"}, + {"foo": "bar"}, + ), + ( + lambda User, Address: select(User) + .join(User.addresses) + .where(User.id == bindparam("foo")), + "SELECT users.id, users.name FROM users JOIN addresses " + "ON users.id = addresses.user_id WHERE users.id = :foo", + {"foo": "bar"}, + {"foo": "bar"}, + ), + ) + def test_params_with_join( + self, test_case, expected, bindparams, expected_params + ): + User, Address = self.classes("User", "Address") + + stmt = resolve_lambda(test_case, **locals()) + + stmt = stmt.params(**bindparams) + + self.assert_compile(stmt, expected, checkparams=expected_params) + class LoadersInSubqueriesTest(QueryTest, AssertsCompiledSQL): """The Query object calls eanble_eagerloads(False) when you call diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index e4c89c7e8..3c8f83f91 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -86,7 +86,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): assert um.attrs.addresses.primaryjoin.compare( users.c.id == addresses.c.user_id ) - assert um.attrs.addresses.order_by[0].compare(Address.id) + assert um.attrs.addresses.order_by[0].compare(Address.id.expression) configure_mappers() |
