diff options
author | A. Jesse Jiryu Davis <jesse@10gen.com> | 2013-03-28 15:39:55 -0400 |
---|---|---|
committer | A. Jesse Jiryu Davis <jesse@10gen.com> | 2013-03-28 15:39:55 -0400 |
commit | 80d1312a3e9c869f26fa4790a8978fd7f8486fb1 (patch) | |
tree | e5fdcef6fa6327b903a4bf897c771547d7a1434c | |
download | trollius-git-80d1312a3e9c869f26fa4790a8978fd7f8486fb1.tar.gz |
Use logger named 'tulip' for library events, Issue 26
71 files changed, 17494 insertions, 0 deletions
@@ -0,0 +1,2 @@ +[patterns]
+** = native
diff --git a/.hgignore b/.hgignore new file mode 100644 index 0000000..2590249 --- /dev/null +++ b/.hgignore @@ -0,0 +1,11 @@ +.*\.py[co]$ +.*~$ +.*\.orig$ +.*\#.*$ +.*@.*$ +\.coverage$ +htmlcov$ +\.DS_Store$ +venv$ +distribute_setup.py$ +distribute-\d+.\d+.\d+.tar.gz$ diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..274da4c --- /dev/null +++ b/Makefile @@ -0,0 +1,31 @@ +# Some simple testing tasks (sorry, UNIX only). + +PYTHON=python3 +VERBOSE=1 +FLAGS= + +test: + $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS) + +testloop: + while sleep 1; do $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS); done + +# See README for coverage installation instructions. +cov coverage: + $(PYTHON) runtests.py --coverage tulip -v $(VERBOSE) $(FLAGS) + echo "open file://`pwd`/htmlcov/index.html" + +check: + $(PYTHON) check.py + +clean: + rm -rf `find . -name __pycache__` + rm -f `find . -type f -name '*.py[co]' ` + rm -f `find . -type f -name '*~' ` + rm -f `find . -type f -name '.*~' ` + rm -f `find . -type f -name '@*' ` + rm -f `find . -type f -name '#*#' ` + rm -f `find . -type f -name '*.orig' ` + rm -f `find . -type f -name '*.rej' ` + rm -f .coverage + rm -rf htmlcov @@ -0,0 +1,176 @@ +Notes from PyCon 2013 sprints +============================= + +- Cancellation. If a task creates several subtasks, and then the + parent task fails, should the subtasks be cancelled? (How do we + even establish the parent/subtask relationship?) + +- Adam Sah suggests that there might be a need for scheduling + (especially when multiple frameworks share an event loop). He + points to lottery scheduling but also mentions that's just one of + the options. However, after posting on python-tulip, it appears + none of the other frameworks have scheduling, and nobody seems to + miss it. + +- Feedback from Bram Cohen (Bittorrent creator) about UDP. He doesn't + think connected UDP is worth supporting, it doesn't do anything + except tell the kernel about the default target address for + sendto(). Basically he says all UDP end points are servers. He + sent me his own UDP event loop so I might glean some tricks from it. + He says we should treat EINTR the same as EAGAIN and friends. (We + should use the exceptions dedicated to errno checking, BTW.) HE + said to make sure we use SO_REUSEADDR (I think we already do). He + said to set the max datagram sizes pretty large (anything larger + than the declared limit is dropped on the floor). He reminds us of + the importance of being able to pick a valid, unused port by binding + to port 0 and then using getsockname(). He has an idea where he's + like to be able to kill all registered callbacks (i.e. Handles) + belonging to a certain "context". I think this can be done at the + application level (you'd have to wrap everything that returns a + Handle and collect these handles in some set or other datastructure) + but if someone thinks it's interesting we could imagine having some + kind of notion of context part of the event loop state, + e.g. associated with a Task (see Cancellation point above). He + brought up uTP (Micro Transport Protocol), a reimplementation of TCP + over UDP with more refined congestion control. + +- Mumblings about UNIX domain sockets and IPv6 addresses being + 4-tuples. The former can be handled by passing in a socket. There + seem to be no real use cases for the latter that can't be dealt with + by passing in suitably esoteric strings for the hostname. + getaddrinfo() will produce the appropriate 4-tuple and connect() + will accept it. + +- Mumblings on the list about add vs. set. + + +Notes from the second Tulip/Twisted meet-up +=========================================== + +Rackspace, 12/11/2012 +Glyph, Brian Warner, David Reid, Duncan McGreggor, others + +Flow control +------------ + +- Pause/resume on transport manages data_received. + +- There's also an API to tell the transport whom to pause when the + write calls are overwhelming it: IConsumer.registerProducer(). + +- There's also something called pipes but it's built on top of the + old interface. + +- Twisted has variations on the basic flow control that I should + ignore. + +Half_close +---------- + +- This sends an EOF after writing some stuff. + +- Can't write any more. + +- Problem with TLS is known (the RFC sadly specifies this behavior). + +- It must be dynamimcally discoverable whether the transport supports + half_close, since the protocol may have to do something different to + make up for its missing (e.g. use chunked encoding). Twisted uses + an interface check for this and also hasattr(trans, 'halfClose') + but a flag (or flag method) is fine too. + +Constructing transport and protocol +----------------------------------- + +- There are good reasons for passing a function to the transport + construction helper that creates the protocol. (You need these + anyway for server-side protocols.) The sequence of events is + something like + + . open socket + . create transport (pass it a socket?) + . create protocol (pass it nothing) + . proto.make_connection(transport); this does: + . self.transport = transport + . self.connection_made(transport) + + But it seems okay to skip make_connection and setting .transport. + Note that make_connection() is a concrete method on the Protocol + implementation base class, while connection_made() is an abstract + method on IProtocol. + +Event Loop +---------- + +- We discussed the sequence of actions in the event loop. I think in the + end we're fine with what Tulip currently does. There are two choices: + + Tulip: + . run ready callbacks until there aren't any left + . poll, adding more callbacks to the ready list + . add now-ready delayed callbacks to the ready list + . go to top + + Tornado: + . run all currently ready callbacks (but not new ones added during this) + . (the rest is the same) + + The difference is that in the Tulip version, CPU bound callbacks + that keep adding more to the queue will starve I/O (and yielding to + other tasks won't actually cause I/O to happen unless you do + e.g. sleep(0.001)). OTOH this may be good because it means there's + less overhead if you frequently split operations in two. + +- I think Twisted does it Tornado style (in a convoluted way :-), but + it may not matter, and it's important to leave this vague so + implementations can do what's best for their platform. (E.g. if the + event loop is built into the OS there are different trade-offs.) + +System call cost +---------------- + +- System calls on MacOS are expensive, on Linux they are cheap. + +- Optimal buffer size ~16K. + +- Try joining small buffer pieces together, but expect to be tuning + this later. + +Futures +------- + +- Futures are the most robust API for async stuff, you can check + errors etc. So let's do this. + +- Just don't implement wait(). + +- For the basics, however, (recv/send, mostly), don't use Futures but use + basic callbacks, transport/protocol style. + +- make_connection() (by any name) can return a Future, it makes it + easier to check for errors. + +- This means revisiting the Tulip proactor branch (IOCP). + +- The semantics of add_done_callback() are fuzzy about in which thread + the callback will be called. (It may be the current thread or + another one.) We don't like that. But always inserting a + call_soon() indirection may be expensive? Glyph suggested changing + the add_done_callback() method name to something else to indicate + the changed promise. + +- Separately, I've been thinking about having two versions of + call_soon() -- a more heavy-weight one to be called from other + threads that also writes a byte to the self-pipe. + +Signals +------- + +- There was a side conversation about signals. A signal handler is + similar to another thread, so probably should use (the heavy-weight + version of) call_soon() to schedule the real callback and not do + anything else. + +- Glyph vaguely recalled some trickiness with the self-pipe. We + should be able to fix this afterwards if necessary, it shouldn't + affect the API design. @@ -0,0 +1,21 @@ +Tulip is the codename for my reference implementation of PEP 3156. + +PEP 3156: http://www.python.org/dev/peps/pep-3156/ + +*** This requires Python 3.3 or later! *** + +Copyright/license: Open source, Apache 2.0. Enjoy. + +Master Mercurial repo: http://code.google.com/p/tulip/ + +The old code lives in the subdirectory 'old'; the new code (conforming +to PEP 3156, under construction) lives in the 'tulip' subdirectory. + +To run tests: + - make test + +To run coverage (coverage package is required): + - make coverage + + +--Guido van Rossum <guido@python.org> @@ -0,0 +1,163 @@ +# -*- Mode: text -*- + +TO DO LARGER TASKS + +- Need more examples. + +- Benchmarkable but more realistic HTTP server? + +- Example of using UDP. + +- Write up a tutorial for the scheduling API. + +- More systematic approach to logging. Logger objects? What about + heavy-duty logging, tracking essentially all task state changes? + +- Restructure directory, move demos and benchmarks to subdirectories. + + +TO DO LATER + +- When multiple tasks are accessing the same socket, they should + either get interleaved I/O or an immediate exception; it should not + compromise the integrity of the scheduler or the app or leave a task + hanging. + +- For epoll you probably want to check/(log?) EPOLLHUP and EPOLLERR errors. + +- Add the simplest API possible to run a generator with a timeout. + +- Ensure multiple tasks can do atomic writes to the same pipe (since + UNIX guarantees that short writes to pipes are atomic). + +- Ensure some easy way of distributing accepted connections across tasks. + +- Be wary of thread-local storage. There should be a standard API to + get the current Context (which holds current task, event loop, and + maybe more) and a standard meta-API to change how that standard API + works (i.e. without monkey-patching). + +- See how much of asyncore I've already replaced. + +- Could BufferedReader reuse the standard io module's readers??? + +- Support ZeroMQ "sockets" which are user objects. Though possibly + this can be supported by getting the underlying fd? See + http://mail.python.org/pipermail/python-ideas/2012-October/017532.html + OTOH see + https://github.com/zeromq/pyzmq/blob/master/zmq/eventloop/ioloop.py + +- Study goroutines (again). + +- Benchmarks: http://nichol.as/benchmark-of-python-web-servers + + +FROM OLDER LIST + +- Multiple readers/writers per socket? (At which level? pollster, + eventloop, or scheduler?) + +- Could poll() usefully be an iterator? + +- Do we need to support more epoll and/or kqueue modes/flags/options/etc.? + +- Optimize register/unregister calls away if they cancel each other out? + +- Add explicit wait queue to wait for Task's completion, instead of + callbacks? + +- Look at pyfdpdlib's ioloop.py: + http://code.google.com/p/pyftpdlib/source/browse/trunk/pyftpdlib/lib/ioloop.py + + +MISTAKES I MADE + +- Forgetting yield from. (E.g.: scheduler.sleep(1); listener.accept().) + +- Forgot to add bare yield at end of internal function, after block(). + +- Forgot to call add_done_callback(). + +- Forgot to pass an undoer to block(), bug only found when cancelled. + +- Subtle accounting mistake in a callback. + +- Used context.eventloop from a different thread, forgetting about TLS. + +- Nasty race: eventloop.ready may contain both an I/O callback and a + cancel callback. How to avoid? Keep the DelayedCall in ready. Is + that enough? + +- If a toplevel task raises an error it just stops and nothing is logged + unless you have debug logging on. This confused me. (Then again, + previously I logged whenever a task raised an error, and that was too + chatty...) + +- Forgot to set the connection socket returned by accept() in + nonblocking mode. + +- Nastiest so far (cost me about a day): A race condition in + call_in_thread() where the Future's done_callback (which was + task.unblock()) would run immediately at the time when + add_done_callback() was called, and this screwed over the task + state. Solution: wrap the callback in eventloop.call_later(). + Ironically, I had a comment stating there might be a race condition. + +- Another bug where I was calling unblock() for the current thread + immediately after calling block(), before yielding. + +- readexactly() wasn't checking for EOF, so could be looping. + (Worse, the first fix I attempted was wrong.) + +- Spent a day trying to understand why a tentative patch trying to + move the recv() implementation into the eventloop (or the pollster) + resulted in problems cancelling a recv() call. Ultimately the + problem is that the cancellation mechanism is part of the coroutine + scheduler, which simply throws an exception into a task when it next + runs, and there isn't anything to be interrupted in the eventloop; + but the eventloop still has a reader registered (which will never + fire because I suspended the server -- that's my test case :-). + Then, the eventloop keeps running until the last file descriptor is + unregistered. What contributed to this disaster? + * I didn't build the whole infrastructure, just played with recv() + * I don't have unittests + * I don't have good logging to see what is going + +- In sockets.py, in some SSL error handling code, used the wrong + variable (sock instead of sslsock). A linter would have found this. + +- In polling.py, in KqueuePollster.register_writer(), a copy/paste + error where I was testing for "if fd not in self.readers" instead of + writers. This only came out when I had both a reader and a writer + for the same fd. + +- Submitted some changes prematurely (forgot to pass the filename on + hg ci). + +- Forgot again that shutdown(SHUT_WR) on an ssl socket does not work + as I expected. I ran into this with the origininal sockets.py and + again in transport.py. + +- Having the same callback for both reading and writing has a problem: + it may be scheduled twice, and if the first call closes the socket, + the second runs into trouble. + + +MISTAKES I MADE IN TULIP V2 + +- Nice one: Future._schedule_callbacks() wasn't scheduling any callbacks. + Spot the bug in these four lines: + + def _schedule_callbacks(self): + callbacks = self._callbacks[:] + self._callbacks[:] = [] + for callback in self._callbacks: + self._event_loop.call_soon(callback, self) + + The good news is that I found it with a unittest (albeit not the + unittest intended to exercise this particular method :-( ). + +- In _make_self_pipe_or_sock(), called _pollster.register_reader() + instead of add_reader(), trying to optimize something but breaking + things instead (since the -- internal -- API of register_reader() + had changed). diff --git a/check.py b/check.py new file mode 100644 index 0000000..64bc2cd --- /dev/null +++ b/check.py @@ -0,0 +1,41 @@ +"""Search for lines > 80 chars or with trailing whitespace.""" + +import sys, os + +def main(): + args = sys.argv[1:] or os.curdir + for arg in args: + if os.path.isdir(arg): + for dn, dirs, files in os.walk(arg): + for fn in sorted(files): + if fn.endswith('.py'): + process(os.path.join(dn, fn)) + dirs[:] = [d for d in dirs if d[0] != '.'] + dirs.sort() + else: + process(arg) + +def isascii(x): + try: + x.encode('ascii') + return True + except UnicodeError: + return False + +def process(fn): + try: + f = open(fn) + except IOError as err: + print(err) + return + try: + for i, line in enumerate(f): + line = line.rstrip('\n') + sline = line.rstrip() + if len(line) > 80 or line != sline or not isascii(line): + print('%s:%d:%s%s' % ( + fn, i+1, sline, '_' * (len(line) - len(sline)))) + finally: + f.close() + +main() diff --git a/crawl.py b/crawl.py new file mode 100755 index 0000000..4e5bebe --- /dev/null +++ b/crawl.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 + +import logging +import re +import signal +import socket +import sys +import urllib.parse + +import tulip +import tulip.http + +END = '\n' +MAXTASKS = 100 + + +class Crawler: + + def __init__(self, rooturl): + self.rooturl = rooturl + self.todo = set() + self.busy = set() + self.done = {} + self.tasks = set() + self.waiter = None + self.addurl(self.rooturl, '') # Set initial work. + self.run() # Kick off work. + + def addurl(self, url, parenturl): + url = urllib.parse.urljoin(parenturl, url) + url, frag = urllib.parse.urldefrag(url) + if not url.startswith(self.rooturl): + return False + if url in self.busy or url in self.done or url in self.todo: + return False + self.todo.add(url) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(None) + return True + + @tulip.task + def run(self): + while self.todo or self.busy or self.tasks: + complete, self.tasks = yield from tulip.wait(self.tasks, timeout=0) + print(len(complete), 'completed tasks,', len(self.tasks), + 'still pending ', end=END) + for task in complete: + try: + yield from task + except Exception as exc: + print('Exception in task:', exc, end=END) + while self.todo and len(self.tasks) < MAXTASKS: + url = self.todo.pop() + self.busy.add(url) + self.tasks.add(self.process(url)) # Async task. + if self.busy: + self.waiter = tulip.Future() + yield from self.waiter + tulip.get_event_loop().stop() + + @tulip.task + def process(self, url): + ok = False + p = None + try: + print('processing', url, end=END) + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not path: + path = '/' + if query: + path = '?'.join([path, query]) + p = tulip.http.HttpClientProtocol( + netloc, path=path, ssl=(scheme=='https')) + delay = 1 + while True: + try: + status, headers, stream = yield from p.connect() + break + except socket.error as exc: + if delay >= 60: + raise + print('...', url, 'has error', repr(str(exc)), + 'retrying after sleep', delay, '...', end=END) + yield from tulip.sleep(delay) + delay *= 2 + + if status[:3] in ('301', '302'): + # Redirect. + u = headers.get('location') or headers.get('uri') + if self.addurl(u, url): + print(' ', url, status[:3], 'redirect to', u, end=END) + elif status.startswith('200'): + ctype = headers.get_content_type() + if ctype == 'text/html': + while True: + line = yield from stream.readline() + if not line: + break + line = line.decode('utf-8', 'replace') + urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', + line) + for u in urls: + if self.addurl(u, url): + print(' ', url, 'href to', u, end=END) + ok = True + finally: + if p is not None: + p.transport.close() + self.done[url] = ok + self.busy.remove(url) + if not ok: + print('failure for', url, sys.exc_info(), end=END) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(None) + + +def main(): + rooturl = sys.argv[1] + c = Crawler(rooturl) + loop = tulip.get_event_loop() + try: + loop.add_signal_handler(signal.SIGINT, loop.stop) + except RuntimeError: + pass + loop.run_forever() + print('todo:', len(c.todo)) + print('busy:', len(c.busy)) + print('done:', len(c.done), '; ok:', sum(c.done.values())) + print('tasks:', len(c.tasks)) + + +if __name__ == '__main__': + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + logging.info('using iocp') + el = windows_events.ProactorEventLoop() + events.set_event_loop(el) + main() @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 + +import sys +import urllib.parse + +import tulip +import tulip.http + + +def main(): + url = sys.argv[1] + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not path: + path = '/' + if query: + path = '?'.join([path, query]) + print(netloc, path, scheme) + p = tulip.http.HttpClientProtocol(netloc, path=path, ssl=(scheme=='https')) + f = p.connect() + sts, headers, stream = p.event_loop.run_until_complete(f) + print(sts) + for k, v in headers.items(): + print('{}: {}'.format(k, v)) + print() + data = p.event_loop.run_until_complete(stream.read()) + print(data.decode('utf-8', 'replace')) + + +if __name__ == '__main__': + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + el = windows_events.ProactorEventLoop() + events.set_event_loop(el) + main() diff --git a/examples/udp_echo.py b/examples/udp_echo.py new file mode 100644 index 0000000..9e995d1 --- /dev/null +++ b/examples/udp_echo.py @@ -0,0 +1,73 @@ +"""UDP echo example. + +Start server: + + >> python ./udp_echo.py --server + +""" + +import sys +import tulip + +ADDRESS = ('127.0.0.1', 10000) + + +class MyServerUdpEchoProtocol: + + def connection_made(self, transport): + print('start', transport) + self.transport = transport + + def datagram_received(self, data, addr): + print('Data received:', data, addr) + self.transport.sendto(data, addr) + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('stop', exc) + + +class MyClientUdpEchoProtocol: + + message = 'This is the message. It will be repeated.' + + def connection_made(self, transport): + self.transport = transport + print('sending "%s"' % self.message) + self.transport.sendto(self.message.encode()) + print('waiting to receive') + + def datagram_received(self, data, addr): + print('received "%s"' % data.decode()) + self.transport.close() + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('closing transport', exc) + loop = tulip.get_event_loop() + loop.stop() + + +def start_server(): + loop = tulip.get_event_loop() + tulip.Task(loop.create_datagram_endpoint( + MyServerUdpEchoProtocol, local_addr=ADDRESS)) + loop.run_forever() + + +def start_client(): + loop = tulip.get_event_loop() + tulip.Task(loop.create_datagram_endpoint( + MyClientUdpEchoProtocol, remote_addr=ADDRESS)) + loop.run_forever() + + +if __name__ == '__main__': + if '--server' in sys.argv: + start_server() + else: + start_client() diff --git a/old/Makefile b/old/Makefile new file mode 100644 index 0000000..d352cd7 --- /dev/null +++ b/old/Makefile @@ -0,0 +1,16 @@ +PYTHON=python3 + +main: + $(PYTHON) main.py -v + +echo: + $(PYTHON) echosvr.py -v + +profile: + $(PYTHON) -m profile -s time main.py + +time: + $(PYTHON) p3time.py + +ytime: + $(PYTHON) yyftime.py diff --git a/old/echoclt.py b/old/echoclt.py new file mode 100644 index 0000000..c24c573 --- /dev/null +++ b/old/echoclt.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3.3 +"""Example echo client.""" + +# Stdlib imports. +import logging +import socket +import sys +import time + +# Local imports. +import scheduling +import sockets + + +def echoclient(host, port): + """COROUTINE""" + testdata = b'hi hi hi ha ha ha\n' + try: + trans = yield from sockets.create_transport(host, port, + af=socket.AF_INET) + except OSError: + return False + try: + ok = yield from trans.send(testdata) + if ok: + response = yield from trans.recv(100) + ok = response == testdata.upper() + return ok + finally: + trans.close() + + +def doit(n): + """COROUTINE""" + t0 = time.time() + tasks = set() + for i in range(n): + t = scheduling.Task(echoclient('127.0.0.1', 1111), 'client-%d' % i) + tasks.add(t) + ok = 0 + bad = 0 + for t in tasks: + try: + yield from t + except Exception: + bad += 1 + else: + ok += 1 + t1 = time.time() + print('ok: ', ok) + print('bad:', bad) + print('dt: ', round(t1-t0, 6)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Get integer from command line. + n = 1 + for arg in sys.argv[1:]: + if not arg.startswith('-'): + n = int(arg) + break + + # Run scheduler, starting it off with doit(). + scheduling.run(doit(n)) + + +if __name__ == '__main__': + main() diff --git a/old/echosvr.py b/old/echosvr.py new file mode 100644 index 0000000..4085f4c --- /dev/null +++ b/old/echosvr.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3.3 +"""Example echo server.""" + +# Stdlib imports. +import logging +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + while True: + line = yield from rdr.readline() + logging.debug('Received: %r from %r', line, addr) + if not line: + break + yield from trans.send(line.upper()) + logging.debug('Closing %r', addr) + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 1111, + af=socket.AF_INET, + backlog=100) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() diff --git a/old/http_client.py b/old/http_client.py new file mode 100644 index 0000000..8937ba2 --- /dev/null +++ b/old/http_client.py @@ -0,0 +1,78 @@ +"""Crummy HTTP client. + +This is not meant as an example of how to write a good client. +""" + +# Stdlib. +import re +import time + +# Local. +import sockets + + +def urlfetch(host, port=None, path='/', method='GET', + body=None, hdrs=None, encoding='utf-8', ssl=None, af=0): + """COROUTINE: Make an HTTP 1.0 request.""" + t0 = time.time() + if port is None: + if ssl: + port = 443 + else: + port = 80 + trans = yield from sockets.create_transport(host, port, ssl=ssl, af=af) + yield from trans.send(method.encode(encoding) + b' ' + + path.encode(encoding) + b' HTTP/1.0\r\n') + if hdrs: + kwds = dict(hdrs) + else: + kwds = {} + if 'host' not in kwds: + kwds['host'] = host + if body is not None: + kwds['content_length'] = len(body) + for header, value in kwds.items(): + yield from trans.send(header.replace('_', '-').encode(encoding) + + b': ' + value.encode(encoding) + b'\r\n') + + yield from trans.send(b'\r\n') + if body is not None: + yield from trans.send(body) + + # Read HTTP response line. + rdr = sockets.BufferedReader(trans) + resp = yield from rdr.readline() + m = re.match(br'(?ix) http/(\d\.\d) \s+ (\d\d\d) \s+ ([^\r]*)\r?\n\Z', + resp) + if not m: + trans.close() + raise IOError('No valid HTTP response: %r' % resp) + http_version, status, message = m.groups() + + # Read HTTP headers. + headers = [] + hdict = {} + while True: + line = yield from rdr.readline() + if not line.strip(): + break + m = re.match(br'([^\s:]+):\s*([^\r]*)\r?\n\Z', line) + if not m: + raise IOError('Invalid header: %r' % line) + header, value = m.groups() + headers.append((header, value)) + hdict[header.decode(encoding).lower()] = value.decode(encoding) + + # Read response body. + content_length = hdict.get('content-length') + if content_length is not None: + size = int(content_length) # TODO: Catch errors. + assert size >= 0, size + else: + size = 2**20 # Protective limit (1 MB). + data = yield from rdr.readexactly(size) + trans.close() # Can this block? + t1 = time.time() + result = (host, port, path, int(status), len(data), round(t1-t0, 3)) +## print(result) + return result diff --git a/old/http_server.py b/old/http_server.py new file mode 100644 index 0000000..2b1e3dd --- /dev/null +++ b/old/http_server.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3.3 +"""Simple HTTP server. + +This currenty exists just so we can benchmark this thing! +""" + +# Stdlib imports. +import logging +import re +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + ##logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + + # Read but ignore request line. + request_line = yield from rdr.readline() + + # Consume headers but don't interpret them. + while True: + header_line = yield from rdr.readline() + if not header_line.strip(): + break + + # Always send an empty 200 response and close. + yield from trans.send(b'HTTP/1.0 200 Ok\r\n\r\n') + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 8080, + af=socket.AF_INET) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() diff --git a/old/main.py b/old/main.py new file mode 100644 index 0000000..c1f9d0a --- /dev/null +++ b/old/main.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3.3 +"""Example HTTP client using yield-from coroutines (PEP 380). + +Requires Python 3.3. + +There are many micro-optimizations possible here, but that's not the point. + +Some incomplete laundry lists: + +TODO: +- Take test urls from command line. +- Move urlfetch to a separate module. +- Profiling. +- Docstrings. +- Unittests. + +FUNCTIONALITY: +- Connection pool (keep connection open). +- Chunked encoding (request and response). +- Pipelining, e.g. zlib (request and response). +- Automatic encoding/decoding. +""" + +__author__ = 'Guido van Rossum <guido@python.org>' + +# Standard library imports (keep in alphabetic order). +import logging +import os +import time +import socket +import sys + +# Local imports (keep in alphabetic order). +import scheduling +import http_client + + + +def doit2(): + argses = [ + ('localhost', 8080, '/'), + ('127.0.0.1', 8080, '/home'), + ('python.org', 80, '/'), + ('xkcd.com', 443, '/'), + ] + results = yield from scheduling.map_over( + lambda args: http_client.urlfetch(*args), argses, timeout=2) + for res in results: + print('-->', res) + return [] + + +def doit(): + TIMEOUT = 2 + tasks = set() + + # This references NDB's default test service. + # (Sadly the service is single-threaded.) + task1 = scheduling.Task(http_client.urlfetch('localhost', 8080, path='/'), + 'root', timeout=TIMEOUT) + tasks.add(task1) + task2 = scheduling.Task(http_client.urlfetch('127.0.0.1', 8080, + path='/home'), + 'home', timeout=TIMEOUT) + tasks.add(task2) + + # Fetch python.org home page. + task3 = scheduling.Task(http_client.urlfetch('python.org', 80, path='/'), + 'python', timeout=TIMEOUT) + tasks.add(task3) + + # Fetch XKCD home page using SSL. (Doesn't like IPv6.) + task4 = scheduling.Task(http_client.urlfetch('xkcd.com', ssl=True, path='/', + af=socket.AF_INET), + 'xkcd', timeout=TIMEOUT) + tasks.add(task4) + +## # Fetch many links from python.org (/x.y.z). +## for x in '123': +## for y in '0123456789': +## path = '/{}.{}'.format(x, y) +## g = http_client.urlfetch('82.94.164.162', 80, +## path=path, hdrs={'host': 'python.org'}) +## t = scheduling.Task(g, path, timeout=2) +## tasks.add(t) + +## print(tasks) + yield from scheduling.Task(scheduling.sleep(1), timeout=0.2).wait() + winners = yield from scheduling.wait_any(tasks) + print('And the winners are:', [w.name for w in winners]) + tasks = yield from scheduling.wait_all(tasks) + print('And the players were:', [t.name for t in tasks]) + return tasks + + +def logtimes(real): + utime, stime, cutime, cstime, unused = os.times() + logging.info('real %10.3f', real) + logging.info('user %10.3f', utime + cutime) + logging.info('sys %10.3f', stime + cstime) + + +def main(): + t0 = time.time() + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + task = scheduling.run(doit()) + if task.exception: + print('Exception:', repr(task.exception)) + if isinstance(task.exception, AssertionError): + raise task.exception + else: + for t in task.result: + print(t.name + ':', + repr(t.exception) if t.exception else t.result) + + # Report real, user, sys times. + t1 = time.time() + logtimes(t1-t0) + + +if __name__ == '__main__': + main() diff --git a/old/p3time.py b/old/p3time.py new file mode 100644 index 0000000..35e14c9 --- /dev/null +++ b/old/p3time.py @@ -0,0 +1,47 @@ +"""Compare timing of plain vs. yield-from calls.""" + +import gc +import time + +def plain(n): + if n <= 0: + return 1 + l = plain(n-1) + r = plain(n-1) + return l + 1 + r + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def submain(depth): + t0 = time.time() + k = plain(depth) + t1 = time.time() + fmt = ' {} {} {:-9,.5f}' + delta0 = t1-t0 + print(('plain' + fmt).format(depth, k, delta0)) + + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + delta1 = t1-t0 + print(('coro.' + fmt).format(depth, k, delta1)) + if delta0: + print(('relat' + fmt).format(depth, k, delta1/delta0)) + +def main(reasonable=16): + gc.disable() + for depth in range(reasonable): + submain(depth) + +if __name__ == '__main__': + main() diff --git a/old/polling.py b/old/polling.py new file mode 100644 index 0000000..6586efc --- /dev/null +++ b/old/polling.py @@ -0,0 +1,535 @@ +"""Event loop and related classes. + +The event loop can be broken up into a pollster (the part responsible +for telling us when file descriptors are ready) and the event loop +proper, which wraps a pollster with functionality for scheduling +callbacks, immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. + +There are several implementations of the pollster part, several using +esoteric system calls that exist only on some platforms. These are: + +- kqueue (most BSD systems) +- epoll (newer Linux systems) +- poll (most UNIX systems) +- select (all UNIX systems, and Windows) +- TODO: Support IOCP on Windows and some UNIX platforms. + +NOTE: We don't use select on systems where any of the others is +available, because select performs poorly as the number of file +descriptors goes up. The ranking is roughly: + + 1. kqueue, epoll, IOCP + 2. poll + 3. select + +TODO: +- Optimize the various pollsters. +- Unittests. +""" + +import collections +import concurrent.futures +import heapq +import logging +import os +import select +import threading +import time + + +class PollsterBase: + """Base class for all polling implementations. + + This defines an interface to register and unregister readers and + writers for specific file descriptors, and an interface to get a + list of events. There's also an interface to check whether any + readers or writers are currently registered. + """ + + def __init__(self): + super().__init__() + self.readers = {} # {fd: token, ...}. + self.writers = {} # {fd: token, ...}. + + def pollable(self): + """Return True if any readers or writers are currently registered.""" + return bool(self.readers or self.writers) + + # Subclasses are expected to extend the add/remove methods. + + def register_reader(self, fd, token): + """Add or update a reader for a file descriptor.""" + self.readers[fd] = token + + def register_writer(self, fd, token): + """Add or update a writer for a file descriptor.""" + self.writers[fd] = token + + def unregister_reader(self, fd): + """Remove the reader for a file descriptor.""" + del self.readers[fd] + + def unregister_writer(self, fd): + """Remove the writer for a file descriptor.""" + del self.writers[fd] + + def poll(self, timeout=None): + """Poll for events. A subclass must implement this. + + If timeout is omitted or None, this blocks until at least one + event is ready. Otherwise, timeout gives a maximum time to + wait (in seconds as an int or float) -- the method returns as + soon as at least one event is ready or when the timeout is + expired. For a non-blocking poll, pass 0. + + The return value is a list of events; it is empty when the + timeout expired before any events were ready. Each event + is a token previously passed to register_reader/writer(). + """ + raise NotImplementedError + + +class SelectPollster(PollsterBase): + """Pollster implementation using select.""" + + def poll(self, timeout=None): + readable, writable, _ = select.select(self.readers, self.writers, + [], timeout) + events = [] + events += (self.readers[fd] for fd in readable) + events += (self.writers[fd] for fd in writable) + return events + + +class PollPollster(PollsterBase): + """Pollster implementation using poll.""" + + def __init__(self): + super().__init__() + self._poll = select.poll() + + def _update(self, fd): + assert isinstance(fd, int), fd + flags = 0 + if fd in self.readers: + flags |= select.POLLIN + if fd in self.writers: + flags |= select.POLLOUT + if flags: + self._poll.register(fd, flags) + else: + self._poll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + # Timeout is in seconds, but poll() takes milliseconds. + msecs = None if timeout is None else int(round(1000 * timeout)) + events = [] + for fd, flags in self._poll.poll(msecs): + if flags & (select.POLLIN | select.POLLHUP): + if fd in self.readers: + events.append(self.readers[fd]) + if flags & (select.POLLOUT | select.POLLHUP): + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class EPollPollster(PollsterBase): + """Pollster implementation using epoll.""" + + def __init__(self): + super().__init__() + self._epoll = select.epoll() + + def _update(self, fd): + assert isinstance(fd, int), fd + eventmask = 0 + if fd in self.readers: + eventmask |= select.EPOLLIN + if fd in self.writers: + eventmask |= select.EPOLLOUT + if eventmask: + try: + self._epoll.register(fd, eventmask) + except IOError: + self._epoll.modify(fd, eventmask) + else: + self._epoll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + if timeout is None: + timeout = -1 # epoll.poll() uses -1 to mean "wait forever". + events = [] + for fd, eventmask in self._epoll.poll(timeout): + if eventmask & select.EPOLLIN: + if fd in self.readers: + events.append(self.readers[fd]) + if eventmask & select.EPOLLOUT: + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class KqueuePollster(PollsterBase): + """Pollster implementation using kqueue.""" + + def __init__(self): + super().__init__() + self._kqueue = select.kqueue() + + def register_reader(self, fd, callback, *args): + if fd not in self.readers: + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_reader(fd, callback, *args) + + def register_writer(self, fd, callback, *args): + if fd not in self.writers: + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_writer(fd, callback, *args) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def poll(self, timeout=None): + events = [] + max_ev = len(self.readers) + len(self.writers) + for kev in self._kqueue.control(None, max_ev, timeout): + fd = kev.ident + flag = kev.filter + if flag == select.KQ_FILTER_READ and fd in self.readers: + events.append(self.readers[fd]) + elif flag == select.KQ_FILTER_WRITE and fd in self.writers: + events.append(self.writers[fd]) + return events + + +# Pick the best pollster class for the platform. +if hasattr(select, 'kqueue'): + best_pollster = KqueuePollster +elif hasattr(select, 'epoll'): + best_pollster = EPollPollster +elif hasattr(select, 'poll'): + best_pollster = PollPollster +else: + best_pollster = SelectPollster + + +class DelayedCall: + """Object returned by callback registration methods.""" + + def __init__(self, when, callback, args, kwds=None): + self.when = when + self.callback = callback + self.args = args + self.kwds = kwds + self.cancelled = False + + def cancel(self): + self.cancelled = True + + def __lt__(self, other): + return self.when < other.when + + def __le__(self, other): + return self.when <= other.when + + def __eq__(self, other): + return self.when == other.when + + +class EventLoop: + """Event loop functionality. + + This defines public APIs call_soon(), call_later(), run_once() and + run(). It also wraps Pollster APIs register_reader(), + register_writer(), remove_reader(), remove_writer() with + add_reader() etc. + + This class's instance variables are not part of its API. + """ + + def __init__(self, pollster=None): + super().__init__() + if pollster is None: + logging.info('Using pollster: %s', best_pollster.__name__) + pollster = best_pollster() + self.pollster = pollster + self.ready = collections.deque() # [(callback, args), ...] + self.scheduled = [] # [(when, callback, args), ...] + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_reader(fd, dcall) + return dcall + + def remove_reader(self, fd): + """Remove a reader callback.""" + self.pollster.unregister_reader(fd) + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_writer(fd, dcall) + return dcall + + def remove_writer(self, fd): + """Remove a writer callback.""" + self.pollster.unregister_writer(fd) + + def add_callback(self, dcall): + """Add a DelayedCall to ready or scheduled.""" + if dcall.cancelled: + return + if dcall.when is None: + self.ready.append(dcall) + else: + heapq.heappush(self.scheduled, dcall) + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + dcall = DelayedCall(None, callback, args) + self.ready.append(dcall) + return dcall + + def call_later(self, when, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The time can be an int or float, expressed in seconds. + + If when is small enough (~11 days), it's assumed to be a + relative time, meaning the call will be scheduled that many + seconds in the future; otherwise it's assumed to be a posix + timestamp as returned by time.time(). + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + if when < 10000000: + when += time.time() + dcall = DelayedCall(when, callback, args) + heapq.heappush(self.scheduled, dcall) + return dcall + + def run_once(self): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Pass in a timeout or deadline or something. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: As step 4, run everything scheduled by steps 1-3. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # TODO: Ensure this loop always finishes, even if some + # callbacks keeps registering more callbacks. + while self.ready: + dcall = self.ready.popleft() + if not dcall.cancelled: + try: + if dcall.kwds: + dcall.callback(*dcall.args, **dcall.kwds) + else: + dcall.callback(*dcall.args) + except Exception: + logging.exception('Exception in callback %s %r', + dcall.callback, dcall.args) + + # Remove delayed calls that were cancelled from head of queue. + while self.scheduled and self.scheduled[0].cancelled: + heapq.heappop(self.scheduled) + + # Inspect the poll queue. + if self.pollster.pollable(): + if self.scheduled: + when = self.scheduled[0].when + timeout = max(0, when - time.time()) + else: + timeout = None + t0 = time.time() + events = self.pollster.poll(timeout) + t1 = time.time() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + for dcall in events: + self.add_callback(dcall) + + # Handle 'later' callbacks that are ready. + now = time.time() + while self.scheduled: + dcall = self.scheduled[0] + if dcall.when > now: + break + dcall = heapq.heappop(self.scheduled) + self.call_soon(dcall.callback, *dcall.args) + + def run(self): + """Run the event loop until there is no work left to do. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + """ + while self.ready or self.scheduled or self.pollster.pollable(): + self.run_once() + + +MAX_WORKERS = 5 # Default max workers when creating an executor. + + +class ThreadRunner: + """Helper to submit work to a thread pool and wait for it. + + This is the glue between the single-threaded callback-based async + world and the threaded world. Use it to call functions that must + block and don't have an async alternative (e.g. getaddrinfo()). + + The only public API is submit(). + """ + + def __init__(self, eventloop, executor=None): + self.eventloop = eventloop + self.executor = executor # Will be constructed lazily. + self.pipe_read_fd, self.pipe_write_fd = os.pipe() + self.active_count = 0 + + def read_callback(self): + """Semi-permanent callback while at least one future is active.""" + assert self.active_count > 0, self.active_count + data = os.read(self.pipe_read_fd, 8192) # Traditional buffer size. + self.active_count -= len(data) + if self.active_count == 0: + self.eventloop.remove_reader(self.pipe_read_fd) + assert self.active_count >= 0, self.active_count + + def submit(self, func, *args, executor=None, callback=None): + """Submit a function to the thread pool. + + This returns a concurrent.futures.Future instance. The caller + should not wait for that, but rather use the callback argument.. + """ + if executor is None: + executor = self.executor + if executor is None: + # Lazily construct a default executor. + # TODO: Should this be shared between threads? + executor = concurrent.futures.ThreadPoolExecutor(MAX_WORKERS) + self.executor = executor + assert self.active_count >= 0, self.active_count + future = executor.submit(func, *args) + if self.active_count == 0: + self.eventloop.add_reader(self.pipe_read_fd, self.read_callback) + self.active_count += 1 + def done_callback(fut): + if callback is not None: + self.eventloop.call_soon(callback, fut) + # TODO: Wake up the pipe in call_soon()? + os.write(self.pipe_write_fd, b'x') + future.add_done_callback(done_callback) + return future + + +class Context(threading.local): + """Thread-local context. + + We use this to avoid having to explicitly pass around an event loop + or something to hold the current task. + + TODO: Add an API so frameworks can substitute a different notion + of context more easily. + """ + + def __init__(self, eventloop=None, threadrunner=None): + # Default event loop and thread runner are lazily constructed + # when first accessed. + self._eventloop = eventloop + self._threadrunner = threadrunner + self.current_task = None # For the benefit of scheduling.py. + + @property + def eventloop(self): + if self._eventloop is None: + self._eventloop = EventLoop() + return self._eventloop + + @property + def threadrunner(self): + if self._threadrunner is None: + self._threadrunner = ThreadRunner(self.eventloop) + return self._threadrunner + + +context = Context() # Thread-local! diff --git a/old/scheduling.py b/old/scheduling.py new file mode 100644 index 0000000..3864571 --- /dev/null +++ b/old/scheduling.py @@ -0,0 +1,354 @@ +#!/usr/bin/env python3.3 +"""Example coroutine scheduler, PEP-380-style ('yield from <generator>'). + +Requires Python 3.3. + +There are likely micro-optimizations possible here, but that's not the point. + +TODO: +- Docstrings. +- Unittests. + +PATTERNS TO TRY: +- Various synchronization primitives (Lock, RLock, Event, Condition, + Semaphore, BoundedSemaphore, Barrier). +""" + +__author__ = 'Guido van Rossum <guido@python.org>' + +# Standard library imports (keep in alphabetic order). +from concurrent.futures import CancelledError, TimeoutError +import logging +import time +import types + +# Local imports (keep in alphabetic order). +import polling + + +context = polling.context + + +class Task: + """Wrapper around a stack of generators. + + This is a bit like a Future, but with a different interface. + + TODO: + - wait for result. + """ + + def __init__(self, gen, name=None, *, timeout=None): + assert isinstance(gen, types.GeneratorType), repr(gen) + self.gen = gen + self.name = name or gen.__name__ + self.timeout = timeout + self.eventloop = context.eventloop + self.canceleer = None + if timeout is not None: + self.canceleer = self.eventloop.call_later(timeout, self.cancel) + self.blocked = False + self.unblocker = None + self.cancelled = False + self.must_cancel = False + self.alive = True + self.result = None + self.exception = None + self.done_callbacks = [] + # Start the task immediately. + self.eventloop.call_soon(self.step) + + def add_done_callback(self, done_callback): + # For better or for worse, the callback will always be called + # with the task as an argument, like concurrent.futures.Future. + # TODO: Call it right away if task is no longer alive. + dcall = polling.DelayedCall(None, done_callback, (self,)) + self.done_callbacks.append(dcall) + self.done_callbacks = [dc for dc in self.done_callbacks + if not dc.cancelled] + return dcall + + def __repr__(self): + parts = [self.name] + is_current = (self is context.current_task) + if self.blocked: + parts.append('blocking' if is_current else 'blocked') + elif self.alive: + parts.append('running' if is_current else 'runnable') + if self.must_cancel: + parts.append('must_cancel') + if self.cancelled: + parts.append('cancelled') + if self.exception is not None: + parts.append('exception=%r' % self.exception) + elif not self.alive: + parts.append('result=%r' % (self.result,)) + if self.timeout is not None: + parts.append('timeout=%.3f' % self.timeout) + return 'Task<' + ', '.join(parts) + '>' + + def cancel(self): + if self.alive: + if not self.must_cancel and not self.cancelled: + self.must_cancel = True + if self.blocked: + self.unblock() + + def step(self): + assert self.alive, self + try: + context.current_task = self + if self.must_cancel: + self.must_cancel = False + self.cancelled = True + self.gen.throw(CancelledError()) + else: + next(self.gen) + except StopIteration as exc: + self.alive = False + self.result = exc.value + except Exception as exc: + self.alive = False + self.exception = exc + logging.debug('Uncaught exception in %s', self, + exc_info=True, stack_info=True) + except BaseException as exc: + self.alive = False + self.exception = exc + raise + else: + if not self.blocked: + self.eventloop.call_soon(self.step) + finally: + context.current_task = None + if not self.alive: + # Cancel timeout callback if set. + if self.canceleer is not None: + self.canceleer.cancel() + # Schedule done_callbacks. + for dcall in self.done_callbacks: + self.eventloop.add_callback(dcall) + + def block(self, unblock_callback=None, *unblock_args): + assert self is context.current_task, self + assert self.alive, self + assert not self.blocked, self + self.blocked = True + self.unblocker = (unblock_callback, unblock_args) + + def unblock_if_alive(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. + if self.alive: + self.unblock() + + def unblock(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. + assert self.alive, self + assert self.blocked, self + self.blocked = False + unblock_callback, unblock_args = self.unblocker + if unblock_callback is not None: + try: + unblock_callback(*unblock_args) + except Exception: + logging.error('Exception in unblocker in task %r', self.name) + raise + finally: + self.unblocker = None + self.eventloop.call_soon(self.step) + + def block_io(self, fd, flag): + assert isinstance(fd, int), repr(fd) + assert flag in ('r', 'w'), repr(flag) + if flag == 'r': + self.block(self.eventloop.remove_reader, fd) + self.eventloop.add_reader(fd, self.unblock) + else: + self.block(self.eventloop.remove_writer, fd) + self.eventloop.add_writer(fd, self.unblock) + + def wait(self): + """COROUTINE: Wait until this task is finished.""" + current_task = context.current_task + assert self is not current_task, (self, current_task) # How confusing! + if not self.alive: + return + current_task.block() + self.add_done_callback(current_task.unblock) + yield + + def __iter__(self): + """COROUTINE: Wait, then return result or raise exception. + + This adds a little magic so you can say + + x = yield from Task(gen()) + + and it is equivalent to + + x = yield from gen() + + but with the option to add a timeout (and only a tad slower). + """ + if self.alive: + yield from self.wait() + assert not self.alive + if self.exception is not None: + raise self.exception + return self.result + + +def run(arg=None): + """Run the event loop until it's out of work. + + If you pass a generator, it will be spawned for you. + You can also pass a task (already started). + Returns the task. + """ + t = None + if arg is not None: + if isinstance(arg, Task): + t = arg + else: + t = Task(arg) + context.eventloop.run() + if t is not None and t.exception is not None: + logging.error('Uncaught exception in startup task: %r', + t.exception) + return t + + +def sleep(secs): + """COROUTINE: Sleep for some time (a float in seconds).""" + current_task = context.current_task + unblocker = context.eventloop.call_later(secs, current_task.unblock) + current_task.block(unblocker.cancel) + yield + + +def block_r(fd): + """COROUTINE: Block until a file descriptor is ready for reading.""" + context.current_task.block_io(fd, 'r') + yield + + +def block_w(fd): + """COROUTINE: Block until a file descriptor is ready for writing.""" + context.current_task.block_io(fd, 'w') + yield + + +def call_in_thread(func, *args, executor=None): + """COROUTINE: Run a function in a thread.""" + task = context.current_task + eventloop = context.eventloop + future = context.threadrunner.submit(func, *args, + executor=executor, + callback=task.unblock_if_alive) + task.block(future.cancel) + yield + assert future.done() + return future.result() + + +def wait_for(count, tasks): + """COROUTINE: Wait for the first N of a set of tasks to complete. + + May return more than N if more than N are immediately ready. + + NOTE: Tasks that were cancelled or raised are also considered ready. + """ + assert tasks + assert all(isinstance(task, Task) for task in tasks) + tasks = set(tasks) + assert 1 <= count <= len(tasks) + current_task = context.current_task + assert all(task is not current_task for task in tasks) + todo = set() + done = set() + dcalls = [] + def wait_for_callback(task): + nonlocal todo, done, current_task, count, dcalls + todo.remove(task) + if len(done) < count: + done.add(task) + if len(done) == count: + for dcall in dcalls: + dcall.cancel() + current_task.unblock() + for task in tasks: + if task.alive: + todo.add(task) + else: + done.add(task) + if len(done) < count: + for task in todo: + dcall = task.add_done_callback(wait_for_callback) + dcalls.append(dcall) + current_task.block() + yield + return done + + +def wait_any(tasks): + """COROUTINE: Wait for the first of a set of tasks to complete.""" + return wait_for(1, tasks) + + +def wait_all(tasks): + """COROUTINE: Wait for all of a set of tasks to complete.""" + return wait_for(len(tasks), tasks) + + +def map_over(gen, *args, timeout=None): + """COROUTINE: map a generator over one or more iterables. + + E.g. map_over(foo, xs, ys) runs + + Task(foo(x, y)) for x, y in zip(xs, ys) + + and returns a list of all results (in that order). However if any + task raises an exception, the remaining tasks are cancelled and + the exception is propagated. + """ + # gen is a generator function. + tasks = [Task(gobj, timeout=timeout) for gobj in map(gen, *args)] + return (yield from par_tasks(tasks)) + + +def par(*args): + """COROUTINE: Wait for generators, return a list of results. + + Raises as soon as one of the tasks raises an exception (and then + remaining tasks are cancelled). + + This differs from par_tasks() in two ways: + - takes *args instead of list of args + - each arg may be a generator or a task + """ + tasks = [] + for arg in args: + if not isinstance(arg, Task): + # TODO: assert arg is a generator or an iterator? + arg = Task(arg) + tasks.append(arg) + return (yield from par_tasks(tasks)) + + +def par_tasks(tasks): + """COROUTINE: Wait for a list of tasks, return a list of results. + + Raises as soon as one of the tasks raises an exception (and then + remaining tasks are cancelled). + """ + todo = set(tasks) + while todo: + ts = yield from wait_any(todo) + for t in ts: + assert not t.alive, t + todo.remove(t) + if t.exception is not None: + for other in todo: + other.cancel() + raise t.exception + return [t.result for t in tasks] diff --git a/old/sockets.py b/old/sockets.py new file mode 100644 index 0000000..a5005dc --- /dev/null +++ b/old/sockets.py @@ -0,0 +1,348 @@ +"""Socket wrappers to go with scheduling.py. + +Classes: + +- SocketTransport: a transport implementation wrapping a socket. +- SslTransport: a transport implementation wrapping SSL around a socket. +- BufferedReader: a buffer wrapping the read end of a transport. + +Functions (all coroutines): + +- connect(): connect a socket. +- getaddrinfo(): look up an address. +- create_connection(): look up address and return a connected socket for it. +- create_transport(): look up address and return a connected transport. + +TODO: +- Improve transport abstraction. +- Make a nice protocol abstraction. +- Unittests. +- A write() call that isn't a generator (needed so you can substitute it + for sys.stderr, pass it to logging.StreamHandler, etc.). +""" + +__author__ = 'Guido van Rossum <guido@python.org>' + +# Stdlib imports. +import errno +import socket +import ssl + +# Local imports. +import scheduling + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + + +class SocketTransport: + """Transport wrapping a socket. + + The socket must already be connected in non-blocking mode. + """ + + def __init__(self, sock): + self.sock = sock + + def recv(self, n): + """COROUTINE: Read up to n bytes, blocking as needed. + + Always returns at least one byte, except if the socket was + closed or disconnected and there's no more data; then it + returns b''. + """ + assert n >= 0, n + while True: + try: + return self.sock.recv(n) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return b'' + else: + raise # Unexpected, propagate. + yield from scheduling.block_r(self.sock.fileno()) + + def send(self, data): + """COROUTINE; Send data to the socket, blocking until all written. + + Return True if all went well, False if socket was disconnected. + """ + while data: + try: + n = self.sock.send(data) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + else: + assert 0 <= n <= len(data), (n, len(data)) + if n == len(data): + break + data = data[n:] + continue + yield from scheduling.block_w(self.sock.fileno()) + + return True + + def close(self): + """Close the socket. (Not a coroutine.)""" + self.sock.close() + + +class SslTransport: + """Transport wrapping a socket in SSL. + + The socket must already be connected at the TCP level in + non-blocking mode. + """ + + def __init__(self, rawsock, sslcontext=None): + self.rawsock = rawsock + self.sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self.sslsock = self.sslcontext.wrap_socket( + self.rawsock, do_handshake_on_connect=False) + + def do_handshake(self): + """COROUTINE: Finish the SSL handshake.""" + while True: + try: + self.sslsock.do_handshake() + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + else: + break + + def recv(self, n): + """COROUTINE: Read up to n bytes. + + This blocks until at least one byte is read, or until EOF. + """ + while True: + try: + return self.sslsock.recv(n) + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_r(self.sslsock.fileno()) + elif err.errno in _DISCONNECTED: + # Can this happen? + return b'' + else: + raise # Unexpected, propagate. + + def send(self, data): + """COROUTINE: Send data to the socket, blocking as needed.""" + while data: + try: + n = self.sslsock.send(data) + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_w(self.sslsock.fileno()) + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + if n == len(data): + break + data = data[n:] + + return True + + def close(self): + """Close the socket. (Not a coroutine.) + + This also closes the raw socket. + """ + self.sslsock.close() + + # TODO: More SSL-specific methods, e.g. certificate stuff, unwrap(), ... + + +class BufferedReader: + """A buffered reader wrapping a transport.""" + + def __init__(self, trans, limit=8192): + self.trans = trans + self.limit = limit + self.buffer = b'' + self.eof = False + + def read(self, n): + """COROUTINE: Read up to n bytes, blocking at most once.""" + assert n >= 0, n + if not self.buffer and not self.eof: + yield from self._fillbuffer(max(n, self.limit)) + return self._getfrombuffer(n) + + def readexactly(self, n): + """COUROUTINE: Read exactly n bytes, or until EOF.""" + blocks = [] + count = 0 + while count < n: + block = yield from self.read(n - count) + if not block: + break + blocks.append(block) + count += len(block) + return b''.join(blocks) + + def readline(self): + """COROUTINE: Read up to newline or limit, whichever comes first.""" + end = self.buffer.find(b'\n') + 1 # Point past newline, or 0. + while not end and not self.eof and len(self.buffer) < self.limit: + anchor = len(self.buffer) + yield from self._fillbuffer(self.limit) + end = self.buffer.find(b'\n', anchor) + 1 + if not end: + end = len(self.buffer) + if end > self.limit: + end = self.limit + return self._getfrombuffer(end) + + def _getfrombuffer(self, n): + """Read up to n bytes without blocking (not a coroutine).""" + if n >= len(self.buffer): + result, self.buffer = self.buffer, b'' + else: + result, self.buffer = self.buffer[:n], self.buffer[n:] + return result + + def _fillbuffer(self, n): + """COROUTINE: Fill buffer with one (up to) n bytes from transport.""" + assert not self.eof, '_fillbuffer called at eof' + data = yield from self.trans.recv(n) + if data: + self.buffer += data + else: + self.eof = True + + +def connect(sock, address): + """COROUTINE: Connect a socket to an address.""" + try: + sock.connect(address) + except socket.error as err: + if err.errno != errno.EINPROGRESS: + raise + yield from scheduling.block_w(sock.fileno()) + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + raise IOError(err, 'Connection refused') + + +def getaddrinfo(host, port, af=0, socktype=0, proto=0): + """COROUTINE: Look up an address and return a list of infos for it. + + Each info is a tuple (af, socktype, protocol, canonname, address). + """ + infos = yield from scheduling.call_in_thread(socket.getaddrinfo, + host, port, af, + socktype, proto) + return infos + + +def create_connection(host, port, af=0, socktype=socket.SOCK_STREAM, proto=0): + """COROUTINE: Look up address and create a socket connected to it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + yield from connect(sock, address) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return sock + + +def create_transport(host, port, af=0, ssl=None): + """COROUTINE: Look up address and create a transport connected to it.""" + if ssl is None: + ssl = (port == 443) + sock = yield from create_connection(host, port, af) + if ssl: + trans = SslTransport(sock) + yield from trans.do_handshake() + else: + trans = SocketTransport(sock) + return trans + + +class Listener: + """Wrapper for a listening socket.""" + + def __init__(self, sock): + self.sock = sock + + def accept(self): + """COROUTINE: Accept a connection.""" + while True: + try: + conn, addr = self.sock.accept() + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_r(self.sock.fileno()) + else: + raise # Unexpected, propagate. + else: + conn.setblocking(False) + return conn, addr + + +def create_listener(host, port, af=0, socktype=0, proto=0, + backlog=5, reuse_addr=True): + """COROUTINE: Look up address and create a listener for it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + if reuse_addr: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + sock.listen(backlog) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return Listener(sock) diff --git a/old/transports.py b/old/transports.py new file mode 100644 index 0000000..19095bf --- /dev/null +++ b/old/transports.py @@ -0,0 +1,496 @@ +"""Transports and Protocols, actually. + +Inspired by Twisted, PEP 3153 and github.com/lvh/async-pep. + +THIS IS NOT REAL CODE! IT IS JUST AN EXPERIMENT. +""" + +# Stdlib imports. +import collections +import errno +import logging +import socket +import ssl +import sys +import time + +# Local imports. +import polling +import scheduling +import sockets + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + + +class Transport: + """ABC representing a transport. + + There may be many implementations. The user never instantiates + this directly; they call some utility function, passing it a + protocol, and the utility function will call the protocol's + connection_made() method with a transport (or it will call + connection_lost() with an exception if it fails to create the + desired transport). + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + def write(self, data): + """Write some data (bytes) to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data (bytes) to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data will + be received. When all buffered data is flushed, the protocol's + connection_lost() method is called with None as its argument. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method is called with None as + its argument. + """ + raise NotImplementedError + + def half_close(self): + """Closes the write end after flushing buffered data. + + Data may still be received. + + TODO: What's the use case for this? How to implement it? + Should it call shutdown(SHUT_WR) after all the data is flushed? + Is there no use case for closing the other half first? + """ + raise NotImplementedError + + def pause(self): + """Pause the receiving end. + + No data will be received until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Cancels a pause() call, resumes receiving data. + """ + raise NotImplementedError + + +class Protocol: + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing. + + When the user wants to requests a transport, they pass a protocol + instance to a utility function. + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_list() will be called exactly once + with either an exception object or None as an argument. + + If the utility function does not succeed in creating a transport, + it will call connection_lost() with an exception object. + + State machine of calls: + + start -> [CM -> DR*] -> CL -> end + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the connection. + To send data, call its write() or writelines() method. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + + TODO: Should we allow it to be a bytesarray or some other + memory buffer? + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + Also called when we fail to make a connection at all (in that + case connection_made() will not be called). + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +# TODO: The rest is platform specific and should move elsewhere. + +class UnixSocketTransport(Transport): + + def __init__(self, eventloop, protocol, sock): + self._eventloop = eventloop + self._protocol = protocol + self._sock = sock + self._buffer = collections.deque() # For write(). + self._write_closed = False + + def _on_readable(self): + try: + data = self._sock.recv(8192) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + else: + if not data: + self._eventloop.remove_reader(self._sock.fileno()) + self._sock.close() + self._protocol.connection_lost(None) + else: + self._protocol.data_received(data) # XXX call_soon()? + + def write(self, data): + assert isinstance(data, bytes) + assert not self._write_closed + if not data: + # Silly, but it happens. + return + if self._buffer: + # We've already registered a callback, just buffer the data. + self._buffer.append(data) + # Consider pausing if the total length of the buffer is + # truly huge. + return + + # TODO: Refactor so there's more sharing between this and + # _on_writable(). + + # There's no callback registered yet. It's quite possible + # that the kernel has buffer space for our data, so try to + # write now. Since the socket is non-blocking it will + # give us an error in _TRYAGAIN if it doesn't have enough + # space for even one more byte; it will return the number + # of bytes written if it can write at least one byte. + try: + n = self._sock.send(data) + except socket.error as exc: + # An error. + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + # The kernel doesn't have room for more data right now. + n = 0 + else: + # Wrote at least one byte. + if n == len(data): + # Wrote it all. Done! + if self._write_closed: + self._sock.shutdown(socket.SHUT_WR) + return + # Throw away the data that was already written. + # TODO: Do this without copying the data? + data = data[n:] + self._buffer.append(data) + self._eventloop.add_writer(self._sock.fileno(), self._on_writable) + + def _on_writable(self): + while self._buffer: + data = self._buffer[0] + # TODO: Join small amounts of data? + try: + n = self._sock.send(data) + except socket.error as exc: + # Error handling is the same as in write(). + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + if n < len(data): + self._buffer[0] = data[n:] + return + self._buffer.popleft() + self._eventloop.remove_writer(self._sock.fileno()) + if self._write_closed: + self._sock.shutdown(socket.SHUT_WR) + + def abort(self): + self._bad_error(None) + + def _bad_error(self, exc): + # A serious error. Close the socket etc. + fd = self._sock.fileno() + # TODO: Record whether we have a writer and/or reader registered. + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + self._sock.close() + self._protocol.connection_lost(exc) # XXX call_soon()? + + def half_close(self): + self._write_closed = True + + +class UnixSslTransport(Transport): + + # TODO: Refactor Socket and Ssl transport to share some code. + # (E.g. buffering.) + + # TODO: Consider using coroutines instead of callbacks, it seems + # much easier that way. + + def __init__(self, eventloop, protocol, rawsock, sslcontext=None): + self._eventloop = eventloop + self._protocol = protocol + self._rawsock = rawsock + self._sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslsock = self._sslcontext.wrap_socket( + self._rawsock, do_handshake_on_connect=False) + + self._buffer = collections.deque() # For write(). + self._write_closed = False + + # Try the handshake now. Likely it will raise EAGAIN, then it + # will take care of registering the appropriate callback. + self._on_handshake() + + def _bad_error(self, exc): + # A serious error. Close the socket etc. + fd = self._sslsock.fileno() + # TODO: Record whether we have a writer and/or reader registered. + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + self._sslsock.close() + self._protocol.connection_lost(exc) # XXX call_soon()? + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._eventloop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._eventloop.add_writable(fd, self._on_handshake) + return + # TODO: What if it raises another error? + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + self._protocol.connection_made(self) + self._eventloop.add_reader(fd, self._on_ready) + self._eventloop.add_writer(fd, self._on_ready) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + else: + if data: + self._protocol.data_received(data) + else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer + self._eventloop.remove_reader(fd) + self._eventloop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + return + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = self._buffer[0] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + else: + if n == len(data): + self._buffer.popleft() + # Could try again, but let's just have the next callback do it. + else: + self._buffer[0] = data[n:] + + def write(self, data): + assert isinstance(data, bytes) + assert not self._write_closed + if not data: + return + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + def half_close(self): + self._write_closed = True + # Just set the flag. Calling shutdown() on the ssl socket + # breaks something, causing recv() to return binary data. + + +def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, + use_ssl=None): + # TODO: Pass in a protocol factory, not a protocol. + # What should be the exact sequence of events? + # - socket + # - transport + # - protocol + # - tell transport about protocol + # - tell protocol about transport + # Or should the latter two be reversed? Does it matter? + if port is None: + port = 443 if use_ssl else 80 + if use_ssl is None: + use_ssl = (port == 443) + if not socktype: + socktype = socket.SOCK_STREAM + eventloop = polling.context.eventloop + + def on_socket_connected(task): + assert not task.alive + if task.exception is not None: + # TODO: Call some callback. + raise task.exception + sock = task.result + assert sock is not None + logging.debug('on_socket_connected') + if use_ssl: + # You can pass an ssl.SSLContext object as use_ssl, + # or a bool. + if isinstance(use_ssl, bool): + sslcontext = None + else: + sslcontext = use_ssl + transport = UnixSslTransport(eventloop, protocol, sock, sslcontext) + else: + transport = UnixSocketTransport(eventloop, protocol, sock) + # TODO: Should the ransport make the following calls? + protocol.connection_made(transport) # XXX call_soon()? + # Don't do this before connection_made() is called. + eventloop.add_reader(sock.fileno(), transport._on_readable) + + coro = sockets.create_connection(host, port, af, socktype, proto) + task = scheduling.Task(coro) + task.add_done_callback(on_socket_connected) + + +def main(): # Testing... + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + host = 'xkcd.com' + if sys.argv[1:] and '.' in sys.argv[-1]: + host = sys.argv[-1] + + t0 = time.time() + + class TestProtocol(Protocol): + def connection_made(self, transport): + logging.info('Connection made at %.3f secs', time.time() - t0) + self.transport = transport + self.transport.write(b'GET / HTTP/1.0\r\nHost: ' + + host.encode('ascii') + + b'\r\n\r\n') + self.transport.half_close() + def data_received(self, data): + logging.info('Received %d bytes at t=%.3f', + len(data), time.time() - t0) + logging.debug('Received %r', data) + def connection_lost(self, exc): + logging.debug('Connection lost: %r', exc) + self.t1 = time.time() + logging.info('Total time %.3f secs', self.t1 - t0) + + tp = TestProtocol() + logging.debug('tp = %r', tp) + make_connection(tp, host, use_ssl=('-S' in sys.argv)) + logging.info('Running...') + polling.context.eventloop.run() + logging.info('Done.') + + +if __name__ == '__main__': + main() diff --git a/old/xkcd.py b/old/xkcd.py new file mode 100755 index 0000000..474009d --- /dev/null +++ b/old/xkcd.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3.3 +"""Minimal synchronous SSL demo, connecting to xkcd.com.""" + +import socket, ssl + +s = socket.socket() +s.connect(('xkcd.com', 443)) +ss = ssl.wrap_socket(s) + +ss.send(b'GET / HTTP/1.0\r\n\r\n') + +while True: + data = ss.recv(1000000) + print(data) + if not data: + break + +ss.close() diff --git a/old/yyftime.py b/old/yyftime.py new file mode 100644 index 0000000..f55234b --- /dev/null +++ b/old/yyftime.py @@ -0,0 +1,75 @@ +"""Compare timing of yield-from <generator> vs. yield <future> calls.""" + +import gc +import time + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def run_coro(depth): + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + print('coro', depth, k, round(t1-t0, 6)) + return t1-t0 + +class Future: + + def __init__(self, g): + self.g = g + + def wait(self): + value = None + try: + while True: + f = self.g.send(value) + f.wait() + value = f.value + except StopIteration as err: + self.value = err.value + + + +def task(func): # Decorator + def wrapper(*args): + g = func(*args) + f = Future(g) + return f + return wrapper + +@task +def oldstyle(n): + if n <= 0: + return 1 + l = yield oldstyle(n-1) + r = yield oldstyle(n-1) + return l + 1 + r + +def run_olds(depth): + t0 = time.time() + f = oldstyle(depth) + f.wait() + k = f.value + t1 = time.time() + print('olds', depth, k, round(t1-t0, 6)) + return t1-t0 + +def main(): + gc.disable() + for depth in range(16): + tc = run_coro(depth) + to = run_olds(depth) + if tc: + print('ratio', round(to/tc, 2)) + +if __name__ == '__main__': + main() diff --git a/overlapped.c b/overlapped.c new file mode 100644 index 0000000..c9f6ec9 --- /dev/null +++ b/overlapped.c @@ -0,0 +1,997 @@ +/* + * Support for overlapped IO + * + * Some code borrowed from Modules/_winapi.c of CPython + */ + +/* XXX check overflow and DWORD <-> Py_ssize_t conversions + Check itemsize */ + +#include "Python.h" +#include "structmember.h" + +#define WINDOWS_LEAN_AND_MEAN +#include <winsock2.h> +#include <ws2tcpip.h> +#include <mswsock.h> + +#if defined(MS_WIN32) && !defined(MS_WIN64) +# define F_POINTER "k" +# define T_POINTER T_ULONG +#else +# define F_POINTER "K" +# define T_POINTER T_ULONGLONG +#endif + +#define F_HANDLE F_POINTER +#define F_ULONG_PTR F_POINTER +#define F_DWORD "k" +#define F_BOOL "i" +#define F_UINT "I" + +#define T_HANDLE T_POINTER + +enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_WRITE, TYPE_ACCEPT, + TYPE_CONNECT, TYPE_DISCONNECT}; + +/* + * Map Windows error codes to subclasses of OSError + */ + +static PyObject * +SetFromWindowsErr(DWORD err) +{ + PyObject *exception_type; + + if (err == 0) + err = GetLastError(); + switch (err) { + case ERROR_CONNECTION_REFUSED: + exception_type = PyExc_ConnectionRefusedError; + break; + case ERROR_CONNECTION_ABORTED: + exception_type = PyExc_ConnectionAbortedError; + break; + default: + exception_type = PyExc_OSError; + } + return PyErr_SetExcFromWindowsErr(exception_type, err); +} + +/* + * Some functions should be loaded at runtime + */ + +static LPFN_ACCEPTEX Py_AcceptEx = NULL; +static LPFN_CONNECTEX Py_ConnectEx = NULL; +static LPFN_DISCONNECTEX Py_DisconnectEx = NULL; +static BOOL (CALLBACK *Py_CancelIoEx)(HANDLE, LPOVERLAPPED) = NULL; + +#define GET_WSA_POINTER(s, x) \ + (SOCKET_ERROR != WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, \ + &Guid##x, sizeof(Guid##x), &Py_##x, \ + sizeof(Py_##x), &dwBytes, NULL, NULL)) + +static int +initialize_function_pointers(void) +{ + GUID GuidAcceptEx = WSAID_ACCEPTEX; + GUID GuidConnectEx = WSAID_CONNECTEX; + GUID GuidDisconnectEx = WSAID_DISCONNECTEX; + HINSTANCE hKernel32; + SOCKET s; + DWORD dwBytes; + + s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (s == INVALID_SOCKET) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + if (!GET_WSA_POINTER(s, AcceptEx) || + !GET_WSA_POINTER(s, ConnectEx) || + !GET_WSA_POINTER(s, DisconnectEx)) + { + closesocket(s); + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + closesocket(s); + + /* On WinXP we will have Py_CancelIoEx == NULL */ + hKernel32 = GetModuleHandle("KERNEL32"); + *(FARPROC *)&Py_CancelIoEx = GetProcAddress(hKernel32, "CancelIoEx"); + return 0; +} + +/* + * Completion port stuff + */ + +PyDoc_STRVAR( + CreateIoCompletionPort_doc, + "CreateIoCompletionPort(handle, port, key, concurrency) -> port\n\n" + "Create a completion port or register a handle with a port."); + +static PyObject * +overlapped_CreateIoCompletionPort(PyObject *self, PyObject *args) +{ + HANDLE FileHandle; + HANDLE ExistingCompletionPort; + ULONG_PTR CompletionKey; + DWORD NumberOfConcurrentThreads; + HANDLE ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE F_ULONG_PTR F_DWORD, + &FileHandle, &ExistingCompletionPort, &CompletionKey, + &NumberOfConcurrentThreads)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = CreateIoCompletionPort(FileHandle, ExistingCompletionPort, + CompletionKey, NumberOfConcurrentThreads); + Py_END_ALLOW_THREADS + + if (ret == NULL) + return SetFromWindowsErr(0); + return Py_BuildValue(F_HANDLE, ret); +} + +PyDoc_STRVAR( + GetQueuedCompletionStatus_doc, + "GetQueuedCompletionStatus(port, msecs) -> (err, bytes, key, address)\n\n" + "Get a message from completion port. Wait for up to msecs milliseconds."); + +static PyObject * +overlapped_GetQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort = NULL; + DWORD NumberOfBytes = 0; + ULONG_PTR CompletionKey = 0; + OVERLAPPED *Overlapped = NULL; + DWORD Milliseconds; + DWORD err; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, + &CompletionPort, &Milliseconds)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = GetQueuedCompletionStatus(CompletionPort, &NumberOfBytes, + &CompletionKey, &Overlapped, Milliseconds); + Py_END_ALLOW_THREADS + + err = ret ? ERROR_SUCCESS : GetLastError(); + if (Overlapped == NULL) { + if (err == WAIT_TIMEOUT) + Py_RETURN_NONE; + else + return SetFromWindowsErr(err); + } + return Py_BuildValue(F_DWORD F_DWORD F_ULONG_PTR F_POINTER, + err, NumberOfBytes, CompletionKey, Overlapped); +} + +PyDoc_STRVAR( + PostQueuedCompletionStatus_doc, + "PostQueuedCompletionStatus(port, bytes, key, address) -> None\n\n" + "Post a message to completion port."); + +static PyObject * +overlapped_PostQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort; + DWORD NumberOfBytes; + ULONG_PTR CompletionKey; + OVERLAPPED *Overlapped; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD F_ULONG_PTR F_POINTER, + &CompletionPort, &NumberOfBytes, &CompletionKey, + &Overlapped)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = PostQueuedCompletionStatus(CompletionPort, NumberOfBytes, + CompletionKey, Overlapped); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +/* + * Bind socket handle to local port without doing slow getaddrinfo() + */ + +PyDoc_STRVAR( + BindLocal_doc, + "BindLocal(handle, length_of_address_tuple) -> None\n\n" + "Bind a socket handle to an arbitrary local port.\n" + "If length_of_address_tuple is 2 then an AF_INET address is used.\n" + "If length_of_address_tuple is 4 then an AF_INET6 address is used."); + +static PyObject * +overlapped_BindLocal(PyObject *self, PyObject *args) +{ + SOCKET Socket; + int TupleLength; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE "i", &Socket, &TupleLength)) + return NULL; + + if (TupleLength == 2) { + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = 0; + addr.sin_addr.S_un.S_addr = INADDR_ANY; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else if (TupleLength == 4) { + struct sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_port = 0; + addr.sin6_addr = in6addr_any; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else { + PyErr_SetString(PyExc_ValueError, "expected tuple of length 2 or 4"); + return NULL; + } + + if (!ret) + return SetFromWindowsErr(WSAGetLastError()); + Py_RETURN_NONE; +} + +/* + * A Python object wrapping an OVERLAPPED structure and other useful data + * for overlapped I/O + */ + +PyDoc_STRVAR( + Overlapped_doc, + "Overlapped object"); + +typedef struct { + PyObject_HEAD + OVERLAPPED overlapped; + /* For convenience, we store the file handle too */ + HANDLE handle; + /* Error returned by last method call */ + DWORD error; + /* Type of operation */ + DWORD type; + /* Buffer used for reading (optional) */ + PyObject *read_buffer; + /* Buffer used for writing (optional) */ + Py_buffer write_buffer; +} OverlappedObject; + + +static PyObject * +Overlapped_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + OverlappedObject *self; + HANDLE event = INVALID_HANDLE_VALUE; + static char *kwlist[] = {"event", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|" F_HANDLE, kwlist, &event)) + return NULL; + + if (event == INVALID_HANDLE_VALUE) { + event = CreateEvent(NULL, TRUE, FALSE, NULL); + if (event == NULL) + return SetFromWindowsErr(0); + } + + self = PyObject_New(OverlappedObject, type); + if (self == NULL) { + if (event != NULL) + CloseHandle(event); + return NULL; + } + + self->handle = NULL; + self->error = 0; + self->type = TYPE_NONE; + self->read_buffer = NULL; + memset(&self->overlapped, 0, sizeof(OVERLAPPED)); + memset(&self->write_buffer, 0, sizeof(Py_buffer)); + if (event) + self->overlapped.hEvent = event; + return (PyObject *)self; +} + +static void +Overlapped_dealloc(OverlappedObject *self) +{ + DWORD bytes; + DWORD olderr = GetLastError(); + BOOL wait = FALSE; + BOOL ret; + + if (!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED) + { + if (Py_CancelIoEx && Py_CancelIoEx(self->handle, &self->overlapped)) + wait = TRUE; + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, + &bytes, wait); + Py_END_ALLOW_THREADS + + switch (ret ? ERROR_SUCCESS : GetLastError()) { + case ERROR_SUCCESS: + case ERROR_NOT_FOUND: + case ERROR_OPERATION_ABORTED: + break; + default: + PyErr_Format( + PyExc_RuntimeError, + "%R still has pending operation at " + "deallocation, the process may crash", self); + PyErr_WriteUnraisable(NULL); + } + } + + if (self->overlapped.hEvent != NULL) + CloseHandle(self->overlapped.hEvent); + + if (self->write_buffer.obj) + PyBuffer_Release(&self->write_buffer); + + Py_CLEAR(self->read_buffer); + PyObject_Del(self); + SetLastError(olderr); +} + +PyDoc_STRVAR( + Overlapped_cancel_doc, + "cancel() -> None\n\n" + "Cancel overlapped operation"); + +static PyObject * +Overlapped_cancel(OverlappedObject *self) +{ + BOOL ret = TRUE; + + if (self->type == TYPE_NOT_STARTED) + Py_RETURN_NONE; + + if (!HasOverlappedIoCompleted(&self->overlapped)) { + Py_BEGIN_ALLOW_THREADS + if (Py_CancelIoEx) + ret = Py_CancelIoEx(self->handle, &self->overlapped); + else + ret = CancelIo(self->handle); + Py_END_ALLOW_THREADS + } + + /* CancelIoEx returns ERROR_NOT_FOUND if the I/O completed in-between */ + if (!ret && GetLastError() != ERROR_NOT_FOUND) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +PyDoc_STRVAR( + Overlapped_getresult_doc, + "getresult(wait=False) -> result\n\n" + "Retrieve result of operation. If wait is true then it blocks\n" + "until the operation is finished. If wait is false and the\n" + "operation is still pending then an error is raised."); + +static PyObject * +Overlapped_getresult(OverlappedObject *self, PyObject *args) +{ + BOOL wait = FALSE; + DWORD transferred = 0; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, "|" F_BOOL, &wait)) + return NULL; + + if (self->type == TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation not yet attempted"); + return NULL; + } + + if (self->type == TYPE_NOT_STARTED) { + PyErr_SetString(PyExc_ValueError, "operation failed to start"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, &transferred, + wait); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + break; + case ERROR_BROKEN_PIPE: + if (self->read_buffer != NULL) + break; + /* fall through */ + default: + return SetFromWindowsErr(err); + } + + switch (self->type) { + case TYPE_READ: + assert(PyBytes_CheckExact(self->read_buffer)); + if (transferred != PyBytes_GET_SIZE(self->read_buffer) && + _PyBytes_Resize(&self->read_buffer, transferred)) + return NULL; + Py_INCREF(self->read_buffer); + return self->read_buffer; + case TYPE_ACCEPT: + case TYPE_CONNECT: + case TYPE_DISCONNECT: + Py_RETURN_NONE; + default: + return PyLong_FromUnsignedLong((unsigned long) transferred); + } +} + +PyDoc_STRVAR( + Overlapped_ReadFile_doc, + "ReadFile(handle, size) -> Overlapped[message]\n\n" + "Start overlapped read"); + +static PyObject * +Overlapped_ReadFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD nread; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &handle, &size)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = ReadFile(handle, PyBytes_AS_STRING(buf), size, &nread, + &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_BROKEN_PIPE: + self->type = TYPE_NOT_STARTED; + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSARecv_doc, + "RecvFile(handle, size, flags) -> Overlapped[message]\n\n" + "Start overlapped receive"); + +static PyObject * +Overlapped_WSARecv(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD flags = 0; + DWORD nread; + PyObject *buf; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD "|" F_DWORD, + &handle, &size, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + wsabuf.len = size; + wsabuf.buf = PyBytes_AS_STRING(buf); + + Py_BEGIN_ALLOW_THREADS + ret = WSARecv((SOCKET)handle, &wsabuf, 1, &nread, &flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_BROKEN_PIPE: + self->type = TYPE_NOT_STARTED; + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WriteFile_doc, + "WriteFile(handle, buf) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped write"); + +static PyObject * +Overlapped_WriteFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD written; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &handle, &bufobj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + + Py_BEGIN_ALLOW_THREADS + ret = WriteFile(handle, self->write_buffer.buf, + (DWORD)self->write_buffer.len, + &written, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSASend_doc, + "WSASend(handle, buf, flags) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped send"); + +static PyObject * +Overlapped_WSASend(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD flags; + DWORD written; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O" F_DWORD, + &handle, &bufobj, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + wsabuf.len = (DWORD)self->write_buffer.len; + wsabuf.buf = self->write_buffer.buf; + + Py_BEGIN_ALLOW_THREADS + ret = WSASend((SOCKET)handle, &wsabuf, 1, &written, flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_AcceptEx_doc, + "AcceptEx(listen_handle, accept_handle) -> Overlapped[address_as_bytes]\n\n" + "Start overlapped wait for client to connect"); + +static PyObject * +Overlapped_AcceptEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ListenSocket; + SOCKET AcceptSocket; + DWORD BytesReceived; + DWORD size; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE, + &ListenSocket, &AcceptSocket)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + size = sizeof(struct sockaddr_in6) + 16; + buf = PyBytes_FromStringAndSize(NULL, size*2); + if (!buf) + return NULL; + + self->type = TYPE_ACCEPT; + self->handle = (HANDLE)ListenSocket; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = Py_AcceptEx(ListenSocket, AcceptSocket, PyBytes_AS_STRING(buf), + 0, size, size, &BytesReceived, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + + +static int +parse_address(PyObject *obj, SOCKADDR *Address, int Length) +{ + char *Host; + unsigned short Port; + unsigned long FlowInfo; + unsigned long ScopeId; + + memset(Address, 0, Length); + + if (PyArg_ParseTuple(obj, "sH", &Host, &Port)) + { + Address->sa_family = AF_INET; + if (WSAStringToAddressA(Host, AF_INET, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN*)Address)->sin_port = htons(Port); + return Length; + } + else if (PyArg_ParseTuple(obj, "sHkk", &Host, &Port, &FlowInfo, &ScopeId)) + { + PyErr_Clear(); + Address->sa_family = AF_INET6; + if (WSAStringToAddressA(Host, AF_INET6, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN6*)Address)->sin6_port = htons(Port); + ((SOCKADDR_IN6*)Address)->sin6_flowinfo = FlowInfo; + ((SOCKADDR_IN6*)Address)->sin6_scope_id = ScopeId; + return Length; + } + + return -1; +} + + +PyDoc_STRVAR( + Overlapped_ConnectEx_doc, + "ConnectEx(client_handle, address_as_bytes) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_ConnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ConnectSocket; + PyObject *AddressObj; + char AddressBuf[sizeof(struct sockaddr_in6)]; + SOCKADDR *Address = (SOCKADDR*)AddressBuf; + int Length; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &ConnectSocket, &AddressObj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + Length = sizeof(AddressBuf); + Length = parse_address(AddressObj, Address, Length); + if (Length < 0) + return NULL; + + self->type = TYPE_CONNECT; + self->handle = (HANDLE)ConnectSocket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_ConnectEx(ConnectSocket, Address, Length, + NULL, 0, NULL, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_DisconnectEx_doc, + "DisconnectEx(handle, flags) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_DisconnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET Socket; + DWORD flags; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &Socket, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + self->type = TYPE_DISCONNECT; + self->handle = (HANDLE)Socket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_DisconnectEx(Socket, &self->overlapped, flags, 0); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +static PyObject* +Overlapped_getaddress(OverlappedObject *self) +{ + return PyLong_FromVoidPtr(&self->overlapped); +} + +static PyObject* +Overlapped_getpending(OverlappedObject *self) +{ + return PyBool_FromLong(!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED); +} + +static PyMethodDef Overlapped_methods[] = { + {"getresult", (PyCFunction) Overlapped_getresult, + METH_VARARGS, Overlapped_getresult_doc}, + {"cancel", (PyCFunction) Overlapped_cancel, + METH_NOARGS, Overlapped_cancel_doc}, + {"ReadFile", (PyCFunction) Overlapped_ReadFile, + METH_VARARGS, Overlapped_ReadFile_doc}, + {"WSARecv", (PyCFunction) Overlapped_WSARecv, + METH_VARARGS, Overlapped_WSARecv_doc}, + {"WriteFile", (PyCFunction) Overlapped_WriteFile, + METH_VARARGS, Overlapped_WriteFile_doc}, + {"WSASend", (PyCFunction) Overlapped_WSASend, + METH_VARARGS, Overlapped_WSASend_doc}, + {"AcceptEx", (PyCFunction) Overlapped_AcceptEx, + METH_VARARGS, Overlapped_AcceptEx_doc}, + {"ConnectEx", (PyCFunction) Overlapped_ConnectEx, + METH_VARARGS, Overlapped_ConnectEx_doc}, + {"DisconnectEx", (PyCFunction) Overlapped_DisconnectEx, + METH_VARARGS, Overlapped_DisconnectEx_doc}, + {NULL} +}; + +static PyMemberDef Overlapped_members[] = { + {"error", T_ULONG, + offsetof(OverlappedObject, error), + READONLY, "Error from last operation"}, + {"event", T_HANDLE, + offsetof(OverlappedObject, overlapped) + offsetof(OVERLAPPED, hEvent), + READONLY, "Overlapped event handle"}, + {NULL} +}; + +static PyGetSetDef Overlapped_getsets[] = { + {"address", (getter)Overlapped_getaddress, NULL, + "Address of overlapped structure"}, + {"pending", (getter)Overlapped_getpending, NULL, + "Whether the operation is pending"}, + {NULL}, +}; + +PyTypeObject OverlappedType = { + PyVarObject_HEAD_INIT(NULL, 0) + /* tp_name */ "_overlapped.Overlapped", + /* tp_basicsize */ sizeof(OverlappedObject), + /* tp_itemsize */ 0, + /* tp_dealloc */ (destructor) Overlapped_dealloc, + /* tp_print */ 0, + /* tp_getattr */ 0, + /* tp_setattr */ 0, + /* tp_reserved */ 0, + /* tp_repr */ 0, + /* tp_as_number */ 0, + /* tp_as_sequence */ 0, + /* tp_as_mapping */ 0, + /* tp_hash */ 0, + /* tp_call */ 0, + /* tp_str */ 0, + /* tp_getattro */ 0, + /* tp_setattro */ 0, + /* tp_as_buffer */ 0, + /* tp_flags */ Py_TPFLAGS_DEFAULT, + /* tp_doc */ "OVERLAPPED structure wrapper", + /* tp_traverse */ 0, + /* tp_clear */ 0, + /* tp_richcompare */ 0, + /* tp_weaklistoffset */ 0, + /* tp_iter */ 0, + /* tp_iternext */ 0, + /* tp_methods */ Overlapped_methods, + /* tp_members */ Overlapped_members, + /* tp_getset */ Overlapped_getsets, + /* tp_base */ 0, + /* tp_dict */ 0, + /* tp_descr_get */ 0, + /* tp_descr_set */ 0, + /* tp_dictoffset */ 0, + /* tp_init */ 0, + /* tp_alloc */ 0, + /* tp_new */ Overlapped_new, +}; + +static PyMethodDef overlapped_functions[] = { + {"CreateIoCompletionPort", overlapped_CreateIoCompletionPort, + METH_VARARGS, CreateIoCompletionPort_doc}, + {"GetQueuedCompletionStatus", overlapped_GetQueuedCompletionStatus, + METH_VARARGS, GetQueuedCompletionStatus_doc}, + {"PostQueuedCompletionStatus", overlapped_PostQueuedCompletionStatus, + METH_VARARGS, PostQueuedCompletionStatus_doc}, + {"BindLocal", overlapped_BindLocal, + METH_VARARGS, BindLocal_doc}, + {NULL} +}; + +static struct PyModuleDef overlapped_module = { + PyModuleDef_HEAD_INIT, + "_overlapped", + NULL, + -1, + overlapped_functions, + NULL, + NULL, + NULL, + NULL +}; + +#define WINAPI_CONSTANT(fmt, con) \ + PyDict_SetItemString(d, #con, Py_BuildValue(fmt, con)) + +PyMODINIT_FUNC +PyInit__overlapped(void) +{ + PyObject *m, *d; + + /* Ensure WSAStartup() called before initializing function pointers */ + m = PyImport_ImportModule("_socket"); + if (!m) + return NULL; + Py_DECREF(m); + + if (initialize_function_pointers() < 0) + return NULL; + + if (PyType_Ready(&OverlappedType) < 0) + return NULL; + + m = PyModule_Create(&overlapped_module); + if (PyModule_AddObject(m, "Overlapped", (PyObject *)&OverlappedType) < 0) + return NULL; + + d = PyModule_GetDict(m); + + WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING); + WINAPI_CONSTANT(F_DWORD, INFINITE); + WINAPI_CONSTANT(F_HANDLE, INVALID_HANDLE_VALUE); + WINAPI_CONSTANT(F_HANDLE, NULL); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_ACCEPT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_CONNECT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, TF_REUSE_SOCKET); + + return m; +} diff --git a/runtests.py b/runtests.py new file mode 100644 index 0000000..096a256 --- /dev/null +++ b/runtests.py @@ -0,0 +1,198 @@ +"""Run all unittests. + +Usage: + python3 runtests.py [-v] [-q] [pattern] ... + +Where: + -v: verbose + -q: quiet + pattern: optional regex patterns to match test ids (default all tests) + +Note that the test id is the fully qualified name of the test, +including package, module, class and method, +e.g. 'tests.events_test.PolicyTests.testPolicy'. + +runtests.py with --coverage argument is equivalent of: + + $(COVERAGE) run --branch runtests.py -v + $(COVERAGE) html $(list of files) + $(COVERAGE) report -m $(list of files) + +""" + +# Originally written by Beech Horn (for NDB). + +import argparse +import logging +import os +import re +import sys +import subprocess +import unittest +import importlib.machinery + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +ARGS = argparse.ArgumentParser(description="Run all unittests.") +ARGS.add_argument( + '-v', action="store", dest='verbose', + nargs='?', const=1, type=int, default=0, help='verbose') +ARGS.add_argument( + '-x', action="store_true", dest='exclude', help='exclude tests') +ARGS.add_argument( + '-q', action="store_true", dest='quiet', help='quiet') +ARGS.add_argument( + '--tests', action="store", dest='testsdir', default='tests', + help='tests directory') +ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') +ARGS.add_argument( + 'pattern', action="store", nargs="*", + help='optional regex patterns to match test ids (default all tests)') + +COV_ARGS = argparse.ArgumentParser(description="Run all unittests.") +COV_ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') + + +def load_modules(basedir, suffix='.py'): + def list_dir(prefix, dir): + files = [] + + modpath = os.path.join(dir, '__init__.py') + if os.path.isfile(modpath): + mod = os.path.split(dir)[-1] + files.append(('%s%s' % (prefix, mod), modpath)) + + prefix = '%s%s.' % (prefix, mod) + + for name in os.listdir(dir): + path = os.path.join(dir, name) + + if os.path.isdir(path): + files.extend(list_dir('%s%s.' % (prefix, name), path)) + else: + if (name != '__init__.py' and + name.endswith(suffix) and + not name.startswith(('.', '_'))): + files.append(('%s%s' % (prefix, name[:-3]), path)) + + return files + + mods = [] + for modname, sourcefile in list_dir('', basedir): + if modname == 'runtests': + continue + try: + loader = importlib.machinery.SourceFileLoader(modname, sourcefile) + mods.append((loader.load_module(), sourcefile)) + except Exception as err: + print("Skipping '%s': %s" % (modname, err), file=sys.stderr) + + return mods + + +def load_tests(testsdir, includes=(), excludes=()): + mods = [mod for mod, _ in load_modules(testsdir)] + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + for mod in mods: + for name in set(dir(mod)): + if name.endswith('Tests'): + test_module = getattr(mod, name) + tests = loader.loadTestsFromTestCase(test_module) + if includes: + tests = [test + for test in tests + if any(re.search(pat, test.id()) + for pat in includes)] + if excludes: + tests = [test + for test in tests + if not any(re.search(pat, test.id()) + for pat in excludes)] + suite.addTests(tests) + + return suite + + +def runtests(): + args = ARGS.parse_args() + + testsdir = os.path.abspath(args.testsdir) + if not os.path.isdir(testsdir): + print("Tests directory is not found: %s\n" % testsdir) + ARGS.print_help() + return + + excludes = includes = [] + if args.exclude: + excludes = args.pattern + else: + includes = args.pattern + + v = 0 if args.quiet else args.verbose + 1 + + tests = load_tests(args.testsdir, includes, excludes) + logger = logging.getLogger() + if v == 0: + logger.setLevel(logging.CRITICAL) + elif v == 1: + logger.setLevel(logging.ERROR) + elif v == 2: + logger.setLevel(logging.WARNING) + elif v == 3: + logger.setLevel(logging.INFO) + elif v >= 4: + logger.setLevel(logging.DEBUG) + result = unittest.TextTestRunner(verbosity=v).run(tests) + sys.exit(not result.wasSuccessful()) + + +def runcoverage(sdir, args): + """ + To install coverage3 for Python 3, you need: + - Distribute (http://packages.python.org/distribute/) + + What worked for me: + - download http://python-distribute.org/distribute_setup.py + * curl -O http://python-distribute.org/distribute_setup.py + - python3 distribute_setup.py + - python3 -m easy_install coverage + """ + try: + import coverage + except ImportError: + print("Coverage package is not found.") + print(runcoverage.__doc__) + return + + sdir = os.path.abspath(sdir) + if not os.path.isdir(sdir): + print("Python files directory is not found: %s\n" % sdir) + ARGS.print_help() + return + + mods = [source for _, source in load_modules(sdir)] + coverage = [sys.executable, '-m', 'coverage'] + + try: + subprocess.check_call( + coverage + ['run', '--branch', 'runtests.py'] + args) + except: + pass + else: + subprocess.check_call(coverage + ['html'] + mods) + subprocess.check_call(coverage + ['report'] + mods) + + +if __name__ == '__main__': + if '--coverage' in sys.argv: + cov_args, args = COV_ARGS.parse_known_args() + runcoverage(cov_args.coverage, args) + else: + runtests() diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..0260f9d --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[build_ext] +build_lib=tulip diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..dcaee96 --- /dev/null +++ b/setup.py @@ -0,0 +1,14 @@ +import os +from distutils.core import setup, Extension + +extensions = [] +if os.name == 'nt': + ext = Extension('_overlapped', ['overlapped.c'], libraries=['ws2_32']) + extensions.append(ext) + +setup(name='tulip', + description="reference implementation of PEP 3156", + url='http://www.python.org/dev/peps/pep-3156/', + packages=['tulip'], + ext_modules=extensions + ) @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +"""Simple server written using an event loop.""" + +import email.message +import os +import sys + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +import tulip +import tulip.http + + +class HttpServer(tulip.http.ServerHttpProtocol): + + def handle_request(self, request_info, message): + print('method = {!r}; path = {!r}; version = {!r}'.format( + request_info.method, request_info.uri, request_info.version)) + + path = request_info.uri + + if (not (path.isprintable() and path.startswith('/')) or '/.' in path): + print('bad path', repr(path)) + path = None + else: + path = '.' + path + if not os.path.exists(path): + print('no file', repr(path)) + path = None + else: + isdir = os.path.isdir(path) + + if not path: + raise tulip.http.HttpStatusException(404) + + headers = email.message.Message() + for hdr, val in message.headers: + print(hdr, val) + headers.add_header(hdr, val) + + if isdir and not path.endswith('/'): + path = path + '/' + raise tulip.http.HttpStatusException( + 302, headers=(('URI', path), ('Location', path))) + + response = tulip.http.Response(self.transport, 200) + response.add_header('Transfer-Encoding', 'chunked') + + # content encoding + accept_encoding = headers.get('accept-encoding', '').lower() + if 'deflate' in accept_encoding: + response.add_header('Content-Encoding', 'deflate') + response.add_compression_filter('deflate') + elif 'gzip' in accept_encoding: + response.add_header('Content-Encoding', 'gzip') + response.add_compression_filter('gzip') + + response.add_chunking_filter(1025) + + if isdir: + response.add_header('Content-type', 'text/html') + response.send_headers() + + response.write(b'<ul>\r\n') + for name in sorted(os.listdir(path)): + if name.isprintable() and not name.startswith('.'): + try: + bname = name.encode('ascii') + except UnicodeError: + pass + else: + if os.path.isdir(os.path.join(path, name)): + response.write(b'<li><a href="' + bname + + b'/">' + bname + b'/</a></li>\r\n') + else: + response.write(b'<li><a href="' + bname + + b'">' + bname + b'</a></li>\r\n') + response.write(b'</ul>') + else: + response.add_header('Content-type', 'text/plain') + response.send_headers() + + try: + with open(path, 'rb') as fp: + chunk = fp.read(8196) + while chunk: + if not response.write(chunk): + break + chunk = fp.read(8196) + except OSError: + response.write(b'Cannot open') + + response.write_eof() + self.close() + + +def main(): + host = '127.0.0.1' + port = 8080 + if sys.argv[1:]: + host = sys.argv[1] + if sys.argv[2:]: + port = int(sys.argv[2]) + elif ':' in host: + host, port = host.split(':', 1) + port = int(port) + loop = tulip.get_event_loop() + f = loop.start_serving(lambda: HttpServer(debug=True), host, port) + x = loop.run_until_complete(f) + print('serving on', x.getsockname()) + loop.run_forever() + + +if __name__ == '__main__': + main() diff --git a/sslsrv.py b/sslsrv.py new file mode 100644 index 0000000..a1bc04f --- /dev/null +++ b/sslsrv.py @@ -0,0 +1,56 @@ +"""Serve up an SSL connection, after Python ssl module docs.""" + +import socket +import ssl +import os + + +def main(): + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + certfile = getcertfile() + context.load_cert_chain(certfile=certfile, keyfile=certfile) + bindsocket = socket.socket() + bindsocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + bindsocket.bind(('', 4443)) + bindsocket.listen(5) + + while True: + newsocket, fromaddr = bindsocket.accept() + try: + connstream = context.wrap_socket(newsocket, server_side=True) + try: + deal_with_client(connstream) + finally: + connstream.shutdown(socket.SHUT_RDWR) + connstream.close() + except Exception as exc: + print(exc.__class__.__name__ + ':', exc) + + +def getcertfile(): + import test # Test package + testdir = os.path.dirname(test.__file__) + certfile = os.path.join(testdir, 'keycert.pem') + print('certfile =', certfile) + return certfile + + +def deal_with_client(connstream): + data = connstream.recv(1024) + # empty data means the client is finished with us + while data: + if not do_something(connstream, data): + # we'll assume do_something returns False + # when we're finished with client + break + data = connstream.recv(1024) + # finished with client + + +def do_something(connstream, data): + # just echo back + connstream.sendall(data) + + +if __name__ == '__main__': + main() diff --git a/tests/base_events_test.py b/tests/base_events_test.py new file mode 100644 index 0000000..88f3faf --- /dev/null +++ b/tests/base_events_test.py @@ -0,0 +1,283 @@ +"""Tests for base_events.py""" + +import concurrent.futures +import logging +import socket +import time +import unittest +import unittest.mock + +from tulip import base_events +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import tasks +from tulip import test_utils + + +class BaseEventLoopTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + + self.event_loop = base_events.BaseEventLoop() + self.event_loop._selector = unittest.mock.Mock() + self.event_loop._selector.registered_count.return_value = 1 + + def test_not_implemented(self): + m = unittest.mock.Mock() + self.assertRaises( + NotImplementedError, + self.event_loop._make_socket_transport, m, m) + self.assertRaises( + NotImplementedError, + self.event_loop._make_ssl_transport, m, m, m, m) + self.assertRaises( + NotImplementedError, + self.event_loop._make_datagram_transport, m, m) + self.assertRaises( + NotImplementedError, self.event_loop._process_events, []) + self.assertRaises( + NotImplementedError, self.event_loop._write_to_self) + self.assertRaises( + NotImplementedError, self.event_loop._read_from_self) + self.assertRaises( + NotImplementedError, + self.event_loop._make_read_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + self.event_loop._make_write_pipe_transport, m, m) + + def test_add_callback_handle(self): + h = events.Handle(lambda: False, ()) + + self.event_loop._add_callback(h) + self.assertFalse(self.event_loop._scheduled) + self.assertIn(h, self.event_loop._ready) + + def test_add_callback_timer(self): + when = time.monotonic() + + h1 = events.Timer(when, lambda: False, ()) + h2 = events.Timer(when+10.0, lambda: False, ()) + + self.event_loop._add_callback(h2) + self.event_loop._add_callback(h1) + self.assertEqual([h1, h2], self.event_loop._scheduled) + self.assertFalse(self.event_loop._ready) + + def test_add_callback_cancelled_handle(self): + h = events.Handle(lambda: False, ()) + h.cancel() + + self.event_loop._add_callback(h) + self.assertFalse(self.event_loop._scheduled) + self.assertFalse(self.event_loop._ready) + + def test_wrap_future(self): + f = futures.Future() + self.assertIs(self.event_loop.wrap_future(f), f) + + def test_wrap_future_concurrent(self): + f = concurrent.futures.Future() + self.assertIsInstance(self.event_loop.wrap_future(f), futures.Future) + + def test_set_default_executor(self): + executor = unittest.mock.Mock() + self.event_loop.set_default_executor(executor) + self.assertIs(executor, self.event_loop._default_executor) + + def test_getnameinfo(self): + sockaddr = unittest.mock.Mock() + self.event_loop.run_in_executor = unittest.mock.Mock() + self.event_loop.getnameinfo(sockaddr) + self.assertEqual( + (None, socket.getnameinfo, sockaddr, 0), + self.event_loop.run_in_executor.call_args[0]) + + def test_call_soon(self): + def cb(): + pass + + h = self.event_loop.call_soon(cb) + self.assertEqual(h._callback, cb) + self.assertIsInstance(h, events.Handle) + self.assertIn(h, self.event_loop._ready) + + def test_call_later(self): + def cb(): + pass + + h = self.event_loop.call_later(10.0, cb) + self.assertIsInstance(h, events.Timer) + self.assertIn(h, self.event_loop._scheduled) + self.assertNotIn(h, self.event_loop._ready) + + def test_call_later_no_delay(self): + def cb(): + pass + + h = self.event_loop.call_later(0, cb) + self.assertIn(h, self.event_loop._ready) + self.assertNotIn(h, self.event_loop._scheduled) + + def test_run_once_in_executor_handle(self): + def cb(): + pass + + self.assertRaises( + AssertionError, self.event_loop.run_in_executor, + None, events.Handle(cb, ()), ('',)) + self.assertRaises( + AssertionError, self.event_loop.run_in_executor, + None, events.Timer(10, cb, ())) + + def test_run_once_in_executor_canceled(self): + def cb(): + pass + h = events.Handle(cb, ()) + h.cancel() + + f = self.event_loop.run_in_executor(None, h) + self.assertIsInstance(f, futures.Future) + self.assertTrue(f.done()) + + def test_run_once_in_executor(self): + def cb(): + pass + h = events.Handle(cb, ()) + f = futures.Future() + executor = unittest.mock.Mock() + executor.submit.return_value = f + + self.event_loop.set_default_executor(executor) + + res = self.event_loop.run_in_executor(None, h) + self.assertIs(f, res) + + executor = unittest.mock.Mock() + executor.submit.return_value = f + res = self.event_loop.run_in_executor(executor, h) + self.assertIs(f, res) + self.assertTrue(executor.submit.called) + + def test_run_once(self): + self.event_loop._run_once = unittest.mock.Mock() + self.event_loop._run_once.side_effect = base_events._StopError + self.event_loop.run_once() + self.assertTrue(self.event_loop._run_once.called) + + def test__run_once(self): + h1 = events.Timer(time.monotonic() + 0.1, lambda: True, ()) + h2 = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + + h1.cancel() + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h1) + self.event_loop._scheduled.append(h2) + self.event_loop._run_once() + + t = self.event_loop._selector.select.call_args[0][0] + self.assertTrue(9.99 < t < 10.1) + self.assertEqual([h2], self.event_loop._scheduled) + self.assertTrue(self.event_loop._process_events.called) + + def test__run_once_timeout(self): + h = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._run_once(1.0) + self.assertEqual((1.0,), self.event_loop._selector.select.call_args[0]) + + def test__run_once_timeout_with_ready(self): + # If event loop has ready callbacks, select timeout is always 0. + h = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._ready.append(h) + self.event_loop._run_once(1.0) + + self.assertEqual((0,), self.event_loop._selector.select.call_args[0]) + + @unittest.mock.patch('tulip.base_events.time') + @unittest.mock.patch('tulip.base_events.tulip_log') + def test__run_once_logging(self, m_logging, m_time): + # Log to INFO level if timeout > 1.0 sec. + idx = -1 + data = [10.0, 10.0, 12.0, 13.0] + + def monotonic(): + nonlocal data, idx + idx += 1 + return data[idx] + + m_time.monotonic = monotonic + m_logging.INFO = logging.INFO + m_logging.DEBUG = logging.DEBUG + + self.event_loop._scheduled.append(events.Timer(11.0, lambda: True, ())) + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._run_once() + self.assertEqual(logging.INFO, m_logging.log.call_args[0][0]) + + idx = -1 + data = [10.0, 10.0, 10.3, 13.0] + self.event_loop._scheduled = [events.Timer(11.0, lambda:True, ())] + self.event_loop._run_once() + self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0]) + + def test__run_once_schedule_handle(self): + handle = None + processed = False + + def cb(event_loop): + nonlocal processed, handle + processed = True + handle = event_loop.call_soon(lambda: True) + + h = events.Timer(time.monotonic() - 1, cb, (self.event_loop,)) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._run_once() + + self.assertTrue(processed) + self.assertEqual([handle], list(self.event_loop._ready)) + + def test_run_until_complete_assertion(self): + self.assertRaises( + AssertionError, self.event_loop.run_until_complete, 'blah') + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_connection_mutiple_errors(self, m_socket): + self.suppress_log_errors() + + class MyProto(protocols.Protocol): + pass + + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + + idx = -1 + errors = ['err1', 'err2'] + + def _socket(*args, **kw): + nonlocal idx, errors + idx += 1 + raise socket.error(errors[idx]) + + m_socket.socket = _socket + m_socket.error = socket.error + + self.event_loop.getaddrinfo = getaddrinfo + + task = tasks.Task( + self.event_loop.create_connection(MyProto, 'xkcd.com', 80)) + task._step() + exc = task.exception() + self.assertEqual("Multiple exceptions: err1, err2", str(exc)) diff --git a/tests/events_test.py b/tests/events_test.py new file mode 100644 index 0000000..4085921 --- /dev/null +++ b/tests/events_test.py @@ -0,0 +1,1379 @@ +"""Tests for events.py.""" + +import concurrent.futures +import contextlib +import gc +import io +import os +import re +import signal +import socket +try: + import ssl +except ImportError: + ssl = None +import sys +import threading +import time +import unittest +import unittest.mock + +from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer + +from tulip import events +from tulip import transports +from tulip import protocols +from tulip import selector_events +from tulip import tasks +from tulip import test_utils + + +class MyProto(protocols.Protocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: xkcd.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + + +class MyDatagramProto(protocols.DatagramProtocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + + +class MyReadPipeProto(protocols.Protocol): + + def __init__(self): + self.state = ['INITIAL'] + self.nbytes = 0 + self.transport = None + + def connection_made(self, transport): + self.transport = transport + assert self.state == ['INITIAL'], self.state + self.state.append('CONNECTED') + + def data_received(self, data): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.state.append('EOF') + self.transport.close() + + def connection_lost(self, exc): + assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state + self.state.append('CLOSED') + + +class MyWritePipeProto(protocols.Protocol): + + def __init__(self): + self.state = 'INITIAL' + self.transport = None + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + + +class EventLoopTestsMixin: + + def setUp(self): + super().setUp() + self.event_loop = self.create_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + gc.collect() + super().tearDown() + + @contextlib.contextmanager + def run_test_server(self, *, use_ssl=False): + + class SilentWSGIRequestHandler(WSGIRequestHandler): + def get_stderr(self): + return io.StringIO() + + def log_message(self, format, *args): + pass + + class SilentWSGIServer(WSGIServer): + def handle_error(self, request, client_address): + pass + + class SSLWSGIServer(SilentWSGIServer): + def finish_request(self, request, client_address): + here = os.path.dirname(__file__) + keyfile = os.path.join(here, 'sample.key') + certfile = os.path.join(here, 'sample.crt') + ssock = ssl.wrap_socket(request, + keyfile=keyfile, + certfile=certfile, + server_side=True) + try: + self.RequestHandlerClass(ssock, client_address, self) + ssock.close() + except OSError: + # maybe socket has been closed by peer + pass + + def app(environ, start_response): + status = '302 Found' + headers = [('Content-type', 'text/plain')] + start_response(status, headers) + return [b'Test message'] + + # Run the test WSGI server in a separate thread in order not to + # interfere with event handling in the main thread + server_class = SSLWSGIServer if use_ssl else SilentWSGIServer + httpd = make_server('127.0.0.1', 0, app, + server_class, SilentWSGIRequestHandler) + server_thread = threading.Thread(target=httpd.serve_forever) + server_thread.start() + try: + yield httpd + finally: + httpd.shutdown() + server_thread.join() + + def test_run(self): + self.event_loop.run() # Returns immediately. + + def test_run_nesting(self): + err = None + + @tasks.coroutine + def coro(): + nonlocal err + yield from [] + self.assertTrue(self.event_loop.is_running()) + try: + self.event_loop.run_until_complete( + tasks.sleep(0.1)) + except Exception as exc: + err = exc + + self.event_loop.run_until_complete(tasks.Task(coro())) + self.assertIsInstance(err, RuntimeError) + + def test_run_once_nesting(self): + err = None + + @tasks.coroutine + def coro(): + nonlocal err + yield from [] + tasks.sleep(0.1) + try: + self.event_loop.run_once() + except Exception as exc: + err = exc + + self.event_loop.run_until_complete(tasks.Task(coro())) + self.assertIsInstance(err, RuntimeError) + + def test_run_once_block(self): + called = False + + def callback(): + nonlocal called + called = True + + def run(): + time.sleep(0.1) + self.event_loop.call_soon_threadsafe(callback) + + t = threading.Thread(target=run) + t0 = time.monotonic() + t.start() + self.event_loop.run_once(None) + t1 = time.monotonic() + t.join() + self.assertTrue(called) + self.assertTrue(0.09 < t1-t0 <= 0.12) + + def test_call_later(self): + results = [] + + def callback(arg): + results.append(arg) + + self.event_loop.call_later(0.1, callback, 'hello world') + t0 = time.monotonic() + self.event_loop.run() + t1 = time.monotonic() + self.assertEqual(results, ['hello world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_call_repeatedly(self): + results = [] + + def callback(arg): + results.append(arg) + + self.event_loop.call_repeatedly(0.03, callback, 'ho') + self.event_loop.call_later(0.1, self.event_loop.stop) + self.event_loop.run() + self.assertEqual(results, ['ho', 'ho', 'ho']) + + def test_call_soon(self): + results = [] + + def callback(arg1, arg2): + results.append((arg1, arg2)) + + self.event_loop.call_soon(callback, 'hello', 'world') + self.event_loop.run() + self.assertEqual(results, [('hello', 'world')]) + + def test_call_soon_with_handle(self): + results = [] + + def callback(): + results.append('yeah') + + handle = events.Handle(callback, ()) + self.assertIs(self.event_loop.call_soon(handle), handle) + self.event_loop.run() + self.assertEqual(results, ['yeah']) + + def test_call_soon_threadsafe(self): + results = [] + + def callback(arg): + results.append(arg) + + def run(): + self.event_loop.call_soon_threadsafe(callback, 'hello') + + t = threading.Thread(target=run) + self.event_loop.call_later(0.1, callback, 'world') + t0 = time.monotonic() + t.start() + self.event_loop.run() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_call_soon_threadsafe_same_thread(self): + results = [] + + def callback(arg): + results.append(arg) + + self.event_loop.call_later(0.1, callback, 'world') + self.event_loop.call_soon_threadsafe(callback, 'hello') + self.event_loop.run() + self.assertEqual(results, ['hello', 'world']) + + def test_call_soon_threadsafe_with_handle(self): + results = [] + + def callback(arg): + results.append(arg) + + handle = events.Handle(callback, ('hello',)) + + def run(): + self.assertIs( + self.event_loop.call_soon_threadsafe(handle), handle) + + t = threading.Thread(target=run) + self.event_loop.call_later(0.1, callback, 'world') + + t0 = time.monotonic() + t.start() + self.event_loop.run() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_wrap_future(self): + def run(arg): + time.sleep(0.1) + return arg + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = self.event_loop.wrap_future(f1) + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'oi') + + def test_run_in_executor(self): + def run(arg): + time.sleep(0.1) + return arg + f2 = self.event_loop.run_in_executor(None, run, 'yo') + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def test_run_in_executor_with_handle(self): + def run(arg): + time.sleep(0.1) + return arg + handle = events.Handle(run, ('yo',)) + f2 = self.event_loop.run_in_executor(None, handle) + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def test_reader_callback(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.event_loop.remove_reader(r.fileno())) + r.close() + + self.event_loop.add_reader(r.fileno(), reader) + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_reader_callback_with_handle(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.event_loop.remove_reader(r.fileno())) + r.close() + + handle = events.Handle(reader, ()) + self.assertIs(handle, self.event_loop.add_reader(r.fileno(), handle)) + + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_reader_callback_cancel(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + return + if data: + bytes_read.append(data) + if sum(len(b) for b in bytes_read) >= 6: + handle.cancel() + if not data: + r.close() + + handle = self.event_loop.add_reader(r.fileno(), reader) + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_writer_callback(self): + r, w = test_utils.socketpair() + w.setblocking(False) + self.event_loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + + def remove_writer(): + self.assertTrue(self.event_loop.remove_writer(w.fileno())) + + self.event_loop.call_later(0.1, remove_writer) + self.event_loop.run() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def test_writer_callback_with_handle(self): + r, w = test_utils.socketpair() + w.setblocking(False) + handle = events.Handle(w.send, (b'x'*(256*1024),)) + self.assertIs(self.event_loop.add_writer(w.fileno(), handle), handle) + + def remove_writer(): + self.assertTrue(self.event_loop.remove_writer(w.fileno())) + + self.event_loop.call_later(0.1, remove_writer) + self.event_loop.run() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def test_writer_callback_cancel(self): + r, w = test_utils.socketpair() + w.setblocking(False) + + def sender(): + w.send(b'x'*256) + handle.cancel() + + handle = self.event_loop.add_writer(w.fileno(), sender) + self.event_loop.run() + w.close() + data = r.recv(1024) + r.close() + self.assertTrue(data == b'x'*256) + + def test_sock_client_ops(self): + with self.run_test_server() as httpd: + sock = socket.socket() + sock.setblocking(False) + address = httpd.socket.getsockname() + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, address)) + self.event_loop.run_until_complete( + self.event_loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.event_loop.run_until_complete( + self.event_loop.sock_recv(sock, 1024)) + sock.close() + + self.assertTrue(re.match(rb'HTTP/1.0 302 Found', data), data) + + def test_sock_client_fail(self): + # Make sure that we will get an unused port + address = None + try: + s = socket.socket() + s.bind(('127.0.0.1', 0)) + address = s.getsockname() + finally: + s.close() + + sock = socket.socket() + sock.setblocking(False) + with self.assertRaises(ConnectionRefusedError): + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, address)) + sock.close() + + def test_sock_accept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + + f = self.event_loop.sock_accept(listener) + conn, addr = self.event_loop.run_until_complete(f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + def test_add_signal_handler(self): + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + # Check error behavior first. + self.assertRaises( + TypeError, self.event_loop.add_signal_handler, 'boom', my_handler) + self.assertRaises( + TypeError, self.event_loop.remove_signal_handler, 'boom') + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, signal.NSIG+1) + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, 0, my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, 0) + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, -1, my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, -1) + self.assertRaises( + RuntimeError, self.event_loop.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + self.event_loop.add_signal_handler(signal.SIGINT, my_handler) + self.event_loop.run_once() + os.kill(os.getpid(), signal.SIGINT) + self.event_loop.run_once() + self.assertEqual(caught, 1) + # Removing it should restore the default handler. + self.assertTrue(self.event_loop.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGINT)) + + @unittest.skipIf(sys.platform == 'win32', 'Unix only') + def test_cancel_signal_handler(self): + # Cancelling the handler should remove it (eventually). + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + handle = self.event_loop.add_signal_handler(signal.SIGINT, my_handler) + handle.cancel() + os.kill(os.getpid(), signal.SIGINT) + self.event_loop.run_once() + self.assertEqual(caught, 0) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_while_selecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + self.event_loop.add_signal_handler( + signal.SIGALRM, my_handler) + + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + self.event_loop.call_later(0.15, self.event_loop.stop) + self.event_loop.run_forever() + self.assertEqual(caught, 1) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_args(self): + some_args = (42,) + caught = 0 + + def my_handler(*args): + nonlocal caught + caught += 1 + self.assertEqual(args, some_args) + + self.event_loop.add_signal_handler( + signal.SIGALRM, my_handler, *some_args) + + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + self.event_loop.call_later(0.15, self.event_loop.stop) + self.event_loop.run_forever() + self.assertEqual(caught, 1) + + def test_create_connection(self): + with self.run_test_server() as httpd: + host, port = httpd.socket.getsockname() + f = tasks.Task( + self.event_loop.create_connection(MyProto, host, port)) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) + + def test_create_connection_sock(self): + with self.run_test_server() as httpd: + host, port = httpd.socket.getsockname() + f = tasks.Task( + self.event_loop.create_connection(MyProto, host, port)) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_ssl_connection(self): + with self.run_test_server(use_ssl=True) as httpsd: + host, port = httpsd.socket.getsockname() + f = self.event_loop.create_connection( + MyProto, host, port, ssl=True) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + self.assertTrue( + hasattr(tr.get_extra_info('socket'), 'getsockname')) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) + + def test_create_connection_host_port_sock(self): + self.suppress_log_errors() + coro = self.event_loop.create_connection( + MyProto, 'xkcd.com', 80, sock=object()) + self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) + + def test_create_connection_no_host_port_sock(self): + self.suppress_log_errors() + coro = self.event_loop.create_connection(MyProto) + self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) + + def test_create_connection_no_getaddrinfo(self): + self.suppress_log_errors() + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + coro = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_create_connection_connect_err(self): + self.suppress_log_errors() + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + coro = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_create_connection_mutiple_errors(self): + self.suppress_log_errors() + + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + self.event_loop.getaddrinfo = getaddrinfo + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + coro = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_start_serving(self): + proto = None + + def factory(): + nonlocal proto + proto = MyProto() + return proto + + f = self.event_loop.start_serving(factory, '0.0.0.0', 0) + sock = self.event_loop.run_until_complete(f) + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + self.event_loop.run_once() + self.assertIsInstance(proto, MyProto) + self.assertEqual('INITIAL', proto.state) + self.event_loop.run_once() + self.assertEqual('CONNECTED', proto.state) + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('socket')) + conn = proto.transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + self.assertEqual( + '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + + # close connection + proto.transport.close() + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + def test_start_serving_sock(self): + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.event_loop.start_serving(MyProto, sock=sock_ob) + sock = self.event_loop.run_until_complete(f) + self.assertIs(sock, sock_ob) + + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + self.event_loop.run_once() # This is quite mysterious, but necessary. + self.event_loop.run_once() + sock.close() + client.close() + + def test_start_serving_host_port_sock(self): + self.suppress_log_errors() + fut = self.event_loop.start_serving( + MyProto, '0.0.0.0', 0, sock=object()) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_start_serving_no_host_port_sock(self): + self.suppress_log_errors() + fut = self.event_loop.start_serving(MyProto) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_start_serving_no_getaddrinfo(self): + self.suppress_log_errors() + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + f = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, f) + + @unittest.mock.patch('tulip.base_events.socket') + def test_start_serving_cant_bind(self, m_socket): + self.suppress_log_errors() + + class Err(socket.error): + pass + + m_socket.error = socket.error + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.setsockopt.side_effect = Err + + fut = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises(Err, self.event_loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_no_addrinfo(self, m_socket): + self.suppress_log_errors() + + m_socket.error = socket.error + m_socket.getaddrinfo.return_value = [] + + coro = self.event_loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 0)) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_create_datagram_endpoint_addr_error(self): + self.suppress_log_errors() + + coro = self.event_loop.create_datagram_endpoint( + MyDatagramProto, local_addr='localhost') + self.assertRaises( + AssertionError, self.event_loop.run_until_complete, coro) + coro = self.event_loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 1, 2, 3)) + self.assertRaises( + AssertionError, self.event_loop.run_until_complete, coro) + + def test_create_datagram_endpoint(self): + class TestMyDatagramProto(MyDatagramProto): + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + coro = self.event_loop.create_datagram_endpoint( + TestMyDatagramProto, local_addr=('127.0.0.1', 0)) + s_transport, server = self.event_loop.run_until_complete(coro) + host, port = s_transport.get_extra_info('addr') + + coro = self.event_loop.create_datagram_endpoint( + MyDatagramProto, remote_addr=(host, port)) + transport, client = self.event_loop.run_until_complete(coro) + + self.assertEqual('INITIALIZED', client.state) + transport.sendto(b'xxx') + self.event_loop.run_once(None) + self.assertEqual(3, server.nbytes) + self.event_loop.run_once(None) + + # received + self.assertEqual(8, client.nbytes) + + # extra info is available + self.assertIsNotNone(transport.get_extra_info('socket')) + conn = transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + + # close connection + transport.close() + + self.assertEqual('CLOSED', client.state) + server.transport.close() + + def test_create_datagram_endpoint_connect_err(self): + self.suppress_log_errors() + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0)) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_socket_err(self, m_socket): + self.suppress_log_errors() + + m_socket.error = socket.error + m_socket.getaddrinfo = socket.getaddrinfo + m_socket.socket.side_effect = socket.error + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, local_addr=('127.0.0.1', 0)) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_create_datagram_endpoint_no_matching_family(self): + self.suppress_log_errors() + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, + remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) + self.assertRaises( + ValueError, self.event_loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_setblk_err(self, m_socket): + self.suppress_log_errors() + + m_socket.error = socket.error + m_socket.socket.return_value.setblocking.side_effect = socket.error + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + self.assertTrue( + m_socket.socket.return_value.close.called) + + def test_create_datagram_endpoint_noaddr_nofamily(self): + self.suppress_log_errors() + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol) + self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_cant_bind(self, m_socket): + self.suppress_log_errors() + + class Err(socket.error): + pass + + m_socket.error = socket.error + m_socket.AF_INET6 = socket.AF_INET6 + m_socket.getaddrinfo = socket.getaddrinfo + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.event_loop.create_datagram_endpoint( + MyDatagramProto, + local_addr=('127.0.0.1', 0), family=socket.AF_INET) + self.assertRaises(Err, self.event_loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_accept_connection_retry(self): + sock = unittest.mock.Mock() + sock.accept.side_effect = BlockingIOError() + + self.event_loop._accept_connection(MyProto, sock) + self.assertFalse(sock.close.called) + + def test_accept_connection_exception(self): + self.suppress_log_errors() + + sock = unittest.mock.Mock() + sock.accept.side_effect = OSError() + + self.event_loop._accept_connection(MyProto, sock) + self.assertTrue(sock.close.called) + + def test_internal_fds(self): + event_loop = self.create_event_loop() + if not isinstance(event_loop, selector_events.BaseSelectorEventLoop): + return + + self.assertEqual(1, event_loop._internal_fds) + event_loop.close() + self.assertEqual(0, event_loop._internal_fds) + self.assertIsNone(event_loop._csock) + self.assertIsNone(event_loop._ssock) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_read_pipe(self): + proto = None + + def factory(): + nonlocal proto + proto = MyReadPipeProto() + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(rpipe, 'rb', 1024) + + @tasks.task + def connect(): + t, p = yield from self.event_loop.connect_read_pipe(factory, + pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(0, proto.nbytes) + + self.event_loop.run_until_complete(connect()) + + os.write(wpipe, b'1') + self.event_loop.run_once() + self.assertEqual(1, proto.nbytes) + + os.write(wpipe, b'2345') + self.event_loop.run_once() + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(5, proto.nbytes) + + os.close(wpipe) + self.event_loop.run_once() + self.assertEqual( + ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto() + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.task + def connect(): + nonlocal transport + t, p = yield from self.event_loop.connect_write_pipe(factory, + pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.event_loop.run_until_complete(connect()) + + transport.write(b'1') + self.event_loop.run_once() + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + transport.write(b'2345') + self.event_loop.run_once() + data = os.read(rpipe, 1024) + self.assertEqual(b'2345', data) + self.assertEqual('CONNECTED', proto.state) + + os.close(rpipe) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + # close connection + proto.transport.close() + self.event_loop.run_once() + self.assertEqual('CLOSED', proto.state) + + +if sys.platform == 'win32': + from tulip import windows_events + + class SelectEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return windows_events.SelectorEventLoop() + + class ProactorEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return windows_events.ProactorEventLoop() + def test_create_ssl_connection(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + def test_reader_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_reader_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_reader_callback_with_handle(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_writer_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_writer_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_writer_callback_with_handle(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_accept_connection_retry(self): + raise unittest.SkipTest( + "IocpEventLoop does not have _accept_connection()") + def test_accept_connection_exception(self): + raise unittest.SkipTest( + "IocpEventLoop does not have _accept_connection()") + def test_create_datagram_endpoint(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_no_connection(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_cant_bind(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_noaddr_nofamily(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_socket_err(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_connect_err(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") +else: + from tulip import selectors + from tulip import unix_events + + if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop( + selectors.KqueueSelector()) + + if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.EpollSelector()) + + if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.PollSelector()) + + # Should always exist. + class SelectEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.SelectSelector()) + + +class HandleTests(unittest.TestCase): + + def test_handle(self): + def callback(*args): + return args + + args = () + h = events.Handle(callback, args) + self.assertIs(h.callback, callback) + self.assertIs(h.args, args) + self.assertFalse(h.cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '<function HandleTests.test_handle.<locals>.callback')) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h.cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '<function HandleTests.test_handle.<locals>.callback')) + self.assertTrue(r.endswith('())<cancelled>')) + + def test_make_handle(self): + def callback(*args): + return args + h1 = events.Handle(callback, ()) + h2 = events.make_handle(h1, ()) + self.assertIs(h1, h2) + + self.assertRaises( + AssertionError, events.make_handle, h1, (1, 2)) + + +class TimerTests(unittest.TestCase): + + def test_timer(self): + def callback(*args): + return args + + args = () + when = time.monotonic() + h = events.Timer(when, callback, args) + self.assertIs(h.callback, callback) + self.assertIs(h.args, args) + self.assertFalse(h.cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h.cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())<cancelled>')) + + self.assertRaises(AssertionError, events.Timer, None, callback, args) + + def test_timer_comparison(self): + def callback(*args): + return args + + when = time.monotonic() + + h1 = events.Timer(when, callback, ()) + h2 = events.Timer(when, callback, ()) + self.assertFalse(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertTrue(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertFalse(h2 > h1) + self.assertTrue(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertTrue(h1 == h2) + self.assertFalse(h1 != h2) + + h2.cancel() + self.assertFalse(h1 == h2) + + h1 = events.Timer(when, callback, ()) + h2 = events.Timer(when + 10.0, callback, ()) + self.assertTrue(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertFalse(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertTrue(h2 > h1) + self.assertFalse(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h3 = events.Handle(callback, ()) + self.assertIs(NotImplemented, h1.__eq__(h3)) + self.assertIs(NotImplemented, h1.__ne__(h3)) + + +class AbstractEventLoopTests(unittest.TestCase): + + def test_not_imlemented(self): + f = unittest.mock.Mock() + ev_loop = events.AbstractEventLoop() + self.assertRaises( + NotImplementedError, ev_loop.run) + self.assertRaises( + NotImplementedError, ev_loop.run_forever) + self.assertRaises( + NotImplementedError, ev_loop.run_once) + self.assertRaises( + NotImplementedError, ev_loop.run_until_complete, None) + self.assertRaises( + NotImplementedError, ev_loop.stop) + self.assertRaises( + NotImplementedError, ev_loop.call_later, None, None) + self.assertRaises( + NotImplementedError, ev_loop.call_repeatedly, None, None) + self.assertRaises( + NotImplementedError, ev_loop.call_soon, None) + self.assertRaises( + NotImplementedError, ev_loop.call_soon_threadsafe, None) + self.assertRaises( + NotImplementedError, ev_loop.wrap_future, f) + self.assertRaises( + NotImplementedError, ev_loop.run_in_executor, f, f) + self.assertRaises( + NotImplementedError, ev_loop.getaddrinfo, 'localhost', 8080) + self.assertRaises( + NotImplementedError, ev_loop.getnameinfo, ('localhost', 8080)) + self.assertRaises( + NotImplementedError, ev_loop.create_connection, f) + self.assertRaises( + NotImplementedError, ev_loop.start_serving, f) + self.assertRaises( + NotImplementedError, ev_loop.create_datagram_endpoint, f) + self.assertRaises( + NotImplementedError, ev_loop.add_reader, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_reader, 1) + self.assertRaises( + NotImplementedError, ev_loop.add_writer, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_writer, 1) + self.assertRaises( + NotImplementedError, ev_loop.sock_recv, f, 10) + self.assertRaises( + NotImplementedError, ev_loop.sock_sendall, f, 10) + self.assertRaises( + NotImplementedError, ev_loop.sock_connect, f, f) + self.assertRaises( + NotImplementedError, ev_loop.sock_accept, f) + self.assertRaises( + NotImplementedError, ev_loop.add_signal_handler, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, ev_loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, ev_loop.connect_read_pipe, f, + unittest.mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, ev_loop.connect_write_pipe, f, + unittest.mock.sentinel.pipe) + + +class ProtocolsAbsTests(unittest.TestCase): + + def test_empty(self): + f = unittest.mock.Mock() + p = protocols.Protocol() + self.assertIsNone(p.connection_made(f)) + self.assertIsNone(p.connection_lost(f)) + self.assertIsNone(p.data_received(f)) + self.assertIsNone(p.eof_received()) + + dp = protocols.DatagramProtocol() + self.assertIsNone(dp.connection_made(f)) + self.assertIsNone(dp.connection_lost(f)) + self.assertIsNone(dp.connection_refused(f)) + self.assertIsNone(dp.datagram_received(f, f)) + + +class PolicyTests(unittest.TestCase): + + def test_event_loop_policy(self): + policy = events.EventLoopPolicy() + self.assertRaises(NotImplementedError, policy.get_event_loop) + self.assertRaises(NotImplementedError, policy.set_event_loop, object()) + self.assertRaises(NotImplementedError, policy.new_event_loop) + + def test_get_event_loop(self): + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy._event_loop) + + event_loop = policy.get_event_loop() + self.assertIsInstance(event_loop, events.AbstractEventLoop) + + self.assertIs(policy._event_loop, event_loop) + self.assertIs(event_loop, policy.get_event_loop()) + + @unittest.mock.patch('tulip.events.threading') + def test_get_event_loop_thread(self, m_threading): + m_t = m_threading.current_thread.return_value = unittest.mock.Mock() + m_t.name = 'Thread 1' + + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy.get_event_loop()) + + def test_new_event_loop(self): + policy = events.DefaultEventLoopPolicy() + + event_loop = policy.new_event_loop() + self.assertIsInstance(event_loop, events.AbstractEventLoop) + + def test_set_event_loop(self): + policy = events.DefaultEventLoopPolicy() + old_event_loop = policy.get_event_loop() + + self.assertRaises(AssertionError, policy.set_event_loop, object()) + + event_loop = policy.new_event_loop() + policy.set_event_loop(event_loop) + self.assertIs(event_loop, policy.get_event_loop()) + self.assertIsNot(old_event_loop, policy.get_event_loop()) + + def test_get_event_loop_policy(self): + policy = events.get_event_loop_policy() + self.assertIsInstance(policy, events.EventLoopPolicy) + self.assertIs(policy, events.get_event_loop_policy()) + + def test_set_event_loop_policy(self): + self.assertRaises( + AssertionError, events.set_event_loop_policy, object()) + + old_policy = events.get_event_loop_policy() + + policy = events.DefaultEventLoopPolicy() + events.set_event_loop_policy(policy) + self.assertIs(policy, events.get_event_loop_policy()) + self.assertIsNot(policy, old_policy) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/futures_test.py b/tests/futures_test.py new file mode 100644 index 0000000..5569cca --- /dev/null +++ b/tests/futures_test.py @@ -0,0 +1,222 @@ +"""Tests for futures.py.""" + +import unittest + +from tulip import futures + + +def _fakefunc(f): + return f + + +class FutureTests(unittest.TestCase): + + def test_initial_state(self): + f = futures.Future() + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertFalse(f.done()) + + def test_init_event_loop_positional(self): + # Make sure Future does't accept a positional argument + self.assertRaises(TypeError, futures.Future, 42) + + def test_cancel(self): + f = futures.Future() + self.assertTrue(f.cancel()) + self.assertTrue(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(futures.CancelledError, f.result) + self.assertRaises(futures.CancelledError, f.exception) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_result(self): + f = futures.Future() + self.assertRaises(futures.InvalidStateError, f.result) + self.assertRaises(futures.InvalidTimeoutError, f.result, 10) + + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_exception(self): + exc = RuntimeError() + f = futures.Future() + self.assertRaises(futures.InvalidStateError, f.exception) + self.assertRaises(futures.InvalidTimeoutError, f.exception, 10) + + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_yield_from_twice(self): + f = futures.Future() + + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + + def test_repr(self): + f_pending = futures.Future() + self.assertEqual(repr(f_pending), 'Future<PENDING>') + + f_cancelled = futures.Future() + f_cancelled.cancel() + self.assertEqual(repr(f_cancelled), 'Future<CANCELLED>') + + f_result = futures.Future() + f_result.set_result(4) + self.assertEqual(repr(f_result), 'Future<result=4>') + + f_exception = futures.Future() + f_exception.set_exception(RuntimeError()) + self.assertEqual(repr(f_exception), 'Future<exception=RuntimeError()>') + + f_few_callbacks = futures.Future() + f_few_callbacks.add_done_callback(_fakefunc) + self.assertIn('Future<PENDING, [<function _fakefunc', + repr(f_few_callbacks)) + + f_many_callbacks = futures.Future() + for i in range(20): + f_many_callbacks.add_done_callback(_fakefunc) + r = repr(f_many_callbacks) + self.assertIn('Future<PENDING, [<function _fakefunc', r) + self.assertIn('<18 more>', r) + + def test_copy_state(self): + # Test the internal _copy_state method since it's being directly + # invoked in other modules. + f = futures.Future() + f.set_result(10) + + newf = futures.Future() + newf._copy_state(f) + self.assertTrue(newf.done()) + self.assertEqual(newf.result(), 10) + + f_exception = futures.Future() + f_exception.set_exception(RuntimeError()) + + newf_exception = futures.Future() + newf_exception._copy_state(f_exception) + self.assertTrue(newf_exception.done()) + self.assertRaises(RuntimeError, newf_exception.result) + + f_cancelled = futures.Future() + f_cancelled.cancel() + + newf_cancelled = futures.Future() + newf_cancelled._copy_state(f_cancelled) + self.assertTrue(newf_cancelled.cancelled()) + + def test_iter(self): + def coro(): + fut = futures.Future() + yield from fut + + def test(): + arg1, arg2 = coro() + + self.assertRaises(AssertionError, test) + + +# A fake event loop for tests. All it does is implement a call_soon method +# that immediately invokes the given function. +class _FakeEventLoop: + def call_soon(self, fn, future): + fn(future) + + +class FutureDoneCallbackTests(unittest.TestCase): + + def _make_callback(self, bag, thing): + # Create a callback function that appends thing to bag. + def bag_appender(future): + bag.append(thing) + return bag_appender + + def _new_future(self): + return futures.Future(event_loop=_FakeEventLoop()) + + def test_callbacks_invoked_on_set_result(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 42)) + f.add_done_callback(self._make_callback(bag, 17)) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def test_callbacks_invoked_on_set_exception(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 100)) + + self.assertEqual(bag, []) + exc = RuntimeError() + f.set_exception(exc) + self.assertEqual(bag, [100]) + self.assertEqual(f.exception(), exc) + + def test_remove_done_callback(self): + bag = [] + f = self._new_future() + cb1 = self._make_callback(bag, 1) + cb2 = self._make_callback(bag, 2) + cb3 = self._make_callback(bag, 3) + + # Add one cb1 and one cb2. + f.add_done_callback(cb1) + f.add_done_callback(cb2) + + # One instance of cb2 removed. Now there's only one cb1. + self.assertEqual(f.remove_done_callback(cb2), 1) + + # Never had any cb3 in there. + self.assertEqual(f.remove_done_callback(cb3), 0) + + # After this there will be 6 instances of cb1 and one of cb2. + f.add_done_callback(cb2) + for i in range(5): + f.add_done_callback(cb1) + + # Remove all instances of cb1. One cb2 remains. + self.assertEqual(f.remove_done_callback(cb1), 6) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [2]) + self.assertEqual(f.result(), 'foo') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py new file mode 100644 index 0000000..74aef7c --- /dev/null +++ b/tests/http_protocol_test.py @@ -0,0 +1,972 @@ +"""Tests for http/protocol.py""" + +import http.client +import unittest +import unittest.mock +import zlib + +import tulip +from tulip.http import protocol +from tulip.test_utils import LogTrackingTestCase + + +class HttpStreamReaderTests(LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.suppress_log_errors() + + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + self.transport = unittest.mock.Mock() + self.stream = protocol.HttpStreamReader() + + def tearDown(self): + self.loop.close() + super().tearDown() + + def test_request_line(self): + self.stream.feed_data(b'get /path HTTP/1.1\r\n') + self.assertEqual( + ('GET', '/path', (1, 1)), + self.loop.run_until_complete(self.stream.read_request_line())) + + def test_request_line_two_slashes(self): + self.stream.feed_data(b'get //path HTTP/1.1\r\n') + self.assertEqual( + ('GET', '//path', (1, 1)), + self.loop.run_until_complete(self.stream.read_request_line())) + + def test_request_line_non_ascii(self): + self.stream.feed_data(b'get /path\xd0\xb0 HTTP/1.1\r\n') + + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete(self.stream.read_request_line()) + + self.assertEqual( + b'get /path\xd0\xb0 HTTP/1.1\r\n', cm.exception.args[0]) + + def test_request_line_bad_status_line(self): + self.stream.feed_data(b'\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_request_line()) + + def test_request_line_bad_method(self): + self.stream.feed_data(b'!12%()+=~$ /get HTTP/1.1\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_request_line()) + + def test_request_line_bad_version(self): + self.stream.feed_data(b'GET //get HT/11\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_request_line()) + + def test_response_status_bad_status_line(self): + self.stream.feed_data(b'\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_response_status()) + + def test_response_status_bad_status_line_eof(self): + self.stream.feed_eof() + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_response_status()) + + def test_response_status_bad_status_non_ascii(self): + self.stream.feed_data(b'HTTP/1.1 200 \xd0\xb0\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete(self.stream.read_response_status()) + + self.assertEqual(b'HTTP/1.1 200 \xd0\xb0\r\n', cm.exception.args[0]) + + def test_response_status_bad_version(self): + self.stream.feed_data(b'HT/11 200 Ok\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete(self.stream.read_response_status()) + + self.assertEqual('HT/11 200 Ok', cm.exception.args[0]) + + def test_response_status_no_reason(self): + self.stream.feed_data(b'HTTP/1.1 200\r\n') + + v, s, r = self.loop.run_until_complete( + self.stream.read_response_status()) + self.assertEqual(v, (1, 1)) + self.assertEqual(s, 200) + self.assertEqual(r, '') + + def test_response_status_bad(self): + self.stream.feed_data(b'HTT/1\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + self.stream.read_response_status()) + + self.assertIn('HTT/1', str(cm.exception)) + + def test_response_status_bad_code_under_100(self): + self.stream.feed_data(b'HTTP/1.1 99 test\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + self.stream.read_response_status()) + + self.assertIn('HTTP/1.1 99 test', str(cm.exception)) + + def test_response_status_bad_code_above_999(self): + self.stream.feed_data(b'HTTP/1.1 9999 test\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + self.stream.read_response_status()) + + self.assertIn('HTTP/1.1 9999 test', str(cm.exception)) + + def test_response_status_bad_code_not_int(self): + self.stream.feed_data(b'HTTP/1.1 ttt test\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + self.stream.read_response_status()) + + self.assertIn('HTTP/1.1 ttt test', str(cm.exception)) + + def test_read_headers(self): + self.stream.feed_data(b'test: line\r\n' + b' continue\r\n' + b'test2: data\r\n' + b'\r\n') + + headers = self.loop.run_until_complete(self.stream.read_headers()) + self.assertEqual(headers, + [('TEST', 'line\r\n continue'), ('TEST2', 'data')]) + + def test_read_headers_size(self): + self.stream.feed_data(b'test: line\r\n') + self.stream.feed_data(b' continue\r\n') + self.stream.feed_data(b'test2: data\r\n') + self.stream.feed_data(b'\r\n') + + self.stream.MAX_HEADERS = 5 + self.assertRaises( + http.client.LineTooLong, + self.loop.run_until_complete, + self.stream.read_headers()) + + def test_read_headers_invalid_header(self): + self.stream.feed_data(b'test line\r\n') + + with self.assertRaises(ValueError) as cm: + self.loop.run_until_complete(self.stream.read_headers()) + + self.assertIn("Invalid header b'test line'", str(cm.exception)) + + def test_read_headers_invalid_name(self): + self.stream.feed_data(b'test[]: line\r\n') + + with self.assertRaises(ValueError) as cm: + self.loop.run_until_complete(self.stream.read_headers()) + + self.assertIn("Invalid header name b'TEST[]'", str(cm.exception)) + + def test_read_headers_headers_size(self): + self.stream.MAX_HEADERFIELD_SIZE = 5 + self.stream.feed_data(b'test: line data data\r\ndata\r\n') + + with self.assertRaises(http.client.LineTooLong) as cm: + self.loop.run_until_complete(self.stream.read_headers()) + + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_read_headers_continuation_headers_size(self): + self.stream.MAX_HEADERFIELD_SIZE = 5 + self.stream.feed_data(b'test: line\r\n test\r\n') + + with self.assertRaises(http.client.LineTooLong) as cm: + self.loop.run_until_complete(self.stream.read_headers()) + + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_read_message_should_close(self): + self.stream.feed_data( + b'Host: example.com\r\nConnection: close\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + self.assertTrue(msg.should_close) + + def test_read_message_should_close_http11(self): + self.stream.feed_data( + b'Host: example.com\r\n\r\n') + + msg = self.loop.run_until_complete( + self.stream.read_message(version=(1, 1))) + self.assertFalse(msg.should_close) + + def test_read_message_should_close_http10(self): + self.stream.feed_data( + b'Host: example.com\r\n\r\n') + + msg = self.loop.run_until_complete( + self.stream.read_message(version=(1, 0))) + self.assertTrue(msg.should_close) + + def test_read_message_should_close_keep_alive(self): + self.stream.feed_data( + b'Host: example.com\r\nConnection: keep-alive\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + self.assertFalse(msg.should_close) + + def test_read_message_content_length_broken(self): + self.stream.feed_data( + b'Host: example.com\r\nContent-Length: qwe\r\n\r\n') + + self.assertRaises( + http.client.HTTPException, + self.loop.run_until_complete, + self.stream.read_message()) + + def test_read_message_content_length_wrong(self): + self.stream.feed_data( + b'Host: example.com\r\nContent-Length: -1\r\n\r\n') + + self.assertRaises( + http.client.HTTPException, + self.loop.run_until_complete, + self.stream.read_message()) + + def test_read_message_content_length(self): + self.stream.feed_data( + b'Host: example.com\r\nContent-Length: 2\r\n\r\n12') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'12', payload) + + def test_read_message_content_length_no_val(self): + self.stream.feed_data(b'Host: example.com\r\n\r\n12') + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=False)) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'', payload) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_read_message_deflate(self): + self.stream.feed_data( + ('Host: example.com\r\nContent-Length: %s\r\n' + 'Content-Encoding: deflate\r\n\r\n' % + len(self._COMPRESSED)).encode()) + self.stream.feed_data(self._COMPRESSED) + + msg = self.loop.run_until_complete(self.stream.read_message()) + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'data', payload) + + def test_read_message_deflate_disabled(self): + self.stream.feed_data( + ('Host: example.com\r\nContent-Encoding: deflate\r\n' + 'Content-Length: %s\r\n\r\n' % + len(self._COMPRESSED)).encode()) + self.stream.feed_data(self._COMPRESSED) + + msg = self.loop.run_until_complete( + self.stream.read_message(compression=False)) + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(self._COMPRESSED, payload) + + def test_read_message_deflate_unknown(self): + self.stream.feed_data( + ('Host: example.com\r\nContent-Encoding: compress\r\n' + 'Content-Length: %s\r\n\r\n' % len(self._COMPRESSED)).encode()) + self.stream.feed_data(self._COMPRESSED) + + msg = self.loop.run_until_complete( + self.stream.read_message(compression=False)) + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(self._COMPRESSED, payload) + + def test_read_message_websocket(self): + self.stream.feed_data( + b'Host: example.com\r\nSec-Websocket-Key1: 13\r\n\r\n1234567890') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'12345678', payload) + + def test_read_message_chunked(self): + self.stream.feed_data( + b'Host: example.com\r\nTransfer-Encoding: chunked\r\n\r\n') + self.stream.feed_data( + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'dataline', payload) + + def test_read_message_readall_eof(self): + self.stream.feed_data( + b'Host: example.com\r\n\r\n') + self.stream.feed_data(b'data') + self.stream.feed_data(b'line') + self.stream.feed_eof() + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'dataline', payload) + + def test_read_message_payload(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 8\r\n\r\n') + self.stream.feed_data(b'data') + self.stream.feed_data(b'data') + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + data = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'datadata', data) + + def test_read_message_payload_eof(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 4\r\n\r\n') + self.stream.feed_data(b'da') + self.stream.feed_eof() + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, msg.payload.read()) + + def test_read_message_length_payload_zero(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 0\r\n\r\n') + self.stream.feed_data(b'data') + + msg = self.loop.run_until_complete(self.stream.read_message()) + data = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'', data) + + def test_read_message_length_payload_incomplete(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 8\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data(b'data') + self.stream.feed_eof() + return (yield from msg.payload.read()) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, coro()) + + def test_read_message_eof_payload(self): + self.stream.feed_data(b'Host: example.com\r\n\r\n') + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + def coro(): + self.stream.feed_data(b'data') + self.stream.feed_eof() + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'data', data) + + def test_read_message_length_payload(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 4\r\n\r\n') + self.stream.feed_data(b'da') + self.stream.feed_data(b't') + self.stream.feed_data(b'ali') + self.stream.feed_data(b'ne') + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + self.assertIsInstance(msg.payload, tulip.StreamReader) + + data = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'data', data) + self.assertEqual(b'line', b''.join(self.stream.buffer)) + + def test_read_message_length_payload_extra(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 4\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data(b'da') + self.stream.feed_data(b't') + self.stream.feed_data(b'ali') + self.stream.feed_data(b'ne') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'data', data) + self.assertEqual(b'line', b''.join(self.stream.buffer)) + + def test_parse_length_payload_eof_exc(self): + parser = self.stream._parse_length_payload(4) + next(parser) + + stream = tulip.StreamReader() + parser.send(stream) + self.stream._parser = parser + self.stream.feed_data(b'da') + + def eof(): + yield from [] + self.stream.feed_eof() + + t1 = tulip.Task(stream.read()) + t2 = tulip.Task(eof()) + + self.loop.run_until_complete(tulip.wait([t1, t2])) + self.assertRaises(http.client.IncompleteRead, t1.result) + self.assertIsNone(self.stream._parser) + + def test_read_message_deflate_payload(self): + comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + + data = b''.join([comp.compress(b'data'), comp.flush()]) + + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Encoding: deflate\r\n' + + ('Content-Length: %s\r\n\r\n' % len(data)).encode()) + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + def coro(): + self.stream.feed_data(data) + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'data', data) + + def test_read_message_chunked_payload(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data( + b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'dataline', data) + + def test_read_message_chunked_payload_chunks(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data(b'4\r\ndata\r') + self.stream.feed_data(b'\n4') + self.stream.feed_data(b'\r') + self.stream.feed_data(b'\n') + self.stream.feed_data(b'line\r\n0\r\n') + self.stream.feed_data(b'test\r\n\r\n') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'dataline', data) + + def test_read_message_chunked_payload_incomplete(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data(b'4\r\ndata\r\n') + self.stream.feed_eof() + return (yield from msg.payload.read()) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, coro()) + + def test_read_message_chunked_payload_extension(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data( + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'dataline', data) + + def test_read_message_chunked_payload_size_error(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data(b'blah\r\n') + return (yield from msg.payload.read()) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, coro()) + + def test_deflate_stream_set_exception(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + exc = ValueError() + dstream.set_exception(exc) + self.assertIs(exc, stream.exception()) + + def test_deflate_stream_feed_data(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + dstream.zlib = unittest.mock.Mock() + dstream.zlib.decompress.return_value = b'line' + + dstream.feed_data(b'data') + self.assertEqual([b'line'], list(stream.buffer)) + + def test_deflate_stream_feed_data_err(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + exc = ValueError() + dstream.zlib = unittest.mock.Mock() + dstream.zlib.decompress.side_effect = exc + + dstream.feed_data(b'data') + self.assertIsInstance(stream.exception(), http.client.IncompleteRead) + + def test_deflate_stream_feed_eof(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + dstream.zlib = unittest.mock.Mock() + dstream.zlib.flush.return_value = b'line' + + dstream.feed_eof() + self.assertEqual([b'line'], list(stream.buffer)) + self.assertTrue(stream.eof) + + def test_deflate_stream_feed_eof_err(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + dstream.zlib = unittest.mock.Mock() + dstream.zlib.flush.return_value = b'line' + dstream.zlib.eof = False + + dstream.feed_eof() + self.assertIsInstance(stream.exception(), http.client.IncompleteRead) + + +class HttpMessageTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + + def test_start_request(self): + msg = protocol.Request( + self.transport, 'GET', '/index.html', close=True) + + self.assertIs(msg.transport, self.transport) + self.assertIsNone(msg.status) + self.assertTrue(msg.closing) + self.assertEqual(msg.status_line, 'GET /index.html HTTP/1.1\r\n') + + def test_start_response(self): + msg = protocol.Response(self.transport, 200, close=True) + + self.assertIs(msg.transport, self.transport) + self.assertEqual(msg.status, 200) + self.assertTrue(msg.closing) + self.assertEqual(msg.status_line, 'HTTP/1.1 200 OK\r\n') + + def test_force_close(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.closing) + msg.force_close() + self.assertTrue(msg.closing) + + def test_force_chunked(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.chunked) + msg.force_chunked() + self.assertTrue(msg.chunked) + + def test_keep_alive(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.keep_alive()) + msg.keepalive = True + self.assertTrue(msg.keep_alive()) + + msg.force_close() + self.assertFalse(msg.keep_alive()) + + def test_add_header(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], msg.headers) + + msg.add_header('content-type', 'plain/html') + self.assertEqual([('CONTENT-TYPE', 'plain/html')], msg.headers) + + def test_add_headers(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], msg.headers) + + msg.add_headers(('content-type', 'plain/html')) + self.assertEqual([('CONTENT-TYPE', 'plain/html')], msg.headers) + + def test_add_headers_length(self): + msg = protocol.Response(self.transport, 200) + self.assertIsNone(msg.length) + + msg.add_headers(('content-length', '200')) + self.assertEqual(200, msg.length) + + def test_add_headers_upgrade(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.upgrade) + + msg.add_headers(('connection', 'upgrade')) + self.assertTrue(msg.upgrade) + + def test_add_headers_upgrade_websocket(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('upgrade', 'test')) + self.assertEqual([], msg.headers) + + msg.add_headers(('upgrade', 'websocket')) + self.assertEqual([('UPGRADE', 'websocket')], msg.headers) + + def test_add_headers_connection_keepalive(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('connection', 'keep-alive')) + self.assertEqual([], msg.headers) + self.assertTrue(msg.keepalive) + + msg.add_headers(('connection', 'close')) + self.assertFalse(msg.keepalive) + + def test_add_headers_hop_headers(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('connection', 'test'), ('transfer-encoding', 't')) + self.assertEqual([], msg.headers) + + def test_default_headers(self): + msg = protocol.Response(self.transport, 200) + + headers = [r for r, _ in msg._default_headers()] + self.assertIn('DATE', headers) + self.assertIn('CONNECTION', headers) + + def test_default_headers_server(self): + msg = protocol.Response(self.transport, 200) + + headers = [r for r, _ in msg._default_headers()] + self.assertIn('SERVER', headers) + + def test_default_headers_useragent(self): + msg = protocol.Request(self.transport, 'GET', '/') + + headers = [r for r, _ in msg._default_headers()] + self.assertNotIn('SERVER', headers) + self.assertIn('USER-AGENT', headers) + + def test_default_headers_chunked(self): + msg = protocol.Response(self.transport, 200) + + headers = [r for r, _ in msg._default_headers()] + self.assertNotIn('TRANSFER-ENCODING', headers) + + msg.force_chunked() + + headers = [r for r, _ in msg._default_headers()] + self.assertIn('TRANSFER-ENCODING', headers) + + def test_default_headers_connection_upgrade(self): + msg = protocol.Response(self.transport, 200) + msg.upgrade = True + + headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'upgrade')], headers) + + def test_default_headers_connection_close(self): + msg = protocol.Response(self.transport, 200) + msg.force_close() + + headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'close')], headers) + + def test_default_headers_connection_keep_alive(self): + msg = protocol.Response(self.transport, 200) + msg.keepalive = True + + headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'keep-alive')], headers) + + def test_send_headers(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-type', 'plain/html')) + self.assertFalse(msg.is_headers_sent()) + + msg.send_headers() + + content = b''.join([arg[1][0] for arg in list(write.mock_calls)]) + + self.assertTrue(content.startswith(b'HTTP/1.1 200 OK\r\n')) + self.assertIn(b'CONTENT-TYPE: plain/html', content) + self.assertTrue(msg.headers_sent) + self.assertTrue(msg.is_headers_sent()) + + def test_send_headers_nomore_add(self): + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-type', 'plain/html')) + msg.send_headers() + + self.assertRaises(AssertionError, + msg.add_header, 'content-type', 'plain/html') + + def test_prepare_length(self): + msg = protocol.Response(self.transport, 200) + length = msg._write_length_payload = unittest.mock.Mock() + length.return_value = iter([1, 2, 3]) + + msg.add_headers(('content-length', '200')) + msg.send_headers() + + self.assertTrue(length.called) + self.assertTrue((200,), length.call_args[0]) + + def test_prepare_chunked_force(self): + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + + chunked = msg._write_chunked_payload = unittest.mock.Mock() + chunked.return_value = iter([1, 2, 3]) + + msg.add_headers(('content-length', '200')) + msg.send_headers() + self.assertTrue(chunked.called) + + def test_prepare_chunked_no_length(self): + msg = protocol.Response(self.transport, 200) + + chunked = msg._write_chunked_payload = unittest.mock.Mock() + chunked.return_value = iter([1, 2, 3]) + + msg.send_headers() + self.assertTrue(chunked.called) + + def test_prepare_eof(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + + eof = msg._write_eof_payload = unittest.mock.Mock() + eof.return_value = iter([1, 2, 3]) + + msg.send_headers() + self.assertTrue(eof.called) + + def test_write_auto_send_headers(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg._send_headers = True + + msg.write(b'data1') + self.assertTrue(msg.headers_sent) + + def test_write_payload_eof(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg.send_headers() + + msg.write(b'data1') + self.assertTrue(msg.headers_sent) + + msg.write(b'data2') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'data1data2', content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg.send_headers() + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'4\r\ndata\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_multiple(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg.send_headers() + + msg.write(b'data1') + msg.write(b'data2') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'5\r\ndata1\r\n5\r\ndata2\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_length(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '2')) + msg.send_headers() + + msg.write(b'd') + msg.write(b'ata') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'da', content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_filter(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(2) + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith(b'2\r\nda\r\n2\r\nta\r\n0\r\n\r\n')) + + def test_write_payload_chunked_filter_mutiple_chunks(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(2) + msg.write(b'data1') + msg.write(b'data2') + msg.write_eof() + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith( + b'2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n' + b'2\r\na2\r\n0\r\n\r\n')) + + def test_write_payload_chunked_large_chunk(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(1024) + msg.write(b'data') + msg.write_eof() + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith(b'4\r\ndata\r\n0\r\n\r\n')) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_write_payload_deflate_filter(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '%s' % len(self._COMPRESSED))) + msg.send_headers() + + msg.add_compression_filter('deflate') + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_deflate_and_chunked(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_compression_filter('deflate') + msg.add_chunking_filter(2) + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'2\r\nKI\r\n2\r\n,I\r\n2\r\n\x04\x00\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_and_deflate(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '%s' % len(self._COMPRESSED))) + + msg.add_chunking_filter(2) + msg.add_compression_filter('deflate') + msg.send_headers() + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) diff --git a/tests/http_server_test.py b/tests/http_server_test.py new file mode 100644 index 0000000..dc55eff --- /dev/null +++ b/tests/http_server_test.py @@ -0,0 +1,242 @@ +"""Tests for http/server.py""" + +import unittest +import unittest.mock + +import tulip +from tulip.http import server +from tulip.http import errors +from tulip.test_utils import LogTrackingTestCase + + +class HttpServerProtocolTests(LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.suppress_log_errors() + + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + super().tearDown() + + def test_http_status_exception(self): + exc = errors.HttpStatusException(500, message='Internal error') + self.assertEqual(exc.code, 500) + self.assertEqual(exc.message, 'Internal error') + + def test_handle_request(self): + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + rline = unittest.mock.Mock() + rline.version = (1, 1) + message = unittest.mock.Mock() + srv.handle_request(rline, message) + + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertTrue(content.startswith(b'HTTP/1.1 404 Not Found\r\n')) + + def test_connection_made(self): + srv = server.ServerHttpProtocol() + self.assertIsNone(srv._request_handle) + + srv.connection_made(unittest.mock.Mock()) + self.assertIsNotNone(srv._request_handle) + + def test_data_received(self): + srv = server.ServerHttpProtocol() + srv.connection_made(unittest.mock.Mock()) + + srv.data_received(b'123') + self.assertEqual(b'123', b''.join(srv.stream.buffer)) + + srv.data_received(b'456') + self.assertEqual(b'123456', b''.join(srv.stream.buffer)) + + def test_eof_received(self): + srv = server.ServerHttpProtocol() + srv.connection_made(unittest.mock.Mock()) + srv.eof_received() + self.assertTrue(srv.stream.eof) + + def test_connection_lost(self): + srv = server.ServerHttpProtocol() + srv.connection_made(unittest.mock.Mock()) + srv.data_received(b'123') + + handle = srv._request_handle + srv.connection_lost(None) + + self.assertIsNone(srv._request_handle) + self.assertTrue(handle.cancelled()) + + srv.connection_lost(None) + self.assertIsNone(srv._request_handle) + + def test_close(self): + srv = server.ServerHttpProtocol() + self.assertFalse(srv.closing) + + srv.close() + self.assertTrue(srv.closing) + + def test_handle_error(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + srv.handle_error(404) + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertIn(b'HTTP/1.1 404 Not Found', content) + + @unittest.mock.patch('tulip.http.server.traceback') + def test_handle_error_traceback_exc(self, m_trace): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(debug=True) + srv.connection_made(transport) + + m_trace.format_exc.side_effect = ValueError + + srv.handle_error(500, exc=object()) + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertTrue( + content.startswith(b'HTTP/1.1 500 Internal Server Error')) + + def test_handle_error_debug(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.debug = True + srv.connection_made(transport) + + try: + raise ValueError() + except Exception as exc: + srv.handle_error(999, exc=exc) + + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + + self.assertIn(b'HTTP/1.1 500 Internal', content) + self.assertIn(b'Traceback (most recent call last):', content) + + def test_handle_error_500(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log) + srv.connection_made(transport) + + srv.handle_error(500) + self.assertTrue(log.exception.called) + + def test_handle(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + srv.close() + + self.loop.run_until_complete(srv._request_handle) + self.assertTrue(handle.called) + self.assertIsNone(srv._request_handle) + + def test_handle_coro(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + called = False + + @tulip.coroutine + def coro(rline, message): + nonlocal called + called = True + yield from [] + srv.eof_received() + + srv.handle_request = coro + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + self.loop.run_until_complete(srv._request_handle) + self.assertTrue(called) + + def test_handle_close(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + srv.close() + self.loop.run_until_complete(srv._request_handle) + + self.assertTrue(handle.called) + self.assertTrue(transport.close.called) + + def test_handle_cancel(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log, debug=True) + srv.connection_made(transport) + + srv.handle_request = unittest.mock.Mock() + + @tulip.task + def cancel(): + yield from [] + srv._request_handle.cancel() + + srv.close() + self.loop.run_until_complete( + tulip.wait([srv._request_handle, cancel()])) + self.assertTrue(log.debug.called) + + def test_handle_400(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + srv.handle_error = unittest.mock.Mock() + + def side_effect(*args): + srv.close() + srv.handle_error.side_effect = side_effect + + srv.stream.feed_data(b'GET / HT/asd\r\n') + + self.loop.run_until_complete(srv._request_handle) + self.assertTrue(srv.handle_error.called) + self.assertTrue(400, srv.handle_error.call_args[0][0]) + self.assertTrue(transport.close.called) + + def test_handle_500(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + handle.side_effect = ValueError + srv.handle_error = unittest.mock.Mock() + srv.close() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + self.loop.run_until_complete(srv._request_handle) + + self.assertTrue(srv.handle_error.called) + self.assertTrue(500, srv.handle_error.call_args[0][0]) diff --git a/tests/locks_test.py b/tests/locks_test.py new file mode 100644 index 0000000..7d2111d --- /dev/null +++ b/tests/locks_test.py @@ -0,0 +1,747 @@ +"""Tests for lock.py""" + +import time +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import locks +from tulip import tasks +from tulip import test_utils + + +class LockTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + lock = locks.Lock() + self.assertTrue(repr(lock).endswith('[unlocked]>')) + + def acquire_lock(): + yield from lock + + self.event_loop.run_until_complete(acquire_lock()) + self.assertTrue(repr(lock).endswith('[locked]>')) + + def test_lock(self): + lock = locks.Lock() + + def acquire_lock(): + return (yield from lock) + + res = self.event_loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_acquire(self): + lock = locks.Lock() + result = [] + + self.assertTrue( + self.event_loop.run_until_complete(lock.acquire())) + + @tasks.coroutine + def c1(result): + if (yield from lock.acquire()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from lock.acquire()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from lock.acquire()): + result.append(3) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + self.event_loop.run_once() + self.assertEqual([1], result) + + tasks.Task(c3(result)) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + + def test_acquire_timeout(self): + lock = locks.Lock() + self.assertTrue( + self.event_loop.run_until_complete(lock.acquire())) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete( + lock.acquire(timeout=0.1)) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + lock = locks.Lock() + self.event_loop.run_until_complete(lock.acquire()) + + self.event_loop.call_later(0.1, lock.release) + acquired = self.event_loop.run_until_complete(lock.acquire(10.1)) + self.assertTrue(acquired) + + def test_acquire_timeout_mixed(self): + lock = locks.Lock() + self.event_loop.run_until_complete(lock.acquire()) + tasks.Task(lock.acquire()) + tasks.Task(lock.acquire()) + acquire_task = tasks.Task(lock.acquire(0.1)) + tasks.Task(lock.acquire()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(lock._waiters)) + + def test_acquire_cancel(self): + lock = locks.Lock() + self.assertTrue( + self.event_loop.run_until_complete(lock.acquire())) + + task = tasks.Task(lock.acquire()) + self.event_loop.call_soon(task.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, task) + self.assertFalse(lock._waiters) + + def test_release_not_acquired(self): + lock = locks.Lock() + + self.assertRaises(RuntimeError, lock.release) + + def test_release_no_waiters(self): + lock = locks.Lock() + self.event_loop.run_until_complete(lock.acquire()) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_context_manager(self): + lock = locks.Lock() + + @tasks.task + def acquire_lock(): + return (yield from lock) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + def test_context_manager_no_yield(self): + lock = locks.Lock() + + try: + with lock: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + +class EventWaiterTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + ev = locks.EventWaiter() + self.assertTrue(repr(ev).endswith('[unset]>')) + + ev.set() + self.assertTrue(repr(ev).endswith('[set]>')) + + def test_wait(self): + ev = locks.EventWaiter() + self.assertFalse(ev.is_set()) + + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from ev.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from ev.wait()): + result.append(3) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + tasks.Task(c3(result)) + + ev.set() + self.event_loop.run_once() + self.assertEqual([3, 1, 2], result) + + def test_wait_on_set(self): + ev = locks.EventWaiter() + ev.set() + + res = self.event_loop.run_until_complete(ev.wait()) + self.assertTrue(res) + + def test_wait_timeout(self): + ev = locks.EventWaiter() + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(ev.wait(0.1)) + self.assertFalse(res) + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + ev = locks.EventWaiter() + self.event_loop.call_later(0.1, ev.set) + acquired = self.event_loop.run_until_complete(ev.wait(10.1)) + self.assertTrue(acquired) + + def test_wait_timeout_mixed(self): + ev = locks.EventWaiter() + tasks.Task(ev.wait()) + tasks.Task(ev.wait()) + acquire_task = tasks.Task(ev.wait(0.1)) + tasks.Task(ev.wait()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(ev._waiters)) + + def test_wait_cancel(self): + ev = locks.EventWaiter() + + wait = tasks.Task(ev.wait()) + self.event_loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, wait) + self.assertFalse(ev._waiters) + + def test_clear(self): + ev = locks.EventWaiter() + self.assertFalse(ev.is_set()) + + ev.set() + self.assertTrue(ev.is_set()) + + ev.clear() + self.assertFalse(ev.is_set()) + + def test_clear_with_waiters(self): + ev = locks.EventWaiter() + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + tasks.Task(c1(result)) + self.event_loop.run_once() + self.assertEqual([], result) + + ev.set() + ev.clear() + self.assertFalse(ev.is_set()) + + ev.set() + ev.set() + self.assertEqual(1, len(ev._waiters)) + + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertEqual(0, len(ev._waiters)) + + +class ConditionTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_wait(self): + cond = locks.Condition() + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue( + self.event_loop.run_until_complete(cond.acquire())) + cond.notify() + self.event_loop.run_once() + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + def test_wait_timeout(self): + cond = locks.Condition() + self.event_loop.run_until_complete(cond.acquire()) + + t0 = time.monotonic() + wait = self.event_loop.run_until_complete(cond.wait(0.1)) + self.assertFalse(wait) + self.assertTrue(cond.locked()) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + def test_wait_cancel(self): + cond = locks.Condition() + self.event_loop.run_until_complete(cond.acquire()) + + wait = tasks.Task(cond.wait()) + self.event_loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, wait) + self.assertFalse(cond._condition_waiters) + self.assertTrue(cond.locked()) + + def test_wait_unacquired(self): + self.suppress_log_errors() + + cond = locks.Condition() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, cond.wait()) + + def test_wait_for(self): + cond = locks.Condition() + presult = False + + def predicate(): + return presult + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate)): + result.append(1) + cond.release() + + tasks.Task(c1(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([], result) + + presult = True + self.event_loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + def test_wait_for_timeout(self): + cond = locks.Condition() + + result = [] + + predicate = unittest.mock.Mock() + predicate.return_value = False + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate, 0.2)): + result.append(1) + else: + result.append(2) + cond.release() + + wait_for = tasks.Task(c1(result)) + + t0 = time.monotonic() + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(wait_for) + self.assertEqual([2], result) + self.assertEqual(3, predicate.call_count) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.18 < total_time < 0.22) + + def test_wait_for_unacquired(self): + self.suppress_log_errors() + + cond = locks.Condition() + + # predicate can return true immediately + res = self.event_loop.run_until_complete( + cond.wait_for(lambda: [1, 2, 3])) + self.assertEqual([1, 2, 3], res) + + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, + cond.wait_for(lambda: False)) + + def test_notify(self): + cond = locks.Condition() + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + cond.release() + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.notify(2048) + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + + def test_notify_all(self): + cond = locks.Condition() + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify_all() + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + + def test_notify_unacquired(self): + cond = locks.Condition() + self.assertRaises(RuntimeError, cond.notify) + + def test_notify_all_unacquired(self): + cond = locks.Condition() + self.assertRaises(RuntimeError, cond.notify_all) + + +class SemaphoreTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + sem = locks.Semaphore() + self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) + + self.event_loop.run_until_complete(sem.acquire()) + self.assertTrue(repr(sem).endswith('[locked]>')) + + def test_semaphore(self): + sem = locks.Semaphore() + self.assertEqual(1, sem._value) + + @tasks.task + def acquire_lock(): + return (yield from sem) + + res = self.event_loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(sem.locked()) + self.assertEqual(0, sem._value) + + sem.release() + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + def test_semaphore_value(self): + self.assertRaises(ValueError, locks.Semaphore, -1) + + def test_acquire(self): + sem = locks.Semaphore(3) + result = [] + + self.assertTrue( + self.event_loop.run_until_complete(sem.acquire())) + self.assertTrue( + self.event_loop.run_until_complete(sem.acquire())) + self.assertFalse(sem.locked()) + + @tasks.coroutine + def c1(result): + yield from sem.acquire() + result.append(1) + + @tasks.coroutine + def c2(result): + yield from sem.acquire() + result.append(2) + + @tasks.coroutine + def c3(result): + yield from sem.acquire() + result.append(3) + + @tasks.coroutine + def c4(result): + yield from sem.acquire() + result.append(4) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(sem.locked()) + self.assertEqual(2, len(sem._waiters)) + self.assertEqual(0, sem._value) + + tasks.Task(c4(result)) + + sem.release() + sem.release() + self.assertEqual(2, sem._value) + + self.event_loop.run_once() + self.assertEqual(0, sem._value) + self.assertEqual([1, 2, 3], result) + self.assertTrue(sem.locked()) + self.assertEqual(1, len(sem._waiters)) + self.assertEqual(0, sem._value) + + def test_acquire_timeout(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(sem.acquire(0.1)) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + + self.event_loop.call_later(0.1, sem.release) + acquired = self.event_loop.run_until_complete(sem.acquire(10.1)) + self.assertTrue(acquired) + + def test_acquire_timeout_mixed(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + tasks.Task(sem.acquire()) + tasks.Task(sem.acquire()) + acquire_task = tasks.Task(sem.acquire(0.1)) + tasks.Task(sem.acquire()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(sem._waiters)) + + def test_acquire_cancel(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + + acquire = tasks.Task(sem.acquire()) + self.event_loop.call_soon(acquire.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, acquire) + self.assertFalse(sem._waiters) + + def test_release_not_acquired(self): + sem = locks.Semaphore(bound=True) + + self.assertRaises(ValueError, sem.release) + + def test_release_no_waiters(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + self.assertTrue(sem.locked()) + + sem.release() + self.assertFalse(sem.locked()) + + def test_context_manager(self): + sem = locks.Semaphore(2) + + @tasks.task + def acquire_lock(): + return (yield from sem) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertTrue(sem.locked()) + + self.assertEqual(2, sem._value) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/queues_test.py b/tests/queues_test.py new file mode 100644 index 0000000..d86abd7 --- /dev/null +++ b/tests/queues_test.py @@ -0,0 +1,370 @@ +"""Tests for queues.py""" + +import unittest +import queue + +from tulip import events +from tulip import locks +from tulip import queues +from tulip import tasks + + +class _QueueTestBase(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + +class QueueBasicTests(_QueueTestBase): + + def _test_repr_or_str(self, fn, expect_id): + """Test Queue's repr or str. + + fn is repr or str. expect_id is True if we expect the Queue's id to + appear in fn(Queue()). + """ + q = queues.Queue() + self.assertTrue(fn(q).startswith('<Queue')) + id_is_present = hex(id(q)) in fn(q) + self.assertEqual(expect_id, id_is_present) + + @tasks.coroutine + def add_getter(): + q = queues.Queue() + # Start a task that waits to get. + tasks.Task(q.get()) + # Let it start waiting. + yield from tasks.sleep(0.1) + self.assertTrue('_getters[1]' in fn(q)) + + self.event_loop.run_until_complete(add_getter()) + + @tasks.coroutine + def add_putter(): + q = queues.Queue(maxsize=1) + q.put_nowait(1) + # Start a task that waits to put. + tasks.Task(q.put(2)) + # Let it start waiting. + yield from tasks.sleep(0.1) + self.assertTrue('_putters[1]' in fn(q)) + + self.event_loop.run_until_complete(add_putter()) + + q = queues.Queue() + q.put_nowait(1) + self.assertTrue('_queue=[1]' in fn(q)) + + def test_repr(self): + self._test_repr_or_str(repr, True) + + def test_str(self): + self._test_repr_or_str(str, False) + + def test_empty(self): + q = queues.Queue() + self.assertTrue(q.empty()) + q.put_nowait(1) + self.assertFalse(q.empty()) + self.assertEqual(1, q.get_nowait()) + self.assertTrue(q.empty()) + + def test_full(self): + q = queues.Queue() + self.assertFalse(q.full()) + + q = queues.Queue(maxsize=1) + q.put_nowait(1) + self.assertTrue(q.full()) + + def test_order(self): + q = queues.Queue() + for i in [1, 3, 2]: + q.put_nowait(i) + + items = [q.get_nowait() for _ in range(3)] + self.assertEqual([1, 3, 2], items) + + def test_maxsize(self): + q = queues.Queue(maxsize=2) + self.assertEqual(2, q.maxsize) + have_been_put = [] + + @tasks.coroutine + def putter(): + for i in range(3): + yield from q.put(i) + have_been_put.append(i) + + @tasks.coroutine + def test(): + tasks.Task(putter()) + yield from tasks.sleep(0.1) + + # The putter is blocked after putting two items. + self.assertEqual([0, 1], have_been_put) + self.assertEqual(0, q.get_nowait()) + + # Let the putter resume and put last item. + yield from tasks.sleep(0.1) + self.assertEqual([0, 1, 2], have_been_put) + self.assertEqual(1, q.get_nowait()) + self.assertEqual(2, q.get_nowait()) + + self.event_loop.run_until_complete(test()) + + +class QueueGetTests(_QueueTestBase): + + def test_blocking_get(self): + q = queues.Queue() + q.put_nowait(1) + + @tasks.coroutine + def queue_get(): + return (yield from q.get()) + + res = self.event_loop.run_until_complete(queue_get()) + self.assertEqual(1, res) + + def test_blocking_get_wait(self): + q = queues.Queue() + started = locks.EventWaiter() + finished = False + + @tasks.coroutine + def queue_get(): + nonlocal finished + started.set() + res = yield from q.get() + finished = True + return res + + @tasks.coroutine + def queue_put(): + self.event_loop.call_later(0.1, q.put_nowait, 1) + queue_get_task = tasks.Task(queue_get()) + yield from started.wait() + self.assertFalse(finished) + res = yield from queue_get_task + self.assertTrue(finished) + return res + + res = self.event_loop.run_until_complete(queue_put()) + self.assertEqual(1, res) + + def test_nonblocking_get(self): + q = queues.Queue() + q.put_nowait(1) + self.assertEqual(1, q.get_nowait()) + + def test_nonblocking_get_exception(self): + q = queues.Queue() + self.assertRaises(queue.Empty, q.get_nowait) + + def test_get_timeout(self): + q = queues.Queue() + + @tasks.coroutine + def queue_get(): + with self.assertRaises(queue.Empty): + return (yield from q.get(timeout=0.1)) + + # Get works after timeout, with blocking and non-blocking put. + q.put_nowait(1) + self.assertEqual(1, (yield from q.get())) + + tasks.Task(q.put(2)) + self.assertEqual(2, (yield from q.get())) + + self.event_loop.run_until_complete(queue_get()) + + def test_get_timeout_cancelled(self): + q = queues.Queue() + + @tasks.coroutine + def queue_get(): + return (yield from q.get(timeout=0.2)) + + @tasks.coroutine + def test(): + get_task = tasks.Task(queue_get()) + yield from tasks.sleep(0.1) # let the task start + q.put_nowait(1) + return (yield from get_task) + + self.assertEqual(1, self.event_loop.run_until_complete(test())) + + +class QueuePutTests(_QueueTestBase): + + def test_blocking_put(self): + q = queues.Queue() + + @tasks.coroutine + def queue_put(): + # No maxsize, won't block. + yield from q.put(1) + + self.event_loop.run_until_complete(queue_put()) + + def test_blocking_put_wait(self): + q = queues.Queue(maxsize=1) + started = locks.EventWaiter() + finished = False + + @tasks.coroutine + def queue_put(): + nonlocal finished + started.set() + yield from q.put(1) + yield from q.put(2) + finished = True + + @tasks.coroutine + def queue_get(): + self.event_loop.call_later(0.1, q.get_nowait) + queue_put_task = tasks.Task(queue_put()) + yield from started.wait() + self.assertFalse(finished) + yield from queue_put_task + self.assertTrue(finished) + + self.event_loop.run_until_complete(queue_get()) + + def test_nonblocking_put(self): + q = queues.Queue() + q.put_nowait(1) + self.assertEqual(1, q.get_nowait()) + + def test_nonblocking_put_exception(self): + q = queues.Queue(maxsize=1) + q.put_nowait(1) + self.assertRaises(queue.Full, q.put_nowait, 2) + + def test_put_timeout(self): + q = queues.Queue(1) + q.put_nowait(0) + + @tasks.coroutine + def queue_put(): + with self.assertRaises(queue.Full): + return (yield from q.put(1, timeout=0.1)) + + self.assertEqual(0, q.get_nowait()) + + # Put works after timeout, with blocking and non-blocking get. + get_task = tasks.Task(q.get()) + # Let the get start waiting. + yield from tasks.sleep(0.1) + q.put_nowait(2) + self.assertEqual(2, (yield from get_task)) + + q.put_nowait(3) + self.assertEqual(3, q.get_nowait()) + + self.event_loop.run_until_complete(queue_put()) + + def test_put_timeout_cancelled(self): + q = queues.Queue() + + @tasks.coroutine + def queue_put(): + yield from q.put(1, timeout=0.1) + + @tasks.coroutine + def test(): + tasks.Task(queue_put()) + return (yield from q.get()) + + self.assertEqual(1, self.event_loop.run_until_complete(test())) + + +class LifoQueueTests(_QueueTestBase): + + def test_order(self): + q = queues.LifoQueue() + for i in [1, 3, 2]: + q.put_nowait(i) + + items = [q.get_nowait() for _ in range(3)] + self.assertEqual([2, 3, 1], items) + + +class PriorityQueueTests(_QueueTestBase): + + def test_order(self): + q = queues.PriorityQueue() + for i in [1, 3, 2]: + q.put_nowait(i) + + items = [q.get_nowait() for _ in range(3)] + self.assertEqual([1, 2, 3], items) + + +class JoinableQueueTests(_QueueTestBase): + + def test_task_done_underflow(self): + q = queues.JoinableQueue() + self.assertRaises(q.task_done) + + def test_task_done(self): + q = queues.JoinableQueue() + for i in range(100): + q.put_nowait(i) + + accumulator = 0 + + # Two workers get items from the queue and call task_done after each. + # Join the queue and assert all items have been processed. + + @tasks.coroutine + def worker(): + nonlocal accumulator + + while True: + item = yield from q.get() + accumulator += item + q.task_done() + + @tasks.coroutine + def test(): + for _ in range(2): + tasks.Task(worker()) + + yield from q.join() + + self.event_loop.run_until_complete(test()) + self.assertEqual(sum(range(100)), accumulator) + + def test_join_empty_queue(self): + q = queues.JoinableQueue() + + # Test that a queue join()s successfully, and before anything else + # (done twice for insurance). + + @tasks.coroutine + def join(): + yield from q.join() + yield from q.join() + + self.event_loop.run_until_complete(join()) + + def test_join_timeout(self): + q = queues.JoinableQueue() + q.put_nowait(1) + + @tasks.coroutine + def join(): + yield from q.join(0.1) + + # Join completes in ~ 0.1 seconds, although no one calls task_done(). + self.event_loop.run_until_complete(join()) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/sample.crt b/tests/sample.crt new file mode 100644 index 0000000..6a1e3f3 --- /dev/null +++ b/tests/sample.crt @@ -0,0 +1,14 @@ +-----BEGIN CERTIFICATE----- +MIICMzCCAZwCCQDFl4ys0fU7iTANBgkqhkiG9w0BAQUFADBeMQswCQYDVQQGEwJV +UzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwNU2FuLUZyYW5jaXNjbzEi +MCAGA1UECgwZUHl0aG9uIFNvZnR3YXJlIEZvbmRhdGlvbjAeFw0xMzAzMTgyMDA3 +MjhaFw0yMzAzMTYyMDA3MjhaMF4xCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxp +Zm9ybmlhMRYwFAYDVQQHDA1TYW4tRnJhbmNpc2NvMSIwIAYDVQQKDBlQeXRob24g +U29mdHdhcmUgRm9uZGF0aW9uMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCn +t3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx/47Vc5TZSaO11uO7 +gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIiqusnLfpqR8cIAavg +Z06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQABMA0GCSqGSIb3DQEB +BQUAA4GBAE9PknG6pv72+5z/gsDGYy8sK5UNkbWSNr4i4e5lxVsF03+/M71H+3AB +MxVX4+A+Vlk2fmU+BrdHIIUE0r1dDcO3josQ9hc9OJpp5VLSQFP8VeuJCmzYPp9I +I8WbW93cnXnChTrYQVdgVoFdv7GE9YgU7NYkrGIM0nZl1/f/bHPB +-----END CERTIFICATE----- diff --git a/tests/sample.key b/tests/sample.key new file mode 100644 index 0000000..edfea8d --- /dev/null +++ b/tests/sample.key @@ -0,0 +1,15 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXQIBAAKBgQCnt3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx +/47Vc5TZSaO11uO7gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIi +qusnLfpqR8cIAavgZ06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQAB +AoGABfm8k19Yue3W68BecKEGS0VBV57GRTPT+MiBGvVGNIQ15gk6w3sGfMZsdD1y +bsUkQgcDb2d/4i5poBTpl/+Cd41V+c20IC/sSl5X1IEreHMKSLhy/uyjyiyfXlP1 +iXhToFCgLWwENWc8LzfUV8vuAV5WG6oL9bnudWzZxeqx8V0CQQDR7xwVj6LN70Eb +DUhSKLkusmFw5Gk9NJ/7wZ4eHg4B8c9KNVvSlLCLhcsVTQXuqYeFpOqytI45SneP +lr0vrvsDAkEAzITYiXu6ox5huDCG7imX2W9CAYuX638urLxBqBXMS7GqBzojD6RL +21Q8oPwJWJquERa3HDScq1deiQbM9uKIkwJBAIa1PLslGN216Xv3UPHPScyKD/aF +ynXIv+OnANPoiyp6RH4ksQ/18zcEGiVH8EeNpvV9tlAHhb+DZibQHgNr74sCQQC0 +zhToplu/bVKSlUQUNO0rqrI9z30FErDewKeCw5KSsIRSU1E/uM3fHr9iyq4wiL6u +GNjUtKZ0y46lsT9uW6LFAkB5eqeEQnshAdr3X5GykWHJ8DDGBXPPn6Rce1NX4RSq +V9khG2z1bFyfo+hMqpYnF2k32hVq3E54RS8YYnwBsVof +-----END RSA PRIVATE KEY----- diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py new file mode 100644 index 0000000..c63db15 --- /dev/null +++ b/tests/selector_events_test.py @@ -0,0 +1,1286 @@ +"""Tests for selector_events.py""" + +import errno +import socket +import unittest +import unittest.mock +try: + import ssl +except ImportError: + ssl = None + +from tulip import futures +from tulip import selectors +from tulip.events import AbstractEventLoop +from tulip.protocols import DatagramProtocol, Protocol +from tulip.selector_events import BaseSelectorEventLoop +from tulip.selector_events import _SelectorSslTransport +from tulip.selector_events import _SelectorSocketTransport +from tulip.selector_events import _SelectorDatagramTransport + + +class TestBaseSelectorEventLoop(BaseSelectorEventLoop): + + def _make_self_pipe(self): + self._ssock = unittest.mock.Mock() + self._csock = unittest.mock.Mock() + self._internal_fds += 1 + + +class BaseSelectorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.event_loop = TestBaseSelectorEventLoop(unittest.mock.Mock()) + + def test_make_socket_transport(self): + m = unittest.mock.Mock() + self.event_loop.add_reader = unittest.mock.Mock() + self.assertIsInstance( + self.event_loop._make_socket_transport(m, m), + _SelectorSocketTransport) + + def test_make_ssl_transport(self): + m = unittest.mock.Mock() + self.event_loop.add_reader = unittest.mock.Mock() + self.event_loop.add_writer = unittest.mock.Mock() + self.event_loop.remove_reader = unittest.mock.Mock() + self.event_loop.remove_writer = unittest.mock.Mock() + self.assertIsInstance( + self.event_loop._make_ssl_transport(m, m, m, m), + _SelectorSslTransport) + + def test_close(self): + ssock = self.event_loop._ssock + ssock.fileno.return_value = 7 + csock = self.event_loop._csock + csock.fileno.return_value = 1 + remove_reader = self.event_loop.remove_reader = unittest.mock.Mock() + + self.event_loop._selector.close() + self.event_loop._selector = selector = unittest.mock.Mock() + self.event_loop.close() + self.assertIsNone(self.event_loop._selector) + self.assertIsNone(self.event_loop._csock) + self.assertIsNone(self.event_loop._ssock) + selector.close.assert_called_with() + ssock.close.assert_called_with() + csock.close.assert_called_with() + remove_reader.assert_called_with(7) + + self.event_loop.close() + self.event_loop.close() + + def test_close_no_selector(self): + ssock = self.event_loop._ssock + csock = self.event_loop._csock + remove_reader = self.event_loop.remove_reader = unittest.mock.Mock() + + self.event_loop._selector.close() + self.event_loop._selector = None + self.event_loop.close() + self.assertIsNone(self.event_loop._selector) + self.assertFalse(ssock.close.called) + self.assertFalse(csock.close.called) + self.assertFalse(remove_reader.called) + + def test_socketpair(self): + self.assertRaises(NotImplementedError, self.event_loop._socketpair) + + def test_read_from_self_tryagain(self): + self.event_loop._ssock.recv.side_effect = BlockingIOError + self.assertIsNone(self.event_loop._read_from_self()) + + def test_read_from_self_exception(self): + self.event_loop._ssock.recv.side_effect = OSError + self.assertRaises(OSError, self.event_loop._read_from_self) + + def test_write_to_self_tryagain(self): + self.event_loop._csock.send.side_effect = BlockingIOError + self.assertIsNone(self.event_loop._write_to_self()) + + def test_write_to_self_exception(self): + self.event_loop._csock.send.side_effect = OSError() + self.assertRaises(OSError, self.event_loop._write_to_self) + + def test_sock_recv(self): + sock = unittest.mock.Mock() + self.event_loop._sock_recv = unittest.mock.Mock() + + f = self.event_loop.sock_recv(sock, 1024) + self.assertIsInstance(f, futures.Future) + self.event_loop._sock_recv.assert_called_with( + f, False, sock, 1024) + + def test__sock_recv_canceled_fut(self): + sock = unittest.mock.Mock() + + f = futures.Future() + f.cancel() + + self.event_loop._sock_recv(f, False, sock, 1024) + self.assertFalse(sock.recv.called) + + def test__sock_recv_unregister(self): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + f = futures.Future() + f.cancel() + + self.event_loop.remove_reader = unittest.mock.Mock() + self.event_loop._sock_recv(f, True, sock, 1024) + self.assertEqual((10,), self.event_loop.remove_reader.call_args[0]) + + def test__sock_recv_tryagain(self): + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.recv.side_effect = BlockingIOError + + self.event_loop.add_reader = unittest.mock.Mock() + self.event_loop._sock_recv(f, False, sock, 1024) + self.assertEqual((10, self.event_loop._sock_recv, f, True, sock, 1024), + self.event_loop.add_reader.call_args[0]) + + def test__sock_recv_exception(self): + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + err = sock.recv.side_effect = OSError() + + self.event_loop._sock_recv(f, False, sock, 1024) + self.assertIs(err, f.exception()) + + def test_sock_sendall(self): + sock = unittest.mock.Mock() + self.event_loop._sock_sendall = unittest.mock.Mock() + + f = self.event_loop.sock_sendall(sock, b'data') + self.assertIsInstance(f, futures.Future) + self.assertEqual( + (f, False, sock, b'data'), + self.event_loop._sock_sendall.call_args[0]) + + def test_sock_sendall_nodata(self): + sock = unittest.mock.Mock() + self.event_loop._sock_sendall = unittest.mock.Mock() + + f = self.event_loop.sock_sendall(sock, b'') + self.assertIsInstance(f, futures.Future) + self.assertTrue(f.done()) + self.assertFalse(self.event_loop._sock_sendall.called) + + def test__sock_sendall_canceled_fut(self): + sock = unittest.mock.Mock() + + f = futures.Future() + f.cancel() + + self.event_loop._sock_sendall(f, False, sock, b'data') + self.assertFalse(sock.send.called) + + def test__sock_sendall_unregister(self): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + f = futures.Future() + f.cancel() + + self.event_loop.remove_writer = unittest.mock.Mock() + self.event_loop._sock_sendall(f, True, sock, b'data') + self.assertEqual((10,), self.event_loop.remove_writer.call_args[0]) + + def test__sock_sendall_tryagain(self): + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.send.side_effect = BlockingIOError + + self.event_loop.add_writer = unittest.mock.Mock() + self.event_loop._sock_sendall(f, False, sock, b'data') + self.assertEqual( + (10, self.event_loop._sock_sendall, f, True, sock, b'data'), + self.event_loop.add_writer.call_args[0]) + + def test__sock_sendall_exception(self): + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + err = sock.send.side_effect = OSError() + + self.event_loop._sock_sendall(f, False, sock, b'data') + self.assertIs(f.exception(), err) + + def test__sock_sendall(self): + sock = unittest.mock.Mock() + + f = futures.Future() + sock.fileno.return_value = 10 + sock.send.return_value = 4 + + self.event_loop._sock_sendall(f, False, sock, b'data') + self.assertTrue(f.done()) + + def test__sock_sendall_partial(self): + sock = unittest.mock.Mock() + + f = futures.Future() + sock.fileno.return_value = 10 + sock.send.return_value = 2 + + self.event_loop.add_writer = unittest.mock.Mock() + self.event_loop._sock_sendall(f, False, sock, b'data') + self.assertFalse(f.done()) + self.assertEqual( + (10, self.event_loop._sock_sendall, f, True, sock, b'ta'), + self.event_loop.add_writer.call_args[0]) + + def test__sock_sendall_none(self): + sock = unittest.mock.Mock() + + f = futures.Future() + sock.fileno.return_value = 10 + sock.send.return_value = 0 + + self.event_loop.add_writer = unittest.mock.Mock() + self.event_loop._sock_sendall(f, False, sock, b'data') + self.assertFalse(f.done()) + self.assertEqual( + (10, self.event_loop._sock_sendall, f, True, sock, b'data'), + self.event_loop.add_writer.call_args[0]) + + def test_sock_connect(self): + sock = unittest.mock.Mock() + self.event_loop._sock_connect = unittest.mock.Mock() + + f = self.event_loop.sock_connect(sock, ('127.0.0.1', 8080)) + self.assertIsInstance(f, futures.Future) + self.assertEqual( + (f, False, sock, ('127.0.0.1', 8080)), + self.event_loop._sock_connect.call_args[0]) + + def test__sock_connect(self): + f = futures.Future() + + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + self.event_loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) + self.assertTrue(f.done()) + self.assertTrue(sock.connect.called) + + def test__sock_connect_canceled_fut(self): + sock = unittest.mock.Mock() + + f = futures.Future() + f.cancel() + + self.event_loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) + self.assertFalse(sock.connect.called) + + def test__sock_connect_unregister(self): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + f = futures.Future() + f.cancel() + + self.event_loop.remove_writer = unittest.mock.Mock() + self.event_loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.assertEqual((10,), self.event_loop.remove_writer.call_args[0]) + + def test__sock_connect_tryagain(self): + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.getsockopt.return_value = errno.EAGAIN + + self.event_loop.add_writer = unittest.mock.Mock() + self.event_loop.remove_writer = unittest.mock.Mock() + + self.event_loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.assertEqual( + (10, self.event_loop._sock_connect, f, + True, sock, ('127.0.0.1', 8080)), + self.event_loop.add_writer.call_args[0]) + + def test__sock_connect_exception(self): + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.getsockopt.return_value = errno.ENOTCONN + + self.event_loop.remove_writer = unittest.mock.Mock() + self.event_loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.assertIsInstance(f.exception(), socket.error) + + def test_sock_accept(self): + sock = unittest.mock.Mock() + self.event_loop._sock_accept = unittest.mock.Mock() + + f = self.event_loop.sock_accept(sock) + self.assertIsInstance(f, futures.Future) + self.assertEqual( + (f, False, sock), self.event_loop._sock_accept.call_args[0]) + + def test__sock_accept(self): + f = futures.Future() + + conn = unittest.mock.Mock() + + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.accept.return_value = conn, ('127.0.0.1', 1000) + + self.event_loop._sock_accept(f, False, sock) + self.assertTrue(f.done()) + self.assertEqual((conn, ('127.0.0.1', 1000)), f.result()) + self.assertEqual((False,), conn.setblocking.call_args[0]) + + def test__sock_accept_canceled_fut(self): + sock = unittest.mock.Mock() + + f = futures.Future() + f.cancel() + + self.event_loop._sock_accept(f, False, sock) + self.assertFalse(sock.accept.called) + + def test__sock_accept_unregister(self): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + f = futures.Future() + f.cancel() + + self.event_loop.remove_reader = unittest.mock.Mock() + self.event_loop._sock_accept(f, True, sock) + self.assertEqual((10,), self.event_loop.remove_reader.call_args[0]) + + def test__sock_accept_tryagain(self): + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.accept.side_effect = BlockingIOError + + self.event_loop.add_reader = unittest.mock.Mock() + self.event_loop._sock_accept(f, False, sock) + self.assertEqual( + (10, self.event_loop._sock_accept, f, True, sock), + self.event_loop.add_reader.call_args[0]) + + def test__sock_accept_exception(self): + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + err = sock.accept.side_effect = OSError() + + self.event_loop._sock_accept(f, False, sock) + self.assertIs(err, f.exception()) + + def test_add_reader(self): + self.event_loop._selector.get_info.side_effect = KeyError + h = self.event_loop.add_reader(1, lambda: True) + + self.assertTrue(self.event_loop._selector.register.called) + self.assertEqual( + (1, selectors.EVENT_READ, (h, None)), + self.event_loop._selector.register.call_args[0]) + + def test_add_reader_existing(self): + reader = unittest.mock.Mock() + writer = unittest.mock.Mock() + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_WRITE, (reader, writer)) + h = self.event_loop.add_reader(1, lambda: True) + + self.assertTrue(reader.cancel.called) + self.assertFalse(self.event_loop._selector.register.called) + self.assertTrue(self.event_loop._selector.modify.called) + self.assertEqual( + (1, selectors.EVENT_WRITE | selectors.EVENT_READ, (h, writer)), + self.event_loop._selector.modify.call_args[0]) + + def test_add_reader_existing_writer(self): + writer = unittest.mock.Mock() + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_WRITE, (None, writer)) + h = self.event_loop.add_reader(1, lambda: True) + + self.assertFalse(self.event_loop._selector.register.called) + self.assertTrue(self.event_loop._selector.modify.called) + self.assertEqual( + (1, selectors.EVENT_WRITE | selectors.EVENT_READ, (h, writer)), + self.event_loop._selector.modify.call_args[0]) + + def test_remove_reader(self): + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_READ, (None, None)) + self.assertFalse(self.event_loop.remove_reader(1)) + + self.assertTrue(self.event_loop._selector.unregister.called) + + def test_remove_reader_read_write(self): + reader = unittest.mock.Mock() + writer = unittest.mock.Mock() + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_READ | selectors.EVENT_WRITE, (reader, writer)) + self.assertTrue( + self.event_loop.remove_reader(1)) + + self.assertFalse(self.event_loop._selector.unregister.called) + self.assertEqual( + (1, selectors.EVENT_WRITE, (None, writer)), + self.event_loop._selector.modify.call_args[0]) + + def test_remove_reader_unknown(self): + self.event_loop._selector.get_info.side_effect = KeyError + self.assertFalse( + self.event_loop.remove_reader(1)) + + def test_add_writer(self): + self.event_loop._selector.get_info.side_effect = KeyError + h = self.event_loop.add_writer(1, lambda: True) + + self.assertTrue(self.event_loop._selector.register.called) + self.assertEqual( + (1, selectors.EVENT_WRITE, (None, h)), + self.event_loop._selector.register.call_args[0]) + + def test_add_writer_existing(self): + reader = unittest.mock.Mock() + writer = unittest.mock.Mock() + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_READ, (reader, writer)) + h = self.event_loop.add_writer(1, lambda: True) + + self.assertTrue(writer.cancel.called) + self.assertFalse(self.event_loop._selector.register.called) + self.assertTrue(self.event_loop._selector.modify.called) + self.assertEqual( + (1, selectors.EVENT_WRITE | selectors.EVENT_READ, (reader, h)), + self.event_loop._selector.modify.call_args[0]) + + def test_remove_writer(self): + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_WRITE, (None, None)) + self.assertFalse(self.event_loop.remove_writer(1)) + + self.assertTrue(self.event_loop._selector.unregister.called) + + def test_remove_writer_read_write(self): + reader = unittest.mock.Mock() + writer = unittest.mock.Mock() + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_READ | selectors.EVENT_WRITE, (reader, writer)) + self.assertTrue( + self.event_loop.remove_writer(1)) + + self.assertFalse(self.event_loop._selector.unregister.called) + self.assertEqual( + (1, selectors.EVENT_READ, (reader, None)), + self.event_loop._selector.modify.call_args[0]) + + def test_remove_writer_unknown(self): + self.event_loop._selector.get_info.side_effect = KeyError + self.assertFalse( + self.event_loop.remove_writer(1)) + + def test_process_events_read(self): + reader = unittest.mock.Mock() + reader.cancelled = False + + self.event_loop._add_callback = unittest.mock.Mock() + self.event_loop._process_events( + ((1, selectors.EVENT_READ, (reader, None)),)) + self.assertTrue(self.event_loop._add_callback.called) + self.event_loop._add_callback.assert_called_with(reader) + + def test_process_events_read_cancelled(self): + reader = unittest.mock.Mock() + reader.cancelled = True + + self.event_loop.remove_reader = unittest.mock.Mock() + self.event_loop._process_events( + ((1, selectors.EVENT_READ, (reader, None)),)) + self.event_loop.remove_reader.assert_called_with(1) + + def test_process_events_write(self): + writer = unittest.mock.Mock() + writer.cancelled = False + + self.event_loop._add_callback = unittest.mock.Mock() + self.event_loop._process_events( + ((1, selectors.EVENT_WRITE, (None, writer)),)) + self.event_loop._add_callback.assert_called_with(writer) + + def test_process_events_write_cancelled(self): + writer = unittest.mock.Mock() + writer.cancelled = True + self.event_loop.remove_writer = unittest.mock.Mock() + + self.event_loop._process_events( + ((1, selectors.EVENT_WRITE, (None, writer)),)) + self.event_loop.remove_writer.assert_called_with(1) + + +class SelectorSocketTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock(spec_set=AbstractEventLoop) + self.sock = unittest.mock.Mock(socket.socket) + self.sock.fileno.return_value = 7 + self.protocol = unittest.mock.Mock(Protocol) + + def test_ctor(self): + tr = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + self.event_loop.add_reader.assert_called_with(7, tr._read_ready) + self.event_loop.call_soon.assert_called_with( + self.protocol.connection_made, tr) + + def test_ctor_with_waiter(self): + fut = futures.Future() + + _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol, fut) + self.assertEqual(2, self.event_loop.call_soon.call_count) + self.assertEqual(fut.set_result, + self.event_loop.call_soon.call_args[0][0]) + + def test_read_ready(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + + self.sock.recv.return_value = b'data' + transport._read_ready() + + self.protocol.data_received.assert_called_with(b'data') + + def test_read_ready_eof(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + + self.sock.recv.return_value = b'' + transport._read_ready() + + self.assertTrue(self.event_loop.remove_reader.called) + self.protocol.eof_received.assert_called_with() + + @unittest.mock.patch('logging.exception') + def test_read_ready_tryagain(self, m_exc): + self.sock.recv.side_effect = BlockingIOError + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + + @unittest.mock.patch('logging.exception') + def test_read_ready_err(self, m_exc): + err = self.sock.recv.side_effect = OSError() + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + transport._fatal_error.assert_called_with(err) + + def test_abort(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._close = unittest.mock.Mock() + + transport.abort() + transport._close.assert_called_with(None) + + def test_write(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport.write(data) + self.sock.send.assert_called_with(data) + + def test_write_no_data(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append(b'data') + transport.write(b'') + self.assertFalse(self.sock.send.called) + self.assertEqual([b'data'], transport._buffer) + + def test_write_buffer(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append(b'data1') + transport.write(b'data2') + self.assertFalse(self.sock.send.called) + self.assertEqual([b'data1', b'data2'], transport._buffer) + + def test_write_partial(self): + data = b'data' + self.sock.send.return_value = 2 + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport.write(data) + + self.assertTrue(self.event_loop.add_writer.called) + self.assertEqual( + transport._write_ready, self.event_loop.add_writer.call_args[0][1]) + + self.assertEqual([b'ta'], transport._buffer) + + def test_write_partial_none(self): + data = b'data' + self.sock.send.return_value = 0 + self.sock.fileno.return_value = 7 + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport.write(data) + + self.event_loop.add_writer.assert_called_with( + 7, transport._write_ready) + self.assertEqual([b'data'], transport._buffer) + + def test_write_tryagain(self): + self.sock.send.side_effect = BlockingIOError + + data = b'data' + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport.write(data) + + self.assertTrue(self.event_loop.add_writer.called) + self.assertEqual( + transport._write_ready, self.event_loop.add_writer.call_args[0][1]) + + self.assertEqual([b'data'], transport._buffer) + + def test_write_exception(self): + err = self.sock.send.side_effect = OSError() + + data = b'data' + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport.write(data) + transport._fatal_error.assert_called_with(err) + + def test_write_str(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + self.assertRaises(AssertionError, transport.write, 'str') + + def test_write_closing(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport.close() + self.assertRaises(AssertionError, transport.write, b'data') + + def test_write_ready(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append(data) + transport._write_ready() + self.assertTrue(self.sock.send.called) + self.assertEqual(self.sock.send.call_args[0], (data,)) + self.assertTrue(self.event_loop.remove_writer.called) + + def test_write_ready_closing(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._closing = True + transport._buffer.append(data) + transport._write_ready() + self.sock.send.assert_called_with(data) + self.event_loop.remove_writer.assert_called_with(7) + self.protocol.connection_lost.assert_called_with(None) + + def test_write_ready_no_data(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + self.assertRaises(AssertionError, transport._write_ready) + + def test_write_ready_partial(self): + data = b'data' + self.sock.send.return_value = 2 + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append(data) + transport._write_ready() + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'ta'], transport._buffer) + + def test_write_ready_partial_none(self): + data = b'data' + self.sock.send.return_value = 0 + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append(data) + transport._write_ready() + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'data'], transport._buffer) + + def test_write_ready_tryagain(self): + self.sock.send.side_effect = BlockingIOError + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer = [b'data1', b'data2'] + transport._write_ready() + + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'data1data2'], transport._buffer) + + def test_write_ready_exception(self): + err = self.sock.send.side_effect = OSError() + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._buffer.append(b'data') + transport._write_ready() + transport._fatal_error.assert_called_with(err) + + def test_close(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport.close() + + self.assertTrue(transport._closing) + self.event_loop.remove_reader.assert_called_with(7) + self.protocol.connection_lost(None) + + def test_close_write_buffer(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + self.event_loop.reset_mock() + transport._buffer.append(b'data') + transport.close() + + self.assertTrue(self.event_loop.remove_reader.called) + self.assertFalse(self.event_loop.call_soon.called) + + @unittest.mock.patch('tulip.log.tulip_log.exception') + def test_fatal_error(self, m_exc): + exc = OSError() + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append(b'data') + transport._fatal_error(exc) + + self.assertEqual([], transport._buffer) + self.event_loop.remove_reader.assert_called_with(7) + self.event_loop.remove_writer.assert_called_with(7) + self.protocol.connection_lost.assert_called_with(exc) + m_exc.assert_called_with('Fatal error for %s', transport) + + def test_connection_lost(self): + exc = object() + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._call_connection_lost(exc) + + self.protocol.connection_lost.assert_called_with(exc) + self.sock.close.assert_called_with() + + +@unittest.skipIf(ssl is None, 'No ssl module') +class SelectorSslTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock(spec_set=AbstractEventLoop) + self.sock = unittest.mock.Mock(socket.socket) + self.sock.fileno.return_value = 7 + self.protocol = unittest.mock.Mock(spec_set=Protocol) + self.sslsock = unittest.mock.Mock() + self.sslsock.fileno.return_value = 1 + self.sslcontext = unittest.mock.Mock() + self.sslcontext.wrap_socket.return_value = self.sslsock + self.waiter = futures.Future() + + self.transport = _SelectorSslTransport( + self.event_loop, self.sock, + self.protocol, self.sslcontext, self.waiter) + self.event_loop.reset_mock() + self.sock.reset_mock() + self.protocol.reset_mock() + self.sslcontext.reset_mock() + + def test_on_handshake(self): + self.transport._on_handshake() + self.assertTrue(self.sslsock.do_handshake.called) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertEqual( + (1, self.transport._on_ready,), + self.event_loop.add_reader.call_args[0]) + self.assertEqual( + (1, self.transport._on_ready,), + self.event_loop.add_writer.call_args[0]) + + def test_on_handshake_reader_retry(self): + self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError + self.transport._on_handshake() + self.assertEqual( + (1, self.transport._on_handshake,), + self.event_loop.add_reader.call_args[0]) + + def test_on_handshake_writer_retry(self): + self.sslsock.do_handshake.side_effect = ssl.SSLWantWriteError + self.transport._on_handshake() + self.assertEqual( + (1, self.transport._on_handshake,), + self.event_loop.add_writer.call_args[0]) + + def test_on_handshake_exc(self): + self.sslsock.do_handshake.side_effect = ValueError + self.transport._on_handshake() + self.assertTrue(self.sslsock.close.called) + + def test_on_handshake_base_exc(self): + self.sslsock.do_handshake.side_effect = BaseException + self.assertRaises(BaseException, self.transport._on_handshake) + self.assertTrue(self.sslsock.close.called) + + def test_write_no_data(self): + self.transport._buffer.append(b'data') + self.transport.write(b'') + self.assertEqual([b'data'], self.transport._buffer) + + def test_write_str(self): + self.assertRaises(AssertionError, self.transport.write, 'str') + + def test_write_closing(self): + self.transport.close() + self.assertRaises(AssertionError, self.transport.write, b'data') + + def test_abort(self): + self.transport._close = unittest.mock.Mock() + self.transport.abort() + self.transport._close.assert_called_with(None) + + @unittest.mock.patch('tulip.log.tulip_log.exception') + def test_fatal_error(self, m_exc): + exc = OSError() + self.transport._buffer.append(b'data') + self.transport._fatal_error(exc) + + self.assertEqual([], self.transport._buffer) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.event_loop.remove_reader.called) + self.protocol.connection_lost.assert_called_with(exc) + m_exc.assert_called_with('Fatal error for %s', self.transport) + + def test_close(self): + self.transport.close() + self.assertTrue(self.transport._closing) + self.assertTrue(self.event_loop.remove_reader.called) + self.protocol.connection_lost.assert_called_with(None) + + def test_close_write_buffer(self): + self.transport._buffer.append(b'data') + self.transport.close() + + self.assertTrue(self.event_loop.remove_reader.called) + self.assertFalse(self.event_loop.call_soon.called) + + def test_on_ready_closed(self): + self.sslsock.fileno.return_value = -1 + self.transport._on_ready() + self.assertFalse(self.sslsock.recv.called) + + def test_on_ready_recv(self): + self.sslsock.recv.return_value = b'data' + self.transport._on_ready() + self.assertTrue(self.sslsock.recv.called) + self.assertEqual((b'data',), self.protocol.data_received.call_args[0]) + + def test_on_ready_recv_eof(self): + self.sslsock.recv.return_value = b'' + self.transport._on_ready() + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.sslsock.close.called) + self.assertTrue(self.protocol.connection_lost.called) + + def test_on_ready_recv_retry(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.transport._on_ready() + self.assertTrue(self.sslsock.recv.called) + self.assertFalse(self.protocol.data_received.called) + + self.sslsock.recv.side_effect = ssl.SSLWantWriteError + self.transport._on_ready() + self.assertFalse(self.protocol.data_received.called) + + self.sslsock.recv.side_effect = BlockingIOError + self.transport._on_ready() + self.assertFalse(self.protocol.data_received.called) + + def test_on_ready_recv_exc(self): + err = self.sslsock.recv.side_effect = OSError() + self.transport._fatal_error = unittest.mock.Mock() + self.transport._on_ready() + self.transport._fatal_error.assert_called_with(err) + + def test_on_ready_send(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 4 + self.transport._buffer = [b'data'] + self.transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertEqual([], self.transport._buffer) + + def test_on_ready_send_none(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 0 + self.transport._buffer = [b'data1', b'data2'] + self.transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertEqual([b'data1data2'], self.transport._buffer) + + def test_on_ready_send_partial(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 2 + self.transport._buffer = [b'data1', b'data2'] + self.transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertEqual([b'ta1data2'], self.transport._buffer) + + def test_on_ready_send_closing_partial(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 2 + self.transport._buffer = [b'data1', b'data2'] + self.transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertFalse(self.sslsock.close.called) + + def test_on_ready_send_closing(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 4 + self.transport.close() + self.transport._buffer = [b'data'] + self.transport._on_ready() + self.assertTrue(self.sslsock.close.called) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.protocol.connection_lost.called) + + def test_on_ready_send_retry(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + + self.transport._buffer = [b'data'] + + self.sslsock.send.side_effect = ssl.SSLWantReadError + self.transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertEqual([b'data'], self.transport._buffer) + + self.sslsock.send.side_effect = ssl.SSLWantWriteError + self.transport._on_ready() + self.assertEqual([b'data'], self.transport._buffer) + + self.sslsock.send.side_effect = BlockingIOError() + self.transport._on_ready() + self.assertEqual([b'data'], self.transport._buffer) + + def test_on_ready_send_exc(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + err = self.sslsock.send.side_effect = OSError() + + self.transport._buffer = [b'data'] + self.transport._fatal_error = unittest.mock.Mock() + self.transport._on_ready() + self.transport._fatal_error.assert_called_with(err) + self.assertEqual([], self.transport._buffer) + + +class SelectorDatagramTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock(spec_set=AbstractEventLoop) + self.sock = unittest.mock.Mock(spec_set=socket.socket) + self.sock.fileno.return_value = 7 + self.protocol = unittest.mock.Mock(spec_set=DatagramProtocol) + + def test_read_ready(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + + self.sock.recvfrom.return_value = (b'data', ('0.0.0.0', 1234)) + transport._read_ready() + + self.protocol.datagram_received.assert_called_with( + b'data', ('0.0.0.0', 1234)) + + def test_read_ready_tryagain(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + + self.sock.recvfrom.side_effect = BlockingIOError + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + + def test_read_ready_err(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + + err = self.sock.recvfrom.side_effect = OSError() + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + transport._fatal_error.assert_called_with(err) + + def test_abort(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._close = unittest.mock.Mock() + + transport.abort() + transport._close.assert_called_with(None) + + def test_sendto(self): + data = b'data' + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport.sendto(data, ('0.0.0.0', 1234)) + self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234))) + + def test_sendto_no_data(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append((b'data', ('0.0.0.0', 12345))) + transport.sendto(b'', ()) + self.assertFalse(self.sock.sendto.called) + self.assertEqual( + [(b'data', ('0.0.0.0', 12345))], list(transport._buffer)) + + def test_sendto_buffer(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append((b'data1', ('0.0.0.0', 12345))) + transport.sendto(b'data2', ('0.0.0.0', 12345)) + self.assertFalse(self.sock.sendto.called) + self.assertEqual( + [(b'data1', ('0.0.0.0', 12345)), + (b'data2', ('0.0.0.0', 12345))], + list(transport._buffer)) + + def test_sendto_tryagain(self): + data = b'data' + + self.sock.sendto.side_effect = BlockingIOError + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport.sendto(data, ('0.0.0.0', 12345)) + + self.assertTrue(self.event_loop.add_writer.called) + self.assertEqual( + transport._sendto_ready, + self.event_loop.add_writer.call_args[0][1]) + + self.assertEqual( + [(b'data', ('0.0.0.0', 12345))], list(transport._buffer)) + + def test_sendto_exception(self): + data = b'data' + err = self.sock.sendto.side_effect = OSError() + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport.sendto(data, ()) + + self.assertTrue(transport._fatal_error.called) + transport._fatal_error.assert_called_with(err) + + def test_sendto_connection_refused(self): + data = b'data' + + self.sock.sendto.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport.sendto(data, ()) + + self.assertFalse(transport._fatal_error.called) + + def test_sendto_connection_refused_connected(self): + data = b'data' + + self.sock.send.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1)) + transport._fatal_error = unittest.mock.Mock() + transport.sendto(data) + + self.assertTrue(transport._fatal_error.called) + + def test_sendto_str(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + self.assertRaises(AssertionError, transport.sendto, 'str', ()) + + def test_sendto_connected_addr(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1)) + self.assertRaises( + AssertionError, transport.sendto, b'str', ('0.0.0.0', 2)) + + def test_sendto_closing(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport.close() + self.assertRaises(AssertionError, transport.sendto, b'data', ()) + + def test_sendto_ready(self): + data = b'data' + self.sock.sendto.return_value = len(data) + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append((data, ('0.0.0.0', 12345))) + transport._sendto_ready() + self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (data, ('0.0.0.0', 12345))) + self.assertTrue(self.event_loop.remove_writer.called) + + def test_sendto_ready_closing(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._closing = True + transport._buffer.append((data, ())) + transport._sendto_ready() + self.sock.sendto.assert_called_with(data, ()) + self.event_loop.remove_writer.assert_called_with(7) + self.protocol.connection_lost.assert_called_with(None) + + def test_sendto_ready_no_data(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._sendto_ready() + self.assertFalse(self.sock.sendto.called) + self.assertTrue(self.event_loop.remove_writer.called) + + def test_sendto_ready_tryagain(self): + self.sock.sendto.side_effect = BlockingIOError + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.extend([(b'data1', ()), (b'data2', ())]) + transport._sendto_ready() + + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual( + [(b'data1', ()), (b'data2', ())], + list(transport._buffer)) + + def test_sendto_ready_exception(self): + err = self.sock.sendto.side_effect = OSError() + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._buffer.append((b'data', ())) + transport._sendto_ready() + + transport._fatal_error.assert_called_with(err) + + def test_sendto_ready_connection_refused(self): + self.sock.sendto.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._buffer.append((b'data', ())) + transport._sendto_ready() + + self.assertFalse(transport._fatal_error.called) + + def test_sendto_ready_connection_refused_connection(self): + self.sock.send.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1)) + transport._fatal_error = unittest.mock.Mock() + transport._buffer.append((b'data', ())) + transport._sendto_ready() + + self.assertTrue(transport._fatal_error.called) + + def test_close(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport.close() + + self.assertTrue(transport._closing) + self.event_loop.remove_reader.assert_called_with(7) + self.protocol.connection_lost.assert_called_with(None) + + def test_close_write_buffer(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append((b'data', ())) + transport.close() + + self.event_loop.remove_reader.assert_called_with(7) + self.assertFalse(self.protocol.connection_lost.called) + + @unittest.mock.patch('tulip.log.tulip_log.exception') + def test_fatal_error(self, m_exc): + exc = OSError() + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + self.event_loop.reset_mock() + transport._buffer.append((b'data', ())) + transport._fatal_error(exc) + + self.assertEqual([], list(transport._buffer)) + self.event_loop.remove_writer.assert_called_with(7) + self.event_loop.remove_reader.assert_called_with(7) + self.protocol.connection_lost.assert_called_with(exc) + m_exc.assert_called_with('Fatal error for %s', transport) + + @unittest.mock.patch('tulip.log.tulip_log.exception') + def test_fatal_error_connected(self, m_exc): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1)) + err = ConnectionRefusedError() + transport._fatal_error(err) + self.protocol.connection_refused.assert_called_with(err) + m_exc.assert_called_with('Fatal error for %s', transport) + + def test_transport_closing(self): + exc = object() + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._call_connection_lost(exc) + + self.protocol.connection_lost.assert_called_with(exc) + self.sock.close.assert_called_with() diff --git a/tests/selectors_test.py b/tests/selectors_test.py new file mode 100644 index 0000000..996c013 --- /dev/null +++ b/tests/selectors_test.py @@ -0,0 +1,137 @@ +"""Tests for selectors.py.""" + +import unittest +import unittest.mock + +from tulip import selectors + + +class BaseSelectorTests(unittest.TestCase): + + def test_fileobj_to_fd(self): + self.assertEqual(10, selectors._fileobj_to_fd(10)) + + f = unittest.mock.Mock() + f.fileno.return_value = 10 + self.assertEqual(10, selectors._fileobj_to_fd(f)) + + f.fileno.side_effect = TypeError + self.assertRaises(ValueError, selectors._fileobj_to_fd, f) + + def test_selector_key_repr(self): + key = selectors.SelectorKey(10, selectors.EVENT_READ) + self.assertEqual( + "SelectorKey<fileobj=10, fd=10, events=0x1, data=None>", repr(key)) + + def test_register(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ) + self.assertIsInstance(key, selectors.SelectorKey) + self.assertEqual(key.fd, 10) + self.assertIs(key, s._fd_to_key[10]) + + def test_register_unknown_event(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.register, unittest.mock.Mock(), 999999) + + def test_register_already_registered(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + s.register(fobj, selectors.EVENT_READ) + self.assertRaises(ValueError, s.register, fobj, selectors.EVENT_READ) + + def test_unregister(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + s.register(fobj, selectors.EVENT_READ) + s.unregister(fobj) + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_unregister_unknown(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.unregister, unittest.mock.Mock()) + + def test_modify_unknown(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.modify, unittest.mock.Mock(), 1) + + def test_modify(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ) + key2 = s.modify(fobj, selectors.EVENT_WRITE) + self.assertNotEqual(key.events, key2.events) + self.assertEqual((selectors.EVENT_WRITE, None), s.get_info(fobj)) + + def test_modify_data(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + d1 = object() + d2 = object() + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ, d1) + key2 = s.modify(fobj, selectors.EVENT_READ, d2) + self.assertEqual(key.events, key2.events) + self.assertNotEqual(key.data, key2.data) + self.assertEqual((selectors.EVENT_READ, d2), s.get_info(fobj)) + + def test_modify_same(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + data = object() + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ, data) + key2 = s.modify(fobj, selectors.EVENT_READ, data) + self.assertIs(key, key2) + + def test_select(self): + s = selectors._BaseSelector() + self.assertRaises(NotImplementedError, s.select) + + def test_close(self): + s = selectors._BaseSelector() + s.register(1, selectors.EVENT_READ) + + s.close() + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_registered_count(self): + s = selectors._BaseSelector() + self.assertEqual(0, s.registered_count()) + + s.register(1, selectors.EVENT_READ) + self.assertEqual(1, s.registered_count()) + + s.unregister(1) + self.assertEqual(0, s.registered_count()) + + def test_context_manager(self): + s = selectors._BaseSelector() + + with s as sel: + sel.register(1, selectors.EVENT_READ) + + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_key_from_fd(self): + s = selectors._BaseSelector() + key = s.register(1, selectors.EVENT_READ) + + self.assertIs(key, s._key_from_fd(1)) + self.assertIsNone(s._key_from_fd(10)) diff --git a/tests/streams_test.py b/tests/streams_test.py new file mode 100644 index 0000000..dc6eeaf --- /dev/null +++ b/tests/streams_test.py @@ -0,0 +1,299 @@ +"""Tests for streams.py.""" + +import unittest + +from tulip import events +from tulip import streams +from tulip import tasks +from tulip import test_utils + + +class StreamReaderTests(test_utils.LogTrackingTestCase): + + DATA = b'line1\nline2\nline3\n' + + def setUp(self): + super().setUp() + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + super().tearDown() + self.event_loop.close() + + def test_feed_empty_data(self): + stream = streams.StreamReader() + + stream.feed_data(b'') + self.assertEqual(0, stream.byte_count) + + def test_feed_data_byte_count(self): + stream = streams.StreamReader() + + stream.feed_data(self.DATA) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read_zero(self): + # Read zero bytes. + stream = streams.StreamReader() + stream.feed_data(self.DATA) + + data = self.event_loop.run_until_complete(stream.read(0)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read(self): + # Read bytes. + stream = streams.StreamReader() + read_task = tasks.Task(stream.read(30)) + + def cb(): + stream.feed_data(self.DATA) + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + + def test_read_line_breaks(self): + # Read bytes without line breaks. + stream = streams.StreamReader() + stream.feed_data(b'line1') + stream.feed_data(b'line2') + + data = self.event_loop.run_until_complete(stream.read(5)) + + self.assertEqual(b'line1', data) + self.assertEqual(5, stream.byte_count) + + def test_read_eof(self): + # Read bytes, stop at eof. + stream = streams.StreamReader() + read_task = tasks.Task(stream.read(1024)) + + def cb(): + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'', data) + self.assertFalse(stream.byte_count) + + def test_read_until_eof(self): + # Read all bytes until eof. + stream = streams.StreamReader() + read_task = tasks.Task(stream.read(-1)) + + def cb(): + stream.feed_data(b'chunk1\n') + stream.feed_data(b'chunk2') + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'chunk1\nchunk2', data) + self.assertFalse(stream.byte_count) + + def test_read_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete(stream.read(2)) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, stream.read(2)) + + def test_readline(self): + # Read one line. + stream = streams.StreamReader() + stream.feed_data(b'chunk1 ') + read_task = tasks.Task(stream.readline()) + + def cb(): + stream.feed_data(b'chunk2 ') + stream.feed_data(b'chunk3 ') + stream.feed_data(b'\n chunk4') + self.event_loop.call_soon(cb) + + line = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) + self.assertEqual(len(b'\n chunk4')-1, stream.byte_count) + + def test_readline_limit_with_existing_data(self): + self.suppress_log_errors() + + stream = streams.StreamReader(3) + stream.feed_data(b'li') + stream.feed_data(b'ne1\nline2\n') + + self.assertRaises( + ValueError, self.event_loop.run_until_complete, stream.readline()) + self.assertEqual([b'line2\n'], list(stream.buffer)) + + stream = streams.StreamReader(3) + stream.feed_data(b'li') + stream.feed_data(b'ne1') + stream.feed_data(b'li') + + self.assertRaises( + ValueError, self.event_loop.run_until_complete, stream.readline()) + self.assertEqual([b'li'], list(stream.buffer)) + self.assertEqual(2, stream.byte_count) + + def test_readline_limit(self): + self.suppress_log_errors() + + stream = streams.StreamReader(7) + + def cb(): + stream.feed_data(b'chunk1') + stream.feed_data(b'chunk2') + stream.feed_data(b'chunk3\n') + stream.feed_eof() + self.event_loop.call_soon(cb) + + self.assertRaises( + ValueError, self.event_loop.run_until_complete, stream.readline()) + self.assertEqual([b'chunk3\n'], list(stream.buffer)) + self.assertEqual(7, stream.byte_count) + + def test_readline_line_byte_count(self): + stream = streams.StreamReader() + stream.feed_data(self.DATA[:6]) + stream.feed_data(self.DATA[6:]) + + line = self.event_loop.run_until_complete(stream.readline()) + + self.assertEqual(b'line1\n', line) + self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) + + def test_readline_eof(self): + stream = streams.StreamReader() + stream.feed_data(b'some data') + stream.feed_eof() + + line = self.event_loop.run_until_complete(stream.readline()) + self.assertEqual(b'some data', line) + + def test_readline_empty_eof(self): + stream = streams.StreamReader() + stream.feed_eof() + + line = self.event_loop.run_until_complete(stream.readline()) + self.assertEqual(b'', line) + + def test_readline_read_byte_count(self): + stream = streams.StreamReader() + stream.feed_data(self.DATA) + + self.event_loop.run_until_complete(stream.readline()) + + data = self.event_loop.run_until_complete(stream.read(7)) + + self.assertEqual(b'line2\nl', data) + self.assertEqual( + len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), + stream.byte_count) + + def test_readline_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete(stream.readline()) + self.assertEqual(b'line\n', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, stream.readline()) + + def test_readexactly_zero_or_less(self): + # Read exact number of bytes (zero or less). + stream = streams.StreamReader() + stream.feed_data(self.DATA) + + data = self.event_loop.run_until_complete(stream.readexactly(0)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + data = self.event_loop.run_until_complete(stream.readexactly(-1)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readexactly(self): + # Read exact number of bytes. + stream = streams.StreamReader() + + n = 2 * len(self.DATA) + read_task = tasks.Task(stream.readexactly(n)) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA + self.DATA, data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readexactly_eof(self): + # Read exact number of bytes (eof). + stream = streams.StreamReader() + n = 2 * len(self.DATA) + read_task = tasks.Task(stream.readexactly(n)) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + + def test_readexactly_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete(stream.readexactly(2)) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, + stream.readexactly(2)) + + def test_exception(self): + stream = streams.StreamReader() + self.assertIsNone(stream.exception()) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(stream.exception(), exc) + + def test_exception_waiter(self): + stream = streams.StreamReader() + + def set_err(): + yield from [] + stream.set_exception(ValueError()) + + def readline(): + yield from stream.readline() + + t1 = tasks.Task(stream.readline()) + t2 = tasks.Task(set_err()) + + self.event_loop.run_until_complete(tasks.wait([t1, t2])) + + self.assertRaises(ValueError, t1.result) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/subprocess_test.py b/tests/subprocess_test.py new file mode 100644 index 0000000..09aaed5 --- /dev/null +++ b/tests/subprocess_test.py @@ -0,0 +1,54 @@ +"""Tests for subprocess_transport.py.""" + +import logging +import unittest + +from tulip import events +from tulip import protocols +from tulip import subprocess_transport + + +class MyProto(protocols.Protocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write_eof() + + def data_received(self, data): + logging.info('received: %r', data) + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + + +class FutureTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_unix_subprocess(self): + p = MyProto() + subprocess_transport.UnixSubprocessTransport(p, ['/bin/ls', '-lR']) + self.event_loop.run() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/tasks_test.py b/tests/tasks_test.py new file mode 100644 index 0000000..9ac15bb --- /dev/null +++ b/tests/tasks_test.py @@ -0,0 +1,647 @@ +"""Tests for tasks.py.""" + +import concurrent.futures +import time +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import tasks +from tulip import test_utils + + +class Dummy: + + def __repr__(self): + return 'Dummy()' + + def __call__(self, *args): + pass + + +class TaskTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + super().tearDown() + + def test_task_class(self): + @tasks.coroutine + def notmuch(): + yield from [] + return 'ok' + t = tasks.Task(notmuch()) + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t._event_loop, self.event_loop) + + event_loop = events.new_event_loop() + t = tasks.Task(notmuch(), event_loop=event_loop) + self.assertIs(t._event_loop, event_loop) + + def test_task_decorator(self): + @tasks.task + def notmuch(): + yield from [] + return 'ko' + t = notmuch() + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + + def test_task_repr(self): + @tasks.task + def notmuch(): + yield from [] + return 'abc' + t = notmuch() + t.add_done_callback(Dummy()) + self.assertEqual(repr(t), 'Task(<notmuch>)<PENDING, [Dummy()]>') + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), 'Task(<notmuch>)<CANCELLING, [Dummy()]>') + self.assertRaises(futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertEqual(repr(t), 'Task(<notmuch>)<CANCELLED>') + t = notmuch() + self.event_loop.run_until_complete(t) + self.assertEqual(repr(t), "Task(<notmuch>)<result='abc'>") + + def test_task_repr_custom(self): + def coro(): + yield from [] + + class T(futures.Future): + def __repr__(self): + return 'T[]' + + class MyTask(tasks.Task, T): + def __repr__(self): + return super().__repr__() + + t = MyTask(coro()) + self.assertEqual(repr(t), 'T[](<coro>)') + + def test_task_basics(self): + @tasks.task + def outer(): + a = yield from inner1() + b = yield from inner2() + return a+b + + @tasks.task + def inner1(): + yield from [] + return 42 + + @tasks.task + def inner2(): + yield from [] + return 1000 + + t = outer() + self.assertEqual(self.event_loop.run_until_complete(t), 1042) + + def test_cancel(self): + @tasks.task + def task(): + yield from tasks.sleep(10.0) + return 12 + + t = task() + self.event_loop.call_soon(t.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_future_timeout(self): + @tasks.coroutine + def coro(): + yield from tasks.sleep(10.0) + return 12 + + t = tasks.Task(coro(), timeout=0.1) + + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_future_timeout_catch(self): + @tasks.coroutine + def coro(): + yield from tasks.sleep(10.0) + return 12 + + err = None + + @tasks.coroutine + def coro2(): + nonlocal err + try: + yield from tasks.Task(coro(), timeout=0.1) + except futures.CancelledError as exc: + err = exc + + self.event_loop.run_until_complete(tasks.Task(coro2())) + self.assertIsInstance(err, futures.CancelledError) + + def test_cancel_in_coro(self): + @tasks.coroutine + def task(): + t.cancel() + yield from [] + return 12 + + t = tasks.Task(task()) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_stop_while_run_in_complete(self): + x = 0 + + @tasks.coroutine + def task(): + nonlocal x + while x < 10: + yield from tasks.sleep(0.1) + x += 1 + if x == 2: + self.event_loop.stop() + + t = tasks.Task(task()) + t0 = time.monotonic() + self.assertRaises( + futures.InvalidStateError, + self.event_loop.run_until_complete, t) + t1 = time.monotonic() + self.assertFalse(t.done()) + self.assertTrue(0.18 <= t1-t0 <= 0.22) + self.assertEqual(x, 2) + + def test_timeout(self): + @tasks.task + def task(): + yield from tasks.sleep(10.0) + return 42 + + t = task() + t0 = time.monotonic() + self.assertRaises( + futures.TimeoutError, + self.event_loop.run_until_complete, t, 0.1) + t1 = time.monotonic() + self.assertFalse(t.done()) + self.assertTrue(0.08 <= t1-t0 <= 0.12) + + def test_timeout_not(self): + @tasks.task + def task(): + yield from tasks.sleep(0.1) + return 42 + + t = task() + t0 = time.monotonic() + r = self.event_loop.run_until_complete(t, 10.0) + t1 = time.monotonic() + self.assertTrue(t.done()) + self.assertEqual(r, 42) + self.assertTrue(0.08 <= t1-t0 <= 0.12) + + def test_wait(self): + a = tasks.sleep(0.1) + b = tasks.sleep(0.15) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertEqual(res, 42) + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + # TODO: Test different return_when values. + + def test_wait_first_completed(self): + a = tasks.sleep(10.0) + b = tasks.sleep(0.1) + task = tasks.Task(tasks.wait( + [b, a], return_when=tasks.FIRST_COMPLETED)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + + def test_wait_really_done(self): + self.suppress_log_errors() + # there is possibility that some tasks in the pending list + # became done but their callbacks haven't all been called yet + + @tasks.coroutine + def coro1(): + yield from [None] + + @tasks.coroutine + def coro2(): + yield from [None, None] + + a = tasks.Task(coro1()) + b = tasks.Task(coro2()) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({a, b}, done) + + def test_wait_first_exception(self): + self.suppress_log_errors() + + a = tasks.sleep(10.0) + + @tasks.coroutine + def exc(): + yield from [] + raise ZeroDivisionError('err') + + b = tasks.Task(exc()) + task = tasks.Task(tasks.wait( + [b, a], return_when=tasks.FIRST_EXCEPTION)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + + def test_wait_with_exception(self): + self.suppress_log_errors() + a = tasks.sleep(0.1) + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(0.15) + raise ZeroDivisionError('really') + + b = tasks.Task(sleeper()) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + + t0 = time.monotonic() + self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + t0 = time.monotonic() + self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + + def test_wait_with_timeout(self): + a = tasks.sleep(0.1) + b = tasks.sleep(0.15) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], timeout=0.11) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + + t0 = time.monotonic() + self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.1) + self.assertTrue(t1-t0 <= 0.13) + + def test_as_completed(self): + @tasks.coroutine + def sleeper(dt, x): + yield from tasks.sleep(dt) + return x + + a = sleeper(0.1, 'a') + b = sleeper(0.1, 'b') + c = sleeper(0.15, 'c') + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([b, c, a]): + values.append((yield from f)) + return values + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertTrue('a' in res[:2]) + self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + + def test_as_completed_with_timeout(self): + self.suppress_log_errors() + a = tasks.sleep(0.1, 'a') + b = tasks.sleep(0.15, 'b') + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([a, b], timeout=0.12): + try: + v = yield from f + values.append((1, v)) + except futures.TimeoutError as exc: + values.append((2, exc)) + return values + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.11) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) + + def test_sleep(self): + @tasks.coroutine + def sleeper(dt, arg): + yield from tasks.sleep(dt/2) + res = yield from tasks.sleep(dt/2, arg) + return res + + t = tasks.Task(sleeper(0.1, 'yeah')) + t0 = time.monotonic() + self.event_loop.run() + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.09) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + + def test_task_cancel_sleeping_task(self): + sleepfut = None + + @tasks.task + def sleep(dt): + nonlocal sleepfut + sleepfut = tasks.sleep(dt) + try: + time.monotonic() + yield from sleepfut + finally: + time.monotonic() + + @tasks.task + def doit(): + sleeper = sleep(5000) + self.event_loop.call_later(0.1, sleeper.cancel) + try: + time.monotonic() + yield from sleeper + except futures.CancelledError: + time.monotonic() + return 'cancelled' + else: + return 'slept in' + + t0 = time.monotonic() + doer = doit() + self.assertEqual(self.event_loop.run_until_complete(doer), 'cancelled') + t1 = time.monotonic() + self.assertTrue(0.09 <= t1-t0 <= 0.13, (t1-t0, sleepfut, doer)) + + @unittest.mock.patch('tulip.tasks.tulip_log') + def test_step_in_completed_task(self, m_logging): + @tasks.coroutine + def notmuch(): + yield from [] + return 'ko' + + task = tasks.Task(notmuch()) + task.set_result('ok') + + task._step() + self.assertTrue(m_logging.warn.called) + self.assertTrue(m_logging.warn.call_args[0][0].startswith( + '_step(): already done: ')) + + @unittest.mock.patch('tulip.tasks.tulip_log') + def test_step_result(self, m_logging): + @tasks.coroutine + def notmuch(): + yield from [None, 1] + return 'ko' + + task = tasks.Task(notmuch()) + task._step() + self.assertFalse(m_logging.warn.called) + + task._step() + self.assertTrue(m_logging.warn.called) + self.assertEqual( + '_step(): bad yield: %r', + m_logging.warn.call_args[0][0]) + self.assertEqual(1, m_logging.warn.call_args[0][1]) + + def test_step_result_future(self): + # If coroutine returns future, task waits on this future. + self.suppress_log_warnings() + + class Fut(futures.Future): + def __init__(self, *args): + self.cb_added = False + super().__init__(*args) + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + fut = Fut() + result = None + + @tasks.task + def wait_for_future(): + nonlocal result + result = yield from fut + + wait_for_future() + self.event_loop.run_once() + self.assertTrue(fut.cb_added) + + res = object() + fut.set_result(res) + self.event_loop.run_once() + self.assertIs(res, result) + + def test_step_result_concurrent_future(self): + # Coroutine returns concurrent.futures.Future + self.suppress_log_warnings() + + class Fut(concurrent.futures.Future): + def __init__(self): + self.cb_added = False + super().__init__() + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + c_fut = Fut() + + @tasks.coroutine + def notmuch(): + yield from [c_fut] + return (yield) + + task = tasks.Task(notmuch()) + task._step() + self.assertTrue(c_fut.cb_added) + + res = object() + c_fut.set_result(res) + self.event_loop.run() + self.assertIs(res, task.result()) + + def test_step_with_baseexception(self): + self.suppress_log_errors() + + @tasks.coroutine + def notmutch(): + yield from [] + raise BaseException() + + task = tasks.Task(notmutch()) + self.assertRaises(BaseException, task._step) + + self.assertTrue(task.done()) + self.assertIsInstance(task.exception(), BaseException) + + def test_baseexception_during_cancel(self): + self.suppress_log_errors() + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(10) + + @tasks.coroutine + def notmutch(): + try: + yield from sleeper() + except futures.CancelledError: + raise BaseException() + + task = tasks.Task(notmutch()) + self.event_loop.run_once() + + task.cancel() + self.assertFalse(task.done()) + + self.assertRaises(BaseException, self.event_loop.run_once) + + self.assertTrue(task.done()) + self.assertTrue(task.cancelled()) + + def test_iscoroutinefunction(self): + def fn(): + pass + + self.assertFalse(tasks.iscoroutinefunction(fn)) + + def fn1(): + yield + self.assertFalse(tasks.iscoroutinefunction(fn1)) + + @tasks.coroutine + def fn2(): + yield + self.assertTrue(tasks.iscoroutinefunction(fn2)) + + def test_yield_vs_yield_from(self): + fut = futures.Future() + + @tasks.task + def wait_for_future(): + yield fut + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, task) + + def test_yield_vs_yield_from_generator(self): + fut = futures.Future() + + @tasks.coroutine + def coro(): + yield from fut + + @tasks.task + def wait_for_future(): + yield coro() + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, task) + + def test_coroutine_non_gen_function(self): + + @tasks.coroutine + def func(): + return 'test' + + self.assertTrue(tasks.iscoroutinefunction(func)) + + coro = func() + self.assertTrue(tasks.iscoroutine(coro)) + + res = self.event_loop.run_until_complete(coro) + self.assertEqual(res, 'test') + + def test_coroutine_non_gen_function_return_future(self): + fut = futures.Future() + + @tasks.coroutine + def func(): + return fut + + @tasks.coroutine + def coro(): + fut.set_result('test') + + t1 = tasks.Task(func()) + tasks.Task(coro()) + res = self.event_loop.run_until_complete(t1) + self.assertEqual(res, 'test') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/transports_test.py b/tests/transports_test.py new file mode 100644 index 0000000..4b24b50 --- /dev/null +++ b/tests/transports_test.py @@ -0,0 +1,45 @@ +"""Tests for transports.py.""" + +import unittest +import unittest.mock + +from tulip import transports + + +class TransportTests(unittest.TestCase): + + def test_ctor_extra_is_none(self): + transport = transports.Transport() + self.assertEqual(transport._extra, {}) + + def test_get_extra_info(self): + transport = transports.Transport({'extra': 'info'}) + self.assertEqual('info', transport.get_extra_info('extra')) + self.assertIsNone(transport.get_extra_info('unknown')) + + default = object() + self.assertIs(default, transport.get_extra_info('unknown', default)) + + def test_writelines(self): + transport = transports.Transport() + transport.write = unittest.mock.Mock() + + transport.writelines(['line1', 'line2', 'line3']) + self.assertEqual(3, transport.write.call_count) + + def test_not_implemented(self): + transport = transports.Transport() + + self.assertRaises(NotImplementedError, transport.write, 'data') + self.assertRaises(NotImplementedError, transport.write_eof) + self.assertRaises(NotImplementedError, transport.can_write_eof) + self.assertRaises(NotImplementedError, transport.pause) + self.assertRaises(NotImplementedError, transport.resume) + self.assertRaises(NotImplementedError, transport.close) + self.assertRaises(NotImplementedError, transport.abort) + + def test_dgram_not_implemented(self): + transport = transports.DatagramTransport() + + self.assertRaises(NotImplementedError, transport.sendto, 'data') + self.assertRaises(NotImplementedError, transport.abort) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py new file mode 100644 index 0000000..d7af7ec --- /dev/null +++ b/tests/unix_events_test.py @@ -0,0 +1,573 @@ +"""Tests for unix_events.py.""" + +import errno +import io +import unittest +import unittest.mock + +try: + import signal +except ImportError: + signal = None + +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import unix_events + + +@unittest.skipUnless(signal, 'Signals are not supported') +class SelectorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unix_events.SelectorEventLoop() + + def test_check_signal(self): + self.assertRaises( + TypeError, self.event_loop._check_signal, '1') + self.assertRaises( + ValueError, self.event_loop._check_signal, signal.NSIG + 1) + + unix_events.signal = None + + def restore_signal(): + unix_events.signal = signal + self.addCleanup(restore_signal) + + self.assertRaises( + RuntimeError, self.event_loop._check_signal, signal.SIGINT) + + def test_handle_signal_no_handler(self): + self.event_loop._handle_signal(signal.NSIG + 1, ()) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_setup_error(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.set_wakeup_fd.side_effect = ValueError + + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + h = self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + self.assertIsInstance(h, events.Handle) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_install_error(self, m_signal): + m_signal.NSIG = signal.NSIG + + def set_wakeup_fd(fd): + if fd == -1: + raise ValueError() + m_signal.set_wakeup_fd = set_wakeup_fd + + class Err(OSError): + errno = errno.EFAULT + m_signal.signal.side_effect = Err + + self.assertRaises( + Err, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_add_signal_handler_install_error2(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.event_loop._signal_handlers[signal.SIGHUP] = lambda: True + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(1, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_add_signal_handler_install_error3(self, m_logging, m_signal): + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + m_signal.NSIG = signal.NSIG + + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(2, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + self.assertTrue( + self.event_loop.remove_signal_handler(signal.SIGHUP)) + self.assertTrue(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_2(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.SIGINT = signal.SIGINT + + self.event_loop.add_signal_handler(signal.SIGINT, lambda: True) + self.event_loop._signal_handlers[signal.SIGHUP] = object() + m_signal.set_wakeup_fd.reset_mock() + + self.assertTrue( + self.event_loop.remove_signal_handler(signal.SIGINT)) + self.assertFalse(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGINT, m_signal.default_int_handler), + m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.set_wakeup_fd.side_effect = ValueError + + self.event_loop.remove_signal_handler(signal.SIGHUP) + self.assertTrue(m_logging.info) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error(self, m_signal): + m_signal.NSIG = signal.NSIG + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.signal.side_effect = OSError + + self.assertRaises( + OSError, self.event_loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error2(self, m_signal): + m_signal.NSIG = signal.NSIG + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.assertRaises( + RuntimeError, self.event_loop.remove_signal_handler, signal.SIGHUP) + + +class UnixReadPipeTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor(self, m_fcntl): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + self.event_loop.add_reader.assert_called_with(5, tr._read_ready) + self.event_loop.call_soon.assert_called_with( + self.protocol.connection_made, tr) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor_with_waiter(self, m_fcntl): + fut = futures.Future() + unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol, fut) + self.event_loop.call_soon.assert_called_with(fut.set_result, None) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + m_read.return_value = b'data' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.data_received.assert_called_with(b'data') + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready_eof(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + m_read.return_value = b'' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.eof_received.assert_called_with() + self.event_loop.remove_reader.assert_called_with(5) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready_blocked(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + self.event_loop.reset_mock() + m_read.side_effect = BlockingIOError + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.assertFalse(self.protocol.data_received.called) + + @unittest.mock.patch('tulip.log.tulip_log.exception') + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready_error(self, m_fcntl, m_read, m_logexc): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + err = OSError() + m_read.side_effect = err + tr._close = unittest.mock.Mock() + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + tr._close.assert_called_with(err) + m_logexc.assert_called_with('Fatal error for %s', tr) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_pause(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.pause() + self.event_loop.remove_reader.assert_called_with(5) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_resume(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.resume() + self.event_loop.add_reader.assert_called_with(5, tr._read_ready) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_close(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._close = unittest.mock.Mock() + tr.close() + tr._close.assert_called_with(None) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_close_already_closing(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._closing = True + tr._close = unittest.mock.Mock() + tr.close() + self.assertFalse(tr._close.called) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__close(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = object() + tr._close(err) + self.assertTrue(tr._closing) + self.event_loop.remove_reader.assert_called_with(5) + self.protocol.connection_lost.assert_called_with(err) + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost(self, m_fcntl): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost_with_err(self, m_fcntl): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + +class UnixWritePipeTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + self.event_loop.call_soon.assert_called_with( + self.protocol.connection_made, tr) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor_with_waiter(self, m_fcntl): + fut = futures.Future() + unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol, fut) + self.event_loop.call_soon.assert_called_with(fut.set_result, None) + + @unittest.mock.patch('fcntl.fcntl') + def test_can_write_eof(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + self.assertTrue(tr.can_write_eof()) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + m_write.return_value = 4 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.add_writer.called) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_no_data(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write(b'') + self.assertFalse(m_write.called) + self.assertFalse(self.event_loop.add_writer.called) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_partial(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + m_write.return_value = 2 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.event_loop.add_writer.assert_called_with(5, tr._write_ready) + self.assertEqual([b'ta'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_buffer(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'previous'] + tr.write(b'data') + self.assertFalse(m_write.called) + self.assertFalse(self.event_loop.add_writer.called) + self.assertEqual([b'previous', b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_again(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + m_write.side_effect = BlockingIOError() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.event_loop.add_writer.assert_called_with(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_err(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = OSError() + m_write.side_effect = err + tr._fatal_error = unittest.mock.Mock() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.called) + self.assertEqual([], tr._buffer) + tr._fatal_error.assert_called_with(err) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_partial(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.return_value = 3 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'a'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_again(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.side_effect = BlockingIOError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_empty(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.return_value = 0 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('tulip.log.tulip_log.exception') + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_err(self, m_fcntl, m_write, m_logexc): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.side_effect = err = OSError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + self.protocol.connection_lost.assert_called_with(err) + m_logexc.assert_called_with('Fatal error for %s', tr) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_closing(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._closing = True + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_abort(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + tr.abort() + self.assertFalse(m_write.called) + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost_with_err(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test_close(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr.close() + tr.write_eof.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test_close_closing(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr._closing = True + tr.close() + self.assertFalse(tr.write_eof.called) + + @unittest.mock.patch('fcntl.fcntl') + def test_write_eof(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write_eof() + self.assertTrue(tr._closing) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('fcntl.fcntl') + def test_write_eof_pending(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + tr._buffer = [b'data'] + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.protocol.connection_lost.called) diff --git a/tests/winsocketpair_test.py b/tests/winsocketpair_test.py new file mode 100644 index 0000000..381fb22 --- /dev/null +++ b/tests/winsocketpair_test.py @@ -0,0 +1,26 @@ +"""Tests for winsocketpair.py""" + +import unittest +import unittest.mock + +from tulip import winsocketpair + + +class WinsocketpairTests(unittest.TestCase): + + def test_winsocketpair(self): + ssock, csock = winsocketpair.socketpair() + + csock.send(b'xxx') + self.assertEqual(b'xxx', ssock.recv(1024)) + + csock.close() + ssock.close() + + @unittest.mock.patch('tulip.winsocketpair.socket') + def test_winsocketpair_exc(self, m_socket): + m_socket.socket.return_value.getsockname.return_value = ('', 12345) + m_socket.socket.return_value.accept.return_value = object(), object() + m_socket.socket.return_value.connect.side_effect = OSError() + + self.assertRaises(OSError, winsocketpair.socketpair) diff --git a/tulip/TODO b/tulip/TODO new file mode 100644 index 0000000..acec5c2 --- /dev/null +++ b/tulip/TODO @@ -0,0 +1,28 @@ +TODO in tulip v2 (tulip/ package directory) +------------------------------------------- + +- See also TBD and Open Issues in PEP 3156 + +- Refactor unix_events.py (it's getting too long) + +- Docstrings + +- Unittests + +- better run_once() behavior? (Run ready list last.) + +- start_serving() + +- Make Handler() callable? Log the exception in there? + +- Add the repeat interval to the Handler class? + +- Recognize Handler passed to add_reader(), call_soon(), etc.? + +- SSL support + +- buffered stream implementation + +- Primitives like par() and wait_one() + +- Remove test dependency on xkcd.com, write our own test server diff --git a/tulip/__init__.py b/tulip/__init__.py new file mode 100644 index 0000000..faf307f --- /dev/null +++ b/tulip/__init__.py @@ -0,0 +1,26 @@ +"""Tulip 2.0, tracking PEP 3156.""" + +import sys + +# This relies on each of the submodules having an __all__ variable. +from .futures import * +from .events import * +from .locks import * +from .transports import * +from .protocols import * +from .streams import * +from .tasks import * + +if sys.platform == 'win32': # pragma: no cover + from .windows_events import * +else: + from .unix_events import * # pragma: no cover + + +__all__ = (futures.__all__ + + events.__all__ + + locks.__all__ + + transports.__all__ + + protocols.__all__ + + streams.__all__ + + tasks.__all__) diff --git a/tulip/base_events.py b/tulip/base_events.py new file mode 100644 index 0000000..5ed257a --- /dev/null +++ b/tulip/base_events.py @@ -0,0 +1,548 @@ +"""Base implementation of event loop. + +The event loop can be broken up into a multiplexer (the part +responsible for notifying us of IO events) and the event loop proper, +which wraps a multiplexer with functionality for scheduling callbacks, +immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. +""" + + +import collections +import concurrent.futures +import heapq +import logging +import socket +import time + +from . import events +from . import futures +from . import tasks +from .log import tulip_log + + +__all__ = ['BaseEventLoop'] + + +# Argument for default thread pool executor creation. +_MAX_WORKERS = 5 + + +class _StopError(BaseException): + """Raised to stop the event loop.""" + + +def _raise_stop_error(*args): + raise _StopError + + +class BaseEventLoop(events.AbstractEventLoop): + + def __init__(self): + self._ready = collections.deque() + self._scheduled = [] + self._default_executor = None + self._internal_fds = 0 + self._running = False + + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + """Create socket transport.""" + raise NotImplementedError + + def _make_ssl_transport(self, rawsock, protocol, + sslcontext, waiter, extra=None): + """Create SSL transport.""" + raise NotImplementedError + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + """Create datagram transport.""" + raise NotImplementedError + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create read pipe transport.""" + raise NotImplementedError + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create write pipe transport.""" + raise NotImplementedError + + def _read_from_self(self): + """XXX""" + raise NotImplementedError + + def _write_to_self(self): + """XXX""" + raise NotImplementedError + + def _process_events(self, event_list): + """Process selector events.""" + raise NotImplementedError + + def is_running(self): + """Returns running status of event loop.""" + return self._running + + def run(self): + """Run the event loop until nothing left to do or stop() called. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + + TODO: Give this a timeout too? + """ + if self._running: + raise RuntimeError('Event loop is running.') + + self._running = True + try: + while (self._ready or + self._scheduled or + self._selector.registered_count() > 1): + try: + self._run_once() + except _StopError: + break + finally: + self._running = False + + def run_forever(self): + """Run until stop() is called. + + This only makes sense over run() if you have another thread + scheduling callbacks using call_soon_threadsafe(). + """ + handle = self.call_repeatedly(24*3600, lambda: None) + try: + self.run() + finally: + handle.cancel() + + def run_once(self, timeout=0): + """Run through all callbacks and all I/O polls once. + + Calling stop() will break out of this too. + """ + if self._running: + raise RuntimeError('Event loop is running.') + + self._running = True + try: + self._run_once(timeout) + except _StopError: + pass + finally: + self._running = False + + def run_until_complete(self, future, timeout=None): + """Run until the Future is done, or until a timeout. + + If the argument is a coroutine, it is wrapped in a Task. + + XXX TBD: It would be disastrous to call run_until_complete() + with the same coroutine twice -- it would wrap it in two + different Tasks and that can't be good. + + Return the Future's result, or raise its exception. If the + timeout is reached or stop() is called, raise TimeoutError. + """ + if not isinstance(future, futures.Future): + if tasks.iscoroutine(future): + future = tasks.Task(future) + else: + assert False, 'A Future or coroutine is required' + + handle_called = False + + def stop_loop(): + nonlocal handle_called + handle_called = True + raise _StopError + + future.add_done_callback(_raise_stop_error) + + if timeout is None: + self.run_forever() + else: + handle = self.call_later(timeout, stop_loop) + self.run_forever() + handle.cancel() + + if handle_called: + raise futures.TimeoutError + + return future.result() + + def stop(self): + """Stop running the event loop. + + Every callback scheduled before stop() is called will run. + Callback scheduled after stop() is called won't. However, + those callbacks will run if run() is called again later. + """ + self.call_soon(_raise_stop_error) + + def call_later(self, delay, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The delay can be an int or float, expressed in seconds. It is + always a relative time. + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Callbacks scheduled in the past are passed on to call_soon(), + so these will be called in the order in which they were + registered rather than by time due. This is so you can't + cheat and insert yourself at the front of the ready queue by + using a negative time. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + if delay <= 0: + return self.call_soon(callback, *args) + + handle = events.Timer(time.monotonic() + delay, callback, args) + heapq.heappush(self._scheduled, handle) + return handle + + def call_repeatedly(self, interval, callback, *args): + """Call a callback every 'interval' seconds.""" + assert interval > 0, 'Interval must be > 0: %r' % (interval,) + # TODO: What if callback is already a Handle? + def wrapper(): + callback(*args) # If this fails, the chain is broken. + handle._when = time.monotonic() + interval + heapq.heappush(self._scheduled, handle) + + handle = events.Timer(time.monotonic() + interval, wrapper, ()) + heapq.heappush(self._scheduled, handle) + return handle + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + handle = events.make_handle(callback, args) + self._ready.append(handle) + return handle + + def call_soon_threadsafe(self, callback, *args): + """XXX""" + handle = self.call_soon(callback, *args) + self._write_to_self() + return handle + + def run_in_executor(self, executor, callback, *args): + if isinstance(callback, events.Handle): + assert not args + assert not isinstance(callback, events.Timer) + if callback.cancelled: + f = futures.Future() + f.set_result(None) + return f + callback, args = callback.callback, callback.args + if executor is None: + executor = self._default_executor + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + self._default_executor = executor + return self.wrap_future(executor.submit(callback, *args)) + + def set_default_executor(self, executor): + self._default_executor = executor + + def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) + + def getnameinfo(self, sockaddr, flags=0): + return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + + @tasks.coroutine + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=False, family=0, proto=0, flags=0, sock=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + "host, port and sock can not be specified at the same time") + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + yield from self.sock_connect(sock, address) + except socket.error as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + if len(exceptions) == 1: + raise exceptions[0] + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise socket.error('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + + elif sock is None: + raise ValueError( + "host and port was not specified and no sock specified") + + sock.setblocking(False) + + protocol = protocol_factory() + waiter = futures.Future() + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + transport = self._make_ssl_transport( + sock, protocol, sslcontext, waiter) + else: + transport = self._make_socket_transport(sock, protocol, waiter) + + yield from waiter + return transport, protocol + + @tasks.coroutine + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + """Create datagram connection.""" + if not (local_addr or remote_addr): + if family == 0: + raise ValueError('unexpected address family') + addr_pairs_info = (((family, proto), (None, None)),) + else: + # join addresss by (family, protocol) + addr_infos = collections.OrderedDict() + for idx, addr in ((0, local_addr), (1, remote_addr)): + if addr is not None: + assert isinstance(addr, tuple) and len(addr) == 2, ( + '2-tuple is expected') + + infos = yield from self.getaddrinfo( + *addr, family=family, type=socket.SOCK_DGRAM, + proto=proto, flags=flags) + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + for fam, _, pro, _, address in infos: + key = (fam, pro) + if key not in addr_infos: + addr_infos[key] = [None, None] + addr_infos[key][idx] = address + + # each addr has to have info for each (family, proto) pair + addr_pairs_info = [ + (key, addr_pair) for key, addr_pair in addr_infos.items() + if not ((local_addr and addr_pair[0] is None) or + (remote_addr and addr_pair[1] is None))] + + if not addr_pairs_info: + raise ValueError('can not get address information') + + exceptions = [] + + for (family, proto), (local_address, remote_address) in addr_pairs_info: + sock = None + l_addr = None + r_addr = None + try: + sock = socket.socket( + family=family, type=socket.SOCK_DGRAM, proto=proto) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(False) + + if local_addr: + sock.bind(local_address) + l_addr = sock.getsockname() + if remote_addr: + yield from self.sock_connect(sock, remote_address) + r_addr = remote_address + except socket.error as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + + protocol = protocol_factory() + transport = self._make_datagram_transport( + sock, protocol, r_addr, extra={'addr': l_addr}) + return transport, protocol + + # TODO: Or create_server()? + @tasks.task + def start_serving(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, backlog=100, sock=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + "host, port and sock can not be specified at the same time") + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + # TODO: Maybe we want to bind every address in the list + # instead of the first one that works? + exceptions = [] + for family, type, proto, cname, address in infos: + sock = socket.socket(family=family, type=type, proto=proto) + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + except socket.error as exc: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + + elif sock is None: + raise ValueError( + "host and port was not specified and no sock specified") + + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock) + return sock + + @tasks.coroutine + def connect_read_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future() + transport = self._make_read_pipe_transport(pipe, protocol, waiter, + extra={}) + yield from waiter + return transport, protocol + + @tasks.coroutine + def connect_write_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future() + transport = self._make_write_pipe_transport(pipe, protocol, waiter, + extra={}) + yield from waiter + return transport, protocol + + def _add_callback(self, handle): + """Add a Handle to ready or scheduled.""" + if handle.cancelled: + return + if isinstance(handle, events.Timer): + heapq.heappush(self._scheduled, handle) + else: + self._ready.append(handle) + + def wrap_future(self, future): + """XXX""" + if isinstance(future, futures.Future): + return future # Don't wrap our own type of Future. + new_future = futures.Future() + future.add_done_callback( + lambda future: + self.call_soon_threadsafe(new_future._copy_state, future)) + return new_future + + def _run_once(self, timeout=None): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0].cancelled: + heapq.heappop(self._scheduled) + + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0].when + deadline = max(0, when - time.monotonic()) + if timeout is None: + timeout = deadline + else: + timeout = min(timeout, deadline) + + t0 = time.monotonic() + event_list = self._selector.select(timeout) + t1 = time.monotonic() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + tulip_log.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + self._process_events(event_list) + + # Handle 'later' callbacks that are ready. + now = time.monotonic() + while self._scheduled: + handle = self._scheduled[0] + if handle.when > now: + break + handle = heapq.heappop(self._scheduled) + self._ready.append(handle) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is threadsafe without using locks. + ntodo = len(self._ready) + for i in range(ntodo): + handle = self._ready.popleft() + if not handle.cancelled: + handle.run() diff --git a/tulip/events.py b/tulip/events.py new file mode 100644 index 0000000..3a6ad40 --- /dev/null +++ b/tulip/events.py @@ -0,0 +1,356 @@ +"""Event loop and event loop policy. + +Beyond the PEP: +- Only the main thread has a default event loop. +""" + +__all__ = ['EventLoopPolicy', 'DefaultEventLoopPolicy', + 'AbstractEventLoop', 'Timer', 'Handle', 'make_handle', + 'get_event_loop_policy', 'set_event_loop_policy', + 'get_event_loop', 'set_event_loop', 'new_event_loop', + ] + +import sys +import threading + +from .log import tulip_log + + +class Handle: + """Object returned by callback registration methods.""" + + def __init__(self, callback, args): + self._callback = callback + self._args = args + self._cancelled = False + + def __repr__(self): + res = 'Handle({}, {})'.format(self._callback, self._args) + if self._cancelled: + res += '<cancelled>' + return res + + @property + def callback(self): + return self._callback + + @property + def args(self): + return self._args + + @property + def cancelled(self): + return self._cancelled + + def cancel(self): + self._cancelled = True + + def run(self): + try: + self._callback(*self._args) + except Exception: + tulip_log.exception('Exception in callback %s %r', + self._callback, self._args) + + +def make_handle(callback, args): + if isinstance(callback, Handle): + assert not args + return callback + return Handle(callback, args) + + +class Timer(Handle): + """Object returned by timed callback registration methods.""" + + def __init__(self, when, callback, args): + assert when is not None + super().__init__(callback, args) + + self._when = when + + def __repr__(self): + res = 'Timer({}, {}, {})'.format(self._when, + self._callback, + self._args) + if self._cancelled: + res += '<cancelled>' + + return res + + @property + def when(self): + return self._when + + def __lt__(self, other): + return self._when < other._when + + def __le__(self, other): + if self._when < other._when: + return True + return self.__eq__(other) + + def __gt__(self, other): + return self._when > other._when + + def __ge__(self, other): + if self._when > other._when: + return True + return self.__eq__(other) + + def __eq__(self, other): + if isinstance(other, Timer): + return (self._when == other._when and + self._callback == other._callback and + self._args == other._args and + self._cancelled == other._cancelled) + return NotImplemented + + def __ne__(self, other): + equal = self.__eq__(other) + return NotImplemented if equal is NotImplemented else not equal + + +class AbstractEventLoop: + """Abstract event loop.""" + + # TODO: Rename run() -> run_until_idle(), run_forever() -> run(). + + def run(self): + """Run the event loop. Block until there is nothing left to do.""" + raise NotImplementedError + + def run_forever(self): + """Run the event loop. Block until stop() is called.""" + raise NotImplementedError + + def run_once(self, timeout=None): # NEW! + """Run one complete cycle of the event loop.""" + raise NotImplementedError + + def run_until_complete(self, future, timeout=None): # NEW! + """Run the event loop until a Future is done. + + Return the Future's result, or raise its exception. + + If timeout is not None, run it for at most that long; + if the Future is still not done, raise TimeoutError + (but don't cancel the Future). + """ + raise NotImplementedError + + def stop(self): # NEW! + """Stop the event loop as soon as reasonable. + + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + raise NotImplementedError + + # Methods returning Handles for scheduling callbacks. + + def call_later(self, delay, callback, *args): + raise NotImplementedError + + def call_repeatedly(self, interval, callback, *args): # NEW! + raise NotImplementedError + + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) + + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError + + # Methods returning Futures for interacting with threads. + + def wrap_future(self, future): + raise NotImplementedError + + def run_in_executor(self, executor, callback, *args): + raise NotImplementedError + + # Network I/O methods returning Futures. + + def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError + + def create_connection(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, sock=None): + raise NotImplementedError + + def start_serving(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, sock=None): + raise NotImplementedError + + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + def connect_read_pipe(self, protocol_factory, pipe): + """Register read pipe in eventloop. + + protocol_factory should instantiate object with Protocol interface. + pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + ReadTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def connect_write_pipe(self, protocol_factory, pipe): + """Register write pipe in eventloop. + + protocol_factory should instantiate object with BaseProtocol interface. + Pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + WriteTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + #def spawn_subprocess(self, protocol_factory, pipe): + # raise NotImplementedError + + # Ready-based callback registration methods. + # The add_*() methods return a Handle. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. + + def add_reader(self, fd, callback, *args): + raise NotImplementedError + + def remove_reader(self, fd): + raise NotImplementedError + + def add_writer(self, fd, callback, *args): + raise NotImplementedError + + def remove_writer(self, fd): + raise NotImplementedError + + # Completion based I/O methods returning Futures. + + def sock_recv(self, sock, nbytes): + raise NotImplementedError + + def sock_sendall(self, sock, data): + raise NotImplementedError + + def sock_connect(self, sock, address): + raise NotImplementedError + + def sock_accept(self, sock): + raise NotImplementedError + + # Signal handling. + + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError + + def remove_signal_handler(self, sig): + raise NotImplementedError + + +class EventLoopPolicy: + """Abstract policy for accessing the event loop.""" + + def get_event_loop(self): + """XXX""" + raise NotImplementedError + + def set_event_loop(self, event_loop): + """XXX""" + raise NotImplementedError + + def new_event_loop(self): + """XXX""" + raise NotImplementedError + + +class DefaultEventLoopPolicy(threading.local, EventLoopPolicy): + """Default policy implementation for accessing the event loop. + + In this policy, each thread has its own event loop. However, we + only automatically create an event loop by default for the main + thread; other threads by default have no event loop. + + Other policies may have different rules (e.g. a single global + event loop, or automatically creating an event loop per thread, or + using some other notion of context to which an event loop is + associated). + """ + + _event_loop = None + + def get_event_loop(self): + """Get the event loop. + + This may be None or an instance of EventLoop. + """ + if (self._event_loop is None and + threading.current_thread().name == 'MainThread'): + self._event_loop = self.new_event_loop() + return self._event_loop + + def set_event_loop(self, event_loop): + """Set the event loop.""" + assert event_loop is None or isinstance(event_loop, AbstractEventLoop) + self._event_loop = event_loop + + def new_event_loop(self): + """Create a new event loop. + + You must call set_event_loop() to make this the current event + loop. + """ + if sys.platform == 'win32': # pragma: no cover + from . import windows_events + return windows_events.SelectorEventLoop() + else: # pragma: no cover + from . import unix_events + return unix_events.SelectorEventLoop() + + +# Event loop policy. The policy itself is always global, even if the +# policy's rules say that there is an event loop per thread (or other +# notion of context). The default policy is installed by the first +# call to get_event_loop_policy(). +_event_loop_policy = None + + +def get_event_loop_policy(): + """XXX""" + global _event_loop_policy + if _event_loop_policy is None: + _event_loop_policy = DefaultEventLoopPolicy() + return _event_loop_policy + + +def set_event_loop_policy(policy): + """XXX""" + global _event_loop_policy + assert policy is None or isinstance(policy, EventLoopPolicy) + _event_loop_policy = policy + + +def get_event_loop(): + """XXX""" + return get_event_loop_policy().get_event_loop() + + +def set_event_loop(event_loop): + """XXX""" + get_event_loop_policy().set_event_loop(event_loop) + + +def new_event_loop(): + """XXX""" + return get_event_loop_policy().new_event_loop() diff --git a/tulip/futures.py b/tulip/futures.py new file mode 100644 index 0000000..39137aa --- /dev/null +++ b/tulip/futures.py @@ -0,0 +1,255 @@ +"""A Future class similar to the one in PEP 3148.""" + +__all__ = ['CancelledError', 'TimeoutError', + 'InvalidStateError', 'InvalidTimeoutError', + 'Future', + ] + +import concurrent.futures._base + +from . import events + +# States for Future. +_PENDING = 'PENDING' +_CANCELLED = 'CANCELLED' +_FINISHED = 'FINISHED' + +# TODO: Do we really want to depend on concurrent.futures internals? +Error = concurrent.futures._base.Error +CancelledError = concurrent.futures.CancelledError +TimeoutError = concurrent.futures.TimeoutError + + +class InvalidStateError(Error): + """The operation is not allowed in this state.""" + # TODO: Show the future, its state, the method, and the required state. + + +class InvalidTimeoutError(Error): + """Called result() or exception() with timeout != 0.""" + # TODO: Print a nice error message. + + +class Future: + """This class is *almost* compatible with concurrent.futures.Future. + + Differences: + + - result() and exception() do not take a timeout argument and + raise an exception when the future isn't done yet. + + - Callbacks registered with add_done_callback() are always called + via the event loop's call_soon_threadsafe(). + + - This class is not compatible with the wait() and as_completed() + methods in the concurrent.futures package. + + (In Python 3.4 or later we may be able to unify the implementations.) + """ + + # Class variables serving as defaults for instance variables. + _state = _PENDING + _result = None + _exception = None + _timeout_handle = None + + _blocking = False # proper use of future (yield vs yield from) + + def __init__(self, *, event_loop=None, timeout=None): + """Initialize the future. + + The optional event_loop argument allows to explicitly set the event + loop object used by the future. If it's not provided, the future uses + the default event loop. + """ + if event_loop is None: + self._event_loop = events.get_event_loop() + else: + self._event_loop = event_loop + self._callbacks = [] + + if timeout is not None: + self._timeout_handle = self._event_loop.call_later( + timeout, self.cancel) + + def __repr__(self): + res = self.__class__.__name__ + if self._state == _FINISHED: + if self._exception is not None: + res += '<exception={!r}>'.format(self._exception) + else: + res += '<result={!r}>'.format(self._result) + elif self._callbacks: + size = len(self._callbacks) + if size > 2: + res += '<{}, [{}, <{} more>, {}]>'.format( + self._state, self._callbacks[0], + size-2, self._callbacks[-1]) + else: + res += '<{}, {}>'.format(self._state, self._callbacks) + else: + res += '<{}>'.format(self._state) + return res + + def cancel(self): + """Cancel the future and schedule callbacks. + + If the future is already done or cancelled, return False. Otherwise, + change the future's state to cancelled, schedule the callbacks and + return True. + """ + if self._state != _PENDING: + return False + self._state = _CANCELLED + self._schedule_callbacks() + return True + + def _schedule_callbacks(self): + """Internal: Ask the event loop to call all callbacks. + + The callbacks are scheduled to be called as soon as possible. Also + clears the callback list. + """ + # Cancel timeout handle + if self._timeout_handle is not None: + self._timeout_handle.cancel() + self._timeout_handle = None + + callbacks = self._callbacks[:] + if not callbacks: + return + + self._callbacks[:] = [] + for callback in callbacks: + self._event_loop.call_soon(callback, self) + + def cancelled(self): + """Return True if the future was cancelled.""" + return self._state == _CANCELLED + + def running(self): + """Always return False. + + This method is for compatibility with concurrent.futures; we don't + have a running state. + """ + return False # We don't have a running state. + + def done(self): + """Return True if the future is done. + + Done means either that a result / exception are available, or that the + future was cancelled. + """ + return self._state != _PENDING + + def result(self, timeout=0): + """Return the result this future represents. + + If the future has been cancelled, raises CancelledError. If the + future's result isn't yet available, raises InvalidStateError. If + the future is done and has an exception set, this exception is raised. + Timeout values other than 0 are not supported. + """ + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + if self._exception is not None: + raise self._exception + return self._result + + def exception(self, timeout=0): + """Return the exception that was set on this future. + + The exception (or None if no exception was set) is returned only if + the future is done. If the future has been cancelled, raises + CancelledError. If the future isn't done yet, raises + InvalidStateError. Timeout values other than 0 are not supported. + """ + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + return self._exception + + def add_done_callback(self, fn): + """Add a callback to be run when the future becomes done. + + The callback is called with a single argument - the future object. If + the future is already done when this is called, the callback is + scheduled with call_soon. + """ + if self._state != _PENDING: + self._event_loop.call_soon(fn, self) + else: + self._callbacks.append(fn) + + # New method not in PEP 3148. + + def remove_done_callback(self, fn): + """Remove all instances of a callback from the "call when done" list. + + Returns the number of callbacks removed. + """ + filtered_callbacks = [f for f in self._callbacks if f != fn] + removed_count = len(self._callbacks) - len(filtered_callbacks) + if removed_count: + self._callbacks[:] = filtered_callbacks + return removed_count + + # So-called internal methods (note: no set_running_or_notify_cancel()). + + def set_result(self, result): + """Mark the future done and set its result. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError + self._result = result + self._state = _FINISHED + self._schedule_callbacks() + + def set_exception(self, exception): + """ Mark the future done and set an exception. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError + self._exception = exception + self._state = _FINISHED + self._schedule_callbacks() + + # Truly internal methods. + + def _copy_state(self, other): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert other.done() + assert not self.done() + if other.cancelled(): + self.cancel() + else: + exception = other.exception() + if exception is not None: + self.set_exception(exception) + else: + result = other.result() + self.set_result(result) + + def __iter__(self): + if not self.done(): + self._blocking = True + yield self # This tells Task to wait for completion. + assert self.done(), "yield from wasn't used with future" + return self.result() # May raise too. diff --git a/tulip/http/__init__.py b/tulip/http/__init__.py new file mode 100644 index 0000000..582f080 --- /dev/null +++ b/tulip/http/__init__.py @@ -0,0 +1,12 @@ +# This relies on each of the submodules having an __all__ variable. + +from .client import * +from .errors import * +from .protocol import * +from .server import * + + +__all__ = (client.__all__ + + errors.__all__ + + protocol.__all__ + + server.__all__) diff --git a/tulip/http/client.py b/tulip/http/client.py new file mode 100644 index 0000000..b65b90a --- /dev/null +++ b/tulip/http/client.py @@ -0,0 +1,145 @@ +"""HTTP Client for Tulip. + +Most basic usage: + + sts, headers, response = yield from http_client.fetch(url, + method='GET', headers={}, request=b'') + assert isinstance(sts, int) + assert isinstance(headers, dict) + # sort of; case insensitive (what about multiple values for same header?) + headers['status'] == '200 Ok' # or some such + assert isinstance(response, bytes) + +TODO: Reuse email.Message class (or its subclass, http.client.HTTPMessage). +TODO: How do we do connection keep alive? Pooling? +""" + +__all__ = ['HttpClientProtocol'] + + +import email.message +import email.parser + +import tulip + +from . import protocol + + +class HttpClientProtocol: + """This Protocol class is also used to initiate the connection. + + Usage: + p = HttpClientProtocol(url, ...) + sts, headers, stream = yield from p.connect() + + """ + + def __init__(self, host, port=None, *, + path='/', method='GET', headers=None, ssl=None, + make_body=None, encoding='utf-8', version=(1, 1), + chunked=False): + host = self.validate(host, 'host') + if ':' in host: + assert port is None + host, port_s = host.split(':', 1) + port = int(port_s) + self.host = host + if port is None: + if ssl: + port = 443 + else: + port = 80 + assert isinstance(port, int) + self.port = port + self.path = self.validate(path, 'path') + self.method = self.validate(method, 'method') + self.headers = email.message.Message() + self.headers['Accept-Encoding'] = 'gzip, deflate' + if headers: + for key, value in headers.items(): + self.validate(key, 'header key') + self.validate(value, 'header value', True) + self.headers[key] = value + self.encoding = self.validate(encoding, 'encoding') + self.version = version + self.make_body = make_body + self.chunked = chunked + self.ssl = ssl + if 'content-length' not in self.headers: + if self.make_body is None: + self.headers['Content-Length'] = '0' + else: + self.chunked = True + if self.chunked: + if 'Transfer-Encoding' not in self.headers: + self.headers['Transfer-Encoding'] = 'chunked' + else: + assert self.headers['Transfer-Encoding'].lower() == 'chunked' + if 'host' not in self.headers: + self.headers['Host'] = self.host + self.event_loop = tulip.get_event_loop() + self.transport = None + + def validate(self, value, name, embedded_spaces_okay=False): + # Must be a string. If embedded_spaces_okay is False, no + # whitespace is allowed; otherwise, internal single spaces are + # allowed (but no other whitespace). + assert isinstance(value, str), \ + '{} should be str, not {}'.format(name, type(value)) + parts = value.split() + assert parts, '{} should not be empty'.format(name) + if embedded_spaces_okay: + assert ' '.join(parts) == value, \ + '{} can only contain embedded single spaces ({!r})'.format( + name, value) + else: + assert parts == [value], \ + '{} cannot contain whitespace ({!r})'.format(name, value) + return value + + @tulip.coroutine + def connect(self): + yield from self.event_loop.create_connection( + lambda: self, self.host, self.port, ssl=self.ssl) + + # read response status + version, status, reason = yield from self.stream.read_response_status() + + message = yield from self.stream.read_message(version) + + # headers + headers = email.message.Message() + for hdr, val in message.headers: + headers.add_header(hdr, val) + + sts = '{} {}'.format(status, reason) + return (sts, headers, message.payload) + + def connection_made(self, transport): + self.transport = transport + self.stream = protocol.HttpStreamReader() + + self.request = protocol.Request( + transport, self.method, self.path, self.version) + + self.request.add_headers(*self.headers.items()) + self.request.send_headers() + + if self.make_body is not None: + if self.chunked: + self.make_body( + self.request.write, self.request.eof) + else: + self.make_body( + self.request.write, self.request.eof) + else: + self.request.write_eof() + + def data_received(self, data): + self.stream.feed_data(data) + + def eof_received(self): + self.stream.feed_eof() + + def connection_lost(self, exc): + pass diff --git a/tulip/http/errors.py b/tulip/http/errors.py new file mode 100644 index 0000000..41344de --- /dev/null +++ b/tulip/http/errors.py @@ -0,0 +1,44 @@ +"""http related errors.""" + +__all__ = ['HttpException', 'HttpStatusException', + 'IncompleteRead', 'BadStatusLine', 'LineTooLong', 'InvalidHeader'] + +import http.client + + +class HttpException(http.client.HTTPException): + + code = None + headers = () + + +class HttpStatusException(HttpException): + + def __init__(self, code, headers=None, message=''): + self.code = code + self.headers = headers + self.message = message + + +class BadRequestException(HttpException): + + code = 400 + + +class IncompleteRead(BadRequestException, http.client.IncompleteRead): + pass + + +class BadStatusLine(BadRequestException, http.client.BadStatusLine): + pass + + +class LineTooLong(BadRequestException, http.client.LineTooLong): + pass + + +class InvalidHeader(BadRequestException): + + def __init__(self, hdr): + super().__init__('Invalid HTTP Header: %s' % hdr) + self.hdr = hdr diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py new file mode 100644 index 0000000..6a0e127 --- /dev/null +++ b/tulip/http/protocol.py @@ -0,0 +1,877 @@ +"""Http related helper utils.""" + +__all__ = ['HttpStreamReader', + 'HttpMessage', 'Request', 'Response', + 'RawHttpMessage', 'RequestLine', 'ResponseStatus'] + +import collections +import email.utils +import functools +import http.server +import itertools +import re +import sys +import zlib + +import tulip +from . import errors + +METHRE = re.compile('[A-Z0-9$-_.]+') +VERSRE = re.compile('HTTP/(\d+).(\d+)') +HDRRE = re.compile(b"[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]") +CONTINUATION = (b' ', b'\t') +RESPONSES = http.server.BaseHTTPRequestHandler.responses + +RequestLine = collections.namedtuple( + 'RequestLine', ['method', 'uri', 'version']) + + +ResponseStatus = collections.namedtuple( + 'ResponseStatus', ['version', 'code', 'reason']) + + +RawHttpMessage = collections.namedtuple( + 'RawHttpMessage', ['headers', 'payload', 'should_close', 'compression']) + + +class HttpStreamReader(tulip.StreamReader): + + MAX_HEADERS = 32768 + MAX_HEADERFIELD_SIZE = 8190 + + # if _parser is set, feed_data and feed_eof sends data into + # _parser instead of self. is it being used as stream redirection for + # _parse_chunked_payload, _parse_length_payload and _parse_eof_payload + _parser = None + + def feed_data(self, data): + """_parser is a generator, if _parser is set, feed_data sends + incoming data into the generator untile generator stops.""" + if self._parser: + try: + self._parser.send(data) + except StopIteration as exc: + self._parser = None + if exc.value: + self.feed_data(exc.value) + else: + super().feed_data(data) + + def feed_eof(self): + """_parser is a generator, if _parser is set feed_eof throws + StreamEofException into this generator.""" + if self._parser: + try: + self._parser.throw(StreamEofException()) + except StopIteration: + self._parser = None + + super().feed_eof() + + @tulip.coroutine + def read_request_line(self): + """Read request status line. Exception errors.BadStatusLine + could be raised in case of any errors in status line. + Returns three values (method, uri, version) + + Example: + + GET /path HTTP/1.1 + + >> yield from reader.read_request_line() + ('GET', '/path', (1, 1)) + + """ + bline = yield from self.readline() + try: + line = bline.decode('ascii').rstrip() + except UnicodeDecodeError: + raise errors.BadStatusLine(bline) from None + + try: + method, uri, version = line.split(None, 2) + except ValueError: + raise errors.BadStatusLine(line) from None + + # method + method = method.upper() + if not METHRE.match(method): + raise errors.BadStatusLine(method) + + # version + match = VERSRE.match(version) + if match is None: + raise errors.BadStatusLine(version) + version = (int(match.group(1)), int(match.group(2))) + + return RequestLine(method, uri, version) + + @tulip.coroutine + def read_response_status(self): + """Read response status line. Exception errors.BadStatusLine + could be raised in case of any errors in status line. + Returns three values (version, status_code, reason) + + Example: + + HTTP/1.1 200 Ok + + >> yield from reader.read_response_status() + ((1, 1), 200, 'Ok') + + """ + bline = yield from self.readline() + if not bline: + # Presumably, the server closed the connection before + # sending a valid response. + raise errors.BadStatusLine(bline) + + try: + line = bline.decode('ascii').rstrip() + except UnicodeDecodeError: + raise errors.BadStatusLine(bline) from None + + try: + version, status = line.split(None, 1) + except ValueError: + raise errors.BadStatusLine(line) from None + else: + try: + status, reason = status.split(None, 1) + except ValueError: + reason = '' + + # version + match = VERSRE.match(version) + if match is None: + raise errors.BadStatusLine(line) + version = (int(match.group(1)), int(match.group(2))) + + # The status code is a three-digit number + try: + status = int(status) + except ValueError: + raise errors.BadStatusLine(line) from None + + if status < 100 or status > 999: + raise errors.BadStatusLine(line) + + return ResponseStatus(version, status, reason.strip()) + + @tulip.coroutine + def read_headers(self): + """Read and parses RFC2822 headers from a stream. + + Line continuations are supported. Returns list of header name + and value pairs. Header name is in upper case. + """ + size = 0 + headers = [] + + line = yield from self.readline() + + while line not in (b'\r\n', b'\n'): + header_length = len(line) + + # Parse initial header name : value pair. + sep_pos = line.find(b':') + if sep_pos < 0: + raise ValueError('Invalid header %s' % line.strip()) + + name, value = line[:sep_pos], line[sep_pos+1:] + name = name.rstrip(b' \t').upper() + if HDRRE.search(name): + raise ValueError('Invalid header name %s' % name) + + name = name.strip().decode('ascii', 'surrogateescape') + value = [value.lstrip()] + + # next line + line = yield from self.readline() + + # consume continuation lines + continuation = line.startswith(CONTINUATION) + + if continuation: + while continuation: + header_length += len(line) + if header_length > self.MAX_HEADERFIELD_SIZE: + raise errors.LineTooLong( + 'limit request headers fields size') + value.append(line) + + line = yield from self.readline() + continuation = line.startswith(CONTINUATION) + else: + if header_length > self.MAX_HEADERFIELD_SIZE: + raise errors.LineTooLong( + 'limit request headers fields size') + + # total headers size + size += header_length + if size >= self.MAX_HEADERS: + raise errors.LineTooLong('limit request headers fields') + + headers.append( + (name, + b''.join(value).rstrip().decode('ascii', 'surrogateescape'))) + + return headers + + def _parse_chunked_payload(self): + """Chunked transfer encoding parser.""" + stream = yield + + try: + data = bytearray() + + while True: + # read line + if b'\n' not in data: + data.extend((yield)) + continue + + line, data = data.split(b'\n', 1) + + # Read the next chunk size from the file + i = line.find(b';') + if i >= 0: + line = line[:i] # strip chunk-extensions + try: + size = int(line, 16) + except ValueError: + raise errors.IncompleteRead(b'') from None + + if size == 0: + break + + # read chunk + while len(data) < size: + data.extend((yield)) + + # feed stream + stream.feed_data(data[:size]) + + data = data[size:] + + # toss the CRLF at the end of the chunk + while len(data) < 2: + data.extend((yield)) + + data = data[2:] + + # read and discard trailer up to the CRLF terminator + while True: + if b'\n' in data: + line, data = data.split(b'\n', 1) + if line in (b'\r', b''): + break + else: + data.extend((yield)) + + # stream eof + stream.feed_eof() + return data + + except StreamEofException: + stream.set_exception(errors.IncompleteRead(b'')) + except errors.IncompleteRead as exc: + stream.set_exception(exc) + + def _parse_length_payload(self, length): + """Read specified amount of bytes.""" + stream = yield + + try: + data = bytearray() + while length: + data.extend((yield)) + + data_len = len(data) + if data_len <= length: + stream.feed_data(data) + data = bytearray() + length -= data_len + else: + stream.feed_data(data[:length]) + data = data[length:] + length = 0 + + stream.feed_eof() + return data + except StreamEofException: + stream.set_exception(errors.IncompleteRead(b'')) + + def _parse_eof_payload(self): + """Read all bytes untile eof.""" + stream = yield + + try: + while True: + stream.feed_data((yield)) + except StreamEofException: + stream.feed_eof() + + @tulip.coroutine + def read_message(self, version=(1, 1), + length=None, compression=True, readall=False): + """Read RFC2822 headers and message payload from a stream. + + read_message() automatically decompress gzip and deflate content + encoding. To prevent decompression pass compression=False. + + Returns tuple of headers, payload stream, should close flag, + compression type. + """ + # load headers + headers = yield from self.read_headers() + + # payload params + chunked = False + encoding = None + close_conn = None + + for name, value in headers: + if name == 'CONTENT-LENGTH': + length = value + elif name == 'TRANSFER-ENCODING': + chunked = value.lower() == 'chunked' + elif name == 'SEC-WEBSOCKET-KEY1': + length = 8 + elif name == 'CONNECTION': + v = value.lower() + if v == 'close': + close_conn = True + elif v == 'keep-alive': + close_conn = False + elif compression and name == 'CONTENT-ENCODING': + enc = value.lower() + if enc in ('gzip', 'deflate'): + encoding = enc + + if close_conn is None: + close_conn = version <= (1, 0) + + # payload parser + if chunked: + parser = self._parse_chunked_payload() + + elif length is not None: + try: + length = int(length) + except ValueError: + raise errors.InvalidHeader('CONTENT-LENGTH') from None + + if length < 0: + raise errors.InvalidHeader('CONTENT-LENGTH') + + parser = self._parse_length_payload(length) + else: + if readall: + parser = self._parse_eof_payload() + else: + parser = self._parse_length_payload(0) + + next(parser) + + payload = stream = tulip.StreamReader() + + # payload decompression wrapper + if encoding is not None: + stream = DeflateStream(stream, encoding) + + try: + # initialize payload parser with stream, stream is being + # used by parser as destination stream + parser.send(stream) + except StopIteration: + pass + else: + # feed existing buffer to payload parser + self.byte_count = 0 + while self.buffer: + try: + parser.send(self.buffer.popleft()) + except StopIteration as exc: + parser = None + + # parser is done + buf = b''.join(self.buffer) + self.buffer.clear() + + # re-add remaining data back to buffer + if exc.value: + self.feed_data(exc.value) + + if buf: + self.feed_data(buf) + + break + + # parser still require more data + if parser is not None: + if self.eof: + try: + parser.throw(StreamEofException()) + except StopIteration as exc: + pass + else: + self._parser = parser + + return RawHttpMessage(headers, payload, close_conn, encoding) + + +class StreamEofException(Exception): + """Internal exception: eof received.""" + + +class DeflateStream: + """DeflateStream decomress stream and feed data into specified stream.""" + + def __init__(self, stream, encoding): + self.stream = stream + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + + self.zlib = zlib.decompressobj(wbits=zlib_mode) + + def set_exception(self, exc): + self.stream.set_exception(exc) + + def feed_data(self, chunk): + try: + chunk = self.zlib.decompress(chunk) + except: + self.stream.set_exception(errors.IncompleteRead(b'')) + + if chunk: + self.stream.feed_data(chunk) + + def feed_eof(self): + self.stream.feed_data(self.zlib.flush()) + if not self.zlib.eof: + self.stream.set_exception(errors.IncompleteRead(b'')) + + self.stream.feed_eof() + + +EOF_MARKER = object() +EOL_MARKER = object() + + +def wrap_payload_filter(func): + """Wraps payload filter and piped filters. + + Filter is a generatator that accepts arbitrary chunks of data, + modify data and emit new stream of data. + + For example we have stream of chunks: ['1', '2', '3', '4', '5'], + we can apply chunking filter to this stream: + + ['1', '2', '3', '4', '5'] + | + response.add_chunking_filter(2) + | + ['12', '34', '5'] + + It is possible to use different filters at the same time. + + For a example to compress incoming stream with 'deflate' encoding + and then split data and emit chunks of 8196 bytes size chunks: + + >> response.add_compression_filter('deflate') + >> response.add_chunking_filter(8196) + + Filters do not alter transfer encoding. + + Filter can receive types types of data, bytes object or EOF_MARKER. + + 1. If filter receives bytes object, it should process data + and yield processed data then yield EOL_MARKER object. + 2. If Filter recevied EOF_MARKER, it should yield remaining + data (buffered) and then yield EOF_MARKER. + """ + @functools.wraps(func) + def wrapper(self, *args, **kw): + new_filter = func(self, *args, **kw) + + filter = self.filter + if filter is not None: + next(new_filter) + self.filter = filter_pipe(filter, new_filter) + else: + self.filter = new_filter + + next(self.filter) + + return wrapper + + +def filter_pipe(filter, filter2): + """Creates pipe between two filters. + + filter_pipe() feeds first filter with incoming data and then + send yielded from first filter data into filter2, results of + filter2 are being emitted. + + 1. If filter_pipe receives bytes object, it sends it to the first filter. + 2. Reads yielded values from the first filter until it receives + EOF_MARKER or EOL_MARKER. + 3. Each of this values is being send to second filter. + 4. Reads yielded values from second filter until it recives EOF_MARKER or + EOL_MARKER. Each of this values yields to writer. + """ + chunk = yield + + while True: + eof = chunk is EOF_MARKER + chunk = filter.send(chunk) + + while chunk is not EOL_MARKER: + chunk = filter2.send(chunk) + + while chunk not in (EOF_MARKER, EOL_MARKER): + yield chunk + chunk = next(filter2) + + if chunk is not EOF_MARKER: + if eof: + chunk = EOF_MARKER + else: + chunk = next(filter) + else: + break + + chunk = yield EOL_MARKER + + +class HttpMessage: + """HttpMessage allows to write headers and payload to a stream. + + For example, lets say we want to read file then compress it with deflate + compression and then send it with chunked transfer encoding, code may look + like this: + + >> response = tulip.http.Response(transport, 200) + + We have to use deflate compression first: + + >> response.add_compression_filter('deflate') + + Then we want to split output stream into chunks of 1024 bytes size: + + >> response.add_chunking_filter(1024) + + We can add headers to response with add_headers() method. add_headers() + does not send data to transport, send_headers() sends request/response + line and then sends headers: + + >> response.add_headers( + .. ('Content-Disposition', 'attachment; filename="..."')) + >> response.send_headers() + + Now we can use chunked writer to write stream to a network stream. + First call to write() method sends response status line and headers, + add_header() and add_headers() method unavailble at this stage: + + >> with open('...', 'rb') as f: + .. chunk = fp.read(8196) + .. while chunk: + .. response.write(chunk) + .. chunk = fp.read(8196) + + >> response.write_eof() + """ + + writer = None + + # 'filter' is being used for altering write() bahaviour, + # add_chunking_filter adds deflate/gzip compression and + # add_compression_filter splits incoming data into a chunks. + filter = None + + HOP_HEADERS = None # Must be set by subclass. + + SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} tulip/0.0'.format(sys.version_info) + + status = None + status_line = b'' + + # subclass can enable auto sending headers with write() call, + # this is useful for wsgi's start_response implementation. + _send_headers = False + + def __init__(self, transport, version, close): + self.transport = transport + self.version = version + self.closing = close + self.keepalive = False + + self.chunked = False + self.length = None + self.upgrade = False + self.headers = [] + self.headers_sent = False + + def force_close(self): + self.closing = True + + def force_chunked(self): + self.chunked = True + + def keep_alive(self): + return self.keepalive and not self.closing + + def is_headers_sent(self): + return self.headers_sent + + def add_header(self, name, value): + """Analyze headers. Calculate content length, + removes hop headers, etc.""" + assert not self.headers_sent, 'headers have been sent already' + assert isinstance(name, str), '%r is not a string' % name + + name = name.strip().upper() + + if name == 'CONTENT-LENGTH': + self.length = int(value) + + if name == 'CONNECTION': + val = value.lower().strip() + # handle websocket + if val == 'upgrade': + self.upgrade = True + # connection keep-alive + elif val == 'close': + self.keepalive = False + elif val == 'keep-alive': + self.keepalive = True + + elif name == 'UPGRADE': + if 'websocket' in value.lower(): + self.headers.append((name, value)) + + elif name == 'TRANSFER-ENCODING' and not self.chunked: + self.chunked = value.lower().strip() == 'chunked' + + elif name not in self.HOP_HEADERS: + # ignore hopbyhop headers + self.headers.append((name, value)) + + def add_headers(self, *headers): + """Adds headers to a http message.""" + for name, value in headers: + self.add_header(name, value) + + def send_headers(self): + """Writes headers to a stream. Constructs payload writer.""" + # Chunked response is only for HTTP/1.1 clients or newer + # and there is no Content-Length header is set. + # Do not use chunked responses when the response is guaranteed to + # not have a response body (304, 204). + assert not self.headers_sent, 'headers have been sent already' + self.headers_sent = True + + if (self.chunked is True) or ( + self.length is None and + self.version >= (1, 1) and + self.status not in (304, 204)): + self.chunked = True + self.writer = self._write_chunked_payload() + + elif self.length is not None: + self.writer = self._write_length_payload(self.length) + + else: + self.writer = self._write_eof_payload() + + next(self.writer) + + # status line + self.transport.write(self.status_line.encode('ascii')) + + # send headers + self.transport.write( + ('%s\r\n\r\n' % '\r\n'.join( + ('%s: %s' % (k, v) for k, v in + itertools.chain(self._default_headers(), self.headers))) + ).encode('ascii')) + + def _default_headers(self): + # set the connection header + if self.upgrade: + connection = 'upgrade' + elif self.keep_alive(): + connection = 'keep-alive' + else: + connection = 'close' + + headers = [('CONNECTION', connection)] + + if self.chunked: + headers.append(('TRANSFER-ENCODING', 'chunked')) + + return headers + + def write(self, chunk): + """write() writes chunk of data to a steram by using different writers. + writer uses filter to modify chunk of data. write_eof() indicates + end of stream. writer can't be used after write_eof() method + being called.""" + assert (isinstance(chunk, (bytes, bytearray)) or + chunk is EOF_MARKER), chunk + + if self._send_headers and not self.headers_sent: + self.send_headers() + + assert self.writer is not None, 'send_headers() is not called.' + + if self.filter: + chunk = self.filter.send(chunk) + while chunk not in (EOF_MARKER, EOL_MARKER): + self.writer.send(chunk) + chunk = next(self.filter) + else: + if chunk is not EOF_MARKER: + self.writer.send(chunk) + + def write_eof(self): + self.write(EOF_MARKER) + try: + self.writer.throw(StreamEofException()) + except StopIteration: + pass + + def _write_chunked_payload(self): + """Write data in chunked transfer encoding.""" + while True: + try: + chunk = yield + except StreamEofException: + self.transport.write(b'0\r\n\r\n') + break + + self.transport.write('{:x}\r\n'.format(len(chunk)).encode('ascii')) + self.transport.write(chunk) + self.transport.write(b'\r\n') + + def _write_length_payload(self, length): + """Write specified number of bytes to a stream.""" + while True: + try: + chunk = yield + except StreamEofException: + break + + if length: + l = len(chunk) + if length >= l: + self.transport.write(chunk) + else: + self.transport.write(chunk[:length]) + + length = max(0, length-l) + + def _write_eof_payload(self): + while True: + try: + chunk = yield + except StreamEofException: + break + + self.transport.write(chunk) + + @wrap_payload_filter + def add_chunking_filter(self, chunk_size=16*1024): + """Split incoming stream into chunks.""" + buf = bytearray() + chunk = yield + + while True: + if chunk is EOF_MARKER: + if buf: + yield buf + + yield EOF_MARKER + + else: + buf.extend(chunk) + + while len(buf) >= chunk_size: + chunk, buf = buf[:chunk_size], buf[chunk_size:] + yield chunk + + chunk = yield EOL_MARKER + + @wrap_payload_filter + def add_compression_filter(self, encoding='deflate'): + """Compress incoming stream with deflate or gzip encoding.""" + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + zcomp = zlib.compressobj(wbits=zlib_mode) + + chunk = yield + while True: + if chunk is EOF_MARKER: + yield zcomp.flush() + chunk = yield EOF_MARKER + + else: + yield zcomp.compress(chunk) + chunk = yield EOL_MARKER + + +class Response(HttpMessage): + """Create http response message. + + Transport is a socket stream transport. status is a response status code, + status has to be integer value. http_version is a tuple that represents + http version, (1, 0) stands for HTTP/1.0 and (1, 1) is for HTTP/1.1 + """ + + HOP_HEADERS = { + 'CONNECTION', + 'KEEP-ALIVE', + 'PROXY-AUTHENTICATE', + 'PROXY-AUTHORIZATION', + 'TE', + 'TRAILERS', + 'TRANSFER-ENCODING', + 'UPGRADE', + 'SERVER', + 'DATE', + } + + def __init__(self, transport, status, http_version=(1, 1), close=False): + super().__init__(transport, http_version, close) + + self.status = status + self.status_line = 'HTTP/{0[0]}.{0[1]} {1} {2}\r\n'.format( + http_version, status, RESPONSES[status][0]) + + def _default_headers(self): + headers = super()._default_headers() + headers.extend((('DATE', email.utils.formatdate()), + ('SERVER', self.SERVER_SOFTWARE))) + + return headers + + +class Request(HttpMessage): + + HOP_HEADERS = () + + def __init__(self, transport, method, uri, + http_version=(1, 1), close=False): + super().__init__(transport, http_version, close) + + self.method = method + self.uri = uri + self.status_line = '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format( + method, uri, http_version) + + def _default_headers(self): + headers = super()._default_headers() + headers.append(('USER-AGENT', self.SERVER_SOFTWARE)) + + return headers diff --git a/tulip/http/server.py b/tulip/http/server.py new file mode 100644 index 0000000..7590e47 --- /dev/null +++ b/tulip/http/server.py @@ -0,0 +1,176 @@ +"""simple http server.""" + +__all__ = ['ServerHttpProtocol'] + +import http.server +import inspect +import logging +import traceback + +import tulip +import tulip.http + +from . import errors + +RESPONSES = http.server.BaseHTTPRequestHandler.responses +DEFAULT_ERROR_MESSAGE = """ +<html> + <head> + <title>%(status)s %(reason)s</title> + </head> + <body> + <h1>%(status)s %(reason)s</h1> + %(message)s + </body> +</html>""" + + +class ServerHttpProtocol(tulip.Protocol): + """Simple http protocol implementation. + + ServerHttpProtocol handles incoming http request. It reads request line, + request headers and request payload and calls handler_request() method. + By default it always returns with 404 respose. + + ServerHttpProtocol handles errors in incoming request, like bad + status line, bad headers or incomplete payload. If any error occurs, + connection gets closed. + """ + closing = False + request_count = 0 + _request_handle = None + + def __init__(self, log=logging, debug=False): + self.log = log + self.debug = debug + + def connection_made(self, transport): + self.transport = transport + self.stream = tulip.http.HttpStreamReader() + self._request_handle = self.start() + + def data_received(self, data): + self.stream.feed_data(data) + + def connection_lost(self, exc): + if self._request_handle is not None: + self._request_handle.cancel() + self._request_handle = None + + def eof_received(self): + self.stream.feed_eof() + + def close(self): + self.closing = True + + def log_access(self, status, info, message, *args, **kw): + pass + + def log_debug(self, *args, **kw): + if self.debug: + self.log.debug(*args, **kw) + + def log_exception(self, *args, **kw): + self.log.exception(*args, **kw) + + @tulip.task + def start(self): + """Start processing of incoming requests. + It reads request line, request headers and request payload, then + calls handle_request() method. Subclass has to override + handle_request(). start() handles various excetions in request + or response handling. In case of any error connection is being closed. + """ + + while True: + info = None + message = None + self.request_count += 1 + + try: + info = yield from self.stream.read_request_line() + message = yield from self.stream.read_message(info.version) + + handler = self.handle_request(info, message) + if (inspect.isgenerator(handler) or + isinstance(handler, tulip.Future)): + yield from handler + + except tulip.CancelledError: + self.log_debug('Ignored premature client disconnection.') + break + except errors.HttpException as exc: + self.handle_error(exc.code, info, message, exc, exc.headers) + except Exception as exc: + self.handle_error(500, info, message, exc) + finally: + if self.closing: + self.transport.close() + break + + self._request_handle = None + + def handle_error(self, status=500, info=None, + message=None, exc=None, headers=None): + """Handle errors. + + Returns http response with specific status code. Logs additional + information. It always closes current connection.""" + + if status == 500: + self.log_exception("Error handling request") + + try: + reason, msg = RESPONSES[status] + except KeyError: + status = 500 + reason, msg = '???', '' + + if self.debug and exc is not None: + try: + tb = traceback.format_exc() + msg += '<br><h2>Traceback:</h2>\n<pre>%s</pre>' % tb + except: + pass + + self.log_access(status, info, message) + + html = DEFAULT_ERROR_MESSAGE % { + 'status': status, 'reason': reason, 'message': msg} + + response = tulip.http.Response(self.transport, status, close=True) + response.add_headers( + ('Content-Type', 'text/html'), + ('Content-Length', str(len(html)))) + if headers is not None: + response.add_headers(*headers) + response.send_headers() + + response.write(html.encode('ascii')) + response.write_eof() + + self.close() + + def handle_request(self, info, message): + """Handle a single http request. + + Subclass should override this method. By default it always + returns 404 response. + + info: tulip.http.RequestLine instance + message: tulip.http.RawHttpMessage instance + """ + response = tulip.http.Response( + self.transport, 404, http_version=info.version, close=True) + + body = b'Page Not Found!' + + response.add_headers( + ('Content-Type', 'text/plain'), + ('Content-Length', str(len(body)))) + response.send_headers() + response.write(body) + response.write_eof() + + self.close() + self.log_access(404, info, message) diff --git a/tulip/locks.py b/tulip/locks.py new file mode 100644 index 0000000..4024796 --- /dev/null +++ b/tulip/locks.py @@ -0,0 +1,433 @@ +"""Synchronization primitives""" + +__all__ = ['Lock', 'EventWaiter', 'Condition', 'Semaphore'] + +import collections +import time + +from . import events +from . import futures +from . import tasks + + +class Lock: + """The class implementing primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned by + a particular coroutine when locked. A primitive lock is in one of two + states, "locked" or "unlocked". + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() changes + the state to locked and returns immediately. When the state is locked, + acquire() blocks until a call to release() in another coroutine changes + it to unlocked, then the acquire() call resets it to locked and returns. + The release() method should only be called in the locked state; it changes + the state to unlocked and returns immediately. If an attempt is made + to release an unlocked lock, a RuntimeError will be raised. + + When more than one coroutine is blocked in acquire() waiting for the state + to turn to unlocked, only one coroutine proceeds when a release() call + resets the state to unlocked; first coroutine which is blocked in acquire() + is being processed. + + acquire() method is a coroutine and should be called with "yield from" + + Locks also support the context manager protocol. (yield from lock) should + be used as context manager expression. + + Usage: + + lock = Lock() + ... + yield from lock + try: + ... + finally: + lock.release() + + Context manager usage: + + lock = Lock() + ... + with (yield from lock): + ... + + Lock object could be tested for locking state: + + if not lock.locked(): + yield from lock + else: + # lock is acquired + ... + + """ + + def __init__(self): + self._waiters = collections.deque() + self._locked = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<%s [%s]>' % ( + res[1:-1], 'locked' if self._locked else 'unlocked') + + def locked(self): + """Return true if lock is acquired.""" + return self._locked + + @tasks.coroutine + def acquire(self, timeout=None): + """Acquire a lock. + + Acquire method blocks until the lock is unlocked, then set it to + locked and return True. + + When invoked with the floating-point timeout argument set, blocks for + at most the number of seconds specified by timeout and as long as + the lock cannot be acquired. + + The return value is True if the lock is acquired successfully, + False if not (for example if the timeout expired). + """ + if not self._waiters and not self._locked: + self._locked = True + return True + + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + self._locked = True + return True + + def release(self): + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + If any other coroutines are blocked waiting for the lock to become + unlocked, allow exactly one of them to proceed. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + if self._waiters: + self._waiters[0].set_result(True) + else: + raise RuntimeError('Lock is not acquired.') + + def __enter__(self): + if not self._locked: + raise RuntimeError( + '"yield from" should be used as context manager expression') + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self + + +class EventWaiter: + """A EventWaiter implementation, our equivalent to threading.Event + + Class implementing event objects. An event manages a flag that can be set + to true with the set() method and reset to false with the clear() method. + The wait() method blocks until the flag is true. The flag is initially + false. + """ + + def __init__(self): + self._waiters = collections.deque() + self._value = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<%s [%s]>' % (res[1:-1], 'set' if self._value else 'unset') + + def is_set(self): + """Return true if and only if the internal flag is true.""" + return self._value + + def set(self): + """Set the internal flag to true. All coroutines waiting for it to + become true are awakened. Coroutine that call wait() once the flag is + true will not block at all. + """ + if not self._value: + self._value = True + + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + + def clear(self): + """Reset the internal flag to false. Subsequently, coroutines calling + wait() will block until set() is called to set the internal flag + to true again.""" + self._value = False + + @tasks.coroutine + def wait(self, timeout=None): + """Block until the internal flag is true. If the internal flag + is true on entry, return immediately. Otherwise, block until another + coroutine calls set() to set the flag to true, or until the optional + timeout occurs. + + When the timeout argument is present and not None, it should be + a floating point number specifying a timeout for the operation in + seconds (or fractions thereof). + + This method returns true if and only if the internal flag has been + set to true, either before the wait call or after the wait starts, + so it will always return True except if a timeout is given and + the operation times out. + + wait() method is a coroutine. + """ + if self._value: + return True + + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + return True + + +class Condition(Lock): + """A Condition implementation. + + This class implements condition variable objects. A condition variable + allows one or more coroutines to wait until they are notified by another + coroutine. + """ + + def __init__(self): + super().__init__() + + self._condition_waiters = collections.deque() + + @tasks.coroutine + def wait(self, timeout=None): + """Wait until notified or until a timeout occurs. If the calling + coroutine has not acquired the lock when this method is called, + a RuntimeError is raised. + + This method releases the underlying lock, and then blocks until it is + awakened by a notify() or notify_all() call for the same condition + variable in another coroutine, or until the optional timeout occurs. + Once awakened or timed out, it re-acquires the lock and returns. + + When the timeout argument is present and not None, it should be + a floating point number specifying a timeout for the operation + in seconds (or fractions thereof). + + The return value is True unless a given timeout expired, in which + case it is False. + """ + if not self._locked: + raise RuntimeError('cannot wait on un-acquired lock') + + self.release() + + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + + self._condition_waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._condition_waiters.remove(fut) + return False + else: + f = self._condition_waiters.popleft() + assert fut is f + finally: + yield from self.acquire() + + return True + + @tasks.coroutine + def wait_for(self, predicate, timeout=None): + """Wait until a condition evaluates to True. predicate should be a + callable which result will be interpreted as a boolean value. A timeout + may be provided giving the maximum time to wait. + """ + endtime = None + waittime = timeout + result = predicate() + + while not result: + if waittime is not None: + if endtime is None: + endtime = time.monotonic() + waittime + else: + waittime = endtime - time.monotonic() + if waittime <= 0: + break + + yield from self.wait(waittime) + result = predicate() + + return result + + def notify(self, n=1): + """By default, wake up one coroutine waiting on this condition, if any. + If the calling coroutine has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up at most n of the coroutines waiting for the + condition variable; it is a no-op if no coroutines are waiting. + + Note: an awakened coroutine does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self._locked: + raise RuntimeError('cannot notify on un-acquired lock') + + idx = 0 + for fut in self._condition_waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self): + """Wake up all threads waiting on this condition. This method acts + like notify(), but wakes up all waiting threads instead of one. If the + calling thread has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._condition_waiters)) + + +class Semaphore: + """A Semaphore implementation. + + A semaphore manages an internal counter which is decremented by each + acquire() call and incremented by each release() call. The counter + can never go below zero; when acquire() finds that it is zero, it blocks, + waiting until some other thread calls release(). + + Semaphores also support the context manager protocol. + + The first optional argument gives the initial value for the internal + counter; it defaults to 1. If the value given is less than 0, + ValueError is raised. + + The second optional argument determins can semophore be released more than + initial internal counter value; it defaults to False. If the value given + is True and number of release() is more than number of successfull + acquire() calls ValueError is raised. + """ + + def __init__(self, value=1, bound=False): + if value < 0: + raise ValueError("Semaphore initial value must be > 0") + self._value = value + self._bound = bound + self._bound_value = value + self._waiters = collections.deque() + self._locked = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<%s [%s]>' % ( + res[1:-1], + 'locked' if self._locked else 'unlocked,value:%s' % self._value) + + def locked(self): + """Returns True if semaphore can not be acquired immediately.""" + return self._locked + + @tasks.coroutine + def acquire(self, timeout=None): + """Acquire a semaphore. acquire() method is a coroutine. + + When invoked without arguments: if the internal counter is larger + than zero on entry, decrement it by one and return immediately. + If it is zero on entry, block, waiting until some other coroutine has + called release() to make it larger than zero. + + When invoked with a timeout other than None, it will block for at + most timeout seconds. If acquire does not complete successfully in + that interval, return false. Return true otherwise. + """ + if not self._waiters and self._value > 0: + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + When it was zero on entry and another coroutine is waiting for it to + become larger than zero again, wake up that coroutine. + + If Semaphore is create with "bound" paramter equals true, then + release() method checks to make sure its current value doesn't exceed + its initial value. If it does, ValueError is raised. + """ + if self._bound and self._value >= self._bound_value: + raise ValueError('Semaphore released too many times') + + self._value += 1 + self._locked = False + + for waiter in self._waiters: + if not waiter.done(): + waiter.set_result(True) + break + + def __enter__(self): + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self diff --git a/tulip/log.py b/tulip/log.py new file mode 100644 index 0000000..b918fe5 --- /dev/null +++ b/tulip/log.py @@ -0,0 +1,6 @@ +"""Tulip logging configuration""" + +import logging + + +tulip_log = logging.getLogger("tulip") diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py new file mode 100644 index 0000000..46ffa58 --- /dev/null +++ b/tulip/proactor_events.py @@ -0,0 +1,189 @@ +"""Event loop using a proactor and related classes. + +A proactor is a "notify-on-completion" multiplexer. Currently a +proactor is only implemented on Windows with IOCP. +""" + +from . import base_events +from . import transports +from .log import tulip_log + + +class _ProactorSocketTransport(transports.Transport): + + def __init__(self, event_loop, sock, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._event_loop = event_loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._read_fut = None + self._write_fut = None + self._closing = False # Set when close() called. + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.call_soon(self._loop_reading) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _loop_reading(self, fut=None): + data = None + + try: + if fut is not None: + assert fut is self._read_fut + + data = fut.result() # deliver data later in "finally" clause + if not data: + self._read_fut = None + return + + self._read_fut = self._event_loop._proactor.recv(self._sock, 4096) + except ConnectionAbortedError as exc: + if not self._closing: + self._fatal_error(exc) + except OSError as exc: + self._fatal_error(exc) + else: + self._read_fut.add_done_callback(self._loop_reading) + finally: + if data: + self._protocol.data_received(data) + elif data is not None: + self._protocol.eof_received() + + def write(self, data): + assert isinstance(data, bytes), repr(data) + assert not self._closing + if not data: + return + self._buffer.append(data) + if not self._write_fut: + self._loop_writing() + + def _loop_writing(self, f=None): + try: + assert f is self._write_fut + if f: + f.result() + data = b''.join(self._buffer) + self._buffer = [] + if not data: + self._write_fut = None + return + self._write_fut = self._event_loop._proactor.send(self._sock, data) + except OSError as exc: + self._fatal_error(exc) + else: + self._write_fut.add_done_callback(self._loop_writing) + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + if self._write_fut: + self._write_fut.cancel() + if not self._buffer: + self._event_loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + tulip_log.exception('Fatal error for %s', self) + if self._write_fut: + self._write_fut.cancel() + if self._read_fut: # XXX + self._read_fut.cancel() + self._write_fut = self._read_fut = None + self._buffer = [] + self._event_loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + + +class BaseProactorEventLoop(base_events.BaseEventLoop): + + def __init__(self, proactor): + super().__init__() + tulip_log.debug('Using proactor: %s', proactor.__class__.__name__) + self._proactor = proactor + self._selector = proactor # convenient alias + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + return _ProactorSocketTransport(self, sock, protocol, waiter, extra) + + def close(self): + if self._proactor is not None: + self._close_self_pipe() + self._proactor.close() + self._proactor = None + self._selector = None + + def sock_recv(self, sock, n): + return self._proactor.recv(sock, n) + + def sock_sendall(self, sock, data): + return self._proactor.send(sock, data) + + def sock_connect(self, sock, address): + return self._proactor.connect(sock, address) + + def sock_accept(self, sock): + return self._proactor.accept(sock) + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + + def loop(f=None): + try: + if f: + f.result() # may raise + f = self._proactor.recv(self._ssock, 4096) + except: + self.close() + raise + else: + f.add_done_callback(loop) + self.call_soon(loop) + + def _write_to_self(self): + self._csock.send(b'x') + + def _start_serving(self, protocol_factory, sock): + def loop(f=None): + try: + if f: + conn, addr = f.result() + protocol = protocol_factory() + self._make_socket_transport( + conn, protocol, extra={'addr': addr}) + f = self._proactor.accept(sock) + except OSError: + sock.close() + tulip_log.exception('Accept failed') + else: + f.add_done_callback(loop) + self.call_soon(loop) + + def _process_events(self, event_list): + pass # XXX hard work currently done in poll diff --git a/tulip/protocols.py b/tulip/protocols.py new file mode 100644 index 0000000..593ee74 --- /dev/null +++ b/tulip/protocols.py @@ -0,0 +1,78 @@ +"""Abstract Protocol class.""" + +__all__ = ['Protocol', 'DatagramProtocol'] + + +class BaseProtocol: + """ABC for base protocol class. + + Usually user implements protocols that derived from BaseProtocol + like Protocol or ProcessProtocol. + + The only case when BaseProtocol should be implemented directly is + write-only transport like write pipe + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the pipe connection. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +class Protocol(BaseProtocol): + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing (they don't raise exceptions). + + When the user wants to requests a transport, they pass a protocol + factory to a utility function (e.g., EventLoop.create_connection()). + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_lost() will be called exactly once + with either an exception object or None as an argument. + + State machine of calls: + + start -> CM [-> DR*] [-> ER?] -> CL -> end + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + """ + + def eof_received(self): + """Called when the other end calls write_eof() or equivalent. + + The default implementation does nothing. + + TODO: By default close the transport. But we don't have the + transport as an instance variable (connection_made() may not + set it). + """ + + +class DatagramProtocol(BaseProtocol): + """ABC representing a datagram protocol.""" + + def datagram_received(self, data, addr): + """Called when some datagram is received.""" + + def connection_refused(self, exc): + """Connection is refused.""" diff --git a/tulip/queues.py b/tulip/queues.py new file mode 100644 index 0000000..ee349e1 --- /dev/null +++ b/tulip/queues.py @@ -0,0 +1,291 @@ +"""Queues""" + +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue'] + +import collections +import concurrent.futures +import heapq +import queue + +from . import events +from . import futures +from . import locks +from .tasks import coroutine + + +class Queue: + """A queue, useful for coordinating producer and consumer coroutines. + + If maxsize is less than or equal to zero, the queue size is infinite. If it + is an integer greater than 0, then "yield from put()" will block when the + queue reaches maxsize, until an item is removed by get(). + + Unlike the standard library Queue, you can reliably know this Queue's size + with qsize(), since your single-threaded Tulip application won't be + interrupted between calling qsize() and doing an operation on the Queue. + """ + + def __init__(self, maxsize=0): + self._event_loop = events.get_event_loop() + self._maxsize = maxsize + + # Futures. + self._getters = collections.deque() + # Pairs of (item, Future). + self._putters = collections.deque() + self._init(maxsize) + + def _init(self, maxsize): + self._queue = collections.deque() + + def _get(self): + return self._queue.popleft() + + def _put(self, item): + self._queue.append(item) + + def __repr__(self): + return '<%s at %s %s>' % ( + type(self).__name__, hex(id(self)), self._format()) + + def __str__(self): + return '<%s %s>' % (type(self).__name__, self._format()) + + def _format(self): + result = 'maxsize=%r' % (self._maxsize, ) + if getattr(self, '_queue', None): + result += ' _queue=%r' % list(self._queue) + if self._getters: + result += ' _getters[%s]' % len(self._getters) + if self._putters: + result += ' _putters[%s]' % len(self._putters) + return result + + def _consume_done_getters(self, waiters): + # Delete waiters at the head of the get() queue who've timed out. + while waiters and waiters[0].done(): + waiters.popleft() + + def _consume_done_putters(self): + # Delete waiters at the head of the put() queue who've timed out. + while self._putters and self._putters[0][1].done(): + self._putters.popleft() + + def qsize(self): + """Number of items in the queue.""" + return len(self._queue) + + @property + def maxsize(self): + """Number of items allowed in the queue.""" + return self._maxsize + + def empty(self): + """Return True if the queue is empty, False otherwise.""" + return not self._queue + + def full(self): + """Return True if there are maxsize items in the queue. + + Note: if the Queue was initialized with maxsize=0 (the default), + then full() is never True. + """ + if self._maxsize <= 0: + return False + else: + return self.qsize() == self._maxsize + + @coroutine + def put(self, item, timeout=None): + """Put an item into the queue. + + If you yield from put() and timeout is None (the default), wait until a + free slot is available before adding item. + + If a timeout is provided, raise queue.Full if no free slot becomes + available before the timeout. + """ + self._consume_done_getters(self._getters) + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + waiter = futures.Future( + event_loop=self._event_loop, timeout=timeout) + + self._putters.append((item, waiter)) + try: + yield from waiter + except concurrent.futures.CancelledError: + raise queue.Full + + else: + self._put(item) + + def put_nowait(self, item): + """Put an item into the queue without blocking. + + If no free slot is immediately available, raise queue.Full. + """ + self._consume_done_getters(self._getters) + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + raise queue.Full + else: + self._put(item) + + @coroutine + def get(self, timeout=None): + """Remove and return an item from the queue. + + If you yield from get() and timeout is None (the default), wait until a + item is available. + + If a timeout is provided, raise queue.Empty if no item is available + before the timeout. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + + # When a getter runs and frees up a slot so this putter can + # run, we need to defer the put for a tick to ensure that + # getters and putters alternate perfectly. See + # ChannelTest.test_wait. + self._event_loop.call_soon(putter.set_result, None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + waiter = futures.Future( + event_loop=self._event_loop, timeout=timeout) + + self._getters.append(waiter) + try: + return (yield from waiter) + except concurrent.futures.CancelledError: + raise queue.Empty + + def get_nowait(self): + """Remove and return an item from the queue. + + Return an item if one is immediately available, else raise queue.Full. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + # Wake putter on next tick. + putter.set_result(None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + raise queue.Empty + + +class PriorityQueue(Queue): + """A subclass of Queue; retrieves entries in priority order (lowest first). + + Entries are typically tuples of the form: (priority number, data). + """ + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item, heappush=heapq.heappush): + heappush(self._queue, item) + + def _get(self, heappop=heapq.heappop): + return heappop(self._queue) + + +class LifoQueue(Queue): + """A subclass of Queue that retrieves most recently added entries first.""" + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item): + self._queue.append(item) + + def _get(self): + return self._queue.pop() + + +class JoinableQueue(Queue): + """A subclass of Queue with task_done() and join() methods.""" + + def __init__(self, maxsize=0): + self._unfinished_tasks = 0 + self._finished = locks.EventWaiter() + self._finished.set() + super().__init__(maxsize=maxsize) + + def _format(self): + result = Queue._format(self) + if self._unfinished_tasks: + result += ' tasks=%s' % self._unfinished_tasks + return result + + def _put(self, item): + super()._put(item) + self._unfinished_tasks += 1 + self._finished.clear() + + def task_done(self): + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each get() used to fetch a task, + a subsequent call to task_done() tells the queue that the processing + on the task is complete. + + If a join() is currently blocking, it will resume when all items have + been processed (meaning that a task_done() call was received for every + item that had been put() into the queue). + + Raises ValueError if called more times than there were items placed in + the queue. + """ + if self._unfinished_tasks <= 0: + raise ValueError('task_done() called too many times') + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + @coroutine + def join(self, timeout=None): + """Block until all items in the queue have been gotten and processed. + + The count of unfinished tasks goes up whenever an item is added to the + queue. The count goes down whenever a consumer thread calls task_done() + to indicate that the item was retrieved and all work on it is complete. + When the count of unfinished tasks drops to zero, join() unblocks. + """ + if self._unfinished_tasks > 0: + yield from self._finished.wait(timeout=timeout) diff --git a/tulip/selector_events.py b/tulip/selector_events.py new file mode 100644 index 0000000..20e5db0 --- /dev/null +++ b/tulip/selector_events.py @@ -0,0 +1,655 @@ +"""Event loop using a selector and related classes. + +A selector is a "notify-when-ready" multiplexer. For a subclass which +also includes support for signal handling, see the unix_events sub-module. +""" + +import collections +import errno +import socket +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import base_events +from . import events +from . import futures +from . import selectors +from . import transports +from .log import tulip_log + + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + + +class BaseSelectorEventLoop(base_events.BaseEventLoop): + """Selector event loop. + + See events.EventLoop for API specification. + """ + + def __init__(self, selector=None): + super().__init__() + + if selector is None: + selector = selectors.Selector() + tulip_log.debug('Using selector: %s', selector.__class__.__name__) + self._selector = selector + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + return _SelectorSocketTransport(self, sock, protocol, waiter, extra) + + def _make_ssl_transport(self, rawsock, protocol, + sslcontext, waiter, extra=None): + return _SelectorSslTransport( + self, rawsock, protocol, sslcontext, waiter, extra) + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + return _SelectorDatagramTransport(self, sock, protocol, address, extra) + + def close(self): + if self._selector is not None: + self._close_self_pipe() + self._selector.close() + self._selector = None + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self.remove_reader(self._ssock.fileno()) + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.add_reader(self._ssock.fileno(), self._read_from_self) + + def _read_from_self(self): + try: + self._ssock.recv(1) + except (BlockingIOError, InterruptedError): + pass + + def _write_to_self(self): + try: + self._csock.send(b'x') + except (BlockingIOError, InterruptedError): + pass + + def _start_serving(self, protocol_factory, sock): + self.add_reader(sock.fileno(), self._accept_connection, + protocol_factory, sock) + + def _accept_connection(self, protocol_factory, sock): + try: + conn, addr = sock.accept() + except (BlockingIOError, InterruptedError): + pass # False alarm. + except: + # Bad error. Stop serving. + self.remove_reader(sock.fileno()) + sock.close() + # There's nowhere to send the error, so just log it. + # TODO: Someone will want an error handler for this. + tulip_log.exception('Accept failed') + else: + self._make_socket_transport( + conn, protocol_factory(), extra={'addr': addr}) + # It's now up to the protocol to handle the connection. + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a Handle instance.""" + handle = events.make_handle(callback, args) + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_READ, + (handle, None)) + else: + self._selector.modify(fd, mask | selectors.EVENT_READ, + (handle, writer)) + if reader is not None: + reader.cancel() + + return handle + + def remove_reader(self, fd): + """Remove a reader callback.""" + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + return False + else: + mask &= ~selectors.EVENT_READ + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (None, writer)) + + if reader is not None: + reader.cancel() + return True + else: + return False + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a Handle instance.""" + handle = events.make_handle(callback, args) + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_WRITE, + (None, handle)) + else: + self._selector.modify(fd, mask | selectors.EVENT_WRITE, + (reader, handle)) + if writer is not None: + writer.cancel() + + return handle + + def remove_writer(self, fd): + """Remove a writer callback.""" + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + return False + else: + # Remove both writer and connector. + mask &= ~selectors.EVENT_WRITE + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None)) + + if writer is not None: + writer.cancel() + return True + else: + return False + + def sock_recv(self, sock, n): + """XXX""" + fut = futures.Future() + self._sock_recv(fut, False, sock, n) + return fut + + def _sock_recv(self, fut, registered, sock, n): + fd = sock.fileno() + if registered: + # Remove the callback early. It should be rare that the + # selector says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. + self.remove_reader(fd) + if fut.cancelled(): + return + try: + data = sock.recv(n) + fut.set_result(data) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_recv, fut, True, sock, n) + except Exception as exc: + fut.set_exception(exc) + + def sock_sendall(self, sock, data): + """XXX""" + fut = futures.Future() + if data: + self._sock_sendall(fut, False, sock, data) + else: + fut.set_result(None) + return fut + + def _sock_sendall(self, fut, registered, sock, data): + fd = sock.fileno() + + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + + try: + n = sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + fut.set_exception(exc) + return + + if n == len(data): + fut.set_result(None) + else: + if n: + data = data[n:] + self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + + def sock_connect(self, sock, address): + """XXX""" + # That address better not require a lookup! We're not calling + # self.getaddrinfo() for you here. But verifying this is + # complicated; the socket module doesn't have a pattern for + # IPv6 addresses (there are too many forms, apparently). + fut = futures.Future() + self._sock_connect(fut, False, sock, address) + return fut + + def _sock_connect(self, fut, registered, sock, address): + fd = sock.fileno() + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + try: + if not registered: + # First time around. + sock.connect(address) + else: + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to the except clause below. + raise socket.error(err, 'Connect call failed') + fut.set_result(None) + except (BlockingIOError, InterruptedError): + self.add_writer(fd, self._sock_connect, fut, True, sock, address) + except Exception as exc: + fut.set_exception(exc) + + def sock_accept(self, sock): + """XXX""" + fut = futures.Future() + self._sock_accept(fut, False, sock) + return fut + + def _sock_accept(self, fut, registered, sock): + fd = sock.fileno() + if registered: + self.remove_reader(fd) + if fut.cancelled(): + return + try: + conn, address = sock.accept() + conn.setblocking(False) + fut.set_result((conn, address)) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_accept, fut, True, sock) + except Exception as exc: + fut.set_exception(exc) + + def _process_events(self, event_list): + for fileobj, mask, (reader, writer) in event_list: + if mask & selectors.EVENT_READ and reader is not None: + if reader.cancelled: + self.remove_reader(fileobj) + else: + self._add_callback(reader) + if mask & selectors.EVENT_WRITE and writer is not None: + if writer.cancelled: + self.remove_writer(fileobj) + else: + self._add_callback(writer) + + +class _SelectorSocketTransport(transports.Transport): + + def __init__(self, event_loop, sock, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._event_loop = event_loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._closing = False # Set when close() called. + self._event_loop.add_reader(self._sock.fileno(), self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = self._sock.recv(16*1024) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + self._event_loop.remove_reader(self._sock.fileno()) + self._protocol.eof_received() + + def write(self, data): + assert isinstance(data, (bytes, bytearray)), repr(data) + assert not self._closing + if not data: + return + + if not self._buffer: + # Attempt to send it right away first. + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except socket.error as exc: + self._fatal_error(exc) + return + + if n == len(data): + return + elif n: + data = data[n:] + self._event_loop.add_writer(self._sock.fileno(), self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + assert data, "Data should not be empty" + + self._buffer.clear() + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except Exception as exc: + self._fatal_error(exc) + else: + if n == len(data): + self._event_loop.remove_writer(self._sock.fileno()) + if self._closing: + self._call_connection_lost(None) + return + elif n: + data = data[n:] + + self._buffer.append(data) # Try again later. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._close(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sock.fileno()) + if not self._buffer: + self._call_connection_lost(None) + + def _fatal_error(self, exc): + # should be called from exception handler only + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._event_loop.remove_writer(self._sock.fileno()) + self._event_loop.remove_reader(self._sock.fileno()) + self._buffer.clear() + self._call_connection_lost(exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + + +class _SelectorSslTransport(transports.Transport): + + def __init__(self, event_loop, rawsock, + protocol, sslcontext, waiter, extra=None): + super().__init__(extra) + + self._event_loop = event_loop + self._rawsock = rawsock + self._protocol = protocol + sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslcontext = sslcontext + self._waiter = waiter + sslsock = sslcontext.wrap_socket(rawsock, + do_handshake_on_connect=False) + self._sslsock = sslsock + self._buffer = [] + self._closing = False # Set when close() called. + self._extra['socket'] = sslsock + + self._on_handshake() + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._event_loop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._event_loop.add_writer(fd, self._on_handshake) + return + except Exception as exc: + self._sslsock.close() + self._waiter.set_exception(exc) + return + except BaseException as exc: + self._sslsock.close() + self._waiter.set_exception(exc) + raise + self._event_loop.remove_reader(fd) + self._event_loop.remove_writer(fd) + self._event_loop.add_reader(fd, self._on_ready) + self._event_loop.add_writer(fd, self._on_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.call_soon(self._waiter.set_result, None) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer + self._event_loop.remove_reader(fd) + self._event_loop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + return + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = b''.join(self._buffer) + self._buffer = [] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + n = 0 + except ssl.SSLWantWriteError: + n = 0 + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + + if n < len(data): + self._buffer.append(data[n:]) + elif self._closing: + self._event_loop.remove_writer(self._sslsock.fileno()) + self._sslsock.close() + self._protocol.connection_lost(None) + + def write(self, data): + assert isinstance(data, bytes), repr(data) + assert not self._closing + if not data: + return + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._close(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sslsock.fileno()) + if not self._buffer: + self._protocol.connection_lost(None) + + def _fatal_error(self, exc): + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._event_loop.remove_writer(self._sslsock.fileno()) + self._event_loop.remove_reader(self._sslsock.fileno()) + self._buffer = [] + self._protocol.connection_lost(exc) + + +class _SelectorDatagramTransport(transports.DatagramTransport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, event_loop, sock, protocol, address=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._event_loop = event_loop + self._sock = sock + self._fileno = sock.fileno() + self._protocol = protocol + self._address = address + self._buffer = collections.deque() + self._closing = False # Set when close() called. + self._event_loop.add_reader(self._fileno, self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + + def _read_ready(self): + try: + data, addr = self._sock.recvfrom(self.max_size) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + self._protocol.datagram_received(data, addr) + + def sendto(self, data, addr=None): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + + if self._address: + assert addr in (None, self._address) + + if not self._buffer: + # Attempt to send it right away first. + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + return + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except (BlockingIOError, InterruptedError): + self._event_loop.add_writer(self._fileno, self._sendto_ready) + except Exception as exc: + self._fatal_error(exc) + return + + self._buffer.append((data, addr)) + + def _sendto_ready(self): + while self._buffer: + data, addr = self._buffer.popleft() + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except (BlockingIOError, InterruptedError): + self._buffer.appendleft((data, addr)) # Try again later. + break + except Exception as exc: + self._fatal_error(exc) + return + + if not self._buffer: + self._event_loop.remove_writer(self._fileno) + if self._closing: + self._call_connection_lost(None) + + def abort(self): + self._close(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._fileno) + if not self._buffer: + self._call_connection_lost(None) + + def _fatal_error(self, exc): + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._buffer.clear() + self._event_loop.remove_writer(self._fileno) + self._event_loop.remove_reader(self._fileno) + if self._address and isinstance(exc, ConnectionRefusedError): + self._protocol.connection_refused(exc) + self._call_connection_lost(exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() diff --git a/tulip/selectors.py b/tulip/selectors.py new file mode 100644 index 0000000..57be7ab --- /dev/null +++ b/tulip/selectors.py @@ -0,0 +1,418 @@ +"""Select module. + +This module supports asynchronous I/O on multiple file descriptors. +""" + +import sys +from select import * + +from .log import tulip_log + + +# generic events, that must be mapped to implementation-specific ones +# read event +EVENT_READ = (1 << 0) +# write event +EVENT_WRITE = (1 << 1) + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file descriptor, or any object with a `fileno()` method + + Returns: + corresponding file descriptor + """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (ValueError, TypeError): + raise ValueError("Invalid file object: {!r}".format(fileobj)) + return fd + + +class SelectorKey: + """Object used internally to associate a file object to its backing file + descriptor, selected event mask and attached data.""" + + def __init__(self, fileobj, events, data=None): + self.fileobj = fileobj + self.fd = _fileobj_to_fd(fileobj) + self.events = events + self.data = data + + def __repr__(self): + return '{}<fileobj={}, fd={}, events={:#x}, data={}>'.format( + self.__class__.__name__, + self.fileobj, self.fd, self.events, self.data) + + +class _BaseSelector: + """Base selector class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + performant implementation on the current platform. + """ + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # this maps file objects to keys - for fast (un)registering + self._fileobj_to_key = {} + + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + """ + if (not events) or (events & ~(EVENT_READ|EVENT_WRITE)): + raise ValueError("Invalid events: {}".format(events)) + + if fileobj in self._fileobj_to_key: + raise ValueError("{!r} is already registered".format(fileobj)) + + key = SelectorKey(fileobj, events, data) + self._fd_to_key[key.fd] = key + self._fileobj_to_key[fileobj] = key + return key + + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object + + Returns: + SelectorKey instance + """ + try: + key = self._fileobj_to_key[fileobj] + del self._fd_to_key[key.fd] + del self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + return key + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + """ + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + if events != key.events or data != key.data: + self.unregister(fileobj) + return self.register(fileobj, events, data) + else: + return key + + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout == 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (fileobj, events, attached data) for ready file objects + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE + """ + raise NotImplementedError() + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + self._fd_to_key.clear() + self._fileobj_to_key.clear() + + def get_info(self, fileobj): + """Return information about a registered file object. + + Returns: + (events, data) associated to this file object + + Raises KeyError if the file object is not registered. + """ + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise KeyError("{} is not registered".format(fileobj)) + return key.events, key.data + + def registered_count(self): + """Return the number of registered file objects. + + Returns: + number of currently registered file objects + """ + return len(self._fd_to_key) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. + + Parameters: + fd -- file descriptor + + Returns: + corresponding key + """ + try: + return self._fd_to_key[fd] + except KeyError: + tulip_log.warn('No key found for fd %r', fd) + return None + + +class SelectSelector(_BaseSelector): + """Select-based selector.""" + + def __init__(self): + super().__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & EVENT_WRITE: + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + def select(self, timeout=None): + try: + r, w, _ = self._select(self._readers, self._writers, [], timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + r = set(r) + w = set(w) + ready = [] + for fd in r | w: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + if sys.platform == 'win32': + def _select(self, r, w, _, timeout=None): + r, w, x = select(r, w, w, timeout) + return r, w + x, [] + else: + from select import select as _select + + +if 'poll' in globals(): + + # TODO: Implement poll() for Windows with workaround for + # brokenness in WSAPoll() (Richard Oudkerk, see + # http://bugs.python.org/issue16507). + + class PollSelector(_BaseSelector): + """Poll-based selector.""" + + def __init__(self): + super().__init__() + self._poll = poll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= POLLIN + if events & EVENT_WRITE: + poll_events |= POLLOUT + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = None if timeout is None else int(1000 * timeout) + ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~POLLIN: + events |= EVENT_WRITE + if event & ~POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + +if 'epoll' in globals(): + + class EpollSelector(_BaseSelector): + """Epoll-based selector.""" + + def __init__(self): + super().__init__() + self._epoll = epoll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + epoll_events = 0 + if events & EVENT_READ: + epoll_events |= EPOLLIN + if events & EVENT_WRITE: + epoll_events |= EPOLLOUT + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._epoll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = -1 if timeout is None else timeout + max_ev = self.registered_count() + ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~EPOLLIN: + events |= EVENT_WRITE + if event & ~EPOLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + def close(self): + super().close() + self._epoll.close() + + +if 'kqueue' in globals(): + + class KqueueSelector(_BaseSelector): + """Kqueue-based selector.""" + + def __init__(self): + super().__init__() + self._kqueue = kqueue() + + def unregister(self, fileobj): + key = super().unregister(fileobj) + if key.events & EVENT_READ: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + if key.events & EVENT_WRITE: + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + return key + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & EVENT_WRITE: + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def select(self, timeout=None): + max_ev = self.registered_count() + ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == KQ_FILTER_READ: + events |= EVENT_READ + if flag == KQ_FILTER_WRITE: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + def close(self): + super().close() + self._kqueue.close() + + +# Choose the best implementation: roughly, epoll|kqueue > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + Selector = KqueueSelector +elif 'EpollSelector' in globals(): + Selector = EpollSelector +elif 'PollSelector' in globals(): + Selector = PollSelector +else: + Selector = SelectSelector diff --git a/tulip/streams.py b/tulip/streams.py new file mode 100644 index 0000000..8d7f623 --- /dev/null +++ b/tulip/streams.py @@ -0,0 +1,145 @@ +"""Stream-related things.""" + +__all__ = ['StreamReader'] + +import collections + +from . import futures +from . import tasks + + +class StreamReader: + + def __init__(self, limit=2**16): + self.limit = limit # Max line length. (Security feature.) + self.buffer = collections.deque() # Deque of bytes objects. + self.byte_count = 0 # Bytes in buffer. + self.eof = False # Whether we're done. + self.waiter = None # A future. + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + if self.waiter is not None: + self.waiter.set_exception(exc) + + def feed_eof(self): + self.eof = True + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(True) + + def feed_data(self, data): + if not data: + return + + self.buffer.append(data) + self.byte_count += len(data) + + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(False) + + @tasks.coroutine + def readline(self): + if self._exception is not None: + raise self._exception + + parts = [] + parts_size = 0 + not_enough = True + + while not_enough: + while self.buffer and not_enough: + data = self.buffer.popleft() + ichar = data.find(b'\n') + if ichar < 0: + parts.append(data) + parts_size += len(data) + else: + ichar += 1 + head, tail = data[:ichar], data[ichar:] + if tail: + self.buffer.appendleft(tail) + not_enough = False + parts.append(head) + parts_size += len(head) + + if parts_size > self.limit: + self.byte_count -= parts_size + raise ValueError('Line is too long') + + if self.eof: + break + + if not_enough: + assert self.waiter is None + self.waiter = futures.Future() + yield from self.waiter + + line = b''.join(parts) + self.byte_count -= parts_size + + return line + + @tasks.coroutine + def read(self, n=-1): + if self._exception is not None: + raise self._exception + + if not n: + return b'' + + if n < 0: + while not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + else: + if not self.byte_count and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + + if n < 0 or self.byte_count <= n: + data = b''.join(self.buffer) + self.buffer.clear() + self.byte_count = 0 + return data + + parts = [] + parts_bytes = 0 + while self.buffer and parts_bytes < n: + data = self.buffer.popleft() + data_bytes = len(data) + if n < parts_bytes + data_bytes: + data_bytes = n - parts_bytes + data, rest = data[:data_bytes], data[data_bytes:] + self.buffer.appendleft(rest) + + parts.append(data) + parts_bytes += data_bytes + self.byte_count -= data_bytes + + return b''.join(parts) + + @tasks.coroutine + def readexactly(self, n): + if self._exception is not None: + raise self._exception + + if n <= 0: + return b'' + + while self.byte_count < n and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + + return (yield from self.read(n)) diff --git a/tulip/subprocess_transport.py b/tulip/subprocess_transport.py new file mode 100644 index 0000000..734a5fa --- /dev/null +++ b/tulip/subprocess_transport.py @@ -0,0 +1,139 @@ +import fcntl +import os +import traceback + +from . import transports +from . import events +from .log import tulip_log + + +class UnixSubprocessTransport(transports.Transport): + """Transport class managing a subprocess. + + TODO: Separate this into something that just handles pipe I/O, + and something else that handles pipe setup, fork, and exec. + """ + + def __init__(self, protocol, args): + self._protocol = protocol # Not a factory! :-) + self._args = args # args[0] must be full path of binary. + self._event_loop = events.get_event_loop() + self._buffer = [] + self._eof = False + rstdin, self._wstdin = os.pipe() + self._rstdout, wstdout = os.pipe() + + # TODO: This is incredibly naive. Should look at + # subprocess.py for all the precautions around fork/exec. + pid = os.fork() + if not pid: + # Child. + try: + os.dup2(rstdin, 0) + os.dup2(wstdout, 1) + # TODO: What to do with stderr? + os.execv(args[0], args) + except: + try: + traceback.print_traceback() + finally: + os._exit(127) + + # Parent. + os.close(rstdin) + os.close(wstdout) + _setnonblocking(self._wstdin) + _setnonblocking(self._rstdout) + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.add_reader(self._rstdout, self._stdout_callback) + + def write(self, data): + assert not self._eof + assert isinstance(data, bytes), repr(data) + if not data: + return + + if not self._buffer: + # Attempt to write it right away first. + try: + n = os.write(self._wstdin, data) + except BlockingIOError: + pass + except Exception as exc: + self._fatal_error(exc) + return + else: + if n == len(data): + return + elif n: + data = data[n:] + self._event_loop.add_writer(self._wstdin, self._stdin_callback) + self._buffer.append(data) + + def write_eof(self): + assert not self._eof + assert self._wstdin >= 0 + self._eof = True + if not self._buffer: + self._event_loop.remove_writer(self._wstdin) + os.close(self._wstdin) + self._wstdin = -1 + + def close(self): + if not self._eof: + self.write_eof() + # XXX What else? + + def _fatal_error(self, exc): + tulip_log.error('Fatal error: %r', exc) + if self._rstdout >= 0: + os.close(self._rstdout) + self._rstdout = -1 + if self._wstdin >= 0: + os.close(self._wstdin) + self._wstdin = -1 + self._eof = True + self._buffer = None + + def _stdin_callback(self): + data = b''.join(self._buffer) + assert data, "Data shold not be empty" + + self._buffer = [] + try: + n = os.write(self._wstdin, data) + except BlockingIOError: + self._buffer.append(data) + except Exception as exc: + self._fatal_error(exc) + else: + if n >= len(data): + self._event_loop.remove_writer(self._wstdin) + if self._eof: + os.close(self._wstdin) + self._wstdin = -1 + return + + elif n > 0: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def _stdout_callback(self): + try: + data = os.read(self._rstdout, 1024) + except BlockingIOError: + pass + else: + if data: + self._event_loop.call_soon(self._protocol.data_received, data) + else: + self._event_loop.remove_reader(self._rstdout) + os.close(self._rstdout) + self._rstdout = -1 + self._event_loop.call_soon(self._protocol.eof_received) + + +def _setnonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) diff --git a/tulip/tasks.py b/tulip/tasks.py new file mode 100644 index 0000000..81359a2 --- /dev/null +++ b/tulip/tasks.py @@ -0,0 +1,320 @@ +"""Support for tasks, coroutines and the scheduler.""" + +__all__ = ['coroutine', 'task', 'Task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'as_completed', 'sleep', + ] + +import concurrent.futures +import functools +import inspect +import time + +from . import futures +from .log import tulip_log + + +def coroutine(func): + """Decorator to mark coroutines. + + Decorator wraps non generator functions and returns generator wrapper. + If non generator function returns generator of Future it yield-from it. + + TODO: This is a feel-good API only. It is not enforced. + """ + if inspect.isgeneratorfunction(func): + coro = func + else: + tulip_log.warning( + 'Coroutine function %s is not a generator.', func.__name__) + + @functools.wraps(func) + def coro(*args, **kw): + res = func(*args, **kw) + + if isinstance(res, futures.Future) or inspect.isgenerator(res): + res = yield from res + + return res + + coro._is_coroutine = True # Not sure who can use this. + return coro + + +# TODO: Do we need this? +def iscoroutinefunction(func): + """Return True if func is a decorated coroutine function.""" + return (inspect.isgeneratorfunction(func) and + getattr(func, '_is_coroutine', False)) + + +# TODO: Do we need this? +def iscoroutine(obj): + """Return True if obj is a coroutine object.""" + return inspect.isgenerator(obj) # TODO: And what? + + +def task(func): + """Decorator for a coroutine to be wrapped in a Task.""" + def task_wrapper(*args, **kwds): + coro = func(*args, **kwds) + return Task(coro) + return task_wrapper + + +class Task(futures.Future): + """A coroutine wrapped in a Future.""" + + def __init__(self, coro, event_loop=None, timeout=None): + assert inspect.isgenerator(coro) # Must be a coroutine *object*. + super().__init__(event_loop=event_loop, timeout=timeout) + self._coro = coro + self._must_cancel = False + self._event_loop.call_soon(self._step) + + def __repr__(self): + res = super().__repr__() + if (self._must_cancel and + self._state == futures._PENDING and + '<PENDING' in res): + res = res.replace('<PENDING', '<CANCELLING', 1) + i = res.find('<') + if i < 0: + i = len(res) + res = res[:i] + '(<{}>)'.format(self._coro.__name__) + res[i:] + return res + + def cancel(self): + if self.done(): + return False + self._must_cancel = True + # _step() will call super().cancel() to call the callbacks. + self._event_loop.call_soon(self._step_maybe) + return True + + def cancelled(self): + return self._must_cancel or super().cancelled() + + def _step_maybe(self): + # Helper for cancel(). + if not self.done(): + return self._step() + + def _step(self, value=None, exc=None): + if self.done(): + tulip_log.warn('_step(): already done: %r, %r, %r', self, value, exc) + return + # We'll call either coro.throw(exc) or coro.send(value). + if self._must_cancel: + exc = futures.CancelledError + coro = self._coro + try: + if exc is not None: + result = coro.throw(exc) + elif value is not None: + result = coro.send(value) + else: + result = next(coro) + except StopIteration as exc: + if self._must_cancel: + super().cancel() + else: + self.set_result(exc.value) + except Exception as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + tulip_log.exception('Exception in task') + except BaseException as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + tulip_log.exception('BaseException in task') + raise + else: + # XXX No check for self._must_cancel here? + if isinstance(result, futures.Future): + if not result._blocking: + self._event_loop.call_soon( + self._step, + None, RuntimeError( + 'yield was used instead of yield from in task %r ' + 'with %r' % (self, result))) + else: + result._blocking = False + result.add_done_callback(self._wakeup) + + elif isinstance(result, concurrent.futures.Future): + # This ought to be more efficient than wrap_future(), + # because we don't create an extra Future. + result.add_done_callback( + lambda future: + self._event_loop.call_soon_threadsafe( + self._wakeup, future)) + else: + if inspect.isgenerator(result): + self._event_loop.call_soon( + self._step, + None, RuntimeError( + 'yield was used instead of yield from for ' + 'generator in task %r with %s' % (self, result))) + else: + if result is not None: + tulip_log.warn('_step(): bad yield: %r', result) + + self._event_loop.call_soon(self._step) + + def _wakeup(self, future): + try: + value = future.result() + except Exception as exc: + self._step(None, exc) + else: + self._step(value, None) + + +# wait() and as_completed() similar to those in PEP 3148. + +FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED +FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION +ALL_COMPLETED = concurrent.futures.ALL_COMPLETED + + +# Even though this *is* a @coroutine, we don't mark it as such! +def wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Wait for the Futures and and coroutines given by fs to complete. + + Coroutines will be wrapped in Tasks. + + Returns two sets of Future: (done, pending). + + Usage: + + done, pending = yield from tulip.wait(fs) + + Note: This does not raise TimeoutError! Futures that aren't done + when the timeout occurs are returned in the second set. + """ + fs = _wrap_coroutines(fs) + return _wait(fs, timeout, return_when) + + +@coroutine +def _wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Internal helper: Like wait() but does not wrap coroutines.""" + done, pending = set(), set() + + errors = 0 + for f in fs: + if f.done(): + done.add(f) + if not f.cancelled() and f.exception() is not None: + errors += 1 + else: + pending.add(f) + + if (not pending or + timeout is not None and timeout <= 0 or + return_when == FIRST_COMPLETED and done or + return_when == FIRST_EXCEPTION and errors): + return done, pending + + # Will always be cancelled eventually. + bail = futures.Future(timeout=timeout) + + def _on_completion(f): + pending.remove(f) + done.add(f) + if (not pending or + return_when == FIRST_COMPLETED or + (return_when == FIRST_EXCEPTION and + not f.cancelled() and + f.exception() is not None)): + bail.cancel() + + try: + for f in pending: + f.add_done_callback(_on_completion) + try: + yield from bail + except futures.CancelledError: + pass + finally: + for f in pending: + f.remove_done_callback(_on_completion) + + really_done = set(f for f in pending if f.done()) + if really_done: + done.update(really_done) + pending.difference_update(really_done) + + return done, pending + + +# This is *not* a @coroutine! It is just an iterator (yielding Futures). +def as_completed(fs, timeout=None): + """Return an iterator whose values, when waited for, are Futures. + + This differs from PEP 3148; the proper way to use this is: + + for f in as_completed(fs): + result = yield from f # The 'yield from' may raise. + # Use result. + + Raises TimeoutError if the timeout occurs before all Futures are + done. + + Note: The futures 'f' are not necessarily members of fs. + """ + deadline = None + if timeout is not None: + deadline = time.monotonic() + timeout + + done = None # Make nonlocal happy. + fs = _wrap_coroutines(fs) + + while fs: + if deadline is not None: + timeout = deadline - time.monotonic() + + @coroutine + def _wait_for_some(): + nonlocal done, fs + done, fs = yield from _wait(fs, timeout=timeout, + return_when=FIRST_COMPLETED) + if not done: + fs = set() + raise futures.TimeoutError() + return done.pop().result() # May raise. + + yield Task(_wait_for_some()) + for f in done: + yield f + + +def _wrap_coroutines(fs): + """Internal helper to process an iterator of Futures and coroutines. + + Returns a set of Futures. + """ + wrapped = set() + for f in fs: + if not isinstance(f, futures.Future): + assert iscoroutine(f) + f = Task(f) + wrapped.add(f) + return wrapped + + +def sleep(when, result=None): + """Return a Future that completes after a given time (in seconds). + + It's okay to cancel the Future. + + Undocumented feature: sleep(when, x) sets the Future's result to x. + """ + future = futures.Future() + future._event_loop.call_later(when, future.set_result, result) + return future diff --git a/tulip/test_utils.py b/tulip/test_utils.py new file mode 100644 index 0000000..9b87db2 --- /dev/null +++ b/tulip/test_utils.py @@ -0,0 +1,30 @@ +"""Utilities shared by tests.""" + +import logging +import socket +import sys +import unittest + + +if sys.platform == 'win32': # pragma: no cover + from .winsocketpair import socketpair +else: + from socket import socketpair # pragma: no cover + + +class LogTrackingTestCase(unittest.TestCase): + + def setUp(self): + self._logger = logging.getLogger() + self._log_level = self._logger.getEffectiveLevel() + + def tearDown(self): + self._logger.setLevel(self._log_level) + + def suppress_log_errors(self): # pragma: no cover + if self._log_level >= logging.WARNING: + self._logger.setLevel(logging.CRITICAL) + + def suppress_log_warnings(self): # pragma: no cover + if self._log_level >= logging.WARNING: + self._logger.setLevel(logging.ERROR) diff --git a/tulip/transports.py b/tulip/transports.py new file mode 100644 index 0000000..a9ec07a --- /dev/null +++ b/tulip/transports.py @@ -0,0 +1,134 @@ +"""Abstract Transport class.""" + +__all__ = ['ReadTransport', 'WriteTransport', 'Transport'] + + +class BaseTransport: + """Base ABC for transports.""" + + def __init__(self, extra=None): + if extra is None: + extra = {} + self._extra = extra + + def get_extra_info(self, name, default=None): + """Get optional transport information.""" + return self._extra.get(name, default) + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + raise NotImplementedError + + +class ReadTransport(BaseTransport): + """ABC for read-only transports.""" + + def pause(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + raise NotImplementedError + + +class WriteTransport(BaseTransport): + """ABC for write-only transports.""" + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def write_eof(self): + """Closes the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + raise NotImplementedError + + def can_write_eof(self): + """Return True if this protocol supports write_eof(), False if not.""" + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class Transport(ReadTransport, WriteTransport): + """ABC representing a bidirectional transport. + + There may be several implementations, but typically, the user does + not implement new transports; rather, the platform provides some + useful transports that are implemented using the platform's best + practices. + + The user never instantiates a transport directly; they call a + utility function, passing it a protocol factory and other + information necessary to create the transport and protocol. (E.g. + EventLoop.create_connection() or EventLoop.start_serving().) + + The utility function will asynchronously create a transport and a + protocol and hook them up by calling the protocol's + connection_made() method, passing it the transport. + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + +class DatagramTransport(BaseTransport): + """ABC for datagram (UDP) transports.""" + + def sendto(self, data, addr=None): + """Send data to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + addr is target socket address. + If addr is None use target address pointed on transport creation. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError diff --git a/tulip/unix_events.py b/tulip/unix_events.py new file mode 100644 index 0000000..3073ab6 --- /dev/null +++ b/tulip/unix_events.py @@ -0,0 +1,301 @@ +"""Selector eventloop for Unix with signal handling.""" + +import errno +import fcntl +import os +import socket +import sys + +try: + import signal +except ImportError: # pragma: no cover + signal = None + +from . import events +from . import selector_events +from . import transports +from .log import tulip_log + + +__all__ = ['SelectorEventLoop'] + + +if sys.platform == 'win32': # pragma: no cover + raise ImportError('Signals are not really supported on Windows') + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Unix event loop + + Adds signal handling to SelectorEventLoop + """ + + def __init__(self, selector=None): + super().__init__(selector) + self._signal_handlers = {} + + def _socketpair(self): + return socket.socketpair() + + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + self._check_signal(sig) + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except ValueError as exc: + raise RuntimeError(str(exc)) + + handle = events.make_handle(callback, args) + self._signal_handlers[sig] = handle + + try: + signal.signal(sig, self._handle_signal) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as nexc: + tulip_log.info('set_wakeup_fd(-1) failed: %s', nexc) + + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + return handle + + def _handle_signal(self, sig, arg): + """Internal helper that is the actual signal handler.""" + handle = self._signal_handlers.get(sig) + if handle is None: + return # Assume it's some race condition. + if handle.cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self.call_soon_threadsafe(handle) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not.""" + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as exc: + tulip_log.info('set_wakeup_fd(-1) failed: %s', exc) + + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + + if signal is None: + raise RuntimeError('Signals are not supported') + + if not (1 <= sig < signal.NSIG): + raise ValueError( + 'sig {} out of range(1, {})'.format(sig, signal.NSIG)) + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixReadPipeTransport(self, pipe, protocol, waiter, extra) + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixWritePipeTransport(self, pipe, protocol, waiter, extra) + + +def _set_nonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + + +class _UnixReadPipeTransport(transports.ReadTransport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, event_loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._event_loop = event_loop + self._pipe = pipe + self._fileno = pipe.fileno() + _set_nonblocking(self._fileno) + self._protocol = protocol + self._closing = False + self._event_loop.add_reader(self._fileno, self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = os.read(self._fileno, self.max_size) + except BlockingIOError: + pass + except OSError as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + self._event_loop.remove_reader(self._fileno) + self._protocol.eof_received() + + def pause(self): + self._event_loop.remove_reader(self._fileno) + + def resume(self): + self._event_loop.add_reader(self._fileno, self._read_ready) + + def close(self): + if not self._closing: + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._closing = True + self._event_loop.remove_reader(self._fileno) + self._call_connection_lost(exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + + +class _UnixWritePipeTransport(transports.WriteTransport): + + def __init__(self, event_loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._event_loop = event_loop + self._pipe = pipe + self._fileno = pipe.fileno() + _set_nonblocking(self._fileno) + self._protocol = protocol + self._buffer = [] + self._closing = False # Set when close() or write_eof() called. + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def write(self, data): + assert isinstance(data, bytes), repr(data) + assert not self._closing + if not data: + return + + if not self._buffer: + # Attempt to send it right away first. + try: + n = os.write(self._fileno, data) + except BlockingIOError: + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + if n == len(data): + return + elif n > 0: + data = data[n:] + self._event_loop.add_writer(self._fileno, self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + assert data, "Data should not be empty" + + self._buffer.clear() + try: + n = os.write(self._fileno, data) + except BlockingIOError: + self._buffer.append(data) + except Exception as exc: + self._fatal_error(exc) + else: + if n == len(data): + self._event_loop.remove_writer(self._fileno) + if self._closing: + self._call_connection_lost(None) + return + elif n > 0: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def can_write_eof(self): + return True + + def write_eof(self): + assert not self._closing + assert self._pipe + self._closing = True + if not self._buffer: + self._call_connection_lost(None) + + def close(self): + if not self._closing: + # write_eof is all what we needed to close the write pipe + self.write_eof() + + def abort(self): + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc=None): + self._closing = True + self._buffer.clear() + self._event_loop.remove_writer(self._fileno) + self._call_connection_lost(exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() diff --git a/tulip/windows_events.py b/tulip/windows_events.py new file mode 100644 index 0000000..2ec8561 --- /dev/null +++ b/tulip/windows_events.py @@ -0,0 +1,157 @@ +"""Selector and proactor eventloops for Windows.""" + +import socket +import weakref +import struct +import _winapi + +from . import futures +from . import proactor_events +from . import selector_events +from . import winsocketpair +from . import _overlapped +from .log import tulip_log + + +__all__ = ['SelectorEventLoop', 'ProactorEventLoop'] + + +NULL = 0 +INFINITE = 0xffffffff +ERROR_CONNECTION_REFUSED = 1225 +ERROR_CONNECTION_ABORTED = 1236 + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + def _socketpair(self): + return winsocketpair.socketpair() + + +class ProactorEventLoop(proactor_events.BaseProactorEventLoop): + def __init__(self, proactor=None): + if proactor is None: + proactor = IocpProactor() + super().__init__(proactor) + + def _socketpair(self): + return winsocketpair.socketpair() + + +class IocpProactor: + + def __init__(self, concurrency=0xffffffff): + self._results = [] + self._iocp = _overlapped.CreateIoCompletionPort( + _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency) + self._cache = {} + self._registered = weakref.WeakSet() + + def registered_count(self): + return len(self._cache) + + def select(self, timeout=None): + if not self._results: + self._poll(timeout) + tmp = self._results + self._results = [] + return tmp + + def recv(self, conn, nbytes, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + ov.WSARecv(conn.fileno(), nbytes, flags) + return self._register(ov, conn, ov.getresult) + + def send(self, conn, buf, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + ov.WSASend(conn.fileno(), buf, flags) + return self._register(ov, conn, ov.getresult) + + def accept(self, listener): + self._register_with_iocp(listener) + conn = self._get_accept_socket() + ov = _overlapped.Overlapped(NULL) + ov.AcceptEx(listener.fileno(), conn.fileno()) + + def finish_accept(): + addr = ov.getresult() + buf = struct.pack('@P', listener.fileno()) + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_ACCEPT_CONTEXT, + buf) + conn.settimeout(listener.gettimeout()) + return conn, conn.getpeername() + + return self._register(ov, listener, finish_accept) + + def connect(self, conn, address): + self._register_with_iocp(conn) + _overlapped.BindLocal(conn.fileno(), len(address)) + ov = _overlapped.Overlapped(NULL) + ov.ConnectEx(conn.fileno(), address) + + def finish_connect(): + ov.getresult() + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_CONNECT_CONTEXT, + 0) + return conn + + return self._register(ov, conn, finish_connect) + + def _register_with_iocp(self, obj): + if obj not in self._registered: + self._registered.add(obj) + _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) + + def _register(self, ov, obj, callback): + f = futures.Future() + self._cache[ov.address] = (f, ov, obj, callback) + return f + + def _get_accept_socket(self): + s = socket.socket() + s.settimeout(0) + return s + + def _poll(self, timeout=None): + if timeout is None: + ms = INFINITE + elif timeout < 0: + raise ValueError("negative timeout") + else: + ms = int(timeout * 1000 + 0.5) + if ms >= INFINITE: + raise ValueError("timeout too big") + while True: + status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms) + if status is None: + return + address = status[3] + f, ov, obj, callback = self._cache.pop(address) + try: + value = callback() + except OSError as e: + f.set_exception(e) + self._results.append(f) + else: + f.set_result(value) + self._results.append(f) + ms = 0 + + def close(self): + for (f, ov, obj, callback) in self._cache.values(): + try: + ov.cancel() + except OSError: + pass + + while self._cache: + if not self._poll(1): + tulip_log.debug('taking long time to close proactor') + + self._results = [] + if self._iocp is not None: + _winapi.CloseHandle(self._iocp) + self._iocp = None diff --git a/tulip/winsocketpair.py b/tulip/winsocketpair.py new file mode 100644 index 0000000..bd1e092 --- /dev/null +++ b/tulip/winsocketpair.py @@ -0,0 +1,34 @@ +"""A socket pair usable as a self-pipe, for Windows. + +Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. +""" + +import socket +import sys + +if sys.platform != 'win32': # pragma: no cover + raise ImportError('winsocketpair is win32 only') + + +def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """Emulate the Unix socketpair() function on Windows.""" + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) + lsock.bind(('localhost', 0)) + lsock.listen(1) + addr, port = lsock.getsockname() + csock = socket.socket(family, type, proto) + csock.setblocking(False) + try: + csock.connect((addr, port)) + except (BlockingIOError, InterruptedError): + pass + except: + lsock.close() + csock.close() + raise + ssock, _ = lsock.accept() + csock.setblocking(True) + lsock.close() + return (ssock, csock) |