Diffstat (limited to 'examples/cacheclt.py')
-rw-r--r--  examples/cacheclt.py  213
1 file changed, 213 insertions, 0 deletions
diff --git a/examples/cacheclt.py b/examples/cacheclt.py
new file mode 100644
index 0000000..b11a4d1
--- /dev/null
+++ b/examples/cacheclt.py
@@ -0,0 +1,213 @@
+"""Client for cache server.
+
+See cachesvr.py for protocol description.
+"""
+
+import argparse
+import asyncio
+from asyncio import test_utils
+import json
+import logging
+
+ARGS = argparse.ArgumentParser(description='Cache client example.')
+ARGS.add_argument(
+    '--tls', action='store_true', dest='tls',
+    default=False, help='Use TLS')
+ARGS.add_argument(
+    '--iocp', action='store_true', dest='iocp',
+    default=False, help='Use IOCP event loop (Windows only)')
+ARGS.add_argument(
+    '--host', action='store', dest='host',
+    default='localhost', help='Host name')
+ARGS.add_argument(
+    '--port', action='store', dest='port',
+    default=54321, type=int, help='Port number')
+ARGS.add_argument(
+    '--timeout', action='store', dest='timeout',
+    default=5, type=float, help='Timeout')
+ARGS.add_argument(
+    '--max_backoff', action='store', dest='max_backoff',
+    default=5, type=float, help='Max backoff on reconnect')
+ARGS.add_argument(
+    '--ntasks', action='store', dest='ntasks',
+    default=10, type=int, help='Number of tester tasks')
+ARGS.add_argument(
+    '--ntries', action='store', dest='ntries',
+    default=5, type=int, help='Number of request tries before giving up')
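+
+# Hypothetical invocation (any host/port with a running cachesvr.py will do):
+#
+#   python cacheclt.py --host localhost --port 54321 --ntasks 3 --timeout 5
+#
+# Note that --ntries is parsed but not referenced anywhere in this script.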
+
+
+args = ARGS.parse_args()
+
+
+class CacheClient:
+    """Multiplexing cache client.
+
+    This wraps a single connection to the cache server. The
+    connection is automatically re-opened when an error occurs.
+
+    Multiple tasks may share this object; the requests will be
+    serialized.
+
+    The public API is get(), set(), delete() (all are coroutines).
+    """
+
+    def __init__(self, host, port, sslctx=None, loop=None):
+        self.host = host
+        self.port = port
+        self.sslctx = sslctx
+        self.loop = loop
+        self.todo = set()
+        self.initialized = False
+        self.task = asyncio.Task(self.activity(), loop=self.loop)
+
+    @asyncio.coroutine
+    def get(self, key):
+        resp = yield from self.request('get', key)
+        if resp is None:
+            return None
+        return resp.get('value')
+
+    @asyncio.coroutine
+    def set(self, key, value):
+        resp = yield from self.request('set', key, value)
+        if resp is None:
+            return False
+        return resp.get('status') == 'ok'
+
+    @asyncio.coroutine
+    def delete(self, key):
+        resp = yield from self.request('delete', key)
+        if resp is None:
+            return False
+        return resp.get('status') == 'ok'
+
+    @asyncio.coroutine
+    def request(self, type, key, value=None):
+        assert not self.task.done()
+        data = {'type': type, 'key': key}
+        if value is not None:
+            data['value'] = value
+        payload = json.dumps(data).encode('utf8')
+        waiter = asyncio.Future(loop=self.loop)
+        if self.initialized:
+            try:
+                yield from self.send(payload, waiter)
+            except IOError:
+                self.todo.add((payload, waiter))
+        else:
+            self.todo.add((payload, waiter))
+        return (yield from waiter)
+
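+    # activity() owns the connection: it connects, retrying with a backoff
+    # capped at --max_backoff, flushes any requests parked in self.todo while
+    # the link was down, then reads response frames and resolves each pending
+    # waiter Future by request id. On any error the unanswered requests are
+    # moved back to self.todo and the loop reconnects.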
+    @asyncio.coroutine
+    def activity(self):
+        backoff = 0
+        while True:
+            try:
+                self.reader, self.writer = yield from asyncio.open_connection(
+                    self.host, self.port, ssl=self.sslctx, loop=self.loop)
+            except Exception as exc:
+                backoff = min(args.max_backoff, backoff + (backoff//2) + 1)
+                logging.info('Error connecting: %r; sleep %s', exc, backoff)
+                yield from asyncio.sleep(backoff, loop=self.loop)
+                continue
+            backoff = 0
+            self.next_id = 0
+            self.pending = {}
+            self.initialized = True
+            try:
+                while self.todo:
+                    payload, waiter = self.todo.pop()
+                    if not waiter.done():
+                        yield from self.send(payload, waiter)
+                while True:
+                    resp_id, resp = yield from self.process()
+                    if resp_id in self.pending:
+                        payload, waiter = self.pending.pop(resp_id)
+                        if not waiter.done():
+                            waiter.set_result(resp)
+            except Exception as exc:
+                self.initialized = False
+                self.writer.close()
+                while self.pending:
+                    req_id, pair = self.pending.popitem()
+                    payload, waiter = pair
+                    if not waiter.done():
+                        self.todo.add(pair)
+                logging.info('Error processing: %r', exc)
+
+    @asyncio.coroutine
+    def send(self, payload, waiter):
+        self.next_id += 1
+        req_id = self.next_id
+        frame = 'request %d %d\n' % (req_id, len(payload))
+        self.writer.write(frame.encode('ascii'))
+        self.writer.write(payload)
+        self.pending[req_id] = payload, waiter
+        yield from self.writer.drain()
+
+    @asyncio.coroutine
+    def process(self):
+        frame = yield from self.reader.readline()
+        if not frame:
+            raise EOFError()
+        head, tail = frame.split(None, 1)
+        if head == b'error':
+            raise IOError('OOB error: %r' % tail)
+        if head != b'response':
+            raise IOError('Bad frame: %r' % frame)
+        resp_id, resp_size = map(int, tail.split())
+        data = yield from self.reader.readexactly(resp_size)
+        if len(data) != resp_size:
+            raise EOFError()
+        resp = json.loads(data.decode('utf8'))
+        return resp_id, resp
+
+
+def main():
+    asyncio.set_event_loop(None)
+    if args.iocp:
+        from asyncio.windows_events import ProactorEventLoop
+        loop = ProactorEventLoop()
+    else:
+        loop = asyncio.new_event_loop()
+    sslctx = None
+    if args.tls:
+        sslctx = test_utils.dummy_ssl_context()
+    cache = CacheClient(args.host, args.port, sslctx=sslctx, loop=loop)
+    try:
+        loop.run_until_complete(
+            asyncio.gather(
+                *[testing(i, cache, loop) for i in range(args.ntasks)],
+                loop=loop))
+    finally:
+        loop.close()
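+
+# Typical session (the server's flag names are assumptions; check cachesvr.py
+# for its actual options): start the server, e.g.
+# `python cachesvr.py --port 54321`, then run `python cacheclt.py --ntasks 3`.
+# Each tester task below loops set/get/delete/get on its own key until an
+# unexpected exception occurs. With --tls, dummy_ssl_context() from
+# asyncio.test_utils is used, so that test helper must be importable.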
+
+
+@asyncio.coroutine
+def testing(label, cache, loop):
+
+    def w(g):
+        return asyncio.wait_for(g, args.timeout, loop=loop)
+
+    key = 'foo-%s' % label
+    while True:
+        logging.info('%s %s', label, '-'*20)
+        try:
+            ret = yield from w(cache.set(key, 'hello-%s-world' % label))
+            logging.info('%s set %s', label, ret)
+            ret = yield from w(cache.get(key))
+            logging.info('%s get %s', label, ret)
+            ret = yield from w(cache.delete(key))
+            logging.info('%s del %s', label, ret)
+            ret = yield from w(cache.get(key))
+            logging.info('%s get2 %s', label, ret)
+        except asyncio.TimeoutError:
+            logging.warning('%s Timeout', label)
+        except Exception as exc:
+            logging.exception('%s Client exception: %r', label, exc)
+            break
+
+
+if __name__ == '__main__':
+    logging.basicConfig(level=logging.INFO)
+    main()