diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-09-25 21:40:48 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-09-25 21:41:14 -0400 |
| commit | 74f6e38ec717979bcde76244c94b1d8c519a5b63 (patch) | |
| tree | e36e3c8a6fc1c7d04b1d34007ea4a71fef47a453 | |
| parent | e708cfea0bdaae82ac30dd7d33f9442115b9af6d (diff) | |
| download | sqlalchemy-74f6e38ec717979bcde76244c94b1d8c519a5b63.tar.gz | |
add typing for sqlalchemy.orm.validates
Fixes: #8577
Change-Id: Iede1c956078960fb866da45f1ac6aa43842516bc
| -rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 14 | ||||
| -rw-r--r-- | test/ext/mypy/plain_files/orm_config_constructs.py | 20 |
2 files changed, 28 insertions, 6 deletions
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 98c0eba0c..553f7b35b 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -118,6 +118,8 @@ if TYPE_CHECKING: _T = TypeVar("_T", bound=Any) _MP = TypeVar("_MP", bound="MapperProperty[Any]") +_Fn = TypeVar("_Fn", bound="Callable[..., Any]") + _WithPolymorphicArg = Union[ Literal["*"], @@ -3895,7 +3897,9 @@ def reconstructor(fn): return fn -def validates(*names, **kw): +def validates( + *names: str, include_removes: bool = False, include_backrefs: bool = False +) -> Callable[[_Fn], _Fn]: r"""Decorate a method as a 'validator' for one or more named properties. Designates a method as a validator, a method which receives the @@ -3930,12 +3934,10 @@ def validates(*names, **kw): :ref:`simple_validators` - usage examples for :func:`.validates` """ - include_removes = kw.pop("include_removes", False) - include_backrefs = kw.pop("include_backrefs", True) - def wrap(fn): - fn.__sa_validators__ = names - fn.__sa_validation_opts__ = { + def wrap(fn: _Fn) -> _Fn: + fn.__sa_validators__ = names # type: ignore[attr-defined] + fn.__sa_validation_opts__ = { # type: ignore[attr-defined] "include_removes": include_removes, "include_backrefs": include_backrefs, } diff --git a/test/ext/mypy/plain_files/orm_config_constructs.py b/test/ext/mypy/plain_files/orm_config_constructs.py new file mode 100644 index 000000000..008e16f24 --- /dev/null +++ b/test/ext/mypy/plain_files/orm_config_constructs.py @@ -0,0 +1,20 @@ +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import validates + + +class Base(DeclarativeBase): + pass + + +class User(Base): + __tablename__ = "User" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + + @validates("name", include_removes=True) + def validate_name(self, name: str) -> str: + """test #8577""" + return name + "hi" |
