summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2023-01-25 22:45:31 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2023-01-25 22:45:31 +0000
commitf24c0521cee3b2c1fa224d5dbc6441b9d5f8c0bb (patch)
tree213c8a8719b43982797f173e45eebbfd3a76e4eb
parentec2151fd5915d35ba9a8b9f09b9b677a209a66ad (diff)
parent1526cf68af500141480cc51ec4de18c705fe0b0a (diff)
downloadsqlalchemy-f24c0521cee3b2c1fa224d5dbc6441b9d5f8c0bb.tar.gz
Merge "Add public protocol for mapped class" into main
-rw-r--r--doc/build/orm/mapping_api.rst3
-rw-r--r--doc/build/orm/mapping_styles.rst2
-rw-r--r--lib/sqlalchemy/orm/__init__.py1
-rw-r--r--lib/sqlalchemy/orm/decl_base.py79
-rw-r--r--test/ext/mypy/plain_files/declared_attr_one.py19
-rw-r--r--test/orm/declarative/test_basic.py18
6 files changed, 94 insertions, 28 deletions
diff --git a/doc/build/orm/mapping_api.rst b/doc/build/orm/mapping_api.rst
index 8eebe7c77..1a33f9566 100644
--- a/doc/build/orm/mapping_api.rst
+++ b/doc/build/orm/mapping_api.rst
@@ -141,3 +141,6 @@ Class Mapping API
.. autoclass:: MappedAsDataclass
:members:
+
+.. autoclass:: MappedClassProtocol
+ :no-members:
diff --git a/doc/build/orm/mapping_styles.rst b/doc/build/orm/mapping_styles.rst
index b26399393..b4c21a353 100644
--- a/doc/build/orm/mapping_styles.rst
+++ b/doc/build/orm/mapping_styles.rst
@@ -36,6 +36,8 @@ the class itself has been :term:`instrumented` to include behaviors linked to
relational operations both at the level of the class as well as on instances of
that class. As the process is basically the same in all cases, classes mapped
from different styles are always fully interoperable with each other.
+The protocol :class:`_orm.MappedClassProtocol` can be used to indicate a mapped
+class when using type checkers such as mypy.
The original mapping API is commonly referred to as "classical" style,
whereas the more automated style of mapping is known as "declarative" style.
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
index 6980db2e2..d54e1ccb9 100644
--- a/lib/sqlalchemy/orm/__init__.py
+++ b/lib/sqlalchemy/orm/__init__.py
@@ -65,6 +65,7 @@ from .decl_api import has_inherited_table as has_inherited_table
from .decl_api import MappedAsDataclass as MappedAsDataclass
from .decl_api import registry as registry
from .decl_api import synonym_for as synonym_for
+from .decl_base import MappedClassProtocol as MappedClassProtocol
from .descriptor_props import Composite as Composite
from .descriptor_props import CompositeProperty as CompositeProperty
from .descriptor_props import Synonym as Synonym
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py
index a379af2dd..9e8b02359 100644
--- a/lib/sqlalchemy/orm/decl_base.py
+++ b/lib/sqlalchemy/orm/decl_base.py
@@ -49,7 +49,6 @@ from .interfaces import _IntrospectsAnnotations
from .interfaces import _MappedAttribute
from .interfaces import _MapsColumns
from .interfaces import MapperProperty
-from .mapper import Mapper as mapper
from .mapper import Mapper
from .properties import ColumnProperty
from .properties import MappedColumn
@@ -84,25 +83,38 @@ if TYPE_CHECKING:
_T = TypeVar("_T", bound=Any)
_MapperKwArgs = Mapping[str, Any]
-
_TableArgsType = Union[Tuple[Any, ...], Dict[str, Any]]
-class _DeclMappedClassProtocol(Protocol[_O]):
- metadata: MetaData
+class MappedClassProtocol(Protocol[_O]):
+ """A protocol representing a SQLAlchemy mapped class.
+
+ The protocol is generic on the type of class, use
+ ``MappedClassProtocol[Any]`` to allow any mapped class.
+ """
+
+ __name__: str
__mapper__: Mapper[_O]
- __table__: Table
+ __table__: FromClause
+
+ def __call__(self, **kw: Any) -> _O:
+ ...
+
+
+class _DeclMappedClassProtocol(MappedClassProtocol[_O], Protocol):
+ "Internal more detailed version of ``MappedClassProtocol``."
+ metadata: MetaData
__tablename__: str
- __mapper_args__: Mapping[str, Any]
+ __mapper_args__: _MapperKwArgs
__table_args__: Optional[_TableArgsType]
_sa_apply_dc_transforms: Optional[_DataclassArguments]
def __declare_first__(self) -> None:
- pass
+ ...
def __declare_last__(self) -> None:
- pass
+ ...
class _DataclassArguments(TypedDict):
@@ -241,7 +253,7 @@ def _mapper(
mapper_kw: _MapperKwArgs,
) -> Mapper[_O]:
_ImperativeMapperConfig(registry, cls, table, mapper_kw)
- return cast("_DeclMappedClassProtocol[_O]", cls).__mapper__
+ return cast("MappedClassProtocol[_O]", cls).__mapper__
@util.preload_module("sqlalchemy.orm.decl_api")
@@ -297,7 +309,7 @@ class _MapperConfig:
manager = attributes.opt_manager_of_class(cls)
if manager and manager.class_ is cls_:
raise exc.InvalidRequestError(
- "Class %r already has been " "instrumented declaratively" % cls
+ f"Class {cls!r} already has been instrumented declaratively"
)
if cls_.__dict__.get("__abstract__", False):
@@ -382,7 +394,7 @@ class _ImperativeMapperConfig(_MapperConfig):
self._early_mapping(mapper_kw)
def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]:
- mapper_cls = mapper
+ mapper_cls = Mapper
return self.set_cls_attribute(
"__mapper__",
@@ -413,7 +425,7 @@ class _ImperativeMapperConfig(_MapperConfig):
% (cls, inherits_search)
)
inherits = inherits_search[0]
- elif isinstance(inherits, mapper):
+ elif isinstance(inherits, Mapper):
inherits = inherits.class_
self.inherits = inherits
@@ -567,7 +579,7 @@ class _ClassScanMapperConfig(_MapperConfig):
def _setup_declared_events(self) -> None:
if _get_immediate_cls_attr(self.cls, "__declare_last__"):
- @event.listens_for(mapper, "after_configured")
+ @event.listens_for(Mapper, "after_configured")
def after_configured() -> None:
cast(
"_DeclMappedClassProtocol[Any]", self.cls
@@ -575,7 +587,7 @@ class _ClassScanMapperConfig(_MapperConfig):
if _get_immediate_cls_attr(self.cls, "__declare_first__"):
- @event.listens_for(mapper, "before_configured")
+ @event.listens_for(Mapper, "before_configured")
def before_configured() -> None:
cast(
"_DeclMappedClassProtocol[Any]", self.cls
@@ -1507,7 +1519,7 @@ class _ClassScanMapperConfig(_MapperConfig):
def _setup_table(self, table: Optional[FromClause] = None) -> None:
cls = self.cls
- cls_as_Decl = cast("_DeclMappedClassProtocol[Any]", cls)
+ cls_as_Decl = cast("MappedClassProtocol[Any]", cls)
tablename = self.tablename
table_args = self.table_args
@@ -1570,8 +1582,9 @@ class _ClassScanMapperConfig(_MapperConfig):
self.local_table = table
def _metadata_for_cls(self, manager: ClassManager[Any]) -> MetaData:
- if hasattr(self.cls, "metadata"):
- return cast("_DeclMappedClassProtocol[Any]", self.cls).metadata
+ meta: Optional[MetaData] = getattr(self.cls, "metadata", None)
+ if meta is not None:
+ return meta
else:
return manager.registry.metadata
@@ -1599,7 +1612,7 @@ class _ClassScanMapperConfig(_MapperConfig):
% (cls, inherits_search)
)
inherits = inherits_search[0]
- elif isinstance(inherits, mapper):
+ elif isinstance(inherits, Mapper):
inherits = inherits.class_
self.inherits = inherits
@@ -1701,7 +1714,7 @@ class _ClassScanMapperConfig(_MapperConfig):
if "inherits" in mapper_args:
inherits_arg = mapper_args["inherits"]
- if isinstance(inherits_arg, mapper):
+ if isinstance(inherits_arg, Mapper):
inherits_arg = inherits_arg.class_
if inherits_arg is not self.inherits:
@@ -1762,7 +1775,7 @@ class _ClassScanMapperConfig(_MapperConfig):
),
)
else:
- mapper_cls = mapper
+ mapper_cls = Mapper
return self.set_cls_attribute(
"__mapper__",
@@ -1873,18 +1886,29 @@ def _add_attribute(
"""
if "__mapper__" in cls.__dict__:
- mapped_cls = cast("_DeclMappedClassProtocol[Any]", cls)
+ mapped_cls = cast("MappedClassProtocol[Any]", cls)
+
+ def _table_or_raise(mc: MappedClassProtocol[Any]) -> Table:
+ if isinstance(mc.__table__, Table):
+ return mc.__table__
+ raise exc.InvalidRequestError(
+ f"Cannot add a new attribute to mapped class {mc.__name__!r} "
+ "because it's not mapped against a table."
+ )
+
if isinstance(value, Column):
_undefer_column_name(key, value)
- # TODO: raise for this is not a Table
- mapped_cls.__table__.append_column(value, replace_existing=True)
+ _table_or_raise(mapped_cls).append_column(
+ value, replace_existing=True
+ )
mapped_cls.__mapper__.add_property(key, value)
elif isinstance(value, _MapsColumns):
mp = value.mapper_property_to_assign
for col in value.columns_to_assign:
_undefer_column_name(key, col)
- # TODO: raise for this is not a Table
- mapped_cls.__table__.append_column(col, replace_existing=True)
+ _table_or_raise(mapped_cls).append_column(
+ col, replace_existing=True
+ )
if not mp:
mapped_cls.__mapper__.add_property(key, col)
if mp:
@@ -1904,12 +1928,11 @@ def _add_attribute(
def _del_attribute(cls: Type[Any], key: str) -> None:
-
if (
"__mapper__" in cls.__dict__
and key in cls.__dict__
and not cast(
- "_DeclMappedClassProtocol[Any]", cls
+ "MappedClassProtocol[Any]", cls
).__mapper__._dispose_called
):
value = cls.__dict__[key]
@@ -1922,7 +1945,7 @@ def _del_attribute(cls: Type[Any], key: str) -> None:
else:
type.__delattr__(cls, key)
cast(
- "_DeclMappedClassProtocol[Any]", cls
+ "MappedClassProtocol[Any]", cls
).__mapper__._expire_memoizations()
else:
type.__delattr__(cls, key)
diff --git a/test/ext/mypy/plain_files/declared_attr_one.py b/test/ext/mypy/plain_files/declared_attr_one.py
index a6d96f39e..d4f3c826e 100644
--- a/test/ext/mypy/plain_files/declared_attr_one.py
+++ b/test/ext/mypy/plain_files/declared_attr_one.py
@@ -10,6 +10,7 @@ from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import declared_attr
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import MappedClassProtocol
from sqlalchemy.sql.schema import PrimaryKeyConstraint
@@ -70,6 +71,24 @@ class Manager(Employee):
)
+def do_something_with_mapped_class(
+ cls_: MappedClassProtocol[Employee],
+) -> None:
+
+ # EXPECTED_TYPE: Select[Any]
+ reveal_type(cls_.__table__.select())
+
+ # EXPECTED_TYPE: Mapper[Employee]
+ reveal_type(cls_.__mapper__)
+
+ # EXPECTED_TYPE: Employee
+ reveal_type(cls_())
+
+
+do_something_with_mapped_class(Manager)
+do_something_with_mapped_class(Engineer)
+
+
if typing.TYPE_CHECKING:
# EXPECTED_TYPE: InstrumentedAttribute[datetime]
diff --git a/test/orm/declarative/test_basic.py b/test/orm/declarative/test_basic.py
index 83d103864..28fdc97f2 100644
--- a/test/orm/declarative/test_basic.py
+++ b/test/orm/declarative/test_basic.py
@@ -611,6 +611,24 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase):
sa.Column("id", Integer, primary_key=True),
)
+ def test_cannot_add_to_selectable(self):
+ class Base(DeclarativeBase):
+ pass
+
+ class Foo(Base):
+ __table__ = (
+ select(sa.Column("x", sa.Integer, primary_key=True))
+ .select_from(sa.table("foo"))
+ .subquery("foo")
+ )
+
+ with assertions.expect_raises_message(
+ exc.InvalidRequestError,
+ "Cannot add a new attribute to mapped class 'Foo' "
+ "because it's not mapped against a table",
+ ):
+ Foo.y = mapped_column(sa.Text)
+
@testing.combinations(
("declarative_base_nometa_superclass",),