"""Client for cache server.

See cachesvr.py for protocol description.
"""

import argparse
import json
import logging

import trollius as asyncio
from trollius import From, Return
from trollius import test_utils


ARGS = argparse.ArgumentParser(description='Cache client example.')
ARGS.add_argument(
    '--tls', action='store_true', dest='tls',
    default=False, help='Use TLS')
ARGS.add_argument(
    '--iocp', action='store_true', dest='iocp',
    default=False, help='Use IOCP event loop (Windows only)')
ARGS.add_argument(
    '--host', action='store', dest='host',
    default='localhost', help='Host name')
ARGS.add_argument(
    '--port', action='store', dest='port',
    default=54321, type=int, help='Port number')
ARGS.add_argument(
    '--timeout', action='store', dest='timeout',
    default=5, type=float, help='Timeout')
ARGS.add_argument(
    '--max_backoff', action='store', dest='max_backoff',
    default=5, type=float, help='Max backoff on reconnect')
ARGS.add_argument(
    '--ntasks', action='store', dest='ntasks',
    default=10, type=int, help='Number of tester tasks')
ARGS.add_argument(
    '--ntries', action='store', dest='ntries',
    default=5, type=int, help='Number of request tries before giving up')


# Parsed at import time; CacheClient.activity() and testing() read
# args.max_backoff / args.timeout from this module-level namespace.
args = ARGS.parse_args()


class CacheClient:
    """Multiplexing cache client.

    This wraps a single connection to the cache server.  The connection
    is automatically re-opened when an error occurs.

    Multiple tasks may share this object.  Their requests are written over
    the one connection, each tagged with a request id, and every response
    is routed back to the waiting caller by that id.

    The public API is get(), set(), delete() (all are coroutines).
    """

    def __init__(self, host, port, sslctx=None, loop=None):
        self.host = host
        self.port = port
        self.sslctx = sslctx
        self.loop = loop
        # (payload, waiter) pairs waiting to be sent once a connection is
        # up.  A set, so re-queueing the same pair is harmless, but send
        # order is not preserved.
        self.todo = set()
        # True while activity() holds a live connection; request() then
        # sends directly instead of queueing.
        self.initialized = False
        # Background task owning the connection (see activity()).
        self.task = asyncio.Task(self.activity(), loop=self.loop)

    @asyncio.coroutine
    def get(self, key):
        """Return the value stored under *key*.

        Returns None when there is no response or it has no 'value' field.
        """
        resp = yield From(self.request('get', key))
        if resp is None:
            raise Return()
        raise Return(resp.get('value'))

    @asyncio.coroutine
    def set(self, key, value):
        """Store *value* under *key*; return True iff status is 'ok'."""
        resp = yield From(self.request('set', key, value))
        if resp is None:
            raise Return(False)
        raise Return(resp.get('status') == 'ok')

    @asyncio.coroutine
    def delete(self, key):
        """Delete *key*; return True iff the server's status is 'ok'."""
        resp = yield From(self.request('delete', key))
        if resp is None:
            raise Return(False)
        raise Return(resp.get('status') == 'ok')

    @asyncio.coroutine
    def request(self, type, key, value=None):
        """Send one JSON request and wait for its decoded JSON response.

        If no connection is up, or sending fails with IOError, the request
        is queued in self.todo.  activity() sends it after (re)connecting.

        Raises AssertionError if the background task has already finished.
        """
        assert not self.task.done()
        data = {'type': type, 'key': key}
        if value is not None:
            data['value'] = value
        payload = json.dumps(data).encode('utf8')
        waiter = asyncio.Future(loop=self.loop)
        if self.initialized:
            try:
                yield From(self.send(payload, waiter))
            except IOError:
                self.todo.add((payload, waiter))
        else:
            self.todo.add((payload, waiter))
        result = (yield From(waiter))
        raise Return(result)

    @asyncio.coroutine
    def activity(self):
        """Background loop: keep a connection open and dispatch responses.

        Connects with a growing backoff (capped by args.max_backoff).  It
        flushes queued requests, then reads responses and resolves the
        matching waiters.  On any error it closes the connection, moves
        unanswered requests back to self.todo and reconnects.  Cancellation
        is propagated rather than treated as a connection error.
        """
        backoff = 0
        while True:
            try:
                self.reader, self.writer = yield From(asyncio.open_connection(
                    self.host, self.port, ssl=self.sslctx, loop=self.loop))
            except asyncio.CancelledError:
                # CancelledError is an Exception subclass here; without this
                # clause a cancel() would just trigger another reconnect.
                raise
            except Exception as exc:
                backoff = min(args.max_backoff, backoff + (backoff // 2) + 1)
                logging.info('Error connecting: %r; sleep %s', exc, backoff)
                yield From(asyncio.sleep(backoff, loop=self.loop))
                continue
            backoff = 0
            # Request ids restart on every new connection.
            self.next_id = 0
            self.pending = {}
            self.initialized = True
            try:
                # Flush requests queued while disconnected; skip those whose
                # caller already gave up (e.g. cancelled by a timeout).
                while self.todo:
                    payload, waiter = self.todo.pop()
                    if not waiter.done():
                        yield From(self.send(payload, waiter))
                while True:
                    resp_id, resp = yield From(self.process())
                    if resp_id in self.pending:
                        payload, waiter = self.pending.pop(resp_id)
                        if not waiter.done():
                            waiter.set_result(resp)
            except asyncio.CancelledError:
                self.initialized = False
                self.writer.close()
                raise
            except Exception as exc:
                self.initialized = False
                self.writer.close()
                # Re-queue every still-wanted in-flight request so it is
                # resent on the next connection.
                while self.pending:
                    req_id, pair = self.pending.popitem()
                    payload, waiter = pair
                    if not waiter.done():
                        self.todo.add(pair)
                logging.info('Error processing: %r', exc)

    @asyncio.coroutine
    def send(self, payload, waiter):
        """Write one framed request and wait for the writer to drain.

        The frame is a header line ``request <id> <size>`` followed by the
        payload bytes.  The request is recorded in self.pending before any
        bytes are written.  If a write raises, activity()'s error handler
        still finds it there and re-queues it instead of losing it.
        """
        self.next_id += 1
        req_id = self.next_id
        self.pending[req_id] = payload, waiter
        frame = 'request %d %d\n' % (req_id, len(payload))
        self.writer.write(frame.encode('ascii'))
        self.writer.write(payload)
        yield From(self.writer.drain())

    @asyncio.coroutine
    def process(self):
        """Read one framed response and return ``(resp_id, decoded_json)``.

        Raises EOFError when the stream ends.  Raises IOError for an
        out-of-band ``error`` frame or any malformed frame header.
        """
        frame = yield From(self.reader.readline())
        if not frame:
            raise EOFError()
        parts = frame.split(None, 1)
        if not parts:
            raise IOError('Bad frame: %r' % frame)
        head = parts[0]
        tail = parts[1] if len(parts) > 1 else b''
        if head == b'error':
            raise IOError('OOB error: %r' % tail)
        if head != b'response':
            raise IOError('Bad frame: %r' % frame)
        fields = tail.split()
        if len(fields) != 2:
            raise IOError('Bad frame: %r' % frame)
        resp_id, resp_size = map(int, fields)
        # readexactly() raises IncompleteReadError (an EOFError subclass)
        # when the stream ends early, so no separate length check is needed.
        data = yield From(self.reader.readexactly(resp_size))
        resp = json.loads(data.decode('utf8'))
        raise Return(resp_id, resp)


def main():
    """Build the event loop and client, then run the tester tasks."""
    # Force every component to use the explicitly passed loop.
    asyncio.set_event_loop(None)
    if args.iocp:
        from trollius.windows_events import ProactorEventLoop
        loop = ProactorEventLoop()
    else:
        loop = asyncio.new_event_loop()
    sslctx = None
    if args.tls:
        sslctx = test_utils.dummy_ssl_context()
    cache = CacheClient(args.host, args.port, sslctx=sslctx, loop=loop)
    try:
        loop.run_until_complete(
            asyncio.gather(
                *[testing(i, cache, loop) for i in range(args.ntasks)],
                loop=loop))
    finally:
        loop.close()


@asyncio.coroutine
def testing(label, cache, loop):
    """Repeatedly run set/get/delete/get on a per-task key.

    Each call is bounded by args.timeout.  Timeouts are logged and the
    loop continues; any other exception is logged and ends this tester.
    """

    def w(g):
        return asyncio.wait_for(g, args.timeout, loop=loop)

    key = 'foo-%s' % label
    while True:
        logging.info('%s %s', label, '-' * 20)
        try:
            ret = yield From(w(cache.set(key, 'hello-%s-world' % label)))
            logging.info('%s set %s', label, ret)
            ret = yield From(w(cache.get(key)))
            logging.info('%s get %s', label, ret)
            ret = yield From(w(cache.delete(key)))
            logging.info('%s del %s', label, ret)
            ret = yield From(w(cache.get(key)))
            logging.info('%s get2 %s', label, ret)
        except asyncio.TimeoutError:
            logging.warning('%s Timeout', label)
        except Exception as exc:
            logging.exception('%s Client exception: %r', label, exc)
            break


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    main()