summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGuido van Rossum <guido@python.org>2013-11-19 11:32:45 -0800
committerGuido van Rossum <guido@python.org>2013-11-19 11:32:45 -0800
commitf2c6d5a40ce3e7ebc90b6f103527b9d6b9d47376 (patch)
treeb6ccc8d0b6249339a0bd03065fc25eb4971a9df1
parentc391fb84370c5e865eaab1687026d28cda1aac63 (diff)
downloadtrollius-f2c6d5a40ce3e7ebc90b6f103527b9d6b9d47376.tar.gz
Add streams.start_server(), by Gustavo Carneiro.
-rw-r--r--asyncio/streams.py53
-rw-r--r--examples/simple_tcp_server.py151
-rw-r--r--tests/test_streams.py66
3 files changed, 268 insertions, 2 deletions
diff --git a/asyncio/streams.py b/asyncio/streams.py
index e995368..331d28d 100644
--- a/asyncio/streams.py
+++ b/asyncio/streams.py
@@ -1,6 +1,8 @@
"""Stream-related things."""
-__all__ = ['StreamReader', 'StreamReaderProtocol', 'open_connection']
+__all__ = ['StreamReader', 'StreamReaderProtocol',
+ 'open_connection', 'start_server',
+ ]
import collections
@@ -43,6 +45,42 @@ def open_connection(host=None, port=None, *,
return reader, writer
+@tasks.coroutine
+def start_server(client_connected_cb, host=None, port=None, *,
+ loop=None, limit=_DEFAULT_LIMIT, **kwds):
+ """Start a socket server, call back for each client connected.
+
+ The first parameter, `client_connected_cb`, takes two parameters:
+ client_reader, client_writer. client_reader is a StreamReader
+ object, while client_writer is a StreamWriter object. This
+ parameter can either be a plain callback function or a coroutine;
+ if it is a coroutine, it will be automatically converted into a
+ Task.
+
+ The rest of the arguments are all the usual arguments to
+ loop.create_server() except protocol_factory; most common are
+ positional host and port, with various optional keyword arguments
+ following. The return value is the same as loop.create_server().
+
+ Additional optional keyword arguments are loop (to set the event loop
+ instance to use) and limit (to set the buffer limit passed to the
+ StreamReader).
+
+ The return value is the same as loop.create_server(), i.e. a
+ Server object which can be used to stop the service.
+ """
+ if loop is None:
+ loop = events.get_event_loop()
+
+ def factory():
+ reader = StreamReader(limit=limit, loop=loop)
+ protocol = StreamReaderProtocol(reader, client_connected_cb,
+ loop=loop)
+ return protocol
+
+ return (yield from loop.create_server(factory, host, port, **kwds))
+
+
class StreamReaderProtocol(protocols.Protocol):
"""Trivial helper class to adapt between Protocol and StreamReader.
@@ -52,13 +90,24 @@ class StreamReaderProtocol(protocols.Protocol):
call inappropriate methods of the protocol.)
"""
- def __init__(self, stream_reader):
+ def __init__(self, stream_reader, client_connected_cb=None, loop=None):
self._stream_reader = stream_reader
+ self._stream_writer = None
self._drain_waiter = None
self._paused = False
+ self._client_connected_cb = client_connected_cb
+ self._loop = loop # May be None; we may never need it.
def connection_made(self, transport):
self._stream_reader.set_transport(transport)
+ if self._client_connected_cb is not None:
+ self._stream_writer = StreamWriter(transport, self,
+ self._stream_reader,
+ self._loop)
+ res = self._client_connected_cb(self._stream_reader,
+ self._stream_writer)
+ if tasks.iscoroutine(res):
+ tasks.Task(res, loop=self._loop)
def connection_lost(self, exc):
if exc is None:
diff --git a/examples/simple_tcp_server.py b/examples/simple_tcp_server.py
new file mode 100644
index 0000000..c36710d
--- /dev/null
+++ b/examples/simple_tcp_server.py
@@ -0,0 +1,151 @@
+"""
+Example of a simple TCP server that is written in (mostly) coroutine
+style and uses asyncio.streams.start_server() and
+asyncio.streams.open_connection().
+
+Note that running this example starts both the TCP server and client
+in the same process. It listens on port 1234 on 127.0.0.1, so it will
+fail if this port is currently in use.
+"""
+
+import sys
+import asyncio
+import asyncio.streams
+
+
+class MyServer:
+ """
+ This is just an example of how a TCP server might be potentially
+ structured. This class has basically 3 methods: start the server,
+ handle a client, and stop the server.
+
+ Note that you don't have to follow this structure, it is really
+ just an example or possible starting point.
+ """
+
+ def __init__(self):
+ self.server = None # encapsulates the server sockets
+
+ # this keeps track of all the clients that connected to our
+ # server. It can be useful in some cases, for instance to
+ # kill client connections or to broadcast some data to all
+ # clients...
+ self.clients = {} # task -> (reader, writer)
+
+ def _accept_client(self, client_reader, client_writer):
+ """
+ This method accepts a new client connection and creates a Task
+ to handle this client. self.clients is updated to keep track
+ of the new client.
+ """
+
+ # start a new Task to handle this specific client connection
+ task = asyncio.Task(self._handle_client(client_reader, client_writer))
+ self.clients[task] = (client_reader, client_writer)
+
+ def client_done(task):
+ print("client task done:", task, file=sys.stderr)
+ del self.clients[task]
+
+ task.add_done_callback(client_done)
+
+ @asyncio.coroutine
+ def _handle_client(self, client_reader, client_writer):
+ """
+ This method actually does the work to handle the requests for
+ a specific client. The protocol is line oriented, so there is
+ a main loop that reads a line with a request and then sends
+ out one or more lines back to the client with the result.
+ """
+ while True:
+ data = (yield from client_reader.readline()).decode("utf-8")
+ if not data: # an empty string means the client disconnected
+ break
+ cmd, *args = data.rstrip().split(' ')
+ if cmd == 'add':
+ arg1 = float(args[0])
+ arg2 = float(args[1])
+ retval = arg1 + arg2
+ client_writer.write("{!r}\n".format(retval).encode("utf-8"))
+ elif cmd == 'repeat':
+ times = int(args[0])
+ msg = args[1]
+ client_writer.write("begin\n".encode("utf-8"))
+ for idx in range(times):
+ client_writer.write("{}. {}\n".format(idx+1, msg)
+ .encode("utf-8"))
+ client_writer.write("end\n".encode("utf-8"))
+ else:
+ print("Bad command {!r}".format(data), file=sys.stderr)
+
+ # This enables us to have flow control in our connection.
+ yield from client_writer.drain()
+
+ def start(self, loop):
+ """
+ Starts the TCP server, so that it listens on port 1234.
+
+ For each client that connects, the accept_client method gets
+ called. This method runs the loop until the server sockets
+ are ready to accept connections.
+ """
+ self.server = loop.run_until_complete(
+ asyncio.streams.start_server(self._accept_client,
+ '127.0.0.1', 12345,
+ loop=loop))
+
+ def stop(self, loop):
+ """
+ Stops the TCP server, i.e. closes the listening socket(s).
+
+ This method runs the loop until the server sockets are closed.
+ """
+ if self.server is not None:
+ self.server.close()
+ loop.run_until_complete(self.server.wait_closed())
+ self.server = None
+
+
+def main():
+ loop = asyncio.get_event_loop()
+
+ # creates a server and starts listening to TCP connections
+ server = MyServer()
+ server.start(loop)
+
+ @asyncio.coroutine
+ def client():
+ reader, writer = yield from asyncio.streams.open_connection(
+ '127.0.0.1', 12345, loop=loop)
+
+ def send(msg):
+ print("> " + msg)
+ writer.write((msg + '\n').encode("utf-8"))
+
+ def recv():
+ msgback = (yield from reader.readline()).decode("utf-8").rstrip()
+ print("< " + msgback)
+ return msgback
+
+ # send a line
+ send("add 1 2")
+ msg = yield from recv()
+
+ send("repeat 5 hello")
+ msg = yield from recv()
+ assert msg == 'begin'
+ while True:
+ msg = yield from recv()
+ if msg == 'end':
+ break
+
+ writer.close()
+ yield from asyncio.sleep(0.5)
+
+ # creates a client and connects to our server
+ msg = loop.run_until_complete(client())
+ server.stop(loop)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tests/test_streams.py b/tests/test_streams.py
index 69e2246..5516c15 100644
--- a/tests/test_streams.py
+++ b/tests/test_streams.py
@@ -359,6 +359,72 @@ class StreamReaderTests(unittest.TestCase):
test_utils.run_briefly(self.loop)
self.assertIs(stream._waiter, None)
+ def test_start_server(self):
+
+ class MyServer:
+
+ def __init__(self, loop):
+ self.server = None
+ self.loop = loop
+
+ @tasks.coroutine
+ def handle_client(self, client_reader, client_writer):
+ data = yield from client_reader.readline()
+ client_writer.write(data)
+
+ def start(self):
+ self.server = self.loop.run_until_complete(
+ streams.start_server(self.handle_client,
+ '127.0.0.1', 12345,
+ loop=self.loop))
+
+ def handle_client_callback(self, client_reader, client_writer):
+ task = tasks.Task(client_reader.readline(), loop=self.loop)
+
+ def done(task):
+ client_writer.write(task.result())
+
+ task.add_done_callback(done)
+
+ def start_callback(self):
+ self.server = self.loop.run_until_complete(
+ streams.start_server(self.handle_client_callback,
+ '127.0.0.1', 12345,
+ loop=self.loop))
+
+ def stop(self):
+ if self.server is not None:
+ self.server.close()
+ self.loop.run_until_complete(self.server.wait_closed())
+ self.server = None
+
+ @tasks.coroutine
+ def client():
+ reader, writer = yield from streams.open_connection(
+ '127.0.0.1', 12345, loop=self.loop)
+ # send a line
+ writer.write(b"hello world!\n")
+ # read it back
+ msgback = yield from reader.readline()
+ writer.close()
+ return msgback
+
+ # test the server variant with a coroutine as client handler
+ server = MyServer(self.loop)
+ server.start()
+ msg = self.loop.run_until_complete(tasks.Task(client(),
+ loop=self.loop))
+ server.stop()
+ self.assertEqual(msg, b"hello world!\n")
+
+ # test the server variant with a callback as client handler
+ server = MyServer(self.loop)
+ server.start_callback()
+ msg = self.loop.run_until_complete(tasks.Task(client(),
+ loop=self.loop))
+ server.stop()
+ self.assertEqual(msg, b"hello world!\n")
+
if __name__ == '__main__':
unittest.main()