summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/util.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-07-03 16:25:15 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-07-03 22:33:48 -0400
commit148711cb8515a19b6177dc07655cc6e652de0553 (patch)
treeb75505c907d25395d77f45b94919b9a17e9432cf /lib/sqlalchemy/orm/util.py
parent4b3f204d07d53ae09b59ce8f33b534f26a605cd4 (diff)
downloadsqlalchemy-148711cb8515a19b6177dc07655cc6e652de0553.tar.gz
runtime annotation fixes for relationship
* derive uselist=False when fwd ref passed to relationship This case needs to work whether or not the class name is a forward ref. we dont allow the colleciton to be a forward ref so this will work. * fix issues with MappedCollection When using string annotations or __future__.annotations, we need to do more parsing in order to get the target collection properly Change-Id: I9e5a1358b62d060a8815826f98190801a9cc0b68
Diffstat (limited to 'lib/sqlalchemy/orm/util.py')
-rw-r--r--lib/sqlalchemy/orm/util.py30
1 files changed, 24 insertions, 6 deletions
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 317abe2b4..02080a27f 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -1958,8 +1958,12 @@ def _getitem(iterable_query: Query[Any], item: Any) -> Any:
def _is_mapped_annotation(
raw_annotation: _AnnotationScanType, cls: Type[Any]
) -> bool:
- annotated = de_stringify_annotation(cls, raw_annotation)
- return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm")
+ try:
+ annotated = de_stringify_annotation(cls, raw_annotation)
+ except NameError:
+ return False
+ else:
+ return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm")
def _cleanup_mapped_str_annotation(annotation: str) -> str:
@@ -1984,7 +1988,10 @@ def _cleanup_mapped_str_annotation(annotation: str) -> str:
# stack: ['Mapped', 'List', 'Address']
if not re.match(r"""^["'].*["']$""", stack[-1]):
- stack[-1] = f'"{stack[-1]}"'
+ stripchars = "\"' "
+ stack[-1] = ", ".join(
+ f'"{elem.strip(stripchars)}"' for elem in stack[-1].split(",")
+ )
# stack: ['Mapped', 'List', '"Address"']
annotation = "[".join(stack) + ("]" * (len(stack) - 1))
@@ -2007,6 +2014,7 @@ def _extract_mapped_subtype(
Includes error raise scenarios and other options.
"""
+
if raw_annotation is None:
if required:
@@ -2017,9 +2025,19 @@ def _extract_mapped_subtype(
)
return None
- annotated = de_stringify_annotation(
- cls, raw_annotation, _cleanup_mapped_str_annotation
- )
+ try:
+ annotated = de_stringify_annotation(
+ cls, raw_annotation, _cleanup_mapped_str_annotation
+ )
+ except NameError as ne:
+ if raiseerr and "Mapped[" in raw_annotation: # type: ignore
+ raise sa_exc.ArgumentError(
+ f"Could not interpret annotation {raw_annotation}. "
+ "Check that it's not using names that might not be imported "
+ "at the module level. See chained stack trace for more hints."
+ ) from ne
+
+ annotated = raw_annotation # type: ignore
if is_dataclass_field:
return annotated