diff options
author | Guido van Rossum <guido@python.org> | 2013-11-19 11:32:45 -0800 |
---|---|---|
committer | Guido van Rossum <guido@python.org> | 2013-11-19 11:32:45 -0800 |
commit | f2c6d5a40ce3e7ebc90b6f103527b9d6b9d47376 (patch) | |
tree | b6ccc8d0b6249339a0bd03065fc25eb4971a9df1 | |
parent | c391fb84370c5e865eaab1687026d28cda1aac63 (diff) | |
download | trollius-f2c6d5a40ce3e7ebc90b6f103527b9d6b9d47376.tar.gz |
Add streams.start_server(), by Gustavo Carneiro.
-rw-r--r-- | asyncio/streams.py | 53 | ||||
-rw-r--r-- | examples/simple_tcp_server.py | 151 | ||||
-rw-r--r-- | tests/test_streams.py | 66 |
3 files changed, 268 insertions, 2 deletions
diff --git a/asyncio/streams.py b/asyncio/streams.py index e995368..331d28d 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -1,6 +1,8 @@ """Stream-related things.""" -__all__ = ['StreamReader', 'StreamReaderProtocol', 'open_connection'] +__all__ = ['StreamReader', 'StreamReaderProtocol', + 'open_connection', 'start_server', + ] import collections @@ -43,6 +45,42 @@ def open_connection(host=None, port=None, *, return reader, writer +@tasks.coroutine +def start_server(client_connected_cb, host=None, port=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """Start a socket server, call back for each client connected. + + The first parameter, `client_connected_cb`, takes two parameters: + client_reader, client_writer. client_reader is a StreamReader + object, while client_writer is a StreamWriter object. This + parameter can either be a plain callback function or a coroutine; + if it is a coroutine, it will be automatically converted into a + Task. + + The rest of the arguments are all the usual arguments to + loop.create_server() except protocol_factory; most common are + positional host and port, with various optional keyword arguments + following. The return value is the same as loop.create_server(). + + Additional optional keyword arguments are loop (to set the event loop + instance to use) and limit (to set the buffer limit passed to the + StreamReader). + + The return value is the same as loop.create_server(), i.e. a + Server object which can be used to stop the service. + """ + if loop is None: + loop = events.get_event_loop() + + def factory(): + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, client_connected_cb, + loop=loop) + return protocol + + return (yield from loop.create_server(factory, host, port, **kwds)) + + class StreamReaderProtocol(protocols.Protocol): """Trivial helper class to adapt between Protocol and StreamReader. @@ -52,13 +90,24 @@ class StreamReaderProtocol(protocols.Protocol): call inappropriate methods of the protocol.) """ - def __init__(self, stream_reader): + def __init__(self, stream_reader, client_connected_cb=None, loop=None): self._stream_reader = stream_reader + self._stream_writer = None self._drain_waiter = None self._paused = False + self._client_connected_cb = client_connected_cb + self._loop = loop # May be None; we may never need it. def connection_made(self, transport): self._stream_reader.set_transport(transport) + if self._client_connected_cb is not None: + self._stream_writer = StreamWriter(transport, self, + self._stream_reader, + self._loop) + res = self._client_connected_cb(self._stream_reader, + self._stream_writer) + if tasks.iscoroutine(res): + tasks.Task(res, loop=self._loop) def connection_lost(self, exc): if exc is None: diff --git a/examples/simple_tcp_server.py b/examples/simple_tcp_server.py new file mode 100644 index 0000000..c36710d --- /dev/null +++ b/examples/simple_tcp_server.py @@ -0,0 +1,151 @@ +""" +Example of a simple TCP server that is written in (mostly) coroutine +style and uses asyncio.streams.start_server() and +asyncio.streams.open_connection(). + +Note that running this example starts both the TCP server and client +in the same process. It listens on port 1234 on 127.0.0.1, so it will +fail if this port is currently in use. +""" + +import sys +import asyncio +import asyncio.streams + + +class MyServer: + """ + This is just an example of how a TCP server might be potentially + structured. This class has basically 3 methods: start the server, + handle a client, and stop the server. + + Note that you don't have to follow this structure, it is really + just an example or possible starting point. + """ + + def __init__(self): + self.server = None # encapsulates the server sockets + + # this keeps track of all the clients that connected to our + # server. It can be useful in some cases, for instance to + # kill client connections or to broadcast some data to all + # clients... + self.clients = {} # task -> (reader, writer) + + def _accept_client(self, client_reader, client_writer): + """ + This method accepts a new client connection and creates a Task + to handle this client. self.clients is updated to keep track + of the new client. + """ + + # start a new Task to handle this specific client connection + task = asyncio.Task(self._handle_client(client_reader, client_writer)) + self.clients[task] = (client_reader, client_writer) + + def client_done(task): + print("client task done:", task, file=sys.stderr) + del self.clients[task] + + task.add_done_callback(client_done) + + @asyncio.coroutine + def _handle_client(self, client_reader, client_writer): + """ + This method actually does the work to handle the requests for + a specific client. The protocol is line oriented, so there is + a main loop that reads a line with a request and then sends + out one or more lines back to the client with the result. + """ + while True: + data = (yield from client_reader.readline()).decode("utf-8") + if not data: # an empty string means the client disconnected + break + cmd, *args = data.rstrip().split(' ') + if cmd == 'add': + arg1 = float(args[0]) + arg2 = float(args[1]) + retval = arg1 + arg2 + client_writer.write("{!r}\n".format(retval).encode("utf-8")) + elif cmd == 'repeat': + times = int(args[0]) + msg = args[1] + client_writer.write("begin\n".encode("utf-8")) + for idx in range(times): + client_writer.write("{}. {}\n".format(idx+1, msg) + .encode("utf-8")) + client_writer.write("end\n".encode("utf-8")) + else: + print("Bad command {!r}".format(data), file=sys.stderr) + + # This enables us to have flow control in our connection. + yield from client_writer.drain() + + def start(self, loop): + """ + Starts the TCP server, so that it listens on port 1234. + + For each client that connects, the accept_client method gets + called. This method runs the loop until the server sockets + are ready to accept connections. + """ + self.server = loop.run_until_complete( + asyncio.streams.start_server(self._accept_client, + '127.0.0.1', 12345, + loop=loop)) + + def stop(self, loop): + """ + Stops the TCP server, i.e. closes the listening socket(s). + + This method runs the loop until the server sockets are closed. + """ + if self.server is not None: + self.server.close() + loop.run_until_complete(self.server.wait_closed()) + self.server = None + + +def main(): + loop = asyncio.get_event_loop() + + # creates a server and starts listening to TCP connections + server = MyServer() + server.start(loop) + + @asyncio.coroutine + def client(): + reader, writer = yield from asyncio.streams.open_connection( + '127.0.0.1', 12345, loop=loop) + + def send(msg): + print("> " + msg) + writer.write((msg + '\n').encode("utf-8")) + + def recv(): + msgback = (yield from reader.readline()).decode("utf-8").rstrip() + print("< " + msgback) + return msgback + + # send a line + send("add 1 2") + msg = yield from recv() + + send("repeat 5 hello") + msg = yield from recv() + assert msg == 'begin' + while True: + msg = yield from recv() + if msg == 'end': + break + + writer.close() + yield from asyncio.sleep(0.5) + + # creates a client and connects to our server + msg = loop.run_until_complete(client()) + server.stop(loop) + + +if __name__ == '__main__': + main() diff --git a/tests/test_streams.py b/tests/test_streams.py index 69e2246..5516c15 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -359,6 +359,72 @@ class StreamReaderTests(unittest.TestCase): test_utils.run_briefly(self.loop) self.assertIs(stream._waiter, None) + def test_start_server(self): + + class MyServer: + + def __init__(self, loop): + self.server = None + self.loop = loop + + @tasks.coroutine + def handle_client(self, client_reader, client_writer): + data = yield from client_reader.readline() + client_writer.write(data) + + def start(self): + self.server = self.loop.run_until_complete( + streams.start_server(self.handle_client, + '127.0.0.1', 12345, + loop=self.loop)) + + def handle_client_callback(self, client_reader, client_writer): + task = tasks.Task(client_reader.readline(), loop=self.loop) + + def done(task): + client_writer.write(task.result()) + + task.add_done_callback(done) + + def start_callback(self): + self.server = self.loop.run_until_complete( + streams.start_server(self.handle_client_callback, + '127.0.0.1', 12345, + loop=self.loop)) + + def stop(self): + if self.server is not None: + self.server.close() + self.loop.run_until_complete(self.server.wait_closed()) + self.server = None + + @tasks.coroutine + def client(): + reader, writer = yield from streams.open_connection( + '127.0.0.1', 12345, loop=self.loop) + # send a line + writer.write(b"hello world!\n") + # read it back + msgback = yield from reader.readline() + writer.close() + return msgback + + # test the server variant with a coroutine as client handler + server = MyServer(self.loop) + server.start() + msg = self.loop.run_until_complete(tasks.Task(client(), + loop=self.loop)) + server.stop() + self.assertEqual(msg, b"hello world!\n") + + # test the server variant with a callback as client handler + server = MyServer(self.loop) + server.start_callback() + msg = self.loop.run_until_complete(tasks.Task(client(), + loop=self.loop)) + server.stop() + self.assertEqual(msg, b"hello world!\n") + if __name__ == '__main__': unittest.main() |