summaryrefslogtreecommitdiff
path: root/tests/twisted/servicetest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/twisted/servicetest.py')
-rw-r--r--tests/twisted/servicetest.py510
1 files changed, 510 insertions, 0 deletions
diff --git a/tests/twisted/servicetest.py b/tests/twisted/servicetest.py
new file mode 100644
index 0000000..0634678
--- /dev/null
+++ b/tests/twisted/servicetest.py
@@ -0,0 +1,510 @@
+
+"""
+Infrastructure code for testing connection managers.
+"""
+
+from twisted.internet import glib2reactor
+from twisted.internet.protocol import Protocol, Factory, ClientFactory
+glib2reactor.install()
+import sys
+
+import pprint
+import unittest
+
+import dbus.glib
+
+from twisted.internet import reactor
+
+import constants as cs
+
+tp_name_prefix = 'org.freedesktop.Telepathy'
+tp_path_prefix = '/org/freedesktop/Telepathy'
+
+class Event:
+ def __init__(self, type, **kw):
+ self.__dict__.update(kw)
+ self.type = type
+
+def format_event(event):
+ ret = ['- type %s' % event.type]
+
+ for key in dir(event):
+ if key != 'type' and not key.startswith('_'):
+ ret.append('- %s: %s' % (
+ key, pprint.pformat(getattr(event, key))))
+
+ if key == 'error':
+ ret.append('%s' % getattr(event, key))
+
+ return ret
+
+class EventPattern:
+ def __init__(self, type, **properties):
+ self.type = type
+ self.predicate = lambda x: True
+ if 'predicate' in properties:
+ self.predicate = properties['predicate']
+ del properties['predicate']
+ self.properties = properties
+
+ def __repr__(self):
+ properties = dict(self.properties)
+
+ if self.predicate:
+ properties['predicate'] = self.predicate
+
+ return '%s(%r, **%r)' % (
+ self.__class__.__name__, self.type, properties)
+
+ def match(self, event):
+ if event.type != self.type:
+ return False
+
+ for key, value in self.properties.iteritems():
+ try:
+ if getattr(event, key) != value:
+ return False
+ except AttributeError:
+ return False
+
+ if self.predicate(event):
+ return True
+
+ return False
+
+
+class TimeoutError(Exception):
+ pass
+
+class BaseEventQueue:
+ """Abstract event queue base class.
+
+ Implement the wait() method to have something that works.
+ """
+
+ def __init__(self, timeout=None):
+ self.verbose = False
+ self.forbidden_events = set()
+
+ if timeout is None:
+ self.timeout = 5
+ else:
+ self.timeout = timeout
+
+ def log(self, s):
+ if self.verbose:
+ print s
+
+ def log_event(self, event):
+ if self.verbose:
+ self.log('got event:')
+
+ if self.verbose:
+ map(self.log, format_event(event))
+
+ def forbid_events(self, patterns):
+ """
+ Add patterns (an iterable of EventPattern) to the set of forbidden
+ events. If a forbidden event occurs during an expect or expect_many,
+ the test will fail.
+ """
+ self.forbidden_events.update(set(patterns))
+
+ def unforbid_events(self, patterns):
+ """
+ Remove 'patterns' (an iterable of EventPattern) from the set of
+ forbidden events. These must be the same EventPattern pointers that
+ were passed to forbid_events.
+ """
+ self.forbidden_events.difference_update(set(patterns))
+
+ def _check_forbidden(self, event):
+ for e in self.forbidden_events:
+ if e.match(event):
+ print "forbidden event occurred:"
+ for x in format_event(event):
+ print x
+ assert False
+
+ def expect(self, type, **kw):
+ pattern = EventPattern(type, **kw)
+
+ while True:
+ event = self.wait()
+ self.log_event(event)
+ self._check_forbidden(event)
+
+ if pattern.match(event):
+ self.log('handled')
+ self.log('')
+ return event
+
+ self.log('not handled')
+ self.log('')
+
+ def expect_many(self, *patterns):
+ ret = [None] * len(patterns)
+
+ while None in ret:
+ try:
+ event = self.wait()
+ except TimeoutError:
+ self.log('timeout')
+ self.log('still expecting:')
+ for i, pattern in enumerate(patterns):
+ if ret[i] is None:
+ self.log(' - %r' % pattern)
+ raise
+ self.log_event(event)
+ self._check_forbidden(event)
+
+ for i, pattern in enumerate(patterns):
+ if ret[i] is None and pattern.match(event):
+ self.log('handled')
+ self.log('')
+ ret[i] = event
+ break
+ else:
+ self.log('not handled')
+ self.log('')
+
+ return ret
+
+ def demand(self, type, **kw):
+ pattern = EventPattern(type, **kw)
+
+ event = self.wait()
+ self.log_event(event)
+
+ if pattern.match(event):
+ self.log('handled')
+ self.log('')
+ return event
+
+ self.log('not handled')
+ raise RuntimeError('expected %r, got %r' % (pattern, event))
+
+class IteratingEventQueue(BaseEventQueue):
+ """Event queue that works by iterating the Twisted reactor."""
+
+ def __init__(self, timeout=None):
+ BaseEventQueue.__init__(self, timeout)
+ self.events = []
+
+ def wait(self):
+ stop = [False]
+
+ def later():
+ stop[0] = True
+
+ delayed_call = reactor.callLater(self.timeout, later)
+
+ while (not self.events) and (not stop[0]):
+ reactor.iterate(0.1)
+
+ if self.events:
+ delayed_call.cancel()
+ return self.events.pop(0)
+ else:
+ raise TimeoutError
+
+ def append(self, event):
+ self.events.append(event)
+
+ # compatibility
+ handle_event = append
+
+class TestEventQueue(BaseEventQueue):
+ def __init__(self, events):
+ BaseEventQueue.__init__(self)
+ self.events = events
+
+ def wait(self):
+ if self.events:
+ return self.events.pop(0)
+ else:
+ raise TimeoutError
+
+class EventQueueTest(unittest.TestCase):
+ def test_expect(self):
+ queue = TestEventQueue([Event('foo'), Event('bar')])
+ assert queue.expect('foo').type == 'foo'
+ assert queue.expect('bar').type == 'bar'
+
+ def test_expect_many(self):
+ queue = TestEventQueue([Event('foo'), Event('bar')])
+ bar, foo = queue.expect_many(
+ EventPattern('bar'),
+ EventPattern('foo'))
+ assert bar.type == 'bar'
+ assert foo.type == 'foo'
+
+ def test_expect_many2(self):
+ # Test that events are only matched against patterns that haven't yet
+ # been matched. This tests a regression.
+ queue = TestEventQueue([Event('foo', x=1), Event('foo', x=2)])
+ foo1, foo2 = queue.expect_many(
+ EventPattern('foo'),
+ EventPattern('foo'))
+ assert foo1.type == 'foo' and foo1.x == 1
+ assert foo2.type == 'foo' and foo2.x == 2
+
+ def test_timeout(self):
+ queue = TestEventQueue([])
+ self.assertRaises(TimeoutError, queue.expect, 'foo')
+
+ def test_demand(self):
+ queue = TestEventQueue([Event('foo'), Event('bar')])
+ foo = queue.demand('foo')
+ assert foo.type == 'foo'
+
+ def test_demand_fail(self):
+ queue = TestEventQueue([Event('foo'), Event('bar')])
+ self.assertRaises(RuntimeError, queue.demand, 'bar')
+
+def unwrap(x):
+ """Hack to unwrap D-Bus values, so that they're easier to read when
+ printed."""
+
+ if isinstance(x, list):
+ return map(unwrap, x)
+
+ if isinstance(x, tuple):
+ return tuple(map(unwrap, x))
+
+ if isinstance(x, dict):
+ return dict([(unwrap(k), unwrap(v)) for k, v in x.iteritems()])
+
+ if isinstance(x, dbus.Boolean):
+ return bool(x)
+
+ for t in [unicode, str, long, int, float]:
+ if isinstance(x, t):
+ return t(x)
+
+ return x
+
+def call_async(test, proxy, method, *args, **kw):
+ """Call a D-Bus method asynchronously and generate an event for the
+ resulting method return/error."""
+
+ def reply_func(*ret):
+ test.handle_event(Event('dbus-return', method=method,
+ value=unwrap(ret)))
+
+ def error_func(err):
+ test.handle_event(Event('dbus-error', method=method, error=err,
+ name=err.get_dbus_name(), message=str(err)))
+
+ method_proxy = getattr(proxy, method)
+ kw.update({'reply_handler': reply_func, 'error_handler': error_func})
+ method_proxy(*args, **kw)
+
+def sync_dbus(bus, q, conn):
+ # Dummy D-Bus method call
+ # This won't do the right thing unless the proxy has a unique name.
+ assert conn.object.bus_name.startswith(':')
+ root_object = bus.get_object(conn.object.bus_name, '/')
+ call_async(
+ q, dbus.Interface(root_object, 'org.freedesktop.DBus.Peer'), 'Ping')
+ q.expect('dbus-return', method='Ping')
+
+class ProxyWrapper:
+ def __init__(self, object, default, others):
+ self.object = object
+ self.default_interface = dbus.Interface(object, default)
+ self.Properties = dbus.Interface(object, dbus.PROPERTIES_IFACE)
+ self.TpProperties = \
+ dbus.Interface(object, tp_name_prefix + '.Properties')
+ self.interfaces = dict([
+ (name, dbus.Interface(object, iface))
+ for name, iface in others.iteritems()])
+
+ def __getattr__(self, name):
+ if name in self.interfaces:
+ return self.interfaces[name]
+
+ if name in self.object.__dict__:
+ return getattr(self.object, name)
+
+ return getattr(self.default_interface, name)
+
+def wrap_connection(conn):
+ return ProxyWrapper(conn, tp_name_prefix + '.Connection',
+ dict([
+ (name, tp_name_prefix + '.Connection.Interface.' + name)
+ for name in ['Aliasing', 'Avatars', 'Capabilities', 'Contacts',
+ 'Presence', 'SimplePresence', 'Requests']] +
+ [('Peer', 'org.freedesktop.DBus.Peer'),
+ ('ContactCapabilities', cs.CONN_IFACE_CONTACT_CAPS),
+ ('Location', cs.CONN_IFACE_LOCATION),
+ ('Future', tp_name_prefix + '.Connection.FUTURE'),
+ ]))
+
+def wrap_channel(chan, type_, extra=None):
+ interfaces = {
+ type_: tp_name_prefix + '.Channel.Type.' + type_,
+ 'Group': tp_name_prefix + '.Channel.Interface.Group',
+ }
+
+ if extra:
+ interfaces.update(dict([
+ (name, tp_name_prefix + '.Channel.Interface.' + name)
+ for name in extra]))
+
+ return ProxyWrapper(chan, tp_name_prefix + '.Channel', interfaces)
+
+def make_connection(bus, event_func, name, proto, params):
+ cm = bus.get_object(
+ tp_name_prefix + '.ConnectionManager.%s' % name,
+ tp_path_prefix + '/ConnectionManager/%s' % name)
+ cm_iface = dbus.Interface(cm, tp_name_prefix + '.ConnectionManager')
+
+ connection_name, connection_path = cm_iface.RequestConnection(
+ proto, params)
+ conn = wrap_connection(bus.get_object(connection_name, connection_path))
+
+ bus.add_signal_receiver(
+ lambda *args, **kw:
+ event_func(
+ Event('dbus-signal',
+ path=unwrap(kw['path']),
+ signal=kw['member'], args=map(unwrap, args),
+ interface=kw['interface'])),
+ None, # signal name
+ None, # interface
+ None,
+ path_keyword='path',
+ member_keyword='member',
+ interface_keyword='interface',
+ byte_arrays=True
+ )
+
+ return conn
+
+def make_channel_proxy(conn, path, iface):
+ bus = dbus.SessionBus()
+ chan = bus.get_object(conn.object.bus_name, path)
+ chan = dbus.Interface(chan, tp_name_prefix + '.' + iface)
+ return chan
+
+# block_reading can be used if the test want to choose when we start to read
+# data from the socket.
+class EventProtocol(Protocol):
+ def __init__(self, queue=None, block_reading=False):
+ self.queue = queue
+ self.block_reading = block_reading
+
+ def dataReceived(self, data):
+ if self.queue is not None:
+ self.queue.handle_event(Event('socket-data', protocol=self,
+ data=data))
+
+ def sendData(self, data):
+ self.transport.write(data)
+
+ def connectionMade(self):
+ if self.block_reading:
+ self.transport.stopReading()
+
+ def connectionLost(self, reason=None):
+ if self.queue is not None:
+ self.queue.handle_event(Event('socket-disconnected', protocol=self))
+
+class EventProtocolFactory(Factory):
+ def __init__(self, queue, block_reading=False):
+ self.queue = queue
+ self.block_reading = block_reading
+
+ def _create_protocol(self):
+ return EventProtocol(self.queue, self.block_reading)
+
+ def buildProtocol(self, addr):
+ proto = self._create_protocol()
+ self.queue.handle_event(Event('socket-connected', protocol=proto))
+ return proto
+
+class EventProtocolClientFactory(EventProtocolFactory, ClientFactory):
+ pass
+
+def watch_tube_signals(q, tube):
+ def got_signal_cb(*args, **kwargs):
+ q.handle_event(Event('tube-signal',
+ path=kwargs['path'],
+ signal=kwargs['member'],
+ args=map(unwrap, args),
+ tube=tube))
+
+ tube.add_signal_receiver(got_signal_cb,
+ path_keyword='path', member_keyword='member',
+ byte_arrays=True)
+
+def pretty(x):
+ return pprint.pformat(unwrap(x))
+
+def assertEquals(expected, value):
+ if expected != value:
+ raise AssertionError(
+ "expected:\n%s\ngot:\n%s" % (pretty(expected), pretty(value)))
+
+def assertNotEquals(expected, value):
+ if expected == value:
+ raise AssertionError(
+ "expected something other than:\n%s" % pretty(value))
+
+def assertContains(element, value):
+ if element not in value:
+ raise AssertionError(
+ "expected:\n%s\nin:\n%s" % (pretty(element), pretty(value)))
+
+def assertDoesNotContain(element, value):
+ if element in value:
+ raise AssertionError(
+ "expected:\n%s\nnot in:\n%s" % (pretty(element), pretty(value)))
+
+def assertLength(length, value):
+ if len(value) != length:
+ raise AssertionError("expected: length %d, got length %d:\n%s" % (
+ length, len(value), pretty(value)))
+
+def assertFlagsSet(flags, value):
+ masked = value & flags
+ if masked != flags:
+ raise AssertionError(
+ "expected flags %u, of which only %u are set in %u" % (
+ flags, masked, value))
+
+def assertFlagsUnset(flags, value):
+ masked = value & flags
+ if masked != 0:
+ raise AssertionError(
+ "expected none of flags %u, but %u are set in %u" % (
+ flags, masked, value))
+
+def install_colourer():
+ def red(s):
+ return '\x1b[31m%s\x1b[0m' % s
+
+ def green(s):
+ return '\x1b[32m%s\x1b[0m' % s
+
+ patterns = {
+ 'handled': green,
+ 'not handled': red,
+ }
+
+ class Colourer:
+ def __init__(self, fh, patterns):
+ self.fh = fh
+ self.patterns = patterns
+
+ def write(self, s):
+ f = self.patterns.get(s, lambda x: x)
+ self.fh.write(f(s))
+
+ sys.stdout = Colourer(sys.stdout, patterns)
+ return sys.stdout
+
+if __name__ == '__main__':
+ unittest.main()
+