diff options
Diffstat (limited to 'gear')
-rw-r--r-- | gear/__init__.py | 78 | ||||
-rw-r--r-- | gear/tests/__init__.py | 2 | ||||
-rw-r--r-- | gear/tests/test_functional.py | 37 |
3 files changed, 90 insertions, 27 deletions
diff --git a/gear/__init__.py b/gear/__init__.py index d436274..38b037a 100644 --- a/gear/__init__.py +++ b/gear/__init__.py @@ -1487,6 +1487,7 @@ class Client(BaseClient): self.sendPacket(packet, conn) except Exception: # Error handling is all done by sendPacket + self.log.info("Sending packet failed") continue complete = task.wait(timeout) if not complete: @@ -2704,6 +2705,38 @@ class ServerConnection(NonBlockingConnection): id(self), self.client_id, self.host, self.port) +class Poller(object): + """A poller using Epoll if available and Poll on non-linux systems. + + :arg bool use_epoll: If epoll should be used (needs to be also supported + by the OS) + """ + + def __init__(self, use_epoll=True): + self.use_epoll = use_epoll and hasattr(select, 'epoll') + + self.POLL_EDGE = select.EPOLLET if self.use_epoll else 0 + self.POLL_IN = select.EPOLLIN if self.use_epoll else select.POLLIN + self.POLL_OUT = select.EPOLLOUT if self.use_epoll else select.POLLOUT + self.POLL_HUP = select.EPOLLHUP if self.use_epoll else select.POLLHUP + self.POLL_PRI = 0 if self.use_epoll else select.POLLPRI + self.POLL_ERR = select.EPOLLERR if self.use_epoll else select.POLLERR + + if self.use_epoll: + self._poll = select.epoll() + else: + self._poll = select.poll() + + def register(self, fd, event_mask): + self._poll.register(fd, event_mask) + + def unregister(self, fd): + self._poll.unregister(fd) + + def poll(self): + return self._poll.poll(0) + + class Server(BaseClientServer): """A simple gearman server implementation for testing (not for production use). @@ -2727,17 +2760,15 @@ class Server(BaseClientServer): :arg int tcp_keepidle: Idle time after which to start keepalives sending :arg int tcp_keepintvl: Interval in seconds between TCP keepalives :arg int tcp_keepcnt: Count of TCP keepalives to send before disconnect + :arg bool use_epoll: If epoll should be used (needs to be also supported + by the OS) """ - edge_bitmask = select.EPOLLET - error_bitmask = (select.EPOLLERR | select.EPOLLHUP | edge_bitmask) - read_bitmask = (select.EPOLLIN | error_bitmask) - readwrite_bitmask = (select.EPOLLOUT | read_bitmask) - def __init__(self, port=4730, ssl_key=None, ssl_cert=None, ssl_ca=None, statsd_host=None, statsd_port=8125, statsd_prefix=None, server_id=None, acl=None, host=None, keepalive=False, - tcp_keepidle=7200, tcp_keepintvl=75, tcp_keepcnt=9): + tcp_keepidle=7200, tcp_keepintvl=75, tcp_keepcnt=9, + use_epoll=True): self.port = port self.ssl_key = ssl_key self.ssl_cert = ssl_cert @@ -2753,10 +2784,15 @@ class Server(BaseClientServer): self.max_handle = 0 self.acl = acl self.connect_wake_read, self.connect_wake_write = os.pipe() - self.poll = select.epoll() - # Reverse mapping of fd -> connection + self.poller = Poller(use_epoll) self.connection_map = {} + self.edge_bitmask = self.poller.POLL_EDGE + self.error_bitmask = (self.poller.POLL_ERR | self.poller.POLL_HUP + | self.edge_bitmask) + self.read_bitmask = (self.poller.POLL_IN | self.error_bitmask) + self.readwrite_bitmask = (self.poller.POLL_OUT | self.read_bitmask) + self.use_ssl = False if all([self.ssl_key, self.ssl_cert, self.ssl_ca]): self.use_ssl = True @@ -2805,7 +2841,7 @@ class Server(BaseClientServer): # Register the wake pipe so that we can break if we need to # reconfigure connections - self.poll.register(self.wake_read, self.read_bitmask) + self.poller.register(self.wake_read, self.read_bitmask) if server_id: self.log = logging.getLogger("gear.Server.%s" % (self.client_id,)) @@ -2835,8 +2871,7 @@ class Server(BaseClientServer): poll = select.poll() bitmask = (select.POLLIN | select.POLLERR | select.POLLHUP | select.POLLNVAL) - # Register the wake pipe so that we can break if we need to - # shutdown. + # Register the wake pipe so that we can break if we need to shutdown. poll.register(self.connect_wake_read, bitmask) poll.register(self.socket.fileno(), bitmask) while self.running: @@ -2897,11 +2932,11 @@ class Server(BaseClientServer): # The exception handlers here can raise exceptions and if they # do, it's okay, the poll loop will be restarted. try: - if event & (select.EPOLLERR | select.EPOLLHUP): - self.log.debug("Received error event on %s: %s" % ( - conn, event)) + if event & (self.poller.POLL_ERR | self.poller.POLL_HUP): + self.log.debug("Received error event on %s: %s" % + (conn, event)) raise DisconnectError() - if event & (select.POLLIN | select.POLLOUT): + if event & (self.poller.POLL_IN | self.poller.POLL_OUT): self.readFromConnection(conn) self.writeToConnection(conn) except socket.error as e: @@ -2933,15 +2968,16 @@ class Server(BaseClientServer): # loop and therefore the list in guaranteed never to shrink. connections = self.active_connections[:] for conn in connections: - self._processPollEvent(conn, select.POLLIN | select.POLLOUT) + self._processPollEvent(conn, + self.poller.POLL_IN | self.poller.POLL_OUT) def _doPollLoop(self): # Outer run method of poll thread. while self.running: try: self._pollLoop() - except Exception: - self.log.exception("Exception in poll loop:") + except Exception as e: + self.log.exception("Exception in poll loop: %s" % str(e)) def _pollLoop(self): # Inner method of poll loop. @@ -2951,7 +2987,7 @@ class Server(BaseClientServer): while self.running: self.log.debug("Polling %s connections" % len(self.active_connections)) - ret = self.poll.poll() + ret = self.poller.poll() # Since we're using edge-triggering, we need to make sure # that every file descriptor in 'ret' is processed. for fd, event in ret: @@ -2983,7 +3019,7 @@ class Server(BaseClientServer): # Call while holding the connection condition self.log.debug("Registering %s" % conn) self.connection_map[conn.conn.fileno()] = conn - self.poll.register(conn.conn.fileno(), self.readwrite_bitmask) + self.poller.register(conn.conn.fileno(), self.readwrite_bitmask) def _unregisterConnection(self, conn): # Unregister the connection with the poll object @@ -2993,7 +3029,7 @@ class Server(BaseClientServer): if fd not in self.connection_map: return try: - self.poll.unregister(fd) + self.poller.unregister(fd) except KeyError: pass try: diff --git a/gear/tests/__init__.py b/gear/tests/__init__.py index 6d5edb4..6161890 100644 --- a/gear/tests/__init__.py +++ b/gear/tests/__init__.py @@ -51,7 +51,7 @@ class BaseTestCase(testtools.TestCase, testresources.ResourcedTestCase): self.useFixture(fixtures.MonkeyPatch('sys.stderr', stderr)) self.useFixture(fixtures.FakeLogger( - level=logging.DEBUG, + level=logging.INFO, format='%(asctime)s %(name)-32s ' '%(levelname)-8s %(message)s')) self.useFixture(fixtures.NestedTempfile()) diff --git a/gear/tests/test_functional.py b/gear/tests/test_functional.py index 3bca907..f322e17 100644 --- a/gear/tests/test_functional.py +++ b/gear/tests/test_functional.py @@ -14,6 +14,7 @@ # limitations under the License. import os +import select import threading import time import uuid @@ -37,14 +38,25 @@ def iterate_timeout(max_seconds, purpose): raise Exception("Timeout waiting for %s" % purpose) +def _wait_for_connection(server, timeout=10): + time.sleep(1) + for _ in iterate_timeout(10, "available connections"): + if server.active_connections: + break + + class TestFunctional(tests.BaseTestCase): scenarios = [ - ('no_ssl', dict(ssl=False)), - ('ssl', dict(ssl=True)), + ('no_ssl_with_epoll', dict(ssl=False, use_epoll=True)), + ('ssl_with_epoll', dict(ssl=True, use_epoll=True)), + ('no_ssl_without_epoll', dict(ssl=False, use_epoll=False)), + ('ssl_without_epoll', dict(ssl=True, use_epoll=False)), ] def setUp(self): super(TestFunctional, self).setUp() + if self.use_epoll and not hasattr(select, 'epoll'): + self.skipTest("Epoll not available.") if self.ssl: self.tmp_root = self.useFixture(fixtures.TempDir()).path root_subject, root_key = self.create_cert('root') @@ -55,7 +67,8 @@ class TestFunctional(tests.BaseTestCase): 0, os.path.join(self.tmp_root, 'server.key'), os.path.join(self.tmp_root, 'server.crt'), - os.path.join(self.tmp_root, 'root.crt')) + os.path.join(self.tmp_root, 'root.crt'), + use_epoll=self.use_epoll) self.client = gear.Client('client') self.worker = gear.Worker('worker') self.client.addServer('127.0.0.1', self.server.port, @@ -67,7 +80,7 @@ class TestFunctional(tests.BaseTestCase): os.path.join(self.tmp_root, 'worker.crt'), os.path.join(self.tmp_root, 'root.crt')) else: - self.server = gear.Server(0) + self.server = gear.Server(0, use_epoll=self.use_epoll) self.client = gear.Client('client') self.worker = gear.Worker('worker') self.client.addServer('127.0.0.1', self.server.port) @@ -113,6 +126,7 @@ class TestFunctional(tests.BaseTestCase): for jobcount in range(2): job = gear.Job(b'test', b'testdata') + _wait_for_connection(self.server) self.client.submitJob(job) self.assertNotEqual(job.handle, None) @@ -132,6 +146,7 @@ class TestFunctional(tests.BaseTestCase): self.worker.registerFunction('test') job = gear.Job(b'test', b'testdata') + _wait_for_connection(self.server) self.client.submitJob(job, background=True) self.assertNotEqual(job.handle, None) self.client.shutdown() @@ -158,6 +173,7 @@ class TestFunctional(tests.BaseTestCase): for jobcount in range(2): job = gear.Job('test', b'testdata') + _wait_for_connection(self.server) self.client.submitJob(job) self.assertNotEqual(job.handle, None) @@ -166,9 +182,16 @@ class TestFunctional(tests.BaseTestCase): class TestFunctionalText(tests.BaseTestCase): + scenarios = [ + ('with_epoll', dict(use_epoll=True)), + ('without_epoll', dict(use_epoll=False)), + ] + def setUp(self): super(TestFunctionalText, self).setUp() - self.server = gear.Server(0) + if self.use_epoll and not hasattr(select, 'epoll'): + self.skipTest("Epoll not available.") + self.server = gear.Server(0, use_epoll=self.use_epoll) self.client = gear.Client('client') self.worker = gear.TextWorker('worker') self.client.addServer('127.0.0.1', self.server.port) @@ -181,6 +204,7 @@ class TestFunctionalText(tests.BaseTestCase): for jobcount in range(2): job = gear.TextJob('test', 'testdata') + _wait_for_connection(self.server) self.client.submitJob(job) self.assertNotEqual(job.handle, None) @@ -202,6 +226,7 @@ class TestFunctionalText(tests.BaseTestCase): for jobcount in range(2): jobunique = uuid.uuid4().hex job = gear.TextJob('test', 'testdata', unique=jobunique) + _wait_for_connection(self.server) self.client.submitJob(job) self.assertNotEqual(job.handle, None) @@ -224,6 +249,7 @@ class TestFunctionalText(tests.BaseTestCase): for jobcount in range(2): job = gear.TextJob('test', 'testdata') + _wait_for_connection(self.server) self.client.submitJob(job) self.assertNotEqual(job.handle, None) @@ -241,6 +267,7 @@ class TestFunctionalText(tests.BaseTestCase): def test_grab_job_after_register(self): jobunique = uuid.uuid4().hex job = gear.TextJob('test', 'testdata', unique=jobunique) + _wait_for_connection(self.server) self.client.submitJob(job) self.assertNotEqual(job.handle, None) |