diff options
author | Nobuaki Sukegawa <nsuke@apache.org> | 2017-02-12 21:11:36 +0900 |
---|---|---|
committer | Nobuaki Sukegawa <nsuke@apache.org> | 2017-02-12 21:11:36 +0900 |
commit | 4626fd889da53462023d42d99d1d82e13a6e890f (patch) | |
tree | bce5eda5b1e48eab0f097ee90aa25c91ab5e3d23 /lib/py | |
parent | bff044667caf8a8c2b0dd30ed11b328ff2902cf5 (diff) | |
download | thrift-4626fd889da53462023d42d99d1d82e13a6e890f.tar.gz |
THRIFT-3938 Python TNonblockingServer does not work with SSL
This closes #1100
Diffstat (limited to 'lib/py')
-rw-r--r-- | lib/py/src/server/TNonblockingServer.py | 131 | ||||
-rw-r--r-- | lib/py/src/transport/TTransport.py | 4 |
2 files changed, 78 insertions, 57 deletions
diff --git a/lib/py/src/server/TNonblockingServer.py b/lib/py/src/server/TNonblockingServer.py index 87031c137..67ee04ed5 100644 --- a/lib/py/src/server/TNonblockingServer.py +++ b/lib/py/src/server/TNonblockingServer.py @@ -31,6 +31,7 @@ import socket import struct import threading +from collections import deque from six.moves import queue from thrift.transport import TTransport @@ -58,7 +59,7 @@ class Worker(threading.Thread): processor.process(iprot, oprot) callback(True, otrans.getvalue()) except Exception: - logger.exception("Exception while processing request") + logger.exception("Exception while processing request", exc_info=True) callback(False, b'') WAIT_LEN = 0 @@ -85,10 +86,23 @@ def socket_exception(func): try: return func(self, *args, **kwargs) except socket.error: + logger.debug('ignoring socket exception', exc_info=True) self.close() return read +class Message(object): + def __init__(self, offset, len_, header): + self.offset = offset + self.len = len_ + self.buffer = None + self.is_header = header + + @property + def end(self): + return self.offset + self.len + + class Connection(object): """Basic class is represented connection. @@ -106,68 +120,60 @@ class Connection(object): self.socket.setblocking(False) self.status = WAIT_LEN self.len = 0 - self.message = b'' + self.received = deque() + self._reading = Message(0, 4, True) + self._rbuf = b'' + self._wbuf = b'' self.lock = threading.Lock() self.wake_up = wake_up - - def _read_len(self): - """Reads length of request. - - It's a safer alternative to self.socket.recv(4) - """ - read = self.socket.recv(4 - len(self.message)) - if len(read) == 0: - # if we read 0 bytes and self.message is empty, then - # the client closed the connection - if len(self.message) != 0: - logger.error("can't read frame size from socket") - self.close() - return - self.message += read - if len(self.message) == 4: - self.len, = struct.unpack('!i', self.message) - if self.len < 0: - logger.error("negative frame size, it seems client " - "doesn't use FramedTransport") - self.close() - elif self.len == 0: - logger.error("empty frame, it's really strange") - self.close() - else: - self.message = b'' - self.status = WAIT_MESSAGE + self.remaining = False @socket_exception def read(self): """Reads data from stream and switch state.""" assert self.status in (WAIT_LEN, WAIT_MESSAGE) - if self.status == WAIT_LEN: - self._read_len() - # go back to the main loop here for simplicity instead of - # falling through, even though there is a good chance that - # the message is already available - elif self.status == WAIT_MESSAGE: - read = self.socket.recv(self.len - len(self.message)) - if len(read) == 0: - logger.error("can't read frame from socket (get %d of " - "%d bytes)" % (len(self.message), self.len)) + assert not self.received + buf_size = 8192 + first = True + done = False + while not done: + read = self.socket.recv(buf_size) + rlen = len(read) + done = rlen < buf_size + self._rbuf += read + if first and rlen == 0: + if self.status != WAIT_LEN or self._rbuf: + logger.error('could not read frame from socket') + else: + logger.debug('read zero length. client might have disconnected') self.close() - return - self.message += read - if len(self.message) == self.len: + while len(self._rbuf) >= self._reading.end: + if self._reading.is_header: + mlen, = struct.unpack('!i', self._rbuf[:4]) + self._reading = Message(self._reading.end, mlen, False) + self.status = WAIT_MESSAGE + else: + self._reading.buffer = self._rbuf + self.received.append(self._reading) + self._rbuf = self._rbuf[self._reading.end:] + self._reading = Message(0, 4, True) + first = False + if self.received: self.status = WAIT_PROCESS + break + self.remaining = not done @socket_exception def write(self): """Writes data from socket and switch state.""" assert self.status == SEND_ANSWER - sent = self.socket.send(self.message) - if sent == len(self.message): + sent = self.socket.send(self._wbuf) + if sent == len(self._wbuf): self.status = WAIT_LEN - self.message = b'' + self._wbuf = b'' self.len = 0 else: - self.message = self.message[sent:] + self._wbuf = self.message[sent:] @locked def ready(self, all_ok, message): @@ -190,10 +196,10 @@ class Connection(object): self.len = 0 if len(message) == 0: # it was a oneway request, do not write answer - self.message = b'' + self._wbuf = b'' self.status = WAIT_LEN else: - self.message = struct.pack('!i', len(message)) + message + self._wbuf = struct.pack('!i', len(message)) + message self.status = SEND_ANSWER self.wake_up() @@ -292,14 +298,20 @@ class TNonblockingServer(object): """Does select on open connections.""" readable = [self.socket.handle.fileno(), self._read.fileno()] writable = [] + remaining = [] for i, connection in list(self.clients.items()): if connection.is_readable(): readable.append(connection.fileno()) + if connection.remaining or connection.received: + remaining.append(connection.fileno()) if connection.is_writeable(): writable.append(connection.fileno()) if connection.is_closed(): del self.clients[i] - return select.select(readable, writable, readable) + if remaining: + return remaining, [], [], False + else: + return select.select(readable, writable, readable) + (True,) def handle(self): """Handle requests. @@ -307,20 +319,27 @@ class TNonblockingServer(object): WARNING! You must call prepare() BEFORE calling handle() """ assert self.prepared, "You have to call prepare before handle" - rset, wset, xset = self._select() + rset, wset, xset, selected = self._select() for readable in rset: if readable == self._read.fileno(): # don't care i just need to clean readable flag self._read.recv(1024) elif readable == self.socket.handle.fileno(): - client = self.socket.accept().handle - self.clients[client.fileno()] = Connection(client, - self.wake_up) + try: + client = self.socket.accept() + if client: + self.clients[client.handle.fileno()] = Connection(client.handle, + self.wake_up) + except socket.error: + logger.debug('error while accepting', exc_info=True) else: connection = self.clients[readable] - connection.read() - if connection.status == WAIT_PROCESS: - itransport = TTransport.TMemoryBuffer(connection.message) + if selected: + connection.read() + if connection.received: + connection.status = WAIT_PROCESS + msg = connection.received.popleft() + itransport = TTransport.TMemoryBuffer(msg.buffer, msg.offset) otransport = TTransport.TMemoryBuffer() iprot = self.in_protocol.getProtocol(itransport) oprot = self.out_protocol.getProtocol(otransport) diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py index a3e97253f..ae01ac675 100644 --- a/lib/py/src/transport/TTransport.py +++ b/lib/py/src/transport/TTransport.py @@ -206,7 +206,7 @@ class TMemoryBuffer(TTransportBase, CReadableTransport): TODO(dreiss): Make this work like the C++ version. """ - def __init__(self, value=None): + def __init__(self, value=None, offset=0): """value -- a value to read from for stringio If value is set, this will be a transport for reading, @@ -215,6 +215,8 @@ class TMemoryBuffer(TTransportBase, CReadableTransport): self._buffer = BufferIO(value) else: self._buffer = BufferIO() + if offset: + self._buffer.seek(offset) def isOpen(self): return not self._buffer.closed |