summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJan-Jaap Driessen <jdriessen@minddistrict.com>2020-09-25 11:29:29 +0200
committerJan-Jaap Driessen <jdriessen@minddistrict.com>2020-09-25 11:29:29 +0200
commit1025519cb68ad8fdae379daa6de8d9ec427456cd (patch)
tree0a0ed8306ca64b5c92b8051d09204a7b8d3dab83 /src
parent255db9d3c0ea7a5016e3c87644c23fca6eda9a7d (diff)
downloadzope-interface-1025519cb68ad8fdae379daa6de8d9ec427456cd.tar.gz
When an invariant is defined in an interface, it's found by
`validateInvariants` in all interfaces inheriting from that interface. Make sure to call each invariant only once when validating invariants.
Diffstat (limited to 'src')
-rw-r--r--src/zope/interface/interface.py9
-rw-r--r--src/zope/interface/tests/test_interface.py14
2 files changed, 21 insertions, 2 deletions
diff --git a/src/zope/interface/interface.py b/src/zope/interface/interface.py
index f819441..9b6b752 100644
--- a/src/zope/interface/interface.py
+++ b/src/zope/interface/interface.py
@@ -875,9 +875,14 @@ class InterfaceClass(_InterfaceClassBase):
def queryDescriptionFor(self, name, default=None):
return self.get(name, default)
- def validateInvariants(self, obj, errors=None):
+ def validateInvariants(self, obj, errors=None, seen=None):
"""validate object to defined invariants."""
+ if seen is None:
+ seen = set()
for call in self.queryTaggedValue('invariants', []):
+ if call in seen:
+ continue
+ seen.add(call)
try:
call(obj)
except Invalid as e:
@@ -886,7 +891,7 @@ class InterfaceClass(_InterfaceClassBase):
errors.append(e)
for base in self.__bases__:
try:
- base.validateInvariants(obj, errors)
+ base.validateInvariants(obj, errors, seen=seen)
except Invalid:
if errors is None:
raise
diff --git a/src/zope/interface/tests/test_interface.py b/src/zope/interface/tests/test_interface.py
index 036e858..9dc2aff 100644
--- a/src/zope/interface/tests/test_interface.py
+++ b/src/zope/interface/tests/test_interface.py
@@ -1014,6 +1014,20 @@ class InterfaceClassTests(unittest.TestCase):
self.assertEqual(len(_errors), 1)
self.assertTrue(isinstance(_errors[0], Invalid))
+ def test_validateInvariants_inherited_not_called_multiple_times(self):
+ _passable_called_with = []
+
+ def _passable(*args, **kw):
+ _passable_called_with.append((args, kw))
+ return True
+
+ obj = object()
+ base = self._makeOne('IBase')
+ base.setTaggedValue('invariants', [_passable])
+ derived = self._makeOne('IDerived', (base,))
+ derived.validateInvariants(obj)
+ self.assertEqual(1, len(_passable_called_with))
+
def test___reduce__(self):
iface = self._makeOne('PickleMe')
self.assertEqual(iface.__reduce__(), 'PickleMe')