diff options
author | Jenkins <jenkins@review.openstack.org> | 2014-12-15 23:27:37 +0000 |
---|---|---|
committer | Gerrit Code Review <review@openstack.org> | 2014-12-15 23:27:37 +0000 |
commit | bf8bac544929f90cef5d3608d5880597f2e2fd70 (patch) | |
tree | b47368f15599b467323de73810bcf0974e4a87c2 | |
parent | bab054cc6f9c94f12603de4fdd8995b544d77180 (diff) | |
parent | 0f24beb72b340628a564ce8d755341eca70a4c40 (diff) | |
download | gear-bf8bac544929f90cef5d3608d5880597f2e2fd70.tar.gz |
Merge "Use non-blocking IO in server"
-rw-r--r-- | gear/__init__.py | 302 |
1 files changed, 277 insertions, 25 deletions
diff --git a/gear/__init__.py b/gear/__init__.py index fd8ad09..1a182de 100644 --- a/gear/__init__.py +++ b/gear/__init__.py @@ -73,6 +73,10 @@ class GearmanError(Exception): pass +class DisconnectError(Exception): + pass + + def convert_to_bytes(data): try: data = data.encode('utf8') @@ -130,6 +134,7 @@ class Connection(object): if all([self.ssl_key, self.ssl_cert, self.ssl_ca]): self.use_ssl = True + self.input_buffer = b'' self.echo_lock = threading.Lock() self._init() @@ -270,39 +275,56 @@ class Connection(object): return buff + def _putAdminRequest(self, req): + self.admin_requests.insert(0, req) + def readPacket(self): """Read one packet or administrative response from the server. - Blocks until the complete packet or response is read. - :returns: The :py:class:`Packet` or :py:class:`AdminRequest` read. :rtype: :py:class:`Packet` or :py:class:`AdminRequest` """ - packet = b'' + # This handles non-blocking or blocking IO. datalen = 0 code = None ptype = None admin = None admin_request = None - while True: - c = self._readRawBytes(1) - if not c: - return None - if admin is None: - if c == b'\x00': - admin = False + packet = self.input_buffer + try: + while True: + try: + c = self._readRawBytes(1) + except socket.error as e: + if e.errno == errno.EAGAIN: + if admin_request: + self._putAdminRequest(admin_request) + raise + if not c: + packet = b'' + return None + packet += c + if admin is None: + if packet[0] == b'\x00': + admin = False + else: + admin = True + admin_request = self._getAdminRequest() + if admin: + if admin_request.isComplete(packet): + packet = b'' + return admin_request else: - admin = True - admin_request = self._getAdminRequest() - packet += c - if admin: - if admin_request.isComplete(packet): - return admin_request - else: - if len(packet) == 12: - code, ptype, datalen = struct.unpack('!4sii', packet) - if len(packet) == datalen + 12: - return Packet(code, ptype, packet[12:], connection=self) + if len(packet) == 12: + code, ptype, datalen = struct.unpack('!4sii', + packet) + if len(packet) == datalen + 12: + p = Packet(code, ptype, packet[12:], + connection=self) + packet = b'' + return p + finally: + self.input_buffer = packet def sendAdminRequest(self, request, timeout=90): """Send an administrative request to the server. @@ -2187,7 +2209,71 @@ class ServerAdminRequest(AdminRequest): return False -class ServerConnection(Connection): +class NonBlockingConnection(Connection): + """A Non-blocking connection to a Gearman Client.""" + + def __init__(self, host, port, ssl_key=None, ssl_cert=None, + ssl_ca=None, client_id='unknown'): + super(NonBlockingConnection, self).__init__( + host, port, ssl_key, + ssl_cert, ssl_ca, client_id) + self.send_queue = [] + + def connect(self): + super(NonBlockingConnection, self).connect() + if self.connected and self.conn: + self.conn.setblocking(0) + self.input_buffer = b'' + + def sendPacket(self, packet): + """Append a packet to this connection's send queue. The Client or + Server must manage actually sending the data. + + :arg :py:class:`Packet` packet The packet to send + + """ + self.log.debug("Queuing packet to %s: %s" % (self, packet)) + self.send_queue.append(packet.toBinary()) + self.sendQueuedData() + + def sendRaw(self, data): + """Append raw data to this connection's send queue. The Client or + Server must manage actually sending the data. + + :arg bytes data The raw data to send + + """ + self.log.debug("Queuing data to %s: %s" % (self, data)) + self.send_queue.append(data) + self.sendQueuedData() + + def sendQueuedData(self): + """Send previously queued data to the socket.""" + while len(self.send_queue): + data = self.send_queue.pop(0) + r = 0 + try: + r = self.conn.send(data) + except ssl.SSLError as e: + if e.errno == ssl.SSL_ERROR_WANT_READ: + pass + elif e.errno == ssl.SSL_ERROR_WANT_WRITE: + pass + else: + raise + except socket.error as e: + if e.errno == errno.EAGAIN: + self.log.debug("Write operation on %s would block" + % self) + return + raise + finally: + data = data[r:] + if data: + self.send_queue.insert(0, data) + + +class ServerConnection(NonBlockingConnection): """A Connection to a Gearman Client.""" def __init__(self, addr, conn, use_ssl, client_id): @@ -2196,9 +2282,13 @@ class ServerConnection(Connection): (client_id,)) else: self.log = logging.getLogger("gear.ServerConnection") + self.send_queue = [] + self.admin_requests = [] self.host = addr[0] self.port = addr[1] self.conn = conn + self.conn.setblocking(0) + self.input_buffer = b'' self.use_ssl = use_ssl self.client_id = None self.functions = set() @@ -2214,6 +2304,12 @@ class ServerConnection(Connection): def _getAdminRequest(self): return ServerAdminRequest(self) + def _putAdminRequest(self, req): + # The server does not need to keep track of admin requests + # that have been partially received; it will simply create a + # new instance the next time it tries to read. + pass + def __repr__(self): return '<gear.ServerConnection 0x%x name: %s host: %s port: %s>' % ( id(self), self.client_id, self.host, self.port) @@ -2238,6 +2334,11 @@ class Server(BaseClientServer): access control rules to its connections. """ + edge_bitmask = select.EPOLLET + error_bitmask = (select.EPOLLERR | select.EPOLLHUP | edge_bitmask) + read_bitmask = (select.EPOLLIN | error_bitmask) + readwrite_bitmask = (select.EPOLLOUT | read_bitmask) + def __init__(self, port=4730, ssl_key=None, ssl_cert=None, ssl_ca=None, statsd_host=None, statsd_port=8125, statsd_prefix=None, server_id=None, acl=None): @@ -2253,6 +2354,9 @@ class Server(BaseClientServer): self.max_handle = 0 self.acl = acl self.connect_wake_read, self.connect_wake_write = os.pipe() + self.poll = select.epoll() + # Reverse mapping of fd -> connection + self.connection_map = {} self.use_ssl = False if all([self.ssl_key, self.ssl_cert, self.ssl_ca]): @@ -2285,6 +2389,11 @@ class Server(BaseClientServer): self.port = self.socket.getsockname()[1] super(Server, self).__init__(server_id) + + # Register the wake pipe so that we can break if we need to + # reconfigure connections + self.poll.register(self.wake_read, self.read_bitmask) + if server_id: self.log = logging.getLogger("gear.Server.%s" % (self.client_id,)) else: @@ -2342,11 +2451,114 @@ class Server(BaseClientServer): self.connections_condition.acquire() try: self.active_connections.append(conn) + self._registerConnection(conn) self.connections_condition.notifyAll() - os.write(self.wake_write, b'1\n') finally: self.connections_condition.release() + def readFromConnection(self, conn): + while True: + self.log.debug("Processing input on %s" % conn) + try: + p = conn.readPacket() + except socket.error as e: + if e.errno == errno.EAGAIN: + # Read operation would block, we're done until + # epoll flags this connection again + return + raise + if p: + if isinstance(p, Packet): + self.handlePacket(p) + else: + self.handleAdminRequest(p) + else: + self.log.debug("Received no data on %s" % conn) + raise DisconnectError() + + def writeToConnection(self, conn): + self.log.debug("Processing output on %s" % conn) + conn.sendQueuedData() + + def _processPollEvent(self, conn, event): + # This should do whatever is necessary to process a connection + # that has triggered a poll event. It should generally not + # raise exceptions so as to avoid restarting the poll loop. + # The exception handlers here can raise exceptions and if they + # do, it's okay, the poll loop will be restarted. + try: + if event & (select.EPOLLERR | select.EPOLLHUP): + self.log.debug("Received error event on %s: %s" % ( + conn, event)) + raise DisconnectError() + if event & select.POLLIN: + self.readFromConnection(conn) + if event & select.POLLOUT: + self.writeToConnection(conn) + except socket.error as e: + if e.errno == errno.ECONNRESET: + self.log.debug("Connection reset by peer: %s" % (conn,)) + self._lostConnection(conn) + return + raise + except DisconnectError: + # Our inner method says we should quietly drop + # this connection + self._lostConnection(conn) + return + except Exception: + self.log.exception("Exception reading or writing " + "from %s:" % (conn,)) + self._lostConnection(conn) + return + + def _flushAllConnections(self): + # If we need to restart the poll loop, we need to make sure + # there are no pending data on any connection. Simulate poll + # in+out events on every connection. + # + # If this method raises an exception, the poll loop wil + # restart again. + # + # No need to get the lock since this is called within the poll + # loop and therefore the list in guaranteed never to shrink. + connections = self.active_connections[:] + for conn in connections: + self._processPollEvent(conn, select.POLLIN | select.POLLOUT) + + def _doPollLoop(self): + # Outer run method of poll thread. + while self.running: + try: + self._pollLoop() + except Exception: + self.log.exception("Exception in poll loop:") + + def _pollLoop(self): + # Inner method of poll loop. + self.log.debug("Preparing to poll") + # Ensure there are no pending data. + self._flushAllConnections() + while self.running: + self.log.debug("Polling %s connections" % + len(self.active_connections)) + ret = self.poll.poll() + # Since we're using edge-triggering, we need to make sure + # that every file descriptor in 'ret' is processed. + for fd, event in ret: + if fd == self.wake_read: + # This means we're exiting, so we can ignore the + # rest of 'ret'. + self.log.debug("Woken by pipe") + while True: + if os.read(self.wake_read, 1) == b'\n': + break + return + # In the unlikely event this raises an exception, the + # loop will be restarted. + conn = self.connection_map[fd] + self._processPollEvent(conn, event) + def _shutdown(self): super(Server, self)._shutdown() os.write(self.connect_wake_write, b'1\n') @@ -2357,10 +2569,34 @@ class Server(BaseClientServer): os.close(self.connect_wake_read) os.close(self.connect_wake_write) + def _registerConnection(self, conn): + # Register the connection with the poll object + # Call while holding the connection condition + self.log.debug("Registering %s" % conn) + self.connection_map[conn.conn.fileno()] = conn + self.poll.register(conn.conn.fileno(), self.readwrite_bitmask) + + def _unregisterConnection(self, conn): + # Unregister the connection with the poll object + # Call while holding the connection condition + self.log.debug("Unregistering %s" % conn) + fd = conn.conn.fileno() + if fd not in self.connection_map: + return + try: + self.poll.unregister(fd) + except KeyError: + pass + try: + del self.connection_map[fd] + except KeyError: + pass + def _lostConnection(self, conn): # Called as soon as a connection is detected as faulty. self.log.info("Marking %s as disconnected" % conn) self.connections_condition.acquire() + self._unregisterConnection(conn) try: jobs = conn.related_jobs.values() if conn in self.active_connections: @@ -2378,6 +2614,20 @@ class Server(BaseClientServer): self.log.exception("Sending WORK_FAIL to client after " "worker disconnect failed:") self._removeJob(job) + try: + conn.conn.shutdown(socket.SHUT_RDWR) + except socket.error as e: + if e.errno != errno.ENOTCONN: + self.log.exception("Unable to shutdown socket " + "for connection %s" % (conn,)) + except Exception: + self.log.exception("Unable to shutdown socket " + "for connection %s" % (conn,)) + try: + conn.conn.close() + except Exception: + self.log.exception("Unable to close socket " + "for connection %s" % (conn,)) self._updateStats() def _removeJob(self, job, dequeue=True): @@ -2766,7 +3016,8 @@ class Server(BaseClientServer): handle = packet.getArgument(0) job = self.jobs.get(handle) if not job: - raise UnknownJobError() + self.log.info("Received packet %s for unknown job" % (packet,)) + return job.numerator = packet.getArgument(1) job.denominator = packet.getArgument(2) self.handlePassthrough(packet) @@ -2775,7 +3026,8 @@ class Server(BaseClientServer): handle = packet.getArgument(0) job = self.jobs.get(handle) if not job: - raise UnknownJobError() + self.log.info("Received packet %s for unknown job" % (packet,)) + return packet.code = constants.RES job.client_connection.sendPacket(packet) if finished: |