summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClaudiu Popa <cpopa@cloudbasesolutions.com>2015-01-10 17:27:39 +0200
committerClaudiu Popa <cpopa@cloudbasesolutions.com>2015-01-10 17:27:39 +0200
commit978e1de2a85a449230cbfc6340fd5b4bba294132 (patch)
tree05ce5657c831ee9e170388a78de2081ae2b63bc5
parente55961e43d3b933b325a5e8ad98baf6611396634 (diff)
downloadastroid-978e1de2a85a449230cbfc6340fd5b4bba294132.tar.gz
Improve the mro resolution.
First, validate that a class has duplicate entries in the mro or not. If it has, there's no way to determine a stable hierarchy, so raise an error when this situation occurs. Second, use a better algorithm for our purposes than .ancestors to infer the base classes, since .ancestors will return all the inferred objects for a base, which will also include base classes in the case of superclasses, which will result in inconsistent hierarchy again from the point of view of _c3_merge.
-rw-r--r--astroid/scoped_nodes.py49
-rw-r--r--astroid/tests/unittest_scoped_nodes.py17
2 files changed, 64 insertions, 2 deletions
diff --git a/astroid/scoped_nodes.py b/astroid/scoped_nodes.py
index 372cd9c..e342ea4 100644
--- a/astroid/scoped_nodes.py
+++ b/astroid/scoped_nodes.py
@@ -78,6 +78,13 @@ def _c3_merge(sequences):
del seq[0]
+def _verify_duplicates_mro(sequences):
+ for sequence in sequences:
+ names = [node.qname() for node in sequence]
+ if len(names) != len(set(names)):
+ raise ResolveError('Duplicates found in the mro.')
+
+
def remove_nodes(func, cls):
def wrapper(*args, **kwargs):
nodes = [n for n in func(*args, **kwargs) if not isinstance(n, cls)]
@@ -1393,6 +1400,41 @@ class Class(Statement, LocalsDictNodeNG, FilterStmtsMixin):
return None
return [first] + list(slots)
+ def _inferred_bases(self, recurs=True, context=None):
+ # TODO(cpopa): really similar with .ancestors,
+ # but the difference is when one base is inferred,
+ # only the first object is wanted. That's because
+ # we aren't interested in superclasses, as in the following
+ # example:
+ #
+ # class SomeSuperClass(object): pass
+ # class SomeClass(SomeSuperClass): pass
+ # class Test(SomeClass): pass
+ #
+ # Inferring SomeClass from the Test's bases will give
+ # us both SomeClass and SomeSuperClass, but we are interested
+ # only in SomeClass.
+
+ if context is None:
+ context = InferenceContext()
+ if sys.version_info[0] >= 3:
+ if not self.bases and self.qname() != 'builtins.object':
+ yield builtin_lookup("object")[1][0]
+ return
+
+ for stmt in self.bases:
+ try:
+ baseobj = next(stmt.infer(context=context))
+ except InferenceError:
+ # XXX log error ?
+ continue
+ if isinstance(baseobj, Instance):
+ baseobj = baseobj._proxied
+ if not isinstance(baseobj, Class):
+ continue
+ if not baseobj.hide:
+ yield baseobj
+
def mro(self, context=None):
"""Get the method resolution order, using C3 linearization.
@@ -1404,5 +1446,8 @@ class Class(Statement, LocalsDictNodeNG, FilterStmtsMixin):
raise NotImplementedError(
"Could not obtain mro for old-style classes.")
- bases = list(self.ancestors(recurs=False, context=context))
- return _c3_merge([[self]] + [base.mro() for base in bases] + [bases])
+ bases = list(self._inferred_bases(context=context))
+ unmerged_mro = [[self]] + [base.mro() for base in bases] + [bases]
+
+ _verify_duplicates_mro(unmerged_mro)
+ return _c3_merge(unmerged_mro)
diff --git a/astroid/tests/unittest_scoped_nodes.py b/astroid/tests/unittest_scoped_nodes.py
index 9b62213..e4befa9 100644
--- a/astroid/tests/unittest_scoped_nodes.py
+++ b/astroid/tests/unittest_scoped_nodes.py
@@ -1131,6 +1131,19 @@ class ClassNodeTest(ModuleLoader, unittest.TestCase):
class PedalWheelBoat(EngineLess, WheelBoat): pass
class SmallCatamaran(SmallMultihull): pass
class Pedalo(PedalWheelBoat, SmallCatamaran): pass
+
+ class OuterA(object):
+ class Inner(object):
+ pass
+ class OuterB(OuterA):
+ class Inner(OuterA.Inner):
+ pass
+ class OuterC(OuterA):
+ class Inner(OuterA.Inner):
+ pass
+ class OuterD(OuterC):
+ class Inner(OuterC.Inner, OuterB.Inner):
+ pass
""")
self.assertEqualMro(astroid['D'], ['D', 'dict', 'C', 'object'])
@@ -1154,6 +1167,10 @@ class ClassNodeTest(ModuleLoader, unittest.TestCase):
["Pedalo", "PedalWheelBoat", "EngineLess", "SmallCatamaran",
"SmallMultihull", "DayBoat", "WheelBoat", "Boat", "object"])
+ self.assertEqualMro(
+ astroid['OuterD']['Inner'],
+ ['Inner', 'Inner', 'Inner', 'Inner', 'object'])
+
if __name__ == '__main__':
unittest.main()