summaryrefslogtreecommitdiff
path: root/mox3/mox.py
diff options
context:
space:
mode:
Diffstat (limited to 'mox3/mox.py')
-rw-r--r--mox3/mox.py2135
1 files changed, 2135 insertions, 0 deletions
diff --git a/mox3/mox.py b/mox3/mox.py
new file mode 100644
index 0000000..da1d89b
--- /dev/null
+++ b/mox3/mox.py
@@ -0,0 +1,2135 @@
+# Copyright 2008 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+#
+# This is a fork of the pymox library intended to work with Python 3.
+# The file was modified by quermit@gmail.com and dawid.fatyga@gmail.com
+
+"""Mox, an object-mocking framework for Python.
+
+Mox works in the record-replay-verify paradigm. When you first create
+a mock object, it is in record mode. You then programmatically set
+the expected behavior of the mock object (what methods are to be
+called on it, with what parameters, what they should return, and in
+what order).
+
+Once you have set up the expected mock behavior, you put it in replay
+mode. Now the mock responds to method calls just as you told it to.
+If an unexpected method (or an expected method with unexpected
+parameters) is called, then an exception will be raised.
+
+Once you are done interacting with the mock, you need to verify that
+all the expected interactions occured. (Maybe your code exited
+prematurely without calling some cleanup method!) The verify phase
+ensures that every expected method was called; otherwise, an exception
+will be raised.
+
+WARNING! Mock objects created by Mox are not thread-safe. If you are
+call a mock in multiple threads, it should be guarded by a mutex.
+
+TODO(stevepm): Add the option to make mocks thread-safe!
+
+Suggested usage / workflow:
+
+ # Create Mox factory
+ my_mox = Mox()
+
+ # Create a mock data access object
+ mock_dao = my_mox.CreateMock(DAOClass)
+
+ # Set up expected behavior
+ mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person)
+ mock_dao.DeletePerson(person)
+
+ # Put mocks in replay mode
+ my_mox.ReplayAll()
+
+ # Inject mock object and run test
+ controller.SetDao(mock_dao)
+ controller.DeletePersonById('1')
+
+ # Verify all methods were called as expected
+ my_mox.VerifyAll()
+"""
+
+from collections import deque
+import difflib
+import inspect
+import re
+import types
+import unittest
+
+from . import stubout
+
+
+class Error(AssertionError):
+ """Base exception for this module."""
+
+ pass
+
+
+class ExpectedMethodCallsError(Error):
+ """Raised when an expected method wasn't called.
+
+ This can occur if Verify() is called before all expected methods have been
+ called.
+ """
+
+ def __init__(self, expected_methods):
+ """Init exception.
+
+ Args:
+ # expected_methods: A sequence of MockMethod objects that should have
+ # been called.
+ expected_methods: [MockMethod]
+
+ Raises:
+ ValueError: if expected_methods contains no methods.
+ """
+
+ if not expected_methods:
+ raise ValueError("There must be at least one expected method")
+ Error.__init__(self)
+ self._expected_methods = expected_methods
+
+ def __str__(self):
+ calls = "\n".join(["%3d. %s" % (i, m)
+ for i, m in enumerate(self._expected_methods)])
+ return "Verify: Expected methods never called:\n%s" % (calls,)
+
+
+class UnexpectedMethodCallError(Error):
+ """Raised when an unexpected method is called.
+
+ This can occur if a method is called with incorrect parameters, or out of the
+ specified order.
+ """
+
+ def __init__(self, unexpected_method, expected):
+ """Init exception.
+
+ Args:
+ # unexpected_method: MockMethod that was called but was not at the head
+ # of the expected_method queue.
+ # expected: MockMethod or UnorderedGroup the method should have
+ # been in.
+ unexpected_method: MockMethod
+ expected: MockMethod or UnorderedGroup
+ """
+
+ Error.__init__(self)
+ if expected is None:
+ self._str = "Unexpected method call %s" % (unexpected_method,)
+ else:
+ differ = difflib.Differ()
+ diff = differ.compare(str(unexpected_method).splitlines(True),
+ str(expected).splitlines(True))
+ self._str = ("Unexpected method call. unexpected:- expected:+\n%s"
+ % ("\n".join(line.rstrip() for line in diff),))
+
+ def __str__(self):
+ return self._str
+
+
+class UnknownMethodCallError(Error):
+ """Raised if an unknown method is requested of the mock object."""
+
+ def __init__(self, unknown_method_name):
+ """Init exception.
+
+ Args:
+ # unknown_method_name: Method call that is not part of the mocked class's
+ # public interface.
+ unknown_method_name: str
+ """
+
+ Error.__init__(self)
+ self._unknown_method_name = unknown_method_name
+
+ def __str__(self):
+ return "Method called is not a member of the object: %s" % \
+ self._unknown_method_name
+
+
+class PrivateAttributeError(Error):
+ """
+ Raised if a MockObject is passed a private additional attribute name.
+ """
+
+ def __init__(self, attr):
+ Error.__init__(self)
+ self._attr = attr
+
+ def __str__(self):
+ return ("Attribute '%s' is private and should not be available in a mock "
+ "object." % self._attr)
+
+
+class ExpectedMockCreationError(Error):
+ """Raised if mocks should have been created by StubOutClassWithMocks."""
+
+ def __init__(self, expected_mocks):
+ """Init exception.
+
+ Args:
+ # expected_mocks: A sequence of MockObjects that should have been
+ # created
+
+ Raises:
+ ValueError: if expected_mocks contains no methods.
+ """
+
+ if not expected_mocks:
+ raise ValueError("There must be at least one expected method")
+ Error.__init__(self)
+ self._expected_mocks = expected_mocks
+
+ def __str__(self):
+ mocks = "\n".join(["%3d. %s" % (i, m)
+ for i, m in enumerate(self._expected_mocks)])
+ return "Verify: Expected mocks never created:\n%s" % (mocks,)
+
+
+class UnexpectedMockCreationError(Error):
+ """Raised if too many mocks were created by StubOutClassWithMocks."""
+
+ def __init__(self, instance, *params, **named_params):
+ """Init exception.
+
+ Args:
+ # instance: the type of obejct that was created
+ # params: parameters given during instantiation
+ # named_params: named parameters given during instantiation
+ """
+
+ Error.__init__(self)
+ self._instance = instance
+ self._params = params
+ self._named_params = named_params
+
+ def __str__(self):
+ args = ", ".join(["%s" % v for i, v in enumerate(self._params)])
+ error = "Unexpected mock creation: %s(%s" % (self._instance, args)
+
+ if self._named_params:
+ error += ", " + ", ".join(["%s=%s" % (k, v) for k, v in
+ self._named_params.items()])
+
+ error += ")"
+ return error
+
+
+class Mox(object):
+ """Mox: a factory for creating mock objects."""
+
+ # A list of types that should be stubbed out with MockObjects (as
+ # opposed to MockAnythings).
+ _USE_MOCK_OBJECT = [types.FunctionType, types.ModuleType, types.MethodType]
+
+ def __init__(self):
+ """Initialize a new Mox."""
+
+ self._mock_objects = []
+ self.stubs = stubout.StubOutForTesting()
+
+ def CreateMock(self, class_to_mock, attrs=None, bounded_to=None):
+ """Create a new mock object.
+
+ Args:
+ # class_to_mock: the class to be mocked
+ class_to_mock: class
+ attrs: dict of attribute names to values that will be set on the mock
+ object. Only public attributes may be set.
+ bounded_to: optionally, when class_to_mock is not a class, it points
+ to a real class object, to which attribute is bound
+
+ Returns:
+ MockObject that can be used as the class_to_mock would be.
+ """
+ if attrs is None:
+ attrs = {}
+ new_mock = MockObject(class_to_mock, attrs=attrs, class_to_bind=bounded_to)
+ self._mock_objects.append(new_mock)
+ return new_mock
+
+ def CreateMockAnything(self, description=None):
+ """Create a mock that will accept any method calls.
+
+ This does not enforce an interface.
+
+ Args:
+ description: str. Optionally, a descriptive name for the mock object
+ being created, for debugging output purposes.
+ """
+ new_mock = MockAnything(description=description)
+ self._mock_objects.append(new_mock)
+ return new_mock
+
+ def ReplayAll(self):
+ """Set all mock objects to replay mode."""
+
+ for mock_obj in self._mock_objects:
+ mock_obj._Replay()
+
+ def VerifyAll(self):
+ """Call verify on all mock objects created."""
+
+ for mock_obj in self._mock_objects:
+ mock_obj._Verify()
+
+ def ResetAll(self):
+ """Call reset on all mock objects. This does not unset stubs."""
+
+ for mock_obj in self._mock_objects:
+ mock_obj._Reset()
+
+ def StubOutWithMock(self, obj, attr_name, use_mock_anything=False):
+ """Replace a method, attribute, etc. with a Mock.
+
+ This will replace a class or module with a MockObject, and everything else
+ (method, function, etc) with a MockAnything. This can be overridden to
+ always use a MockAnything by setting use_mock_anything to True.
+
+ Args:
+ obj: A Python object (class, module, instance, callable).
+ attr_name: str. The name of the attribute to replace with a mock.
+ use_mock_anything: bool. True if a MockAnything should be used regardless
+ of the type of attribute.
+ """
+
+ if inspect.isclass(obj):
+ class_to_bind = obj
+ else:
+ class_to_bind = None
+
+ attr_to_replace = getattr(obj, attr_name)
+ attr_type = type(attr_to_replace)
+
+ if attr_type == MockAnything or attr_type == MockObject:
+ raise TypeError('Cannot mock a MockAnything! Did you remember to '
+ 'call UnsetStubs in your previous test?')
+
+ type_check = (
+ attr_type in self._USE_MOCK_OBJECT or inspect.isclass(attr_to_replace)
+ or isinstance(attr_to_replace, object))
+ if type_check and not use_mock_anything:
+ stub = self.CreateMock(attr_to_replace, bounded_to=class_to_bind)
+ else:
+ stub = self.CreateMockAnything(
+ description='Stub for %s' % attr_to_replace)
+ stub.__name__ = attr_name
+
+ self.stubs.Set(obj, attr_name, stub)
+
+ def StubOutClassWithMocks(self, obj, attr_name):
+ """Replace a class with a "mock factory" that will create mock objects.
+
+ This is useful if the code-under-test directly instantiates
+ dependencies. Previously some boilder plate was necessary to
+ create a mock that would act as a factory. Using
+ StubOutClassWithMocks, once you've stubbed out the class you may
+ use the stubbed class as you would any other mock created by mox:
+ during the record phase, new mock instances will be created, and
+ during replay, the recorded mocks will be returned.
+
+ In replay mode
+
+ # Example using StubOutWithMock (the old, clunky way):
+
+ mock1 = mox.CreateMock(my_import.FooClass)
+ mock2 = mox.CreateMock(my_import.FooClass)
+ foo_factory = mox.StubOutWithMock(my_import, 'FooClass',
+ use_mock_anything=True)
+ foo_factory(1, 2).AndReturn(mock1)
+ foo_factory(9, 10).AndReturn(mock2)
+ mox.ReplayAll()
+
+ my_import.FooClass(1, 2) # Returns mock1 again.
+ my_import.FooClass(9, 10) # Returns mock2 again.
+ mox.VerifyAll()
+
+ # Example using StubOutClassWithMocks:
+
+ mox.StubOutClassWithMocks(my_import, 'FooClass')
+ mock1 = my_import.FooClass(1, 2) # Returns a new mock of FooClass
+ mock2 = my_import.FooClass(9, 10) # Returns another mock instance
+ mox.ReplayAll()
+
+ my_import.FooClass(1, 2) # Returns mock1 again.
+ my_import.FooClass(9, 10) # Returns mock2 again.
+ mox.VerifyAll()
+ """
+ attr_to_replace = getattr(obj, attr_name)
+ attr_type = type(attr_to_replace)
+
+ if attr_type == MockAnything or attr_type == MockObject:
+ raise TypeError('Cannot mock a MockAnything! Did you remember to '
+ 'call UnsetStubs in your previous test?')
+
+ if not inspect.isclass(attr_to_replace):
+ raise TypeError('Given attr is not a Class. Use StubOutWithMock.')
+
+ factory = _MockObjectFactory(attr_to_replace, self)
+ self._mock_objects.append(factory)
+ self.stubs.Set(obj, attr_name, factory)
+
+ def UnsetStubs(self):
+ """Restore stubs to their original state."""
+
+ self.stubs.UnsetAll()
+
+
+def Replay(*args):
+ """Put mocks into Replay mode.
+
+ Args:
+ # args is any number of mocks to put into replay mode.
+ """
+
+ for mock in args:
+ mock._Replay()
+
+
+def Verify(*args):
+ """Verify mocks.
+
+ Args:
+ # args is any number of mocks to be verified.
+ """
+
+ for mock in args:
+ mock._Verify()
+
+
+def Reset(*args):
+ """Reset mocks.
+
+ Args:
+ # args is any number of mocks to be reset.
+ """
+
+ for mock in args:
+ mock._Reset()
+
+
+class MockAnything(object):
+ """A mock that can be used to mock anything.
+
+ This is helpful for mocking classes that do not provide a public interface.
+ """
+
+ def __init__(self, description=None):
+ """Initialize a new MockAnything.
+
+ Args:
+ description: str. Optionally, a descriptive name for the mock object
+ being created, for debugging output purposes.
+ """
+ self._description = description
+ self._Reset()
+
+ def __repr__(self):
+ if self._description:
+ return '<MockAnything instance of %s>' % self._description
+ else:
+ return '<MockAnything instance>'
+
+ def __getattr__(self, method_name):
+ """Intercept method calls on this object.
+
+ A new MockMethod is returned that is aware of the MockAnything's
+ state (record or replay). The call will be recorded or replayed
+ by the MockMethod's __call__.
+
+ Args:
+ # method name: the name of the method being called.
+ method_name: str
+
+ Returns:
+ A new MockMethod aware of MockAnything's state (record or replay).
+ """
+ if method_name == '__dir__':
+ return self.__class__.__dir__.__get__(self, self.__class__)
+
+ return self._CreateMockMethod(method_name)
+
+ def __str__(self):
+ return self._CreateMockMethod('__str__')()
+
+ def __call__(self, *args, **kwargs):
+ return self._CreateMockMethod('__call__')(*args, **kwargs)
+
+ def __getitem__(self, i):
+ return self._CreateMockMethod('__getitem__')(i)
+
+ def _CreateMockMethod(self, method_name, method_to_mock=None,
+ class_to_bind=object):
+ """Create a new mock method call and return it.
+
+ Args:
+ # method_name: the name of the method being called.
+ # method_to_mock: The actual method being mocked, used for introspection.
+ # class_to_bind: Class to which method is bounded (object by default)
+ method_name: str
+ method_to_mock: a method object
+
+ Returns:
+ A new MockMethod aware of MockAnything's state (record or replay).
+ """
+
+ return MockMethod(method_name, self._expected_calls_queue,
+ self._replay_mode, method_to_mock=method_to_mock,
+ description=self._description,
+ class_to_bind=class_to_bind)
+
+ def __nonzero__(self):
+ """Return 1 for nonzero so the mock can be used as a conditional."""
+
+ return 1
+
+ def __bool__(self):
+ """Return True for nonzero so the mock can be used as a conditional."""
+ return True
+
+ def __eq__(self, rhs):
+ """Provide custom logic to compare objects."""
+
+ return (isinstance(rhs, MockAnything) and
+ self._replay_mode == rhs._replay_mode and
+ self._expected_calls_queue == rhs._expected_calls_queue)
+
+ def __ne__(self, rhs):
+ """Provide custom logic to compare objects."""
+
+ return not self == rhs
+
+ def _Replay(self):
+ """Start replaying expected method calls."""
+
+ self._replay_mode = True
+
+ def _Verify(self):
+ """Verify that all of the expected calls have been made.
+
+ Raises:
+ ExpectedMethodCallsError: if there are still more method calls in the
+ expected queue.
+ """
+
+ # If the list of expected calls is not empty, raise an exception
+ if self._expected_calls_queue:
+ # The last MultipleTimesGroup is not popped from the queue.
+ if (len(self._expected_calls_queue) == 1 and
+ isinstance(self._expected_calls_queue[0], MultipleTimesGroup) and
+ self._expected_calls_queue[0].IsSatisfied()):
+ pass
+ else:
+ raise ExpectedMethodCallsError(self._expected_calls_queue)
+
+ def _Reset(self):
+ """Reset the state of this mock to record mode with an empty queue."""
+
+ # Maintain a list of method calls we are expecting
+ self._expected_calls_queue = deque()
+
+ # Make sure we are in setup mode, not replay mode
+ self._replay_mode = False
+
+
+class MockObject(MockAnything):
+ """A mock object that simulates the public/protected interface of a class."""
+
+ def __init__(self, class_to_mock, attrs=None, class_to_bind=None):
+ """Initialize a mock object.
+
+ This determines the methods and properties of the class and stores them.
+
+ Args:
+ # class_to_mock: class to be mocked
+ class_to_mock: class
+ attrs: dict of attribute names to values that will be set on the mock
+ object. Only public attributes may be set.
+ class_to_bind: optionally, when class_to_mock is not a class at all, it
+ points to a real class
+
+ Raises:
+ PrivateAttributeError: if a supplied attribute is not public.
+ ValueError: if an attribute would mask an existing method.
+ """
+ if attrs is None:
+ attrs = {}
+
+ # This is used to hack around the mixin/inheritance of MockAnything, which
+ # is not a proper object (it can be anything. :-)
+ MockAnything.__dict__['__init__'](self)
+
+ # Get a list of all the public and special methods we should mock.
+ self._known_methods = set()
+ self._known_vars = set()
+ self._class_to_mock = class_to_mock
+
+ if inspect.isclass(class_to_mock):
+ self._class_to_bind = self._class_to_mock
+ else:
+ self._class_to_bind = class_to_bind
+
+ try:
+ if inspect.isclass(self._class_to_mock):
+ self._description = class_to_mock.__name__
+ else:
+ self._description = type(class_to_mock).__name__
+ except Exception:
+ pass
+
+ for method in dir(class_to_mock):
+ attr = getattr(class_to_mock, method)
+ if callable(attr):
+ self._known_methods.add(method)
+ elif not (type(attr) is property):
+ # treating properties as class vars makes little sense.
+ self._known_vars.add(method)
+
+ # Set additional attributes at instantiation time; this is quicker
+ # than manually setting attributes that are normally created in
+ # __init__.
+ for attr, value in attrs.items():
+ if attr.startswith("_"):
+ raise PrivateAttributeError(attr)
+ elif attr in self._known_methods:
+ raise ValueError("'%s' is a method of '%s' objects." % (attr,
+ class_to_mock))
+ else:
+ setattr(self, attr, value)
+
+ def _CreateMockMethod(self, *args, **kwargs):
+ """Overridden to provide self._class_to_mock to class_to_bind parameter."""
+ kwargs.setdefault("class_to_bind", self._class_to_bind)
+ return super(MockObject, self)._CreateMockMethod(*args, **kwargs)
+
+ def __getattr__(self, name):
+ """Intercept attribute request on this object.
+
+ If the attribute is a public class variable, it will be returned and not
+ recorded as a call.
+
+ If the attribute is not a variable, it is handled like a method
+ call. The method name is checked against the set of mockable
+ methods, and a new MockMethod is returned that is aware of the
+ MockObject's state (record or replay). The call will be recorded
+ or replayed by the MockMethod's __call__.
+
+ Args:
+ # name: the name of the attribute being requested.
+ name: str
+
+ Returns:
+ Either a class variable or a new MockMethod that is aware of the state
+ of the mock (record or replay).
+
+ Raises:
+ UnknownMethodCallError if the MockObject does not mock the requested
+ method.
+ """
+
+ if name in self._known_vars:
+ return getattr(self._class_to_mock, name)
+
+ if name in self._known_methods:
+ return self._CreateMockMethod(
+ name,
+ method_to_mock=getattr(self._class_to_mock, name))
+
+ raise UnknownMethodCallError(name)
+
+ def __eq__(self, rhs):
+ """Provide custom logic to compare objects."""
+
+ return (isinstance(rhs, MockObject) and
+ self._class_to_mock == rhs._class_to_mock and
+ self._replay_mode == rhs._replay_mode and
+ self._expected_calls_queue == rhs._expected_calls_queue)
+
+ def __setitem__(self, key, value):
+ """Provide custom logic for mocking classes that support item assignment.
+
+ Args:
+ key: Key to set the value for.
+ value: Value to set.
+
+ Returns:
+ Expected return value in replay mode. A MockMethod object for the
+ __setitem__ method that has already been called if not in replay mode.
+
+ Raises:
+ TypeError if the underlying class does not support item assignment.
+ UnexpectedMethodCallError if the object does not expect the call to
+ __setitem__.
+
+ """
+ # Verify the class supports item assignment.
+ if '__setitem__' not in dir(self._class_to_mock):
+ raise TypeError('object does not support item assignment')
+
+ # If we are in replay mode then simply call the mock __setitem__ method.
+ if self._replay_mode:
+ return MockMethod('__setitem__', self._expected_calls_queue,
+ self._replay_mode)(key, value)
+
+ # Otherwise, create a mock method __setitem__.
+ return self._CreateMockMethod('__setitem__')(key, value)
+
+ def __getitem__(self, key):
+ """Provide custom logic for mocking classes that are subscriptable.
+
+ Args:
+ key: Key to return the value for.
+
+ Returns:
+ Expected return value in replay mode. A MockMethod object for the
+ __getitem__ method that has already been called if not in replay mode.
+
+ Raises:
+ TypeError if the underlying class is not subscriptable.
+ UnexpectedMethodCallError if the object does not expect the call to
+ __getitem__.
+
+ """
+ # Verify the class supports item assignment.
+ if '__getitem__' not in dir(self._class_to_mock):
+ raise TypeError('unsubscriptable object')
+
+ # If we are in replay mode then simply call the mock __getitem__ method.
+ if self._replay_mode:
+ return MockMethod('__getitem__', self._expected_calls_queue,
+ self._replay_mode)(key)
+
+ # Otherwise, create a mock method __getitem__.
+ return self._CreateMockMethod('__getitem__')(key)
+
+ def __iter__(self):
+ """Provide custom logic for mocking classes that are iterable.
+
+ Returns:
+ Expected return value in replay mode. A MockMethod object for the
+ __iter__ method that has already been called if not in replay mode.
+
+ Raises:
+ TypeError if the underlying class is not iterable.
+ UnexpectedMethodCallError if the object does not expect the call to
+ __iter__.
+
+ """
+ methods = dir(self._class_to_mock)
+
+ # Verify the class supports iteration.
+ if '__iter__' not in methods:
+ # If it doesn't have iter method and we are in replay method, then try to
+ # iterate using subscripts.
+ if '__getitem__' not in methods or not self._replay_mode:
+ raise TypeError('not iterable object')
+ else:
+ results = []
+ index = 0
+ try:
+ while True:
+ results.append(self[index])
+ index += 1
+ except IndexError:
+ return iter(results)
+
+ # If we are in replay mode then simply call the mock __iter__ method.
+ if self._replay_mode:
+ return MockMethod('__iter__', self._expected_calls_queue,
+ self._replay_mode)()
+
+ # Otherwise, create a mock method __iter__.
+ return self._CreateMockMethod('__iter__')()
+
+ def __contains__(self, key):
+ """Provide custom logic for mocking classes that contain items.
+
+ Args:
+ key: Key to look in container for.
+
+ Returns:
+ Expected return value in replay mode. A MockMethod object for the
+ __contains__ method that has already been called if not in replay mode.
+
+ Raises:
+ TypeError if the underlying class does not implement __contains__
+ UnexpectedMethodCaller if the object does not expect the call to
+ __contains__.
+
+ """
+ contains = self._class_to_mock.__dict__.get('__contains__', None)
+
+ if contains is None:
+ raise TypeError('unsubscriptable object')
+
+ if self._replay_mode:
+ return MockMethod('__contains__', self._expected_calls_queue,
+ self._replay_mode)(key)
+
+ return self._CreateMockMethod('__contains__')(key)
+
+ def __call__(self, *params, **named_params):
+ """Provide custom logic for mocking classes that are callable."""
+
+ # Verify the class we are mocking is callable.
+ is_callable = hasattr(self._class_to_mock, '__call__')
+ if not is_callable:
+ raise TypeError('Not callable')
+
+ # Because the call is happening directly on this object instead of
+ # a method, the call on the mock method is made right here
+
+ # If we are mocking a Function, then use the function, and not the
+ # __call__ method
+ method = None
+ if type(self._class_to_mock) in (types.FunctionType, types.MethodType):
+ method = self._class_to_mock
+ else:
+ method = getattr(self._class_to_mock, '__call__')
+ mock_method = self._CreateMockMethod('__call__', method_to_mock=method)
+
+ return mock_method(*params, **named_params)
+
+ @property
+ def __name__(self):
+ """Return the name that is being mocked."""
+ return self._description
+
+ # TODO(dejw): this property stopped to work after I introduced changes with
+ # binding classes. Fortunately I found a solution in the form of
+ # __getattribute__ method below, but this issue should be investigated
+ @property
+ def __class__(self):
+ return self._class_to_mock
+
+ def __dir__(self):
+ """Return only attributes of a class to mock """
+ return dir(self._class_to_mock)
+
+ def __getattribute__(self, name):
+ """Return _class_to_mock on __class__ attribute. """
+ if name == "__class__":
+ return super(MockObject, self).__getattribute__("_class_to_mock")
+
+ return super(MockObject, self).__getattribute__(name)
+
+
+class _MockObjectFactory(MockObject):
+ """A MockObjectFactory creates mocks and verifies __init__ params.
+
+ A MockObjectFactory removes the boiler plate code that was previously
+ necessary to stub out direction instantiation of a class.
+
+ The MockObjectFactory creates new MockObjects when called and verifies the
+ __init__ params are correct when in record mode. When replaying, existing
+ mocks are returned, and the __init__ params are verified.
+
+ See StubOutWithMock vs StubOutClassWithMocks for more detail.
+ """
+
+ def __init__(self, class_to_mock, mox_instance):
+ MockObject.__init__(self, class_to_mock)
+ self._mox = mox_instance
+ self._instance_queue = deque()
+
+ def __call__(self, *params, **named_params):
+ """Instantiate and record that a new mock has been created."""
+
+ method = getattr(self._class_to_mock, '__init__')
+ mock_method = self._CreateMockMethod('__init__', method_to_mock=method)
+ # Note: calling mock_method() is deferred in order to catch the
+ # empty instance_queue first.
+
+ if self._replay_mode:
+ if not self._instance_queue:
+ raise UnexpectedMockCreationError(self._class_to_mock, *params,
+ **named_params)
+
+ mock_method(*params, **named_params)
+
+ return self._instance_queue.pop()
+ else:
+ mock_method(*params, **named_params)
+
+ instance = self._mox.CreateMock(self._class_to_mock)
+ self._instance_queue.appendleft(instance)
+ return instance
+
+ def _Verify(self):
+ """Verify that all mocks have been created."""
+ if self._instance_queue:
+ raise ExpectedMockCreationError(self._instance_queue)
+ super(_MockObjectFactory, self)._Verify()
+
+
+class MethodSignatureChecker(object):
+ """Ensures that methods are called correctly."""
+
+ _NEEDED, _DEFAULT, _GIVEN = range(3)
+
+ def __init__(self, method, class_to_bind=None):
+ """Creates a checker.
+
+ Args:
+ # method: A method to check.
+ # class_to_bind: optionally, a class used to type check first method
+ # parameter, only used with unbound methods
+ method: function
+ class_to_bind: type or None
+
+ Raises:
+ ValueError: method could not be inspected, so checks aren't possible.
+ Some methods and functions like built-ins can't be inspected.
+ """
+ try:
+ self._args, varargs, varkw, defaults = inspect.getargspec(method)
+ except TypeError:
+ raise ValueError('Could not get argument specification for %r'
+ % (method,))
+ if inspect.ismethod(method) or class_to_bind:
+ self._args = self._args[1:] # Skip 'self'.
+ self._method = method
+ self._instance = None # May contain the instance this is bound to.
+ self._instance = getattr(method, "__self__", None)
+
+ # _bounded_to determines whether the method is bound or not
+ if self._instance:
+ self._bounded_to = self._instance.__class__
+ else:
+ self._bounded_to = class_to_bind or getattr(method, "im_class", None)
+
+ self._has_varargs = varargs is not None
+ self._has_varkw = varkw is not None
+ if defaults is None:
+ self._required_args = self._args
+ self._default_args = []
+ else:
+ self._required_args = self._args[:-len(defaults)]
+ self._default_args = self._args[-len(defaults):]
+
+ def _RecordArgumentGiven(self, arg_name, arg_status):
+ """Mark an argument as being given.
+
+ Args:
+ # arg_name: The name of the argument to mark in arg_status.
+ # arg_status: Maps argument names to one of _NEEDED, _DEFAULT, _GIVEN.
+ arg_name: string
+ arg_status: dict
+
+ Raises:
+ AttributeError: arg_name is already marked as _GIVEN.
+ """
+ if arg_status.get(arg_name, None) == MethodSignatureChecker._GIVEN:
+ raise AttributeError('%s provided more than once' % (arg_name,))
+ arg_status[arg_name] = MethodSignatureChecker._GIVEN
+
+ def Check(self, params, named_params):
+ """Ensures that the parameters used while recording a call are valid.
+
+ Args:
+ # params: A list of positional parameters.
+ # named_params: A dict of named parameters.
+ params: list
+ named_params: dict
+
+ Raises:
+ AttributeError: the given parameters don't work with the given method.
+ """
+ arg_status = dict((a, MethodSignatureChecker._NEEDED)
+ for a in self._required_args)
+ for arg in self._default_args:
+ arg_status[arg] = MethodSignatureChecker._DEFAULT
+
+ # WARNING: Suspect hack ahead.
+ #
+ # Check to see if this is an unbound method, where the instance
+ # should be bound as the first argument. We try to determine if
+ # the first argument (param[0]) is an instance of the class, or it
+ # is equivalent to the class (used to account for Comparators).
+ #
+ # NOTE: If a Func() comparator is used, and the signature is not
+ # correct, this will cause extra executions of the function.
+ if inspect.ismethod(self._method) or self._bounded_to:
+ # The extra param accounts for the bound instance.
+ if len(params) > len(self._required_args):
+ expected = self._bounded_to
+
+ # Check if the param is an instance of the expected class,
+ # or check equality (useful for checking Comparators).
+
+ # This is a hack to work around the fact that the first
+ # parameter can be a Comparator, and the comparison may raise
+ # an exception during this comparison, which is OK.
+ try:
+ param_equality = (params[0] == expected)
+ except:
+ param_equality = False
+
+ if isinstance(params[0], expected) or param_equality:
+ params = params[1:]
+ # If the IsA() comparator is being used, we need to check the
+ # inverse of the usual case - that the given instance is a subclass
+ # of the expected class. For example, the code under test does
+ # late binding to a subclass.
+ elif isinstance(params[0], IsA) and params[0]._IsSubClass(expected):
+ params = params[1:]
+
+ # Check that each positional param is valid.
+ for i in range(len(params)):
+ try:
+ arg_name = self._args[i]
+ except IndexError:
+ if not self._has_varargs:
+ raise AttributeError('%s does not take %d or more positional '
+ 'arguments' % (self._method.__name__, i))
+ else:
+ self._RecordArgumentGiven(arg_name, arg_status)
+
+ # Check each keyword argument.
+ for arg_name in named_params:
+ if arg_name not in arg_status and not self._has_varkw:
+ raise AttributeError('%s is not expecting keyword argument %s'
+ % (self._method.__name__, arg_name))
+ self._RecordArgumentGiven(arg_name, arg_status)
+
+ # Ensure all the required arguments have been given.
+ still_needed = [k for k, v in arg_status.items()
+ if v == MethodSignatureChecker._NEEDED]
+ if still_needed:
+ raise AttributeError('No values given for arguments: %s'
+ % (' '.join(sorted(still_needed))))
+
+
+class MockMethod(object):
+ """Callable mock method.
+
+ A MockMethod should act exactly like the method it mocks, accepting
+ parameters and returning a value, or throwing an exception (as specified).
+ When this method is called, it can optionally verify whether the called
+ method (name and signature) matches the expected method.
+ """
+
+ def __init__(self, method_name, call_queue, replay_mode,
+ method_to_mock=None, description=None, class_to_bind=None):
+ """Construct a new mock method.
+
+ Args:
+ # method_name: the name of the method
+ # call_queue: deque of calls, verify this call against the head, or add
+ # this call to the queue.
+ # replay_mode: False if we are recording, True if we are verifying calls
+ # against the call queue.
+ # method_to_mock: The actual method being mocked, used for introspection.
+ # description: optionally, a descriptive name for this method. Typically
+ # this is equal to the descriptive name of the method's class.
+ # class_to_bind: optionally, a class that is used for unbound methods
+ # (or functions in Python3) to which method is bound, in order not
+ # to loose binding information. If given, it will be used for
+ # checking the type of first method parameter
+ method_name: str
+ call_queue: list or deque
+ replay_mode: bool
+ method_to_mock: a method object
+ description: str or None
+ class_to_bind: type or None
+ """
+
+ self._name = method_name
+ self.__name__ = method_name
+ self._call_queue = call_queue
+ if not isinstance(call_queue, deque):
+ self._call_queue = deque(self._call_queue)
+ self._replay_mode = replay_mode
+ self._description = description
+
+ self._params = None
+ self._named_params = None
+ self._return_value = None
+ self._exception = None
+ self._side_effects = None
+
+ try:
+ self._checker = MethodSignatureChecker(method_to_mock,
+ class_to_bind=class_to_bind)
+ except ValueError:
+ self._checker = None
+
+ def __call__(self, *params, **named_params):
+ """Log parameters and return the specified return value.
+
+ If the Mock(Anything/Object) associated with this call is in record mode,
+ this MockMethod will be pushed onto the expected call queue. If the mock
+ is in replay mode, this will pop a MockMethod off the top of the queue and
+ verify this call is equal to the expected call.
+
+ Raises:
+ UnexpectedMethodCall if this call is supposed to match an expected method
+ call and it does not.
+ """
+
+ self._params = params
+ self._named_params = named_params
+
+ if not self._replay_mode:
+ if self._checker is not None:
+ self._checker.Check(params, named_params)
+ self._call_queue.append(self)
+ return self
+
+ expected_method = self._VerifyMethodCall()
+
+ if expected_method._side_effects:
+ result = expected_method._side_effects(*params, **named_params)
+ if expected_method._return_value is None:
+ expected_method._return_value = result
+
+ if expected_method._exception:
+ raise expected_method._exception
+
+ return expected_method._return_value
+
+ def __getattr__(self, name):
+ """Raise an AttributeError with a helpful message."""
+
+ raise AttributeError('MockMethod has no attribute "%s". '
+ 'Did you remember to put your mocks in replay mode?' % name)
+
+ def __iter__(self):
+ """Raise a TypeError with a helpful message."""
+ raise TypeError('MockMethod cannot be iterated. '
+ 'Did you remember to put your mocks in replay mode?')
+
+ def next(self):
+ """Raise a TypeError with a helpful message."""
+ raise TypeError('MockMethod cannot be iterated. '
+ 'Did you remember to put your mocks in replay mode?')
+
+ def __next__(self):
+ """Raise a TypeError with a helpful message."""
+ raise TypeError('MockMethod cannot be iterated. '
+ 'Did you remember to put your mocks in replay mode?')
+
+ def _PopNextMethod(self):
+ """Pop the next method from our call queue."""
+ try:
+ return self._call_queue.popleft()
+ except IndexError:
+ raise UnexpectedMethodCallError(self, None)
+
+ def _VerifyMethodCall(self):
+ """Verify the called method is expected.
+
+ This can be an ordered method, or part of an unordered set.
+
+ Returns:
+ The expected mock method.
+
+ Raises:
+ UnexpectedMethodCall if the method called was not expected.
+ """
+
+ expected = self._PopNextMethod()
+
+ # Loop here, because we might have a MethodGroup followed by another
+ # group.
+ while isinstance(expected, MethodGroup):
+ expected, method = expected.MethodCalled(self)
+ if method is not None:
+ return method
+
+ # This is a mock method, so just check equality.
+ if expected != self:
+ raise UnexpectedMethodCallError(self, expected)
+
+ return expected
+
+ def __str__(self):
+ params = ', '.join(
+ [repr(p) for p in self._params or []] +
+ ['%s=%r' % x for x in sorted((self._named_params or {}).items())])
+ full_desc = "%s(%s) -> %r" % (self._name, params, self._return_value)
+ if self._description:
+ full_desc = "%s.%s" % (self._description, full_desc)
+ return full_desc
+
+ def __hash__(self):
+ return id(self)
+
+ def __eq__(self, rhs):
+ """Test whether this MockMethod is equivalent to another MockMethod.
+
+ Args:
+ # rhs: the right hand side of the test
+ rhs: MockMethod
+ """
+
+ return (isinstance(rhs, MockMethod) and
+ self._name == rhs._name and
+ self._params == rhs._params and
+ self._named_params == rhs._named_params)
+
+ def __ne__(self, rhs):
+ """Test whether this MockMethod is not equivalent to another MockMethod.
+
+ Args:
+ # rhs: the right hand side of the test
+ rhs: MockMethod
+ """
+
+ return not self == rhs
+
+ def GetPossibleGroup(self):
+ """Returns a possible group from the end of the call queue or None if no
+ other methods are on the stack.
+ """
+
+ # Remove this method from the tail of the queue so we can add it
+ # to a group.
+ this_method = self._call_queue.pop()
+ assert this_method == self
+
+ # Determine if the tail of the queue is a group, or just a regular ordered
+ # mock method.
+ group = None
+ try:
+ group = self._call_queue[-1]
+ except IndexError:
+ pass
+
+ return group
+
+ def _CheckAndCreateNewGroup(self, group_name, group_class):
+ """Checks if the last method (a possible group) is an instance of our
+ group_class. Adds the current method to this group or creates a new one.
+
+ Args:
+
+ group_name: the name of the group.
+ group_class: the class used to create instance of this new group
+ """
+ group = self.GetPossibleGroup()
+
+ # If this is a group, and it is the correct group, add the method.
+ if isinstance(group, group_class) and group.group_name() == group_name:
+ group.AddMethod(self)
+ return self
+
+ # Create a new group and add the method.
+ new_group = group_class(group_name)
+ new_group.AddMethod(self)
+ self._call_queue.append(new_group)
+ return self
+
+ def InAnyOrder(self, group_name="default"):
+ """Move this method into a group of unordered calls.
+
+ A group of unordered calls must be defined together, and must be executed
+ in full before the next expected method can be called. There can be
+ multiple groups that are expected serially, if they are given
+ different group names. The same group name can be reused if there is a
+ standard method call, or a group with a different name, spliced between
+ usages.
+
+ Args:
+ group_name: the name of the unordered group.
+
+ Returns:
+ self
+ """
+ return self._CheckAndCreateNewGroup(group_name, UnorderedGroup)
+
+ def MultipleTimes(self, group_name="default"):
+ """Move this method into group of calls which may be called multiple times.
+
+ A group of repeating calls must be defined together, and must be executed
+ in full before the next expected method can be called.
+
+ Args:
+ group_name: the name of the unordered group.
+
+ Returns:
+ self
+ """
+ return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup)
+
+ def AndReturn(self, return_value):
+ """Set the value to return when this method is called.
+
+ Args:
+ # return_value can be anything.
+ """
+
+ self._return_value = return_value
+ return return_value
+
+ def AndRaise(self, exception):
+ """Set the exception to raise when this method is called.
+
+ Args:
+ # exception: the exception to raise when this method is called.
+ exception: Exception
+ """
+
+ self._exception = exception
+
+ def WithSideEffects(self, side_effects):
+ """Set the side effects that are simulated when this method is called.
+
+ Args:
+ side_effects: A callable which modifies the parameters or other relevant
+ state which a given test case depends on.
+
+ Returns:
+ Self for chaining with AndReturn and AndRaise.
+ """
+ self._side_effects = side_effects
+ return self
+
+
+class Comparator:
+ """Base class for all Mox comparators.
+
+ A Comparator can be used as a parameter to a mocked method when the exact
+ value is not known. For example, the code you are testing might build up a
+ long SQL string that is passed to your mock DAO. You're only interested that
+ the IN clause contains the proper primary keys, so you can set your mock
+ up as follows:
+
+ mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
+
+ Now whatever query is passed in must contain the string 'IN (1, 2, 4, 5)'.
+
+ A Comparator may replace one or more parameters, for example:
+ # return at most 10 rows
+ mock_dao.RunQuery(StrContains('SELECT'), 10)
+
+ or
+
+ # Return some non-deterministic number of rows
+ mock_dao.RunQuery(StrContains('SELECT'), IsA(int))
+ """
+
+ def equals(self, rhs):
+ """Special equals method that all comparators must implement.
+
+ Args:
+ rhs: any python object
+ """
+
+ raise NotImplementedError('method must be implemented by a subclass.')
+
+ def __eq__(self, rhs):
+ return self.equals(rhs)
+
+ def __ne__(self, rhs):
+ return not self.equals(rhs)
+
+
+class Is(Comparator):
+ """Comparison class used to check identity, instead of equality."""
+
+ def __init__(self, obj):
+ self._obj = obj
+
+ def equals(self, rhs):
+ return rhs is self._obj
+
+ def __repr__(self):
+ return "<is %r (%s)>" % (self._obj, id(self._obj))
+
+
+class IsA(Comparator):
+ """This class wraps a basic Python type or class. It is used to verify
+ that a parameter is of the given type or class.
+
+ Example:
+ mock_dao.Connect(IsA(DbConnectInfo))
+ """
+
+ def __init__(self, class_name):
+ """Initialize IsA
+
+ Args:
+ class_name: basic python type or a class
+ """
+
+ self._class_name = class_name
+
+ def equals(self, rhs):
+ """Check to see if the RHS is an instance of class_name.
+
+ Args:
+ # rhs: the right hand side of the test
+ rhs: object
+
+ Returns:
+ bool
+ """
+
+ try:
+ return isinstance(rhs, self._class_name)
+ except TypeError:
+ # Check raw types if there was a type error. This is helpful for
+ # things like cStringIO.StringIO.
+ return type(rhs) == type(self._class_name)
+
+ def _IsSubClass(self, clazz):
+ """Check to see if the IsA comparators class is a subclass of clazz.
+
+ Args:
+ # clazz: a class object
+
+ Returns:
+ bool
+ """
+
+ try:
+ return issubclass(self._class_name, clazz)
+ except TypeError:
+ # Check raw types if there was a type error. This is helpful for
+ # things like cStringIO.StringIO.
+ return type(clazz) == type(self._class_name)
+
+ def __repr__(self):
+ return 'mox.IsA(%s) ' % str(self._class_name)
+
+
+class IsAlmost(Comparator):
+ """Comparison class used to check whether a parameter is nearly equal
+ to a given value. Generally useful for floating point numbers.
+
+ Example mock_dao.SetTimeout((IsAlmost(3.9)))
+ """
+
+ def __init__(self, float_value, places=7):
+ """Initialize IsAlmost.
+
+ Args:
+ float_value: The value for making the comparison.
+ places: The number of decimal places to round to.
+ """
+
+ self._float_value = float_value
+ self._places = places
+
+ def equals(self, rhs):
+ """Check to see if RHS is almost equal to float_value
+
+ Args:
+ rhs: the value to compare to float_value
+
+ Returns:
+ bool
+ """
+
+ try:
+ return round(rhs - self._float_value, self._places) == 0
+ except Exception:
+ # This is probably because either float_value or rhs is not a number.
+ return False
+
+ def __repr__(self):
+ return str(self._float_value)
+
+
+class StrContains(Comparator):
+ """Comparison class used to check whether a substring exists in a
+ string parameter. This can be useful in mocking a database with SQL
+ passed in as a string parameter, for example.
+
+ Example:
+ mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
+ """
+
+ def __init__(self, search_string):
+ """Initialize.
+
+ Args:
+ # search_string: the string you are searching for
+ search_string: str
+ """
+
+ self._search_string = search_string
+
+ def equals(self, rhs):
+ """Check to see if the search_string is contained in the rhs string.
+
+ Args:
+ # rhs: the right hand side of the test
+ rhs: object
+
+ Returns:
+ bool
+ """
+
+ try:
+ return rhs.find(self._search_string) > -1
+ except Exception:
+ return False
+
+ def __repr__(self):
+ return '<str containing \'%s\'>' % self._search_string
+
+
+class Regex(Comparator):
+ """Checks if a string matches a regular expression.
+
+ This uses a given regular expression to determine equality.
+ """
+
+ def __init__(self, pattern, flags=0):
+ """Initialize.
+
+ Args:
+ # pattern is the regular expression to search for
+ pattern: str
+ # flags passed to re.compile function as the second argument
+ flags: int
+ """
+ self.flags = flags
+ self.regex = re.compile(pattern, flags=flags)
+
+ def equals(self, rhs):
+ """Check to see if rhs matches regular expression pattern.
+
+ Returns:
+ bool
+ """
+
+ try:
+ return self.regex.search(rhs) is not None
+ except Exception:
+ return False
+
+ def __repr__(self):
+ s = '<regular expression \'%s\'' % self.regex.pattern
+ if self.flags:
+ s += ', flags=%d' % self.flags
+ s += '>'
+ return s
+
+
+class In(Comparator):
+ """Checks whether an item (or key) is in a list (or dict) parameter.
+
+ Example:
+ mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result)
+ """
+
+ def __init__(self, key):
+ """Initialize.
+
+ Args:
+ # key is any thing that could be in a list or a key in a dict
+ """
+
+ self._key = key
+
+ def equals(self, rhs):
+ """Check to see whether key is in rhs.
+
+ Args:
+ rhs: dict
+
+ Returns:
+ bool
+ """
+
+ try:
+ return self._key in rhs
+ except Exception:
+ return False
+
+ def __repr__(self):
+ return '<sequence or map containing \'%s\'>' % str(self._key)
+
+
+class Not(Comparator):
+ """Checks whether a predicates is False.
+
+ Example:
+ mock_dao.UpdateUsers(Not(ContainsKeyValue('stevepm', stevepm_user_info)))
+ """
+
+ def __init__(self, predicate):
+ """Initialize.
+
+ Args:
+ # predicate: a Comparator instance.
+ """
+
+ assert isinstance(predicate, Comparator), ("predicate %r must be a"
+ " Comparator." % predicate)
+ self._predicate = predicate
+
+ def equals(self, rhs):
+ """Check to see whether the predicate is False.
+
+ Args:
+ rhs: A value that will be given in argument of the predicate.
+
+ Returns:
+ bool
+ """
+
+ try:
+ return not self._predicate.equals(rhs)
+ except Exception:
+ return False
+
+ def __repr__(self):
+ return '<not \'%s\'>' % self._predicate
+
+
+class ContainsKeyValue(Comparator):
+ """Checks whether a key/value pair is in a dict parameter.
+
+ Example:
+ mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info))
+ """
+
+ def __init__(self, key, value):
+ """Initialize.
+
+ Args:
+ # key: a key in a dict
+ # value: the corresponding value
+ """
+
+ self._key = key
+ self._value = value
+
+ def equals(self, rhs):
+ """Check whether the given key/value pair is in the rhs dict.
+
+ Returns:
+ bool
+ """
+
+ try:
+ return rhs[self._key] == self._value
+ except Exception:
+ return False
+
+ def __repr__(self):
+ return '<map containing the entry \'%s: %s\'>' % (str(self._key),
+ str(self._value))
+
+
+class ContainsAttributeValue(Comparator):
+ """Checks whether a passed parameter contains attributes with a given value.
+
+ Example:
+ mock_dao.UpdateSomething(ContainsAttribute('stevepm', stevepm_user_info))
+ """
+
+ def __init__(self, key, value):
+ """Initialize.
+
+ Args:
+ # key: an attribute name of an object
+ # value: the corresponding value
+ """
+
+ self._key = key
+ self._value = value
+
+ def equals(self, rhs):
+ """Check if the given attribute has a matching value in the rhs object.
+
+ Returns:
+ bool
+ """
+
+ try:
+ return getattr(rhs, self._key) == self._value
+ except Exception:
+ return False
+
+
+class SameElementsAs(Comparator):
+ """Checks whether sequences contain the same elements (ignoring order).
+
+ Example:
+ mock_dao.ProcessUsers(SameElementsAs('stevepm', 'salomaki'))
+ """
+
+ def __init__(self, expected_seq):
+ """Initialize.
+
+ Args:
+ expected_seq: a sequence
+ """
+ # Store in case expected_seq is an iterator.
+ self._expected_list = list(expected_seq)
+
+ def equals(self, actual_seq):
+ """Check to see whether actual_seq has same elements as expected_seq.
+
+ Args:
+ actual_seq: sequence
+
+ Returns:
+ bool
+ """
+ try:
+ # Store in case actual_seq is an iterator. We potentially iterate twice:
+ # once to make the dict, once in the list fallback.
+ actual_list = list(actual_seq)
+ except TypeError:
+ # actual_seq cannot be read as a sequence.
+ #
+ # This happens because Mox uses __eq__ both to check object equality (in
+ # MethodSignatureChecker) and to invoke Comparators.
+ return False
+
+ try:
+ return set(self._expected_list) == set(actual_list)
+ except TypeError:
+ # Fall back to slower list-compare if any of the objects are unhashable.
+ if len(self._expected_list) != len(actual_list):
+ return False
+ for el in actual_list:
+ if el not in self._expected_list:
+ return False
+ return True
+
+ def __repr__(self):
+ return '<sequence with same elements as \'%s\'>' % self._expected_list
+
+
+class And(Comparator):
+ """Evaluates one or more Comparators on RHS, returns an AND of the results.
+ """
+
+ def __init__(self, *args):
+ """Initialize.
+
+ Args:
+ *args: One or more Comparator
+ """
+
+ self._comparators = args
+
+ def equals(self, rhs):
+ """Checks whether all Comparators are equal to rhs.
+
+ Args:
+ # rhs: can be anything
+
+ Returns:
+ bool
+ """
+
+ for comparator in self._comparators:
+ if not comparator.equals(rhs):
+ return False
+
+ return True
+
+ def __repr__(self):
+ return '<AND %s>' % str(self._comparators)
+
+
+class Or(Comparator):
+ """Evaluates one or more Comparators on RHS and returns an OR of the results.
+ """
+
+ def __init__(self, *args):
+ """Initialize.
+
+ Args:
+ *args: One or more Mox comparators
+ """
+
+ self._comparators = args
+
+ def equals(self, rhs):
+ """Checks whether any Comparator is equal to rhs.
+
+ Args:
+ # rhs: can be anything
+
+ Returns:
+ bool
+ """
+
+ for comparator in self._comparators:
+ if comparator.equals(rhs):
+ return True
+
+ return False
+
+ def __repr__(self):
+ return '<OR %s>' % str(self._comparators)
+
+
+class Func(Comparator):
+ """Call a function that should verify the parameter passed in is correct.
+
+ You may need the ability to perform more advanced operations on the parameter
+ in order to validate it. You can use this to have a callable validate any
+ parameter. The callable should return either True or False.
+
+
+ Example:
+
+ def myParamValidator(param):
+ # Advanced logic here
+ return True
+
+ mock_dao.DoSomething(Func(myParamValidator), true)
+ """
+
+ def __init__(self, func):
+ """Initialize.
+
+ Args:
+ func: callable that takes one parameter and returns a bool
+ """
+
+ self._func = func
+
+ def equals(self, rhs):
+ """Test whether rhs passes the function test.
+
+ rhs is passed into func.
+
+ Args:
+ rhs: any python object
+
+ Returns:
+ the result of func(rhs)
+ """
+
+ return self._func(rhs)
+
+ def __repr__(self):
+ return str(self._func)
+
+
+class IgnoreArg(Comparator):
+ """Ignore an argument.
+
+ This can be used when we don't care about an argument of a method call.
+
+ Example:
+ # Check if CastMagic is called with 3 as first arg and 'disappear' as third.
+ mymock.CastMagic(3, IgnoreArg(), 'disappear')
+ """
+
+ def equals(self, unused_rhs):
+ """Ignores arguments and returns True.
+
+ Args:
+ unused_rhs: any python object
+
+ Returns:
+ always returns True
+ """
+
+ return True
+
+ def __repr__(self):
+ return '<IgnoreArg>'
+
+
+class Value(Comparator):
+ """Compares argument against a remembered value.
+
+ To be used in conjunction with Remember comparator. See Remember()
+ for example.
+ """
+
+ def __init__(self):
+ self._value = None
+ self._has_value = False
+
+ def store_value(self, rhs):
+ self._value = rhs
+ self._has_value = True
+
+ def equals(self, rhs):
+ if not self._has_value:
+ return False
+ else:
+ return rhs == self._value
+
+ def __repr__(self):
+ if self._has_value:
+ return "<Value %r>" % self._value
+ else:
+ return "<Value>"
+
+
+class Remember(Comparator):
+ """Remembers the argument to a value store.
+
+ To be used in conjunction with Value comparator.
+
+ Example:
+ # Remember the argument for one method call.
+ users_list = Value()
+ mock_dao.ProcessUsers(Remember(users_list))
+
+ # Check argument against remembered value.
+ mock_dao.ReportUsers(users_list)
+ """
+
+ def __init__(self, value_store):
+ if not isinstance(value_store, Value):
+ raise TypeError("value_store is not an instance of the Value class")
+ self._value_store = value_store
+
+ def equals(self, rhs):
+ self._value_store.store_value(rhs)
+ return True
+
+ def __repr__(self):
+ return "<Remember %d>" % id(self._value_store)
+
+
+class MethodGroup(object):
+ """Base class containing common behaviour for MethodGroups."""
+
+ def __init__(self, group_name):
+ self._group_name = group_name
+
+ def group_name(self):
+ return self._group_name
+
+ def __str__(self):
+ return '<%s "%s">' % (self.__class__.__name__, self._group_name)
+
+ def AddMethod(self, mock_method):
+ raise NotImplementedError
+
+ def MethodCalled(self, mock_method):
+ raise NotImplementedError
+
+ def IsSatisfied(self):
+ raise NotImplementedError
+
+
+class UnorderedGroup(MethodGroup):
+ """UnorderedGroup holds a set of method calls that may occur in any order.
+
+ This construct is helpful for non-deterministic events, such as iterating
+ over the keys of a dict.
+ """
+
+ def __init__(self, group_name):
+ super(UnorderedGroup, self).__init__(group_name)
+ self._methods = []
+
+ def __str__(self):
+ return '%s "%s" pending calls:\n%s' % (
+ self.__class__.__name__,
+ self._group_name,
+ "\n".join(str(method) for method in self._methods))
+
+ def AddMethod(self, mock_method):
+ """Add a method to this group.
+
+ Args:
+ mock_method: A mock method to be added to this group.
+ """
+
+ self._methods.append(mock_method)
+
+ def MethodCalled(self, mock_method):
+ """Remove a method call from the group.
+
+ If the method is not in the set, an UnexpectedMethodCallError will be
+ raised.
+
+ Args:
+ mock_method: a mock method that should be equal to a method in the group.
+
+ Returns:
+ The mock method from the group
+
+ Raises:
+ UnexpectedMethodCallError if the mock_method was not in the group.
+ """
+
+ # Check to see if this method exists, and if so, remove it from the set
+ # and return it.
+ for method in self._methods:
+ if method == mock_method:
+ # Remove the called mock_method instead of the method in the group.
+ # The called method will match any comparators when equality is checked
+ # during removal. The method in the group could pass a comparator to
+ # another comparator during the equality check.
+ self._methods.remove(mock_method)
+
+ # If this group is not empty, put it back at the head of the queue.
+ if not self.IsSatisfied():
+ mock_method._call_queue.appendleft(self)
+
+ return self, method
+
+ raise UnexpectedMethodCallError(mock_method, self)
+
+ def IsSatisfied(self):
+ """Return True if there are not any methods in this group."""
+
+ return len(self._methods) == 0
+
+
+class MultipleTimesGroup(MethodGroup):
+ """MultipleTimesGroup holds methods that may be called any number of times.
+
+ Note: Each method must be called at least once.
+
+ This is helpful, if you don't know or care how many times a method is called.
+ """
+
+ def __init__(self, group_name):
+ super(MultipleTimesGroup, self).__init__(group_name)
+ self._methods = set()
+ self._methods_left = set()
+
+ def AddMethod(self, mock_method):
+ """Add a method to this group.
+
+ Args:
+ mock_method: A mock method to be added to this group.
+ """
+
+ self._methods.add(mock_method)
+ self._methods_left.add(mock_method)
+
+ def MethodCalled(self, mock_method):
+ """Remove a method call from the group.
+
+ If the method is not in the set, an UnexpectedMethodCallError will be
+ raised.
+
+ Args:
+ mock_method: a mock method that should be equal to a method in the group.
+
+ Returns:
+ The mock method from the group
+
+ Raises:
+ UnexpectedMethodCallError if the mock_method was not in the group.
+ """
+
+ # Check to see if this method exists, and if so add it to the set of
+ # called methods.
+ for method in self._methods:
+ if method == mock_method:
+ self._methods_left.discard(method)
+ # Always put this group back on top of the queue, because we don't know
+ # when we are done.
+ mock_method._call_queue.appendleft(self)
+ return self, method
+
+ if self.IsSatisfied():
+ next_method = mock_method._PopNextMethod()
+ return next_method, None
+ else:
+ raise UnexpectedMethodCallError(mock_method, self)
+
+ def IsSatisfied(self):
+ """Return True if all methods in this group are called at least once."""
+ return len(self._methods_left) == 0
+
+
+class MoxMetaTestBase(type):
+ """Metaclass to add mox cleanup and verification to every test.
+
+ As the mox unit testing class is being constructed (MoxTestBase or a
+ subclass), this metaclass will modify all test functions to call the
+ CleanUpMox method of the test class after they finish. This means that
+ unstubbing and verifying will happen for every test with no additional code,
+ and any failures will result in test failures as opposed to errors.
+ """
+
+ def __init__(cls, name, bases, d):
+ type.__init__(cls, name, bases, d)
+
+ # also get all the attributes from the base classes to account
+ # for a case when test class is not the immediate child of MoxTestBase
+ for base in bases:
+ for attr_name in dir(base):
+ if attr_name not in d:
+ d[attr_name] = getattr(base, attr_name)
+
+ for func_name, func in d.items():
+ if func_name.startswith('test') and callable(func):
+
+ setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func))
+
+ @staticmethod
+ def CleanUpTest(cls, func):
+ """Adds Mox cleanup code to any MoxTestBase method.
+
+ Always unsets stubs after a test. Will verify all mocks for tests that
+ otherwise pass.
+
+ Args:
+ cls: MoxTestBase or subclass; the class whose method we are altering.
+ func: method; the method of the MoxTestBase test class we wish to alter.
+
+ Returns:
+ The modified method.
+ """
+ def new_method(self, *args, **kwargs):
+ mox_obj = getattr(self, 'mox', None)
+ stubout_obj = getattr(self, 'stubs', None)
+ cleanup_mox = False
+ cleanup_stubout = False
+ if mox_obj and isinstance(mox_obj, Mox):
+ cleanup_mox = True
+ if stubout_obj and isinstance(stubout_obj, stubout.StubOutForTesting):
+ cleanup_stubout = True
+ try:
+ func(self, *args, **kwargs)
+ finally:
+ if cleanup_mox:
+ mox_obj.UnsetStubs()
+ if cleanup_stubout:
+ stubout_obj.UnsetAll()
+ stubout_obj.SmartUnsetAll()
+ if cleanup_mox:
+ mox_obj.VerifyAll()
+ new_method.__name__ = func.__name__
+ new_method.__doc__ = func.__doc__
+ new_method.__module__ = func.__module__
+ return new_method
+
+
+_MoxTestBase = MoxMetaTestBase('_MoxTestBase', (unittest.TestCase, ), {})
+
+
+class MoxTestBase(_MoxTestBase):
+ """Convenience test class to make stubbing easier.
+
+ Sets up a "mox" attribute which is an instance of Mox (any mox tests will
+ want this), and a "stubs" attribute that is an instance of StubOutForTesting
+ (needed at times). Also automatically unsets any stubs and verifies that all
+ mock methods have been called at the end of each test, eliminating
+ boilerplate code.
+ """
+
+ def setUp(self):
+ super(MoxTestBase, self).setUp()
+ self.mox = Mox()
+ self.stubs = stubout.StubOutForTesting()