summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-07-03 16:25:15 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-07-03 22:33:48 -0400
commit148711cb8515a19b6177dc07655cc6e652de0553 (patch)
treeb75505c907d25395d77f45b94919b9a17e9432cf
parent4b3f204d07d53ae09b59ce8f33b534f26a605cd4 (diff)
downloadsqlalchemy-148711cb8515a19b6177dc07655cc6e652de0553.tar.gz
runtime annotation fixes for relationship
* derive uselist=False when fwd ref passed to relationship This case needs to work whether or not the class name is a forward ref. we dont allow the colleciton to be a forward ref so this will work. * fix issues with MappedCollection When using string annotations or __future__.annotations, we need to do more parsing in order to get the target collection properly Change-Id: I9e5a1358b62d060a8815826f98190801a9cc0b68
-rw-r--r--lib/sqlalchemy/orm/__init__.py4
-rw-r--r--lib/sqlalchemy/orm/clsregistry.py3
-rw-r--r--lib/sqlalchemy/orm/relationships.py9
-rw-r--r--lib/sqlalchemy/orm/util.py30
-rw-r--r--lib/sqlalchemy/util/typing.py6
-rw-r--r--test/orm/declarative/test_dc_transforms.py32
-rw-r--r--test/orm/declarative/test_tm_future_annotations.py154
-rw-r--r--test/orm/declarative/test_typed_mapping.py36
-rw-r--r--test/orm/test_deferred.py44
9 files changed, 308 insertions, 10 deletions
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
index 4f19ba946..539cf2600 100644
--- a/lib/sqlalchemy/orm/__init__.py
+++ b/lib/sqlalchemy/orm/__init__.py
@@ -87,6 +87,10 @@ from .interfaces import PropComparator as PropComparator
from .interfaces import UserDefinedOption as UserDefinedOption
from .loading import merge_frozen_result as merge_frozen_result
from .loading import merge_result as merge_result
+from .mapped_collection import attribute_mapped_collection
+from .mapped_collection import column_mapped_collection
+from .mapped_collection import mapped_collection
+from .mapped_collection import MappedCollection
from .mapper import configure_mappers as configure_mappers
from .mapper import Mapper as Mapper
from .mapper import reconstructor as reconstructor
diff --git a/lib/sqlalchemy/orm/clsregistry.py b/lib/sqlalchemy/orm/clsregistry.py
index b3fcd29ea..dd79eb1d0 100644
--- a/lib/sqlalchemy/orm/clsregistry.py
+++ b/lib/sqlalchemy/orm/clsregistry.py
@@ -463,6 +463,7 @@ class _class_resolver:
generic_match = re.match(r"(.+)\[(.+)\]", name)
if generic_match:
+ clsarg = generic_match.group(2).strip("'")
raise exc.InvalidRequestError(
f"When initializing mapper {self.prop.parent}, "
f'expression "relationship({self.arg!r})" seems to be '
@@ -470,7 +471,7 @@ class _class_resolver:
"please state the generic argument "
"using an annotation, e.g. "
f'"{self.prop.key}: Mapped[{generic_match.group(1)}'
- f'[{generic_match.group(2)}]] = relationship()"'
+ f"['{clsarg}']] = relationship()\""
) from err
else:
raise exc.InvalidRequestError(
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py
index 630f6898f..77a95a195 100644
--- a/lib/sqlalchemy/orm/relationships.py
+++ b/lib/sqlalchemy/orm/relationships.py
@@ -1724,11 +1724,12 @@ class Relationship(
self.collection_class = collection_class
else:
self.uselist = False
+
if argument.__args__: # type: ignore
if issubclass(
argument.__origin__, typing.Mapping # type: ignore
):
- type_arg = argument.__args__[1] # type: ignore
+ type_arg = argument.__args__[-1] # type: ignore
else:
type_arg = argument.__args__[0] # type: ignore
if hasattr(type_arg, "__forward_arg__"):
@@ -1743,6 +1744,12 @@ class Relationship(
elif hasattr(argument, "__forward_arg__"):
argument = argument.__forward_arg__ # type: ignore
+ # we don't allow the collection class to be a
+ # __forward_arg__ right now, so if we see a forward arg here,
+ # we know there was no collection class either
+ if self.collection_class is None:
+ self.uselist = False
+
self.argument = argument
@util.preload_module("sqlalchemy.orm.mapper")
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 317abe2b4..02080a27f 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -1958,8 +1958,12 @@ def _getitem(iterable_query: Query[Any], item: Any) -> Any:
def _is_mapped_annotation(
raw_annotation: _AnnotationScanType, cls: Type[Any]
) -> bool:
- annotated = de_stringify_annotation(cls, raw_annotation)
- return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm")
+ try:
+ annotated = de_stringify_annotation(cls, raw_annotation)
+ except NameError:
+ return False
+ else:
+ return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm")
def _cleanup_mapped_str_annotation(annotation: str) -> str:
@@ -1984,7 +1988,10 @@ def _cleanup_mapped_str_annotation(annotation: str) -> str:
# stack: ['Mapped', 'List', 'Address']
if not re.match(r"""^["'].*["']$""", stack[-1]):
- stack[-1] = f'"{stack[-1]}"'
+ stripchars = "\"' "
+ stack[-1] = ", ".join(
+ f'"{elem.strip(stripchars)}"' for elem in stack[-1].split(",")
+ )
# stack: ['Mapped', 'List', '"Address"']
annotation = "[".join(stack) + ("]" * (len(stack) - 1))
@@ -2007,6 +2014,7 @@ def _extract_mapped_subtype(
Includes error raise scenarios and other options.
"""
+
if raw_annotation is None:
if required:
@@ -2017,9 +2025,19 @@ def _extract_mapped_subtype(
)
return None
- annotated = de_stringify_annotation(
- cls, raw_annotation, _cleanup_mapped_str_annotation
- )
+ try:
+ annotated = de_stringify_annotation(
+ cls, raw_annotation, _cleanup_mapped_str_annotation
+ )
+ except NameError as ne:
+ if raiseerr and "Mapped[" in raw_annotation: # type: ignore
+ raise sa_exc.ArgumentError(
+ f"Could not interpret annotation {raw_annotation}. "
+ "Check that it's not using names that might not be imported "
+ "at the module level. See chained stack trace for more hints."
+ ) from ne
+
+ annotated = raw_annotation # type: ignore
if is_dataclass_field:
return annotated
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py
index 653301f1f..45fe63765 100644
--- a/lib/sqlalchemy/util/typing.py
+++ b/lib/sqlalchemy/util/typing.py
@@ -113,8 +113,10 @@ def de_stringify_annotation(
try:
annotation = eval(annotation, base_globals, None)
- except NameError:
- pass
+ except NameError as err:
+ raise NameError(
+ f"Could not de-stringify annotation {annotation}"
+ ) from err
return annotation # type: ignore
diff --git a/test/orm/declarative/test_dc_transforms.py b/test/orm/declarative/test_dc_transforms.py
index 44976b5d8..f5111bfc7 100644
--- a/test/orm/declarative/test_dc_transforms.py
+++ b/test/orm/declarative/test_dc_transforms.py
@@ -38,6 +38,7 @@ from sqlalchemy.testing import eq_regex
from sqlalchemy.testing import expect_raises
from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_
from sqlalchemy.testing import is_false
from sqlalchemy.testing import is_true
from sqlalchemy.testing import ne_
@@ -547,6 +548,37 @@ class RelationshipDefaultFactoryTest(fixtures.TestBase):
):
A()
+ def test_one_to_one_example(self, dc_decl_base: Type[MappedAsDataclass]):
+ """test example in the relationship docs will derive uselist=False
+ correctly"""
+
+ class Parent(dc_decl_base):
+ __tablename__ = "parent"
+
+ id: Mapped[int] = mapped_column(init=False, primary_key=True)
+ child: Mapped["Child"] = relationship( # noqa: F821
+ back_populates="parent", default=None
+ )
+
+ class Child(dc_decl_base):
+ __tablename__ = "child"
+
+ id: Mapped[int] = mapped_column(init=False, primary_key=True)
+ parent_id: Mapped[int] = mapped_column(
+ ForeignKey("parent.id"), init=False
+ )
+ parent: Mapped["Parent"] = relationship(
+ back_populates="child", default=None
+ )
+
+ c1 = Child()
+ p1 = Parent(child=c1)
+ is_(p1.child, c1)
+ is_(c1.parent, p1)
+
+ p2 = Parent()
+ is_(p2.child, None)
+
def test_replace_operation_works_w_history_etc(
self, registry: _RegistryType
):
diff --git a/test/orm/declarative/test_tm_future_annotations.py b/test/orm/declarative/test_tm_future_annotations.py
index f8abd686a..74cbebb4d 100644
--- a/test/orm/declarative/test_tm_future_annotations.py
+++ b/test/orm/declarative/test_tm_future_annotations.py
@@ -1,13 +1,21 @@
from __future__ import annotations
from typing import List
+from typing import Set
+from typing import TypeVar
+from sqlalchemy import exc
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
+from sqlalchemy.orm import attribute_mapped_collection
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import MappedCollection
from sqlalchemy.orm import relationship
+from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_false
+from sqlalchemy.testing import is_true
from .test_typed_mapping import MappedColumnTest # noqa
from .test_typed_mapping import RelationshipLHSTest as _RelationshipLHSTest
@@ -17,6 +25,13 @@ having ``from __future__ import annotations`` in effect.
"""
+_R = TypeVar("_R")
+
+
+class MappedOneArg(MappedCollection[str, _R]):
+ pass
+
+
class RelationshipLHSTest(_RelationshipLHSTest):
def test_bidirectional_literal_annotations(self, decl_base):
"""test the 'string cleanup' function in orm/util.py, where
@@ -54,3 +69,142 @@ class RelationshipLHSTest(_RelationshipLHSTest):
b1 = B()
a1.bs.append(b1)
is_(a1, b1.a)
+
+ def test_collection_class_uselist_implicit_fwd(self, decl_base):
+ class A(decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ data: Mapped[str] = mapped_column()
+ bs_list: Mapped[List[B]] = relationship( # noqa: F821
+ viewonly=True
+ )
+ bs_set: Mapped[Set[B]] = relationship(viewonly=True) # noqa: F821
+ bs_list_warg: Mapped[List[B]] = relationship( # noqa: F821
+ "B", viewonly=True
+ )
+ bs_set_warg: Mapped[Set[B]] = relationship( # noqa: F821
+ "B", viewonly=True
+ )
+
+ b_one_to_one: Mapped[B] = relationship(viewonly=True) # noqa: F821
+
+ b_one_to_one_warg: Mapped[B] = relationship( # noqa: F821
+ "B", viewonly=True
+ )
+
+ class B(decl_base):
+ __tablename__ = "b"
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
+
+ a: Mapped[A] = relationship(viewonly=True)
+ a_warg: Mapped[A] = relationship("A", viewonly=True)
+
+ is_(A.__mapper__.attrs["bs_list"].collection_class, list)
+ is_(A.__mapper__.attrs["bs_set"].collection_class, set)
+ is_(A.__mapper__.attrs["bs_list_warg"].collection_class, list)
+ is_(A.__mapper__.attrs["bs_set_warg"].collection_class, set)
+ is_true(A.__mapper__.attrs["bs_list"].uselist)
+ is_true(A.__mapper__.attrs["bs_set"].uselist)
+ is_true(A.__mapper__.attrs["bs_list_warg"].uselist)
+ is_true(A.__mapper__.attrs["bs_set_warg"].uselist)
+
+ is_false(A.__mapper__.attrs["b_one_to_one"].uselist)
+ is_false(A.__mapper__.attrs["b_one_to_one_warg"].uselist)
+
+ is_false(B.__mapper__.attrs["a"].uselist)
+ is_false(B.__mapper__.attrs["a_warg"].uselist)
+
+ def test_collection_class_dict_attr_mapped_collection_literal_annotations(
+ self, decl_base
+ ):
+ class A(decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ data: Mapped[str] = mapped_column()
+
+ bs: Mapped[MappedCollection[str, B]] = relationship( # noqa: F821
+ collection_class=attribute_mapped_collection("name")
+ )
+
+ class B(decl_base):
+ __tablename__ = "b"
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
+ name: Mapped[str] = mapped_column()
+
+ self._assert_dict(A, B)
+
+ def test_collection_cls_attr_mapped_collection_dbl_literal_annotations(
+ self, decl_base
+ ):
+ class A(decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ data: Mapped[str] = mapped_column()
+
+ bs: Mapped[
+ MappedCollection[str, "B"]
+ ] = relationship( # noqa: F821
+ collection_class=attribute_mapped_collection("name")
+ )
+
+ class B(decl_base):
+ __tablename__ = "b"
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
+ name: Mapped[str] = mapped_column()
+
+ self._assert_dict(A, B)
+
+ def test_collection_cls_not_locatable(self, decl_base):
+ class MyCollection(MappedCollection):
+ pass
+
+ with expect_raises_message(
+ exc.ArgumentError,
+ r"Could not interpret annotation Mapped\[MyCollection\['B'\]\].",
+ ):
+
+ class A(decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ data: Mapped[str] = mapped_column()
+
+ bs: Mapped[MyCollection["B"]] = relationship( # noqa: F821
+ collection_class=attribute_mapped_collection("name")
+ )
+
+ def test_collection_cls_one_arg(self, decl_base):
+ class A(decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ data: Mapped[str] = mapped_column()
+
+ bs: Mapped[MappedOneArg["B"]] = relationship( # noqa: F821
+ collection_class=attribute_mapped_collection("name")
+ )
+
+ class B(decl_base):
+ __tablename__ = "b"
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
+ name: Mapped[str] = mapped_column()
+
+ self._assert_dict(A, B)
+
+ def _assert_dict(self, A, B):
+ A.registry.configure()
+
+ a1 = A()
+ b1 = B(name="foo")
+
+ # collection appender on MappedCollection
+ a1.bs.set(b1)
+
+ is_(a1.bs["foo"], b1)
diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py
index beb5d783b..3bf3a0182 100644
--- a/test/orm/declarative/test_typed_mapping.py
+++ b/test/orm/declarative/test_typed_mapping.py
@@ -1077,6 +1077,15 @@ class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL):
"B", viewonly=True
)
+ # note this is string annotation
+ b_one_to_one: Mapped["B"] = relationship( # noqa: F821
+ viewonly=True
+ )
+
+ b_one_to_one_warg: Mapped["B"] = relationship( # noqa: F821
+ "B", viewonly=True
+ )
+
class B(decl_base):
__tablename__ = "b"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
@@ -1094,9 +1103,36 @@ class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL):
is_true(A.__mapper__.attrs["bs_list_warg"].uselist)
is_true(A.__mapper__.attrs["bs_set_warg"].uselist)
+ is_false(A.__mapper__.attrs["b_one_to_one"].uselist)
+ is_false(A.__mapper__.attrs["b_one_to_one_warg"].uselist)
+
is_false(B.__mapper__.attrs["a"].uselist)
is_false(B.__mapper__.attrs["a_warg"].uselist)
+ def test_one_to_one_example(self, decl_base: Type[DeclarativeBase]):
+ """test example in the relationship docs will derive uselist=False
+ correctly"""
+
+ class Parent(decl_base):
+ __tablename__ = "parent"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ child: Mapped["Child"] = relationship( # noqa: F821
+ back_populates="parent"
+ )
+
+ class Child(decl_base):
+ __tablename__ = "child"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id"))
+ parent: Mapped["Parent"] = relationship(back_populates="child")
+
+ c1 = Child()
+ p1 = Parent(child=c1)
+ is_(p1.child, c1)
+ is_(c1.parent, p1)
+
def test_collection_class_dict_no_collection(self, decl_base):
class A(decl_base):
__tablename__ = "a"
diff --git a/test/orm/test_deferred.py b/test/orm/test_deferred.py
index 14c0e81ee..0dda9f52f 100644
--- a/test/orm/test_deferred.py
+++ b/test/orm/test_deferred.py
@@ -10,6 +10,7 @@ from sqlalchemy import util
from sqlalchemy.orm import aliased
from sqlalchemy.orm import attributes
from sqlalchemy.orm import contains_eager
+from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import defaultload
from sqlalchemy.orm import defer
from sqlalchemy.orm import deferred
@@ -18,6 +19,8 @@ from sqlalchemy.orm import joinedload
from sqlalchemy.orm import lazyload
from sqlalchemy.orm import Load
from sqlalchemy.orm import load_only
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import query_expression
from sqlalchemy.orm import relationship
from sqlalchemy.orm import selectinload
@@ -86,6 +89,47 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest):
],
)
+ def test_basic_w_new_style(self):
+ """sanity check that mapped_column(deferred=True) works"""
+
+ class Base(DeclarativeBase):
+ pass
+
+ class Order(Base):
+ __tablename__ = "orders"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ user_id: Mapped[int]
+ address_id: Mapped[int]
+ isopen: Mapped[bool]
+ description: Mapped[str] = mapped_column(deferred=True)
+
+ q = fixture_session().query(Order).order_by(Order.id)
+
+ def go():
+ result = q.all()
+ o2 = result[2]
+ o2.description
+
+ self.sql_eq_(
+ go,
+ [
+ (
+ "SELECT orders.id AS orders_id, "
+ "orders.user_id AS orders_user_id, "
+ "orders.address_id AS orders_address_id, "
+ "orders.isopen AS orders_isopen "
+ "FROM orders ORDER BY orders.id",
+ {},
+ ),
+ (
+ "SELECT orders.description AS orders_description "
+ "FROM orders WHERE orders.id = :pk_1",
+ {"pk_1": 3},
+ ),
+ ],
+ )
+
def test_defer_primary_key(self):
"""what happens when we try to defer the primary key?"""