diff options
author | Claudiu Popa <cpopa@cloudbasesolutions.com> | 2015-01-10 17:27:39 +0200 |
---|---|---|
committer | Claudiu Popa <cpopa@cloudbasesolutions.com> | 2015-01-10 17:27:39 +0200 |
commit | 978e1de2a85a449230cbfc6340fd5b4bba294132 (patch) | |
tree | 05ce5657c831ee9e170388a78de2081ae2b63bc5 | |
parent | e55961e43d3b933b325a5e8ad98baf6611396634 (diff) | |
download | astroid-978e1de2a85a449230cbfc6340fd5b4bba294132.tar.gz |
Improve the mro resolution.
First, validate that a class has duplicate entries in the mro or not. If it has,
there's no way to determine a stable hierarchy, so raise an error when this situation
occurs. Second, use a better algorithm for our purposes than .ancestors to infer
the base classes, since .ancestors will return all the inferred objects for a base,
which will also include base classes in the case of superclasses, which will result
in inconsistent hierarchy again from the point of view of _c3_merge.
-rw-r--r-- | astroid/scoped_nodes.py | 49 | ||||
-rw-r--r-- | astroid/tests/unittest_scoped_nodes.py | 17 |
2 files changed, 64 insertions, 2 deletions
diff --git a/astroid/scoped_nodes.py b/astroid/scoped_nodes.py index 372cd9c..e342ea4 100644 --- a/astroid/scoped_nodes.py +++ b/astroid/scoped_nodes.py @@ -78,6 +78,13 @@ def _c3_merge(sequences): del seq[0] +def _verify_duplicates_mro(sequences): + for sequence in sequences: + names = [node.qname() for node in sequence] + if len(names) != len(set(names)): + raise ResolveError('Duplicates found in the mro.') + + def remove_nodes(func, cls): def wrapper(*args, **kwargs): nodes = [n for n in func(*args, **kwargs) if not isinstance(n, cls)] @@ -1393,6 +1400,41 @@ class Class(Statement, LocalsDictNodeNG, FilterStmtsMixin): return None return [first] + list(slots) + def _inferred_bases(self, recurs=True, context=None): + # TODO(cpopa): really similar with .ancestors, + # but the difference is when one base is inferred, + # only the first object is wanted. That's because + # we aren't interested in superclasses, as in the following + # example: + # + # class SomeSuperClass(object): pass + # class SomeClass(SomeSuperClass): pass + # class Test(SomeClass): pass + # + # Inferring SomeClass from the Test's bases will give + # us both SomeClass and SomeSuperClass, but we are interested + # only in SomeClass. + + if context is None: + context = InferenceContext() + if sys.version_info[0] >= 3: + if not self.bases and self.qname() != 'builtins.object': + yield builtin_lookup("object")[1][0] + return + + for stmt in self.bases: + try: + baseobj = next(stmt.infer(context=context)) + except InferenceError: + # XXX log error ? + continue + if isinstance(baseobj, Instance): + baseobj = baseobj._proxied + if not isinstance(baseobj, Class): + continue + if not baseobj.hide: + yield baseobj + def mro(self, context=None): """Get the method resolution order, using C3 linearization. @@ -1404,5 +1446,8 @@ class Class(Statement, LocalsDictNodeNG, FilterStmtsMixin): raise NotImplementedError( "Could not obtain mro for old-style classes.") - bases = list(self.ancestors(recurs=False, context=context)) - return _c3_merge([[self]] + [base.mro() for base in bases] + [bases]) + bases = list(self._inferred_bases(context=context)) + unmerged_mro = [[self]] + [base.mro() for base in bases] + [bases] + + _verify_duplicates_mro(unmerged_mro) + return _c3_merge(unmerged_mro) diff --git a/astroid/tests/unittest_scoped_nodes.py b/astroid/tests/unittest_scoped_nodes.py index 9b62213..e4befa9 100644 --- a/astroid/tests/unittest_scoped_nodes.py +++ b/astroid/tests/unittest_scoped_nodes.py @@ -1131,6 +1131,19 @@ class ClassNodeTest(ModuleLoader, unittest.TestCase): class PedalWheelBoat(EngineLess, WheelBoat): pass class SmallCatamaran(SmallMultihull): pass class Pedalo(PedalWheelBoat, SmallCatamaran): pass + + class OuterA(object): + class Inner(object): + pass + class OuterB(OuterA): + class Inner(OuterA.Inner): + pass + class OuterC(OuterA): + class Inner(OuterA.Inner): + pass + class OuterD(OuterC): + class Inner(OuterC.Inner, OuterB.Inner): + pass """) self.assertEqualMro(astroid['D'], ['D', 'dict', 'C', 'object']) @@ -1154,6 +1167,10 @@ class ClassNodeTest(ModuleLoader, unittest.TestCase): ["Pedalo", "PedalWheelBoat", "EngineLess", "SmallCatamaran", "SmallMultihull", "DayBoat", "WheelBoat", "Boat", "object"]) + self.assertEqualMro( + astroid['OuterD']['Inner'], + ['Inner', 'Inner', 'Inner', 'Inner', 'object']) + if __name__ == '__main__': unittest.main() |