summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2021-03-14 21:31:08 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2021-03-14 21:31:08 +0000
commit16b13ba06b661a6dfb13b1537ee873c9065b084e (patch)
treea15d91f80324fd21bc38093efb7a5a92e8080577 /lib/sqlalchemy/ext
parent827af37d681b61de1300d7dacc5c50ff23a4fbf9 (diff)
parent5f8ee3920066c0cbe5d6d6b0ceb987524f7542c4 (diff)
downloadsqlalchemy-16b13ba06b661a6dfb13b1537ee873c9065b084e.tar.gz
Merge "Implement Mypy plugin"
Diffstat (limited to 'lib/sqlalchemy/ext')
-rw-r--r--lib/sqlalchemy/ext/mypy/__init__.py0
-rw-r--r--lib/sqlalchemy/ext/mypy/decl_class.py989
-rw-r--r--lib/sqlalchemy/ext/mypy/names.py194
-rw-r--r--lib/sqlalchemy/ext/mypy/plugin.py215
-rw-r--r--lib/sqlalchemy/ext/mypy/util.py80
5 files changed, 1478 insertions, 0 deletions
diff --git a/lib/sqlalchemy/ext/mypy/__init__.py b/lib/sqlalchemy/ext/mypy/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/__init__.py
diff --git a/lib/sqlalchemy/ext/mypy/decl_class.py b/lib/sqlalchemy/ext/mypy/decl_class.py
new file mode 100644
index 000000000..f5215ca1c
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/decl_class.py
@@ -0,0 +1,989 @@
+# ext/mypy/decl_class.py
+# Copyright (C) 2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Type
+from typing import Union
+
+from mypy import nodes
+from mypy import types
+from mypy.messages import format_type
+from mypy.nodes import ARG_NAMED_OPT
+from mypy.nodes import Argument
+from mypy.nodes import AssignmentStmt
+from mypy.nodes import CallExpr
+from mypy.nodes import ClassDef
+from mypy.nodes import Decorator
+from mypy.nodes import JsonDict
+from mypy.nodes import ListExpr
+from mypy.nodes import MDEF
+from mypy.nodes import NameExpr
+from mypy.nodes import PlaceholderNode
+from mypy.nodes import RefExpr
+from mypy.nodes import StrExpr
+from mypy.nodes import SymbolTableNode
+from mypy.nodes import TempNode
+from mypy.nodes import TypeInfo
+from mypy.nodes import Var
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.plugins.common import add_method_to_class
+from mypy.plugins.common import deserialize_and_fixup_type
+from mypy.subtypes import is_subtype
+from mypy.types import AnyType
+from mypy.types import Instance
+from mypy.types import NoneTyp
+from mypy.types import NoneType
+from mypy.types import TypeOfAny
+from mypy.types import UnboundType
+from mypy.types import UnionType
+
+from . import names
+from . import util
+
+
class DeclClassApplied:
    """Record of the ORM-mapping state detected for a declarative class.

    Instances are serialized into ``TypeInfo.metadata`` so that the result
    of a scan survives mypy's incremental-cache round trips.
    """

    def __init__(
        self,
        is_mapped: bool,
        has_table: bool,
        mapped_attr_names: Sequence[Tuple[str, Type]],
        mapped_mro: Sequence[Type],
    ):
        # True when the class itself is ORM-mapped (not just a mixin scan)
        self.is_mapped = is_mapped
        # True when a __tablename__ assignment was seen on the class
        self.has_table = has_table
        # (attribute name, mapped type) pairs discovered on the class
        self.mapped_attr_names = mapped_attr_names
        # bases in the hierarchy that carry mapped attributes
        self.mapped_mro = mapped_mro

    def serialize(self) -> JsonDict:
        """Render this record as a JSON-compatible dictionary."""
        serialized: JsonDict = {
            "is_mapped": self.is_mapped,
            "has_table": self.has_table,
        }
        serialized["mapped_attr_names"] = [
            (name, type_.serialize())
            for name, type_ in self.mapped_attr_names
        ]
        serialized["mapped_mro"] = [
            type_.serialize() for type_ in self.mapped_mro
        ]
        return serialized

    @classmethod
    def deserialize(
        cls, data: JsonDict, api: SemanticAnalyzerPluginInterface
    ) -> "DeclClassApplied":
        """Reconstruct a record previously written by :meth:`serialize`."""
        attr_names = [
            (name, deserialize_and_fixup_type(type_, api))
            for name, type_ in data["mapped_attr_names"]
        ]
        mro = [
            deserialize_and_fixup_type(type_, api)
            for type_ in data["mapped_mro"]
        ]
        return DeclClassApplied(
            is_mapped=data["is_mapped"],
            has_table=data["has_table"],
            mapped_attr_names=attr_names,
            mapped_mro=mro,
        )
+
+
def _scan_declarative_assignments_and_apply_types(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    is_mixin_scan: bool = False,
) -> Optional[DeclClassApplied]:
    """Scan a declarative class body for mapped attributes and rewrite them.

    Returns a :class:`DeclClassApplied` describing what was found, or
    ``None`` for classes that are never scanned (builtins).

    :param cls: class definition node under analysis.
    :param api: mypy semantic-analyzer plugin API.
    :param is_mixin_scan: True when the class was reached while walking the
     bases of a mapped class, rather than via its own decorator/metaclass.
    """

    if cls.fullname.startswith("builtins"):
        return None
    elif "_sa_decl_class_applied" in cls.info.metadata:
        # already scanned on a previous pass; restore the serialized result
        cls_metadata = DeclClassApplied.deserialize(
            cls.info.metadata["_sa_decl_class_applied"], api
        )

        # ensure that a class that's mapped is always picked up by
        # its mapped() decorator or declarative metaclass before
        # it would be detected as an unmapped mixin class
        if not is_mixin_scan:
            assert cls_metadata.is_mapped

        # mypy can call us more than once. it then will have reset the
        # left hand side of everything, but not the right that we removed,
        # removing our ability to re-scan. but we have the types
        # here, so lets re-apply them.

        _re_apply_declarative_assignments(cls, api, cls_metadata)

        return cls_metadata

    # first pass over this class: start from a blank record
    cls_metadata = DeclClassApplied(not is_mixin_scan, False, [], [])

    for stmt in util._flatten_typechecking(cls.defs.body):
        if isinstance(stmt, AssignmentStmt):
            _scan_declarative_assignment_stmt(cls, api, stmt, cls_metadata)
        elif isinstance(stmt, Decorator):
            # @declared_attr methods
            _scan_declarative_decorator_stmt(cls, api, stmt, cls_metadata)
    _scan_for_mapped_bases(cls, api, cls_metadata)
    _add_additional_orm_attributes(cls, api, cls_metadata)

    # persist the scan result so incremental runs can restore it above
    cls.info.metadata["_sa_decl_class_applied"] = cls_metadata.serialize()

    return cls_metadata
+
+
def _scan_declarative_decorator_stmt(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    stmt: Decorator,
    cls_metadata: DeclClassApplied,
):
    """Extract mapping information from a @declared_attr in a declarative
    class.

    E.g.::

        @reg.mapped
        class MyClass:
            # ...

            @declared_attr
            def updated_at(cls) -> Column[DateTime]:
                return Column(DateTime)

    Will resolve in mypy as::

        @reg.mapped
        class MyClass:
            # ...

            updated_at: Mapped[Optional[datetime.datetime]]

    """
    # only act on functions decorated with declared_attr
    for dec in stmt.decorators:
        if names._type_id_for_named_node(dec) is names.DECLARED_ATTR:
            break
    else:
        return

    dec_index = cls.defs.body.index(stmt)

    left_hand_explicit_type = None

    if stmt.func.type is not None:
        func_type = stmt.func.type.ret_type
        if isinstance(func_type, UnboundType):
            type_id = names._type_id_for_unbound_type(func_type, cls, api)
        else:
            # this does not seem to occur unless the type argument is
            # incorrect
            return

        if (
            type_id
            in {
                names.MAPPED,
                names.RELATIONSHIP,
                names.COMPOSITE_PROPERTY,
                names.MAPPER_PROPERTY,
                names.SYNONYM_PROPERTY,
                names.COLUMN_PROPERTY,
            }
            and func_type.args
        ):
            left_hand_explicit_type = func_type.args[0]
        elif type_id is names.COLUMN and func_type.args:
            typeengine_arg = func_type.args[0]
            if isinstance(typeengine_arg, UnboundType):
                sym = api.lookup(typeengine_arg.name, typeengine_arg)
                # BUGFIX: the annotation may reference a name that could not
                # be resolved (sym is None) or that is not a class at all;
                # previously the failure branch dereferenced sym.node.fullname
                # and sym.node.mro unconditionally, raising AttributeError.
                if sym is not None and isinstance(sym.node, TypeInfo):
                    if names._mro_has_id(sym.node.mro, names.TYPEENGINE):
                        left_hand_explicit_type = UnionType(
                            [
                                _extract_python_type_from_typeengine(
                                    sym.node
                                ),
                                NoneType(),
                            ]
                        )
                    else:
                        util.fail(
                            api,
                            "Column type should be a TypeEngine "
                            "subclass not '{}'".format(sym.node.fullname),
                            func_type,
                        )

    if left_hand_explicit_type is None:
        # no type on the decorated function. our option here is to
        # dig into the function body and get the return type, but they
        # should just have an annotation.
        msg = (
            "Can't infer type from @declared_attr on function '{}'; "
            "please specify a return type from this function that is "
            "one of: Mapped[<python type>], relationship[<target class>], "
            "Column[<TypeEngine>], MapperProperty[<python type>]"
        )
        util.fail(api, msg.format(stmt.var.name), stmt)

        left_hand_explicit_type = AnyType(TypeOfAny.special_form)

    descriptor = api.modules["sqlalchemy.orm.attributes"].names["Mapped"]

    left_node = NameExpr(stmt.var.name)
    left_node.node = stmt.var

    # totally feeling around in the dark here as I don't totally understand
    # the significance of UnboundType. It seems to be something that is
    # not going to do what's expected when it is applied as the type of
    # an AssignmentStatement. So do a feeling-around-in-the-dark version
    # of converting it to the regular Instance/TypeInfo/UnionType structures
    # we see everywhere else.
    if isinstance(left_hand_explicit_type, UnboundType):
        left_hand_explicit_type = util._unbound_to_instance(
            api, left_hand_explicit_type
        )

    left_node.node.type = Instance(descriptor.node, [left_hand_explicit_type])

    # this will ignore the rvalue entirely
    # rvalue = TempNode(AnyType(TypeOfAny.special_form))

    # rewrite the node as:
    # <attr> : Mapped[<typ>] =
    # _sa_Mapped._empty_constructor(lambda: <function body>)
    # the function body is maintained so it gets type checked internally
    api.add_symbol_table_node("_sa_Mapped", descriptor)
    column_descriptor = nodes.NameExpr("_sa_Mapped")
    column_descriptor.fullname = "sqlalchemy.orm.Mapped"
    mm = nodes.MemberExpr(column_descriptor, "_empty_constructor")

    arg = nodes.LambdaExpr(stmt.func.arguments, stmt.func.body)
    rvalue = CallExpr(
        mm,
        [arg],
        [nodes.ARG_POS],
        ["arg1"],
    )

    new_stmt = AssignmentStmt([left_node], rvalue)
    new_stmt.type = left_node.node.type

    cls_metadata.mapped_attr_names.append(
        (left_node.name, left_hand_explicit_type)
    )
    cls.defs.body[dec_index] = new_stmt
+
+
def _scan_declarative_assignment_stmt(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    stmt: AssignmentStmt,
    cls_metadata: DeclClassApplied,
) -> None:
    """Extract mapping information from an assignment statement in a
    declarative class.

    Handles dunder directives (``__abstract__``, ``__tablename__``),
    the ``_mypy_mapped_attrs`` escape hatch, and ORM constructs such as
    ``Column()`` / ``relationship()`` on the right hand side, recording
    each mapped attribute in ``cls_metadata`` and rewriting the statement.

    """
    lvalue = stmt.lvalues[0]
    if not isinstance(lvalue, NameExpr):
        return

    sym = cls.info.names.get(lvalue.name)

    # this establishes that semantic analysis has taken place, which
    # means the nodes are populated and we are called from an appropriate
    # hook.
    assert sym is not None
    node = sym.node

    if isinstance(node, PlaceholderNode):
        return

    assert node is lvalue.node
    assert isinstance(node, Var)

    if node.name == "__abstract__":
        # __abstract__ = True means the class is not mapped itself
        if stmt.rvalue.fullname == "builtins.True":
            cls_metadata.is_mapped = False
        return
    elif node.name == "__tablename__":
        cls_metadata.has_table = True
    elif node.name.startswith("__"):
        # other dunders are not mapped attributes
        return
    elif node.name == "_mypy_mapped_attrs":
        if not isinstance(stmt.rvalue, ListExpr):
            util.fail(api, "_mypy_mapped_attrs is expected to be a list", stmt)
        else:
            for item in stmt.rvalue.items:
                if isinstance(item, (NameExpr, StrExpr)):
                    _apply_mypy_mapped_attr(cls, api, item, cls_metadata)

    left_hand_mapped_type: Optional[Type] = None

    if node.is_inferred or node.type is None:
        if isinstance(stmt.type, UnboundType):
            # look for an explicit Mapped[] type annotation on the left
            # side with nothing on the right

            # print(stmt.type)
            # Mapped?[Optional?[A?]]

            left_hand_explicit_type = stmt.type

            if stmt.type.name == "Mapped":
                mapped_sym = api.lookup("Mapped", cls)
                if (
                    mapped_sym is not None
                    and names._type_id_for_named_node(mapped_sym.node)
                    is names.MAPPED
                ):
                    # unwrap Mapped[X] to X
                    left_hand_explicit_type = stmt.type.args[0]
                    left_hand_mapped_type = stmt.type

            # TODO: do we need to convert from unbound for this case?
            # left_hand_explicit_type = util._unbound_to_instance(
            #     api, left_hand_explicit_type
            # )

        else:
            left_hand_explicit_type = None
    else:
        if (
            isinstance(node.type, Instance)
            and names._type_id_for_named_node(node.type.type) is names.MAPPED
        ):
            # print(node.type)
            # sqlalchemy.orm.attributes.Mapped[<python type>]
            left_hand_explicit_type = node.type.args[0]
            left_hand_mapped_type = node.type
        else:
            # print(node.type)
            # <python type>
            left_hand_explicit_type = node.type
            left_hand_mapped_type = None

    if isinstance(stmt.rvalue, TempNode) and left_hand_mapped_type is not None:
        # annotation without assignment and Mapped is present
        # as type annotation
        # equivalent to using _infer_type_from_left_hand_type_only.

        python_type_for_type = left_hand_explicit_type
    elif isinstance(stmt.rvalue, CallExpr) and isinstance(
        stmt.rvalue.callee, RefExpr
    ):
        # right hand side is a call: dispatch on which ORM construct
        # is being invoked to infer the mapped Python type

        type_id = names._type_id_for_callee(stmt.rvalue.callee)

        if type_id is None:
            return
        elif type_id is names.COLUMN:
            python_type_for_type = _infer_type_from_decl_column(
                api, stmt, node, left_hand_explicit_type, stmt.rvalue
            )
        elif type_id is names.RELATIONSHIP:
            python_type_for_type = _infer_type_from_relationship(
                api, stmt, node, left_hand_explicit_type
            )
        elif type_id is names.COLUMN_PROPERTY:
            python_type_for_type = _infer_type_from_decl_column_property(
                api, stmt, node, left_hand_explicit_type
            )
        elif type_id is names.SYNONYM_PROPERTY:
            python_type_for_type = _infer_type_from_left_hand_type_only(
                api, node, left_hand_explicit_type
            )
        elif type_id is names.COMPOSITE_PROPERTY:
            python_type_for_type = _infer_type_from_decl_composite_property(
                api, stmt, node, left_hand_explicit_type
            )
        else:
            return

    else:
        return

    cls_metadata.mapped_attr_names.append((node.name, python_type_for_type))

    assert python_type_for_type is not None

    _apply_type_to_mapped_statement(
        api,
        stmt,
        lvalue,
        left_hand_explicit_type,
        python_type_for_type,
    )
+
+
def _apply_mypy_mapped_attr(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    item: Union[NameExpr, StrExpr],
    cls_metadata: DeclClassApplied,
):
    """Mark one attribute listed in ``_mypy_mapped_attrs`` as mapped.

    ``item`` names the attribute either as a bare name or a string; the
    matching assignment in the class body must carry a type annotation.
    """
    if isinstance(item, NameExpr):
        name = item.name
    elif isinstance(item, StrExpr):
        name = item.value
    else:
        return

    # locate the assignment statement in the class body for this name
    matched = None
    for body_stmt in cls.defs.body:
        if (
            isinstance(body_stmt, AssignmentStmt)
            and body_stmt.lvalues[0].name == name
        ):
            matched = body_stmt
            break

    if matched is None:
        util.fail(api, "Can't find mapped attribute {}".format(name), cls)
        return

    if matched.type is None:
        util.fail(
            api,
            "Statement linked from _mypy_mapped_attrs has no "
            "typing information",
            matched,
        )
        return

    left_hand_explicit_type = matched.type
    cls_metadata.mapped_attr_names.append((name, left_hand_explicit_type))

    _apply_type_to_mapped_statement(
        api, matched, matched.lvalues[0], left_hand_explicit_type, None
    )
+
+
def _infer_type_from_relationship(
    api: SemanticAnalyzerPluginInterface,
    stmt: AssignmentStmt,
    node: Var,
    left_hand_explicit_type: Optional[types.Type],
) -> Union[Instance, UnionType, None]:
    """Infer the type of mapping from a relationship.

    E.g.::

        @reg.mapped
        class MyClass:
            # ...

            addresses = relationship(Address, uselist=True)

            order: Mapped["Order"] = relationship("Order")

    Will resolve in mypy as::

        @reg.mapped
        class MyClass:
            # ...

            addresses: Mapped[List[Address]]

            order: Mapped["Order"]

    """

    assert isinstance(stmt.rvalue, CallExpr)
    target_cls_arg = stmt.rvalue.args[0]
    python_type_for_type = None

    # only a direct class reference as the first positional argument can
    # be resolved here; strings and other forms fall through to None
    if isinstance(target_cls_arg, NameExpr) and isinstance(
        target_cls_arg.node, TypeInfo
    ):
        # type
        related_object_type = target_cls_arg.node
        python_type_for_type = Instance(related_object_type, [])

    # other cases not covered - an error message directs the user
    # to set an explicit type annotation
    #
    # node.type == str, it's a string
    # if isinstance(target_cls_arg, NameExpr) and isinstance(
    #     target_cls_arg.node, Var
    # )
    # points to a type
    # isinstance(target_cls_arg, NameExpr) and isinstance(
    #     target_cls_arg.node, TypeAlias
    # )
    # string expression
    # isinstance(target_cls_arg, StrExpr)

    uselist_arg = util._get_callexpr_kwarg(stmt.rvalue, "uselist")
    collection_cls_arg = util._get_callexpr_kwarg(
        stmt.rvalue, "collection_class"
    )

    # this can be used to determine Optional for a many-to-one
    # in the same way nullable=False could be used, if we start supporting
    # that.
    # innerjoin_arg = _get_callexpr_kwarg(stmt.rvalue, "innerjoin")

    # uselist=True with no collection_class: a plain list collection
    if (
        uselist_arg is not None
        and uselist_arg.fullname == "builtins.True"
        and collection_cls_arg is None
    ):
        if python_type_for_type is not None:
            python_type_for_type = Instance(
                api.lookup_fully_qualified("builtins.list").node,
                [python_type_for_type],
            )
    # explicit collection_class (with uselist unset or True)
    elif (
        uselist_arg is None or uselist_arg.fullname == "builtins.True"
    ) and collection_cls_arg is not None:
        if isinstance(collection_cls_arg.node, TypeInfo):
            if python_type_for_type is not None:
                python_type_for_type = Instance(
                    collection_cls_arg.node, [python_type_for_type]
                )
        else:
            util.fail(
                api,
                "Expected Python collection type for "
                "collection_class parameter",
                stmt.rvalue,
            )
            python_type_for_type = None
    # uselist=False: scalar relationship, which may be None
    elif uselist_arg is not None and uselist_arg.fullname == "builtins.False":
        if collection_cls_arg is not None:
            util.fail(
                api,
                "Sending uselist=False and collection_class at the same time "
                "does not make sense",
                stmt.rvalue,
            )
        if python_type_for_type is not None:
            python_type_for_type = UnionType(
                [python_type_for_type, NoneType()]
            )

    else:
        # neither uselist nor collection_class given; scalar-vs-collection
        # cannot be determined from the right hand side alone
        if left_hand_explicit_type is None:
            msg = (
                "Can't infer scalar or collection for ORM mapped expression "
                "assigned to attribute '{}' if both 'uselist' and "
                "'collection_class' arguments are absent from the "
                "relationship(); please specify a "
                "type annotation on the left hand side."
            )
            util.fail(api, msg.format(node.name), node)

    if python_type_for_type is None:
        return _infer_type_from_left_hand_type_only(
            api, node, left_hand_explicit_type
        )
    elif left_hand_explicit_type is not None:
        # both sides known: validate that they are compatible
        return _infer_type_from_left_and_inferred_right(
            api, node, left_hand_explicit_type, python_type_for_type
        )
    else:
        return python_type_for_type
+
+
def _infer_type_from_decl_composite_property(
    api: SemanticAnalyzerPluginInterface,
    stmt: AssignmentStmt,
    node: Var,
    left_hand_explicit_type: Optional[types.Type],
) -> Union[Instance, UnionType, None]:
    """Infer the type of mapping from a CompositeProperty."""

    assert isinstance(stmt.rvalue, CallExpr)
    target_cls_arg = stmt.rvalue.args[0]

    # the composite class must be given as a direct class reference for the
    # right hand side to be usable
    if isinstance(target_cls_arg, NameExpr) and isinstance(
        target_cls_arg.node, TypeInfo
    ):
        python_type_for_type = Instance(target_cls_arg.node, [])
    else:
        python_type_for_type = None

    if python_type_for_type is None:
        return _infer_type_from_left_hand_type_only(
            api, node, left_hand_explicit_type
        )
    if left_hand_explicit_type is not None:
        return _infer_type_from_left_and_inferred_right(
            api, node, left_hand_explicit_type, python_type_for_type
        )
    return python_type_for_type
+
+
def _infer_type_from_decl_column_property(
    api: SemanticAnalyzerPluginInterface,
    stmt: AssignmentStmt,
    node: Var,
    left_hand_explicit_type: Optional[types.Type],
) -> Union[Instance, UnionType, None]:
    """Infer the type of mapping from a ColumnProperty.

    This includes mappings against ``column_property()`` as well as the
    ``deferred()`` function.

    """
    assert isinstance(stmt.rvalue, CallExpr)
    first_prop_arg = stmt.rvalue.args[0]

    if isinstance(first_prop_arg, CallExpr):
        type_id = names._type_id_for_callee(first_prop_arg.callee)
    else:
        type_id = None

    # (a stray debugging print of the attribute name was removed here)

    # look for column_property() / deferred() etc with Column as first
    # argument
    if type_id is names.COLUMN:
        return _infer_type_from_decl_column(
            api, stmt, node, left_hand_explicit_type, first_prop_arg
        )
    else:
        return _infer_type_from_left_hand_type_only(
            api, node, left_hand_explicit_type
        )
+
+
def _infer_type_from_decl_column(
    api: SemanticAnalyzerPluginInterface,
    stmt: AssignmentStmt,
    node: Var,
    left_hand_explicit_type: Optional[types.Type],
    right_hand_expression: CallExpr,
) -> Union[Instance, UnionType, None]:
    """Infer the type of mapping from a Column.

    E.g.::

        @reg.mapped
        class MyClass:
            # ...

            a = Column(Integer)

            b = Column("b", String)

            c: Mapped[int] = Column(Integer)

            d: bool = Column(Boolean)

    Will resolve in MyPy as::

        @reg.mapped
        class MyClass:
            # ...

            a : Mapped[int]

            b : Mapped[str]

            c: Mapped[int]

            d: Mapped[bool]

    """
    assert isinstance(node, Var)

    callee = None

    # the type argument is either the first or second positional argument
    # (a string name may come first); scan only those two
    for column_arg in right_hand_expression.args[0:2]:
        if isinstance(column_arg, nodes.CallExpr):
            # x = Column(String(50))
            callee = column_arg.callee
            break
        elif isinstance(column_arg, nodes.NameExpr):
            if isinstance(column_arg.node, TypeInfo):
                # x = Column(String)
                callee = column_arg
                break
            else:
                # x = Column(some_name, String), go to next argument
                continue
        elif isinstance(column_arg, (StrExpr,)):
            # x = Column("name", String), go to next argument
            continue
        else:
            # unexpected expression form for a Column argument
            assert False

    if callee is None:
        return None

    # NOTE(review): when callee came from the Column(String(50)) branch,
    # callee.node is assumed to be a TypeInfo with an ``mro``; a callable
    # that is not a class would presumably break here — confirm.
    if names._mro_has_id(callee.node.mro, names.TYPEENGINE):
        python_type_for_type = _extract_python_type_from_typeengine(
            callee.node
        )

        if left_hand_explicit_type is not None:

            return _infer_type_from_left_and_inferred_right(
                api, node, left_hand_explicit_type, python_type_for_type
            )

        else:
            # no annotation: a Column is nullable by default, so the
            # inferred Python type is Optional
            python_type_for_type = UnionType(
                [python_type_for_type, NoneType()]
            )
        return python_type_for_type
    else:
        # it's not TypeEngine, it's typically implicitly typed
        # like ForeignKey. we can't infer from the right side.
        return _infer_type_from_left_hand_type_only(
            api, node, left_hand_explicit_type
        )
+
+
def _infer_type_from_left_and_inferred_right(
    api: SemanticAnalyzerPluginInterface,
    node: Var,
    left_hand_explicit_type: Optional[types.Type],
    python_type_for_type: Union[Instance, UnionType],
) -> Optional[Union[Instance, UnionType]]:
    """Validate type when a left hand annotation is present and we also
    could infer the right hand side::

        attrname: SomeType = Column(SomeDBType)

    """
    # the annotation always wins; emit an error first if it does not
    # agree with the type inferred from the right hand side
    if not is_subtype(left_hand_explicit_type, python_type_for_type):
        descriptor = api.modules["sqlalchemy.orm.attributes"].names["Mapped"]
        effective_type = Instance(descriptor.node, [python_type_for_type])

        msg = (
            "Left hand assignment '{}: {}' not compatible "
            "with ORM mapped expression of type {}"
        )
        util.fail(
            api,
            msg.format(
                node.name,
                format_type(left_hand_explicit_type),
                format_type(effective_type),
            ),
            node,
        )

    return left_hand_explicit_type
+
+
def _infer_type_from_left_hand_type_only(
    api: SemanticAnalyzerPluginInterface,
    node: Var,
    left_hand_explicit_type: Optional[types.Type],
) -> Optional[Union[Instance, UnionType]]:
    """Determine the type based on explicit annotation only.

    if no annotation were present, note that we need one there to know
    the type.

    """
    if left_hand_explicit_type is not None:
        # use type from the left hand side
        return left_hand_explicit_type

    # nothing to go on: report the problem and substitute Mapped[Any]
    msg = (
        "Can't infer type from ORM mapped expression "
        "assigned to attribute '{}'; please specify a "
        "Python type or "
        "Mapped[<python type>] on the left hand side."
    )
    util.fail(api, msg.format(node.name), node)

    descriptor = api.modules["sqlalchemy.orm.attributes"].names["Mapped"]
    return Instance(descriptor.node, [AnyType(TypeOfAny.special_form)])
+
+
def _re_apply_declarative_assignments(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    cls_metadata: DeclClassApplied,
):
    """For multiple class passes, re-apply our left-hand side types as mypy
    seems to reset them in place.

    """
    mapped_attr_lookup = dict(cls_metadata.mapped_attr_names)

    descriptor = api.modules["sqlalchemy.orm.attributes"].names["Mapped"]

    for stmt in cls.defs.body:
        # for a re-apply, all of our statements are AssignmentStmt;
        # @declared_attr calls will have been converted and this
        # currently seems to be preserved by mypy (but who knows if this
        # will change).
        if not isinstance(stmt, AssignmentStmt):
            continue
        attr_name = stmt.lvalues[0].name
        if attr_name not in mapped_attr_lookup:
            continue

        # restore Mapped[<recorded type>] onto the left hand node
        typ = mapped_attr_lookup[attr_name]
        stmt.lvalues[0].node.type = Instance(descriptor.node, [typ])
+
+
def _apply_type_to_mapped_statement(
    api: SemanticAnalyzerPluginInterface,
    stmt: AssignmentStmt,
    lvalue: NameExpr,
    left_hand_explicit_type: Optional[Union[Instance, UnionType]],
    python_type_for_type: Union[Instance, UnionType],
) -> None:
    """Apply the Mapped[<type>] annotation and right hand object to a
    declarative assignment statement.

    This converts a Python declarative class statement such as::

        class User(Base):
            # ...

            attrname = Column(Integer)

    To one that describes the final Python behavior to Mypy::

        class User(Base):
            # ...

            attrname : Mapped[Optional[int]] = <meaningless temp node>

    :param stmt: the assignment statement being rewritten in place.
    :param lvalue: the left hand name expression of ``stmt``.
    :param left_hand_explicit_type: annotation the user wrote, if any;
     takes precedence over the inferred type.
    :param python_type_for_type: type inferred from the right hand side.
    """
    descriptor = api.modules["sqlalchemy.orm.attributes"].names["Mapped"]

    left_node = lvalue.node

    inst = Instance(descriptor.node, [python_type_for_type])

    if left_hand_explicit_type is not None:
        left_node.type = Instance(descriptor.node, [left_hand_explicit_type])
    else:
        lvalue.is_inferred_def = False
        left_node.type = inst

    # so to have it skip the right side totally, we can do this:
    # stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form))

    # however, if we instead manufacture a new node that uses the old
    # one, then we can still get type checking for the call itself,
    # e.g. the Column, relationship() call, etc.

    # rewrite the node as:
    # <attr> : Mapped[<typ>] =
    # _sa_Mapped._empty_constructor(<original CallExpr from rvalue>)
    # the original right-hand side is maintained so it gets type checked
    # internally
    api.add_symbol_table_node("_sa_Mapped", descriptor)
    column_descriptor = nodes.NameExpr("_sa_Mapped")
    column_descriptor.fullname = "sqlalchemy.orm.Mapped"
    mm = nodes.MemberExpr(column_descriptor, "_empty_constructor")
    orig_call_expr = stmt.rvalue
    stmt.rvalue = CallExpr(
        mm,
        [orig_call_expr],
        [nodes.ARG_POS],
        ["arg1"],
    )
+
+
def _scan_for_mapped_bases(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    cls_metadata: DeclClassApplied,
) -> None:
    """Given a class, iterate through its superclass hierarchy to find
    all other classes that are considered as ORM-significant.

    Locates non-mapped mixins and scans them for mapped attributes to be
    applied to subclasses.

    """

    # breadth-first walk over the full base-class hierarchy
    to_scan = list(cls.info.bases)
    while to_scan:
        base = to_scan.pop(0)

        # scan each base for mapped attributes. if they are not already
        # scanned, that means they are unmapped mixins
        applied = _scan_declarative_assignments_and_apply_types(
            base.type.defn, api, is_mixin_scan=True
        )
        if applied is not None:
            cls_metadata.mapped_mro.append(base)
        to_scan.extend(base.type.bases)
+
+
def _add_additional_orm_attributes(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    cls_metadata: DeclClassApplied,
) -> None:
    """Apply __init__, __table__ and other attributes to the mapped class."""
    if cls_metadata.is_mapped and "__init__" not in cls.info.names:
        # collect attributes from this class first so they take precedence
        # over same-named attributes on mapped bases
        mapped_attr_names = dict(cls_metadata.mapped_attr_names)

        for mapped_base in cls_metadata.mapped_mro:
            base_cls_metadata = DeclClassApplied.deserialize(
                mapped_base.type.metadata["_sa_decl_class_applied"], api
            )
            for name, typ in base_cls_metadata.mapped_attr_names:
                mapped_attr_names.setdefault(name, typ)

        # synthesize a keyword-only __init__ from the mapped attributes
        arguments = []
        for name, typ in mapped_attr_names.items():
            arg_type = typ if typ is not None else AnyType(
                TypeOfAny.special_form
            )
            arguments.append(
                Argument(
                    variable=Var(name, arg_type),
                    type_annotation=arg_type,
                    initializer=TempNode(arg_type),
                    kind=ARG_NAMED_OPT,
                )
            )
        add_method_to_class(api, cls, "__init__", arguments, NoneTyp())

    if cls_metadata.has_table and "__table__" not in cls.info.names:
        _apply_placeholder_attr_to_class(
            api, cls, "sqlalchemy.sql.schema.Table", "__table__"
        )
    if cls_metadata.is_mapped:
        _apply_placeholder_attr_to_class(
            api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__"
        )
+
+
def _apply_placeholder_attr_to_class(
    api: SemanticAnalyzerPluginInterface,
    cls: ClassDef,
    qualified_name: str,
    attrname: str,
):
    """Add ``attrname`` to the class as an instance of ``qualified_name``,
    falling back to ``Any`` if that name cannot be resolved.
    """
    sym = api.lookup_fully_qualified_or_none(qualified_name)
    if sym is None:
        placeholder_type = AnyType(TypeOfAny.special_form)
    else:
        assert isinstance(sym.node, TypeInfo)
        placeholder_type = Instance(sym.node, [])

    var = Var(attrname)
    var.info = cls.info
    var.type = placeholder_type
    cls.info.names[attrname] = SymbolTableNode(MDEF, var)
+
+
def _extract_python_type_from_typeengine(node: TypeInfo) -> Instance:
    """Given a TypeEngine subclass, return the Python type it is generic
    over (e.g. ``String`` -> ``str``) by walking its mro for the
    ``TypeEngine[...]`` base.
    """
    for mr in node.mro:
        if not mr.bases:
            continue
        last_base = mr.bases[-1]
        if last_base.type.fullname == "sqlalchemy.sql.type_api.TypeEngine":
            return last_base.args[-1]
    assert False, "could not extract Python type from node: %s" % node
diff --git a/lib/sqlalchemy/ext/mypy/names.py b/lib/sqlalchemy/ext/mypy/names.py
new file mode 100644
index 000000000..c9d48fcd8
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/names.py
@@ -0,0 +1,194 @@
+# ext/mypy/names.py
+# Copyright (C) 2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+from typing import List
+
+from mypy.nodes import ClassDef
+from mypy.nodes import Expression
+from mypy.nodes import FuncDef
+from mypy.nodes import RefExpr
+from mypy.nodes import SymbolNode
+from mypy.nodes import TypeAlias
+from mypy.nodes import TypeInfo
+from mypy.nodes import Union
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.types import UnboundType
+
+from ... import util
+
# Symbols identifying the SQLAlchemy constructs the plugin cares about;
# compared by identity throughout the plugin.
COLUMN = util.symbol("COLUMN")
RELATIONSHIP = util.symbol("RELATIONSHIP")
REGISTRY = util.symbol("REGISTRY")
COLUMN_PROPERTY = util.symbol("COLUMN_PROPERTY")
TYPEENGINE = util.symbol("TYPEENGINE")  # fixed typo: was "TYPEENGNE"
MAPPED = util.symbol("MAPPED")
DECLARATIVE_BASE = util.symbol("DECLARATIVE_BASE")
DECLARATIVE_META = util.symbol("DECLARATIVE_META")
MAPPED_DECORATOR = util.symbol("MAPPED_DECORATOR")
SYNONYM_PROPERTY = util.symbol("SYNONYM_PROPERTY")
COMPOSITE_PROPERTY = util.symbol("COMPOSITE_PROPERTY")
DECLARED_ATTR = util.symbol("DECLARED_ATTR")
MAPPER_PROPERTY = util.symbol("MAPPER_PROPERTY")
+
+
# Lookup table mapping an unqualified class/function name to
# (symbol type id, set of fully-qualified names that count as a match).
# Used by the _type_id_for_* helpers below to classify nodes.
_lookup = {
    "Column": (
        COLUMN,
        {
            "sqlalchemy.sql.schema.Column",
            "sqlalchemy.sql.Column",
        },
    ),
    "RelationshipProperty": (
        RELATIONSHIP,
        {
            "sqlalchemy.orm.relationships.RelationshipProperty",
            "sqlalchemy.orm.RelationshipProperty",
        },
    ),
    "registry": (
        REGISTRY,
        {
            "sqlalchemy.orm.decl_api.registry",
            "sqlalchemy.orm.registry",
        },
    ),
    "ColumnProperty": (
        COLUMN_PROPERTY,
        {
            "sqlalchemy.orm.properties.ColumnProperty",
            "sqlalchemy.orm.ColumnProperty",
        },
    ),
    "SynonymProperty": (
        SYNONYM_PROPERTY,
        {
            "sqlalchemy.orm.descriptor_props.SynonymProperty",
            "sqlalchemy.orm.SynonymProperty",
        },
    ),
    "CompositeProperty": (
        COMPOSITE_PROPERTY,
        {
            "sqlalchemy.orm.descriptor_props.CompositeProperty",
            "sqlalchemy.orm.CompositeProperty",
        },
    ),
    "MapperProperty": (
        MAPPER_PROPERTY,
        {
            "sqlalchemy.orm.interfaces.MapperProperty",
            "sqlalchemy.orm.MapperProperty",
        },
    ),
    "TypeEngine": (TYPEENGINE, {"sqlalchemy.sql.type_api.TypeEngine"}),
    "Mapped": (MAPPED, {"sqlalchemy.orm.attributes.Mapped"}),
    "declarative_base": (
        DECLARATIVE_BASE,
        {
            "sqlalchemy.ext.declarative.declarative_base",
            "sqlalchemy.orm.declarative_base",
            "sqlalchemy.orm.decl_api.declarative_base",
        },
    ),
    "DeclarativeMeta": (
        DECLARATIVE_META,
        {
            "sqlalchemy.ext.declarative.DeclarativeMeta",
            "sqlalchemy.orm.DeclarativeMeta",
            "sqlalchemy.orm.decl_api.DeclarativeMeta",
        },
    ),
    "mapped": (
        MAPPED_DECORATOR,
        {
            "sqlalchemy.orm.decl_api.registry.mapped",
            "sqlalchemy.orm.registry.mapped",
        },
    ),
    "declared_attr": (
        DECLARED_ATTR,
        {
            "sqlalchemy.orm.decl_api.declared_attr",
            "sqlalchemy.orm.declared_attr",
        },
    ),
}
+
+
def _mro_has_id(mro: List[TypeInfo], type_id: int) -> bool:
    """Return True if any class in ``mro`` matches ``type_id``.

    Only the first mro entry whose short name is known to ``_lookup`` is
    considered; its fully-qualified name must also match.
    """
    for mr in mro:
        check_type_id, fullnames = _lookup.get(mr.name, (None, None))
        if check_type_id == type_id:
            return mr.fullname in fullnames
    return False
+
+
def _type_id_for_unbound_type(
    type_: UnboundType, cls: ClassDef, api: SemanticAnalyzerPluginInterface
) -> int:
    """Resolve an unbound annotation to a symbol type id, or None."""
    sym = api.lookup(type_.name, type_)
    if sym is None:
        return None

    node = sym.node
    if isinstance(node, TypeAlias):
        # follow the alias to its target class
        return _type_id_for_named_node(node.target.type)
    if isinstance(node, TypeInfo):
        return _type_id_for_named_node(node)
    return None
+
+
def _type_id_for_callee(callee: Expression) -> int:
    """Classify the callee of a call expression as a symbol type id."""
    node = callee.node
    if isinstance(node, FuncDef):
        # plain function: classify by its return type
        return _type_id_for_funcdef(node)
    if isinstance(node, TypeAlias):
        return _type_id_for_fullname(node.target.type.fullname)
    if isinstance(node, TypeInfo):
        return _type_id_for_named_node(callee)
    return None
+
+
def _type_id_for_funcdef(node: FuncDef) -> int:
    """Classify a function by the class named in its return type, if any."""
    if not hasattr(node.type.ret_type, "type"):
        return None
    return _type_id_for_fullname(node.type.ret_type.type.fullname)
+
+
def _type_id_for_named_node(node: Union[RefExpr, SymbolNode]) -> int:
    """Classify a named node via ``_lookup``; both the short name and the
    fully-qualified name must match."""
    type_id, fullnames = _lookup.get(node.name, (None, None))
    if type_id is not None and node.fullname in fullnames:
        return type_id
    return None
+
+
def _type_id_for_fullname(fullname: str) -> int:
    """Classify a dotted name via ``_lookup`` using its last component."""
    immediate = fullname.split(".")[-1]
    type_id, fullnames = _lookup.get(immediate, (None, None))
    if type_id is not None and fullname in fullnames:
        return type_id
    return None
diff --git a/lib/sqlalchemy/ext/mypy/plugin.py b/lib/sqlalchemy/ext/mypy/plugin.py
new file mode 100644
index 000000000..9fcd09b1e
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/plugin.py
@@ -0,0 +1,215 @@
+# ext/mypy/plugin.py
+# Copyright (C) 2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""
+Mypy plugin for SQLAlchemy ORM.
+
+"""
+from typing import List
+from typing import Tuple
+from typing import Type
+
+from mypy import nodes
+from mypy.mro import calculate_mro
+from mypy.mro import MroError
+from mypy.nodes import Block
+from mypy.nodes import ClassDef
+from mypy.nodes import GDEF
+from mypy.nodes import MypyFile
+from mypy.nodes import NameExpr
+from mypy.nodes import SymbolTable
+from mypy.nodes import SymbolTableNode
+from mypy.nodes import TypeInfo
+from mypy.plugin import AttributeContext
+from mypy.plugin import Callable
+from mypy.plugin import ClassDefContext
+from mypy.plugin import DynamicClassDefContext
+from mypy.plugin import Optional
+from mypy.plugin import Plugin
+from mypy.types import Instance
+
+from . import decl_class
+from . import names
+from . import util
+
+
+class CustomPlugin(Plugin):
+    """Mypy plugin implementing SQLAlchemy declarative-mapping support.
+
+    Each ``get_*_hook`` method is called by mypy's semantic analyzer;
+    returning a callable opts this plugin in for that construct.
+
+    NOTE(review): the ``is`` comparisons against ``names.*`` id constants
+    rely on small-int interning; ``==`` would be safer — confirm the
+    constants are always identity-comparable.
+    """
+
+    def get_dynamic_class_hook(
+        self, fullname: str
+    ) -> Optional[Callable[[DynamicClassDefContext], None]]:
+        # fire for Base = declarative_base() style dynamic classes
+        if names._type_id_for_fullname(fullname) is names.DECLARATIVE_BASE:
+            return _dynamic_class_hook
+        return None
+
+    def get_base_class_hook(
+        self, fullname: str
+    ) -> Optional[Callable[[ClassDefContext], None]]:
+
+        # kind of a strange relationship between get_metaclass_hook()
+        # and get_base_class_hook(). the former doesn't fire off for
+        # subclasses. but then you can just check it here from the "base"
+        # and get the same effect.
+        sym = self.lookup_fully_qualified(fullname)
+        if (
+            sym
+            and isinstance(sym.node, TypeInfo)
+            and sym.node.metaclass_type
+            and names._type_id_for_named_node(sym.node.metaclass_type.type)
+            is names.DECLARATIVE_META
+        ):
+            return _base_cls_hook
+        return None
+
+    def get_class_decorator_hook(
+        self, fullname: str
+    ) -> Optional[Callable[[ClassDefContext], None]]:
+
+        # fire for classes decorated with @registry.mapped
+        sym = self.lookup_fully_qualified(fullname)
+
+        if (
+            sym is not None
+            and names._type_id_for_named_node(sym.node)
+            is names.MAPPED_DECORATOR
+        ):
+            return _cls_decorator_hook
+        return None
+
+    def get_customize_class_mro_hook(
+        self, fullname: str
+    ) -> Optional[Callable[[ClassDefContext], None]]:
+        # unconditionally returned for every class; _fill_in_decorators
+        # itself filters for the "mapped" decorator it cares about
+        return _fill_in_decorators
+
+    def get_attribute_hook(
+        self, fullname: str
+    ) -> Optional[Callable[[AttributeContext], Type]]:
+        # attribute access on ORM-instrumented attributes
+        if fullname.startswith(
+            "sqlalchemy.orm.attributes.QueryableAttribute."
+        ):
+            return _queryable_getattr_hook
+        return None
+
+    def get_additional_deps(
+        self, file: MypyFile
+    ) -> List[Tuple[int, str, int]]:
+        # make every analyzed module depend on the ORM modules whose
+        # symbols the hooks above look up (priority 10, line -1)
+        return [
+            (10, "sqlalchemy.orm.attributes", -1),
+            (10, "sqlalchemy.orm.decl_api", -1),
+        ]
+
+
+def plugin(version: str):
+ return CustomPlugin
+
+
+def _queryable_getattr_hook(ctx: AttributeContext) -> Type:
+    """Attribute hook for ``QueryableAttribute.*``; currently a plain
+    pass-through of mypy's default inferred attribute type.
+    """
+    # how do I....tell it it has no attribute of a certain name?
+    # can't find any Type that seems to match that
+    return ctx.default_attr_type
+
+
+def _fill_in_decorators(ctx: ClassDefContext) -> None:
+ for decorator in ctx.cls.decorators:
+ # set the ".fullname" attribute of a class decorator
+ # that is a MemberExpr. This causes the logic in
+ # semanal.py->apply_class_plugin_hooks to invoke the
+ # get_class_decorator_hook for our "registry.map_class()" method.
+ # this seems like a bug in mypy that these decorators are otherwise
+ # skipped.
+ if (
+ isinstance(decorator, nodes.MemberExpr)
+ and decorator.name == "mapped"
+ ):
+
+ sym = ctx.api.lookup(
+ decorator.expr.name, decorator, suppress_errors=True
+ )
+ if sym:
+ if sym.node.type and hasattr(sym.node.type, "type"):
+ decorator.fullname = (
+ f"{sym.node.type.type.fullname}.{decorator.name}"
+ )
+ else:
+ # if the registry is in the same file as where the
+ # decorator is used, it might not have semantic
+ # symbols applied and we can't get a fully qualified
+ # name or an inferred type, so we are actually going to
+ # flag an error in this case that they need to annotate
+ # it. The "registry" is declared just
+ # once (or few times), so they have to just not use
+ # type inference for its assignment in this one case.
+ util.fail(
+ ctx.api,
+ "Class decorator called mapped(), but we can't "
+ "tell if it's from an ORM registry. Please "
+ "annotate the registry assignment, e.g. "
+ "my_registry: registry = registry()",
+ sym.node,
+ )
+
+
+def _cls_metadata_hook(ctx: ClassDefContext) -> None:
+    """Scan the class body and apply declarative mapping types.
+
+    NOTE(review): not referenced by any hook in this module; presumably
+    kept for a metaclass-based hook — confirm before removing.
+    """
+    decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
+
+
+def _base_cls_hook(ctx: ClassDefContext) -> None:
+    """Base-class hook: scan a class whose base uses DeclarativeMeta and
+    apply declarative mapping types to its assignments.
+    """
+    decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
+
+
+def _cls_decorator_hook(ctx: ClassDefContext) -> None:
+    """Scan a class mapped via the ``@registry.mapped`` decorator.
+
+    The asserts are developer invariants: the hook only fires for
+    MemberExpr decorators whose receiver resolved to a registry (see
+    ``_fill_in_decorators``).  NOTE(review): plain ``assert`` is stripped
+    under ``-O`` — acceptable here only as an internal sanity check.
+    """
+    assert isinstance(ctx.reason, nodes.MemberExpr)
+    expr = ctx.reason.expr
+    assert names._type_id_for_named_node(expr.node.type.type) is names.REGISTRY
+
+    decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
+
+
+def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None:
+ """Generate a declarative Base class when the declarative_base() function
+ is encountered."""
+
+ cls = ClassDef(ctx.name, Block([]))
+ cls.fullname = ctx.api.qualified_name(ctx.name)
+
+ declarative_meta_sym: SymbolTableNode = ctx.api.modules[
+ "sqlalchemy.orm.decl_api"
+ ].names["DeclarativeMeta"]
+ declarative_meta_typeinfo: TypeInfo = declarative_meta_sym.node
+
+ declarative_meta_name: NameExpr = NameExpr("DeclarativeMeta")
+ declarative_meta_name.kind = GDEF
+ declarative_meta_name.fullname = "sqlalchemy.orm.decl_api.DeclarativeMeta"
+ declarative_meta_name.node = declarative_meta_typeinfo
+
+ cls.metaclass = declarative_meta_name
+
+ declarative_meta_instance = Instance(declarative_meta_typeinfo, [])
+
+ info = TypeInfo(SymbolTable(), cls, ctx.api.cur_mod_id)
+ info.declared_metaclass = info.metaclass_type = declarative_meta_instance
+ cls.info = info
+
+ cls_arg = util._get_callexpr_kwarg(ctx.call, "cls")
+ if cls_arg is not None:
+ decl_class._scan_declarative_assignments_and_apply_types(
+ cls_arg.node.defn, ctx.api, is_mixin_scan=True
+ )
+ info.bases = [Instance(cls_arg.node, [])]
+ else:
+ obj = ctx.api.builtin_type("builtins.object")
+
+ info.bases = [obj]
+
+ try:
+ calculate_mro(info)
+ except MroError:
+ util.fail(
+ ctx.api, "Not able to calculate MRO for declarative base", ctx.call
+ )
+ info.bases = [obj]
+ info.fallback_to_any = True
+
+ ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))
diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py
new file mode 100644
index 000000000..e7178a885
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/util.py
@@ -0,0 +1,80 @@
+from typing import Optional
+
+from mypy.nodes import CallExpr
+from mypy.nodes import Context
+from mypy.nodes import IfStmt
+from mypy.nodes import NameExpr
+from mypy.nodes import SymbolTableNode
+from mypy.nodes import TypeInfo
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.types import Instance
+from mypy.types import NoneType
+from mypy.types import Type
+from mypy.types import UnboundType
+from mypy.types import UnionType
+
+
+def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context):
+ msg = "[SQLAlchemy Mypy plugin] %s" % msg
+ return api.fail(msg, ctx)
+
+
+def _get_callexpr_kwarg(callexpr: CallExpr, name: str) -> Optional[NameExpr]:
+ try:
+ arg_idx = callexpr.arg_names.index(name)
+ except ValueError:
+ return None
+
+ return callexpr.args[arg_idx]
+
+
+def _flatten_typechecking(stmts):
+ for stmt in stmts:
+ if isinstance(stmt, IfStmt) and stmt.expr[0].name == "TYPE_CHECKING":
+ for substmt in stmt.body[0].body:
+ yield substmt
+ else:
+ yield stmt
+
+
+def _unbound_to_instance(
+ api: SemanticAnalyzerPluginInterface, typ: UnboundType
+) -> Type:
+ """Take the UnboundType that we seem to get as the ret_type from a FuncDef
+ and convert it into an Instance/TypeInfo kind of structure that seems
+ to work as the left-hand type of an AssignmentStatement.
+
+ """
+
+ if not isinstance(typ, UnboundType):
+ return typ
+
+ # TODO: figure out a more robust way to check this. The node is some
+ # kind of _SpecialForm, there's a typing.Optional that's _SpecialForm,
+ # but I cant figure out how to get them to match up
+ if typ.name == "Optional":
+ # convert from "Optional?" to the more familiar
+ # UnionType[..., NoneType()]
+ return _unbound_to_instance(
+ api,
+ UnionType(
+ [_unbound_to_instance(api, typ_arg) for typ_arg in typ.args]
+ + [NoneType()]
+ ),
+ )
+
+ node = api.lookup(typ.name, typ)
+
+ if node is not None and isinstance(node, SymbolTableNode):
+ bound_type = node.node
+
+ return Instance(
+ bound_type,
+ [
+ _unbound_to_instance(api, arg)
+ if isinstance(arg, UnboundType)
+ else arg
+ for arg in typ.args
+ ],
+ )
+ else:
+ return typ