summaryrefslogtreecommitdiff
path: root/gear
diff options
context:
space:
mode:
Diffstat (limited to 'gear')
-rw-r--r--gear/__init__.py78
-rw-r--r--gear/tests/__init__.py2
-rw-r--r--gear/tests/test_functional.py37
3 files changed, 90 insertions, 27 deletions
diff --git a/gear/__init__.py b/gear/__init__.py
index d436274..38b037a 100644
--- a/gear/__init__.py
+++ b/gear/__init__.py
@@ -1487,6 +1487,7 @@ class Client(BaseClient):
self.sendPacket(packet, conn)
except Exception:
# Error handling is all done by sendPacket
+ self.log.info("Sending packet failed")
continue
complete = task.wait(timeout)
if not complete:
@@ -2704,6 +2705,38 @@ class ServerConnection(NonBlockingConnection):
id(self), self.client_id, self.host, self.port)
+class Poller(object):
+ """A poller using Epoll if available and Poll on non-linux systems.
+
+ :arg bool use_epoll: If epoll should be used (needs to be also supported
+ by the OS)
+ """
+
+ def __init__(self, use_epoll=True):
+ self.use_epoll = use_epoll and hasattr(select, 'epoll')
+
+ self.POLL_EDGE = select.EPOLLET if self.use_epoll else 0
+ self.POLL_IN = select.EPOLLIN if self.use_epoll else select.POLLIN
+ self.POLL_OUT = select.EPOLLOUT if self.use_epoll else select.POLLOUT
+ self.POLL_HUP = select.EPOLLHUP if self.use_epoll else select.POLLHUP
+ self.POLL_PRI = 0 if self.use_epoll else select.POLLPRI
+ self.POLL_ERR = select.EPOLLERR if self.use_epoll else select.POLLERR
+
+ if self.use_epoll:
+ self._poll = select.epoll()
+ else:
+ self._poll = select.poll()
+
+ def register(self, fd, event_mask):
+ self._poll.register(fd, event_mask)
+
+ def unregister(self, fd):
+ self._poll.unregister(fd)
+
+ def poll(self):
+ return self._poll.poll(0)
+
+
class Server(BaseClientServer):
"""A simple gearman server implementation for testing
(not for production use).
@@ -2727,17 +2760,15 @@ class Server(BaseClientServer):
:arg int tcp_keepidle: Idle time after which to start keepalives sending
:arg int tcp_keepintvl: Interval in seconds between TCP keepalives
:arg int tcp_keepcnt: Count of TCP keepalives to send before disconnect
+ :arg bool use_epoll: If epoll should be used (needs to be also supported
+ by the OS)
"""
- edge_bitmask = select.EPOLLET
- error_bitmask = (select.EPOLLERR | select.EPOLLHUP | edge_bitmask)
- read_bitmask = (select.EPOLLIN | error_bitmask)
- readwrite_bitmask = (select.EPOLLOUT | read_bitmask)
-
def __init__(self, port=4730, ssl_key=None, ssl_cert=None, ssl_ca=None,
statsd_host=None, statsd_port=8125, statsd_prefix=None,
server_id=None, acl=None, host=None, keepalive=False,
- tcp_keepidle=7200, tcp_keepintvl=75, tcp_keepcnt=9):
+ tcp_keepidle=7200, tcp_keepintvl=75, tcp_keepcnt=9,
+ use_epoll=True):
self.port = port
self.ssl_key = ssl_key
self.ssl_cert = ssl_cert
@@ -2753,10 +2784,15 @@ class Server(BaseClientServer):
self.max_handle = 0
self.acl = acl
self.connect_wake_read, self.connect_wake_write = os.pipe()
- self.poll = select.epoll()
- # Reverse mapping of fd -> connection
+ self.poller = Poller(use_epoll)
self.connection_map = {}
+ self.edge_bitmask = self.poller.POLL_EDGE
+ self.error_bitmask = (self.poller.POLL_ERR | self.poller.POLL_HUP
+ | self.edge_bitmask)
+ self.read_bitmask = (self.poller.POLL_IN | self.error_bitmask)
+ self.readwrite_bitmask = (self.poller.POLL_OUT | self.read_bitmask)
+
self.use_ssl = False
if all([self.ssl_key, self.ssl_cert, self.ssl_ca]):
self.use_ssl = True
@@ -2805,7 +2841,7 @@ class Server(BaseClientServer):
# Register the wake pipe so that we can break if we need to
# reconfigure connections
- self.poll.register(self.wake_read, self.read_bitmask)
+ self.poller.register(self.wake_read, self.read_bitmask)
if server_id:
self.log = logging.getLogger("gear.Server.%s" % (self.client_id,))
@@ -2835,8 +2871,7 @@ class Server(BaseClientServer):
poll = select.poll()
bitmask = (select.POLLIN | select.POLLERR |
select.POLLHUP | select.POLLNVAL)
- # Register the wake pipe so that we can break if we need to
- # shutdown.
+ # Register the wake pipe so that we can break if we need to shutdown.
poll.register(self.connect_wake_read, bitmask)
poll.register(self.socket.fileno(), bitmask)
while self.running:
@@ -2897,11 +2932,11 @@ class Server(BaseClientServer):
# The exception handlers here can raise exceptions and if they
# do, it's okay, the poll loop will be restarted.
try:
- if event & (select.EPOLLERR | select.EPOLLHUP):
- self.log.debug("Received error event on %s: %s" % (
- conn, event))
+ if event & (self.poller.POLL_ERR | self.poller.POLL_HUP):
+ self.log.debug("Received error event on %s: %s" %
+ (conn, event))
raise DisconnectError()
- if event & (select.POLLIN | select.POLLOUT):
+ if event & (self.poller.POLL_IN | self.poller.POLL_OUT):
self.readFromConnection(conn)
self.writeToConnection(conn)
except socket.error as e:
@@ -2933,15 +2968,16 @@ class Server(BaseClientServer):
# loop and therefore the list in guaranteed never to shrink.
connections = self.active_connections[:]
for conn in connections:
- self._processPollEvent(conn, select.POLLIN | select.POLLOUT)
+ self._processPollEvent(conn,
+ self.poller.POLL_IN | self.poller.POLL_OUT)
def _doPollLoop(self):
# Outer run method of poll thread.
while self.running:
try:
self._pollLoop()
- except Exception:
- self.log.exception("Exception in poll loop:")
+ except Exception as e:
+ self.log.exception("Exception in poll loop: %s" % str(e))
def _pollLoop(self):
# Inner method of poll loop.
@@ -2951,7 +2987,7 @@ class Server(BaseClientServer):
while self.running:
self.log.debug("Polling %s connections" %
len(self.active_connections))
- ret = self.poll.poll()
+ ret = self.poller.poll()
# Since we're using edge-triggering, we need to make sure
# that every file descriptor in 'ret' is processed.
for fd, event in ret:
@@ -2983,7 +3019,7 @@ class Server(BaseClientServer):
# Call while holding the connection condition
self.log.debug("Registering %s" % conn)
self.connection_map[conn.conn.fileno()] = conn
- self.poll.register(conn.conn.fileno(), self.readwrite_bitmask)
+ self.poller.register(conn.conn.fileno(), self.readwrite_bitmask)
def _unregisterConnection(self, conn):
# Unregister the connection with the poll object
@@ -2993,7 +3029,7 @@ class Server(BaseClientServer):
if fd not in self.connection_map:
return
try:
- self.poll.unregister(fd)
+ self.poller.unregister(fd)
except KeyError:
pass
try:
diff --git a/gear/tests/__init__.py b/gear/tests/__init__.py
index 6d5edb4..6161890 100644
--- a/gear/tests/__init__.py
+++ b/gear/tests/__init__.py
@@ -51,7 +51,7 @@ class BaseTestCase(testtools.TestCase, testresources.ResourcedTestCase):
self.useFixture(fixtures.MonkeyPatch('sys.stderr', stderr))
self.useFixture(fixtures.FakeLogger(
- level=logging.DEBUG,
+ level=logging.INFO,
format='%(asctime)s %(name)-32s '
'%(levelname)-8s %(message)s'))
self.useFixture(fixtures.NestedTempfile())
diff --git a/gear/tests/test_functional.py b/gear/tests/test_functional.py
index 3bca907..f322e17 100644
--- a/gear/tests/test_functional.py
+++ b/gear/tests/test_functional.py
@@ -14,6 +14,7 @@
# limitations under the License.
import os
+import select
import threading
import time
import uuid
@@ -37,14 +38,25 @@ def iterate_timeout(max_seconds, purpose):
raise Exception("Timeout waiting for %s" % purpose)
+def _wait_for_connection(server, timeout=10):
+ time.sleep(1)
+ for _ in iterate_timeout(10, "available connections"):
+ if server.active_connections:
+ break
+
+
class TestFunctional(tests.BaseTestCase):
scenarios = [
- ('no_ssl', dict(ssl=False)),
- ('ssl', dict(ssl=True)),
+ ('no_ssl_with_epoll', dict(ssl=False, use_epoll=True)),
+ ('ssl_with_epoll', dict(ssl=True, use_epoll=True)),
+ ('no_ssl_without_epoll', dict(ssl=False, use_epoll=False)),
+ ('ssl_without_epoll', dict(ssl=True, use_epoll=False)),
]
def setUp(self):
super(TestFunctional, self).setUp()
+ if self.use_epoll and not hasattr(select, 'epoll'):
+ self.skipTest("Epoll not available.")
if self.ssl:
self.tmp_root = self.useFixture(fixtures.TempDir()).path
root_subject, root_key = self.create_cert('root')
@@ -55,7 +67,8 @@ class TestFunctional(tests.BaseTestCase):
0,
os.path.join(self.tmp_root, 'server.key'),
os.path.join(self.tmp_root, 'server.crt'),
- os.path.join(self.tmp_root, 'root.crt'))
+ os.path.join(self.tmp_root, 'root.crt'),
+ use_epoll=self.use_epoll)
self.client = gear.Client('client')
self.worker = gear.Worker('worker')
self.client.addServer('127.0.0.1', self.server.port,
@@ -67,7 +80,7 @@ class TestFunctional(tests.BaseTestCase):
os.path.join(self.tmp_root, 'worker.crt'),
os.path.join(self.tmp_root, 'root.crt'))
else:
- self.server = gear.Server(0)
+ self.server = gear.Server(0, use_epoll=self.use_epoll)
self.client = gear.Client('client')
self.worker = gear.Worker('worker')
self.client.addServer('127.0.0.1', self.server.port)
@@ -113,6 +126,7 @@ class TestFunctional(tests.BaseTestCase):
for jobcount in range(2):
job = gear.Job(b'test', b'testdata')
+ _wait_for_connection(self.server)
self.client.submitJob(job)
self.assertNotEqual(job.handle, None)
@@ -132,6 +146,7 @@ class TestFunctional(tests.BaseTestCase):
self.worker.registerFunction('test')
job = gear.Job(b'test', b'testdata')
+ _wait_for_connection(self.server)
self.client.submitJob(job, background=True)
self.assertNotEqual(job.handle, None)
self.client.shutdown()
@@ -158,6 +173,7 @@ class TestFunctional(tests.BaseTestCase):
for jobcount in range(2):
job = gear.Job('test', b'testdata')
+ _wait_for_connection(self.server)
self.client.submitJob(job)
self.assertNotEqual(job.handle, None)
@@ -166,9 +182,16 @@ class TestFunctional(tests.BaseTestCase):
class TestFunctionalText(tests.BaseTestCase):
+ scenarios = [
+ ('with_epoll', dict(use_epoll=True)),
+ ('without_epoll', dict(use_epoll=False)),
+ ]
+
def setUp(self):
super(TestFunctionalText, self).setUp()
- self.server = gear.Server(0)
+ if self.use_epoll and not hasattr(select, 'epoll'):
+ self.skipTest("Epoll not available.")
+ self.server = gear.Server(0, use_epoll=self.use_epoll)
self.client = gear.Client('client')
self.worker = gear.TextWorker('worker')
self.client.addServer('127.0.0.1', self.server.port)
@@ -181,6 +204,7 @@ class TestFunctionalText(tests.BaseTestCase):
for jobcount in range(2):
job = gear.TextJob('test', 'testdata')
+ _wait_for_connection(self.server)
self.client.submitJob(job)
self.assertNotEqual(job.handle, None)
@@ -202,6 +226,7 @@ class TestFunctionalText(tests.BaseTestCase):
for jobcount in range(2):
jobunique = uuid.uuid4().hex
job = gear.TextJob('test', 'testdata', unique=jobunique)
+ _wait_for_connection(self.server)
self.client.submitJob(job)
self.assertNotEqual(job.handle, None)
@@ -224,6 +249,7 @@ class TestFunctionalText(tests.BaseTestCase):
for jobcount in range(2):
job = gear.TextJob('test', 'testdata')
+ _wait_for_connection(self.server)
self.client.submitJob(job)
self.assertNotEqual(job.handle, None)
@@ -241,6 +267,7 @@ class TestFunctionalText(tests.BaseTestCase):
def test_grab_job_after_register(self):
jobunique = uuid.uuid4().hex
job = gear.TextJob('test', 'testdata', unique=jobunique)
+ _wait_for_connection(self.server)
self.client.submitJob(job)
self.assertNotEqual(job.handle, None)