diff options
| author | Peter Schutt <peter.github@proton.me> | 2022-09-01 19:11:40 -0400 |
|---|---|---|
| committer | sqla-tester <sqla-tester@sqlalchemy.org> | 2022-09-01 19:11:40 -0400 |
| commit | c3cfee5b00a40790c18d444a1ea1454aafc27889 (patch) | |
| tree | f6a557b289d5a4c567b70c207ec887c1bd18a08d /lib/sqlalchemy | |
| parent | d3e0b8e750d864766148cdf1a658a601079eed46 (diff) | |
| download | sqlalchemy-c3cfee5b00a40790c18d444a1ea1454aafc27889.tar.gz | |
Detection of PEP 604 union syntax.
### Description
Fixes #8478
Handle `UnionType` as arguments to `Mapped`, e.g., `Mapped[str | None]`:
- adds `utils.typing.is_optional_union()` used to detect if a column should be nullable.
- adds `"UnionType"` to `utils.typing.is_optional()` names.
- uses `get_origin()` in `utils.typing.is_origin_of()` as `UnionType` has no `__origin__` attribute.
- tests with runtime type and postponed annotations and guard the tests running with `compat.py310`.
### Checklist
<!-- go over following points. check them with an `x` if they do apply, (they turn into clickable checkboxes once the PR is submitted, so no need to do everything at once)
-->
This pull request is:
- [ ] A documentation / typographical error fix
- Good to go, no issue or tests are needed
- [x] A short code fix
- please include the issue number, and create an issue if none exists, which
must include a complete example of the issue. one line code fixes without an
issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests. one line code fixes without tests will not be accepted.
- [ ] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.
**Have a nice day!**
Closes: #8479
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8479
Pull-request-sha: 12417654822272c5847e684c53677f665553ef0e
Change-Id: Ib3248043dd4a97324ac592c048385006536b2d49
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/orm/properties.py | 10 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/typing.py | 9 |
2 files changed, 11 insertions, 8 deletions
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 6213cfef8..7d7175678 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -52,8 +52,8 @@ from ..sql.schema import SchemaConst from ..util.typing import de_optionalize_union_types from ..util.typing import de_stringify_annotation from ..util.typing import is_fwd_ref +from ..util.typing import is_optional_union from ..util.typing import is_pep593 -from ..util.typing import NoneType from ..util.typing import Self from ..util.typing import typing_get_args @@ -652,17 +652,15 @@ class MappedColumn( ) -> None: sqltype = self.column.type - nullable = False + if is_fwd_ref(argument): + argument = de_stringify_annotation(cls, argument) - if hasattr(argument, "__origin__"): - nullable = NoneType in argument.__args__ # type: ignore + nullable = is_optional_union(argument) if not self._has_nullable: self.column.nullable = nullable our_type = de_optionalize_union_types(argument) - if is_fwd_ref(our_type): - our_type = de_stringify_annotation(cls, our_type) use_args_from = None if is_pep593(our_type): diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 45fe63765..85c1bae72 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -169,7 +169,7 @@ def make_union_type(*types: _AnnotationScanType) -> Type[Any]: def expand_unions( type_: Type[Any], include_union: bool = False, discard_none: bool = False ) -> Tuple[Type[Any], ...]: - """Return a type as as a tuple of individual types, expanding for + """Return a type as a tuple of individual types, expanding for ``Union`` types.""" if is_union(type_): @@ -191,9 +191,14 @@ def is_optional(type_): type_, "Optional", "Union", + "UnionType", ) +def is_optional_union(type_: Any) -> bool: + return is_optional(type_) and NoneType in typing_get_args(type_) + + def is_union(type_): return is_origin_of(type_, "Union") @@ -204,7 +209,7 @@ def is_origin_of( """return True if the given type has an __origin__ with the given name and optional module.""" - origin = getattr(type_, "__origin__", None) + origin = typing_get_origin(type_) if origin is None: return False |
