diff options
| author | Marc Mueller <30130371+cdce8p@users.noreply.github.com> | 2021-04-10 14:49:18 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-04-10 14:49:18 +0200 |
| commit | cd0f896672049493b2e3cfc87a327c871f8dd329 (patch) | |
| tree | 8e884f5a77f6335bc421dea202e08488b3b1aeda /tests | |
| parent | 31a731a7dc04507b6278dae66dd4ef1521881d96 (diff) | |
| download | astroid-git-cd0f896672049493b2e3cfc87a327c871f8dd329.tar.gz | |
Modify infernce tip for typing.Generic and typing.Annotated with ``__class_getitem__`` (#931)
* Modify infernce tip for typing.Generic and typing.Annotated with __class_getitem__
* Fix issue with slots caching
* Clean typing.Generic from mro
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/unittest_brain.py | 50 | ||||
| -rw-r--r-- | tests/unittest_scoped_nodes.py | 139 |
2 files changed, 189 insertions, 0 deletions
diff --git a/tests/unittest_brain.py b/tests/unittest_brain.py index f15c4a1c..a217f223 100644 --- a/tests/unittest_brain.py +++ b/tests/unittest_brain.py @@ -1361,6 +1361,56 @@ class TypingBrain(unittest.TestCase): inferred = next(node.infer()) self.assertIsInstance(inferred, nodes.ClassDef, node.as_string()) + @test_utils.require_version(minver="3.7") + def test_typing_generic_subscriptable(self): + """Test typing.Generic is subscriptable with __class_getitem__ (added in PY37)""" + node = builder.extract_node( + """ + from typing import Generic, TypeVar + T = TypeVar('T') + Generic[T] + """ + ) + inferred = next(node.infer()) + assert isinstance(inferred, nodes.ClassDef) + assert isinstance(inferred.getattr("__class_getitem__")[0], nodes.FunctionDef) + + @test_utils.require_version(minver="3.9") + def test_typing_annotated_subscriptable(self): + """Test typing.Annotated is subscriptable with __class_getitem__""" + node = builder.extract_node( + """ + import typing + typing.Annotated[str, "data"] + """ + ) + inferred = next(node.infer()) + assert isinstance(inferred, nodes.ClassDef) + assert isinstance(inferred.getattr("__class_getitem__")[0], nodes.FunctionDef) + + @test_utils.require_version(minver="3.7") + def test_typing_generic_slots(self): + """Test cache reset for slots if Generic subscript is inferred.""" + node = builder.extract_node( + """ + from typing import Generic, TypeVar + T = TypeVar('T') + class A(Generic[T]): + __slots__ = ['value'] + def __init__(self, value): + self.value = value + """ + ) + inferred = next(node.infer()) + assert len(inferred.slots()) == 0 + # Only after the subscript base is inferred and the inference tip applied, + # will slots contain the correct value + next(node.bases[0].infer()) + slots = inferred.slots() + assert len(slots) == 1 + assert isinstance(slots[0], nodes.Const) + assert slots[0].value == "value" + def test_has_dunder_args(self): ast_node = builder.extract_node( """ diff --git a/tests/unittest_scoped_nodes.py b/tests/unittest_scoped_nodes.py index a0f882ae..b98cd8f8 100644 --- a/tests/unittest_scoped_nodes.py +++ b/tests/unittest_scoped_nodes.py @@ -1275,6 +1275,9 @@ class ClassNodeTest(ModuleLoader, unittest.TestCase): def assertEqualMro(self, klass, expected_mro): self.assertEqual([member.name for member in klass.mro()], expected_mro) + def assertEqualMroQName(self, klass, expected_mro): + self.assertEqual([member.qname() for member in klass.mro()], expected_mro) + @unittest.skipUnless(HAS_SIX, "These tests require the six library") def test_with_metaclass_mro(self): astroid = builder.parse( @@ -1438,6 +1441,142 @@ class ClassNodeTest(ModuleLoader, unittest.TestCase): ) self.assertEqualMro(cls, ["C", "A", "B", "object"]) + @test_utils.require_version(minver="3.7") + def test_mro_generic_1(self): + cls = builder.extract_node( + """ + import typing + T = typing.TypeVar('T') + class A(typing.Generic[T]): ... + class B: ... + class C(A[T], B): ... + """ + ) + self.assertEqualMroQName( + cls, [".C", ".A", "typing.Generic", ".B", "builtins.object"] + ) + + @test_utils.require_version(minver="3.7") + def test_mro_generic_2(self): + cls = builder.extract_node( + """ + from typing import Generic, TypeVar + T = TypeVar('T') + class A: ... + class B(Generic[T]): ... + class C(Generic[T], A, B[T]): ... + """ + ) + self.assertEqualMroQName( + cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"] + ) + + @test_utils.require_version(minver="3.7") + def test_mro_generic_3(self): + cls = builder.extract_node( + """ + from typing import Generic, TypeVar + T = TypeVar('T') + class A: ... + class B(A, Generic[T]): ... + class C(Generic[T]): ... + class D(B[T], C[T], Generic[T]): ... + """ + ) + self.assertEqualMroQName( + cls, [".D", ".B", ".A", ".C", "typing.Generic", "builtins.object"] + ) + + @test_utils.require_version(minver="3.7") + def test_mro_generic_4(self): + cls = builder.extract_node( + """ + from typing import Generic, TypeVar + T = TypeVar('T') + class A: ... + class B(Generic[T]): ... + class C(A, Generic[T], B[T]): ... + """ + ) + self.assertEqualMroQName( + cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"] + ) + + @test_utils.require_version(minver="3.7") + def test_mro_generic_5(self): + cls = builder.extract_node( + """ + from typing import Generic, TypeVar + T1 = TypeVar('T1') + T2 = TypeVar('T2') + class A(Generic[T1]): ... + class B(Generic[T2]): ... + class C(A[T1], B[T2]): ... + """ + ) + self.assertEqualMroQName( + cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"] + ) + + @test_utils.require_version(minver="3.7") + def test_mro_generic_6(self): + cls = builder.extract_node( + """ + from typing import Generic as TGeneric, TypeVar + T = TypeVar('T') + class Generic: ... + class A(Generic): ... + class B(TGeneric[T]): ... + class C(A, B[T]): ... + """ + ) + self.assertEqualMroQName( + cls, [".C", ".A", ".Generic", ".B", "typing.Generic", "builtins.object"] + ) + + @test_utils.require_version(minver="3.7") + def test_mro_generic_7(self): + cls = builder.extract_node( + """ + from typing import Generic, TypeVar + T = TypeVar('T') + class A(): ... + class B(Generic[T]): ... + class C(A, B[T]): ... + class D: ... + class E(C[str], D): ... + """ + ) + self.assertEqualMroQName( + cls, [".E", ".C", ".A", ".B", "typing.Generic", ".D", "builtins.object"] + ) + + @test_utils.require_version(minver="3.7") + def test_mro_generic_error_1(self): + cls = builder.extract_node( + """ + from typing import Generic, TypeVar + T1 = TypeVar('T1') + T2 = TypeVar('T2') + class A(Generic[T1], Generic[T2]): ... + """ + ) + with self.assertRaises(DuplicateBasesError) as ex: + cls.mro() + + @test_utils.require_version(minver="3.7") + def test_mro_generic_error_2(self): + cls = builder.extract_node( + """ + from typing import Generic, TypeVar + T = TypeVar('T') + class A(Generic[T]): ... + class B(A[T], A[T]): ... + """ + ) + with self.assertRaises(DuplicateBasesError) as ex: + cls.mro() + def test_generator_from_infer_call_result_parent(self): func = builder.extract_node( """ |
