summaryrefslogtreecommitdiff
path: root/mox.py
diff options
context:
space:
mode:
Diffstat (limited to 'mox.py')
-rwxr-xr-xmox.py505
1 files changed, 459 insertions, 46 deletions
diff --git a/mox.py b/mox.py
index 8bb2099..2ffdf90 100755
--- a/mox.py
+++ b/mox.py
@@ -130,7 +130,7 @@ class UnexpectedMethodCallError(Error):
diff = differ.compare(str(unexpected_method).splitlines(True),
str(expected).splitlines(True))
self._str = ("Unexpected method call. unexpected:- expected:+\n%s"
- % ("\n".join(diff),))
+ % ("\n".join(line.rstrip() for line in diff),))
def __str__(self):
return self._str
@@ -156,13 +156,86 @@ class UnknownMethodCallError(Error):
self._unknown_method_name
+class PrivateAttributeError(Error):
+ """
+ Raised if a MockObject is passed a private additional attribute name.
+ """
+
+ def __init__(self, attr):
+ Error.__init__(self)
+ self._attr = attr
+
+ def __str__(self):
+ return ("Attribute '%s' is private and should not be available in a mock "
+ "object." % attr)
+
+
+class ExpectedMockCreationError(Error):
+ """Raised if mocks should have been created by StubOutClassWithMocks."""
+
+ def __init__(self, expected_mocks):
+ """Init exception.
+
+ Args:
+ # expected_mocks: A sequence of MockObjects that should have been
+ # created
+
+ Raises:
+ ValueError: if expected_mocks contains no methods.
+ """
+
+ if not expected_mocks:
+ raise ValueError("There must be at least one expected method")
+ Error.__init__(self)
+ self._expected_mocks = expected_mocks
+
+ def __str__(self):
+ mocks = "\n".join(["%3d. %s" % (i, m)
+ for i, m in enumerate(self._expected_mocks)])
+ return "Verify: Expected mocks never created:\n%s" % (mocks,)
+
+
+class UnexpectedMockCreationError(Error):
+ """Raised if too many mocks were created by StubOutClassWithMocks."""
+
+ def __init__(self, instance, *params, **named_params):
+ """Init exception.
+
+ Args:
+ # instance: the type of obejct that was created
+ # params: parameters given during instantiation
+ # named_params: named parameters given during instantiation
+ """
+
+ Error.__init__(self)
+ self._instance = instance
+ self._params = params
+ self._named_params = named_params
+
+ def __str__(self):
+ args = ", ".join(["%s" % v for i, v in enumerate(self._params)])
+ error = "Unexpected mock creation: %s(%s" % (self._instance, args)
+
+ if self._named_params:
+ error += ", " + ", ".join(["%s=%s" % (k, v) for k, v in
+ self._named_params.iteritems()])
+
+ error += ")"
+ return error
+
+
class Mox(object):
"""Mox: a factory for creating mock objects."""
# A list of types that should be stubbed out with MockObjects (as
# opposed to MockAnythings).
- _USE_MOCK_OBJECT = [types.ClassType, types.InstanceType, types.ModuleType,
- types.ObjectType, types.TypeType]
+ _USE_MOCK_OBJECT = [types.ClassType, types.FunctionType, types.InstanceType,
+ types.ModuleType, types.ObjectType, types.TypeType,
+ types.MethodType, types.UnboundMethodType,
+ ]
+
+ # A list of types that may be stubbed out with a MockObjectFactory.
+ _USE_MOCK_FACTORY = [types.ClassType, types.ObjectType, types.TypeType]
def __init__(self):
"""Initialize a new Mox."""
@@ -170,18 +243,21 @@ class Mox(object):
self._mock_objects = []
self.stubs = stubout.StubOutForTesting()
- def CreateMock(self, class_to_mock):
+ def CreateMock(self, class_to_mock, attrs=None):
"""Create a new mock object.
Args:
# class_to_mock: the class to be mocked
class_to_mock: class
+ attrs: dict of attribute names to values that will be set on the mock
+ object. Only public attributes may be set.
Returns:
MockObject that can be used as the class_to_mock would be.
"""
-
- new_mock = MockObject(class_to_mock)
+ if attrs is None:
+ attrs = {}
+ new_mock = MockObject(class_to_mock, attrs=attrs)
self._mock_objects.append(new_mock)
return new_mock
@@ -232,19 +308,72 @@ class Mox(object):
"""
attr_to_replace = getattr(obj, attr_name)
+ attr_type = type(attr_to_replace)
- # Check for a MockAnything. This could cause confusing problems later on.
- if attr_to_replace == MockAnything():
+ if attr_type == MockAnything or attr_type == MockObject:
raise TypeError('Cannot mock a MockAnything! Did you remember to '
'call UnsetStubs in your previous test?')
- if type(attr_to_replace) in self._USE_MOCK_OBJECT and not use_mock_anything:
+ if attr_type in self._USE_MOCK_OBJECT and not use_mock_anything:
stub = self.CreateMock(attr_to_replace)
else:
stub = self.CreateMockAnything(description='Stub for %s' % attr_to_replace)
+ stub.__name__ = attr_name
self.stubs.Set(obj, attr_name, stub)
+ def StubOutClassWithMocks(self, obj, attr_name):
+ """Replace a class with a "mock factory" that will create mock objects.
+
+ This is useful if the code-under-test directly instantiates
+ dependencies. Previously some boilder plate was necessary to
+ create a mock that would act as a factory. Using
+ StubOutClassWithMocks, once you've stubbed out the class you may
+ use the stubbed class as you would any other mock created by mox:
+ during the record phase, new mock instances will be created, and
+ during replay, the recorded mocks will be returned.
+
+ In replay mode
+
+ # Example using StubOutWithMock (the old, clunky way):
+
+ mock1 = mox.CreateMock(my_import.FooClass)
+ mock2 = mox.CreateMock(my_import.FooClass)
+ foo_factory = mox.StubOutWithMock(my_import, 'FooClass',
+ use_mock_anything=True)
+ foo_factory(1, 2).AndReturn(mock1)
+ foo_factory(9, 10).AndReturn(mock2)
+ mox.ReplayAll()
+
+ my_import.FooClass(1, 2) # Returns mock1 again.
+ my_import.FooClass(9, 10) # Returns mock2 again.
+ mox.VerifyAll()
+
+ # Example using StubOutClassWithMocks:
+
+ mox.StubOutClassWithMocks(my_import, 'FooClass')
+ mock1 = my_import.FooClass(1, 2) # Returns a new mock of FooClass
+ mock2 = my_import.FooClass(9, 10) # Returns another mock instance
+ mox.ReplayAll()
+
+ my_import.FooClass(1, 2) # Returns mock1 again.
+ my_import.FooClass(9, 10) # Returns mock2 again.
+ mox.VerifyAll()
+ """
+ attr_to_replace = getattr(obj, attr_name)
+ attr_type = type(attr_to_replace)
+
+ if attr_type == MockAnything or attr_type == MockObject:
+ raise TypeError('Cannot mock a MockAnything! Did you remember to '
+ 'call UnsetStubs in your previous test?')
+
+ if attr_type not in self._USE_MOCK_FACTORY:
+ raise TypeError('Given attr is not a Class. Use StubOutWithMock.')
+
+ factory = _MockObjectFactory(attr_to_replace, self)
+ self._mock_objects.append(factory)
+ self.stubs.Set(obj, attr_name, factory)
+
def UnsetStubs(self):
"""Restore stubs to their original state."""
@@ -299,11 +428,11 @@ class MockAnything:
self._description = description
self._Reset()
- def __str__(self):
- return "<MockAnything instance at %s>" % id(self)
-
def __repr__(self):
- return '<MockAnything instance>'
+ if self._description:
+ return '<MockAnything instance of %s>' % self._description
+ else:
+ return '<MockAnything instance>'
def __getattr__(self, method_name):
"""Intercept method calls on this object.
@@ -319,6 +448,8 @@ class MockAnything:
Returns:
A new MockMethod aware of MockAnything's state (record or replay).
"""
+ if method_name == '__dir__':
+ return self.__class__.__dir__.__get__(self, self.__class__)
return self._CreateMockMethod(method_name)
@@ -392,7 +523,7 @@ class MockAnything:
class MockObject(MockAnything, object):
"""A mock object that simulates the public/protected interface of a class."""
- def __init__(self, class_to_mock):
+ def __init__(self, class_to_mock, attrs=None):
"""Initialize a mock object.
This determines the methods and properties of the class and stores them.
@@ -400,7 +531,15 @@ class MockObject(MockAnything, object):
Args:
# class_to_mock: class to be mocked
class_to_mock: class
+ attrs: dict of attribute names to values that will be set on the mock
+ object. Only public attributes may be set.
+
+ Raises:
+ PrivateAttributeError: if a supplied attribute is not public.
+ ValueError: if an attribute would mask an existing method.
"""
+ if attrs is None:
+ attrs = {}
# This is used to hack around the mixin/inheritance of MockAnything, which
# is not a proper object (it can be anything. :-)
@@ -410,12 +549,34 @@ class MockObject(MockAnything, object):
self._known_methods = set()
self._known_vars = set()
self._class_to_mock = class_to_mock
+ try:
+ if inspect.isclass(self._class_to_mock):
+ self._description = class_to_mock.__name__
+ else:
+ self._description = type(class_to_mock).__name__
+ except Exception:
+ pass
+
for method in dir(class_to_mock):
- if callable(getattr(class_to_mock, method)):
+ attr = getattr(class_to_mock, method)
+ if callable(attr):
self._known_methods.add(method)
- else:
+ elif not (type(attr) is property):
+ # treating properties as class vars makes little sense.
self._known_vars.add(method)
+ # Set additional attributes at instantiation time; this is quicker
+ # than manually setting attributes that are normally created in
+ # __init__.
+ for attr, value in attrs.items():
+ if attr.startswith("_"):
+ raise PrivateAttributeError(attr)
+ elif attr in self._known_methods:
+ raise ValueError("'%s' is a method of '%s' objects." % (attr,
+ class_to_mock))
+ else:
+ setattr(self, attr, value)
+
def __getattr__(self, name):
"""Intercept attribute request on this object.
@@ -596,7 +757,16 @@ class MockObject(MockAnything, object):
# Because the call is happening directly on this object instead of a method,
# the call on the mock method is made right here
- mock_method = self._CreateMockMethod('__call__')
+
+ # If we are mocking a Function, then use the function, and not the
+ # __call__ method
+ method = None
+ if type(self._class_to_mock) in (types.FunctionType, types.MethodType):
+ method = self._class_to_mock;
+ else:
+ method = getattr(self._class_to_mock, '__call__')
+ mock_method = self._CreateMockMethod('__call__', method_to_mock=method)
+
return mock_method(*params, **named_params)
@property
@@ -605,8 +775,61 @@ class MockObject(MockAnything, object):
return self._class_to_mock
+ @property
+ def __name__(self):
+ """Return the name that is being mocked."""
+ return self._description
+
+
+class _MockObjectFactory(MockObject):
+ """A MockObjectFactory creates mocks and verifies __init__ params.
+
+ A MockObjectFactory removes the boiler plate code that was previously
+ necessary to stub out direction instantiation of a class.
+
+ The MockObjectFactory creates new MockObjects when called and verifies the
+ __init__ params are correct when in record mode. When replaying, existing
+ mocks are returned, and the __init__ params are verified.
+
+ See StubOutWithMock vs StubOutClassWithMocks for more detail.
+ """
+
+ def __init__(self, class_to_mock, mox_instance):
+ MockObject.__init__(self, class_to_mock)
+ self._mox = mox_instance
+ self._instance_queue = deque()
+
+ def __call__(self, *params, **named_params):
+ """Instantiate and record that a new mock has been created."""
+
+ method = getattr(self._class_to_mock, '__init__')
+ mock_method = self._CreateMockMethod('__init__', method_to_mock=method)
+ # Note: calling mock_method() is deferred in order to catch the
+ # empty instance_queue first.
+
+ if self._replay_mode:
+ if not self._instance_queue:
+ raise UnexpectedMockCreationError(self._class_to_mock, *params,
+ **named_params)
+
+ mock_method(*params, **named_params)
+
+ return self._instance_queue.pop()
+ else:
+ mock_method(*params, **named_params)
+
+ instance = self._mox.CreateMock(self._class_to_mock)
+ self._instance_queue.appendleft(instance)
+ return instance
+
+ def _Verify(self):
+ """Verify that all mocks have been created."""
+ if self._instance_queue:
+ raise ExpectedMockCreationError(self._instance_queue)
+ super(_MockObjectFactory, self)._Verify()
+
-class MethodCallChecker(object):
+class MethodSignatureChecker(object):
"""Ensures that methods are called correctly."""
_NEEDED, _DEFAULT, _GIVEN = range(3)
@@ -630,6 +853,7 @@ class MethodCallChecker(object):
if inspect.ismethod(method):
self._args = self._args[1:] # Skip 'self'.
self._method = method
+ self._instance = None # May contain the instance this is bound to.
self._has_varargs = varargs is not None
self._has_varkw = varkw is not None
@@ -652,9 +876,9 @@ class MethodCallChecker(object):
Raises:
AttributeError: arg_name is already marked as _GIVEN.
"""
- if arg_status.get(arg_name, None) == MethodCallChecker._GIVEN:
+ if arg_status.get(arg_name, None) == MethodSignatureChecker._GIVEN:
raise AttributeError('%s provided more than once' % (arg_name,))
- arg_status[arg_name] = MethodCallChecker._GIVEN
+ arg_status[arg_name] = MethodSignatureChecker._GIVEN
def Check(self, params, named_params):
"""Ensures that the parameters used while recording a call are valid.
@@ -668,10 +892,45 @@ class MethodCallChecker(object):
Raises:
AttributeError: the given parameters don't work with the given method.
"""
- arg_status = dict((a, MethodCallChecker._NEEDED)
+ arg_status = dict((a, MethodSignatureChecker._NEEDED)
for a in self._required_args)
for arg in self._default_args:
- arg_status[arg] = MethodCallChecker._DEFAULT
+ arg_status[arg] = MethodSignatureChecker._DEFAULT
+
+ # WARNING: Suspect hack ahead.
+ #
+ # Check to see if this is an unbound method, where the instance
+ # should be bound as the first argument. We try to determine if
+ # the first argument (param[0]) is an instance of the class, or it
+ # is equivalent to the class (used to account for Comparators).
+ #
+ # NOTE: If a Func() comparator is used, and the signature is not
+ # correct, this will cause extra executions of the function.
+ if inspect.ismethod(self._method):
+ # The extra param accounts for the bound instance.
+ if len(params) > len(self._required_args):
+ expected = getattr(self._method, 'im_class', None)
+
+ # Check if the param is an instance of the expected class,
+ # or check equality (useful for checking Comparators).
+
+ # This is a hack to work around the fact that the first
+ # parameter can be a Comparator, and the comparison may raise
+ # an exception during this comparison, which is OK.
+ try:
+ param_equality = (params[0] == expected)
+ except:
+ param_equality = False;
+
+
+ if isinstance(params[0], expected) or param_equality:
+ params = params[1:]
+ # If the IsA() comparator is being used, we need to check the
+ # inverse of the usual case - that the given instance is a subclass
+ # of the expected class. For example, the code under test does
+ # late binding to a subclass.
+ elif isinstance(params[0], IsA) and params[0]._IsSubClass(expected):
+ params = params[1:]
# Check that each positional param is valid.
for i in range(len(params)):
@@ -693,9 +952,9 @@ class MethodCallChecker(object):
# Ensure all the required arguments have been given.
still_needed = [k for k, v in arg_status.iteritems()
- if v == MethodCallChecker._NEEDED]
+ if v == MethodSignatureChecker._NEEDED]
if still_needed:
- raise AttributeError('No values given for arguments %s'
+ raise AttributeError('No values given for arguments: %s'
% (' '.join(sorted(still_needed))))
@@ -729,6 +988,7 @@ class MockMethod(object):
"""
self._name = method_name
+ self.__name__ = method_name
self._call_queue = call_queue
if not isinstance(call_queue, deque):
self._call_queue = deque(self._call_queue)
@@ -742,7 +1002,7 @@ class MockMethod(object):
self._side_effects = None
try:
- self._checker = MethodCallChecker(method_to_mock)
+ self._checker = MethodSignatureChecker(method_to_mock)
except ValueError:
self._checker = None
@@ -771,7 +1031,9 @@ class MockMethod(object):
expected_method = self._VerifyMethodCall()
if expected_method._side_effects:
- expected_method._side_effects(*params, **named_params)
+ result = expected_method._side_effects(*params, **named_params)
+ if expected_method._return_value is None:
+ expected_method._return_value = result
if expected_method._exception:
raise expected_method._exception
@@ -923,7 +1185,7 @@ class MockMethod(object):
"""Move this method into group of calls which may be called multiple times.
A group of repeating calls must be defined together, and must be executed in
- full before the next expected mehtod can be called.
+ full before the next expected method can be called.
Args:
group_name: the name of the unordered group.
@@ -1004,6 +1266,17 @@ class Comparator:
def __ne__(self, rhs):
return not self.equals(rhs)
+class Is(Comparator):
+ """Comparison class used to check identity, instead of equality."""
+
+ def __init__(self, obj):
+ self._obj = obj
+
+ def equals(self, rhs):
+ return rhs is self._obj
+
+ def __repr__(self):
+ return "<is %r (%s)>" % (self._obj, id(self._obj))
class IsA(Comparator):
"""This class wraps a basic Python type or class. It is used to verify
@@ -1040,8 +1313,26 @@ class IsA(Comparator):
# things like cStringIO.StringIO.
return type(rhs) == type(self._class_name)
+ def _IsSubClass(self, clazz):
+ """Check to see if the IsA comparators class is a subclass of clazz.
+
+ Args:
+ # clazz: a class object
+
+ Returns:
+ bool
+ """
+
+ try:
+ return issubclass(self._class_name, clazz)
+ except TypeError:
+ # Check raw types if there was a type error. This is helpful for
+ # things like cStringIO.StringIO.
+ return type(clazz) == type(self._class_name)
+
def __repr__(self):
- return str(self._class_name)
+ return 'mox.IsA(%s) ' % str(self._class_name)
+
class IsAlmost(Comparator):
"""Comparison class used to check whether a parameter is nearly equal
@@ -1073,7 +1364,7 @@ class IsAlmost(Comparator):
try:
return round(rhs-self._float_value, self._places) == 0
- except TypeError:
+ except Exception:
# This is probably because either float_value or rhs is not a number.
return False
@@ -1144,7 +1435,10 @@ class Regex(Comparator):
bool
"""
- return self.regex.search(rhs) is not None
+ try:
+ return self.regex.search(rhs) is not None
+ except Exception:
+ return False
def __repr__(self):
s = '<regular expression \'%s\'' % self.regex.pattern
@@ -1180,10 +1474,13 @@ class In(Comparator):
bool
"""
- return self._key in rhs
+ try:
+ return self._key in rhs
+ except Exception:
+ return False
def __repr__(self):
- return '<sequence or map containing \'%s\'>' % self._key
+ return '<sequence or map containing \'%s\'>' % str(self._key)
class Not(Comparator):
@@ -1214,7 +1511,10 @@ class Not(Comparator):
bool
"""
- return not self._predicate.equals(rhs)
+ try:
+ return not self._predicate.equals(rhs)
+ except Exception:
+ return False
def __repr__(self):
return '<not \'%s\'>' % self._predicate
@@ -1251,11 +1551,43 @@ class ContainsKeyValue(Comparator):
return False
def __repr__(self):
- return '<map containing the entry \'%s: %s\'>' % (self._key, self._value)
+ return '<map containing the entry \'%s: %s\'>' % (str(self._key),
+ str(self._value))
+
+
+class ContainsAttributeValue(Comparator):
+ """Checks whether a passed parameter contains attributes with a given value.
+
+ Example:
+ mock_dao.UpdateSomething(ContainsAttribute('stevepm', stevepm_user_info))
+ """
+
+ def __init__(self, key, value):
+ """Initialize.
+
+ Args:
+ # key: an attribute name of an object
+ # value: the corresponding value
+ """
+
+ self._key = key
+ self._value = value
+
+ def equals(self, rhs):
+ """Check whether the given attribute has a matching value in the rhs object.
+
+ Returns:
+ bool
+ """
+
+ try:
+ return getattr(rhs, self._key) == self._value
+ except Exception:
+ return False
class SameElementsAs(Comparator):
- """Checks whether iterables contain the same elements (ignoring order).
+ """Checks whether sequences contain the same elements (ignoring order).
Example:
mock_dao.ProcessUsers(SameElementsAs('stevepm', 'salomaki'))
@@ -1267,8 +1599,8 @@ class SameElementsAs(Comparator):
Args:
expected_seq: a sequence
"""
-
- self._expected_seq = expected_seq
+ # Store in case expected_seq is an iterator.
+ self._expected_list = list(expected_seq)
def equals(self, actual_seq):
"""Check to see whether actual_seq has same elements as expected_seq.
@@ -1279,20 +1611,30 @@ class SameElementsAs(Comparator):
Returns:
bool
"""
+ try:
+ # Store in case actual_seq is an iterator. We potentially iterate twice:
+ # once to make the dict, once in the list fallback.
+ actual_list = list(actual_seq)
+ except TypeError:
+ # actual_seq cannot be read as a sequence.
+ #
+ # This happens because Mox uses __eq__ both to check object equality (in
+ # MethodSignatureChecker) and to invoke Comparators.
+ return False
try:
- expected = dict([(element, None) for element in self._expected_seq])
- actual = dict([(element, None) for element in actual_seq])
+ expected = dict([(element, None) for element in self._expected_list])
+ actual = dict([(element, None) for element in actual_list])
except TypeError:
# Fall back to slower list-compare if any of the objects are unhashable.
- expected = list(self._expected_seq)
- actual = list(actual_seq)
+ expected = self._expected_list
+ actual = actual_list
expected.sort()
actual.sort()
return expected == actual
def __repr__(self):
- return '<sequence with same elements as \'%s\'>' % self._expected_seq
+ return '<sequence with same elements as \'%s\'>' % self._expected_list
class And(Comparator):
@@ -1431,6 +1773,61 @@ class IgnoreArg(Comparator):
return '<IgnoreArg>'
+class Value(Comparator):
+ """Compares argument against a remembered value.
+
+ To be used in conjunction with Remember comparator. See Remember()
+ for example.
+ """
+
+ def __init__(self):
+ self._value = None
+ self._has_value = False
+
+ def store_value(self, rhs):
+ self._value = rhs
+ self._has_value = True
+
+ def equals(self, rhs):
+ if not self._has_value:
+ return False
+ else:
+ return rhs == self._value
+
+ def __repr__(self):
+ if self._has_value:
+ return "<Value %r>" % self._value
+ else:
+ return "<Value>"
+
+
+class Remember(Comparator):
+ """Remembers the argument to a value store.
+
+ To be used in conjunction with Value comparator.
+
+ Example:
+ # Remember the argument for one method call.
+ users_list = Value()
+ mock_dao.ProcessUsers(Remember(users_list))
+
+ # Check argument against remembered value.
+ mock_dao.ReportUsers(users_list)
+ """
+
+ def __init__(self, value_store):
+ if not isinstance(value_store, Value):
+ raise TypeError("value_store is not an instance of the Value class")
+ self._value_store = value_store
+
+ def equals(self, rhs):
+ self._value_store.store_value(rhs)
+ return True
+
+ def __repr__(self):
+ return "<Remember %d>" % id(self._value_store)
+
+
class MethodGroup(object):
"""Base class containing common behaviour for MethodGroups."""
@@ -1463,6 +1860,12 @@ class UnorderedGroup(MethodGroup):
super(UnorderedGroup, self).__init__(group_name)
self._methods = []
+ def __str__(self):
+ return '%s "%s" pending calls:\n%s' % (
+ self.__class__.__name__,
+ self._group_name,
+ "\n".join(str(method) for method in self._methods))
+
def AddMethod(self, mock_method):
"""Add a method to this group.
@@ -1589,7 +1992,8 @@ class MoxMetaTestBase(type):
# for a case when test class is not the immediate child of MoxTestBase
for base in bases:
for attr_name in dir(base):
- d[attr_name] = getattr(base, attr_name)
+ if attr_name not in d:
+ d[attr_name] = getattr(base, attr_name)
for func_name, func in d.items():
if func_name.startswith('test') and callable(func):
@@ -1611,14 +2015,21 @@ class MoxMetaTestBase(type):
"""
def new_method(self, *args, **kwargs):
mox_obj = getattr(self, 'mox', None)
+ stubout_obj = getattr(self, 'stubs', None)
cleanup_mox = False
+ cleanup_stubout = False
if mox_obj and isinstance(mox_obj, Mox):
cleanup_mox = True
+ if stubout_obj and isinstance(stubout_obj, stubout.StubOutForTesting):
+ cleanup_stubout = True
try:
func(self, *args, **kwargs)
finally:
if cleanup_mox:
mox_obj.UnsetStubs()
+ if cleanup_stubout:
+ stubout_obj.UnsetAll()
+ stubout_obj.SmartUnsetAll()
if cleanup_mox:
mox_obj.VerifyAll()
new_method.__name__ = func.__name__
@@ -1630,9 +2041,10 @@ class MoxMetaTestBase(type):
class MoxTestBase(unittest.TestCase):
"""Convenience test class to make stubbing easier.
- Sets up a "mox" attribute which is an instance of Mox - any mox tests will
- want this. Also automatically unsets any stubs and verifies that all mock
- methods have been called at the end of each test, eliminating boilerplate
+ Sets up a "mox" attribute which is an instance of Mox (any mox tests will
+ want this), and a "stubs" attribute that is an instance of StubOutForTesting
+ (needed at times). Also automatically unsets any stubs and verifies that all
+ mock methods have been called at the end of each test, eliminating boilerplate
code.
"""
@@ -1641,3 +2053,4 @@ class MoxTestBase(unittest.TestCase):
def setUp(self):
super(MoxTestBase, self).setUp()
self.mox = Mox()
+ self.stubs = stubout.StubOutForTesting()