summaryrefslogtreecommitdiff
path: root/lib/py
diff options
context:
space:
mode:
authorNobuaki Sukegawa <nsuke@apache.org>2017-02-12 21:11:36 +0900
committerNobuaki Sukegawa <nsuke@apache.org>2017-02-12 21:11:36 +0900
commit4626fd889da53462023d42d99d1d82e13a6e890f (patch)
treebce5eda5b1e48eab0f097ee90aa25c91ab5e3d23 /lib/py
parentbff044667caf8a8c2b0dd30ed11b328ff2902cf5 (diff)
downloadthrift-4626fd889da53462023d42d99d1d82e13a6e890f.tar.gz
THRIFT-3938 Python TNonblockingServer does not work with SSL
This closes #1100
Diffstat (limited to 'lib/py')
-rw-r--r--lib/py/src/server/TNonblockingServer.py131
-rw-r--r--lib/py/src/transport/TTransport.py4
2 files changed, 78 insertions, 57 deletions
diff --git a/lib/py/src/server/TNonblockingServer.py b/lib/py/src/server/TNonblockingServer.py
index 87031c137..67ee04ed5 100644
--- a/lib/py/src/server/TNonblockingServer.py
+++ b/lib/py/src/server/TNonblockingServer.py
@@ -31,6 +31,7 @@ import socket
import struct
import threading
+from collections import deque
from six.moves import queue
from thrift.transport import TTransport
@@ -58,7 +59,7 @@ class Worker(threading.Thread):
processor.process(iprot, oprot)
callback(True, otrans.getvalue())
except Exception:
- logger.exception("Exception while processing request")
+ logger.exception("Exception while processing request", exc_info=True)
callback(False, b'')
WAIT_LEN = 0
@@ -85,10 +86,23 @@ def socket_exception(func):
try:
return func(self, *args, **kwargs)
except socket.error:
+ logger.debug('ignoring socket exception', exc_info=True)
self.close()
return read
+class Message(object):
+ def __init__(self, offset, len_, header):
+ self.offset = offset
+ self.len = len_
+ self.buffer = None
+ self.is_header = header
+
+ @property
+ def end(self):
+ return self.offset + self.len
+
+
class Connection(object):
"""Basic class is represented connection.
@@ -106,68 +120,60 @@ class Connection(object):
self.socket.setblocking(False)
self.status = WAIT_LEN
self.len = 0
- self.message = b''
+ self.received = deque()
+ self._reading = Message(0, 4, True)
+ self._rbuf = b''
+ self._wbuf = b''
self.lock = threading.Lock()
self.wake_up = wake_up
-
- def _read_len(self):
- """Reads length of request.
-
- It's a safer alternative to self.socket.recv(4)
- """
- read = self.socket.recv(4 - len(self.message))
- if len(read) == 0:
- # if we read 0 bytes and self.message is empty, then
- # the client closed the connection
- if len(self.message) != 0:
- logger.error("can't read frame size from socket")
- self.close()
- return
- self.message += read
- if len(self.message) == 4:
- self.len, = struct.unpack('!i', self.message)
- if self.len < 0:
- logger.error("negative frame size, it seems client "
- "doesn't use FramedTransport")
- self.close()
- elif self.len == 0:
- logger.error("empty frame, it's really strange")
- self.close()
- else:
- self.message = b''
- self.status = WAIT_MESSAGE
+ self.remaining = False
@socket_exception
def read(self):
"""Reads data from stream and switch state."""
assert self.status in (WAIT_LEN, WAIT_MESSAGE)
- if self.status == WAIT_LEN:
- self._read_len()
- # go back to the main loop here for simplicity instead of
- # falling through, even though there is a good chance that
- # the message is already available
- elif self.status == WAIT_MESSAGE:
- read = self.socket.recv(self.len - len(self.message))
- if len(read) == 0:
- logger.error("can't read frame from socket (get %d of "
- "%d bytes)" % (len(self.message), self.len))
+ assert not self.received
+ buf_size = 8192
+ first = True
+ done = False
+ while not done:
+ read = self.socket.recv(buf_size)
+ rlen = len(read)
+ done = rlen < buf_size
+ self._rbuf += read
+ if first and rlen == 0:
+ if self.status != WAIT_LEN or self._rbuf:
+ logger.error('could not read frame from socket')
+ else:
+ logger.debug('read zero length. client might have disconnected')
self.close()
- return
- self.message += read
- if len(self.message) == self.len:
+ while len(self._rbuf) >= self._reading.end:
+ if self._reading.is_header:
+ mlen, = struct.unpack('!i', self._rbuf[:4])
+ self._reading = Message(self._reading.end, mlen, False)
+ self.status = WAIT_MESSAGE
+ else:
+ self._reading.buffer = self._rbuf
+ self.received.append(self._reading)
+ self._rbuf = self._rbuf[self._reading.end:]
+ self._reading = Message(0, 4, True)
+ first = False
+ if self.received:
self.status = WAIT_PROCESS
+ break
+ self.remaining = not done
@socket_exception
def write(self):
"""Writes data from socket and switch state."""
assert self.status == SEND_ANSWER
- sent = self.socket.send(self.message)
- if sent == len(self.message):
+ sent = self.socket.send(self._wbuf)
+ if sent == len(self._wbuf):
self.status = WAIT_LEN
- self.message = b''
+ self._wbuf = b''
self.len = 0
else:
- self.message = self.message[sent:]
+ self._wbuf = self.message[sent:]
@locked
def ready(self, all_ok, message):
@@ -190,10 +196,10 @@ class Connection(object):
self.len = 0
if len(message) == 0:
# it was a oneway request, do not write answer
- self.message = b''
+ self._wbuf = b''
self.status = WAIT_LEN
else:
- self.message = struct.pack('!i', len(message)) + message
+ self._wbuf = struct.pack('!i', len(message)) + message
self.status = SEND_ANSWER
self.wake_up()
@@ -292,14 +298,20 @@ class TNonblockingServer(object):
"""Does select on open connections."""
readable = [self.socket.handle.fileno(), self._read.fileno()]
writable = []
+ remaining = []
for i, connection in list(self.clients.items()):
if connection.is_readable():
readable.append(connection.fileno())
+ if connection.remaining or connection.received:
+ remaining.append(connection.fileno())
if connection.is_writeable():
writable.append(connection.fileno())
if connection.is_closed():
del self.clients[i]
- return select.select(readable, writable, readable)
+ if remaining:
+ return remaining, [], [], False
+ else:
+ return select.select(readable, writable, readable) + (True,)
def handle(self):
"""Handle requests.
@@ -307,20 +319,27 @@ class TNonblockingServer(object):
WARNING! You must call prepare() BEFORE calling handle()
"""
assert self.prepared, "You have to call prepare before handle"
- rset, wset, xset = self._select()
+ rset, wset, xset, selected = self._select()
for readable in rset:
if readable == self._read.fileno():
# don't care i just need to clean readable flag
self._read.recv(1024)
elif readable == self.socket.handle.fileno():
- client = self.socket.accept().handle
- self.clients[client.fileno()] = Connection(client,
- self.wake_up)
+ try:
+ client = self.socket.accept()
+ if client:
+ self.clients[client.handle.fileno()] = Connection(client.handle,
+ self.wake_up)
+ except socket.error:
+ logger.debug('error while accepting', exc_info=True)
else:
connection = self.clients[readable]
- connection.read()
- if connection.status == WAIT_PROCESS:
- itransport = TTransport.TMemoryBuffer(connection.message)
+ if selected:
+ connection.read()
+ if connection.received:
+ connection.status = WAIT_PROCESS
+ msg = connection.received.popleft()
+ itransport = TTransport.TMemoryBuffer(msg.buffer, msg.offset)
otransport = TTransport.TMemoryBuffer()
iprot = self.in_protocol.getProtocol(itransport)
oprot = self.out_protocol.getProtocol(otransport)
diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py
index a3e97253f..ae01ac675 100644
--- a/lib/py/src/transport/TTransport.py
+++ b/lib/py/src/transport/TTransport.py
@@ -206,7 +206,7 @@ class TMemoryBuffer(TTransportBase, CReadableTransport):
TODO(dreiss): Make this work like the C++ version.
"""
- def __init__(self, value=None):
+ def __init__(self, value=None, offset=0):
"""value -- a value to read from for stringio
If value is set, this will be a transport for reading,
@@ -215,6 +215,8 @@ class TMemoryBuffer(TTransportBase, CReadableTransport):
self._buffer = BufferIO(value)
else:
self._buffer = BufferIO()
+ if offset:
+ self._buffer.seek(offset)
def isOpen(self):
return not self._buffer.closed