summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJenkins <jenkins@review.openstack.org>2014-12-15 23:27:37 +0000
committerGerrit Code Review <review@openstack.org>2014-12-15 23:27:37 +0000
commitbf8bac544929f90cef5d3608d5880597f2e2fd70 (patch)
treeb47368f15599b467323de73810bcf0974e4a87c2
parentbab054cc6f9c94f12603de4fdd8995b544d77180 (diff)
parent0f24beb72b340628a564ce8d755341eca70a4c40 (diff)
downloadgear-bf8bac544929f90cef5d3608d5880597f2e2fd70.tar.gz
Merge "Use non-blocking IO in server"
-rw-r--r--gear/__init__.py302
1 files changed, 277 insertions, 25 deletions
diff --git a/gear/__init__.py b/gear/__init__.py
index fd8ad09..1a182de 100644
--- a/gear/__init__.py
+++ b/gear/__init__.py
@@ -73,6 +73,10 @@ class GearmanError(Exception):
pass
+class DisconnectError(Exception):
+ pass
+
+
def convert_to_bytes(data):
try:
data = data.encode('utf8')
@@ -130,6 +134,7 @@ class Connection(object):
if all([self.ssl_key, self.ssl_cert, self.ssl_ca]):
self.use_ssl = True
+ self.input_buffer = b''
self.echo_lock = threading.Lock()
self._init()
@@ -270,39 +275,56 @@ class Connection(object):
return buff
+ def _putAdminRequest(self, req):
+ self.admin_requests.insert(0, req)
+
def readPacket(self):
"""Read one packet or administrative response from the server.
- Blocks until the complete packet or response is read.
-
:returns: The :py:class:`Packet` or :py:class:`AdminRequest` read.
:rtype: :py:class:`Packet` or :py:class:`AdminRequest`
"""
- packet = b''
+ # This handles non-blocking or blocking IO.
datalen = 0
code = None
ptype = None
admin = None
admin_request = None
- while True:
- c = self._readRawBytes(1)
- if not c:
- return None
- if admin is None:
- if c == b'\x00':
- admin = False
+ packet = self.input_buffer
+ try:
+ while True:
+ try:
+ c = self._readRawBytes(1)
+ except socket.error as e:
+ if e.errno == errno.EAGAIN:
+ if admin_request:
+ self._putAdminRequest(admin_request)
+ raise
+ if not c:
+ packet = b''
+ return None
+ packet += c
+ if admin is None:
+ if packet[0] == b'\x00':
+ admin = False
+ else:
+ admin = True
+ admin_request = self._getAdminRequest()
+ if admin:
+ if admin_request.isComplete(packet):
+ packet = b''
+ return admin_request
else:
- admin = True
- admin_request = self._getAdminRequest()
- packet += c
- if admin:
- if admin_request.isComplete(packet):
- return admin_request
- else:
- if len(packet) == 12:
- code, ptype, datalen = struct.unpack('!4sii', packet)
- if len(packet) == datalen + 12:
- return Packet(code, ptype, packet[12:], connection=self)
+ if len(packet) == 12:
+ code, ptype, datalen = struct.unpack('!4sii',
+ packet)
+ if len(packet) == datalen + 12:
+ p = Packet(code, ptype, packet[12:],
+ connection=self)
+ packet = b''
+ return p
+ finally:
+ self.input_buffer = packet
def sendAdminRequest(self, request, timeout=90):
"""Send an administrative request to the server.
@@ -2187,7 +2209,71 @@ class ServerAdminRequest(AdminRequest):
return False
-class ServerConnection(Connection):
+class NonBlockingConnection(Connection):
+ """A Non-blocking connection to a Gearman Client."""
+
+ def __init__(self, host, port, ssl_key=None, ssl_cert=None,
+ ssl_ca=None, client_id='unknown'):
+ super(NonBlockingConnection, self).__init__(
+ host, port, ssl_key,
+ ssl_cert, ssl_ca, client_id)
+ self.send_queue = []
+
+ def connect(self):
+ super(NonBlockingConnection, self).connect()
+ if self.connected and self.conn:
+ self.conn.setblocking(0)
+ self.input_buffer = b''
+
+ def sendPacket(self, packet):
+ """Append a packet to this connection's send queue. The Client or
+ Server must manage actually sending the data.
+
+ :arg :py:class:`Packet` packet The packet to send
+
+ """
+ self.log.debug("Queuing packet to %s: %s" % (self, packet))
+ self.send_queue.append(packet.toBinary())
+ self.sendQueuedData()
+
+ def sendRaw(self, data):
+ """Append raw data to this connection's send queue. The Client or
+ Server must manage actually sending the data.
+
+ :arg bytes data The raw data to send
+
+ """
+ self.log.debug("Queuing data to %s: %s" % (self, data))
+ self.send_queue.append(data)
+ self.sendQueuedData()
+
+ def sendQueuedData(self):
+ """Send previously queued data to the socket."""
+ while len(self.send_queue):
+ data = self.send_queue.pop(0)
+ r = 0
+ try:
+ r = self.conn.send(data)
+ except ssl.SSLError as e:
+ if e.errno == ssl.SSL_ERROR_WANT_READ:
+ pass
+ elif e.errno == ssl.SSL_ERROR_WANT_WRITE:
+ pass
+ else:
+ raise
+ except socket.error as e:
+ if e.errno == errno.EAGAIN:
+ self.log.debug("Write operation on %s would block"
+ % self)
+ return
+ raise
+ finally:
+ data = data[r:]
+ if data:
+ self.send_queue.insert(0, data)
+
+
+class ServerConnection(NonBlockingConnection):
"""A Connection to a Gearman Client."""
def __init__(self, addr, conn, use_ssl, client_id):
@@ -2196,9 +2282,13 @@ class ServerConnection(Connection):
(client_id,))
else:
self.log = logging.getLogger("gear.ServerConnection")
+ self.send_queue = []
+ self.admin_requests = []
self.host = addr[0]
self.port = addr[1]
self.conn = conn
+ self.conn.setblocking(0)
+ self.input_buffer = b''
self.use_ssl = use_ssl
self.client_id = None
self.functions = set()
@@ -2214,6 +2304,12 @@ class ServerConnection(Connection):
def _getAdminRequest(self):
return ServerAdminRequest(self)
+ def _putAdminRequest(self, req):
+ # The server does not need to keep track of admin requests
+ # that have been partially received; it will simply create a
+ # new instance the next time it tries to read.
+ pass
+
def __repr__(self):
return '<gear.ServerConnection 0x%x name: %s host: %s port: %s>' % (
id(self), self.client_id, self.host, self.port)
@@ -2238,6 +2334,11 @@ class Server(BaseClientServer):
access control rules to its connections.
"""
+ edge_bitmask = select.EPOLLET
+ error_bitmask = (select.EPOLLERR | select.EPOLLHUP | edge_bitmask)
+ read_bitmask = (select.EPOLLIN | error_bitmask)
+ readwrite_bitmask = (select.EPOLLOUT | read_bitmask)
+
def __init__(self, port=4730, ssl_key=None, ssl_cert=None, ssl_ca=None,
statsd_host=None, statsd_port=8125, statsd_prefix=None,
server_id=None, acl=None):
@@ -2253,6 +2354,9 @@ class Server(BaseClientServer):
self.max_handle = 0
self.acl = acl
self.connect_wake_read, self.connect_wake_write = os.pipe()
+ self.poll = select.epoll()
+ # Reverse mapping of fd -> connection
+ self.connection_map = {}
self.use_ssl = False
if all([self.ssl_key, self.ssl_cert, self.ssl_ca]):
@@ -2285,6 +2389,11 @@ class Server(BaseClientServer):
self.port = self.socket.getsockname()[1]
super(Server, self).__init__(server_id)
+
+ # Register the wake pipe so that we can break if we need to
+ # reconfigure connections
+ self.poll.register(self.wake_read, self.read_bitmask)
+
if server_id:
self.log = logging.getLogger("gear.Server.%s" % (self.client_id,))
else:
@@ -2342,11 +2451,114 @@ class Server(BaseClientServer):
self.connections_condition.acquire()
try:
self.active_connections.append(conn)
+ self._registerConnection(conn)
self.connections_condition.notifyAll()
- os.write(self.wake_write, b'1\n')
finally:
self.connections_condition.release()
+ def readFromConnection(self, conn):
+ while True:
+ self.log.debug("Processing input on %s" % conn)
+ try:
+ p = conn.readPacket()
+ except socket.error as e:
+ if e.errno == errno.EAGAIN:
+ # Read operation would block, we're done until
+ # epoll flags this connection again
+ return
+ raise
+ if p:
+ if isinstance(p, Packet):
+ self.handlePacket(p)
+ else:
+ self.handleAdminRequest(p)
+ else:
+ self.log.debug("Received no data on %s" % conn)
+ raise DisconnectError()
+
+ def writeToConnection(self, conn):
+ self.log.debug("Processing output on %s" % conn)
+ conn.sendQueuedData()
+
+ def _processPollEvent(self, conn, event):
+ # This should do whatever is necessary to process a connection
+ # that has triggered a poll event. It should generally not
+ # raise exceptions so as to avoid restarting the poll loop.
+ # The exception handlers here can raise exceptions and if they
+ # do, it's okay, the poll loop will be restarted.
+ try:
+ if event & (select.EPOLLERR | select.EPOLLHUP):
+ self.log.debug("Received error event on %s: %s" % (
+ conn, event))
+ raise DisconnectError()
+ if event & select.POLLIN:
+ self.readFromConnection(conn)
+ if event & select.POLLOUT:
+ self.writeToConnection(conn)
+ except socket.error as e:
+ if e.errno == errno.ECONNRESET:
+ self.log.debug("Connection reset by peer: %s" % (conn,))
+ self._lostConnection(conn)
+ return
+ raise
+ except DisconnectError:
+ # Our inner method says we should quietly drop
+ # this connection
+ self._lostConnection(conn)
+ return
+ except Exception:
+ self.log.exception("Exception reading or writing "
+ "from %s:" % (conn,))
+ self._lostConnection(conn)
+ return
+
+ def _flushAllConnections(self):
+ # If we need to restart the poll loop, we need to make sure
+ # there are no pending data on any connection. Simulate poll
+ # in+out events on every connection.
+ #
+ # If this method raises an exception, the poll loop will
+ # restart again.
+ #
+ # No need to get the lock since this is called within the poll
+ # loop and therefore the list is guaranteed never to shrink.
+ connections = self.active_connections[:]
+ for conn in connections:
+ self._processPollEvent(conn, select.POLLIN | select.POLLOUT)
+
+ def _doPollLoop(self):
+ # Outer run method of poll thread.
+ while self.running:
+ try:
+ self._pollLoop()
+ except Exception:
+ self.log.exception("Exception in poll loop:")
+
+ def _pollLoop(self):
+ # Inner method of poll loop.
+ self.log.debug("Preparing to poll")
+ # Ensure there are no pending data.
+ self._flushAllConnections()
+ while self.running:
+ self.log.debug("Polling %s connections" %
+ len(self.active_connections))
+ ret = self.poll.poll()
+ # Since we're using edge-triggering, we need to make sure
+ # that every file descriptor in 'ret' is processed.
+ for fd, event in ret:
+ if fd == self.wake_read:
+ # This means we're exiting, so we can ignore the
+ # rest of 'ret'.
+ self.log.debug("Woken by pipe")
+ while True:
+ if os.read(self.wake_read, 1) == b'\n':
+ break
+ return
+ # In the unlikely event this raises an exception, the
+ # loop will be restarted.
+ conn = self.connection_map[fd]
+ self._processPollEvent(conn, event)
+
def _shutdown(self):
super(Server, self)._shutdown()
os.write(self.connect_wake_write, b'1\n')
@@ -2357,10 +2569,34 @@ class Server(BaseClientServer):
os.close(self.connect_wake_read)
os.close(self.connect_wake_write)
+ def _registerConnection(self, conn):
+ # Register the connection with the poll object
+ # Call while holding the connection condition
+ self.log.debug("Registering %s" % conn)
+ self.connection_map[conn.conn.fileno()] = conn
+ self.poll.register(conn.conn.fileno(), self.readwrite_bitmask)
+
+ def _unregisterConnection(self, conn):
+ # Unregister the connection with the poll object
+ # Call while holding the connection condition
+ self.log.debug("Unregistering %s" % conn)
+ fd = conn.conn.fileno()
+ if fd not in self.connection_map:
+ return
+ try:
+ self.poll.unregister(fd)
+ except KeyError:
+ pass
+ try:
+ del self.connection_map[fd]
+ except KeyError:
+ pass
+
def _lostConnection(self, conn):
# Called as soon as a connection is detected as faulty.
self.log.info("Marking %s as disconnected" % conn)
self.connections_condition.acquire()
+ self._unregisterConnection(conn)
try:
jobs = conn.related_jobs.values()
if conn in self.active_connections:
@@ -2378,6 +2614,20 @@ class Server(BaseClientServer):
self.log.exception("Sending WORK_FAIL to client after "
"worker disconnect failed:")
self._removeJob(job)
+ try:
+ conn.conn.shutdown(socket.SHUT_RDWR)
+ except socket.error as e:
+ if e.errno != errno.ENOTCONN:
+ self.log.exception("Unable to shutdown socket "
+ "for connection %s" % (conn,))
+ except Exception:
+ self.log.exception("Unable to shutdown socket "
+ "for connection %s" % (conn,))
+ try:
+ conn.conn.close()
+ except Exception:
+ self.log.exception("Unable to close socket "
+ "for connection %s" % (conn,))
self._updateStats()
def _removeJob(self, job, dequeue=True):
@@ -2766,7 +3016,8 @@ class Server(BaseClientServer):
handle = packet.getArgument(0)
job = self.jobs.get(handle)
if not job:
- raise UnknownJobError()
+ self.log.info("Received packet %s for unknown job" % (packet,))
+ return
job.numerator = packet.getArgument(1)
job.denominator = packet.getArgument(2)
self.handlePassthrough(packet)
@@ -2775,7 +3026,8 @@ class Server(BaseClientServer):
handle = packet.getArgument(0)
job = self.jobs.get(handle)
if not job:
- raise UnknownJobError()
+ self.log.info("Received packet %s for unknown job" % (packet,))
+ return
packet.code = constants.RES
job.client_connection.sendPacket(packet)
if finished: