author    A. Jesse Jiryu Davis <jesse@10gen.com>  2013-03-28 15:39:55 -0400
committer A. Jesse Jiryu Davis <jesse@10gen.com>  2013-03-28 15:39:55 -0400
commit    80d1312a3e9c869f26fa4790a8978fd7f8486fb1 (patch)
tree      e5fdcef6fa6327b903a4bf897c771547d7a1434c
download  trollius-git-80d1312a3e9c869f26fa4790a8978fd7f8486fb1.tar.gz
Use logger named 'tulip' for library events, Issue 26
-rw-r--r--  .hgeol                            2
-rw-r--r--  .hgignore                        11
-rw-r--r--  Makefile                         31
-rw-r--r--  NOTES                           176
-rw-r--r--  README                           21
-rw-r--r--  TODO                            163
-rw-r--r--  check.py                         41
-rwxr-xr-x  crawl.py                        143
-rwxr-xr-x  curl.py                          35
-rw-r--r--  examples/udp_echo.py             73
-rw-r--r--  old/Makefile                     16
-rw-r--r--  old/echoclt.py                   79
-rw-r--r--  old/echosvr.py                   60
-rw-r--r--  old/http_client.py               78
-rw-r--r--  old/http_server.py               68
-rw-r--r--  old/main.py                     134
-rw-r--r--  old/p3time.py                    47
-rw-r--r--  old/polling.py                  535
-rw-r--r--  old/scheduling.py               354
-rw-r--r--  old/sockets.py                  348
-rw-r--r--  old/transports.py               496
-rwxr-xr-x  old/xkcd.py                      18
-rw-r--r--  old/yyftime.py                   75
-rw-r--r--  overlapped.c                    997
-rw-r--r--  runtests.py                     198
-rw-r--r--  setup.cfg                         2
-rw-r--r--  setup.py                         14
-rwxr-xr-x  srv.py                          115
-rw-r--r--  sslsrv.py                        56
-rw-r--r--  tests/base_events_test.py       283
-rw-r--r--  tests/events_test.py           1379
-rw-r--r--  tests/futures_test.py           222
-rw-r--r--  tests/http_protocol_test.py     972
-rw-r--r--  tests/http_server_test.py       242
-rw-r--r--  tests/locks_test.py             747
-rw-r--r--  tests/queues_test.py            370
-rw-r--r--  tests/sample.crt                 14
-rw-r--r--  tests/sample.key                 15
-rw-r--r--  tests/selector_events_test.py  1286
-rw-r--r--  tests/selectors_test.py         137
-rw-r--r--  tests/streams_test.py           299
-rw-r--r--  tests/subprocess_test.py         54
-rw-r--r--  tests/tasks_test.py             647
-rw-r--r--  tests/transports_test.py         45
-rw-r--r--  tests/unix_events_test.py       573
-rw-r--r--  tests/winsocketpair_test.py      26
-rw-r--r--  tulip/TODO                       28
-rw-r--r--  tulip/__init__.py                26
-rw-r--r--  tulip/base_events.py            548
-rw-r--r--  tulip/events.py                 356
-rw-r--r--  tulip/futures.py                255
-rw-r--r--  tulip/http/__init__.py           12
-rw-r--r--  tulip/http/client.py            145
-rw-r--r--  tulip/http/errors.py             44
-rw-r--r--  tulip/http/protocol.py          877
-rw-r--r--  tulip/http/server.py            176
-rw-r--r--  tulip/locks.py                  433
-rw-r--r--  tulip/log.py                      6
-rw-r--r--  tulip/proactor_events.py        189
-rw-r--r--  tulip/protocols.py               78
-rw-r--r--  tulip/queues.py                 291
-rw-r--r--  tulip/selector_events.py        655
-rw-r--r--  tulip/selectors.py              418
-rw-r--r--  tulip/streams.py                145
-rw-r--r--  tulip/subprocess_transport.py   139
-rw-r--r--  tulip/tasks.py                  320
-rw-r--r--  tulip/test_utils.py              30
-rw-r--r--  tulip/transports.py             134
-rw-r--r--  tulip/unix_events.py            301
-rw-r--r--  tulip/windows_events.py         157
-rw-r--r--  tulip/winsocketpair.py           34
71 files changed, 17494 insertions, 0 deletions
diff --git a/.hgeol b/.hgeol
new file mode 100644
index 0000000..b6910a2
--- /dev/null
+++ b/.hgeol
@@ -0,0 +1,2 @@
+[patterns]
+** = native
diff --git a/.hgignore b/.hgignore
new file mode 100644
index 0000000..2590249
--- /dev/null
+++ b/.hgignore
@@ -0,0 +1,11 @@
+.*\.py[co]$
+.*~$
+.*\.orig$
+.*\#.*$
+.*@.*$
+\.coverage$
+htmlcov$
+\.DS_Store$
+venv$
+distribute_setup.py$
+distribute-\d+.\d+.\d+.tar.gz$
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..274da4c
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,31 @@
+# Some simple testing tasks (sorry, UNIX only).
+
+PYTHON=python3
+VERBOSE=1
+FLAGS=
+
+test:
+ $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS)
+
+testloop:
+ while sleep 1; do $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS); done
+
+# See README for coverage installation instructions.
+cov coverage:
+ $(PYTHON) runtests.py --coverage tulip -v $(VERBOSE) $(FLAGS)
+ echo "open file://`pwd`/htmlcov/index.html"
+
+check:
+ $(PYTHON) check.py
+
+clean:
+ rm -rf `find . -name __pycache__`
+ rm -f `find . -type f -name '*.py[co]' `
+ rm -f `find . -type f -name '*~' `
+ rm -f `find . -type f -name '.*~' `
+ rm -f `find . -type f -name '@*' `
+ rm -f `find . -type f -name '#*#' `
+ rm -f `find . -type f -name '*.orig' `
+ rm -f `find . -type f -name '*.rej' `
+ rm -f .coverage
+ rm -rf htmlcov
diff --git a/NOTES b/NOTES
new file mode 100644
index 0000000..3b94ba9
--- /dev/null
+++ b/NOTES
@@ -0,0 +1,176 @@
+Notes from PyCon 2013 sprints
+=============================
+
+- Cancellation. If a task creates several subtasks, and then the
+ parent task fails, should the subtasks be cancelled? (How do we
+ even establish the parent/subtask relationship?)
+
+- Adam Sah suggests that there might be a need for scheduling
+ (especially when multiple frameworks share an event loop). He
+ points to lottery scheduling but also mentions that's just one of
+ the options. However, after posting on python-tulip, it appears
+ none of the other frameworks have scheduling, and nobody seems to
+ miss it.
+
+- Feedback from Bram Cohen (BitTorrent creator) about UDP. He doesn't
+  think connected UDP is worth supporting; it doesn't do anything
+ except tell the kernel about the default target address for
+ sendto(). Basically he says all UDP end points are servers. He
+ sent me his own UDP event loop so I might glean some tricks from it.
+ He says we should treat EINTR the same as EAGAIN and friends. (We
+  should use the exceptions dedicated to errno checking, BTW.) He
+  said to make sure we use SO_REUSEADDR (I think we already do). He
+ said to set the max datagram sizes pretty large (anything larger
+ than the declared limit is dropped on the floor). He reminds us of
+  the importance of being able to pick a valid, unused port by binding
+  to port 0 and then using getsockname() (see the sketch after this
+  item). He has an idea where he'd like to be able to kill all
+  registered callbacks (i.e. Handles)
+ belonging to a certain "context". I think this can be done at the
+ application level (you'd have to wrap everything that returns a
+  Handle and collect these handles in some set or other data structure)
+ but if someone thinks it's interesting we could imagine having some
+  kind of notion of context as part of the event loop state,
+ e.g. associated with a Task (see Cancellation point above). He
+ brought up uTP (Micro Transport Protocol), a reimplementation of TCP
+ over UDP with more refined congestion control.
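+
+  A minimal sketch of the bind-to-port-0 trick he mentions (standard
+  library only):
+
+      import socket
+      s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+      s.bind(('127.0.0.1', 0))   # port 0: the kernel picks a free port
+      port = s.getsockname()[1]  # the port we actually got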
+
+- Mumblings about UNIX domain sockets and IPv6 addresses being
+ 4-tuples. The former can be handled by passing in a socket. There
+ seem to be no real use cases for the latter that can't be dealt with
+ by passing in suitably esoteric strings for the hostname.
+ getaddrinfo() will produce the appropriate 4-tuple and connect()
+ will accept it.
+
+- Mumblings on the list about add vs. set.
+
+
+Notes from the second Tulip/Twisted meet-up
+===========================================
+
+Rackspace, 12/11/2012
+Glyph, Brian Warner, David Reid, Duncan McGreggor, others
+
+Flow control
+------------
+
+- Pause/resume on transport manages data_received.
+
+- There's also an API to tell the transport whom to pause when the
+ write calls are overwhelming it: IConsumer.registerProducer().
+
+- There's also something called pipes but it's built on top of the
+ old interface.
+
+- Twisted has variations on the basic flow control that I should
+ ignore.
+
+Half_close
+----------
+
+- This sends an EOF after writing some stuff.
+
+- Can't write any more.
+
+- Problem with TLS is known (the RFC sadly specifies this behavior).
+
+- It must be dynamically discoverable whether the transport supports
+  half_close, since the protocol may have to do something different to
+  make up for it being missing (e.g. use chunked encoding). Twisted uses
+ an interface check for this and also hasattr(trans, 'halfClose')
+ but a flag (or flag method) is fine too.
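+
+  A flag-method check might look like this sketch (the method names
+  are illustrative, not a settled API):
+
+      if getattr(transport, 'can_half_close', lambda: False)():
+          transport.half_close()  # send EOF, keep receiving
+      else:
+          use_chunked_encoding()  # hypothetical fallback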
+
+Constructing transport and protocol
+-----------------------------------
+
+- There are good reasons for passing a function to the transport
+ construction helper that creates the protocol. (You need these
+ anyway for server-side protocols.) The sequence of events is
+ something like
+
+ . open socket
+ . create transport (pass it a socket?)
+ . create protocol (pass it nothing)
+ . proto.make_connection(transport); this does:
+ . self.transport = transport
+ . self.connection_made(transport)
+
+ But it seems okay to skip make_connection and setting .transport.
+ Note that make_connection() is a concrete method on the Protocol
+ implementation base class, while connection_made() is an abstract
+ method on IProtocol.
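+
+  In code, the concrete method is essentially (mirroring the sequence
+  above, with Twisted-style names):
+
+      class Protocol:
+          def make_connection(self, transport):
+              self.transport = transport
+              self.connection_made(transport)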
+
+Event Loop
+----------
+
+- We discussed the sequence of actions in the event loop. I think in the
+ end we're fine with what Tulip currently does. There are two choices:
+
+ Tulip:
+ . run ready callbacks until there aren't any left
+ . poll, adding more callbacks to the ready list
+ . add now-ready delayed callbacks to the ready list
+ . go to top
+
+ Tornado:
+ . run all currently ready callbacks (but not new ones added during this)
+ . (the rest is the same)
+
+ The difference is that in the Tulip version, CPU bound callbacks
+ that keep adding more to the queue will starve I/O (and yielding to
+ other tasks won't actually cause I/O to happen unless you do
+ e.g. sleep(0.001)). OTOH this may be good because it means there's
+ less overhead if you frequently split operations in two.
+
+- I think Twisted does it Tornado style (in a convoluted way :-), but
+ it may not matter, and it's important to leave this vague so
+ implementations can do what's best for their platform. (E.g. if the
+ event loop is built into the OS there are different trade-offs.)
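+
+  A sketch of the two drain strategies, assuming a hypothetical
+  `ready` deque of zero-argument callbacks:
+
+      # Tulip: run until the queue is empty, including callbacks
+      # that the callbacks themselves just added.
+      while ready:
+          ready.popleft()()
+
+      # Tornado: run only the callbacks present when this iteration
+      # started; newly added ones wait for the next pass.
+      for _ in range(len(ready)):
+          ready.popleft()()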
+
+System call cost
+----------------
+
+- System calls on MacOS are expensive; on Linux they are cheap.
+
+- Optimal buffer size ~16K.
+
+- Try joining small buffer pieces together, but expect to be tuning
+ this later.
+
+Futures
+-------
+
+- Futures are the most robust API for async stuff, you can check
+ errors etc. So let's do this.
+
+- Just don't implement wait().
+
+- For the basics, however (recv/send, mostly), don't use Futures but use
+ basic callbacks, transport/protocol style.
+
+- make_connection() (by any name) can return a Future; it makes it
+ easier to check for errors.
+
+- This means revisiting the Tulip proactor branch (IOCP).
+
+- The semantics of add_done_callback() are fuzzy about in which thread
+ the callback will be called. (It may be the current thread or
+ another one.) We don't like that. But always inserting a
+ call_soon() indirection may be expensive? Glyph suggested changing
+ the add_done_callback() method name to something else to indicate
+ the changed promise.
+
+- Separately, I've been thinking about having two versions of
+ call_soon() -- a more heavy-weight one to be called from other
+ threads that also writes a byte to the self-pipe.
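+
+  Roughly, the heavy-weight variant could look like this sketch (the
+  name and self-pipe attribute are made up; assumes `import os`):
+
+      def call_soon_threadsafe(self, callback, *args):
+          handle = self.call_soon(callback, *args)
+          os.write(self._self_pipe_write_fd, b'x')  # wake up poll()
+          return handle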
+
+Signals
+-------
+
+- There was a side conversation about signals. A signal handler is
+ similar to another thread, so probably should use (the heavy-weight
+ version of) call_soon() to schedule the real callback and not do
+ anything else.
+
+- Glyph vaguely recalled some trickiness with the self-pipe. We
+  should be able to fix this afterwards if necessary; it shouldn't
+ affect the API design.
diff --git a/README b/README
new file mode 100644
index 0000000..85bfe5a
--- /dev/null
+++ b/README
@@ -0,0 +1,21 @@
+Tulip is the codename for my reference implementation of PEP 3156.
+
+PEP 3156: http://www.python.org/dev/peps/pep-3156/
+
+*** This requires Python 3.3 or later! ***
+
+Copyright/license: Open source, Apache 2.0. Enjoy.
+
+Master Mercurial repo: http://code.google.com/p/tulip/
+
+The old code lives in the subdirectory 'old'; the new code (conforming
+to PEP 3156, under construction) lives in the 'tulip' subdirectory.
+
+To run tests:
+ - make test
+
+To run coverage (coverage package is required):
+ - make coverage
+
+
+--Guido van Rossum <guido@python.org>
diff --git a/TODO b/TODO
new file mode 100644
index 0000000..c6d4eea
--- /dev/null
+++ b/TODO
@@ -0,0 +1,163 @@
+# -*- Mode: text -*-
+
+TO DO LARGER TASKS
+
+- Need more examples.
+
+- Benchmarkable but more realistic HTTP server?
+
+- Example of using UDP.
+
+- Write up a tutorial for the scheduling API.
+
+- More systematic approach to logging. Logger objects? What about
+ heavy-duty logging, tracking essentially all task state changes?
+
+- Restructure directory, move demos and benchmarks to subdirectories.
+
+
+TO DO LATER
+
+- When multiple tasks are accessing the same socket, they should
+ either get interleaved I/O or an immediate exception; it should not
+ compromise the integrity of the scheduler or the app or leave a task
+ hanging.
+
+- For epoll you probably want to check (and log?) EPOLLHUP and
+  EPOLLERR errors.
+
+- Add the simplest API possible to run a generator with a timeout.
+
+- Ensure multiple tasks can do atomic writes to the same pipe (since
+ UNIX guarantees that short writes to pipes are atomic).
+
+- Ensure some easy way of distributing accepted connections across tasks.
+
+- Be wary of thread-local storage. There should be a standard API to
+ get the current Context (which holds current task, event loop, and
+ maybe more) and a standard meta-API to change how that standard API
+ works (i.e. without monkey-patching).
+
+- See how much of asyncore I've already replaced.
+
+- Could BufferedReader reuse the standard io module's readers???
+
+- Support ZeroMQ "sockets" which are user objects. Though possibly
+ this can be supported by getting the underlying fd? See
+ http://mail.python.org/pipermail/python-ideas/2012-October/017532.html
+ OTOH see
+ https://github.com/zeromq/pyzmq/blob/master/zmq/eventloop/ioloop.py
+
+- Study goroutines (again).
+
+- Benchmarks: http://nichol.as/benchmark-of-python-web-servers
+
+
+FROM OLDER LIST
+
+- Multiple readers/writers per socket? (At which level? pollster,
+ eventloop, or scheduler?)
+
+- Could poll() usefully be an iterator?
+
+- Do we need to support more epoll and/or kqueue modes/flags/options/etc.?
+
+- Optimize register/unregister calls away if they cancel each other out?
+
+- Add explicit wait queue to wait for Task's completion, instead of
+ callbacks?
+
+- Look at pyftpdlib's ioloop.py:
+ http://code.google.com/p/pyftpdlib/source/browse/trunk/pyftpdlib/lib/ioloop.py
+
+
+MISTAKES I MADE
+
+- Forgetting yield from. (E.g.: scheduler.sleep(1); listener.accept().)
+
+- Forgot to add bare yield at end of internal function, after block().
+
+- Forgot to call add_done_callback().
+
+- Forgot to pass an undoer to block(), bug only found when cancelled.
+
+- Subtle accounting mistake in a callback.
+
+- Used context.eventloop from a different thread, forgetting about TLS.
+
+- Nasty race: eventloop.ready may contain both an I/O callback and a
+ cancel callback. How to avoid? Keep the DelayedCall in ready. Is
+ that enough?
+
+- If a toplevel task raises an error it just stops and nothing is logged
+ unless you have debug logging on. This confused me. (Then again,
+ previously I logged whenever a task raised an error, and that was too
+ chatty...)
+
+- Forgot to set the connection socket returned by accept() in
+ nonblocking mode.
+
+- Nastiest so far (cost me about a day): A race condition in
+ call_in_thread() where the Future's done_callback (which was
+ task.unblock()) would run immediately at the time when
+ add_done_callback() was called, and this screwed over the task
+ state. Solution: wrap the callback in eventloop.call_later().
+ Ironically, I had a comment stating there might be a race condition.
+
+- Another bug where I was calling unblock() for the current thread
+ immediately after calling block(), before yielding.
+
+- readexactly() wasn't checking for EOF, so could be looping.
+ (Worse, the first fix I attempted was wrong.)
+
+- Spent a day trying to understand why a tentative patch trying to
+ move the recv() implementation into the eventloop (or the pollster)
+ resulted in problems cancelling a recv() call. Ultimately the
+ problem is that the cancellation mechanism is part of the coroutine
+ scheduler, which simply throws an exception into a task when it next
+ runs, and there isn't anything to be interrupted in the eventloop;
+ but the eventloop still has a reader registered (which will never
+ fire because I suspended the server -- that's my test case :-).
+ Then, the eventloop keeps running until the last file descriptor is
+ unregistered. What contributed to this disaster?
+ * I didn't build the whole infrastructure, just played with recv()
+ * I don't have unittests
+  * I don't have good logging to see what is going on
+
+- In sockets.py, in some SSL error handling code, used the wrong
+ variable (sock instead of sslsock). A linter would have found this.
+
+- In polling.py, in KqueuePollster.register_writer(), a copy/paste
+ error where I was testing for "if fd not in self.readers" instead of
+ writers. This only came out when I had both a reader and a writer
+ for the same fd.
+
+- Submitted some changes prematurely (forgot to pass the filename on
+ hg ci).
+
+- Forgot again that shutdown(SHUT_WR) on an ssl socket does not work
+  as I expected. I ran into this with the original sockets.py and
+ again in transport.py.
+
+- Having the same callback for both reading and writing has a problem:
+ it may be scheduled twice, and if the first call closes the socket,
+ the second runs into trouble.
+
+
+MISTAKES I MADE IN TULIP V2
+
+- Nice one: Future._schedule_callbacks() wasn't scheduling any callbacks.
+ Spot the bug in these four lines:
+
+ def _schedule_callbacks(self):
+ callbacks = self._callbacks[:]
+ self._callbacks[:] = []
+ for callback in self._callbacks:
+ self._event_loop.call_soon(callback, self)
+
+ The good news is that I found it with a unittest (albeit not the
+ unittest intended to exercise this particular method :-( ).
+
+- In _make_self_pipe_or_sock(), called _pollster.register_reader()
+ instead of add_reader(), trying to optimize something but breaking
+ things instead (since the -- internal -- API of register_reader()
+ had changed).
diff --git a/check.py b/check.py
new file mode 100644
index 0000000..64bc2cd
--- /dev/null
+++ b/check.py
@@ -0,0 +1,41 @@
+"""Search for lines > 80 chars or with trailing whitespace."""
+
+import sys, os
+
+def main():
+    args = sys.argv[1:] or [os.curdir]  # Default: scan the current dir.
+ for arg in args:
+ if os.path.isdir(arg):
+ for dn, dirs, files in os.walk(arg):
+ for fn in sorted(files):
+ if fn.endswith('.py'):
+ process(os.path.join(dn, fn))
+ dirs[:] = [d for d in dirs if d[0] != '.']
+ dirs.sort()
+ else:
+ process(arg)
+
+def isascii(x):
+ try:
+ x.encode('ascii')
+ return True
+ except UnicodeError:
+ return False
+
+def process(fn):
+ try:
+ f = open(fn)
+ except IOError as err:
+ print(err)
+ return
+ try:
+ for i, line in enumerate(f):
+ line = line.rstrip('\n')
+ sline = line.rstrip()
+ if len(line) > 80 or line != sline or not isascii(line):
+ print('%s:%d:%s%s' % (
+ fn, i+1, sline, '_' * (len(line) - len(sline))))
+ finally:
+ f.close()
+
+main()
diff --git a/crawl.py b/crawl.py
new file mode 100755
index 0000000..4e5bebe
--- /dev/null
+++ b/crawl.py
@@ -0,0 +1,143 @@
+#!/usr/bin/env python3
+
+import logging
+import re
+import signal
+import socket
+import sys
+import urllib.parse
+
+import tulip
+import tulip.http
+
+END = '\n'
+MAXTASKS = 100
+
+
+class Crawler:
+
+ def __init__(self, rooturl):
+ self.rooturl = rooturl
+ self.todo = set()
+ self.busy = set()
+ self.done = {}
+ self.tasks = set()
+ self.waiter = None
+ self.addurl(self.rooturl, '') # Set initial work.
+ self.run() # Kick off work.
+
+ def addurl(self, url, parenturl):
+ url = urllib.parse.urljoin(parenturl, url)
+ url, frag = urllib.parse.urldefrag(url)
+ if not url.startswith(self.rooturl):
+ return False
+ if url in self.busy or url in self.done or url in self.todo:
+ return False
+ self.todo.add(url)
+ waiter = self.waiter
+ if waiter is not None:
+ self.waiter = None
+ waiter.set_result(None)
+ return True
+
+ @tulip.task
+ def run(self):
+ while self.todo or self.busy or self.tasks:
+ complete, self.tasks = yield from tulip.wait(self.tasks, timeout=0)
+ print(len(complete), 'completed tasks,', len(self.tasks),
+ 'still pending ', end=END)
+ for task in complete:
+ try:
+ yield from task
+ except Exception as exc:
+ print('Exception in task:', exc, end=END)
+ while self.todo and len(self.tasks) < MAXTASKS:
+ url = self.todo.pop()
+ self.busy.add(url)
+ self.tasks.add(self.process(url)) # Async task.
+ if self.busy:
+ self.waiter = tulip.Future()
+ yield from self.waiter
+ tulip.get_event_loop().stop()
+
+ @tulip.task
+ def process(self, url):
+ ok = False
+ p = None
+ try:
+ print('processing', url, end=END)
+ scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url)
+ if not path:
+ path = '/'
+ if query:
+ path = '?'.join([path, query])
+ p = tulip.http.HttpClientProtocol(
+ netloc, path=path, ssl=(scheme=='https'))
+ delay = 1
+ while True:
+ try:
+ status, headers, stream = yield from p.connect()
+ break
+ except socket.error as exc:
+ if delay >= 60:
+ raise
+ print('...', url, 'has error', repr(str(exc)),
+ 'retrying after sleep', delay, '...', end=END)
+ yield from tulip.sleep(delay)
+ delay *= 2
+
+ if status[:3] in ('301', '302'):
+ # Redirect.
+ u = headers.get('location') or headers.get('uri')
+ if self.addurl(u, url):
+ print(' ', url, status[:3], 'redirect to', u, end=END)
+ elif status.startswith('200'):
+ ctype = headers.get_content_type()
+ if ctype == 'text/html':
+ while True:
+ line = yield from stream.readline()
+ if not line:
+ break
+ line = line.decode('utf-8', 'replace')
+ urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)',
+ line)
+ for u in urls:
+ if self.addurl(u, url):
+ print(' ', url, 'href to', u, end=END)
+ ok = True
+ finally:
+ if p is not None:
+ p.transport.close()
+ self.done[url] = ok
+ self.busy.remove(url)
+ if not ok:
+ print('failure for', url, sys.exc_info(), end=END)
+ waiter = self.waiter
+ if waiter is not None:
+ self.waiter = None
+ waiter.set_result(None)
+
+
+def main():
+ rooturl = sys.argv[1]
+ c = Crawler(rooturl)
+ loop = tulip.get_event_loop()
+ try:
+ loop.add_signal_handler(signal.SIGINT, loop.stop)
+ except RuntimeError:
+ pass
+ loop.run_forever()
+ print('todo:', len(c.todo))
+ print('busy:', len(c.busy))
+ print('done:', len(c.done), '; ok:', sum(c.done.values()))
+ print('tasks:', len(c.tasks))
+
+
+if __name__ == '__main__':
+ if '--iocp' in sys.argv:
+ from tulip import events, windows_events
+ sys.argv.remove('--iocp')
+ logging.info('using iocp')
+ el = windows_events.ProactorEventLoop()
+ events.set_event_loop(el)
+ main()
diff --git a/curl.py b/curl.py
new file mode 100755
index 0000000..37fce75
--- /dev/null
+++ b/curl.py
@@ -0,0 +1,35 @@
+#!/usr/bin/env python3
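+"""Minimal command-line HTTP fetcher (a tiny curl).
+
+Usage: curl.py <url> [--iocp]
+"""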
+
+import sys
+import urllib.parse
+
+import tulip
+import tulip.http
+
+
+def main():
+ url = sys.argv[1]
+ scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url)
+ if not path:
+ path = '/'
+ if query:
+ path = '?'.join([path, query])
+ print(netloc, path, scheme)
+ p = tulip.http.HttpClientProtocol(netloc, path=path, ssl=(scheme=='https'))
+ f = p.connect()
+ sts, headers, stream = p.event_loop.run_until_complete(f)
+ print(sts)
+ for k, v in headers.items():
+ print('{}: {}'.format(k, v))
+ print()
+ data = p.event_loop.run_until_complete(stream.read())
+ print(data.decode('utf-8', 'replace'))
+
+
+if __name__ == '__main__':
+ if '--iocp' in sys.argv:
+ from tulip import events, windows_events
+ sys.argv.remove('--iocp')
+ el = windows_events.ProactorEventLoop()
+ events.set_event_loop(el)
+ main()
diff --git a/examples/udp_echo.py b/examples/udp_echo.py
new file mode 100644
index 0000000..9e995d1
--- /dev/null
+++ b/examples/udp_echo.py
@@ -0,0 +1,73 @@
+"""UDP echo example.
+
+Start server:
+
+ >> python ./udp_echo.py --server
+
+"""
+
+import sys
+import tulip
+
+ADDRESS = ('127.0.0.1', 10000)
+
+
+class MyServerUdpEchoProtocol:
+
+ def connection_made(self, transport):
+ print('start', transport)
+ self.transport = transport
+
+ def datagram_received(self, data, addr):
+ print('Data received:', data, addr)
+ self.transport.sendto(data, addr)
+
+ def connection_refused(self, exc):
+ print('Connection refused:', exc)
+
+ def connection_lost(self, exc):
+ print('stop', exc)
+
+
+class MyClientUdpEchoProtocol:
+
+ message = 'This is the message. It will be repeated.'
+
+ def connection_made(self, transport):
+ self.transport = transport
+ print('sending "%s"' % self.message)
+ self.transport.sendto(self.message.encode())
+ print('waiting to receive')
+
+ def datagram_received(self, data, addr):
+ print('received "%s"' % data.decode())
+ self.transport.close()
+
+ def connection_refused(self, exc):
+ print('Connection refused:', exc)
+
+ def connection_lost(self, exc):
+ print('closing transport', exc)
+ loop = tulip.get_event_loop()
+ loop.stop()
+
+
+def start_server():
+ loop = tulip.get_event_loop()
+ tulip.Task(loop.create_datagram_endpoint(
+ MyServerUdpEchoProtocol, local_addr=ADDRESS))
+ loop.run_forever()
+
+
+def start_client():
+ loop = tulip.get_event_loop()
+ tulip.Task(loop.create_datagram_endpoint(
+ MyClientUdpEchoProtocol, remote_addr=ADDRESS))
+ loop.run_forever()
+
+
+if __name__ == '__main__':
+ if '--server' in sys.argv:
+ start_server()
+ else:
+ start_client()
diff --git a/old/Makefile b/old/Makefile
new file mode 100644
index 0000000..d352cd7
--- /dev/null
+++ b/old/Makefile
@@ -0,0 +1,16 @@
+PYTHON=python3
+
+main:
+ $(PYTHON) main.py -v
+
+echo:
+ $(PYTHON) echosvr.py -v
+
+profile:
+ $(PYTHON) -m profile -s time main.py
+
+time:
+ $(PYTHON) p3time.py
+
+ytime:
+ $(PYTHON) yyftime.py
diff --git a/old/echoclt.py b/old/echoclt.py
new file mode 100644
index 0000000..c24c573
--- /dev/null
+++ b/old/echoclt.py
@@ -0,0 +1,79 @@
+#!/usr/bin/env python3.3
+"""Example echo client."""
+
+# Stdlib imports.
+import logging
+import socket
+import sys
+import time
+
+# Local imports.
+import scheduling
+import sockets
+
+
+def echoclient(host, port):
+ """COROUTINE"""
+ testdata = b'hi hi hi ha ha ha\n'
+ try:
+ trans = yield from sockets.create_transport(host, port,
+ af=socket.AF_INET)
+ except OSError:
+ return False
+ try:
+ ok = yield from trans.send(testdata)
+ if ok:
+ response = yield from trans.recv(100)
+ ok = response == testdata.upper()
+ return ok
+ finally:
+ trans.close()
+
+
+def doit(n):
+ """COROUTINE"""
+ t0 = time.time()
+ tasks = set()
+ for i in range(n):
+ t = scheduling.Task(echoclient('127.0.0.1', 1111), 'client-%d' % i)
+ tasks.add(t)
+ ok = 0
+ bad = 0
+ for t in tasks:
+ try:
+ yield from t
+ except Exception:
+ bad += 1
+ else:
+ ok += 1
+ t1 = time.time()
+ print('ok: ', ok)
+ print('bad:', bad)
+ print('dt: ', round(t1-t0, 6))
+
+
+def main():
+ # Initialize logging.
+ if '-d' in sys.argv:
+ level = logging.DEBUG
+ elif '-v' in sys.argv:
+ level = logging.INFO
+ elif '-q' in sys.argv:
+ level = logging.ERROR
+ else:
+ level = logging.WARN
+ logging.basicConfig(level=level)
+
+ # Get integer from command line.
+ n = 1
+ for arg in sys.argv[1:]:
+ if not arg.startswith('-'):
+ n = int(arg)
+ break
+
+ # Run scheduler, starting it off with doit().
+ scheduling.run(doit(n))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/old/echosvr.py b/old/echosvr.py
new file mode 100644
index 0000000..4085f4c
--- /dev/null
+++ b/old/echosvr.py
@@ -0,0 +1,60 @@
+#!/usr/bin/env python3.3
+"""Example echo server."""
+
+# Stdlib imports.
+import logging
+import socket
+import sys
+
+# Local imports.
+import scheduling
+import sockets
+
+
+def handler(conn, addr):
+ """COROUTINE: Handle one connection."""
+ logging.info('Accepting connection from %r', addr)
+ trans = sockets.SocketTransport(conn)
+ rdr = sockets.BufferedReader(trans)
+ while True:
+ line = yield from rdr.readline()
+ logging.debug('Received: %r from %r', line, addr)
+ if not line:
+ break
+ yield from trans.send(line.upper())
+ logging.debug('Closing %r', addr)
+ trans.close()
+
+
+def doit():
+ """COROUTINE: Set the wheels in motion."""
+ # Set up listener.
+ listener = yield from sockets.create_listener('localhost', 1111,
+ af=socket.AF_INET,
+ backlog=100)
+ logging.info('Listening on %r', listener.sock.getsockname())
+
+ # Loop accepting connections.
+ while True:
+ conn, addr = yield from listener.accept()
+ t = scheduling.Task(handler(conn, addr))
+
+
+def main():
+ # Initialize logging.
+ if '-d' in sys.argv:
+ level = logging.DEBUG
+ elif '-v' in sys.argv:
+ level = logging.INFO
+ elif '-q' in sys.argv:
+ level = logging.ERROR
+ else:
+ level = logging.WARN
+ logging.basicConfig(level=level)
+
+ # Run scheduler, starting it off with doit().
+ scheduling.run(doit())
+
+
+if __name__ == '__main__':
+ main()
diff --git a/old/http_client.py b/old/http_client.py
new file mode 100644
index 0000000..8937ba2
--- /dev/null
+++ b/old/http_client.py
@@ -0,0 +1,78 @@
+"""Crummy HTTP client.
+
+This is not meant as an example of how to write a good client.
+"""
+
+# Stdlib.
+import re
+import time
+
+# Local.
+import sockets
+
+
+def urlfetch(host, port=None, path='/', method='GET',
+ body=None, hdrs=None, encoding='utf-8', ssl=None, af=0):
+ """COROUTINE: Make an HTTP 1.0 request."""
+ t0 = time.time()
+ if port is None:
+ if ssl:
+ port = 443
+ else:
+ port = 80
+ trans = yield from sockets.create_transport(host, port, ssl=ssl, af=af)
+ yield from trans.send(method.encode(encoding) + b' ' +
+ path.encode(encoding) + b' HTTP/1.0\r\n')
+ if hdrs:
+ kwds = dict(hdrs)
+ else:
+ kwds = {}
+ if 'host' not in kwds:
+ kwds['host'] = host
+ if body is not None:
+ kwds['content_length'] = len(body)
+ for header, value in kwds.items():
+ yield from trans.send(header.replace('_', '-').encode(encoding) +
+ b': ' + value.encode(encoding) + b'\r\n')
+
+ yield from trans.send(b'\r\n')
+ if body is not None:
+ yield from trans.send(body)
+
+ # Read HTTP response line.
+ rdr = sockets.BufferedReader(trans)
+ resp = yield from rdr.readline()
+ m = re.match(br'(?ix) http/(\d\.\d) \s+ (\d\d\d) \s+ ([^\r]*)\r?\n\Z',
+ resp)
+ if not m:
+ trans.close()
+ raise IOError('No valid HTTP response: %r' % resp)
+ http_version, status, message = m.groups()
+
+ # Read HTTP headers.
+ headers = []
+ hdict = {}
+ while True:
+ line = yield from rdr.readline()
+ if not line.strip():
+ break
+ m = re.match(br'([^\s:]+):\s*([^\r]*)\r?\n\Z', line)
+ if not m:
+ raise IOError('Invalid header: %r' % line)
+ header, value = m.groups()
+ headers.append((header, value))
+ hdict[header.decode(encoding).lower()] = value.decode(encoding)
+
+ # Read response body.
+ content_length = hdict.get('content-length')
+ if content_length is not None:
+ size = int(content_length) # TODO: Catch errors.
+ assert size >= 0, size
+ else:
+ size = 2**20 # Protective limit (1 MB).
+ data = yield from rdr.readexactly(size)
+ trans.close() # Can this block?
+ t1 = time.time()
+ result = (host, port, path, int(status), len(data), round(t1-t0, 3))
+## print(result)
+ return result
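+
+# Typical use from a coroutine (see main.py for real examples):
+#
+#   result = yield from urlfetch('python.org', 80, path='/')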
diff --git a/old/http_server.py b/old/http_server.py
new file mode 100644
index 0000000..2b1e3dd
--- /dev/null
+++ b/old/http_server.py
@@ -0,0 +1,68 @@
+#!/usr/bin/env python3.3
+"""Simple HTTP server.
+
+This currently exists just so we can benchmark this thing!
+"""
+
+# Stdlib imports.
+import logging
+import re
+import socket
+import sys
+
+# Local imports.
+import scheduling
+import sockets
+
+
+def handler(conn, addr):
+ """COROUTINE: Handle one connection."""
+ ##logging.info('Accepting connection from %r', addr)
+ trans = sockets.SocketTransport(conn)
+ rdr = sockets.BufferedReader(trans)
+
+ # Read but ignore request line.
+ request_line = yield from rdr.readline()
+
+ # Consume headers but don't interpret them.
+ while True:
+ header_line = yield from rdr.readline()
+ if not header_line.strip():
+ break
+
+ # Always send an empty 200 response and close.
+ yield from trans.send(b'HTTP/1.0 200 Ok\r\n\r\n')
+ trans.close()
+
+
+def doit():
+ """COROUTINE: Set the wheels in motion."""
+ # Set up listener.
+ listener = yield from sockets.create_listener('localhost', 8080,
+ af=socket.AF_INET)
+ logging.info('Listening on %r', listener.sock.getsockname())
+
+ # Loop accepting connections.
+ while True:
+ conn, addr = yield from listener.accept()
+ t = scheduling.Task(handler(conn, addr))
+
+
+def main():
+ # Initialize logging.
+ if '-d' in sys.argv:
+ level = logging.DEBUG
+ elif '-v' in sys.argv:
+ level = logging.INFO
+ elif '-q' in sys.argv:
+ level = logging.ERROR
+ else:
+ level = logging.WARN
+ logging.basicConfig(level=level)
+
+ # Run scheduler, starting it off with doit().
+ scheduling.run(doit())
+
+
+if __name__ == '__main__':
+ main()
diff --git a/old/main.py b/old/main.py
new file mode 100644
index 0000000..c1f9d0a
--- /dev/null
+++ b/old/main.py
@@ -0,0 +1,134 @@
+#!/usr/bin/env python3.3
+"""Example HTTP client using yield-from coroutines (PEP 380).
+
+Requires Python 3.3.
+
+There are many micro-optimizations possible here, but that's not the point.
+
+Some incomplete laundry lists:
+
+TODO:
+- Take test urls from command line.
+- Move urlfetch to a separate module.
+- Profiling.
+- Docstrings.
+- Unittests.
+
+FUNCTIONALITY:
+- Connection pool (keep connection open).
+- Chunked encoding (request and response).
+- Compression, e.g. zlib (request and response).
+- Automatic encoding/decoding.
+"""
+
+__author__ = 'Guido van Rossum <guido@python.org>'
+
+# Standard library imports (keep in alphabetic order).
+import logging
+import os
+import time
+import socket
+import sys
+
+# Local imports (keep in alphabetic order).
+import scheduling
+import http_client
+
+
+
+def doit2():
+ argses = [
+ ('localhost', 8080, '/'),
+ ('127.0.0.1', 8080, '/home'),
+ ('python.org', 80, '/'),
+ ('xkcd.com', 443, '/'),
+ ]
+ results = yield from scheduling.map_over(
+ lambda args: http_client.urlfetch(*args), argses, timeout=2)
+ for res in results:
+ print('-->', res)
+ return []
+
+
+def doit():
+ TIMEOUT = 2
+ tasks = set()
+
+ # This references NDB's default test service.
+ # (Sadly the service is single-threaded.)
+ task1 = scheduling.Task(http_client.urlfetch('localhost', 8080, path='/'),
+ 'root', timeout=TIMEOUT)
+ tasks.add(task1)
+ task2 = scheduling.Task(http_client.urlfetch('127.0.0.1', 8080,
+ path='/home'),
+ 'home', timeout=TIMEOUT)
+ tasks.add(task2)
+
+ # Fetch python.org home page.
+ task3 = scheduling.Task(http_client.urlfetch('python.org', 80, path='/'),
+ 'python', timeout=TIMEOUT)
+ tasks.add(task3)
+
+ # Fetch XKCD home page using SSL. (Doesn't like IPv6.)
+ task4 = scheduling.Task(http_client.urlfetch('xkcd.com', ssl=True, path='/',
+ af=socket.AF_INET),
+ 'xkcd', timeout=TIMEOUT)
+ tasks.add(task4)
+
+## # Fetch many links from python.org (/x.y.z).
+## for x in '123':
+## for y in '0123456789':
+## path = '/{}.{}'.format(x, y)
+## g = http_client.urlfetch('82.94.164.162', 80,
+## path=path, hdrs={'host': 'python.org'})
+## t = scheduling.Task(g, path, timeout=2)
+## tasks.add(t)
+
+## print(tasks)
+ yield from scheduling.Task(scheduling.sleep(1), timeout=0.2).wait()
+ winners = yield from scheduling.wait_any(tasks)
+ print('And the winners are:', [w.name for w in winners])
+ tasks = yield from scheduling.wait_all(tasks)
+ print('And the players were:', [t.name for t in tasks])
+ return tasks
+
+
+def logtimes(real):
+ utime, stime, cutime, cstime, unused = os.times()
+ logging.info('real %10.3f', real)
+ logging.info('user %10.3f', utime + cutime)
+ logging.info('sys %10.3f', stime + cstime)
+
+
+def main():
+ t0 = time.time()
+
+ # Initialize logging.
+ if '-d' in sys.argv:
+ level = logging.DEBUG
+ elif '-v' in sys.argv:
+ level = logging.INFO
+ elif '-q' in sys.argv:
+ level = logging.ERROR
+ else:
+ level = logging.WARN
+ logging.basicConfig(level=level)
+
+ # Run scheduler, starting it off with doit().
+ task = scheduling.run(doit())
+ if task.exception:
+ print('Exception:', repr(task.exception))
+ if isinstance(task.exception, AssertionError):
+ raise task.exception
+ else:
+ for t in task.result:
+ print(t.name + ':',
+ repr(t.exception) if t.exception else t.result)
+
+ # Report real, user, sys times.
+ t1 = time.time()
+ logtimes(t1-t0)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/old/p3time.py b/old/p3time.py
new file mode 100644
index 0000000..35e14c9
--- /dev/null
+++ b/old/p3time.py
@@ -0,0 +1,47 @@
+"""Compare timing of plain vs. yield-from calls."""
+
+import gc
+import time
+
+def plain(n):
+ if n <= 0:
+ return 1
+ l = plain(n-1)
+ r = plain(n-1)
+ return l + 1 + r
+
+def coroutine(n):
+ if n <= 0:
+ return 1
+ l = yield from coroutine(n-1)
+ r = yield from coroutine(n-1)
+ return l + 1 + r
+
+def submain(depth):
+ t0 = time.time()
+ k = plain(depth)
+ t1 = time.time()
+ fmt = ' {} {} {:-9,.5f}'
+ delta0 = t1-t0
+ print(('plain' + fmt).format(depth, k, delta0))
+
+ t0 = time.time()
+ try:
+ g = coroutine(depth)
+ while True:
+ next(g)
+ except StopIteration as err:
+ k = err.value
+ t1 = time.time()
+ delta1 = t1-t0
+ print(('coro.' + fmt).format(depth, k, delta1))
+ if delta0:
+ print(('relat' + fmt).format(depth, k, delta1/delta0))
+
+def main(reasonable=16):
+ gc.disable()
+ for depth in range(reasonable):
+ submain(depth)
+
+if __name__ == '__main__':
+ main()
diff --git a/old/polling.py b/old/polling.py
new file mode 100644
index 0000000..6586efc
--- /dev/null
+++ b/old/polling.py
@@ -0,0 +1,535 @@
+"""Event loop and related classes.
+
+The event loop can be broken up into a pollster (the part responsible
+for telling us when file descriptors are ready) and the event loop
+proper, which wraps a pollster with functionality for scheduling
+callbacks, immediately or at a given time in the future.
+
+Whenever a public API takes a callback, subsequent positional
+arguments will be passed to the callback if/when it is called. This
+avoids the proliferation of trivial lambdas implementing closures.
+Keyword arguments for the callback are not supported; this is a
+conscious design decision, leaving the door open for keyword arguments
+to modify the meaning of the API call itself.
+
+There are several implementations of the pollster part, several using
+esoteric system calls that exist only on some platforms. These are:
+
+- kqueue (most BSD systems)
+- epoll (newer Linux systems)
+- poll (most UNIX systems)
+- select (all UNIX systems, and Windows)
+- TODO: Support IOCP on Windows and some UNIX platforms.
+
+NOTE: We don't use select on systems where any of the others is
+available, because select performs poorly as the number of file
+descriptors goes up. The ranking is roughly:
+
+ 1. kqueue, epoll, IOCP
+ 2. poll
+ 3. select
+
+TODO:
+- Optimize the various pollsters.
+- Unittests.
+"""
+
+import collections
+import concurrent.futures
+import heapq
+import logging
+import os
+import select
+import threading
+import time
+
+
+class PollsterBase:
+ """Base class for all polling implementations.
+
+ This defines an interface to register and unregister readers and
+ writers for specific file descriptors, and an interface to get a
+ list of events. There's also an interface to check whether any
+ readers or writers are currently registered.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.readers = {} # {fd: token, ...}.
+ self.writers = {} # {fd: token, ...}.
+
+ def pollable(self):
+ """Return True if any readers or writers are currently registered."""
+ return bool(self.readers or self.writers)
+
+ # Subclasses are expected to extend the add/remove methods.
+
+ def register_reader(self, fd, token):
+ """Add or update a reader for a file descriptor."""
+ self.readers[fd] = token
+
+ def register_writer(self, fd, token):
+ """Add or update a writer for a file descriptor."""
+ self.writers[fd] = token
+
+ def unregister_reader(self, fd):
+ """Remove the reader for a file descriptor."""
+ del self.readers[fd]
+
+ def unregister_writer(self, fd):
+ """Remove the writer for a file descriptor."""
+ del self.writers[fd]
+
+ def poll(self, timeout=None):
+ """Poll for events. A subclass must implement this.
+
+ If timeout is omitted or None, this blocks until at least one
+ event is ready. Otherwise, timeout gives a maximum time to
+ wait (in seconds as an int or float) -- the method returns as
+ soon as at least one event is ready or when the timeout is
+ expired. For a non-blocking poll, pass 0.
+
+ The return value is a list of events; it is empty when the
+ timeout expired before any events were ready. Each event
+ is a token previously passed to register_reader/writer().
+ """
+ raise NotImplementedError
+
+
+class SelectPollster(PollsterBase):
+ """Pollster implementation using select."""
+
+ def poll(self, timeout=None):
+ readable, writable, _ = select.select(self.readers, self.writers,
+ [], timeout)
+ events = []
+ events += (self.readers[fd] for fd in readable)
+ events += (self.writers[fd] for fd in writable)
+ return events
+
+
+class PollPollster(PollsterBase):
+ """Pollster implementation using poll."""
+
+ def __init__(self):
+ super().__init__()
+ self._poll = select.poll()
+
+ def _update(self, fd):
+ assert isinstance(fd, int), fd
+ flags = 0
+ if fd in self.readers:
+ flags |= select.POLLIN
+ if fd in self.writers:
+ flags |= select.POLLOUT
+ if flags:
+ self._poll.register(fd, flags)
+ else:
+ self._poll.unregister(fd)
+
+ def register_reader(self, fd, callback, *args):
+ super().register_reader(fd, callback, *args)
+ self._update(fd)
+
+ def register_writer(self, fd, callback, *args):
+ super().register_writer(fd, callback, *args)
+ self._update(fd)
+
+ def unregister_reader(self, fd):
+ super().unregister_reader(fd)
+ self._update(fd)
+
+ def unregister_writer(self, fd):
+ super().unregister_writer(fd)
+ self._update(fd)
+
+ def poll(self, timeout=None):
+ # Timeout is in seconds, but poll() takes milliseconds.
+ msecs = None if timeout is None else int(round(1000 * timeout))
+ events = []
+ for fd, flags in self._poll.poll(msecs):
+ if flags & (select.POLLIN | select.POLLHUP):
+ if fd in self.readers:
+ events.append(self.readers[fd])
+ if flags & (select.POLLOUT | select.POLLHUP):
+ if fd in self.writers:
+ events.append(self.writers[fd])
+ return events
+
+
+class EPollPollster(PollsterBase):
+ """Pollster implementation using epoll."""
+
+ def __init__(self):
+ super().__init__()
+ self._epoll = select.epoll()
+
+ def _update(self, fd):
+ assert isinstance(fd, int), fd
+ eventmask = 0
+ if fd in self.readers:
+ eventmask |= select.EPOLLIN
+ if fd in self.writers:
+ eventmask |= select.EPOLLOUT
+ if eventmask:
+ try:
+ self._epoll.register(fd, eventmask)
+ except IOError:
+ self._epoll.modify(fd, eventmask)
+ else:
+ self._epoll.unregister(fd)
+
+ def register_reader(self, fd, callback, *args):
+ super().register_reader(fd, callback, *args)
+ self._update(fd)
+
+ def register_writer(self, fd, callback, *args):
+ super().register_writer(fd, callback, *args)
+ self._update(fd)
+
+ def unregister_reader(self, fd):
+ super().unregister_reader(fd)
+ self._update(fd)
+
+ def unregister_writer(self, fd):
+ super().unregister_writer(fd)
+ self._update(fd)
+
+ def poll(self, timeout=None):
+ if timeout is None:
+ timeout = -1 # epoll.poll() uses -1 to mean "wait forever".
+ events = []
+ for fd, eventmask in self._epoll.poll(timeout):
+ if eventmask & select.EPOLLIN:
+ if fd in self.readers:
+ events.append(self.readers[fd])
+ if eventmask & select.EPOLLOUT:
+ if fd in self.writers:
+ events.append(self.writers[fd])
+ return events
+
+
+class KqueuePollster(PollsterBase):
+ """Pollster implementation using kqueue."""
+
+ def __init__(self):
+ super().__init__()
+ self._kqueue = select.kqueue()
+
+ def register_reader(self, fd, callback, *args):
+ if fd not in self.readers:
+ kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD)
+ self._kqueue.control([kev], 0, 0)
+ return super().register_reader(fd, callback, *args)
+
+ def register_writer(self, fd, callback, *args):
+ if fd not in self.writers:
+ kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD)
+ self._kqueue.control([kev], 0, 0)
+ return super().register_writer(fd, callback, *args)
+
+ def unregister_reader(self, fd):
+ super().unregister_reader(fd)
+ kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE)
+ self._kqueue.control([kev], 0, 0)
+
+ def unregister_writer(self, fd):
+ super().unregister_writer(fd)
+ kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE)
+ self._kqueue.control([kev], 0, 0)
+
+ def poll(self, timeout=None):
+ events = []
+ max_ev = len(self.readers) + len(self.writers)
+ for kev in self._kqueue.control(None, max_ev, timeout):
+ fd = kev.ident
+ flag = kev.filter
+ if flag == select.KQ_FILTER_READ and fd in self.readers:
+ events.append(self.readers[fd])
+ elif flag == select.KQ_FILTER_WRITE and fd in self.writers:
+ events.append(self.writers[fd])
+ return events
+
+
+# Pick the best pollster class for the platform.
+if hasattr(select, 'kqueue'):
+ best_pollster = KqueuePollster
+elif hasattr(select, 'epoll'):
+ best_pollster = EPollPollster
+elif hasattr(select, 'poll'):
+ best_pollster = PollPollster
+else:
+ best_pollster = SelectPollster
+
+
+class DelayedCall:
+ """Object returned by callback registration methods."""
+
+ def __init__(self, when, callback, args, kwds=None):
+ self.when = when
+ self.callback = callback
+ self.args = args
+ self.kwds = kwds
+ self.cancelled = False
+
+ def cancel(self):
+ self.cancelled = True
+
+ def __lt__(self, other):
+ return self.when < other.when
+
+ def __le__(self, other):
+ return self.when <= other.when
+
+ def __eq__(self, other):
+ return self.when == other.when
+
+
+class EventLoop:
+ """Event loop functionality.
+
+    This defines the public APIs call_soon(), call_later(), run_once()
+    and run(). It also wraps the Pollster APIs register_reader(),
+    register_writer(), unregister_reader(), unregister_writer() as
+    add_reader(), add_writer(), remove_reader(), remove_writer().
+
+ This class's instance variables are not part of its API.
+ """
+
+ def __init__(self, pollster=None):
+ super().__init__()
+ if pollster is None:
+ logging.info('Using pollster: %s', best_pollster.__name__)
+ pollster = best_pollster()
+ self.pollster = pollster
+ self.ready = collections.deque() # [(callback, args), ...]
+ self.scheduled = [] # [(when, callback, args), ...]
+
+ def add_reader(self, fd, callback, *args):
+ """Add a reader callback. Return a DelayedCall instance."""
+ dcall = DelayedCall(None, callback, args)
+ self.pollster.register_reader(fd, dcall)
+ return dcall
+
+ def remove_reader(self, fd):
+ """Remove a reader callback."""
+ self.pollster.unregister_reader(fd)
+
+ def add_writer(self, fd, callback, *args):
+ """Add a writer callback. Return a DelayedCall instance."""
+ dcall = DelayedCall(None, callback, args)
+ self.pollster.register_writer(fd, dcall)
+ return dcall
+
+ def remove_writer(self, fd):
+ """Remove a writer callback."""
+ self.pollster.unregister_writer(fd)
+
+ def add_callback(self, dcall):
+ """Add a DelayedCall to ready or scheduled."""
+ if dcall.cancelled:
+ return
+ if dcall.when is None:
+ self.ready.append(dcall)
+ else:
+ heapq.heappush(self.scheduled, dcall)
+
+ def call_soon(self, callback, *args):
+ """Arrange for a callback to be called as soon as possible.
+
+        This operates as a FIFO queue: callbacks are called in the
+ order in which they are registered. Each callback will be
+ called exactly once.
+
+ Any positional arguments after the callback will be passed to
+ the callback when it is called.
+ """
+ dcall = DelayedCall(None, callback, args)
+ self.ready.append(dcall)
+ return dcall
+
+ def call_later(self, when, callback, *args):
+ """Arrange for a callback to be called at a given time.
+
+ Return an object with a cancel() method that can be used to
+ cancel the call.
+
+ The time can be an int or float, expressed in seconds.
+
+        If when is small enough (less than 10,000,000 seconds, about
+        116 days), it's assumed to be a relative time, meaning the
+        call will be scheduled that many seconds in the future;
+        otherwise it's assumed to be a posix timestamp as returned
+        by time.time().
+
+ Each callback will be called exactly once. If two callbacks
+        are scheduled for exactly the same time, it is undefined which
+ will be called first.
+
+ Any positional arguments after the callback will be passed to
+ the callback when it is called.
+ """
+ if when < 10000000:
+ when += time.time()
+ dcall = DelayedCall(when, callback, args)
+ heapq.heappush(self.scheduled, dcall)
+ return dcall
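+
+    # For example (a sketch; both calls below run ~5 seconds from now):
+    #
+    #   loop.call_later(5, callback)                # relative seconds
+    #   loop.call_later(time.time() + 5, callback)  # posix timestamp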
+
+ def run_once(self):
+ """Run one full iteration of the event loop.
+
+ This calls all currently ready callbacks, polls for I/O,
+ schedules the resulting callbacks, and finally schedules
+ 'call_later' callbacks.
+ """
+ # TODO: Break each of these into smaller pieces.
+ # TODO: Pass in a timeout or deadline or something.
+ # TODO: Refactor to separate the callbacks from the readers/writers.
+ # TODO: As step 4, run everything scheduled by steps 1-3.
+ # TODO: An alternative API would be to do the *minimal* amount
+ # of work, e.g. one callback or one I/O poll.
+
+ # This is the only place where callbacks are actually *called*.
+ # All other places just add them to ready.
+ # TODO: Ensure this loop always finishes, even if some
+        # callbacks keep registering more callbacks.
+ while self.ready:
+ dcall = self.ready.popleft()
+ if not dcall.cancelled:
+ try:
+ if dcall.kwds:
+ dcall.callback(*dcall.args, **dcall.kwds)
+ else:
+ dcall.callback(*dcall.args)
+ except Exception:
+ logging.exception('Exception in callback %s %r',
+ dcall.callback, dcall.args)
+
+ # Remove delayed calls that were cancelled from head of queue.
+ while self.scheduled and self.scheduled[0].cancelled:
+ heapq.heappop(self.scheduled)
+
+ # Inspect the poll queue.
+ if self.pollster.pollable():
+ if self.scheduled:
+ when = self.scheduled[0].when
+ timeout = max(0, when - time.time())
+ else:
+ timeout = None
+ t0 = time.time()
+ events = self.pollster.poll(timeout)
+ t1 = time.time()
+ argstr = '' if timeout is None else ' %.3f' % timeout
+ if t1-t0 >= 1:
+ level = logging.INFO
+ else:
+ level = logging.DEBUG
+ logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0)
+ for dcall in events:
+ self.add_callback(dcall)
+
+ # Handle 'later' callbacks that are ready.
+ now = time.time()
+ while self.scheduled:
+ dcall = self.scheduled[0]
+ if dcall.when > now:
+ break
+ dcall = heapq.heappop(self.scheduled)
+ self.call_soon(dcall.callback, *dcall.args)
+
+ def run(self):
+ """Run the event loop until there is no work left to do.
+
+ This keeps going as long as there are either readable and
+ writable file descriptors, or scheduled callbacks (of either
+ variety).
+ """
+ while self.ready or self.scheduled or self.pollster.pollable():
+ self.run_once()
+
+
+MAX_WORKERS = 5 # Default max workers when creating an executor.
+
+
+class ThreadRunner:
+ """Helper to submit work to a thread pool and wait for it.
+
+ This is the glue between the single-threaded callback-based async
+ world and the threaded world. Use it to call functions that must
+ block and don't have an async alternative (e.g. getaddrinfo()).
+
+ The only public API is submit().
+ """
+
+ def __init__(self, eventloop, executor=None):
+ self.eventloop = eventloop
+ self.executor = executor # Will be constructed lazily.
+ self.pipe_read_fd, self.pipe_write_fd = os.pipe()
+ self.active_count = 0
+
+ def read_callback(self):
+ """Semi-permanent callback while at least one future is active."""
+ assert self.active_count > 0, self.active_count
+ data = os.read(self.pipe_read_fd, 8192) # Traditional buffer size.
+ self.active_count -= len(data)
+ if self.active_count == 0:
+ self.eventloop.remove_reader(self.pipe_read_fd)
+ assert self.active_count >= 0, self.active_count
+
+ def submit(self, func, *args, executor=None, callback=None):
+ """Submit a function to the thread pool.
+
+ This returns a concurrent.futures.Future instance. The caller
+        should not wait for that, but rather use the callback argument.
+ """
+ if executor is None:
+ executor = self.executor
+ if executor is None:
+ # Lazily construct a default executor.
+ # TODO: Should this be shared between threads?
+ executor = concurrent.futures.ThreadPoolExecutor(MAX_WORKERS)
+ self.executor = executor
+ assert self.active_count >= 0, self.active_count
+ future = executor.submit(func, *args)
+ if self.active_count == 0:
+ self.eventloop.add_reader(self.pipe_read_fd, self.read_callback)
+ self.active_count += 1
+ def done_callback(fut):
+ if callback is not None:
+ self.eventloop.call_soon(callback, fut)
+ # TODO: Wake up the pipe in call_soon()?
+ os.write(self.pipe_write_fd, b'x')
+ future.add_done_callback(done_callback)
+ return future
+
+
+class Context(threading.local):
+ """Thread-local context.
+
+ We use this to avoid having to explicitly pass around an event loop
+ or something to hold the current task.
+
+ TODO: Add an API so frameworks can substitute a different notion
+ of context more easily.
+ """
+
+ def __init__(self, eventloop=None, threadrunner=None):
+ # Default event loop and thread runner are lazily constructed
+ # when first accessed.
+ self._eventloop = eventloop
+ self._threadrunner = threadrunner
+ self.current_task = None # For the benefit of scheduling.py.
+
+ @property
+ def eventloop(self):
+ if self._eventloop is None:
+ self._eventloop = EventLoop()
+ return self._eventloop
+
+ @property
+ def threadrunner(self):
+ if self._threadrunner is None:
+ self._threadrunner = ThreadRunner(self.eventloop)
+ return self._threadrunner
+
+
+context = Context() # Thread-local!
diff --git a/old/scheduling.py b/old/scheduling.py
new file mode 100644
index 0000000..3864571
--- /dev/null
+++ b/old/scheduling.py
@@ -0,0 +1,354 @@
+#!/usr/bin/env python3.3
+"""Example coroutine scheduler, PEP-380-style ('yield from <generator>').
+
+Requires Python 3.3.
+
+There are likely micro-optimizations possible here, but that's not the point.
+
+TODO:
+- Docstrings.
+- Unittests.
+
+PATTERNS TO TRY:
+- Various synchronization primitives (Lock, RLock, Event, Condition,
+ Semaphore, BoundedSemaphore, Barrier).
+"""
+
+__author__ = 'Guido van Rossum <guido@python.org>'
+
+# Standard library imports (keep in alphabetic order).
+from concurrent.futures import CancelledError, TimeoutError
+import logging
+import time
+import types
+
+# Local imports (keep in alphabetic order).
+import polling
+
+
+context = polling.context
+
+
+class Task:
+ """Wrapper around a stack of generators.
+
+ This is a bit like a Future, but with a different interface.
+
+ TODO:
+ - wait for result.
+ """
+
+ def __init__(self, gen, name=None, *, timeout=None):
+ assert isinstance(gen, types.GeneratorType), repr(gen)
+ self.gen = gen
+ self.name = name or gen.__name__
+ self.timeout = timeout
+ self.eventloop = context.eventloop
+ self.canceleer = None
+ if timeout is not None:
+ self.canceleer = self.eventloop.call_later(timeout, self.cancel)
+ self.blocked = False
+ self.unblocker = None
+ self.cancelled = False
+ self.must_cancel = False
+ self.alive = True
+ self.result = None
+ self.exception = None
+ self.done_callbacks = []
+ # Start the task immediately.
+ self.eventloop.call_soon(self.step)
+
+ def add_done_callback(self, done_callback):
+ # For better or for worse, the callback will always be called
+ # with the task as an argument, like concurrent.futures.Future.
+ # TODO: Call it right away if task is no longer alive.
+ dcall = polling.DelayedCall(None, done_callback, (self,))
+ self.done_callbacks.append(dcall)
+ self.done_callbacks = [dc for dc in self.done_callbacks
+ if not dc.cancelled]
+ return dcall
+
+ def __repr__(self):
+ parts = [self.name]
+ is_current = (self is context.current_task)
+ if self.blocked:
+ parts.append('blocking' if is_current else 'blocked')
+ elif self.alive:
+ parts.append('running' if is_current else 'runnable')
+ if self.must_cancel:
+ parts.append('must_cancel')
+ if self.cancelled:
+ parts.append('cancelled')
+ if self.exception is not None:
+ parts.append('exception=%r' % self.exception)
+ elif not self.alive:
+ parts.append('result=%r' % (self.result,))
+ if self.timeout is not None:
+ parts.append('timeout=%.3f' % self.timeout)
+ return 'Task<' + ', '.join(parts) + '>'
+
+ def cancel(self):
+ if self.alive:
+ if not self.must_cancel and not self.cancelled:
+ self.must_cancel = True
+ if self.blocked:
+ self.unblock()
+
+ def step(self):
+ assert self.alive, self
+ try:
+ context.current_task = self
+ if self.must_cancel:
+ self.must_cancel = False
+ self.cancelled = True
+ self.gen.throw(CancelledError())
+ else:
+ next(self.gen)
+ except StopIteration as exc:
+ self.alive = False
+ self.result = exc.value
+ except Exception as exc:
+ self.alive = False
+ self.exception = exc
+ logging.debug('Uncaught exception in %s', self,
+ exc_info=True, stack_info=True)
+ except BaseException as exc:
+ self.alive = False
+ self.exception = exc
+ raise
+ else:
+ if not self.blocked:
+ self.eventloop.call_soon(self.step)
+ finally:
+ context.current_task = None
+ if not self.alive:
+ # Cancel timeout callback if set.
+ if self.canceleer is not None:
+ self.canceleer.cancel()
+ # Schedule done_callbacks.
+ for dcall in self.done_callbacks:
+ self.eventloop.add_callback(dcall)
+
+ def block(self, unblock_callback=None, *unblock_args):
+ assert self is context.current_task, self
+ assert self.alive, self
+ assert not self.blocked, self
+ self.blocked = True
+ self.unblocker = (unblock_callback, unblock_args)
+
+ def unblock_if_alive(self, unused=None):
+ # Ignore optional argument so we can be a Future's done_callback.
+ if self.alive:
+ self.unblock()
+
+ def unblock(self, unused=None):
+ # Ignore optional argument so we can be a Future's done_callback.
+ assert self.alive, self
+ assert self.blocked, self
+ self.blocked = False
+ unblock_callback, unblock_args = self.unblocker
+ if unblock_callback is not None:
+ try:
+ unblock_callback(*unblock_args)
+ except Exception:
+ logging.error('Exception in unblocker in task %r', self.name)
+ raise
+ finally:
+ self.unblocker = None
+ self.eventloop.call_soon(self.step)
+
+ def block_io(self, fd, flag):
+ assert isinstance(fd, int), repr(fd)
+ assert flag in ('r', 'w'), repr(flag)
+ if flag == 'r':
+ self.block(self.eventloop.remove_reader, fd)
+ self.eventloop.add_reader(fd, self.unblock)
+ else:
+ self.block(self.eventloop.remove_writer, fd)
+ self.eventloop.add_writer(fd, self.unblock)
+
+ def wait(self):
+ """COROUTINE: Wait until this task is finished."""
+ current_task = context.current_task
+ assert self is not current_task, (self, current_task) # How confusing!
+ if not self.alive:
+ return
+ current_task.block()
+ self.add_done_callback(current_task.unblock)
+ yield
+
+ def __iter__(self):
+ """COROUTINE: Wait, then return result or raise exception.
+
+ This adds a little magic so you can say
+
+ x = yield from Task(gen())
+
+ and it is equivalent to
+
+ x = yield from gen()
+
+ but with the option to add a timeout (and only a tad slower).
+ """
+ if self.alive:
+ yield from self.wait()
+ assert not self.alive
+ if self.exception is not None:
+ raise self.exception
+ return self.result
+
+
+def run(arg=None):
+ """Run the event loop until it's out of work.
+
+ If you pass a generator, it will be spawned for you.
+ You can also pass a task (already started).
+ Returns the task.
+ """
+ t = None
+ if arg is not None:
+ if isinstance(arg, Task):
+ t = arg
+ else:
+ t = Task(arg)
+ context.eventloop.run()
+ if t is not None and t.exception is not None:
+ logging.error('Uncaught exception in startup task: %r',
+ t.exception)
+ return t
+
+
+def sleep(secs):
+ """COROUTINE: Sleep for some time (a float in seconds)."""
+ current_task = context.current_task
+ unblocker = context.eventloop.call_later(secs, current_task.unblock)
+ current_task.block(unblocker.cancel)
+ yield
+
+
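+# A minimal usage sketch (illustrative only; the _demo_* names are not
+# part of the API): wrap a generator in a Task to start it, give it a
+# timeout, and use run() to drive the event loop to completion.
+def _demo_greeter(name):
+    """COROUTINE: Sleep briefly, then return a greeting."""
+    yield from sleep(0.1)
+    return 'hello ' + name
+
+def _demo_run_greeter():
+    task = run(Task(_demo_greeter('world'), timeout=5))
+    print(task.result)  # Prints 'hello world'.
+
+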
+def block_r(fd):
+ """COROUTINE: Block until a file descriptor is ready for reading."""
+ context.current_task.block_io(fd, 'r')
+ yield
+
+
+def block_w(fd):
+ """COROUTINE: Block until a file descriptor is ready for writing."""
+ context.current_task.block_io(fd, 'w')
+ yield
+
+
+def call_in_thread(func, *args, executor=None):
+ """COROUTINE: Run a function in a thread."""
+ task = context.current_task
+ eventloop = context.eventloop
+ future = context.threadrunner.submit(func, *args,
+ executor=executor,
+ callback=task.unblock_if_alive)
+ task.block(future.cancel)
+ yield
+ assert future.done()
+ return future.result()
+
+
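+# Sketch of off-loading a blocking call with call_in_thread() (assumes
+# context.threadrunner is set up; reading a file is just illustration).
+def _demo_read_file(path):
+    """COROUTINE: Read a whole file without blocking the event loop."""
+    def _blocking_read(p):
+        with open(p, 'rb') as f:
+            return f.read()
+    data = yield from call_in_thread(_blocking_read, path)
+    return data
+
+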
+def wait_for(count, tasks):
+ """COROUTINE: Wait for the first N of a set of tasks to complete.
+
+ May return more than N if more than N are immediately ready.
+
+ NOTE: Tasks that were cancelled or raised are also considered ready.
+ """
+ assert tasks
+ assert all(isinstance(task, Task) for task in tasks)
+ tasks = set(tasks)
+ assert 1 <= count <= len(tasks)
+ current_task = context.current_task
+ assert all(task is not current_task for task in tasks)
+ todo = set()
+ done = set()
+ dcalls = []
+ def wait_for_callback(task):
+ nonlocal todo, done, current_task, count, dcalls
+ todo.remove(task)
+ if len(done) < count:
+ done.add(task)
+ if len(done) == count:
+ for dcall in dcalls:
+ dcall.cancel()
+ current_task.unblock()
+ for task in tasks:
+ if task.alive:
+ todo.add(task)
+ else:
+ done.add(task)
+ if len(done) < count:
+ for task in todo:
+ dcall = task.add_done_callback(wait_for_callback)
+ dcalls.append(dcall)
+ current_task.block()
+ yield
+ return done
+
+
+def wait_any(tasks):
+ """COROUTINE: Wait for the first of a set of tasks to complete."""
+ return wait_for(1, tasks)
+
+
+def wait_all(tasks):
+ """COROUTINE: Wait for all of a set of tasks to complete."""
+ return wait_for(len(tasks), tasks)
+
+
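+# Sketch: race several tasks with wait_any() and keep the first result;
+# the losers are cancelled. Here probe stands in for any coroutine
+# function the caller supplies.
+def _demo_first_result(probe, args):
+    """COROUTINE: Return the result of whichever task finishes first."""
+    tasks = {Task(probe(a), timeout=10) for a in args}
+    done = yield from wait_any(tasks)
+    for task in tasks - done:
+        task.cancel()
+    winner = done.pop()
+    if winner.exception is not None:
+        raise winner.exception
+    return winner.result
+
+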
+def map_over(gen, *args, timeout=None):
+ """COROUTINE: map a generator over one or more iterables.
+
+ E.g. map_over(foo, xs, ys) runs
+
+ Task(foo(x, y)) for x, y in zip(xs, ys)
+
+    and returns a list of all results (in that order). However, if any
+ task raises an exception, the remaining tasks are cancelled and
+ the exception is propagated.
+ """
+ # gen is a generator function.
+ tasks = [Task(gobj, timeout=timeout) for gobj in map(gen, *args)]
+ return (yield from par_tasks(tasks))
+
+
+def par(*args):
+ """COROUTINE: Wait for generators, return a list of results.
+
+ Raises as soon as one of the tasks raises an exception (and then
+ remaining tasks are cancelled).
+
+ This differs from par_tasks() in two ways:
+ - takes *args instead of list of args
+ - each arg may be a generator or a task
+ """
+ tasks = []
+ for arg in args:
+ if not isinstance(arg, Task):
+ # TODO: assert arg is a generator or an iterator?
+ arg = Task(arg)
+ tasks.append(arg)
+ return (yield from par_tasks(tasks))
+
+
+def par_tasks(tasks):
+ """COROUTINE: Wait for a list of tasks, return a list of results.
+
+ Raises as soon as one of the tasks raises an exception (and then
+ remaining tasks are cancelled).
+ """
+ todo = set(tasks)
+ while todo:
+ ts = yield from wait_any(todo)
+ for t in ts:
+ assert not t.alive, t
+ todo.remove(t)
+ if t.exception is not None:
+ for other in todo:
+ other.cancel()
+ raise t.exception
+ return [t.result for t in tasks]
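+
+
+# Sketch tying the combinators together: fan a hypothetical fetch()
+# coroutine out over many inputs with map_over(), failing fast if any
+# task raises.
+def _demo_fetch_all(fetch, urls):
+    """COROUTINE: Map fetch() over urls; results keep input order."""
+    results = yield from map_over(fetch, urls, timeout=30)
+    return dict(zip(urls, results))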
diff --git a/old/sockets.py b/old/sockets.py
new file mode 100644
index 0000000..a5005dc
--- /dev/null
+++ b/old/sockets.py
@@ -0,0 +1,348 @@
+"""Socket wrappers to go with scheduling.py.
+
+Classes:
+
+- SocketTransport: a transport implementation wrapping a socket.
+- SslTransport: a transport implementation wrapping SSL around a socket.
+- BufferedReader: a buffer wrapping the read end of a transport.
+
+Functions (all coroutines):
+
+- connect(): connect a socket.
+- getaddrinfo(): look up an address.
+- create_connection(): look up address and return a connected socket for it.
+- create_transport(): look up address and return a connected transport.
+
+TODO:
+- Improve transport abstraction.
+- Make a nice protocol abstraction.
+- Unittests.
+- A write() call that isn't a generator (needed so you can substitute it
+ for sys.stderr, pass it to logging.StreamHandler, etc.).
+"""
+
+__author__ = 'Guido van Rossum <guido@python.org>'
+
+# Stdlib imports.
+import errno
+import socket
+import ssl
+
+# Local imports.
+import scheduling
+
+# Errno values indicating the connection was disconnected.
+_DISCONNECTED = frozenset((errno.ECONNRESET,
+ errno.ENOTCONN,
+ errno.ESHUTDOWN,
+ errno.ECONNABORTED,
+ errno.EPIPE,
+ errno.EBADF,
+ ))
+
+# Errno values indicating the socket isn't ready for I/O just yet.
+_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK))
+
+
+class SocketTransport:
+ """Transport wrapping a socket.
+
+ The socket must already be connected in non-blocking mode.
+ """
+
+ def __init__(self, sock):
+ self.sock = sock
+
+ def recv(self, n):
+ """COROUTINE: Read up to n bytes, blocking as needed.
+
+        Always returns at least one byte, unless the socket was
+        closed or disconnected and there's no more data, in which
+        case it returns b''.
+ """
+ assert n >= 0, n
+ while True:
+ try:
+ return self.sock.recv(n)
+ except socket.error as err:
+ if err.errno in _TRYAGAIN:
+ pass
+ elif err.errno in _DISCONNECTED:
+ return b''
+ else:
+ raise # Unexpected, propagate.
+ yield from scheduling.block_r(self.sock.fileno())
+
+ def send(self, data):
+ """COROUTINE; Send data to the socket, blocking until all written.
+
+ Return True if all went well, False if socket was disconnected.
+ """
+ while data:
+ try:
+ n = self.sock.send(data)
+ except socket.error as err:
+ if err.errno in _TRYAGAIN:
+ pass
+ elif err.errno in _DISCONNECTED:
+ return False
+ else:
+ raise # Unexpected, propagate.
+ else:
+ assert 0 <= n <= len(data), (n, len(data))
+ if n == len(data):
+ break
+ data = data[n:]
+ continue
+ yield from scheduling.block_w(self.sock.fileno())
+
+ return True
+
+ def close(self):
+ """Close the socket. (Not a coroutine.)"""
+ self.sock.close()
+
+
+class SslTransport:
+ """Transport wrapping a socket in SSL.
+
+ The socket must already be connected at the TCP level in
+ non-blocking mode.
+ """
+
+ def __init__(self, rawsock, sslcontext=None):
+ self.rawsock = rawsock
+ self.sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ self.sslsock = self.sslcontext.wrap_socket(
+ self.rawsock, do_handshake_on_connect=False)
+
+ def do_handshake(self):
+ """COROUTINE: Finish the SSL handshake."""
+ while True:
+ try:
+ self.sslsock.do_handshake()
+ except ssl.SSLWantReadError:
+ yield from scheduling.block_r(self.sslsock.fileno())
+ except ssl.SSLWantWriteError:
+ yield from scheduling.block_w(self.sslsock.fileno())
+ else:
+ break
+
+ def recv(self, n):
+ """COROUTINE: Read up to n bytes.
+
+ This blocks until at least one byte is read, or until EOF.
+ """
+ while True:
+ try:
+ return self.sslsock.recv(n)
+ except ssl.SSLWantReadError:
+ yield from scheduling.block_r(self.sslsock.fileno())
+ except ssl.SSLWantWriteError:
+ yield from scheduling.block_w(self.sslsock.fileno())
+ except socket.error as err:
+ if err.errno in _TRYAGAIN:
+ yield from scheduling.block_r(self.sslsock.fileno())
+ elif err.errno in _DISCONNECTED:
+ # Can this happen?
+ return b''
+ else:
+ raise # Unexpected, propagate.
+
+ def send(self, data):
+ """COROUTINE: Send data to the socket, blocking as needed."""
+ while data:
+ try:
+ n = self.sslsock.send(data)
+ except ssl.SSLWantReadError:
+ yield from scheduling.block_r(self.sslsock.fileno())
+ except ssl.SSLWantWriteError:
+ yield from scheduling.block_w(self.sslsock.fileno())
+ except socket.error as err:
+ if err.errno in _TRYAGAIN:
+ yield from scheduling.block_w(self.sslsock.fileno())
+ elif err.errno in _DISCONNECTED:
+ return False
+ else:
+ raise # Unexpected, propagate.
+            else:
+                if n == len(data):
+                    break
+                data = data[n:]
+
+ return True
+
+ def close(self):
+ """Close the socket. (Not a coroutine.)
+
+ This also closes the raw socket.
+ """
+ self.sslsock.close()
+
+ # TODO: More SSL-specific methods, e.g. certificate stuff, unwrap(), ...
+
+
+class BufferedReader:
+ """A buffered reader wrapping a transport."""
+
+ def __init__(self, trans, limit=8192):
+ self.trans = trans
+ self.limit = limit
+ self.buffer = b''
+ self.eof = False
+
+ def read(self, n):
+ """COROUTINE: Read up to n bytes, blocking at most once."""
+ assert n >= 0, n
+ if not self.buffer and not self.eof:
+ yield from self._fillbuffer(max(n, self.limit))
+ return self._getfrombuffer(n)
+
+ def readexactly(self, n):
+ """COUROUTINE: Read exactly n bytes, or until EOF."""
+ blocks = []
+ count = 0
+ while count < n:
+ block = yield from self.read(n - count)
+ if not block:
+ break
+ blocks.append(block)
+ count += len(block)
+ return b''.join(blocks)
+
+ def readline(self):
+ """COROUTINE: Read up to newline or limit, whichever comes first."""
+ end = self.buffer.find(b'\n') + 1 # Point past newline, or 0.
+ while not end and not self.eof and len(self.buffer) < self.limit:
+ anchor = len(self.buffer)
+ yield from self._fillbuffer(self.limit)
+ end = self.buffer.find(b'\n', anchor) + 1
+ if not end:
+ end = len(self.buffer)
+ if end > self.limit:
+ end = self.limit
+ return self._getfrombuffer(end)
+
+ def _getfrombuffer(self, n):
+ """Read up to n bytes without blocking (not a coroutine)."""
+ if n >= len(self.buffer):
+ result, self.buffer = self.buffer, b''
+ else:
+ result, self.buffer = self.buffer[:n], self.buffer[n:]
+ return result
+
+ def _fillbuffer(self, n):
+ """COROUTINE: Fill buffer with one (up to) n bytes from transport."""
+ assert not self.eof, '_fillbuffer called at eof'
+ data = yield from self.trans.recv(n)
+ if data:
+ self.buffer += data
+ else:
+ self.eof = True
+
+
+def connect(sock, address):
+ """COROUTINE: Connect a socket to an address."""
+ try:
+ sock.connect(address)
+ except socket.error as err:
+ if err.errno != errno.EINPROGRESS:
+ raise
+ yield from scheduling.block_w(sock.fileno())
+ err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
+ if err != 0:
+ raise IOError(err, 'Connection refused')
+
+
+def getaddrinfo(host, port, af=0, socktype=0, proto=0):
+ """COROUTINE: Look up an address and return a list of infos for it.
+
+ Each info is a tuple (af, socktype, protocol, canonname, address).
+ """
+ infos = yield from scheduling.call_in_thread(socket.getaddrinfo,
+ host, port, af,
+ socktype, proto)
+ return infos
+
+
+def create_connection(host, port, af=0, socktype=socket.SOCK_STREAM, proto=0):
+ """COROUTINE: Look up address and create a socket connected to it."""
+ infos = yield from getaddrinfo(host, port, af, socktype, proto)
+ if not infos:
+ raise IOError('getaddrinfo() returned an empty list')
+ exc = None
+ for af, socktype, proto, cname, address in infos:
+ sock = None
+ try:
+ sock = socket.socket(af, socktype, proto)
+ sock.setblocking(False)
+ yield from connect(sock, address)
+ break
+ except socket.error as err:
+ if sock is not None:
+ sock.close()
+ if exc is None:
+ exc = err
+ else:
+ raise exc
+ return sock
+
+
+def create_transport(host, port, af=0, ssl=None):
+ """COROUTINE: Look up address and create a transport connected to it."""
+ if ssl is None:
+ ssl = (port == 443)
+ sock = yield from create_connection(host, port, af)
+ if ssl:
+ trans = SslTransport(sock)
+ yield from trans.do_handshake()
+ else:
+ trans = SocketTransport(sock)
+ return trans
+
+
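+# Usage sketch (illustrative only): fetch the HTTP status line of a
+# page with create_transport() plus BufferedReader.
+def _demo_get_status(host):
+    """COROUTINE: Return the HTTP status line from host, port 80."""
+    trans = yield from create_transport(host, 80)
+    ok = yield from trans.send(b'GET / HTTP/1.0\r\nHost: ' +
+                               host.encode('ascii') + b'\r\n\r\n')
+    assert ok, 'disconnected while sending'
+    status = yield from BufferedReader(trans).readline()
+    trans.close()
+    return status
+
+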
+class Listener:
+ """Wrapper for a listening socket."""
+
+ def __init__(self, sock):
+ self.sock = sock
+
+ def accept(self):
+ """COROUTINE: Accept a connection."""
+ while True:
+ try:
+ conn, addr = self.sock.accept()
+ except socket.error as err:
+ if err.errno in _TRYAGAIN:
+ yield from scheduling.block_r(self.sock.fileno())
+ else:
+ raise # Unexpected, propagate.
+ else:
+ conn.setblocking(False)
+ return conn, addr
+
+
+def create_listener(host, port, af=0, socktype=0, proto=0,
+ backlog=5, reuse_addr=True):
+ """COROUTINE: Look up address and create a listener for it."""
+ infos = yield from getaddrinfo(host, port, af, socktype, proto)
+ if not infos:
+ raise IOError('getaddrinfo() returned an empty list')
+ exc = None
+ for af, socktype, proto, cname, address in infos:
+ sock = None
+ try:
+ sock = socket.socket(af, socktype, proto)
+ sock.setblocking(False)
+ if reuse_addr:
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock.bind(address)
+ sock.listen(backlog)
+ break
+ except socket.error as err:
+ if sock is not None:
+ sock.close()
+ if exc is None:
+ exc = err
+ else:
+ raise exc
+ return Listener(sock)
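+
+
+# Server-side sketch (illustrative only): accept connections forever
+# and echo one line back, giving each connection its own Task.
+def _demo_echo_handler(conn):
+    """COROUTINE: Echo a single line, then close the connection."""
+    trans = SocketTransport(conn)
+    line = yield from BufferedReader(trans).readline()
+    yield from trans.send(line)
+    trans.close()
+
+
+def _demo_echo_server(host, port):
+    """COROUTINE: Accept forever, spawning a handler per connection."""
+    listener = yield from create_listener(host, port)
+    while True:
+        conn, addr = yield from listener.accept()
+        scheduling.Task(_demo_echo_handler(conn))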
diff --git a/old/transports.py b/old/transports.py
new file mode 100644
index 0000000..19095bf
--- /dev/null
+++ b/old/transports.py
@@ -0,0 +1,496 @@
+"""Transports and Protocols, actually.
+
+Inspired by Twisted, PEP 3153 and github.com/lvh/async-pep.
+
+THIS IS NOT REAL CODE! IT IS JUST AN EXPERIMENT.
+"""
+
+# Stdlib imports.
+import collections
+import errno
+import logging
+import socket
+import ssl
+import sys
+import time
+
+# Local imports.
+import polling
+import scheduling
+import sockets
+
+# Errno values indicating the connection was disconnected.
+_DISCONNECTED = frozenset((errno.ECONNRESET,
+ errno.ENOTCONN,
+ errno.ESHUTDOWN,
+ errno.ECONNABORTED,
+ errno.EPIPE,
+ errno.EBADF,
+ ))
+
+# Errno values indicating the socket isn't ready for I/O just yet.
+_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK))
+
+
+class Transport:
+ """ABC representing a transport.
+
+ There may be many implementations. The user never instantiates
+ this directly; they call some utility function, passing it a
+ protocol, and the utility function will call the protocol's
+ connection_made() method with a transport (or it will call
+ connection_lost() with an exception if it fails to create the
+ desired transport).
+
+    The implementation here raises NotImplementedError for every method
+ except writelines(), which calls write() in a loop.
+ """
+
+ def write(self, data):
+ """Write some data (bytes) to the transport.
+
+ This does not block; it buffers the data and arranges for it
+ to be sent out asynchronously.
+ """
+ raise NotImplementedError
+
+ def writelines(self, list_of_data):
+ """Write a list (or any iterable) of data (bytes) to the transport.
+
+ The default implementation just calls write() for each item in
+ the list/iterable.
+ """
+ for data in list_of_data:
+ self.write(data)
+
+ def close(self):
+ """Closes the transport.
+
+ Buffered data will be flushed asynchronously. No more data will
+ be received. When all buffered data is flushed, the protocol's
+ connection_lost() method is called with None as its argument.
+ """
+ raise NotImplementedError
+
+ def abort(self):
+ """Closes the transport immediately.
+
+ Buffered data will be lost. No more data will be received.
+ The protocol's connection_lost() method is called with None as
+ its argument.
+ """
+ raise NotImplementedError
+
+ def half_close(self):
+ """Closes the write end after flushing buffered data.
+
+ Data may still be received.
+
+ TODO: What's the use case for this? How to implement it?
+ Should it call shutdown(SHUT_WR) after all the data is flushed?
+ Is there no use case for closing the other half first?
+ """
+ raise NotImplementedError
+
+ def pause(self):
+ """Pause the receiving end.
+
+ No data will be received until resume() is called.
+ """
+ raise NotImplementedError
+
+ def resume(self):
+ """Resume the receiving end.
+
+ Cancels a pause() call, resumes receiving data.
+ """
+ raise NotImplementedError
+
+
+class Protocol:
+ """ABC representing a protocol.
+
+ The user should implement this interface. They can inherit from
+ this class but don't need to. The implementations here do
+ nothing.
+
+    When the user wants to request a transport, they pass a protocol
+ instance to a utility function.
+
+ When the connection is made successfully, connection_made() is
+ called with a suitable transport object. Then data_received()
+ will be called 0 or more times with data (bytes) received from the
+    transport; finally, connection_lost() will be called exactly once
+ with either an exception object or None as an argument.
+
+ If the utility function does not succeed in creating a transport,
+ it will call connection_lost() with an exception object.
+
+ State machine of calls:
+
+ start -> [CM -> DR*] -> CL -> end
+ """
+
+ def connection_made(self, transport):
+ """Called when a connection is made.
+
+ The argument is the transport representing the connection.
+ To send data, call its write() or writelines() method.
+ To receive data, wait for data_received() calls.
+ When the connection is closed, connection_lost() is called.
+ """
+
+ def data_received(self, data):
+ """Called when some data is received.
+
+ The argument is a bytes object.
+
+        TODO: Should we allow it to be a bytearray or some other
+ memory buffer?
+ """
+
+ def connection_lost(self, exc):
+ """Called when the connection is lost or closed.
+
+ Also called when we fail to make a connection at all (in that
+ case connection_made() will not be called).
+
+ The argument is an exception object or None (the latter
+        meaning a regular EOF was received, or the connection was
+ aborted or closed).
+ """
+
+
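+# A bare-bones Protocol implementation following the state machine
+# above (CM -> DR* -> CL); purely illustrative.
+class ByteCounter(Protocol):
+    """Count received bytes and report them when the connection ends."""
+
+    def connection_made(self, transport):
+        self.transport = transport
+        self.nbytes = 0
+
+    def data_received(self, data):
+        self.nbytes += len(data)
+
+    def connection_lost(self, exc):
+        print('closed after', self.nbytes, 'bytes; exc =', exc)
+
+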
+# TODO: The rest is platform specific and should move elsewhere.
+
+class UnixSocketTransport(Transport):
+
+ def __init__(self, eventloop, protocol, sock):
+ self._eventloop = eventloop
+ self._protocol = protocol
+ self._sock = sock
+ self._buffer = collections.deque() # For write().
+ self._write_closed = False
+
+ def _on_readable(self):
+ try:
+ data = self._sock.recv(8192)
+ except socket.error as exc:
+ if exc.errno not in _TRYAGAIN:
+ self._bad_error(exc)
+ else:
+ if not data:
+ self._eventloop.remove_reader(self._sock.fileno())
+ self._sock.close()
+ self._protocol.connection_lost(None)
+ else:
+ self._protocol.data_received(data) # XXX call_soon()?
+
+ def write(self, data):
+ assert isinstance(data, bytes)
+ assert not self._write_closed
+ if not data:
+ # Silly, but it happens.
+ return
+ if self._buffer:
+ # We've already registered a callback, just buffer the data.
+ self._buffer.append(data)
+ # Consider pausing if the total length of the buffer is
+ # truly huge.
+ return
+
+ # TODO: Refactor so there's more sharing between this and
+ # _on_writable().
+
+ # There's no callback registered yet. It's quite possible
+ # that the kernel has buffer space for our data, so try to
+ # write now. Since the socket is non-blocking it will
+ # give us an error in _TRYAGAIN if it doesn't have enough
+ # space for even one more byte; it will return the number
+ # of bytes written if it can write at least one byte.
+ try:
+ n = self._sock.send(data)
+ except socket.error as exc:
+ # An error.
+ if exc.errno not in _TRYAGAIN:
+ self._bad_error(exc)
+ return
+ # The kernel doesn't have room for more data right now.
+ n = 0
+ else:
+ # Wrote at least one byte.
+ if n == len(data):
+ # Wrote it all. Done!
+ if self._write_closed:
+ self._sock.shutdown(socket.SHUT_WR)
+ return
+ # Throw away the data that was already written.
+ # TODO: Do this without copying the data?
+ data = data[n:]
+ self._buffer.append(data)
+ self._eventloop.add_writer(self._sock.fileno(), self._on_writable)
+
+ def _on_writable(self):
+ while self._buffer:
+ data = self._buffer[0]
+ # TODO: Join small amounts of data?
+ try:
+ n = self._sock.send(data)
+ except socket.error as exc:
+ # Error handling is the same as in write().
+ if exc.errno not in _TRYAGAIN:
+ self._bad_error(exc)
+ return
+ if n < len(data):
+ self._buffer[0] = data[n:]
+ return
+ self._buffer.popleft()
+ self._eventloop.remove_writer(self._sock.fileno())
+ if self._write_closed:
+ self._sock.shutdown(socket.SHUT_WR)
+
+ def abort(self):
+ self._bad_error(None)
+
+ def _bad_error(self, exc):
+ # A serious error. Close the socket etc.
+ fd = self._sock.fileno()
+ # TODO: Record whether we have a writer and/or reader registered.
+ try:
+ self._eventloop.remove_writer(fd)
+ except Exception:
+ pass
+ try:
+ self._eventloop.remove_reader(fd)
+ except Exception:
+ pass
+ self._sock.close()
+ self._protocol.connection_lost(exc) # XXX call_soon()?
+
+ def half_close(self):
+ self._write_closed = True
+
+
+class UnixSslTransport(Transport):
+
+ # TODO: Refactor Socket and Ssl transport to share some code.
+ # (E.g. buffering.)
+
+ # TODO: Consider using coroutines instead of callbacks, it seems
+ # much easier that way.
+
+ def __init__(self, eventloop, protocol, rawsock, sslcontext=None):
+ self._eventloop = eventloop
+ self._protocol = protocol
+ self._rawsock = rawsock
+ self._sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ self._sslsock = self._sslcontext.wrap_socket(
+ self._rawsock, do_handshake_on_connect=False)
+
+ self._buffer = collections.deque() # For write().
+ self._write_closed = False
+
+        # Try the handshake now. Likely it will raise SSLWantReadError
+        # or SSLWantWriteError; _on_handshake() then registers the
+        # appropriate callback with the event loop.
+ self._on_handshake()
+
+ def _bad_error(self, exc):
+ # A serious error. Close the socket etc.
+ fd = self._sslsock.fileno()
+ # TODO: Record whether we have a writer and/or reader registered.
+ try:
+ self._eventloop.remove_writer(fd)
+ except Exception:
+ pass
+ try:
+ self._eventloop.remove_reader(fd)
+ except Exception:
+ pass
+ self._sslsock.close()
+ self._protocol.connection_lost(exc) # XXX call_soon()?
+
+ def _on_handshake(self):
+ fd = self._sslsock.fileno()
+ try:
+ self._sslsock.do_handshake()
+ except ssl.SSLWantReadError:
+ self._eventloop.add_reader(fd, self._on_handshake)
+ return
+ except ssl.SSLWantWriteError:
+            self._eventloop.add_writer(fd, self._on_handshake)
+ return
+ # TODO: What if it raises another error?
+ try:
+ self._eventloop.remove_reader(fd)
+ except Exception:
+ pass
+ try:
+ self._eventloop.remove_writer(fd)
+ except Exception:
+ pass
+ self._protocol.connection_made(self)
+ self._eventloop.add_reader(fd, self._on_ready)
+ self._eventloop.add_writer(fd, self._on_ready)
+
+ def _on_ready(self):
+ # Because of renegotiations (?), there's no difference between
+ # readable and writable. We just try both. XXX This may be
+ # incorrect; we probably need to keep state about what we
+ # should do next.
+
+ # Maybe we're already closed...
+ fd = self._sslsock.fileno()
+ if fd < 0:
+ return
+
+ # First try reading.
+ try:
+ data = self._sslsock.recv(8192)
+ except ssl.SSLWantReadError:
+ pass
+ except ssl.SSLWantWriteError:
+ pass
+ except socket.error as exc:
+ if exc.errno not in _TRYAGAIN:
+ self._bad_error(exc)
+ return
+ else:
+ if data:
+ self._protocol.data_received(data)
+ else:
+ # TODO: Don't close when self._buffer is non-empty.
+ assert not self._buffer
+ self._eventloop.remove_reader(fd)
+ self._eventloop.remove_writer(fd)
+ self._sslsock.close()
+ self._protocol.connection_lost(None)
+ return
+
+ # Now try writing, if there's anything to write.
+ if not self._buffer:
+ return
+
+ data = self._buffer[0]
+ try:
+ n = self._sslsock.send(data)
+ except ssl.SSLWantReadError:
+ pass
+ except ssl.SSLWantWriteError:
+ pass
+ except socket.error as exc:
+ if exc.errno not in _TRYAGAIN:
+ self._bad_error(exc)
+ return
+ else:
+ if n == len(data):
+ self._buffer.popleft()
+ # Could try again, but let's just have the next callback do it.
+ else:
+ self._buffer[0] = data[n:]
+
+ def write(self, data):
+ assert isinstance(data, bytes)
+ assert not self._write_closed
+ if not data:
+ return
+ self._buffer.append(data)
+ # We could optimize, but the callback can do this for now.
+
+ def half_close(self):
+ self._write_closed = True
+ # Just set the flag. Calling shutdown() on the ssl socket
+ # breaks something, causing recv() to return binary data.
+
+
+def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0,
+ use_ssl=None):
+ # TODO: Pass in a protocol factory, not a protocol.
+ # What should be the exact sequence of events?
+ # - socket
+ # - transport
+ # - protocol
+ # - tell transport about protocol
+ # - tell protocol about transport
+ # Or should the latter two be reversed? Does it matter?
+ if port is None:
+ port = 443 if use_ssl else 80
+ if use_ssl is None:
+ use_ssl = (port == 443)
+ if not socktype:
+ socktype = socket.SOCK_STREAM
+ eventloop = polling.context.eventloop
+
+ def on_socket_connected(task):
+ assert not task.alive
+ if task.exception is not None:
+ # TODO: Call some callback.
+ raise task.exception
+ sock = task.result
+ assert sock is not None
+ logging.debug('on_socket_connected')
+ if use_ssl:
+ # You can pass an ssl.SSLContext object as use_ssl,
+ # or a bool.
+ if isinstance(use_ssl, bool):
+ sslcontext = None
+ else:
+ sslcontext = use_ssl
+ transport = UnixSslTransport(eventloop, protocol, sock, sslcontext)
+ else:
+ transport = UnixSocketTransport(eventloop, protocol, sock)
+        # TODO: Should the transport make the following calls?
+ protocol.connection_made(transport) # XXX call_soon()?
+ # Don't do this before connection_made() is called.
+ eventloop.add_reader(sock.fileno(), transport._on_readable)
+
+ coro = sockets.create_connection(host, port, af, socktype, proto)
+ task = scheduling.Task(coro)
+ task.add_done_callback(on_socket_connected)
+
+
+def main(): # Testing...
+
+ # Initialize logging.
+ if '-d' in sys.argv:
+ level = logging.DEBUG
+ elif '-v' in sys.argv:
+ level = logging.INFO
+ elif '-q' in sys.argv:
+ level = logging.ERROR
+ else:
+ level = logging.WARN
+ logging.basicConfig(level=level)
+
+ host = 'xkcd.com'
+ if sys.argv[1:] and '.' in sys.argv[-1]:
+ host = sys.argv[-1]
+
+ t0 = time.time()
+
+ class TestProtocol(Protocol):
+ def connection_made(self, transport):
+ logging.info('Connection made at %.3f secs', time.time() - t0)
+ self.transport = transport
+ self.transport.write(b'GET / HTTP/1.0\r\nHost: ' +
+ host.encode('ascii') +
+ b'\r\n\r\n')
+ self.transport.half_close()
+ def data_received(self, data):
+ logging.info('Received %d bytes at t=%.3f',
+ len(data), time.time() - t0)
+ logging.debug('Received %r', data)
+ def connection_lost(self, exc):
+ logging.debug('Connection lost: %r', exc)
+ self.t1 = time.time()
+ logging.info('Total time %.3f secs', self.t1 - t0)
+
+ tp = TestProtocol()
+ logging.debug('tp = %r', tp)
+ make_connection(tp, host, use_ssl=('-S' in sys.argv))
+ logging.info('Running...')
+ polling.context.eventloop.run()
+ logging.info('Done.')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/old/xkcd.py b/old/xkcd.py
new file mode 100755
index 0000000..474009d
--- /dev/null
+++ b/old/xkcd.py
@@ -0,0 +1,18 @@
+#!/usr/bin/env python3.3
+"""Minimal synchronous SSL demo, connecting to xkcd.com."""
+
+import socket, ssl
+
+s = socket.socket()
+s.connect(('xkcd.com', 443))
+ss = ssl.wrap_socket(s)
+
+ss.send(b'GET / HTTP/1.0\r\n\r\n')
+
+while True:
+ data = ss.recv(1000000)
+ print(data)
+ if not data:
+ break
+
+ss.close()
diff --git a/old/yyftime.py b/old/yyftime.py
new file mode 100644
index 0000000..f55234b
--- /dev/null
+++ b/old/yyftime.py
@@ -0,0 +1,75 @@
+"""Compare timing of yield-from <generator> vs. yield <future> calls."""
+
+import gc
+import time
+
+def coroutine(n):
+ if n <= 0:
+ return 1
+ l = yield from coroutine(n-1)
+ r = yield from coroutine(n-1)
+ return l + 1 + r
+
+def run_coro(depth):
+ t0 = time.time()
+ try:
+ g = coroutine(depth)
+ while True:
+ next(g)
+ except StopIteration as err:
+ k = err.value
+ t1 = time.time()
+ print('coro', depth, k, round(t1-t0, 6))
+ return t1-t0
+
+class Future:
+
+ def __init__(self, g):
+ self.g = g
+
+ def wait(self):
+ value = None
+ try:
+ while True:
+ f = self.g.send(value)
+ f.wait()
+ value = f.value
+ except StopIteration as err:
+ self.value = err.value
+
+
+def task(func):  # Decorator
+ def wrapper(*args):
+ g = func(*args)
+ f = Future(g)
+ return f
+ return wrapper
+
+@task
+def oldstyle(n):
+ if n <= 0:
+ return 1
+ l = yield oldstyle(n-1)
+ r = yield oldstyle(n-1)
+ return l + 1 + r
+
+def run_olds(depth):
+ t0 = time.time()
+ f = oldstyle(depth)
+ f.wait()
+ k = f.value
+ t1 = time.time()
+ print('olds', depth, k, round(t1-t0, 6))
+ return t1-t0
+
+def main():
+ gc.disable()
+ for depth in range(16):
+ tc = run_coro(depth)
+ to = run_olds(depth)
+ if tc:
+ print('ratio', round(to/tc, 2))
+
+if __name__ == '__main__':
+ main()
diff --git a/overlapped.c b/overlapped.c
new file mode 100644
index 0000000..c9f6ec9
--- /dev/null
+++ b/overlapped.c
@@ -0,0 +1,997 @@
+/*
+ * Support for overlapped IO
+ *
+ * Some code borrowed from Modules/_winapi.c of CPython
+ */
+
+/* XXX check overflow and DWORD <-> Py_ssize_t conversions
+ Check itemsize */
+
+#include "Python.h"
+#include "structmember.h"
+
+#define WIN32_LEAN_AND_MEAN
+#include <winsock2.h>
+#include <ws2tcpip.h>
+#include <mswsock.h>
+
+#if defined(MS_WIN32) && !defined(MS_WIN64)
+# define F_POINTER "k"
+# define T_POINTER T_ULONG
+#else
+# define F_POINTER "K"
+# define T_POINTER T_ULONGLONG
+#endif
+
+#define F_HANDLE F_POINTER
+#define F_ULONG_PTR F_POINTER
+#define F_DWORD "k"
+#define F_BOOL "i"
+#define F_UINT "I"
+
+#define T_HANDLE T_POINTER
+
+enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_WRITE, TYPE_ACCEPT,
+ TYPE_CONNECT, TYPE_DISCONNECT};
+
+/*
+ * Map Windows error codes to subclasses of OSError
+ */
+
+static PyObject *
+SetFromWindowsErr(DWORD err)
+{
+ PyObject *exception_type;
+
+ if (err == 0)
+ err = GetLastError();
+ switch (err) {
+ case ERROR_CONNECTION_REFUSED:
+ exception_type = PyExc_ConnectionRefusedError;
+ break;
+ case ERROR_CONNECTION_ABORTED:
+ exception_type = PyExc_ConnectionAbortedError;
+ break;
+ default:
+ exception_type = PyExc_OSError;
+ }
+ return PyErr_SetExcFromWindowsErr(exception_type, err);
+}
+
+/*
+ * Some functions should be loaded at runtime
+ */
+
+static LPFN_ACCEPTEX Py_AcceptEx = NULL;
+static LPFN_CONNECTEX Py_ConnectEx = NULL;
+static LPFN_DISCONNECTEX Py_DisconnectEx = NULL;
+static BOOL (CALLBACK *Py_CancelIoEx)(HANDLE, LPOVERLAPPED) = NULL;
+
+#define GET_WSA_POINTER(s, x) \
+ (SOCKET_ERROR != WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, \
+ &Guid##x, sizeof(Guid##x), &Py_##x, \
+ sizeof(Py_##x), &dwBytes, NULL, NULL))
+
+static int
+initialize_function_pointers(void)
+{
+ GUID GuidAcceptEx = WSAID_ACCEPTEX;
+ GUID GuidConnectEx = WSAID_CONNECTEX;
+ GUID GuidDisconnectEx = WSAID_DISCONNECTEX;
+ HINSTANCE hKernel32;
+ SOCKET s;
+ DWORD dwBytes;
+
+ s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+ if (s == INVALID_SOCKET) {
+ SetFromWindowsErr(WSAGetLastError());
+ return -1;
+ }
+
+ if (!GET_WSA_POINTER(s, AcceptEx) ||
+ !GET_WSA_POINTER(s, ConnectEx) ||
+ !GET_WSA_POINTER(s, DisconnectEx))
+ {
+ closesocket(s);
+ SetFromWindowsErr(WSAGetLastError());
+ return -1;
+ }
+
+ closesocket(s);
+
+ /* On WinXP we will have Py_CancelIoEx == NULL */
+ hKernel32 = GetModuleHandle("KERNEL32");
+ *(FARPROC *)&Py_CancelIoEx = GetProcAddress(hKernel32, "CancelIoEx");
+ return 0;
+}
+
+/*
+ * Completion port stuff
+ */
+
+PyDoc_STRVAR(
+ CreateIoCompletionPort_doc,
+ "CreateIoCompletionPort(handle, port, key, concurrency) -> port\n\n"
+ "Create a completion port or register a handle with a port.");
+
+static PyObject *
+overlapped_CreateIoCompletionPort(PyObject *self, PyObject *args)
+{
+ HANDLE FileHandle;
+ HANDLE ExistingCompletionPort;
+ ULONG_PTR CompletionKey;
+ DWORD NumberOfConcurrentThreads;
+ HANDLE ret;
+
+ if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE F_ULONG_PTR F_DWORD,
+ &FileHandle, &ExistingCompletionPort, &CompletionKey,
+ &NumberOfConcurrentThreads))
+ return NULL;
+
+ Py_BEGIN_ALLOW_THREADS
+ ret = CreateIoCompletionPort(FileHandle, ExistingCompletionPort,
+ CompletionKey, NumberOfConcurrentThreads);
+ Py_END_ALLOW_THREADS
+
+ if (ret == NULL)
+ return SetFromWindowsErr(0);
+ return Py_BuildValue(F_HANDLE, ret);
+}
+
+PyDoc_STRVAR(
+ GetQueuedCompletionStatus_doc,
+ "GetQueuedCompletionStatus(port, msecs) -> (err, bytes, key, address)\n\n"
+ "Get a message from completion port. Wait for up to msecs milliseconds.");
+
+static PyObject *
+overlapped_GetQueuedCompletionStatus(PyObject *self, PyObject *args)
+{
+ HANDLE CompletionPort = NULL;
+ DWORD NumberOfBytes = 0;
+ ULONG_PTR CompletionKey = 0;
+ OVERLAPPED *Overlapped = NULL;
+ DWORD Milliseconds;
+ DWORD err;
+ BOOL ret;
+
+ if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD,
+ &CompletionPort, &Milliseconds))
+ return NULL;
+
+ Py_BEGIN_ALLOW_THREADS
+ ret = GetQueuedCompletionStatus(CompletionPort, &NumberOfBytes,
+ &CompletionKey, &Overlapped, Milliseconds);
+ Py_END_ALLOW_THREADS
+
+ err = ret ? ERROR_SUCCESS : GetLastError();
+ if (Overlapped == NULL) {
+ if (err == WAIT_TIMEOUT)
+ Py_RETURN_NONE;
+ else
+ return SetFromWindowsErr(err);
+ }
+ return Py_BuildValue(F_DWORD F_DWORD F_ULONG_PTR F_POINTER,
+ err, NumberOfBytes, CompletionKey, Overlapped);
+}
+
+PyDoc_STRVAR(
+ PostQueuedCompletionStatus_doc,
+ "PostQueuedCompletionStatus(port, bytes, key, address) -> None\n\n"
+ "Post a message to completion port.");
+
+static PyObject *
+overlapped_PostQueuedCompletionStatus(PyObject *self, PyObject *args)
+{
+ HANDLE CompletionPort;
+ DWORD NumberOfBytes;
+ ULONG_PTR CompletionKey;
+ OVERLAPPED *Overlapped;
+ BOOL ret;
+
+ if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD F_ULONG_PTR F_POINTER,
+ &CompletionPort, &NumberOfBytes, &CompletionKey,
+ &Overlapped))
+ return NULL;
+
+ Py_BEGIN_ALLOW_THREADS
+ ret = PostQueuedCompletionStatus(CompletionPort, NumberOfBytes,
+ CompletionKey, Overlapped);
+ Py_END_ALLOW_THREADS
+
+ if (!ret)
+ return SetFromWindowsErr(0);
+ Py_RETURN_NONE;
+}
+
+/*
+ * Bind socket handle to local port without doing slow getaddrinfo()
+ */
+
+PyDoc_STRVAR(
+ BindLocal_doc,
+ "BindLocal(handle, length_of_address_tuple) -> None\n\n"
+ "Bind a socket handle to an arbitrary local port.\n"
+ "If length_of_address_tuple is 2 then an AF_INET address is used.\n"
+ "If length_of_address_tuple is 4 then an AF_INET6 address is used.");
+
+static PyObject *
+overlapped_BindLocal(PyObject *self, PyObject *args)
+{
+ SOCKET Socket;
+ int TupleLength;
+ BOOL ret;
+
+ if (!PyArg_ParseTuple(args, F_HANDLE "i", &Socket, &TupleLength))
+ return NULL;
+
+ if (TupleLength == 2) {
+ struct sockaddr_in addr;
+ memset(&addr, 0, sizeof(addr));
+ addr.sin_family = AF_INET;
+ addr.sin_port = 0;
+ addr.sin_addr.S_un.S_addr = INADDR_ANY;
+ ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR;
+ } else if (TupleLength == 4) {
+ struct sockaddr_in6 addr;
+ memset(&addr, 0, sizeof(addr));
+ addr.sin6_family = AF_INET6;
+ addr.sin6_port = 0;
+ addr.sin6_addr = in6addr_any;
+ ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR;
+ } else {
+ PyErr_SetString(PyExc_ValueError, "expected tuple of length 2 or 4");
+ return NULL;
+ }
+
+ if (!ret)
+ return SetFromWindowsErr(WSAGetLastError());
+ Py_RETURN_NONE;
+}
+
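+/*
+ * Illustrative smoke test from Python, assuming this module builds as
+ * _overlapped (Windows only):
+ *
+ *     import _overlapped
+ *     port = _overlapped.CreateIoCompletionPort(
+ *         _overlapped.INVALID_HANDLE_VALUE, _overlapped.NULL, 0, 0)
+ *     _overlapped.PostQueuedCompletionStatus(port, 1, 2, 3)
+ *     print(_overlapped.GetQueuedCompletionStatus(port, 1000))
+ *     # -> (0, 1, 2, 3): err, transferred, key, overlapped address
+ */
+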
+/*
+ * A Python object wrapping an OVERLAPPED structure and other useful data
+ * for overlapped I/O
+ */
+
+PyDoc_STRVAR(
+ Overlapped_doc,
+ "Overlapped object");
+
+typedef struct {
+ PyObject_HEAD
+ OVERLAPPED overlapped;
+ /* For convenience, we store the file handle too */
+ HANDLE handle;
+ /* Error returned by last method call */
+ DWORD error;
+ /* Type of operation */
+ DWORD type;
+ /* Buffer used for reading (optional) */
+ PyObject *read_buffer;
+ /* Buffer used for writing (optional) */
+ Py_buffer write_buffer;
+} OverlappedObject;
+
+
+static PyObject *
+Overlapped_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+{
+ OverlappedObject *self;
+ HANDLE event = INVALID_HANDLE_VALUE;
+ static char *kwlist[] = {"event", NULL};
+
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "|" F_HANDLE, kwlist, &event))
+ return NULL;
+
+ if (event == INVALID_HANDLE_VALUE) {
+ event = CreateEvent(NULL, TRUE, FALSE, NULL);
+ if (event == NULL)
+ return SetFromWindowsErr(0);
+ }
+
+ self = PyObject_New(OverlappedObject, type);
+ if (self == NULL) {
+ if (event != NULL)
+ CloseHandle(event);
+ return NULL;
+ }
+
+ self->handle = NULL;
+ self->error = 0;
+ self->type = TYPE_NONE;
+ self->read_buffer = NULL;
+ memset(&self->overlapped, 0, sizeof(OVERLAPPED));
+ memset(&self->write_buffer, 0, sizeof(Py_buffer));
+ if (event)
+ self->overlapped.hEvent = event;
+ return (PyObject *)self;
+}
+
+static void
+Overlapped_dealloc(OverlappedObject *self)
+{
+ DWORD bytes;
+ DWORD olderr = GetLastError();
+ BOOL wait = FALSE;
+ BOOL ret;
+
+ if (!HasOverlappedIoCompleted(&self->overlapped) &&
+ self->type != TYPE_NOT_STARTED)
+ {
+ if (Py_CancelIoEx && Py_CancelIoEx(self->handle, &self->overlapped))
+ wait = TRUE;
+
+ Py_BEGIN_ALLOW_THREADS
+ ret = GetOverlappedResult(self->handle, &self->overlapped,
+ &bytes, wait);
+ Py_END_ALLOW_THREADS
+
+ switch (ret ? ERROR_SUCCESS : GetLastError()) {
+ case ERROR_SUCCESS:
+ case ERROR_NOT_FOUND:
+ case ERROR_OPERATION_ABORTED:
+ break;
+ default:
+ PyErr_Format(
+ PyExc_RuntimeError,
+ "%R still has pending operation at "
+ "deallocation, the process may crash", self);
+ PyErr_WriteUnraisable(NULL);
+ }
+ }
+
+ if (self->overlapped.hEvent != NULL)
+ CloseHandle(self->overlapped.hEvent);
+
+ if (self->write_buffer.obj)
+ PyBuffer_Release(&self->write_buffer);
+
+ Py_CLEAR(self->read_buffer);
+ PyObject_Del(self);
+ SetLastError(olderr);
+}
+
+PyDoc_STRVAR(
+ Overlapped_cancel_doc,
+ "cancel() -> None\n\n"
+ "Cancel overlapped operation");
+
+static PyObject *
+Overlapped_cancel(OverlappedObject *self)
+{
+ BOOL ret = TRUE;
+
+ if (self->type == TYPE_NOT_STARTED)
+ Py_RETURN_NONE;
+
+ if (!HasOverlappedIoCompleted(&self->overlapped)) {
+ Py_BEGIN_ALLOW_THREADS
+ if (Py_CancelIoEx)
+ ret = Py_CancelIoEx(self->handle, &self->overlapped);
+ else
+ ret = CancelIo(self->handle);
+ Py_END_ALLOW_THREADS
+ }
+
+ /* CancelIoEx returns ERROR_NOT_FOUND if the I/O completed in-between */
+ if (!ret && GetLastError() != ERROR_NOT_FOUND)
+ return SetFromWindowsErr(0);
+ Py_RETURN_NONE;
+}
+
+PyDoc_STRVAR(
+ Overlapped_getresult_doc,
+ "getresult(wait=False) -> result\n\n"
+ "Retrieve result of operation. If wait is true then it blocks\n"
+ "until the operation is finished. If wait is false and the\n"
+ "operation is still pending then an error is raised.");
+
+static PyObject *
+Overlapped_getresult(OverlappedObject *self, PyObject *args)
+{
+ BOOL wait = FALSE;
+ DWORD transferred = 0;
+ BOOL ret;
+ DWORD err;
+
+ if (!PyArg_ParseTuple(args, "|" F_BOOL, &wait))
+ return NULL;
+
+ if (self->type == TYPE_NONE) {
+ PyErr_SetString(PyExc_ValueError, "operation not yet attempted");
+ return NULL;
+ }
+
+ if (self->type == TYPE_NOT_STARTED) {
+ PyErr_SetString(PyExc_ValueError, "operation failed to start");
+ return NULL;
+ }
+
+ Py_BEGIN_ALLOW_THREADS
+ ret = GetOverlappedResult(self->handle, &self->overlapped, &transferred,
+ wait);
+ Py_END_ALLOW_THREADS
+
+ self->error = err = ret ? ERROR_SUCCESS : GetLastError();
+ switch (err) {
+ case ERROR_SUCCESS:
+ case ERROR_MORE_DATA:
+ break;
+ case ERROR_BROKEN_PIPE:
+ if (self->read_buffer != NULL)
+ break;
+ /* fall through */
+ default:
+ return SetFromWindowsErr(err);
+ }
+
+ switch (self->type) {
+ case TYPE_READ:
+ assert(PyBytes_CheckExact(self->read_buffer));
+ if (transferred != PyBytes_GET_SIZE(self->read_buffer) &&
+ _PyBytes_Resize(&self->read_buffer, transferred))
+ return NULL;
+ Py_INCREF(self->read_buffer);
+ return self->read_buffer;
+ case TYPE_ACCEPT:
+ case TYPE_CONNECT:
+ case TYPE_DISCONNECT:
+ Py_RETURN_NONE;
+ default:
+ return PyLong_FromUnsignedLong((unsigned long) transferred);
+ }
+}
+
+PyDoc_STRVAR(
+ Overlapped_ReadFile_doc,
+ "ReadFile(handle, size) -> Overlapped[message]\n\n"
+ "Start overlapped read");
+
+static PyObject *
+Overlapped_ReadFile(OverlappedObject *self, PyObject *args)
+{
+ HANDLE handle;
+ DWORD size;
+ DWORD nread;
+ PyObject *buf;
+ BOOL ret;
+ DWORD err;
+
+ if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &handle, &size))
+ return NULL;
+
+ if (self->type != TYPE_NONE) {
+ PyErr_SetString(PyExc_ValueError, "operation already attempted");
+ return NULL;
+ }
+
+#if SIZEOF_SIZE_T <= SIZEOF_LONG
+ size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX);
+#endif
+ buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1));
+ if (buf == NULL)
+ return NULL;
+
+ self->type = TYPE_READ;
+ self->handle = handle;
+ self->read_buffer = buf;
+
+ Py_BEGIN_ALLOW_THREADS
+ ret = ReadFile(handle, PyBytes_AS_STRING(buf), size, &nread,
+ &self->overlapped);
+ Py_END_ALLOW_THREADS
+
+ self->error = err = ret ? ERROR_SUCCESS : GetLastError();
+ switch (err) {
+ case ERROR_BROKEN_PIPE:
+ self->type = TYPE_NOT_STARTED;
+ Py_RETURN_NONE;
+ case ERROR_SUCCESS:
+ case ERROR_MORE_DATA:
+ case ERROR_IO_PENDING:
+ Py_RETURN_NONE;
+ default:
+ self->type = TYPE_NOT_STARTED;
+ return SetFromWindowsErr(err);
+ }
+}
+
+PyDoc_STRVAR(
+ Overlapped_WSARecv_doc,
+ "RecvFile(handle, size, flags) -> Overlapped[message]\n\n"
+ "Start overlapped receive");
+
+static PyObject *
+Overlapped_WSARecv(OverlappedObject *self, PyObject *args)
+{
+ HANDLE handle;
+ DWORD size;
+ DWORD flags = 0;
+ DWORD nread;
+ PyObject *buf;
+ WSABUF wsabuf;
+ int ret;
+ DWORD err;
+
+ if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD "|" F_DWORD,
+ &handle, &size, &flags))
+ return NULL;
+
+ if (self->type != TYPE_NONE) {
+ PyErr_SetString(PyExc_ValueError, "operation already attempted");
+ return NULL;
+ }
+
+#if SIZEOF_SIZE_T <= SIZEOF_LONG
+ size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX);
+#endif
+ buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1));
+ if (buf == NULL)
+ return NULL;
+
+ self->type = TYPE_READ;
+ self->handle = handle;
+ self->read_buffer = buf;
+ wsabuf.len = size;
+ wsabuf.buf = PyBytes_AS_STRING(buf);
+
+ Py_BEGIN_ALLOW_THREADS
+ ret = WSARecv((SOCKET)handle, &wsabuf, 1, &nread, &flags,
+ &self->overlapped, NULL);
+ Py_END_ALLOW_THREADS
+
+ self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS);
+ switch (err) {
+ case ERROR_BROKEN_PIPE:
+ self->type = TYPE_NOT_STARTED;
+ Py_RETURN_NONE;
+ case ERROR_SUCCESS:
+ case ERROR_MORE_DATA:
+ case ERROR_IO_PENDING:
+ Py_RETURN_NONE;
+ default:
+ self->type = TYPE_NOT_STARTED;
+ return SetFromWindowsErr(err);
+ }
+}
+
+PyDoc_STRVAR(
+ Overlapped_WriteFile_doc,
+ "WriteFile(handle, buf) -> Overlapped[bytes_transferred]\n\n"
+ "Start overlapped write");
+
+static PyObject *
+Overlapped_WriteFile(OverlappedObject *self, PyObject *args)
+{
+ HANDLE handle;
+ PyObject *bufobj;
+ DWORD written;
+ BOOL ret;
+ DWORD err;
+
+ if (!PyArg_ParseTuple(args, F_HANDLE "O", &handle, &bufobj))
+ return NULL;
+
+ if (self->type != TYPE_NONE) {
+ PyErr_SetString(PyExc_ValueError, "operation already attempted");
+ return NULL;
+ }
+
+ if (!PyArg_Parse(bufobj, "y*", &self->write_buffer))
+ return NULL;
+
+#if SIZEOF_SIZE_T > SIZEOF_LONG
+ if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) {
+ PyBuffer_Release(&self->write_buffer);
+        PyErr_SetString(PyExc_ValueError, "buffer too large");
+ return NULL;
+ }
+#endif
+
+ self->type = TYPE_WRITE;
+ self->handle = handle;
+
+ Py_BEGIN_ALLOW_THREADS
+ ret = WriteFile(handle, self->write_buffer.buf,
+ (DWORD)self->write_buffer.len,
+ &written, &self->overlapped);
+ Py_END_ALLOW_THREADS
+
+ self->error = err = ret ? ERROR_SUCCESS : GetLastError();
+ switch (err) {
+ case ERROR_SUCCESS:
+ case ERROR_IO_PENDING:
+ Py_RETURN_NONE;
+ default:
+ self->type = TYPE_NOT_STARTED;
+ return SetFromWindowsErr(err);
+ }
+}
+
+PyDoc_STRVAR(
+ Overlapped_WSASend_doc,
+ "WSASend(handle, buf, flags) -> Overlapped[bytes_transferred]\n\n"
+ "Start overlapped send");
+
+static PyObject *
+Overlapped_WSASend(OverlappedObject *self, PyObject *args)
+{
+ HANDLE handle;
+ PyObject *bufobj;
+ DWORD flags;
+ DWORD written;
+ WSABUF wsabuf;
+ int ret;
+ DWORD err;
+
+ if (!PyArg_ParseTuple(args, F_HANDLE "O" F_DWORD,
+ &handle, &bufobj, &flags))
+ return NULL;
+
+ if (self->type != TYPE_NONE) {
+ PyErr_SetString(PyExc_ValueError, "operation already attempted");
+ return NULL;
+ }
+
+ if (!PyArg_Parse(bufobj, "y*", &self->write_buffer))
+ return NULL;
+
+#if SIZEOF_SIZE_T > SIZEOF_LONG
+ if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) {
+ PyBuffer_Release(&self->write_buffer);
+        PyErr_SetString(PyExc_ValueError, "buffer too large");
+ return NULL;
+ }
+#endif
+
+ self->type = TYPE_WRITE;
+ self->handle = handle;
+ wsabuf.len = (DWORD)self->write_buffer.len;
+ wsabuf.buf = self->write_buffer.buf;
+
+ Py_BEGIN_ALLOW_THREADS
+ ret = WSASend((SOCKET)handle, &wsabuf, 1, &written, flags,
+ &self->overlapped, NULL);
+ Py_END_ALLOW_THREADS
+
+ self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS);
+ switch (err) {
+ case ERROR_SUCCESS:
+ case ERROR_IO_PENDING:
+ Py_RETURN_NONE;
+ default:
+ self->type = TYPE_NOT_STARTED;
+ return SetFromWindowsErr(err);
+ }
+}
+
+PyDoc_STRVAR(
+ Overlapped_AcceptEx_doc,
+ "AcceptEx(listen_handle, accept_handle) -> Overlapped[address_as_bytes]\n\n"
+ "Start overlapped wait for client to connect");
+
+static PyObject *
+Overlapped_AcceptEx(OverlappedObject *self, PyObject *args)
+{
+ SOCKET ListenSocket;
+ SOCKET AcceptSocket;
+ DWORD BytesReceived;
+ DWORD size;
+ PyObject *buf;
+ BOOL ret;
+ DWORD err;
+
+ if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE,
+ &ListenSocket, &AcceptSocket))
+ return NULL;
+
+ if (self->type != TYPE_NONE) {
+ PyErr_SetString(PyExc_ValueError, "operation already attempted");
+ return NULL;
+ }
+
+ size = sizeof(struct sockaddr_in6) + 16;
+ buf = PyBytes_FromStringAndSize(NULL, size*2);
+ if (!buf)
+ return NULL;
+
+ self->type = TYPE_ACCEPT;
+ self->handle = (HANDLE)ListenSocket;
+ self->read_buffer = buf;
+
+ Py_BEGIN_ALLOW_THREADS
+ ret = Py_AcceptEx(ListenSocket, AcceptSocket, PyBytes_AS_STRING(buf),
+ 0, size, size, &BytesReceived, &self->overlapped);
+ Py_END_ALLOW_THREADS
+
+ self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError();
+ switch (err) {
+ case ERROR_SUCCESS:
+ case ERROR_IO_PENDING:
+ Py_RETURN_NONE;
+ default:
+ self->type = TYPE_NOT_STARTED;
+ return SetFromWindowsErr(err);
+ }
+}
+
+
+static int
+parse_address(PyObject *obj, SOCKADDR *Address, int Length)
+{
+ char *Host;
+ unsigned short Port;
+ unsigned long FlowInfo;
+ unsigned long ScopeId;
+
+ memset(Address, 0, Length);
+
+ if (PyArg_ParseTuple(obj, "sH", &Host, &Port))
+ {
+ Address->sa_family = AF_INET;
+ if (WSAStringToAddressA(Host, AF_INET, NULL, Address, &Length) < 0) {
+ SetFromWindowsErr(WSAGetLastError());
+ return -1;
+ }
+ ((SOCKADDR_IN*)Address)->sin_port = htons(Port);
+ return Length;
+ }
+ else if (PyArg_ParseTuple(obj, "sHkk", &Host, &Port, &FlowInfo, &ScopeId))
+ {
+ PyErr_Clear();
+ Address->sa_family = AF_INET6;
+ if (WSAStringToAddressA(Host, AF_INET6, NULL, Address, &Length) < 0) {
+ SetFromWindowsErr(WSAGetLastError());
+ return -1;
+ }
+ ((SOCKADDR_IN6*)Address)->sin6_port = htons(Port);
+ ((SOCKADDR_IN6*)Address)->sin6_flowinfo = FlowInfo;
+ ((SOCKADDR_IN6*)Address)->sin6_scope_id = ScopeId;
+ return Length;
+ }
+
+ return -1;
+}
+
+
+PyDoc_STRVAR(
+ Overlapped_ConnectEx_doc,
+ "ConnectEx(client_handle, address_as_bytes) -> Overlapped[None]\n\n"
+ "Start overlapped connect. client_handle should be unbound.");
+
+static PyObject *
+Overlapped_ConnectEx(OverlappedObject *self, PyObject *args)
+{
+ SOCKET ConnectSocket;
+ PyObject *AddressObj;
+ char AddressBuf[sizeof(struct sockaddr_in6)];
+ SOCKADDR *Address = (SOCKADDR*)AddressBuf;
+ int Length;
+ BOOL ret;
+ DWORD err;
+
+ if (!PyArg_ParseTuple(args, F_HANDLE "O", &ConnectSocket, &AddressObj))
+ return NULL;
+
+ if (self->type != TYPE_NONE) {
+ PyErr_SetString(PyExc_ValueError, "operation already attempted");
+ return NULL;
+ }
+
+ Length = sizeof(AddressBuf);
+ Length = parse_address(AddressObj, Address, Length);
+ if (Length < 0)
+ return NULL;
+
+ self->type = TYPE_CONNECT;
+ self->handle = (HANDLE)ConnectSocket;
+
+ Py_BEGIN_ALLOW_THREADS
+ ret = Py_ConnectEx(ConnectSocket, Address, Length,
+ NULL, 0, NULL, &self->overlapped);
+ Py_END_ALLOW_THREADS
+
+ self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError();
+ switch (err) {
+ case ERROR_SUCCESS:
+ case ERROR_IO_PENDING:
+ Py_RETURN_NONE;
+ default:
+ self->type = TYPE_NOT_STARTED;
+ return SetFromWindowsErr(err);
+ }
+}
+
+PyDoc_STRVAR(
+ Overlapped_DisconnectEx_doc,
+ "DisconnectEx(handle, flags) -> Overlapped[None]\n\n"
+ "Start overlapped connect. client_handle should be unbound.");
+
+static PyObject *
+Overlapped_DisconnectEx(OverlappedObject *self, PyObject *args)
+{
+ SOCKET Socket;
+ DWORD flags;
+ BOOL ret;
+ DWORD err;
+
+ if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &Socket, &flags))
+ return NULL;
+
+ if (self->type != TYPE_NONE) {
+ PyErr_SetString(PyExc_ValueError, "operation already attempted");
+ return NULL;
+ }
+
+ self->type = TYPE_DISCONNECT;
+ self->handle = (HANDLE)Socket;
+
+ Py_BEGIN_ALLOW_THREADS
+ ret = Py_DisconnectEx(Socket, &self->overlapped, flags, 0);
+ Py_END_ALLOW_THREADS
+
+ self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError();
+ switch (err) {
+ case ERROR_SUCCESS:
+ case ERROR_IO_PENDING:
+ Py_RETURN_NONE;
+ default:
+ self->type = TYPE_NOT_STARTED;
+ return SetFromWindowsErr(err);
+ }
+}
+
+static PyObject*
+Overlapped_getaddress(OverlappedObject *self)
+{
+ return PyLong_FromVoidPtr(&self->overlapped);
+}
+
+static PyObject*
+Overlapped_getpending(OverlappedObject *self)
+{
+ return PyBool_FromLong(!HasOverlappedIoCompleted(&self->overlapped) &&
+ self->type != TYPE_NOT_STARTED);
+}
+
+static PyMethodDef Overlapped_methods[] = {
+ {"getresult", (PyCFunction) Overlapped_getresult,
+ METH_VARARGS, Overlapped_getresult_doc},
+ {"cancel", (PyCFunction) Overlapped_cancel,
+ METH_NOARGS, Overlapped_cancel_doc},
+ {"ReadFile", (PyCFunction) Overlapped_ReadFile,
+ METH_VARARGS, Overlapped_ReadFile_doc},
+ {"WSARecv", (PyCFunction) Overlapped_WSARecv,
+ METH_VARARGS, Overlapped_WSARecv_doc},
+ {"WriteFile", (PyCFunction) Overlapped_WriteFile,
+ METH_VARARGS, Overlapped_WriteFile_doc},
+ {"WSASend", (PyCFunction) Overlapped_WSASend,
+ METH_VARARGS, Overlapped_WSASend_doc},
+ {"AcceptEx", (PyCFunction) Overlapped_AcceptEx,
+ METH_VARARGS, Overlapped_AcceptEx_doc},
+ {"ConnectEx", (PyCFunction) Overlapped_ConnectEx,
+ METH_VARARGS, Overlapped_ConnectEx_doc},
+ {"DisconnectEx", (PyCFunction) Overlapped_DisconnectEx,
+ METH_VARARGS, Overlapped_DisconnectEx_doc},
+ {NULL}
+};
+
+static PyMemberDef Overlapped_members[] = {
+ {"error", T_ULONG,
+ offsetof(OverlappedObject, error),
+ READONLY, "Error from last operation"},
+ {"event", T_HANDLE,
+ offsetof(OverlappedObject, overlapped) + offsetof(OVERLAPPED, hEvent),
+ READONLY, "Overlapped event handle"},
+ {NULL}
+};
+
+static PyGetSetDef Overlapped_getsets[] = {
+ {"address", (getter)Overlapped_getaddress, NULL,
+ "Address of overlapped structure"},
+ {"pending", (getter)Overlapped_getpending, NULL,
+ "Whether the operation is pending"},
+ {NULL},
+};
+
+PyTypeObject OverlappedType = {
+ PyVarObject_HEAD_INIT(NULL, 0)
+ /* tp_name */ "_overlapped.Overlapped",
+ /* tp_basicsize */ sizeof(OverlappedObject),
+ /* tp_itemsize */ 0,
+ /* tp_dealloc */ (destructor) Overlapped_dealloc,
+ /* tp_print */ 0,
+ /* tp_getattr */ 0,
+ /* tp_setattr */ 0,
+ /* tp_reserved */ 0,
+ /* tp_repr */ 0,
+ /* tp_as_number */ 0,
+ /* tp_as_sequence */ 0,
+ /* tp_as_mapping */ 0,
+ /* tp_hash */ 0,
+ /* tp_call */ 0,
+ /* tp_str */ 0,
+ /* tp_getattro */ 0,
+ /* tp_setattro */ 0,
+ /* tp_as_buffer */ 0,
+ /* tp_flags */ Py_TPFLAGS_DEFAULT,
+ /* tp_doc */ "OVERLAPPED structure wrapper",
+ /* tp_traverse */ 0,
+ /* tp_clear */ 0,
+ /* tp_richcompare */ 0,
+ /* tp_weaklistoffset */ 0,
+ /* tp_iter */ 0,
+ /* tp_iternext */ 0,
+ /* tp_methods */ Overlapped_methods,
+ /* tp_members */ Overlapped_members,
+ /* tp_getset */ Overlapped_getsets,
+ /* tp_base */ 0,
+ /* tp_dict */ 0,
+ /* tp_descr_get */ 0,
+ /* tp_descr_set */ 0,
+ /* tp_dictoffset */ 0,
+ /* tp_init */ 0,
+ /* tp_alloc */ 0,
+ /* tp_new */ Overlapped_new,
+};
+
+static PyMethodDef overlapped_functions[] = {
+ {"CreateIoCompletionPort", overlapped_CreateIoCompletionPort,
+ METH_VARARGS, CreateIoCompletionPort_doc},
+ {"GetQueuedCompletionStatus", overlapped_GetQueuedCompletionStatus,
+ METH_VARARGS, GetQueuedCompletionStatus_doc},
+ {"PostQueuedCompletionStatus", overlapped_PostQueuedCompletionStatus,
+ METH_VARARGS, PostQueuedCompletionStatus_doc},
+ {"BindLocal", overlapped_BindLocal,
+ METH_VARARGS, BindLocal_doc},
+ {NULL}
+};
+
+static struct PyModuleDef overlapped_module = {
+ PyModuleDef_HEAD_INIT,
+ "_overlapped",
+ NULL,
+ -1,
+ overlapped_functions,
+ NULL,
+ NULL,
+ NULL,
+ NULL
+};
+
+#define WINAPI_CONSTANT(fmt, con) \
+ PyDict_SetItemString(d, #con, Py_BuildValue(fmt, con))
+
+PyMODINIT_FUNC
+PyInit__overlapped(void)
+{
+ PyObject *m, *d;
+
+ /* Ensure WSAStartup() called before initializing function pointers */
+ m = PyImport_ImportModule("_socket");
+ if (!m)
+ return NULL;
+ Py_DECREF(m);
+
+ if (initialize_function_pointers() < 0)
+ return NULL;
+
+ if (PyType_Ready(&OverlappedType) < 0)
+ return NULL;
+
+ m = PyModule_Create(&overlapped_module);
+ if (PyModule_AddObject(m, "Overlapped", (PyObject *)&OverlappedType) < 0)
+ return NULL;
+
+ d = PyModule_GetDict(m);
+
+ WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING);
+ WINAPI_CONSTANT(F_DWORD, INFINITE);
+ WINAPI_CONSTANT(F_HANDLE, INVALID_HANDLE_VALUE);
+ WINAPI_CONSTANT(F_HANDLE, NULL);
+ WINAPI_CONSTANT(F_DWORD, SO_UPDATE_ACCEPT_CONTEXT);
+ WINAPI_CONSTANT(F_DWORD, SO_UPDATE_CONNECT_CONTEXT);
+ WINAPI_CONSTANT(F_DWORD, TF_REUSE_SOCKET);
+
+ return m;
+}
diff --git a/runtests.py b/runtests.py
new file mode 100644
index 0000000..096a256
--- /dev/null
+++ b/runtests.py
@@ -0,0 +1,198 @@
+"""Run all unittests.
+
+Usage:
+ python3 runtests.py [-v] [-q] [pattern] ...
+
+Where:
+ -v: verbose
+ -q: quiet
+ pattern: optional regex patterns to match test ids (default all tests)
+
+Note that the test id is the fully qualified name of the test,
+including package, module, class and method,
+e.g. 'tests.events_test.PolicyTests.testPolicy'.
+
+Running runtests.py with the --coverage argument is equivalent to:
+
+ $(COVERAGE) run --branch runtests.py -v
+ $(COVERAGE) html $(list of files)
+ $(COVERAGE) report -m $(list of files)
+
+"""
+
+# Originally written by Beech Horn (for NDB).
+
+import argparse
+import logging
+import os
+import re
+import sys
+import subprocess
+import unittest
+import importlib.machinery
+
+assert sys.version_info >= (3, 3), 'Please use Python 3.3 or higher.'
+
+ARGS = argparse.ArgumentParser(description="Run all unittests.")
+ARGS.add_argument(
+ '-v', action="store", dest='verbose',
+ nargs='?', const=1, type=int, default=0, help='verbose')
+ARGS.add_argument(
+ '-x', action="store_true", dest='exclude', help='exclude tests')
+ARGS.add_argument(
+ '-q', action="store_true", dest='quiet', help='quiet')
+ARGS.add_argument(
+ '--tests', action="store", dest='testsdir', default='tests',
+ help='tests directory')
+ARGS.add_argument(
+ '--coverage', action="store", dest='coverage', nargs='?', const='',
+ help='enable coverage report and provide python files directory')
+ARGS.add_argument(
+ 'pattern', action="store", nargs="*",
+ help='optional regex patterns to match test ids (default all tests)')
+
+COV_ARGS = argparse.ArgumentParser(description="Run all unittests.")
+COV_ARGS.add_argument(
+ '--coverage', action="store", dest='coverage', nargs='?', const='',
+ help='enable coverage report and provide python files directory')
+
+
+def load_modules(basedir, suffix='.py'):
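+    """Import every test module found under *basedir*.
+
+    Returns a list of (module, source_path) pairs; modules that fail
+    to import are skipped with a note on stderr.
+    """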
+ def list_dir(prefix, dir):
+ files = []
+
+ modpath = os.path.join(dir, '__init__.py')
+ if os.path.isfile(modpath):
+ mod = os.path.split(dir)[-1]
+ files.append(('%s%s' % (prefix, mod), modpath))
+
+ prefix = '%s%s.' % (prefix, mod)
+
+ for name in os.listdir(dir):
+ path = os.path.join(dir, name)
+
+ if os.path.isdir(path):
+ files.extend(list_dir('%s%s.' % (prefix, name), path))
+ else:
+ if (name != '__init__.py' and
+ name.endswith(suffix) and
+ not name.startswith(('.', '_'))):
+ files.append(('%s%s' % (prefix, name[:-3]), path))
+
+ return files
+
+ mods = []
+ for modname, sourcefile in list_dir('', basedir):
+ if modname == 'runtests':
+ continue
+ try:
+ loader = importlib.machinery.SourceFileLoader(modname, sourcefile)
+ mods.append((loader.load_module(), sourcefile))
+ except Exception as err:
+ print("Skipping '%s': %s" % (modname, err), file=sys.stderr)
+
+ return mods
+
+
+def load_tests(testsdir, includes=(), excludes=()):
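+    """Collect *Tests classes from the tests directory into a TestSuite.
+
+    Test ids matching any regex in *includes* are kept (all tests, if
+    empty); ids matching any regex in *excludes* are dropped.
+    """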
+ mods = [mod for mod, _ in load_modules(testsdir)]
+
+ loader = unittest.TestLoader()
+ suite = unittest.TestSuite()
+
+ for mod in mods:
+ for name in set(dir(mod)):
+ if name.endswith('Tests'):
+ test_module = getattr(mod, name)
+ tests = loader.loadTestsFromTestCase(test_module)
+ if includes:
+ tests = [test
+ for test in tests
+ if any(re.search(pat, test.id())
+ for pat in includes)]
+ if excludes:
+ tests = [test
+ for test in tests
+ if not any(re.search(pat, test.id())
+ for pat in excludes)]
+ suite.addTests(tests)
+
+ return suite
+
+
+def runtests():
+ args = ARGS.parse_args()
+
+ testsdir = os.path.abspath(args.testsdir)
+ if not os.path.isdir(testsdir):
+ print("Tests directory is not found: %s\n" % testsdir)
+ ARGS.print_help()
+ return
+
+ excludes = includes = []
+ if args.exclude:
+ excludes = args.pattern
+ else:
+ includes = args.pattern
+
+ v = 0 if args.quiet else args.verbose + 1
+
+ tests = load_tests(args.testsdir, includes, excludes)
+ logger = logging.getLogger()
+ if v == 0:
+ logger.setLevel(logging.CRITICAL)
+ elif v == 1:
+ logger.setLevel(logging.ERROR)
+ elif v == 2:
+ logger.setLevel(logging.WARNING)
+ elif v == 3:
+ logger.setLevel(logging.INFO)
+ elif v >= 4:
+ logger.setLevel(logging.DEBUG)
+ result = unittest.TextTestRunner(verbosity=v).run(tests)
+ sys.exit(not result.wasSuccessful())
+
+
+def runcoverage(sdir, args):
+ """
+ To install coverage3 for Python 3, you need:
+ - Distribute (http://packages.python.org/distribute/)
+
+ What worked for me:
+ - download http://python-distribute.org/distribute_setup.py
+ * curl -O http://python-distribute.org/distribute_setup.py
+ - python3 distribute_setup.py
+ - python3 -m easy_install coverage
+ """
+ try:
+ import coverage
+ except ImportError:
+ print("Coverage package is not found.")
+ print(runcoverage.__doc__)
+ return
+
+ sdir = os.path.abspath(sdir)
+ if not os.path.isdir(sdir):
+ print("Python files directory is not found: %s\n" % sdir)
+ ARGS.print_help()
+ return
+
+    mods = [source for _, source in load_modules(sdir)]
+    coverage_cmd = [sys.executable, '-m', 'coverage']
+
+    try:
+        subprocess.check_call(
+            coverage_cmd + ['run', '--branch', 'runtests.py'] + args)
+    except subprocess.CalledProcessError:
+        # Test failures are reported by the child process; skip the report.
+        pass
+    else:
+        subprocess.check_call(coverage_cmd + ['html'] + mods)
+        subprocess.check_call(coverage_cmd + ['report'] + mods)
+
+
+if __name__ == '__main__':
+ if '--coverage' in sys.argv:
+ cov_args, args = COV_ARGS.parse_known_args()
+ runcoverage(cov_args.coverage, args)
+ else:
+ runtests()
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..0260f9d
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,2 @@
+[build_ext]
+build_lib=tulip
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..dcaee96
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,14 @@
+import os
+from distutils.core import setup, Extension
+
+extensions = []
+if os.name == 'nt':
+ ext = Extension('_overlapped', ['overlapped.c'], libraries=['ws2_32'])
+ extensions.append(ext)
+
+setup(name='tulip',
+ description="reference implementation of PEP 3156",
+ url='http://www.python.org/dev/peps/pep-3156/',
+ packages=['tulip'],
+ ext_modules=extensions
+ )
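+
+# With the [build_ext] setting in setup.cfg, building places the optional
+# _overlapped extension inside the tulip/ package (illustrative):
+#
+#   python3 setup.py build_ext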
diff --git a/srv.py b/srv.py
new file mode 100755
index 0000000..b28abbd
--- /dev/null
+++ b/srv.py
@@ -0,0 +1,115 @@
+#!/usr/bin/env python3
+"""Simple server written using an event loop."""
+
+import email.message
+import os
+import sys
+
+assert sys.version_info >= (3, 3), 'Please use Python 3.3 or higher.'
+
+import tulip
+import tulip.http
+
+
+class HttpServer(tulip.http.ServerHttpProtocol):
+
+ def handle_request(self, request_info, message):
+ print('method = {!r}; path = {!r}; version = {!r}'.format(
+ request_info.method, request_info.uri, request_info.version))
+
+ path = request_info.uri
+
+ if (not (path.isprintable() and path.startswith('/')) or '/.' in path):
+ print('bad path', repr(path))
+ path = None
+ else:
+ path = '.' + path
+ if not os.path.exists(path):
+ print('no file', repr(path))
+ path = None
+ else:
+ isdir = os.path.isdir(path)
+
+ if not path:
+ raise tulip.http.HttpStatusException(404)
+
+ headers = email.message.Message()
+ for hdr, val in message.headers:
+ print(hdr, val)
+ headers.add_header(hdr, val)
+
+ if isdir and not path.endswith('/'):
+ path = path + '/'
+ raise tulip.http.HttpStatusException(
+ 302, headers=(('URI', path), ('Location', path)))
+
+ response = tulip.http.Response(self.transport, 200)
+ response.add_header('Transfer-Encoding', 'chunked')
+
+ # content encoding
+ accept_encoding = headers.get('accept-encoding', '').lower()
+ if 'deflate' in accept_encoding:
+ response.add_header('Content-Encoding', 'deflate')
+ response.add_compression_filter('deflate')
+ elif 'gzip' in accept_encoding:
+ response.add_header('Content-Encoding', 'gzip')
+ response.add_compression_filter('gzip')
+
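+        # Emit the body with chunked transfer-encoding in ~1 KiB pieces.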
+ response.add_chunking_filter(1025)
+
+ if isdir:
+ response.add_header('Content-type', 'text/html')
+ response.send_headers()
+
+ response.write(b'<ul>\r\n')
+ for name in sorted(os.listdir(path)):
+ if name.isprintable() and not name.startswith('.'):
+ try:
+ bname = name.encode('ascii')
+ except UnicodeError:
+ pass
+ else:
+ if os.path.isdir(os.path.join(path, name)):
+ response.write(b'<li><a href="' + bname +
+ b'/">' + bname + b'/</a></li>\r\n')
+ else:
+ response.write(b'<li><a href="' + bname +
+ b'">' + bname + b'</a></li>\r\n')
+ response.write(b'</ul>')
+ else:
+ response.add_header('Content-type', 'text/plain')
+ response.send_headers()
+
+ try:
+ with open(path, 'rb') as fp:
+                chunk = fp.read(8192)
+                while chunk:
+                    if not response.write(chunk):
+                        break
+                    chunk = fp.read(8192)
+ except OSError:
+ response.write(b'Cannot open')
+
+ response.write_eof()
+ self.close()
+
+
+def main():
+ host = '127.0.0.1'
+ port = 8080
+ if sys.argv[1:]:
+ host = sys.argv[1]
+ if sys.argv[2:]:
+ port = int(sys.argv[2])
+ elif ':' in host:
+ host, port = host.split(':', 1)
+ port = int(port)
+ loop = tulip.get_event_loop()
+ f = loop.start_serving(lambda: HttpServer(debug=True), host, port)
+ x = loop.run_until_complete(f)
+ print('serving on', x.getsockname())
+ loop.run_forever()
+
+
+if __name__ == '__main__':
+ main()
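+
+# Quick manual check once the server is running (illustrative):
+#
+#   curl -v http://127.0.0.1:8080/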
diff --git a/sslsrv.py b/sslsrv.py
new file mode 100644
index 0000000..a1bc04f
--- /dev/null
+++ b/sslsrv.py
@@ -0,0 +1,56 @@
+"""Serve up an SSL connection, after Python ssl module docs."""
+
+import socket
+import ssl
+import os
+
+
+def main():
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ certfile = getcertfile()
+ context.load_cert_chain(certfile=certfile, keyfile=certfile)
+ bindsocket = socket.socket()
+ bindsocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ bindsocket.bind(('', 4443))
+ bindsocket.listen(5)
+
+ while True:
+ newsocket, fromaddr = bindsocket.accept()
+ try:
+ connstream = context.wrap_socket(newsocket, server_side=True)
+ try:
+ deal_with_client(connstream)
+ finally:
+ connstream.shutdown(socket.SHUT_RDWR)
+ connstream.close()
+ except Exception as exc:
+ print(exc.__class__.__name__ + ':', exc)
+
+
+def getcertfile():
+ import test # Test package
+ testdir = os.path.dirname(test.__file__)
+ certfile = os.path.join(testdir, 'keycert.pem')
+ print('certfile =', certfile)
+ return certfile
+
+
+def deal_with_client(connstream):
+ data = connstream.recv(1024)
+ # empty data means the client is finished with us
+ while data:
+ if not do_something(connstream, data):
+ # we'll assume do_something returns False
+ # when we're finished with client
+ break
+ data = connstream.recv(1024)
+ # finished with client
+
+
+def do_something(connstream, data):
+    # just echo back; return a true value so deal_with_client keeps going
+    connstream.sendall(data)
+    return True
+
+
+if __name__ == '__main__':
+ main()
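+
+# To poke at the running server by hand (illustrative):
+#
+#   openssl s_client -connect 127.0.0.1:4443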
diff --git a/tests/base_events_test.py b/tests/base_events_test.py
new file mode 100644
index 0000000..88f3faf
--- /dev/null
+++ b/tests/base_events_test.py
@@ -0,0 +1,283 @@
+"""Tests for base_events.py"""
+
+import concurrent.futures
+import logging
+import socket
+import time
+import unittest
+import unittest.mock
+
+from tulip import base_events
+from tulip import events
+from tulip import futures
+from tulip import protocols
+from tulip import tasks
+from tulip import test_utils
+
+
+class BaseEventLoopTests(test_utils.LogTrackingTestCase):
+
+ def setUp(self):
+ super().setUp()
+
+ self.event_loop = base_events.BaseEventLoop()
+ self.event_loop._selector = unittest.mock.Mock()
+ self.event_loop._selector.registered_count.return_value = 1
+
+ def test_not_implemented(self):
+ m = unittest.mock.Mock()
+ self.assertRaises(
+ NotImplementedError,
+ self.event_loop._make_socket_transport, m, m)
+ self.assertRaises(
+ NotImplementedError,
+ self.event_loop._make_ssl_transport, m, m, m, m)
+ self.assertRaises(
+ NotImplementedError,
+ self.event_loop._make_datagram_transport, m, m)
+ self.assertRaises(
+ NotImplementedError, self.event_loop._process_events, [])
+ self.assertRaises(
+ NotImplementedError, self.event_loop._write_to_self)
+ self.assertRaises(
+ NotImplementedError, self.event_loop._read_from_self)
+ self.assertRaises(
+ NotImplementedError,
+ self.event_loop._make_read_pipe_transport, m, m)
+ self.assertRaises(
+ NotImplementedError,
+ self.event_loop._make_write_pipe_transport, m, m)
+
+ def test_add_callback_handle(self):
+ h = events.Handle(lambda: False, ())
+
+ self.event_loop._add_callback(h)
+ self.assertFalse(self.event_loop._scheduled)
+ self.assertIn(h, self.event_loop._ready)
+
+ def test_add_callback_timer(self):
+ when = time.monotonic()
+
+ h1 = events.Timer(when, lambda: False, ())
+ h2 = events.Timer(when+10.0, lambda: False, ())
+
+ self.event_loop._add_callback(h2)
+ self.event_loop._add_callback(h1)
+ self.assertEqual([h1, h2], self.event_loop._scheduled)
+ self.assertFalse(self.event_loop._ready)
+
+ def test_add_callback_cancelled_handle(self):
+ h = events.Handle(lambda: False, ())
+ h.cancel()
+
+ self.event_loop._add_callback(h)
+ self.assertFalse(self.event_loop._scheduled)
+ self.assertFalse(self.event_loop._ready)
+
+ def test_wrap_future(self):
+ f = futures.Future()
+ self.assertIs(self.event_loop.wrap_future(f), f)
+
+ def test_wrap_future_concurrent(self):
+ f = concurrent.futures.Future()
+ self.assertIsInstance(self.event_loop.wrap_future(f), futures.Future)
+
+ def test_set_default_executor(self):
+ executor = unittest.mock.Mock()
+ self.event_loop.set_default_executor(executor)
+ self.assertIs(executor, self.event_loop._default_executor)
+
+ def test_getnameinfo(self):
+ sockaddr = unittest.mock.Mock()
+ self.event_loop.run_in_executor = unittest.mock.Mock()
+ self.event_loop.getnameinfo(sockaddr)
+ self.assertEqual(
+ (None, socket.getnameinfo, sockaddr, 0),
+ self.event_loop.run_in_executor.call_args[0])
+
+ def test_call_soon(self):
+ def cb():
+ pass
+
+ h = self.event_loop.call_soon(cb)
+ self.assertEqual(h._callback, cb)
+ self.assertIsInstance(h, events.Handle)
+ self.assertIn(h, self.event_loop._ready)
+
+ def test_call_later(self):
+ def cb():
+ pass
+
+ h = self.event_loop.call_later(10.0, cb)
+ self.assertIsInstance(h, events.Timer)
+ self.assertIn(h, self.event_loop._scheduled)
+ self.assertNotIn(h, self.event_loop._ready)
+
+ def test_call_later_no_delay(self):
+ def cb():
+ pass
+
+ h = self.event_loop.call_later(0, cb)
+ self.assertIn(h, self.event_loop._ready)
+ self.assertNotIn(h, self.event_loop._scheduled)
+
+ def test_run_once_in_executor_handle(self):
+ def cb():
+ pass
+
+ self.assertRaises(
+ AssertionError, self.event_loop.run_in_executor,
+ None, events.Handle(cb, ()), ('',))
+ self.assertRaises(
+ AssertionError, self.event_loop.run_in_executor,
+ None, events.Timer(10, cb, ()))
+
+ def test_run_once_in_executor_canceled(self):
+ def cb():
+ pass
+ h = events.Handle(cb, ())
+ h.cancel()
+
+ f = self.event_loop.run_in_executor(None, h)
+ self.assertIsInstance(f, futures.Future)
+ self.assertTrue(f.done())
+
+ def test_run_once_in_executor(self):
+ def cb():
+ pass
+ h = events.Handle(cb, ())
+ f = futures.Future()
+ executor = unittest.mock.Mock()
+ executor.submit.return_value = f
+
+ self.event_loop.set_default_executor(executor)
+
+ res = self.event_loop.run_in_executor(None, h)
+ self.assertIs(f, res)
+
+ executor = unittest.mock.Mock()
+ executor.submit.return_value = f
+ res = self.event_loop.run_in_executor(executor, h)
+ self.assertIs(f, res)
+ self.assertTrue(executor.submit.called)
+
+ def test_run_once(self):
+ self.event_loop._run_once = unittest.mock.Mock()
+ self.event_loop._run_once.side_effect = base_events._StopError
+ self.event_loop.run_once()
+ self.assertTrue(self.event_loop._run_once.called)
+
+ def test__run_once(self):
+ h1 = events.Timer(time.monotonic() + 0.1, lambda: True, ())
+ h2 = events.Timer(time.monotonic() + 10.0, lambda: True, ())
+
+ h1.cancel()
+
+ self.event_loop._process_events = unittest.mock.Mock()
+ self.event_loop._scheduled.append(h1)
+ self.event_loop._scheduled.append(h2)
+ self.event_loop._run_once()
+
+ t = self.event_loop._selector.select.call_args[0][0]
+ self.assertTrue(9.99 < t < 10.1)
+ self.assertEqual([h2], self.event_loop._scheduled)
+ self.assertTrue(self.event_loop._process_events.called)
+
+ def test__run_once_timeout(self):
+ h = events.Timer(time.monotonic() + 10.0, lambda: True, ())
+
+ self.event_loop._process_events = unittest.mock.Mock()
+ self.event_loop._scheduled.append(h)
+ self.event_loop._run_once(1.0)
+ self.assertEqual((1.0,), self.event_loop._selector.select.call_args[0])
+
+ def test__run_once_timeout_with_ready(self):
+ # If event loop has ready callbacks, select timeout is always 0.
+ h = events.Timer(time.monotonic() + 10.0, lambda: True, ())
+
+ self.event_loop._process_events = unittest.mock.Mock()
+ self.event_loop._scheduled.append(h)
+ self.event_loop._ready.append(h)
+ self.event_loop._run_once(1.0)
+
+ self.assertEqual((0,), self.event_loop._selector.select.call_args[0])
+
+ @unittest.mock.patch('tulip.base_events.time')
+ @unittest.mock.patch('tulip.base_events.tulip_log')
+ def test__run_once_logging(self, m_logging, m_time):
+ # Log to INFO level if timeout > 1.0 sec.
+ idx = -1
+ data = [10.0, 10.0, 12.0, 13.0]
+
+ def monotonic():
+ nonlocal data, idx
+ idx += 1
+ return data[idx]
+
+ m_time.monotonic = monotonic
+ m_logging.INFO = logging.INFO
+ m_logging.DEBUG = logging.DEBUG
+
+ self.event_loop._scheduled.append(events.Timer(11.0, lambda: True, ()))
+ self.event_loop._process_events = unittest.mock.Mock()
+ self.event_loop._run_once()
+ self.assertEqual(logging.INFO, m_logging.log.call_args[0][0])
+
+ idx = -1
+ data = [10.0, 10.0, 10.3, 13.0]
+        self.event_loop._scheduled = [events.Timer(11.0, lambda: True, ())]
+ self.event_loop._run_once()
+ self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0])
+
+ def test__run_once_schedule_handle(self):
+ handle = None
+ processed = False
+
+ def cb(event_loop):
+ nonlocal processed, handle
+ processed = True
+ handle = event_loop.call_soon(lambda: True)
+
+ h = events.Timer(time.monotonic() - 1, cb, (self.event_loop,))
+
+ self.event_loop._process_events = unittest.mock.Mock()
+ self.event_loop._scheduled.append(h)
+ self.event_loop._run_once()
+
+ self.assertTrue(processed)
+ self.assertEqual([handle], list(self.event_loop._ready))
+
+ def test_run_until_complete_assertion(self):
+ self.assertRaises(
+ AssertionError, self.event_loop.run_until_complete, 'blah')
+
+ @unittest.mock.patch('tulip.base_events.socket')
+    def test_create_connection_multiple_errors(self, m_socket):
+ self.suppress_log_errors()
+
+ class MyProto(protocols.Protocol):
+ pass
+
+ def getaddrinfo(*args, **kw):
+ yield from []
+ return [(2, 1, 6, '', ('107.6.106.82', 80)),
+ (2, 1, 6, '', ('107.6.106.82', 80))]
+
+ idx = -1
+ errors = ['err1', 'err2']
+
+ def _socket(*args, **kw):
+ nonlocal idx, errors
+ idx += 1
+ raise socket.error(errors[idx])
+
+ m_socket.socket = _socket
+ m_socket.error = socket.error
+
+ self.event_loop.getaddrinfo = getaddrinfo
+
+ task = tasks.Task(
+ self.event_loop.create_connection(MyProto, 'xkcd.com', 80))
+ task._step()
+ exc = task.exception()
+ self.assertEqual("Multiple exceptions: err1, err2", str(exc))
diff --git a/tests/events_test.py b/tests/events_test.py
new file mode 100644
index 0000000..4085921
--- /dev/null
+++ b/tests/events_test.py
@@ -0,0 +1,1379 @@
+"""Tests for events.py."""
+
+import concurrent.futures
+import contextlib
+import gc
+import io
+import os
+import re
+import signal
+import socket
+try:
+ import ssl
+except ImportError:
+ ssl = None
+import sys
+import threading
+import time
+import unittest
+import unittest.mock
+
+from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
+
+from tulip import events
+from tulip import transports
+from tulip import protocols
+from tulip import selector_events
+from tulip import tasks
+from tulip import test_utils
+
+
+class MyProto(protocols.Protocol):
+
+ def __init__(self):
+ self.state = 'INITIAL'
+ self.nbytes = 0
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+ transport.write(b'GET / HTTP/1.0\r\nHost: xkcd.com\r\n\r\n')
+
+ def data_received(self, data):
+ assert self.state == 'CONNECTED', self.state
+ self.nbytes += len(data)
+
+ def eof_received(self):
+ assert self.state == 'CONNECTED', self.state
+ self.state = 'EOF'
+ self.transport.close()
+
+ def connection_lost(self, exc):
+ assert self.state in ('CONNECTED', 'EOF'), self.state
+ self.state = 'CLOSED'
+
+
+class MyDatagramProto(protocols.DatagramProtocol):
+
+ def __init__(self):
+ self.state = 'INITIAL'
+ self.nbytes = 0
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'INITIALIZED'
+
+ def datagram_received(self, data, addr):
+ assert self.state == 'INITIALIZED', self.state
+ self.nbytes += len(data)
+
+ def connection_refused(self, exc):
+ assert self.state == 'INITIALIZED', self.state
+
+ def connection_lost(self, exc):
+ assert self.state == 'INITIALIZED', self.state
+ self.state = 'CLOSED'
+
+
+class MyReadPipeProto(protocols.Protocol):
+
+ def __init__(self):
+ self.state = ['INITIAL']
+ self.nbytes = 0
+ self.transport = None
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == ['INITIAL'], self.state
+ self.state.append('CONNECTED')
+
+ def data_received(self, data):
+ assert self.state == ['INITIAL', 'CONNECTED'], self.state
+ self.nbytes += len(data)
+
+ def eof_received(self):
+ assert self.state == ['INITIAL', 'CONNECTED'], self.state
+ self.state.append('EOF')
+ self.transport.close()
+
+ def connection_lost(self, exc):
+ assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state
+ self.state.append('CLOSED')
+
+
+class MyWritePipeProto(protocols.Protocol):
+
+ def __init__(self):
+ self.state = 'INITIAL'
+ self.transport = None
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+
+ def connection_lost(self, exc):
+ assert self.state == 'CONNECTED', self.state
+ self.state = 'CLOSED'
+
+
+class EventLoopTestsMixin:
+
+ def setUp(self):
+ super().setUp()
+ self.event_loop = self.create_event_loop()
+ events.set_event_loop(self.event_loop)
+
+ def tearDown(self):
+ self.event_loop.close()
+ gc.collect()
+ super().tearDown()
+
+ @contextlib.contextmanager
+ def run_test_server(self, *, use_ssl=False):
+
+ class SilentWSGIRequestHandler(WSGIRequestHandler):
+ def get_stderr(self):
+ return io.StringIO()
+
+ def log_message(self, format, *args):
+ pass
+
+ class SilentWSGIServer(WSGIServer):
+ def handle_error(self, request, client_address):
+ pass
+
+ class SSLWSGIServer(SilentWSGIServer):
+ def finish_request(self, request, client_address):
+ here = os.path.dirname(__file__)
+ keyfile = os.path.join(here, 'sample.key')
+ certfile = os.path.join(here, 'sample.crt')
+ ssock = ssl.wrap_socket(request,
+ keyfile=keyfile,
+ certfile=certfile,
+ server_side=True)
+ try:
+ self.RequestHandlerClass(ssock, client_address, self)
+ ssock.close()
+ except OSError:
+ # maybe socket has been closed by peer
+ pass
+
+ def app(environ, start_response):
+ status = '302 Found'
+ headers = [('Content-type', 'text/plain')]
+ start_response(status, headers)
+ return [b'Test message']
+
+ # Run the test WSGI server in a separate thread in order not to
+ # interfere with event handling in the main thread
+ server_class = SSLWSGIServer if use_ssl else SilentWSGIServer
+ httpd = make_server('127.0.0.1', 0, app,
+ server_class, SilentWSGIRequestHandler)
+ server_thread = threading.Thread(target=httpd.serve_forever)
+ server_thread.start()
+ try:
+ yield httpd
+ finally:
+ httpd.shutdown()
+ server_thread.join()
+
+ def test_run(self):
+ self.event_loop.run() # Returns immediately.
+
+ def test_run_nesting(self):
+ err = None
+
+ @tasks.coroutine
+ def coro():
+ nonlocal err
+ yield from []
+ self.assertTrue(self.event_loop.is_running())
+ try:
+ self.event_loop.run_until_complete(
+ tasks.sleep(0.1))
+ except Exception as exc:
+ err = exc
+
+ self.event_loop.run_until_complete(tasks.Task(coro()))
+ self.assertIsInstance(err, RuntimeError)
+
+ def test_run_once_nesting(self):
+ err = None
+
+ @tasks.coroutine
+ def coro():
+ nonlocal err
+ yield from []
+ tasks.sleep(0.1)
+ try:
+ self.event_loop.run_once()
+ except Exception as exc:
+ err = exc
+
+ self.event_loop.run_until_complete(tasks.Task(coro()))
+ self.assertIsInstance(err, RuntimeError)
+
+ def test_run_once_block(self):
+ called = False
+
+ def callback():
+ nonlocal called
+ called = True
+
+ def run():
+ time.sleep(0.1)
+ self.event_loop.call_soon_threadsafe(callback)
+
+ t = threading.Thread(target=run)
+ t0 = time.monotonic()
+ t.start()
+ self.event_loop.run_once(None)
+ t1 = time.monotonic()
+ t.join()
+ self.assertTrue(called)
+ self.assertTrue(0.09 < t1-t0 <= 0.12)
+
+ def test_call_later(self):
+ results = []
+
+ def callback(arg):
+ results.append(arg)
+
+ self.event_loop.call_later(0.1, callback, 'hello world')
+ t0 = time.monotonic()
+ self.event_loop.run()
+ t1 = time.monotonic()
+ self.assertEqual(results, ['hello world'])
+ self.assertTrue(t1-t0 >= 0.09)
+
+ def test_call_repeatedly(self):
+ results = []
+
+ def callback(arg):
+ results.append(arg)
+
+ self.event_loop.call_repeatedly(0.03, callback, 'ho')
+ self.event_loop.call_later(0.1, self.event_loop.stop)
+ self.event_loop.run()
+ self.assertEqual(results, ['ho', 'ho', 'ho'])
+
+ def test_call_soon(self):
+ results = []
+
+ def callback(arg1, arg2):
+ results.append((arg1, arg2))
+
+ self.event_loop.call_soon(callback, 'hello', 'world')
+ self.event_loop.run()
+ self.assertEqual(results, [('hello', 'world')])
+
+ def test_call_soon_with_handle(self):
+ results = []
+
+ def callback():
+ results.append('yeah')
+
+ handle = events.Handle(callback, ())
+ self.assertIs(self.event_loop.call_soon(handle), handle)
+ self.event_loop.run()
+ self.assertEqual(results, ['yeah'])
+
+ def test_call_soon_threadsafe(self):
+ results = []
+
+ def callback(arg):
+ results.append(arg)
+
+ def run():
+ self.event_loop.call_soon_threadsafe(callback, 'hello')
+
+ t = threading.Thread(target=run)
+ self.event_loop.call_later(0.1, callback, 'world')
+ t0 = time.monotonic()
+ t.start()
+ self.event_loop.run()
+ t1 = time.monotonic()
+ t.join()
+ self.assertEqual(results, ['hello', 'world'])
+ self.assertTrue(t1-t0 >= 0.09)
+
+ def test_call_soon_threadsafe_same_thread(self):
+ results = []
+
+ def callback(arg):
+ results.append(arg)
+
+ self.event_loop.call_later(0.1, callback, 'world')
+ self.event_loop.call_soon_threadsafe(callback, 'hello')
+ self.event_loop.run()
+ self.assertEqual(results, ['hello', 'world'])
+
+ def test_call_soon_threadsafe_with_handle(self):
+ results = []
+
+ def callback(arg):
+ results.append(arg)
+
+ handle = events.Handle(callback, ('hello',))
+
+ def run():
+ self.assertIs(
+ self.event_loop.call_soon_threadsafe(handle), handle)
+
+ t = threading.Thread(target=run)
+ self.event_loop.call_later(0.1, callback, 'world')
+
+ t0 = time.monotonic()
+ t.start()
+ self.event_loop.run()
+ t1 = time.monotonic()
+ t.join()
+ self.assertEqual(results, ['hello', 'world'])
+ self.assertTrue(t1-t0 >= 0.09)
+
+ def test_wrap_future(self):
+ def run(arg):
+ time.sleep(0.1)
+ return arg
+ ex = concurrent.futures.ThreadPoolExecutor(1)
+ f1 = ex.submit(run, 'oi')
+ f2 = self.event_loop.wrap_future(f1)
+ res = self.event_loop.run_until_complete(f2)
+ self.assertEqual(res, 'oi')
+
+ def test_run_in_executor(self):
+ def run(arg):
+ time.sleep(0.1)
+ return arg
+ f2 = self.event_loop.run_in_executor(None, run, 'yo')
+ res = self.event_loop.run_until_complete(f2)
+ self.assertEqual(res, 'yo')
+
+ def test_run_in_executor_with_handle(self):
+ def run(arg):
+ time.sleep(0.1)
+ return arg
+ handle = events.Handle(run, ('yo',))
+ f2 = self.event_loop.run_in_executor(None, handle)
+ res = self.event_loop.run_until_complete(f2)
+ self.assertEqual(res, 'yo')
+
+ def test_reader_callback(self):
+ r, w = test_utils.socketpair()
+ bytes_read = []
+
+ def reader():
+ try:
+ data = r.recv(1024)
+ except BlockingIOError:
+ # Spurious readiness notifications are possible
+ # at least on Linux -- see man select.
+ return
+ if data:
+ bytes_read.append(data)
+ else:
+ self.assertTrue(self.event_loop.remove_reader(r.fileno()))
+ r.close()
+
+ self.event_loop.add_reader(r.fileno(), reader)
+ self.event_loop.call_later(0.05, w.send, b'abc')
+ self.event_loop.call_later(0.1, w.send, b'def')
+ self.event_loop.call_later(0.15, w.close)
+ self.event_loop.run()
+ self.assertEqual(b''.join(bytes_read), b'abcdef')
+
+ def test_reader_callback_with_handle(self):
+ r, w = test_utils.socketpair()
+ bytes_read = []
+
+ def reader():
+ try:
+ data = r.recv(1024)
+ except BlockingIOError:
+ # Spurious readiness notifications are possible
+ # at least on Linux -- see man select.
+ return
+ if data:
+ bytes_read.append(data)
+ else:
+ self.assertTrue(self.event_loop.remove_reader(r.fileno()))
+ r.close()
+
+ handle = events.Handle(reader, ())
+ self.assertIs(handle, self.event_loop.add_reader(r.fileno(), handle))
+
+ self.event_loop.call_later(0.05, w.send, b'abc')
+ self.event_loop.call_later(0.1, w.send, b'def')
+ self.event_loop.call_later(0.15, w.close)
+ self.event_loop.run()
+ self.assertEqual(b''.join(bytes_read), b'abcdef')
+
+ def test_reader_callback_cancel(self):
+ r, w = test_utils.socketpair()
+ bytes_read = []
+
+ def reader():
+ try:
+ data = r.recv(1024)
+ except BlockingIOError:
+ return
+ if data:
+ bytes_read.append(data)
+ if sum(len(b) for b in bytes_read) >= 6:
+ handle.cancel()
+ if not data:
+ r.close()
+
+ handle = self.event_loop.add_reader(r.fileno(), reader)
+ self.event_loop.call_later(0.05, w.send, b'abc')
+ self.event_loop.call_later(0.1, w.send, b'def')
+ self.event_loop.call_later(0.15, w.close)
+ self.event_loop.run()
+ self.assertEqual(b''.join(bytes_read), b'abcdef')
+
+ def test_writer_callback(self):
+ r, w = test_utils.socketpair()
+ w.setblocking(False)
+ self.event_loop.add_writer(w.fileno(), w.send, b'x'*(256*1024))
+
+ def remove_writer():
+ self.assertTrue(self.event_loop.remove_writer(w.fileno()))
+
+ self.event_loop.call_later(0.1, remove_writer)
+ self.event_loop.run()
+ w.close()
+ data = r.recv(256*1024)
+ r.close()
+ self.assertTrue(len(data) >= 200)
+
+ def test_writer_callback_with_handle(self):
+ r, w = test_utils.socketpair()
+ w.setblocking(False)
+ handle = events.Handle(w.send, (b'x'*(256*1024),))
+ self.assertIs(self.event_loop.add_writer(w.fileno(), handle), handle)
+
+ def remove_writer():
+ self.assertTrue(self.event_loop.remove_writer(w.fileno()))
+
+ self.event_loop.call_later(0.1, remove_writer)
+ self.event_loop.run()
+ w.close()
+ data = r.recv(256*1024)
+ r.close()
+ self.assertTrue(len(data) >= 200)
+
+ def test_writer_callback_cancel(self):
+ r, w = test_utils.socketpair()
+ w.setblocking(False)
+
+ def sender():
+ w.send(b'x'*256)
+ handle.cancel()
+
+ handle = self.event_loop.add_writer(w.fileno(), sender)
+ self.event_loop.run()
+ w.close()
+ data = r.recv(1024)
+ r.close()
+ self.assertTrue(data == b'x'*256)
+
+ def test_sock_client_ops(self):
+ with self.run_test_server() as httpd:
+ sock = socket.socket()
+ sock.setblocking(False)
+ address = httpd.socket.getsockname()
+ self.event_loop.run_until_complete(
+ self.event_loop.sock_connect(sock, address))
+ self.event_loop.run_until_complete(
+ self.event_loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
+ data = self.event_loop.run_until_complete(
+ self.event_loop.sock_recv(sock, 1024))
+ sock.close()
+
+ self.assertTrue(re.match(rb'HTTP/1.0 302 Found', data), data)
+
+ def test_sock_client_fail(self):
+ # Make sure that we will get an unused port
+ address = None
+ try:
+ s = socket.socket()
+ s.bind(('127.0.0.1', 0))
+ address = s.getsockname()
+ finally:
+ s.close()
+
+ sock = socket.socket()
+ sock.setblocking(False)
+ with self.assertRaises(ConnectionRefusedError):
+ self.event_loop.run_until_complete(
+ self.event_loop.sock_connect(sock, address))
+ sock.close()
+
+ def test_sock_accept(self):
+ listener = socket.socket()
+ listener.setblocking(False)
+ listener.bind(('127.0.0.1', 0))
+ listener.listen(1)
+ client = socket.socket()
+ client.connect(listener.getsockname())
+
+ f = self.event_loop.sock_accept(listener)
+ conn, addr = self.event_loop.run_until_complete(f)
+ self.assertEqual(conn.gettimeout(), 0)
+ self.assertEqual(addr, client.getsockname())
+ self.assertEqual(client.getpeername(), listener.getsockname())
+ client.close()
+ conn.close()
+ listener.close()
+
+ @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL')
+ def test_add_signal_handler(self):
+ caught = 0
+
+ def my_handler():
+ nonlocal caught
+ caught += 1
+
+ # Check error behavior first.
+ self.assertRaises(
+ TypeError, self.event_loop.add_signal_handler, 'boom', my_handler)
+ self.assertRaises(
+ TypeError, self.event_loop.remove_signal_handler, 'boom')
+ self.assertRaises(
+ ValueError, self.event_loop.add_signal_handler, signal.NSIG+1,
+ my_handler)
+ self.assertRaises(
+ ValueError, self.event_loop.remove_signal_handler, signal.NSIG+1)
+ self.assertRaises(
+ ValueError, self.event_loop.add_signal_handler, 0, my_handler)
+ self.assertRaises(
+ ValueError, self.event_loop.remove_signal_handler, 0)
+ self.assertRaises(
+ ValueError, self.event_loop.add_signal_handler, -1, my_handler)
+ self.assertRaises(
+ ValueError, self.event_loop.remove_signal_handler, -1)
+ self.assertRaises(
+ RuntimeError, self.event_loop.add_signal_handler, signal.SIGKILL,
+ my_handler)
+ # Removing SIGKILL doesn't raise, since we don't call signal().
+ self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGKILL))
+ # Now set a handler and handle it.
+ self.event_loop.add_signal_handler(signal.SIGINT, my_handler)
+ self.event_loop.run_once()
+ os.kill(os.getpid(), signal.SIGINT)
+ self.event_loop.run_once()
+ self.assertEqual(caught, 1)
+ # Removing it should restore the default handler.
+ self.assertTrue(self.event_loop.remove_signal_handler(signal.SIGINT))
+ self.assertEqual(signal.getsignal(signal.SIGINT),
+ signal.default_int_handler)
+ # Removing again returns False.
+ self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGINT))
+
+ @unittest.skipIf(sys.platform == 'win32', 'Unix only')
+ def test_cancel_signal_handler(self):
+ # Cancelling the handler should remove it (eventually).
+ caught = 0
+
+ def my_handler():
+ nonlocal caught
+ caught += 1
+
+ handle = self.event_loop.add_signal_handler(signal.SIGINT, my_handler)
+ handle.cancel()
+ os.kill(os.getpid(), signal.SIGINT)
+ self.event_loop.run_once()
+ self.assertEqual(caught, 0)
+
+ @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM')
+ def test_signal_handling_while_selecting(self):
+ # Test with a signal actually arriving during a select() call.
+ caught = 0
+
+ def my_handler():
+ nonlocal caught
+ caught += 1
+
+ self.event_loop.add_signal_handler(
+ signal.SIGALRM, my_handler)
+
+ signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once.
+ self.event_loop.call_later(0.15, self.event_loop.stop)
+ self.event_loop.run_forever()
+ self.assertEqual(caught, 1)
+
+ @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM')
+ def test_signal_handling_args(self):
+ some_args = (42,)
+ caught = 0
+
+ def my_handler(*args):
+ nonlocal caught
+ caught += 1
+ self.assertEqual(args, some_args)
+
+ self.event_loop.add_signal_handler(
+ signal.SIGALRM, my_handler, *some_args)
+
+ signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once.
+ self.event_loop.call_later(0.15, self.event_loop.stop)
+ self.event_loop.run_forever()
+ self.assertEqual(caught, 1)
+
+ def test_create_connection(self):
+ with self.run_test_server() as httpd:
+ host, port = httpd.socket.getsockname()
+ f = tasks.Task(
+ self.event_loop.create_connection(MyProto, host, port))
+ tr, pr = self.event_loop.run_until_complete(f)
+ self.assertTrue(isinstance(tr, transports.Transport))
+ self.assertTrue(isinstance(pr, protocols.Protocol))
+ self.event_loop.run()
+ self.assertTrue(pr.nbytes > 0)
+
+ def test_create_connection_sock(self):
+ with self.run_test_server() as httpd:
+ host, port = httpd.socket.getsockname()
+ f = tasks.Task(
+ self.event_loop.create_connection(MyProto, host, port))
+ tr, pr = self.event_loop.run_until_complete(f)
+ self.assertTrue(isinstance(tr, transports.Transport))
+ self.assertTrue(isinstance(pr, protocols.Protocol))
+ self.event_loop.run()
+ self.assertTrue(pr.nbytes > 0)
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_create_ssl_connection(self):
+ with self.run_test_server(use_ssl=True) as httpsd:
+ host, port = httpsd.socket.getsockname()
+ f = self.event_loop.create_connection(
+ MyProto, host, port, ssl=True)
+ tr, pr = self.event_loop.run_until_complete(f)
+ self.assertTrue(isinstance(tr, transports.Transport))
+ self.assertTrue(isinstance(pr, protocols.Protocol))
+ self.assertTrue('ssl' in tr.__class__.__name__.lower())
+ self.assertTrue(
+ hasattr(tr.get_extra_info('socket'), 'getsockname'))
+ self.event_loop.run()
+ self.assertTrue(pr.nbytes > 0)
+
+ def test_create_connection_host_port_sock(self):
+ self.suppress_log_errors()
+ coro = self.event_loop.create_connection(
+ MyProto, 'xkcd.com', 80, sock=object())
+ self.assertRaises(ValueError, self.event_loop.run_until_complete, coro)
+
+ def test_create_connection_no_host_port_sock(self):
+ self.suppress_log_errors()
+ coro = self.event_loop.create_connection(MyProto)
+ self.assertRaises(ValueError, self.event_loop.run_until_complete, coro)
+
+ def test_create_connection_no_getaddrinfo(self):
+ self.suppress_log_errors()
+ getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock()
+ getaddrinfo.return_value = []
+
+ coro = self.event_loop.create_connection(MyProto, 'xkcd.com', 80)
+ self.assertRaises(
+ socket.error, self.event_loop.run_until_complete, coro)
+
+ def test_create_connection_connect_err(self):
+ self.suppress_log_errors()
+ self.event_loop.sock_connect = unittest.mock.Mock()
+ self.event_loop.sock_connect.side_effect = socket.error
+
+ coro = self.event_loop.create_connection(MyProto, 'xkcd.com', 80)
+ self.assertRaises(
+ socket.error, self.event_loop.run_until_complete, coro)
+
+    def test_create_connection_multiple_errors(self):
+ self.suppress_log_errors()
+
+ def getaddrinfo(*args, **kw):
+ yield from []
+ return [(2, 1, 6, '', ('107.6.106.82', 80)),
+ (2, 1, 6, '', ('107.6.106.82', 80))]
+ self.event_loop.getaddrinfo = getaddrinfo
+ self.event_loop.sock_connect = unittest.mock.Mock()
+ self.event_loop.sock_connect.side_effect = socket.error
+
+ coro = self.event_loop.create_connection(MyProto, 'xkcd.com', 80)
+ self.assertRaises(
+ socket.error, self.event_loop.run_until_complete, coro)
+
+ def test_start_serving(self):
+ proto = None
+
+ def factory():
+ nonlocal proto
+ proto = MyProto()
+ return proto
+
+ f = self.event_loop.start_serving(factory, '0.0.0.0', 0)
+ sock = self.event_loop.run_until_complete(f)
+ host, port = sock.getsockname()
+ self.assertEqual(host, '0.0.0.0')
+ client = socket.socket()
+ client.connect(('127.0.0.1', port))
+ client.send(b'xxx')
+ self.event_loop.run_once()
+ self.assertIsInstance(proto, MyProto)
+ self.assertEqual('INITIAL', proto.state)
+ self.event_loop.run_once()
+ self.assertEqual('CONNECTED', proto.state)
+ self.assertEqual(3, proto.nbytes)
+
+ # extra info is available
+ self.assertIsNotNone(proto.transport.get_extra_info('socket'))
+ conn = proto.transport.get_extra_info('socket')
+ self.assertTrue(hasattr(conn, 'getsockname'))
+ self.assertEqual(
+ '127.0.0.1', proto.transport.get_extra_info('addr')[0])
+
+ # close connection
+ proto.transport.close()
+
+ self.assertEqual('CLOSED', proto.state)
+
+ # the client socket must be closed after to avoid ECONNRESET upon
+ # recv()/send() on the serving socket
+ client.close()
+
+ def test_start_serving_sock(self):
+ sock_ob = socket.socket(type=socket.SOCK_STREAM)
+ sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock_ob.bind(('0.0.0.0', 0))
+
+ f = self.event_loop.start_serving(MyProto, sock=sock_ob)
+ sock = self.event_loop.run_until_complete(f)
+ self.assertIs(sock, sock_ob)
+
+ host, port = sock.getsockname()
+ self.assertEqual(host, '0.0.0.0')
+ client = socket.socket()
+ client.connect(('127.0.0.1', port))
+ client.send(b'xxx')
+ self.event_loop.run_once() # This is quite mysterious, but necessary.
+ self.event_loop.run_once()
+ sock.close()
+ client.close()
+
+ def test_start_serving_host_port_sock(self):
+ self.suppress_log_errors()
+ fut = self.event_loop.start_serving(
+ MyProto, '0.0.0.0', 0, sock=object())
+ self.assertRaises(ValueError, self.event_loop.run_until_complete, fut)
+
+ def test_start_serving_no_host_port_sock(self):
+ self.suppress_log_errors()
+ fut = self.event_loop.start_serving(MyProto)
+ self.assertRaises(ValueError, self.event_loop.run_until_complete, fut)
+
+ def test_start_serving_no_getaddrinfo(self):
+ self.suppress_log_errors()
+ getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock()
+ getaddrinfo.return_value = []
+
+ f = self.event_loop.start_serving(MyProto, '0.0.0.0', 0)
+ self.assertRaises(
+ socket.error, self.event_loop.run_until_complete, f)
+
+ @unittest.mock.patch('tulip.base_events.socket')
+ def test_start_serving_cant_bind(self, m_socket):
+ self.suppress_log_errors()
+
+ class Err(socket.error):
+ pass
+
+ m_socket.error = socket.error
+ m_socket.getaddrinfo.return_value = [
+ (2, 1, 6, '', ('127.0.0.1', 10100))]
+ m_sock = m_socket.socket.return_value = unittest.mock.Mock()
+ m_sock.setsockopt.side_effect = Err
+
+ fut = self.event_loop.start_serving(MyProto, '0.0.0.0', 0)
+ self.assertRaises(Err, self.event_loop.run_until_complete, fut)
+ self.assertTrue(m_sock.close.called)
+
+ @unittest.mock.patch('tulip.base_events.socket')
+ def test_create_datagram_endpoint_no_addrinfo(self, m_socket):
+ self.suppress_log_errors()
+
+ m_socket.error = socket.error
+ m_socket.getaddrinfo.return_value = []
+
+ coro = self.event_loop.create_datagram_endpoint(
+ MyDatagramProto, local_addr=('localhost', 0))
+ self.assertRaises(
+ socket.error, self.event_loop.run_until_complete, coro)
+
+ def test_create_datagram_endpoint_addr_error(self):
+ self.suppress_log_errors()
+
+ coro = self.event_loop.create_datagram_endpoint(
+ MyDatagramProto, local_addr='localhost')
+ self.assertRaises(
+ AssertionError, self.event_loop.run_until_complete, coro)
+ coro = self.event_loop.create_datagram_endpoint(
+ MyDatagramProto, local_addr=('localhost', 1, 2, 3))
+ self.assertRaises(
+ AssertionError, self.event_loop.run_until_complete, coro)
+
+ def test_create_datagram_endpoint(self):
+ class TestMyDatagramProto(MyDatagramProto):
+ def datagram_received(self, data, addr):
+ super().datagram_received(data, addr)
+ self.transport.sendto(b'resp:'+data, addr)
+
+ coro = self.event_loop.create_datagram_endpoint(
+ TestMyDatagramProto, local_addr=('127.0.0.1', 0))
+ s_transport, server = self.event_loop.run_until_complete(coro)
+ host, port = s_transport.get_extra_info('addr')
+
+ coro = self.event_loop.create_datagram_endpoint(
+ MyDatagramProto, remote_addr=(host, port))
+ transport, client = self.event_loop.run_until_complete(coro)
+
+ self.assertEqual('INITIALIZED', client.state)
+ transport.sendto(b'xxx')
+ self.event_loop.run_once(None)
+ self.assertEqual(3, server.nbytes)
+ self.event_loop.run_once(None)
+
+ # received
+ self.assertEqual(8, client.nbytes)
+
+ # extra info is available
+ self.assertIsNotNone(transport.get_extra_info('socket'))
+ conn = transport.get_extra_info('socket')
+ self.assertTrue(hasattr(conn, 'getsockname'))
+
+ # close connection
+ transport.close()
+
+ self.assertEqual('CLOSED', client.state)
+ server.transport.close()
+
+ def test_create_datagram_endpoint_connect_err(self):
+ self.suppress_log_errors()
+ self.event_loop.sock_connect = unittest.mock.Mock()
+ self.event_loop.sock_connect.side_effect = socket.error
+
+ coro = self.event_loop.create_datagram_endpoint(
+ protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0))
+ self.assertRaises(
+ socket.error, self.event_loop.run_until_complete, coro)
+
+ @unittest.mock.patch('tulip.base_events.socket')
+ def test_create_datagram_endpoint_socket_err(self, m_socket):
+ self.suppress_log_errors()
+
+ m_socket.error = socket.error
+ m_socket.getaddrinfo = socket.getaddrinfo
+ m_socket.socket.side_effect = socket.error
+
+ coro = self.event_loop.create_datagram_endpoint(
+ protocols.DatagramProtocol, family=socket.AF_INET)
+ self.assertRaises(
+ socket.error, self.event_loop.run_until_complete, coro)
+
+ coro = self.event_loop.create_datagram_endpoint(
+ protocols.DatagramProtocol, local_addr=('127.0.0.1', 0))
+ self.assertRaises(
+ socket.error, self.event_loop.run_until_complete, coro)
+
+ def test_create_datagram_endpoint_no_matching_family(self):
+ self.suppress_log_errors()
+
+ coro = self.event_loop.create_datagram_endpoint(
+ protocols.DatagramProtocol,
+ remote_addr=('127.0.0.1', 0), local_addr=('::1', 0))
+ self.assertRaises(
+ ValueError, self.event_loop.run_until_complete, coro)
+
+ @unittest.mock.patch('tulip.base_events.socket')
+ def test_create_datagram_endpoint_setblk_err(self, m_socket):
+ self.suppress_log_errors()
+
+ m_socket.error = socket.error
+ m_socket.socket.return_value.setblocking.side_effect = socket.error
+
+ coro = self.event_loop.create_datagram_endpoint(
+ protocols.DatagramProtocol, family=socket.AF_INET)
+ self.assertRaises(
+ socket.error, self.event_loop.run_until_complete, coro)
+ self.assertTrue(
+ m_socket.socket.return_value.close.called)
+
+ def test_create_datagram_endpoint_noaddr_nofamily(self):
+ self.suppress_log_errors()
+
+ coro = self.event_loop.create_datagram_endpoint(
+ protocols.DatagramProtocol)
+ self.assertRaises(ValueError, self.event_loop.run_until_complete, coro)
+
+ @unittest.mock.patch('tulip.base_events.socket')
+ def test_create_datagram_endpoint_cant_bind(self, m_socket):
+ self.suppress_log_errors()
+
+ class Err(socket.error):
+ pass
+
+ m_socket.error = socket.error
+ m_socket.AF_INET6 = socket.AF_INET6
+ m_socket.getaddrinfo = socket.getaddrinfo
+ m_sock = m_socket.socket.return_value = unittest.mock.Mock()
+ m_sock.bind.side_effect = Err
+
+ fut = self.event_loop.create_datagram_endpoint(
+ MyDatagramProto,
+ local_addr=('127.0.0.1', 0), family=socket.AF_INET)
+ self.assertRaises(Err, self.event_loop.run_until_complete, fut)
+ self.assertTrue(m_sock.close.called)
+
+ def test_accept_connection_retry(self):
+ sock = unittest.mock.Mock()
+ sock.accept.side_effect = BlockingIOError()
+
+ self.event_loop._accept_connection(MyProto, sock)
+ self.assertFalse(sock.close.called)
+
+ def test_accept_connection_exception(self):
+ self.suppress_log_errors()
+
+ sock = unittest.mock.Mock()
+ sock.accept.side_effect = OSError()
+
+ self.event_loop._accept_connection(MyProto, sock)
+ self.assertTrue(sock.close.called)
+
+ def test_internal_fds(self):
+ event_loop = self.create_event_loop()
+ if not isinstance(event_loop, selector_events.BaseSelectorEventLoop):
+ return
+
+ self.assertEqual(1, event_loop._internal_fds)
+ event_loop.close()
+ self.assertEqual(0, event_loop._internal_fds)
+ self.assertIsNone(event_loop._csock)
+ self.assertIsNone(event_loop._ssock)
+
+    @unittest.skipIf(sys.platform == 'win32',
+                     "Pipes are not supported on Windows")
+ def test_read_pipe(self):
+ proto = None
+
+ def factory():
+ nonlocal proto
+ proto = MyReadPipeProto()
+ return proto
+
+ rpipe, wpipe = os.pipe()
+ pipeobj = io.open(rpipe, 'rb', 1024)
+
+ @tasks.task
+ def connect():
+ t, p = yield from self.event_loop.connect_read_pipe(factory,
+ pipeobj)
+ self.assertIs(p, proto)
+ self.assertIs(t, proto.transport)
+ self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
+ self.assertEqual(0, proto.nbytes)
+
+ self.event_loop.run_until_complete(connect())
+
+ os.write(wpipe, b'1')
+ self.event_loop.run_once()
+ self.assertEqual(1, proto.nbytes)
+
+ os.write(wpipe, b'2345')
+ self.event_loop.run_once()
+ self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
+ self.assertEqual(5, proto.nbytes)
+
+ os.close(wpipe)
+ self.event_loop.run_once()
+ self.assertEqual(
+ ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state)
+ # extra info is available
+ self.assertIsNotNone(proto.transport.get_extra_info('pipe'))
+
+    @unittest.skipIf(sys.platform == 'win32',
+                     "Pipes are not supported on Windows")
+ def test_write_pipe(self):
+ proto = None
+ transport = None
+
+ def factory():
+ nonlocal proto
+ proto = MyWritePipeProto()
+ return proto
+
+ rpipe, wpipe = os.pipe()
+ pipeobj = io.open(wpipe, 'wb', 1024)
+
+ @tasks.task
+ def connect():
+ nonlocal transport
+ t, p = yield from self.event_loop.connect_write_pipe(factory,
+ pipeobj)
+ self.assertIs(p, proto)
+ self.assertIs(t, proto.transport)
+ self.assertEqual('CONNECTED', proto.state)
+ transport = t
+
+ self.event_loop.run_until_complete(connect())
+
+ transport.write(b'1')
+ self.event_loop.run_once()
+ data = os.read(rpipe, 1024)
+ self.assertEqual(b'1', data)
+
+ transport.write(b'2345')
+ self.event_loop.run_once()
+ data = os.read(rpipe, 1024)
+ self.assertEqual(b'2345', data)
+ self.assertEqual('CONNECTED', proto.state)
+
+ os.close(rpipe)
+
+ # extra info is available
+ self.assertIsNotNone(proto.transport.get_extra_info('pipe'))
+
+ # close connection
+ proto.transport.close()
+ self.event_loop.run_once()
+ self.assertEqual('CLOSED', proto.state)
+
+
+if sys.platform == 'win32':
+ from tulip import windows_events
+
+ class SelectEventLoopTests(EventLoopTestsMixin,
+ test_utils.LogTrackingTestCase):
+ def create_event_loop(self):
+ return windows_events.SelectorEventLoop()
+
+ class ProactorEventLoopTests(EventLoopTestsMixin,
+ test_utils.LogTrackingTestCase):
+ def create_event_loop(self):
+ return windows_events.ProactorEventLoop()
+
+        def test_create_ssl_connection(self):
+            raise unittest.SkipTest("IocpEventLoop incompatible with SSL")
+
+        def test_reader_callback(self):
+            raise unittest.SkipTest("IocpEventLoop does not have add_reader()")
+
+        def test_reader_callback_cancel(self):
+            raise unittest.SkipTest("IocpEventLoop does not have add_reader()")
+
+        def test_reader_callback_with_handle(self):
+            raise unittest.SkipTest("IocpEventLoop does not have add_reader()")
+
+        def test_writer_callback(self):
+            raise unittest.SkipTest("IocpEventLoop does not have add_writer()")
+
+        def test_writer_callback_cancel(self):
+            raise unittest.SkipTest("IocpEventLoop does not have add_writer()")
+
+        def test_writer_callback_with_handle(self):
+            raise unittest.SkipTest("IocpEventLoop does not have add_writer()")
+
+        def test_accept_connection_retry(self):
+            raise unittest.SkipTest(
+                "IocpEventLoop does not have _accept_connection()")
+
+        def test_accept_connection_exception(self):
+            raise unittest.SkipTest(
+                "IocpEventLoop does not have _accept_connection()")
+
+        def test_create_datagram_endpoint(self):
+            raise unittest.SkipTest(
+                "IocpEventLoop does not have create_datagram_endpoint()")
+
+        def test_create_datagram_endpoint_no_connection(self):
+            raise unittest.SkipTest(
+                "IocpEventLoop does not have create_datagram_endpoint()")
+
+        def test_create_datagram_endpoint_cant_bind(self):
+            raise unittest.SkipTest(
+                "IocpEventLoop does not have create_datagram_endpoint()")
+
+        def test_create_datagram_endpoint_noaddr_nofamily(self):
+            raise unittest.SkipTest(
+                "IocpEventLoop does not have create_datagram_endpoint()")
+
+        def test_create_datagram_endpoint_socket_err(self):
+            raise unittest.SkipTest(
+                "IocpEventLoop does not have create_datagram_endpoint()")
+
+        def test_create_datagram_endpoint_connect_err(self):
+            raise unittest.SkipTest(
+                "IocpEventLoop does not have create_datagram_endpoint()")
+else:
+ from tulip import selectors
+ from tulip import unix_events
+
+ if hasattr(selectors, 'KqueueSelector'):
+ class KqueueEventLoopTests(EventLoopTestsMixin,
+ test_utils.LogTrackingTestCase):
+ def create_event_loop(self):
+ return unix_events.SelectorEventLoop(
+ selectors.KqueueSelector())
+
+ if hasattr(selectors, 'EpollSelector'):
+ class EPollEventLoopTests(EventLoopTestsMixin,
+ test_utils.LogTrackingTestCase):
+ def create_event_loop(self):
+ return unix_events.SelectorEventLoop(selectors.EpollSelector())
+
+ if hasattr(selectors, 'PollSelector'):
+ class PollEventLoopTests(EventLoopTestsMixin,
+ test_utils.LogTrackingTestCase):
+ def create_event_loop(self):
+ return unix_events.SelectorEventLoop(selectors.PollSelector())
+
+ # Should always exist.
+ class SelectEventLoopTests(EventLoopTestsMixin,
+ test_utils.LogTrackingTestCase):
+ def create_event_loop(self):
+ return unix_events.SelectorEventLoop(selectors.SelectSelector())
+
+
+class HandleTests(unittest.TestCase):
+
+ def test_handle(self):
+ def callback(*args):
+ return args
+
+ args = ()
+ h = events.Handle(callback, args)
+ self.assertIs(h.callback, callback)
+ self.assertIs(h.args, args)
+ self.assertFalse(h.cancelled)
+
+ r = repr(h)
+ self.assertTrue(r.startswith(
+ 'Handle('
+ '<function HandleTests.test_handle.<locals>.callback'))
+ self.assertTrue(r.endswith('())'))
+
+ h.cancel()
+ self.assertTrue(h.cancelled)
+
+ r = repr(h)
+ self.assertTrue(r.startswith(
+ 'Handle('
+ '<function HandleTests.test_handle.<locals>.callback'))
+ self.assertTrue(r.endswith('())<cancelled>'))
+
+ def test_make_handle(self):
+ def callback(*args):
+ return args
+ h1 = events.Handle(callback, ())
+ h2 = events.make_handle(h1, ())
+ self.assertIs(h1, h2)
+
+ self.assertRaises(
+ AssertionError, events.make_handle, h1, (1, 2))
+
+
+class TimerTests(unittest.TestCase):
+
+ def test_timer(self):
+ def callback(*args):
+ return args
+
+ args = ()
+ when = time.monotonic()
+ h = events.Timer(when, callback, args)
+ self.assertIs(h.callback, callback)
+ self.assertIs(h.args, args)
+ self.assertFalse(h.cancelled)
+
+ r = repr(h)
+ self.assertTrue(r.endswith('())'))
+
+ h.cancel()
+ self.assertTrue(h.cancelled)
+
+ r = repr(h)
+ self.assertTrue(r.endswith('())<cancelled>'))
+
+ self.assertRaises(AssertionError, events.Timer, None, callback, args)
+
+ def test_timer_comparison(self):
+ def callback(*args):
+ return args
+
+ when = time.monotonic()
+
+ h1 = events.Timer(when, callback, ())
+ h2 = events.Timer(when, callback, ())
+ self.assertFalse(h1 < h2)
+ self.assertFalse(h2 < h1)
+ self.assertTrue(h1 <= h2)
+ self.assertTrue(h2 <= h1)
+ self.assertFalse(h1 > h2)
+ self.assertFalse(h2 > h1)
+ self.assertTrue(h1 >= h2)
+ self.assertTrue(h2 >= h1)
+ self.assertTrue(h1 == h2)
+ self.assertFalse(h1 != h2)
+
+ h2.cancel()
+ self.assertFalse(h1 == h2)
+
+ h1 = events.Timer(when, callback, ())
+ h2 = events.Timer(when + 10.0, callback, ())
+ self.assertTrue(h1 < h2)
+ self.assertFalse(h2 < h1)
+ self.assertTrue(h1 <= h2)
+ self.assertFalse(h2 <= h1)
+ self.assertFalse(h1 > h2)
+ self.assertTrue(h2 > h1)
+ self.assertFalse(h1 >= h2)
+ self.assertTrue(h2 >= h1)
+ self.assertFalse(h1 == h2)
+ self.assertTrue(h1 != h2)
+
+ h3 = events.Handle(callback, ())
+ self.assertIs(NotImplemented, h1.__eq__(h3))
+ self.assertIs(NotImplemented, h1.__ne__(h3))
+
+
+class AbstractEventLoopTests(unittest.TestCase):
+
+    def test_not_implemented(self):
+ f = unittest.mock.Mock()
+ ev_loop = events.AbstractEventLoop()
+ self.assertRaises(
+ NotImplementedError, ev_loop.run)
+ self.assertRaises(
+ NotImplementedError, ev_loop.run_forever)
+ self.assertRaises(
+ NotImplementedError, ev_loop.run_once)
+ self.assertRaises(
+ NotImplementedError, ev_loop.run_until_complete, None)
+ self.assertRaises(
+ NotImplementedError, ev_loop.stop)
+ self.assertRaises(
+ NotImplementedError, ev_loop.call_later, None, None)
+ self.assertRaises(
+ NotImplementedError, ev_loop.call_repeatedly, None, None)
+ self.assertRaises(
+ NotImplementedError, ev_loop.call_soon, None)
+ self.assertRaises(
+ NotImplementedError, ev_loop.call_soon_threadsafe, None)
+ self.assertRaises(
+ NotImplementedError, ev_loop.wrap_future, f)
+ self.assertRaises(
+ NotImplementedError, ev_loop.run_in_executor, f, f)
+ self.assertRaises(
+ NotImplementedError, ev_loop.getaddrinfo, 'localhost', 8080)
+ self.assertRaises(
+ NotImplementedError, ev_loop.getnameinfo, ('localhost', 8080))
+ self.assertRaises(
+ NotImplementedError, ev_loop.create_connection, f)
+ self.assertRaises(
+ NotImplementedError, ev_loop.start_serving, f)
+ self.assertRaises(
+ NotImplementedError, ev_loop.create_datagram_endpoint, f)
+ self.assertRaises(
+ NotImplementedError, ev_loop.add_reader, 1, f)
+ self.assertRaises(
+ NotImplementedError, ev_loop.remove_reader, 1)
+ self.assertRaises(
+ NotImplementedError, ev_loop.add_writer, 1, f)
+ self.assertRaises(
+ NotImplementedError, ev_loop.remove_writer, 1)
+ self.assertRaises(
+ NotImplementedError, ev_loop.sock_recv, f, 10)
+ self.assertRaises(
+ NotImplementedError, ev_loop.sock_sendall, f, 10)
+ self.assertRaises(
+ NotImplementedError, ev_loop.sock_connect, f, f)
+ self.assertRaises(
+ NotImplementedError, ev_loop.sock_accept, f)
+ self.assertRaises(
+ NotImplementedError, ev_loop.add_signal_handler, 1, f)
+ self.assertRaises(
+ NotImplementedError, ev_loop.remove_signal_handler, 1)
+ self.assertRaises(
+ NotImplementedError, ev_loop.remove_signal_handler, 1)
+ self.assertRaises(
+ NotImplementedError, ev_loop.connect_read_pipe, f,
+ unittest.mock.sentinel.pipe)
+ self.assertRaises(
+ NotImplementedError, ev_loop.connect_write_pipe, f,
+ unittest.mock.sentinel.pipe)
+
+
+class ProtocolsAbsTests(unittest.TestCase):
+
+ def test_empty(self):
+ f = unittest.mock.Mock()
+ p = protocols.Protocol()
+ self.assertIsNone(p.connection_made(f))
+ self.assertIsNone(p.connection_lost(f))
+ self.assertIsNone(p.data_received(f))
+ self.assertIsNone(p.eof_received())
+
+ dp = protocols.DatagramProtocol()
+ self.assertIsNone(dp.connection_made(f))
+ self.assertIsNone(dp.connection_lost(f))
+ self.assertIsNone(dp.connection_refused(f))
+ self.assertIsNone(dp.datagram_received(f, f))
+
+
+class PolicyTests(unittest.TestCase):
+
+ def test_event_loop_policy(self):
+ policy = events.EventLoopPolicy()
+ self.assertRaises(NotImplementedError, policy.get_event_loop)
+ self.assertRaises(NotImplementedError, policy.set_event_loop, object())
+ self.assertRaises(NotImplementedError, policy.new_event_loop)
+
+ def test_get_event_loop(self):
+ policy = events.DefaultEventLoopPolicy()
+ self.assertIsNone(policy._event_loop)
+
+ event_loop = policy.get_event_loop()
+ self.assertIsInstance(event_loop, events.AbstractEventLoop)
+
+ self.assertIs(policy._event_loop, event_loop)
+ self.assertIs(event_loop, policy.get_event_loop())
+
+ @unittest.mock.patch('tulip.events.threading')
+ def test_get_event_loop_thread(self, m_threading):
+ m_t = m_threading.current_thread.return_value = unittest.mock.Mock()
+ m_t.name = 'Thread 1'
+
+ policy = events.DefaultEventLoopPolicy()
+ self.assertIsNone(policy.get_event_loop())
+
+ def test_new_event_loop(self):
+ policy = events.DefaultEventLoopPolicy()
+
+ event_loop = policy.new_event_loop()
+ self.assertIsInstance(event_loop, events.AbstractEventLoop)
+
+ def test_set_event_loop(self):
+ policy = events.DefaultEventLoopPolicy()
+ old_event_loop = policy.get_event_loop()
+
+ self.assertRaises(AssertionError, policy.set_event_loop, object())
+
+ event_loop = policy.new_event_loop()
+ policy.set_event_loop(event_loop)
+ self.assertIs(event_loop, policy.get_event_loop())
+ self.assertIsNot(old_event_loop, policy.get_event_loop())
+
+ def test_get_event_loop_policy(self):
+ policy = events.get_event_loop_policy()
+ self.assertIsInstance(policy, events.EventLoopPolicy)
+ self.assertIs(policy, events.get_event_loop_policy())
+
+ def test_set_event_loop_policy(self):
+ self.assertRaises(
+ AssertionError, events.set_event_loop_policy, object())
+
+ old_policy = events.get_event_loop_policy()
+
+ policy = events.DefaultEventLoopPolicy()
+ events.set_event_loop_policy(policy)
+ self.assertIs(policy, events.get_event_loop_policy())
+ self.assertIsNot(policy, old_policy)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/futures_test.py b/tests/futures_test.py
new file mode 100644
index 0000000..5569cca
--- /dev/null
+++ b/tests/futures_test.py
@@ -0,0 +1,222 @@
+"""Tests for futures.py."""
+
+import unittest
+
+from tulip import futures
+
+
+def _fakefunc(f):
+ return f
+
+
+class FutureTests(unittest.TestCase):
+
+ def test_initial_state(self):
+ f = futures.Future()
+ self.assertFalse(f.cancelled())
+ self.assertFalse(f.running())
+ self.assertFalse(f.done())
+
+ def test_init_event_loop_positional(self):
+ # Make sure Future doesn't accept a positional argument
+ self.assertRaises(TypeError, futures.Future, 42)
+
+ def test_cancel(self):
+ f = futures.Future()
+ self.assertTrue(f.cancel())
+ self.assertTrue(f.cancelled())
+ self.assertFalse(f.running())
+ self.assertTrue(f.done())
+ self.assertRaises(futures.CancelledError, f.result)
+ self.assertRaises(futures.CancelledError, f.exception)
+ self.assertRaises(futures.InvalidStateError, f.set_result, None)
+ self.assertRaises(futures.InvalidStateError, f.set_exception, None)
+ self.assertFalse(f.cancel())
+
+ def test_result(self):
+ f = futures.Future()
+ self.assertRaises(futures.InvalidStateError, f.result)
+ self.assertRaises(futures.InvalidTimeoutError, f.result, 10)
+
+ f.set_result(42)
+ self.assertFalse(f.cancelled())
+ self.assertFalse(f.running())
+ self.assertTrue(f.done())
+ self.assertEqual(f.result(), 42)
+ self.assertEqual(f.exception(), None)
+ self.assertRaises(futures.InvalidStateError, f.set_result, None)
+ self.assertRaises(futures.InvalidStateError, f.set_exception, None)
+ self.assertFalse(f.cancel())
+
+ def test_exception(self):
+ exc = RuntimeError()
+ f = futures.Future()
+ self.assertRaises(futures.InvalidStateError, f.exception)
+ self.assertRaises(futures.InvalidTimeoutError, f.exception, 10)
+
+ f.set_exception(exc)
+ self.assertFalse(f.cancelled())
+ self.assertFalse(f.running())
+ self.assertTrue(f.done())
+ self.assertRaises(RuntimeError, f.result)
+ self.assertEqual(f.exception(), exc)
+ self.assertRaises(futures.InvalidStateError, f.set_result, None)
+ self.assertRaises(futures.InvalidStateError, f.set_exception, None)
+ self.assertFalse(f.cancel())
+
+ def test_yield_from_twice(self):
+ f = futures.Future()
+
+ def fixture():
+ yield 'A'
+ x = yield from f
+ yield 'B', x
+ y = yield from f
+ yield 'C', y
+
+ g = fixture()
+ self.assertEqual(next(g), 'A') # yield 'A'.
+ self.assertEqual(next(g), f) # First yield from f.
+ f.set_result(42)
+ self.assertEqual(next(g), ('B', 42)) # yield 'B', x.
+ # The second "yield from f" does not yield f.
+ self.assertEqual(next(g), ('C', 42)) # yield 'C', y.
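+ # (Once the Future has a result, iterating it returns the result
+ # immediately instead of yielding the Future again.)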
+
+ def test_repr(self):
+ f_pending = futures.Future()
+ self.assertEqual(repr(f_pending), 'Future<PENDING>')
+
+ f_cancelled = futures.Future()
+ f_cancelled.cancel()
+ self.assertEqual(repr(f_cancelled), 'Future<CANCELLED>')
+
+ f_result = futures.Future()
+ f_result.set_result(4)
+ self.assertEqual(repr(f_result), 'Future<result=4>')
+
+ f_exception = futures.Future()
+ f_exception.set_exception(RuntimeError())
+ self.assertEqual(repr(f_exception), 'Future<exception=RuntimeError()>')
+
+ f_few_callbacks = futures.Future()
+ f_few_callbacks.add_done_callback(_fakefunc)
+ self.assertIn('Future<PENDING, [<function _fakefunc',
+ repr(f_few_callbacks))
+
+ f_many_callbacks = futures.Future()
+ for i in range(20):
+ f_many_callbacks.add_done_callback(_fakefunc)
+ r = repr(f_many_callbacks)
+ self.assertIn('Future<PENDING, [<function _fakefunc', r)
+ self.assertIn('<18 more>', r)
+
+ def test_copy_state(self):
+ # Test the internal _copy_state method since it's being directly
+ # invoked in other modules.
+ f = futures.Future()
+ f.set_result(10)
+
+ newf = futures.Future()
+ newf._copy_state(f)
+ self.assertTrue(newf.done())
+ self.assertEqual(newf.result(), 10)
+
+ f_exception = futures.Future()
+ f_exception.set_exception(RuntimeError())
+
+ newf_exception = futures.Future()
+ newf_exception._copy_state(f_exception)
+ self.assertTrue(newf_exception.done())
+ self.assertRaises(RuntimeError, newf_exception.result)
+
+ f_cancelled = futures.Future()
+ f_cancelled.cancel()
+
+ newf_cancelled = futures.Future()
+ newf_cancelled._copy_state(f_cancelled)
+ self.assertTrue(newf_cancelled.cancelled())
+
+ def test_iter(self):
+ def coro():
+ fut = futures.Future()
+ yield from fut
+
+ def test():
+ arg1, arg2 = coro()
+
+ self.assertRaises(AssertionError, test)
+
+
+# A fake event loop for tests. All it does is implement a call_soon method
+# that immediately invokes the given function.
+class _FakeEventLoop:
+ def call_soon(self, fn, future):
+ fn(future)
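+
+# Illustrative sketch (not part of the test suite): with this fake
+# loop, callbacks run inline rather than being scheduled --
+#
+# loop = _FakeEventLoop()
+# loop.call_soon(print, 'future') # prints immediately
+#
+# which keeps the done-callback tests below synchronous.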
+
+
+class FutureDoneCallbackTests(unittest.TestCase):
+
+ def _make_callback(self, bag, thing):
+ # Create a callback function that appends thing to bag.
+ def bag_appender(future):
+ bag.append(thing)
+ return bag_appender
+
+ def _new_future(self):
+ return futures.Future(event_loop=_FakeEventLoop())
+
+ def test_callbacks_invoked_on_set_result(self):
+ bag = []
+ f = self._new_future()
+ f.add_done_callback(self._make_callback(bag, 42))
+ f.add_done_callback(self._make_callback(bag, 17))
+
+ self.assertEqual(bag, [])
+ f.set_result('foo')
+ self.assertEqual(bag, [42, 17])
+ self.assertEqual(f.result(), 'foo')
+
+ def test_callbacks_invoked_on_set_exception(self):
+ bag = []
+ f = self._new_future()
+ f.add_done_callback(self._make_callback(bag, 100))
+
+ self.assertEqual(bag, [])
+ exc = RuntimeError()
+ f.set_exception(exc)
+ self.assertEqual(bag, [100])
+ self.assertEqual(f.exception(), exc)
+
+ def test_remove_done_callback(self):
+ bag = []
+ f = self._new_future()
+ cb1 = self._make_callback(bag, 1)
+ cb2 = self._make_callback(bag, 2)
+ cb3 = self._make_callback(bag, 3)
+
+ # Add one cb1 and one cb2.
+ f.add_done_callback(cb1)
+ f.add_done_callback(cb2)
+
+ # One instance of cb2 removed. Now there's only one cb1.
+ self.assertEqual(f.remove_done_callback(cb2), 1)
+
+ # Never had any cb3 in there.
+ self.assertEqual(f.remove_done_callback(cb3), 0)
+
+ # After this there will be 6 instances of cb1 and one of cb2.
+ f.add_done_callback(cb2)
+ for i in range(5):
+ f.add_done_callback(cb1)
+
+ # Remove all instances of cb1. One cb2 remains.
+ self.assertEqual(f.remove_done_callback(cb1), 6)
+
+ self.assertEqual(bag, [])
+ f.set_result('foo')
+ self.assertEqual(bag, [2])
+ self.assertEqual(f.result(), 'foo')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py
new file mode 100644
index 0000000..74aef7c
--- /dev/null
+++ b/tests/http_protocol_test.py
@@ -0,0 +1,972 @@
+"""Tests for http/protocol.py"""
+
+import http.client
+import unittest
+import unittest.mock
+import zlib
+
+import tulip
+from tulip.http import protocol
+from tulip.test_utils import LogTrackingTestCase
+
+
+class HttpStreamReaderTests(LogTrackingTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.suppress_log_errors()
+
+ self.loop = tulip.new_event_loop()
+ tulip.set_event_loop(self.loop)
+
+ self.transport = unittest.mock.Mock()
+ self.stream = protocol.HttpStreamReader()
+
+ def tearDown(self):
+ self.loop.close()
+ super().tearDown()
+
+ def test_request_line(self):
+ self.stream.feed_data(b'get /path HTTP/1.1\r\n')
+ self.assertEqual(
+ ('GET', '/path', (1, 1)),
+ self.loop.run_until_complete(self.stream.read_request_line()))
+
+ def test_request_line_two_slashes(self):
+ self.stream.feed_data(b'get //path HTTP/1.1\r\n')
+ self.assertEqual(
+ ('GET', '//path', (1, 1)),
+ self.loop.run_until_complete(self.stream.read_request_line()))
+
+ def test_request_line_non_ascii(self):
+ self.stream.feed_data(b'get /path\xd0\xb0 HTTP/1.1\r\n')
+
+ with self.assertRaises(http.client.BadStatusLine) as cm:
+ self.loop.run_until_complete(self.stream.read_request_line())
+
+ self.assertEqual(
+ b'get /path\xd0\xb0 HTTP/1.1\r\n', cm.exception.args[0])
+
+ def test_request_line_bad_status_line(self):
+ self.stream.feed_data(b'\r\n')
+ self.assertRaises(
+ http.client.BadStatusLine,
+ self.loop.run_until_complete,
+ self.stream.read_request_line())
+
+ def test_request_line_bad_method(self):
+ self.stream.feed_data(b'!12%()+=~$ /get HTTP/1.1\r\n')
+ self.assertRaises(
+ http.client.BadStatusLine,
+ self.loop.run_until_complete,
+ self.stream.read_request_line())
+
+ def test_request_line_bad_version(self):
+ self.stream.feed_data(b'GET //get HT/11\r\n')
+ self.assertRaises(
+ http.client.BadStatusLine,
+ self.loop.run_until_complete,
+ self.stream.read_request_line())
+
+ def test_response_status_bad_status_line(self):
+ self.stream.feed_data(b'\r\n')
+ self.assertRaises(
+ http.client.BadStatusLine,
+ self.loop.run_until_complete,
+ self.stream.read_response_status())
+
+ def test_response_status_bad_status_line_eof(self):
+ self.stream.feed_eof()
+ self.assertRaises(
+ http.client.BadStatusLine,
+ self.loop.run_until_complete,
+ self.stream.read_response_status())
+
+ def test_response_status_bad_status_non_ascii(self):
+ self.stream.feed_data(b'HTTP/1.1 200 \xd0\xb0\r\n')
+ with self.assertRaises(http.client.BadStatusLine) as cm:
+ self.loop.run_until_complete(self.stream.read_response_status())
+
+ self.assertEqual(b'HTTP/1.1 200 \xd0\xb0\r\n', cm.exception.args[0])
+
+ def test_response_status_bad_version(self):
+ self.stream.feed_data(b'HT/11 200 Ok\r\n')
+ with self.assertRaises(http.client.BadStatusLine) as cm:
+ self.loop.run_until_complete(self.stream.read_response_status())
+
+ self.assertEqual('HT/11 200 Ok', cm.exception.args[0])
+
+ def test_response_status_no_reason(self):
+ self.stream.feed_data(b'HTTP/1.1 200\r\n')
+
+ v, s, r = self.loop.run_until_complete(
+ self.stream.read_response_status())
+ self.assertEqual(v, (1, 1))
+ self.assertEqual(s, 200)
+ self.assertEqual(r, '')
+
+ def test_response_status_bad(self):
+ self.stream.feed_data(b'HTT/1\r\n')
+ with self.assertRaises(http.client.BadStatusLine) as cm:
+ self.loop.run_until_complete(
+ self.stream.read_response_status())
+
+ self.assertIn('HTT/1', str(cm.exception))
+
+ def test_response_status_bad_code_under_100(self):
+ self.stream.feed_data(b'HTTP/1.1 99 test\r\n')
+ with self.assertRaises(http.client.BadStatusLine) as cm:
+ self.loop.run_until_complete(
+ self.stream.read_response_status())
+
+ self.assertIn('HTTP/1.1 99 test', str(cm.exception))
+
+ def test_response_status_bad_code_above_999(self):
+ self.stream.feed_data(b'HTTP/1.1 9999 test\r\n')
+ with self.assertRaises(http.client.BadStatusLine) as cm:
+ self.loop.run_until_complete(
+ self.stream.read_response_status())
+
+ self.assertIn('HTTP/1.1 9999 test', str(cm.exception))
+
+ def test_response_status_bad_code_not_int(self):
+ self.stream.feed_data(b'HTTP/1.1 ttt test\r\n')
+ with self.assertRaises(http.client.BadStatusLine) as cm:
+ self.loop.run_until_complete(
+ self.stream.read_response_status())
+
+ self.assertIn('HTTP/1.1 ttt test', str(cm.exception))
+
+ def test_read_headers(self):
+ self.stream.feed_data(b'test: line\r\n'
+ b' continue\r\n'
+ b'test2: data\r\n'
+ b'\r\n')
+
+ headers = self.loop.run_until_complete(self.stream.read_headers())
+ self.assertEqual(headers,
+ [('TEST', 'line\r\n continue'), ('TEST2', 'data')])
+
+ def test_read_headers_size(self):
+ self.stream.feed_data(b'test: line\r\n')
+ self.stream.feed_data(b' continue\r\n')
+ self.stream.feed_data(b'test2: data\r\n')
+ self.stream.feed_data(b'\r\n')
+
+ self.stream.MAX_HEADERS = 5
+ self.assertRaises(
+ http.client.LineTooLong,
+ self.loop.run_until_complete,
+ self.stream.read_headers())
+
+ def test_read_headers_invalid_header(self):
+ self.stream.feed_data(b'test line\r\n')
+
+ with self.assertRaises(ValueError) as cm:
+ self.loop.run_until_complete(self.stream.read_headers())
+
+ self.assertIn("Invalid header b'test line'", str(cm.exception))
+
+ def test_read_headers_invalid_name(self):
+ self.stream.feed_data(b'test[]: line\r\n')
+
+ with self.assertRaises(ValueError) as cm:
+ self.loop.run_until_complete(self.stream.read_headers())
+
+ self.assertIn("Invalid header name b'TEST[]'", str(cm.exception))
+
+ def test_read_headers_headers_size(self):
+ self.stream.MAX_HEADERFIELD_SIZE = 5
+ self.stream.feed_data(b'test: line data data\r\ndata\r\n')
+
+ with self.assertRaises(http.client.LineTooLong) as cm:
+ self.loop.run_until_complete(self.stream.read_headers())
+
+ self.assertIn("limit request headers fields size", str(cm.exception))
+
+ def test_read_headers_continuation_headers_size(self):
+ self.stream.MAX_HEADERFIELD_SIZE = 5
+ self.stream.feed_data(b'test: line\r\n test\r\n')
+
+ with self.assertRaises(http.client.LineTooLong) as cm:
+ self.loop.run_until_complete(self.stream.read_headers())
+
+ self.assertIn("limit request headers fields size", str(cm.exception))
+
+ def test_read_message_should_close(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\nConnection: close\r\n\r\n')
+
+ msg = self.loop.run_until_complete(self.stream.read_message())
+ self.assertTrue(msg.should_close)
+
+ def test_read_message_should_close_http11(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n\r\n')
+
+ msg = self.loop.run_until_complete(
+ self.stream.read_message(version=(1, 1)))
+ self.assertFalse(msg.should_close)
+
+ def test_read_message_should_close_http10(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n\r\n')
+
+ msg = self.loop.run_until_complete(
+ self.stream.read_message(version=(1, 0)))
+ self.assertTrue(msg.should_close)
+
+ def test_read_message_should_close_keep_alive(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\nConnection: keep-alive\r\n\r\n')
+
+ msg = self.loop.run_until_complete(self.stream.read_message())
+ self.assertFalse(msg.should_close)
+
+ def test_read_message_content_length_broken(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\nContent-Length: qwe\r\n\r\n')
+
+ self.assertRaises(
+ http.client.HTTPException,
+ self.loop.run_until_complete,
+ self.stream.read_message())
+
+ def test_read_message_content_length_wrong(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\nContent-Length: -1\r\n\r\n')
+
+ self.assertRaises(
+ http.client.HTTPException,
+ self.loop.run_until_complete,
+ self.stream.read_message())
+
+ def test_read_message_content_length(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\nContent-Length: 2\r\n\r\n12')
+
+ msg = self.loop.run_until_complete(self.stream.read_message())
+
+ payload = self.loop.run_until_complete(msg.payload.read())
+ self.assertEqual(b'12', payload)
+
+ def test_read_message_content_length_no_val(self):
+ self.stream.feed_data(b'Host: example.com\r\n\r\n12')
+
+ msg = self.loop.run_until_complete(
+ self.stream.read_message(readall=False))
+
+ payload = self.loop.run_until_complete(msg.payload.read())
+ self.assertEqual(b'', payload)
+
+ _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS)
+ _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()])
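+ # Note: wbits=-zlib.MAX_WBITS produces a raw deflate stream with no
+ # zlib header or trailing checksum; that raw form is what the
+ # Content-Encoding tests below feed to the reader.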
+
+ def test_read_message_deflate(self):
+ self.stream.feed_data(
+ ('Host: example.com\r\nContent-Length: %s\r\n'
+ 'Content-Encoding: deflate\r\n\r\n' %
+ len(self._COMPRESSED)).encode())
+ self.stream.feed_data(self._COMPRESSED)
+
+ msg = self.loop.run_until_complete(self.stream.read_message())
+ payload = self.loop.run_until_complete(msg.payload.read())
+ self.assertEqual(b'data', payload)
+
+ def test_read_message_deflate_disabled(self):
+ self.stream.feed_data(
+ ('Host: example.com\r\nContent-Encoding: deflate\r\n'
+ 'Content-Length: %s\r\n\r\n' %
+ len(self._COMPRESSED)).encode())
+ self.stream.feed_data(self._COMPRESSED)
+
+ msg = self.loop.run_until_complete(
+ self.stream.read_message(compression=False))
+ payload = self.loop.run_until_complete(msg.payload.read())
+ self.assertEqual(self._COMPRESSED, payload)
+
+ def test_read_message_deflate_unknown(self):
+ self.stream.feed_data(
+ ('Host: example.com\r\nContent-Encoding: compress\r\n'
+ 'Content-Length: %s\r\n\r\n' % len(self._COMPRESSED)).encode())
+ self.stream.feed_data(self._COMPRESSED)
+
+ msg = self.loop.run_until_complete(
+ self.stream.read_message(compression=False))
+ payload = self.loop.run_until_complete(msg.payload.read())
+ self.assertEqual(self._COMPRESSED, payload)
+
+ def test_read_message_websocket(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\nSec-Websocket-Key1: 13\r\n\r\n1234567890')
+
+ msg = self.loop.run_until_complete(self.stream.read_message())
+
+ payload = self.loop.run_until_complete(msg.payload.read())
+ self.assertEqual(b'12345678', payload)
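+ # (Sec-WebSocket-Key1 marks the old hixie-76 handshake, whose body
+ # is exactly 8 bytes, so only b'12345678' of the ten bytes fed
+ # above belongs to the message payload.)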
+
+ def test_read_message_chunked(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\nTransfer-Encoding: chunked\r\n\r\n')
+ self.stream.feed_data(
+ b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n')
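+ # The chunked body decodes as: a 4-byte chunk 'data' (with the
+ # extension ';test'), a 4-byte chunk 'line', then the terminating
+ # zero-size chunk with a trailer.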
+
+ msg = self.loop.run_until_complete(self.stream.read_message())
+
+ payload = self.loop.run_until_complete(msg.payload.read())
+ self.assertEqual(b'dataline', payload)
+
+ def test_read_message_readall_eof(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n\r\n')
+ self.stream.feed_data(b'data')
+ self.stream.feed_data(b'line')
+ self.stream.feed_eof()
+
+ msg = self.loop.run_until_complete(
+ self.stream.read_message(readall=True))
+
+ payload = self.loop.run_until_complete(msg.payload.read())
+ self.assertEqual(b'dataline', payload)
+
+ def test_read_message_payload(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n'
+ b'Content-Length: 8\r\n\r\n')
+ self.stream.feed_data(b'data')
+ self.stream.feed_data(b'data')
+
+ msg = self.loop.run_until_complete(
+ self.stream.read_message(readall=True))
+
+ data = self.loop.run_until_complete(msg.payload.read())
+ self.assertEqual(b'datadata', data)
+
+ def test_read_message_payload_eof(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n'
+ b'Content-Length: 4\r\n\r\n')
+ self.stream.feed_data(b'da')
+ self.stream.feed_eof()
+
+ msg = self.loop.run_until_complete(
+ self.stream.read_message(readall=True))
+
+ self.assertRaises(
+ http.client.IncompleteRead,
+ self.loop.run_until_complete, msg.payload.read())
+
+ def test_read_message_length_payload_zero(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n'
+ b'Content-Length: 0\r\n\r\n')
+ self.stream.feed_data(b'data')
+
+ msg = self.loop.run_until_complete(self.stream.read_message())
+ data = self.loop.run_until_complete(msg.payload.read())
+ self.assertEqual(b'', data)
+
+ def test_read_message_length_payload_incomplete(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n'
+ b'Content-Length: 8\r\n\r\n')
+
+ msg = self.loop.run_until_complete(self.stream.read_message())
+
+ def coro():
+ self.stream.feed_data(b'data')
+ self.stream.feed_eof()
+ return (yield from msg.payload.read())
+
+ self.assertRaises(
+ http.client.IncompleteRead,
+ self.loop.run_until_complete, coro())
+
+ def test_read_message_eof_payload(self):
+ self.stream.feed_data(b'Host: example.com\r\n\r\n')
+
+ msg = self.loop.run_until_complete(
+ self.stream.read_message(readall=True))
+
+ def coro():
+ self.stream.feed_data(b'data')
+ self.stream.feed_eof()
+ return (yield from msg.payload.read())
+
+ data = self.loop.run_until_complete(coro())
+ self.assertEqual(b'data', data)
+
+ def test_read_message_length_payload(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n'
+ b'Content-Length: 4\r\n\r\n')
+ self.stream.feed_data(b'da')
+ self.stream.feed_data(b't')
+ self.stream.feed_data(b'ali')
+ self.stream.feed_data(b'ne')
+
+ msg = self.loop.run_until_complete(
+ self.stream.read_message(readall=True))
+
+ self.assertIsInstance(msg.payload, tulip.StreamReader)
+
+ data = self.loop.run_until_complete(msg.payload.read())
+ self.assertEqual(b'data', data)
+ self.assertEqual(b'line', b''.join(self.stream.buffer))
+
+ def test_read_message_length_payload_extra(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n'
+ b'Content-Length: 4\r\n\r\n')
+
+ msg = self.loop.run_until_complete(self.stream.read_message())
+
+ def coro():
+ self.stream.feed_data(b'da')
+ self.stream.feed_data(b't')
+ self.stream.feed_data(b'ali')
+ self.stream.feed_data(b'ne')
+ return (yield from msg.payload.read())
+
+ data = self.loop.run_until_complete(coro())
+ self.assertEqual(b'data', data)
+ self.assertEqual(b'line', b''.join(self.stream.buffer))
+
+ def test_parse_length_payload_eof_exc(self):
+ parser = self.stream._parse_length_payload(4)
+ next(parser)
+
+ stream = tulip.StreamReader()
+ parser.send(stream)
+ self.stream._parser = parser
+ self.stream.feed_data(b'da')
+
+ def eof():
+ yield from []
+ self.stream.feed_eof()
+
+ t1 = tulip.Task(stream.read())
+ t2 = tulip.Task(eof())
+
+ self.loop.run_until_complete(tulip.wait([t1, t2]))
+ self.assertRaises(http.client.IncompleteRead, t1.result)
+ self.assertIsNone(self.stream._parser)
+
+ def test_read_message_deflate_payload(self):
+ comp = zlib.compressobj(wbits=-zlib.MAX_WBITS)
+
+ data = b''.join([comp.compress(b'data'), comp.flush()])
+
+ self.stream.feed_data(
+ b'Host: example.com\r\n'
+ b'Content-Encoding: deflate\r\n' +
+ ('Content-Length: %s\r\n\r\n' % len(data)).encode())
+
+ msg = self.loop.run_until_complete(
+ self.stream.read_message(readall=True))
+
+ def coro():
+ self.stream.feed_data(data)
+ return (yield from msg.payload.read())
+
+ data = self.loop.run_until_complete(coro())
+ self.assertEqual(b'data', data)
+
+ def test_read_message_chunked_payload(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n'
+ b'Transfer-Encoding: chunked\r\n\r\n')
+
+ msg = self.loop.run_until_complete(self.stream.read_message())
+
+ def coro():
+ self.stream.feed_data(
+ b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n')
+ return (yield from msg.payload.read())
+
+ data = self.loop.run_until_complete(coro())
+ self.assertEqual(b'dataline', data)
+
+ def test_read_message_chunked_payload_chunks(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n'
+ b'Transfer-Encoding: chunked\r\n\r\n')
+
+ msg = self.loop.run_until_complete(self.stream.read_message())
+
+ def coro():
+ self.stream.feed_data(b'4\r\ndata\r')
+ self.stream.feed_data(b'\n4')
+ self.stream.feed_data(b'\r')
+ self.stream.feed_data(b'\n')
+ self.stream.feed_data(b'line\r\n0\r\n')
+ self.stream.feed_data(b'test\r\n\r\n')
+ return (yield from msg.payload.read())
+
+ data = self.loop.run_until_complete(coro())
+ self.assertEqual(b'dataline', data)
+
+ def test_read_message_chunked_payload_incomplete(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n'
+ b'Transfer-Encoding: chunked\r\n\r\n')
+
+ msg = self.loop.run_until_complete(self.stream.read_message())
+
+ def coro():
+ self.stream.feed_data(b'4\r\ndata\r\n')
+ self.stream.feed_eof()
+ return (yield from msg.payload.read())
+
+ self.assertRaises(
+ http.client.IncompleteRead,
+ self.loop.run_until_complete, coro())
+
+ def test_read_message_chunked_payload_extension(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n'
+ b'Transfer-Encoding: chunked\r\n\r\n')
+
+ msg = self.loop.run_until_complete(self.stream.read_message())
+
+ def coro():
+ self.stream.feed_data(
+ b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n')
+ return (yield from msg.payload.read())
+
+ data = self.loop.run_until_complete(coro())
+ self.assertEqual(b'dataline', data)
+
+ def test_read_message_chunked_payload_size_error(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n'
+ b'Transfer-Encoding: chunked\r\n\r\n')
+
+ msg = self.loop.run_until_complete(self.stream.read_message())
+
+ def coro():
+ self.stream.feed_data(b'blah\r\n')
+ return (yield from msg.payload.read())
+
+ self.assertRaises(
+ http.client.IncompleteRead,
+ self.loop.run_until_complete, coro())
+
+ def test_deflate_stream_set_exception(self):
+ stream = tulip.StreamReader()
+ dstream = protocol.DeflateStream(stream, 'deflate')
+
+ exc = ValueError()
+ dstream.set_exception(exc)
+ self.assertIs(exc, stream.exception())
+
+ def test_deflate_stream_feed_data(self):
+ stream = tulip.StreamReader()
+ dstream = protocol.DeflateStream(stream, 'deflate')
+
+ dstream.zlib = unittest.mock.Mock()
+ dstream.zlib.decompress.return_value = b'line'
+
+ dstream.feed_data(b'data')
+ self.assertEqual([b'line'], list(stream.buffer))
+
+ def test_deflate_stream_feed_data_err(self):
+ stream = tulip.StreamReader()
+ dstream = protocol.DeflateStream(stream, 'deflate')
+
+ exc = ValueError()
+ dstream.zlib = unittest.mock.Mock()
+ dstream.zlib.decompress.side_effect = exc
+
+ dstream.feed_data(b'data')
+ self.assertIsInstance(stream.exception(), http.client.IncompleteRead)
+
+ def test_deflate_stream_feed_eof(self):
+ stream = tulip.StreamReader()
+ dstream = protocol.DeflateStream(stream, 'deflate')
+
+ dstream.zlib = unittest.mock.Mock()
+ dstream.zlib.flush.return_value = b'line'
+
+ dstream.feed_eof()
+ self.assertEqual([b'line'], list(stream.buffer))
+ self.assertTrue(stream.eof)
+
+ def test_deflate_stream_feed_eof_err(self):
+ stream = tulip.StreamReader()
+ dstream = protocol.DeflateStream(stream, 'deflate')
+
+ dstream.zlib = unittest.mock.Mock()
+ dstream.zlib.flush.return_value = b'line'
+ dstream.zlib.eof = False
+
+ dstream.feed_eof()
+ self.assertIsInstance(stream.exception(), http.client.IncompleteRead)
+
+
+class HttpMessageTests(unittest.TestCase):
+
+ def setUp(self):
+ self.transport = unittest.mock.Mock()
+
+ def test_start_request(self):
+ msg = protocol.Request(
+ self.transport, 'GET', '/index.html', close=True)
+
+ self.assertIs(msg.transport, self.transport)
+ self.assertIsNone(msg.status)
+ self.assertTrue(msg.closing)
+ self.assertEqual(msg.status_line, 'GET /index.html HTTP/1.1\r\n')
+
+ def test_start_response(self):
+ msg = protocol.Response(self.transport, 200, close=True)
+
+ self.assertIs(msg.transport, self.transport)
+ self.assertEqual(msg.status, 200)
+ self.assertTrue(msg.closing)
+ self.assertEqual(msg.status_line, 'HTTP/1.1 200 OK\r\n')
+
+ def test_force_close(self):
+ msg = protocol.Response(self.transport, 200)
+ self.assertFalse(msg.closing)
+ msg.force_close()
+ self.assertTrue(msg.closing)
+
+ def test_force_chunked(self):
+ msg = protocol.Response(self.transport, 200)
+ self.assertFalse(msg.chunked)
+ msg.force_chunked()
+ self.assertTrue(msg.chunked)
+
+ def test_keep_alive(self):
+ msg = protocol.Response(self.transport, 200)
+ self.assertFalse(msg.keep_alive())
+ msg.keepalive = True
+ self.assertTrue(msg.keep_alive())
+
+ msg.force_close()
+ self.assertFalse(msg.keep_alive())
+
+ def test_add_header(self):
+ msg = protocol.Response(self.transport, 200)
+ self.assertEqual([], msg.headers)
+
+ msg.add_header('content-type', 'plain/html')
+ self.assertEqual([('CONTENT-TYPE', 'plain/html')], msg.headers)
+
+ def test_add_headers(self):
+ msg = protocol.Response(self.transport, 200)
+ self.assertEqual([], msg.headers)
+
+ msg.add_headers(('content-type', 'plain/html'))
+ self.assertEqual([('CONTENT-TYPE', 'plain/html')], msg.headers)
+
+ def test_add_headers_length(self):
+ msg = protocol.Response(self.transport, 200)
+ self.assertIsNone(msg.length)
+
+ msg.add_headers(('content-length', '200'))
+ self.assertEqual(200, msg.length)
+
+ def test_add_headers_upgrade(self):
+ msg = protocol.Response(self.transport, 200)
+ self.assertFalse(msg.upgrade)
+
+ msg.add_headers(('connection', 'upgrade'))
+ self.assertTrue(msg.upgrade)
+
+ def test_add_headers_upgrade_websocket(self):
+ msg = protocol.Response(self.transport, 200)
+
+ msg.add_headers(('upgrade', 'test'))
+ self.assertEqual([], msg.headers)
+
+ msg.add_headers(('upgrade', 'websocket'))
+ self.assertEqual([('UPGRADE', 'websocket')], msg.headers)
+
+ def test_add_headers_connection_keepalive(self):
+ msg = protocol.Response(self.transport, 200)
+
+ msg.add_headers(('connection', 'keep-alive'))
+ self.assertEqual([], msg.headers)
+ self.assertTrue(msg.keepalive)
+
+ msg.add_headers(('connection', 'close'))
+ self.assertFalse(msg.keepalive)
+
+ def test_add_headers_hop_headers(self):
+ msg = protocol.Response(self.transport, 200)
+
+ msg.add_headers(('connection', 'test'), ('transfer-encoding', 't'))
+ self.assertEqual([], msg.headers)
+
+ def test_default_headers(self):
+ msg = protocol.Response(self.transport, 200)
+
+ headers = [r for r, _ in msg._default_headers()]
+ self.assertIn('DATE', headers)
+ self.assertIn('CONNECTION', headers)
+
+ def test_default_headers_server(self):
+ msg = protocol.Response(self.transport, 200)
+
+ headers = [r for r, _ in msg._default_headers()]
+ self.assertIn('SERVER', headers)
+
+ def test_default_headers_useragent(self):
+ msg = protocol.Request(self.transport, 'GET', '/')
+
+ headers = [r for r, _ in msg._default_headers()]
+ self.assertNotIn('SERVER', headers)
+ self.assertIn('USER-AGENT', headers)
+
+ def test_default_headers_chunked(self):
+ msg = protocol.Response(self.transport, 200)
+
+ headers = [r for r, _ in msg._default_headers()]
+ self.assertNotIn('TRANSFER-ENCODING', headers)
+
+ msg.force_chunked()
+
+ headers = [r for r, _ in msg._default_headers()]
+ self.assertIn('TRANSFER-ENCODING', headers)
+
+ def test_default_headers_connection_upgrade(self):
+ msg = protocol.Response(self.transport, 200)
+ msg.upgrade = True
+
+ headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION']
+ self.assertEqual([('CONNECTION', 'upgrade')], headers)
+
+ def test_default_headers_connection_close(self):
+ msg = protocol.Response(self.transport, 200)
+ msg.force_close()
+
+ headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION']
+ self.assertEqual([('CONNECTION', 'close')], headers)
+
+ def test_default_headers_connection_keep_alive(self):
+ msg = protocol.Response(self.transport, 200)
+ msg.keepalive = True
+
+ headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION']
+ self.assertEqual([('CONNECTION', 'keep-alive')], headers)
+
+ def test_send_headers(self):
+ write = self.transport.write = unittest.mock.Mock()
+
+ msg = protocol.Response(self.transport, 200)
+ msg.add_headers(('content-type', 'plain/html'))
+ self.assertFalse(msg.is_headers_sent())
+
+ msg.send_headers()
+
+ content = b''.join([arg[1][0] for arg in list(write.mock_calls)])
+
+ self.assertTrue(content.startswith(b'HTTP/1.1 200 OK\r\n'))
+ self.assertIn(b'CONTENT-TYPE: plain/html', content)
+ self.assertTrue(msg.headers_sent)
+ self.assertTrue(msg.is_headers_sent())
+
+ def test_send_headers_nomore_add(self):
+ msg = protocol.Response(self.transport, 200)
+ msg.add_headers(('content-type', 'plain/html'))
+ msg.send_headers()
+
+ self.assertRaises(AssertionError,
+ msg.add_header, 'content-type', 'plain/html')
+
+ def test_prepare_length(self):
+ msg = protocol.Response(self.transport, 200)
+ length = msg._write_length_payload = unittest.mock.Mock()
+ length.return_value = iter([1, 2, 3])
+
+ msg.add_headers(('content-length', '200'))
+ msg.send_headers()
+
+ self.assertTrue(length.called)
+ self.assertEqual((200,), length.call_args[0])
+
+ def test_prepare_chunked_force(self):
+ msg = protocol.Response(self.transport, 200)
+ msg.force_chunked()
+
+ chunked = msg._write_chunked_payload = unittest.mock.Mock()
+ chunked.return_value = iter([1, 2, 3])
+
+ msg.add_headers(('content-length', '200'))
+ msg.send_headers()
+ self.assertTrue(chunked.called)
+
+ def test_prepare_chunked_no_length(self):
+ msg = protocol.Response(self.transport, 200)
+
+ chunked = msg._write_chunked_payload = unittest.mock.Mock()
+ chunked.return_value = iter([1, 2, 3])
+
+ msg.send_headers()
+ self.assertTrue(chunked.called)
+
+ def test_prepare_eof(self):
+ msg = protocol.Response(self.transport, 200, http_version=(1, 0))
+
+ eof = msg._write_eof_payload = unittest.mock.Mock()
+ eof.return_value = iter([1, 2, 3])
+
+ msg.send_headers()
+ self.assertTrue(eof.called)
+
+ def test_write_auto_send_headers(self):
+ msg = protocol.Response(self.transport, 200, http_version=(1, 0))
+ msg._send_headers = True
+
+ msg.write(b'data1')
+ self.assertTrue(msg.headers_sent)
+
+ def test_write_payload_eof(self):
+ write = self.transport.write = unittest.mock.Mock()
+ msg = protocol.Response(self.transport, 200, http_version=(1, 0))
+ msg.send_headers()
+
+ msg.write(b'data1')
+ self.assertTrue(msg.headers_sent)
+
+ msg.write(b'data2')
+ msg.write_eof()
+
+ content = b''.join([c[1][0] for c in list(write.mock_calls)])
+ self.assertEqual(
+ b'data1data2', content.split(b'\r\n\r\n', 1)[-1])
+
+ def test_write_payload_chunked(self):
+ write = self.transport.write = unittest.mock.Mock()
+
+ msg = protocol.Response(self.transport, 200)
+ msg.force_chunked()
+ msg.send_headers()
+
+ msg.write(b'data')
+ msg.write_eof()
+
+ content = b''.join([c[1][0] for c in list(write.mock_calls)])
+ self.assertEqual(
+ b'4\r\ndata\r\n0\r\n\r\n',
+ content.split(b'\r\n\r\n', 1)[-1])
+
+ def test_write_payload_chunked_multiple(self):
+ write = self.transport.write = unittest.mock.Mock()
+
+ msg = protocol.Response(self.transport, 200)
+ msg.force_chunked()
+ msg.send_headers()
+
+ msg.write(b'data1')
+ msg.write(b'data2')
+ msg.write_eof()
+
+ content = b''.join([c[1][0] for c in list(write.mock_calls)])
+ self.assertEqual(
+ b'5\r\ndata1\r\n5\r\ndata2\r\n0\r\n\r\n',
+ content.split(b'\r\n\r\n', 1)[-1])
+
+ def test_write_payload_length(self):
+ write = self.transport.write = unittest.mock.Mock()
+
+ msg = protocol.Response(self.transport, 200)
+ msg.add_headers(('content-length', '2'))
+ msg.send_headers()
+
+ msg.write(b'd')
+ msg.write(b'ata')
+ msg.write_eof()
+
+ content = b''.join([c[1][0] for c in list(write.mock_calls)])
+ self.assertEqual(
+ b'da', content.split(b'\r\n\r\n', 1)[-1])
+
+ def test_write_payload_chunked_filter(self):
+ write = self.transport.write = unittest.mock.Mock()
+
+ msg = protocol.Response(self.transport, 200)
+ msg.send_headers()
+
+ msg.add_chunking_filter(2)
+ msg.write(b'data')
+ msg.write_eof()
+
+ content = b''.join([c[1][0] for c in list(write.mock_calls)])
+ self.assertTrue(content.endswith(b'2\r\nda\r\n2\r\nta\r\n0\r\n\r\n'))
+
+ def test_write_payload_chunked_filter_multiple_chunks(self):
+ write = self.transport.write = unittest.mock.Mock()
+ msg = protocol.Response(self.transport, 200)
+ msg.send_headers()
+
+ msg.add_chunking_filter(2)
+ msg.write(b'data1')
+ msg.write(b'data2')
+ msg.write_eof()
+ content = b''.join([c[1][0] for c in list(write.mock_calls)])
+ self.assertTrue(content.endswith(
+ b'2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n'
+ b'2\r\na2\r\n0\r\n\r\n'))
+
+ def test_write_payload_chunked_large_chunk(self):
+ write = self.transport.write = unittest.mock.Mock()
+ msg = protocol.Response(self.transport, 200)
+ msg.send_headers()
+
+ msg.add_chunking_filter(1024)
+ msg.write(b'data')
+ msg.write_eof()
+ content = b''.join([c[1][0] for c in list(write.mock_calls)])
+ self.assertTrue(content.endswith(b'4\r\ndata\r\n0\r\n\r\n'))
+
+ _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS)
+ _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()])
+
+ def test_write_payload_deflate_filter(self):
+ write = self.transport.write = unittest.mock.Mock()
+ msg = protocol.Response(self.transport, 200)
+ msg.add_headers(('content-length', '%s' % len(self._COMPRESSED)))
+ msg.send_headers()
+
+ msg.add_compression_filter('deflate')
+ msg.write(b'data')
+ msg.write_eof()
+
+ content = b''.join([c[1][0] for c in list(write.mock_calls)])
+ self.assertEqual(
+ self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1])
+
+ def test_write_payload_deflate_and_chunked(self):
+ write = self.transport.write = unittest.mock.Mock()
+ msg = protocol.Response(self.transport, 200)
+ msg.send_headers()
+
+ msg.add_compression_filter('deflate')
+ msg.add_chunking_filter(2)
+
+ msg.write(b'data')
+ msg.write_eof()
+
+ content = b''.join([c[1][0] for c in list(write.mock_calls)])
+ self.assertEqual(
+ b'2\r\nKI\r\n2\r\n,I\r\n2\r\n\x04\x00\r\n0\r\n\r\n',
+ content.split(b'\r\n\r\n', 1)[-1])
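+ # (b'KI,I\x04\x00' is the raw-deflate encoding of b'data'; the
+ # chunking filter then splits it into the 2-byte chunks asserted
+ # above.)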
+
+ def test_write_payload_chunked_and_deflate(self):
+ write = self.transport.write = unittest.mock.Mock()
+ msg = protocol.Response(self.transport, 200)
+ msg.add_headers(('content-length', '%s' % len(self._COMPRESSED)))
+
+ msg.add_chunking_filter(2)
+ msg.add_compression_filter('deflate')
+ msg.send_headers()
+
+ msg.write(b'data')
+ msg.write_eof()
+
+ content = b''.join([c[1][0] for c in list(write.mock_calls)])
+ self.assertEqual(
+ self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1])
diff --git a/tests/http_server_test.py b/tests/http_server_test.py
new file mode 100644
index 0000000..dc55eff
--- /dev/null
+++ b/tests/http_server_test.py
@@ -0,0 +1,242 @@
+"""Tests for http/server.py"""
+
+import unittest
+import unittest.mock
+
+import tulip
+from tulip.http import server
+from tulip.http import errors
+from tulip.test_utils import LogTrackingTestCase
+
+
+class HttpServerProtocolTests(LogTrackingTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.suppress_log_errors()
+
+ self.loop = tulip.new_event_loop()
+ tulip.set_event_loop(self.loop)
+
+ def tearDown(self):
+ self.loop.close()
+ super().tearDown()
+
+ def test_http_status_exception(self):
+ exc = errors.HttpStatusException(500, message='Internal error')
+ self.assertEqual(exc.code, 500)
+ self.assertEqual(exc.message, 'Internal error')
+
+ def test_handle_request(self):
+ transport = unittest.mock.Mock()
+
+ srv = server.ServerHttpProtocol()
+ srv.connection_made(transport)
+
+ rline = unittest.mock.Mock()
+ rline.version = (1, 1)
+ message = unittest.mock.Mock()
+ srv.handle_request(rline, message)
+
+ content = b''.join([c[1][0] for c in list(transport.write.mock_calls)])
+ self.assertTrue(content.startswith(b'HTTP/1.1 404 Not Found\r\n'))
+
+ def test_connection_made(self):
+ srv = server.ServerHttpProtocol()
+ self.assertIsNone(srv._request_handle)
+
+ srv.connection_made(unittest.mock.Mock())
+ self.assertIsNotNone(srv._request_handle)
+
+ def test_data_received(self):
+ srv = server.ServerHttpProtocol()
+ srv.connection_made(unittest.mock.Mock())
+
+ srv.data_received(b'123')
+ self.assertEqual(b'123', b''.join(srv.stream.buffer))
+
+ srv.data_received(b'456')
+ self.assertEqual(b'123456', b''.join(srv.stream.buffer))
+
+ def test_eof_received(self):
+ srv = server.ServerHttpProtocol()
+ srv.connection_made(unittest.mock.Mock())
+ srv.eof_received()
+ self.assertTrue(srv.stream.eof)
+
+ def test_connection_lost(self):
+ srv = server.ServerHttpProtocol()
+ srv.connection_made(unittest.mock.Mock())
+ srv.data_received(b'123')
+
+ handle = srv._request_handle
+ srv.connection_lost(None)
+
+ self.assertIsNone(srv._request_handle)
+ self.assertTrue(handle.cancelled())
+
+ srv.connection_lost(None)
+ self.assertIsNone(srv._request_handle)
+
+ def test_close(self):
+ srv = server.ServerHttpProtocol()
+ self.assertFalse(srv.closing)
+
+ srv.close()
+ self.assertTrue(srv.closing)
+
+ def test_handle_error(self):
+ transport = unittest.mock.Mock()
+ srv = server.ServerHttpProtocol()
+ srv.connection_made(transport)
+
+ srv.handle_error(404)
+ content = b''.join([c[1][0] for c in list(transport.write.mock_calls)])
+ self.assertIn(b'HTTP/1.1 404 Not Found', content)
+
+ @unittest.mock.patch('tulip.http.server.traceback')
+ def test_handle_error_traceback_exc(self, m_trace):
+ transport = unittest.mock.Mock()
+ srv = server.ServerHttpProtocol(debug=True)
+ srv.connection_made(transport)
+
+ m_trace.format_exc.side_effect = ValueError
+
+ srv.handle_error(500, exc=object())
+ content = b''.join([c[1][0] for c in list(transport.write.mock_calls)])
+ self.assertTrue(
+ content.startswith(b'HTTP/1.1 500 Internal Server Error'))
+
+ def test_handle_error_debug(self):
+ transport = unittest.mock.Mock()
+ srv = server.ServerHttpProtocol()
+ srv.debug = True
+ srv.connection_made(transport)
+
+ try:
+ raise ValueError()
+ except Exception as exc:
+ srv.handle_error(999, exc=exc)
+
+ content = b''.join([c[1][0] for c in list(transport.write.mock_calls)])
+
+ self.assertIn(b'HTTP/1.1 500 Internal', content)
+ self.assertIn(b'Traceback (most recent call last):', content)
+
+ def test_handle_error_500(self):
+ log = unittest.mock.Mock()
+ transport = unittest.mock.Mock()
+
+ srv = server.ServerHttpProtocol(log=log)
+ srv.connection_made(transport)
+
+ srv.handle_error(500)
+ self.assertTrue(log.exception.called)
+
+ def test_handle(self):
+ transport = unittest.mock.Mock()
+ srv = server.ServerHttpProtocol()
+ srv.connection_made(transport)
+
+ handle = srv.handle_request = unittest.mock.Mock()
+
+ srv.stream.feed_data(
+ b'GET / HTTP/1.0\r\n'
+ b'Host: example.com\r\n\r\n')
+ srv.close()
+
+ self.loop.run_until_complete(srv._request_handle)
+ self.assertTrue(handle.called)
+ self.assertIsNone(srv._request_handle)
+
+ def test_handle_coro(self):
+ transport = unittest.mock.Mock()
+ srv = server.ServerHttpProtocol()
+ srv.connection_made(transport)
+
+ called = False
+
+ @tulip.coroutine
+ def coro(rline, message):
+ nonlocal called
+ called = True
+ yield from []
+ srv.eof_received()
+
+ srv.handle_request = coro
+
+ srv.stream.feed_data(
+ b'GET / HTTP/1.0\r\n'
+ b'Host: example.com\r\n\r\n')
+ self.loop.run_until_complete(srv._request_handle)
+ self.assertTrue(called)
+
+ def test_handle_close(self):
+ transport = unittest.mock.Mock()
+ srv = server.ServerHttpProtocol()
+ srv.connection_made(transport)
+
+ handle = srv.handle_request = unittest.mock.Mock()
+
+ srv.stream.feed_data(
+ b'GET / HTTP/1.0\r\n'
+ b'Host: example.com\r\n\r\n')
+ srv.close()
+ self.loop.run_until_complete(srv._request_handle)
+
+ self.assertTrue(handle.called)
+ self.assertTrue(transport.close.called)
+
+ def test_handle_cancel(self):
+ log = unittest.mock.Mock()
+ transport = unittest.mock.Mock()
+
+ srv = server.ServerHttpProtocol(log=log, debug=True)
+ srv.connection_made(transport)
+
+ srv.handle_request = unittest.mock.Mock()
+
+ @tulip.task
+ def cancel():
+ yield from []
+ srv._request_handle.cancel()
+
+ srv.close()
+ self.loop.run_until_complete(
+ tulip.wait([srv._request_handle, cancel()]))
+ self.assertTrue(log.debug.called)
+
+ def test_handle_400(self):
+ transport = unittest.mock.Mock()
+ srv = server.ServerHttpProtocol()
+ srv.connection_made(transport)
+ srv.handle_error = unittest.mock.Mock()
+
+ def side_effect(*args):
+ srv.close()
+ srv.handle_error.side_effect = side_effect
+
+ srv.stream.feed_data(b'GET / HT/asd\r\n')
+
+ self.loop.run_until_complete(srv._request_handle)
+ self.assertTrue(srv.handle_error.called)
+ self.assertEqual(400, srv.handle_error.call_args[0][0])
+ self.assertTrue(transport.close.called)
+
+ def test_handle_500(self):
+ transport = unittest.mock.Mock()
+ srv = server.ServerHttpProtocol()
+ srv.connection_made(transport)
+
+ handle = srv.handle_request = unittest.mock.Mock()
+ handle.side_effect = ValueError
+ srv.handle_error = unittest.mock.Mock()
+ srv.close()
+
+ srv.stream.feed_data(
+ b'GET / HTTP/1.0\r\n'
+ b'Host: example.com\r\n\r\n')
+ self.loop.run_until_complete(srv._request_handle)
+
+ self.assertTrue(srv.handle_error.called)
+ self.assertEqual(500, srv.handle_error.call_args[0][0])
diff --git a/tests/locks_test.py b/tests/locks_test.py
new file mode 100644
index 0000000..7d2111d
--- /dev/null
+++ b/tests/locks_test.py
@@ -0,0 +1,747 @@
+"""Tests for lock.py"""
+
+import time
+import unittest
+import unittest.mock
+
+from tulip import events
+from tulip import futures
+from tulip import locks
+from tulip import tasks
+from tulip import test_utils
+
+
+class LockTests(unittest.TestCase):
+
+ def setUp(self):
+ self.event_loop = events.new_event_loop()
+ events.set_event_loop(self.event_loop)
+
+ def tearDown(self):
+ self.event_loop.close()
+
+ def test_repr(self):
+ lock = locks.Lock()
+ self.assertTrue(repr(lock).endswith('[unlocked]>'))
+
+ def acquire_lock():
+ yield from lock
+
+ self.event_loop.run_until_complete(acquire_lock())
+ self.assertTrue(repr(lock).endswith('[locked]>'))
+
+ def test_lock(self):
+ lock = locks.Lock()
+
+ def acquire_lock():
+ return (yield from lock)
+
+ res = self.event_loop.run_until_complete(acquire_lock())
+
+ self.assertTrue(res)
+ self.assertTrue(lock.locked())
+
+ lock.release()
+ self.assertFalse(lock.locked())
+
+ def test_acquire(self):
+ lock = locks.Lock()
+ result = []
+
+ self.assertTrue(
+ self.event_loop.run_until_complete(lock.acquire()))
+
+ @tasks.coroutine
+ def c1(result):
+ if (yield from lock.acquire()):
+ result.append(1)
+
+ @tasks.coroutine
+ def c2(result):
+ if (yield from lock.acquire()):
+ result.append(2)
+
+ @tasks.coroutine
+ def c3(result):
+ if (yield from lock.acquire()):
+ result.append(3)
+
+ tasks.Task(c1(result))
+ tasks.Task(c2(result))
+
+ self.event_loop.run_once()
+ self.assertEqual([], result)
+
+ lock.release()
+ self.event_loop.run_once()
+ self.assertEqual([1], result)
+
+ self.event_loop.run_once()
+ self.assertEqual([1], result)
+
+ tasks.Task(c3(result))
+
+ lock.release()
+ self.event_loop.run_once()
+ self.assertEqual([1, 2], result)
+
+ lock.release()
+ self.event_loop.run_once()
+ self.assertEqual([1, 2, 3], result)
+
+ def test_acquire_timeout(self):
+ lock = locks.Lock()
+ self.assertTrue(
+ self.event_loop.run_until_complete(lock.acquire()))
+
+ t0 = time.monotonic()
+ acquired = self.event_loop.run_until_complete(
+ lock.acquire(timeout=0.1))
+ self.assertFalse(acquired)
+
+ total_time = (time.monotonic() - t0)
+ self.assertTrue(0.08 < total_time < 0.12)
+
+ lock = locks.Lock()
+ self.event_loop.run_until_complete(lock.acquire())
+
+ self.event_loop.call_later(0.1, lock.release)
+ acquired = self.event_loop.run_until_complete(lock.acquire(10.1))
+ self.assertTrue(acquired)
+
+ def test_acquire_timeout_mixed(self):
+ lock = locks.Lock()
+ self.event_loop.run_until_complete(lock.acquire())
+ tasks.Task(lock.acquire())
+ tasks.Task(lock.acquire())
+ acquire_task = tasks.Task(lock.acquire(0.1))
+ tasks.Task(lock.acquire())
+
+ t0 = time.monotonic()
+ acquired = self.event_loop.run_until_complete(acquire_task)
+ self.assertFalse(acquired)
+
+ total_time = (time.monotonic() - t0)
+ self.assertTrue(0.08 < total_time < 0.12)
+
+ self.assertEqual(3, len(lock._waiters))
+
+ def test_acquire_cancel(self):
+ lock = locks.Lock()
+ self.assertTrue(
+ self.event_loop.run_until_complete(lock.acquire()))
+
+ task = tasks.Task(lock.acquire())
+ self.event_loop.call_soon(task.cancel)
+ self.assertRaises(
+ futures.CancelledError,
+ self.event_loop.run_until_complete, task)
+ self.assertFalse(lock._waiters)
+
+ def test_release_not_acquired(self):
+ lock = locks.Lock()
+
+ self.assertRaises(RuntimeError, lock.release)
+
+ def test_release_no_waiters(self):
+ lock = locks.Lock()
+ self.event_loop.run_until_complete(lock.acquire())
+ self.assertTrue(lock.locked())
+
+ lock.release()
+ self.assertFalse(lock.locked())
+
+ def test_context_manager(self):
+ lock = locks.Lock()
+
+ @tasks.task
+ def acquire_lock():
+ return (yield from lock)
+
+ with self.event_loop.run_until_complete(acquire_lock()):
+ self.assertTrue(lock.locked())
+
+ self.assertFalse(lock.locked())
+
+ def test_context_manager_no_yield(self):
+ lock = locks.Lock()
+
+ try:
+ with lock:
+ self.fail('RuntimeError was not raised in with expression')
+ except RuntimeError as err:
+ self.assertEqual(
+ str(err),
+ '"yield from" should be used as context manager expression')
+
+
+class EventWaiterTests(unittest.TestCase):
+
+ def setUp(self):
+ self.event_loop = events.new_event_loop()
+ events.set_event_loop(self.event_loop)
+
+ def tearDown(self):
+ self.event_loop.close()
+
+ def test_repr(self):
+ ev = locks.EventWaiter()
+ self.assertTrue(repr(ev).endswith('[unset]>'))
+
+ ev.set()
+ self.assertTrue(repr(ev).endswith('[set]>'))
+
+ def test_wait(self):
+ ev = locks.EventWaiter()
+ self.assertFalse(ev.is_set())
+
+ result = []
+
+ @tasks.coroutine
+ def c1(result):
+ if (yield from ev.wait()):
+ result.append(1)
+
+ @tasks.coroutine
+ def c2(result):
+ if (yield from ev.wait()):
+ result.append(2)
+
+ @tasks.coroutine
+ def c3(result):
+ if (yield from ev.wait()):
+ result.append(3)
+
+ tasks.Task(c1(result))
+ tasks.Task(c2(result))
+
+ self.event_loop.run_once()
+ self.assertEqual([], result)
+
+ tasks.Task(c3(result))
+
+ ev.set()
+ self.event_loop.run_once()
+ self.assertEqual([3, 1, 2], result)
+
+ def test_wait_on_set(self):
+ ev = locks.EventWaiter()
+ ev.set()
+
+ res = self.event_loop.run_until_complete(ev.wait())
+ self.assertTrue(res)
+
+ def test_wait_timeout(self):
+ ev = locks.EventWaiter()
+
+ t0 = time.monotonic()
+ res = self.event_loop.run_until_complete(ev.wait(0.1))
+ self.assertFalse(res)
+ total_time = (time.monotonic() - t0)
+ self.assertTrue(0.08 < total_time < 0.12)
+
+ ev = locks.EventWaiter()
+ self.event_loop.call_later(0.1, ev.set)
+ acquired = self.event_loop.run_until_complete(ev.wait(10.1))
+ self.assertTrue(acquired)
+
+ def test_wait_timeout_mixed(self):
+ ev = locks.EventWaiter()
+ tasks.Task(ev.wait())
+ tasks.Task(ev.wait())
+ acquire_task = tasks.Task(ev.wait(0.1))
+ tasks.Task(ev.wait())
+
+ t0 = time.monotonic()
+ acquired = self.event_loop.run_until_complete(acquire_task)
+ self.assertFalse(acquired)
+
+ total_time = (time.monotonic() - t0)
+ self.assertTrue(0.08 < total_time < 0.12)
+
+ self.assertEqual(3, len(ev._waiters))
+
+ def test_wait_cancel(self):
+ ev = locks.EventWaiter()
+
+ wait = tasks.Task(ev.wait())
+ self.event_loop.call_soon(wait.cancel)
+ self.assertRaises(
+ futures.CancelledError,
+ self.event_loop.run_until_complete, wait)
+ self.assertFalse(ev._waiters)
+
+ def test_clear(self):
+ ev = locks.EventWaiter()
+ self.assertFalse(ev.is_set())
+
+ ev.set()
+ self.assertTrue(ev.is_set())
+
+ ev.clear()
+ self.assertFalse(ev.is_set())
+
+ def test_clear_with_waiters(self):
+ ev = locks.EventWaiter()
+ result = []
+
+ @tasks.coroutine
+ def c1(result):
+ if (yield from ev.wait()):
+ result.append(1)
+
+ tasks.Task(c1(result))
+ self.event_loop.run_once()
+ self.assertEqual([], result)
+
+ ev.set()
+ ev.clear()
+ self.assertFalse(ev.is_set())
+
+ ev.set()
+ ev.set()
+ self.assertEqual(1, len(ev._waiters))
+
+ self.event_loop.run_once()
+ self.assertEqual([1], result)
+ self.assertEqual(0, len(ev._waiters))
+
+
+class ConditionTests(test_utils.LogTrackingTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.event_loop = events.new_event_loop()
+ events.set_event_loop(self.event_loop)
+
+ def tearDown(self):
+ self.event_loop.close()
+
+ def test_wait(self):
+ cond = locks.Condition()
+ result = []
+
+ @tasks.coroutine
+ def c1(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(1)
+
+ @tasks.coroutine
+ def c2(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(2)
+
+ @tasks.coroutine
+ def c3(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(3)
+
+ tasks.Task(c1(result))
+ tasks.Task(c2(result))
+ tasks.Task(c3(result))
+
+ self.event_loop.run_once()
+ self.assertEqual([], result)
+ self.assertFalse(cond.locked())
+
+ self.assertTrue(
+ self.event_loop.run_until_complete(cond.acquire()))
+ cond.notify()
+ self.event_loop.run_once()
+ self.assertEqual([], result)
+ self.assertTrue(cond.locked())
+
+ cond.release()
+ self.event_loop.run_once()
+ self.assertEqual([1], result)
+ self.assertTrue(cond.locked())
+
+ cond.notify(2)
+ self.event_loop.run_once()
+ self.assertEqual([1], result)
+ self.assertTrue(cond.locked())
+
+ cond.release()
+ self.event_loop.run_once()
+ self.assertEqual([1, 2], result)
+ self.assertTrue(cond.locked())
+
+ cond.release()
+ self.event_loop.run_once()
+ self.assertEqual([1, 2, 3], result)
+ self.assertTrue(cond.locked())
+
+ def test_wait_timeout(self):
+ cond = locks.Condition()
+ self.event_loop.run_until_complete(cond.acquire())
+
+ t0 = time.monotonic()
+ wait = self.event_loop.run_until_complete(cond.wait(0.1))
+ self.assertFalse(wait)
+ self.assertTrue(cond.locked())
+
+ total_time = (time.monotonic() - t0)
+ self.assertTrue(0.08 < total_time < 0.12)
+
+ def test_wait_cancel(self):
+ cond = locks.Condition()
+ self.event_loop.run_until_complete(cond.acquire())
+
+ wait = tasks.Task(cond.wait())
+ self.event_loop.call_soon(wait.cancel)
+ self.assertRaises(
+ futures.CancelledError,
+ self.event_loop.run_until_complete, wait)
+ self.assertFalse(cond._condition_waiters)
+ self.assertTrue(cond.locked())
+
+ def test_wait_unacquired(self):
+ self.suppress_log_errors()
+
+ cond = locks.Condition()
+ self.assertRaises(
+ RuntimeError,
+ self.event_loop.run_until_complete, cond.wait())
+
+ def test_wait_for(self):
+ cond = locks.Condition()
+ presult = False
+
+ def predicate():
+ return presult
+
+ result = []
+
+ @tasks.coroutine
+ def c1(result):
+ yield from cond.acquire()
+ if (yield from cond.wait_for(predicate)):
+ result.append(1)
+ cond.release()
+
+ tasks.Task(c1(result))
+
+ self.event_loop.run_once()
+ self.assertEqual([], result)
+
+ self.event_loop.run_until_complete(cond.acquire())
+ cond.notify()
+ cond.release()
+ self.event_loop.run_once()
+ self.assertEqual([], result)
+
+ presult = True
+ self.event_loop.run_until_complete(cond.acquire())
+ cond.notify()
+ cond.release()
+ self.event_loop.run_once()
+ self.assertEqual([1], result)
+
+ def test_wait_for_timeout(self):
+ cond = locks.Condition()
+
+ result = []
+
+ predicate = unittest.mock.Mock()
+ predicate.return_value = False
+
+ @tasks.coroutine
+ def c1(result):
+ yield from cond.acquire()
+ if (yield from cond.wait_for(predicate, 0.2)):
+ result.append(1)
+ else:
+ result.append(2)
+ cond.release()
+
+ wait_for = tasks.Task(c1(result))
+
+ t0 = time.monotonic()
+
+ self.event_loop.run_once()
+ self.assertEqual([], result)
+
+ self.event_loop.run_until_complete(cond.acquire())
+ cond.notify()
+ cond.release()
+ self.event_loop.run_once()
+ self.assertEqual([], result)
+
+ self.event_loop.run_until_complete(wait_for)
+ self.assertEqual([2], result)
+ self.assertEqual(3, predicate.call_count)
+
+ total_time = (time.monotonic() - t0)
+ self.assertTrue(0.18 < total_time < 0.22)
+
+ def test_wait_for_unacquired(self):
+ self.suppress_log_errors()
+
+ cond = locks.Condition()
+
+ # predicate can return true immediately
+ res = self.event_loop.run_until_complete(
+ cond.wait_for(lambda: [1, 2, 3]))
+ self.assertEqual([1, 2, 3], res)
+
+ self.assertRaises(
+ RuntimeError,
+ self.event_loop.run_until_complete,
+ cond.wait_for(lambda: False))
+
+ def test_notify(self):
+ cond = locks.Condition()
+ result = []
+
+ @tasks.coroutine
+ def c1(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(1)
+ cond.release()
+
+ @tasks.coroutine
+ def c2(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(2)
+ cond.release()
+
+ @tasks.coroutine
+ def c3(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(3)
+ cond.release()
+
+ tasks.Task(c1(result))
+ tasks.Task(c2(result))
+ tasks.Task(c3(result))
+
+ self.event_loop.run_once()
+ self.assertEqual([], result)
+
+ self.event_loop.run_until_complete(cond.acquire())
+ cond.notify(1)
+ cond.release()
+ self.event_loop.run_once()
+ self.assertEqual([1], result)
+
+ self.event_loop.run_until_complete(cond.acquire())
+ cond.notify(1)
+ cond.notify(2048)
+ cond.release()
+ self.event_loop.run_once()
+ self.assertEqual([1, 2, 3], result)
+
+ def test_notify_all(self):
+ cond = locks.Condition()
+
+ result = []
+
+ @tasks.coroutine
+ def c1(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(1)
+ cond.release()
+
+ @tasks.coroutine
+ def c2(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(2)
+ cond.release()
+
+ tasks.Task(c1(result))
+ tasks.Task(c2(result))
+
+ self.event_loop.run_once()
+ self.assertEqual([], result)
+
+ self.event_loop.run_until_complete(cond.acquire())
+ cond.notify_all()
+ cond.release()
+ self.event_loop.run_once()
+ self.assertEqual([1, 2], result)
+
+ def test_notify_unacquired(self):
+ cond = locks.Condition()
+ self.assertRaises(RuntimeError, cond.notify)
+
+ def test_notify_all_unacquired(self):
+ cond = locks.Condition()
+ self.assertRaises(RuntimeError, cond.notify_all)
+
+
+class SemaphoreTests(unittest.TestCase):
+
+ def setUp(self):
+ self.event_loop = events.new_event_loop()
+ events.set_event_loop(self.event_loop)
+
+ def tearDown(self):
+ self.event_loop.close()
+
+ def test_repr(self):
+ sem = locks.Semaphore()
+ self.assertTrue(repr(sem).endswith('[unlocked,value:1]>'))
+
+ self.event_loop.run_until_complete(sem.acquire())
+ self.assertTrue(repr(sem).endswith('[locked]>'))
+
+ def test_semaphore(self):
+ sem = locks.Semaphore()
+ self.assertEqual(1, sem._value)
+
+ @tasks.task
+ def acquire_lock():
+ return (yield from sem)
+
+ res = self.event_loop.run_until_complete(acquire_lock())
+
+ self.assertTrue(res)
+ self.assertTrue(sem.locked())
+ self.assertEqual(0, sem._value)
+
+ sem.release()
+ self.assertFalse(sem.locked())
+ self.assertEqual(1, sem._value)
+
+ def test_semaphore_value(self):
+ self.assertRaises(ValueError, locks.Semaphore, -1)
+
+ def test_acquire(self):
+ sem = locks.Semaphore(3)
+ result = []
+
+ self.assertTrue(
+ self.event_loop.run_until_complete(sem.acquire()))
+ self.assertTrue(
+ self.event_loop.run_until_complete(sem.acquire()))
+ self.assertFalse(sem.locked())
+
+ @tasks.coroutine
+ def c1(result):
+ yield from sem.acquire()
+ result.append(1)
+
+ @tasks.coroutine
+ def c2(result):
+ yield from sem.acquire()
+ result.append(2)
+
+ @tasks.coroutine
+ def c3(result):
+ yield from sem.acquire()
+ result.append(3)
+
+ @tasks.coroutine
+ def c4(result):
+ yield from sem.acquire()
+ result.append(4)
+
+ tasks.Task(c1(result))
+ tasks.Task(c2(result))
+ tasks.Task(c3(result))
+
+ self.event_loop.run_once()
+ self.assertEqual([1], result)
+ self.assertTrue(sem.locked())
+ self.assertEqual(2, len(sem._waiters))
+ self.assertEqual(0, sem._value)
+
+ tasks.Task(c4(result))
+
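+        # Two releases wake c2 and c3; c4 queues behind them.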
+ sem.release()
+ sem.release()
+ self.assertEqual(2, sem._value)
+
+ self.event_loop.run_once()
+ self.assertEqual(0, sem._value)
+ self.assertEqual([1, 2, 3], result)
+ self.assertTrue(sem.locked())
+ self.assertEqual(1, len(sem._waiters))
+ self.assertEqual(0, sem._value)
+
+ def test_acquire_timeout(self):
+ sem = locks.Semaphore()
+ self.event_loop.run_until_complete(sem.acquire())
+
+ t0 = time.monotonic()
+ acquired = self.event_loop.run_until_complete(sem.acquire(0.1))
+ self.assertFalse(acquired)
+
+ total_time = (time.monotonic() - t0)
+ self.assertTrue(0.08 < total_time < 0.12)
+
+ sem = locks.Semaphore()
+ self.event_loop.run_until_complete(sem.acquire())
+
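+        # A release scheduled before the deadline lets acquire() succeed.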
+ self.event_loop.call_later(0.1, sem.release)
+ acquired = self.event_loop.run_until_complete(sem.acquire(10.1))
+ self.assertTrue(acquired)
+
+ def test_acquire_timeout_mixed(self):
+ sem = locks.Semaphore()
+ self.event_loop.run_until_complete(sem.acquire())
+ tasks.Task(sem.acquire())
+ tasks.Task(sem.acquire())
+ acquire_task = tasks.Task(sem.acquire(0.1))
+ tasks.Task(sem.acquire())
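+        # Only acquire_task has a timeout; the other three wait forever.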
+
+ t0 = time.monotonic()
+ acquired = self.event_loop.run_until_complete(acquire_task)
+ self.assertFalse(acquired)
+
+ total_time = (time.monotonic() - t0)
+ self.assertTrue(0.08 < total_time < 0.12)
+
+ self.assertEqual(3, len(sem._waiters))
+
+ def test_acquire_cancel(self):
+ sem = locks.Semaphore()
+ self.event_loop.run_until_complete(sem.acquire())
+
+ acquire = tasks.Task(sem.acquire())
+ self.event_loop.call_soon(acquire.cancel)
+ self.assertRaises(
+ futures.CancelledError,
+ self.event_loop.run_until_complete, acquire)
+ self.assertFalse(sem._waiters)
+
+ def test_release_not_acquired(self):
+ sem = locks.Semaphore(bound=True)
+
+ self.assertRaises(ValueError, sem.release)
+
+ def test_release_no_waiters(self):
+ sem = locks.Semaphore()
+ self.event_loop.run_until_complete(sem.acquire())
+ self.assertTrue(sem.locked())
+
+ sem.release()
+ self.assertFalse(sem.locked())
+
+ def test_context_manager(self):
+ sem = locks.Semaphore(2)
+
+ @tasks.task
+ def acquire_lock():
+ return (yield from sem)
+
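+        # The acquired semaphore is a context manager that releases on exit.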
+ with self.event_loop.run_until_complete(acquire_lock()):
+ self.assertFalse(sem.locked())
+ self.assertEqual(1, sem._value)
+
+ with self.event_loop.run_until_complete(acquire_lock()):
+ self.assertTrue(sem.locked())
+
+ self.assertEqual(2, sem._value)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/queues_test.py b/tests/queues_test.py
new file mode 100644
index 0000000..d86abd7
--- /dev/null
+++ b/tests/queues_test.py
@@ -0,0 +1,370 @@
+"""Tests for queues.py"""
+
+import unittest
+import queue
+
+from tulip import events
+from tulip import locks
+from tulip import queues
+from tulip import tasks
+
+
+class _QueueTestBase(unittest.TestCase):
+
+ def setUp(self):
+ self.event_loop = events.new_event_loop()
+ events.set_event_loop(self.event_loop)
+
+ def tearDown(self):
+ self.event_loop.close()
+
+
+class QueueBasicTests(_QueueTestBase):
+
+ def _test_repr_or_str(self, fn, expect_id):
+ """Test Queue's repr or str.
+
+ fn is repr or str. expect_id is True if we expect the Queue's id to
+ appear in fn(Queue()).
+ """
+ q = queues.Queue()
+ self.assertTrue(fn(q).startswith('<Queue'))
+ id_is_present = hex(id(q)) in fn(q)
+ self.assertEqual(expect_id, id_is_present)
+
+ @tasks.coroutine
+ def add_getter():
+ q = queues.Queue()
+ # Start a task that waits to get.
+ tasks.Task(q.get())
+ # Let it start waiting.
+ yield from tasks.sleep(0.1)
+ self.assertTrue('_getters[1]' in fn(q))
+
+ self.event_loop.run_until_complete(add_getter())
+
+ @tasks.coroutine
+ def add_putter():
+ q = queues.Queue(maxsize=1)
+ q.put_nowait(1)
+ # Start a task that waits to put.
+ tasks.Task(q.put(2))
+ # Let it start waiting.
+ yield from tasks.sleep(0.1)
+ self.assertTrue('_putters[1]' in fn(q))
+
+ self.event_loop.run_until_complete(add_putter())
+
+ q = queues.Queue()
+ q.put_nowait(1)
+ self.assertTrue('_queue=[1]' in fn(q))
+
+ def test_repr(self):
+ self._test_repr_or_str(repr, True)
+
+ def test_str(self):
+ self._test_repr_or_str(str, False)
+
+ def test_empty(self):
+ q = queues.Queue()
+ self.assertTrue(q.empty())
+ q.put_nowait(1)
+ self.assertFalse(q.empty())
+ self.assertEqual(1, q.get_nowait())
+ self.assertTrue(q.empty())
+
+ def test_full(self):
+ q = queues.Queue()
+ self.assertFalse(q.full())
+
+ q = queues.Queue(maxsize=1)
+ q.put_nowait(1)
+ self.assertTrue(q.full())
+
+ def test_order(self):
+ q = queues.Queue()
+ for i in [1, 3, 2]:
+ q.put_nowait(i)
+
+ items = [q.get_nowait() for _ in range(3)]
+ self.assertEqual([1, 3, 2], items)
+
+ def test_maxsize(self):
+ q = queues.Queue(maxsize=2)
+ self.assertEqual(2, q.maxsize)
+ have_been_put = []
+
+ @tasks.coroutine
+ def putter():
+ for i in range(3):
+ yield from q.put(i)
+ have_been_put.append(i)
+
+ @tasks.coroutine
+ def test():
+ tasks.Task(putter())
+ yield from tasks.sleep(0.1)
+
+ # The putter is blocked after putting two items.
+ self.assertEqual([0, 1], have_been_put)
+ self.assertEqual(0, q.get_nowait())
+
+ # Let the putter resume and put last item.
+ yield from tasks.sleep(0.1)
+ self.assertEqual([0, 1, 2], have_been_put)
+ self.assertEqual(1, q.get_nowait())
+ self.assertEqual(2, q.get_nowait())
+
+ self.event_loop.run_until_complete(test())
+
+
+class QueueGetTests(_QueueTestBase):
+
+ def test_blocking_get(self):
+ q = queues.Queue()
+ q.put_nowait(1)
+
+ @tasks.coroutine
+ def queue_get():
+ return (yield from q.get())
+
+ res = self.event_loop.run_until_complete(queue_get())
+ self.assertEqual(1, res)
+
+ def test_blocking_get_wait(self):
+ q = queues.Queue()
+ started = locks.EventWaiter()
+ finished = False
+
+ @tasks.coroutine
+ def queue_get():
+ nonlocal finished
+ started.set()
+ res = yield from q.get()
+ finished = True
+ return res
+
+ @tasks.coroutine
+ def queue_put():
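+            # Feed the queue after 0.1s; until then the getter is blocked.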
+ self.event_loop.call_later(0.1, q.put_nowait, 1)
+ queue_get_task = tasks.Task(queue_get())
+ yield from started.wait()
+ self.assertFalse(finished)
+ res = yield from queue_get_task
+ self.assertTrue(finished)
+ return res
+
+ res = self.event_loop.run_until_complete(queue_put())
+ self.assertEqual(1, res)
+
+ def test_nonblocking_get(self):
+ q = queues.Queue()
+ q.put_nowait(1)
+ self.assertEqual(1, q.get_nowait())
+
+ def test_nonblocking_get_exception(self):
+ q = queues.Queue()
+ self.assertRaises(queue.Empty, q.get_nowait)
+
+ def test_get_timeout(self):
+ q = queues.Queue()
+
+ @tasks.coroutine
+ def queue_get():
+ with self.assertRaises(queue.Empty):
+ return (yield from q.get(timeout=0.1))
+
+            # Get works after timeout, with blocking and non-blocking put.
+            q.put_nowait(1)
+            self.assertEqual(1, (yield from q.get()))
+
+            tasks.Task(q.put(2))
+            self.assertEqual(2, (yield from q.get()))
+
+        self.event_loop.run_until_complete(queue_get())
+
+ def test_get_timeout_cancelled(self):
+ q = queues.Queue()
+
+ @tasks.coroutine
+ def queue_get():
+ return (yield from q.get(timeout=0.2))
+
+ @tasks.coroutine
+ def test():
+ get_task = tasks.Task(queue_get())
+ yield from tasks.sleep(0.1) # let the task start
+ q.put_nowait(1)
+ return (yield from get_task)
+
+ self.assertEqual(1, self.event_loop.run_until_complete(test()))
+
+
+class QueuePutTests(_QueueTestBase):
+
+ def test_blocking_put(self):
+ q = queues.Queue()
+
+ @tasks.coroutine
+ def queue_put():
+ # No maxsize, won't block.
+ yield from q.put(1)
+
+ self.event_loop.run_until_complete(queue_put())
+
+ def test_blocking_put_wait(self):
+ q = queues.Queue(maxsize=1)
+ started = locks.EventWaiter()
+ finished = False
+
+ @tasks.coroutine
+ def queue_put():
+ nonlocal finished
+ started.set()
+ yield from q.put(1)
+ yield from q.put(2)
+ finished = True
+
+ @tasks.coroutine
+ def queue_get():
+ self.event_loop.call_later(0.1, q.get_nowait)
+ queue_put_task = tasks.Task(queue_put())
+ yield from started.wait()
+ self.assertFalse(finished)
+ yield from queue_put_task
+ self.assertTrue(finished)
+
+ self.event_loop.run_until_complete(queue_get())
+
+ def test_nonblocking_put(self):
+ q = queues.Queue()
+ q.put_nowait(1)
+ self.assertEqual(1, q.get_nowait())
+
+ def test_nonblocking_put_exception(self):
+ q = queues.Queue(maxsize=1)
+ q.put_nowait(1)
+ self.assertRaises(queue.Full, q.put_nowait, 2)
+
+ def test_put_timeout(self):
+ q = queues.Queue(1)
+ q.put_nowait(0)
+
+ @tasks.coroutine
+ def queue_put():
+ with self.assertRaises(queue.Full):
+ return (yield from q.put(1, timeout=0.1))
+
+            self.assertEqual(0, q.get_nowait())
+
+            # Put works after timeout, with blocking and non-blocking get.
+            get_task = tasks.Task(q.get())
+            # Let the get start waiting.
+            yield from tasks.sleep(0.1)
+            q.put_nowait(2)
+            self.assertEqual(2, (yield from get_task))
+
+            q.put_nowait(3)
+            self.assertEqual(3, q.get_nowait())
+
+        self.event_loop.run_until_complete(queue_put())
+
+ def test_put_timeout_cancelled(self):
+ q = queues.Queue()
+
+ @tasks.coroutine
+ def queue_put():
+ yield from q.put(1, timeout=0.1)
+
+ @tasks.coroutine
+ def test():
+ tasks.Task(queue_put())
+ return (yield from q.get())
+
+ self.assertEqual(1, self.event_loop.run_until_complete(test()))
+
+
+class LifoQueueTests(_QueueTestBase):
+
+ def test_order(self):
+ q = queues.LifoQueue()
+ for i in [1, 3, 2]:
+ q.put_nowait(i)
+
+ items = [q.get_nowait() for _ in range(3)]
+ self.assertEqual([2, 3, 1], items)
+
+
+class PriorityQueueTests(_QueueTestBase):
+
+ def test_order(self):
+ q = queues.PriorityQueue()
+ for i in [1, 3, 2]:
+ q.put_nowait(i)
+
+ items = [q.get_nowait() for _ in range(3)]
+ self.assertEqual([1, 2, 3], items)
+
+
+class JoinableQueueTests(_QueueTestBase):
+
+ def test_task_done_underflow(self):
+ q = queues.JoinableQueue()
+        self.assertRaises(ValueError, q.task_done)
+
+ def test_task_done(self):
+ q = queues.JoinableQueue()
+ for i in range(100):
+ q.put_nowait(i)
+
+ accumulator = 0
+
+ # Two workers get items from the queue and call task_done after each.
+ # Join the queue and assert all items have been processed.
+
+ @tasks.coroutine
+ def worker():
+ nonlocal accumulator
+
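+            # Workers loop forever; join() returns once every item is done.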
+ while True:
+ item = yield from q.get()
+ accumulator += item
+ q.task_done()
+
+ @tasks.coroutine
+ def test():
+ for _ in range(2):
+ tasks.Task(worker())
+
+ yield from q.join()
+
+ self.event_loop.run_until_complete(test())
+ self.assertEqual(sum(range(100)), accumulator)
+
+ def test_join_empty_queue(self):
+ q = queues.JoinableQueue()
+
+        # join() on an empty queue must finish immediately, before anything
+        # is ever put (done twice for insurance).
+
+ @tasks.coroutine
+ def join():
+ yield from q.join()
+ yield from q.join()
+
+ self.event_loop.run_until_complete(join())
+
+ def test_join_timeout(self):
+ q = queues.JoinableQueue()
+ q.put_nowait(1)
+
+ @tasks.coroutine
+ def join():
+ yield from q.join(0.1)
+
+ # Join completes in ~ 0.1 seconds, although no one calls task_done().
+ self.event_loop.run_until_complete(join())
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/sample.crt b/tests/sample.crt
new file mode 100644
index 0000000..6a1e3f3
--- /dev/null
+++ b/tests/sample.crt
@@ -0,0 +1,14 @@
+-----BEGIN CERTIFICATE-----
+MIICMzCCAZwCCQDFl4ys0fU7iTANBgkqhkiG9w0BAQUFADBeMQswCQYDVQQGEwJV
+UzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwNU2FuLUZyYW5jaXNjbzEi
+MCAGA1UECgwZUHl0aG9uIFNvZnR3YXJlIEZvbmRhdGlvbjAeFw0xMzAzMTgyMDA3
+MjhaFw0yMzAzMTYyMDA3MjhaMF4xCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxp
+Zm9ybmlhMRYwFAYDVQQHDA1TYW4tRnJhbmNpc2NvMSIwIAYDVQQKDBlQeXRob24g
+U29mdHdhcmUgRm9uZGF0aW9uMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCn
+t3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx/47Vc5TZSaO11uO7
+gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIiqusnLfpqR8cIAavg
+Z06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQABMA0GCSqGSIb3DQEB
+BQUAA4GBAE9PknG6pv72+5z/gsDGYy8sK5UNkbWSNr4i4e5lxVsF03+/M71H+3AB
+MxVX4+A+Vlk2fmU+BrdHIIUE0r1dDcO3josQ9hc9OJpp5VLSQFP8VeuJCmzYPp9I
+I8WbW93cnXnChTrYQVdgVoFdv7GE9YgU7NYkrGIM0nZl1/f/bHPB
+-----END CERTIFICATE-----
diff --git a/tests/sample.key b/tests/sample.key
new file mode 100644
index 0000000..edfea8d
--- /dev/null
+++ b/tests/sample.key
@@ -0,0 +1,15 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIICXQIBAAKBgQCnt3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx
+/47Vc5TZSaO11uO7gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIi
+qusnLfpqR8cIAavgZ06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQAB
+AoGABfm8k19Yue3W68BecKEGS0VBV57GRTPT+MiBGvVGNIQ15gk6w3sGfMZsdD1y
+bsUkQgcDb2d/4i5poBTpl/+Cd41V+c20IC/sSl5X1IEreHMKSLhy/uyjyiyfXlP1
+iXhToFCgLWwENWc8LzfUV8vuAV5WG6oL9bnudWzZxeqx8V0CQQDR7xwVj6LN70Eb
+DUhSKLkusmFw5Gk9NJ/7wZ4eHg4B8c9KNVvSlLCLhcsVTQXuqYeFpOqytI45SneP
+lr0vrvsDAkEAzITYiXu6ox5huDCG7imX2W9CAYuX638urLxBqBXMS7GqBzojD6RL
+21Q8oPwJWJquERa3HDScq1deiQbM9uKIkwJBAIa1PLslGN216Xv3UPHPScyKD/aF
+ynXIv+OnANPoiyp6RH4ksQ/18zcEGiVH8EeNpvV9tlAHhb+DZibQHgNr74sCQQC0
+zhToplu/bVKSlUQUNO0rqrI9z30FErDewKeCw5KSsIRSU1E/uM3fHr9iyq4wiL6u
+GNjUtKZ0y46lsT9uW6LFAkB5eqeEQnshAdr3X5GykWHJ8DDGBXPPn6Rce1NX4RSq
+V9khG2z1bFyfo+hMqpYnF2k32hVq3E54RS8YYnwBsVof
+-----END RSA PRIVATE KEY-----
diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py
new file mode 100644
index 0000000..c63db15
--- /dev/null
+++ b/tests/selector_events_test.py
@@ -0,0 +1,1286 @@
+"""Tests for selector_events.py"""
+
+import errno
+import socket
+import unittest
+import unittest.mock
+try:
+ import ssl
+except ImportError:
+ ssl = None
+
+from tulip import futures
+from tulip import selectors
+from tulip.events import AbstractEventLoop
+from tulip.protocols import DatagramProtocol, Protocol
+from tulip.selector_events import BaseSelectorEventLoop
+from tulip.selector_events import _SelectorSslTransport
+from tulip.selector_events import _SelectorSocketTransport
+from tulip.selector_events import _SelectorDatagramTransport
+
+
+class TestBaseSelectorEventLoop(BaseSelectorEventLoop):
+
+ def _make_self_pipe(self):
+ self._ssock = unittest.mock.Mock()
+ self._csock = unittest.mock.Mock()
+ self._internal_fds += 1
+
+
+class BaseSelectorEventLoopTests(unittest.TestCase):
+
+ def setUp(self):
+ self.event_loop = TestBaseSelectorEventLoop(unittest.mock.Mock())
+
+ def test_make_socket_transport(self):
+ m = unittest.mock.Mock()
+ self.event_loop.add_reader = unittest.mock.Mock()
+ self.assertIsInstance(
+ self.event_loop._make_socket_transport(m, m),
+ _SelectorSocketTransport)
+
+ def test_make_ssl_transport(self):
+ m = unittest.mock.Mock()
+ self.event_loop.add_reader = unittest.mock.Mock()
+ self.event_loop.add_writer = unittest.mock.Mock()
+ self.event_loop.remove_reader = unittest.mock.Mock()
+ self.event_loop.remove_writer = unittest.mock.Mock()
+ self.assertIsInstance(
+ self.event_loop._make_ssl_transport(m, m, m, m),
+ _SelectorSslTransport)
+
+ def test_close(self):
+ ssock = self.event_loop._ssock
+ ssock.fileno.return_value = 7
+ csock = self.event_loop._csock
+ csock.fileno.return_value = 1
+ remove_reader = self.event_loop.remove_reader = unittest.mock.Mock()
+
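+        # Closing the loop unregisters the self-pipe and closes both ends.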
+ self.event_loop._selector.close()
+ self.event_loop._selector = selector = unittest.mock.Mock()
+ self.event_loop.close()
+ self.assertIsNone(self.event_loop._selector)
+ self.assertIsNone(self.event_loop._csock)
+ self.assertIsNone(self.event_loop._ssock)
+ selector.close.assert_called_with()
+ ssock.close.assert_called_with()
+ csock.close.assert_called_with()
+ remove_reader.assert_called_with(7)
+
+ self.event_loop.close()
+ self.event_loop.close()
+
+ def test_close_no_selector(self):
+ ssock = self.event_loop._ssock
+ csock = self.event_loop._csock
+ remove_reader = self.event_loop.remove_reader = unittest.mock.Mock()
+
+ self.event_loop._selector.close()
+ self.event_loop._selector = None
+ self.event_loop.close()
+ self.assertIsNone(self.event_loop._selector)
+ self.assertFalse(ssock.close.called)
+ self.assertFalse(csock.close.called)
+ self.assertFalse(remove_reader.called)
+
+ def test_socketpair(self):
+ self.assertRaises(NotImplementedError, self.event_loop._socketpair)
+
+ def test_read_from_self_tryagain(self):
+ self.event_loop._ssock.recv.side_effect = BlockingIOError
+ self.assertIsNone(self.event_loop._read_from_self())
+
+ def test_read_from_self_exception(self):
+ self.event_loop._ssock.recv.side_effect = OSError
+ self.assertRaises(OSError, self.event_loop._read_from_self)
+
+ def test_write_to_self_tryagain(self):
+ self.event_loop._csock.send.side_effect = BlockingIOError
+ self.assertIsNone(self.event_loop._write_to_self())
+
+ def test_write_to_self_exception(self):
+ self.event_loop._csock.send.side_effect = OSError()
+ self.assertRaises(OSError, self.event_loop._write_to_self)
+
+ def test_sock_recv(self):
+ sock = unittest.mock.Mock()
+ self.event_loop._sock_recv = unittest.mock.Mock()
+
+ f = self.event_loop.sock_recv(sock, 1024)
+ self.assertIsInstance(f, futures.Future)
+ self.event_loop._sock_recv.assert_called_with(
+ f, False, sock, 1024)
+
+ def test__sock_recv_canceled_fut(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future()
+ f.cancel()
+
+ self.event_loop._sock_recv(f, False, sock, 1024)
+ self.assertFalse(sock.recv.called)
+
+ def test__sock_recv_unregister(self):
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+
+ f = futures.Future()
+ f.cancel()
+
+ self.event_loop.remove_reader = unittest.mock.Mock()
+ self.event_loop._sock_recv(f, True, sock, 1024)
+ self.assertEqual((10,), self.event_loop.remove_reader.call_args[0])
+
+ def test__sock_recv_tryagain(self):
+ f = futures.Future()
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.recv.side_effect = BlockingIOError
+
+ self.event_loop.add_reader = unittest.mock.Mock()
+ self.event_loop._sock_recv(f, False, sock, 1024)
+ self.assertEqual((10, self.event_loop._sock_recv, f, True, sock, 1024),
+ self.event_loop.add_reader.call_args[0])
+
+ def test__sock_recv_exception(self):
+ f = futures.Future()
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ err = sock.recv.side_effect = OSError()
+
+ self.event_loop._sock_recv(f, False, sock, 1024)
+ self.assertIs(err, f.exception())
+
+ def test_sock_sendall(self):
+ sock = unittest.mock.Mock()
+ self.event_loop._sock_sendall = unittest.mock.Mock()
+
+ f = self.event_loop.sock_sendall(sock, b'data')
+ self.assertIsInstance(f, futures.Future)
+ self.assertEqual(
+ (f, False, sock, b'data'),
+ self.event_loop._sock_sendall.call_args[0])
+
+ def test_sock_sendall_nodata(self):
+ sock = unittest.mock.Mock()
+ self.event_loop._sock_sendall = unittest.mock.Mock()
+
+ f = self.event_loop.sock_sendall(sock, b'')
+ self.assertIsInstance(f, futures.Future)
+ self.assertTrue(f.done())
+ self.assertFalse(self.event_loop._sock_sendall.called)
+
+ def test__sock_sendall_canceled_fut(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future()
+ f.cancel()
+
+ self.event_loop._sock_sendall(f, False, sock, b'data')
+ self.assertFalse(sock.send.called)
+
+ def test__sock_sendall_unregister(self):
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+
+ f = futures.Future()
+ f.cancel()
+
+ self.event_loop.remove_writer = unittest.mock.Mock()
+ self.event_loop._sock_sendall(f, True, sock, b'data')
+ self.assertEqual((10,), self.event_loop.remove_writer.call_args[0])
+
+ def test__sock_sendall_tryagain(self):
+ f = futures.Future()
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.send.side_effect = BlockingIOError
+
+ self.event_loop.add_writer = unittest.mock.Mock()
+ self.event_loop._sock_sendall(f, False, sock, b'data')
+ self.assertEqual(
+ (10, self.event_loop._sock_sendall, f, True, sock, b'data'),
+ self.event_loop.add_writer.call_args[0])
+
+ def test__sock_sendall_exception(self):
+ f = futures.Future()
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ err = sock.send.side_effect = OSError()
+
+ self.event_loop._sock_sendall(f, False, sock, b'data')
+ self.assertIs(f.exception(), err)
+
+ def test__sock_sendall(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future()
+ sock.fileno.return_value = 10
+ sock.send.return_value = 4
+
+ self.event_loop._sock_sendall(f, False, sock, b'data')
+ self.assertTrue(f.done())
+
+ def test__sock_sendall_partial(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future()
+ sock.fileno.return_value = 10
+ sock.send.return_value = 2
+
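+        # send() wrote 2 of 4 bytes; the tail b'ta' must be rescheduled.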
+ self.event_loop.add_writer = unittest.mock.Mock()
+ self.event_loop._sock_sendall(f, False, sock, b'data')
+ self.assertFalse(f.done())
+ self.assertEqual(
+ (10, self.event_loop._sock_sendall, f, True, sock, b'ta'),
+ self.event_loop.add_writer.call_args[0])
+
+ def test__sock_sendall_none(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future()
+ sock.fileno.return_value = 10
+ sock.send.return_value = 0
+
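+        # send() returning 0 wrote nothing; the whole payload is retried.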
+ self.event_loop.add_writer = unittest.mock.Mock()
+ self.event_loop._sock_sendall(f, False, sock, b'data')
+ self.assertFalse(f.done())
+ self.assertEqual(
+ (10, self.event_loop._sock_sendall, f, True, sock, b'data'),
+ self.event_loop.add_writer.call_args[0])
+
+ def test_sock_connect(self):
+ sock = unittest.mock.Mock()
+ self.event_loop._sock_connect = unittest.mock.Mock()
+
+ f = self.event_loop.sock_connect(sock, ('127.0.0.1', 8080))
+ self.assertIsInstance(f, futures.Future)
+ self.assertEqual(
+ (f, False, sock, ('127.0.0.1', 8080)),
+ self.event_loop._sock_connect.call_args[0])
+
+ def test__sock_connect(self):
+ f = futures.Future()
+
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+
+ self.event_loop._sock_connect(f, False, sock, ('127.0.0.1', 8080))
+ self.assertTrue(f.done())
+ self.assertTrue(sock.connect.called)
+
+ def test__sock_connect_canceled_fut(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future()
+ f.cancel()
+
+ self.event_loop._sock_connect(f, False, sock, ('127.0.0.1', 8080))
+ self.assertFalse(sock.connect.called)
+
+ def test__sock_connect_unregister(self):
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+
+ f = futures.Future()
+ f.cancel()
+
+ self.event_loop.remove_writer = unittest.mock.Mock()
+ self.event_loop._sock_connect(f, True, sock, ('127.0.0.1', 8080))
+ self.assertEqual((10,), self.event_loop.remove_writer.call_args[0])
+
+ def test__sock_connect_tryagain(self):
+ f = futures.Future()
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.getsockopt.return_value = errno.EAGAIN
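+        # EAGAIN from getsockopt() means the connect is still in progress.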
+
+ self.event_loop.add_writer = unittest.mock.Mock()
+ self.event_loop.remove_writer = unittest.mock.Mock()
+
+ self.event_loop._sock_connect(f, True, sock, ('127.0.0.1', 8080))
+ self.assertEqual(
+ (10, self.event_loop._sock_connect, f,
+ True, sock, ('127.0.0.1', 8080)),
+ self.event_loop.add_writer.call_args[0])
+
+ def test__sock_connect_exception(self):
+ f = futures.Future()
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.getsockopt.return_value = errno.ENOTCONN
+
+ self.event_loop.remove_writer = unittest.mock.Mock()
+ self.event_loop._sock_connect(f, True, sock, ('127.0.0.1', 8080))
+ self.assertIsInstance(f.exception(), socket.error)
+
+ def test_sock_accept(self):
+ sock = unittest.mock.Mock()
+ self.event_loop._sock_accept = unittest.mock.Mock()
+
+ f = self.event_loop.sock_accept(sock)
+ self.assertIsInstance(f, futures.Future)
+ self.assertEqual(
+ (f, False, sock), self.event_loop._sock_accept.call_args[0])
+
+ def test__sock_accept(self):
+ f = futures.Future()
+
+ conn = unittest.mock.Mock()
+
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.accept.return_value = conn, ('127.0.0.1', 1000)
+
+ self.event_loop._sock_accept(f, False, sock)
+ self.assertTrue(f.done())
+ self.assertEqual((conn, ('127.0.0.1', 1000)), f.result())
+ self.assertEqual((False,), conn.setblocking.call_args[0])
+
+ def test__sock_accept_canceled_fut(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future()
+ f.cancel()
+
+ self.event_loop._sock_accept(f, False, sock)
+ self.assertFalse(sock.accept.called)
+
+ def test__sock_accept_unregister(self):
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+
+ f = futures.Future()
+ f.cancel()
+
+ self.event_loop.remove_reader = unittest.mock.Mock()
+ self.event_loop._sock_accept(f, True, sock)
+ self.assertEqual((10,), self.event_loop.remove_reader.call_args[0])
+
+ def test__sock_accept_tryagain(self):
+ f = futures.Future()
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.accept.side_effect = BlockingIOError
+
+ self.event_loop.add_reader = unittest.mock.Mock()
+ self.event_loop._sock_accept(f, False, sock)
+ self.assertEqual(
+ (10, self.event_loop._sock_accept, f, True, sock),
+ self.event_loop.add_reader.call_args[0])
+
+ def test__sock_accept_exception(self):
+ f = futures.Future()
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ err = sock.accept.side_effect = OSError()
+
+ self.event_loop._sock_accept(f, False, sock)
+ self.assertIs(err, f.exception())
+
+ def test_add_reader(self):
+ self.event_loop._selector.get_info.side_effect = KeyError
+ h = self.event_loop.add_reader(1, lambda: True)
+
+ self.assertTrue(self.event_loop._selector.register.called)
+ self.assertEqual(
+ (1, selectors.EVENT_READ, (h, None)),
+ self.event_loop._selector.register.call_args[0])
+
+ def test_add_reader_existing(self):
+ reader = unittest.mock.Mock()
+ writer = unittest.mock.Mock()
+ self.event_loop._selector.get_info.return_value = (
+ selectors.EVENT_WRITE, (reader, writer))
+ h = self.event_loop.add_reader(1, lambda: True)
+
+ self.assertTrue(reader.cancel.called)
+ self.assertFalse(self.event_loop._selector.register.called)
+ self.assertTrue(self.event_loop._selector.modify.called)
+ self.assertEqual(
+ (1, selectors.EVENT_WRITE | selectors.EVENT_READ, (h, writer)),
+ self.event_loop._selector.modify.call_args[0])
+
+ def test_add_reader_existing_writer(self):
+ writer = unittest.mock.Mock()
+ self.event_loop._selector.get_info.return_value = (
+ selectors.EVENT_WRITE, (None, writer))
+ h = self.event_loop.add_reader(1, lambda: True)
+
+ self.assertFalse(self.event_loop._selector.register.called)
+ self.assertTrue(self.event_loop._selector.modify.called)
+ self.assertEqual(
+ (1, selectors.EVENT_WRITE | selectors.EVENT_READ, (h, writer)),
+ self.event_loop._selector.modify.call_args[0])
+
+ def test_remove_reader(self):
+ self.event_loop._selector.get_info.return_value = (
+ selectors.EVENT_READ, (None, None))
+ self.assertFalse(self.event_loop.remove_reader(1))
+
+ self.assertTrue(self.event_loop._selector.unregister.called)
+
+ def test_remove_reader_read_write(self):
+ reader = unittest.mock.Mock()
+ writer = unittest.mock.Mock()
+ self.event_loop._selector.get_info.return_value = (
+ selectors.EVENT_READ | selectors.EVENT_WRITE, (reader, writer))
+ self.assertTrue(
+ self.event_loop.remove_reader(1))
+
+ self.assertFalse(self.event_loop._selector.unregister.called)
+ self.assertEqual(
+ (1, selectors.EVENT_WRITE, (None, writer)),
+ self.event_loop._selector.modify.call_args[0])
+
+ def test_remove_reader_unknown(self):
+ self.event_loop._selector.get_info.side_effect = KeyError
+ self.assertFalse(
+ self.event_loop.remove_reader(1))
+
+ def test_add_writer(self):
+ self.event_loop._selector.get_info.side_effect = KeyError
+ h = self.event_loop.add_writer(1, lambda: True)
+
+ self.assertTrue(self.event_loop._selector.register.called)
+ self.assertEqual(
+ (1, selectors.EVENT_WRITE, (None, h)),
+ self.event_loop._selector.register.call_args[0])
+
+ def test_add_writer_existing(self):
+ reader = unittest.mock.Mock()
+ writer = unittest.mock.Mock()
+ self.event_loop._selector.get_info.return_value = (
+ selectors.EVENT_READ, (reader, writer))
+ h = self.event_loop.add_writer(1, lambda: True)
+
+ self.assertTrue(writer.cancel.called)
+ self.assertFalse(self.event_loop._selector.register.called)
+ self.assertTrue(self.event_loop._selector.modify.called)
+ self.assertEqual(
+ (1, selectors.EVENT_WRITE | selectors.EVENT_READ, (reader, h)),
+ self.event_loop._selector.modify.call_args[0])
+
+ def test_remove_writer(self):
+ self.event_loop._selector.get_info.return_value = (
+ selectors.EVENT_WRITE, (None, None))
+ self.assertFalse(self.event_loop.remove_writer(1))
+
+ self.assertTrue(self.event_loop._selector.unregister.called)
+
+ def test_remove_writer_read_write(self):
+ reader = unittest.mock.Mock()
+ writer = unittest.mock.Mock()
+ self.event_loop._selector.get_info.return_value = (
+ selectors.EVENT_READ | selectors.EVENT_WRITE, (reader, writer))
+ self.assertTrue(
+ self.event_loop.remove_writer(1))
+
+ self.assertFalse(self.event_loop._selector.unregister.called)
+ self.assertEqual(
+ (1, selectors.EVENT_READ, (reader, None)),
+ self.event_loop._selector.modify.call_args[0])
+
+ def test_remove_writer_unknown(self):
+ self.event_loop._selector.get_info.side_effect = KeyError
+ self.assertFalse(
+ self.event_loop.remove_writer(1))
+
+ def test_process_events_read(self):
+ reader = unittest.mock.Mock()
+ reader.cancelled = False
+
+ self.event_loop._add_callback = unittest.mock.Mock()
+ self.event_loop._process_events(
+ ((1, selectors.EVENT_READ, (reader, None)),))
+ self.assertTrue(self.event_loop._add_callback.called)
+ self.event_loop._add_callback.assert_called_with(reader)
+
+ def test_process_events_read_cancelled(self):
+ reader = unittest.mock.Mock()
+ reader.cancelled = True
+
+ self.event_loop.remove_reader = unittest.mock.Mock()
+ self.event_loop._process_events(
+ ((1, selectors.EVENT_READ, (reader, None)),))
+ self.event_loop.remove_reader.assert_called_with(1)
+
+ def test_process_events_write(self):
+ writer = unittest.mock.Mock()
+ writer.cancelled = False
+
+ self.event_loop._add_callback = unittest.mock.Mock()
+ self.event_loop._process_events(
+ ((1, selectors.EVENT_WRITE, (None, writer)),))
+ self.event_loop._add_callback.assert_called_with(writer)
+
+ def test_process_events_write_cancelled(self):
+ writer = unittest.mock.Mock()
+ writer.cancelled = True
+ self.event_loop.remove_writer = unittest.mock.Mock()
+
+ self.event_loop._process_events(
+ ((1, selectors.EVENT_WRITE, (None, writer)),))
+ self.event_loop.remove_writer.assert_called_with(1)
+
+
+class SelectorSocketTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.event_loop = unittest.mock.Mock(spec_set=AbstractEventLoop)
+ self.sock = unittest.mock.Mock(socket.socket)
+ self.sock.fileno.return_value = 7
+ self.protocol = unittest.mock.Mock(Protocol)
+
+ def test_ctor(self):
+ tr = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+ self.event_loop.add_reader.assert_called_with(7, tr._read_ready)
+ self.event_loop.call_soon.assert_called_with(
+ self.protocol.connection_made, tr)
+
+ def test_ctor_with_waiter(self):
+ fut = futures.Future()
+
+ _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol, fut)
+ self.assertEqual(2, self.event_loop.call_soon.call_count)
+ self.assertEqual(fut.set_result,
+ self.event_loop.call_soon.call_args[0][0])
+
+ def test_read_ready(self):
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+
+ self.sock.recv.return_value = b'data'
+ transport._read_ready()
+
+ self.protocol.data_received.assert_called_with(b'data')
+
+ def test_read_ready_eof(self):
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+
+ self.sock.recv.return_value = b''
+ transport._read_ready()
+
+ self.assertTrue(self.event_loop.remove_reader.called)
+ self.protocol.eof_received.assert_called_with()
+
+ @unittest.mock.patch('logging.exception')
+ def test_read_ready_tryagain(self, m_exc):
+ self.sock.recv.side_effect = BlockingIOError
+
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport._read_ready()
+
+ self.assertFalse(transport._fatal_error.called)
+
+ @unittest.mock.patch('logging.exception')
+ def test_read_ready_err(self, m_exc):
+ err = self.sock.recv.side_effect = OSError()
+
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport._read_ready()
+
+ transport._fatal_error.assert_called_with(err)
+
+ def test_abort(self):
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._close = unittest.mock.Mock()
+
+ transport.abort()
+ transport._close.assert_called_with(None)
+
+ def test_write(self):
+ data = b'data'
+ self.sock.send.return_value = len(data)
+
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport.write(data)
+ self.sock.send.assert_called_with(data)
+
+ def test_write_no_data(self):
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._buffer.append(b'data')
+ transport.write(b'')
+ self.assertFalse(self.sock.send.called)
+ self.assertEqual([b'data'], transport._buffer)
+
+ def test_write_buffer(self):
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._buffer.append(b'data1')
+ transport.write(b'data2')
+ self.assertFalse(self.sock.send.called)
+ self.assertEqual([b'data1', b'data2'], transport._buffer)
+
+ def test_write_partial(self):
+ data = b'data'
+ self.sock.send.return_value = 2
+
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport.write(data)
+
+ self.assertTrue(self.event_loop.add_writer.called)
+ self.assertEqual(
+ transport._write_ready, self.event_loop.add_writer.call_args[0][1])
+
+ self.assertEqual([b'ta'], transport._buffer)
+
+ def test_write_partial_none(self):
+ data = b'data'
+ self.sock.send.return_value = 0
+ self.sock.fileno.return_value = 7
+
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport.write(data)
+
+ self.event_loop.add_writer.assert_called_with(
+ 7, transport._write_ready)
+ self.assertEqual([b'data'], transport._buffer)
+
+ def test_write_tryagain(self):
+ self.sock.send.side_effect = BlockingIOError
+
+ data = b'data'
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport.write(data)
+
+ self.assertTrue(self.event_loop.add_writer.called)
+ self.assertEqual(
+ transport._write_ready, self.event_loop.add_writer.call_args[0][1])
+
+ self.assertEqual([b'data'], transport._buffer)
+
+ def test_write_exception(self):
+ err = self.sock.send.side_effect = OSError()
+
+ data = b'data'
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport.write(data)
+ transport._fatal_error.assert_called_with(err)
+
+ def test_write_str(self):
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+ self.assertRaises(AssertionError, transport.write, 'str')
+
+ def test_write_closing(self):
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport.close()
+ self.assertRaises(AssertionError, transport.write, b'data')
+
+ def test_write_ready(self):
+ data = b'data'
+ self.sock.send.return_value = len(data)
+
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._buffer.append(data)
+ transport._write_ready()
+ self.assertTrue(self.sock.send.called)
+ self.assertEqual(self.sock.send.call_args[0], (data,))
+ self.assertTrue(self.event_loop.remove_writer.called)
+
+ def test_write_ready_closing(self):
+ data = b'data'
+ self.sock.send.return_value = len(data)
+
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._closing = True
+ transport._buffer.append(data)
+ transport._write_ready()
+ self.sock.send.assert_called_with(data)
+ self.event_loop.remove_writer.assert_called_with(7)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_write_ready_no_data(self):
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+ self.assertRaises(AssertionError, transport._write_ready)
+
+ def test_write_ready_partial(self):
+ data = b'data'
+ self.sock.send.return_value = 2
+
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._buffer.append(data)
+ transport._write_ready()
+ self.assertFalse(self.event_loop.remove_writer.called)
+ self.assertEqual([b'ta'], transport._buffer)
+
+ def test_write_ready_partial_none(self):
+ data = b'data'
+ self.sock.send.return_value = 0
+
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._buffer.append(data)
+ transport._write_ready()
+ self.assertFalse(self.event_loop.remove_writer.called)
+ self.assertEqual([b'data'], transport._buffer)
+
+ def test_write_ready_tryagain(self):
+ self.sock.send.side_effect = BlockingIOError
+
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._buffer = [b'data1', b'data2']
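+        # On EAGAIN the pending chunks are coalesced and kept for retry.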
+ transport._write_ready()
+
+ self.assertFalse(self.event_loop.remove_writer.called)
+ self.assertEqual([b'data1data2'], transport._buffer)
+
+ def test_write_ready_exception(self):
+ err = self.sock.send.side_effect = OSError()
+
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport._buffer.append(b'data')
+ transport._write_ready()
+ transport._fatal_error.assert_called_with(err)
+
+ def test_close(self):
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport.close()
+
+ self.assertTrue(transport._closing)
+ self.event_loop.remove_reader.assert_called_with(7)
+        self.protocol.connection_lost.assert_called_with(None)
+
+ def test_close_write_buffer(self):
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+ self.event_loop.reset_mock()
+ transport._buffer.append(b'data')
+ transport.close()
+
+ self.assertTrue(self.event_loop.remove_reader.called)
+ self.assertFalse(self.event_loop.call_soon.called)
+
+ @unittest.mock.patch('tulip.log.tulip_log.exception')
+ def test_fatal_error(self, m_exc):
+ exc = OSError()
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._buffer.append(b'data')
+ transport._fatal_error(exc)
+
+ self.assertEqual([], transport._buffer)
+ self.event_loop.remove_reader.assert_called_with(7)
+ self.event_loop.remove_writer.assert_called_with(7)
+ self.protocol.connection_lost.assert_called_with(exc)
+ m_exc.assert_called_with('Fatal error for %s', transport)
+
+ def test_connection_lost(self):
+ exc = object()
+ transport = _SelectorSocketTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._call_connection_lost(exc)
+
+ self.protocol.connection_lost.assert_called_with(exc)
+ self.sock.close.assert_called_with()
+
+
+@unittest.skipIf(ssl is None, 'No ssl module')
+class SelectorSslTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.event_loop = unittest.mock.Mock(spec_set=AbstractEventLoop)
+ self.sock = unittest.mock.Mock(socket.socket)
+ self.sock.fileno.return_value = 7
+ self.protocol = unittest.mock.Mock(spec_set=Protocol)
+ self.sslsock = unittest.mock.Mock()
+ self.sslsock.fileno.return_value = 1
+ self.sslcontext = unittest.mock.Mock()
+ self.sslcontext.wrap_socket.return_value = self.sslsock
+ self.waiter = futures.Future()
+
+ self.transport = _SelectorSslTransport(
+ self.event_loop, self.sock,
+ self.protocol, self.sslcontext, self.waiter)
+ self.event_loop.reset_mock()
+ self.sock.reset_mock()
+ self.protocol.reset_mock()
+ self.sslcontext.reset_mock()
+
+ def test_on_handshake(self):
+ self.transport._on_handshake()
+ self.assertTrue(self.sslsock.do_handshake.called)
+ self.assertTrue(self.event_loop.remove_reader.called)
+ self.assertTrue(self.event_loop.remove_writer.called)
+ self.assertEqual(
+ (1, self.transport._on_ready,),
+ self.event_loop.add_reader.call_args[0])
+ self.assertEqual(
+ (1, self.transport._on_ready,),
+ self.event_loop.add_writer.call_args[0])
+
+ def test_on_handshake_reader_retry(self):
+ self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError
+ self.transport._on_handshake()
+ self.assertEqual(
+ (1, self.transport._on_handshake,),
+ self.event_loop.add_reader.call_args[0])
+
+ def test_on_handshake_writer_retry(self):
+ self.sslsock.do_handshake.side_effect = ssl.SSLWantWriteError
+ self.transport._on_handshake()
+ self.assertEqual(
+ (1, self.transport._on_handshake,),
+ self.event_loop.add_writer.call_args[0])
+
+ def test_on_handshake_exc(self):
+ self.sslsock.do_handshake.side_effect = ValueError
+ self.transport._on_handshake()
+ self.assertTrue(self.sslsock.close.called)
+
+ def test_on_handshake_base_exc(self):
+ self.sslsock.do_handshake.side_effect = BaseException
+ self.assertRaises(BaseException, self.transport._on_handshake)
+ self.assertTrue(self.sslsock.close.called)
+
+ def test_write_no_data(self):
+ self.transport._buffer.append(b'data')
+ self.transport.write(b'')
+ self.assertEqual([b'data'], self.transport._buffer)
+
+ def test_write_str(self):
+ self.assertRaises(AssertionError, self.transport.write, 'str')
+
+ def test_write_closing(self):
+ self.transport.close()
+ self.assertRaises(AssertionError, self.transport.write, b'data')
+
+ def test_abort(self):
+ self.transport._close = unittest.mock.Mock()
+ self.transport.abort()
+ self.transport._close.assert_called_with(None)
+
+ @unittest.mock.patch('tulip.log.tulip_log.exception')
+ def test_fatal_error(self, m_exc):
+ exc = OSError()
+ self.transport._buffer.append(b'data')
+ self.transport._fatal_error(exc)
+
+ self.assertEqual([], self.transport._buffer)
+ self.assertTrue(self.event_loop.remove_writer.called)
+ self.assertTrue(self.event_loop.remove_reader.called)
+ self.protocol.connection_lost.assert_called_with(exc)
+ m_exc.assert_called_with('Fatal error for %s', self.transport)
+
+ def test_close(self):
+ self.transport.close()
+ self.assertTrue(self.transport._closing)
+ self.assertTrue(self.event_loop.remove_reader.called)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_close_write_buffer(self):
+ self.transport._buffer.append(b'data')
+ self.transport.close()
+
+ self.assertTrue(self.event_loop.remove_reader.called)
+ self.assertFalse(self.event_loop.call_soon.called)
+
+ def test_on_ready_closed(self):
+ self.sslsock.fileno.return_value = -1
+ self.transport._on_ready()
+ self.assertFalse(self.sslsock.recv.called)
+
+ def test_on_ready_recv(self):
+ self.sslsock.recv.return_value = b'data'
+ self.transport._on_ready()
+ self.assertTrue(self.sslsock.recv.called)
+ self.assertEqual((b'data',), self.protocol.data_received.call_args[0])
+
+ def test_on_ready_recv_eof(self):
+ self.sslsock.recv.return_value = b''
+ self.transport._on_ready()
+ self.assertTrue(self.event_loop.remove_reader.called)
+ self.assertTrue(self.event_loop.remove_writer.called)
+ self.assertTrue(self.sslsock.close.called)
+ self.assertTrue(self.protocol.connection_lost.called)
+
+ def test_on_ready_recv_retry(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ self.transport._on_ready()
+ self.assertTrue(self.sslsock.recv.called)
+ self.assertFalse(self.protocol.data_received.called)
+
+ self.sslsock.recv.side_effect = ssl.SSLWantWriteError
+ self.transport._on_ready()
+ self.assertFalse(self.protocol.data_received.called)
+
+ self.sslsock.recv.side_effect = BlockingIOError
+ self.transport._on_ready()
+ self.assertFalse(self.protocol.data_received.called)
+
+ def test_on_ready_recv_exc(self):
+ err = self.sslsock.recv.side_effect = OSError()
+ self.transport._fatal_error = unittest.mock.Mock()
+ self.transport._on_ready()
+ self.transport._fatal_error.assert_called_with(err)
+
+ def test_on_ready_send(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ self.sslsock.send.return_value = 4
+ self.transport._buffer = [b'data']
+ self.transport._on_ready()
+ self.assertTrue(self.sslsock.send.called)
+ self.assertEqual([], self.transport._buffer)
+
+ def test_on_ready_send_none(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ self.sslsock.send.return_value = 0
+ self.transport._buffer = [b'data1', b'data2']
+ self.transport._on_ready()
+ self.assertTrue(self.sslsock.send.called)
+ self.assertEqual([b'data1data2'], self.transport._buffer)
+
+ def test_on_ready_send_partial(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ self.sslsock.send.return_value = 2
+ self.transport._buffer = [b'data1', b'data2']
+ self.transport._on_ready()
+ self.assertTrue(self.sslsock.send.called)
+ self.assertEqual([b'ta1data2'], self.transport._buffer)
+
+ def test_on_ready_send_closing_partial(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ self.sslsock.send.return_value = 2
+ self.transport._buffer = [b'data1', b'data2']
+ self.transport._on_ready()
+ self.assertTrue(self.sslsock.send.called)
+ self.assertFalse(self.sslsock.close.called)
+
+ def test_on_ready_send_closing(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ self.sslsock.send.return_value = 4
+ self.transport.close()
+ self.transport._buffer = [b'data']
+ self.transport._on_ready()
+ self.assertTrue(self.sslsock.close.called)
+ self.assertTrue(self.event_loop.remove_writer.called)
+ self.assertTrue(self.protocol.connection_lost.called)
+
+ def test_on_ready_send_retry(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+
+ self.transport._buffer = [b'data']
+
+ self.sslsock.send.side_effect = ssl.SSLWantReadError
+ self.transport._on_ready()
+ self.assertTrue(self.sslsock.send.called)
+ self.assertEqual([b'data'], self.transport._buffer)
+
+ self.sslsock.send.side_effect = ssl.SSLWantWriteError
+ self.transport._on_ready()
+ self.assertEqual([b'data'], self.transport._buffer)
+
+ self.sslsock.send.side_effect = BlockingIOError()
+ self.transport._on_ready()
+ self.assertEqual([b'data'], self.transport._buffer)
+
+ def test_on_ready_send_exc(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ err = self.sslsock.send.side_effect = OSError()
+
+ self.transport._buffer = [b'data']
+ self.transport._fatal_error = unittest.mock.Mock()
+ self.transport._on_ready()
+ self.transport._fatal_error.assert_called_with(err)
+ self.assertEqual([], self.transport._buffer)
+
+
+class SelectorDatagramTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.event_loop = unittest.mock.Mock(spec_set=AbstractEventLoop)
+ self.sock = unittest.mock.Mock(spec_set=socket.socket)
+ self.sock.fileno.return_value = 7
+ self.protocol = unittest.mock.Mock(spec_set=DatagramProtocol)
+
+ def test_read_ready(self):
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol)
+
+ self.sock.recvfrom.return_value = (b'data', ('0.0.0.0', 1234))
+ transport._read_ready()
+
+ self.protocol.datagram_received.assert_called_with(
+ b'data', ('0.0.0.0', 1234))
+
+ def test_read_ready_tryagain(self):
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol)
+
+ self.sock.recvfrom.side_effect = BlockingIOError
+ transport._fatal_error = unittest.mock.Mock()
+ transport._read_ready()
+
+ self.assertFalse(transport._fatal_error.called)
+
+ def test_read_ready_err(self):
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol)
+
+ err = self.sock.recvfrom.side_effect = OSError()
+ transport._fatal_error = unittest.mock.Mock()
+ transport._read_ready()
+
+ transport._fatal_error.assert_called_with(err)
+
+ def test_abort(self):
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._close = unittest.mock.Mock()
+
+ transport.abort()
+ transport._close.assert_called_with(None)
+
+ def test_sendto(self):
+ data = b'data'
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport.sendto(data, ('0.0.0.0', 1234))
+ self.assertTrue(self.sock.sendto.called)
+ self.assertEqual(
+ self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234)))
+
+ def test_sendto_no_data(self):
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._buffer.append((b'data', ('0.0.0.0', 12345)))
+ transport.sendto(b'', ())
+ self.assertFalse(self.sock.sendto.called)
+ self.assertEqual(
+ [(b'data', ('0.0.0.0', 12345))], list(transport._buffer))
+
+ def test_sendto_buffer(self):
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
+ transport.sendto(b'data2', ('0.0.0.0', 12345))
+ self.assertFalse(self.sock.sendto.called)
+ self.assertEqual(
+ [(b'data1', ('0.0.0.0', 12345)),
+ (b'data2', ('0.0.0.0', 12345))],
+ list(transport._buffer))
+
+ def test_sendto_tryagain(self):
+ data = b'data'
+
+ self.sock.sendto.side_effect = BlockingIOError
+
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport.sendto(data, ('0.0.0.0', 12345))
+
+ self.assertTrue(self.event_loop.add_writer.called)
+ self.assertEqual(
+ transport._sendto_ready,
+ self.event_loop.add_writer.call_args[0][1])
+
+ self.assertEqual(
+ [(b'data', ('0.0.0.0', 12345))], list(transport._buffer))
+
+ def test_sendto_exception(self):
+ data = b'data'
+ err = self.sock.sendto.side_effect = OSError()
+
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport.sendto(data, ())
+
+ self.assertTrue(transport._fatal_error.called)
+ transport._fatal_error.assert_called_with(err)
+
+ def test_sendto_connection_refused(self):
+ data = b'data'
+
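+        # Unconnected UDP: ECONNREFUSED is ignored rather than fatal.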
+ self.sock.sendto.side_effect = ConnectionRefusedError
+
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport.sendto(data, ())
+
+ self.assertFalse(transport._fatal_error.called)
+
+ def test_sendto_connection_refused_connected(self):
+ data = b'data'
+
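+        # Connected UDP: ECONNREFUSED tears the transport down.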
+ self.sock.send.side_effect = ConnectionRefusedError
+
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1))
+ transport._fatal_error = unittest.mock.Mock()
+ transport.sendto(data)
+
+ self.assertTrue(transport._fatal_error.called)
+
+ def test_sendto_str(self):
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol)
+ self.assertRaises(AssertionError, transport.sendto, 'str', ())
+
+ def test_sendto_connected_addr(self):
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1))
+ self.assertRaises(
+ AssertionError, transport.sendto, b'str', ('0.0.0.0', 2))
+
+ def test_sendto_closing(self):
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport.close()
+ self.assertRaises(AssertionError, transport.sendto, b'data', ())
+
+ def test_sendto_ready(self):
+ data = b'data'
+ self.sock.sendto.return_value = len(data)
+
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._buffer.append((data, ('0.0.0.0', 12345)))
+ transport._sendto_ready()
+ self.assertTrue(self.sock.sendto.called)
+ self.assertEqual(
+ self.sock.sendto.call_args[0], (data, ('0.0.0.0', 12345)))
+ self.assertTrue(self.event_loop.remove_writer.called)
+
+ def test_sendto_ready_closing(self):
+ data = b'data'
+        self.sock.sendto.return_value = len(data)
+
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._closing = True
+ transport._buffer.append((data, ()))
+ transport._sendto_ready()
+ self.sock.sendto.assert_called_with(data, ())
+ self.event_loop.remove_writer.assert_called_with(7)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_sendto_ready_no_data(self):
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._sendto_ready()
+ self.assertFalse(self.sock.sendto.called)
+ self.assertTrue(self.event_loop.remove_writer.called)
+
+ def test_sendto_ready_tryagain(self):
+ self.sock.sendto.side_effect = BlockingIOError
+
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._buffer.extend([(b'data1', ()), (b'data2', ())])
+ transport._sendto_ready()
+
+ self.assertFalse(self.event_loop.remove_writer.called)
+ self.assertEqual(
+ [(b'data1', ()), (b'data2', ())],
+ list(transport._buffer))
+
+ def test_sendto_ready_exception(self):
+ err = self.sock.sendto.side_effect = OSError()
+
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport._buffer.append((b'data', ()))
+ transport._sendto_ready()
+
+ transport._fatal_error.assert_called_with(err)
+
+ def test_sendto_ready_connection_refused(self):
+ self.sock.sendto.side_effect = ConnectionRefusedError
+
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport._buffer.append((b'data', ()))
+ transport._sendto_ready()
+
+ self.assertFalse(transport._fatal_error.called)
+
+    def test_sendto_ready_connection_refused_connected(self):
+ self.sock.send.side_effect = ConnectionRefusedError
+
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1))
+ transport._fatal_error = unittest.mock.Mock()
+ transport._buffer.append((b'data', ()))
+ transport._sendto_ready()
+
+ self.assertTrue(transport._fatal_error.called)
+
+ def test_close(self):
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport.close()
+
+ self.assertTrue(transport._closing)
+ self.event_loop.remove_reader.assert_called_with(7)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_close_write_buffer(self):
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._buffer.append((b'data', ()))
+ transport.close()
+
+ self.event_loop.remove_reader.assert_called_with(7)
+ self.assertFalse(self.protocol.connection_lost.called)
+
+ @unittest.mock.patch('tulip.log.tulip_log.exception')
+ def test_fatal_error(self, m_exc):
+ exc = OSError()
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol)
+ self.event_loop.reset_mock()
+ transport._buffer.append((b'data', ()))
+ transport._fatal_error(exc)
+
+ self.assertEqual([], list(transport._buffer))
+ self.event_loop.remove_writer.assert_called_with(7)
+ self.event_loop.remove_reader.assert_called_with(7)
+ self.protocol.connection_lost.assert_called_with(exc)
+ m_exc.assert_called_with('Fatal error for %s', transport)
+
+ @unittest.mock.patch('tulip.log.tulip_log.exception')
+ def test_fatal_error_connected(self, m_exc):
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1))
+ err = ConnectionRefusedError()
+ transport._fatal_error(err)
+ self.protocol.connection_refused.assert_called_with(err)
+ m_exc.assert_called_with('Fatal error for %s', transport)
+
+ def test_transport_closing(self):
+ exc = object()
+ transport = _SelectorDatagramTransport(
+ self.event_loop, self.sock, self.protocol)
+ transport._call_connection_lost(exc)
+
+ self.protocol.connection_lost.assert_called_with(exc)
+ self.sock.close.assert_called_with()
diff --git a/tests/selectors_test.py b/tests/selectors_test.py
new file mode 100644
index 0000000..996c013
--- /dev/null
+++ b/tests/selectors_test.py
@@ -0,0 +1,137 @@
+"""Tests for selectors.py."""
+
+import unittest
+import unittest.mock
+
+from tulip import selectors
+
+
+class BaseSelectorTests(unittest.TestCase):
+
+ def test_fileobj_to_fd(self):
+ self.assertEqual(10, selectors._fileobj_to_fd(10))
+
+ f = unittest.mock.Mock()
+ f.fileno.return_value = 10
+ self.assertEqual(10, selectors._fileobj_to_fd(f))
+
+ f.fileno.side_effect = TypeError
+ self.assertRaises(ValueError, selectors._fileobj_to_fd, f)
+
+ def test_selector_key_repr(self):
+ key = selectors.SelectorKey(10, selectors.EVENT_READ)
+ self.assertEqual(
+ "SelectorKey<fileobj=10, fd=10, events=0x1, data=None>", repr(key))
+
+ def test_register(self):
+ fobj = unittest.mock.Mock()
+ fobj.fileno.return_value = 10
+
+ s = selectors._BaseSelector()
+ key = s.register(fobj, selectors.EVENT_READ)
+ self.assertIsInstance(key, selectors.SelectorKey)
+ self.assertEqual(key.fd, 10)
+ self.assertIs(key, s._fd_to_key[10])
+
+ def test_register_unknown_event(self):
+ s = selectors._BaseSelector()
+ self.assertRaises(ValueError, s.register, unittest.mock.Mock(), 999999)
+
+ def test_register_already_registered(self):
+ fobj = unittest.mock.Mock()
+ fobj.fileno.return_value = 10
+
+ s = selectors._BaseSelector()
+ s.register(fobj, selectors.EVENT_READ)
+ self.assertRaises(ValueError, s.register, fobj, selectors.EVENT_READ)
+
+ def test_unregister(self):
+ fobj = unittest.mock.Mock()
+ fobj.fileno.return_value = 10
+
+ s = selectors._BaseSelector()
+ s.register(fobj, selectors.EVENT_READ)
+ s.unregister(fobj)
+ self.assertFalse(s._fd_to_key)
+ self.assertFalse(s._fileobj_to_key)
+
+ def test_unregister_unknown(self):
+ s = selectors._BaseSelector()
+ self.assertRaises(ValueError, s.unregister, unittest.mock.Mock())
+
+ def test_modify_unknown(self):
+ s = selectors._BaseSelector()
+ self.assertRaises(ValueError, s.modify, unittest.mock.Mock(), 1)
+
+ def test_modify(self):
+ fobj = unittest.mock.Mock()
+ fobj.fileno.return_value = 10
+
+ s = selectors._BaseSelector()
+ key = s.register(fobj, selectors.EVENT_READ)
+ key2 = s.modify(fobj, selectors.EVENT_WRITE)
+ self.assertNotEqual(key.events, key2.events)
+ self.assertEqual((selectors.EVENT_WRITE, None), s.get_info(fobj))
+
+ def test_modify_data(self):
+ fobj = unittest.mock.Mock()
+ fobj.fileno.return_value = 10
+
+ d1 = object()
+ d2 = object()
+
+ s = selectors._BaseSelector()
+ key = s.register(fobj, selectors.EVENT_READ, d1)
+ key2 = s.modify(fobj, selectors.EVENT_READ, d2)
+ self.assertEqual(key.events, key2.events)
+ self.assertNotEqual(key.data, key2.data)
+ self.assertEqual((selectors.EVENT_READ, d2), s.get_info(fobj))
+
+ def test_modify_same(self):
+ fobj = unittest.mock.Mock()
+ fobj.fileno.return_value = 10
+
+ data = object()
+
+ s = selectors._BaseSelector()
+ key = s.register(fobj, selectors.EVENT_READ, data)
+ key2 = s.modify(fobj, selectors.EVENT_READ, data)
+ self.assertIs(key, key2)
+
+ def test_select(self):
+ s = selectors._BaseSelector()
+ self.assertRaises(NotImplementedError, s.select)
+
+ def test_close(self):
+ s = selectors._BaseSelector()
+ s.register(1, selectors.EVENT_READ)
+
+ s.close()
+ self.assertFalse(s._fd_to_key)
+ self.assertFalse(s._fileobj_to_key)
+
+ def test_registered_count(self):
+ s = selectors._BaseSelector()
+ self.assertEqual(0, s.registered_count())
+
+ s.register(1, selectors.EVENT_READ)
+ self.assertEqual(1, s.registered_count())
+
+ s.unregister(1)
+ self.assertEqual(0, s.registered_count())
+
+ def test_context_manager(self):
+ s = selectors._BaseSelector()
+
+ with s as sel:
+ sel.register(1, selectors.EVENT_READ)
+
+ self.assertFalse(s._fd_to_key)
+ self.assertFalse(s._fileobj_to_key)
+
+ def test_key_from_fd(self):
+ s = selectors._BaseSelector()
+ key = s.register(1, selectors.EVENT_READ)
+
+ self.assertIs(key, s._key_from_fd(1))
+ self.assertIsNone(s._key_from_fd(10))
diff --git a/tests/streams_test.py b/tests/streams_test.py
new file mode 100644
index 0000000..dc6eeaf
--- /dev/null
+++ b/tests/streams_test.py
@@ -0,0 +1,299 @@
+"""Tests for streams.py."""
+
+import unittest
+
+from tulip import events
+from tulip import streams
+from tulip import tasks
+from tulip import test_utils
+
+
+class StreamReaderTests(test_utils.LogTrackingTestCase):
+
+ DATA = b'line1\nline2\nline3\n'
+
+ def setUp(self):
+ super().setUp()
+ self.event_loop = events.new_event_loop()
+ events.set_event_loop(self.event_loop)
+
+ def tearDown(self):
+ super().tearDown()
+ self.event_loop.close()
+
+ def test_feed_empty_data(self):
+ stream = streams.StreamReader()
+
+ stream.feed_data(b'')
+ self.assertEqual(0, stream.byte_count)
+
+ def test_feed_data_byte_count(self):
+ stream = streams.StreamReader()
+
+ stream.feed_data(self.DATA)
+ self.assertEqual(len(self.DATA), stream.byte_count)
+
+ def test_read_zero(self):
+ # Read zero bytes.
+ stream = streams.StreamReader()
+ stream.feed_data(self.DATA)
+
+ data = self.event_loop.run_until_complete(stream.read(0))
+ self.assertEqual(b'', data)
+ self.assertEqual(len(self.DATA), stream.byte_count)
+
+ def test_read(self):
+ # Read bytes.
+ stream = streams.StreamReader()
+ read_task = tasks.Task(stream.read(30))
+
+ def cb():
+ stream.feed_data(self.DATA)
+ self.event_loop.call_soon(cb)
+
+ data = self.event_loop.run_until_complete(read_task)
+ self.assertEqual(self.DATA, data)
+ self.assertFalse(stream.byte_count)
+
+ def test_read_line_breaks(self):
+ # Read bytes without line breaks.
+ stream = streams.StreamReader()
+ stream.feed_data(b'line1')
+ stream.feed_data(b'line2')
+
+ data = self.event_loop.run_until_complete(stream.read(5))
+
+ self.assertEqual(b'line1', data)
+ self.assertEqual(5, stream.byte_count)
+
+ def test_read_eof(self):
+ # Read bytes, stop at eof.
+ stream = streams.StreamReader()
+ read_task = tasks.Task(stream.read(1024))
+
+ def cb():
+ stream.feed_eof()
+ self.event_loop.call_soon(cb)
+
+ data = self.event_loop.run_until_complete(read_task)
+ self.assertEqual(b'', data)
+ self.assertFalse(stream.byte_count)
+
+ def test_read_until_eof(self):
+ # Read all bytes until eof.
+ stream = streams.StreamReader()
+ read_task = tasks.Task(stream.read(-1))
+
+ def cb():
+ stream.feed_data(b'chunk1\n')
+ stream.feed_data(b'chunk2')
+ stream.feed_eof()
+ self.event_loop.call_soon(cb)
+
+ data = self.event_loop.run_until_complete(read_task)
+
+ self.assertEqual(b'chunk1\nchunk2', data)
+ self.assertFalse(stream.byte_count)
+
+ def test_read_exception(self):
+ stream = streams.StreamReader()
+ stream.feed_data(b'line\n')
+
+ data = self.event_loop.run_until_complete(stream.read(2))
+ self.assertEqual(b'li', data)
+
+ stream.set_exception(ValueError())
+ self.assertRaises(
+ ValueError,
+ self.event_loop.run_until_complete, stream.read(2))
+
+ def test_readline(self):
+ # Read one line.
+ stream = streams.StreamReader()
+ stream.feed_data(b'chunk1 ')
+ read_task = tasks.Task(stream.readline())
+
+ def cb():
+ stream.feed_data(b'chunk2 ')
+ stream.feed_data(b'chunk3 ')
+ stream.feed_data(b'\n chunk4')
+ self.event_loop.call_soon(cb)
+
+ line = self.event_loop.run_until_complete(read_task)
+ self.assertEqual(b'chunk1 chunk2 chunk3 \n', line)
+ self.assertEqual(len(b'\n chunk4')-1, stream.byte_count)
+
+ def test_readline_limit_with_existing_data(self):
+ self.suppress_log_errors()
+
+ stream = streams.StreamReader(3)
+ stream.feed_data(b'li')
+ stream.feed_data(b'ne1\nline2\n')
+
+ self.assertRaises(
+ ValueError, self.event_loop.run_until_complete, stream.readline())
+ self.assertEqual([b'line2\n'], list(stream.buffer))
+
+ stream = streams.StreamReader(3)
+ stream.feed_data(b'li')
+ stream.feed_data(b'ne1')
+ stream.feed_data(b'li')
+
+ self.assertRaises(
+ ValueError, self.event_loop.run_until_complete, stream.readline())
+ self.assertEqual([b'li'], list(stream.buffer))
+ self.assertEqual(2, stream.byte_count)
+
+ def test_readline_limit(self):
+ self.suppress_log_errors()
+
+ stream = streams.StreamReader(7)
+
+ def cb():
+ stream.feed_data(b'chunk1')
+ stream.feed_data(b'chunk2')
+ stream.feed_data(b'chunk3\n')
+ stream.feed_eof()
+ self.event_loop.call_soon(cb)
+
+ self.assertRaises(
+ ValueError, self.event_loop.run_until_complete, stream.readline())
+ self.assertEqual([b'chunk3\n'], list(stream.buffer))
+ self.assertEqual(7, stream.byte_count)
+
+ def test_readline_line_byte_count(self):
+ stream = streams.StreamReader()
+ stream.feed_data(self.DATA[:6])
+ stream.feed_data(self.DATA[6:])
+
+ line = self.event_loop.run_until_complete(stream.readline())
+
+ self.assertEqual(b'line1\n', line)
+ self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count)
+
+ def test_readline_eof(self):
+ stream = streams.StreamReader()
+ stream.feed_data(b'some data')
+ stream.feed_eof()
+
+ line = self.event_loop.run_until_complete(stream.readline())
+ self.assertEqual(b'some data', line)
+
+ def test_readline_empty_eof(self):
+ stream = streams.StreamReader()
+ stream.feed_eof()
+
+ line = self.event_loop.run_until_complete(stream.readline())
+ self.assertEqual(b'', line)
+
+ def test_readline_read_byte_count(self):
+ stream = streams.StreamReader()
+ stream.feed_data(self.DATA)
+
+ self.event_loop.run_until_complete(stream.readline())
+
+ data = self.event_loop.run_until_complete(stream.read(7))
+
+ self.assertEqual(b'line2\nl', data)
+ self.assertEqual(
+ len(self.DATA) - len(b'line1\n') - len(b'line2\nl'),
+ stream.byte_count)
+
+ def test_readline_exception(self):
+ stream = streams.StreamReader()
+ stream.feed_data(b'line\n')
+
+ data = self.event_loop.run_until_complete(stream.readline())
+ self.assertEqual(b'line\n', data)
+
+ stream.set_exception(ValueError())
+ self.assertRaises(
+ ValueError,
+ self.event_loop.run_until_complete, stream.readline())
+
+ def test_readexactly_zero_or_less(self):
+ # Read exact number of bytes (zero or less).
+ stream = streams.StreamReader()
+ stream.feed_data(self.DATA)
+
+ data = self.event_loop.run_until_complete(stream.readexactly(0))
+ self.assertEqual(b'', data)
+ self.assertEqual(len(self.DATA), stream.byte_count)
+
+ data = self.event_loop.run_until_complete(stream.readexactly(-1))
+ self.assertEqual(b'', data)
+ self.assertEqual(len(self.DATA), stream.byte_count)
+
+ def test_readexactly(self):
+ # Read exact number of bytes.
+ stream = streams.StreamReader()
+
+ n = 2 * len(self.DATA)
+ read_task = tasks.Task(stream.readexactly(n))
+
+ def cb():
+ stream.feed_data(self.DATA)
+ stream.feed_data(self.DATA)
+ stream.feed_data(self.DATA)
+ self.event_loop.call_soon(cb)
+
+ data = self.event_loop.run_until_complete(read_task)
+ self.assertEqual(self.DATA + self.DATA, data)
+ self.assertEqual(len(self.DATA), stream.byte_count)
+
+ def test_readexactly_eof(self):
+ # Read exact number of bytes (eof).
+ stream = streams.StreamReader()
+ n = 2 * len(self.DATA)
+ read_task = tasks.Task(stream.readexactly(n))
+
+ def cb():
+ stream.feed_data(self.DATA)
+ stream.feed_eof()
+ self.event_loop.call_soon(cb)
+
+ data = self.event_loop.run_until_complete(read_task)
+ self.assertEqual(self.DATA, data)
+ self.assertFalse(stream.byte_count)
+
+ def test_readexactly_exception(self):
+ stream = streams.StreamReader()
+ stream.feed_data(b'line\n')
+
+ data = self.event_loop.run_until_complete(stream.readexactly(2))
+ self.assertEqual(b'li', data)
+
+ stream.set_exception(ValueError())
+ self.assertRaises(
+ ValueError,
+ self.event_loop.run_until_complete,
+ stream.readexactly(2))
+
+ def test_exception(self):
+ stream = streams.StreamReader()
+ self.assertIsNone(stream.exception())
+
+ exc = ValueError()
+ stream.set_exception(exc)
+ self.assertIs(stream.exception(), exc)
+
+ def test_exception_waiter(self):
+ stream = streams.StreamReader()
+
+ def set_err():
+ yield from []
+ stream.set_exception(ValueError())
+
+ def readline():
+ yield from stream.readline()
+
+ t1 = tasks.Task(stream.readline())
+ t2 = tasks.Task(set_err())
+
+ self.event_loop.run_until_complete(tasks.wait([t1, t2]))
+
+ self.assertRaises(ValueError, t1.result)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/subprocess_test.py b/tests/subprocess_test.py
new file mode 100644
index 0000000..09aaed5
--- /dev/null
+++ b/tests/subprocess_test.py
@@ -0,0 +1,54 @@
+"""Tests for subprocess_transport.py."""
+
+import logging
+import unittest
+
+from tulip import events
+from tulip import protocols
+from tulip import subprocess_transport
+
+
+class MyProto(protocols.Protocol):
+
+ def __init__(self):
+ self.state = 'INITIAL'
+ self.nbytes = 0
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+ transport.write_eof()
+
+ def data_received(self, data):
+ logging.info('received: %r', data)
+ assert self.state == 'CONNECTED', self.state
+ self.nbytes += len(data)
+
+ def eof_received(self):
+ assert self.state == 'CONNECTED', self.state
+ self.state = 'EOF'
+ self.transport.close()
+
+ def connection_lost(self, exc):
+ assert self.state in ('CONNECTED', 'EOF'), self.state
+ self.state = 'CLOSED'
+
+
+class SubprocessTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.event_loop = events.new_event_loop()
+ events.set_event_loop(self.event_loop)
+
+ def tearDown(self):
+ self.event_loop.close()
+
+ def test_unix_subprocess(self):
+ p = MyProto()
+ subprocess_transport.UnixSubprocessTransport(p, ['/bin/ls', '-lR'])
+ self.event_loop.run()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/tasks_test.py b/tests/tasks_test.py
new file mode 100644
index 0000000..9ac15bb
--- /dev/null
+++ b/tests/tasks_test.py
@@ -0,0 +1,647 @@
+"""Tests for tasks.py."""
+
+import concurrent.futures
+import time
+import unittest
+import unittest.mock
+
+from tulip import events
+from tulip import futures
+from tulip import tasks
+from tulip import test_utils
+
+
+class Dummy:
+
+ def __repr__(self):
+ return 'Dummy()'
+
+ def __call__(self, *args):
+ pass
+
+
+class TaskTests(test_utils.LogTrackingTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.event_loop = events.new_event_loop()
+ events.set_event_loop(self.event_loop)
+
+ def tearDown(self):
+ self.event_loop.close()
+ super().tearDown()
+
+ def test_task_class(self):
+ @tasks.coroutine
+ def notmuch():
+ yield from []
+ return 'ok'
+ t = tasks.Task(notmuch())
+ self.event_loop.run()
+ self.assertTrue(t.done())
+ self.assertEqual(t.result(), 'ok')
+ self.assertIs(t._event_loop, self.event_loop)
+
+ event_loop = events.new_event_loop()
+ t = tasks.Task(notmuch(), event_loop=event_loop)
+ self.assertIs(t._event_loop, event_loop)
+
+ def test_task_decorator(self):
+ @tasks.task
+ def notmuch():
+ yield from []
+ return 'ko'
+ t = notmuch()
+ self.event_loop.run()
+ self.assertTrue(t.done())
+ self.assertEqual(t.result(), 'ko')
+
+ def test_task_repr(self):
+ @tasks.task
+ def notmuch():
+ yield from []
+ return 'abc'
+ t = notmuch()
+ t.add_done_callback(Dummy())
+ self.assertEqual(repr(t), 'Task(<notmuch>)<PENDING, [Dummy()]>')
+ t.cancel() # Does not take immediate effect!
+ self.assertEqual(repr(t), 'Task(<notmuch>)<CANCELLING, [Dummy()]>')
+ self.assertRaises(futures.CancelledError,
+ self.event_loop.run_until_complete, t)
+ self.assertEqual(repr(t), 'Task(<notmuch>)<CANCELLED>')
+ t = notmuch()
+ self.event_loop.run_until_complete(t)
+ self.assertEqual(repr(t), "Task(<notmuch>)<result='abc'>")
+
+ def test_task_repr_custom(self):
+ def coro():
+ yield from []
+
+ class T(futures.Future):
+ def __repr__(self):
+ return 'T[]'
+
+ class MyTask(tasks.Task, T):
+ def __repr__(self):
+ return super().__repr__()
+
+ t = MyTask(coro())
+ self.assertEqual(repr(t), 'T[](<coro>)')
+
+ def test_task_basics(self):
+ @tasks.task
+ def outer():
+ a = yield from inner1()
+ b = yield from inner2()
+ return a+b
+
+ @tasks.task
+ def inner1():
+ yield from []
+ return 42
+
+ @tasks.task
+ def inner2():
+ yield from []
+ return 1000
+
+ t = outer()
+ self.assertEqual(self.event_loop.run_until_complete(t), 1042)
+
+ def test_cancel(self):
+ @tasks.task
+ def task():
+ yield from tasks.sleep(10.0)
+ return 12
+
+ t = task()
+ self.event_loop.call_soon(t.cancel)
+ self.assertRaises(
+ futures.CancelledError,
+ self.event_loop.run_until_complete, t)
+ self.assertTrue(t.done())
+ self.assertFalse(t.cancel())
+
+ def test_future_timeout(self):
+ @tasks.coroutine
+ def coro():
+ yield from tasks.sleep(10.0)
+ return 12
+
+ t = tasks.Task(coro(), timeout=0.1)
+
+ self.assertRaises(
+ futures.CancelledError,
+ self.event_loop.run_until_complete, t)
+ self.assertTrue(t.done())
+ self.assertFalse(t.cancel())
+
+ def test_future_timeout_catch(self):
+ @tasks.coroutine
+ def coro():
+ yield from tasks.sleep(10.0)
+ return 12
+
+ err = None
+
+ @tasks.coroutine
+ def coro2():
+ nonlocal err
+ try:
+ yield from tasks.Task(coro(), timeout=0.1)
+ except futures.CancelledError as exc:
+ err = exc
+
+ self.event_loop.run_until_complete(tasks.Task(coro2()))
+ self.assertIsInstance(err, futures.CancelledError)
+
+ def test_cancel_in_coro(self):
+ @tasks.coroutine
+ def task():
+ t.cancel()
+ yield from []
+ return 12
+
+ t = tasks.Task(task())
+ self.assertRaises(
+ futures.CancelledError,
+ self.event_loop.run_until_complete, t)
+ self.assertTrue(t.done())
+ self.assertFalse(t.cancel())
+
+ def test_stop_while_run_in_complete(self):
+ x = 0
+
+ @tasks.coroutine
+ def task():
+ nonlocal x
+ while x < 10:
+ yield from tasks.sleep(0.1)
+ x += 1
+ if x == 2:
+ self.event_loop.stop()
+
+ t = tasks.Task(task())
+ t0 = time.monotonic()
+ self.assertRaises(
+ futures.InvalidStateError,
+ self.event_loop.run_until_complete, t)
+ t1 = time.monotonic()
+ self.assertFalse(t.done())
+ self.assertTrue(0.18 <= t1-t0 <= 0.22)
+ self.assertEqual(x, 2)
+
+ def test_timeout(self):
+ @tasks.task
+ def task():
+ yield from tasks.sleep(10.0)
+ return 42
+
+ t = task()
+ t0 = time.monotonic()
+ self.assertRaises(
+ futures.TimeoutError,
+ self.event_loop.run_until_complete, t, 0.1)
+ t1 = time.monotonic()
+ self.assertFalse(t.done())
+ self.assertTrue(0.08 <= t1-t0 <= 0.12)
+
+ def test_timeout_not(self):
+ @tasks.task
+ def task():
+ yield from tasks.sleep(0.1)
+ return 42
+
+ t = task()
+ t0 = time.monotonic()
+ r = self.event_loop.run_until_complete(t, 10.0)
+ t1 = time.monotonic()
+ self.assertTrue(t.done())
+ self.assertEqual(r, 42)
+ self.assertTrue(0.08 <= t1-t0 <= 0.12)
+
+ def test_wait(self):
+ a = tasks.sleep(0.1)
+ b = tasks.sleep(0.15)
+
+ @tasks.coroutine
+ def foo():
+ done, pending = yield from tasks.wait([b, a])
+ self.assertEqual(done, set([a, b]))
+ self.assertEqual(pending, set())
+ return 42
+
+ t0 = time.monotonic()
+ res = self.event_loop.run_until_complete(tasks.Task(foo()))
+ t1 = time.monotonic()
+ self.assertTrue(t1-t0 >= 0.14)
+ self.assertEqual(res, 42)
+ # Doing it again should take no time and exercise a different path.
+ t0 = time.monotonic()
+ res = self.event_loop.run_until_complete(tasks.Task(foo()))
+ t1 = time.monotonic()
+ self.assertTrue(t1-t0 <= 0.01)
+ # TODO: Test different return_when values.
+
+ def test_wait_first_completed(self):
+ a = tasks.sleep(10.0)
+ b = tasks.sleep(0.1)
+ task = tasks.Task(tasks.wait(
+ [b, a], return_when=tasks.FIRST_COMPLETED))
+
+ done, pending = self.event_loop.run_until_complete(task)
+ self.assertEqual({b}, done)
+ self.assertEqual({a}, pending)
+
+ def test_wait_really_done(self):
+ self.suppress_log_errors()
+        # There is a possibility that some tasks in the pending list
+        # became done, but their callbacks haven't all been called yet.
+
+ @tasks.coroutine
+ def coro1():
+ yield from [None]
+
+ @tasks.coroutine
+ def coro2():
+ yield from [None, None]
+
+ a = tasks.Task(coro1())
+ b = tasks.Task(coro2())
+ task = tasks.Task(
+ tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED))
+
+ done, pending = self.event_loop.run_until_complete(task)
+ self.assertEqual({a, b}, done)
+
+ def test_wait_first_exception(self):
+ self.suppress_log_errors()
+
+ a = tasks.sleep(10.0)
+
+ @tasks.coroutine
+ def exc():
+ yield from []
+ raise ZeroDivisionError('err')
+
+ b = tasks.Task(exc())
+ task = tasks.Task(tasks.wait(
+ [b, a], return_when=tasks.FIRST_EXCEPTION))
+
+ done, pending = self.event_loop.run_until_complete(task)
+ self.assertEqual({b}, done)
+ self.assertEqual({a}, pending)
+
+ def test_wait_with_exception(self):
+ self.suppress_log_errors()
+ a = tasks.sleep(0.1)
+
+ @tasks.coroutine
+ def sleeper():
+ yield from tasks.sleep(0.15)
+ raise ZeroDivisionError('really')
+
+ b = tasks.Task(sleeper())
+
+ @tasks.coroutine
+ def foo():
+ done, pending = yield from tasks.wait([b, a])
+ self.assertEqual(len(done), 2)
+ self.assertEqual(pending, set())
+ errors = set(f for f in done if f.exception() is not None)
+ self.assertEqual(len(errors), 1)
+
+ t0 = time.monotonic()
+ self.event_loop.run_until_complete(tasks.Task(foo()))
+ t1 = time.monotonic()
+ self.assertTrue(t1-t0 >= 0.14)
+ t0 = time.monotonic()
+ self.event_loop.run_until_complete(tasks.Task(foo()))
+ t1 = time.monotonic()
+ self.assertTrue(t1-t0 <= 0.01)
+
+ def test_wait_with_timeout(self):
+ a = tasks.sleep(0.1)
+ b = tasks.sleep(0.15)
+
+ @tasks.coroutine
+ def foo():
+ done, pending = yield from tasks.wait([b, a], timeout=0.11)
+ self.assertEqual(done, set([a]))
+ self.assertEqual(pending, set([b]))
+
+ t0 = time.monotonic()
+ self.event_loop.run_until_complete(tasks.Task(foo()))
+ t1 = time.monotonic()
+ self.assertTrue(t1-t0 >= 0.1)
+ self.assertTrue(t1-t0 <= 0.13)
+
+ def test_as_completed(self):
+ @tasks.coroutine
+ def sleeper(dt, x):
+ yield from tasks.sleep(dt)
+ return x
+
+ a = sleeper(0.1, 'a')
+ b = sleeper(0.1, 'b')
+ c = sleeper(0.15, 'c')
+
+ @tasks.coroutine
+ def foo():
+ values = []
+ for f in tasks.as_completed([b, c, a]):
+ values.append((yield from f))
+ return values
+
+ t0 = time.monotonic()
+ res = self.event_loop.run_until_complete(tasks.Task(foo()))
+ t1 = time.monotonic()
+ self.assertTrue(t1-t0 >= 0.14)
+ self.assertTrue('a' in res[:2])
+ self.assertTrue('b' in res[:2])
+ self.assertEqual(res[2], 'c')
+ # Doing it again should take no time and exercise a different path.
+ t0 = time.monotonic()
+ res = self.event_loop.run_until_complete(tasks.Task(foo()))
+ t1 = time.monotonic()
+ self.assertTrue(t1-t0 <= 0.01)
+
+ def test_as_completed_with_timeout(self):
+ self.suppress_log_errors()
+ a = tasks.sleep(0.1, 'a')
+ b = tasks.sleep(0.15, 'b')
+
+ @tasks.coroutine
+ def foo():
+ values = []
+ for f in tasks.as_completed([a, b], timeout=0.12):
+ try:
+ v = yield from f
+ values.append((1, v))
+ except futures.TimeoutError as exc:
+ values.append((2, exc))
+ return values
+
+ t0 = time.monotonic()
+ res = self.event_loop.run_until_complete(tasks.Task(foo()))
+ t1 = time.monotonic()
+ self.assertTrue(t1-t0 >= 0.11)
+ self.assertEqual(len(res), 2, res)
+ self.assertEqual(res[0], (1, 'a'))
+ self.assertEqual(res[1][0], 2)
+ self.assertTrue(isinstance(res[1][1], futures.TimeoutError))
+
+ def test_sleep(self):
+ @tasks.coroutine
+ def sleeper(dt, arg):
+ yield from tasks.sleep(dt/2)
+ res = yield from tasks.sleep(dt/2, arg)
+ return res
+
+ t = tasks.Task(sleeper(0.1, 'yeah'))
+ t0 = time.monotonic()
+ self.event_loop.run()
+ t1 = time.monotonic()
+ self.assertTrue(t1-t0 >= 0.09)
+ self.assertTrue(t.done())
+ self.assertEqual(t.result(), 'yeah')
+
+ def test_task_cancel_sleeping_task(self):
+ sleepfut = None
+
+ @tasks.task
+ def sleep(dt):
+ nonlocal sleepfut
+ sleepfut = tasks.sleep(dt)
+ try:
+ time.monotonic()
+ yield from sleepfut
+ finally:
+ time.monotonic()
+
+ @tasks.task
+ def doit():
+ sleeper = sleep(5000)
+ self.event_loop.call_later(0.1, sleeper.cancel)
+ try:
+ time.monotonic()
+ yield from sleeper
+ except futures.CancelledError:
+ time.monotonic()
+ return 'cancelled'
+ else:
+ return 'slept in'
+
+ t0 = time.monotonic()
+ doer = doit()
+ self.assertEqual(self.event_loop.run_until_complete(doer), 'cancelled')
+ t1 = time.monotonic()
+ self.assertTrue(0.09 <= t1-t0 <= 0.13, (t1-t0, sleepfut, doer))
+
+ @unittest.mock.patch('tulip.tasks.tulip_log')
+ def test_step_in_completed_task(self, m_logging):
+ @tasks.coroutine
+ def notmuch():
+ yield from []
+ return 'ko'
+
+ task = tasks.Task(notmuch())
+ task.set_result('ok')
+
+ task._step()
+ self.assertTrue(m_logging.warn.called)
+ self.assertTrue(m_logging.warn.call_args[0][0].startswith(
+ '_step(): already done: '))
+
+ @unittest.mock.patch('tulip.tasks.tulip_log')
+ def test_step_result(self, m_logging):
+ @tasks.coroutine
+ def notmuch():
+ yield from [None, 1]
+ return 'ko'
+
+ task = tasks.Task(notmuch())
+ task._step()
+ self.assertFalse(m_logging.warn.called)
+
+ task._step()
+ self.assertTrue(m_logging.warn.called)
+ self.assertEqual(
+ '_step(): bad yield: %r',
+ m_logging.warn.call_args[0][0])
+ self.assertEqual(1, m_logging.warn.call_args[0][1])
+
+ def test_step_result_future(self):
+ # If coroutine returns future, task waits on this future.
+ self.suppress_log_warnings()
+
+ class Fut(futures.Future):
+ def __init__(self, *args):
+ self.cb_added = False
+ super().__init__(*args)
+
+ def add_done_callback(self, fn):
+ self.cb_added = True
+ super().add_done_callback(fn)
+
+ fut = Fut()
+ result = None
+
+ @tasks.task
+ def wait_for_future():
+ nonlocal result
+ result = yield from fut
+
+ wait_for_future()
+ self.event_loop.run_once()
+ self.assertTrue(fut.cb_added)
+
+ res = object()
+ fut.set_result(res)
+ self.event_loop.run_once()
+ self.assertIs(res, result)
+
+ def test_step_result_concurrent_future(self):
+ # Coroutine returns concurrent.futures.Future
+ self.suppress_log_warnings()
+
+ class Fut(concurrent.futures.Future):
+ def __init__(self):
+ self.cb_added = False
+ super().__init__()
+
+ def add_done_callback(self, fn):
+ self.cb_added = True
+ super().add_done_callback(fn)
+
+ c_fut = Fut()
+
+ @tasks.coroutine
+ def notmuch():
+ yield from [c_fut]
+ return (yield)
+
+ task = tasks.Task(notmuch())
+ task._step()
+ self.assertTrue(c_fut.cb_added)
+
+ res = object()
+ c_fut.set_result(res)
+ self.event_loop.run()
+ self.assertIs(res, task.result())
+
+ def test_step_with_baseexception(self):
+ self.suppress_log_errors()
+
+ @tasks.coroutine
+ def notmutch():
+ yield from []
+ raise BaseException()
+
+ task = tasks.Task(notmutch())
+ self.assertRaises(BaseException, task._step)
+
+ self.assertTrue(task.done())
+ self.assertIsInstance(task.exception(), BaseException)
+
+ def test_baseexception_during_cancel(self):
+ self.suppress_log_errors()
+
+ @tasks.coroutine
+ def sleeper():
+ yield from tasks.sleep(10)
+
+ @tasks.coroutine
+ def notmutch():
+ try:
+ yield from sleeper()
+ except futures.CancelledError:
+ raise BaseException()
+
+ task = tasks.Task(notmutch())
+ self.event_loop.run_once()
+
+ task.cancel()
+ self.assertFalse(task.done())
+
+ self.assertRaises(BaseException, self.event_loop.run_once)
+
+ self.assertTrue(task.done())
+ self.assertTrue(task.cancelled())
+
+ def test_iscoroutinefunction(self):
+ def fn():
+ pass
+
+ self.assertFalse(tasks.iscoroutinefunction(fn))
+
+ def fn1():
+ yield
+ self.assertFalse(tasks.iscoroutinefunction(fn1))
+
+ @tasks.coroutine
+ def fn2():
+ yield
+ self.assertTrue(tasks.iscoroutinefunction(fn2))
+
+ def test_yield_vs_yield_from(self):
+ fut = futures.Future()
+
+ @tasks.task
+ def wait_for_future():
+ yield fut
+
+ task = wait_for_future()
+ self.assertRaises(
+ RuntimeError,
+ self.event_loop.run_until_complete, task)
+
+ def test_yield_vs_yield_from_generator(self):
+ fut = futures.Future()
+
+ @tasks.coroutine
+ def coro():
+ yield from fut
+
+ @tasks.task
+ def wait_for_future():
+ yield coro()
+
+ task = wait_for_future()
+ self.assertRaises(
+ RuntimeError,
+ self.event_loop.run_until_complete, task)
+
+ def test_coroutine_non_gen_function(self):
+
+ @tasks.coroutine
+ def func():
+ return 'test'
+
+ self.assertTrue(tasks.iscoroutinefunction(func))
+
+ coro = func()
+ self.assertTrue(tasks.iscoroutine(coro))
+
+ res = self.event_loop.run_until_complete(coro)
+ self.assertEqual(res, 'test')
+
+ def test_coroutine_non_gen_function_return_future(self):
+ fut = futures.Future()
+
+ @tasks.coroutine
+ def func():
+ return fut
+
+ @tasks.coroutine
+ def coro():
+ fut.set_result('test')
+
+ t1 = tasks.Task(func())
+ tasks.Task(coro())
+ res = self.event_loop.run_until_complete(t1)
+ self.assertEqual(res, 'test')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/transports_test.py b/tests/transports_test.py
new file mode 100644
index 0000000..4b24b50
--- /dev/null
+++ b/tests/transports_test.py
@@ -0,0 +1,45 @@
+"""Tests for transports.py."""
+
+import unittest
+import unittest.mock
+
+from tulip import transports
+
+
+class TransportTests(unittest.TestCase):
+
+ def test_ctor_extra_is_none(self):
+ transport = transports.Transport()
+ self.assertEqual(transport._extra, {})
+
+ def test_get_extra_info(self):
+ transport = transports.Transport({'extra': 'info'})
+ self.assertEqual('info', transport.get_extra_info('extra'))
+ self.assertIsNone(transport.get_extra_info('unknown'))
+
+ default = object()
+ self.assertIs(default, transport.get_extra_info('unknown', default))
+
+ def test_writelines(self):
+ transport = transports.Transport()
+ transport.write = unittest.mock.Mock()
+
+ transport.writelines(['line1', 'line2', 'line3'])
+ self.assertEqual(3, transport.write.call_count)
+
+ def test_not_implemented(self):
+ transport = transports.Transport()
+
+ self.assertRaises(NotImplementedError, transport.write, 'data')
+ self.assertRaises(NotImplementedError, transport.write_eof)
+ self.assertRaises(NotImplementedError, transport.can_write_eof)
+ self.assertRaises(NotImplementedError, transport.pause)
+ self.assertRaises(NotImplementedError, transport.resume)
+ self.assertRaises(NotImplementedError, transport.close)
+ self.assertRaises(NotImplementedError, transport.abort)
+
+ def test_dgram_not_implemented(self):
+ transport = transports.DatagramTransport()
+
+ self.assertRaises(NotImplementedError, transport.sendto, 'data')
+ self.assertRaises(NotImplementedError, transport.abort)
diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py
new file mode 100644
index 0000000..d7af7ec
--- /dev/null
+++ b/tests/unix_events_test.py
@@ -0,0 +1,573 @@
+"""Tests for unix_events.py."""
+
+import errno
+import io
+import unittest
+import unittest.mock
+
+try:
+ import signal
+except ImportError:
+ signal = None
+
+from tulip import events
+from tulip import futures
+from tulip import protocols
+from tulip import unix_events
+
+
+@unittest.skipUnless(signal, 'Signals are not supported')
+class SelectorEventLoopTests(unittest.TestCase):
+
+ def setUp(self):
+ self.event_loop = unix_events.SelectorEventLoop()
+
+ def test_check_signal(self):
+ self.assertRaises(
+ TypeError, self.event_loop._check_signal, '1')
+ self.assertRaises(
+ ValueError, self.event_loop._check_signal, signal.NSIG + 1)
+
+ unix_events.signal = None
+
+ def restore_signal():
+ unix_events.signal = signal
+ self.addCleanup(restore_signal)
+
+ self.assertRaises(
+ RuntimeError, self.event_loop._check_signal, signal.SIGINT)
+
+ def test_handle_signal_no_handler(self):
+ self.event_loop._handle_signal(signal.NSIG + 1, ())
+
+ @unittest.mock.patch('tulip.unix_events.signal')
+ def test_add_signal_handler_setup_error(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+ m_signal.set_wakeup_fd.side_effect = ValueError
+
+ self.assertRaises(
+ RuntimeError,
+ self.event_loop.add_signal_handler,
+ signal.SIGINT, lambda: True)
+
+ @unittest.mock.patch('tulip.unix_events.signal')
+ def test_add_signal_handler(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+
+ h = self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True)
+ self.assertIsInstance(h, events.Handle)
+
+ @unittest.mock.patch('tulip.unix_events.signal')
+ def test_add_signal_handler_install_error(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+
+ def set_wakeup_fd(fd):
+ if fd == -1:
+ raise ValueError()
+ m_signal.set_wakeup_fd = set_wakeup_fd
+
+ class Err(OSError):
+ errno = errno.EFAULT
+ m_signal.signal.side_effect = Err
+
+ self.assertRaises(
+ Err,
+ self.event_loop.add_signal_handler,
+ signal.SIGINT, lambda: True)
+
+ @unittest.mock.patch('tulip.unix_events.signal')
+ @unittest.mock.patch('tulip.unix_events.tulip_log')
+ def test_add_signal_handler_install_error2(self, m_logging, m_signal):
+ m_signal.NSIG = signal.NSIG
+
+ class Err(OSError):
+ errno = errno.EINVAL
+ m_signal.signal.side_effect = Err
+
+ self.event_loop._signal_handlers[signal.SIGHUP] = lambda: True
+ self.assertRaises(
+ RuntimeError,
+ self.event_loop.add_signal_handler,
+ signal.SIGINT, lambda: True)
+ self.assertFalse(m_logging.info.called)
+ self.assertEqual(1, m_signal.set_wakeup_fd.call_count)
+
+ @unittest.mock.patch('tulip.unix_events.signal')
+ @unittest.mock.patch('tulip.unix_events.tulip_log')
+ def test_add_signal_handler_install_error3(self, m_logging, m_signal):
+ class Err(OSError):
+ errno = errno.EINVAL
+ m_signal.signal.side_effect = Err
+ m_signal.NSIG = signal.NSIG
+
+ self.assertRaises(
+ RuntimeError,
+ self.event_loop.add_signal_handler,
+ signal.SIGINT, lambda: True)
+ self.assertFalse(m_logging.info.called)
+ self.assertEqual(2, m_signal.set_wakeup_fd.call_count)
+
+ @unittest.mock.patch('tulip.unix_events.signal')
+ def test_remove_signal_handler(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+
+ self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True)
+
+ self.assertTrue(
+ self.event_loop.remove_signal_handler(signal.SIGHUP))
+ self.assertTrue(m_signal.set_wakeup_fd.called)
+ self.assertTrue(m_signal.signal.called)
+ self.assertEqual(
+ (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0])
+
+ @unittest.mock.patch('tulip.unix_events.signal')
+ def test_remove_signal_handler_2(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+ m_signal.SIGINT = signal.SIGINT
+
+ self.event_loop.add_signal_handler(signal.SIGINT, lambda: True)
+ self.event_loop._signal_handlers[signal.SIGHUP] = object()
+ m_signal.set_wakeup_fd.reset_mock()
+
+ self.assertTrue(
+ self.event_loop.remove_signal_handler(signal.SIGINT))
+ self.assertFalse(m_signal.set_wakeup_fd.called)
+ self.assertTrue(m_signal.signal.called)
+ self.assertEqual(
+ (signal.SIGINT, m_signal.default_int_handler),
+ m_signal.signal.call_args[0])
+
+ @unittest.mock.patch('tulip.unix_events.signal')
+ @unittest.mock.patch('tulip.unix_events.tulip_log')
+ def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal):
+ m_signal.NSIG = signal.NSIG
+ self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True)
+
+ m_signal.set_wakeup_fd.side_effect = ValueError
+
+ self.event_loop.remove_signal_handler(signal.SIGHUP)
+        self.assertTrue(m_logging.info.called)
+
+ @unittest.mock.patch('tulip.unix_events.signal')
+ def test_remove_signal_handler_error(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+ self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True)
+
+ m_signal.signal.side_effect = OSError
+
+ self.assertRaises(
+ OSError, self.event_loop.remove_signal_handler, signal.SIGHUP)
+
+ @unittest.mock.patch('tulip.unix_events.signal')
+ def test_remove_signal_handler_error2(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+ self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True)
+
+ class Err(OSError):
+ errno = errno.EINVAL
+ m_signal.signal.side_effect = Err
+
+ self.assertRaises(
+ RuntimeError, self.event_loop.remove_signal_handler, signal.SIGHUP)
+
+
+class UnixReadPipeTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.event_loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop)
+ self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase)
+ self.pipe.fileno.return_value = 5
+ self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol)
+
+ @unittest.mock.patch('fcntl.fcntl')
+ def test_ctor(self, m_fcntl):
+ tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+ self.event_loop.add_reader.assert_called_with(5, tr._read_ready)
+ self.event_loop.call_soon.assert_called_with(
+ self.protocol.connection_made, tr)
+
+ @unittest.mock.patch('fcntl.fcntl')
+ def test_ctor_with_waiter(self, m_fcntl):
+ fut = futures.Future()
+ unix_events._UnixReadPipeTransport(self.event_loop, self.pipe,
+ self.protocol, fut)
+ self.event_loop.call_soon.assert_called_with(fut.set_result, None)
+
+ @unittest.mock.patch('os.read')
+ @unittest.mock.patch('fcntl.fcntl')
+ def test__read_ready(self, m_fcntl, m_read):
+ tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+ m_read.return_value = b'data'
+ tr._read_ready()
+
+ m_read.assert_called_with(5, tr.max_size)
+ self.protocol.data_received.assert_called_with(b'data')
+
+ @unittest.mock.patch('os.read')
+ @unittest.mock.patch('fcntl.fcntl')
+ def test__read_ready_eof(self, m_fcntl, m_read):
+ tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+ m_read.return_value = b''
+ tr._read_ready()
+
+ m_read.assert_called_with(5, tr.max_size)
+ self.protocol.eof_received.assert_called_with()
+ self.event_loop.remove_reader.assert_called_with(5)
+
+ @unittest.mock.patch('os.read')
+ @unittest.mock.patch('fcntl.fcntl')
+ def test__read_ready_blocked(self, m_fcntl, m_read):
+ tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+ self.event_loop.reset_mock()
+ m_read.side_effect = BlockingIOError
+ tr._read_ready()
+
+ m_read.assert_called_with(5, tr.max_size)
+ self.assertFalse(self.protocol.data_received.called)
+
+ @unittest.mock.patch('tulip.log.tulip_log.exception')
+ @unittest.mock.patch('os.read')
+ @unittest.mock.patch('fcntl.fcntl')
+ def test__read_ready_error(self, m_fcntl, m_read, m_logexc):
+ tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+ err = OSError()
+ m_read.side_effect = err
+ tr._close = unittest.mock.Mock()
+ tr._read_ready()
+
+ m_read.assert_called_with(5, tr.max_size)
+ tr._close.assert_called_with(err)
+ m_logexc.assert_called_with('Fatal error for %s', tr)
+
+ @unittest.mock.patch('os.read')
+ @unittest.mock.patch('fcntl.fcntl')
+ def test_pause(self, m_fcntl, m_read):
+ tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+
+ tr.pause()
+ self.event_loop.remove_reader.assert_called_with(5)
+
+ @unittest.mock.patch('os.read')
+ @unittest.mock.patch('fcntl.fcntl')
+ def test_resume(self, m_fcntl, m_read):
+ tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+
+ tr.resume()
+ self.event_loop.add_reader.assert_called_with(5, tr._read_ready)
+
+ @unittest.mock.patch('os.read')
+ @unittest.mock.patch('fcntl.fcntl')
+ def test_close(self, m_fcntl, m_read):
+ tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+
+ tr._close = unittest.mock.Mock()
+ tr.close()
+ tr._close.assert_called_with(None)
+
+ @unittest.mock.patch('os.read')
+ @unittest.mock.patch('fcntl.fcntl')
+ def test_close_already_closing(self, m_fcntl, m_read):
+ tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+
+ tr._closing = True
+ tr._close = unittest.mock.Mock()
+ tr.close()
+ self.assertFalse(tr._close.called)
+
+ @unittest.mock.patch('os.read')
+ @unittest.mock.patch('fcntl.fcntl')
+ def test__close(self, m_fcntl, m_read):
+ tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+
+ err = object()
+ tr._close(err)
+ self.assertTrue(tr._closing)
+ self.event_loop.remove_reader.assert_called_with(5)
+ self.protocol.connection_lost.assert_called_with(err)
+
+ @unittest.mock.patch('fcntl.fcntl')
+ def test__call_connection_lost(self, m_fcntl):
+ tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+
+ err = None
+ tr._call_connection_lost(err)
+ self.protocol.connection_lost.assert_called_with(err)
+ self.pipe.close.assert_called_with()
+
+ @unittest.mock.patch('fcntl.fcntl')
+ def test__call_connection_lost_with_err(self, m_fcntl):
+ tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+
+ err = OSError()
+ tr._call_connection_lost(err)
+ self.protocol.connection_lost.assert_called_with(err)
+ self.pipe.close.assert_called_with()
+
+
+class UnixWritePipeTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.event_loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop)
+ self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase)
+ self.pipe.fileno.return_value = 5
+ self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol)
+
+ @unittest.mock.patch('fcntl.fcntl')
+ def test_ctor(self, m_fcntl):
+ tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+ self.event_loop.call_soon.assert_called_with(
+ self.protocol.connection_made, tr)
+
+ @unittest.mock.patch('fcntl.fcntl')
+ def test_ctor_with_waiter(self, m_fcntl):
+ fut = futures.Future()
+ unix_events._UnixWritePipeTransport(self.event_loop, self.pipe,
+ self.protocol, fut)
+ self.event_loop.call_soon.assert_called_with(fut.set_result, None)
+
+ @unittest.mock.patch('fcntl.fcntl')
+ def test_can_write_eof(self, m_fcntl):
+ tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+ self.assertTrue(tr.can_write_eof())
+
+ @unittest.mock.patch('os.write')
+ @unittest.mock.patch('fcntl.fcntl')
+ def test_write(self, m_fcntl, m_write):
+ tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+
+ m_write.return_value = 4
+ tr.write(b'data')
+ m_write.assert_called_with(5, b'data')
+ self.assertFalse(self.event_loop.add_writer.called)
+ self.assertEqual([], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ @unittest.mock.patch('fcntl.fcntl')
+ def test_write_no_data(self, m_fcntl, m_write):
+ tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+
+ tr.write(b'')
+ self.assertFalse(m_write.called)
+ self.assertFalse(self.event_loop.add_writer.called)
+ self.assertEqual([], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ @unittest.mock.patch('fcntl.fcntl')
+ def test_write_partial(self, m_fcntl, m_write):
+ tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+
+ m_write.return_value = 2
+ tr.write(b'data')
+ m_write.assert_called_with(5, b'data')
+ self.event_loop.add_writer.assert_called_with(5, tr._write_ready)
+ self.assertEqual([b'ta'], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ @unittest.mock.patch('fcntl.fcntl')
+ def test_write_buffer(self, m_fcntl, m_write):
+ tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+
+ tr._buffer = [b'previous']
+ tr.write(b'data')
+ self.assertFalse(m_write.called)
+ self.assertFalse(self.event_loop.add_writer.called)
+ self.assertEqual([b'previous', b'data'], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ @unittest.mock.patch('fcntl.fcntl')
+ def test_write_again(self, m_fcntl, m_write):
+ tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+
+ m_write.side_effect = BlockingIOError()
+ tr.write(b'data')
+ m_write.assert_called_with(5, b'data')
+ self.event_loop.add_writer.assert_called_with(5, tr._write_ready)
+ self.assertEqual([b'data'], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ @unittest.mock.patch('fcntl.fcntl')
+ def test_write_err(self, m_fcntl, m_write):
+ tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+
+ err = OSError()
+ m_write.side_effect = err
+ tr._fatal_error = unittest.mock.Mock()
+ tr.write(b'data')
+ m_write.assert_called_with(5, b'data')
+        self.assertFalse(self.event_loop.add_writer.called)
+ self.assertEqual([], tr._buffer)
+ tr._fatal_error.assert_called_with(err)
+
+ @unittest.mock.patch('os.write')
+ @unittest.mock.patch('fcntl.fcntl')
+ def test__write_ready(self, m_fcntl, m_write):
+ tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+ tr._buffer = [b'da', b'ta']
+ m_write.return_value = 4
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.event_loop.remove_writer.assert_called_with(5)
+ self.assertEqual([], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ @unittest.mock.patch('fcntl.fcntl')
+ def test__write_ready_partial(self, m_fcntl, m_write):
+ tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+
+ tr._buffer = [b'da', b'ta']
+ m_write.return_value = 3
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.assertFalse(self.event_loop.remove_writer.called)
+ self.assertEqual([b'a'], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ @unittest.mock.patch('fcntl.fcntl')
+ def test__write_ready_again(self, m_fcntl, m_write):
+ tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+
+ tr._buffer = [b'da', b'ta']
+ m_write.side_effect = BlockingIOError()
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.assertFalse(self.event_loop.remove_writer.called)
+ self.assertEqual([b'data'], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ @unittest.mock.patch('fcntl.fcntl')
+ def test__write_ready_empty(self, m_fcntl, m_write):
+ tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+
+ tr._buffer = [b'da', b'ta']
+ m_write.return_value = 0
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.assertFalse(self.event_loop.remove_writer.called)
+ self.assertEqual([b'data'], tr._buffer)
+
+ @unittest.mock.patch('tulip.log.tulip_log.exception')
+ @unittest.mock.patch('os.write')
+ @unittest.mock.patch('fcntl.fcntl')
+ def test__write_ready_err(self, m_fcntl, m_write, m_logexc):
+ tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+
+ tr._buffer = [b'da', b'ta']
+ m_write.side_effect = err = OSError()
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.event_loop.remove_writer.assert_called_with(5)
+ self.assertEqual([], tr._buffer)
+ self.assertTrue(tr._closing)
+ self.protocol.connection_lost.assert_called_with(err)
+ m_logexc.assert_called_with('Fatal error for %s', tr)
+
+ @unittest.mock.patch('os.write')
+ @unittest.mock.patch('fcntl.fcntl')
+ def test__write_ready_closing(self, m_fcntl, m_write):
+ tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+
+ tr._closing = True
+ tr._buffer = [b'da', b'ta']
+ m_write.return_value = 4
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.event_loop.remove_writer.assert_called_with(5)
+ self.assertEqual([], tr._buffer)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ @unittest.mock.patch('os.write')
+ @unittest.mock.patch('fcntl.fcntl')
+ def test_abort(self, m_fcntl, m_write):
+ tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+
+ tr._buffer = [b'da', b'ta']
+ tr.abort()
+ self.assertFalse(m_write.called)
+ self.event_loop.remove_writer.assert_called_with(5)
+ self.assertEqual([], tr._buffer)
+ self.assertTrue(tr._closing)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ @unittest.mock.patch('fcntl.fcntl')
+ def test__call_connection_lost(self, m_fcntl):
+ tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+
+ err = None
+ tr._call_connection_lost(err)
+ self.protocol.connection_lost.assert_called_with(err)
+ self.pipe.close.assert_called_with()
+
+ @unittest.mock.patch('fcntl.fcntl')
+ def test__call_connection_lost_with_err(self, m_fcntl):
+ tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+
+ err = OSError()
+ tr._call_connection_lost(err)
+ self.protocol.connection_lost.assert_called_with(err)
+ self.pipe.close.assert_called_with()
+
+ @unittest.mock.patch('fcntl.fcntl')
+ def test_close(self, m_fcntl):
+ tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+
+ tr.write_eof = unittest.mock.Mock()
+ tr.close()
+ tr.write_eof.assert_called_with()
+
+ @unittest.mock.patch('fcntl.fcntl')
+ def test_close_closing(self, m_fcntl):
+ tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+
+ tr.write_eof = unittest.mock.Mock()
+ tr._closing = True
+ tr.close()
+ self.assertFalse(tr.write_eof.called)
+
+ @unittest.mock.patch('fcntl.fcntl')
+ def test_write_eof(self, m_fcntl):
+ tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+
+ tr.write_eof()
+ self.assertTrue(tr._closing)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ @unittest.mock.patch('fcntl.fcntl')
+ def test_write_eof_pending(self, m_fcntl):
+ tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe,
+ self.protocol)
+ tr._buffer = [b'data']
+ tr.write_eof()
+ self.assertTrue(tr._closing)
+ self.assertFalse(self.protocol.connection_lost.called)
diff --git a/tests/winsocketpair_test.py b/tests/winsocketpair_test.py
new file mode 100644
index 0000000..381fb22
--- /dev/null
+++ b/tests/winsocketpair_test.py
@@ -0,0 +1,26 @@
+"""Tests for winsocketpair.py"""
+
+import unittest
+import unittest.mock
+
+from tulip import winsocketpair
+
+
+class WinsocketpairTests(unittest.TestCase):
+
+ def test_winsocketpair(self):
+ ssock, csock = winsocketpair.socketpair()
+
+ csock.send(b'xxx')
+ self.assertEqual(b'xxx', ssock.recv(1024))
+
+ csock.close()
+ ssock.close()
+
+ @unittest.mock.patch('tulip.winsocketpair.socket')
+ def test_winsocketpair_exc(self, m_socket):
+ m_socket.socket.return_value.getsockname.return_value = ('', 12345)
+ m_socket.socket.return_value.accept.return_value = object(), object()
+ m_socket.socket.return_value.connect.side_effect = OSError()
+
+ self.assertRaises(OSError, winsocketpair.socketpair)
diff --git a/tulip/TODO b/tulip/TODO
new file mode 100644
index 0000000..acec5c2
--- /dev/null
+++ b/tulip/TODO
@@ -0,0 +1,28 @@
+TODO in tulip v2 (tulip/ package directory)
+-------------------------------------------
+
+- See also TBD and Open Issues in PEP 3156
+
+- Refactor unix_events.py (it's getting too long)
+
+- Docstrings
+
+- Unittests
+
+- better run_once() behavior? (Run ready list last.)
+
+- start_serving()
+
+- Make Handler() callable? Log the exception in there?
+
+- Add the repeat interval to the Handler class?
+
+- Recognize Handler passed to add_reader(), call_soon(), etc.?
+
+- SSL support
+
+- buffered stream implementation
+
+- Primitives like par() and wait_one()
+
+- Remove test dependency on xkcd.com, write our own test server
diff --git a/tulip/__init__.py b/tulip/__init__.py
new file mode 100644
index 0000000..faf307f
--- /dev/null
+++ b/tulip/__init__.py
@@ -0,0 +1,26 @@
+"""Tulip 2.0, tracking PEP 3156."""
+
+import sys
+
+# This relies on each of the submodules having an __all__ variable.
+from .futures import *
+from .events import *
+from .locks import *
+from .transports import *
+from .protocols import *
+from .streams import *
+from .tasks import *
+
+if sys.platform == 'win32': # pragma: no cover
+ from .windows_events import *
+else:
+ from .unix_events import * # pragma: no cover
+
+
+__all__ = (futures.__all__ +
+ events.__all__ +
+ locks.__all__ +
+ transports.__all__ +
+ protocols.__all__ +
+ streams.__all__ +
+ tasks.__all__)
diff --git a/tulip/base_events.py b/tulip/base_events.py
new file mode 100644
index 0000000..5ed257a
--- /dev/null
+++ b/tulip/base_events.py
@@ -0,0 +1,548 @@
+"""Base implementation of event loop.
+
+The event loop can be broken up into a multiplexer (the part
+responsible for notifying us of IO events) and the event loop proper,
+which wraps a multiplexer with functionality for scheduling callbacks,
+immediately or at a given time in the future.
+
+Whenever a public API takes a callback, subsequent positional
+arguments will be passed to the callback if/when it is called. This
+avoids the proliferation of trivial lambdas implementing closures.
+Keyword arguments for the callback are not supported; this is a
+conscious design decision, leaving the door open for keyword arguments
+to modify the meaning of the API call itself.
+"""
+
+
+import collections
+import concurrent.futures
+import heapq
+import logging
+import socket
+import time
+
+from . import events
+from . import futures
+from . import tasks
+from .log import tulip_log
+
+
+__all__ = ['BaseEventLoop']
+
+
+# Argument for default thread pool executor creation.
+_MAX_WORKERS = 5
+
+
+class _StopError(BaseException):
+ """Raised to stop the event loop."""
+
+
+def _raise_stop_error(*args):
+ raise _StopError
+
+
+class BaseEventLoop(events.AbstractEventLoop):
+
+ def __init__(self):
+ self._ready = collections.deque()
+ self._scheduled = []
+ self._default_executor = None
+ self._internal_fds = 0
+ self._running = False
+
+ def _make_socket_transport(self, sock, protocol, waiter=None, extra=None):
+ """Create socket transport."""
+ raise NotImplementedError
+
+ def _make_ssl_transport(self, rawsock, protocol,
+ sslcontext, waiter, extra=None):
+ """Create SSL transport."""
+ raise NotImplementedError
+
+ def _make_datagram_transport(self, sock, protocol,
+ address=None, extra=None):
+ """Create datagram transport."""
+ raise NotImplementedError
+
+ def _make_read_pipe_transport(self, pipe, protocol, waiter=None,
+ extra=None):
+ """Create read pipe transport."""
+ raise NotImplementedError
+
+ def _make_write_pipe_transport(self, pipe, protocol, waiter=None,
+ extra=None):
+ """Create write pipe transport."""
+ raise NotImplementedError
+
+ def _read_from_self(self):
+ """XXX"""
+ raise NotImplementedError
+
+ def _write_to_self(self):
+ """XXX"""
+ raise NotImplementedError
+
+ def _process_events(self, event_list):
+ """Process selector events."""
+ raise NotImplementedError
+
+ def is_running(self):
+ """Returns running status of event loop."""
+ return self._running
+
+ def run(self):
+ """Run the event loop until nothing left to do or stop() called.
+
+ This keeps going as long as there are either readable and
+ writable file descriptors, or scheduled callbacks (of either
+ variety).
+
+ TODO: Give this a timeout too?
+ """
+ if self._running:
+ raise RuntimeError('Event loop is running.')
+
+ self._running = True
+ try:
+ while (self._ready or
+ self._scheduled or
+ self._selector.registered_count() > 1):
+ try:
+ self._run_once()
+ except _StopError:
+ break
+ finally:
+ self._running = False
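+
+    # For example (a sketch): run() returns once the ready and timer queues
+    # are empty and only the loop's internal self-pipe remains registered:
+    #
+    #     loop.call_soon(print, 'once')
+    #     loop.run()   # prints 'once', then returns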
+
+ def run_forever(self):
+ """Run until stop() is called.
+
+ This only makes sense over run() if you have another thread
+ scheduling callbacks using call_soon_threadsafe().
+ """
+ handle = self.call_repeatedly(24*3600, lambda: None)
+ try:
+ self.run()
+ finally:
+ handle.cancel()
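+
+    # For example (a sketch; `handler` is any callable): another thread can
+    # schedule work and wake the selector, which run_forever() then serves:
+    #
+    #     threading.Thread(
+    #         target=loop.call_soon_threadsafe, args=(handler,)).start()
+    #     loop.run_forever()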
+
+ def run_once(self, timeout=0):
+ """Run through all callbacks and all I/O polls once.
+
+ Calling stop() will break out of this too.
+ """
+ if self._running:
+ raise RuntimeError('Event loop is running.')
+
+ self._running = True
+ try:
+ self._run_once(timeout)
+ except _StopError:
+ pass
+ finally:
+ self._running = False
+
+ def run_until_complete(self, future, timeout=None):
+ """Run until the Future is done, or until a timeout.
+
+ If the argument is a coroutine, it is wrapped in a Task.
+
+ XXX TBD: It would be disastrous to call run_until_complete()
+ with the same coroutine twice -- it would wrap it in two
+ different Tasks and that can't be good.
+
+ Return the Future's result, or raise its exception. If the
+ timeout is reached or stop() is called, raise TimeoutError.
+ """
+ if not isinstance(future, futures.Future):
+ if tasks.iscoroutine(future):
+ future = tasks.Task(future)
+ else:
+                raise TypeError('A Future or coroutine is required')
+
+ handle_called = False
+
+ def stop_loop():
+ nonlocal handle_called
+ handle_called = True
+ raise _StopError
+
+ future.add_done_callback(_raise_stop_error)
+
+ if timeout is None:
+ self.run_forever()
+ else:
+ handle = self.call_later(timeout, stop_loop)
+ self.run_forever()
+ handle.cancel()
+
+ if handle_called:
+ raise futures.TimeoutError
+
+ return future.result()
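+
+    # Typical use (a sketch; `some_coro` stands for any coroutine): block
+    # until it finishes, raising TimeoutError if the deadline passes first:
+    #
+    #     result = loop.run_until_complete(some_coro(), timeout=5.0)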
+
+ def stop(self):
+ """Stop running the event loop.
+
+ Every callback scheduled before stop() is called will run.
+        Callbacks scheduled after stop() is called won't run now.
+        However, those callbacks will run if run() is called again later.
+ """
+ self.call_soon(_raise_stop_error)
+
+ def call_later(self, delay, callback, *args):
+ """Arrange for a callback to be called at a given time.
+
+ Return an object with a cancel() method that can be used to
+ cancel the call.
+
+ The delay can be an int or float, expressed in seconds. It is
+ always a relative time.
+
+ Each callback will be called exactly once. If two callbacks
+        are scheduled for exactly the same time, it is undefined which
+ will be called first.
+
+ Callbacks scheduled in the past are passed on to call_soon(),
+ so these will be called in the order in which they were
+ registered rather than by time due. This is so you can't
+ cheat and insert yourself at the front of the ready queue by
+ using a negative time.
+
+ Any positional arguments after the callback will be passed to
+ the callback when it is called.
+ """
+ if delay <= 0:
+ return self.call_soon(callback, *args)
+
+ handle = events.Timer(time.monotonic() + delay, callback, args)
+ heapq.heappush(self._scheduled, handle)
+ return handle
+
+ def call_repeatedly(self, interval, callback, *args):
+ """Call a callback every 'interval' seconds."""
+ assert interval > 0, 'Interval must be > 0: %r' % (interval,)
+ # TODO: What if callback is already a Handle?
+ def wrapper():
+ callback(*args) # If this fails, the chain is broken.
+ handle._when = time.monotonic() + interval
+ heapq.heappush(self._scheduled, handle)
+
+ handle = events.Timer(time.monotonic() + interval, wrapper, ())
+ heapq.heappush(self._scheduled, handle)
+ return handle
+
+ def call_soon(self, callback, *args):
+ """Arrange for a callback to be called as soon as possible.
+
+ This operates as a FIFO queue, callbacks are called in the
+ order in which they are registered. Each callback will be
+ called exactly once.
+
+ Any positional arguments after the callback will be passed to
+ the callback when it is called.
+ """
+ handle = events.make_handle(callback, args)
+ self._ready.append(handle)
+ return handle
+
+ def call_soon_threadsafe(self, callback, *args):
+ """XXX"""
+ handle = self.call_soon(callback, *args)
+ self._write_to_self()
+ return handle
+
+ def run_in_executor(self, executor, callback, *args):
+ if isinstance(callback, events.Handle):
+ assert not args
+ assert not isinstance(callback, events.Timer)
+ if callback.cancelled:
+ f = futures.Future()
+ f.set_result(None)
+ return f
+ callback, args = callback.callback, callback.args
+ if executor is None:
+ executor = self._default_executor
+ if executor is None:
+ executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS)
+ self._default_executor = executor
+ return self.wrap_future(executor.submit(callback, *args))
+
+ def set_default_executor(self, executor):
+ self._default_executor = executor
+
+ def getaddrinfo(self, host, port, *,
+ family=0, type=0, proto=0, flags=0):
+ return self.run_in_executor(None, socket.getaddrinfo,
+ host, port, family, type, proto, flags)
+
+ def getnameinfo(self, sockaddr, flags=0):
+ return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags)
+
+ @tasks.coroutine
+ def create_connection(self, protocol_factory, host=None, port=None, *,
+ ssl=False, family=0, proto=0, flags=0, sock=None):
+ """XXX"""
+ if host is not None or port is not None:
+ if sock is not None:
+ raise ValueError(
+ "host, port and sock can not be specified at the same time")
+
+ infos = yield from self.getaddrinfo(
+ host, port, family=family,
+ type=socket.SOCK_STREAM, proto=proto, flags=flags)
+
+ if not infos:
+ raise socket.error('getaddrinfo() returned empty list')
+
+ exceptions = []
+ for family, type, proto, cname, address in infos:
+ try:
+ sock = socket.socket(family=family, type=type, proto=proto)
+ sock.setblocking(False)
+ yield from self.sock_connect(sock, address)
+ except socket.error as exc:
+ if sock is not None:
+ sock.close()
+ exceptions.append(exc)
+ else:
+ break
+ else:
+ if len(exceptions) == 1:
+ raise exceptions[0]
+ else:
+ # If they all have the same str(), raise one.
+ model = str(exceptions[0])
+ if all(str(exc) == model for exc in exceptions):
+ raise exceptions[0]
+ # Raise a combined exception so the user can see all
+ # the various error messages.
+ raise socket.error('Multiple exceptions: {}'.format(
+ ', '.join(str(exc) for exc in exceptions)))
+
+ elif sock is None:
+ raise ValueError(
+ "host and port was not specified and no sock specified")
+
+ sock.setblocking(False)
+
+ protocol = protocol_factory()
+ waiter = futures.Future()
+ if ssl:
+ sslcontext = None if isinstance(ssl, bool) else ssl
+ transport = self._make_ssl_transport(
+ sock, protocol, sslcontext, waiter)
+ else:
+ transport = self._make_socket_transport(sock, protocol, waiter)
+
+ yield from waiter
+ return transport, protocol
+
+ @tasks.coroutine
+ def create_datagram_endpoint(self, protocol_factory,
+ local_addr=None, remote_addr=None, *,
+ family=0, proto=0, flags=0):
+ """Create datagram connection."""
+ if not (local_addr or remote_addr):
+ if family == 0:
+ raise ValueError('unexpected address family')
+ addr_pairs_info = (((family, proto), (None, None)),)
+ else:
+            # group addresses by (family, protocol)
+ addr_infos = collections.OrderedDict()
+ for idx, addr in ((0, local_addr), (1, remote_addr)):
+ if addr is not None:
+ assert isinstance(addr, tuple) and len(addr) == 2, (
+ '2-tuple is expected')
+
+ infos = yield from self.getaddrinfo(
+ *addr, family=family, type=socket.SOCK_DGRAM,
+ proto=proto, flags=flags)
+ if not infos:
+ raise socket.error('getaddrinfo() returned empty list')
+
+ for fam, _, pro, _, address in infos:
+ key = (fam, pro)
+ if key not in addr_infos:
+ addr_infos[key] = [None, None]
+ addr_infos[key][idx] = address
+
+ # each addr has to have info for each (family, proto) pair
+ addr_pairs_info = [
+ (key, addr_pair) for key, addr_pair in addr_infos.items()
+ if not ((local_addr and addr_pair[0] is None) or
+ (remote_addr and addr_pair[1] is None))]
+
+ if not addr_pairs_info:
+ raise ValueError('can not get address information')
+
+ exceptions = []
+
+ for (family, proto), (local_address, remote_address) in addr_pairs_info:
+ sock = None
+ l_addr = None
+ r_addr = None
+ try:
+ sock = socket.socket(
+ family=family, type=socket.SOCK_DGRAM, proto=proto)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock.setblocking(False)
+
+ if local_addr:
+ sock.bind(local_address)
+ l_addr = sock.getsockname()
+ if remote_addr:
+ yield from self.sock_connect(sock, remote_address)
+ r_addr = remote_address
+ except socket.error as exc:
+ if sock is not None:
+ sock.close()
+ exceptions.append(exc)
+ else:
+ break
+ else:
+ raise exceptions[0]
+
+ protocol = protocol_factory()
+ transport = self._make_datagram_transport(
+ sock, protocol, r_addr, extra={'addr': l_addr})
+ return transport, protocol
+
+ # TODO: Or create_server()?
+ @tasks.task
+ def start_serving(self, protocol_factory, host=None, port=None, *,
+ family=0, proto=0, flags=0, backlog=100, sock=None):
+ """XXX"""
+ if host is not None or port is not None:
+ if sock is not None:
+ raise ValueError(
+ "host, port and sock can not be specified at the same time")
+
+ infos = yield from self.getaddrinfo(
+ host, port, family=family,
+ type=socket.SOCK_STREAM, proto=proto, flags=flags)
+
+ if not infos:
+ raise socket.error('getaddrinfo() returned empty list')
+
+ # TODO: Maybe we want to bind every address in the list
+ # instead of the first one that works?
+ exceptions = []
+ for family, type, proto, cname, address in infos:
+ sock = socket.socket(family=family, type=type, proto=proto)
+ try:
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock.bind(address)
+ except socket.error as exc:
+ sock.close()
+ exceptions.append(exc)
+ else:
+ break
+ else:
+ raise exceptions[0]
+
+ elif sock is None:
+ raise ValueError(
+ "host and port was not specified and no sock specified")
+
+ sock.listen(backlog)
+ sock.setblocking(False)
+ self._start_serving(protocol_factory, sock)
+ return sock
+
+ @tasks.coroutine
+ def connect_read_pipe(self, protocol_factory, pipe):
+ protocol = protocol_factory()
+ waiter = futures.Future()
+ transport = self._make_read_pipe_transport(pipe, protocol, waiter,
+ extra={})
+ yield from waiter
+ return transport, protocol
+
+ @tasks.coroutine
+ def connect_write_pipe(self, protocol_factory, pipe):
+ protocol = protocol_factory()
+ waiter = futures.Future()
+ transport = self._make_write_pipe_transport(pipe, protocol, waiter,
+ extra={})
+ yield from waiter
+ return transport, protocol
+
+ def _add_callback(self, handle):
+ """Add a Handle to ready or scheduled."""
+ if handle.cancelled:
+ return
+ if isinstance(handle, events.Timer):
+ heapq.heappush(self._scheduled, handle)
+ else:
+ self._ready.append(handle)
+
+ def wrap_future(self, future):
+ """XXX"""
+ if isinstance(future, futures.Future):
+ return future # Don't wrap our own type of Future.
+ new_future = futures.Future()
+ future.add_done_callback(
+ lambda future:
+ self.call_soon_threadsafe(new_future._copy_state, future))
+ return new_future
+
+ def _run_once(self, timeout=None):
+ """Run one full iteration of the event loop.
+
+ This calls all currently ready callbacks, polls for I/O,
+ schedules the resulting callbacks, and finally schedules
+ 'call_later' callbacks.
+ """
+ # TODO: Break each of these into smaller pieces.
+ # TODO: Refactor to separate the callbacks from the readers/writers.
+ # TODO: An alternative API would be to do the *minimal* amount
+ # of work, e.g. one callback or one I/O poll.
+
+ # Remove delayed calls that were cancelled from head of queue.
+ while self._scheduled and self._scheduled[0].cancelled:
+ heapq.heappop(self._scheduled)
+
+ if self._ready:
+ timeout = 0
+ elif self._scheduled:
+ # Compute the desired timeout.
+ when = self._scheduled[0].when
+ deadline = max(0, when - time.monotonic())
+ if timeout is None:
+ timeout = deadline
+ else:
+ timeout = min(timeout, deadline)
+
+ t0 = time.monotonic()
+ event_list = self._selector.select(timeout)
+ t1 = time.monotonic()
+ argstr = '' if timeout is None else ' %.3f' % timeout
+ if t1-t0 >= 1:
+ level = logging.INFO
+ else:
+ level = logging.DEBUG
+ tulip_log.log(level, 'poll%s took %.3f seconds', argstr, t1-t0)
+ self._process_events(event_list)
+
+ # Handle 'later' callbacks that are ready.
+ now = time.monotonic()
+ while self._scheduled:
+ handle = self._scheduled[0]
+ if handle.when > now:
+ break
+ handle = heapq.heappop(self._scheduled)
+ self._ready.append(handle)
+
+ # This is the only place where callbacks are actually *called*.
+ # All other places just add them to ready.
+ # Note: We run all currently scheduled callbacks, but not any
+ # callbacks scheduled by callbacks run this time around --
+ # they will be run the next time (after another I/O poll).
+ # Use an idiom that is threadsafe without using locks.
+ ntodo = len(self._ready)
+ for i in range(ntodo):
+ handle = self._ready.popleft()
+ if not handle.cancelled:
+ handle.run()
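+
+
+# A minimal, hypothetical usage sketch of the run_until_complete() API
+# defined above; it is an editorial illustration, not used by the library.
+def _example_run_until_complete():
+    f = futures.Future()
+    loop = events.get_event_loop()
+    loop.call_soon(f.set_result, 3)
+    # Run the loop until the future is done, with a 5 second safety timeout.
+    assert loop.run_until_complete(f, timeout=5) == 3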
diff --git a/tulip/events.py b/tulip/events.py
new file mode 100644
index 0000000..3a6ad40
--- /dev/null
+++ b/tulip/events.py
@@ -0,0 +1,356 @@
+"""Event loop and event loop policy.
+
+Beyond the PEP:
+- Only the main thread has a default event loop.
+"""
+
+__all__ = ['EventLoopPolicy', 'DefaultEventLoopPolicy',
+ 'AbstractEventLoop', 'Timer', 'Handle', 'make_handle',
+ 'get_event_loop_policy', 'set_event_loop_policy',
+ 'get_event_loop', 'set_event_loop', 'new_event_loop',
+ ]
+
+import sys
+import threading
+
+from .log import tulip_log
+
+
+class Handle:
+ """Object returned by callback registration methods."""
+
+ def __init__(self, callback, args):
+ self._callback = callback
+ self._args = args
+ self._cancelled = False
+
+ def __repr__(self):
+ res = 'Handle({}, {})'.format(self._callback, self._args)
+ if self._cancelled:
+ res += '<cancelled>'
+ return res
+
+ @property
+ def callback(self):
+ return self._callback
+
+ @property
+ def args(self):
+ return self._args
+
+ @property
+ def cancelled(self):
+ return self._cancelled
+
+ def cancel(self):
+ self._cancelled = True
+
+ def run(self):
+ try:
+ self._callback(*self._args)
+ except Exception:
+ tulip_log.exception('Exception in callback %s %r',
+ self._callback, self._args)
+
+
+def make_handle(callback, args):
+ if isinstance(callback, Handle):
+ assert not args
+ return callback
+ return Handle(callback, args)
+
+
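+# A small, hypothetical sketch of the Handle contract: run() invokes the
+# stored callback and logs (rather than propagates) exceptions, while
+# cancel() only sets a flag that the event loop checks before running it.
+def _example_handle():
+    handle = Handle(print, ('hello',))
+    handle.run()             # calls print('hello')
+    handle.cancel()
+    assert handle.cancelled  # the event loop would now skip this handle
+
+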
+class Timer(Handle):
+ """Object returned by timed callback registration methods."""
+
+ def __init__(self, when, callback, args):
+ assert when is not None
+ super().__init__(callback, args)
+
+ self._when = when
+
+ def __repr__(self):
+ res = 'Timer({}, {}, {})'.format(self._when,
+ self._callback,
+ self._args)
+ if self._cancelled:
+ res += '<cancelled>'
+
+ return res
+
+ @property
+ def when(self):
+ return self._when
+
+ def __lt__(self, other):
+ return self._when < other._when
+
+ def __le__(self, other):
+ if self._when < other._when:
+ return True
+ return self.__eq__(other)
+
+ def __gt__(self, other):
+ return self._when > other._when
+
+ def __ge__(self, other):
+ if self._when > other._when:
+ return True
+ return self.__eq__(other)
+
+ def __eq__(self, other):
+ if isinstance(other, Timer):
+ return (self._when == other._when and
+ self._callback == other._callback and
+ self._args == other._args and
+ self._cancelled == other._cancelled)
+ return NotImplemented
+
+ def __ne__(self, other):
+ equal = self.__eq__(other)
+ return NotImplemented if equal is NotImplemented else not equal
+
+
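+# A hypothetical sketch of why Timer defines the full set of comparison
+# methods: the event loop keeps Timers in a heapq, which uses '<' to pop
+# the earliest deadline first.
+def _example_timer_ordering():
+    import heapq
+    import time
+
+    sooner = Timer(time.monotonic() + 1, print, ())
+    later = Timer(time.monotonic() + 10, print, ())
+    heap = []
+    heapq.heappush(heap, later)
+    heapq.heappush(heap, sooner)
+    assert heapq.heappop(heap) is sooner
+
+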
+class AbstractEventLoop:
+ """Abstract event loop."""
+
+ # TODO: Rename run() -> run_until_idle(), run_forever() -> run().
+
+ def run(self):
+ """Run the event loop. Block until there is nothing left to do."""
+ raise NotImplementedError
+
+ def run_forever(self):
+ """Run the event loop. Block until stop() is called."""
+ raise NotImplementedError
+
+ def run_once(self, timeout=None): # NEW!
+ """Run one complete cycle of the event loop."""
+ raise NotImplementedError
+
+ def run_until_complete(self, future, timeout=None): # NEW!
+ """Run the event loop until a Future is done.
+
+ Return the Future's result, or raise its exception.
+
+ If timeout is not None, run it for at most that long;
+ if the Future is still not done, raise TimeoutError
+ (but don't cancel the Future).
+ """
+ raise NotImplementedError
+
+ def stop(self): # NEW!
+ """Stop the event loop as soon as reasonable.
+
+ Exactly how soon that is may depend on the implementation, but
+ no more I/O callbacks should be scheduled.
+ """
+ raise NotImplementedError
+
+ # Methods returning Handles for scheduling callbacks.
+
+ def call_later(self, delay, callback, *args):
+ raise NotImplementedError
+
+ def call_repeatedly(self, interval, callback, *args): # NEW!
+ raise NotImplementedError
+
+ def call_soon(self, callback, *args):
+ return self.call_later(0, callback, *args)
+
+ def call_soon_threadsafe(self, callback, *args):
+ raise NotImplementedError
+
+ # Methods returning Futures for interacting with threads.
+
+ def wrap_future(self, future):
+ raise NotImplementedError
+
+ def run_in_executor(self, executor, callback, *args):
+ raise NotImplementedError
+
+ # Network I/O methods returning Futures.
+
+ def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0):
+ raise NotImplementedError
+
+ def getnameinfo(self, sockaddr, flags=0):
+ raise NotImplementedError
+
+ def create_connection(self, protocol_factory, host=None, port=None, *,
+ family=0, proto=0, flags=0, sock=None):
+ raise NotImplementedError
+
+ def start_serving(self, protocol_factory, host=None, port=None, *,
+ family=0, proto=0, flags=0, sock=None):
+ raise NotImplementedError
+
+ def create_datagram_endpoint(self, protocol_factory,
+ local_addr=None, remote_addr=None, *,
+ family=0, proto=0, flags=0):
+ raise NotImplementedError
+
+ def connect_read_pipe(self, protocol_factory, pipe):
+ """Register read pipe in eventloop.
+
+ protocol_factory should instantiate object with Protocol interface.
+ pipe is file-like object already switched to nonblocking.
+ Return pair (transport, protocol), where transport support
+ ReadTransport ABC"""
+ # The reason to accept file-like object instead of just file descriptor
+ # is: we need to own pipe and close it at transport finishing
+ # Can got complicated errors if pass f.fileno(),
+ # close fd in pipe transport then close f and vise versa.
+ raise NotImplementedError
+
+ def connect_write_pipe(self, protocol_factory, pipe):
+ """Register write pipe in eventloop.
+
+ protocol_factory should instantiate object with BaseProtocol interface.
+ Pipe is file-like object already switched to nonblocking.
+ Return pair (transport, protocol), where transport support
+ WriteTransport ABC"""
+ # The reason to accept file-like object instead of just file descriptor
+ # is: we need to own pipe and close it at transport finishing
+ # Can got complicated errors if pass f.fileno(),
+ # close fd in pipe transport then close f and vise versa.
+ raise NotImplementedError
+
+ #def spawn_subprocess(self, protocol_factory, pipe):
+ # raise NotImplementedError
+
+ # Ready-based callback registration methods.
+ # The add_*() methods return a Handle.
+ # The remove_*() methods return True if something was removed,
+ # False if there was nothing to delete.
+
+ def add_reader(self, fd, callback, *args):
+ raise NotImplementedError
+
+ def remove_reader(self, fd):
+ raise NotImplementedError
+
+ def add_writer(self, fd, callback, *args):
+ raise NotImplementedError
+
+ def remove_writer(self, fd):
+ raise NotImplementedError
+
+ # Completion based I/O methods returning Futures.
+
+ def sock_recv(self, sock, nbytes):
+ raise NotImplementedError
+
+ def sock_sendall(self, sock, data):
+ raise NotImplementedError
+
+ def sock_connect(self, sock, address):
+ raise NotImplementedError
+
+ def sock_accept(self, sock):
+ raise NotImplementedError
+
+ # Signal handling.
+
+ def add_signal_handler(self, sig, callback, *args):
+ raise NotImplementedError
+
+ def remove_signal_handler(self, sig):
+ raise NotImplementedError
+
+
+class EventLoopPolicy:
+ """Abstract policy for accessing the event loop."""
+
+ def get_event_loop(self):
+ """XXX"""
+ raise NotImplementedError
+
+ def set_event_loop(self, event_loop):
+ """XXX"""
+ raise NotImplementedError
+
+ def new_event_loop(self):
+ """XXX"""
+ raise NotImplementedError
+
+
+class DefaultEventLoopPolicy(threading.local, EventLoopPolicy):
+ """Default policy implementation for accessing the event loop.
+
+    In this policy, each thread has its own event loop. However, an
+    event loop is created automatically only for the main thread;
+    other threads have no event loop by default.
+
+ Other policies may have different rules (e.g. a single global
+ event loop, or automatically creating an event loop per thread, or
+ using some other notion of context to which an event loop is
+ associated).
+ """
+
+ _event_loop = None
+
+ def get_event_loop(self):
+ """Get the event loop.
+
+        This may be None or an instance of AbstractEventLoop.
+ """
+ if (self._event_loop is None and
+ threading.current_thread().name == 'MainThread'):
+ self._event_loop = self.new_event_loop()
+ return self._event_loop
+
+ def set_event_loop(self, event_loop):
+ """Set the event loop."""
+ assert event_loop is None or isinstance(event_loop, AbstractEventLoop)
+ self._event_loop = event_loop
+
+ def new_event_loop(self):
+ """Create a new event loop.
+
+ You must call set_event_loop() to make this the current event
+ loop.
+ """
+ if sys.platform == 'win32': # pragma: no cover
+ from . import windows_events
+ return windows_events.SelectorEventLoop()
+ else: # pragma: no cover
+ from . import unix_events
+ return unix_events.SelectorEventLoop()
+
+
+# Event loop policy. The policy itself is always global, even if the
+# policy's rules say that there is an event loop per thread (or other
+# notion of context). The default policy is installed by the first
+# call to get_event_loop_policy().
+_event_loop_policy = None
+
+
+def get_event_loop_policy():
+ """XXX"""
+ global _event_loop_policy
+ if _event_loop_policy is None:
+ _event_loop_policy = DefaultEventLoopPolicy()
+ return _event_loop_policy
+
+
+def set_event_loop_policy(policy):
+ """XXX"""
+ global _event_loop_policy
+ assert policy is None or isinstance(policy, EventLoopPolicy)
+ _event_loop_policy = policy
+
+
+def get_event_loop():
+ """XXX"""
+ return get_event_loop_policy().get_event_loop()
+
+
+def set_event_loop(event_loop):
+ """XXX"""
+ get_event_loop_policy().set_event_loop(event_loop)
+
+
+def new_event_loop():
+ """XXX"""
+ return get_event_loop_policy().new_event_loop()
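+
+
+# A minimal sketch of the policy round-trip: the module-level helpers above
+# simply delegate to the (lazily created) global policy.
+def _example_policy_roundtrip():
+    loop = new_event_loop()
+    set_event_loop(loop)
+    assert get_event_loop() is loop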
diff --git a/tulip/futures.py b/tulip/futures.py
new file mode 100644
index 0000000..39137aa
--- /dev/null
+++ b/tulip/futures.py
@@ -0,0 +1,255 @@
+"""A Future class similar to the one in PEP 3148."""
+
+__all__ = ['CancelledError', 'TimeoutError',
+ 'InvalidStateError', 'InvalidTimeoutError',
+ 'Future',
+ ]
+
+import concurrent.futures._base
+
+from . import events
+
+# States for Future.
+_PENDING = 'PENDING'
+_CANCELLED = 'CANCELLED'
+_FINISHED = 'FINISHED'
+
+# TODO: Do we really want to depend on concurrent.futures internals?
+Error = concurrent.futures._base.Error
+CancelledError = concurrent.futures.CancelledError
+TimeoutError = concurrent.futures.TimeoutError
+
+
+class InvalidStateError(Error):
+ """The operation is not allowed in this state."""
+ # TODO: Show the future, its state, the method, and the required state.
+
+
+class InvalidTimeoutError(Error):
+ """Called result() or exception() with timeout != 0."""
+ # TODO: Print a nice error message.
+
+
+class Future:
+ """This class is *almost* compatible with concurrent.futures.Future.
+
+ Differences:
+
+ - result() and exception() do not take a timeout argument and
+ raise an exception when the future isn't done yet.
+
+ - Callbacks registered with add_done_callback() are always called
+ via the event loop's call_soon_threadsafe().
+
+ - This class is not compatible with the wait() and as_completed()
+ methods in the concurrent.futures package.
+
+ (In Python 3.4 or later we may be able to unify the implementations.)
+ """
+
+ # Class variables serving as defaults for instance variables.
+ _state = _PENDING
+ _result = None
+ _exception = None
+ _timeout_handle = None
+
+ _blocking = False # proper use of future (yield vs yield from)
+
+ def __init__(self, *, event_loop=None, timeout=None):
+ """Initialize the future.
+
+        The optional event_loop argument allows explicitly setting the event
+ loop object used by the future. If it's not provided, the future uses
+ the default event loop.
+ """
+ if event_loop is None:
+ self._event_loop = events.get_event_loop()
+ else:
+ self._event_loop = event_loop
+ self._callbacks = []
+
+ if timeout is not None:
+ self._timeout_handle = self._event_loop.call_later(
+ timeout, self.cancel)
+
+ def __repr__(self):
+ res = self.__class__.__name__
+ if self._state == _FINISHED:
+ if self._exception is not None:
+ res += '<exception={!r}>'.format(self._exception)
+ else:
+ res += '<result={!r}>'.format(self._result)
+ elif self._callbacks:
+ size = len(self._callbacks)
+ if size > 2:
+ res += '<{}, [{}, <{} more>, {}]>'.format(
+ self._state, self._callbacks[0],
+ size-2, self._callbacks[-1])
+ else:
+ res += '<{}, {}>'.format(self._state, self._callbacks)
+ else:
+ res += '<{}>'.format(self._state)
+ return res
+
+ def cancel(self):
+ """Cancel the future and schedule callbacks.
+
+ If the future is already done or cancelled, return False. Otherwise,
+ change the future's state to cancelled, schedule the callbacks and
+ return True.
+ """
+ if self._state != _PENDING:
+ return False
+ self._state = _CANCELLED
+ self._schedule_callbacks()
+ return True
+
+ def _schedule_callbacks(self):
+ """Internal: Ask the event loop to call all callbacks.
+
+ The callbacks are scheduled to be called as soon as possible. Also
+ clears the callback list.
+ """
+ # Cancel timeout handle
+ if self._timeout_handle is not None:
+ self._timeout_handle.cancel()
+ self._timeout_handle = None
+
+ callbacks = self._callbacks[:]
+ if not callbacks:
+ return
+
+ self._callbacks[:] = []
+ for callback in callbacks:
+ self._event_loop.call_soon(callback, self)
+
+ def cancelled(self):
+ """Return True if the future was cancelled."""
+ return self._state == _CANCELLED
+
+ def running(self):
+ """Always return False.
+
+ This method is for compatibility with concurrent.futures; we don't
+ have a running state.
+ """
+ return False # We don't have a running state.
+
+ def done(self):
+ """Return True if the future is done.
+
+        Done means either that a result or exception is available, or that the
+ future was cancelled.
+ """
+ return self._state != _PENDING
+
+ def result(self, timeout=0):
+ """Return the result this future represents.
+
+ If the future has been cancelled, raises CancelledError. If the
+ future's result isn't yet available, raises InvalidStateError. If
+ the future is done and has an exception set, this exception is raised.
+ Timeout values other than 0 are not supported.
+ """
+ if timeout != 0:
+ raise InvalidTimeoutError
+ if self._state == _CANCELLED:
+ raise CancelledError
+ if self._state != _FINISHED:
+ raise InvalidStateError
+ if self._exception is not None:
+ raise self._exception
+ return self._result
+
+ def exception(self, timeout=0):
+ """Return the exception that was set on this future.
+
+ The exception (or None if no exception was set) is returned only if
+ the future is done. If the future has been cancelled, raises
+ CancelledError. If the future isn't done yet, raises
+ InvalidStateError. Timeout values other than 0 are not supported.
+ """
+ if timeout != 0:
+ raise InvalidTimeoutError
+ if self._state == _CANCELLED:
+ raise CancelledError
+ if self._state != _FINISHED:
+ raise InvalidStateError
+ return self._exception
+
+ def add_done_callback(self, fn):
+ """Add a callback to be run when the future becomes done.
+
+ The callback is called with a single argument - the future object. If
+ the future is already done when this is called, the callback is
+ scheduled with call_soon.
+ """
+ if self._state != _PENDING:
+ self._event_loop.call_soon(fn, self)
+ else:
+ self._callbacks.append(fn)
+
+ # New method not in PEP 3148.
+
+ def remove_done_callback(self, fn):
+ """Remove all instances of a callback from the "call when done" list.
+
+ Returns the number of callbacks removed.
+ """
+ filtered_callbacks = [f for f in self._callbacks if f != fn]
+ removed_count = len(self._callbacks) - len(filtered_callbacks)
+ if removed_count:
+ self._callbacks[:] = filtered_callbacks
+ return removed_count
+
+ # So-called internal methods (note: no set_running_or_notify_cancel()).
+
+ def set_result(self, result):
+ """Mark the future done and set its result.
+
+ If the future is already done when this method is called, raises
+ InvalidStateError.
+ """
+ if self._state != _PENDING:
+ raise InvalidStateError
+ self._result = result
+ self._state = _FINISHED
+ self._schedule_callbacks()
+
+ def set_exception(self, exception):
+ """ Mark the future done and set an exception.
+
+ If the future is already done when this method is called, raises
+ InvalidStateError.
+ """
+ if self._state != _PENDING:
+ raise InvalidStateError
+ self._exception = exception
+ self._state = _FINISHED
+ self._schedule_callbacks()
+
+ # Truly internal methods.
+
+ def _copy_state(self, other):
+ """Internal helper to copy state from another Future.
+
+ The other Future may be a concurrent.futures.Future.
+ """
+ assert other.done()
+ assert not self.done()
+ if other.cancelled():
+ self.cancel()
+ else:
+ exception = other.exception()
+ if exception is not None:
+ self.set_exception(exception)
+ else:
+ result = other.result()
+ self.set_result(result)
+
+ def __iter__(self):
+ if not self.done():
+ self._blocking = True
+ yield self # This tells Task to wait for completion.
+ assert self.done(), "yield from wasn't used with future"
+ return self.result() # May raise too.
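+
+
+# A minimal sketch of the Future lifecycle. Note that done callbacks are
+# only *scheduled* with call_soon(); they run once the event loop resumes.
+def _example_future_lifecycle():
+    f = Future()
+    seen = []
+    f.add_done_callback(seen.append)  # scheduled on completion, not called
+    f.set_result(42)
+    assert f.done() and f.result() == 42
+    assert seen == []  # the callback fires on the next loop iteration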
diff --git a/tulip/http/__init__.py b/tulip/http/__init__.py
new file mode 100644
index 0000000..582f080
--- /dev/null
+++ b/tulip/http/__init__.py
@@ -0,0 +1,12 @@
+# This relies on each of the submodules having an __all__ variable.
+
+from .client import *
+from .errors import *
+from .protocol import *
+from .server import *
+
+
+__all__ = (client.__all__ +
+ errors.__all__ +
+ protocol.__all__ +
+ server.__all__)
diff --git a/tulip/http/client.py b/tulip/http/client.py
new file mode 100644
index 0000000..b65b90a
--- /dev/null
+++ b/tulip/http/client.py
@@ -0,0 +1,145 @@
+"""HTTP Client for Tulip.
+
+Most basic usage:
+
+ sts, headers, response = yield from http_client.fetch(url,
+ method='GET', headers={}, request=b'')
+ assert isinstance(sts, int)
+ assert isinstance(headers, dict)
+ # sort of; case insensitive (what about multiple values for same header?)
+ headers['status'] == '200 Ok' # or some such
+ assert isinstance(response, bytes)
+
+TODO: Reuse email.Message class (or its subclass, http.client.HTTPMessage).
+TODO: How do we do connection keep alive? Pooling?
+"""
+
+__all__ = ['HttpClientProtocol']
+
+
+import email.message
+import email.parser
+
+import tulip
+
+from . import protocol
+
+
+class HttpClientProtocol:
+ """This Protocol class is also used to initiate the connection.
+
+ Usage:
+ p = HttpClientProtocol(url, ...)
+ sts, headers, stream = yield from p.connect()
+
+ """
+
+ def __init__(self, host, port=None, *,
+ path='/', method='GET', headers=None, ssl=None,
+ make_body=None, encoding='utf-8', version=(1, 1),
+ chunked=False):
+ host = self.validate(host, 'host')
+ if ':' in host:
+ assert port is None
+ host, port_s = host.split(':', 1)
+ port = int(port_s)
+ self.host = host
+ if port is None:
+ if ssl:
+ port = 443
+ else:
+ port = 80
+ assert isinstance(port, int)
+ self.port = port
+ self.path = self.validate(path, 'path')
+ self.method = self.validate(method, 'method')
+ self.headers = email.message.Message()
+ self.headers['Accept-Encoding'] = 'gzip, deflate'
+ if headers:
+ for key, value in headers.items():
+ self.validate(key, 'header key')
+ self.validate(value, 'header value', True)
+ self.headers[key] = value
+ self.encoding = self.validate(encoding, 'encoding')
+ self.version = version
+ self.make_body = make_body
+ self.chunked = chunked
+ self.ssl = ssl
+ if 'content-length' not in self.headers:
+ if self.make_body is None:
+ self.headers['Content-Length'] = '0'
+ else:
+ self.chunked = True
+ if self.chunked:
+ if 'Transfer-Encoding' not in self.headers:
+ self.headers['Transfer-Encoding'] = 'chunked'
+ else:
+ assert self.headers['Transfer-Encoding'].lower() == 'chunked'
+ if 'host' not in self.headers:
+ self.headers['Host'] = self.host
+ self.event_loop = tulip.get_event_loop()
+ self.transport = None
+
+ def validate(self, value, name, embedded_spaces_okay=False):
+ # Must be a string. If embedded_spaces_okay is False, no
+ # whitespace is allowed; otherwise, internal single spaces are
+ # allowed (but no other whitespace).
+ assert isinstance(value, str), \
+ '{} should be str, not {}'.format(name, type(value))
+ parts = value.split()
+ assert parts, '{} should not be empty'.format(name)
+ if embedded_spaces_okay:
+ assert ' '.join(parts) == value, \
+ '{} can only contain embedded single spaces ({!r})'.format(
+ name, value)
+ else:
+ assert parts == [value], \
+ '{} cannot contain whitespace ({!r})'.format(name, value)
+ return value
+
+ @tulip.coroutine
+ def connect(self):
+ yield from self.event_loop.create_connection(
+ lambda: self, self.host, self.port, ssl=self.ssl)
+
+ # read response status
+ version, status, reason = yield from self.stream.read_response_status()
+
+ message = yield from self.stream.read_message(version)
+
+ # headers
+ headers = email.message.Message()
+ for hdr, val in message.headers:
+ headers.add_header(hdr, val)
+
+ sts = '{} {}'.format(status, reason)
+ return (sts, headers, message.payload)
+
+ def connection_made(self, transport):
+ self.transport = transport
+ self.stream = protocol.HttpStreamReader()
+
+ self.request = protocol.Request(
+ transport, self.method, self.path, self.version)
+
+ self.request.add_headers(*self.headers.items())
+ self.request.send_headers()
+
+        if self.make_body is not None:
+            # make_body() writes the payload through request.write() and
+            # finishes it with request.write_eof().
+            self.make_body(self.request.write, self.request.write_eof)
+ else:
+ self.request.write_eof()
+
+ def data_received(self, data):
+ self.stream.feed_data(data)
+
+ def eof_received(self):
+ self.stream.feed_eof()
+
+ def connection_lost(self, exc):
+ pass
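+
+
+# A minimal, hypothetical sketch of driving the client from the event loop;
+# the host name here is only an example.
+def _example_fetch():
+    p = HttpClientProtocol('python.org', path='/', method='GET')
+    loop = tulip.get_event_loop()
+    sts, headers, stream = loop.run_until_complete(p.connect())
+    # sts is e.g. '200 OK'; stream is the payload stream of the message.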
diff --git a/tulip/http/errors.py b/tulip/http/errors.py
new file mode 100644
index 0000000..41344de
--- /dev/null
+++ b/tulip/http/errors.py
@@ -0,0 +1,44 @@
+"""http related errors."""
+
+__all__ = ['HttpException', 'HttpStatusException',
+ 'IncompleteRead', 'BadStatusLine', 'LineTooLong', 'InvalidHeader']
+
+import http.client
+
+
+class HttpException(http.client.HTTPException):
+
+ code = None
+ headers = ()
+
+
+class HttpStatusException(HttpException):
+
+ def __init__(self, code, headers=None, message=''):
+ self.code = code
+ self.headers = headers
+ self.message = message
+
+
+class BadRequestException(HttpException):
+
+ code = 400
+
+
+class IncompleteRead(BadRequestException, http.client.IncompleteRead):
+ pass
+
+
+class BadStatusLine(BadRequestException, http.client.BadStatusLine):
+ pass
+
+
+class LineTooLong(BadRequestException, http.client.LineTooLong):
+ pass
+
+
+class InvalidHeader(BadRequestException):
+
+ def __init__(self, hdr):
+ super().__init__('Invalid HTTP Header: %s' % hdr)
+ self.hdr = hdr
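+
+
+# A small sketch of the hierarchy above: every parsing error maps to a 400
+# response through the shared BadRequestException.code attribute.
+def _example_error_code():
+    try:
+        raise BadStatusLine('not a real status line')
+    except HttpException as exc:
+        assert exc.code == 400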
diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py
new file mode 100644
index 0000000..6a0e127
--- /dev/null
+++ b/tulip/http/protocol.py
@@ -0,0 +1,877 @@
+"""Http related helper utils."""
+
+__all__ = ['HttpStreamReader',
+ 'HttpMessage', 'Request', 'Response',
+ 'RawHttpMessage', 'RequestLine', 'ResponseStatus']
+
+import collections
+import email.utils
+import functools
+import http.server
+import itertools
+import re
+import sys
+import zlib
+
+import tulip
+from . import errors
+
+METHRE = re.compile('[A-Z0-9$-_.]+')
+VERSRE = re.compile(r'HTTP/(\d+)\.(\d+)')
+HDRRE = re.compile(b"[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]")
+CONTINUATION = (b' ', b'\t')
+RESPONSES = http.server.BaseHTTPRequestHandler.responses
+
+RequestLine = collections.namedtuple(
+ 'RequestLine', ['method', 'uri', 'version'])
+
+
+ResponseStatus = collections.namedtuple(
+ 'ResponseStatus', ['version', 'code', 'reason'])
+
+
+RawHttpMessage = collections.namedtuple(
+ 'RawHttpMessage', ['headers', 'payload', 'should_close', 'compression'])
+
+
+class HttpStreamReader(tulip.StreamReader):
+
+ MAX_HEADERS = 32768
+ MAX_HEADERFIELD_SIZE = 8190
+
+    # If _parser is set, feed_data() and feed_eof() send data into
+    # _parser instead of self. It is used as stream redirection for
+    # _parse_chunked_payload, _parse_length_payload and _parse_eof_payload.
+ _parser = None
+
+ def feed_data(self, data):
+ """_parser is a generator, if _parser is set, feed_data sends
+ incoming data into the generator untile generator stops."""
+ if self._parser:
+ try:
+ self._parser.send(data)
+ except StopIteration as exc:
+ self._parser = None
+ if exc.value:
+ self.feed_data(exc.value)
+ else:
+ super().feed_data(data)
+
+ def feed_eof(self):
+ """_parser is a generator, if _parser is set feed_eof throws
+ StreamEofException into this generator."""
+ if self._parser:
+ try:
+ self._parser.throw(StreamEofException())
+ except StopIteration:
+ self._parser = None
+
+ super().feed_eof()
+
+ @tulip.coroutine
+ def read_request_line(self):
+ """Read request status line. Exception errors.BadStatusLine
+ could be raised in case of any errors in status line.
+ Returns three values (method, uri, version)
+
+ Example:
+
+ GET /path HTTP/1.1
+
+ >> yield from reader.read_request_line()
+ ('GET', '/path', (1, 1))
+
+ """
+ bline = yield from self.readline()
+ try:
+ line = bline.decode('ascii').rstrip()
+ except UnicodeDecodeError:
+ raise errors.BadStatusLine(bline) from None
+
+ try:
+ method, uri, version = line.split(None, 2)
+ except ValueError:
+ raise errors.BadStatusLine(line) from None
+
+ # method
+ method = method.upper()
+ if not METHRE.match(method):
+ raise errors.BadStatusLine(method)
+
+ # version
+ match = VERSRE.match(version)
+ if match is None:
+ raise errors.BadStatusLine(version)
+ version = (int(match.group(1)), int(match.group(2)))
+
+ return RequestLine(method, uri, version)
+
+ @tulip.coroutine
+ def read_response_status(self):
+ """Read response status line. Exception errors.BadStatusLine
+ could be raised in case of any errors in status line.
+ Returns three values (version, status_code, reason)
+
+ Example:
+
+ HTTP/1.1 200 Ok
+
+ >> yield from reader.read_response_status()
+ ((1, 1), 200, 'Ok')
+
+ """
+ bline = yield from self.readline()
+ if not bline:
+ # Presumably, the server closed the connection before
+ # sending a valid response.
+ raise errors.BadStatusLine(bline)
+
+ try:
+ line = bline.decode('ascii').rstrip()
+ except UnicodeDecodeError:
+ raise errors.BadStatusLine(bline) from None
+
+ try:
+ version, status = line.split(None, 1)
+ except ValueError:
+ raise errors.BadStatusLine(line) from None
+ else:
+ try:
+ status, reason = status.split(None, 1)
+ except ValueError:
+ reason = ''
+
+ # version
+ match = VERSRE.match(version)
+ if match is None:
+ raise errors.BadStatusLine(line)
+ version = (int(match.group(1)), int(match.group(2)))
+
+ # The status code is a three-digit number
+ try:
+ status = int(status)
+ except ValueError:
+ raise errors.BadStatusLine(line) from None
+
+ if status < 100 or status > 999:
+ raise errors.BadStatusLine(line)
+
+ return ResponseStatus(version, status, reason.strip())
+
+ @tulip.coroutine
+ def read_headers(self):
+ """Read and parses RFC2822 headers from a stream.
+
+ Line continuations are supported. Returns list of header name
+ and value pairs. Header name is in upper case.
+ """
+ size = 0
+ headers = []
+
+ line = yield from self.readline()
+
+ while line not in (b'\r\n', b'\n'):
+ header_length = len(line)
+
+ # Parse initial header name : value pair.
+ sep_pos = line.find(b':')
+ if sep_pos < 0:
+ raise ValueError('Invalid header %s' % line.strip())
+
+ name, value = line[:sep_pos], line[sep_pos+1:]
+ name = name.rstrip(b' \t').upper()
+ if HDRRE.search(name):
+ raise ValueError('Invalid header name %s' % name)
+
+ name = name.strip().decode('ascii', 'surrogateescape')
+ value = [value.lstrip()]
+
+ # next line
+ line = yield from self.readline()
+
+ # consume continuation lines
+ continuation = line.startswith(CONTINUATION)
+
+ if continuation:
+ while continuation:
+ header_length += len(line)
+ if header_length > self.MAX_HEADERFIELD_SIZE:
+ raise errors.LineTooLong(
+ 'limit request headers fields size')
+ value.append(line)
+
+ line = yield from self.readline()
+ continuation = line.startswith(CONTINUATION)
+ else:
+ if header_length > self.MAX_HEADERFIELD_SIZE:
+ raise errors.LineTooLong(
+ 'limit request headers fields size')
+
+ # total headers size
+ size += header_length
+ if size >= self.MAX_HEADERS:
+ raise errors.LineTooLong('limit request headers fields')
+
+ headers.append(
+ (name,
+ b''.join(value).rstrip().decode('ascii', 'surrogateescape')))
+
+ return headers
+
+ def _parse_chunked_payload(self):
+ """Chunked transfer encoding parser."""
+ stream = yield
+
+ try:
+ data = bytearray()
+
+ while True:
+ # read line
+ if b'\n' not in data:
+ data.extend((yield))
+ continue
+
+ line, data = data.split(b'\n', 1)
+
+ # Read the next chunk size from the file
+ i = line.find(b';')
+ if i >= 0:
+ line = line[:i] # strip chunk-extensions
+ try:
+ size = int(line, 16)
+ except ValueError:
+ raise errors.IncompleteRead(b'') from None
+
+ if size == 0:
+ break
+
+ # read chunk
+ while len(data) < size:
+ data.extend((yield))
+
+ # feed stream
+ stream.feed_data(data[:size])
+
+ data = data[size:]
+
+ # toss the CRLF at the end of the chunk
+ while len(data) < 2:
+ data.extend((yield))
+
+ data = data[2:]
+
+ # read and discard trailer up to the CRLF terminator
+ while True:
+ if b'\n' in data:
+ line, data = data.split(b'\n', 1)
+ if line in (b'\r', b''):
+ break
+ else:
+ data.extend((yield))
+
+ # stream eof
+ stream.feed_eof()
+ return data
+
+ except StreamEofException:
+ stream.set_exception(errors.IncompleteRead(b''))
+ except errors.IncompleteRead as exc:
+ stream.set_exception(exc)
+
+ def _parse_length_payload(self, length):
+ """Read specified amount of bytes."""
+ stream = yield
+
+ try:
+ data = bytearray()
+ while length:
+ data.extend((yield))
+
+ data_len = len(data)
+ if data_len <= length:
+ stream.feed_data(data)
+ data = bytearray()
+ length -= data_len
+ else:
+ stream.feed_data(data[:length])
+ data = data[length:]
+ length = 0
+
+ stream.feed_eof()
+ return data
+ except StreamEofException:
+ stream.set_exception(errors.IncompleteRead(b''))
+
+ def _parse_eof_payload(self):
+ """Read all bytes untile eof."""
+ stream = yield
+
+ try:
+ while True:
+ stream.feed_data((yield))
+ except StreamEofException:
+ stream.feed_eof()
+
+ @tulip.coroutine
+ def read_message(self, version=(1, 1),
+ length=None, compression=True, readall=False):
+ """Read RFC2822 headers and message payload from a stream.
+
+        read_message() automatically decompresses gzip and deflate content
+        encodings. To prevent decompression, pass compression=False.
+
+        Returns a tuple of headers, payload stream, should-close flag and
+        compression type.
+ """
+ # load headers
+ headers = yield from self.read_headers()
+
+ # payload params
+ chunked = False
+ encoding = None
+ close_conn = None
+
+ for name, value in headers:
+ if name == 'CONTENT-LENGTH':
+ length = value
+ elif name == 'TRANSFER-ENCODING':
+ chunked = value.lower() == 'chunked'
+ elif name == 'SEC-WEBSOCKET-KEY1':
+ length = 8
+ elif name == 'CONNECTION':
+ v = value.lower()
+ if v == 'close':
+ close_conn = True
+ elif v == 'keep-alive':
+ close_conn = False
+ elif compression and name == 'CONTENT-ENCODING':
+ enc = value.lower()
+ if enc in ('gzip', 'deflate'):
+ encoding = enc
+
+ if close_conn is None:
+ close_conn = version <= (1, 0)
+
+ # payload parser
+ if chunked:
+ parser = self._parse_chunked_payload()
+
+ elif length is not None:
+ try:
+ length = int(length)
+ except ValueError:
+ raise errors.InvalidHeader('CONTENT-LENGTH') from None
+
+ if length < 0:
+ raise errors.InvalidHeader('CONTENT-LENGTH')
+
+ parser = self._parse_length_payload(length)
+ else:
+ if readall:
+ parser = self._parse_eof_payload()
+ else:
+ parser = self._parse_length_payload(0)
+
+ next(parser)
+
+ payload = stream = tulip.StreamReader()
+
+ # payload decompression wrapper
+ if encoding is not None:
+ stream = DeflateStream(stream, encoding)
+
+ try:
+            # initialize the payload parser with the stream; the parser
+            # uses the stream as its destination
+ parser.send(stream)
+ except StopIteration:
+ pass
+ else:
+ # feed existing buffer to payload parser
+ self.byte_count = 0
+ while self.buffer:
+ try:
+ parser.send(self.buffer.popleft())
+ except StopIteration as exc:
+ parser = None
+
+ # parser is done
+ buf = b''.join(self.buffer)
+ self.buffer.clear()
+
+ # re-add remaining data back to buffer
+ if exc.value:
+ self.feed_data(exc.value)
+
+ if buf:
+ self.feed_data(buf)
+
+ break
+
+        # parser still requires more data
+ if parser is not None:
+ if self.eof:
+ try:
+ parser.throw(StreamEofException())
+                except StopIteration:
+ pass
+ else:
+ self._parser = parser
+
+ return RawHttpMessage(headers, payload, close_conn, encoding)
+
+
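+# A hypothetical sketch: feed a canned response into the reader and parse it
+# with read_message(). StreamReader.read() returning the whole payload is an
+# assumption about tulip/streams.py.
+def _example_read_message():
+    @tulip.coroutine
+    def parse():
+        reader = HttpStreamReader()
+        reader.feed_data(b'Content-Length: 2\r\n\r\nhi')
+        message = yield from reader.read_message()
+        body = yield from message.payload.read()  # assumed StreamReader API
+        return body
+
+    assert tulip.get_event_loop().run_until_complete(parse()) == b'hi'
+
+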
+class StreamEofException(Exception):
+ """Internal exception: eof received."""
+
+
+class DeflateStream:
+ """DeflateStream decomress stream and feed data into specified stream."""
+
+ def __init__(self, stream, encoding):
+ self.stream = stream
+ zlib_mode = (16 + zlib.MAX_WBITS
+ if encoding == 'gzip' else -zlib.MAX_WBITS)
+
+ self.zlib = zlib.decompressobj(wbits=zlib_mode)
+
+ def set_exception(self, exc):
+ self.stream.set_exception(exc)
+
+ def feed_data(self, chunk):
+ try:
+ chunk = self.zlib.decompress(chunk)
+        except Exception:
+            self.stream.set_exception(errors.IncompleteRead(b''))
+            return
+
+ if chunk:
+ self.stream.feed_data(chunk)
+
+ def feed_eof(self):
+ self.stream.feed_data(self.zlib.flush())
+ if not self.zlib.eof:
+ self.stream.set_exception(errors.IncompleteRead(b''))
+
+ self.stream.feed_eof()
+
+
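+# A small sketch of DeflateStream: it transparently decompresses a 'deflate'
+# (or 'gzip') payload before the data reaches the destination StreamReader.
+def _example_deflate_stream():
+    dest = tulip.StreamReader()
+    wrapper = DeflateStream(dest, 'deflate')
+    # Raw deflate, as used on the wire for the 'deflate' content encoding.
+    comp = zlib.compressobj(wbits=-zlib.MAX_WBITS)
+    wrapper.feed_data(comp.compress(b'payload') + comp.flush())
+    wrapper.feed_eof()  # flushes and verifies the end of the zlib stream
+
+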
+EOF_MARKER = object()
+EOL_MARKER = object()
+
+
+def wrap_payload_filter(func):
+ """Wraps payload filter and piped filters.
+
+    A filter is a generator that accepts arbitrary chunks of data,
+    modifies them and emits a new stream of data.
+
+ For example we have stream of chunks: ['1', '2', '3', '4', '5'],
+ we can apply chunking filter to this stream:
+
+ ['1', '2', '3', '4', '5']
+ |
+ response.add_chunking_filter(2)
+ |
+ ['12', '34', '5']
+
+ It is possible to use different filters at the same time.
+
+    For example, to compress the incoming stream with 'deflate' encoding
+    and then split the data into chunks of 8196 bytes:
+
+ >> response.add_compression_filter('deflate')
+ >> response.add_chunking_filter(8196)
+
+ Filters do not alter transfer encoding.
+
+    A filter can receive two types of data: a bytes object or EOF_MARKER.
+
+    1. If the filter receives a bytes object, it should process the data
+       and yield the processed data, then yield the EOL_MARKER object.
+    2. If the filter receives EOF_MARKER, it should yield any remaining
+       (buffered) data and then yield EOF_MARKER.
+ """
+ @functools.wraps(func)
+ def wrapper(self, *args, **kw):
+ new_filter = func(self, *args, **kw)
+
+ filter = self.filter
+ if filter is not None:
+ next(new_filter)
+ self.filter = filter_pipe(filter, new_filter)
+ else:
+ self.filter = new_filter
+
+ next(self.filter)
+
+ return wrapper
+
+
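+# A hypothetical sketch of the simplest filter obeying the protocol
+# described above: yield the processed data, then EOL_MARKER; on
+# EOF_MARKER, yield any buffered data followed by EOF_MARKER (this
+# identity filter buffers nothing).
+def _example_identity_filter():
+    chunk = yield
+    while True:
+        if chunk is EOF_MARKER:
+            chunk = yield EOF_MARKER
+        else:
+            yield chunk
+            chunk = yield EOL_MARKER
+
+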
+def filter_pipe(filter, filter2):
+ """Creates pipe between two filters.
+
+    filter_pipe() feeds the first filter with incoming data, then sends
+    the data yielded by the first filter into filter2; the results of
+    filter2 are emitted.
+
+    1. If filter_pipe receives a bytes object, it sends it to the first
+       filter.
+    2. Reads values yielded by the first filter until it receives
+       EOF_MARKER or EOL_MARKER.
+    3. Each of these values is sent to the second filter.
+    4. Reads values yielded by the second filter until it receives
+       EOF_MARKER or EOL_MARKER. Each of these values is yielded to
+       the writer.
+ """
+ chunk = yield
+
+ while True:
+ eof = chunk is EOF_MARKER
+ chunk = filter.send(chunk)
+
+ while chunk is not EOL_MARKER:
+ chunk = filter2.send(chunk)
+
+ while chunk not in (EOF_MARKER, EOL_MARKER):
+ yield chunk
+ chunk = next(filter2)
+
+ if chunk is not EOF_MARKER:
+ if eof:
+ chunk = EOF_MARKER
+ else:
+ chunk = next(filter)
+ else:
+ break
+
+ chunk = yield EOL_MARKER
+
+
+class HttpMessage:
+ """HttpMessage allows to write headers and payload to a stream.
+
+ For example, lets say we want to read file then compress it with deflate
+ compression and then send it with chunked transfer encoding, code may look
+ like this:
+
+ >> response = tulip.http.Response(transport, 200)
+
+ We have to use deflate compression first:
+
+ >> response.add_compression_filter('deflate')
+
+ Then we want to split output stream into chunks of 1024 bytes size:
+
+ >> response.add_chunking_filter(1024)
+
+ We can add headers to response with add_headers() method. add_headers()
+ does not send data to transport, send_headers() sends request/response
+ line and then sends headers:
+
+ >> response.add_headers(
+ .. ('Content-Disposition', 'attachment; filename="..."'))
+ >> response.send_headers()
+
+    Now we can use the chunked writer to write the stream to the network
+    stream. The first call to the write() method sends the response status
+    line and headers; the add_header() and add_headers() methods are
+    unavailable at this stage:
+
+    >> with open('...', 'rb') as f:
+    ..     chunk = f.read(8196)
+    ..     while chunk:
+    ..         response.write(chunk)
+    ..         chunk = f.read(8196)
+
+ >> response.write_eof()
+ """
+
+ writer = None
+
+    # 'filter' is used to alter write() behaviour:
+    # add_compression_filter adds deflate/gzip compression and
+    # add_chunking_filter splits incoming data into chunks.
+ filter = None
+
+ HOP_HEADERS = None # Must be set by subclass.
+
+ SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} tulip/0.0'.format(sys.version_info)
+
+ status = None
+ status_line = b''
+
+    # A subclass can enable automatic sending of headers on the first
+    # write() call; this is useful for wsgi's start_response implementation.
+ _send_headers = False
+
+ def __init__(self, transport, version, close):
+ self.transport = transport
+ self.version = version
+ self.closing = close
+ self.keepalive = False
+
+ self.chunked = False
+ self.length = None
+ self.upgrade = False
+ self.headers = []
+ self.headers_sent = False
+
+ def force_close(self):
+ self.closing = True
+
+ def force_chunked(self):
+ self.chunked = True
+
+ def keep_alive(self):
+ return self.keepalive and not self.closing
+
+ def is_headers_sent(self):
+ return self.headers_sent
+
+ def add_header(self, name, value):
+ """Analyze headers. Calculate content length,
+ removes hop headers, etc."""
+ assert not self.headers_sent, 'headers have been sent already'
+ assert isinstance(name, str), '%r is not a string' % name
+
+ name = name.strip().upper()
+
+ if name == 'CONTENT-LENGTH':
+ self.length = int(value)
+
+ if name == 'CONNECTION':
+ val = value.lower().strip()
+ # handle websocket
+ if val == 'upgrade':
+ self.upgrade = True
+ # connection keep-alive
+ elif val == 'close':
+ self.keepalive = False
+ elif val == 'keep-alive':
+ self.keepalive = True
+
+ elif name == 'UPGRADE':
+ if 'websocket' in value.lower():
+ self.headers.append((name, value))
+
+ elif name == 'TRANSFER-ENCODING' and not self.chunked:
+ self.chunked = value.lower().strip() == 'chunked'
+
+ elif name not in self.HOP_HEADERS:
+            # ignore hop-by-hop headers
+ self.headers.append((name, value))
+
+ def add_headers(self, *headers):
+ """Adds headers to a http message."""
+ for name, value in headers:
+ self.add_header(name, value)
+
+ def send_headers(self):
+ """Writes headers to a stream. Constructs payload writer."""
+ # Chunked response is only for HTTP/1.1 clients or newer
+ # and there is no Content-Length header is set.
+ # Do not use chunked responses when the response is guaranteed to
+ # not have a response body (304, 204).
+ assert not self.headers_sent, 'headers have been sent already'
+ self.headers_sent = True
+
+ if (self.chunked is True) or (
+ self.length is None and
+ self.version >= (1, 1) and
+ self.status not in (304, 204)):
+ self.chunked = True
+ self.writer = self._write_chunked_payload()
+
+ elif self.length is not None:
+ self.writer = self._write_length_payload(self.length)
+
+ else:
+ self.writer = self._write_eof_payload()
+
+ next(self.writer)
+
+ # status line
+ self.transport.write(self.status_line.encode('ascii'))
+
+ # send headers
+ self.transport.write(
+ ('%s\r\n\r\n' % '\r\n'.join(
+ ('%s: %s' % (k, v) for k, v in
+ itertools.chain(self._default_headers(), self.headers)))
+ ).encode('ascii'))
+
+ def _default_headers(self):
+ # set the connection header
+ if self.upgrade:
+ connection = 'upgrade'
+ elif self.keep_alive():
+ connection = 'keep-alive'
+ else:
+ connection = 'close'
+
+ headers = [('CONNECTION', connection)]
+
+ if self.chunked:
+ headers.append(('TRANSFER-ENCODING', 'chunked'))
+
+ return headers
+
+ def write(self, chunk):
+ """write() writes chunk of data to a steram by using different writers.
+ writer uses filter to modify chunk of data. write_eof() indicates
+ end of stream. writer can't be used after write_eof() method
+ being called."""
+ assert (isinstance(chunk, (bytes, bytearray)) or
+ chunk is EOF_MARKER), chunk
+
+ if self._send_headers and not self.headers_sent:
+ self.send_headers()
+
+ assert self.writer is not None, 'send_headers() is not called.'
+
+ if self.filter:
+ chunk = self.filter.send(chunk)
+ while chunk not in (EOF_MARKER, EOL_MARKER):
+ self.writer.send(chunk)
+ chunk = next(self.filter)
+ else:
+ if chunk is not EOF_MARKER:
+ self.writer.send(chunk)
+
+ def write_eof(self):
+ self.write(EOF_MARKER)
+ try:
+ self.writer.throw(StreamEofException())
+ except StopIteration:
+ pass
+
+ def _write_chunked_payload(self):
+ """Write data in chunked transfer encoding."""
+ while True:
+ try:
+ chunk = yield
+ except StreamEofException:
+ self.transport.write(b'0\r\n\r\n')
+ break
+
+ self.transport.write('{:x}\r\n'.format(len(chunk)).encode('ascii'))
+ self.transport.write(chunk)
+ self.transport.write(b'\r\n')
+
+ def _write_length_payload(self, length):
+ """Write specified number of bytes to a stream."""
+ while True:
+ try:
+ chunk = yield
+ except StreamEofException:
+ break
+
+            if length:
+                chunk_len = len(chunk)
+                if length >= chunk_len:
+                    self.transport.write(chunk)
+                else:
+                    self.transport.write(chunk[:length])
+
+                length = max(0, length - chunk_len)
+
+ def _write_eof_payload(self):
+ while True:
+ try:
+ chunk = yield
+ except StreamEofException:
+ break
+
+ self.transport.write(chunk)
+
+ @wrap_payload_filter
+ def add_chunking_filter(self, chunk_size=16*1024):
+ """Split incoming stream into chunks."""
+ buf = bytearray()
+ chunk = yield
+
+ while True:
+ if chunk is EOF_MARKER:
+ if buf:
+ yield buf
+
+ yield EOF_MARKER
+
+ else:
+ buf.extend(chunk)
+
+ while len(buf) >= chunk_size:
+ chunk, buf = buf[:chunk_size], buf[chunk_size:]
+ yield chunk
+
+ chunk = yield EOL_MARKER
+
+ @wrap_payload_filter
+ def add_compression_filter(self, encoding='deflate'):
+ """Compress incoming stream with deflate or gzip encoding."""
+ zlib_mode = (16 + zlib.MAX_WBITS
+ if encoding == 'gzip' else -zlib.MAX_WBITS)
+ zcomp = zlib.compressobj(wbits=zlib_mode)
+
+ chunk = yield
+ while True:
+ if chunk is EOF_MARKER:
+ yield zcomp.flush()
+ chunk = yield EOF_MARKER
+
+ else:
+ yield zcomp.compress(chunk)
+ chunk = yield EOL_MARKER
+
+
+class Response(HttpMessage):
+ """Create http response message.
+
+    transport is a socket stream transport. status is the response status
+    code; it has to be an integer. http_version is a tuple representing the
+    HTTP version: (1, 0) stands for HTTP/1.0 and (1, 1) for HTTP/1.1.
+    """
+
+ HOP_HEADERS = {
+ 'CONNECTION',
+ 'KEEP-ALIVE',
+ 'PROXY-AUTHENTICATE',
+ 'PROXY-AUTHORIZATION',
+ 'TE',
+ 'TRAILERS',
+ 'TRANSFER-ENCODING',
+ 'UPGRADE',
+ 'SERVER',
+ 'DATE',
+ }
+
+ def __init__(self, transport, status, http_version=(1, 1), close=False):
+ super().__init__(transport, http_version, close)
+
+ self.status = status
+ self.status_line = 'HTTP/{0[0]}.{0[1]} {1} {2}\r\n'.format(
+ http_version, status, RESPONSES[status][0])
+
+ def _default_headers(self):
+ headers = super()._default_headers()
+ headers.extend((('DATE', email.utils.formatdate()),
+ ('SERVER', self.SERVER_SOFTWARE)))
+
+ return headers
+
+
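+# A hypothetical sketch tying the pieces together: a Response with
+# compression and chunking filters writing to a minimal stand-in transport
+# (any object with a write() method will do).
+def _example_response():
+    class _Transport:
+        def __init__(self):
+            self.chunks = []
+
+        def write(self, data):
+            self.chunks.append(data)
+
+    transport = _Transport()
+    response = Response(transport, 200)
+    response.add_compression_filter('deflate')
+    response.add_chunking_filter(1024)
+    response.send_headers()
+    response.write(b'x' * 4096)
+    response.write_eof()  # flushes the filters and writes the final 0-chunk
+
+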
+class Request(HttpMessage):
+
+ HOP_HEADERS = ()
+
+ def __init__(self, transport, method, uri,
+ http_version=(1, 1), close=False):
+ super().__init__(transport, http_version, close)
+
+ self.method = method
+ self.uri = uri
+ self.status_line = '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format(
+ method, uri, http_version)
+
+ def _default_headers(self):
+ headers = super()._default_headers()
+ headers.append(('USER-AGENT', self.SERVER_SOFTWARE))
+
+ return headers
diff --git a/tulip/http/server.py b/tulip/http/server.py
new file mode 100644
index 0000000..7590e47
--- /dev/null
+++ b/tulip/http/server.py
@@ -0,0 +1,176 @@
+"""simple http server."""
+
+__all__ = ['ServerHttpProtocol']
+
+import http.server
+import inspect
+import logging
+import traceback
+
+import tulip
+import tulip.http
+
+from . import errors
+
+RESPONSES = http.server.BaseHTTPRequestHandler.responses
+DEFAULT_ERROR_MESSAGE = """
+<html>
+ <head>
+ <title>%(status)s %(reason)s</title>
+ </head>
+ <body>
+ <h1>%(status)s %(reason)s</h1>
+ %(message)s
+ </body>
+</html>"""
+
+
+class ServerHttpProtocol(tulip.Protocol):
+ """Simple http protocol implementation.
+
+ ServerHttpProtocol handles incoming http request. It reads request line,
+ request headers and request payload and calls handler_request() method.
+ By default it always returns with 404 respose.
+
+ ServerHttpProtocol handles errors in incoming request, like bad
+ status line, bad headers or incomplete payload. If any error occurs,
+ connection gets closed.
+ """
+ closing = False
+ request_count = 0
+ _request_handle = None
+
+ def __init__(self, log=logging, debug=False):
+ self.log = log
+ self.debug = debug
+
+ def connection_made(self, transport):
+ self.transport = transport
+ self.stream = tulip.http.HttpStreamReader()
+ self._request_handle = self.start()
+
+ def data_received(self, data):
+ self.stream.feed_data(data)
+
+ def connection_lost(self, exc):
+ if self._request_handle is not None:
+ self._request_handle.cancel()
+ self._request_handle = None
+
+ def eof_received(self):
+ self.stream.feed_eof()
+
+ def close(self):
+ self.closing = True
+
+ def log_access(self, status, info, message, *args, **kw):
+ pass
+
+ def log_debug(self, *args, **kw):
+ if self.debug:
+ self.log.debug(*args, **kw)
+
+ def log_exception(self, *args, **kw):
+ self.log.exception(*args, **kw)
+
+ @tulip.task
+ def start(self):
+ """Start processing of incoming requests.
+ It reads request line, request headers and request payload, then
+ calls handle_request() method. Subclass has to override
+ handle_request(). start() handles various excetions in request
+ or response handling. In case of any error connection is being closed.
+ """
+
+ while True:
+ info = None
+ message = None
+ self.request_count += 1
+
+ try:
+ info = yield from self.stream.read_request_line()
+ message = yield from self.stream.read_message(info.version)
+
+ handler = self.handle_request(info, message)
+ if (inspect.isgenerator(handler) or
+ isinstance(handler, tulip.Future)):
+ yield from handler
+
+ except tulip.CancelledError:
+ self.log_debug('Ignored premature client disconnection.')
+ break
+ except errors.HttpException as exc:
+ self.handle_error(exc.code, info, message, exc, exc.headers)
+ except Exception as exc:
+ self.handle_error(500, info, message, exc)
+ finally:
+ if self.closing:
+ self.transport.close()
+ break
+
+ self._request_handle = None
+
+ def handle_error(self, status=500, info=None,
+ message=None, exc=None, headers=None):
+ """Handle errors.
+
+        Sends an http response with the given status code and logs
+        additional information. It always closes the current connection."""
+
+ if status == 500:
+ self.log_exception("Error handling request")
+
+ try:
+ reason, msg = RESPONSES[status]
+ except KeyError:
+ status = 500
+ reason, msg = '???', ''
+
+ if self.debug and exc is not None:
+ try:
+ tb = traceback.format_exc()
+ msg += '<br><h2>Traceback:</h2>\n<pre>%s</pre>' % tb
+            except Exception:
+ pass
+
+ self.log_access(status, info, message)
+
+ html = DEFAULT_ERROR_MESSAGE % {
+ 'status': status, 'reason': reason, 'message': msg}
+
+ response = tulip.http.Response(self.transport, status, close=True)
+ response.add_headers(
+ ('Content-Type', 'text/html'),
+ ('Content-Length', str(len(html))))
+ if headers is not None:
+ response.add_headers(*headers)
+ response.send_headers()
+
+ response.write(html.encode('ascii'))
+ response.write_eof()
+
+ self.close()
+
+ def handle_request(self, info, message):
+ """Handle a single http request.
+
+ Subclass should override this method. By default it always
+ returns 404 response.
+
+ info: tulip.http.RequestLine instance
+ message: tulip.http.RawHttpMessage instance
+ """
+ response = tulip.http.Response(
+ self.transport, 404, http_version=info.version, close=True)
+
+ body = b'Page Not Found!'
+
+ response.add_headers(
+ ('Content-Type', 'text/plain'),
+ ('Content-Length', str(len(body))))
+ response.send_headers()
+ response.write(body)
+ response.write_eof()
+
+ self.close()
+ self.log_access(404, info, message)
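+
+
+# A minimal sketch of the intended subclassing pattern; "HelloServer" is
+# a hypothetical name, and the 200 response simply mirrors the default
+# 404 branch above.
+class HelloServer(ServerHttpProtocol):
+
+    def handle_request(self, info, message):
+        response = tulip.http.Response(
+            self.transport, 200, http_version=info.version, close=True)
+        body = b'Hello world!'
+        response.add_headers(
+            ('Content-Type', 'text/plain'),
+            ('Content-Length', str(len(body))))
+        response.send_headers()
+        response.write(body)
+        response.write_eof()
+        self.close()
+        self.log_access(200, info, message)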
diff --git a/tulip/locks.py b/tulip/locks.py
new file mode 100644
index 0000000..4024796
--- /dev/null
+++ b/tulip/locks.py
@@ -0,0 +1,433 @@
+"""Synchronization primitives"""
+
+__all__ = ['Lock', 'EventWaiter', 'Condition', 'Semaphore']
+
+import collections
+import time
+
+from . import events
+from . import futures
+from . import tasks
+
+
+class Lock:
+ """The class implementing primitive lock objects.
+
+ A primitive lock is a synchronization primitive that is not owned by
+ a particular coroutine when locked. A primitive lock is in one of two
+ states, "locked" or "unlocked".
+ It is created in the unlocked state. It has two basic methods,
+ acquire() and release(). When the state is unlocked, acquire() changes
+ the state to locked and returns immediately. When the state is locked,
+ acquire() blocks until a call to release() in another coroutine changes
+ it to unlocked, then the acquire() call resets it to locked and returns.
+ The release() method should only be called in the locked state; it changes
+ the state to unlocked and returns immediately. If an attempt is made
+ to release an unlocked lock, a RuntimeError will be raised.
+
+    When more than one coroutine is blocked in acquire() waiting for the
+    state to turn unlocked, only one coroutine proceeds when a release()
+    call resets the state to unlocked: the first coroutine that blocked
+    in acquire() is allowed to proceed.
+
+    The acquire() method is a coroutine and should be called with
+    "yield from".
+
+    Locks also support the context manager protocol. (yield from lock)
+    should be used as the context manager expression.
+
+ Usage:
+
+ lock = Lock()
+ ...
+ yield from lock
+ try:
+ ...
+ finally:
+ lock.release()
+
+ Context manager usage:
+
+ lock = Lock()
+ ...
+ with (yield from lock):
+ ...
+
+    A Lock object can be tested for its locking state:
+
+ if not lock.locked():
+ yield from lock
+ else:
+ # lock is acquired
+ ...
+
+ """
+
+ def __init__(self):
+ self._waiters = collections.deque()
+ self._locked = False
+ self._event_loop = events.get_event_loop()
+
+ def __repr__(self):
+ res = super().__repr__()
+ return '<%s [%s]>' % (
+ res[1:-1], 'locked' if self._locked else 'unlocked')
+
+ def locked(self):
+ """Return true if lock is acquired."""
+ return self._locked
+
+ @tasks.coroutine
+ def acquire(self, timeout=None):
+ """Acquire a lock.
+
+        This method blocks until the lock is unlocked, then sets it to
+        locked and returns True.
+
+ When invoked with the floating-point timeout argument set, blocks for
+ at most the number of seconds specified by timeout and as long as
+ the lock cannot be acquired.
+
+ The return value is True if the lock is acquired successfully,
+ False if not (for example if the timeout expired).
+ """
+ if not self._waiters and not self._locked:
+ self._locked = True
+ return True
+
+ fut = futures.Future(event_loop=self._event_loop, timeout=timeout)
+
+ self._waiters.append(fut)
+ try:
+ yield from fut
+ except futures.CancelledError:
+ self._waiters.remove(fut)
+ return False
+ else:
+ f = self._waiters.popleft()
+ assert f is fut
+
+ self._locked = True
+ return True
+
+ def release(self):
+ """Release a lock.
+
+ When the lock is locked, reset it to unlocked, and return.
+ If any other coroutines are blocked waiting for the lock to become
+ unlocked, allow exactly one of them to proceed.
+
+ When invoked on an unlocked lock, a RuntimeError is raised.
+
+ There is no return value.
+ """
+ if self._locked:
+ self._locked = False
+ if self._waiters:
+ self._waiters[0].set_result(True)
+ else:
+ raise RuntimeError('Lock is not acquired.')
+
+ def __enter__(self):
+ if not self._locked:
+ raise RuntimeError(
+ '"yield from" should be used as context manager expression')
+ return True
+
+ def __exit__(self, *args):
+ self.release()
+
+ def __iter__(self):
+ yield from self.acquire()
+ return self
+
+
+class EventWaiter:
+    """An EventWaiter implementation, our equivalent to threading.Event.
+
+ Class implementing event objects. An event manages a flag that can be set
+ to true with the set() method and reset to false with the clear() method.
+ The wait() method blocks until the flag is true. The flag is initially
+ false.
+ """
+
+ def __init__(self):
+ self._waiters = collections.deque()
+ self._value = False
+ self._event_loop = events.get_event_loop()
+
+ def __repr__(self):
+ res = super().__repr__()
+ return '<%s [%s]>' % (res[1:-1], 'set' if self._value else 'unset')
+
+ def is_set(self):
+ """Return true if and only if the internal flag is true."""
+ return self._value
+
+ def set(self):
+ """Set the internal flag to true. All coroutines waiting for it to
+        become true are awakened. Coroutines that call wait() once the
+        flag is true will not block at all.
+ """
+ if not self._value:
+ self._value = True
+
+ for fut in self._waiters:
+ if not fut.done():
+ fut.set_result(True)
+
+ def clear(self):
+ """Reset the internal flag to false. Subsequently, coroutines calling
+ wait() will block until set() is called to set the internal flag
+ to true again."""
+ self._value = False
+
+ @tasks.coroutine
+ def wait(self, timeout=None):
+ """Block until the internal flag is true. If the internal flag
+ is true on entry, return immediately. Otherwise, block until another
+ coroutine calls set() to set the flag to true, or until the optional
+ timeout occurs.
+
+ When the timeout argument is present and not None, it should be
+ a floating point number specifying a timeout for the operation in
+ seconds (or fractions thereof).
+
+ This method returns true if and only if the internal flag has been
+ set to true, either before the wait call or after the wait starts,
+ so it will always return True except if a timeout is given and
+ the operation times out.
+
+ wait() method is a coroutine.
+ """
+ if self._value:
+ return True
+
+ fut = futures.Future(event_loop=self._event_loop, timeout=timeout)
+
+ self._waiters.append(fut)
+ try:
+ yield from fut
+ except futures.CancelledError:
+ self._waiters.remove(fut)
+ return False
+ else:
+ f = self._waiters.popleft()
+ assert f is fut
+
+ return True
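+
+
+# A minimal sketch of the set()/wait() handshake described above; both
+# functions are hypothetical coroutines assumed to run as tasks under
+# one event loop.
+@tasks.coroutine
+def _waiter(event):
+    yield from event.wait()   # blocks until _setter() runs
+
+
+@tasks.coroutine
+def _setter(event):
+    event.set()               # wakes every coroutine blocked in wait()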
+
+
+class Condition(Lock):
+ """A Condition implementation.
+
+ This class implements condition variable objects. A condition variable
+ allows one or more coroutines to wait until they are notified by another
+ coroutine.
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ self._condition_waiters = collections.deque()
+
+ @tasks.coroutine
+ def wait(self, timeout=None):
+ """Wait until notified or until a timeout occurs. If the calling
+ coroutine has not acquired the lock when this method is called,
+ a RuntimeError is raised.
+
+ This method releases the underlying lock, and then blocks until it is
+ awakened by a notify() or notify_all() call for the same condition
+ variable in another coroutine, or until the optional timeout occurs.
+ Once awakened or timed out, it re-acquires the lock and returns.
+
+ When the timeout argument is present and not None, it should be
+ a floating point number specifying a timeout for the operation
+ in seconds (or fractions thereof).
+
+ The return value is True unless a given timeout expired, in which
+ case it is False.
+ """
+ if not self._locked:
+ raise RuntimeError('cannot wait on un-acquired lock')
+
+ self.release()
+
+ fut = futures.Future(event_loop=self._event_loop, timeout=timeout)
+
+ self._condition_waiters.append(fut)
+ try:
+ yield from fut
+ except futures.CancelledError:
+ self._condition_waiters.remove(fut)
+ return False
+ else:
+ f = self._condition_waiters.popleft()
+ assert fut is f
+ finally:
+ yield from self.acquire()
+
+ return True
+
+ @tasks.coroutine
+ def wait_for(self, predicate, timeout=None):
+        """Wait until a condition evaluates to True. predicate should be
+        a callable whose result will be interpreted as a boolean value.
+        A timeout may be provided giving the maximum time to wait.
+ """
+ endtime = None
+ waittime = timeout
+ result = predicate()
+
+ while not result:
+ if waittime is not None:
+ if endtime is None:
+ endtime = time.monotonic() + waittime
+ else:
+ waittime = endtime - time.monotonic()
+ if waittime <= 0:
+ break
+
+ yield from self.wait(waittime)
+ result = predicate()
+
+ return result
+
+ def notify(self, n=1):
+ """By default, wake up one coroutine waiting on this condition, if any.
+ If the calling coroutine has not acquired the lock when this method
+ is called, a RuntimeError is raised.
+
+ This method wakes up at most n of the coroutines waiting for the
+ condition variable; it is a no-op if no coroutines are waiting.
+
+ Note: an awakened coroutine does not actually return from its
+ wait() call until it can reacquire the lock. Since notify() does
+ not release the lock, its caller should.
+ """
+ if not self._locked:
+ raise RuntimeError('cannot notify on un-acquired lock')
+
+ idx = 0
+ for fut in self._condition_waiters:
+ if idx >= n:
+ break
+
+ if not fut.done():
+ idx += 1
+ fut.set_result(False)
+
+ def notify_all(self):
+        """Wake up all coroutines waiting on this condition. This method
+        acts like notify(), but wakes up all waiting coroutines instead
+        of one. If the calling coroutine has not acquired the lock when
+        this method is called, a RuntimeError is raised.
+ """
+ self.notify(len(self._condition_waiters))
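+
+
+# A minimal sketch of the notify()/wait_for() protocol described above;
+# "shared" is a hypothetical mutable container guarded by the condition,
+# and both coroutines are assumed to run as tasks under one event loop.
+@tasks.coroutine
+def _consume_when_ready(cond, shared):
+    yield from cond.acquire()
+    try:
+        yield from cond.wait_for(lambda: len(shared) > 0)
+        return shared.pop()
+    finally:
+        cond.release()
+
+
+@tasks.coroutine
+def _publish(cond, shared, item):
+    yield from cond.acquire()
+    try:
+        shared.append(item)
+        cond.notify()   # waiter resumes once we release the lock
+    finally:
+        cond.release()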
+
+
+class Semaphore:
+ """A Semaphore implementation.
+
+ A semaphore manages an internal counter which is decremented by each
+ acquire() call and incremented by each release() call. The counter
+ can never go below zero; when acquire() finds that it is zero, it blocks,
+    waiting until some other coroutine calls release().
+
+ Semaphores also support the context manager protocol.
+
+ The first optional argument gives the initial value for the internal
+ counter; it defaults to 1. If the value given is less than 0,
+ ValueError is raised.
+
+    The second optional argument determines whether the semaphore can be
+    released more times than its initial internal counter value; it
+    defaults to False. If the value given is True and the number of
+    release() calls exceeds the number of successful acquire() calls,
+    ValueError is raised.
+ """
+
+ def __init__(self, value=1, bound=False):
+ if value < 0:
+            raise ValueError("Semaphore initial value must be >= 0")
+ self._value = value
+ self._bound = bound
+ self._bound_value = value
+ self._waiters = collections.deque()
+ self._locked = False
+ self._event_loop = events.get_event_loop()
+
+ def __repr__(self):
+ res = super().__repr__()
+ return '<%s [%s]>' % (
+ res[1:-1],
+ 'locked' if self._locked else 'unlocked,value:%s' % self._value)
+
+ def locked(self):
+        """Return True if the semaphore cannot be acquired immediately."""
+ return self._locked
+
+ @tasks.coroutine
+ def acquire(self, timeout=None):
+ """Acquire a semaphore. acquire() method is a coroutine.
+
+ When invoked without arguments: if the internal counter is larger
+ than zero on entry, decrement it by one and return immediately.
+ If it is zero on entry, block, waiting until some other coroutine has
+ called release() to make it larger than zero.
+
+ When invoked with a timeout other than None, it will block for at
+ most timeout seconds. If acquire does not complete successfully in
+ that interval, return false. Return true otherwise.
+ """
+ if not self._waiters and self._value > 0:
+ self._value -= 1
+ if self._value == 0:
+ self._locked = True
+ return True
+
+ fut = futures.Future(event_loop=self._event_loop, timeout=timeout)
+
+ self._waiters.append(fut)
+ try:
+ yield from fut
+ except futures.CancelledError:
+ self._waiters.remove(fut)
+ return False
+ else:
+ f = self._waiters.popleft()
+ assert f is fut
+
+ self._value -= 1
+ if self._value == 0:
+ self._locked = True
+ return True
+
+ def release(self):
+ """Release a semaphore, incrementing the internal counter by one.
+ When it was zero on entry and another coroutine is waiting for it to
+ become larger than zero again, wake up that coroutine.
+
+        If the Semaphore was created with the "bound" parameter set to
+        true, release() checks to make sure the current value doesn't
+        exceed the initial value. If it would, ValueError is raised.
+ """
+ if self._bound and self._value >= self._bound_value:
+ raise ValueError('Semaphore released too many times')
+
+ self._value += 1
+ self._locked = False
+
+ for waiter in self._waiters:
+ if not waiter.done():
+ waiter.set_result(True)
+ break
+
+ def __enter__(self):
+ return True
+
+ def __exit__(self, *args):
+ self.release()
+
+ def __iter__(self):
+ yield from self.acquire()
+ return self
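+
+
+# A minimal sketch of bounded-semaphore use as described above;
+# _limited() is a hypothetical coroutine assumed to run as a task.
+@tasks.coroutine
+def _limited(sem):
+    yield from sem.acquire()   # decrements the counter, may block at 0
+    try:
+        pass  # ... at most the initial value run here concurrently ...
+    finally:
+        sem.release()          # bound=True guards against over-release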
diff --git a/tulip/log.py b/tulip/log.py
new file mode 100644
index 0000000..b918fe5
--- /dev/null
+++ b/tulip/log.py
@@ -0,0 +1,6 @@
+"""Tulip logging configuration"""
+
+import logging
+
+
+tulip_log = logging.getLogger("tulip")
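+
+# A sketch of how an application might surface tulip's library events;
+# the level choice below is illustrative only:
+#
+#     import logging
+#     logging.basicConfig(level=logging.DEBUG)
+#     logging.getLogger('tulip').setLevel(logging.DEBUG)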
diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py
new file mode 100644
index 0000000..46ffa58
--- /dev/null
+++ b/tulip/proactor_events.py
@@ -0,0 +1,189 @@
+"""Event loop using a proactor and related classes.
+
+A proactor is a "notify-on-completion" multiplexer. Currently a
+proactor is only implemented on Windows with IOCP.
+"""
+
+from . import base_events
+from . import transports
+from .log import tulip_log
+
+
+class _ProactorSocketTransport(transports.Transport):
+
+ def __init__(self, event_loop, sock, protocol, waiter=None, extra=None):
+ super().__init__(extra)
+ self._extra['socket'] = sock
+ self._event_loop = event_loop
+ self._sock = sock
+ self._protocol = protocol
+ self._buffer = []
+ self._read_fut = None
+ self._write_fut = None
+ self._closing = False # Set when close() called.
+ self._event_loop.call_soon(self._protocol.connection_made, self)
+ self._event_loop.call_soon(self._loop_reading)
+ if waiter is not None:
+ self._event_loop.call_soon(waiter.set_result, None)
+
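+    # Completion-chain pattern: each finished recv() schedules the next
+    # read by re-registering _loop_reading() as its done-callback, so
+    # reading continues until EOF or an error.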
+ def _loop_reading(self, fut=None):
+ data = None
+
+ try:
+ if fut is not None:
+ assert fut is self._read_fut
+
+ data = fut.result() # deliver data later in "finally" clause
+ if not data:
+ self._read_fut = None
+ return
+
+ self._read_fut = self._event_loop._proactor.recv(self._sock, 4096)
+ except ConnectionAbortedError as exc:
+ if not self._closing:
+ self._fatal_error(exc)
+ except OSError as exc:
+ self._fatal_error(exc)
+ else:
+ self._read_fut.add_done_callback(self._loop_reading)
+ finally:
+ if data:
+ self._protocol.data_received(data)
+ elif data is not None:
+ self._protocol.eof_received()
+
+ def write(self, data):
+ assert isinstance(data, bytes), repr(data)
+ assert not self._closing
+ if not data:
+ return
+ self._buffer.append(data)
+ if not self._write_fut:
+ self._loop_writing()
+
+ def _loop_writing(self, f=None):
+ try:
+ assert f is self._write_fut
+ if f:
+ f.result()
+ data = b''.join(self._buffer)
+ self._buffer = []
+ if not data:
+ self._write_fut = None
+ return
+ self._write_fut = self._event_loop._proactor.send(self._sock, data)
+ except OSError as exc:
+ self._fatal_error(exc)
+ else:
+ self._write_fut.add_done_callback(self._loop_writing)
+
+ # TODO: write_eof(), can_write_eof().
+
+ def abort(self):
+ self._fatal_error(None)
+
+ def close(self):
+ self._closing = True
+ if self._write_fut:
+ self._write_fut.cancel()
+ if not self._buffer:
+ self._event_loop.call_soon(self._call_connection_lost, None)
+
+ def _fatal_error(self, exc):
+ tulip_log.exception('Fatal error for %s', self)
+ if self._write_fut:
+ self._write_fut.cancel()
+ if self._read_fut: # XXX
+ self._read_fut.cancel()
+ self._write_fut = self._read_fut = None
+ self._buffer = []
+ self._event_loop.call_soon(self._call_connection_lost, exc)
+
+ def _call_connection_lost(self, exc):
+ try:
+ self._protocol.connection_lost(exc)
+ finally:
+ self._sock.close()
+
+
+class BaseProactorEventLoop(base_events.BaseEventLoop):
+
+ def __init__(self, proactor):
+ super().__init__()
+ tulip_log.debug('Using proactor: %s', proactor.__class__.__name__)
+ self._proactor = proactor
+ self._selector = proactor # convenient alias
+ self._make_self_pipe()
+
+ def _make_socket_transport(self, sock, protocol, waiter=None, extra=None):
+ return _ProactorSocketTransport(self, sock, protocol, waiter, extra)
+
+ def close(self):
+ if self._proactor is not None:
+ self._close_self_pipe()
+ self._proactor.close()
+ self._proactor = None
+ self._selector = None
+
+ def sock_recv(self, sock, n):
+ return self._proactor.recv(sock, n)
+
+ def sock_sendall(self, sock, data):
+ return self._proactor.send(sock, data)
+
+ def sock_connect(self, sock, address):
+ return self._proactor.connect(sock, address)
+
+ def sock_accept(self, sock):
+ return self._proactor.accept(sock)
+
+ def _socketpair(self):
+ raise NotImplementedError
+
+ def _close_self_pipe(self):
+ self._ssock.close()
+ self._ssock = None
+ self._csock.close()
+ self._csock = None
+ self._internal_fds -= 1
+
+ def _make_self_pipe(self):
+ # A self-socket, really. :-)
+ self._ssock, self._csock = self._socketpair()
+ self._ssock.setblocking(False)
+ self._csock.setblocking(False)
+ self._internal_fds += 1
+
+ def loop(f=None):
+ try:
+ if f:
+ f.result() # may raise
+ f = self._proactor.recv(self._ssock, 4096)
+ except:
+ self.close()
+ raise
+ else:
+ f.add_done_callback(loop)
+ self.call_soon(loop)
+
+ def _write_to_self(self):
+ self._csock.send(b'x')
+
+ def _start_serving(self, protocol_factory, sock):
+ def loop(f=None):
+ try:
+ if f:
+ conn, addr = f.result()
+ protocol = protocol_factory()
+ self._make_socket_transport(
+ conn, protocol, extra={'addr': addr})
+ f = self._proactor.accept(sock)
+ except OSError:
+ sock.close()
+ tulip_log.exception('Accept failed')
+ else:
+ f.add_done_callback(loop)
+ self.call_soon(loop)
+
+ def _process_events(self, event_list):
+ pass # XXX hard work currently done in poll
diff --git a/tulip/protocols.py b/tulip/protocols.py
new file mode 100644
index 0000000..593ee74
--- /dev/null
+++ b/tulip/protocols.py
@@ -0,0 +1,78 @@
+"""Abstract Protocol class."""
+
+__all__ = ['Protocol', 'DatagramProtocol']
+
+
+class BaseProtocol:
+ """ABC for base protocol class.
+
+    Usually the user implements protocols derived from BaseProtocol,
+    like Protocol or ProcessProtocol.
+
+    The only case where BaseProtocol should be implemented directly is
+    a write-only transport, like a write pipe.
+ """
+
+ def connection_made(self, transport):
+ """Called when a connection is made.
+
+ The argument is the transport representing the pipe connection.
+ To receive data, wait for data_received() calls.
+ When the connection is closed, connection_lost() is called.
+ """
+
+ def connection_lost(self, exc):
+ """Called when the connection is lost or closed.
+
+ The argument is an exception object or None (the latter
+ meaning a regular EOF is received or the connection was
+ aborted or closed).
+ """
+
+
+class Protocol(BaseProtocol):
+ """ABC representing a protocol.
+
+ The user should implement this interface. They can inherit from
+ this class but don't need to. The implementations here do
+ nothing (they don't raise exceptions).
+
+    When the user wants to request a transport, they pass a protocol
+ factory to a utility function (e.g., EventLoop.create_connection()).
+
+ When the connection is made successfully, connection_made() is
+ called with a suitable transport object. Then data_received()
+ will be called 0 or more times with data (bytes) received from the
+ transport; finally, connection_lost() will be called exactly once
+ with either an exception object or None as an argument.
+
+ State machine of calls:
+
+ start -> CM [-> DR*] [-> ER?] -> CL -> end
+ """
+
+ def data_received(self, data):
+ """Called when some data is received.
+
+ The argument is a bytes object.
+ """
+
+ def eof_received(self):
+ """Called when the other end calls write_eof() or equivalent.
+
+ The default implementation does nothing.
+
+ TODO: By default close the transport. But we don't have the
+ transport as an instance variable (connection_made() may not
+ set it).
+ """
+
+
+class DatagramProtocol(BaseProtocol):
+ """ABC representing a datagram protocol."""
+
+ def datagram_received(self, data, addr):
+ """Called when some datagram is received."""
+
+ def connection_refused(self, exc):
+ """Connection is refused."""
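+
+
+# A minimal sketch of the call sequence documented above; _EchoProtocol
+# is a hypothetical name, and the transport passed to connection_made()
+# is assumed to provide write() and close().
+class _EchoProtocol(Protocol):
+
+    def connection_made(self, transport):   # CM
+        self.transport = transport
+
+    def data_received(self, data):          # DR*
+        self.transport.write(data)
+
+    def eof_received(self):                 # ER?
+        self.transport.close()
+
+    def connection_lost(self, exc):         # CL
+        self.transport = None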
diff --git a/tulip/queues.py b/tulip/queues.py
new file mode 100644
index 0000000..ee349e1
--- /dev/null
+++ b/tulip/queues.py
@@ -0,0 +1,291 @@
+"""Queues"""
+
+__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue']
+
+import collections
+import concurrent.futures
+import heapq
+import queue
+
+from . import events
+from . import futures
+from . import locks
+from .tasks import coroutine
+
+
+class Queue:
+ """A queue, useful for coordinating producer and consumer coroutines.
+
+ If maxsize is less than or equal to zero, the queue size is infinite. If it
+ is an integer greater than 0, then "yield from put()" will block when the
+ queue reaches maxsize, until an item is removed by get().
+
+ Unlike the standard library Queue, you can reliably know this Queue's size
+ with qsize(), since your single-threaded Tulip application won't be
+ interrupted between calling qsize() and doing an operation on the Queue.
+ """
+
+ def __init__(self, maxsize=0):
+ self._event_loop = events.get_event_loop()
+ self._maxsize = maxsize
+
+ # Futures.
+ self._getters = collections.deque()
+ # Pairs of (item, Future).
+ self._putters = collections.deque()
+ self._init(maxsize)
+
+ def _init(self, maxsize):
+ self._queue = collections.deque()
+
+ def _get(self):
+ return self._queue.popleft()
+
+ def _put(self, item):
+ self._queue.append(item)
+
+ def __repr__(self):
+ return '<%s at %s %s>' % (
+ type(self).__name__, hex(id(self)), self._format())
+
+ def __str__(self):
+ return '<%s %s>' % (type(self).__name__, self._format())
+
+ def _format(self):
+ result = 'maxsize=%r' % (self._maxsize, )
+ if getattr(self, '_queue', None):
+ result += ' _queue=%r' % list(self._queue)
+ if self._getters:
+ result += ' _getters[%s]' % len(self._getters)
+ if self._putters:
+ result += ' _putters[%s]' % len(self._putters)
+ return result
+
+ def _consume_done_getters(self, waiters):
+ # Delete waiters at the head of the get() queue who've timed out.
+ while waiters and waiters[0].done():
+ waiters.popleft()
+
+ def _consume_done_putters(self):
+ # Delete waiters at the head of the put() queue who've timed out.
+ while self._putters and self._putters[0][1].done():
+ self._putters.popleft()
+
+ def qsize(self):
+ """Number of items in the queue."""
+ return len(self._queue)
+
+ @property
+ def maxsize(self):
+ """Number of items allowed in the queue."""
+ return self._maxsize
+
+ def empty(self):
+ """Return True if the queue is empty, False otherwise."""
+ return not self._queue
+
+ def full(self):
+ """Return True if there are maxsize items in the queue.
+
+ Note: if the Queue was initialized with maxsize=0 (the default),
+ then full() is never True.
+ """
+ if self._maxsize <= 0:
+ return False
+ else:
+ return self.qsize() == self._maxsize
+
+ @coroutine
+ def put(self, item, timeout=None):
+ """Put an item into the queue.
+
+ If you yield from put() and timeout is None (the default), wait until a
+ free slot is available before adding item.
+
+ If a timeout is provided, raise queue.Full if no free slot becomes
+ available before the timeout.
+ """
+ self._consume_done_getters(self._getters)
+ if self._getters:
+ assert not self._queue, (
+ 'queue non-empty, why are getters waiting?')
+
+ getter = self._getters.popleft()
+
+ # Use _put and _get instead of passing item straight to getter, in
+ # case a subclass has logic that must run (e.g. JoinableQueue).
+ self._put(item)
+ getter.set_result(self._get())
+
+ elif self._maxsize > 0 and self._maxsize == self.qsize():
+ waiter = futures.Future(
+ event_loop=self._event_loop, timeout=timeout)
+
+ self._putters.append((item, waiter))
+ try:
+ yield from waiter
+ except concurrent.futures.CancelledError:
+ raise queue.Full
+
+ else:
+ self._put(item)
+
+ def put_nowait(self, item):
+ """Put an item into the queue without blocking.
+
+ If no free slot is immediately available, raise queue.Full.
+ """
+ self._consume_done_getters(self._getters)
+ if self._getters:
+ assert not self._queue, (
+ 'queue non-empty, why are getters waiting?')
+
+ getter = self._getters.popleft()
+
+ # Use _put and _get instead of passing item straight to getter, in
+ # case a subclass has logic that must run (e.g. JoinableQueue).
+ self._put(item)
+ getter.set_result(self._get())
+
+ elif self._maxsize > 0 and self._maxsize == self.qsize():
+ raise queue.Full
+ else:
+ self._put(item)
+
+ @coroutine
+ def get(self, timeout=None):
+ """Remove and return an item from the queue.
+
+        If you yield from get() and timeout is None (the default), wait
+        until an item is available.
+
+ If a timeout is provided, raise queue.Empty if no item is available
+ before the timeout.
+ """
+ self._consume_done_putters()
+ if self._putters:
+ assert self.full(), 'queue not full, why are putters waiting?'
+ item, putter = self._putters.popleft()
+ self._put(item)
+
+ # When a getter runs and frees up a slot so this putter can
+ # run, we need to defer the put for a tick to ensure that
+ # getters and putters alternate perfectly. See
+ # ChannelTest.test_wait.
+ self._event_loop.call_soon(putter.set_result, None)
+
+ return self._get()
+
+ elif self.qsize():
+ return self._get()
+ else:
+ waiter = futures.Future(
+ event_loop=self._event_loop, timeout=timeout)
+
+ self._getters.append(waiter)
+ try:
+ return (yield from waiter)
+ except concurrent.futures.CancelledError:
+ raise queue.Empty
+
+ def get_nowait(self):
+ """Remove and return an item from the queue.
+
+        Return an item if one is immediately available, else raise
+        queue.Empty.
+ """
+ self._consume_done_putters()
+ if self._putters:
+ assert self.full(), 'queue not full, why are putters waiting?'
+ item, putter = self._putters.popleft()
+ self._put(item)
+ # Wake putter on next tick.
+ putter.set_result(None)
+
+ return self._get()
+
+ elif self.qsize():
+ return self._get()
+ else:
+ raise queue.Empty
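+
+
+# A minimal sketch of the producer/consumer coordination described in
+# the class docstring; both functions are hypothetical coroutines
+# assumed to run as tasks under one event loop.
+@coroutine
+def _produce(q, items):
+    for item in items:
+        yield from q.put(item)     # blocks while the queue is full
+
+
+@coroutine
+def _consume(q, n):
+    results = []
+    for _ in range(n):
+        results.append((yield from q.get()))  # blocks while empty
+    return results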
+
+
+class PriorityQueue(Queue):
+ """A subclass of Queue; retrieves entries in priority order (lowest first).
+
+ Entries are typically tuples of the form: (priority number, data).
+ """
+
+ def _init(self, maxsize):
+ self._queue = []
+
+ def _put(self, item, heappush=heapq.heappush):
+ heappush(self._queue, item)
+
+ def _get(self, heappop=heapq.heappop):
+ return heappop(self._queue)
+
+
+class LifoQueue(Queue):
+ """A subclass of Queue that retrieves most recently added entries first."""
+
+ def _init(self, maxsize):
+ self._queue = []
+
+ def _put(self, item):
+ self._queue.append(item)
+
+ def _get(self):
+ return self._queue.pop()
+
+
+class JoinableQueue(Queue):
+ """A subclass of Queue with task_done() and join() methods."""
+
+ def __init__(self, maxsize=0):
+ self._unfinished_tasks = 0
+ self._finished = locks.EventWaiter()
+ self._finished.set()
+ super().__init__(maxsize=maxsize)
+
+ def _format(self):
+ result = Queue._format(self)
+ if self._unfinished_tasks:
+ result += ' tasks=%s' % self._unfinished_tasks
+ return result
+
+ def _put(self, item):
+ super()._put(item)
+ self._unfinished_tasks += 1
+ self._finished.clear()
+
+ def task_done(self):
+ """Indicate that a formerly enqueued task is complete.
+
+ Used by queue consumers. For each get() used to fetch a task,
+ a subsequent call to task_done() tells the queue that the processing
+ on the task is complete.
+
+ If a join() is currently blocking, it will resume when all items have
+ been processed (meaning that a task_done() call was received for every
+ item that had been put() into the queue).
+
+ Raises ValueError if called more times than there were items placed in
+ the queue.
+ """
+ if self._unfinished_tasks <= 0:
+ raise ValueError('task_done() called too many times')
+ self._unfinished_tasks -= 1
+ if self._unfinished_tasks == 0:
+ self._finished.set()
+
+ @coroutine
+ def join(self, timeout=None):
+ """Block until all items in the queue have been gotten and processed.
+
+ The count of unfinished tasks goes up whenever an item is added to the
+        queue. The count goes down whenever a consumer coroutine calls
+        task_done() to indicate that the item was retrieved and all work
+        on it is complete.
+ When the count of unfinished tasks drops to zero, join() unblocks.
+ """
+ if self._unfinished_tasks > 0:
+ yield from self._finished.wait(timeout=timeout)
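+
+
+# A minimal sketch of the task_done()/join() pairing described above;
+# _worker() is a hypothetical consumer coroutine.
+@coroutine
+def _worker(jq):
+    while True:
+        item = yield from jq.get()
+        try:
+            pass  # ... process item ...
+        finally:
+            jq.task_done()  # a blocked join() resumes at zero unfinished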
diff --git a/tulip/selector_events.py b/tulip/selector_events.py
new file mode 100644
index 0000000..20e5db0
--- /dev/null
+++ b/tulip/selector_events.py
@@ -0,0 +1,655 @@
+"""Event loop using a selector and related classes.
+
+A selector is a "notify-when-ready" multiplexer. For a subclass which
+also includes support for signal handling, see the unix_events sub-module.
+"""
+
+import collections
+import errno
+import socket
+try:
+ import ssl
+except ImportError: # pragma: no cover
+ ssl = None
+
+from . import base_events
+from . import events
+from . import futures
+from . import selectors
+from . import transports
+from .log import tulip_log
+
+
+# Errno values indicating the connection was disconnected.
+_DISCONNECTED = frozenset((errno.ECONNRESET,
+ errno.ENOTCONN,
+ errno.ESHUTDOWN,
+ errno.ECONNABORTED,
+ errno.EPIPE,
+ errno.EBADF,
+ ))
+
+
+class BaseSelectorEventLoop(base_events.BaseEventLoop):
+ """Selector event loop.
+
+ See events.EventLoop for API specification.
+ """
+
+ def __init__(self, selector=None):
+ super().__init__()
+
+ if selector is None:
+ selector = selectors.Selector()
+ tulip_log.debug('Using selector: %s', selector.__class__.__name__)
+ self._selector = selector
+ self._make_self_pipe()
+
+ def _make_socket_transport(self, sock, protocol, waiter=None, extra=None):
+ return _SelectorSocketTransport(self, sock, protocol, waiter, extra)
+
+ def _make_ssl_transport(self, rawsock, protocol,
+ sslcontext, waiter, extra=None):
+ return _SelectorSslTransport(
+ self, rawsock, protocol, sslcontext, waiter, extra)
+
+ def _make_datagram_transport(self, sock, protocol,
+ address=None, extra=None):
+ return _SelectorDatagramTransport(self, sock, protocol, address, extra)
+
+ def close(self):
+ if self._selector is not None:
+ self._close_self_pipe()
+ self._selector.close()
+ self._selector = None
+
+ def _socketpair(self):
+ raise NotImplementedError
+
+ def _close_self_pipe(self):
+ self.remove_reader(self._ssock.fileno())
+ self._ssock.close()
+ self._ssock = None
+ self._csock.close()
+ self._csock = None
+ self._internal_fds -= 1
+
+ def _make_self_pipe(self):
+ # A self-socket, really. :-)
+ self._ssock, self._csock = self._socketpair()
+ self._ssock.setblocking(False)
+ self._csock.setblocking(False)
+ self._internal_fds += 1
+ self.add_reader(self._ssock.fileno(), self._read_from_self)
+
+ def _read_from_self(self):
+ try:
+ self._ssock.recv(1)
+ except (BlockingIOError, InterruptedError):
+ pass
+
+ def _write_to_self(self):
+ try:
+ self._csock.send(b'x')
+ except (BlockingIOError, InterruptedError):
+ pass
+
+ def _start_serving(self, protocol_factory, sock):
+ self.add_reader(sock.fileno(), self._accept_connection,
+ protocol_factory, sock)
+
+ def _accept_connection(self, protocol_factory, sock):
+ try:
+ conn, addr = sock.accept()
+ except (BlockingIOError, InterruptedError):
+ pass # False alarm.
+ except:
+ # Bad error. Stop serving.
+ self.remove_reader(sock.fileno())
+ sock.close()
+ # There's nowhere to send the error, so just log it.
+ # TODO: Someone will want an error handler for this.
+ tulip_log.exception('Accept failed')
+ else:
+ self._make_socket_transport(
+ conn, protocol_factory(), extra={'addr': addr})
+ # It's now up to the protocol to handle the connection.
+
+ def add_reader(self, fd, callback, *args):
+ """Add a reader callback. Return a Handle instance."""
+ handle = events.make_handle(callback, args)
+ try:
+ mask, (reader, writer) = self._selector.get_info(fd)
+ except KeyError:
+ self._selector.register(fd, selectors.EVENT_READ,
+ (handle, None))
+ else:
+ self._selector.modify(fd, mask | selectors.EVENT_READ,
+ (handle, writer))
+ if reader is not None:
+ reader.cancel()
+
+ return handle
+
+ def remove_reader(self, fd):
+ """Remove a reader callback."""
+ try:
+ mask, (reader, writer) = self._selector.get_info(fd)
+ except KeyError:
+ return False
+ else:
+ mask &= ~selectors.EVENT_READ
+ if not mask:
+ self._selector.unregister(fd)
+ else:
+ self._selector.modify(fd, mask, (None, writer))
+
+ if reader is not None:
+ reader.cancel()
+ return True
+ else:
+ return False
+
+ def add_writer(self, fd, callback, *args):
+ """Add a writer callback. Return a Handle instance."""
+ handle = events.make_handle(callback, args)
+ try:
+ mask, (reader, writer) = self._selector.get_info(fd)
+ except KeyError:
+ self._selector.register(fd, selectors.EVENT_WRITE,
+ (None, handle))
+ else:
+ self._selector.modify(fd, mask | selectors.EVENT_WRITE,
+ (reader, handle))
+ if writer is not None:
+ writer.cancel()
+
+ return handle
+
+ def remove_writer(self, fd):
+ """Remove a writer callback."""
+ try:
+ mask, (reader, writer) = self._selector.get_info(fd)
+ except KeyError:
+ return False
+ else:
+ # Remove both writer and connector.
+ mask &= ~selectors.EVENT_WRITE
+ if not mask:
+ self._selector.unregister(fd)
+ else:
+ self._selector.modify(fd, mask, (reader, None))
+
+ if writer is not None:
+ writer.cancel()
+ return True
+ else:
+ return False
+
+ def sock_recv(self, sock, n):
+ """XXX"""
+ fut = futures.Future()
+ self._sock_recv(fut, False, sock, n)
+ return fut
+
+ def _sock_recv(self, fut, registered, sock, n):
+ fd = sock.fileno()
+ if registered:
+ # Remove the callback early. It should be rare that the
+ # selector says the fd is ready but the call still returns
+ # EAGAIN, and I am willing to take a hit in that case in
+ # order to simplify the common case.
+ self.remove_reader(fd)
+ if fut.cancelled():
+ return
+ try:
+ data = sock.recv(n)
+ fut.set_result(data)
+ except (BlockingIOError, InterruptedError):
+ self.add_reader(fd, self._sock_recv, fut, True, sock, n)
+ except Exception as exc:
+ fut.set_exception(exc)
+
+ def sock_sendall(self, sock, data):
+ """XXX"""
+ fut = futures.Future()
+ if data:
+ self._sock_sendall(fut, False, sock, data)
+ else:
+ fut.set_result(None)
+ return fut
+
+ def _sock_sendall(self, fut, registered, sock, data):
+ fd = sock.fileno()
+
+ if registered:
+ self.remove_writer(fd)
+ if fut.cancelled():
+ return
+
+ try:
+ n = sock.send(data)
+ except (BlockingIOError, InterruptedError):
+ n = 0
+ except Exception as exc:
+ fut.set_exception(exc)
+ return
+
+ if n == len(data):
+ fut.set_result(None)
+ else:
+ if n:
+ data = data[n:]
+ self.add_writer(fd, self._sock_sendall, fut, True, sock, data)
+
+ def sock_connect(self, sock, address):
+ """XXX"""
+ # That address better not require a lookup! We're not calling
+ # self.getaddrinfo() for you here. But verifying this is
+ # complicated; the socket module doesn't have a pattern for
+ # IPv6 addresses (there are too many forms, apparently).
+ fut = futures.Future()
+ self._sock_connect(fut, False, sock, address)
+ return fut
+
+ def _sock_connect(self, fut, registered, sock, address):
+ fd = sock.fileno()
+ if registered:
+ self.remove_writer(fd)
+ if fut.cancelled():
+ return
+ try:
+ if not registered:
+ # First time around.
+ sock.connect(address)
+ else:
+ err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
+ if err != 0:
+ # Jump to the except clause below.
+ raise socket.error(err, 'Connect call failed')
+ fut.set_result(None)
+ except (BlockingIOError, InterruptedError):
+ self.add_writer(fd, self._sock_connect, fut, True, sock, address)
+ except Exception as exc:
+ fut.set_exception(exc)
+
+ def sock_accept(self, sock):
+ """XXX"""
+ fut = futures.Future()
+ self._sock_accept(fut, False, sock)
+ return fut
+
+ def _sock_accept(self, fut, registered, sock):
+ fd = sock.fileno()
+ if registered:
+ self.remove_reader(fd)
+ if fut.cancelled():
+ return
+ try:
+ conn, address = sock.accept()
+ conn.setblocking(False)
+ fut.set_result((conn, address))
+ except (BlockingIOError, InterruptedError):
+ self.add_reader(fd, self._sock_accept, fut, True, sock)
+ except Exception as exc:
+ fut.set_exception(exc)
+
+ def _process_events(self, event_list):
+ for fileobj, mask, (reader, writer) in event_list:
+ if mask & selectors.EVENT_READ and reader is not None:
+ if reader.cancelled:
+ self.remove_reader(fileobj)
+ else:
+ self._add_callback(reader)
+ if mask & selectors.EVENT_WRITE and writer is not None:
+ if writer.cancelled:
+ self.remove_writer(fileobj)
+ else:
+ self._add_callback(writer)
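+
+
+# A minimal sketch of the fd-callback API above; "loop" is assumed to
+# be a concrete BaseSelectorEventLoop and "sock" a non-blocking socket.
+def _watch_readable(loop, sock, on_readable):
+    handle = loop.add_reader(sock.fileno(), on_readable, sock)
+    return handle   # later: loop.remove_reader(sock.fileno())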
+
+
+class _SelectorSocketTransport(transports.Transport):
+
+ def __init__(self, event_loop, sock, protocol, waiter=None, extra=None):
+ super().__init__(extra)
+ self._extra['socket'] = sock
+ self._event_loop = event_loop
+ self._sock = sock
+ self._protocol = protocol
+ self._buffer = []
+ self._closing = False # Set when close() called.
+ self._event_loop.add_reader(self._sock.fileno(), self._read_ready)
+ self._event_loop.call_soon(self._protocol.connection_made, self)
+ if waiter is not None:
+ self._event_loop.call_soon(waiter.set_result, None)
+
+ def _read_ready(self):
+ try:
+ data = self._sock.recv(16*1024)
+ except (BlockingIOError, InterruptedError):
+ pass
+ except Exception as exc:
+ self._fatal_error(exc)
+ else:
+ if data:
+ self._protocol.data_received(data)
+ else:
+ self._event_loop.remove_reader(self._sock.fileno())
+ self._protocol.eof_received()
+
+ def write(self, data):
+ assert isinstance(data, (bytes, bytearray)), repr(data)
+ assert not self._closing
+ if not data:
+ return
+
+ if not self._buffer:
+ # Attempt to send it right away first.
+ try:
+ n = self._sock.send(data)
+ except (BlockingIOError, InterruptedError):
+ n = 0
+ except socket.error as exc:
+ self._fatal_error(exc)
+ return
+
+ if n == len(data):
+ return
+ elif n:
+ data = data[n:]
+ self._event_loop.add_writer(self._sock.fileno(), self._write_ready)
+
+ self._buffer.append(data)
+
+ def _write_ready(self):
+ data = b''.join(self._buffer)
+ assert data, "Data should not be empty"
+
+ self._buffer.clear()
+ try:
+ n = self._sock.send(data)
+ except (BlockingIOError, InterruptedError):
+ self._buffer.append(data)
+ except Exception as exc:
+ self._fatal_error(exc)
+ else:
+ if n == len(data):
+ self._event_loop.remove_writer(self._sock.fileno())
+ if self._closing:
+ self._call_connection_lost(None)
+ return
+ elif n:
+ data = data[n:]
+
+ self._buffer.append(data) # Try again later.
+
+ # TODO: write_eof(), can_write_eof().
+
+ def abort(self):
+ self._close(None)
+
+ def close(self):
+ self._closing = True
+ self._event_loop.remove_reader(self._sock.fileno())
+ if not self._buffer:
+ self._call_connection_lost(None)
+
+ def _fatal_error(self, exc):
+ # should be called from exception handler only
+ tulip_log.exception('Fatal error for %s', self)
+ self._close(exc)
+
+ def _close(self, exc):
+ self._event_loop.remove_writer(self._sock.fileno())
+ self._event_loop.remove_reader(self._sock.fileno())
+ self._buffer.clear()
+ self._call_connection_lost(exc)
+
+ def _call_connection_lost(self, exc):
+ try:
+ self._protocol.connection_lost(exc)
+ finally:
+ self._sock.close()
+
+
+class _SelectorSslTransport(transports.Transport):
+
+ def __init__(self, event_loop, rawsock,
+ protocol, sslcontext, waiter, extra=None):
+ super().__init__(extra)
+
+ self._event_loop = event_loop
+ self._rawsock = rawsock
+ self._protocol = protocol
+ sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ self._sslcontext = sslcontext
+ self._waiter = waiter
+ sslsock = sslcontext.wrap_socket(rawsock,
+ do_handshake_on_connect=False)
+ self._sslsock = sslsock
+ self._buffer = []
+ self._closing = False # Set when close() called.
+ self._extra['socket'] = sslsock
+
+ self._on_handshake()
+
+ def _on_handshake(self):
+ fd = self._sslsock.fileno()
+ try:
+ self._sslsock.do_handshake()
+ except ssl.SSLWantReadError:
+ self._event_loop.add_reader(fd, self._on_handshake)
+ return
+ except ssl.SSLWantWriteError:
+ self._event_loop.add_writer(fd, self._on_handshake)
+ return
+ except Exception as exc:
+ self._sslsock.close()
+ self._waiter.set_exception(exc)
+ return
+ except BaseException as exc:
+ self._sslsock.close()
+ self._waiter.set_exception(exc)
+ raise
+ self._event_loop.remove_reader(fd)
+ self._event_loop.remove_writer(fd)
+ self._event_loop.add_reader(fd, self._on_ready)
+ self._event_loop.add_writer(fd, self._on_ready)
+ self._event_loop.call_soon(self._protocol.connection_made, self)
+ self._event_loop.call_soon(self._waiter.set_result, None)
+
+ def _on_ready(self):
+ # Because of renegotiations (?), there's no difference between
+ # readable and writable. We just try both. XXX This may be
+ # incorrect; we probably need to keep state about what we
+ # should do next.
+
+ # Maybe we're already closed...
+ fd = self._sslsock.fileno()
+ if fd < 0:
+ return
+
+ # First try reading.
+ try:
+ data = self._sslsock.recv(8192)
+ except ssl.SSLWantReadError:
+ pass
+ except ssl.SSLWantWriteError:
+ pass
+ except (BlockingIOError, InterruptedError):
+ pass
+ except Exception as exc:
+ self._fatal_error(exc)
+ else:
+ if data:
+ self._protocol.data_received(data)
+ else:
+ # TODO: Don't close when self._buffer is non-empty.
+ assert not self._buffer
+ self._event_loop.remove_reader(fd)
+ self._event_loop.remove_writer(fd)
+ self._sslsock.close()
+ self._protocol.connection_lost(None)
+ return
+
+ # Now try writing, if there's anything to write.
+ if not self._buffer:
+ return
+
+ data = b''.join(self._buffer)
+ self._buffer = []
+ try:
+ n = self._sslsock.send(data)
+ except ssl.SSLWantReadError:
+ n = 0
+ except ssl.SSLWantWriteError:
+ n = 0
+ except (BlockingIOError, InterruptedError):
+ n = 0
+ except Exception as exc:
+ self._fatal_error(exc)
+ return
+
+ if n < len(data):
+ self._buffer.append(data[n:])
+ elif self._closing:
+ self._event_loop.remove_writer(self._sslsock.fileno())
+ self._sslsock.close()
+ self._protocol.connection_lost(None)
+
+ def write(self, data):
+ assert isinstance(data, bytes), repr(data)
+ assert not self._closing
+ if not data:
+ return
+ self._buffer.append(data)
+ # We could optimize, but the callback can do this for now.
+
+ # TODO: write_eof(), can_write_eof().
+
+ def abort(self):
+ self._close(None)
+
+ def close(self):
+ self._closing = True
+ self._event_loop.remove_reader(self._sslsock.fileno())
+ if not self._buffer:
+ self._protocol.connection_lost(None)
+
+ def _fatal_error(self, exc):
+ tulip_log.exception('Fatal error for %s', self)
+ self._close(exc)
+
+ def _close(self, exc):
+ self._event_loop.remove_writer(self._sslsock.fileno())
+ self._event_loop.remove_reader(self._sslsock.fileno())
+ self._buffer = []
+ self._protocol.connection_lost(exc)
+
+
+class _SelectorDatagramTransport(transports.DatagramTransport):
+
+ max_size = 256 * 1024 # max bytes we read in one eventloop iteration
+
+ def __init__(self, event_loop, sock, protocol, address=None, extra=None):
+ super().__init__(extra)
+ self._extra['socket'] = sock
+ self._event_loop = event_loop
+ self._sock = sock
+ self._fileno = sock.fileno()
+ self._protocol = protocol
+ self._address = address
+ self._buffer = collections.deque()
+ self._closing = False # Set when close() called.
+ self._event_loop.add_reader(self._fileno, self._read_ready)
+ self._event_loop.call_soon(self._protocol.connection_made, self)
+
+ def _read_ready(self):
+ try:
+ data, addr = self._sock.recvfrom(self.max_size)
+ except (BlockingIOError, InterruptedError):
+ pass
+ except Exception as exc:
+ self._fatal_error(exc)
+ else:
+ self._protocol.datagram_received(data, addr)
+
+ def sendto(self, data, addr=None):
+ assert isinstance(data, bytes)
+ assert not self._closing
+ if not data:
+ return
+
+ if self._address:
+ assert addr in (None, self._address)
+
+ if not self._buffer:
+ # Attempt to send it right away first.
+ try:
+ if self._address:
+ self._sock.send(data)
+ else:
+ self._sock.sendto(data, addr)
+ return
+ except ConnectionRefusedError as exc:
+ if self._address:
+ self._fatal_error(exc)
+ return
+ except (BlockingIOError, InterruptedError):
+ self._event_loop.add_writer(self._fileno, self._sendto_ready)
+ except Exception as exc:
+ self._fatal_error(exc)
+ return
+
+ self._buffer.append((data, addr))
+
+ def _sendto_ready(self):
+ while self._buffer:
+ data, addr = self._buffer.popleft()
+ try:
+ if self._address:
+ self._sock.send(data)
+ else:
+ self._sock.sendto(data, addr)
+ except ConnectionRefusedError as exc:
+ if self._address:
+ self._fatal_error(exc)
+ return
+ except (BlockingIOError, InterruptedError):
+ self._buffer.appendleft((data, addr)) # Try again later.
+ break
+ except Exception as exc:
+ self._fatal_error(exc)
+ return
+
+ if not self._buffer:
+ self._event_loop.remove_writer(self._fileno)
+ if self._closing:
+ self._call_connection_lost(None)
+
+ def abort(self):
+ self._close(None)
+
+ def close(self):
+ self._closing = True
+ self._event_loop.remove_reader(self._fileno)
+ if not self._buffer:
+ self._call_connection_lost(None)
+
+ def _fatal_error(self, exc):
+ tulip_log.exception('Fatal error for %s', self)
+ self._close(exc)
+
+ def _close(self, exc):
+ self._buffer.clear()
+ self._event_loop.remove_writer(self._fileno)
+ self._event_loop.remove_reader(self._fileno)
+ if self._address and isinstance(exc, ConnectionRefusedError):
+ self._protocol.connection_refused(exc)
+ self._call_connection_lost(exc)
+
+ def _call_connection_lost(self, exc):
+ try:
+ self._protocol.connection_lost(exc)
+ finally:
+ self._sock.close()
diff --git a/tulip/selectors.py b/tulip/selectors.py
new file mode 100644
index 0000000..57be7ab
--- /dev/null
+++ b/tulip/selectors.py
@@ -0,0 +1,418 @@
+"""Select module.
+
+This module supports asynchronous I/O on multiple file descriptors.
+"""
+
+import sys
+from select import *
+
+from .log import tulip_log
+
+
+# Generic events that must be mapped to implementation-specific ones.
+# read event
+EVENT_READ = (1 << 0)
+# write event
+EVENT_WRITE = (1 << 1)
+
+
+def _fileobj_to_fd(fileobj):
+ """Return a file descriptor from a file object.
+
+ Parameters:
+ fileobj -- file descriptor, or any object with a `fileno()` method
+
+ Returns:
+ corresponding file descriptor
+ """
+ if isinstance(fileobj, int):
+ fd = fileobj
+ else:
+ try:
+ fd = int(fileobj.fileno())
+ except (ValueError, TypeError):
+ raise ValueError("Invalid file object: {!r}".format(fileobj))
+ return fd
+
+
+class SelectorKey:
+ """Object used internally to associate a file object to its backing file
+ descriptor, selected event mask and attached data."""
+
+ def __init__(self, fileobj, events, data=None):
+ self.fileobj = fileobj
+ self.fd = _fileobj_to_fd(fileobj)
+ self.events = events
+ self.data = data
+
+ def __repr__(self):
+ return '{}<fileobj={}, fd={}, events={:#x}, data={}>'.format(
+ self.__class__.__name__,
+ self.fileobj, self.fd, self.events, self.data)
+
+
+class _BaseSelector:
+ """Base selector class.
+
+ A selector supports registering file objects to be monitored for specific
+ I/O events.
+
+ A file object is a file descriptor or any object with a `fileno()` method.
+ An arbitrary object can be attached to the file object, which can be used
+ for example to store context information, a callback, etc.
+
+ A selector can use various implementations (select(), poll(), epoll()...)
+ depending on the platform. The default `Selector` class uses the most
+ performant implementation on the current platform.
+ """
+
+ def __init__(self):
+ # this maps file descriptors to keys
+ self._fd_to_key = {}
+ # this maps file objects to keys - for fast (un)registering
+ self._fileobj_to_key = {}
+
+ def register(self, fileobj, events, data=None):
+ """Register a file object.
+
+ Parameters:
+ fileobj -- file object
+ events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE)
+ data -- attached data
+
+ Returns:
+ SelectorKey instance
+ """
+ if (not events) or (events & ~(EVENT_READ|EVENT_WRITE)):
+ raise ValueError("Invalid events: {}".format(events))
+
+ if fileobj in self._fileobj_to_key:
+ raise ValueError("{!r} is already registered".format(fileobj))
+
+ key = SelectorKey(fileobj, events, data)
+ self._fd_to_key[key.fd] = key
+ self._fileobj_to_key[fileobj] = key
+ return key
+
+ def unregister(self, fileobj):
+ """Unregister a file object.
+
+ Parameters:
+ fileobj -- file object
+
+ Returns:
+ SelectorKey instance
+ """
+ try:
+ key = self._fileobj_to_key[fileobj]
+ del self._fd_to_key[key.fd]
+ del self._fileobj_to_key[fileobj]
+ except KeyError:
+ raise ValueError("{!r} is not registered".format(fileobj))
+ return key
+
+ def modify(self, fileobj, events, data=None):
+ """Change a registered file object monitored events or attached data.
+
+ Parameters:
+ fileobj -- file object
+ events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE)
+ data -- attached data
+ """
+ # TODO: Subclasses can probably optimize this even further.
+ try:
+ key = self._fileobj_to_key[fileobj]
+ except KeyError:
+ raise ValueError("{!r} is not registered".format(fileobj))
+ if events != key.events or data != key.data:
+ self.unregister(fileobj)
+ return self.register(fileobj, events, data)
+ else:
+ return key
+
+ def select(self, timeout=None):
+ """Perform the actual selection, until some monitored file objects are
+ ready or a timeout expires.
+
+ Parameters:
+ timeout -- if timeout > 0, this specifies the maximum wait time, in
+ seconds
+ if timeout == 0, the select() call won't block, and will
+ report the currently ready file objects
+ if timeout is None, select() will block until a monitored
+ file object becomes ready
+
+ Returns:
+ list of (fileobj, events, attached data) for ready file objects
+ `events` is a bitwise mask of EVENT_READ|EVENT_WRITE
+ """
+ raise NotImplementedError()
+
+ def close(self):
+ """Close the selector.
+
+ This must be called to make sure that any underlying resource is freed.
+ """
+ self._fd_to_key.clear()
+ self._fileobj_to_key.clear()
+
+ def get_info(self, fileobj):
+ """Return information about a registered file object.
+
+ Returns:
+ (events, data) associated to this file object
+
+ Raises KeyError if the file object is not registered.
+ """
+ try:
+ key = self._fileobj_to_key[fileobj]
+ except KeyError:
+ raise KeyError("{} is not registered".format(fileobj))
+ return key.events, key.data
+
+ def registered_count(self):
+ """Return the number of registered file objects.
+
+ Returns:
+ number of currently registered file objects
+ """
+ return len(self._fd_to_key)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *args):
+ self.close()
+
+ def _key_from_fd(self, fd):
+ """Return the key associated to a given file descriptor.
+
+ Parameters:
+ fd -- file descriptor
+
+ Returns:
+ corresponding key
+ """
+ try:
+ return self._fd_to_key[fd]
+ except KeyError:
+            tulip_log.warning('No key found for fd %r', fd)
+ return None
+
+
+class SelectSelector(_BaseSelector):
+ """Select-based selector."""
+
+ def __init__(self):
+ super().__init__()
+ self._readers = set()
+ self._writers = set()
+
+ def register(self, fileobj, events, data=None):
+ key = super().register(fileobj, events, data)
+ if events & EVENT_READ:
+ self._readers.add(key.fd)
+ if events & EVENT_WRITE:
+ self._writers.add(key.fd)
+ return key
+
+ def unregister(self, fileobj):
+ key = super().unregister(fileobj)
+ self._readers.discard(key.fd)
+ self._writers.discard(key.fd)
+ return key
+
+ def select(self, timeout=None):
+ try:
+ r, w, _ = self._select(self._readers, self._writers, [], timeout)
+ except InterruptedError:
+ # A signal arrived. Don't die, just return no events.
+ return []
+ r = set(r)
+ w = set(w)
+ ready = []
+ for fd in r | w:
+ events = 0
+ if fd in r:
+ events |= EVENT_READ
+ if fd in w:
+ events |= EVENT_WRITE
+
+ key = self._key_from_fd(fd)
+ if key:
+ ready.append((key.fileobj, events & key.events, key.data))
+ return ready
+
+ if sys.platform == 'win32':
+ def _select(self, r, w, _, timeout=None):
+ r, w, x = select(r, w, w, timeout)
+ return r, w + x, []
+ else:
+ from select import select as _select
+
+
+if 'poll' in globals():
+
+ # TODO: Implement poll() for Windows with workaround for
+ # brokenness in WSAPoll() (Richard Oudkerk, see
+ # http://bugs.python.org/issue16507).
+
+ class PollSelector(_BaseSelector):
+ """Poll-based selector."""
+
+ def __init__(self):
+ super().__init__()
+ self._poll = poll()
+
+ def register(self, fileobj, events, data=None):
+ key = super().register(fileobj, events, data)
+ poll_events = 0
+ if events & EVENT_READ:
+ poll_events |= POLLIN
+ if events & EVENT_WRITE:
+ poll_events |= POLLOUT
+ self._poll.register(key.fd, poll_events)
+ return key
+
+ def unregister(self, fileobj):
+ key = super().unregister(fileobj)
+ self._poll.unregister(key.fd)
+ return key
+
+ def select(self, timeout=None):
+ timeout = None if timeout is None else int(1000 * timeout)
+ ready = []
+ try:
+ fd_event_list = self._poll.poll(timeout)
+ except InterruptedError:
+ # A signal arrived. Don't die, just return no events.
+ return []
+ for fd, event in fd_event_list:
+ events = 0
+ if event & ~POLLIN:
+ events |= EVENT_WRITE
+ if event & ~POLLOUT:
+ events |= EVENT_READ
+
+ key = self._key_from_fd(fd)
+ if key:
+ ready.append((key.fileobj, events & key.events, key.data))
+ return ready
+
+
+if 'epoll' in globals():
+
+ class EpollSelector(_BaseSelector):
+ """Epoll-based selector."""
+
+ def __init__(self):
+ super().__init__()
+ self._epoll = epoll()
+
+ def register(self, fileobj, events, data=None):
+ key = super().register(fileobj, events, data)
+ epoll_events = 0
+ if events & EVENT_READ:
+ epoll_events |= EPOLLIN
+ if events & EVENT_WRITE:
+ epoll_events |= EPOLLOUT
+ self._epoll.register(key.fd, epoll_events)
+ return key
+
+ def unregister(self, fileobj):
+ key = super().unregister(fileobj)
+ self._epoll.unregister(key.fd)
+ return key
+
+ def select(self, timeout=None):
+ timeout = -1 if timeout is None else timeout
+ max_ev = self.registered_count()
+ ready = []
+ try:
+ fd_event_list = self._epoll.poll(timeout, max_ev)
+ except InterruptedError:
+ # A signal arrived. Don't die, just return no events.
+ return []
+ for fd, event in fd_event_list:
+ events = 0
+ if event & ~EPOLLIN:
+ events |= EVENT_WRITE
+ if event & ~EPOLLOUT:
+ events |= EVENT_READ
+
+ key = self._key_from_fd(fd)
+ if key:
+ ready.append((key.fileobj, events & key.events, key.data))
+ return ready
+
+ def close(self):
+ super().close()
+ self._epoll.close()
+
+
+if 'kqueue' in globals():
+
+ class KqueueSelector(_BaseSelector):
+ """Kqueue-based selector."""
+
+ def __init__(self):
+ super().__init__()
+ self._kqueue = kqueue()
+
+ def unregister(self, fileobj):
+ key = super().unregister(fileobj)
+ if key.events & EVENT_READ:
+ kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_DELETE)
+ self._kqueue.control([kev], 0, 0)
+ if key.events & EVENT_WRITE:
+ kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_DELETE)
+ self._kqueue.control([kev], 0, 0)
+ return key
+
+ def register(self, fileobj, events, data=None):
+ key = super().register(fileobj, events, data)
+ if events & EVENT_READ:
+ kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_ADD)
+ self._kqueue.control([kev], 0, 0)
+ if events & EVENT_WRITE:
+ kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_ADD)
+ self._kqueue.control([kev], 0, 0)
+ return key
+
+ def select(self, timeout=None):
+ max_ev = self.registered_count()
+ ready = []
+ try:
+ kev_list = self._kqueue.control(None, max_ev, timeout)
+ except InterruptedError:
+ # A signal arrived. Don't die, just return no events.
+ return []
+ for kev in kev_list:
+ fd = kev.ident
+ flag = kev.filter
+ events = 0
+ if flag == KQ_FILTER_READ:
+ events |= EVENT_READ
+ if flag == KQ_FILTER_WRITE:
+ events |= EVENT_WRITE
+
+ key = self._key_from_fd(fd)
+ if key:
+ ready.append((key.fileobj, events & key.events, key.data))
+ return ready
+
+ def close(self):
+ super().close()
+ self._kqueue.close()
+
+
+# Choose the best implementation: roughly, epoll|kqueue > poll > select.
+# select() also can't accept an fd > FD_SETSIZE (usually around 1024).
+if 'KqueueSelector' in globals():
+ Selector = KqueueSelector
+elif 'EpollSelector' in globals():
+ Selector = EpollSelector
+elif 'PollSelector' in globals():
+ Selector = PollSelector
+else:
+ Selector = SelectSelector
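+
+# Minimal usage sketch (illustrative only; 'sock' and 'on_readable' are
+# placeholder names, not part of this module):
+#
+#     sel = Selector()
+#     sel.register(sock, EVENT_READ, data=on_readable)
+#     for fileobj, events, data in sel.select(timeout=1.0):
+#         data(fileobj)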
diff --git a/tulip/streams.py b/tulip/streams.py
new file mode 100644
index 0000000..8d7f623
--- /dev/null
+++ b/tulip/streams.py
@@ -0,0 +1,145 @@
+"""Stream-related things."""
+
+__all__ = ['StreamReader']
+
+import collections
+
+from . import futures
+from . import tasks
+
+
+class StreamReader:
+
+ def __init__(self, limit=2**16):
+ self.limit = limit # Max line length. (Security feature.)
+ self.buffer = collections.deque() # Deque of bytes objects.
+ self.byte_count = 0 # Bytes in buffer.
+ self.eof = False # Whether we're done.
+ self.waiter = None # A future.
+ self._exception = None
+
+ def exception(self):
+ return self._exception
+
+ def set_exception(self, exc):
+ self._exception = exc
+
+ if self.waiter is not None:
+ self.waiter.set_exception(exc)
+
+ def feed_eof(self):
+ self.eof = True
+ waiter = self.waiter
+ if waiter is not None:
+ self.waiter = None
+ waiter.set_result(True)
+
+ def feed_data(self, data):
+ if not data:
+ return
+
+ self.buffer.append(data)
+ self.byte_count += len(data)
+
+ waiter = self.waiter
+ if waiter is not None:
+ self.waiter = None
+ waiter.set_result(False)
+
+ @tasks.coroutine
+ def readline(self):
+ if self._exception is not None:
+ raise self._exception
+
+ parts = []
+ parts_size = 0
+ not_enough = True
+
+ while not_enough:
+ while self.buffer and not_enough:
+ data = self.buffer.popleft()
+ ichar = data.find(b'\n')
+ if ichar < 0:
+ parts.append(data)
+ parts_size += len(data)
+ else:
+ ichar += 1
+ head, tail = data[:ichar], data[ichar:]
+ if tail:
+ self.buffer.appendleft(tail)
+ not_enough = False
+ parts.append(head)
+ parts_size += len(head)
+
+ if parts_size > self.limit:
+ self.byte_count -= parts_size
+ raise ValueError('Line is too long')
+
+ if self.eof:
+ break
+
+ if not_enough:
+ assert self.waiter is None
+ self.waiter = futures.Future()
+ yield from self.waiter
+
+ line = b''.join(parts)
+ self.byte_count -= parts_size
+
+ return line
+
+ @tasks.coroutine
+ def read(self, n=-1):
+ if self._exception is not None:
+ raise self._exception
+
+ if not n:
+ return b''
+
+ if n < 0:
+ while not self.eof:
+ assert not self.waiter
+ self.waiter = futures.Future()
+ yield from self.waiter
+ else:
+ if not self.byte_count and not self.eof:
+ assert not self.waiter
+ self.waiter = futures.Future()
+ yield from self.waiter
+
+ if n < 0 or self.byte_count <= n:
+ data = b''.join(self.buffer)
+ self.buffer.clear()
+ self.byte_count = 0
+ return data
+
+ parts = []
+ parts_bytes = 0
+ while self.buffer and parts_bytes < n:
+ data = self.buffer.popleft()
+ data_bytes = len(data)
+ if n < parts_bytes + data_bytes:
+ data_bytes = n - parts_bytes
+ data, rest = data[:data_bytes], data[data_bytes:]
+ self.buffer.appendleft(rest)
+
+ parts.append(data)
+ parts_bytes += data_bytes
+ self.byte_count -= data_bytes
+
+ return b''.join(parts)
+
+ @tasks.coroutine
+ def readexactly(self, n):
+ if self._exception is not None:
+ raise self._exception
+
+ if n <= 0:
+ return b''
+
+ while self.byte_count < n and not self.eof:
+ assert not self.waiter
+ self.waiter = futures.Future()
+ yield from self.waiter
+
+ return (yield from self.read(n))
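+
+# Usage sketch (names illustrative): a transport feeds the reader from
+# its data_received() callback while a coroutine consumes lines.
+#
+#     reader = StreamReader()
+#     reader.feed_data(b'hello\n')
+#     reader.feed_eof()
+#     line = yield from reader.readline()  # inside a coroutine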
diff --git a/tulip/subprocess_transport.py b/tulip/subprocess_transport.py
new file mode 100644
index 0000000..734a5fa
--- /dev/null
+++ b/tulip/subprocess_transport.py
@@ -0,0 +1,139 @@
+import fcntl
+import os
+import traceback
+
+from . import transports
+from . import events
+from .log import tulip_log
+
+
+class UnixSubprocessTransport(transports.Transport):
+ """Transport class managing a subprocess.
+
+ TODO: Separate this into something that just handles pipe I/O,
+ and something else that handles pipe setup, fork, and exec.
+ """
+
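+    # Usage sketch ('MyProtocol' is an illustrative protocol instance):
+    #
+    #     transport = UnixSubprocessTransport(MyProtocol(), ['/bin/cat'])
+    #     transport.write(b'hello\n')
+    #     transport.write_eof()
+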
+ def __init__(self, protocol, args):
+ self._protocol = protocol # Not a factory! :-)
+ self._args = args # args[0] must be full path of binary.
+ self._event_loop = events.get_event_loop()
+ self._buffer = []
+ self._eof = False
+ rstdin, self._wstdin = os.pipe()
+ self._rstdout, wstdout = os.pipe()
+
+ # TODO: This is incredibly naive. Should look at
+ # subprocess.py for all the precautions around fork/exec.
+ pid = os.fork()
+ if not pid:
+ # Child.
+ try:
+ os.dup2(rstdin, 0)
+ os.dup2(wstdout, 1)
+ # TODO: What to do with stderr?
+ os.execv(args[0], args)
+ except:
+ try:
+                    traceback.print_exc()
+ finally:
+ os._exit(127)
+
+ # Parent.
+ os.close(rstdin)
+ os.close(wstdout)
+ _setnonblocking(self._wstdin)
+ _setnonblocking(self._rstdout)
+ self._event_loop.call_soon(self._protocol.connection_made, self)
+ self._event_loop.add_reader(self._rstdout, self._stdout_callback)
+
+ def write(self, data):
+ assert not self._eof
+ assert isinstance(data, bytes), repr(data)
+ if not data:
+ return
+
+ if not self._buffer:
+ # Attempt to write it right away first.
+ try:
+ n = os.write(self._wstdin, data)
+ except BlockingIOError:
+ pass
+ except Exception as exc:
+ self._fatal_error(exc)
+ return
+ else:
+ if n == len(data):
+ return
+ elif n:
+ data = data[n:]
+ self._event_loop.add_writer(self._wstdin, self._stdin_callback)
+ self._buffer.append(data)
+
+ def write_eof(self):
+ assert not self._eof
+ assert self._wstdin >= 0
+ self._eof = True
+ if not self._buffer:
+ self._event_loop.remove_writer(self._wstdin)
+ os.close(self._wstdin)
+ self._wstdin = -1
+
+ def close(self):
+ if not self._eof:
+ self.write_eof()
+ # XXX What else?
+
+ def _fatal_error(self, exc):
+ tulip_log.error('Fatal error: %r', exc)
+ if self._rstdout >= 0:
+ os.close(self._rstdout)
+ self._rstdout = -1
+ if self._wstdin >= 0:
+ os.close(self._wstdin)
+ self._wstdin = -1
+ self._eof = True
+ self._buffer = None
+
+ def _stdin_callback(self):
+ data = b''.join(self._buffer)
+ assert data, "Data shold not be empty"
+
+ self._buffer = []
+ try:
+ n = os.write(self._wstdin, data)
+ except BlockingIOError:
+ self._buffer.append(data)
+ except Exception as exc:
+ self._fatal_error(exc)
+ else:
+ if n >= len(data):
+ self._event_loop.remove_writer(self._wstdin)
+ if self._eof:
+ os.close(self._wstdin)
+ self._wstdin = -1
+ return
+
+ elif n > 0:
+ data = data[n:]
+
+ self._buffer.append(data) # Try again later.
+
+ def _stdout_callback(self):
+ try:
+ data = os.read(self._rstdout, 1024)
+ except BlockingIOError:
+ pass
+ else:
+ if data:
+ self._event_loop.call_soon(self._protocol.data_received, data)
+ else:
+ self._event_loop.remove_reader(self._rstdout)
+ os.close(self._rstdout)
+ self._rstdout = -1
+ self._event_loop.call_soon(self._protocol.eof_received)
+
+
+def _setnonblocking(fd):
+ flags = fcntl.fcntl(fd, fcntl.F_GETFL)
+ fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
diff --git a/tulip/tasks.py b/tulip/tasks.py
new file mode 100644
index 0000000..81359a2
--- /dev/null
+++ b/tulip/tasks.py
@@ -0,0 +1,320 @@
+"""Support for tasks, coroutines and the scheduler."""
+
+__all__ = ['coroutine', 'task', 'Task',
+ 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED',
+ 'wait', 'as_completed', 'sleep',
+ ]
+
+import concurrent.futures
+import functools
+import inspect
+import time
+
+from . import futures
+from .log import tulip_log
+
+
+def coroutine(func):
+ """Decorator to mark coroutines.
+
+    Non-generator functions are wrapped in a generator function. If the
+    wrapped function returns a Future or a generator, the wrapper yields
+    from it before returning the result.
+
+ TODO: This is a feel-good API only. It is not enforced.
+ """
+ if inspect.isgeneratorfunction(func):
+ coro = func
+ else:
+ tulip_log.warning(
+ 'Coroutine function %s is not a generator.', func.__name__)
+
+ @functools.wraps(func)
+ def coro(*args, **kw):
+ res = func(*args, **kw)
+
+ if isinstance(res, futures.Future) or inspect.isgenerator(res):
+ res = yield from res
+
+ return res
+
+ coro._is_coroutine = True # Not sure who can use this.
+ return coro
+
+
+# TODO: Do we need this?
+def iscoroutinefunction(func):
+ """Return True if func is a decorated coroutine function."""
+ return (inspect.isgeneratorfunction(func) and
+ getattr(func, '_is_coroutine', False))
+
+
+# TODO: Do we need this?
+def iscoroutine(obj):
+ """Return True if obj is a coroutine object."""
+ return inspect.isgenerator(obj) # TODO: And what?
+
+
+def task(func):
+ """Decorator for a coroutine to be wrapped in a Task."""
+ def task_wrapper(*args, **kwds):
+ coro = func(*args, **kwds)
+ return Task(coro)
+ return task_wrapper
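+
+# Usage sketch for the decorators above ('fetch' is an illustrative name):
+#
+#     @task
+#     def fetch():
+#         yield from sleep(1)
+#         return 'done'
+#
+#     t = fetch()  # Returns a Task scheduled on the current event loop.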
+
+
+class Task(futures.Future):
+ """A coroutine wrapped in a Future."""
+
+ def __init__(self, coro, event_loop=None, timeout=None):
+ assert inspect.isgenerator(coro) # Must be a coroutine *object*.
+ super().__init__(event_loop=event_loop, timeout=timeout)
+ self._coro = coro
+ self._must_cancel = False
+ self._event_loop.call_soon(self._step)
+
+ def __repr__(self):
+ res = super().__repr__()
+ if (self._must_cancel and
+ self._state == futures._PENDING and
+ '<PENDING' in res):
+ res = res.replace('<PENDING', '<CANCELLING', 1)
+ i = res.find('<')
+ if i < 0:
+ i = len(res)
+ res = res[:i] + '(<{}>)'.format(self._coro.__name__) + res[i:]
+ return res
+
+ def cancel(self):
+ if self.done():
+ return False
+ self._must_cancel = True
+ # _step() will call super().cancel() to call the callbacks.
+ self._event_loop.call_soon(self._step_maybe)
+ return True
+
+ def cancelled(self):
+ return self._must_cancel or super().cancelled()
+
+ def _step_maybe(self):
+ # Helper for cancel().
+ if not self.done():
+ return self._step()
+
+ def _step(self, value=None, exc=None):
+ if self.done():
+            tulip_log.warning(
+                '_step(): already done: %r, %r, %r', self, value, exc)
+ return
+ # We'll call either coro.throw(exc) or coro.send(value).
+ if self._must_cancel:
+ exc = futures.CancelledError
+ coro = self._coro
+ try:
+ if exc is not None:
+ result = coro.throw(exc)
+ elif value is not None:
+ result = coro.send(value)
+ else:
+ result = next(coro)
+ except StopIteration as exc:
+ if self._must_cancel:
+ super().cancel()
+ else:
+ self.set_result(exc.value)
+ except Exception as exc:
+ if self._must_cancel:
+ super().cancel()
+ else:
+ self.set_exception(exc)
+ tulip_log.exception('Exception in task')
+ except BaseException as exc:
+ if self._must_cancel:
+ super().cancel()
+ else:
+ self.set_exception(exc)
+ tulip_log.exception('BaseException in task')
+ raise
+ else:
+ # XXX No check for self._must_cancel here?
+ if isinstance(result, futures.Future):
+ if not result._blocking:
+ self._event_loop.call_soon(
+ self._step,
+ None, RuntimeError(
+ 'yield was used instead of yield from in task %r '
+ 'with %r' % (self, result)))
+ else:
+ result._blocking = False
+ result.add_done_callback(self._wakeup)
+
+ elif isinstance(result, concurrent.futures.Future):
+ # This ought to be more efficient than wrap_future(),
+ # because we don't create an extra Future.
+ result.add_done_callback(
+ lambda future:
+ self._event_loop.call_soon_threadsafe(
+ self._wakeup, future))
+ else:
+ if inspect.isgenerator(result):
+ self._event_loop.call_soon(
+ self._step,
+ None, RuntimeError(
+ 'yield was used instead of yield from for '
+ 'generator in task %r with %s' % (self, result)))
+ else:
+ if result is not None:
+                        tulip_log.warning('_step(): bad yield: %r', result)
+
+ self._event_loop.call_soon(self._step)
+
+ def _wakeup(self, future):
+ try:
+ value = future.result()
+ except Exception as exc:
+ self._step(None, exc)
+ else:
+ self._step(value, None)
+
+
+# wait() and as_completed() similar to those in PEP 3148.
+
+FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED
+FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION
+ALL_COMPLETED = concurrent.futures.ALL_COMPLETED
+
+
+# Even though this *is* a @coroutine, we don't mark it as such!
+def wait(fs, timeout=None, return_when=ALL_COMPLETED):
+ """Wait for the Futures and and coroutines given by fs to complete.
+
+ Coroutines will be wrapped in Tasks.
+
+ Returns two sets of Future: (done, pending).
+
+ Usage:
+
+ done, pending = yield from tulip.wait(fs)
+
+ Note: This does not raise TimeoutError! Futures that aren't done
+ when the timeout occurs are returned in the second set.
+ """
+ fs = _wrap_coroutines(fs)
+ return _wait(fs, timeout, return_when)
+
+
+@coroutine
+def _wait(fs, timeout=None, return_when=ALL_COMPLETED):
+ """Internal helper: Like wait() but does not wrap coroutines."""
+ done, pending = set(), set()
+
+ errors = 0
+ for f in fs:
+ if f.done():
+ done.add(f)
+ if not f.cancelled() and f.exception() is not None:
+ errors += 1
+ else:
+ pending.add(f)
+
+ if (not pending or
+ timeout is not None and timeout <= 0 or
+ return_when == FIRST_COMPLETED and done or
+ return_when == FIRST_EXCEPTION and errors):
+ return done, pending
+
+ # Will always be cancelled eventually.
+ bail = futures.Future(timeout=timeout)
+
+ def _on_completion(f):
+ pending.remove(f)
+ done.add(f)
+ if (not pending or
+ return_when == FIRST_COMPLETED or
+ (return_when == FIRST_EXCEPTION and
+ not f.cancelled() and
+ f.exception() is not None)):
+ bail.cancel()
+
+ try:
+ for f in pending:
+ f.add_done_callback(_on_completion)
+ try:
+ yield from bail
+ except futures.CancelledError:
+ pass
+ finally:
+ for f in pending:
+ f.remove_done_callback(_on_completion)
+
+ really_done = set(f for f in pending if f.done())
+ if really_done:
+ done.update(really_done)
+ pending.difference_update(really_done)
+
+ return done, pending
+
+
+# This is *not* a @coroutine! It is just an iterator (yielding Futures).
+def as_completed(fs, timeout=None):
+ """Return an iterator whose values, when waited for, are Futures.
+
+ This differs from PEP 3148; the proper way to use this is:
+
+ for f in as_completed(fs):
+ result = yield from f # The 'yield from' may raise.
+ # Use result.
+
+ Raises TimeoutError if the timeout occurs before all Futures are
+ done.
+
+ Note: The futures 'f' are not necessarily members of fs.
+ """
+ deadline = None
+ if timeout is not None:
+ deadline = time.monotonic() + timeout
+
+ done = None # Make nonlocal happy.
+ fs = _wrap_coroutines(fs)
+
+ while fs:
+ if deadline is not None:
+ timeout = deadline - time.monotonic()
+
+ @coroutine
+ def _wait_for_some():
+ nonlocal done, fs
+ done, fs = yield from _wait(fs, timeout=timeout,
+ return_when=FIRST_COMPLETED)
+ if not done:
+ fs = set()
+ raise futures.TimeoutError()
+ return done.pop().result() # May raise.
+
+ yield Task(_wait_for_some())
+ for f in done:
+ yield f
+
+
+def _wrap_coroutines(fs):
+ """Internal helper to process an iterator of Futures and coroutines.
+
+ Returns a set of Futures.
+ """
+ wrapped = set()
+ for f in fs:
+ if not isinstance(f, futures.Future):
+ assert iscoroutine(f)
+ f = Task(f)
+ wrapped.add(f)
+ return wrapped
+
+
+def sleep(when, result=None):
+ """Return a Future that completes after a given time (in seconds).
+
+ It's okay to cancel the Future.
+
+ Undocumented feature: sleep(when, x) sets the Future's result to x.
+ """
+ future = futures.Future()
+ future._event_loop.call_later(when, future.set_result, result)
+ return future
diff --git a/tulip/test_utils.py b/tulip/test_utils.py
new file mode 100644
index 0000000..9b87db2
--- /dev/null
+++ b/tulip/test_utils.py
@@ -0,0 +1,30 @@
+"""Utilities shared by tests."""
+
+import logging
+import socket
+import sys
+import unittest
+
+
+if sys.platform == 'win32': # pragma: no cover
+ from .winsocketpair import socketpair
+else:
+ from socket import socketpair # pragma: no cover
+
+
+class LogTrackingTestCase(unittest.TestCase):
+
+ def setUp(self):
+ self._logger = logging.getLogger()
+ self._log_level = self._logger.getEffectiveLevel()
+
+ def tearDown(self):
+ self._logger.setLevel(self._log_level)
+
+ def suppress_log_errors(self): # pragma: no cover
+ if self._log_level >= logging.WARNING:
+ self._logger.setLevel(logging.CRITICAL)
+
+ def suppress_log_warnings(self): # pragma: no cover
+ if self._log_level >= logging.WARNING:
+ self._logger.setLevel(logging.ERROR)
diff --git a/tulip/transports.py b/tulip/transports.py
new file mode 100644
index 0000000..a9ec07a
--- /dev/null
+++ b/tulip/transports.py
@@ -0,0 +1,134 @@
+"""Abstract Transport class."""
+
+__all__ = ['ReadTransport', 'WriteTransport', 'Transport']
+
+
+class BaseTransport:
+ """Base ABC for transports."""
+
+ def __init__(self, extra=None):
+ if extra is None:
+ extra = {}
+ self._extra = extra
+
+ def get_extra_info(self, name, default=None):
+ """Get optional transport information."""
+ return self._extra.get(name, default)
+
+ def close(self):
+ """Closes the transport.
+
+ Buffered data will be flushed asynchronously. No more data
+ will be received. After all buffered data is flushed, the
+        protocol's connection_lost() method will (eventually) be called
+ with None as its argument.
+ """
+ raise NotImplementedError
+
+
+class ReadTransport(BaseTransport):
+ """ABC for read-only transports."""
+
+ def pause(self):
+ """Pause the receiving end.
+
+ No data will be passed to the protocol's data_received()
+ method until resume() is called.
+ """
+ raise NotImplementedError
+
+ def resume(self):
+ """Resume the receiving end.
+
+ Data received will once again be passed to the protocol's
+ data_received() method.
+ """
+ raise NotImplementedError
+
+
+class WriteTransport(BaseTransport):
+ """ABC for write-only transports."""
+
+ def write(self, data):
+ """Write some data bytes to the transport.
+
+ This does not block; it buffers the data and arranges for it
+ to be sent out asynchronously.
+ """
+ raise NotImplementedError
+
+ def writelines(self, list_of_data):
+ """Write a list (or any iterable) of data bytes to the transport.
+
+ The default implementation just calls write() for each item in
+ the list/iterable.
+ """
+ for data in list_of_data:
+ self.write(data)
+
+ def write_eof(self):
+ """Closes the write end after flushing buffered data.
+
+ (This is like typing ^D into a UNIX program reading from stdin.)
+
+ Data may still be received.
+ """
+ raise NotImplementedError
+
+ def can_write_eof(self):
+ """Return True if this protocol supports write_eof(), False if not."""
+ raise NotImplementedError
+
+ def abort(self):
+ """Closes the transport immediately.
+
+ Buffered data will be lost. No more data will be received.
+ The protocol's connection_lost() method will (eventually) be
+ called with None as its argument.
+ """
+ raise NotImplementedError
+
+
+class Transport(ReadTransport, WriteTransport):
+ """ABC representing a bidirectional transport.
+
+ There may be several implementations, but typically, the user does
+ not implement new transports; rather, the platform provides some
+ useful transports that are implemented using the platform's best
+ practices.
+
+ The user never instantiates a transport directly; they call a
+ utility function, passing it a protocol factory and other
+ information necessary to create the transport and protocol. (E.g.
+ EventLoop.create_connection() or EventLoop.start_serving().)
+
+ The utility function will asynchronously create a transport and a
+ protocol and hook them up by calling the protocol's
+ connection_made() method, passing it the transport.
+
+    The implementation here raises NotImplementedError for every method
+ except writelines(), which calls write() in a loop.
+ """
+
+
+class DatagramTransport(BaseTransport):
+ """ABC for datagram (UDP) transports."""
+
+ def sendto(self, data, addr=None):
+ """Send data to the transport.
+
+ This does not block; it buffers the data and arranges for it
+ to be sent out asynchronously.
+        addr is the target socket address; if addr is None, the target
+        address given at transport creation is used.
+ """
+ raise NotImplementedError
+
+ def abort(self):
+ """Closes the transport immediately.
+
+ Buffered data will be lost. No more data will be received.
+ The protocol's connection_lost() method will (eventually) be
+ called with None as its argument.
+ """
+ raise NotImplementedError
diff --git a/tulip/unix_events.py b/tulip/unix_events.py
new file mode 100644
index 0000000..3073ab6
--- /dev/null
+++ b/tulip/unix_events.py
@@ -0,0 +1,301 @@
+"""Selector eventloop for Unix with signal handling."""
+
+import errno
+import fcntl
+import os
+import socket
+import sys
+
+try:
+ import signal
+except ImportError: # pragma: no cover
+ signal = None
+
+from . import events
+from . import selector_events
+from . import transports
+from .log import tulip_log
+
+
+__all__ = ['SelectorEventLoop']
+
+
+if sys.platform == 'win32': # pragma: no cover
+ raise ImportError('Signals are not really supported on Windows')
+
+
+class SelectorEventLoop(selector_events.BaseSelectorEventLoop):
+ """Unix event loop
+
+ Adds signal handling to SelectorEventLoop
+ """
+
+ def __init__(self, selector=None):
+ super().__init__(selector)
+ self._signal_handlers = {}
+
+ def _socketpair(self):
+ return socket.socketpair()
+
+ def add_signal_handler(self, sig, callback, *args):
+ """Add a handler for a signal. UNIX only.
+
+ Raise ValueError if the signal number is invalid or uncatchable.
+ Raise RuntimeError if there is a problem setting up the handler.
+ """
+ self._check_signal(sig)
+ try:
+ # set_wakeup_fd() raises ValueError if this is not the
+ # main thread. By calling it early we ensure that an
+ # event loop running in another thread cannot add a signal
+ # handler.
+ signal.set_wakeup_fd(self._csock.fileno())
+ except ValueError as exc:
+ raise RuntimeError(str(exc))
+
+ handle = events.make_handle(callback, args)
+ self._signal_handlers[sig] = handle
+
+ try:
+ signal.signal(sig, self._handle_signal)
+ except OSError as exc:
+ del self._signal_handlers[sig]
+ if not self._signal_handlers:
+ try:
+ signal.set_wakeup_fd(-1)
+ except ValueError as nexc:
+ tulip_log.info('set_wakeup_fd(-1) failed: %s', nexc)
+
+ if exc.errno == errno.EINVAL:
+ raise RuntimeError('sig {} cannot be caught'.format(sig))
+ else:
+ raise
+
+ return handle
+
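+    # Usage sketch (illustrative): make SIGINT stop a running loop.
+    #
+    #     loop = SelectorEventLoop()
+    #     loop.add_signal_handler(signal.SIGINT, loop.stop)
+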
+ def _handle_signal(self, sig, arg):
+ """Internal helper that is the actual signal handler."""
+ handle = self._signal_handlers.get(sig)
+ if handle is None:
+ return # Assume it's some race condition.
+ if handle.cancelled:
+ self.remove_signal_handler(sig) # Remove it properly.
+ else:
+ self.call_soon_threadsafe(handle)
+
+ def remove_signal_handler(self, sig):
+ """Remove a handler for a signal. UNIX only.
+
+ Return True if a signal handler was removed, False if not."""
+ self._check_signal(sig)
+ try:
+ del self._signal_handlers[sig]
+ except KeyError:
+ return False
+
+ if sig == signal.SIGINT:
+ handler = signal.default_int_handler
+ else:
+ handler = signal.SIG_DFL
+
+ try:
+ signal.signal(sig, handler)
+ except OSError as exc:
+ if exc.errno == errno.EINVAL:
+ raise RuntimeError('sig {} cannot be caught'.format(sig))
+ else:
+ raise
+
+ if not self._signal_handlers:
+ try:
+ signal.set_wakeup_fd(-1)
+ except ValueError as exc:
+ tulip_log.info('set_wakeup_fd(-1) failed: %s', exc)
+
+ return True
+
+ def _check_signal(self, sig):
+ """Internal helper to validate a signal.
+
+ Raise ValueError if the signal number is invalid or uncatchable.
+ Raise RuntimeError if there is a problem setting up the handler.
+ """
+ if not isinstance(sig, int):
+ raise TypeError('sig must be an int, not {!r}'.format(sig))
+
+ if signal is None:
+ raise RuntimeError('Signals are not supported')
+
+ if not (1 <= sig < signal.NSIG):
+ raise ValueError(
+ 'sig {} out of range(1, {})'.format(sig, signal.NSIG))
+
+ def _make_read_pipe_transport(self, pipe, protocol, waiter=None,
+ extra=None):
+ return _UnixReadPipeTransport(self, pipe, protocol, waiter, extra)
+
+ def _make_write_pipe_transport(self, pipe, protocol, waiter=None,
+ extra=None):
+ return _UnixWritePipeTransport(self, pipe, protocol, waiter, extra)
+
+
+def _set_nonblocking(fd):
+ flags = fcntl.fcntl(fd, fcntl.F_GETFL)
+ flags = flags | os.O_NONBLOCK
+ fcntl.fcntl(fd, fcntl.F_SETFL, flags)
+
+
+class _UnixReadPipeTransport(transports.ReadTransport):
+
+ max_size = 256 * 1024 # max bytes we read in one eventloop iteration
+
+ def __init__(self, event_loop, pipe, protocol, waiter=None, extra=None):
+ super().__init__(extra)
+ self._extra['pipe'] = pipe
+ self._event_loop = event_loop
+ self._pipe = pipe
+ self._fileno = pipe.fileno()
+ _set_nonblocking(self._fileno)
+ self._protocol = protocol
+ self._closing = False
+ self._event_loop.add_reader(self._fileno, self._read_ready)
+ self._event_loop.call_soon(self._protocol.connection_made, self)
+ if waiter is not None:
+ self._event_loop.call_soon(waiter.set_result, None)
+
+ def _read_ready(self):
+ try:
+ data = os.read(self._fileno, self.max_size)
+ except BlockingIOError:
+ pass
+ except OSError as exc:
+ self._fatal_error(exc)
+ else:
+ if data:
+ self._protocol.data_received(data)
+ else:
+ self._event_loop.remove_reader(self._fileno)
+ self._protocol.eof_received()
+
+ def pause(self):
+ self._event_loop.remove_reader(self._fileno)
+
+ def resume(self):
+ self._event_loop.add_reader(self._fileno, self._read_ready)
+
+ def close(self):
+ if not self._closing:
+ self._close(None)
+
+ def _fatal_error(self, exc):
+ # should be called by exception handler only
+ tulip_log.exception('Fatal error for %s', self)
+ self._close(exc)
+
+ def _close(self, exc):
+ self._closing = True
+ self._event_loop.remove_reader(self._fileno)
+ self._call_connection_lost(exc)
+
+ def _call_connection_lost(self, exc):
+ try:
+ self._protocol.connection_lost(exc)
+ finally:
+ self._pipe.close()
+
+
+class _UnixWritePipeTransport(transports.WriteTransport):
+
+ def __init__(self, event_loop, pipe, protocol, waiter=None, extra=None):
+ super().__init__(extra)
+ self._extra['pipe'] = pipe
+ self._event_loop = event_loop
+ self._pipe = pipe
+ self._fileno = pipe.fileno()
+ _set_nonblocking(self._fileno)
+ self._protocol = protocol
+ self._buffer = []
+ self._closing = False # Set when close() or write_eof() called.
+ self._event_loop.call_soon(self._protocol.connection_made, self)
+ if waiter is not None:
+ self._event_loop.call_soon(waiter.set_result, None)
+
+ def write(self, data):
+ assert isinstance(data, bytes), repr(data)
+ assert not self._closing
+ if not data:
+ return
+
+ if not self._buffer:
+ # Attempt to send it right away first.
+ try:
+ n = os.write(self._fileno, data)
+ except BlockingIOError:
+ n = 0
+ except Exception as exc:
+ self._fatal_error(exc)
+ return
+ if n == len(data):
+ return
+ elif n > 0:
+ data = data[n:]
+ self._event_loop.add_writer(self._fileno, self._write_ready)
+
+ self._buffer.append(data)
+
+ def _write_ready(self):
+ data = b''.join(self._buffer)
+ assert data, "Data should not be empty"
+
+ self._buffer.clear()
+ try:
+ n = os.write(self._fileno, data)
+ except BlockingIOError:
+ self._buffer.append(data)
+ except Exception as exc:
+ self._fatal_error(exc)
+ else:
+ if n == len(data):
+ self._event_loop.remove_writer(self._fileno)
+ if self._closing:
+ self._call_connection_lost(None)
+ return
+ elif n > 0:
+ data = data[n:]
+
+ self._buffer.append(data) # Try again later.
+
+ def can_write_eof(self):
+ return True
+
+ def write_eof(self):
+ assert not self._closing
+ assert self._pipe
+ self._closing = True
+ if not self._buffer:
+ self._call_connection_lost(None)
+
+ def close(self):
+ if not self._closing:
+            # write_eof() is all we need to close the write pipe.
+ self.write_eof()
+
+ def abort(self):
+ self._close(None)
+
+ def _fatal_error(self, exc):
+ # should be called by exception handler only
+ tulip_log.exception('Fatal error for %s', self)
+ self._close(exc)
+
+ def _close(self, exc=None):
+ self._closing = True
+ self._buffer.clear()
+ self._event_loop.remove_writer(self._fileno)
+ self._call_connection_lost(exc)
+
+ def _call_connection_lost(self, exc):
+ try:
+ self._protocol.connection_lost(exc)
+ finally:
+ self._pipe.close()
diff --git a/tulip/windows_events.py b/tulip/windows_events.py
new file mode 100644
index 0000000..2ec8561
--- /dev/null
+++ b/tulip/windows_events.py
@@ -0,0 +1,157 @@
+"""Selector and proactor eventloops for Windows."""
+
+import socket
+import weakref
+import struct
+import _winapi
+
+from . import futures
+from . import proactor_events
+from . import selector_events
+from . import winsocketpair
+from . import _overlapped
+from .log import tulip_log
+
+
+__all__ = ['SelectorEventLoop', 'ProactorEventLoop']
+
+
+NULL = 0
+INFINITE = 0xffffffff
+ERROR_CONNECTION_REFUSED = 1225
+ERROR_CONNECTION_ABORTED = 1236
+
+
+class SelectorEventLoop(selector_events.BaseSelectorEventLoop):
+ def _socketpair(self):
+ return winsocketpair.socketpair()
+
+
+class ProactorEventLoop(proactor_events.BaseProactorEventLoop):
+ def __init__(self, proactor=None):
+ if proactor is None:
+ proactor = IocpProactor()
+ super().__init__(proactor)
+
+ def _socketpair(self):
+ return winsocketpair.socketpair()
+
+
+class IocpProactor:
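+    """Proactor backed by a Windows I/O completion port (IOCP).
+
+    Each overlapped operation stores its Future in self._cache;
+    _poll() drains the completion port and resolves those Futures.
+    """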
+
+ def __init__(self, concurrency=0xffffffff):
+ self._results = []
+ self._iocp = _overlapped.CreateIoCompletionPort(
+ _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency)
+ self._cache = {}
+ self._registered = weakref.WeakSet()
+
+ def registered_count(self):
+ return len(self._cache)
+
+ def select(self, timeout=None):
+ if not self._results:
+ self._poll(timeout)
+ tmp = self._results
+ self._results = []
+ return tmp
+
+ def recv(self, conn, nbytes, flags=0):
+ self._register_with_iocp(conn)
+ ov = _overlapped.Overlapped(NULL)
+ ov.WSARecv(conn.fileno(), nbytes, flags)
+ return self._register(ov, conn, ov.getresult)
+
+ def send(self, conn, buf, flags=0):
+ self._register_with_iocp(conn)
+ ov = _overlapped.Overlapped(NULL)
+ ov.WSASend(conn.fileno(), buf, flags)
+ return self._register(ov, conn, ov.getresult)
+
+ def accept(self, listener):
+ self._register_with_iocp(listener)
+ conn = self._get_accept_socket()
+ ov = _overlapped.Overlapped(NULL)
+ ov.AcceptEx(listener.fileno(), conn.fileno())
+
+ def finish_accept():
+ addr = ov.getresult()
+ buf = struct.pack('@P', listener.fileno())
+ conn.setsockopt(socket.SOL_SOCKET,
+ _overlapped.SO_UPDATE_ACCEPT_CONTEXT,
+ buf)
+ conn.settimeout(listener.gettimeout())
+ return conn, conn.getpeername()
+
+ return self._register(ov, listener, finish_accept)
+
+ def connect(self, conn, address):
+ self._register_with_iocp(conn)
+ _overlapped.BindLocal(conn.fileno(), len(address))
+ ov = _overlapped.Overlapped(NULL)
+ ov.ConnectEx(conn.fileno(), address)
+
+ def finish_connect():
+ ov.getresult()
+ conn.setsockopt(socket.SOL_SOCKET,
+ _overlapped.SO_UPDATE_CONNECT_CONTEXT,
+ 0)
+ return conn
+
+ return self._register(ov, conn, finish_connect)
+
+ def _register_with_iocp(self, obj):
+ if obj not in self._registered:
+ self._registered.add(obj)
+ _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0)
+
+ def _register(self, ov, obj, callback):
+ f = futures.Future()
+ self._cache[ov.address] = (f, ov, obj, callback)
+ return f
+
+ def _get_accept_socket(self):
+ s = socket.socket()
+ s.settimeout(0)
+ return s
+
+ def _poll(self, timeout=None):
+ if timeout is None:
+ ms = INFINITE
+ elif timeout < 0:
+ raise ValueError("negative timeout")
+ else:
+ ms = int(timeout * 1000 + 0.5)
+ if ms >= INFINITE:
+ raise ValueError("timeout too big")
+ while True:
+ status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms)
+ if status is None:
+ return
+ address = status[3]
+ f, ov, obj, callback = self._cache.pop(address)
+ try:
+ value = callback()
+ except OSError as e:
+ f.set_exception(e)
+ self._results.append(f)
+ else:
+ f.set_result(value)
+ self._results.append(f)
+ ms = 0
+
+ def close(self):
+ for (f, ov, obj, callback) in self._cache.values():
+ try:
+ ov.cancel()
+ except OSError:
+ pass
+
+ while self._cache:
+ if not self._poll(1):
+ tulip_log.debug('taking long time to close proactor')
+
+ self._results = []
+ if self._iocp is not None:
+ _winapi.CloseHandle(self._iocp)
+ self._iocp = None
diff --git a/tulip/winsocketpair.py b/tulip/winsocketpair.py
new file mode 100644
index 0000000..bd1e092
--- /dev/null
+++ b/tulip/winsocketpair.py
@@ -0,0 +1,34 @@
+"""A socket pair usable as a self-pipe, for Windows.
+
+Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain.
+"""
+
+import socket
+import sys
+
+if sys.platform != 'win32': # pragma: no cover
+ raise ImportError('winsocketpair is win32 only')
+
+
+def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0):
+ """Emulate the Unix socketpair() function on Windows."""
+ # We create a connected TCP socket. Note the trick with setblocking(0)
+ # that prevents us from having to create a thread.
+ lsock = socket.socket(family, type, proto)
+ lsock.bind(('localhost', 0))
+ lsock.listen(1)
+ addr, port = lsock.getsockname()
+ csock = socket.socket(family, type, proto)
+ csock.setblocking(False)
+ try:
+ csock.connect((addr, port))
+ except (BlockingIOError, InterruptedError):
+ pass
+ except:
+ lsock.close()
+ csock.close()
+ raise
+ ssock, _ = lsock.accept()
+ csock.setblocking(True)
+ lsock.close()
+ return (ssock, csock)