summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorMarc Mueller <30130371+cdce8p@users.noreply.github.com>2021-04-10 14:49:18 +0200
committerGitHub <noreply@github.com>2021-04-10 14:49:18 +0200
commitcd0f896672049493b2e3cfc87a327c871f8dd329 (patch)
tree8e884f5a77f6335bc421dea202e08488b3b1aeda /tests
parent31a731a7dc04507b6278dae66dd4ef1521881d96 (diff)
downloadastroid-git-cd0f896672049493b2e3cfc87a327c871f8dd329.tar.gz
Modify infernce tip for typing.Generic and typing.Annotated with ``__class_getitem__`` (#931)
* Modify infernce tip for typing.Generic and typing.Annotated with __class_getitem__ * Fix issue with slots caching * Clean typing.Generic from mro
Diffstat (limited to 'tests')
-rw-r--r--tests/unittest_brain.py50
-rw-r--r--tests/unittest_scoped_nodes.py139
2 files changed, 189 insertions, 0 deletions
diff --git a/tests/unittest_brain.py b/tests/unittest_brain.py
index f15c4a1c..a217f223 100644
--- a/tests/unittest_brain.py
+++ b/tests/unittest_brain.py
@@ -1361,6 +1361,56 @@ class TypingBrain(unittest.TestCase):
inferred = next(node.infer())
self.assertIsInstance(inferred, nodes.ClassDef, node.as_string())
+ @test_utils.require_version(minver="3.7")
+ def test_typing_generic_subscriptable(self):
+ """Test typing.Generic is subscriptable with __class_getitem__ (added in PY37)"""
+ node = builder.extract_node(
+ """
+ from typing import Generic, TypeVar
+ T = TypeVar('T')
+ Generic[T]
+ """
+ )
+ inferred = next(node.infer())
+ assert isinstance(inferred, nodes.ClassDef)
+ assert isinstance(inferred.getattr("__class_getitem__")[0], nodes.FunctionDef)
+
+ @test_utils.require_version(minver="3.9")
+ def test_typing_annotated_subscriptable(self):
+ """Test typing.Annotated is subscriptable with __class_getitem__"""
+ node = builder.extract_node(
+ """
+ import typing
+ typing.Annotated[str, "data"]
+ """
+ )
+ inferred = next(node.infer())
+ assert isinstance(inferred, nodes.ClassDef)
+ assert isinstance(inferred.getattr("__class_getitem__")[0], nodes.FunctionDef)
+
+ @test_utils.require_version(minver="3.7")
+ def test_typing_generic_slots(self):
+ """Test cache reset for slots if Generic subscript is inferred."""
+ node = builder.extract_node(
+ """
+ from typing import Generic, TypeVar
+ T = TypeVar('T')
+ class A(Generic[T]):
+ __slots__ = ['value']
+ def __init__(self, value):
+ self.value = value
+ """
+ )
+ inferred = next(node.infer())
+ assert len(inferred.slots()) == 0
+ # Only after the subscript base is inferred and the inference tip applied,
+ # will slots contain the correct value
+ next(node.bases[0].infer())
+ slots = inferred.slots()
+ assert len(slots) == 1
+ assert isinstance(slots[0], nodes.Const)
+ assert slots[0].value == "value"
+
def test_has_dunder_args(self):
ast_node = builder.extract_node(
"""
diff --git a/tests/unittest_scoped_nodes.py b/tests/unittest_scoped_nodes.py
index a0f882ae..b98cd8f8 100644
--- a/tests/unittest_scoped_nodes.py
+++ b/tests/unittest_scoped_nodes.py
@@ -1275,6 +1275,9 @@ class ClassNodeTest(ModuleLoader, unittest.TestCase):
def assertEqualMro(self, klass, expected_mro):
self.assertEqual([member.name for member in klass.mro()], expected_mro)
+ def assertEqualMroQName(self, klass, expected_mro):
+ self.assertEqual([member.qname() for member in klass.mro()], expected_mro)
+
@unittest.skipUnless(HAS_SIX, "These tests require the six library")
def test_with_metaclass_mro(self):
astroid = builder.parse(
@@ -1438,6 +1441,142 @@ class ClassNodeTest(ModuleLoader, unittest.TestCase):
)
self.assertEqualMro(cls, ["C", "A", "B", "object"])
+ @test_utils.require_version(minver="3.7")
+ def test_mro_generic_1(self):
+ cls = builder.extract_node(
+ """
+ import typing
+ T = typing.TypeVar('T')
+ class A(typing.Generic[T]): ...
+ class B: ...
+ class C(A[T], B): ...
+ """
+ )
+ self.assertEqualMroQName(
+ cls, [".C", ".A", "typing.Generic", ".B", "builtins.object"]
+ )
+
+ @test_utils.require_version(minver="3.7")
+ def test_mro_generic_2(self):
+ cls = builder.extract_node(
+ """
+ from typing import Generic, TypeVar
+ T = TypeVar('T')
+ class A: ...
+ class B(Generic[T]): ...
+ class C(Generic[T], A, B[T]): ...
+ """
+ )
+ self.assertEqualMroQName(
+ cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"]
+ )
+
+ @test_utils.require_version(minver="3.7")
+ def test_mro_generic_3(self):
+ cls = builder.extract_node(
+ """
+ from typing import Generic, TypeVar
+ T = TypeVar('T')
+ class A: ...
+ class B(A, Generic[T]): ...
+ class C(Generic[T]): ...
+ class D(B[T], C[T], Generic[T]): ...
+ """
+ )
+ self.assertEqualMroQName(
+ cls, [".D", ".B", ".A", ".C", "typing.Generic", "builtins.object"]
+ )
+
+ @test_utils.require_version(minver="3.7")
+ def test_mro_generic_4(self):
+ cls = builder.extract_node(
+ """
+ from typing import Generic, TypeVar
+ T = TypeVar('T')
+ class A: ...
+ class B(Generic[T]): ...
+ class C(A, Generic[T], B[T]): ...
+ """
+ )
+ self.assertEqualMroQName(
+ cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"]
+ )
+
+ @test_utils.require_version(minver="3.7")
+ def test_mro_generic_5(self):
+ cls = builder.extract_node(
+ """
+ from typing import Generic, TypeVar
+ T1 = TypeVar('T1')
+ T2 = TypeVar('T2')
+ class A(Generic[T1]): ...
+ class B(Generic[T2]): ...
+ class C(A[T1], B[T2]): ...
+ """
+ )
+ self.assertEqualMroQName(
+ cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"]
+ )
+
+ @test_utils.require_version(minver="3.7")
+ def test_mro_generic_6(self):
+ cls = builder.extract_node(
+ """
+ from typing import Generic as TGeneric, TypeVar
+ T = TypeVar('T')
+ class Generic: ...
+ class A(Generic): ...
+ class B(TGeneric[T]): ...
+ class C(A, B[T]): ...
+ """
+ )
+ self.assertEqualMroQName(
+ cls, [".C", ".A", ".Generic", ".B", "typing.Generic", "builtins.object"]
+ )
+
+ @test_utils.require_version(minver="3.7")
+ def test_mro_generic_7(self):
+ cls = builder.extract_node(
+ """
+ from typing import Generic, TypeVar
+ T = TypeVar('T')
+ class A(): ...
+ class B(Generic[T]): ...
+ class C(A, B[T]): ...
+ class D: ...
+ class E(C[str], D): ...
+ """
+ )
+ self.assertEqualMroQName(
+ cls, [".E", ".C", ".A", ".B", "typing.Generic", ".D", "builtins.object"]
+ )
+
+ @test_utils.require_version(minver="3.7")
+ def test_mro_generic_error_1(self):
+ cls = builder.extract_node(
+ """
+ from typing import Generic, TypeVar
+ T1 = TypeVar('T1')
+ T2 = TypeVar('T2')
+ class A(Generic[T1], Generic[T2]): ...
+ """
+ )
+ with self.assertRaises(DuplicateBasesError) as ex:
+ cls.mro()
+
+ @test_utils.require_version(minver="3.7")
+ def test_mro_generic_error_2(self):
+ cls = builder.extract_node(
+ """
+ from typing import Generic, TypeVar
+ T = TypeVar('T')
+ class A(Generic[T]): ...
+ class B(A[T], A[T]): ...
+ """
+ )
+ with self.assertRaises(DuplicateBasesError) as ex:
+ cls.mro()
+
def test_generator_from_infer_call_result_parent(self):
func = builder.extract_node(
"""