From 80d1312a3e9c869f26fa4790a8978fd7f8486fb1 Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Thu, 28 Mar 2013 15:39:55 -0400 Subject: Use logger named 'tulip' for library events, Issue 26 --- .hgeol | 2 + .hgignore | 11 + Makefile | 31 + NOTES | 176 ++++++ README | 21 + TODO | 163 +++++ check.py | 41 ++ crawl.py | 143 +++++ curl.py | 35 ++ examples/udp_echo.py | 73 +++ old/Makefile | 16 + old/echoclt.py | 79 +++ old/echosvr.py | 60 ++ old/http_client.py | 78 +++ old/http_server.py | 68 ++ old/main.py | 134 ++++ old/p3time.py | 47 ++ old/polling.py | 535 ++++++++++++++++ old/scheduling.py | 354 +++++++++++ old/sockets.py | 348 +++++++++++ old/transports.py | 496 +++++++++++++++ old/xkcd.py | 18 + old/yyftime.py | 75 +++ overlapped.c | 997 +++++++++++++++++++++++++++++ runtests.py | 198 ++++++ setup.cfg | 2 + setup.py | 14 + srv.py | 115 ++++ sslsrv.py | 56 ++ tests/base_events_test.py | 283 +++++++++ tests/events_test.py | 1379 +++++++++++++++++++++++++++++++++++++++++ tests/futures_test.py | 222 +++++++ tests/http_protocol_test.py | 972 +++++++++++++++++++++++++++++ tests/http_server_test.py | 242 ++++++++ tests/locks_test.py | 747 ++++++++++++++++++++++ tests/queues_test.py | 370 +++++++++++ tests/sample.crt | 14 + tests/sample.key | 15 + tests/selector_events_test.py | 1286 ++++++++++++++++++++++++++++++++++++++ tests/selectors_test.py | 137 ++++ tests/streams_test.py | 299 +++++++++ tests/subprocess_test.py | 54 ++ tests/tasks_test.py | 647 +++++++++++++++++++ tests/transports_test.py | 45 ++ tests/unix_events_test.py | 573 +++++++++++++++++ tests/winsocketpair_test.py | 26 + tulip/TODO | 28 + tulip/__init__.py | 26 + tulip/base_events.py | 548 ++++++++++++++++ tulip/events.py | 356 +++++++++++ tulip/futures.py | 255 ++++++++ tulip/http/__init__.py | 12 + tulip/http/client.py | 145 +++++ tulip/http/errors.py | 44 ++ tulip/http/protocol.py | 877 ++++++++++++++++++++++++++ tulip/http/server.py | 176 ++++++ tulip/locks.py | 433 +++++++++++++ tulip/log.py | 6 + tulip/proactor_events.py | 189 ++++++ tulip/protocols.py | 78 +++ tulip/queues.py | 291 +++++++++ tulip/selector_events.py | 655 +++++++++++++++++++ tulip/selectors.py | 418 +++++++++++++ tulip/streams.py | 145 +++++ tulip/subprocess_transport.py | 139 +++++ tulip/tasks.py | 320 ++++++++++ tulip/test_utils.py | 30 + tulip/transports.py | 134 ++++ tulip/unix_events.py | 301 +++++++++ tulip/windows_events.py | 157 +++++ tulip/winsocketpair.py | 34 + 71 files changed, 17494 insertions(+) create mode 100644 .hgeol create mode 100644 .hgignore create mode 100644 Makefile create mode 100644 NOTES create mode 100644 README create mode 100644 TODO create mode 100644 check.py create mode 100755 crawl.py create mode 100755 curl.py create mode 100644 examples/udp_echo.py create mode 100644 old/Makefile create mode 100644 old/echoclt.py create mode 100644 old/echosvr.py create mode 100644 old/http_client.py create mode 100644 old/http_server.py create mode 100644 old/main.py create mode 100644 old/p3time.py create mode 100644 old/polling.py create mode 100644 old/scheduling.py create mode 100644 old/sockets.py create mode 100644 old/transports.py create mode 100755 old/xkcd.py create mode 100644 old/yyftime.py create mode 100644 overlapped.c create mode 100644 runtests.py create mode 100644 setup.cfg create mode 100644 setup.py create mode 100755 srv.py create mode 100644 sslsrv.py create mode 100644 tests/base_events_test.py create mode 100644 tests/events_test.py create mode 100644 tests/futures_test.py create mode 100644 tests/http_protocol_test.py create mode 100644 tests/http_server_test.py create mode 100644 tests/locks_test.py create mode 100644 tests/queues_test.py create mode 100644 tests/sample.crt create mode 100644 tests/sample.key create mode 100644 tests/selector_events_test.py create mode 100644 tests/selectors_test.py create mode 100644 tests/streams_test.py create mode 100644 tests/subprocess_test.py create mode 100644 tests/tasks_test.py create mode 100644 tests/transports_test.py create mode 100644 tests/unix_events_test.py create mode 100644 tests/winsocketpair_test.py create mode 100644 tulip/TODO create mode 100644 tulip/__init__.py create mode 100644 tulip/base_events.py create mode 100644 tulip/events.py create mode 100644 tulip/futures.py create mode 100644 tulip/http/__init__.py create mode 100644 tulip/http/client.py create mode 100644 tulip/http/errors.py create mode 100644 tulip/http/protocol.py create mode 100644 tulip/http/server.py create mode 100644 tulip/locks.py create mode 100644 tulip/log.py create mode 100644 tulip/proactor_events.py create mode 100644 tulip/protocols.py create mode 100644 tulip/queues.py create mode 100644 tulip/selector_events.py create mode 100644 tulip/selectors.py create mode 100644 tulip/streams.py create mode 100644 tulip/subprocess_transport.py create mode 100644 tulip/tasks.py create mode 100644 tulip/test_utils.py create mode 100644 tulip/transports.py create mode 100644 tulip/unix_events.py create mode 100644 tulip/windows_events.py create mode 100644 tulip/winsocketpair.py diff --git a/.hgeol b/.hgeol new file mode 100644 index 0000000..b6910a2 --- /dev/null +++ b/.hgeol @@ -0,0 +1,2 @@ +[patterns] +** = native diff --git a/.hgignore b/.hgignore new file mode 100644 index 0000000..2590249 --- /dev/null +++ b/.hgignore @@ -0,0 +1,11 @@ +.*\.py[co]$ +.*~$ +.*\.orig$ +.*\#.*$ +.*@.*$ +\.coverage$ +htmlcov$ +\.DS_Store$ +venv$ +distribute_setup.py$ +distribute-\d+.\d+.\d+.tar.gz$ diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..274da4c --- /dev/null +++ b/Makefile @@ -0,0 +1,31 @@ +# Some simple testing tasks (sorry, UNIX only). + +PYTHON=python3 +VERBOSE=1 +FLAGS= + +test: + $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS) + +testloop: + while sleep 1; do $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS); done + +# See README for coverage installation instructions. +cov coverage: + $(PYTHON) runtests.py --coverage tulip -v $(VERBOSE) $(FLAGS) + echo "open file://`pwd`/htmlcov/index.html" + +check: + $(PYTHON) check.py + +clean: + rm -rf `find . -name __pycache__` + rm -f `find . -type f -name '*.py[co]' ` + rm -f `find . -type f -name '*~' ` + rm -f `find . -type f -name '.*~' ` + rm -f `find . -type f -name '@*' ` + rm -f `find . -type f -name '#*#' ` + rm -f `find . -type f -name '*.orig' ` + rm -f `find . -type f -name '*.rej' ` + rm -f .coverage + rm -rf htmlcov diff --git a/NOTES b/NOTES new file mode 100644 index 0000000..3b94ba9 --- /dev/null +++ b/NOTES @@ -0,0 +1,176 @@ +Notes from PyCon 2013 sprints +============================= + +- Cancellation. If a task creates several subtasks, and then the + parent task fails, should the subtasks be cancelled? (How do we + even establish the parent/subtask relationship?) + +- Adam Sah suggests that there might be a need for scheduling + (especially when multiple frameworks share an event loop). He + points to lottery scheduling but also mentions that's just one of + the options. However, after posting on python-tulip, it appears + none of the other frameworks have scheduling, and nobody seems to + miss it. + +- Feedback from Bram Cohen (Bittorrent creator) about UDP. He doesn't + think connected UDP is worth supporting, it doesn't do anything + except tell the kernel about the default target address for + sendto(). Basically he says all UDP end points are servers. He + sent me his own UDP event loop so I might glean some tricks from it. + He says we should treat EINTR the same as EAGAIN and friends. (We + should use the exceptions dedicated to errno checking, BTW.) HE + said to make sure we use SO_REUSEADDR (I think we already do). He + said to set the max datagram sizes pretty large (anything larger + than the declared limit is dropped on the floor). He reminds us of + the importance of being able to pick a valid, unused port by binding + to port 0 and then using getsockname(). He has an idea where he's + like to be able to kill all registered callbacks (i.e. Handles) + belonging to a certain "context". I think this can be done at the + application level (you'd have to wrap everything that returns a + Handle and collect these handles in some set or other datastructure) + but if someone thinks it's interesting we could imagine having some + kind of notion of context part of the event loop state, + e.g. associated with a Task (see Cancellation point above). He + brought up uTP (Micro Transport Protocol), a reimplementation of TCP + over UDP with more refined congestion control. + +- Mumblings about UNIX domain sockets and IPv6 addresses being + 4-tuples. The former can be handled by passing in a socket. There + seem to be no real use cases for the latter that can't be dealt with + by passing in suitably esoteric strings for the hostname. + getaddrinfo() will produce the appropriate 4-tuple and connect() + will accept it. + +- Mumblings on the list about add vs. set. + + +Notes from the second Tulip/Twisted meet-up +=========================================== + +Rackspace, 12/11/2012 +Glyph, Brian Warner, David Reid, Duncan McGreggor, others + +Flow control +------------ + +- Pause/resume on transport manages data_received. + +- There's also an API to tell the transport whom to pause when the + write calls are overwhelming it: IConsumer.registerProducer(). + +- There's also something called pipes but it's built on top of the + old interface. + +- Twisted has variations on the basic flow control that I should + ignore. + +Half_close +---------- + +- This sends an EOF after writing some stuff. + +- Can't write any more. + +- Problem with TLS is known (the RFC sadly specifies this behavior). + +- It must be dynamimcally discoverable whether the transport supports + half_close, since the protocol may have to do something different to + make up for its missing (e.g. use chunked encoding). Twisted uses + an interface check for this and also hasattr(trans, 'halfClose') + but a flag (or flag method) is fine too. + +Constructing transport and protocol +----------------------------------- + +- There are good reasons for passing a function to the transport + construction helper that creates the protocol. (You need these + anyway for server-side protocols.) The sequence of events is + something like + + . open socket + . create transport (pass it a socket?) + . create protocol (pass it nothing) + . proto.make_connection(transport); this does: + . self.transport = transport + . self.connection_made(transport) + + But it seems okay to skip make_connection and setting .transport. + Note that make_connection() is a concrete method on the Protocol + implementation base class, while connection_made() is an abstract + method on IProtocol. + +Event Loop +---------- + +- We discussed the sequence of actions in the event loop. I think in the + end we're fine with what Tulip currently does. There are two choices: + + Tulip: + . run ready callbacks until there aren't any left + . poll, adding more callbacks to the ready list + . add now-ready delayed callbacks to the ready list + . go to top + + Tornado: + . run all currently ready callbacks (but not new ones added during this) + . (the rest is the same) + + The difference is that in the Tulip version, CPU bound callbacks + that keep adding more to the queue will starve I/O (and yielding to + other tasks won't actually cause I/O to happen unless you do + e.g. sleep(0.001)). OTOH this may be good because it means there's + less overhead if you frequently split operations in two. + +- I think Twisted does it Tornado style (in a convoluted way :-), but + it may not matter, and it's important to leave this vague so + implementations can do what's best for their platform. (E.g. if the + event loop is built into the OS there are different trade-offs.) + +System call cost +---------------- + +- System calls on MacOS are expensive, on Linux they are cheap. + +- Optimal buffer size ~16K. + +- Try joining small buffer pieces together, but expect to be tuning + this later. + +Futures +------- + +- Futures are the most robust API for async stuff, you can check + errors etc. So let's do this. + +- Just don't implement wait(). + +- For the basics, however, (recv/send, mostly), don't use Futures but use + basic callbacks, transport/protocol style. + +- make_connection() (by any name) can return a Future, it makes it + easier to check for errors. + +- This means revisiting the Tulip proactor branch (IOCP). + +- The semantics of add_done_callback() are fuzzy about in which thread + the callback will be called. (It may be the current thread or + another one.) We don't like that. But always inserting a + call_soon() indirection may be expensive? Glyph suggested changing + the add_done_callback() method name to something else to indicate + the changed promise. + +- Separately, I've been thinking about having two versions of + call_soon() -- a more heavy-weight one to be called from other + threads that also writes a byte to the self-pipe. + +Signals +------- + +- There was a side conversation about signals. A signal handler is + similar to another thread, so probably should use (the heavy-weight + version of) call_soon() to schedule the real callback and not do + anything else. + +- Glyph vaguely recalled some trickiness with the self-pipe. We + should be able to fix this afterwards if necessary, it shouldn't + affect the API design. diff --git a/README b/README new file mode 100644 index 0000000..85bfe5a --- /dev/null +++ b/README @@ -0,0 +1,21 @@ +Tulip is the codename for my reference implementation of PEP 3156. + +PEP 3156: http://www.python.org/dev/peps/pep-3156/ + +*** This requires Python 3.3 or later! *** + +Copyright/license: Open source, Apache 2.0. Enjoy. + +Master Mercurial repo: http://code.google.com/p/tulip/ + +The old code lives in the subdirectory 'old'; the new code (conforming +to PEP 3156, under construction) lives in the 'tulip' subdirectory. + +To run tests: + - make test + +To run coverage (coverage package is required): + - make coverage + + +--Guido van Rossum diff --git a/TODO b/TODO new file mode 100644 index 0000000..c6d4eea --- /dev/null +++ b/TODO @@ -0,0 +1,163 @@ +# -*- Mode: text -*- + +TO DO LARGER TASKS + +- Need more examples. + +- Benchmarkable but more realistic HTTP server? + +- Example of using UDP. + +- Write up a tutorial for the scheduling API. + +- More systematic approach to logging. Logger objects? What about + heavy-duty logging, tracking essentially all task state changes? + +- Restructure directory, move demos and benchmarks to subdirectories. + + +TO DO LATER + +- When multiple tasks are accessing the same socket, they should + either get interleaved I/O or an immediate exception; it should not + compromise the integrity of the scheduler or the app or leave a task + hanging. + +- For epoll you probably want to check/(log?) EPOLLHUP and EPOLLERR errors. + +- Add the simplest API possible to run a generator with a timeout. + +- Ensure multiple tasks can do atomic writes to the same pipe (since + UNIX guarantees that short writes to pipes are atomic). + +- Ensure some easy way of distributing accepted connections across tasks. + +- Be wary of thread-local storage. There should be a standard API to + get the current Context (which holds current task, event loop, and + maybe more) and a standard meta-API to change how that standard API + works (i.e. without monkey-patching). + +- See how much of asyncore I've already replaced. + +- Could BufferedReader reuse the standard io module's readers??? + +- Support ZeroMQ "sockets" which are user objects. Though possibly + this can be supported by getting the underlying fd? See + http://mail.python.org/pipermail/python-ideas/2012-October/017532.html + OTOH see + https://github.com/zeromq/pyzmq/blob/master/zmq/eventloop/ioloop.py + +- Study goroutines (again). + +- Benchmarks: http://nichol.as/benchmark-of-python-web-servers + + +FROM OLDER LIST + +- Multiple readers/writers per socket? (At which level? pollster, + eventloop, or scheduler?) + +- Could poll() usefully be an iterator? + +- Do we need to support more epoll and/or kqueue modes/flags/options/etc.? + +- Optimize register/unregister calls away if they cancel each other out? + +- Add explicit wait queue to wait for Task's completion, instead of + callbacks? + +- Look at pyfdpdlib's ioloop.py: + http://code.google.com/p/pyftpdlib/source/browse/trunk/pyftpdlib/lib/ioloop.py + + +MISTAKES I MADE + +- Forgetting yield from. (E.g.: scheduler.sleep(1); listener.accept().) + +- Forgot to add bare yield at end of internal function, after block(). + +- Forgot to call add_done_callback(). + +- Forgot to pass an undoer to block(), bug only found when cancelled. + +- Subtle accounting mistake in a callback. + +- Used context.eventloop from a different thread, forgetting about TLS. + +- Nasty race: eventloop.ready may contain both an I/O callback and a + cancel callback. How to avoid? Keep the DelayedCall in ready. Is + that enough? + +- If a toplevel task raises an error it just stops and nothing is logged + unless you have debug logging on. This confused me. (Then again, + previously I logged whenever a task raised an error, and that was too + chatty...) + +- Forgot to set the connection socket returned by accept() in + nonblocking mode. + +- Nastiest so far (cost me about a day): A race condition in + call_in_thread() where the Future's done_callback (which was + task.unblock()) would run immediately at the time when + add_done_callback() was called, and this screwed over the task + state. Solution: wrap the callback in eventloop.call_later(). + Ironically, I had a comment stating there might be a race condition. + +- Another bug where I was calling unblock() for the current thread + immediately after calling block(), before yielding. + +- readexactly() wasn't checking for EOF, so could be looping. + (Worse, the first fix I attempted was wrong.) + +- Spent a day trying to understand why a tentative patch trying to + move the recv() implementation into the eventloop (or the pollster) + resulted in problems cancelling a recv() call. Ultimately the + problem is that the cancellation mechanism is part of the coroutine + scheduler, which simply throws an exception into a task when it next + runs, and there isn't anything to be interrupted in the eventloop; + but the eventloop still has a reader registered (which will never + fire because I suspended the server -- that's my test case :-). + Then, the eventloop keeps running until the last file descriptor is + unregistered. What contributed to this disaster? + * I didn't build the whole infrastructure, just played with recv() + * I don't have unittests + * I don't have good logging to see what is going + +- In sockets.py, in some SSL error handling code, used the wrong + variable (sock instead of sslsock). A linter would have found this. + +- In polling.py, in KqueuePollster.register_writer(), a copy/paste + error where I was testing for "if fd not in self.readers" instead of + writers. This only came out when I had both a reader and a writer + for the same fd. + +- Submitted some changes prematurely (forgot to pass the filename on + hg ci). + +- Forgot again that shutdown(SHUT_WR) on an ssl socket does not work + as I expected. I ran into this with the origininal sockets.py and + again in transport.py. + +- Having the same callback for both reading and writing has a problem: + it may be scheduled twice, and if the first call closes the socket, + the second runs into trouble. + + +MISTAKES I MADE IN TULIP V2 + +- Nice one: Future._schedule_callbacks() wasn't scheduling any callbacks. + Spot the bug in these four lines: + + def _schedule_callbacks(self): + callbacks = self._callbacks[:] + self._callbacks[:] = [] + for callback in self._callbacks: + self._event_loop.call_soon(callback, self) + + The good news is that I found it with a unittest (albeit not the + unittest intended to exercise this particular method :-( ). + +- In _make_self_pipe_or_sock(), called _pollster.register_reader() + instead of add_reader(), trying to optimize something but breaking + things instead (since the -- internal -- API of register_reader() + had changed). diff --git a/check.py b/check.py new file mode 100644 index 0000000..64bc2cd --- /dev/null +++ b/check.py @@ -0,0 +1,41 @@ +"""Search for lines > 80 chars or with trailing whitespace.""" + +import sys, os + +def main(): + args = sys.argv[1:] or os.curdir + for arg in args: + if os.path.isdir(arg): + for dn, dirs, files in os.walk(arg): + for fn in sorted(files): + if fn.endswith('.py'): + process(os.path.join(dn, fn)) + dirs[:] = [d for d in dirs if d[0] != '.'] + dirs.sort() + else: + process(arg) + +def isascii(x): + try: + x.encode('ascii') + return True + except UnicodeError: + return False + +def process(fn): + try: + f = open(fn) + except IOError as err: + print(err) + return + try: + for i, line in enumerate(f): + line = line.rstrip('\n') + sline = line.rstrip() + if len(line) > 80 or line != sline or not isascii(line): + print('%s:%d:%s%s' % ( + fn, i+1, sline, '_' * (len(line) - len(sline)))) + finally: + f.close() + +main() diff --git a/crawl.py b/crawl.py new file mode 100755 index 0000000..4e5bebe --- /dev/null +++ b/crawl.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 + +import logging +import re +import signal +import socket +import sys +import urllib.parse + +import tulip +import tulip.http + +END = '\n' +MAXTASKS = 100 + + +class Crawler: + + def __init__(self, rooturl): + self.rooturl = rooturl + self.todo = set() + self.busy = set() + self.done = {} + self.tasks = set() + self.waiter = None + self.addurl(self.rooturl, '') # Set initial work. + self.run() # Kick off work. + + def addurl(self, url, parenturl): + url = urllib.parse.urljoin(parenturl, url) + url, frag = urllib.parse.urldefrag(url) + if not url.startswith(self.rooturl): + return False + if url in self.busy or url in self.done or url in self.todo: + return False + self.todo.add(url) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(None) + return True + + @tulip.task + def run(self): + while self.todo or self.busy or self.tasks: + complete, self.tasks = yield from tulip.wait(self.tasks, timeout=0) + print(len(complete), 'completed tasks,', len(self.tasks), + 'still pending ', end=END) + for task in complete: + try: + yield from task + except Exception as exc: + print('Exception in task:', exc, end=END) + while self.todo and len(self.tasks) < MAXTASKS: + url = self.todo.pop() + self.busy.add(url) + self.tasks.add(self.process(url)) # Async task. + if self.busy: + self.waiter = tulip.Future() + yield from self.waiter + tulip.get_event_loop().stop() + + @tulip.task + def process(self, url): + ok = False + p = None + try: + print('processing', url, end=END) + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not path: + path = '/' + if query: + path = '?'.join([path, query]) + p = tulip.http.HttpClientProtocol( + netloc, path=path, ssl=(scheme=='https')) + delay = 1 + while True: + try: + status, headers, stream = yield from p.connect() + break + except socket.error as exc: + if delay >= 60: + raise + print('...', url, 'has error', repr(str(exc)), + 'retrying after sleep', delay, '...', end=END) + yield from tulip.sleep(delay) + delay *= 2 + + if status[:3] in ('301', '302'): + # Redirect. + u = headers.get('location') or headers.get('uri') + if self.addurl(u, url): + print(' ', url, status[:3], 'redirect to', u, end=END) + elif status.startswith('200'): + ctype = headers.get_content_type() + if ctype == 'text/html': + while True: + line = yield from stream.readline() + if not line: + break + line = line.decode('utf-8', 'replace') + urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', + line) + for u in urls: + if self.addurl(u, url): + print(' ', url, 'href to', u, end=END) + ok = True + finally: + if p is not None: + p.transport.close() + self.done[url] = ok + self.busy.remove(url) + if not ok: + print('failure for', url, sys.exc_info(), end=END) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(None) + + +def main(): + rooturl = sys.argv[1] + c = Crawler(rooturl) + loop = tulip.get_event_loop() + try: + loop.add_signal_handler(signal.SIGINT, loop.stop) + except RuntimeError: + pass + loop.run_forever() + print('todo:', len(c.todo)) + print('busy:', len(c.busy)) + print('done:', len(c.done), '; ok:', sum(c.done.values())) + print('tasks:', len(c.tasks)) + + +if __name__ == '__main__': + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + logging.info('using iocp') + el = windows_events.ProactorEventLoop() + events.set_event_loop(el) + main() diff --git a/curl.py b/curl.py new file mode 100755 index 0000000..37fce75 --- /dev/null +++ b/curl.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 + +import sys +import urllib.parse + +import tulip +import tulip.http + + +def main(): + url = sys.argv[1] + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not path: + path = '/' + if query: + path = '?'.join([path, query]) + print(netloc, path, scheme) + p = tulip.http.HttpClientProtocol(netloc, path=path, ssl=(scheme=='https')) + f = p.connect() + sts, headers, stream = p.event_loop.run_until_complete(f) + print(sts) + for k, v in headers.items(): + print('{}: {}'.format(k, v)) + print() + data = p.event_loop.run_until_complete(stream.read()) + print(data.decode('utf-8', 'replace')) + + +if __name__ == '__main__': + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + el = windows_events.ProactorEventLoop() + events.set_event_loop(el) + main() diff --git a/examples/udp_echo.py b/examples/udp_echo.py new file mode 100644 index 0000000..9e995d1 --- /dev/null +++ b/examples/udp_echo.py @@ -0,0 +1,73 @@ +"""UDP echo example. + +Start server: + + >> python ./udp_echo.py --server + +""" + +import sys +import tulip + +ADDRESS = ('127.0.0.1', 10000) + + +class MyServerUdpEchoProtocol: + + def connection_made(self, transport): + print('start', transport) + self.transport = transport + + def datagram_received(self, data, addr): + print('Data received:', data, addr) + self.transport.sendto(data, addr) + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('stop', exc) + + +class MyClientUdpEchoProtocol: + + message = 'This is the message. It will be repeated.' + + def connection_made(self, transport): + self.transport = transport + print('sending "%s"' % self.message) + self.transport.sendto(self.message.encode()) + print('waiting to receive') + + def datagram_received(self, data, addr): + print('received "%s"' % data.decode()) + self.transport.close() + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('closing transport', exc) + loop = tulip.get_event_loop() + loop.stop() + + +def start_server(): + loop = tulip.get_event_loop() + tulip.Task(loop.create_datagram_endpoint( + MyServerUdpEchoProtocol, local_addr=ADDRESS)) + loop.run_forever() + + +def start_client(): + loop = tulip.get_event_loop() + tulip.Task(loop.create_datagram_endpoint( + MyClientUdpEchoProtocol, remote_addr=ADDRESS)) + loop.run_forever() + + +if __name__ == '__main__': + if '--server' in sys.argv: + start_server() + else: + start_client() diff --git a/old/Makefile b/old/Makefile new file mode 100644 index 0000000..d352cd7 --- /dev/null +++ b/old/Makefile @@ -0,0 +1,16 @@ +PYTHON=python3 + +main: + $(PYTHON) main.py -v + +echo: + $(PYTHON) echosvr.py -v + +profile: + $(PYTHON) -m profile -s time main.py + +time: + $(PYTHON) p3time.py + +ytime: + $(PYTHON) yyftime.py diff --git a/old/echoclt.py b/old/echoclt.py new file mode 100644 index 0000000..c24c573 --- /dev/null +++ b/old/echoclt.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3.3 +"""Example echo client.""" + +# Stdlib imports. +import logging +import socket +import sys +import time + +# Local imports. +import scheduling +import sockets + + +def echoclient(host, port): + """COROUTINE""" + testdata = b'hi hi hi ha ha ha\n' + try: + trans = yield from sockets.create_transport(host, port, + af=socket.AF_INET) + except OSError: + return False + try: + ok = yield from trans.send(testdata) + if ok: + response = yield from trans.recv(100) + ok = response == testdata.upper() + return ok + finally: + trans.close() + + +def doit(n): + """COROUTINE""" + t0 = time.time() + tasks = set() + for i in range(n): + t = scheduling.Task(echoclient('127.0.0.1', 1111), 'client-%d' % i) + tasks.add(t) + ok = 0 + bad = 0 + for t in tasks: + try: + yield from t + except Exception: + bad += 1 + else: + ok += 1 + t1 = time.time() + print('ok: ', ok) + print('bad:', bad) + print('dt: ', round(t1-t0, 6)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Get integer from command line. + n = 1 + for arg in sys.argv[1:]: + if not arg.startswith('-'): + n = int(arg) + break + + # Run scheduler, starting it off with doit(). + scheduling.run(doit(n)) + + +if __name__ == '__main__': + main() diff --git a/old/echosvr.py b/old/echosvr.py new file mode 100644 index 0000000..4085f4c --- /dev/null +++ b/old/echosvr.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3.3 +"""Example echo server.""" + +# Stdlib imports. +import logging +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + while True: + line = yield from rdr.readline() + logging.debug('Received: %r from %r', line, addr) + if not line: + break + yield from trans.send(line.upper()) + logging.debug('Closing %r', addr) + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 1111, + af=socket.AF_INET, + backlog=100) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() diff --git a/old/http_client.py b/old/http_client.py new file mode 100644 index 0000000..8937ba2 --- /dev/null +++ b/old/http_client.py @@ -0,0 +1,78 @@ +"""Crummy HTTP client. + +This is not meant as an example of how to write a good client. +""" + +# Stdlib. +import re +import time + +# Local. +import sockets + + +def urlfetch(host, port=None, path='/', method='GET', + body=None, hdrs=None, encoding='utf-8', ssl=None, af=0): + """COROUTINE: Make an HTTP 1.0 request.""" + t0 = time.time() + if port is None: + if ssl: + port = 443 + else: + port = 80 + trans = yield from sockets.create_transport(host, port, ssl=ssl, af=af) + yield from trans.send(method.encode(encoding) + b' ' + + path.encode(encoding) + b' HTTP/1.0\r\n') + if hdrs: + kwds = dict(hdrs) + else: + kwds = {} + if 'host' not in kwds: + kwds['host'] = host + if body is not None: + kwds['content_length'] = len(body) + for header, value in kwds.items(): + yield from trans.send(header.replace('_', '-').encode(encoding) + + b': ' + value.encode(encoding) + b'\r\n') + + yield from trans.send(b'\r\n') + if body is not None: + yield from trans.send(body) + + # Read HTTP response line. + rdr = sockets.BufferedReader(trans) + resp = yield from rdr.readline() + m = re.match(br'(?ix) http/(\d\.\d) \s+ (\d\d\d) \s+ ([^\r]*)\r?\n\Z', + resp) + if not m: + trans.close() + raise IOError('No valid HTTP response: %r' % resp) + http_version, status, message = m.groups() + + # Read HTTP headers. + headers = [] + hdict = {} + while True: + line = yield from rdr.readline() + if not line.strip(): + break + m = re.match(br'([^\s:]+):\s*([^\r]*)\r?\n\Z', line) + if not m: + raise IOError('Invalid header: %r' % line) + header, value = m.groups() + headers.append((header, value)) + hdict[header.decode(encoding).lower()] = value.decode(encoding) + + # Read response body. + content_length = hdict.get('content-length') + if content_length is not None: + size = int(content_length) # TODO: Catch errors. + assert size >= 0, size + else: + size = 2**20 # Protective limit (1 MB). + data = yield from rdr.readexactly(size) + trans.close() # Can this block? + t1 = time.time() + result = (host, port, path, int(status), len(data), round(t1-t0, 3)) +## print(result) + return result diff --git a/old/http_server.py b/old/http_server.py new file mode 100644 index 0000000..2b1e3dd --- /dev/null +++ b/old/http_server.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3.3 +"""Simple HTTP server. + +This currenty exists just so we can benchmark this thing! +""" + +# Stdlib imports. +import logging +import re +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + ##logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + + # Read but ignore request line. + request_line = yield from rdr.readline() + + # Consume headers but don't interpret them. + while True: + header_line = yield from rdr.readline() + if not header_line.strip(): + break + + # Always send an empty 200 response and close. + yield from trans.send(b'HTTP/1.0 200 Ok\r\n\r\n') + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 8080, + af=socket.AF_INET) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() diff --git a/old/main.py b/old/main.py new file mode 100644 index 0000000..c1f9d0a --- /dev/null +++ b/old/main.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3.3 +"""Example HTTP client using yield-from coroutines (PEP 380). + +Requires Python 3.3. + +There are many micro-optimizations possible here, but that's not the point. + +Some incomplete laundry lists: + +TODO: +- Take test urls from command line. +- Move urlfetch to a separate module. +- Profiling. +- Docstrings. +- Unittests. + +FUNCTIONALITY: +- Connection pool (keep connection open). +- Chunked encoding (request and response). +- Pipelining, e.g. zlib (request and response). +- Automatic encoding/decoding. +""" + +__author__ = 'Guido van Rossum ' + +# Standard library imports (keep in alphabetic order). +import logging +import os +import time +import socket +import sys + +# Local imports (keep in alphabetic order). +import scheduling +import http_client + + + +def doit2(): + argses = [ + ('localhost', 8080, '/'), + ('127.0.0.1', 8080, '/home'), + ('python.org', 80, '/'), + ('xkcd.com', 443, '/'), + ] + results = yield from scheduling.map_over( + lambda args: http_client.urlfetch(*args), argses, timeout=2) + for res in results: + print('-->', res) + return [] + + +def doit(): + TIMEOUT = 2 + tasks = set() + + # This references NDB's default test service. + # (Sadly the service is single-threaded.) + task1 = scheduling.Task(http_client.urlfetch('localhost', 8080, path='/'), + 'root', timeout=TIMEOUT) + tasks.add(task1) + task2 = scheduling.Task(http_client.urlfetch('127.0.0.1', 8080, + path='/home'), + 'home', timeout=TIMEOUT) + tasks.add(task2) + + # Fetch python.org home page. + task3 = scheduling.Task(http_client.urlfetch('python.org', 80, path='/'), + 'python', timeout=TIMEOUT) + tasks.add(task3) + + # Fetch XKCD home page using SSL. (Doesn't like IPv6.) + task4 = scheduling.Task(http_client.urlfetch('xkcd.com', ssl=True, path='/', + af=socket.AF_INET), + 'xkcd', timeout=TIMEOUT) + tasks.add(task4) + +## # Fetch many links from python.org (/x.y.z). +## for x in '123': +## for y in '0123456789': +## path = '/{}.{}'.format(x, y) +## g = http_client.urlfetch('82.94.164.162', 80, +## path=path, hdrs={'host': 'python.org'}) +## t = scheduling.Task(g, path, timeout=2) +## tasks.add(t) + +## print(tasks) + yield from scheduling.Task(scheduling.sleep(1), timeout=0.2).wait() + winners = yield from scheduling.wait_any(tasks) + print('And the winners are:', [w.name for w in winners]) + tasks = yield from scheduling.wait_all(tasks) + print('And the players were:', [t.name for t in tasks]) + return tasks + + +def logtimes(real): + utime, stime, cutime, cstime, unused = os.times() + logging.info('real %10.3f', real) + logging.info('user %10.3f', utime + cutime) + logging.info('sys %10.3f', stime + cstime) + + +def main(): + t0 = time.time() + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + task = scheduling.run(doit()) + if task.exception: + print('Exception:', repr(task.exception)) + if isinstance(task.exception, AssertionError): + raise task.exception + else: + for t in task.result: + print(t.name + ':', + repr(t.exception) if t.exception else t.result) + + # Report real, user, sys times. + t1 = time.time() + logtimes(t1-t0) + + +if __name__ == '__main__': + main() diff --git a/old/p3time.py b/old/p3time.py new file mode 100644 index 0000000..35e14c9 --- /dev/null +++ b/old/p3time.py @@ -0,0 +1,47 @@ +"""Compare timing of plain vs. yield-from calls.""" + +import gc +import time + +def plain(n): + if n <= 0: + return 1 + l = plain(n-1) + r = plain(n-1) + return l + 1 + r + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def submain(depth): + t0 = time.time() + k = plain(depth) + t1 = time.time() + fmt = ' {} {} {:-9,.5f}' + delta0 = t1-t0 + print(('plain' + fmt).format(depth, k, delta0)) + + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + delta1 = t1-t0 + print(('coro.' + fmt).format(depth, k, delta1)) + if delta0: + print(('relat' + fmt).format(depth, k, delta1/delta0)) + +def main(reasonable=16): + gc.disable() + for depth in range(reasonable): + submain(depth) + +if __name__ == '__main__': + main() diff --git a/old/polling.py b/old/polling.py new file mode 100644 index 0000000..6586efc --- /dev/null +++ b/old/polling.py @@ -0,0 +1,535 @@ +"""Event loop and related classes. + +The event loop can be broken up into a pollster (the part responsible +for telling us when file descriptors are ready) and the event loop +proper, which wraps a pollster with functionality for scheduling +callbacks, immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. + +There are several implementations of the pollster part, several using +esoteric system calls that exist only on some platforms. These are: + +- kqueue (most BSD systems) +- epoll (newer Linux systems) +- poll (most UNIX systems) +- select (all UNIX systems, and Windows) +- TODO: Support IOCP on Windows and some UNIX platforms. + +NOTE: We don't use select on systems where any of the others is +available, because select performs poorly as the number of file +descriptors goes up. The ranking is roughly: + + 1. kqueue, epoll, IOCP + 2. poll + 3. select + +TODO: +- Optimize the various pollsters. +- Unittests. +""" + +import collections +import concurrent.futures +import heapq +import logging +import os +import select +import threading +import time + + +class PollsterBase: + """Base class for all polling implementations. + + This defines an interface to register and unregister readers and + writers for specific file descriptors, and an interface to get a + list of events. There's also an interface to check whether any + readers or writers are currently registered. + """ + + def __init__(self): + super().__init__() + self.readers = {} # {fd: token, ...}. + self.writers = {} # {fd: token, ...}. + + def pollable(self): + """Return True if any readers or writers are currently registered.""" + return bool(self.readers or self.writers) + + # Subclasses are expected to extend the add/remove methods. + + def register_reader(self, fd, token): + """Add or update a reader for a file descriptor.""" + self.readers[fd] = token + + def register_writer(self, fd, token): + """Add or update a writer for a file descriptor.""" + self.writers[fd] = token + + def unregister_reader(self, fd): + """Remove the reader for a file descriptor.""" + del self.readers[fd] + + def unregister_writer(self, fd): + """Remove the writer for a file descriptor.""" + del self.writers[fd] + + def poll(self, timeout=None): + """Poll for events. A subclass must implement this. + + If timeout is omitted or None, this blocks until at least one + event is ready. Otherwise, timeout gives a maximum time to + wait (in seconds as an int or float) -- the method returns as + soon as at least one event is ready or when the timeout is + expired. For a non-blocking poll, pass 0. + + The return value is a list of events; it is empty when the + timeout expired before any events were ready. Each event + is a token previously passed to register_reader/writer(). + """ + raise NotImplementedError + + +class SelectPollster(PollsterBase): + """Pollster implementation using select.""" + + def poll(self, timeout=None): + readable, writable, _ = select.select(self.readers, self.writers, + [], timeout) + events = [] + events += (self.readers[fd] for fd in readable) + events += (self.writers[fd] for fd in writable) + return events + + +class PollPollster(PollsterBase): + """Pollster implementation using poll.""" + + def __init__(self): + super().__init__() + self._poll = select.poll() + + def _update(self, fd): + assert isinstance(fd, int), fd + flags = 0 + if fd in self.readers: + flags |= select.POLLIN + if fd in self.writers: + flags |= select.POLLOUT + if flags: + self._poll.register(fd, flags) + else: + self._poll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + # Timeout is in seconds, but poll() takes milliseconds. + msecs = None if timeout is None else int(round(1000 * timeout)) + events = [] + for fd, flags in self._poll.poll(msecs): + if flags & (select.POLLIN | select.POLLHUP): + if fd in self.readers: + events.append(self.readers[fd]) + if flags & (select.POLLOUT | select.POLLHUP): + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class EPollPollster(PollsterBase): + """Pollster implementation using epoll.""" + + def __init__(self): + super().__init__() + self._epoll = select.epoll() + + def _update(self, fd): + assert isinstance(fd, int), fd + eventmask = 0 + if fd in self.readers: + eventmask |= select.EPOLLIN + if fd in self.writers: + eventmask |= select.EPOLLOUT + if eventmask: + try: + self._epoll.register(fd, eventmask) + except IOError: + self._epoll.modify(fd, eventmask) + else: + self._epoll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + if timeout is None: + timeout = -1 # epoll.poll() uses -1 to mean "wait forever". + events = [] + for fd, eventmask in self._epoll.poll(timeout): + if eventmask & select.EPOLLIN: + if fd in self.readers: + events.append(self.readers[fd]) + if eventmask & select.EPOLLOUT: + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class KqueuePollster(PollsterBase): + """Pollster implementation using kqueue.""" + + def __init__(self): + super().__init__() + self._kqueue = select.kqueue() + + def register_reader(self, fd, callback, *args): + if fd not in self.readers: + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_reader(fd, callback, *args) + + def register_writer(self, fd, callback, *args): + if fd not in self.writers: + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_writer(fd, callback, *args) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def poll(self, timeout=None): + events = [] + max_ev = len(self.readers) + len(self.writers) + for kev in self._kqueue.control(None, max_ev, timeout): + fd = kev.ident + flag = kev.filter + if flag == select.KQ_FILTER_READ and fd in self.readers: + events.append(self.readers[fd]) + elif flag == select.KQ_FILTER_WRITE and fd in self.writers: + events.append(self.writers[fd]) + return events + + +# Pick the best pollster class for the platform. +if hasattr(select, 'kqueue'): + best_pollster = KqueuePollster +elif hasattr(select, 'epoll'): + best_pollster = EPollPollster +elif hasattr(select, 'poll'): + best_pollster = PollPollster +else: + best_pollster = SelectPollster + + +class DelayedCall: + """Object returned by callback registration methods.""" + + def __init__(self, when, callback, args, kwds=None): + self.when = when + self.callback = callback + self.args = args + self.kwds = kwds + self.cancelled = False + + def cancel(self): + self.cancelled = True + + def __lt__(self, other): + return self.when < other.when + + def __le__(self, other): + return self.when <= other.when + + def __eq__(self, other): + return self.when == other.when + + +class EventLoop: + """Event loop functionality. + + This defines public APIs call_soon(), call_later(), run_once() and + run(). It also wraps Pollster APIs register_reader(), + register_writer(), remove_reader(), remove_writer() with + add_reader() etc. + + This class's instance variables are not part of its API. + """ + + def __init__(self, pollster=None): + super().__init__() + if pollster is None: + logging.info('Using pollster: %s', best_pollster.__name__) + pollster = best_pollster() + self.pollster = pollster + self.ready = collections.deque() # [(callback, args), ...] + self.scheduled = [] # [(when, callback, args), ...] + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_reader(fd, dcall) + return dcall + + def remove_reader(self, fd): + """Remove a reader callback.""" + self.pollster.unregister_reader(fd) + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_writer(fd, dcall) + return dcall + + def remove_writer(self, fd): + """Remove a writer callback.""" + self.pollster.unregister_writer(fd) + + def add_callback(self, dcall): + """Add a DelayedCall to ready or scheduled.""" + if dcall.cancelled: + return + if dcall.when is None: + self.ready.append(dcall) + else: + heapq.heappush(self.scheduled, dcall) + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + dcall = DelayedCall(None, callback, args) + self.ready.append(dcall) + return dcall + + def call_later(self, when, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The time can be an int or float, expressed in seconds. + + If when is small enough (~11 days), it's assumed to be a + relative time, meaning the call will be scheduled that many + seconds in the future; otherwise it's assumed to be a posix + timestamp as returned by time.time(). + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + if when < 10000000: + when += time.time() + dcall = DelayedCall(when, callback, args) + heapq.heappush(self.scheduled, dcall) + return dcall + + def run_once(self): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Pass in a timeout or deadline or something. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: As step 4, run everything scheduled by steps 1-3. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # TODO: Ensure this loop always finishes, even if some + # callbacks keeps registering more callbacks. + while self.ready: + dcall = self.ready.popleft() + if not dcall.cancelled: + try: + if dcall.kwds: + dcall.callback(*dcall.args, **dcall.kwds) + else: + dcall.callback(*dcall.args) + except Exception: + logging.exception('Exception in callback %s %r', + dcall.callback, dcall.args) + + # Remove delayed calls that were cancelled from head of queue. + while self.scheduled and self.scheduled[0].cancelled: + heapq.heappop(self.scheduled) + + # Inspect the poll queue. + if self.pollster.pollable(): + if self.scheduled: + when = self.scheduled[0].when + timeout = max(0, when - time.time()) + else: + timeout = None + t0 = time.time() + events = self.pollster.poll(timeout) + t1 = time.time() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + for dcall in events: + self.add_callback(dcall) + + # Handle 'later' callbacks that are ready. + now = time.time() + while self.scheduled: + dcall = self.scheduled[0] + if dcall.when > now: + break + dcall = heapq.heappop(self.scheduled) + self.call_soon(dcall.callback, *dcall.args) + + def run(self): + """Run the event loop until there is no work left to do. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + """ + while self.ready or self.scheduled or self.pollster.pollable(): + self.run_once() + + +MAX_WORKERS = 5 # Default max workers when creating an executor. + + +class ThreadRunner: + """Helper to submit work to a thread pool and wait for it. + + This is the glue between the single-threaded callback-based async + world and the threaded world. Use it to call functions that must + block and don't have an async alternative (e.g. getaddrinfo()). + + The only public API is submit(). + """ + + def __init__(self, eventloop, executor=None): + self.eventloop = eventloop + self.executor = executor # Will be constructed lazily. + self.pipe_read_fd, self.pipe_write_fd = os.pipe() + self.active_count = 0 + + def read_callback(self): + """Semi-permanent callback while at least one future is active.""" + assert self.active_count > 0, self.active_count + data = os.read(self.pipe_read_fd, 8192) # Traditional buffer size. + self.active_count -= len(data) + if self.active_count == 0: + self.eventloop.remove_reader(self.pipe_read_fd) + assert self.active_count >= 0, self.active_count + + def submit(self, func, *args, executor=None, callback=None): + """Submit a function to the thread pool. + + This returns a concurrent.futures.Future instance. The caller + should not wait for that, but rather use the callback argument.. + """ + if executor is None: + executor = self.executor + if executor is None: + # Lazily construct a default executor. + # TODO: Should this be shared between threads? + executor = concurrent.futures.ThreadPoolExecutor(MAX_WORKERS) + self.executor = executor + assert self.active_count >= 0, self.active_count + future = executor.submit(func, *args) + if self.active_count == 0: + self.eventloop.add_reader(self.pipe_read_fd, self.read_callback) + self.active_count += 1 + def done_callback(fut): + if callback is not None: + self.eventloop.call_soon(callback, fut) + # TODO: Wake up the pipe in call_soon()? + os.write(self.pipe_write_fd, b'x') + future.add_done_callback(done_callback) + return future + + +class Context(threading.local): + """Thread-local context. + + We use this to avoid having to explicitly pass around an event loop + or something to hold the current task. + + TODO: Add an API so frameworks can substitute a different notion + of context more easily. + """ + + def __init__(self, eventloop=None, threadrunner=None): + # Default event loop and thread runner are lazily constructed + # when first accessed. + self._eventloop = eventloop + self._threadrunner = threadrunner + self.current_task = None # For the benefit of scheduling.py. + + @property + def eventloop(self): + if self._eventloop is None: + self._eventloop = EventLoop() + return self._eventloop + + @property + def threadrunner(self): + if self._threadrunner is None: + self._threadrunner = ThreadRunner(self.eventloop) + return self._threadrunner + + +context = Context() # Thread-local! diff --git a/old/scheduling.py b/old/scheduling.py new file mode 100644 index 0000000..3864571 --- /dev/null +++ b/old/scheduling.py @@ -0,0 +1,354 @@ +#!/usr/bin/env python3.3 +"""Example coroutine scheduler, PEP-380-style ('yield from '). + +Requires Python 3.3. + +There are likely micro-optimizations possible here, but that's not the point. + +TODO: +- Docstrings. +- Unittests. + +PATTERNS TO TRY: +- Various synchronization primitives (Lock, RLock, Event, Condition, + Semaphore, BoundedSemaphore, Barrier). +""" + +__author__ = 'Guido van Rossum ' + +# Standard library imports (keep in alphabetic order). +from concurrent.futures import CancelledError, TimeoutError +import logging +import time +import types + +# Local imports (keep in alphabetic order). +import polling + + +context = polling.context + + +class Task: + """Wrapper around a stack of generators. + + This is a bit like a Future, but with a different interface. + + TODO: + - wait for result. + """ + + def __init__(self, gen, name=None, *, timeout=None): + assert isinstance(gen, types.GeneratorType), repr(gen) + self.gen = gen + self.name = name or gen.__name__ + self.timeout = timeout + self.eventloop = context.eventloop + self.canceleer = None + if timeout is not None: + self.canceleer = self.eventloop.call_later(timeout, self.cancel) + self.blocked = False + self.unblocker = None + self.cancelled = False + self.must_cancel = False + self.alive = True + self.result = None + self.exception = None + self.done_callbacks = [] + # Start the task immediately. + self.eventloop.call_soon(self.step) + + def add_done_callback(self, done_callback): + # For better or for worse, the callback will always be called + # with the task as an argument, like concurrent.futures.Future. + # TODO: Call it right away if task is no longer alive. + dcall = polling.DelayedCall(None, done_callback, (self,)) + self.done_callbacks.append(dcall) + self.done_callbacks = [dc for dc in self.done_callbacks + if not dc.cancelled] + return dcall + + def __repr__(self): + parts = [self.name] + is_current = (self is context.current_task) + if self.blocked: + parts.append('blocking' if is_current else 'blocked') + elif self.alive: + parts.append('running' if is_current else 'runnable') + if self.must_cancel: + parts.append('must_cancel') + if self.cancelled: + parts.append('cancelled') + if self.exception is not None: + parts.append('exception=%r' % self.exception) + elif not self.alive: + parts.append('result=%r' % (self.result,)) + if self.timeout is not None: + parts.append('timeout=%.3f' % self.timeout) + return 'Task<' + ', '.join(parts) + '>' + + def cancel(self): + if self.alive: + if not self.must_cancel and not self.cancelled: + self.must_cancel = True + if self.blocked: + self.unblock() + + def step(self): + assert self.alive, self + try: + context.current_task = self + if self.must_cancel: + self.must_cancel = False + self.cancelled = True + self.gen.throw(CancelledError()) + else: + next(self.gen) + except StopIteration as exc: + self.alive = False + self.result = exc.value + except Exception as exc: + self.alive = False + self.exception = exc + logging.debug('Uncaught exception in %s', self, + exc_info=True, stack_info=True) + except BaseException as exc: + self.alive = False + self.exception = exc + raise + else: + if not self.blocked: + self.eventloop.call_soon(self.step) + finally: + context.current_task = None + if not self.alive: + # Cancel timeout callback if set. + if self.canceleer is not None: + self.canceleer.cancel() + # Schedule done_callbacks. + for dcall in self.done_callbacks: + self.eventloop.add_callback(dcall) + + def block(self, unblock_callback=None, *unblock_args): + assert self is context.current_task, self + assert self.alive, self + assert not self.blocked, self + self.blocked = True + self.unblocker = (unblock_callback, unblock_args) + + def unblock_if_alive(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. + if self.alive: + self.unblock() + + def unblock(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. + assert self.alive, self + assert self.blocked, self + self.blocked = False + unblock_callback, unblock_args = self.unblocker + if unblock_callback is not None: + try: + unblock_callback(*unblock_args) + except Exception: + logging.error('Exception in unblocker in task %r', self.name) + raise + finally: + self.unblocker = None + self.eventloop.call_soon(self.step) + + def block_io(self, fd, flag): + assert isinstance(fd, int), repr(fd) + assert flag in ('r', 'w'), repr(flag) + if flag == 'r': + self.block(self.eventloop.remove_reader, fd) + self.eventloop.add_reader(fd, self.unblock) + else: + self.block(self.eventloop.remove_writer, fd) + self.eventloop.add_writer(fd, self.unblock) + + def wait(self): + """COROUTINE: Wait until this task is finished.""" + current_task = context.current_task + assert self is not current_task, (self, current_task) # How confusing! + if not self.alive: + return + current_task.block() + self.add_done_callback(current_task.unblock) + yield + + def __iter__(self): + """COROUTINE: Wait, then return result or raise exception. + + This adds a little magic so you can say + + x = yield from Task(gen()) + + and it is equivalent to + + x = yield from gen() + + but with the option to add a timeout (and only a tad slower). + """ + if self.alive: + yield from self.wait() + assert not self.alive + if self.exception is not None: + raise self.exception + return self.result + + +def run(arg=None): + """Run the event loop until it's out of work. + + If you pass a generator, it will be spawned for you. + You can also pass a task (already started). + Returns the task. + """ + t = None + if arg is not None: + if isinstance(arg, Task): + t = arg + else: + t = Task(arg) + context.eventloop.run() + if t is not None and t.exception is not None: + logging.error('Uncaught exception in startup task: %r', + t.exception) + return t + + +def sleep(secs): + """COROUTINE: Sleep for some time (a float in seconds).""" + current_task = context.current_task + unblocker = context.eventloop.call_later(secs, current_task.unblock) + current_task.block(unblocker.cancel) + yield + + +def block_r(fd): + """COROUTINE: Block until a file descriptor is ready for reading.""" + context.current_task.block_io(fd, 'r') + yield + + +def block_w(fd): + """COROUTINE: Block until a file descriptor is ready for writing.""" + context.current_task.block_io(fd, 'w') + yield + + +def call_in_thread(func, *args, executor=None): + """COROUTINE: Run a function in a thread.""" + task = context.current_task + eventloop = context.eventloop + future = context.threadrunner.submit(func, *args, + executor=executor, + callback=task.unblock_if_alive) + task.block(future.cancel) + yield + assert future.done() + return future.result() + + +def wait_for(count, tasks): + """COROUTINE: Wait for the first N of a set of tasks to complete. + + May return more than N if more than N are immediately ready. + + NOTE: Tasks that were cancelled or raised are also considered ready. + """ + assert tasks + assert all(isinstance(task, Task) for task in tasks) + tasks = set(tasks) + assert 1 <= count <= len(tasks) + current_task = context.current_task + assert all(task is not current_task for task in tasks) + todo = set() + done = set() + dcalls = [] + def wait_for_callback(task): + nonlocal todo, done, current_task, count, dcalls + todo.remove(task) + if len(done) < count: + done.add(task) + if len(done) == count: + for dcall in dcalls: + dcall.cancel() + current_task.unblock() + for task in tasks: + if task.alive: + todo.add(task) + else: + done.add(task) + if len(done) < count: + for task in todo: + dcall = task.add_done_callback(wait_for_callback) + dcalls.append(dcall) + current_task.block() + yield + return done + + +def wait_any(tasks): + """COROUTINE: Wait for the first of a set of tasks to complete.""" + return wait_for(1, tasks) + + +def wait_all(tasks): + """COROUTINE: Wait for all of a set of tasks to complete.""" + return wait_for(len(tasks), tasks) + + +def map_over(gen, *args, timeout=None): + """COROUTINE: map a generator over one or more iterables. + + E.g. map_over(foo, xs, ys) runs + + Task(foo(x, y)) for x, y in zip(xs, ys) + + and returns a list of all results (in that order). However if any + task raises an exception, the remaining tasks are cancelled and + the exception is propagated. + """ + # gen is a generator function. + tasks = [Task(gobj, timeout=timeout) for gobj in map(gen, *args)] + return (yield from par_tasks(tasks)) + + +def par(*args): + """COROUTINE: Wait for generators, return a list of results. + + Raises as soon as one of the tasks raises an exception (and then + remaining tasks are cancelled). + + This differs from par_tasks() in two ways: + - takes *args instead of list of args + - each arg may be a generator or a task + """ + tasks = [] + for arg in args: + if not isinstance(arg, Task): + # TODO: assert arg is a generator or an iterator? + arg = Task(arg) + tasks.append(arg) + return (yield from par_tasks(tasks)) + + +def par_tasks(tasks): + """COROUTINE: Wait for a list of tasks, return a list of results. + + Raises as soon as one of the tasks raises an exception (and then + remaining tasks are cancelled). + """ + todo = set(tasks) + while todo: + ts = yield from wait_any(todo) + for t in ts: + assert not t.alive, t + todo.remove(t) + if t.exception is not None: + for other in todo: + other.cancel() + raise t.exception + return [t.result for t in tasks] diff --git a/old/sockets.py b/old/sockets.py new file mode 100644 index 0000000..a5005dc --- /dev/null +++ b/old/sockets.py @@ -0,0 +1,348 @@ +"""Socket wrappers to go with scheduling.py. + +Classes: + +- SocketTransport: a transport implementation wrapping a socket. +- SslTransport: a transport implementation wrapping SSL around a socket. +- BufferedReader: a buffer wrapping the read end of a transport. + +Functions (all coroutines): + +- connect(): connect a socket. +- getaddrinfo(): look up an address. +- create_connection(): look up address and return a connected socket for it. +- create_transport(): look up address and return a connected transport. + +TODO: +- Improve transport abstraction. +- Make a nice protocol abstraction. +- Unittests. +- A write() call that isn't a generator (needed so you can substitute it + for sys.stderr, pass it to logging.StreamHandler, etc.). +""" + +__author__ = 'Guido van Rossum ' + +# Stdlib imports. +import errno +import socket +import ssl + +# Local imports. +import scheduling + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + + +class SocketTransport: + """Transport wrapping a socket. + + The socket must already be connected in non-blocking mode. + """ + + def __init__(self, sock): + self.sock = sock + + def recv(self, n): + """COROUTINE: Read up to n bytes, blocking as needed. + + Always returns at least one byte, except if the socket was + closed or disconnected and there's no more data; then it + returns b''. + """ + assert n >= 0, n + while True: + try: + return self.sock.recv(n) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return b'' + else: + raise # Unexpected, propagate. + yield from scheduling.block_r(self.sock.fileno()) + + def send(self, data): + """COROUTINE; Send data to the socket, blocking until all written. + + Return True if all went well, False if socket was disconnected. + """ + while data: + try: + n = self.sock.send(data) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + else: + assert 0 <= n <= len(data), (n, len(data)) + if n == len(data): + break + data = data[n:] + continue + yield from scheduling.block_w(self.sock.fileno()) + + return True + + def close(self): + """Close the socket. (Not a coroutine.)""" + self.sock.close() + + +class SslTransport: + """Transport wrapping a socket in SSL. + + The socket must already be connected at the TCP level in + non-blocking mode. + """ + + def __init__(self, rawsock, sslcontext=None): + self.rawsock = rawsock + self.sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self.sslsock = self.sslcontext.wrap_socket( + self.rawsock, do_handshake_on_connect=False) + + def do_handshake(self): + """COROUTINE: Finish the SSL handshake.""" + while True: + try: + self.sslsock.do_handshake() + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + else: + break + + def recv(self, n): + """COROUTINE: Read up to n bytes. + + This blocks until at least one byte is read, or until EOF. + """ + while True: + try: + return self.sslsock.recv(n) + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_r(self.sslsock.fileno()) + elif err.errno in _DISCONNECTED: + # Can this happen? + return b'' + else: + raise # Unexpected, propagate. + + def send(self, data): + """COROUTINE: Send data to the socket, blocking as needed.""" + while data: + try: + n = self.sslsock.send(data) + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_w(self.sslsock.fileno()) + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + if n == len(data): + break + data = data[n:] + + return True + + def close(self): + """Close the socket. (Not a coroutine.) + + This also closes the raw socket. + """ + self.sslsock.close() + + # TODO: More SSL-specific methods, e.g. certificate stuff, unwrap(), ... + + +class BufferedReader: + """A buffered reader wrapping a transport.""" + + def __init__(self, trans, limit=8192): + self.trans = trans + self.limit = limit + self.buffer = b'' + self.eof = False + + def read(self, n): + """COROUTINE: Read up to n bytes, blocking at most once.""" + assert n >= 0, n + if not self.buffer and not self.eof: + yield from self._fillbuffer(max(n, self.limit)) + return self._getfrombuffer(n) + + def readexactly(self, n): + """COUROUTINE: Read exactly n bytes, or until EOF.""" + blocks = [] + count = 0 + while count < n: + block = yield from self.read(n - count) + if not block: + break + blocks.append(block) + count += len(block) + return b''.join(blocks) + + def readline(self): + """COROUTINE: Read up to newline or limit, whichever comes first.""" + end = self.buffer.find(b'\n') + 1 # Point past newline, or 0. + while not end and not self.eof and len(self.buffer) < self.limit: + anchor = len(self.buffer) + yield from self._fillbuffer(self.limit) + end = self.buffer.find(b'\n', anchor) + 1 + if not end: + end = len(self.buffer) + if end > self.limit: + end = self.limit + return self._getfrombuffer(end) + + def _getfrombuffer(self, n): + """Read up to n bytes without blocking (not a coroutine).""" + if n >= len(self.buffer): + result, self.buffer = self.buffer, b'' + else: + result, self.buffer = self.buffer[:n], self.buffer[n:] + return result + + def _fillbuffer(self, n): + """COROUTINE: Fill buffer with one (up to) n bytes from transport.""" + assert not self.eof, '_fillbuffer called at eof' + data = yield from self.trans.recv(n) + if data: + self.buffer += data + else: + self.eof = True + + +def connect(sock, address): + """COROUTINE: Connect a socket to an address.""" + try: + sock.connect(address) + except socket.error as err: + if err.errno != errno.EINPROGRESS: + raise + yield from scheduling.block_w(sock.fileno()) + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + raise IOError(err, 'Connection refused') + + +def getaddrinfo(host, port, af=0, socktype=0, proto=0): + """COROUTINE: Look up an address and return a list of infos for it. + + Each info is a tuple (af, socktype, protocol, canonname, address). + """ + infos = yield from scheduling.call_in_thread(socket.getaddrinfo, + host, port, af, + socktype, proto) + return infos + + +def create_connection(host, port, af=0, socktype=socket.SOCK_STREAM, proto=0): + """COROUTINE: Look up address and create a socket connected to it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + yield from connect(sock, address) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return sock + + +def create_transport(host, port, af=0, ssl=None): + """COROUTINE: Look up address and create a transport connected to it.""" + if ssl is None: + ssl = (port == 443) + sock = yield from create_connection(host, port, af) + if ssl: + trans = SslTransport(sock) + yield from trans.do_handshake() + else: + trans = SocketTransport(sock) + return trans + + +class Listener: + """Wrapper for a listening socket.""" + + def __init__(self, sock): + self.sock = sock + + def accept(self): + """COROUTINE: Accept a connection.""" + while True: + try: + conn, addr = self.sock.accept() + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_r(self.sock.fileno()) + else: + raise # Unexpected, propagate. + else: + conn.setblocking(False) + return conn, addr + + +def create_listener(host, port, af=0, socktype=0, proto=0, + backlog=5, reuse_addr=True): + """COROUTINE: Look up address and create a listener for it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + if reuse_addr: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + sock.listen(backlog) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return Listener(sock) diff --git a/old/transports.py b/old/transports.py new file mode 100644 index 0000000..19095bf --- /dev/null +++ b/old/transports.py @@ -0,0 +1,496 @@ +"""Transports and Protocols, actually. + +Inspired by Twisted, PEP 3153 and github.com/lvh/async-pep. + +THIS IS NOT REAL CODE! IT IS JUST AN EXPERIMENT. +""" + +# Stdlib imports. +import collections +import errno +import logging +import socket +import ssl +import sys +import time + +# Local imports. +import polling +import scheduling +import sockets + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + + +class Transport: + """ABC representing a transport. + + There may be many implementations. The user never instantiates + this directly; they call some utility function, passing it a + protocol, and the utility function will call the protocol's + connection_made() method with a transport (or it will call + connection_lost() with an exception if it fails to create the + desired transport). + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + def write(self, data): + """Write some data (bytes) to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data (bytes) to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data will + be received. When all buffered data is flushed, the protocol's + connection_lost() method is called with None as its argument. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method is called with None as + its argument. + """ + raise NotImplementedError + + def half_close(self): + """Closes the write end after flushing buffered data. + + Data may still be received. + + TODO: What's the use case for this? How to implement it? + Should it call shutdown(SHUT_WR) after all the data is flushed? + Is there no use case for closing the other half first? + """ + raise NotImplementedError + + def pause(self): + """Pause the receiving end. + + No data will be received until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Cancels a pause() call, resumes receiving data. + """ + raise NotImplementedError + + +class Protocol: + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing. + + When the user wants to requests a transport, they pass a protocol + instance to a utility function. + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_list() will be called exactly once + with either an exception object or None as an argument. + + If the utility function does not succeed in creating a transport, + it will call connection_lost() with an exception object. + + State machine of calls: + + start -> [CM -> DR*] -> CL -> end + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the connection. + To send data, call its write() or writelines() method. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + + TODO: Should we allow it to be a bytesarray or some other + memory buffer? + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + Also called when we fail to make a connection at all (in that + case connection_made() will not be called). + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +# TODO: The rest is platform specific and should move elsewhere. + +class UnixSocketTransport(Transport): + + def __init__(self, eventloop, protocol, sock): + self._eventloop = eventloop + self._protocol = protocol + self._sock = sock + self._buffer = collections.deque() # For write(). + self._write_closed = False + + def _on_readable(self): + try: + data = self._sock.recv(8192) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + else: + if not data: + self._eventloop.remove_reader(self._sock.fileno()) + self._sock.close() + self._protocol.connection_lost(None) + else: + self._protocol.data_received(data) # XXX call_soon()? + + def write(self, data): + assert isinstance(data, bytes) + assert not self._write_closed + if not data: + # Silly, but it happens. + return + if self._buffer: + # We've already registered a callback, just buffer the data. + self._buffer.append(data) + # Consider pausing if the total length of the buffer is + # truly huge. + return + + # TODO: Refactor so there's more sharing between this and + # _on_writable(). + + # There's no callback registered yet. It's quite possible + # that the kernel has buffer space for our data, so try to + # write now. Since the socket is non-blocking it will + # give us an error in _TRYAGAIN if it doesn't have enough + # space for even one more byte; it will return the number + # of bytes written if it can write at least one byte. + try: + n = self._sock.send(data) + except socket.error as exc: + # An error. + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + # The kernel doesn't have room for more data right now. + n = 0 + else: + # Wrote at least one byte. + if n == len(data): + # Wrote it all. Done! + if self._write_closed: + self._sock.shutdown(socket.SHUT_WR) + return + # Throw away the data that was already written. + # TODO: Do this without copying the data? + data = data[n:] + self._buffer.append(data) + self._eventloop.add_writer(self._sock.fileno(), self._on_writable) + + def _on_writable(self): + while self._buffer: + data = self._buffer[0] + # TODO: Join small amounts of data? + try: + n = self._sock.send(data) + except socket.error as exc: + # Error handling is the same as in write(). + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + if n < len(data): + self._buffer[0] = data[n:] + return + self._buffer.popleft() + self._eventloop.remove_writer(self._sock.fileno()) + if self._write_closed: + self._sock.shutdown(socket.SHUT_WR) + + def abort(self): + self._bad_error(None) + + def _bad_error(self, exc): + # A serious error. Close the socket etc. + fd = self._sock.fileno() + # TODO: Record whether we have a writer and/or reader registered. + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + self._sock.close() + self._protocol.connection_lost(exc) # XXX call_soon()? + + def half_close(self): + self._write_closed = True + + +class UnixSslTransport(Transport): + + # TODO: Refactor Socket and Ssl transport to share some code. + # (E.g. buffering.) + + # TODO: Consider using coroutines instead of callbacks, it seems + # much easier that way. + + def __init__(self, eventloop, protocol, rawsock, sslcontext=None): + self._eventloop = eventloop + self._protocol = protocol + self._rawsock = rawsock + self._sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslsock = self._sslcontext.wrap_socket( + self._rawsock, do_handshake_on_connect=False) + + self._buffer = collections.deque() # For write(). + self._write_closed = False + + # Try the handshake now. Likely it will raise EAGAIN, then it + # will take care of registering the appropriate callback. + self._on_handshake() + + def _bad_error(self, exc): + # A serious error. Close the socket etc. + fd = self._sslsock.fileno() + # TODO: Record whether we have a writer and/or reader registered. + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + self._sslsock.close() + self._protocol.connection_lost(exc) # XXX call_soon()? + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._eventloop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._eventloop.add_writable(fd, self._on_handshake) + return + # TODO: What if it raises another error? + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + self._protocol.connection_made(self) + self._eventloop.add_reader(fd, self._on_ready) + self._eventloop.add_writer(fd, self._on_ready) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + else: + if data: + self._protocol.data_received(data) + else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer + self._eventloop.remove_reader(fd) + self._eventloop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + return + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = self._buffer[0] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + else: + if n == len(data): + self._buffer.popleft() + # Could try again, but let's just have the next callback do it. + else: + self._buffer[0] = data[n:] + + def write(self, data): + assert isinstance(data, bytes) + assert not self._write_closed + if not data: + return + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + def half_close(self): + self._write_closed = True + # Just set the flag. Calling shutdown() on the ssl socket + # breaks something, causing recv() to return binary data. + + +def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, + use_ssl=None): + # TODO: Pass in a protocol factory, not a protocol. + # What should be the exact sequence of events? + # - socket + # - transport + # - protocol + # - tell transport about protocol + # - tell protocol about transport + # Or should the latter two be reversed? Does it matter? + if port is None: + port = 443 if use_ssl else 80 + if use_ssl is None: + use_ssl = (port == 443) + if not socktype: + socktype = socket.SOCK_STREAM + eventloop = polling.context.eventloop + + def on_socket_connected(task): + assert not task.alive + if task.exception is not None: + # TODO: Call some callback. + raise task.exception + sock = task.result + assert sock is not None + logging.debug('on_socket_connected') + if use_ssl: + # You can pass an ssl.SSLContext object as use_ssl, + # or a bool. + if isinstance(use_ssl, bool): + sslcontext = None + else: + sslcontext = use_ssl + transport = UnixSslTransport(eventloop, protocol, sock, sslcontext) + else: + transport = UnixSocketTransport(eventloop, protocol, sock) + # TODO: Should the ransport make the following calls? + protocol.connection_made(transport) # XXX call_soon()? + # Don't do this before connection_made() is called. + eventloop.add_reader(sock.fileno(), transport._on_readable) + + coro = sockets.create_connection(host, port, af, socktype, proto) + task = scheduling.Task(coro) + task.add_done_callback(on_socket_connected) + + +def main(): # Testing... + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + host = 'xkcd.com' + if sys.argv[1:] and '.' in sys.argv[-1]: + host = sys.argv[-1] + + t0 = time.time() + + class TestProtocol(Protocol): + def connection_made(self, transport): + logging.info('Connection made at %.3f secs', time.time() - t0) + self.transport = transport + self.transport.write(b'GET / HTTP/1.0\r\nHost: ' + + host.encode('ascii') + + b'\r\n\r\n') + self.transport.half_close() + def data_received(self, data): + logging.info('Received %d bytes at t=%.3f', + len(data), time.time() - t0) + logging.debug('Received %r', data) + def connection_lost(self, exc): + logging.debug('Connection lost: %r', exc) + self.t1 = time.time() + logging.info('Total time %.3f secs', self.t1 - t0) + + tp = TestProtocol() + logging.debug('tp = %r', tp) + make_connection(tp, host, use_ssl=('-S' in sys.argv)) + logging.info('Running...') + polling.context.eventloop.run() + logging.info('Done.') + + +if __name__ == '__main__': + main() diff --git a/old/xkcd.py b/old/xkcd.py new file mode 100755 index 0000000..474009d --- /dev/null +++ b/old/xkcd.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3.3 +"""Minimal synchronous SSL demo, connecting to xkcd.com.""" + +import socket, ssl + +s = socket.socket() +s.connect(('xkcd.com', 443)) +ss = ssl.wrap_socket(s) + +ss.send(b'GET / HTTP/1.0\r\n\r\n') + +while True: + data = ss.recv(1000000) + print(data) + if not data: + break + +ss.close() diff --git a/old/yyftime.py b/old/yyftime.py new file mode 100644 index 0000000..f55234b --- /dev/null +++ b/old/yyftime.py @@ -0,0 +1,75 @@ +"""Compare timing of yield-from vs. yield calls.""" + +import gc +import time + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def run_coro(depth): + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + print('coro', depth, k, round(t1-t0, 6)) + return t1-t0 + +class Future: + + def __init__(self, g): + self.g = g + + def wait(self): + value = None + try: + while True: + f = self.g.send(value) + f.wait() + value = f.value + except StopIteration as err: + self.value = err.value + + + +def task(func): # Decorator + def wrapper(*args): + g = func(*args) + f = Future(g) + return f + return wrapper + +@task +def oldstyle(n): + if n <= 0: + return 1 + l = yield oldstyle(n-1) + r = yield oldstyle(n-1) + return l + 1 + r + +def run_olds(depth): + t0 = time.time() + f = oldstyle(depth) + f.wait() + k = f.value + t1 = time.time() + print('olds', depth, k, round(t1-t0, 6)) + return t1-t0 + +def main(): + gc.disable() + for depth in range(16): + tc = run_coro(depth) + to = run_olds(depth) + if tc: + print('ratio', round(to/tc, 2)) + +if __name__ == '__main__': + main() diff --git a/overlapped.c b/overlapped.c new file mode 100644 index 0000000..c9f6ec9 --- /dev/null +++ b/overlapped.c @@ -0,0 +1,997 @@ +/* + * Support for overlapped IO + * + * Some code borrowed from Modules/_winapi.c of CPython + */ + +/* XXX check overflow and DWORD <-> Py_ssize_t conversions + Check itemsize */ + +#include "Python.h" +#include "structmember.h" + +#define WINDOWS_LEAN_AND_MEAN +#include +#include +#include + +#if defined(MS_WIN32) && !defined(MS_WIN64) +# define F_POINTER "k" +# define T_POINTER T_ULONG +#else +# define F_POINTER "K" +# define T_POINTER T_ULONGLONG +#endif + +#define F_HANDLE F_POINTER +#define F_ULONG_PTR F_POINTER +#define F_DWORD "k" +#define F_BOOL "i" +#define F_UINT "I" + +#define T_HANDLE T_POINTER + +enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_WRITE, TYPE_ACCEPT, + TYPE_CONNECT, TYPE_DISCONNECT}; + +/* + * Map Windows error codes to subclasses of OSError + */ + +static PyObject * +SetFromWindowsErr(DWORD err) +{ + PyObject *exception_type; + + if (err == 0) + err = GetLastError(); + switch (err) { + case ERROR_CONNECTION_REFUSED: + exception_type = PyExc_ConnectionRefusedError; + break; + case ERROR_CONNECTION_ABORTED: + exception_type = PyExc_ConnectionAbortedError; + break; + default: + exception_type = PyExc_OSError; + } + return PyErr_SetExcFromWindowsErr(exception_type, err); +} + +/* + * Some functions should be loaded at runtime + */ + +static LPFN_ACCEPTEX Py_AcceptEx = NULL; +static LPFN_CONNECTEX Py_ConnectEx = NULL; +static LPFN_DISCONNECTEX Py_DisconnectEx = NULL; +static BOOL (CALLBACK *Py_CancelIoEx)(HANDLE, LPOVERLAPPED) = NULL; + +#define GET_WSA_POINTER(s, x) \ + (SOCKET_ERROR != WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, \ + &Guid##x, sizeof(Guid##x), &Py_##x, \ + sizeof(Py_##x), &dwBytes, NULL, NULL)) + +static int +initialize_function_pointers(void) +{ + GUID GuidAcceptEx = WSAID_ACCEPTEX; + GUID GuidConnectEx = WSAID_CONNECTEX; + GUID GuidDisconnectEx = WSAID_DISCONNECTEX; + HINSTANCE hKernel32; + SOCKET s; + DWORD dwBytes; + + s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (s == INVALID_SOCKET) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + if (!GET_WSA_POINTER(s, AcceptEx) || + !GET_WSA_POINTER(s, ConnectEx) || + !GET_WSA_POINTER(s, DisconnectEx)) + { + closesocket(s); + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + closesocket(s); + + /* On WinXP we will have Py_CancelIoEx == NULL */ + hKernel32 = GetModuleHandle("KERNEL32"); + *(FARPROC *)&Py_CancelIoEx = GetProcAddress(hKernel32, "CancelIoEx"); + return 0; +} + +/* + * Completion port stuff + */ + +PyDoc_STRVAR( + CreateIoCompletionPort_doc, + "CreateIoCompletionPort(handle, port, key, concurrency) -> port\n\n" + "Create a completion port or register a handle with a port."); + +static PyObject * +overlapped_CreateIoCompletionPort(PyObject *self, PyObject *args) +{ + HANDLE FileHandle; + HANDLE ExistingCompletionPort; + ULONG_PTR CompletionKey; + DWORD NumberOfConcurrentThreads; + HANDLE ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE F_ULONG_PTR F_DWORD, + &FileHandle, &ExistingCompletionPort, &CompletionKey, + &NumberOfConcurrentThreads)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = CreateIoCompletionPort(FileHandle, ExistingCompletionPort, + CompletionKey, NumberOfConcurrentThreads); + Py_END_ALLOW_THREADS + + if (ret == NULL) + return SetFromWindowsErr(0); + return Py_BuildValue(F_HANDLE, ret); +} + +PyDoc_STRVAR( + GetQueuedCompletionStatus_doc, + "GetQueuedCompletionStatus(port, msecs) -> (err, bytes, key, address)\n\n" + "Get a message from completion port. Wait for up to msecs milliseconds."); + +static PyObject * +overlapped_GetQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort = NULL; + DWORD NumberOfBytes = 0; + ULONG_PTR CompletionKey = 0; + OVERLAPPED *Overlapped = NULL; + DWORD Milliseconds; + DWORD err; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, + &CompletionPort, &Milliseconds)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = GetQueuedCompletionStatus(CompletionPort, &NumberOfBytes, + &CompletionKey, &Overlapped, Milliseconds); + Py_END_ALLOW_THREADS + + err = ret ? ERROR_SUCCESS : GetLastError(); + if (Overlapped == NULL) { + if (err == WAIT_TIMEOUT) + Py_RETURN_NONE; + else + return SetFromWindowsErr(err); + } + return Py_BuildValue(F_DWORD F_DWORD F_ULONG_PTR F_POINTER, + err, NumberOfBytes, CompletionKey, Overlapped); +} + +PyDoc_STRVAR( + PostQueuedCompletionStatus_doc, + "PostQueuedCompletionStatus(port, bytes, key, address) -> None\n\n" + "Post a message to completion port."); + +static PyObject * +overlapped_PostQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort; + DWORD NumberOfBytes; + ULONG_PTR CompletionKey; + OVERLAPPED *Overlapped; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD F_ULONG_PTR F_POINTER, + &CompletionPort, &NumberOfBytes, &CompletionKey, + &Overlapped)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = PostQueuedCompletionStatus(CompletionPort, NumberOfBytes, + CompletionKey, Overlapped); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +/* + * Bind socket handle to local port without doing slow getaddrinfo() + */ + +PyDoc_STRVAR( + BindLocal_doc, + "BindLocal(handle, length_of_address_tuple) -> None\n\n" + "Bind a socket handle to an arbitrary local port.\n" + "If length_of_address_tuple is 2 then an AF_INET address is used.\n" + "If length_of_address_tuple is 4 then an AF_INET6 address is used."); + +static PyObject * +overlapped_BindLocal(PyObject *self, PyObject *args) +{ + SOCKET Socket; + int TupleLength; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE "i", &Socket, &TupleLength)) + return NULL; + + if (TupleLength == 2) { + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = 0; + addr.sin_addr.S_un.S_addr = INADDR_ANY; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else if (TupleLength == 4) { + struct sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_port = 0; + addr.sin6_addr = in6addr_any; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else { + PyErr_SetString(PyExc_ValueError, "expected tuple of length 2 or 4"); + return NULL; + } + + if (!ret) + return SetFromWindowsErr(WSAGetLastError()); + Py_RETURN_NONE; +} + +/* + * A Python object wrapping an OVERLAPPED structure and other useful data + * for overlapped I/O + */ + +PyDoc_STRVAR( + Overlapped_doc, + "Overlapped object"); + +typedef struct { + PyObject_HEAD + OVERLAPPED overlapped; + /* For convenience, we store the file handle too */ + HANDLE handle; + /* Error returned by last method call */ + DWORD error; + /* Type of operation */ + DWORD type; + /* Buffer used for reading (optional) */ + PyObject *read_buffer; + /* Buffer used for writing (optional) */ + Py_buffer write_buffer; +} OverlappedObject; + + +static PyObject * +Overlapped_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + OverlappedObject *self; + HANDLE event = INVALID_HANDLE_VALUE; + static char *kwlist[] = {"event", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|" F_HANDLE, kwlist, &event)) + return NULL; + + if (event == INVALID_HANDLE_VALUE) { + event = CreateEvent(NULL, TRUE, FALSE, NULL); + if (event == NULL) + return SetFromWindowsErr(0); + } + + self = PyObject_New(OverlappedObject, type); + if (self == NULL) { + if (event != NULL) + CloseHandle(event); + return NULL; + } + + self->handle = NULL; + self->error = 0; + self->type = TYPE_NONE; + self->read_buffer = NULL; + memset(&self->overlapped, 0, sizeof(OVERLAPPED)); + memset(&self->write_buffer, 0, sizeof(Py_buffer)); + if (event) + self->overlapped.hEvent = event; + return (PyObject *)self; +} + +static void +Overlapped_dealloc(OverlappedObject *self) +{ + DWORD bytes; + DWORD olderr = GetLastError(); + BOOL wait = FALSE; + BOOL ret; + + if (!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED) + { + if (Py_CancelIoEx && Py_CancelIoEx(self->handle, &self->overlapped)) + wait = TRUE; + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, + &bytes, wait); + Py_END_ALLOW_THREADS + + switch (ret ? ERROR_SUCCESS : GetLastError()) { + case ERROR_SUCCESS: + case ERROR_NOT_FOUND: + case ERROR_OPERATION_ABORTED: + break; + default: + PyErr_Format( + PyExc_RuntimeError, + "%R still has pending operation at " + "deallocation, the process may crash", self); + PyErr_WriteUnraisable(NULL); + } + } + + if (self->overlapped.hEvent != NULL) + CloseHandle(self->overlapped.hEvent); + + if (self->write_buffer.obj) + PyBuffer_Release(&self->write_buffer); + + Py_CLEAR(self->read_buffer); + PyObject_Del(self); + SetLastError(olderr); +} + +PyDoc_STRVAR( + Overlapped_cancel_doc, + "cancel() -> None\n\n" + "Cancel overlapped operation"); + +static PyObject * +Overlapped_cancel(OverlappedObject *self) +{ + BOOL ret = TRUE; + + if (self->type == TYPE_NOT_STARTED) + Py_RETURN_NONE; + + if (!HasOverlappedIoCompleted(&self->overlapped)) { + Py_BEGIN_ALLOW_THREADS + if (Py_CancelIoEx) + ret = Py_CancelIoEx(self->handle, &self->overlapped); + else + ret = CancelIo(self->handle); + Py_END_ALLOW_THREADS + } + + /* CancelIoEx returns ERROR_NOT_FOUND if the I/O completed in-between */ + if (!ret && GetLastError() != ERROR_NOT_FOUND) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +PyDoc_STRVAR( + Overlapped_getresult_doc, + "getresult(wait=False) -> result\n\n" + "Retrieve result of operation. If wait is true then it blocks\n" + "until the operation is finished. If wait is false and the\n" + "operation is still pending then an error is raised."); + +static PyObject * +Overlapped_getresult(OverlappedObject *self, PyObject *args) +{ + BOOL wait = FALSE; + DWORD transferred = 0; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, "|" F_BOOL, &wait)) + return NULL; + + if (self->type == TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation not yet attempted"); + return NULL; + } + + if (self->type == TYPE_NOT_STARTED) { + PyErr_SetString(PyExc_ValueError, "operation failed to start"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, &transferred, + wait); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + break; + case ERROR_BROKEN_PIPE: + if (self->read_buffer != NULL) + break; + /* fall through */ + default: + return SetFromWindowsErr(err); + } + + switch (self->type) { + case TYPE_READ: + assert(PyBytes_CheckExact(self->read_buffer)); + if (transferred != PyBytes_GET_SIZE(self->read_buffer) && + _PyBytes_Resize(&self->read_buffer, transferred)) + return NULL; + Py_INCREF(self->read_buffer); + return self->read_buffer; + case TYPE_ACCEPT: + case TYPE_CONNECT: + case TYPE_DISCONNECT: + Py_RETURN_NONE; + default: + return PyLong_FromUnsignedLong((unsigned long) transferred); + } +} + +PyDoc_STRVAR( + Overlapped_ReadFile_doc, + "ReadFile(handle, size) -> Overlapped[message]\n\n" + "Start overlapped read"); + +static PyObject * +Overlapped_ReadFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD nread; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &handle, &size)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = ReadFile(handle, PyBytes_AS_STRING(buf), size, &nread, + &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_BROKEN_PIPE: + self->type = TYPE_NOT_STARTED; + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSARecv_doc, + "RecvFile(handle, size, flags) -> Overlapped[message]\n\n" + "Start overlapped receive"); + +static PyObject * +Overlapped_WSARecv(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD flags = 0; + DWORD nread; + PyObject *buf; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD "|" F_DWORD, + &handle, &size, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + wsabuf.len = size; + wsabuf.buf = PyBytes_AS_STRING(buf); + + Py_BEGIN_ALLOW_THREADS + ret = WSARecv((SOCKET)handle, &wsabuf, 1, &nread, &flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_BROKEN_PIPE: + self->type = TYPE_NOT_STARTED; + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WriteFile_doc, + "WriteFile(handle, buf) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped write"); + +static PyObject * +Overlapped_WriteFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD written; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &handle, &bufobj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + + Py_BEGIN_ALLOW_THREADS + ret = WriteFile(handle, self->write_buffer.buf, + (DWORD)self->write_buffer.len, + &written, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSASend_doc, + "WSASend(handle, buf, flags) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped send"); + +static PyObject * +Overlapped_WSASend(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD flags; + DWORD written; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O" F_DWORD, + &handle, &bufobj, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + wsabuf.len = (DWORD)self->write_buffer.len; + wsabuf.buf = self->write_buffer.buf; + + Py_BEGIN_ALLOW_THREADS + ret = WSASend((SOCKET)handle, &wsabuf, 1, &written, flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_AcceptEx_doc, + "AcceptEx(listen_handle, accept_handle) -> Overlapped[address_as_bytes]\n\n" + "Start overlapped wait for client to connect"); + +static PyObject * +Overlapped_AcceptEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ListenSocket; + SOCKET AcceptSocket; + DWORD BytesReceived; + DWORD size; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE, + &ListenSocket, &AcceptSocket)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + size = sizeof(struct sockaddr_in6) + 16; + buf = PyBytes_FromStringAndSize(NULL, size*2); + if (!buf) + return NULL; + + self->type = TYPE_ACCEPT; + self->handle = (HANDLE)ListenSocket; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = Py_AcceptEx(ListenSocket, AcceptSocket, PyBytes_AS_STRING(buf), + 0, size, size, &BytesReceived, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + + +static int +parse_address(PyObject *obj, SOCKADDR *Address, int Length) +{ + char *Host; + unsigned short Port; + unsigned long FlowInfo; + unsigned long ScopeId; + + memset(Address, 0, Length); + + if (PyArg_ParseTuple(obj, "sH", &Host, &Port)) + { + Address->sa_family = AF_INET; + if (WSAStringToAddressA(Host, AF_INET, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN*)Address)->sin_port = htons(Port); + return Length; + } + else if (PyArg_ParseTuple(obj, "sHkk", &Host, &Port, &FlowInfo, &ScopeId)) + { + PyErr_Clear(); + Address->sa_family = AF_INET6; + if (WSAStringToAddressA(Host, AF_INET6, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN6*)Address)->sin6_port = htons(Port); + ((SOCKADDR_IN6*)Address)->sin6_flowinfo = FlowInfo; + ((SOCKADDR_IN6*)Address)->sin6_scope_id = ScopeId; + return Length; + } + + return -1; +} + + +PyDoc_STRVAR( + Overlapped_ConnectEx_doc, + "ConnectEx(client_handle, address_as_bytes) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_ConnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ConnectSocket; + PyObject *AddressObj; + char AddressBuf[sizeof(struct sockaddr_in6)]; + SOCKADDR *Address = (SOCKADDR*)AddressBuf; + int Length; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &ConnectSocket, &AddressObj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + Length = sizeof(AddressBuf); + Length = parse_address(AddressObj, Address, Length); + if (Length < 0) + return NULL; + + self->type = TYPE_CONNECT; + self->handle = (HANDLE)ConnectSocket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_ConnectEx(ConnectSocket, Address, Length, + NULL, 0, NULL, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_DisconnectEx_doc, + "DisconnectEx(handle, flags) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_DisconnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET Socket; + DWORD flags; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &Socket, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + self->type = TYPE_DISCONNECT; + self->handle = (HANDLE)Socket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_DisconnectEx(Socket, &self->overlapped, flags, 0); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +static PyObject* +Overlapped_getaddress(OverlappedObject *self) +{ + return PyLong_FromVoidPtr(&self->overlapped); +} + +static PyObject* +Overlapped_getpending(OverlappedObject *self) +{ + return PyBool_FromLong(!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED); +} + +static PyMethodDef Overlapped_methods[] = { + {"getresult", (PyCFunction) Overlapped_getresult, + METH_VARARGS, Overlapped_getresult_doc}, + {"cancel", (PyCFunction) Overlapped_cancel, + METH_NOARGS, Overlapped_cancel_doc}, + {"ReadFile", (PyCFunction) Overlapped_ReadFile, + METH_VARARGS, Overlapped_ReadFile_doc}, + {"WSARecv", (PyCFunction) Overlapped_WSARecv, + METH_VARARGS, Overlapped_WSARecv_doc}, + {"WriteFile", (PyCFunction) Overlapped_WriteFile, + METH_VARARGS, Overlapped_WriteFile_doc}, + {"WSASend", (PyCFunction) Overlapped_WSASend, + METH_VARARGS, Overlapped_WSASend_doc}, + {"AcceptEx", (PyCFunction) Overlapped_AcceptEx, + METH_VARARGS, Overlapped_AcceptEx_doc}, + {"ConnectEx", (PyCFunction) Overlapped_ConnectEx, + METH_VARARGS, Overlapped_ConnectEx_doc}, + {"DisconnectEx", (PyCFunction) Overlapped_DisconnectEx, + METH_VARARGS, Overlapped_DisconnectEx_doc}, + {NULL} +}; + +static PyMemberDef Overlapped_members[] = { + {"error", T_ULONG, + offsetof(OverlappedObject, error), + READONLY, "Error from last operation"}, + {"event", T_HANDLE, + offsetof(OverlappedObject, overlapped) + offsetof(OVERLAPPED, hEvent), + READONLY, "Overlapped event handle"}, + {NULL} +}; + +static PyGetSetDef Overlapped_getsets[] = { + {"address", (getter)Overlapped_getaddress, NULL, + "Address of overlapped structure"}, + {"pending", (getter)Overlapped_getpending, NULL, + "Whether the operation is pending"}, + {NULL}, +}; + +PyTypeObject OverlappedType = { + PyVarObject_HEAD_INIT(NULL, 0) + /* tp_name */ "_overlapped.Overlapped", + /* tp_basicsize */ sizeof(OverlappedObject), + /* tp_itemsize */ 0, + /* tp_dealloc */ (destructor) Overlapped_dealloc, + /* tp_print */ 0, + /* tp_getattr */ 0, + /* tp_setattr */ 0, + /* tp_reserved */ 0, + /* tp_repr */ 0, + /* tp_as_number */ 0, + /* tp_as_sequence */ 0, + /* tp_as_mapping */ 0, + /* tp_hash */ 0, + /* tp_call */ 0, + /* tp_str */ 0, + /* tp_getattro */ 0, + /* tp_setattro */ 0, + /* tp_as_buffer */ 0, + /* tp_flags */ Py_TPFLAGS_DEFAULT, + /* tp_doc */ "OVERLAPPED structure wrapper", + /* tp_traverse */ 0, + /* tp_clear */ 0, + /* tp_richcompare */ 0, + /* tp_weaklistoffset */ 0, + /* tp_iter */ 0, + /* tp_iternext */ 0, + /* tp_methods */ Overlapped_methods, + /* tp_members */ Overlapped_members, + /* tp_getset */ Overlapped_getsets, + /* tp_base */ 0, + /* tp_dict */ 0, + /* tp_descr_get */ 0, + /* tp_descr_set */ 0, + /* tp_dictoffset */ 0, + /* tp_init */ 0, + /* tp_alloc */ 0, + /* tp_new */ Overlapped_new, +}; + +static PyMethodDef overlapped_functions[] = { + {"CreateIoCompletionPort", overlapped_CreateIoCompletionPort, + METH_VARARGS, CreateIoCompletionPort_doc}, + {"GetQueuedCompletionStatus", overlapped_GetQueuedCompletionStatus, + METH_VARARGS, GetQueuedCompletionStatus_doc}, + {"PostQueuedCompletionStatus", overlapped_PostQueuedCompletionStatus, + METH_VARARGS, PostQueuedCompletionStatus_doc}, + {"BindLocal", overlapped_BindLocal, + METH_VARARGS, BindLocal_doc}, + {NULL} +}; + +static struct PyModuleDef overlapped_module = { + PyModuleDef_HEAD_INIT, + "_overlapped", + NULL, + -1, + overlapped_functions, + NULL, + NULL, + NULL, + NULL +}; + +#define WINAPI_CONSTANT(fmt, con) \ + PyDict_SetItemString(d, #con, Py_BuildValue(fmt, con)) + +PyMODINIT_FUNC +PyInit__overlapped(void) +{ + PyObject *m, *d; + + /* Ensure WSAStartup() called before initializing function pointers */ + m = PyImport_ImportModule("_socket"); + if (!m) + return NULL; + Py_DECREF(m); + + if (initialize_function_pointers() < 0) + return NULL; + + if (PyType_Ready(&OverlappedType) < 0) + return NULL; + + m = PyModule_Create(&overlapped_module); + if (PyModule_AddObject(m, "Overlapped", (PyObject *)&OverlappedType) < 0) + return NULL; + + d = PyModule_GetDict(m); + + WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING); + WINAPI_CONSTANT(F_DWORD, INFINITE); + WINAPI_CONSTANT(F_HANDLE, INVALID_HANDLE_VALUE); + WINAPI_CONSTANT(F_HANDLE, NULL); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_ACCEPT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_CONNECT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, TF_REUSE_SOCKET); + + return m; +} diff --git a/runtests.py b/runtests.py new file mode 100644 index 0000000..096a256 --- /dev/null +++ b/runtests.py @@ -0,0 +1,198 @@ +"""Run all unittests. + +Usage: + python3 runtests.py [-v] [-q] [pattern] ... + +Where: + -v: verbose + -q: quiet + pattern: optional regex patterns to match test ids (default all tests) + +Note that the test id is the fully qualified name of the test, +including package, module, class and method, +e.g. 'tests.events_test.PolicyTests.testPolicy'. + +runtests.py with --coverage argument is equivalent of: + + $(COVERAGE) run --branch runtests.py -v + $(COVERAGE) html $(list of files) + $(COVERAGE) report -m $(list of files) + +""" + +# Originally written by Beech Horn (for NDB). + +import argparse +import logging +import os +import re +import sys +import subprocess +import unittest +import importlib.machinery + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +ARGS = argparse.ArgumentParser(description="Run all unittests.") +ARGS.add_argument( + '-v', action="store", dest='verbose', + nargs='?', const=1, type=int, default=0, help='verbose') +ARGS.add_argument( + '-x', action="store_true", dest='exclude', help='exclude tests') +ARGS.add_argument( + '-q', action="store_true", dest='quiet', help='quiet') +ARGS.add_argument( + '--tests', action="store", dest='testsdir', default='tests', + help='tests directory') +ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') +ARGS.add_argument( + 'pattern', action="store", nargs="*", + help='optional regex patterns to match test ids (default all tests)') + +COV_ARGS = argparse.ArgumentParser(description="Run all unittests.") +COV_ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') + + +def load_modules(basedir, suffix='.py'): + def list_dir(prefix, dir): + files = [] + + modpath = os.path.join(dir, '__init__.py') + if os.path.isfile(modpath): + mod = os.path.split(dir)[-1] + files.append(('%s%s' % (prefix, mod), modpath)) + + prefix = '%s%s.' % (prefix, mod) + + for name in os.listdir(dir): + path = os.path.join(dir, name) + + if os.path.isdir(path): + files.extend(list_dir('%s%s.' % (prefix, name), path)) + else: + if (name != '__init__.py' and + name.endswith(suffix) and + not name.startswith(('.', '_'))): + files.append(('%s%s' % (prefix, name[:-3]), path)) + + return files + + mods = [] + for modname, sourcefile in list_dir('', basedir): + if modname == 'runtests': + continue + try: + loader = importlib.machinery.SourceFileLoader(modname, sourcefile) + mods.append((loader.load_module(), sourcefile)) + except Exception as err: + print("Skipping '%s': %s" % (modname, err), file=sys.stderr) + + return mods + + +def load_tests(testsdir, includes=(), excludes=()): + mods = [mod for mod, _ in load_modules(testsdir)] + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + for mod in mods: + for name in set(dir(mod)): + if name.endswith('Tests'): + test_module = getattr(mod, name) + tests = loader.loadTestsFromTestCase(test_module) + if includes: + tests = [test + for test in tests + if any(re.search(pat, test.id()) + for pat in includes)] + if excludes: + tests = [test + for test in tests + if not any(re.search(pat, test.id()) + for pat in excludes)] + suite.addTests(tests) + + return suite + + +def runtests(): + args = ARGS.parse_args() + + testsdir = os.path.abspath(args.testsdir) + if not os.path.isdir(testsdir): + print("Tests directory is not found: %s\n" % testsdir) + ARGS.print_help() + return + + excludes = includes = [] + if args.exclude: + excludes = args.pattern + else: + includes = args.pattern + + v = 0 if args.quiet else args.verbose + 1 + + tests = load_tests(args.testsdir, includes, excludes) + logger = logging.getLogger() + if v == 0: + logger.setLevel(logging.CRITICAL) + elif v == 1: + logger.setLevel(logging.ERROR) + elif v == 2: + logger.setLevel(logging.WARNING) + elif v == 3: + logger.setLevel(logging.INFO) + elif v >= 4: + logger.setLevel(logging.DEBUG) + result = unittest.TextTestRunner(verbosity=v).run(tests) + sys.exit(not result.wasSuccessful()) + + +def runcoverage(sdir, args): + """ + To install coverage3 for Python 3, you need: + - Distribute (http://packages.python.org/distribute/) + + What worked for me: + - download http://python-distribute.org/distribute_setup.py + * curl -O http://python-distribute.org/distribute_setup.py + - python3 distribute_setup.py + - python3 -m easy_install coverage + """ + try: + import coverage + except ImportError: + print("Coverage package is not found.") + print(runcoverage.__doc__) + return + + sdir = os.path.abspath(sdir) + if not os.path.isdir(sdir): + print("Python files directory is not found: %s\n" % sdir) + ARGS.print_help() + return + + mods = [source for _, source in load_modules(sdir)] + coverage = [sys.executable, '-m', 'coverage'] + + try: + subprocess.check_call( + coverage + ['run', '--branch', 'runtests.py'] + args) + except: + pass + else: + subprocess.check_call(coverage + ['html'] + mods) + subprocess.check_call(coverage + ['report'] + mods) + + +if __name__ == '__main__': + if '--coverage' in sys.argv: + cov_args, args = COV_ARGS.parse_known_args() + runcoverage(cov_args.coverage, args) + else: + runtests() diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..0260f9d --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[build_ext] +build_lib=tulip diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..dcaee96 --- /dev/null +++ b/setup.py @@ -0,0 +1,14 @@ +import os +from distutils.core import setup, Extension + +extensions = [] +if os.name == 'nt': + ext = Extension('_overlapped', ['overlapped.c'], libraries=['ws2_32']) + extensions.append(ext) + +setup(name='tulip', + description="reference implementation of PEP 3156", + url='http://www.python.org/dev/peps/pep-3156/', + packages=['tulip'], + ext_modules=extensions + ) diff --git a/srv.py b/srv.py new file mode 100755 index 0000000..b28abbd --- /dev/null +++ b/srv.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +"""Simple server written using an event loop.""" + +import email.message +import os +import sys + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +import tulip +import tulip.http + + +class HttpServer(tulip.http.ServerHttpProtocol): + + def handle_request(self, request_info, message): + print('method = {!r}; path = {!r}; version = {!r}'.format( + request_info.method, request_info.uri, request_info.version)) + + path = request_info.uri + + if (not (path.isprintable() and path.startswith('/')) or '/.' in path): + print('bad path', repr(path)) + path = None + else: + path = '.' + path + if not os.path.exists(path): + print('no file', repr(path)) + path = None + else: + isdir = os.path.isdir(path) + + if not path: + raise tulip.http.HttpStatusException(404) + + headers = email.message.Message() + for hdr, val in message.headers: + print(hdr, val) + headers.add_header(hdr, val) + + if isdir and not path.endswith('/'): + path = path + '/' + raise tulip.http.HttpStatusException( + 302, headers=(('URI', path), ('Location', path))) + + response = tulip.http.Response(self.transport, 200) + response.add_header('Transfer-Encoding', 'chunked') + + # content encoding + accept_encoding = headers.get('accept-encoding', '').lower() + if 'deflate' in accept_encoding: + response.add_header('Content-Encoding', 'deflate') + response.add_compression_filter('deflate') + elif 'gzip' in accept_encoding: + response.add_header('Content-Encoding', 'gzip') + response.add_compression_filter('gzip') + + response.add_chunking_filter(1025) + + if isdir: + response.add_header('Content-type', 'text/html') + response.send_headers() + + response.write(b'
    \r\n') + for name in sorted(os.listdir(path)): + if name.isprintable() and not name.startswith('.'): + try: + bname = name.encode('ascii') + except UnicodeError: + pass + else: + if os.path.isdir(os.path.join(path, name)): + response.write(b'
  • ' + bname + b'/
  • \r\n') + else: + response.write(b'
  • ' + bname + b'
  • \r\n') + response.write(b'
') + else: + response.add_header('Content-type', 'text/plain') + response.send_headers() + + try: + with open(path, 'rb') as fp: + chunk = fp.read(8196) + while chunk: + if not response.write(chunk): + break + chunk = fp.read(8196) + except OSError: + response.write(b'Cannot open') + + response.write_eof() + self.close() + + +def main(): + host = '127.0.0.1' + port = 8080 + if sys.argv[1:]: + host = sys.argv[1] + if sys.argv[2:]: + port = int(sys.argv[2]) + elif ':' in host: + host, port = host.split(':', 1) + port = int(port) + loop = tulip.get_event_loop() + f = loop.start_serving(lambda: HttpServer(debug=True), host, port) + x = loop.run_until_complete(f) + print('serving on', x.getsockname()) + loop.run_forever() + + +if __name__ == '__main__': + main() diff --git a/sslsrv.py b/sslsrv.py new file mode 100644 index 0000000..a1bc04f --- /dev/null +++ b/sslsrv.py @@ -0,0 +1,56 @@ +"""Serve up an SSL connection, after Python ssl module docs.""" + +import socket +import ssl +import os + + +def main(): + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + certfile = getcertfile() + context.load_cert_chain(certfile=certfile, keyfile=certfile) + bindsocket = socket.socket() + bindsocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + bindsocket.bind(('', 4443)) + bindsocket.listen(5) + + while True: + newsocket, fromaddr = bindsocket.accept() + try: + connstream = context.wrap_socket(newsocket, server_side=True) + try: + deal_with_client(connstream) + finally: + connstream.shutdown(socket.SHUT_RDWR) + connstream.close() + except Exception as exc: + print(exc.__class__.__name__ + ':', exc) + + +def getcertfile(): + import test # Test package + testdir = os.path.dirname(test.__file__) + certfile = os.path.join(testdir, 'keycert.pem') + print('certfile =', certfile) + return certfile + + +def deal_with_client(connstream): + data = connstream.recv(1024) + # empty data means the client is finished with us + while data: + if not do_something(connstream, data): + # we'll assume do_something returns False + # when we're finished with client + break + data = connstream.recv(1024) + # finished with client + + +def do_something(connstream, data): + # just echo back + connstream.sendall(data) + + +if __name__ == '__main__': + main() diff --git a/tests/base_events_test.py b/tests/base_events_test.py new file mode 100644 index 0000000..88f3faf --- /dev/null +++ b/tests/base_events_test.py @@ -0,0 +1,283 @@ +"""Tests for base_events.py""" + +import concurrent.futures +import logging +import socket +import time +import unittest +import unittest.mock + +from tulip import base_events +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import tasks +from tulip import test_utils + + +class BaseEventLoopTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + + self.event_loop = base_events.BaseEventLoop() + self.event_loop._selector = unittest.mock.Mock() + self.event_loop._selector.registered_count.return_value = 1 + + def test_not_implemented(self): + m = unittest.mock.Mock() + self.assertRaises( + NotImplementedError, + self.event_loop._make_socket_transport, m, m) + self.assertRaises( + NotImplementedError, + self.event_loop._make_ssl_transport, m, m, m, m) + self.assertRaises( + NotImplementedError, + self.event_loop._make_datagram_transport, m, m) + self.assertRaises( + NotImplementedError, self.event_loop._process_events, []) + self.assertRaises( + NotImplementedError, self.event_loop._write_to_self) + self.assertRaises( + NotImplementedError, self.event_loop._read_from_self) + self.assertRaises( + NotImplementedError, + self.event_loop._make_read_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + self.event_loop._make_write_pipe_transport, m, m) + + def test_add_callback_handle(self): + h = events.Handle(lambda: False, ()) + + self.event_loop._add_callback(h) + self.assertFalse(self.event_loop._scheduled) + self.assertIn(h, self.event_loop._ready) + + def test_add_callback_timer(self): + when = time.monotonic() + + h1 = events.Timer(when, lambda: False, ()) + h2 = events.Timer(when+10.0, lambda: False, ()) + + self.event_loop._add_callback(h2) + self.event_loop._add_callback(h1) + self.assertEqual([h1, h2], self.event_loop._scheduled) + self.assertFalse(self.event_loop._ready) + + def test_add_callback_cancelled_handle(self): + h = events.Handle(lambda: False, ()) + h.cancel() + + self.event_loop._add_callback(h) + self.assertFalse(self.event_loop._scheduled) + self.assertFalse(self.event_loop._ready) + + def test_wrap_future(self): + f = futures.Future() + self.assertIs(self.event_loop.wrap_future(f), f) + + def test_wrap_future_concurrent(self): + f = concurrent.futures.Future() + self.assertIsInstance(self.event_loop.wrap_future(f), futures.Future) + + def test_set_default_executor(self): + executor = unittest.mock.Mock() + self.event_loop.set_default_executor(executor) + self.assertIs(executor, self.event_loop._default_executor) + + def test_getnameinfo(self): + sockaddr = unittest.mock.Mock() + self.event_loop.run_in_executor = unittest.mock.Mock() + self.event_loop.getnameinfo(sockaddr) + self.assertEqual( + (None, socket.getnameinfo, sockaddr, 0), + self.event_loop.run_in_executor.call_args[0]) + + def test_call_soon(self): + def cb(): + pass + + h = self.event_loop.call_soon(cb) + self.assertEqual(h._callback, cb) + self.assertIsInstance(h, events.Handle) + self.assertIn(h, self.event_loop._ready) + + def test_call_later(self): + def cb(): + pass + + h = self.event_loop.call_later(10.0, cb) + self.assertIsInstance(h, events.Timer) + self.assertIn(h, self.event_loop._scheduled) + self.assertNotIn(h, self.event_loop._ready) + + def test_call_later_no_delay(self): + def cb(): + pass + + h = self.event_loop.call_later(0, cb) + self.assertIn(h, self.event_loop._ready) + self.assertNotIn(h, self.event_loop._scheduled) + + def test_run_once_in_executor_handle(self): + def cb(): + pass + + self.assertRaises( + AssertionError, self.event_loop.run_in_executor, + None, events.Handle(cb, ()), ('',)) + self.assertRaises( + AssertionError, self.event_loop.run_in_executor, + None, events.Timer(10, cb, ())) + + def test_run_once_in_executor_canceled(self): + def cb(): + pass + h = events.Handle(cb, ()) + h.cancel() + + f = self.event_loop.run_in_executor(None, h) + self.assertIsInstance(f, futures.Future) + self.assertTrue(f.done()) + + def test_run_once_in_executor(self): + def cb(): + pass + h = events.Handle(cb, ()) + f = futures.Future() + executor = unittest.mock.Mock() + executor.submit.return_value = f + + self.event_loop.set_default_executor(executor) + + res = self.event_loop.run_in_executor(None, h) + self.assertIs(f, res) + + executor = unittest.mock.Mock() + executor.submit.return_value = f + res = self.event_loop.run_in_executor(executor, h) + self.assertIs(f, res) + self.assertTrue(executor.submit.called) + + def test_run_once(self): + self.event_loop._run_once = unittest.mock.Mock() + self.event_loop._run_once.side_effect = base_events._StopError + self.event_loop.run_once() + self.assertTrue(self.event_loop._run_once.called) + + def test__run_once(self): + h1 = events.Timer(time.monotonic() + 0.1, lambda: True, ()) + h2 = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + + h1.cancel() + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h1) + self.event_loop._scheduled.append(h2) + self.event_loop._run_once() + + t = self.event_loop._selector.select.call_args[0][0] + self.assertTrue(9.99 < t < 10.1) + self.assertEqual([h2], self.event_loop._scheduled) + self.assertTrue(self.event_loop._process_events.called) + + def test__run_once_timeout(self): + h = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._run_once(1.0) + self.assertEqual((1.0,), self.event_loop._selector.select.call_args[0]) + + def test__run_once_timeout_with_ready(self): + # If event loop has ready callbacks, select timeout is always 0. + h = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._ready.append(h) + self.event_loop._run_once(1.0) + + self.assertEqual((0,), self.event_loop._selector.select.call_args[0]) + + @unittest.mock.patch('tulip.base_events.time') + @unittest.mock.patch('tulip.base_events.tulip_log') + def test__run_once_logging(self, m_logging, m_time): + # Log to INFO level if timeout > 1.0 sec. + idx = -1 + data = [10.0, 10.0, 12.0, 13.0] + + def monotonic(): + nonlocal data, idx + idx += 1 + return data[idx] + + m_time.monotonic = monotonic + m_logging.INFO = logging.INFO + m_logging.DEBUG = logging.DEBUG + + self.event_loop._scheduled.append(events.Timer(11.0, lambda: True, ())) + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._run_once() + self.assertEqual(logging.INFO, m_logging.log.call_args[0][0]) + + idx = -1 + data = [10.0, 10.0, 10.3, 13.0] + self.event_loop._scheduled = [events.Timer(11.0, lambda:True, ())] + self.event_loop._run_once() + self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0]) + + def test__run_once_schedule_handle(self): + handle = None + processed = False + + def cb(event_loop): + nonlocal processed, handle + processed = True + handle = event_loop.call_soon(lambda: True) + + h = events.Timer(time.monotonic() - 1, cb, (self.event_loop,)) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._run_once() + + self.assertTrue(processed) + self.assertEqual([handle], list(self.event_loop._ready)) + + def test_run_until_complete_assertion(self): + self.assertRaises( + AssertionError, self.event_loop.run_until_complete, 'blah') + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_connection_mutiple_errors(self, m_socket): + self.suppress_log_errors() + + class MyProto(protocols.Protocol): + pass + + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + + idx = -1 + errors = ['err1', 'err2'] + + def _socket(*args, **kw): + nonlocal idx, errors + idx += 1 + raise socket.error(errors[idx]) + + m_socket.socket = _socket + m_socket.error = socket.error + + self.event_loop.getaddrinfo = getaddrinfo + + task = tasks.Task( + self.event_loop.create_connection(MyProto, 'xkcd.com', 80)) + task._step() + exc = task.exception() + self.assertEqual("Multiple exceptions: err1, err2", str(exc)) diff --git a/tests/events_test.py b/tests/events_test.py new file mode 100644 index 0000000..4085921 --- /dev/null +++ b/tests/events_test.py @@ -0,0 +1,1379 @@ +"""Tests for events.py.""" + +import concurrent.futures +import contextlib +import gc +import io +import os +import re +import signal +import socket +try: + import ssl +except ImportError: + ssl = None +import sys +import threading +import time +import unittest +import unittest.mock + +from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer + +from tulip import events +from tulip import transports +from tulip import protocols +from tulip import selector_events +from tulip import tasks +from tulip import test_utils + + +class MyProto(protocols.Protocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: xkcd.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + + +class MyDatagramProto(protocols.DatagramProtocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + + +class MyReadPipeProto(protocols.Protocol): + + def __init__(self): + self.state = ['INITIAL'] + self.nbytes = 0 + self.transport = None + + def connection_made(self, transport): + self.transport = transport + assert self.state == ['INITIAL'], self.state + self.state.append('CONNECTED') + + def data_received(self, data): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.state.append('EOF') + self.transport.close() + + def connection_lost(self, exc): + assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state + self.state.append('CLOSED') + + +class MyWritePipeProto(protocols.Protocol): + + def __init__(self): + self.state = 'INITIAL' + self.transport = None + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + + +class EventLoopTestsMixin: + + def setUp(self): + super().setUp() + self.event_loop = self.create_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + gc.collect() + super().tearDown() + + @contextlib.contextmanager + def run_test_server(self, *, use_ssl=False): + + class SilentWSGIRequestHandler(WSGIRequestHandler): + def get_stderr(self): + return io.StringIO() + + def log_message(self, format, *args): + pass + + class SilentWSGIServer(WSGIServer): + def handle_error(self, request, client_address): + pass + + class SSLWSGIServer(SilentWSGIServer): + def finish_request(self, request, client_address): + here = os.path.dirname(__file__) + keyfile = os.path.join(here, 'sample.key') + certfile = os.path.join(here, 'sample.crt') + ssock = ssl.wrap_socket(request, + keyfile=keyfile, + certfile=certfile, + server_side=True) + try: + self.RequestHandlerClass(ssock, client_address, self) + ssock.close() + except OSError: + # maybe socket has been closed by peer + pass + + def app(environ, start_response): + status = '302 Found' + headers = [('Content-type', 'text/plain')] + start_response(status, headers) + return [b'Test message'] + + # Run the test WSGI server in a separate thread in order not to + # interfere with event handling in the main thread + server_class = SSLWSGIServer if use_ssl else SilentWSGIServer + httpd = make_server('127.0.0.1', 0, app, + server_class, SilentWSGIRequestHandler) + server_thread = threading.Thread(target=httpd.serve_forever) + server_thread.start() + try: + yield httpd + finally: + httpd.shutdown() + server_thread.join() + + def test_run(self): + self.event_loop.run() # Returns immediately. + + def test_run_nesting(self): + err = None + + @tasks.coroutine + def coro(): + nonlocal err + yield from [] + self.assertTrue(self.event_loop.is_running()) + try: + self.event_loop.run_until_complete( + tasks.sleep(0.1)) + except Exception as exc: + err = exc + + self.event_loop.run_until_complete(tasks.Task(coro())) + self.assertIsInstance(err, RuntimeError) + + def test_run_once_nesting(self): + err = None + + @tasks.coroutine + def coro(): + nonlocal err + yield from [] + tasks.sleep(0.1) + try: + self.event_loop.run_once() + except Exception as exc: + err = exc + + self.event_loop.run_until_complete(tasks.Task(coro())) + self.assertIsInstance(err, RuntimeError) + + def test_run_once_block(self): + called = False + + def callback(): + nonlocal called + called = True + + def run(): + time.sleep(0.1) + self.event_loop.call_soon_threadsafe(callback) + + t = threading.Thread(target=run) + t0 = time.monotonic() + t.start() + self.event_loop.run_once(None) + t1 = time.monotonic() + t.join() + self.assertTrue(called) + self.assertTrue(0.09 < t1-t0 <= 0.12) + + def test_call_later(self): + results = [] + + def callback(arg): + results.append(arg) + + self.event_loop.call_later(0.1, callback, 'hello world') + t0 = time.monotonic() + self.event_loop.run() + t1 = time.monotonic() + self.assertEqual(results, ['hello world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_call_repeatedly(self): + results = [] + + def callback(arg): + results.append(arg) + + self.event_loop.call_repeatedly(0.03, callback, 'ho') + self.event_loop.call_later(0.1, self.event_loop.stop) + self.event_loop.run() + self.assertEqual(results, ['ho', 'ho', 'ho']) + + def test_call_soon(self): + results = [] + + def callback(arg1, arg2): + results.append((arg1, arg2)) + + self.event_loop.call_soon(callback, 'hello', 'world') + self.event_loop.run() + self.assertEqual(results, [('hello', 'world')]) + + def test_call_soon_with_handle(self): + results = [] + + def callback(): + results.append('yeah') + + handle = events.Handle(callback, ()) + self.assertIs(self.event_loop.call_soon(handle), handle) + self.event_loop.run() + self.assertEqual(results, ['yeah']) + + def test_call_soon_threadsafe(self): + results = [] + + def callback(arg): + results.append(arg) + + def run(): + self.event_loop.call_soon_threadsafe(callback, 'hello') + + t = threading.Thread(target=run) + self.event_loop.call_later(0.1, callback, 'world') + t0 = time.monotonic() + t.start() + self.event_loop.run() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_call_soon_threadsafe_same_thread(self): + results = [] + + def callback(arg): + results.append(arg) + + self.event_loop.call_later(0.1, callback, 'world') + self.event_loop.call_soon_threadsafe(callback, 'hello') + self.event_loop.run() + self.assertEqual(results, ['hello', 'world']) + + def test_call_soon_threadsafe_with_handle(self): + results = [] + + def callback(arg): + results.append(arg) + + handle = events.Handle(callback, ('hello',)) + + def run(): + self.assertIs( + self.event_loop.call_soon_threadsafe(handle), handle) + + t = threading.Thread(target=run) + self.event_loop.call_later(0.1, callback, 'world') + + t0 = time.monotonic() + t.start() + self.event_loop.run() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_wrap_future(self): + def run(arg): + time.sleep(0.1) + return arg + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = self.event_loop.wrap_future(f1) + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'oi') + + def test_run_in_executor(self): + def run(arg): + time.sleep(0.1) + return arg + f2 = self.event_loop.run_in_executor(None, run, 'yo') + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def test_run_in_executor_with_handle(self): + def run(arg): + time.sleep(0.1) + return arg + handle = events.Handle(run, ('yo',)) + f2 = self.event_loop.run_in_executor(None, handle) + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def test_reader_callback(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.event_loop.remove_reader(r.fileno())) + r.close() + + self.event_loop.add_reader(r.fileno(), reader) + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_reader_callback_with_handle(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.event_loop.remove_reader(r.fileno())) + r.close() + + handle = events.Handle(reader, ()) + self.assertIs(handle, self.event_loop.add_reader(r.fileno(), handle)) + + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_reader_callback_cancel(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + return + if data: + bytes_read.append(data) + if sum(len(b) for b in bytes_read) >= 6: + handle.cancel() + if not data: + r.close() + + handle = self.event_loop.add_reader(r.fileno(), reader) + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_writer_callback(self): + r, w = test_utils.socketpair() + w.setblocking(False) + self.event_loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + + def remove_writer(): + self.assertTrue(self.event_loop.remove_writer(w.fileno())) + + self.event_loop.call_later(0.1, remove_writer) + self.event_loop.run() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def test_writer_callback_with_handle(self): + r, w = test_utils.socketpair() + w.setblocking(False) + handle = events.Handle(w.send, (b'x'*(256*1024),)) + self.assertIs(self.event_loop.add_writer(w.fileno(), handle), handle) + + def remove_writer(): + self.assertTrue(self.event_loop.remove_writer(w.fileno())) + + self.event_loop.call_later(0.1, remove_writer) + self.event_loop.run() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def test_writer_callback_cancel(self): + r, w = test_utils.socketpair() + w.setblocking(False) + + def sender(): + w.send(b'x'*256) + handle.cancel() + + handle = self.event_loop.add_writer(w.fileno(), sender) + self.event_loop.run() + w.close() + data = r.recv(1024) + r.close() + self.assertTrue(data == b'x'*256) + + def test_sock_client_ops(self): + with self.run_test_server() as httpd: + sock = socket.socket() + sock.setblocking(False) + address = httpd.socket.getsockname() + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, address)) + self.event_loop.run_until_complete( + self.event_loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.event_loop.run_until_complete( + self.event_loop.sock_recv(sock, 1024)) + sock.close() + + self.assertTrue(re.match(rb'HTTP/1.0 302 Found', data), data) + + def test_sock_client_fail(self): + # Make sure that we will get an unused port + address = None + try: + s = socket.socket() + s.bind(('127.0.0.1', 0)) + address = s.getsockname() + finally: + s.close() + + sock = socket.socket() + sock.setblocking(False) + with self.assertRaises(ConnectionRefusedError): + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, address)) + sock.close() + + def test_sock_accept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + + f = self.event_loop.sock_accept(listener) + conn, addr = self.event_loop.run_until_complete(f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + def test_add_signal_handler(self): + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + # Check error behavior first. + self.assertRaises( + TypeError, self.event_loop.add_signal_handler, 'boom', my_handler) + self.assertRaises( + TypeError, self.event_loop.remove_signal_handler, 'boom') + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, signal.NSIG+1) + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, 0, my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, 0) + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, -1, my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, -1) + self.assertRaises( + RuntimeError, self.event_loop.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + self.event_loop.add_signal_handler(signal.SIGINT, my_handler) + self.event_loop.run_once() + os.kill(os.getpid(), signal.SIGINT) + self.event_loop.run_once() + self.assertEqual(caught, 1) + # Removing it should restore the default handler. + self.assertTrue(self.event_loop.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGINT)) + + @unittest.skipIf(sys.platform == 'win32', 'Unix only') + def test_cancel_signal_handler(self): + # Cancelling the handler should remove it (eventually). + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + handle = self.event_loop.add_signal_handler(signal.SIGINT, my_handler) + handle.cancel() + os.kill(os.getpid(), signal.SIGINT) + self.event_loop.run_once() + self.assertEqual(caught, 0) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_while_selecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + self.event_loop.add_signal_handler( + signal.SIGALRM, my_handler) + + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + self.event_loop.call_later(0.15, self.event_loop.stop) + self.event_loop.run_forever() + self.assertEqual(caught, 1) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_args(self): + some_args = (42,) + caught = 0 + + def my_handler(*args): + nonlocal caught + caught += 1 + self.assertEqual(args, some_args) + + self.event_loop.add_signal_handler( + signal.SIGALRM, my_handler, *some_args) + + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + self.event_loop.call_later(0.15, self.event_loop.stop) + self.event_loop.run_forever() + self.assertEqual(caught, 1) + + def test_create_connection(self): + with self.run_test_server() as httpd: + host, port = httpd.socket.getsockname() + f = tasks.Task( + self.event_loop.create_connection(MyProto, host, port)) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) + + def test_create_connection_sock(self): + with self.run_test_server() as httpd: + host, port = httpd.socket.getsockname() + f = tasks.Task( + self.event_loop.create_connection(MyProto, host, port)) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_ssl_connection(self): + with self.run_test_server(use_ssl=True) as httpsd: + host, port = httpsd.socket.getsockname() + f = self.event_loop.create_connection( + MyProto, host, port, ssl=True) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + self.assertTrue( + hasattr(tr.get_extra_info('socket'), 'getsockname')) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) + + def test_create_connection_host_port_sock(self): + self.suppress_log_errors() + coro = self.event_loop.create_connection( + MyProto, 'xkcd.com', 80, sock=object()) + self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) + + def test_create_connection_no_host_port_sock(self): + self.suppress_log_errors() + coro = self.event_loop.create_connection(MyProto) + self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) + + def test_create_connection_no_getaddrinfo(self): + self.suppress_log_errors() + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + coro = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_create_connection_connect_err(self): + self.suppress_log_errors() + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + coro = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_create_connection_mutiple_errors(self): + self.suppress_log_errors() + + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + self.event_loop.getaddrinfo = getaddrinfo + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + coro = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_start_serving(self): + proto = None + + def factory(): + nonlocal proto + proto = MyProto() + return proto + + f = self.event_loop.start_serving(factory, '0.0.0.0', 0) + sock = self.event_loop.run_until_complete(f) + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + self.event_loop.run_once() + self.assertIsInstance(proto, MyProto) + self.assertEqual('INITIAL', proto.state) + self.event_loop.run_once() + self.assertEqual('CONNECTED', proto.state) + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('socket')) + conn = proto.transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + self.assertEqual( + '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + + # close connection + proto.transport.close() + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + def test_start_serving_sock(self): + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.event_loop.start_serving(MyProto, sock=sock_ob) + sock = self.event_loop.run_until_complete(f) + self.assertIs(sock, sock_ob) + + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + self.event_loop.run_once() # This is quite mysterious, but necessary. + self.event_loop.run_once() + sock.close() + client.close() + + def test_start_serving_host_port_sock(self): + self.suppress_log_errors() + fut = self.event_loop.start_serving( + MyProto, '0.0.0.0', 0, sock=object()) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_start_serving_no_host_port_sock(self): + self.suppress_log_errors() + fut = self.event_loop.start_serving(MyProto) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_start_serving_no_getaddrinfo(self): + self.suppress_log_errors() + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + f = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, f) + + @unittest.mock.patch('tulip.base_events.socket') + def test_start_serving_cant_bind(self, m_socket): + self.suppress_log_errors() + + class Err(socket.error): + pass + + m_socket.error = socket.error + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.setsockopt.side_effect = Err + + fut = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises(Err, self.event_loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_no_addrinfo(self, m_socket): + self.suppress_log_errors() + + m_socket.error = socket.error + m_socket.getaddrinfo.return_value = [] + + coro = self.event_loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 0)) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_create_datagram_endpoint_addr_error(self): + self.suppress_log_errors() + + coro = self.event_loop.create_datagram_endpoint( + MyDatagramProto, local_addr='localhost') + self.assertRaises( + AssertionError, self.event_loop.run_until_complete, coro) + coro = self.event_loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 1, 2, 3)) + self.assertRaises( + AssertionError, self.event_loop.run_until_complete, coro) + + def test_create_datagram_endpoint(self): + class TestMyDatagramProto(MyDatagramProto): + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + coro = self.event_loop.create_datagram_endpoint( + TestMyDatagramProto, local_addr=('127.0.0.1', 0)) + s_transport, server = self.event_loop.run_until_complete(coro) + host, port = s_transport.get_extra_info('addr') + + coro = self.event_loop.create_datagram_endpoint( + MyDatagramProto, remote_addr=(host, port)) + transport, client = self.event_loop.run_until_complete(coro) + + self.assertEqual('INITIALIZED', client.state) + transport.sendto(b'xxx') + self.event_loop.run_once(None) + self.assertEqual(3, server.nbytes) + self.event_loop.run_once(None) + + # received + self.assertEqual(8, client.nbytes) + + # extra info is available + self.assertIsNotNone(transport.get_extra_info('socket')) + conn = transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + + # close connection + transport.close() + + self.assertEqual('CLOSED', client.state) + server.transport.close() + + def test_create_datagram_endpoint_connect_err(self): + self.suppress_log_errors() + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0)) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_socket_err(self, m_socket): + self.suppress_log_errors() + + m_socket.error = socket.error + m_socket.getaddrinfo = socket.getaddrinfo + m_socket.socket.side_effect = socket.error + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, local_addr=('127.0.0.1', 0)) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_create_datagram_endpoint_no_matching_family(self): + self.suppress_log_errors() + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, + remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) + self.assertRaises( + ValueError, self.event_loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_setblk_err(self, m_socket): + self.suppress_log_errors() + + m_socket.error = socket.error + m_socket.socket.return_value.setblocking.side_effect = socket.error + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + self.assertTrue( + m_socket.socket.return_value.close.called) + + def test_create_datagram_endpoint_noaddr_nofamily(self): + self.suppress_log_errors() + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol) + self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_cant_bind(self, m_socket): + self.suppress_log_errors() + + class Err(socket.error): + pass + + m_socket.error = socket.error + m_socket.AF_INET6 = socket.AF_INET6 + m_socket.getaddrinfo = socket.getaddrinfo + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.event_loop.create_datagram_endpoint( + MyDatagramProto, + local_addr=('127.0.0.1', 0), family=socket.AF_INET) + self.assertRaises(Err, self.event_loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_accept_connection_retry(self): + sock = unittest.mock.Mock() + sock.accept.side_effect = BlockingIOError() + + self.event_loop._accept_connection(MyProto, sock) + self.assertFalse(sock.close.called) + + def test_accept_connection_exception(self): + self.suppress_log_errors() + + sock = unittest.mock.Mock() + sock.accept.side_effect = OSError() + + self.event_loop._accept_connection(MyProto, sock) + self.assertTrue(sock.close.called) + + def test_internal_fds(self): + event_loop = self.create_event_loop() + if not isinstance(event_loop, selector_events.BaseSelectorEventLoop): + return + + self.assertEqual(1, event_loop._internal_fds) + event_loop.close() + self.assertEqual(0, event_loop._internal_fds) + self.assertIsNone(event_loop._csock) + self.assertIsNone(event_loop._ssock) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_read_pipe(self): + proto = None + + def factory(): + nonlocal proto + proto = MyReadPipeProto() + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(rpipe, 'rb', 1024) + + @tasks.task + def connect(): + t, p = yield from self.event_loop.connect_read_pipe(factory, + pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(0, proto.nbytes) + + self.event_loop.run_until_complete(connect()) + + os.write(wpipe, b'1') + self.event_loop.run_once() + self.assertEqual(1, proto.nbytes) + + os.write(wpipe, b'2345') + self.event_loop.run_once() + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(5, proto.nbytes) + + os.close(wpipe) + self.event_loop.run_once() + self.assertEqual( + ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto() + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.task + def connect(): + nonlocal transport + t, p = yield from self.event_loop.connect_write_pipe(factory, + pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.event_loop.run_until_complete(connect()) + + transport.write(b'1') + self.event_loop.run_once() + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + transport.write(b'2345') + self.event_loop.run_once() + data = os.read(rpipe, 1024) + self.assertEqual(b'2345', data) + self.assertEqual('CONNECTED', proto.state) + + os.close(rpipe) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + # close connection + proto.transport.close() + self.event_loop.run_once() + self.assertEqual('CLOSED', proto.state) + + +if sys.platform == 'win32': + from tulip import windows_events + + class SelectEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return windows_events.SelectorEventLoop() + + class ProactorEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return windows_events.ProactorEventLoop() + def test_create_ssl_connection(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + def test_reader_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_reader_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_reader_callback_with_handle(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_writer_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_writer_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_writer_callback_with_handle(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_accept_connection_retry(self): + raise unittest.SkipTest( + "IocpEventLoop does not have _accept_connection()") + def test_accept_connection_exception(self): + raise unittest.SkipTest( + "IocpEventLoop does not have _accept_connection()") + def test_create_datagram_endpoint(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_no_connection(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_cant_bind(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_noaddr_nofamily(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_socket_err(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_connect_err(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") +else: + from tulip import selectors + from tulip import unix_events + + if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop( + selectors.KqueueSelector()) + + if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.EpollSelector()) + + if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.PollSelector()) + + # Should always exist. + class SelectEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.SelectSelector()) + + +class HandleTests(unittest.TestCase): + + def test_handle(self): + def callback(*args): + return args + + args = () + h = events.Handle(callback, args) + self.assertIs(h.callback, callback) + self.assertIs(h.args, args) + self.assertFalse(h.cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '.callback')) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h.cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '.callback')) + self.assertTrue(r.endswith('())')) + + def test_make_handle(self): + def callback(*args): + return args + h1 = events.Handle(callback, ()) + h2 = events.make_handle(h1, ()) + self.assertIs(h1, h2) + + self.assertRaises( + AssertionError, events.make_handle, h1, (1, 2)) + + +class TimerTests(unittest.TestCase): + + def test_timer(self): + def callback(*args): + return args + + args = () + when = time.monotonic() + h = events.Timer(when, callback, args) + self.assertIs(h.callback, callback) + self.assertIs(h.args, args) + self.assertFalse(h.cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h.cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + self.assertRaises(AssertionError, events.Timer, None, callback, args) + + def test_timer_comparison(self): + def callback(*args): + return args + + when = time.monotonic() + + h1 = events.Timer(when, callback, ()) + h2 = events.Timer(when, callback, ()) + self.assertFalse(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertTrue(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertFalse(h2 > h1) + self.assertTrue(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertTrue(h1 == h2) + self.assertFalse(h1 != h2) + + h2.cancel() + self.assertFalse(h1 == h2) + + h1 = events.Timer(when, callback, ()) + h2 = events.Timer(when + 10.0, callback, ()) + self.assertTrue(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertFalse(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertTrue(h2 > h1) + self.assertFalse(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h3 = events.Handle(callback, ()) + self.assertIs(NotImplemented, h1.__eq__(h3)) + self.assertIs(NotImplemented, h1.__ne__(h3)) + + +class AbstractEventLoopTests(unittest.TestCase): + + def test_not_imlemented(self): + f = unittest.mock.Mock() + ev_loop = events.AbstractEventLoop() + self.assertRaises( + NotImplementedError, ev_loop.run) + self.assertRaises( + NotImplementedError, ev_loop.run_forever) + self.assertRaises( + NotImplementedError, ev_loop.run_once) + self.assertRaises( + NotImplementedError, ev_loop.run_until_complete, None) + self.assertRaises( + NotImplementedError, ev_loop.stop) + self.assertRaises( + NotImplementedError, ev_loop.call_later, None, None) + self.assertRaises( + NotImplementedError, ev_loop.call_repeatedly, None, None) + self.assertRaises( + NotImplementedError, ev_loop.call_soon, None) + self.assertRaises( + NotImplementedError, ev_loop.call_soon_threadsafe, None) + self.assertRaises( + NotImplementedError, ev_loop.wrap_future, f) + self.assertRaises( + NotImplementedError, ev_loop.run_in_executor, f, f) + self.assertRaises( + NotImplementedError, ev_loop.getaddrinfo, 'localhost', 8080) + self.assertRaises( + NotImplementedError, ev_loop.getnameinfo, ('localhost', 8080)) + self.assertRaises( + NotImplementedError, ev_loop.create_connection, f) + self.assertRaises( + NotImplementedError, ev_loop.start_serving, f) + self.assertRaises( + NotImplementedError, ev_loop.create_datagram_endpoint, f) + self.assertRaises( + NotImplementedError, ev_loop.add_reader, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_reader, 1) + self.assertRaises( + NotImplementedError, ev_loop.add_writer, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_writer, 1) + self.assertRaises( + NotImplementedError, ev_loop.sock_recv, f, 10) + self.assertRaises( + NotImplementedError, ev_loop.sock_sendall, f, 10) + self.assertRaises( + NotImplementedError, ev_loop.sock_connect, f, f) + self.assertRaises( + NotImplementedError, ev_loop.sock_accept, f) + self.assertRaises( + NotImplementedError, ev_loop.add_signal_handler, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, ev_loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, ev_loop.connect_read_pipe, f, + unittest.mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, ev_loop.connect_write_pipe, f, + unittest.mock.sentinel.pipe) + + +class ProtocolsAbsTests(unittest.TestCase): + + def test_empty(self): + f = unittest.mock.Mock() + p = protocols.Protocol() + self.assertIsNone(p.connection_made(f)) + self.assertIsNone(p.connection_lost(f)) + self.assertIsNone(p.data_received(f)) + self.assertIsNone(p.eof_received()) + + dp = protocols.DatagramProtocol() + self.assertIsNone(dp.connection_made(f)) + self.assertIsNone(dp.connection_lost(f)) + self.assertIsNone(dp.connection_refused(f)) + self.assertIsNone(dp.datagram_received(f, f)) + + +class PolicyTests(unittest.TestCase): + + def test_event_loop_policy(self): + policy = events.EventLoopPolicy() + self.assertRaises(NotImplementedError, policy.get_event_loop) + self.assertRaises(NotImplementedError, policy.set_event_loop, object()) + self.assertRaises(NotImplementedError, policy.new_event_loop) + + def test_get_event_loop(self): + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy._event_loop) + + event_loop = policy.get_event_loop() + self.assertIsInstance(event_loop, events.AbstractEventLoop) + + self.assertIs(policy._event_loop, event_loop) + self.assertIs(event_loop, policy.get_event_loop()) + + @unittest.mock.patch('tulip.events.threading') + def test_get_event_loop_thread(self, m_threading): + m_t = m_threading.current_thread.return_value = unittest.mock.Mock() + m_t.name = 'Thread 1' + + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy.get_event_loop()) + + def test_new_event_loop(self): + policy = events.DefaultEventLoopPolicy() + + event_loop = policy.new_event_loop() + self.assertIsInstance(event_loop, events.AbstractEventLoop) + + def test_set_event_loop(self): + policy = events.DefaultEventLoopPolicy() + old_event_loop = policy.get_event_loop() + + self.assertRaises(AssertionError, policy.set_event_loop, object()) + + event_loop = policy.new_event_loop() + policy.set_event_loop(event_loop) + self.assertIs(event_loop, policy.get_event_loop()) + self.assertIsNot(old_event_loop, policy.get_event_loop()) + + def test_get_event_loop_policy(self): + policy = events.get_event_loop_policy() + self.assertIsInstance(policy, events.EventLoopPolicy) + self.assertIs(policy, events.get_event_loop_policy()) + + def test_set_event_loop_policy(self): + self.assertRaises( + AssertionError, events.set_event_loop_policy, object()) + + old_policy = events.get_event_loop_policy() + + policy = events.DefaultEventLoopPolicy() + events.set_event_loop_policy(policy) + self.assertIs(policy, events.get_event_loop_policy()) + self.assertIsNot(policy, old_policy) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/futures_test.py b/tests/futures_test.py new file mode 100644 index 0000000..5569cca --- /dev/null +++ b/tests/futures_test.py @@ -0,0 +1,222 @@ +"""Tests for futures.py.""" + +import unittest + +from tulip import futures + + +def _fakefunc(f): + return f + + +class FutureTests(unittest.TestCase): + + def test_initial_state(self): + f = futures.Future() + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertFalse(f.done()) + + def test_init_event_loop_positional(self): + # Make sure Future does't accept a positional argument + self.assertRaises(TypeError, futures.Future, 42) + + def test_cancel(self): + f = futures.Future() + self.assertTrue(f.cancel()) + self.assertTrue(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(futures.CancelledError, f.result) + self.assertRaises(futures.CancelledError, f.exception) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_result(self): + f = futures.Future() + self.assertRaises(futures.InvalidStateError, f.result) + self.assertRaises(futures.InvalidTimeoutError, f.result, 10) + + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_exception(self): + exc = RuntimeError() + f = futures.Future() + self.assertRaises(futures.InvalidStateError, f.exception) + self.assertRaises(futures.InvalidTimeoutError, f.exception, 10) + + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_yield_from_twice(self): + f = futures.Future() + + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + + def test_repr(self): + f_pending = futures.Future() + self.assertEqual(repr(f_pending), 'Future') + + f_cancelled = futures.Future() + f_cancelled.cancel() + self.assertEqual(repr(f_cancelled), 'Future') + + f_result = futures.Future() + f_result.set_result(4) + self.assertEqual(repr(f_result), 'Future') + + f_exception = futures.Future() + f_exception.set_exception(RuntimeError()) + self.assertEqual(repr(f_exception), 'Future') + + f_few_callbacks = futures.Future() + f_few_callbacks.add_done_callback(_fakefunc) + self.assertIn('Future', r) + + def test_copy_state(self): + # Test the internal _copy_state method since it's being directly + # invoked in other modules. + f = futures.Future() + f.set_result(10) + + newf = futures.Future() + newf._copy_state(f) + self.assertTrue(newf.done()) + self.assertEqual(newf.result(), 10) + + f_exception = futures.Future() + f_exception.set_exception(RuntimeError()) + + newf_exception = futures.Future() + newf_exception._copy_state(f_exception) + self.assertTrue(newf_exception.done()) + self.assertRaises(RuntimeError, newf_exception.result) + + f_cancelled = futures.Future() + f_cancelled.cancel() + + newf_cancelled = futures.Future() + newf_cancelled._copy_state(f_cancelled) + self.assertTrue(newf_cancelled.cancelled()) + + def test_iter(self): + def coro(): + fut = futures.Future() + yield from fut + + def test(): + arg1, arg2 = coro() + + self.assertRaises(AssertionError, test) + + +# A fake event loop for tests. All it does is implement a call_soon method +# that immediately invokes the given function. +class _FakeEventLoop: + def call_soon(self, fn, future): + fn(future) + + +class FutureDoneCallbackTests(unittest.TestCase): + + def _make_callback(self, bag, thing): + # Create a callback function that appends thing to bag. + def bag_appender(future): + bag.append(thing) + return bag_appender + + def _new_future(self): + return futures.Future(event_loop=_FakeEventLoop()) + + def test_callbacks_invoked_on_set_result(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 42)) + f.add_done_callback(self._make_callback(bag, 17)) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def test_callbacks_invoked_on_set_exception(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 100)) + + self.assertEqual(bag, []) + exc = RuntimeError() + f.set_exception(exc) + self.assertEqual(bag, [100]) + self.assertEqual(f.exception(), exc) + + def test_remove_done_callback(self): + bag = [] + f = self._new_future() + cb1 = self._make_callback(bag, 1) + cb2 = self._make_callback(bag, 2) + cb3 = self._make_callback(bag, 3) + + # Add one cb1 and one cb2. + f.add_done_callback(cb1) + f.add_done_callback(cb2) + + # One instance of cb2 removed. Now there's only one cb1. + self.assertEqual(f.remove_done_callback(cb2), 1) + + # Never had any cb3 in there. + self.assertEqual(f.remove_done_callback(cb3), 0) + + # After this there will be 6 instances of cb1 and one of cb2. + f.add_done_callback(cb2) + for i in range(5): + f.add_done_callback(cb1) + + # Remove all instances of cb1. One cb2 remains. + self.assertEqual(f.remove_done_callback(cb1), 6) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [2]) + self.assertEqual(f.result(), 'foo') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py new file mode 100644 index 0000000..74aef7c --- /dev/null +++ b/tests/http_protocol_test.py @@ -0,0 +1,972 @@ +"""Tests for http/protocol.py""" + +import http.client +import unittest +import unittest.mock +import zlib + +import tulip +from tulip.http import protocol +from tulip.test_utils import LogTrackingTestCase + + +class HttpStreamReaderTests(LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.suppress_log_errors() + + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + self.transport = unittest.mock.Mock() + self.stream = protocol.HttpStreamReader() + + def tearDown(self): + self.loop.close() + super().tearDown() + + def test_request_line(self): + self.stream.feed_data(b'get /path HTTP/1.1\r\n') + self.assertEqual( + ('GET', '/path', (1, 1)), + self.loop.run_until_complete(self.stream.read_request_line())) + + def test_request_line_two_slashes(self): + self.stream.feed_data(b'get //path HTTP/1.1\r\n') + self.assertEqual( + ('GET', '//path', (1, 1)), + self.loop.run_until_complete(self.stream.read_request_line())) + + def test_request_line_non_ascii(self): + self.stream.feed_data(b'get /path\xd0\xb0 HTTP/1.1\r\n') + + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete(self.stream.read_request_line()) + + self.assertEqual( + b'get /path\xd0\xb0 HTTP/1.1\r\n', cm.exception.args[0]) + + def test_request_line_bad_status_line(self): + self.stream.feed_data(b'\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_request_line()) + + def test_request_line_bad_method(self): + self.stream.feed_data(b'!12%()+=~$ /get HTTP/1.1\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_request_line()) + + def test_request_line_bad_version(self): + self.stream.feed_data(b'GET //get HT/11\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_request_line()) + + def test_response_status_bad_status_line(self): + self.stream.feed_data(b'\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_response_status()) + + def test_response_status_bad_status_line_eof(self): + self.stream.feed_eof() + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_response_status()) + + def test_response_status_bad_status_non_ascii(self): + self.stream.feed_data(b'HTTP/1.1 200 \xd0\xb0\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete(self.stream.read_response_status()) + + self.assertEqual(b'HTTP/1.1 200 \xd0\xb0\r\n', cm.exception.args[0]) + + def test_response_status_bad_version(self): + self.stream.feed_data(b'HT/11 200 Ok\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete(self.stream.read_response_status()) + + self.assertEqual('HT/11 200 Ok', cm.exception.args[0]) + + def test_response_status_no_reason(self): + self.stream.feed_data(b'HTTP/1.1 200\r\n') + + v, s, r = self.loop.run_until_complete( + self.stream.read_response_status()) + self.assertEqual(v, (1, 1)) + self.assertEqual(s, 200) + self.assertEqual(r, '') + + def test_response_status_bad(self): + self.stream.feed_data(b'HTT/1\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + self.stream.read_response_status()) + + self.assertIn('HTT/1', str(cm.exception)) + + def test_response_status_bad_code_under_100(self): + self.stream.feed_data(b'HTTP/1.1 99 test\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + self.stream.read_response_status()) + + self.assertIn('HTTP/1.1 99 test', str(cm.exception)) + + def test_response_status_bad_code_above_999(self): + self.stream.feed_data(b'HTTP/1.1 9999 test\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + self.stream.read_response_status()) + + self.assertIn('HTTP/1.1 9999 test', str(cm.exception)) + + def test_response_status_bad_code_not_int(self): + self.stream.feed_data(b'HTTP/1.1 ttt test\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + self.stream.read_response_status()) + + self.assertIn('HTTP/1.1 ttt test', str(cm.exception)) + + def test_read_headers(self): + self.stream.feed_data(b'test: line\r\n' + b' continue\r\n' + b'test2: data\r\n' + b'\r\n') + + headers = self.loop.run_until_complete(self.stream.read_headers()) + self.assertEqual(headers, + [('TEST', 'line\r\n continue'), ('TEST2', 'data')]) + + def test_read_headers_size(self): + self.stream.feed_data(b'test: line\r\n') + self.stream.feed_data(b' continue\r\n') + self.stream.feed_data(b'test2: data\r\n') + self.stream.feed_data(b'\r\n') + + self.stream.MAX_HEADERS = 5 + self.assertRaises( + http.client.LineTooLong, + self.loop.run_until_complete, + self.stream.read_headers()) + + def test_read_headers_invalid_header(self): + self.stream.feed_data(b'test line\r\n') + + with self.assertRaises(ValueError) as cm: + self.loop.run_until_complete(self.stream.read_headers()) + + self.assertIn("Invalid header b'test line'", str(cm.exception)) + + def test_read_headers_invalid_name(self): + self.stream.feed_data(b'test[]: line\r\n') + + with self.assertRaises(ValueError) as cm: + self.loop.run_until_complete(self.stream.read_headers()) + + self.assertIn("Invalid header name b'TEST[]'", str(cm.exception)) + + def test_read_headers_headers_size(self): + self.stream.MAX_HEADERFIELD_SIZE = 5 + self.stream.feed_data(b'test: line data data\r\ndata\r\n') + + with self.assertRaises(http.client.LineTooLong) as cm: + self.loop.run_until_complete(self.stream.read_headers()) + + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_read_headers_continuation_headers_size(self): + self.stream.MAX_HEADERFIELD_SIZE = 5 + self.stream.feed_data(b'test: line\r\n test\r\n') + + with self.assertRaises(http.client.LineTooLong) as cm: + self.loop.run_until_complete(self.stream.read_headers()) + + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_read_message_should_close(self): + self.stream.feed_data( + b'Host: example.com\r\nConnection: close\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + self.assertTrue(msg.should_close) + + def test_read_message_should_close_http11(self): + self.stream.feed_data( + b'Host: example.com\r\n\r\n') + + msg = self.loop.run_until_complete( + self.stream.read_message(version=(1, 1))) + self.assertFalse(msg.should_close) + + def test_read_message_should_close_http10(self): + self.stream.feed_data( + b'Host: example.com\r\n\r\n') + + msg = self.loop.run_until_complete( + self.stream.read_message(version=(1, 0))) + self.assertTrue(msg.should_close) + + def test_read_message_should_close_keep_alive(self): + self.stream.feed_data( + b'Host: example.com\r\nConnection: keep-alive\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + self.assertFalse(msg.should_close) + + def test_read_message_content_length_broken(self): + self.stream.feed_data( + b'Host: example.com\r\nContent-Length: qwe\r\n\r\n') + + self.assertRaises( + http.client.HTTPException, + self.loop.run_until_complete, + self.stream.read_message()) + + def test_read_message_content_length_wrong(self): + self.stream.feed_data( + b'Host: example.com\r\nContent-Length: -1\r\n\r\n') + + self.assertRaises( + http.client.HTTPException, + self.loop.run_until_complete, + self.stream.read_message()) + + def test_read_message_content_length(self): + self.stream.feed_data( + b'Host: example.com\r\nContent-Length: 2\r\n\r\n12') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'12', payload) + + def test_read_message_content_length_no_val(self): + self.stream.feed_data(b'Host: example.com\r\n\r\n12') + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=False)) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'', payload) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_read_message_deflate(self): + self.stream.feed_data( + ('Host: example.com\r\nContent-Length: %s\r\n' + 'Content-Encoding: deflate\r\n\r\n' % + len(self._COMPRESSED)).encode()) + self.stream.feed_data(self._COMPRESSED) + + msg = self.loop.run_until_complete(self.stream.read_message()) + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'data', payload) + + def test_read_message_deflate_disabled(self): + self.stream.feed_data( + ('Host: example.com\r\nContent-Encoding: deflate\r\n' + 'Content-Length: %s\r\n\r\n' % + len(self._COMPRESSED)).encode()) + self.stream.feed_data(self._COMPRESSED) + + msg = self.loop.run_until_complete( + self.stream.read_message(compression=False)) + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(self._COMPRESSED, payload) + + def test_read_message_deflate_unknown(self): + self.stream.feed_data( + ('Host: example.com\r\nContent-Encoding: compress\r\n' + 'Content-Length: %s\r\n\r\n' % len(self._COMPRESSED)).encode()) + self.stream.feed_data(self._COMPRESSED) + + msg = self.loop.run_until_complete( + self.stream.read_message(compression=False)) + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(self._COMPRESSED, payload) + + def test_read_message_websocket(self): + self.stream.feed_data( + b'Host: example.com\r\nSec-Websocket-Key1: 13\r\n\r\n1234567890') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'12345678', payload) + + def test_read_message_chunked(self): + self.stream.feed_data( + b'Host: example.com\r\nTransfer-Encoding: chunked\r\n\r\n') + self.stream.feed_data( + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'dataline', payload) + + def test_read_message_readall_eof(self): + self.stream.feed_data( + b'Host: example.com\r\n\r\n') + self.stream.feed_data(b'data') + self.stream.feed_data(b'line') + self.stream.feed_eof() + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'dataline', payload) + + def test_read_message_payload(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 8\r\n\r\n') + self.stream.feed_data(b'data') + self.stream.feed_data(b'data') + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + data = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'datadata', data) + + def test_read_message_payload_eof(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 4\r\n\r\n') + self.stream.feed_data(b'da') + self.stream.feed_eof() + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, msg.payload.read()) + + def test_read_message_length_payload_zero(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 0\r\n\r\n') + self.stream.feed_data(b'data') + + msg = self.loop.run_until_complete(self.stream.read_message()) + data = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'', data) + + def test_read_message_length_payload_incomplete(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 8\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data(b'data') + self.stream.feed_eof() + return (yield from msg.payload.read()) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, coro()) + + def test_read_message_eof_payload(self): + self.stream.feed_data(b'Host: example.com\r\n\r\n') + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + def coro(): + self.stream.feed_data(b'data') + self.stream.feed_eof() + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'data', data) + + def test_read_message_length_payload(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 4\r\n\r\n') + self.stream.feed_data(b'da') + self.stream.feed_data(b't') + self.stream.feed_data(b'ali') + self.stream.feed_data(b'ne') + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + self.assertIsInstance(msg.payload, tulip.StreamReader) + + data = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'data', data) + self.assertEqual(b'line', b''.join(self.stream.buffer)) + + def test_read_message_length_payload_extra(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 4\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data(b'da') + self.stream.feed_data(b't') + self.stream.feed_data(b'ali') + self.stream.feed_data(b'ne') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'data', data) + self.assertEqual(b'line', b''.join(self.stream.buffer)) + + def test_parse_length_payload_eof_exc(self): + parser = self.stream._parse_length_payload(4) + next(parser) + + stream = tulip.StreamReader() + parser.send(stream) + self.stream._parser = parser + self.stream.feed_data(b'da') + + def eof(): + yield from [] + self.stream.feed_eof() + + t1 = tulip.Task(stream.read()) + t2 = tulip.Task(eof()) + + self.loop.run_until_complete(tulip.wait([t1, t2])) + self.assertRaises(http.client.IncompleteRead, t1.result) + self.assertIsNone(self.stream._parser) + + def test_read_message_deflate_payload(self): + comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + + data = b''.join([comp.compress(b'data'), comp.flush()]) + + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Encoding: deflate\r\n' + + ('Content-Length: %s\r\n\r\n' % len(data)).encode()) + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + def coro(): + self.stream.feed_data(data) + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'data', data) + + def test_read_message_chunked_payload(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data( + b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'dataline', data) + + def test_read_message_chunked_payload_chunks(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data(b'4\r\ndata\r') + self.stream.feed_data(b'\n4') + self.stream.feed_data(b'\r') + self.stream.feed_data(b'\n') + self.stream.feed_data(b'line\r\n0\r\n') + self.stream.feed_data(b'test\r\n\r\n') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'dataline', data) + + def test_read_message_chunked_payload_incomplete(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data(b'4\r\ndata\r\n') + self.stream.feed_eof() + return (yield from msg.payload.read()) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, coro()) + + def test_read_message_chunked_payload_extension(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data( + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'dataline', data) + + def test_read_message_chunked_payload_size_error(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data(b'blah\r\n') + return (yield from msg.payload.read()) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, coro()) + + def test_deflate_stream_set_exception(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + exc = ValueError() + dstream.set_exception(exc) + self.assertIs(exc, stream.exception()) + + def test_deflate_stream_feed_data(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + dstream.zlib = unittest.mock.Mock() + dstream.zlib.decompress.return_value = b'line' + + dstream.feed_data(b'data') + self.assertEqual([b'line'], list(stream.buffer)) + + def test_deflate_stream_feed_data_err(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + exc = ValueError() + dstream.zlib = unittest.mock.Mock() + dstream.zlib.decompress.side_effect = exc + + dstream.feed_data(b'data') + self.assertIsInstance(stream.exception(), http.client.IncompleteRead) + + def test_deflate_stream_feed_eof(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + dstream.zlib = unittest.mock.Mock() + dstream.zlib.flush.return_value = b'line' + + dstream.feed_eof() + self.assertEqual([b'line'], list(stream.buffer)) + self.assertTrue(stream.eof) + + def test_deflate_stream_feed_eof_err(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + dstream.zlib = unittest.mock.Mock() + dstream.zlib.flush.return_value = b'line' + dstream.zlib.eof = False + + dstream.feed_eof() + self.assertIsInstance(stream.exception(), http.client.IncompleteRead) + + +class HttpMessageTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + + def test_start_request(self): + msg = protocol.Request( + self.transport, 'GET', '/index.html', close=True) + + self.assertIs(msg.transport, self.transport) + self.assertIsNone(msg.status) + self.assertTrue(msg.closing) + self.assertEqual(msg.status_line, 'GET /index.html HTTP/1.1\r\n') + + def test_start_response(self): + msg = protocol.Response(self.transport, 200, close=True) + + self.assertIs(msg.transport, self.transport) + self.assertEqual(msg.status, 200) + self.assertTrue(msg.closing) + self.assertEqual(msg.status_line, 'HTTP/1.1 200 OK\r\n') + + def test_force_close(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.closing) + msg.force_close() + self.assertTrue(msg.closing) + + def test_force_chunked(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.chunked) + msg.force_chunked() + self.assertTrue(msg.chunked) + + def test_keep_alive(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.keep_alive()) + msg.keepalive = True + self.assertTrue(msg.keep_alive()) + + msg.force_close() + self.assertFalse(msg.keep_alive()) + + def test_add_header(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], msg.headers) + + msg.add_header('content-type', 'plain/html') + self.assertEqual([('CONTENT-TYPE', 'plain/html')], msg.headers) + + def test_add_headers(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], msg.headers) + + msg.add_headers(('content-type', 'plain/html')) + self.assertEqual([('CONTENT-TYPE', 'plain/html')], msg.headers) + + def test_add_headers_length(self): + msg = protocol.Response(self.transport, 200) + self.assertIsNone(msg.length) + + msg.add_headers(('content-length', '200')) + self.assertEqual(200, msg.length) + + def test_add_headers_upgrade(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.upgrade) + + msg.add_headers(('connection', 'upgrade')) + self.assertTrue(msg.upgrade) + + def test_add_headers_upgrade_websocket(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('upgrade', 'test')) + self.assertEqual([], msg.headers) + + msg.add_headers(('upgrade', 'websocket')) + self.assertEqual([('UPGRADE', 'websocket')], msg.headers) + + def test_add_headers_connection_keepalive(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('connection', 'keep-alive')) + self.assertEqual([], msg.headers) + self.assertTrue(msg.keepalive) + + msg.add_headers(('connection', 'close')) + self.assertFalse(msg.keepalive) + + def test_add_headers_hop_headers(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('connection', 'test'), ('transfer-encoding', 't')) + self.assertEqual([], msg.headers) + + def test_default_headers(self): + msg = protocol.Response(self.transport, 200) + + headers = [r for r, _ in msg._default_headers()] + self.assertIn('DATE', headers) + self.assertIn('CONNECTION', headers) + + def test_default_headers_server(self): + msg = protocol.Response(self.transport, 200) + + headers = [r for r, _ in msg._default_headers()] + self.assertIn('SERVER', headers) + + def test_default_headers_useragent(self): + msg = protocol.Request(self.transport, 'GET', '/') + + headers = [r for r, _ in msg._default_headers()] + self.assertNotIn('SERVER', headers) + self.assertIn('USER-AGENT', headers) + + def test_default_headers_chunked(self): + msg = protocol.Response(self.transport, 200) + + headers = [r for r, _ in msg._default_headers()] + self.assertNotIn('TRANSFER-ENCODING', headers) + + msg.force_chunked() + + headers = [r for r, _ in msg._default_headers()] + self.assertIn('TRANSFER-ENCODING', headers) + + def test_default_headers_connection_upgrade(self): + msg = protocol.Response(self.transport, 200) + msg.upgrade = True + + headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'upgrade')], headers) + + def test_default_headers_connection_close(self): + msg = protocol.Response(self.transport, 200) + msg.force_close() + + headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'close')], headers) + + def test_default_headers_connection_keep_alive(self): + msg = protocol.Response(self.transport, 200) + msg.keepalive = True + + headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'keep-alive')], headers) + + def test_send_headers(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-type', 'plain/html')) + self.assertFalse(msg.is_headers_sent()) + + msg.send_headers() + + content = b''.join([arg[1][0] for arg in list(write.mock_calls)]) + + self.assertTrue(content.startswith(b'HTTP/1.1 200 OK\r\n')) + self.assertIn(b'CONTENT-TYPE: plain/html', content) + self.assertTrue(msg.headers_sent) + self.assertTrue(msg.is_headers_sent()) + + def test_send_headers_nomore_add(self): + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-type', 'plain/html')) + msg.send_headers() + + self.assertRaises(AssertionError, + msg.add_header, 'content-type', 'plain/html') + + def test_prepare_length(self): + msg = protocol.Response(self.transport, 200) + length = msg._write_length_payload = unittest.mock.Mock() + length.return_value = iter([1, 2, 3]) + + msg.add_headers(('content-length', '200')) + msg.send_headers() + + self.assertTrue(length.called) + self.assertTrue((200,), length.call_args[0]) + + def test_prepare_chunked_force(self): + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + + chunked = msg._write_chunked_payload = unittest.mock.Mock() + chunked.return_value = iter([1, 2, 3]) + + msg.add_headers(('content-length', '200')) + msg.send_headers() + self.assertTrue(chunked.called) + + def test_prepare_chunked_no_length(self): + msg = protocol.Response(self.transport, 200) + + chunked = msg._write_chunked_payload = unittest.mock.Mock() + chunked.return_value = iter([1, 2, 3]) + + msg.send_headers() + self.assertTrue(chunked.called) + + def test_prepare_eof(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + + eof = msg._write_eof_payload = unittest.mock.Mock() + eof.return_value = iter([1, 2, 3]) + + msg.send_headers() + self.assertTrue(eof.called) + + def test_write_auto_send_headers(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg._send_headers = True + + msg.write(b'data1') + self.assertTrue(msg.headers_sent) + + def test_write_payload_eof(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg.send_headers() + + msg.write(b'data1') + self.assertTrue(msg.headers_sent) + + msg.write(b'data2') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'data1data2', content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg.send_headers() + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'4\r\ndata\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_multiple(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg.send_headers() + + msg.write(b'data1') + msg.write(b'data2') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'5\r\ndata1\r\n5\r\ndata2\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_length(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '2')) + msg.send_headers() + + msg.write(b'd') + msg.write(b'ata') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'da', content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_filter(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(2) + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith(b'2\r\nda\r\n2\r\nta\r\n0\r\n\r\n')) + + def test_write_payload_chunked_filter_mutiple_chunks(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(2) + msg.write(b'data1') + msg.write(b'data2') + msg.write_eof() + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith( + b'2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n' + b'2\r\na2\r\n0\r\n\r\n')) + + def test_write_payload_chunked_large_chunk(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(1024) + msg.write(b'data') + msg.write_eof() + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith(b'4\r\ndata\r\n0\r\n\r\n')) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_write_payload_deflate_filter(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '%s' % len(self._COMPRESSED))) + msg.send_headers() + + msg.add_compression_filter('deflate') + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_deflate_and_chunked(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_compression_filter('deflate') + msg.add_chunking_filter(2) + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'2\r\nKI\r\n2\r\n,I\r\n2\r\n\x04\x00\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_and_deflate(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '%s' % len(self._COMPRESSED))) + + msg.add_chunking_filter(2) + msg.add_compression_filter('deflate') + msg.send_headers() + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) diff --git a/tests/http_server_test.py b/tests/http_server_test.py new file mode 100644 index 0000000..dc55eff --- /dev/null +++ b/tests/http_server_test.py @@ -0,0 +1,242 @@ +"""Tests for http/server.py""" + +import unittest +import unittest.mock + +import tulip +from tulip.http import server +from tulip.http import errors +from tulip.test_utils import LogTrackingTestCase + + +class HttpServerProtocolTests(LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.suppress_log_errors() + + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + super().tearDown() + + def test_http_status_exception(self): + exc = errors.HttpStatusException(500, message='Internal error') + self.assertEqual(exc.code, 500) + self.assertEqual(exc.message, 'Internal error') + + def test_handle_request(self): + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + rline = unittest.mock.Mock() + rline.version = (1, 1) + message = unittest.mock.Mock() + srv.handle_request(rline, message) + + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertTrue(content.startswith(b'HTTP/1.1 404 Not Found\r\n')) + + def test_connection_made(self): + srv = server.ServerHttpProtocol() + self.assertIsNone(srv._request_handle) + + srv.connection_made(unittest.mock.Mock()) + self.assertIsNotNone(srv._request_handle) + + def test_data_received(self): + srv = server.ServerHttpProtocol() + srv.connection_made(unittest.mock.Mock()) + + srv.data_received(b'123') + self.assertEqual(b'123', b''.join(srv.stream.buffer)) + + srv.data_received(b'456') + self.assertEqual(b'123456', b''.join(srv.stream.buffer)) + + def test_eof_received(self): + srv = server.ServerHttpProtocol() + srv.connection_made(unittest.mock.Mock()) + srv.eof_received() + self.assertTrue(srv.stream.eof) + + def test_connection_lost(self): + srv = server.ServerHttpProtocol() + srv.connection_made(unittest.mock.Mock()) + srv.data_received(b'123') + + handle = srv._request_handle + srv.connection_lost(None) + + self.assertIsNone(srv._request_handle) + self.assertTrue(handle.cancelled()) + + srv.connection_lost(None) + self.assertIsNone(srv._request_handle) + + def test_close(self): + srv = server.ServerHttpProtocol() + self.assertFalse(srv.closing) + + srv.close() + self.assertTrue(srv.closing) + + def test_handle_error(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + srv.handle_error(404) + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertIn(b'HTTP/1.1 404 Not Found', content) + + @unittest.mock.patch('tulip.http.server.traceback') + def test_handle_error_traceback_exc(self, m_trace): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(debug=True) + srv.connection_made(transport) + + m_trace.format_exc.side_effect = ValueError + + srv.handle_error(500, exc=object()) + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertTrue( + content.startswith(b'HTTP/1.1 500 Internal Server Error')) + + def test_handle_error_debug(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.debug = True + srv.connection_made(transport) + + try: + raise ValueError() + except Exception as exc: + srv.handle_error(999, exc=exc) + + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + + self.assertIn(b'HTTP/1.1 500 Internal', content) + self.assertIn(b'Traceback (most recent call last):', content) + + def test_handle_error_500(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log) + srv.connection_made(transport) + + srv.handle_error(500) + self.assertTrue(log.exception.called) + + def test_handle(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + srv.close() + + self.loop.run_until_complete(srv._request_handle) + self.assertTrue(handle.called) + self.assertIsNone(srv._request_handle) + + def test_handle_coro(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + called = False + + @tulip.coroutine + def coro(rline, message): + nonlocal called + called = True + yield from [] + srv.eof_received() + + srv.handle_request = coro + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + self.loop.run_until_complete(srv._request_handle) + self.assertTrue(called) + + def test_handle_close(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + srv.close() + self.loop.run_until_complete(srv._request_handle) + + self.assertTrue(handle.called) + self.assertTrue(transport.close.called) + + def test_handle_cancel(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log, debug=True) + srv.connection_made(transport) + + srv.handle_request = unittest.mock.Mock() + + @tulip.task + def cancel(): + yield from [] + srv._request_handle.cancel() + + srv.close() + self.loop.run_until_complete( + tulip.wait([srv._request_handle, cancel()])) + self.assertTrue(log.debug.called) + + def test_handle_400(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + srv.handle_error = unittest.mock.Mock() + + def side_effect(*args): + srv.close() + srv.handle_error.side_effect = side_effect + + srv.stream.feed_data(b'GET / HT/asd\r\n') + + self.loop.run_until_complete(srv._request_handle) + self.assertTrue(srv.handle_error.called) + self.assertTrue(400, srv.handle_error.call_args[0][0]) + self.assertTrue(transport.close.called) + + def test_handle_500(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + handle.side_effect = ValueError + srv.handle_error = unittest.mock.Mock() + srv.close() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + self.loop.run_until_complete(srv._request_handle) + + self.assertTrue(srv.handle_error.called) + self.assertTrue(500, srv.handle_error.call_args[0][0]) diff --git a/tests/locks_test.py b/tests/locks_test.py new file mode 100644 index 0000000..7d2111d --- /dev/null +++ b/tests/locks_test.py @@ -0,0 +1,747 @@ +"""Tests for lock.py""" + +import time +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import locks +from tulip import tasks +from tulip import test_utils + + +class LockTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + lock = locks.Lock() + self.assertTrue(repr(lock).endswith('[unlocked]>')) + + def acquire_lock(): + yield from lock + + self.event_loop.run_until_complete(acquire_lock()) + self.assertTrue(repr(lock).endswith('[locked]>')) + + def test_lock(self): + lock = locks.Lock() + + def acquire_lock(): + return (yield from lock) + + res = self.event_loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_acquire(self): + lock = locks.Lock() + result = [] + + self.assertTrue( + self.event_loop.run_until_complete(lock.acquire())) + + @tasks.coroutine + def c1(result): + if (yield from lock.acquire()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from lock.acquire()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from lock.acquire()): + result.append(3) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + self.event_loop.run_once() + self.assertEqual([1], result) + + tasks.Task(c3(result)) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + + def test_acquire_timeout(self): + lock = locks.Lock() + self.assertTrue( + self.event_loop.run_until_complete(lock.acquire())) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete( + lock.acquire(timeout=0.1)) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + lock = locks.Lock() + self.event_loop.run_until_complete(lock.acquire()) + + self.event_loop.call_later(0.1, lock.release) + acquired = self.event_loop.run_until_complete(lock.acquire(10.1)) + self.assertTrue(acquired) + + def test_acquire_timeout_mixed(self): + lock = locks.Lock() + self.event_loop.run_until_complete(lock.acquire()) + tasks.Task(lock.acquire()) + tasks.Task(lock.acquire()) + acquire_task = tasks.Task(lock.acquire(0.1)) + tasks.Task(lock.acquire()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(lock._waiters)) + + def test_acquire_cancel(self): + lock = locks.Lock() + self.assertTrue( + self.event_loop.run_until_complete(lock.acquire())) + + task = tasks.Task(lock.acquire()) + self.event_loop.call_soon(task.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, task) + self.assertFalse(lock._waiters) + + def test_release_not_acquired(self): + lock = locks.Lock() + + self.assertRaises(RuntimeError, lock.release) + + def test_release_no_waiters(self): + lock = locks.Lock() + self.event_loop.run_until_complete(lock.acquire()) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_context_manager(self): + lock = locks.Lock() + + @tasks.task + def acquire_lock(): + return (yield from lock) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + def test_context_manager_no_yield(self): + lock = locks.Lock() + + try: + with lock: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + +class EventWaiterTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + ev = locks.EventWaiter() + self.assertTrue(repr(ev).endswith('[unset]>')) + + ev.set() + self.assertTrue(repr(ev).endswith('[set]>')) + + def test_wait(self): + ev = locks.EventWaiter() + self.assertFalse(ev.is_set()) + + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from ev.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from ev.wait()): + result.append(3) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + tasks.Task(c3(result)) + + ev.set() + self.event_loop.run_once() + self.assertEqual([3, 1, 2], result) + + def test_wait_on_set(self): + ev = locks.EventWaiter() + ev.set() + + res = self.event_loop.run_until_complete(ev.wait()) + self.assertTrue(res) + + def test_wait_timeout(self): + ev = locks.EventWaiter() + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(ev.wait(0.1)) + self.assertFalse(res) + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + ev = locks.EventWaiter() + self.event_loop.call_later(0.1, ev.set) + acquired = self.event_loop.run_until_complete(ev.wait(10.1)) + self.assertTrue(acquired) + + def test_wait_timeout_mixed(self): + ev = locks.EventWaiter() + tasks.Task(ev.wait()) + tasks.Task(ev.wait()) + acquire_task = tasks.Task(ev.wait(0.1)) + tasks.Task(ev.wait()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(ev._waiters)) + + def test_wait_cancel(self): + ev = locks.EventWaiter() + + wait = tasks.Task(ev.wait()) + self.event_loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, wait) + self.assertFalse(ev._waiters) + + def test_clear(self): + ev = locks.EventWaiter() + self.assertFalse(ev.is_set()) + + ev.set() + self.assertTrue(ev.is_set()) + + ev.clear() + self.assertFalse(ev.is_set()) + + def test_clear_with_waiters(self): + ev = locks.EventWaiter() + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + tasks.Task(c1(result)) + self.event_loop.run_once() + self.assertEqual([], result) + + ev.set() + ev.clear() + self.assertFalse(ev.is_set()) + + ev.set() + ev.set() + self.assertEqual(1, len(ev._waiters)) + + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertEqual(0, len(ev._waiters)) + + +class ConditionTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_wait(self): + cond = locks.Condition() + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue( + self.event_loop.run_until_complete(cond.acquire())) + cond.notify() + self.event_loop.run_once() + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + def test_wait_timeout(self): + cond = locks.Condition() + self.event_loop.run_until_complete(cond.acquire()) + + t0 = time.monotonic() + wait = self.event_loop.run_until_complete(cond.wait(0.1)) + self.assertFalse(wait) + self.assertTrue(cond.locked()) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + def test_wait_cancel(self): + cond = locks.Condition() + self.event_loop.run_until_complete(cond.acquire()) + + wait = tasks.Task(cond.wait()) + self.event_loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, wait) + self.assertFalse(cond._condition_waiters) + self.assertTrue(cond.locked()) + + def test_wait_unacquired(self): + self.suppress_log_errors() + + cond = locks.Condition() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, cond.wait()) + + def test_wait_for(self): + cond = locks.Condition() + presult = False + + def predicate(): + return presult + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate)): + result.append(1) + cond.release() + + tasks.Task(c1(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([], result) + + presult = True + self.event_loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + def test_wait_for_timeout(self): + cond = locks.Condition() + + result = [] + + predicate = unittest.mock.Mock() + predicate.return_value = False + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate, 0.2)): + result.append(1) + else: + result.append(2) + cond.release() + + wait_for = tasks.Task(c1(result)) + + t0 = time.monotonic() + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(wait_for) + self.assertEqual([2], result) + self.assertEqual(3, predicate.call_count) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.18 < total_time < 0.22) + + def test_wait_for_unacquired(self): + self.suppress_log_errors() + + cond = locks.Condition() + + # predicate can return true immediately + res = self.event_loop.run_until_complete( + cond.wait_for(lambda: [1, 2, 3])) + self.assertEqual([1, 2, 3], res) + + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, + cond.wait_for(lambda: False)) + + def test_notify(self): + cond = locks.Condition() + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + cond.release() + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.notify(2048) + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + + def test_notify_all(self): + cond = locks.Condition() + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify_all() + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + + def test_notify_unacquired(self): + cond = locks.Condition() + self.assertRaises(RuntimeError, cond.notify) + + def test_notify_all_unacquired(self): + cond = locks.Condition() + self.assertRaises(RuntimeError, cond.notify_all) + + +class SemaphoreTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + sem = locks.Semaphore() + self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) + + self.event_loop.run_until_complete(sem.acquire()) + self.assertTrue(repr(sem).endswith('[locked]>')) + + def test_semaphore(self): + sem = locks.Semaphore() + self.assertEqual(1, sem._value) + + @tasks.task + def acquire_lock(): + return (yield from sem) + + res = self.event_loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(sem.locked()) + self.assertEqual(0, sem._value) + + sem.release() + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + def test_semaphore_value(self): + self.assertRaises(ValueError, locks.Semaphore, -1) + + def test_acquire(self): + sem = locks.Semaphore(3) + result = [] + + self.assertTrue( + self.event_loop.run_until_complete(sem.acquire())) + self.assertTrue( + self.event_loop.run_until_complete(sem.acquire())) + self.assertFalse(sem.locked()) + + @tasks.coroutine + def c1(result): + yield from sem.acquire() + result.append(1) + + @tasks.coroutine + def c2(result): + yield from sem.acquire() + result.append(2) + + @tasks.coroutine + def c3(result): + yield from sem.acquire() + result.append(3) + + @tasks.coroutine + def c4(result): + yield from sem.acquire() + result.append(4) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(sem.locked()) + self.assertEqual(2, len(sem._waiters)) + self.assertEqual(0, sem._value) + + tasks.Task(c4(result)) + + sem.release() + sem.release() + self.assertEqual(2, sem._value) + + self.event_loop.run_once() + self.assertEqual(0, sem._value) + self.assertEqual([1, 2, 3], result) + self.assertTrue(sem.locked()) + self.assertEqual(1, len(sem._waiters)) + self.assertEqual(0, sem._value) + + def test_acquire_timeout(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(sem.acquire(0.1)) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + + self.event_loop.call_later(0.1, sem.release) + acquired = self.event_loop.run_until_complete(sem.acquire(10.1)) + self.assertTrue(acquired) + + def test_acquire_timeout_mixed(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + tasks.Task(sem.acquire()) + tasks.Task(sem.acquire()) + acquire_task = tasks.Task(sem.acquire(0.1)) + tasks.Task(sem.acquire()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(sem._waiters)) + + def test_acquire_cancel(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + + acquire = tasks.Task(sem.acquire()) + self.event_loop.call_soon(acquire.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, acquire) + self.assertFalse(sem._waiters) + + def test_release_not_acquired(self): + sem = locks.Semaphore(bound=True) + + self.assertRaises(ValueError, sem.release) + + def test_release_no_waiters(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + self.assertTrue(sem.locked()) + + sem.release() + self.assertFalse(sem.locked()) + + def test_context_manager(self): + sem = locks.Semaphore(2) + + @tasks.task + def acquire_lock(): + return (yield from sem) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertTrue(sem.locked()) + + self.assertEqual(2, sem._value) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/queues_test.py b/tests/queues_test.py new file mode 100644 index 0000000..d86abd7 --- /dev/null +++ b/tests/queues_test.py @@ -0,0 +1,370 @@ +"""Tests for queues.py""" + +import unittest +import queue + +from tulip import events +from tulip import locks +from tulip import queues +from tulip import tasks + + +class _QueueTestBase(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + +class QueueBasicTests(_QueueTestBase): + + def _test_repr_or_str(self, fn, expect_id): + """Test Queue's repr or str. + + fn is repr or str. expect_id is True if we expect the Queue's id to + appear in fn(Queue()). + """ + q = queues.Queue() + self.assertTrue(fn(q).startswith('", repr(key)) + + def test_register(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ) + self.assertIsInstance(key, selectors.SelectorKey) + self.assertEqual(key.fd, 10) + self.assertIs(key, s._fd_to_key[10]) + + def test_register_unknown_event(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.register, unittest.mock.Mock(), 999999) + + def test_register_already_registered(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + s.register(fobj, selectors.EVENT_READ) + self.assertRaises(ValueError, s.register, fobj, selectors.EVENT_READ) + + def test_unregister(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + s.register(fobj, selectors.EVENT_READ) + s.unregister(fobj) + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_unregister_unknown(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.unregister, unittest.mock.Mock()) + + def test_modify_unknown(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.modify, unittest.mock.Mock(), 1) + + def test_modify(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ) + key2 = s.modify(fobj, selectors.EVENT_WRITE) + self.assertNotEqual(key.events, key2.events) + self.assertEqual((selectors.EVENT_WRITE, None), s.get_info(fobj)) + + def test_modify_data(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + d1 = object() + d2 = object() + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ, d1) + key2 = s.modify(fobj, selectors.EVENT_READ, d2) + self.assertEqual(key.events, key2.events) + self.assertNotEqual(key.data, key2.data) + self.assertEqual((selectors.EVENT_READ, d2), s.get_info(fobj)) + + def test_modify_same(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + data = object() + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ, data) + key2 = s.modify(fobj, selectors.EVENT_READ, data) + self.assertIs(key, key2) + + def test_select(self): + s = selectors._BaseSelector() + self.assertRaises(NotImplementedError, s.select) + + def test_close(self): + s = selectors._BaseSelector() + s.register(1, selectors.EVENT_READ) + + s.close() + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_registered_count(self): + s = selectors._BaseSelector() + self.assertEqual(0, s.registered_count()) + + s.register(1, selectors.EVENT_READ) + self.assertEqual(1, s.registered_count()) + + s.unregister(1) + self.assertEqual(0, s.registered_count()) + + def test_context_manager(self): + s = selectors._BaseSelector() + + with s as sel: + sel.register(1, selectors.EVENT_READ) + + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_key_from_fd(self): + s = selectors._BaseSelector() + key = s.register(1, selectors.EVENT_READ) + + self.assertIs(key, s._key_from_fd(1)) + self.assertIsNone(s._key_from_fd(10)) diff --git a/tests/streams_test.py b/tests/streams_test.py new file mode 100644 index 0000000..dc6eeaf --- /dev/null +++ b/tests/streams_test.py @@ -0,0 +1,299 @@ +"""Tests for streams.py.""" + +import unittest + +from tulip import events +from tulip import streams +from tulip import tasks +from tulip import test_utils + + +class StreamReaderTests(test_utils.LogTrackingTestCase): + + DATA = b'line1\nline2\nline3\n' + + def setUp(self): + super().setUp() + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + super().tearDown() + self.event_loop.close() + + def test_feed_empty_data(self): + stream = streams.StreamReader() + + stream.feed_data(b'') + self.assertEqual(0, stream.byte_count) + + def test_feed_data_byte_count(self): + stream = streams.StreamReader() + + stream.feed_data(self.DATA) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read_zero(self): + # Read zero bytes. + stream = streams.StreamReader() + stream.feed_data(self.DATA) + + data = self.event_loop.run_until_complete(stream.read(0)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read(self): + # Read bytes. + stream = streams.StreamReader() + read_task = tasks.Task(stream.read(30)) + + def cb(): + stream.feed_data(self.DATA) + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + + def test_read_line_breaks(self): + # Read bytes without line breaks. + stream = streams.StreamReader() + stream.feed_data(b'line1') + stream.feed_data(b'line2') + + data = self.event_loop.run_until_complete(stream.read(5)) + + self.assertEqual(b'line1', data) + self.assertEqual(5, stream.byte_count) + + def test_read_eof(self): + # Read bytes, stop at eof. + stream = streams.StreamReader() + read_task = tasks.Task(stream.read(1024)) + + def cb(): + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'', data) + self.assertFalse(stream.byte_count) + + def test_read_until_eof(self): + # Read all bytes until eof. + stream = streams.StreamReader() + read_task = tasks.Task(stream.read(-1)) + + def cb(): + stream.feed_data(b'chunk1\n') + stream.feed_data(b'chunk2') + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'chunk1\nchunk2', data) + self.assertFalse(stream.byte_count) + + def test_read_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete(stream.read(2)) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, stream.read(2)) + + def test_readline(self): + # Read one line. + stream = streams.StreamReader() + stream.feed_data(b'chunk1 ') + read_task = tasks.Task(stream.readline()) + + def cb(): + stream.feed_data(b'chunk2 ') + stream.feed_data(b'chunk3 ') + stream.feed_data(b'\n chunk4') + self.event_loop.call_soon(cb) + + line = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) + self.assertEqual(len(b'\n chunk4')-1, stream.byte_count) + + def test_readline_limit_with_existing_data(self): + self.suppress_log_errors() + + stream = streams.StreamReader(3) + stream.feed_data(b'li') + stream.feed_data(b'ne1\nline2\n') + + self.assertRaises( + ValueError, self.event_loop.run_until_complete, stream.readline()) + self.assertEqual([b'line2\n'], list(stream.buffer)) + + stream = streams.StreamReader(3) + stream.feed_data(b'li') + stream.feed_data(b'ne1') + stream.feed_data(b'li') + + self.assertRaises( + ValueError, self.event_loop.run_until_complete, stream.readline()) + self.assertEqual([b'li'], list(stream.buffer)) + self.assertEqual(2, stream.byte_count) + + def test_readline_limit(self): + self.suppress_log_errors() + + stream = streams.StreamReader(7) + + def cb(): + stream.feed_data(b'chunk1') + stream.feed_data(b'chunk2') + stream.feed_data(b'chunk3\n') + stream.feed_eof() + self.event_loop.call_soon(cb) + + self.assertRaises( + ValueError, self.event_loop.run_until_complete, stream.readline()) + self.assertEqual([b'chunk3\n'], list(stream.buffer)) + self.assertEqual(7, stream.byte_count) + + def test_readline_line_byte_count(self): + stream = streams.StreamReader() + stream.feed_data(self.DATA[:6]) + stream.feed_data(self.DATA[6:]) + + line = self.event_loop.run_until_complete(stream.readline()) + + self.assertEqual(b'line1\n', line) + self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) + + def test_readline_eof(self): + stream = streams.StreamReader() + stream.feed_data(b'some data') + stream.feed_eof() + + line = self.event_loop.run_until_complete(stream.readline()) + self.assertEqual(b'some data', line) + + def test_readline_empty_eof(self): + stream = streams.StreamReader() + stream.feed_eof() + + line = self.event_loop.run_until_complete(stream.readline()) + self.assertEqual(b'', line) + + def test_readline_read_byte_count(self): + stream = streams.StreamReader() + stream.feed_data(self.DATA) + + self.event_loop.run_until_complete(stream.readline()) + + data = self.event_loop.run_until_complete(stream.read(7)) + + self.assertEqual(b'line2\nl', data) + self.assertEqual( + len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), + stream.byte_count) + + def test_readline_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete(stream.readline()) + self.assertEqual(b'line\n', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, stream.readline()) + + def test_readexactly_zero_or_less(self): + # Read exact number of bytes (zero or less). + stream = streams.StreamReader() + stream.feed_data(self.DATA) + + data = self.event_loop.run_until_complete(stream.readexactly(0)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + data = self.event_loop.run_until_complete(stream.readexactly(-1)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readexactly(self): + # Read exact number of bytes. + stream = streams.StreamReader() + + n = 2 * len(self.DATA) + read_task = tasks.Task(stream.readexactly(n)) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA + self.DATA, data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readexactly_eof(self): + # Read exact number of bytes (eof). + stream = streams.StreamReader() + n = 2 * len(self.DATA) + read_task = tasks.Task(stream.readexactly(n)) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + + def test_readexactly_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete(stream.readexactly(2)) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, + stream.readexactly(2)) + + def test_exception(self): + stream = streams.StreamReader() + self.assertIsNone(stream.exception()) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(stream.exception(), exc) + + def test_exception_waiter(self): + stream = streams.StreamReader() + + def set_err(): + yield from [] + stream.set_exception(ValueError()) + + def readline(): + yield from stream.readline() + + t1 = tasks.Task(stream.readline()) + t2 = tasks.Task(set_err()) + + self.event_loop.run_until_complete(tasks.wait([t1, t2])) + + self.assertRaises(ValueError, t1.result) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/subprocess_test.py b/tests/subprocess_test.py new file mode 100644 index 0000000..09aaed5 --- /dev/null +++ b/tests/subprocess_test.py @@ -0,0 +1,54 @@ +"""Tests for subprocess_transport.py.""" + +import logging +import unittest + +from tulip import events +from tulip import protocols +from tulip import subprocess_transport + + +class MyProto(protocols.Protocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write_eof() + + def data_received(self, data): + logging.info('received: %r', data) + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + + +class FutureTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_unix_subprocess(self): + p = MyProto() + subprocess_transport.UnixSubprocessTransport(p, ['/bin/ls', '-lR']) + self.event_loop.run() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/tasks_test.py b/tests/tasks_test.py new file mode 100644 index 0000000..9ac15bb --- /dev/null +++ b/tests/tasks_test.py @@ -0,0 +1,647 @@ +"""Tests for tasks.py.""" + +import concurrent.futures +import time +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import tasks +from tulip import test_utils + + +class Dummy: + + def __repr__(self): + return 'Dummy()' + + def __call__(self, *args): + pass + + +class TaskTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + super().tearDown() + + def test_task_class(self): + @tasks.coroutine + def notmuch(): + yield from [] + return 'ok' + t = tasks.Task(notmuch()) + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t._event_loop, self.event_loop) + + event_loop = events.new_event_loop() + t = tasks.Task(notmuch(), event_loop=event_loop) + self.assertIs(t._event_loop, event_loop) + + def test_task_decorator(self): + @tasks.task + def notmuch(): + yield from [] + return 'ko' + t = notmuch() + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + + def test_task_repr(self): + @tasks.task + def notmuch(): + yield from [] + return 'abc' + t = notmuch() + t.add_done_callback(Dummy()) + self.assertEqual(repr(t), 'Task()') + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), 'Task()') + self.assertRaises(futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertEqual(repr(t), 'Task()') + t = notmuch() + self.event_loop.run_until_complete(t) + self.assertEqual(repr(t), "Task()") + + def test_task_repr_custom(self): + def coro(): + yield from [] + + class T(futures.Future): + def __repr__(self): + return 'T[]' + + class MyTask(tasks.Task, T): + def __repr__(self): + return super().__repr__() + + t = MyTask(coro()) + self.assertEqual(repr(t), 'T[]()') + + def test_task_basics(self): + @tasks.task + def outer(): + a = yield from inner1() + b = yield from inner2() + return a+b + + @tasks.task + def inner1(): + yield from [] + return 42 + + @tasks.task + def inner2(): + yield from [] + return 1000 + + t = outer() + self.assertEqual(self.event_loop.run_until_complete(t), 1042) + + def test_cancel(self): + @tasks.task + def task(): + yield from tasks.sleep(10.0) + return 12 + + t = task() + self.event_loop.call_soon(t.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_future_timeout(self): + @tasks.coroutine + def coro(): + yield from tasks.sleep(10.0) + return 12 + + t = tasks.Task(coro(), timeout=0.1) + + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_future_timeout_catch(self): + @tasks.coroutine + def coro(): + yield from tasks.sleep(10.0) + return 12 + + err = None + + @tasks.coroutine + def coro2(): + nonlocal err + try: + yield from tasks.Task(coro(), timeout=0.1) + except futures.CancelledError as exc: + err = exc + + self.event_loop.run_until_complete(tasks.Task(coro2())) + self.assertIsInstance(err, futures.CancelledError) + + def test_cancel_in_coro(self): + @tasks.coroutine + def task(): + t.cancel() + yield from [] + return 12 + + t = tasks.Task(task()) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_stop_while_run_in_complete(self): + x = 0 + + @tasks.coroutine + def task(): + nonlocal x + while x < 10: + yield from tasks.sleep(0.1) + x += 1 + if x == 2: + self.event_loop.stop() + + t = tasks.Task(task()) + t0 = time.monotonic() + self.assertRaises( + futures.InvalidStateError, + self.event_loop.run_until_complete, t) + t1 = time.monotonic() + self.assertFalse(t.done()) + self.assertTrue(0.18 <= t1-t0 <= 0.22) + self.assertEqual(x, 2) + + def test_timeout(self): + @tasks.task + def task(): + yield from tasks.sleep(10.0) + return 42 + + t = task() + t0 = time.monotonic() + self.assertRaises( + futures.TimeoutError, + self.event_loop.run_until_complete, t, 0.1) + t1 = time.monotonic() + self.assertFalse(t.done()) + self.assertTrue(0.08 <= t1-t0 <= 0.12) + + def test_timeout_not(self): + @tasks.task + def task(): + yield from tasks.sleep(0.1) + return 42 + + t = task() + t0 = time.monotonic() + r = self.event_loop.run_until_complete(t, 10.0) + t1 = time.monotonic() + self.assertTrue(t.done()) + self.assertEqual(r, 42) + self.assertTrue(0.08 <= t1-t0 <= 0.12) + + def test_wait(self): + a = tasks.sleep(0.1) + b = tasks.sleep(0.15) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertEqual(res, 42) + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + # TODO: Test different return_when values. + + def test_wait_first_completed(self): + a = tasks.sleep(10.0) + b = tasks.sleep(0.1) + task = tasks.Task(tasks.wait( + [b, a], return_when=tasks.FIRST_COMPLETED)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + + def test_wait_really_done(self): + self.suppress_log_errors() + # there is possibility that some tasks in the pending list + # became done but their callbacks haven't all been called yet + + @tasks.coroutine + def coro1(): + yield from [None] + + @tasks.coroutine + def coro2(): + yield from [None, None] + + a = tasks.Task(coro1()) + b = tasks.Task(coro2()) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({a, b}, done) + + def test_wait_first_exception(self): + self.suppress_log_errors() + + a = tasks.sleep(10.0) + + @tasks.coroutine + def exc(): + yield from [] + raise ZeroDivisionError('err') + + b = tasks.Task(exc()) + task = tasks.Task(tasks.wait( + [b, a], return_when=tasks.FIRST_EXCEPTION)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + + def test_wait_with_exception(self): + self.suppress_log_errors() + a = tasks.sleep(0.1) + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(0.15) + raise ZeroDivisionError('really') + + b = tasks.Task(sleeper()) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + + t0 = time.monotonic() + self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + t0 = time.monotonic() + self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + + def test_wait_with_timeout(self): + a = tasks.sleep(0.1) + b = tasks.sleep(0.15) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], timeout=0.11) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + + t0 = time.monotonic() + self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.1) + self.assertTrue(t1-t0 <= 0.13) + + def test_as_completed(self): + @tasks.coroutine + def sleeper(dt, x): + yield from tasks.sleep(dt) + return x + + a = sleeper(0.1, 'a') + b = sleeper(0.1, 'b') + c = sleeper(0.15, 'c') + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([b, c, a]): + values.append((yield from f)) + return values + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertTrue('a' in res[:2]) + self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + + def test_as_completed_with_timeout(self): + self.suppress_log_errors() + a = tasks.sleep(0.1, 'a') + b = tasks.sleep(0.15, 'b') + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([a, b], timeout=0.12): + try: + v = yield from f + values.append((1, v)) + except futures.TimeoutError as exc: + values.append((2, exc)) + return values + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.11) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) + + def test_sleep(self): + @tasks.coroutine + def sleeper(dt, arg): + yield from tasks.sleep(dt/2) + res = yield from tasks.sleep(dt/2, arg) + return res + + t = tasks.Task(sleeper(0.1, 'yeah')) + t0 = time.monotonic() + self.event_loop.run() + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.09) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + + def test_task_cancel_sleeping_task(self): + sleepfut = None + + @tasks.task + def sleep(dt): + nonlocal sleepfut + sleepfut = tasks.sleep(dt) + try: + time.monotonic() + yield from sleepfut + finally: + time.monotonic() + + @tasks.task + def doit(): + sleeper = sleep(5000) + self.event_loop.call_later(0.1, sleeper.cancel) + try: + time.monotonic() + yield from sleeper + except futures.CancelledError: + time.monotonic() + return 'cancelled' + else: + return 'slept in' + + t0 = time.monotonic() + doer = doit() + self.assertEqual(self.event_loop.run_until_complete(doer), 'cancelled') + t1 = time.monotonic() + self.assertTrue(0.09 <= t1-t0 <= 0.13, (t1-t0, sleepfut, doer)) + + @unittest.mock.patch('tulip.tasks.tulip_log') + def test_step_in_completed_task(self, m_logging): + @tasks.coroutine + def notmuch(): + yield from [] + return 'ko' + + task = tasks.Task(notmuch()) + task.set_result('ok') + + task._step() + self.assertTrue(m_logging.warn.called) + self.assertTrue(m_logging.warn.call_args[0][0].startswith( + '_step(): already done: ')) + + @unittest.mock.patch('tulip.tasks.tulip_log') + def test_step_result(self, m_logging): + @tasks.coroutine + def notmuch(): + yield from [None, 1] + return 'ko' + + task = tasks.Task(notmuch()) + task._step() + self.assertFalse(m_logging.warn.called) + + task._step() + self.assertTrue(m_logging.warn.called) + self.assertEqual( + '_step(): bad yield: %r', + m_logging.warn.call_args[0][0]) + self.assertEqual(1, m_logging.warn.call_args[0][1]) + + def test_step_result_future(self): + # If coroutine returns future, task waits on this future. + self.suppress_log_warnings() + + class Fut(futures.Future): + def __init__(self, *args): + self.cb_added = False + super().__init__(*args) + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + fut = Fut() + result = None + + @tasks.task + def wait_for_future(): + nonlocal result + result = yield from fut + + wait_for_future() + self.event_loop.run_once() + self.assertTrue(fut.cb_added) + + res = object() + fut.set_result(res) + self.event_loop.run_once() + self.assertIs(res, result) + + def test_step_result_concurrent_future(self): + # Coroutine returns concurrent.futures.Future + self.suppress_log_warnings() + + class Fut(concurrent.futures.Future): + def __init__(self): + self.cb_added = False + super().__init__() + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + c_fut = Fut() + + @tasks.coroutine + def notmuch(): + yield from [c_fut] + return (yield) + + task = tasks.Task(notmuch()) + task._step() + self.assertTrue(c_fut.cb_added) + + res = object() + c_fut.set_result(res) + self.event_loop.run() + self.assertIs(res, task.result()) + + def test_step_with_baseexception(self): + self.suppress_log_errors() + + @tasks.coroutine + def notmutch(): + yield from [] + raise BaseException() + + task = tasks.Task(notmutch()) + self.assertRaises(BaseException, task._step) + + self.assertTrue(task.done()) + self.assertIsInstance(task.exception(), BaseException) + + def test_baseexception_during_cancel(self): + self.suppress_log_errors() + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(10) + + @tasks.coroutine + def notmutch(): + try: + yield from sleeper() + except futures.CancelledError: + raise BaseException() + + task = tasks.Task(notmutch()) + self.event_loop.run_once() + + task.cancel() + self.assertFalse(task.done()) + + self.assertRaises(BaseException, self.event_loop.run_once) + + self.assertTrue(task.done()) + self.assertTrue(task.cancelled()) + + def test_iscoroutinefunction(self): + def fn(): + pass + + self.assertFalse(tasks.iscoroutinefunction(fn)) + + def fn1(): + yield + self.assertFalse(tasks.iscoroutinefunction(fn1)) + + @tasks.coroutine + def fn2(): + yield + self.assertTrue(tasks.iscoroutinefunction(fn2)) + + def test_yield_vs_yield_from(self): + fut = futures.Future() + + @tasks.task + def wait_for_future(): + yield fut + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, task) + + def test_yield_vs_yield_from_generator(self): + fut = futures.Future() + + @tasks.coroutine + def coro(): + yield from fut + + @tasks.task + def wait_for_future(): + yield coro() + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, task) + + def test_coroutine_non_gen_function(self): + + @tasks.coroutine + def func(): + return 'test' + + self.assertTrue(tasks.iscoroutinefunction(func)) + + coro = func() + self.assertTrue(tasks.iscoroutine(coro)) + + res = self.event_loop.run_until_complete(coro) + self.assertEqual(res, 'test') + + def test_coroutine_non_gen_function_return_future(self): + fut = futures.Future() + + @tasks.coroutine + def func(): + return fut + + @tasks.coroutine + def coro(): + fut.set_result('test') + + t1 = tasks.Task(func()) + tasks.Task(coro()) + res = self.event_loop.run_until_complete(t1) + self.assertEqual(res, 'test') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/transports_test.py b/tests/transports_test.py new file mode 100644 index 0000000..4b24b50 --- /dev/null +++ b/tests/transports_test.py @@ -0,0 +1,45 @@ +"""Tests for transports.py.""" + +import unittest +import unittest.mock + +from tulip import transports + + +class TransportTests(unittest.TestCase): + + def test_ctor_extra_is_none(self): + transport = transports.Transport() + self.assertEqual(transport._extra, {}) + + def test_get_extra_info(self): + transport = transports.Transport({'extra': 'info'}) + self.assertEqual('info', transport.get_extra_info('extra')) + self.assertIsNone(transport.get_extra_info('unknown')) + + default = object() + self.assertIs(default, transport.get_extra_info('unknown', default)) + + def test_writelines(self): + transport = transports.Transport() + transport.write = unittest.mock.Mock() + + transport.writelines(['line1', 'line2', 'line3']) + self.assertEqual(3, transport.write.call_count) + + def test_not_implemented(self): + transport = transports.Transport() + + self.assertRaises(NotImplementedError, transport.write, 'data') + self.assertRaises(NotImplementedError, transport.write_eof) + self.assertRaises(NotImplementedError, transport.can_write_eof) + self.assertRaises(NotImplementedError, transport.pause) + self.assertRaises(NotImplementedError, transport.resume) + self.assertRaises(NotImplementedError, transport.close) + self.assertRaises(NotImplementedError, transport.abort) + + def test_dgram_not_implemented(self): + transport = transports.DatagramTransport() + + self.assertRaises(NotImplementedError, transport.sendto, 'data') + self.assertRaises(NotImplementedError, transport.abort) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py new file mode 100644 index 0000000..d7af7ec --- /dev/null +++ b/tests/unix_events_test.py @@ -0,0 +1,573 @@ +"""Tests for unix_events.py.""" + +import errno +import io +import unittest +import unittest.mock + +try: + import signal +except ImportError: + signal = None + +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import unix_events + + +@unittest.skipUnless(signal, 'Signals are not supported') +class SelectorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unix_events.SelectorEventLoop() + + def test_check_signal(self): + self.assertRaises( + TypeError, self.event_loop._check_signal, '1') + self.assertRaises( + ValueError, self.event_loop._check_signal, signal.NSIG + 1) + + unix_events.signal = None + + def restore_signal(): + unix_events.signal = signal + self.addCleanup(restore_signal) + + self.assertRaises( + RuntimeError, self.event_loop._check_signal, signal.SIGINT) + + def test_handle_signal_no_handler(self): + self.event_loop._handle_signal(signal.NSIG + 1, ()) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_setup_error(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.set_wakeup_fd.side_effect = ValueError + + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + h = self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + self.assertIsInstance(h, events.Handle) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_install_error(self, m_signal): + m_signal.NSIG = signal.NSIG + + def set_wakeup_fd(fd): + if fd == -1: + raise ValueError() + m_signal.set_wakeup_fd = set_wakeup_fd + + class Err(OSError): + errno = errno.EFAULT + m_signal.signal.side_effect = Err + + self.assertRaises( + Err, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_add_signal_handler_install_error2(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.event_loop._signal_handlers[signal.SIGHUP] = lambda: True + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(1, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_add_signal_handler_install_error3(self, m_logging, m_signal): + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + m_signal.NSIG = signal.NSIG + + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(2, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + self.assertTrue( + self.event_loop.remove_signal_handler(signal.SIGHUP)) + self.assertTrue(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_2(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.SIGINT = signal.SIGINT + + self.event_loop.add_signal_handler(signal.SIGINT, lambda: True) + self.event_loop._signal_handlers[signal.SIGHUP] = object() + m_signal.set_wakeup_fd.reset_mock() + + self.assertTrue( + self.event_loop.remove_signal_handler(signal.SIGINT)) + self.assertFalse(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGINT, m_signal.default_int_handler), + m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.set_wakeup_fd.side_effect = ValueError + + self.event_loop.remove_signal_handler(signal.SIGHUP) + self.assertTrue(m_logging.info) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error(self, m_signal): + m_signal.NSIG = signal.NSIG + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.signal.side_effect = OSError + + self.assertRaises( + OSError, self.event_loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error2(self, m_signal): + m_signal.NSIG = signal.NSIG + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.assertRaises( + RuntimeError, self.event_loop.remove_signal_handler, signal.SIGHUP) + + +class UnixReadPipeTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor(self, m_fcntl): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + self.event_loop.add_reader.assert_called_with(5, tr._read_ready) + self.event_loop.call_soon.assert_called_with( + self.protocol.connection_made, tr) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor_with_waiter(self, m_fcntl): + fut = futures.Future() + unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol, fut) + self.event_loop.call_soon.assert_called_with(fut.set_result, None) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + m_read.return_value = b'data' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.data_received.assert_called_with(b'data') + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready_eof(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + m_read.return_value = b'' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.eof_received.assert_called_with() + self.event_loop.remove_reader.assert_called_with(5) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready_blocked(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + self.event_loop.reset_mock() + m_read.side_effect = BlockingIOError + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.assertFalse(self.protocol.data_received.called) + + @unittest.mock.patch('tulip.log.tulip_log.exception') + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready_error(self, m_fcntl, m_read, m_logexc): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + err = OSError() + m_read.side_effect = err + tr._close = unittest.mock.Mock() + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + tr._close.assert_called_with(err) + m_logexc.assert_called_with('Fatal error for %s', tr) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_pause(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.pause() + self.event_loop.remove_reader.assert_called_with(5) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_resume(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.resume() + self.event_loop.add_reader.assert_called_with(5, tr._read_ready) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_close(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._close = unittest.mock.Mock() + tr.close() + tr._close.assert_called_with(None) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_close_already_closing(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._closing = True + tr._close = unittest.mock.Mock() + tr.close() + self.assertFalse(tr._close.called) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__close(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = object() + tr._close(err) + self.assertTrue(tr._closing) + self.event_loop.remove_reader.assert_called_with(5) + self.protocol.connection_lost.assert_called_with(err) + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost(self, m_fcntl): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost_with_err(self, m_fcntl): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + +class UnixWritePipeTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + self.event_loop.call_soon.assert_called_with( + self.protocol.connection_made, tr) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor_with_waiter(self, m_fcntl): + fut = futures.Future() + unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol, fut) + self.event_loop.call_soon.assert_called_with(fut.set_result, None) + + @unittest.mock.patch('fcntl.fcntl') + def test_can_write_eof(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + self.assertTrue(tr.can_write_eof()) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + m_write.return_value = 4 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.add_writer.called) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_no_data(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write(b'') + self.assertFalse(m_write.called) + self.assertFalse(self.event_loop.add_writer.called) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_partial(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + m_write.return_value = 2 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.event_loop.add_writer.assert_called_with(5, tr._write_ready) + self.assertEqual([b'ta'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_buffer(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'previous'] + tr.write(b'data') + self.assertFalse(m_write.called) + self.assertFalse(self.event_loop.add_writer.called) + self.assertEqual([b'previous', b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_again(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + m_write.side_effect = BlockingIOError() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.event_loop.add_writer.assert_called_with(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_err(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = OSError() + m_write.side_effect = err + tr._fatal_error = unittest.mock.Mock() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.called) + self.assertEqual([], tr._buffer) + tr._fatal_error.assert_called_with(err) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_partial(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.return_value = 3 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'a'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_again(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.side_effect = BlockingIOError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_empty(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.return_value = 0 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('tulip.log.tulip_log.exception') + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_err(self, m_fcntl, m_write, m_logexc): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.side_effect = err = OSError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + self.protocol.connection_lost.assert_called_with(err) + m_logexc.assert_called_with('Fatal error for %s', tr) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_closing(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._closing = True + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_abort(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + tr.abort() + self.assertFalse(m_write.called) + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost_with_err(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test_close(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr.close() + tr.write_eof.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test_close_closing(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr._closing = True + tr.close() + self.assertFalse(tr.write_eof.called) + + @unittest.mock.patch('fcntl.fcntl') + def test_write_eof(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write_eof() + self.assertTrue(tr._closing) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('fcntl.fcntl') + def test_write_eof_pending(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + tr._buffer = [b'data'] + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.protocol.connection_lost.called) diff --git a/tests/winsocketpair_test.py b/tests/winsocketpair_test.py new file mode 100644 index 0000000..381fb22 --- /dev/null +++ b/tests/winsocketpair_test.py @@ -0,0 +1,26 @@ +"""Tests for winsocketpair.py""" + +import unittest +import unittest.mock + +from tulip import winsocketpair + + +class WinsocketpairTests(unittest.TestCase): + + def test_winsocketpair(self): + ssock, csock = winsocketpair.socketpair() + + csock.send(b'xxx') + self.assertEqual(b'xxx', ssock.recv(1024)) + + csock.close() + ssock.close() + + @unittest.mock.patch('tulip.winsocketpair.socket') + def test_winsocketpair_exc(self, m_socket): + m_socket.socket.return_value.getsockname.return_value = ('', 12345) + m_socket.socket.return_value.accept.return_value = object(), object() + m_socket.socket.return_value.connect.side_effect = OSError() + + self.assertRaises(OSError, winsocketpair.socketpair) diff --git a/tulip/TODO b/tulip/TODO new file mode 100644 index 0000000..acec5c2 --- /dev/null +++ b/tulip/TODO @@ -0,0 +1,28 @@ +TODO in tulip v2 (tulip/ package directory) +------------------------------------------- + +- See also TBD and Open Issues in PEP 3156 + +- Refactor unix_events.py (it's getting too long) + +- Docstrings + +- Unittests + +- better run_once() behavior? (Run ready list last.) + +- start_serving() + +- Make Handler() callable? Log the exception in there? + +- Add the repeat interval to the Handler class? + +- Recognize Handler passed to add_reader(), call_soon(), etc.? + +- SSL support + +- buffered stream implementation + +- Primitives like par() and wait_one() + +- Remove test dependency on xkcd.com, write our own test server diff --git a/tulip/__init__.py b/tulip/__init__.py new file mode 100644 index 0000000..faf307f --- /dev/null +++ b/tulip/__init__.py @@ -0,0 +1,26 @@ +"""Tulip 2.0, tracking PEP 3156.""" + +import sys + +# This relies on each of the submodules having an __all__ variable. +from .futures import * +from .events import * +from .locks import * +from .transports import * +from .protocols import * +from .streams import * +from .tasks import * + +if sys.platform == 'win32': # pragma: no cover + from .windows_events import * +else: + from .unix_events import * # pragma: no cover + + +__all__ = (futures.__all__ + + events.__all__ + + locks.__all__ + + transports.__all__ + + protocols.__all__ + + streams.__all__ + + tasks.__all__) diff --git a/tulip/base_events.py b/tulip/base_events.py new file mode 100644 index 0000000..5ed257a --- /dev/null +++ b/tulip/base_events.py @@ -0,0 +1,548 @@ +"""Base implementation of event loop. + +The event loop can be broken up into a multiplexer (the part +responsible for notifying us of IO events) and the event loop proper, +which wraps a multiplexer with functionality for scheduling callbacks, +immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. +""" + + +import collections +import concurrent.futures +import heapq +import logging +import socket +import time + +from . import events +from . import futures +from . import tasks +from .log import tulip_log + + +__all__ = ['BaseEventLoop'] + + +# Argument for default thread pool executor creation. +_MAX_WORKERS = 5 + + +class _StopError(BaseException): + """Raised to stop the event loop.""" + + +def _raise_stop_error(*args): + raise _StopError + + +class BaseEventLoop(events.AbstractEventLoop): + + def __init__(self): + self._ready = collections.deque() + self._scheduled = [] + self._default_executor = None + self._internal_fds = 0 + self._running = False + + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + """Create socket transport.""" + raise NotImplementedError + + def _make_ssl_transport(self, rawsock, protocol, + sslcontext, waiter, extra=None): + """Create SSL transport.""" + raise NotImplementedError + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + """Create datagram transport.""" + raise NotImplementedError + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create read pipe transport.""" + raise NotImplementedError + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create write pipe transport.""" + raise NotImplementedError + + def _read_from_self(self): + """XXX""" + raise NotImplementedError + + def _write_to_self(self): + """XXX""" + raise NotImplementedError + + def _process_events(self, event_list): + """Process selector events.""" + raise NotImplementedError + + def is_running(self): + """Returns running status of event loop.""" + return self._running + + def run(self): + """Run the event loop until nothing left to do or stop() called. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + + TODO: Give this a timeout too? + """ + if self._running: + raise RuntimeError('Event loop is running.') + + self._running = True + try: + while (self._ready or + self._scheduled or + self._selector.registered_count() > 1): + try: + self._run_once() + except _StopError: + break + finally: + self._running = False + + def run_forever(self): + """Run until stop() is called. + + This only makes sense over run() if you have another thread + scheduling callbacks using call_soon_threadsafe(). + """ + handle = self.call_repeatedly(24*3600, lambda: None) + try: + self.run() + finally: + handle.cancel() + + def run_once(self, timeout=0): + """Run through all callbacks and all I/O polls once. + + Calling stop() will break out of this too. + """ + if self._running: + raise RuntimeError('Event loop is running.') + + self._running = True + try: + self._run_once(timeout) + except _StopError: + pass + finally: + self._running = False + + def run_until_complete(self, future, timeout=None): + """Run until the Future is done, or until a timeout. + + If the argument is a coroutine, it is wrapped in a Task. + + XXX TBD: It would be disastrous to call run_until_complete() + with the same coroutine twice -- it would wrap it in two + different Tasks and that can't be good. + + Return the Future's result, or raise its exception. If the + timeout is reached or stop() is called, raise TimeoutError. + """ + if not isinstance(future, futures.Future): + if tasks.iscoroutine(future): + future = tasks.Task(future) + else: + assert False, 'A Future or coroutine is required' + + handle_called = False + + def stop_loop(): + nonlocal handle_called + handle_called = True + raise _StopError + + future.add_done_callback(_raise_stop_error) + + if timeout is None: + self.run_forever() + else: + handle = self.call_later(timeout, stop_loop) + self.run_forever() + handle.cancel() + + if handle_called: + raise futures.TimeoutError + + return future.result() + + def stop(self): + """Stop running the event loop. + + Every callback scheduled before stop() is called will run. + Callback scheduled after stop() is called won't. However, + those callbacks will run if run() is called again later. + """ + self.call_soon(_raise_stop_error) + + def call_later(self, delay, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The delay can be an int or float, expressed in seconds. It is + always a relative time. + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Callbacks scheduled in the past are passed on to call_soon(), + so these will be called in the order in which they were + registered rather than by time due. This is so you can't + cheat and insert yourself at the front of the ready queue by + using a negative time. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + if delay <= 0: + return self.call_soon(callback, *args) + + handle = events.Timer(time.monotonic() + delay, callback, args) + heapq.heappush(self._scheduled, handle) + return handle + + def call_repeatedly(self, interval, callback, *args): + """Call a callback every 'interval' seconds.""" + assert interval > 0, 'Interval must be > 0: %r' % (interval,) + # TODO: What if callback is already a Handle? + def wrapper(): + callback(*args) # If this fails, the chain is broken. + handle._when = time.monotonic() + interval + heapq.heappush(self._scheduled, handle) + + handle = events.Timer(time.monotonic() + interval, wrapper, ()) + heapq.heappush(self._scheduled, handle) + return handle + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + handle = events.make_handle(callback, args) + self._ready.append(handle) + return handle + + def call_soon_threadsafe(self, callback, *args): + """XXX""" + handle = self.call_soon(callback, *args) + self._write_to_self() + return handle + + def run_in_executor(self, executor, callback, *args): + if isinstance(callback, events.Handle): + assert not args + assert not isinstance(callback, events.Timer) + if callback.cancelled: + f = futures.Future() + f.set_result(None) + return f + callback, args = callback.callback, callback.args + if executor is None: + executor = self._default_executor + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + self._default_executor = executor + return self.wrap_future(executor.submit(callback, *args)) + + def set_default_executor(self, executor): + self._default_executor = executor + + def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) + + def getnameinfo(self, sockaddr, flags=0): + return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + + @tasks.coroutine + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=False, family=0, proto=0, flags=0, sock=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + "host, port and sock can not be specified at the same time") + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + yield from self.sock_connect(sock, address) + except socket.error as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + if len(exceptions) == 1: + raise exceptions[0] + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise socket.error('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + + elif sock is None: + raise ValueError( + "host and port was not specified and no sock specified") + + sock.setblocking(False) + + protocol = protocol_factory() + waiter = futures.Future() + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + transport = self._make_ssl_transport( + sock, protocol, sslcontext, waiter) + else: + transport = self._make_socket_transport(sock, protocol, waiter) + + yield from waiter + return transport, protocol + + @tasks.coroutine + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + """Create datagram connection.""" + if not (local_addr or remote_addr): + if family == 0: + raise ValueError('unexpected address family') + addr_pairs_info = (((family, proto), (None, None)),) + else: + # join addresss by (family, protocol) + addr_infos = collections.OrderedDict() + for idx, addr in ((0, local_addr), (1, remote_addr)): + if addr is not None: + assert isinstance(addr, tuple) and len(addr) == 2, ( + '2-tuple is expected') + + infos = yield from self.getaddrinfo( + *addr, family=family, type=socket.SOCK_DGRAM, + proto=proto, flags=flags) + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + for fam, _, pro, _, address in infos: + key = (fam, pro) + if key not in addr_infos: + addr_infos[key] = [None, None] + addr_infos[key][idx] = address + + # each addr has to have info for each (family, proto) pair + addr_pairs_info = [ + (key, addr_pair) for key, addr_pair in addr_infos.items() + if not ((local_addr and addr_pair[0] is None) or + (remote_addr and addr_pair[1] is None))] + + if not addr_pairs_info: + raise ValueError('can not get address information') + + exceptions = [] + + for (family, proto), (local_address, remote_address) in addr_pairs_info: + sock = None + l_addr = None + r_addr = None + try: + sock = socket.socket( + family=family, type=socket.SOCK_DGRAM, proto=proto) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(False) + + if local_addr: + sock.bind(local_address) + l_addr = sock.getsockname() + if remote_addr: + yield from self.sock_connect(sock, remote_address) + r_addr = remote_address + except socket.error as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + + protocol = protocol_factory() + transport = self._make_datagram_transport( + sock, protocol, r_addr, extra={'addr': l_addr}) + return transport, protocol + + # TODO: Or create_server()? + @tasks.task + def start_serving(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, backlog=100, sock=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + "host, port and sock can not be specified at the same time") + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + # TODO: Maybe we want to bind every address in the list + # instead of the first one that works? + exceptions = [] + for family, type, proto, cname, address in infos: + sock = socket.socket(family=family, type=type, proto=proto) + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + except socket.error as exc: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + + elif sock is None: + raise ValueError( + "host and port was not specified and no sock specified") + + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock) + return sock + + @tasks.coroutine + def connect_read_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future() + transport = self._make_read_pipe_transport(pipe, protocol, waiter, + extra={}) + yield from waiter + return transport, protocol + + @tasks.coroutine + def connect_write_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future() + transport = self._make_write_pipe_transport(pipe, protocol, waiter, + extra={}) + yield from waiter + return transport, protocol + + def _add_callback(self, handle): + """Add a Handle to ready or scheduled.""" + if handle.cancelled: + return + if isinstance(handle, events.Timer): + heapq.heappush(self._scheduled, handle) + else: + self._ready.append(handle) + + def wrap_future(self, future): + """XXX""" + if isinstance(future, futures.Future): + return future # Don't wrap our own type of Future. + new_future = futures.Future() + future.add_done_callback( + lambda future: + self.call_soon_threadsafe(new_future._copy_state, future)) + return new_future + + def _run_once(self, timeout=None): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0].cancelled: + heapq.heappop(self._scheduled) + + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0].when + deadline = max(0, when - time.monotonic()) + if timeout is None: + timeout = deadline + else: + timeout = min(timeout, deadline) + + t0 = time.monotonic() + event_list = self._selector.select(timeout) + t1 = time.monotonic() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + tulip_log.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + self._process_events(event_list) + + # Handle 'later' callbacks that are ready. + now = time.monotonic() + while self._scheduled: + handle = self._scheduled[0] + if handle.when > now: + break + handle = heapq.heappop(self._scheduled) + self._ready.append(handle) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is threadsafe without using locks. + ntodo = len(self._ready) + for i in range(ntodo): + handle = self._ready.popleft() + if not handle.cancelled: + handle.run() diff --git a/tulip/events.py b/tulip/events.py new file mode 100644 index 0000000..3a6ad40 --- /dev/null +++ b/tulip/events.py @@ -0,0 +1,356 @@ +"""Event loop and event loop policy. + +Beyond the PEP: +- Only the main thread has a default event loop. +""" + +__all__ = ['EventLoopPolicy', 'DefaultEventLoopPolicy', + 'AbstractEventLoop', 'Timer', 'Handle', 'make_handle', + 'get_event_loop_policy', 'set_event_loop_policy', + 'get_event_loop', 'set_event_loop', 'new_event_loop', + ] + +import sys +import threading + +from .log import tulip_log + + +class Handle: + """Object returned by callback registration methods.""" + + def __init__(self, callback, args): + self._callback = callback + self._args = args + self._cancelled = False + + def __repr__(self): + res = 'Handle({}, {})'.format(self._callback, self._args) + if self._cancelled: + res += '' + return res + + @property + def callback(self): + return self._callback + + @property + def args(self): + return self._args + + @property + def cancelled(self): + return self._cancelled + + def cancel(self): + self._cancelled = True + + def run(self): + try: + self._callback(*self._args) + except Exception: + tulip_log.exception('Exception in callback %s %r', + self._callback, self._args) + + +def make_handle(callback, args): + if isinstance(callback, Handle): + assert not args + return callback + return Handle(callback, args) + + +class Timer(Handle): + """Object returned by timed callback registration methods.""" + + def __init__(self, when, callback, args): + assert when is not None + super().__init__(callback, args) + + self._when = when + + def __repr__(self): + res = 'Timer({}, {}, {})'.format(self._when, + self._callback, + self._args) + if self._cancelled: + res += '' + + return res + + @property + def when(self): + return self._when + + def __lt__(self, other): + return self._when < other._when + + def __le__(self, other): + if self._when < other._when: + return True + return self.__eq__(other) + + def __gt__(self, other): + return self._when > other._when + + def __ge__(self, other): + if self._when > other._when: + return True + return self.__eq__(other) + + def __eq__(self, other): + if isinstance(other, Timer): + return (self._when == other._when and + self._callback == other._callback and + self._args == other._args and + self._cancelled == other._cancelled) + return NotImplemented + + def __ne__(self, other): + equal = self.__eq__(other) + return NotImplemented if equal is NotImplemented else not equal + + +class AbstractEventLoop: + """Abstract event loop.""" + + # TODO: Rename run() -> run_until_idle(), run_forever() -> run(). + + def run(self): + """Run the event loop. Block until there is nothing left to do.""" + raise NotImplementedError + + def run_forever(self): + """Run the event loop. Block until stop() is called.""" + raise NotImplementedError + + def run_once(self, timeout=None): # NEW! + """Run one complete cycle of the event loop.""" + raise NotImplementedError + + def run_until_complete(self, future, timeout=None): # NEW! + """Run the event loop until a Future is done. + + Return the Future's result, or raise its exception. + + If timeout is not None, run it for at most that long; + if the Future is still not done, raise TimeoutError + (but don't cancel the Future). + """ + raise NotImplementedError + + def stop(self): # NEW! + """Stop the event loop as soon as reasonable. + + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + raise NotImplementedError + + # Methods returning Handles for scheduling callbacks. + + def call_later(self, delay, callback, *args): + raise NotImplementedError + + def call_repeatedly(self, interval, callback, *args): # NEW! + raise NotImplementedError + + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) + + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError + + # Methods returning Futures for interacting with threads. + + def wrap_future(self, future): + raise NotImplementedError + + def run_in_executor(self, executor, callback, *args): + raise NotImplementedError + + # Network I/O methods returning Futures. + + def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError + + def create_connection(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, sock=None): + raise NotImplementedError + + def start_serving(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, sock=None): + raise NotImplementedError + + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + def connect_read_pipe(self, protocol_factory, pipe): + """Register read pipe in eventloop. + + protocol_factory should instantiate object with Protocol interface. + pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + ReadTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def connect_write_pipe(self, protocol_factory, pipe): + """Register write pipe in eventloop. + + protocol_factory should instantiate object with BaseProtocol interface. + Pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + WriteTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + #def spawn_subprocess(self, protocol_factory, pipe): + # raise NotImplementedError + + # Ready-based callback registration methods. + # The add_*() methods return a Handle. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. + + def add_reader(self, fd, callback, *args): + raise NotImplementedError + + def remove_reader(self, fd): + raise NotImplementedError + + def add_writer(self, fd, callback, *args): + raise NotImplementedError + + def remove_writer(self, fd): + raise NotImplementedError + + # Completion based I/O methods returning Futures. + + def sock_recv(self, sock, nbytes): + raise NotImplementedError + + def sock_sendall(self, sock, data): + raise NotImplementedError + + def sock_connect(self, sock, address): + raise NotImplementedError + + def sock_accept(self, sock): + raise NotImplementedError + + # Signal handling. + + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError + + def remove_signal_handler(self, sig): + raise NotImplementedError + + +class EventLoopPolicy: + """Abstract policy for accessing the event loop.""" + + def get_event_loop(self): + """XXX""" + raise NotImplementedError + + def set_event_loop(self, event_loop): + """XXX""" + raise NotImplementedError + + def new_event_loop(self): + """XXX""" + raise NotImplementedError + + +class DefaultEventLoopPolicy(threading.local, EventLoopPolicy): + """Default policy implementation for accessing the event loop. + + In this policy, each thread has its own event loop. However, we + only automatically create an event loop by default for the main + thread; other threads by default have no event loop. + + Other policies may have different rules (e.g. a single global + event loop, or automatically creating an event loop per thread, or + using some other notion of context to which an event loop is + associated). + """ + + _event_loop = None + + def get_event_loop(self): + """Get the event loop. + + This may be None or an instance of EventLoop. + """ + if (self._event_loop is None and + threading.current_thread().name == 'MainThread'): + self._event_loop = self.new_event_loop() + return self._event_loop + + def set_event_loop(self, event_loop): + """Set the event loop.""" + assert event_loop is None or isinstance(event_loop, AbstractEventLoop) + self._event_loop = event_loop + + def new_event_loop(self): + """Create a new event loop. + + You must call set_event_loop() to make this the current event + loop. + """ + if sys.platform == 'win32': # pragma: no cover + from . import windows_events + return windows_events.SelectorEventLoop() + else: # pragma: no cover + from . import unix_events + return unix_events.SelectorEventLoop() + + +# Event loop policy. The policy itself is always global, even if the +# policy's rules say that there is an event loop per thread (or other +# notion of context). The default policy is installed by the first +# call to get_event_loop_policy(). +_event_loop_policy = None + + +def get_event_loop_policy(): + """XXX""" + global _event_loop_policy + if _event_loop_policy is None: + _event_loop_policy = DefaultEventLoopPolicy() + return _event_loop_policy + + +def set_event_loop_policy(policy): + """XXX""" + global _event_loop_policy + assert policy is None or isinstance(policy, EventLoopPolicy) + _event_loop_policy = policy + + +def get_event_loop(): + """XXX""" + return get_event_loop_policy().get_event_loop() + + +def set_event_loop(event_loop): + """XXX""" + get_event_loop_policy().set_event_loop(event_loop) + + +def new_event_loop(): + """XXX""" + return get_event_loop_policy().new_event_loop() diff --git a/tulip/futures.py b/tulip/futures.py new file mode 100644 index 0000000..39137aa --- /dev/null +++ b/tulip/futures.py @@ -0,0 +1,255 @@ +"""A Future class similar to the one in PEP 3148.""" + +__all__ = ['CancelledError', 'TimeoutError', + 'InvalidStateError', 'InvalidTimeoutError', + 'Future', + ] + +import concurrent.futures._base + +from . import events + +# States for Future. +_PENDING = 'PENDING' +_CANCELLED = 'CANCELLED' +_FINISHED = 'FINISHED' + +# TODO: Do we really want to depend on concurrent.futures internals? +Error = concurrent.futures._base.Error +CancelledError = concurrent.futures.CancelledError +TimeoutError = concurrent.futures.TimeoutError + + +class InvalidStateError(Error): + """The operation is not allowed in this state.""" + # TODO: Show the future, its state, the method, and the required state. + + +class InvalidTimeoutError(Error): + """Called result() or exception() with timeout != 0.""" + # TODO: Print a nice error message. + + +class Future: + """This class is *almost* compatible with concurrent.futures.Future. + + Differences: + + - result() and exception() do not take a timeout argument and + raise an exception when the future isn't done yet. + + - Callbacks registered with add_done_callback() are always called + via the event loop's call_soon_threadsafe(). + + - This class is not compatible with the wait() and as_completed() + methods in the concurrent.futures package. + + (In Python 3.4 or later we may be able to unify the implementations.) + """ + + # Class variables serving as defaults for instance variables. + _state = _PENDING + _result = None + _exception = None + _timeout_handle = None + + _blocking = False # proper use of future (yield vs yield from) + + def __init__(self, *, event_loop=None, timeout=None): + """Initialize the future. + + The optional event_loop argument allows to explicitly set the event + loop object used by the future. If it's not provided, the future uses + the default event loop. + """ + if event_loop is None: + self._event_loop = events.get_event_loop() + else: + self._event_loop = event_loop + self._callbacks = [] + + if timeout is not None: + self._timeout_handle = self._event_loop.call_later( + timeout, self.cancel) + + def __repr__(self): + res = self.__class__.__name__ + if self._state == _FINISHED: + if self._exception is not None: + res += ''.format(self._exception) + else: + res += ''.format(self._result) + elif self._callbacks: + size = len(self._callbacks) + if size > 2: + res += '<{}, [{}, <{} more>, {}]>'.format( + self._state, self._callbacks[0], + size-2, self._callbacks[-1]) + else: + res += '<{}, {}>'.format(self._state, self._callbacks) + else: + res += '<{}>'.format(self._state) + return res + + def cancel(self): + """Cancel the future and schedule callbacks. + + If the future is already done or cancelled, return False. Otherwise, + change the future's state to cancelled, schedule the callbacks and + return True. + """ + if self._state != _PENDING: + return False + self._state = _CANCELLED + self._schedule_callbacks() + return True + + def _schedule_callbacks(self): + """Internal: Ask the event loop to call all callbacks. + + The callbacks are scheduled to be called as soon as possible. Also + clears the callback list. + """ + # Cancel timeout handle + if self._timeout_handle is not None: + self._timeout_handle.cancel() + self._timeout_handle = None + + callbacks = self._callbacks[:] + if not callbacks: + return + + self._callbacks[:] = [] + for callback in callbacks: + self._event_loop.call_soon(callback, self) + + def cancelled(self): + """Return True if the future was cancelled.""" + return self._state == _CANCELLED + + def running(self): + """Always return False. + + This method is for compatibility with concurrent.futures; we don't + have a running state. + """ + return False # We don't have a running state. + + def done(self): + """Return True if the future is done. + + Done means either that a result / exception are available, or that the + future was cancelled. + """ + return self._state != _PENDING + + def result(self, timeout=0): + """Return the result this future represents. + + If the future has been cancelled, raises CancelledError. If the + future's result isn't yet available, raises InvalidStateError. If + the future is done and has an exception set, this exception is raised. + Timeout values other than 0 are not supported. + """ + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + if self._exception is not None: + raise self._exception + return self._result + + def exception(self, timeout=0): + """Return the exception that was set on this future. + + The exception (or None if no exception was set) is returned only if + the future is done. If the future has been cancelled, raises + CancelledError. If the future isn't done yet, raises + InvalidStateError. Timeout values other than 0 are not supported. + """ + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + return self._exception + + def add_done_callback(self, fn): + """Add a callback to be run when the future becomes done. + + The callback is called with a single argument - the future object. If + the future is already done when this is called, the callback is + scheduled with call_soon. + """ + if self._state != _PENDING: + self._event_loop.call_soon(fn, self) + else: + self._callbacks.append(fn) + + # New method not in PEP 3148. + + def remove_done_callback(self, fn): + """Remove all instances of a callback from the "call when done" list. + + Returns the number of callbacks removed. + """ + filtered_callbacks = [f for f in self._callbacks if f != fn] + removed_count = len(self._callbacks) - len(filtered_callbacks) + if removed_count: + self._callbacks[:] = filtered_callbacks + return removed_count + + # So-called internal methods (note: no set_running_or_notify_cancel()). + + def set_result(self, result): + """Mark the future done and set its result. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError + self._result = result + self._state = _FINISHED + self._schedule_callbacks() + + def set_exception(self, exception): + """ Mark the future done and set an exception. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError + self._exception = exception + self._state = _FINISHED + self._schedule_callbacks() + + # Truly internal methods. + + def _copy_state(self, other): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert other.done() + assert not self.done() + if other.cancelled(): + self.cancel() + else: + exception = other.exception() + if exception is not None: + self.set_exception(exception) + else: + result = other.result() + self.set_result(result) + + def __iter__(self): + if not self.done(): + self._blocking = True + yield self # This tells Task to wait for completion. + assert self.done(), "yield from wasn't used with future" + return self.result() # May raise too. diff --git a/tulip/http/__init__.py b/tulip/http/__init__.py new file mode 100644 index 0000000..582f080 --- /dev/null +++ b/tulip/http/__init__.py @@ -0,0 +1,12 @@ +# This relies on each of the submodules having an __all__ variable. + +from .client import * +from .errors import * +from .protocol import * +from .server import * + + +__all__ = (client.__all__ + + errors.__all__ + + protocol.__all__ + + server.__all__) diff --git a/tulip/http/client.py b/tulip/http/client.py new file mode 100644 index 0000000..b65b90a --- /dev/null +++ b/tulip/http/client.py @@ -0,0 +1,145 @@ +"""HTTP Client for Tulip. + +Most basic usage: + + sts, headers, response = yield from http_client.fetch(url, + method='GET', headers={}, request=b'') + assert isinstance(sts, int) + assert isinstance(headers, dict) + # sort of; case insensitive (what about multiple values for same header?) + headers['status'] == '200 Ok' # or some such + assert isinstance(response, bytes) + +TODO: Reuse email.Message class (or its subclass, http.client.HTTPMessage). +TODO: How do we do connection keep alive? Pooling? +""" + +__all__ = ['HttpClientProtocol'] + + +import email.message +import email.parser + +import tulip + +from . import protocol + + +class HttpClientProtocol: + """This Protocol class is also used to initiate the connection. + + Usage: + p = HttpClientProtocol(url, ...) + sts, headers, stream = yield from p.connect() + + """ + + def __init__(self, host, port=None, *, + path='/', method='GET', headers=None, ssl=None, + make_body=None, encoding='utf-8', version=(1, 1), + chunked=False): + host = self.validate(host, 'host') + if ':' in host: + assert port is None + host, port_s = host.split(':', 1) + port = int(port_s) + self.host = host + if port is None: + if ssl: + port = 443 + else: + port = 80 + assert isinstance(port, int) + self.port = port + self.path = self.validate(path, 'path') + self.method = self.validate(method, 'method') + self.headers = email.message.Message() + self.headers['Accept-Encoding'] = 'gzip, deflate' + if headers: + for key, value in headers.items(): + self.validate(key, 'header key') + self.validate(value, 'header value', True) + self.headers[key] = value + self.encoding = self.validate(encoding, 'encoding') + self.version = version + self.make_body = make_body + self.chunked = chunked + self.ssl = ssl + if 'content-length' not in self.headers: + if self.make_body is None: + self.headers['Content-Length'] = '0' + else: + self.chunked = True + if self.chunked: + if 'Transfer-Encoding' not in self.headers: + self.headers['Transfer-Encoding'] = 'chunked' + else: + assert self.headers['Transfer-Encoding'].lower() == 'chunked' + if 'host' not in self.headers: + self.headers['Host'] = self.host + self.event_loop = tulip.get_event_loop() + self.transport = None + + def validate(self, value, name, embedded_spaces_okay=False): + # Must be a string. If embedded_spaces_okay is False, no + # whitespace is allowed; otherwise, internal single spaces are + # allowed (but no other whitespace). + assert isinstance(value, str), \ + '{} should be str, not {}'.format(name, type(value)) + parts = value.split() + assert parts, '{} should not be empty'.format(name) + if embedded_spaces_okay: + assert ' '.join(parts) == value, \ + '{} can only contain embedded single spaces ({!r})'.format( + name, value) + else: + assert parts == [value], \ + '{} cannot contain whitespace ({!r})'.format(name, value) + return value + + @tulip.coroutine + def connect(self): + yield from self.event_loop.create_connection( + lambda: self, self.host, self.port, ssl=self.ssl) + + # read response status + version, status, reason = yield from self.stream.read_response_status() + + message = yield from self.stream.read_message(version) + + # headers + headers = email.message.Message() + for hdr, val in message.headers: + headers.add_header(hdr, val) + + sts = '{} {}'.format(status, reason) + return (sts, headers, message.payload) + + def connection_made(self, transport): + self.transport = transport + self.stream = protocol.HttpStreamReader() + + self.request = protocol.Request( + transport, self.method, self.path, self.version) + + self.request.add_headers(*self.headers.items()) + self.request.send_headers() + + if self.make_body is not None: + if self.chunked: + self.make_body( + self.request.write, self.request.eof) + else: + self.make_body( + self.request.write, self.request.eof) + else: + self.request.write_eof() + + def data_received(self, data): + self.stream.feed_data(data) + + def eof_received(self): + self.stream.feed_eof() + + def connection_lost(self, exc): + pass diff --git a/tulip/http/errors.py b/tulip/http/errors.py new file mode 100644 index 0000000..41344de --- /dev/null +++ b/tulip/http/errors.py @@ -0,0 +1,44 @@ +"""http related errors.""" + +__all__ = ['HttpException', 'HttpStatusException', + 'IncompleteRead', 'BadStatusLine', 'LineTooLong', 'InvalidHeader'] + +import http.client + + +class HttpException(http.client.HTTPException): + + code = None + headers = () + + +class HttpStatusException(HttpException): + + def __init__(self, code, headers=None, message=''): + self.code = code + self.headers = headers + self.message = message + + +class BadRequestException(HttpException): + + code = 400 + + +class IncompleteRead(BadRequestException, http.client.IncompleteRead): + pass + + +class BadStatusLine(BadRequestException, http.client.BadStatusLine): + pass + + +class LineTooLong(BadRequestException, http.client.LineTooLong): + pass + + +class InvalidHeader(BadRequestException): + + def __init__(self, hdr): + super().__init__('Invalid HTTP Header: %s' % hdr) + self.hdr = hdr diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py new file mode 100644 index 0000000..6a0e127 --- /dev/null +++ b/tulip/http/protocol.py @@ -0,0 +1,877 @@ +"""Http related helper utils.""" + +__all__ = ['HttpStreamReader', + 'HttpMessage', 'Request', 'Response', + 'RawHttpMessage', 'RequestLine', 'ResponseStatus'] + +import collections +import email.utils +import functools +import http.server +import itertools +import re +import sys +import zlib + +import tulip +from . import errors + +METHRE = re.compile('[A-Z0-9$-_.]+') +VERSRE = re.compile('HTTP/(\d+).(\d+)') +HDRRE = re.compile(b"[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]") +CONTINUATION = (b' ', b'\t') +RESPONSES = http.server.BaseHTTPRequestHandler.responses + +RequestLine = collections.namedtuple( + 'RequestLine', ['method', 'uri', 'version']) + + +ResponseStatus = collections.namedtuple( + 'ResponseStatus', ['version', 'code', 'reason']) + + +RawHttpMessage = collections.namedtuple( + 'RawHttpMessage', ['headers', 'payload', 'should_close', 'compression']) + + +class HttpStreamReader(tulip.StreamReader): + + MAX_HEADERS = 32768 + MAX_HEADERFIELD_SIZE = 8190 + + # if _parser is set, feed_data and feed_eof sends data into + # _parser instead of self. is it being used as stream redirection for + # _parse_chunked_payload, _parse_length_payload and _parse_eof_payload + _parser = None + + def feed_data(self, data): + """_parser is a generator, if _parser is set, feed_data sends + incoming data into the generator untile generator stops.""" + if self._parser: + try: + self._parser.send(data) + except StopIteration as exc: + self._parser = None + if exc.value: + self.feed_data(exc.value) + else: + super().feed_data(data) + + def feed_eof(self): + """_parser is a generator, if _parser is set feed_eof throws + StreamEofException into this generator.""" + if self._parser: + try: + self._parser.throw(StreamEofException()) + except StopIteration: + self._parser = None + + super().feed_eof() + + @tulip.coroutine + def read_request_line(self): + """Read request status line. Exception errors.BadStatusLine + could be raised in case of any errors in status line. + Returns three values (method, uri, version) + + Example: + + GET /path HTTP/1.1 + + >> yield from reader.read_request_line() + ('GET', '/path', (1, 1)) + + """ + bline = yield from self.readline() + try: + line = bline.decode('ascii').rstrip() + except UnicodeDecodeError: + raise errors.BadStatusLine(bline) from None + + try: + method, uri, version = line.split(None, 2) + except ValueError: + raise errors.BadStatusLine(line) from None + + # method + method = method.upper() + if not METHRE.match(method): + raise errors.BadStatusLine(method) + + # version + match = VERSRE.match(version) + if match is None: + raise errors.BadStatusLine(version) + version = (int(match.group(1)), int(match.group(2))) + + return RequestLine(method, uri, version) + + @tulip.coroutine + def read_response_status(self): + """Read response status line. Exception errors.BadStatusLine + could be raised in case of any errors in status line. + Returns three values (version, status_code, reason) + + Example: + + HTTP/1.1 200 Ok + + >> yield from reader.read_response_status() + ((1, 1), 200, 'Ok') + + """ + bline = yield from self.readline() + if not bline: + # Presumably, the server closed the connection before + # sending a valid response. + raise errors.BadStatusLine(bline) + + try: + line = bline.decode('ascii').rstrip() + except UnicodeDecodeError: + raise errors.BadStatusLine(bline) from None + + try: + version, status = line.split(None, 1) + except ValueError: + raise errors.BadStatusLine(line) from None + else: + try: + status, reason = status.split(None, 1) + except ValueError: + reason = '' + + # version + match = VERSRE.match(version) + if match is None: + raise errors.BadStatusLine(line) + version = (int(match.group(1)), int(match.group(2))) + + # The status code is a three-digit number + try: + status = int(status) + except ValueError: + raise errors.BadStatusLine(line) from None + + if status < 100 or status > 999: + raise errors.BadStatusLine(line) + + return ResponseStatus(version, status, reason.strip()) + + @tulip.coroutine + def read_headers(self): + """Read and parses RFC2822 headers from a stream. + + Line continuations are supported. Returns list of header name + and value pairs. Header name is in upper case. + """ + size = 0 + headers = [] + + line = yield from self.readline() + + while line not in (b'\r\n', b'\n'): + header_length = len(line) + + # Parse initial header name : value pair. + sep_pos = line.find(b':') + if sep_pos < 0: + raise ValueError('Invalid header %s' % line.strip()) + + name, value = line[:sep_pos], line[sep_pos+1:] + name = name.rstrip(b' \t').upper() + if HDRRE.search(name): + raise ValueError('Invalid header name %s' % name) + + name = name.strip().decode('ascii', 'surrogateescape') + value = [value.lstrip()] + + # next line + line = yield from self.readline() + + # consume continuation lines + continuation = line.startswith(CONTINUATION) + + if continuation: + while continuation: + header_length += len(line) + if header_length > self.MAX_HEADERFIELD_SIZE: + raise errors.LineTooLong( + 'limit request headers fields size') + value.append(line) + + line = yield from self.readline() + continuation = line.startswith(CONTINUATION) + else: + if header_length > self.MAX_HEADERFIELD_SIZE: + raise errors.LineTooLong( + 'limit request headers fields size') + + # total headers size + size += header_length + if size >= self.MAX_HEADERS: + raise errors.LineTooLong('limit request headers fields') + + headers.append( + (name, + b''.join(value).rstrip().decode('ascii', 'surrogateescape'))) + + return headers + + def _parse_chunked_payload(self): + """Chunked transfer encoding parser.""" + stream = yield + + try: + data = bytearray() + + while True: + # read line + if b'\n' not in data: + data.extend((yield)) + continue + + line, data = data.split(b'\n', 1) + + # Read the next chunk size from the file + i = line.find(b';') + if i >= 0: + line = line[:i] # strip chunk-extensions + try: + size = int(line, 16) + except ValueError: + raise errors.IncompleteRead(b'') from None + + if size == 0: + break + + # read chunk + while len(data) < size: + data.extend((yield)) + + # feed stream + stream.feed_data(data[:size]) + + data = data[size:] + + # toss the CRLF at the end of the chunk + while len(data) < 2: + data.extend((yield)) + + data = data[2:] + + # read and discard trailer up to the CRLF terminator + while True: + if b'\n' in data: + line, data = data.split(b'\n', 1) + if line in (b'\r', b''): + break + else: + data.extend((yield)) + + # stream eof + stream.feed_eof() + return data + + except StreamEofException: + stream.set_exception(errors.IncompleteRead(b'')) + except errors.IncompleteRead as exc: + stream.set_exception(exc) + + def _parse_length_payload(self, length): + """Read specified amount of bytes.""" + stream = yield + + try: + data = bytearray() + while length: + data.extend((yield)) + + data_len = len(data) + if data_len <= length: + stream.feed_data(data) + data = bytearray() + length -= data_len + else: + stream.feed_data(data[:length]) + data = data[length:] + length = 0 + + stream.feed_eof() + return data + except StreamEofException: + stream.set_exception(errors.IncompleteRead(b'')) + + def _parse_eof_payload(self): + """Read all bytes untile eof.""" + stream = yield + + try: + while True: + stream.feed_data((yield)) + except StreamEofException: + stream.feed_eof() + + @tulip.coroutine + def read_message(self, version=(1, 1), + length=None, compression=True, readall=False): + """Read RFC2822 headers and message payload from a stream. + + read_message() automatically decompress gzip and deflate content + encoding. To prevent decompression pass compression=False. + + Returns tuple of headers, payload stream, should close flag, + compression type. + """ + # load headers + headers = yield from self.read_headers() + + # payload params + chunked = False + encoding = None + close_conn = None + + for name, value in headers: + if name == 'CONTENT-LENGTH': + length = value + elif name == 'TRANSFER-ENCODING': + chunked = value.lower() == 'chunked' + elif name == 'SEC-WEBSOCKET-KEY1': + length = 8 + elif name == 'CONNECTION': + v = value.lower() + if v == 'close': + close_conn = True + elif v == 'keep-alive': + close_conn = False + elif compression and name == 'CONTENT-ENCODING': + enc = value.lower() + if enc in ('gzip', 'deflate'): + encoding = enc + + if close_conn is None: + close_conn = version <= (1, 0) + + # payload parser + if chunked: + parser = self._parse_chunked_payload() + + elif length is not None: + try: + length = int(length) + except ValueError: + raise errors.InvalidHeader('CONTENT-LENGTH') from None + + if length < 0: + raise errors.InvalidHeader('CONTENT-LENGTH') + + parser = self._parse_length_payload(length) + else: + if readall: + parser = self._parse_eof_payload() + else: + parser = self._parse_length_payload(0) + + next(parser) + + payload = stream = tulip.StreamReader() + + # payload decompression wrapper + if encoding is not None: + stream = DeflateStream(stream, encoding) + + try: + # initialize payload parser with stream, stream is being + # used by parser as destination stream + parser.send(stream) + except StopIteration: + pass + else: + # feed existing buffer to payload parser + self.byte_count = 0 + while self.buffer: + try: + parser.send(self.buffer.popleft()) + except StopIteration as exc: + parser = None + + # parser is done + buf = b''.join(self.buffer) + self.buffer.clear() + + # re-add remaining data back to buffer + if exc.value: + self.feed_data(exc.value) + + if buf: + self.feed_data(buf) + + break + + # parser still require more data + if parser is not None: + if self.eof: + try: + parser.throw(StreamEofException()) + except StopIteration as exc: + pass + else: + self._parser = parser + + return RawHttpMessage(headers, payload, close_conn, encoding) + + +class StreamEofException(Exception): + """Internal exception: eof received.""" + + +class DeflateStream: + """DeflateStream decomress stream and feed data into specified stream.""" + + def __init__(self, stream, encoding): + self.stream = stream + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + + self.zlib = zlib.decompressobj(wbits=zlib_mode) + + def set_exception(self, exc): + self.stream.set_exception(exc) + + def feed_data(self, chunk): + try: + chunk = self.zlib.decompress(chunk) + except: + self.stream.set_exception(errors.IncompleteRead(b'')) + + if chunk: + self.stream.feed_data(chunk) + + def feed_eof(self): + self.stream.feed_data(self.zlib.flush()) + if not self.zlib.eof: + self.stream.set_exception(errors.IncompleteRead(b'')) + + self.stream.feed_eof() + + +EOF_MARKER = object() +EOL_MARKER = object() + + +def wrap_payload_filter(func): + """Wraps payload filter and piped filters. + + Filter is a generatator that accepts arbitrary chunks of data, + modify data and emit new stream of data. + + For example we have stream of chunks: ['1', '2', '3', '4', '5'], + we can apply chunking filter to this stream: + + ['1', '2', '3', '4', '5'] + | + response.add_chunking_filter(2) + | + ['12', '34', '5'] + + It is possible to use different filters at the same time. + + For a example to compress incoming stream with 'deflate' encoding + and then split data and emit chunks of 8196 bytes size chunks: + + >> response.add_compression_filter('deflate') + >> response.add_chunking_filter(8196) + + Filters do not alter transfer encoding. + + Filter can receive types types of data, bytes object or EOF_MARKER. + + 1. If filter receives bytes object, it should process data + and yield processed data then yield EOL_MARKER object. + 2. If Filter recevied EOF_MARKER, it should yield remaining + data (buffered) and then yield EOF_MARKER. + """ + @functools.wraps(func) + def wrapper(self, *args, **kw): + new_filter = func(self, *args, **kw) + + filter = self.filter + if filter is not None: + next(new_filter) + self.filter = filter_pipe(filter, new_filter) + else: + self.filter = new_filter + + next(self.filter) + + return wrapper + + +def filter_pipe(filter, filter2): + """Creates pipe between two filters. + + filter_pipe() feeds first filter with incoming data and then + send yielded from first filter data into filter2, results of + filter2 are being emitted. + + 1. If filter_pipe receives bytes object, it sends it to the first filter. + 2. Reads yielded values from the first filter until it receives + EOF_MARKER or EOL_MARKER. + 3. Each of this values is being send to second filter. + 4. Reads yielded values from second filter until it recives EOF_MARKER or + EOL_MARKER. Each of this values yields to writer. + """ + chunk = yield + + while True: + eof = chunk is EOF_MARKER + chunk = filter.send(chunk) + + while chunk is not EOL_MARKER: + chunk = filter2.send(chunk) + + while chunk not in (EOF_MARKER, EOL_MARKER): + yield chunk + chunk = next(filter2) + + if chunk is not EOF_MARKER: + if eof: + chunk = EOF_MARKER + else: + chunk = next(filter) + else: + break + + chunk = yield EOL_MARKER + + +class HttpMessage: + """HttpMessage allows to write headers and payload to a stream. + + For example, lets say we want to read file then compress it with deflate + compression and then send it with chunked transfer encoding, code may look + like this: + + >> response = tulip.http.Response(transport, 200) + + We have to use deflate compression first: + + >> response.add_compression_filter('deflate') + + Then we want to split output stream into chunks of 1024 bytes size: + + >> response.add_chunking_filter(1024) + + We can add headers to response with add_headers() method. add_headers() + does not send data to transport, send_headers() sends request/response + line and then sends headers: + + >> response.add_headers( + .. ('Content-Disposition', 'attachment; filename="..."')) + >> response.send_headers() + + Now we can use chunked writer to write stream to a network stream. + First call to write() method sends response status line and headers, + add_header() and add_headers() method unavailble at this stage: + + >> with open('...', 'rb') as f: + .. chunk = fp.read(8196) + .. while chunk: + .. response.write(chunk) + .. chunk = fp.read(8196) + + >> response.write_eof() + """ + + writer = None + + # 'filter' is being used for altering write() bahaviour, + # add_chunking_filter adds deflate/gzip compression and + # add_compression_filter splits incoming data into a chunks. + filter = None + + HOP_HEADERS = None # Must be set by subclass. + + SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} tulip/0.0'.format(sys.version_info) + + status = None + status_line = b'' + + # subclass can enable auto sending headers with write() call, + # this is useful for wsgi's start_response implementation. + _send_headers = False + + def __init__(self, transport, version, close): + self.transport = transport + self.version = version + self.closing = close + self.keepalive = False + + self.chunked = False + self.length = None + self.upgrade = False + self.headers = [] + self.headers_sent = False + + def force_close(self): + self.closing = True + + def force_chunked(self): + self.chunked = True + + def keep_alive(self): + return self.keepalive and not self.closing + + def is_headers_sent(self): + return self.headers_sent + + def add_header(self, name, value): + """Analyze headers. Calculate content length, + removes hop headers, etc.""" + assert not self.headers_sent, 'headers have been sent already' + assert isinstance(name, str), '%r is not a string' % name + + name = name.strip().upper() + + if name == 'CONTENT-LENGTH': + self.length = int(value) + + if name == 'CONNECTION': + val = value.lower().strip() + # handle websocket + if val == 'upgrade': + self.upgrade = True + # connection keep-alive + elif val == 'close': + self.keepalive = False + elif val == 'keep-alive': + self.keepalive = True + + elif name == 'UPGRADE': + if 'websocket' in value.lower(): + self.headers.append((name, value)) + + elif name == 'TRANSFER-ENCODING' and not self.chunked: + self.chunked = value.lower().strip() == 'chunked' + + elif name not in self.HOP_HEADERS: + # ignore hopbyhop headers + self.headers.append((name, value)) + + def add_headers(self, *headers): + """Adds headers to a http message.""" + for name, value in headers: + self.add_header(name, value) + + def send_headers(self): + """Writes headers to a stream. Constructs payload writer.""" + # Chunked response is only for HTTP/1.1 clients or newer + # and there is no Content-Length header is set. + # Do not use chunked responses when the response is guaranteed to + # not have a response body (304, 204). + assert not self.headers_sent, 'headers have been sent already' + self.headers_sent = True + + if (self.chunked is True) or ( + self.length is None and + self.version >= (1, 1) and + self.status not in (304, 204)): + self.chunked = True + self.writer = self._write_chunked_payload() + + elif self.length is not None: + self.writer = self._write_length_payload(self.length) + + else: + self.writer = self._write_eof_payload() + + next(self.writer) + + # status line + self.transport.write(self.status_line.encode('ascii')) + + # send headers + self.transport.write( + ('%s\r\n\r\n' % '\r\n'.join( + ('%s: %s' % (k, v) for k, v in + itertools.chain(self._default_headers(), self.headers))) + ).encode('ascii')) + + def _default_headers(self): + # set the connection header + if self.upgrade: + connection = 'upgrade' + elif self.keep_alive(): + connection = 'keep-alive' + else: + connection = 'close' + + headers = [('CONNECTION', connection)] + + if self.chunked: + headers.append(('TRANSFER-ENCODING', 'chunked')) + + return headers + + def write(self, chunk): + """write() writes chunk of data to a steram by using different writers. + writer uses filter to modify chunk of data. write_eof() indicates + end of stream. writer can't be used after write_eof() method + being called.""" + assert (isinstance(chunk, (bytes, bytearray)) or + chunk is EOF_MARKER), chunk + + if self._send_headers and not self.headers_sent: + self.send_headers() + + assert self.writer is not None, 'send_headers() is not called.' + + if self.filter: + chunk = self.filter.send(chunk) + while chunk not in (EOF_MARKER, EOL_MARKER): + self.writer.send(chunk) + chunk = next(self.filter) + else: + if chunk is not EOF_MARKER: + self.writer.send(chunk) + + def write_eof(self): + self.write(EOF_MARKER) + try: + self.writer.throw(StreamEofException()) + except StopIteration: + pass + + def _write_chunked_payload(self): + """Write data in chunked transfer encoding.""" + while True: + try: + chunk = yield + except StreamEofException: + self.transport.write(b'0\r\n\r\n') + break + + self.transport.write('{:x}\r\n'.format(len(chunk)).encode('ascii')) + self.transport.write(chunk) + self.transport.write(b'\r\n') + + def _write_length_payload(self, length): + """Write specified number of bytes to a stream.""" + while True: + try: + chunk = yield + except StreamEofException: + break + + if length: + l = len(chunk) + if length >= l: + self.transport.write(chunk) + else: + self.transport.write(chunk[:length]) + + length = max(0, length-l) + + def _write_eof_payload(self): + while True: + try: + chunk = yield + except StreamEofException: + break + + self.transport.write(chunk) + + @wrap_payload_filter + def add_chunking_filter(self, chunk_size=16*1024): + """Split incoming stream into chunks.""" + buf = bytearray() + chunk = yield + + while True: + if chunk is EOF_MARKER: + if buf: + yield buf + + yield EOF_MARKER + + else: + buf.extend(chunk) + + while len(buf) >= chunk_size: + chunk, buf = buf[:chunk_size], buf[chunk_size:] + yield chunk + + chunk = yield EOL_MARKER + + @wrap_payload_filter + def add_compression_filter(self, encoding='deflate'): + """Compress incoming stream with deflate or gzip encoding.""" + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + zcomp = zlib.compressobj(wbits=zlib_mode) + + chunk = yield + while True: + if chunk is EOF_MARKER: + yield zcomp.flush() + chunk = yield EOF_MARKER + + else: + yield zcomp.compress(chunk) + chunk = yield EOL_MARKER + + +class Response(HttpMessage): + """Create http response message. + + Transport is a socket stream transport. status is a response status code, + status has to be integer value. http_version is a tuple that represents + http version, (1, 0) stands for HTTP/1.0 and (1, 1) is for HTTP/1.1 + """ + + HOP_HEADERS = { + 'CONNECTION', + 'KEEP-ALIVE', + 'PROXY-AUTHENTICATE', + 'PROXY-AUTHORIZATION', + 'TE', + 'TRAILERS', + 'TRANSFER-ENCODING', + 'UPGRADE', + 'SERVER', + 'DATE', + } + + def __init__(self, transport, status, http_version=(1, 1), close=False): + super().__init__(transport, http_version, close) + + self.status = status + self.status_line = 'HTTP/{0[0]}.{0[1]} {1} {2}\r\n'.format( + http_version, status, RESPONSES[status][0]) + + def _default_headers(self): + headers = super()._default_headers() + headers.extend((('DATE', email.utils.formatdate()), + ('SERVER', self.SERVER_SOFTWARE))) + + return headers + + +class Request(HttpMessage): + + HOP_HEADERS = () + + def __init__(self, transport, method, uri, + http_version=(1, 1), close=False): + super().__init__(transport, http_version, close) + + self.method = method + self.uri = uri + self.status_line = '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format( + method, uri, http_version) + + def _default_headers(self): + headers = super()._default_headers() + headers.append(('USER-AGENT', self.SERVER_SOFTWARE)) + + return headers diff --git a/tulip/http/server.py b/tulip/http/server.py new file mode 100644 index 0000000..7590e47 --- /dev/null +++ b/tulip/http/server.py @@ -0,0 +1,176 @@ +"""simple http server.""" + +__all__ = ['ServerHttpProtocol'] + +import http.server +import inspect +import logging +import traceback + +import tulip +import tulip.http + +from . import errors + +RESPONSES = http.server.BaseHTTPRequestHandler.responses +DEFAULT_ERROR_MESSAGE = """ + + + %(status)s %(reason)s + + +

%(status)s %(reason)s

+ %(message)s + +""" + + +class ServerHttpProtocol(tulip.Protocol): + """Simple http protocol implementation. + + ServerHttpProtocol handles incoming http request. It reads request line, + request headers and request payload and calls handler_request() method. + By default it always returns with 404 respose. + + ServerHttpProtocol handles errors in incoming request, like bad + status line, bad headers or incomplete payload. If any error occurs, + connection gets closed. + """ + closing = False + request_count = 0 + _request_handle = None + + def __init__(self, log=logging, debug=False): + self.log = log + self.debug = debug + + def connection_made(self, transport): + self.transport = transport + self.stream = tulip.http.HttpStreamReader() + self._request_handle = self.start() + + def data_received(self, data): + self.stream.feed_data(data) + + def connection_lost(self, exc): + if self._request_handle is not None: + self._request_handle.cancel() + self._request_handle = None + + def eof_received(self): + self.stream.feed_eof() + + def close(self): + self.closing = True + + def log_access(self, status, info, message, *args, **kw): + pass + + def log_debug(self, *args, **kw): + if self.debug: + self.log.debug(*args, **kw) + + def log_exception(self, *args, **kw): + self.log.exception(*args, **kw) + + @tulip.task + def start(self): + """Start processing of incoming requests. + It reads request line, request headers and request payload, then + calls handle_request() method. Subclass has to override + handle_request(). start() handles various excetions in request + or response handling. In case of any error connection is being closed. + """ + + while True: + info = None + message = None + self.request_count += 1 + + try: + info = yield from self.stream.read_request_line() + message = yield from self.stream.read_message(info.version) + + handler = self.handle_request(info, message) + if (inspect.isgenerator(handler) or + isinstance(handler, tulip.Future)): + yield from handler + + except tulip.CancelledError: + self.log_debug('Ignored premature client disconnection.') + break + except errors.HttpException as exc: + self.handle_error(exc.code, info, message, exc, exc.headers) + except Exception as exc: + self.handle_error(500, info, message, exc) + finally: + if self.closing: + self.transport.close() + break + + self._request_handle = None + + def handle_error(self, status=500, info=None, + message=None, exc=None, headers=None): + """Handle errors. + + Returns http response with specific status code. Logs additional + information. It always closes current connection.""" + + if status == 500: + self.log_exception("Error handling request") + + try: + reason, msg = RESPONSES[status] + except KeyError: + status = 500 + reason, msg = '???', '' + + if self.debug and exc is not None: + try: + tb = traceback.format_exc() + msg += '

Traceback:

\n
%s
' % tb + except: + pass + + self.log_access(status, info, message) + + html = DEFAULT_ERROR_MESSAGE % { + 'status': status, 'reason': reason, 'message': msg} + + response = tulip.http.Response(self.transport, status, close=True) + response.add_headers( + ('Content-Type', 'text/html'), + ('Content-Length', str(len(html)))) + if headers is not None: + response.add_headers(*headers) + response.send_headers() + + response.write(html.encode('ascii')) + response.write_eof() + + self.close() + + def handle_request(self, info, message): + """Handle a single http request. + + Subclass should override this method. By default it always + returns 404 response. + + info: tulip.http.RequestLine instance + message: tulip.http.RawHttpMessage instance + """ + response = tulip.http.Response( + self.transport, 404, http_version=info.version, close=True) + + body = b'Page Not Found!' + + response.add_headers( + ('Content-Type', 'text/plain'), + ('Content-Length', str(len(body)))) + response.send_headers() + response.write(body) + response.write_eof() + + self.close() + self.log_access(404, info, message) diff --git a/tulip/locks.py b/tulip/locks.py new file mode 100644 index 0000000..4024796 --- /dev/null +++ b/tulip/locks.py @@ -0,0 +1,433 @@ +"""Synchronization primitives""" + +__all__ = ['Lock', 'EventWaiter', 'Condition', 'Semaphore'] + +import collections +import time + +from . import events +from . import futures +from . import tasks + + +class Lock: + """The class implementing primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned by + a particular coroutine when locked. A primitive lock is in one of two + states, "locked" or "unlocked". + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() changes + the state to locked and returns immediately. When the state is locked, + acquire() blocks until a call to release() in another coroutine changes + it to unlocked, then the acquire() call resets it to locked and returns. + The release() method should only be called in the locked state; it changes + the state to unlocked and returns immediately. If an attempt is made + to release an unlocked lock, a RuntimeError will be raised. + + When more than one coroutine is blocked in acquire() waiting for the state + to turn to unlocked, only one coroutine proceeds when a release() call + resets the state to unlocked; first coroutine which is blocked in acquire() + is being processed. + + acquire() method is a coroutine and should be called with "yield from" + + Locks also support the context manager protocol. (yield from lock) should + be used as context manager expression. + + Usage: + + lock = Lock() + ... + yield from lock + try: + ... + finally: + lock.release() + + Context manager usage: + + lock = Lock() + ... + with (yield from lock): + ... + + Lock object could be tested for locking state: + + if not lock.locked(): + yield from lock + else: + # lock is acquired + ... + + """ + + def __init__(self): + self._waiters = collections.deque() + self._locked = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<%s [%s]>' % ( + res[1:-1], 'locked' if self._locked else 'unlocked') + + def locked(self): + """Return true if lock is acquired.""" + return self._locked + + @tasks.coroutine + def acquire(self, timeout=None): + """Acquire a lock. + + Acquire method blocks until the lock is unlocked, then set it to + locked and return True. + + When invoked with the floating-point timeout argument set, blocks for + at most the number of seconds specified by timeout and as long as + the lock cannot be acquired. + + The return value is True if the lock is acquired successfully, + False if not (for example if the timeout expired). + """ + if not self._waiters and not self._locked: + self._locked = True + return True + + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + self._locked = True + return True + + def release(self): + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + If any other coroutines are blocked waiting for the lock to become + unlocked, allow exactly one of them to proceed. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + if self._waiters: + self._waiters[0].set_result(True) + else: + raise RuntimeError('Lock is not acquired.') + + def __enter__(self): + if not self._locked: + raise RuntimeError( + '"yield from" should be used as context manager expression') + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self + + +class EventWaiter: + """A EventWaiter implementation, our equivalent to threading.Event + + Class implementing event objects. An event manages a flag that can be set + to true with the set() method and reset to false with the clear() method. + The wait() method blocks until the flag is true. The flag is initially + false. + """ + + def __init__(self): + self._waiters = collections.deque() + self._value = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<%s [%s]>' % (res[1:-1], 'set' if self._value else 'unset') + + def is_set(self): + """Return true if and only if the internal flag is true.""" + return self._value + + def set(self): + """Set the internal flag to true. All coroutines waiting for it to + become true are awakened. Coroutine that call wait() once the flag is + true will not block at all. + """ + if not self._value: + self._value = True + + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + + def clear(self): + """Reset the internal flag to false. Subsequently, coroutines calling + wait() will block until set() is called to set the internal flag + to true again.""" + self._value = False + + @tasks.coroutine + def wait(self, timeout=None): + """Block until the internal flag is true. If the internal flag + is true on entry, return immediately. Otherwise, block until another + coroutine calls set() to set the flag to true, or until the optional + timeout occurs. + + When the timeout argument is present and not None, it should be + a floating point number specifying a timeout for the operation in + seconds (or fractions thereof). + + This method returns true if and only if the internal flag has been + set to true, either before the wait call or after the wait starts, + so it will always return True except if a timeout is given and + the operation times out. + + wait() method is a coroutine. + """ + if self._value: + return True + + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + return True + + +class Condition(Lock): + """A Condition implementation. + + This class implements condition variable objects. A condition variable + allows one or more coroutines to wait until they are notified by another + coroutine. + """ + + def __init__(self): + super().__init__() + + self._condition_waiters = collections.deque() + + @tasks.coroutine + def wait(self, timeout=None): + """Wait until notified or until a timeout occurs. If the calling + coroutine has not acquired the lock when this method is called, + a RuntimeError is raised. + + This method releases the underlying lock, and then blocks until it is + awakened by a notify() or notify_all() call for the same condition + variable in another coroutine, or until the optional timeout occurs. + Once awakened or timed out, it re-acquires the lock and returns. + + When the timeout argument is present and not None, it should be + a floating point number specifying a timeout for the operation + in seconds (or fractions thereof). + + The return value is True unless a given timeout expired, in which + case it is False. + """ + if not self._locked: + raise RuntimeError('cannot wait on un-acquired lock') + + self.release() + + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + + self._condition_waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._condition_waiters.remove(fut) + return False + else: + f = self._condition_waiters.popleft() + assert fut is f + finally: + yield from self.acquire() + + return True + + @tasks.coroutine + def wait_for(self, predicate, timeout=None): + """Wait until a condition evaluates to True. predicate should be a + callable which result will be interpreted as a boolean value. A timeout + may be provided giving the maximum time to wait. + """ + endtime = None + waittime = timeout + result = predicate() + + while not result: + if waittime is not None: + if endtime is None: + endtime = time.monotonic() + waittime + else: + waittime = endtime - time.monotonic() + if waittime <= 0: + break + + yield from self.wait(waittime) + result = predicate() + + return result + + def notify(self, n=1): + """By default, wake up one coroutine waiting on this condition, if any. + If the calling coroutine has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up at most n of the coroutines waiting for the + condition variable; it is a no-op if no coroutines are waiting. + + Note: an awakened coroutine does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self._locked: + raise RuntimeError('cannot notify on un-acquired lock') + + idx = 0 + for fut in self._condition_waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self): + """Wake up all threads waiting on this condition. This method acts + like notify(), but wakes up all waiting threads instead of one. If the + calling thread has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._condition_waiters)) + + +class Semaphore: + """A Semaphore implementation. + + A semaphore manages an internal counter which is decremented by each + acquire() call and incremented by each release() call. The counter + can never go below zero; when acquire() finds that it is zero, it blocks, + waiting until some other thread calls release(). + + Semaphores also support the context manager protocol. + + The first optional argument gives the initial value for the internal + counter; it defaults to 1. If the value given is less than 0, + ValueError is raised. + + The second optional argument determins can semophore be released more than + initial internal counter value; it defaults to False. If the value given + is True and number of release() is more than number of successfull + acquire() calls ValueError is raised. + """ + + def __init__(self, value=1, bound=False): + if value < 0: + raise ValueError("Semaphore initial value must be > 0") + self._value = value + self._bound = bound + self._bound_value = value + self._waiters = collections.deque() + self._locked = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<%s [%s]>' % ( + res[1:-1], + 'locked' if self._locked else 'unlocked,value:%s' % self._value) + + def locked(self): + """Returns True if semaphore can not be acquired immediately.""" + return self._locked + + @tasks.coroutine + def acquire(self, timeout=None): + """Acquire a semaphore. acquire() method is a coroutine. + + When invoked without arguments: if the internal counter is larger + than zero on entry, decrement it by one and return immediately. + If it is zero on entry, block, waiting until some other coroutine has + called release() to make it larger than zero. + + When invoked with a timeout other than None, it will block for at + most timeout seconds. If acquire does not complete successfully in + that interval, return false. Return true otherwise. + """ + if not self._waiters and self._value > 0: + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + When it was zero on entry and another coroutine is waiting for it to + become larger than zero again, wake up that coroutine. + + If Semaphore is create with "bound" paramter equals true, then + release() method checks to make sure its current value doesn't exceed + its initial value. If it does, ValueError is raised. + """ + if self._bound and self._value >= self._bound_value: + raise ValueError('Semaphore released too many times') + + self._value += 1 + self._locked = False + + for waiter in self._waiters: + if not waiter.done(): + waiter.set_result(True) + break + + def __enter__(self): + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self diff --git a/tulip/log.py b/tulip/log.py new file mode 100644 index 0000000..b918fe5 --- /dev/null +++ b/tulip/log.py @@ -0,0 +1,6 @@ +"""Tulip logging configuration""" + +import logging + + +tulip_log = logging.getLogger("tulip") diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py new file mode 100644 index 0000000..46ffa58 --- /dev/null +++ b/tulip/proactor_events.py @@ -0,0 +1,189 @@ +"""Event loop using a proactor and related classes. + +A proactor is a "notify-on-completion" multiplexer. Currently a +proactor is only implemented on Windows with IOCP. +""" + +from . import base_events +from . import transports +from .log import tulip_log + + +class _ProactorSocketTransport(transports.Transport): + + def __init__(self, event_loop, sock, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._event_loop = event_loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._read_fut = None + self._write_fut = None + self._closing = False # Set when close() called. + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.call_soon(self._loop_reading) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _loop_reading(self, fut=None): + data = None + + try: + if fut is not None: + assert fut is self._read_fut + + data = fut.result() # deliver data later in "finally" clause + if not data: + self._read_fut = None + return + + self._read_fut = self._event_loop._proactor.recv(self._sock, 4096) + except ConnectionAbortedError as exc: + if not self._closing: + self._fatal_error(exc) + except OSError as exc: + self._fatal_error(exc) + else: + self._read_fut.add_done_callback(self._loop_reading) + finally: + if data: + self._protocol.data_received(data) + elif data is not None: + self._protocol.eof_received() + + def write(self, data): + assert isinstance(data, bytes), repr(data) + assert not self._closing + if not data: + return + self._buffer.append(data) + if not self._write_fut: + self._loop_writing() + + def _loop_writing(self, f=None): + try: + assert f is self._write_fut + if f: + f.result() + data = b''.join(self._buffer) + self._buffer = [] + if not data: + self._write_fut = None + return + self._write_fut = self._event_loop._proactor.send(self._sock, data) + except OSError as exc: + self._fatal_error(exc) + else: + self._write_fut.add_done_callback(self._loop_writing) + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + if self._write_fut: + self._write_fut.cancel() + if not self._buffer: + self._event_loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + tulip_log.exception('Fatal error for %s', self) + if self._write_fut: + self._write_fut.cancel() + if self._read_fut: # XXX + self._read_fut.cancel() + self._write_fut = self._read_fut = None + self._buffer = [] + self._event_loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + + +class BaseProactorEventLoop(base_events.BaseEventLoop): + + def __init__(self, proactor): + super().__init__() + tulip_log.debug('Using proactor: %s', proactor.__class__.__name__) + self._proactor = proactor + self._selector = proactor # convenient alias + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + return _ProactorSocketTransport(self, sock, protocol, waiter, extra) + + def close(self): + if self._proactor is not None: + self._close_self_pipe() + self._proactor.close() + self._proactor = None + self._selector = None + + def sock_recv(self, sock, n): + return self._proactor.recv(sock, n) + + def sock_sendall(self, sock, data): + return self._proactor.send(sock, data) + + def sock_connect(self, sock, address): + return self._proactor.connect(sock, address) + + def sock_accept(self, sock): + return self._proactor.accept(sock) + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + + def loop(f=None): + try: + if f: + f.result() # may raise + f = self._proactor.recv(self._ssock, 4096) + except: + self.close() + raise + else: + f.add_done_callback(loop) + self.call_soon(loop) + + def _write_to_self(self): + self._csock.send(b'x') + + def _start_serving(self, protocol_factory, sock): + def loop(f=None): + try: + if f: + conn, addr = f.result() + protocol = protocol_factory() + self._make_socket_transport( + conn, protocol, extra={'addr': addr}) + f = self._proactor.accept(sock) + except OSError: + sock.close() + tulip_log.exception('Accept failed') + else: + f.add_done_callback(loop) + self.call_soon(loop) + + def _process_events(self, event_list): + pass # XXX hard work currently done in poll diff --git a/tulip/protocols.py b/tulip/protocols.py new file mode 100644 index 0000000..593ee74 --- /dev/null +++ b/tulip/protocols.py @@ -0,0 +1,78 @@ +"""Abstract Protocol class.""" + +__all__ = ['Protocol', 'DatagramProtocol'] + + +class BaseProtocol: + """ABC for base protocol class. + + Usually user implements protocols that derived from BaseProtocol + like Protocol or ProcessProtocol. + + The only case when BaseProtocol should be implemented directly is + write-only transport like write pipe + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the pipe connection. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +class Protocol(BaseProtocol): + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing (they don't raise exceptions). + + When the user wants to requests a transport, they pass a protocol + factory to a utility function (e.g., EventLoop.create_connection()). + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_lost() will be called exactly once + with either an exception object or None as an argument. + + State machine of calls: + + start -> CM [-> DR*] [-> ER?] -> CL -> end + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + """ + + def eof_received(self): + """Called when the other end calls write_eof() or equivalent. + + The default implementation does nothing. + + TODO: By default close the transport. But we don't have the + transport as an instance variable (connection_made() may not + set it). + """ + + +class DatagramProtocol(BaseProtocol): + """ABC representing a datagram protocol.""" + + def datagram_received(self, data, addr): + """Called when some datagram is received.""" + + def connection_refused(self, exc): + """Connection is refused.""" diff --git a/tulip/queues.py b/tulip/queues.py new file mode 100644 index 0000000..ee349e1 --- /dev/null +++ b/tulip/queues.py @@ -0,0 +1,291 @@ +"""Queues""" + +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue'] + +import collections +import concurrent.futures +import heapq +import queue + +from . import events +from . import futures +from . import locks +from .tasks import coroutine + + +class Queue: + """A queue, useful for coordinating producer and consumer coroutines. + + If maxsize is less than or equal to zero, the queue size is infinite. If it + is an integer greater than 0, then "yield from put()" will block when the + queue reaches maxsize, until an item is removed by get(). + + Unlike the standard library Queue, you can reliably know this Queue's size + with qsize(), since your single-threaded Tulip application won't be + interrupted between calling qsize() and doing an operation on the Queue. + """ + + def __init__(self, maxsize=0): + self._event_loop = events.get_event_loop() + self._maxsize = maxsize + + # Futures. + self._getters = collections.deque() + # Pairs of (item, Future). + self._putters = collections.deque() + self._init(maxsize) + + def _init(self, maxsize): + self._queue = collections.deque() + + def _get(self): + return self._queue.popleft() + + def _put(self, item): + self._queue.append(item) + + def __repr__(self): + return '<%s at %s %s>' % ( + type(self).__name__, hex(id(self)), self._format()) + + def __str__(self): + return '<%s %s>' % (type(self).__name__, self._format()) + + def _format(self): + result = 'maxsize=%r' % (self._maxsize, ) + if getattr(self, '_queue', None): + result += ' _queue=%r' % list(self._queue) + if self._getters: + result += ' _getters[%s]' % len(self._getters) + if self._putters: + result += ' _putters[%s]' % len(self._putters) + return result + + def _consume_done_getters(self, waiters): + # Delete waiters at the head of the get() queue who've timed out. + while waiters and waiters[0].done(): + waiters.popleft() + + def _consume_done_putters(self): + # Delete waiters at the head of the put() queue who've timed out. + while self._putters and self._putters[0][1].done(): + self._putters.popleft() + + def qsize(self): + """Number of items in the queue.""" + return len(self._queue) + + @property + def maxsize(self): + """Number of items allowed in the queue.""" + return self._maxsize + + def empty(self): + """Return True if the queue is empty, False otherwise.""" + return not self._queue + + def full(self): + """Return True if there are maxsize items in the queue. + + Note: if the Queue was initialized with maxsize=0 (the default), + then full() is never True. + """ + if self._maxsize <= 0: + return False + else: + return self.qsize() == self._maxsize + + @coroutine + def put(self, item, timeout=None): + """Put an item into the queue. + + If you yield from put() and timeout is None (the default), wait until a + free slot is available before adding item. + + If a timeout is provided, raise queue.Full if no free slot becomes + available before the timeout. + """ + self._consume_done_getters(self._getters) + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + waiter = futures.Future( + event_loop=self._event_loop, timeout=timeout) + + self._putters.append((item, waiter)) + try: + yield from waiter + except concurrent.futures.CancelledError: + raise queue.Full + + else: + self._put(item) + + def put_nowait(self, item): + """Put an item into the queue without blocking. + + If no free slot is immediately available, raise queue.Full. + """ + self._consume_done_getters(self._getters) + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + raise queue.Full + else: + self._put(item) + + @coroutine + def get(self, timeout=None): + """Remove and return an item from the queue. + + If you yield from get() and timeout is None (the default), wait until a + item is available. + + If a timeout is provided, raise queue.Empty if no item is available + before the timeout. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + + # When a getter runs and frees up a slot so this putter can + # run, we need to defer the put for a tick to ensure that + # getters and putters alternate perfectly. See + # ChannelTest.test_wait. + self._event_loop.call_soon(putter.set_result, None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + waiter = futures.Future( + event_loop=self._event_loop, timeout=timeout) + + self._getters.append(waiter) + try: + return (yield from waiter) + except concurrent.futures.CancelledError: + raise queue.Empty + + def get_nowait(self): + """Remove and return an item from the queue. + + Return an item if one is immediately available, else raise queue.Full. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + # Wake putter on next tick. + putter.set_result(None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + raise queue.Empty + + +class PriorityQueue(Queue): + """A subclass of Queue; retrieves entries in priority order (lowest first). + + Entries are typically tuples of the form: (priority number, data). + """ + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item, heappush=heapq.heappush): + heappush(self._queue, item) + + def _get(self, heappop=heapq.heappop): + return heappop(self._queue) + + +class LifoQueue(Queue): + """A subclass of Queue that retrieves most recently added entries first.""" + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item): + self._queue.append(item) + + def _get(self): + return self._queue.pop() + + +class JoinableQueue(Queue): + """A subclass of Queue with task_done() and join() methods.""" + + def __init__(self, maxsize=0): + self._unfinished_tasks = 0 + self._finished = locks.EventWaiter() + self._finished.set() + super().__init__(maxsize=maxsize) + + def _format(self): + result = Queue._format(self) + if self._unfinished_tasks: + result += ' tasks=%s' % self._unfinished_tasks + return result + + def _put(self, item): + super()._put(item) + self._unfinished_tasks += 1 + self._finished.clear() + + def task_done(self): + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each get() used to fetch a task, + a subsequent call to task_done() tells the queue that the processing + on the task is complete. + + If a join() is currently blocking, it will resume when all items have + been processed (meaning that a task_done() call was received for every + item that had been put() into the queue). + + Raises ValueError if called more times than there were items placed in + the queue. + """ + if self._unfinished_tasks <= 0: + raise ValueError('task_done() called too many times') + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + @coroutine + def join(self, timeout=None): + """Block until all items in the queue have been gotten and processed. + + The count of unfinished tasks goes up whenever an item is added to the + queue. The count goes down whenever a consumer thread calls task_done() + to indicate that the item was retrieved and all work on it is complete. + When the count of unfinished tasks drops to zero, join() unblocks. + """ + if self._unfinished_tasks > 0: + yield from self._finished.wait(timeout=timeout) diff --git a/tulip/selector_events.py b/tulip/selector_events.py new file mode 100644 index 0000000..20e5db0 --- /dev/null +++ b/tulip/selector_events.py @@ -0,0 +1,655 @@ +"""Event loop using a selector and related classes. + +A selector is a "notify-when-ready" multiplexer. For a subclass which +also includes support for signal handling, see the unix_events sub-module. +""" + +import collections +import errno +import socket +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import base_events +from . import events +from . import futures +from . import selectors +from . import transports +from .log import tulip_log + + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + + +class BaseSelectorEventLoop(base_events.BaseEventLoop): + """Selector event loop. + + See events.EventLoop for API specification. + """ + + def __init__(self, selector=None): + super().__init__() + + if selector is None: + selector = selectors.Selector() + tulip_log.debug('Using selector: %s', selector.__class__.__name__) + self._selector = selector + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + return _SelectorSocketTransport(self, sock, protocol, waiter, extra) + + def _make_ssl_transport(self, rawsock, protocol, + sslcontext, waiter, extra=None): + return _SelectorSslTransport( + self, rawsock, protocol, sslcontext, waiter, extra) + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + return _SelectorDatagramTransport(self, sock, protocol, address, extra) + + def close(self): + if self._selector is not None: + self._close_self_pipe() + self._selector.close() + self._selector = None + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self.remove_reader(self._ssock.fileno()) + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.add_reader(self._ssock.fileno(), self._read_from_self) + + def _read_from_self(self): + try: + self._ssock.recv(1) + except (BlockingIOError, InterruptedError): + pass + + def _write_to_self(self): + try: + self._csock.send(b'x') + except (BlockingIOError, InterruptedError): + pass + + def _start_serving(self, protocol_factory, sock): + self.add_reader(sock.fileno(), self._accept_connection, + protocol_factory, sock) + + def _accept_connection(self, protocol_factory, sock): + try: + conn, addr = sock.accept() + except (BlockingIOError, InterruptedError): + pass # False alarm. + except: + # Bad error. Stop serving. + self.remove_reader(sock.fileno()) + sock.close() + # There's nowhere to send the error, so just log it. + # TODO: Someone will want an error handler for this. + tulip_log.exception('Accept failed') + else: + self._make_socket_transport( + conn, protocol_factory(), extra={'addr': addr}) + # It's now up to the protocol to handle the connection. + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a Handle instance.""" + handle = events.make_handle(callback, args) + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_READ, + (handle, None)) + else: + self._selector.modify(fd, mask | selectors.EVENT_READ, + (handle, writer)) + if reader is not None: + reader.cancel() + + return handle + + def remove_reader(self, fd): + """Remove a reader callback.""" + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + return False + else: + mask &= ~selectors.EVENT_READ + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (None, writer)) + + if reader is not None: + reader.cancel() + return True + else: + return False + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a Handle instance.""" + handle = events.make_handle(callback, args) + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_WRITE, + (None, handle)) + else: + self._selector.modify(fd, mask | selectors.EVENT_WRITE, + (reader, handle)) + if writer is not None: + writer.cancel() + + return handle + + def remove_writer(self, fd): + """Remove a writer callback.""" + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + return False + else: + # Remove both writer and connector. + mask &= ~selectors.EVENT_WRITE + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None)) + + if writer is not None: + writer.cancel() + return True + else: + return False + + def sock_recv(self, sock, n): + """XXX""" + fut = futures.Future() + self._sock_recv(fut, False, sock, n) + return fut + + def _sock_recv(self, fut, registered, sock, n): + fd = sock.fileno() + if registered: + # Remove the callback early. It should be rare that the + # selector says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. + self.remove_reader(fd) + if fut.cancelled(): + return + try: + data = sock.recv(n) + fut.set_result(data) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_recv, fut, True, sock, n) + except Exception as exc: + fut.set_exception(exc) + + def sock_sendall(self, sock, data): + """XXX""" + fut = futures.Future() + if data: + self._sock_sendall(fut, False, sock, data) + else: + fut.set_result(None) + return fut + + def _sock_sendall(self, fut, registered, sock, data): + fd = sock.fileno() + + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + + try: + n = sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + fut.set_exception(exc) + return + + if n == len(data): + fut.set_result(None) + else: + if n: + data = data[n:] + self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + + def sock_connect(self, sock, address): + """XXX""" + # That address better not require a lookup! We're not calling + # self.getaddrinfo() for you here. But verifying this is + # complicated; the socket module doesn't have a pattern for + # IPv6 addresses (there are too many forms, apparently). + fut = futures.Future() + self._sock_connect(fut, False, sock, address) + return fut + + def _sock_connect(self, fut, registered, sock, address): + fd = sock.fileno() + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + try: + if not registered: + # First time around. + sock.connect(address) + else: + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to the except clause below. + raise socket.error(err, 'Connect call failed') + fut.set_result(None) + except (BlockingIOError, InterruptedError): + self.add_writer(fd, self._sock_connect, fut, True, sock, address) + except Exception as exc: + fut.set_exception(exc) + + def sock_accept(self, sock): + """XXX""" + fut = futures.Future() + self._sock_accept(fut, False, sock) + return fut + + def _sock_accept(self, fut, registered, sock): + fd = sock.fileno() + if registered: + self.remove_reader(fd) + if fut.cancelled(): + return + try: + conn, address = sock.accept() + conn.setblocking(False) + fut.set_result((conn, address)) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_accept, fut, True, sock) + except Exception as exc: + fut.set_exception(exc) + + def _process_events(self, event_list): + for fileobj, mask, (reader, writer) in event_list: + if mask & selectors.EVENT_READ and reader is not None: + if reader.cancelled: + self.remove_reader(fileobj) + else: + self._add_callback(reader) + if mask & selectors.EVENT_WRITE and writer is not None: + if writer.cancelled: + self.remove_writer(fileobj) + else: + self._add_callback(writer) + + +class _SelectorSocketTransport(transports.Transport): + + def __init__(self, event_loop, sock, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._event_loop = event_loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._closing = False # Set when close() called. + self._event_loop.add_reader(self._sock.fileno(), self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = self._sock.recv(16*1024) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + self._event_loop.remove_reader(self._sock.fileno()) + self._protocol.eof_received() + + def write(self, data): + assert isinstance(data, (bytes, bytearray)), repr(data) + assert not self._closing + if not data: + return + + if not self._buffer: + # Attempt to send it right away first. + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except socket.error as exc: + self._fatal_error(exc) + return + + if n == len(data): + return + elif n: + data = data[n:] + self._event_loop.add_writer(self._sock.fileno(), self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + assert data, "Data should not be empty" + + self._buffer.clear() + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except Exception as exc: + self._fatal_error(exc) + else: + if n == len(data): + self._event_loop.remove_writer(self._sock.fileno()) + if self._closing: + self._call_connection_lost(None) + return + elif n: + data = data[n:] + + self._buffer.append(data) # Try again later. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._close(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sock.fileno()) + if not self._buffer: + self._call_connection_lost(None) + + def _fatal_error(self, exc): + # should be called from exception handler only + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._event_loop.remove_writer(self._sock.fileno()) + self._event_loop.remove_reader(self._sock.fileno()) + self._buffer.clear() + self._call_connection_lost(exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + + +class _SelectorSslTransport(transports.Transport): + + def __init__(self, event_loop, rawsock, + protocol, sslcontext, waiter, extra=None): + super().__init__(extra) + + self._event_loop = event_loop + self._rawsock = rawsock + self._protocol = protocol + sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslcontext = sslcontext + self._waiter = waiter + sslsock = sslcontext.wrap_socket(rawsock, + do_handshake_on_connect=False) + self._sslsock = sslsock + self._buffer = [] + self._closing = False # Set when close() called. + self._extra['socket'] = sslsock + + self._on_handshake() + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._event_loop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._event_loop.add_writer(fd, self._on_handshake) + return + except Exception as exc: + self._sslsock.close() + self._waiter.set_exception(exc) + return + except BaseException as exc: + self._sslsock.close() + self._waiter.set_exception(exc) + raise + self._event_loop.remove_reader(fd) + self._event_loop.remove_writer(fd) + self._event_loop.add_reader(fd, self._on_ready) + self._event_loop.add_writer(fd, self._on_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.call_soon(self._waiter.set_result, None) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer + self._event_loop.remove_reader(fd) + self._event_loop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + return + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = b''.join(self._buffer) + self._buffer = [] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + n = 0 + except ssl.SSLWantWriteError: + n = 0 + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + + if n < len(data): + self._buffer.append(data[n:]) + elif self._closing: + self._event_loop.remove_writer(self._sslsock.fileno()) + self._sslsock.close() + self._protocol.connection_lost(None) + + def write(self, data): + assert isinstance(data, bytes), repr(data) + assert not self._closing + if not data: + return + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._close(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sslsock.fileno()) + if not self._buffer: + self._protocol.connection_lost(None) + + def _fatal_error(self, exc): + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._event_loop.remove_writer(self._sslsock.fileno()) + self._event_loop.remove_reader(self._sslsock.fileno()) + self._buffer = [] + self._protocol.connection_lost(exc) + + +class _SelectorDatagramTransport(transports.DatagramTransport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, event_loop, sock, protocol, address=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._event_loop = event_loop + self._sock = sock + self._fileno = sock.fileno() + self._protocol = protocol + self._address = address + self._buffer = collections.deque() + self._closing = False # Set when close() called. + self._event_loop.add_reader(self._fileno, self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + + def _read_ready(self): + try: + data, addr = self._sock.recvfrom(self.max_size) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + self._protocol.datagram_received(data, addr) + + def sendto(self, data, addr=None): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + + if self._address: + assert addr in (None, self._address) + + if not self._buffer: + # Attempt to send it right away first. + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + return + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except (BlockingIOError, InterruptedError): + self._event_loop.add_writer(self._fileno, self._sendto_ready) + except Exception as exc: + self._fatal_error(exc) + return + + self._buffer.append((data, addr)) + + def _sendto_ready(self): + while self._buffer: + data, addr = self._buffer.popleft() + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except (BlockingIOError, InterruptedError): + self._buffer.appendleft((data, addr)) # Try again later. + break + except Exception as exc: + self._fatal_error(exc) + return + + if not self._buffer: + self._event_loop.remove_writer(self._fileno) + if self._closing: + self._call_connection_lost(None) + + def abort(self): + self._close(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._fileno) + if not self._buffer: + self._call_connection_lost(None) + + def _fatal_error(self, exc): + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._buffer.clear() + self._event_loop.remove_writer(self._fileno) + self._event_loop.remove_reader(self._fileno) + if self._address and isinstance(exc, ConnectionRefusedError): + self._protocol.connection_refused(exc) + self._call_connection_lost(exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() diff --git a/tulip/selectors.py b/tulip/selectors.py new file mode 100644 index 0000000..57be7ab --- /dev/null +++ b/tulip/selectors.py @@ -0,0 +1,418 @@ +"""Select module. + +This module supports asynchronous I/O on multiple file descriptors. +""" + +import sys +from select import * + +from .log import tulip_log + + +# generic events, that must be mapped to implementation-specific ones +# read event +EVENT_READ = (1 << 0) +# write event +EVENT_WRITE = (1 << 1) + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file descriptor, or any object with a `fileno()` method + + Returns: + corresponding file descriptor + """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (ValueError, TypeError): + raise ValueError("Invalid file object: {!r}".format(fileobj)) + return fd + + +class SelectorKey: + """Object used internally to associate a file object to its backing file + descriptor, selected event mask and attached data.""" + + def __init__(self, fileobj, events, data=None): + self.fileobj = fileobj + self.fd = _fileobj_to_fd(fileobj) + self.events = events + self.data = data + + def __repr__(self): + return '{}'.format( + self.__class__.__name__, + self.fileobj, self.fd, self.events, self.data) + + +class _BaseSelector: + """Base selector class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + performant implementation on the current platform. + """ + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # this maps file objects to keys - for fast (un)registering + self._fileobj_to_key = {} + + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + """ + if (not events) or (events & ~(EVENT_READ|EVENT_WRITE)): + raise ValueError("Invalid events: {}".format(events)) + + if fileobj in self._fileobj_to_key: + raise ValueError("{!r} is already registered".format(fileobj)) + + key = SelectorKey(fileobj, events, data) + self._fd_to_key[key.fd] = key + self._fileobj_to_key[fileobj] = key + return key + + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object + + Returns: + SelectorKey instance + """ + try: + key = self._fileobj_to_key[fileobj] + del self._fd_to_key[key.fd] + del self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + return key + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + """ + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + if events != key.events or data != key.data: + self.unregister(fileobj) + return self.register(fileobj, events, data) + else: + return key + + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout == 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (fileobj, events, attached data) for ready file objects + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE + """ + raise NotImplementedError() + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + self._fd_to_key.clear() + self._fileobj_to_key.clear() + + def get_info(self, fileobj): + """Return information about a registered file object. + + Returns: + (events, data) associated to this file object + + Raises KeyError if the file object is not registered. + """ + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise KeyError("{} is not registered".format(fileobj)) + return key.events, key.data + + def registered_count(self): + """Return the number of registered file objects. + + Returns: + number of currently registered file objects + """ + return len(self._fd_to_key) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. + + Parameters: + fd -- file descriptor + + Returns: + corresponding key + """ + try: + return self._fd_to_key[fd] + except KeyError: + tulip_log.warn('No key found for fd %r', fd) + return None + + +class SelectSelector(_BaseSelector): + """Select-based selector.""" + + def __init__(self): + super().__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & EVENT_WRITE: + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + def select(self, timeout=None): + try: + r, w, _ = self._select(self._readers, self._writers, [], timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + r = set(r) + w = set(w) + ready = [] + for fd in r | w: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + if sys.platform == 'win32': + def _select(self, r, w, _, timeout=None): + r, w, x = select(r, w, w, timeout) + return r, w + x, [] + else: + from select import select as _select + + +if 'poll' in globals(): + + # TODO: Implement poll() for Windows with workaround for + # brokenness in WSAPoll() (Richard Oudkerk, see + # http://bugs.python.org/issue16507). + + class PollSelector(_BaseSelector): + """Poll-based selector.""" + + def __init__(self): + super().__init__() + self._poll = poll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= POLLIN + if events & EVENT_WRITE: + poll_events |= POLLOUT + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = None if timeout is None else int(1000 * timeout) + ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~POLLIN: + events |= EVENT_WRITE + if event & ~POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + +if 'epoll' in globals(): + + class EpollSelector(_BaseSelector): + """Epoll-based selector.""" + + def __init__(self): + super().__init__() + self._epoll = epoll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + epoll_events = 0 + if events & EVENT_READ: + epoll_events |= EPOLLIN + if events & EVENT_WRITE: + epoll_events |= EPOLLOUT + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._epoll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = -1 if timeout is None else timeout + max_ev = self.registered_count() + ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~EPOLLIN: + events |= EVENT_WRITE + if event & ~EPOLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + def close(self): + super().close() + self._epoll.close() + + +if 'kqueue' in globals(): + + class KqueueSelector(_BaseSelector): + """Kqueue-based selector.""" + + def __init__(self): + super().__init__() + self._kqueue = kqueue() + + def unregister(self, fileobj): + key = super().unregister(fileobj) + if key.events & EVENT_READ: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + if key.events & EVENT_WRITE: + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + return key + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & EVENT_WRITE: + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def select(self, timeout=None): + max_ev = self.registered_count() + ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == KQ_FILTER_READ: + events |= EVENT_READ + if flag == KQ_FILTER_WRITE: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + def close(self): + super().close() + self._kqueue.close() + + +# Choose the best implementation: roughly, epoll|kqueue > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + Selector = KqueueSelector +elif 'EpollSelector' in globals(): + Selector = EpollSelector +elif 'PollSelector' in globals(): + Selector = PollSelector +else: + Selector = SelectSelector diff --git a/tulip/streams.py b/tulip/streams.py new file mode 100644 index 0000000..8d7f623 --- /dev/null +++ b/tulip/streams.py @@ -0,0 +1,145 @@ +"""Stream-related things.""" + +__all__ = ['StreamReader'] + +import collections + +from . import futures +from . import tasks + + +class StreamReader: + + def __init__(self, limit=2**16): + self.limit = limit # Max line length. (Security feature.) + self.buffer = collections.deque() # Deque of bytes objects. + self.byte_count = 0 # Bytes in buffer. + self.eof = False # Whether we're done. + self.waiter = None # A future. + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + if self.waiter is not None: + self.waiter.set_exception(exc) + + def feed_eof(self): + self.eof = True + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(True) + + def feed_data(self, data): + if not data: + return + + self.buffer.append(data) + self.byte_count += len(data) + + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(False) + + @tasks.coroutine + def readline(self): + if self._exception is not None: + raise self._exception + + parts = [] + parts_size = 0 + not_enough = True + + while not_enough: + while self.buffer and not_enough: + data = self.buffer.popleft() + ichar = data.find(b'\n') + if ichar < 0: + parts.append(data) + parts_size += len(data) + else: + ichar += 1 + head, tail = data[:ichar], data[ichar:] + if tail: + self.buffer.appendleft(tail) + not_enough = False + parts.append(head) + parts_size += len(head) + + if parts_size > self.limit: + self.byte_count -= parts_size + raise ValueError('Line is too long') + + if self.eof: + break + + if not_enough: + assert self.waiter is None + self.waiter = futures.Future() + yield from self.waiter + + line = b''.join(parts) + self.byte_count -= parts_size + + return line + + @tasks.coroutine + def read(self, n=-1): + if self._exception is not None: + raise self._exception + + if not n: + return b'' + + if n < 0: + while not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + else: + if not self.byte_count and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + + if n < 0 or self.byte_count <= n: + data = b''.join(self.buffer) + self.buffer.clear() + self.byte_count = 0 + return data + + parts = [] + parts_bytes = 0 + while self.buffer and parts_bytes < n: + data = self.buffer.popleft() + data_bytes = len(data) + if n < parts_bytes + data_bytes: + data_bytes = n - parts_bytes + data, rest = data[:data_bytes], data[data_bytes:] + self.buffer.appendleft(rest) + + parts.append(data) + parts_bytes += data_bytes + self.byte_count -= data_bytes + + return b''.join(parts) + + @tasks.coroutine + def readexactly(self, n): + if self._exception is not None: + raise self._exception + + if n <= 0: + return b'' + + while self.byte_count < n and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + + return (yield from self.read(n)) diff --git a/tulip/subprocess_transport.py b/tulip/subprocess_transport.py new file mode 100644 index 0000000..734a5fa --- /dev/null +++ b/tulip/subprocess_transport.py @@ -0,0 +1,139 @@ +import fcntl +import os +import traceback + +from . import transports +from . import events +from .log import tulip_log + + +class UnixSubprocessTransport(transports.Transport): + """Transport class managing a subprocess. + + TODO: Separate this into something that just handles pipe I/O, + and something else that handles pipe setup, fork, and exec. + """ + + def __init__(self, protocol, args): + self._protocol = protocol # Not a factory! :-) + self._args = args # args[0] must be full path of binary. + self._event_loop = events.get_event_loop() + self._buffer = [] + self._eof = False + rstdin, self._wstdin = os.pipe() + self._rstdout, wstdout = os.pipe() + + # TODO: This is incredibly naive. Should look at + # subprocess.py for all the precautions around fork/exec. + pid = os.fork() + if not pid: + # Child. + try: + os.dup2(rstdin, 0) + os.dup2(wstdout, 1) + # TODO: What to do with stderr? + os.execv(args[0], args) + except: + try: + traceback.print_traceback() + finally: + os._exit(127) + + # Parent. + os.close(rstdin) + os.close(wstdout) + _setnonblocking(self._wstdin) + _setnonblocking(self._rstdout) + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.add_reader(self._rstdout, self._stdout_callback) + + def write(self, data): + assert not self._eof + assert isinstance(data, bytes), repr(data) + if not data: + return + + if not self._buffer: + # Attempt to write it right away first. + try: + n = os.write(self._wstdin, data) + except BlockingIOError: + pass + except Exception as exc: + self._fatal_error(exc) + return + else: + if n == len(data): + return + elif n: + data = data[n:] + self._event_loop.add_writer(self._wstdin, self._stdin_callback) + self._buffer.append(data) + + def write_eof(self): + assert not self._eof + assert self._wstdin >= 0 + self._eof = True + if not self._buffer: + self._event_loop.remove_writer(self._wstdin) + os.close(self._wstdin) + self._wstdin = -1 + + def close(self): + if not self._eof: + self.write_eof() + # XXX What else? + + def _fatal_error(self, exc): + tulip_log.error('Fatal error: %r', exc) + if self._rstdout >= 0: + os.close(self._rstdout) + self._rstdout = -1 + if self._wstdin >= 0: + os.close(self._wstdin) + self._wstdin = -1 + self._eof = True + self._buffer = None + + def _stdin_callback(self): + data = b''.join(self._buffer) + assert data, "Data shold not be empty" + + self._buffer = [] + try: + n = os.write(self._wstdin, data) + except BlockingIOError: + self._buffer.append(data) + except Exception as exc: + self._fatal_error(exc) + else: + if n >= len(data): + self._event_loop.remove_writer(self._wstdin) + if self._eof: + os.close(self._wstdin) + self._wstdin = -1 + return + + elif n > 0: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def _stdout_callback(self): + try: + data = os.read(self._rstdout, 1024) + except BlockingIOError: + pass + else: + if data: + self._event_loop.call_soon(self._protocol.data_received, data) + else: + self._event_loop.remove_reader(self._rstdout) + os.close(self._rstdout) + self._rstdout = -1 + self._event_loop.call_soon(self._protocol.eof_received) + + +def _setnonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) diff --git a/tulip/tasks.py b/tulip/tasks.py new file mode 100644 index 0000000..81359a2 --- /dev/null +++ b/tulip/tasks.py @@ -0,0 +1,320 @@ +"""Support for tasks, coroutines and the scheduler.""" + +__all__ = ['coroutine', 'task', 'Task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'as_completed', 'sleep', + ] + +import concurrent.futures +import functools +import inspect +import time + +from . import futures +from .log import tulip_log + + +def coroutine(func): + """Decorator to mark coroutines. + + Decorator wraps non generator functions and returns generator wrapper. + If non generator function returns generator of Future it yield-from it. + + TODO: This is a feel-good API only. It is not enforced. + """ + if inspect.isgeneratorfunction(func): + coro = func + else: + tulip_log.warning( + 'Coroutine function %s is not a generator.', func.__name__) + + @functools.wraps(func) + def coro(*args, **kw): + res = func(*args, **kw) + + if isinstance(res, futures.Future) or inspect.isgenerator(res): + res = yield from res + + return res + + coro._is_coroutine = True # Not sure who can use this. + return coro + + +# TODO: Do we need this? +def iscoroutinefunction(func): + """Return True if func is a decorated coroutine function.""" + return (inspect.isgeneratorfunction(func) and + getattr(func, '_is_coroutine', False)) + + +# TODO: Do we need this? +def iscoroutine(obj): + """Return True if obj is a coroutine object.""" + return inspect.isgenerator(obj) # TODO: And what? + + +def task(func): + """Decorator for a coroutine to be wrapped in a Task.""" + def task_wrapper(*args, **kwds): + coro = func(*args, **kwds) + return Task(coro) + return task_wrapper + + +class Task(futures.Future): + """A coroutine wrapped in a Future.""" + + def __init__(self, coro, event_loop=None, timeout=None): + assert inspect.isgenerator(coro) # Must be a coroutine *object*. + super().__init__(event_loop=event_loop, timeout=timeout) + self._coro = coro + self._must_cancel = False + self._event_loop.call_soon(self._step) + + def __repr__(self): + res = super().__repr__() + if (self._must_cancel and + self._state == futures._PENDING and + ')'.format(self._coro.__name__) + res[i:] + return res + + def cancel(self): + if self.done(): + return False + self._must_cancel = True + # _step() will call super().cancel() to call the callbacks. + self._event_loop.call_soon(self._step_maybe) + return True + + def cancelled(self): + return self._must_cancel or super().cancelled() + + def _step_maybe(self): + # Helper for cancel(). + if not self.done(): + return self._step() + + def _step(self, value=None, exc=None): + if self.done(): + tulip_log.warn('_step(): already done: %r, %r, %r', self, value, exc) + return + # We'll call either coro.throw(exc) or coro.send(value). + if self._must_cancel: + exc = futures.CancelledError + coro = self._coro + try: + if exc is not None: + result = coro.throw(exc) + elif value is not None: + result = coro.send(value) + else: + result = next(coro) + except StopIteration as exc: + if self._must_cancel: + super().cancel() + else: + self.set_result(exc.value) + except Exception as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + tulip_log.exception('Exception in task') + except BaseException as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + tulip_log.exception('BaseException in task') + raise + else: + # XXX No check for self._must_cancel here? + if isinstance(result, futures.Future): + if not result._blocking: + self._event_loop.call_soon( + self._step, + None, RuntimeError( + 'yield was used instead of yield from in task %r ' + 'with %r' % (self, result))) + else: + result._blocking = False + result.add_done_callback(self._wakeup) + + elif isinstance(result, concurrent.futures.Future): + # This ought to be more efficient than wrap_future(), + # because we don't create an extra Future. + result.add_done_callback( + lambda future: + self._event_loop.call_soon_threadsafe( + self._wakeup, future)) + else: + if inspect.isgenerator(result): + self._event_loop.call_soon( + self._step, + None, RuntimeError( + 'yield was used instead of yield from for ' + 'generator in task %r with %s' % (self, result))) + else: + if result is not None: + tulip_log.warn('_step(): bad yield: %r', result) + + self._event_loop.call_soon(self._step) + + def _wakeup(self, future): + try: + value = future.result() + except Exception as exc: + self._step(None, exc) + else: + self._step(value, None) + + +# wait() and as_completed() similar to those in PEP 3148. + +FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED +FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION +ALL_COMPLETED = concurrent.futures.ALL_COMPLETED + + +# Even though this *is* a @coroutine, we don't mark it as such! +def wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Wait for the Futures and and coroutines given by fs to complete. + + Coroutines will be wrapped in Tasks. + + Returns two sets of Future: (done, pending). + + Usage: + + done, pending = yield from tulip.wait(fs) + + Note: This does not raise TimeoutError! Futures that aren't done + when the timeout occurs are returned in the second set. + """ + fs = _wrap_coroutines(fs) + return _wait(fs, timeout, return_when) + + +@coroutine +def _wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Internal helper: Like wait() but does not wrap coroutines.""" + done, pending = set(), set() + + errors = 0 + for f in fs: + if f.done(): + done.add(f) + if not f.cancelled() and f.exception() is not None: + errors += 1 + else: + pending.add(f) + + if (not pending or + timeout is not None and timeout <= 0 or + return_when == FIRST_COMPLETED and done or + return_when == FIRST_EXCEPTION and errors): + return done, pending + + # Will always be cancelled eventually. + bail = futures.Future(timeout=timeout) + + def _on_completion(f): + pending.remove(f) + done.add(f) + if (not pending or + return_when == FIRST_COMPLETED or + (return_when == FIRST_EXCEPTION and + not f.cancelled() and + f.exception() is not None)): + bail.cancel() + + try: + for f in pending: + f.add_done_callback(_on_completion) + try: + yield from bail + except futures.CancelledError: + pass + finally: + for f in pending: + f.remove_done_callback(_on_completion) + + really_done = set(f for f in pending if f.done()) + if really_done: + done.update(really_done) + pending.difference_update(really_done) + + return done, pending + + +# This is *not* a @coroutine! It is just an iterator (yielding Futures). +def as_completed(fs, timeout=None): + """Return an iterator whose values, when waited for, are Futures. + + This differs from PEP 3148; the proper way to use this is: + + for f in as_completed(fs): + result = yield from f # The 'yield from' may raise. + # Use result. + + Raises TimeoutError if the timeout occurs before all Futures are + done. + + Note: The futures 'f' are not necessarily members of fs. + """ + deadline = None + if timeout is not None: + deadline = time.monotonic() + timeout + + done = None # Make nonlocal happy. + fs = _wrap_coroutines(fs) + + while fs: + if deadline is not None: + timeout = deadline - time.monotonic() + + @coroutine + def _wait_for_some(): + nonlocal done, fs + done, fs = yield from _wait(fs, timeout=timeout, + return_when=FIRST_COMPLETED) + if not done: + fs = set() + raise futures.TimeoutError() + return done.pop().result() # May raise. + + yield Task(_wait_for_some()) + for f in done: + yield f + + +def _wrap_coroutines(fs): + """Internal helper to process an iterator of Futures and coroutines. + + Returns a set of Futures. + """ + wrapped = set() + for f in fs: + if not isinstance(f, futures.Future): + assert iscoroutine(f) + f = Task(f) + wrapped.add(f) + return wrapped + + +def sleep(when, result=None): + """Return a Future that completes after a given time (in seconds). + + It's okay to cancel the Future. + + Undocumented feature: sleep(when, x) sets the Future's result to x. + """ + future = futures.Future() + future._event_loop.call_later(when, future.set_result, result) + return future diff --git a/tulip/test_utils.py b/tulip/test_utils.py new file mode 100644 index 0000000..9b87db2 --- /dev/null +++ b/tulip/test_utils.py @@ -0,0 +1,30 @@ +"""Utilities shared by tests.""" + +import logging +import socket +import sys +import unittest + + +if sys.platform == 'win32': # pragma: no cover + from .winsocketpair import socketpair +else: + from socket import socketpair # pragma: no cover + + +class LogTrackingTestCase(unittest.TestCase): + + def setUp(self): + self._logger = logging.getLogger() + self._log_level = self._logger.getEffectiveLevel() + + def tearDown(self): + self._logger.setLevel(self._log_level) + + def suppress_log_errors(self): # pragma: no cover + if self._log_level >= logging.WARNING: + self._logger.setLevel(logging.CRITICAL) + + def suppress_log_warnings(self): # pragma: no cover + if self._log_level >= logging.WARNING: + self._logger.setLevel(logging.ERROR) diff --git a/tulip/transports.py b/tulip/transports.py new file mode 100644 index 0000000..a9ec07a --- /dev/null +++ b/tulip/transports.py @@ -0,0 +1,134 @@ +"""Abstract Transport class.""" + +__all__ = ['ReadTransport', 'WriteTransport', 'Transport'] + + +class BaseTransport: + """Base ABC for transports.""" + + def __init__(self, extra=None): + if extra is None: + extra = {} + self._extra = extra + + def get_extra_info(self, name, default=None): + """Get optional transport information.""" + return self._extra.get(name, default) + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + raise NotImplementedError + + +class ReadTransport(BaseTransport): + """ABC for read-only transports.""" + + def pause(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + raise NotImplementedError + + +class WriteTransport(BaseTransport): + """ABC for write-only transports.""" + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def write_eof(self): + """Closes the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + raise NotImplementedError + + def can_write_eof(self): + """Return True if this protocol supports write_eof(), False if not.""" + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class Transport(ReadTransport, WriteTransport): + """ABC representing a bidirectional transport. + + There may be several implementations, but typically, the user does + not implement new transports; rather, the platform provides some + useful transports that are implemented using the platform's best + practices. + + The user never instantiates a transport directly; they call a + utility function, passing it a protocol factory and other + information necessary to create the transport and protocol. (E.g. + EventLoop.create_connection() or EventLoop.start_serving().) + + The utility function will asynchronously create a transport and a + protocol and hook them up by calling the protocol's + connection_made() method, passing it the transport. + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + +class DatagramTransport(BaseTransport): + """ABC for datagram (UDP) transports.""" + + def sendto(self, data, addr=None): + """Send data to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + addr is target socket address. + If addr is None use target address pointed on transport creation. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError diff --git a/tulip/unix_events.py b/tulip/unix_events.py new file mode 100644 index 0000000..3073ab6 --- /dev/null +++ b/tulip/unix_events.py @@ -0,0 +1,301 @@ +"""Selector eventloop for Unix with signal handling.""" + +import errno +import fcntl +import os +import socket +import sys + +try: + import signal +except ImportError: # pragma: no cover + signal = None + +from . import events +from . import selector_events +from . import transports +from .log import tulip_log + + +__all__ = ['SelectorEventLoop'] + + +if sys.platform == 'win32': # pragma: no cover + raise ImportError('Signals are not really supported on Windows') + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Unix event loop + + Adds signal handling to SelectorEventLoop + """ + + def __init__(self, selector=None): + super().__init__(selector) + self._signal_handlers = {} + + def _socketpair(self): + return socket.socketpair() + + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + self._check_signal(sig) + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except ValueError as exc: + raise RuntimeError(str(exc)) + + handle = events.make_handle(callback, args) + self._signal_handlers[sig] = handle + + try: + signal.signal(sig, self._handle_signal) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as nexc: + tulip_log.info('set_wakeup_fd(-1) failed: %s', nexc) + + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + return handle + + def _handle_signal(self, sig, arg): + """Internal helper that is the actual signal handler.""" + handle = self._signal_handlers.get(sig) + if handle is None: + return # Assume it's some race condition. + if handle.cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self.call_soon_threadsafe(handle) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not.""" + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as exc: + tulip_log.info('set_wakeup_fd(-1) failed: %s', exc) + + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + + if signal is None: + raise RuntimeError('Signals are not supported') + + if not (1 <= sig < signal.NSIG): + raise ValueError( + 'sig {} out of range(1, {})'.format(sig, signal.NSIG)) + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixReadPipeTransport(self, pipe, protocol, waiter, extra) + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixWritePipeTransport(self, pipe, protocol, waiter, extra) + + +def _set_nonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + + +class _UnixReadPipeTransport(transports.ReadTransport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, event_loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._event_loop = event_loop + self._pipe = pipe + self._fileno = pipe.fileno() + _set_nonblocking(self._fileno) + self._protocol = protocol + self._closing = False + self._event_loop.add_reader(self._fileno, self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = os.read(self._fileno, self.max_size) + except BlockingIOError: + pass + except OSError as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + self._event_loop.remove_reader(self._fileno) + self._protocol.eof_received() + + def pause(self): + self._event_loop.remove_reader(self._fileno) + + def resume(self): + self._event_loop.add_reader(self._fileno, self._read_ready) + + def close(self): + if not self._closing: + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._closing = True + self._event_loop.remove_reader(self._fileno) + self._call_connection_lost(exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + + +class _UnixWritePipeTransport(transports.WriteTransport): + + def __init__(self, event_loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._event_loop = event_loop + self._pipe = pipe + self._fileno = pipe.fileno() + _set_nonblocking(self._fileno) + self._protocol = protocol + self._buffer = [] + self._closing = False # Set when close() or write_eof() called. + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def write(self, data): + assert isinstance(data, bytes), repr(data) + assert not self._closing + if not data: + return + + if not self._buffer: + # Attempt to send it right away first. + try: + n = os.write(self._fileno, data) + except BlockingIOError: + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + if n == len(data): + return + elif n > 0: + data = data[n:] + self._event_loop.add_writer(self._fileno, self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + assert data, "Data should not be empty" + + self._buffer.clear() + try: + n = os.write(self._fileno, data) + except BlockingIOError: + self._buffer.append(data) + except Exception as exc: + self._fatal_error(exc) + else: + if n == len(data): + self._event_loop.remove_writer(self._fileno) + if self._closing: + self._call_connection_lost(None) + return + elif n > 0: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def can_write_eof(self): + return True + + def write_eof(self): + assert not self._closing + assert self._pipe + self._closing = True + if not self._buffer: + self._call_connection_lost(None) + + def close(self): + if not self._closing: + # write_eof is all what we needed to close the write pipe + self.write_eof() + + def abort(self): + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc=None): + self._closing = True + self._buffer.clear() + self._event_loop.remove_writer(self._fileno) + self._call_connection_lost(exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() diff --git a/tulip/windows_events.py b/tulip/windows_events.py new file mode 100644 index 0000000..2ec8561 --- /dev/null +++ b/tulip/windows_events.py @@ -0,0 +1,157 @@ +"""Selector and proactor eventloops for Windows.""" + +import socket +import weakref +import struct +import _winapi + +from . import futures +from . import proactor_events +from . import selector_events +from . import winsocketpair +from . import _overlapped +from .log import tulip_log + + +__all__ = ['SelectorEventLoop', 'ProactorEventLoop'] + + +NULL = 0 +INFINITE = 0xffffffff +ERROR_CONNECTION_REFUSED = 1225 +ERROR_CONNECTION_ABORTED = 1236 + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + def _socketpair(self): + return winsocketpair.socketpair() + + +class ProactorEventLoop(proactor_events.BaseProactorEventLoop): + def __init__(self, proactor=None): + if proactor is None: + proactor = IocpProactor() + super().__init__(proactor) + + def _socketpair(self): + return winsocketpair.socketpair() + + +class IocpProactor: + + def __init__(self, concurrency=0xffffffff): + self._results = [] + self._iocp = _overlapped.CreateIoCompletionPort( + _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency) + self._cache = {} + self._registered = weakref.WeakSet() + + def registered_count(self): + return len(self._cache) + + def select(self, timeout=None): + if not self._results: + self._poll(timeout) + tmp = self._results + self._results = [] + return tmp + + def recv(self, conn, nbytes, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + ov.WSARecv(conn.fileno(), nbytes, flags) + return self._register(ov, conn, ov.getresult) + + def send(self, conn, buf, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + ov.WSASend(conn.fileno(), buf, flags) + return self._register(ov, conn, ov.getresult) + + def accept(self, listener): + self._register_with_iocp(listener) + conn = self._get_accept_socket() + ov = _overlapped.Overlapped(NULL) + ov.AcceptEx(listener.fileno(), conn.fileno()) + + def finish_accept(): + addr = ov.getresult() + buf = struct.pack('@P', listener.fileno()) + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_ACCEPT_CONTEXT, + buf) + conn.settimeout(listener.gettimeout()) + return conn, conn.getpeername() + + return self._register(ov, listener, finish_accept) + + def connect(self, conn, address): + self._register_with_iocp(conn) + _overlapped.BindLocal(conn.fileno(), len(address)) + ov = _overlapped.Overlapped(NULL) + ov.ConnectEx(conn.fileno(), address) + + def finish_connect(): + ov.getresult() + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_CONNECT_CONTEXT, + 0) + return conn + + return self._register(ov, conn, finish_connect) + + def _register_with_iocp(self, obj): + if obj not in self._registered: + self._registered.add(obj) + _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) + + def _register(self, ov, obj, callback): + f = futures.Future() + self._cache[ov.address] = (f, ov, obj, callback) + return f + + def _get_accept_socket(self): + s = socket.socket() + s.settimeout(0) + return s + + def _poll(self, timeout=None): + if timeout is None: + ms = INFINITE + elif timeout < 0: + raise ValueError("negative timeout") + else: + ms = int(timeout * 1000 + 0.5) + if ms >= INFINITE: + raise ValueError("timeout too big") + while True: + status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms) + if status is None: + return + address = status[3] + f, ov, obj, callback = self._cache.pop(address) + try: + value = callback() + except OSError as e: + f.set_exception(e) + self._results.append(f) + else: + f.set_result(value) + self._results.append(f) + ms = 0 + + def close(self): + for (f, ov, obj, callback) in self._cache.values(): + try: + ov.cancel() + except OSError: + pass + + while self._cache: + if not self._poll(1): + tulip_log.debug('taking long time to close proactor') + + self._results = [] + if self._iocp is not None: + _winapi.CloseHandle(self._iocp) + self._iocp = None diff --git a/tulip/winsocketpair.py b/tulip/winsocketpair.py new file mode 100644 index 0000000..bd1e092 --- /dev/null +++ b/tulip/winsocketpair.py @@ -0,0 +1,34 @@ +"""A socket pair usable as a self-pipe, for Windows. + +Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. +""" + +import socket +import sys + +if sys.platform != 'win32': # pragma: no cover + raise ImportError('winsocketpair is win32 only') + + +def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """Emulate the Unix socketpair() function on Windows.""" + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) + lsock.bind(('localhost', 0)) + lsock.listen(1) + addr, port = lsock.getsockname() + csock = socket.socket(family, type, proto) + csock.setblocking(False) + try: + csock.connect((addr, port)) + except (BlockingIOError, InterruptedError): + pass + except: + lsock.close() + csock.close() + raise + ssock, _ = lsock.accept() + csock.setblocking(True) + lsock.close() + return (ssock, csock) -- cgit v1.2.1