-rwxr-xr-x  contrib/async-test/test-leaf.py | 15
-rw-r--r--  contrib/fb303/py/fb303/FacebookBase.py | 77
-rw-r--r--  contrib/fb303/py/fb303_scripts/fb303_simple_mgmt.py | 39
-rw-r--r--  contrib/fb303/py/setup.py | 43
-rwxr-xr-x  contrib/parse_profiling.py | 12
-rw-r--r--  contrib/zeromq/TZmqClient.py | 81
-rw-r--r--  contrib/zeromq/TZmqServer.py | 87
-rwxr-xr-x  contrib/zeromq/test-client.py | 40
-rwxr-xr-x  contrib/zeromq/test-server.py | 32
-rw-r--r--  lib/py/setup.py | 75
-rw-r--r--  lib/py/src/TMultiplexedProcessor.py | 49
-rw-r--r--  lib/py/src/TSCons.py | 17
-rw-r--r--  lib/py/src/TTornado.py | 5
-rw-r--r--  lib/py/src/Thrift.py | 288
-rw-r--r--  lib/py/src/compat.py | 51
-rw-r--r--  lib/py/src/protocol/TBase.py | 131
-rw-r--r--  lib/py/src/protocol/TBinaryProtocol.py | 492
-rw-r--r--  lib/py/src/protocol/TCompactProtocol.py | 717
-rw-r--r--  lib/py/src/protocol/TJSONProtocol.py | 939
-rw-r--r--  lib/py/src/protocol/TMultiplexedProtocol.py | 27
-rw-r--r--  lib/py/src/protocol/TProtocol.py | 638
-rw-r--r--  lib/py/src/protocol/TProtocolDecorator.py | 44
-rw-r--r--  lib/py/src/protocol/__init__.py | 3
-rw-r--r--  lib/py/src/server/THttpServer.py | 100
-rw-r--r--  lib/py/src/server/TNonblockingServer.py | 15
-rw-r--r--  lib/py/src/server/TProcessPoolServer.py | 9
-rw-r--r--  lib/py/src/server/TServer.py | 465
-rw-r--r--  lib/py/src/transport/THttpClient.py | 246
-rw-r--r--  lib/py/src/transport/TSSLSocket.py | 646
-rw-r--r--  lib/py/src/transport/TSocket.py | 290
-rw-r--r--  lib/py/src/transport/TTransport.py | 691
-rw-r--r--  lib/py/src/transport/TTwisted.py | 8
-rw-r--r--  lib/py/src/transport/TZlibTransport.py | 420
-rw-r--r--  lib/py/test/_import_local_thrift.py | 8
-rw-r--r--  lib/py/test/test_sslsocket.py | 432
-rw-r--r--  lib/py/test/thrift_json.py | 22
-rw-r--r--  test/crossrunner/collect.py | 188
-rw-r--r--  test/crossrunner/compat.py | 26
-rw-r--r--  test/crossrunner/report.py | 740
-rw-r--r--  test/crossrunner/run.py | 570
-rw-r--r--  test/crossrunner/test.py | 212
-rw-r--r--  test/crossrunner/util.py | 16
-rw-r--r--  test/features/container_limit.py | 106
-rw-r--r--  test/features/local_thrift/__init__.py | 12
-rw-r--r--  test/features/string_limit.py | 88
-rw-r--r--  test/features/theader_binary.py | 84
-rw-r--r--  test/features/util.py | 42
-rwxr-xr-x  test/py.tornado/test_suite.py | 2
-rwxr-xr-x  test/py.twisted/test_suite.py | 19
-rwxr-xr-x  test/py/FastbinaryTest.py | 194
-rwxr-xr-x  test/py/RunClientServer.py | 482
-rwxr-xr-x  test/py/SerializationTest.py | 617
-rw-r--r--  test/py/TSimpleJSONProtocolTest.py | 150
-rwxr-xr-x  test/py/TestClient.py | 522
-rwxr-xr-x  test/py/TestEof.py | 174
-rwxr-xr-x  test/py/TestFrozen.py | 150
-rwxr-xr-x  test/py/TestServer.py | 556
-rwxr-xr-x  test/py/TestSocket.py | 10
-rwxr-xr-x  test/test.py | 220
-rwxr-xr-x  tutorial/php/runserver.py | 3
60 files changed, 6244 insertions, 6193 deletions
diff --git a/contrib/async-test/test-leaf.py b/contrib/async-test/test-leaf.py
index 8b7c3e3f5..4ea4a9b8c 100755
--- a/contrib/async-test/test-leaf.py
+++ b/contrib/async-test/test-leaf.py
@@ -7,16 +7,17 @@ from thrift.protocol import TBinaryProtocol
from thrift.server import THttpServer
from aggr import Aggr
+
class AggrHandler(Aggr.Iface):
- def __init__(self):
- self.values = []
+ def __init__(self):
+ self.values = []
- def addValue(self, value):
- self.values.append(value)
+ def addValue(self, value):
+ self.values.append(value)
- def getValues(self, ):
- time.sleep(1)
- return self.values
+ def getValues(self, ):
+ time.sleep(1)
+ return self.values
processor = Aggr.Processor(AggrHandler())
pfactory = TBinaryProtocol.TBinaryProtocolFactory()
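
The rest of this script (below the hunk shown) presumably wires the processor and protocol factory into THttpServer, which is imported above. A minimal sketch of that wiring, illustrative only, with the port number as an assumption rather than taken from this diff:

    # Illustrative: serve the Aggr processor over HTTP on an assumed port.
    server = THttpServer.THttpServer(processor, ('', 8080), pfactory)
    server.serve()
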
diff --git a/contrib/fb303/py/fb303/FacebookBase.py b/contrib/fb303/py/fb303/FacebookBase.py
index 685ff20f3..07db10cd3 100644
--- a/contrib/fb303/py/fb303/FacebookBase.py
+++ b/contrib/fb303/py/fb303/FacebookBase.py
@@ -24,59 +24,60 @@ import FacebookService
import thrift.reflection.limited
from ttypes import fb_status
+
class FacebookBase(FacebookService.Iface):
- def __init__(self, name):
- self.name = name
- self.alive = int(time.time())
- self.counters = {}
+ def __init__(self, name):
+ self.name = name
+ self.alive = int(time.time())
+ self.counters = {}
- def getName(self, ):
- return self.name
+ def getName(self, ):
+ return self.name
- def getVersion(self, ):
- return ''
+ def getVersion(self, ):
+ return ''
- def getStatus(self, ):
- return fb_status.ALIVE
+ def getStatus(self, ):
+ return fb_status.ALIVE
- def getCounters(self):
- return self.counters
+ def getCounters(self):
+ return self.counters
- def resetCounter(self, key):
- self.counters[key] = 0
+ def resetCounter(self, key):
+ self.counters[key] = 0
- def getCounter(self, key):
- if self.counters.has_key(key):
- return self.counters[key]
- return 0
+ def getCounter(self, key):
+ if self.counters.has_key(key):
+ return self.counters[key]
+ return 0
- def incrementCounter(self, key):
- self.counters[key] = self.getCounter(key) + 1
+ def incrementCounter(self, key):
+ self.counters[key] = self.getCounter(key) + 1
- def setOption(self, key, value):
- pass
+ def setOption(self, key, value):
+ pass
- def getOption(self, key):
- return ""
+ def getOption(self, key):
+ return ""
- def getOptions(self):
- return {}
+ def getOptions(self):
+ return {}
- def getOptions(self):
- return {}
+ def getOptions(self):
+ return {}
- def aliveSince(self):
- return self.alive
+ def aliveSince(self):
+ return self.alive
- def getCpuProfile(self, duration):
- return ""
+ def getCpuProfile(self, duration):
+ return ""
- def getLimitedReflection(self):
- return thrift.reflection.limited.Service()
+ def getLimitedReflection(self):
+ return thrift.reflection.limited.Service()
- def reinitialize(self):
- pass
+ def reinitialize(self):
+ pass
- def shutdown(self):
- pass
+ def shutdown(self):
+ pass
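
Note that the getCounter() body kept above still relies on dict.has_key(), which exists only on Python 2. A hedged sketch of the portable form (illustrative, not part of this patch):

    def getCounter(self, key):
        # 'in' works on both Python 2 and 3; has_key() was removed in Python 3.
        if key in self.counters:
            return self.counters[key]
        return 0
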
diff --git a/contrib/fb303/py/fb303_scripts/fb303_simple_mgmt.py b/contrib/fb303/py/fb303_scripts/fb303_simple_mgmt.py
index 4f8ce9933..4b1c25728 100644
--- a/contrib/fb303/py/fb303_scripts/fb303_simple_mgmt.py
+++ b/contrib/fb303/py/fb303_scripts/fb303_simple_mgmt.py
@@ -19,7 +19,8 @@
# under the License.
#
-import sys, os
+import sys
+import os
from optparse import OptionParser
from thrift.Thrift import *
@@ -31,11 +32,12 @@ from thrift.protocol import TBinaryProtocol
from fb303 import *
from fb303.ttypes import *
+
def service_ctrl(
- command,
- port,
- trans_factory = None,
- prot_factory = None):
+ command,
+ port,
+ trans_factory=None,
+ prot_factory=None):
"""
service_ctrl is a generic function to execute standard fb303 functions
@@ -66,19 +68,19 @@ def service_ctrl(
return 3
# scalar commands
- if command in ["version","alive","name"]:
+ if command in ["version", "alive", "name"]:
try:
- result = fb303_wrapper(command, port, trans_factory, prot_factory)
+ result = fb303_wrapper(command, port, trans_factory, prot_factory)
print result
return 0
except:
- print "failed to get ",command
+ print "failed to get ", command
return 3
# counters
if command in ["counters"]:
try:
- counters = fb303_wrapper('counters', port, trans_factory, prot_factory)
+ counters = fb303_wrapper('counters', port, trans_factory, prot_factory)
for counter in counters:
print "%s: %d" % (counter, counters[counter])
return 0
@@ -86,11 +88,10 @@ def service_ctrl(
print "failed to get counters"
return 3
-
# Only root should be able to run the following commands
if os.getuid() == 0:
# async commands
- if command in ["stop","reload"] :
+ if command in ["stop", "reload"]:
try:
fb303_wrapper(command, port, trans_factory, prot_factory)
return 0
@@ -98,23 +99,21 @@ def service_ctrl(
print "failed to tell the service to ", command
return 3
else:
- if command in ["stop","reload"]:
+ if command in ["stop", "reload"]:
print "root privileges are required to stop or reload the service."
return 4
print "The following commands are available:"
- for command in ["counters","name","version","alive","status"]:
+ for command in ["counters", "name", "version", "alive", "status"]:
print "\t%s" % command
print "The following commands are available for users with root privileges:"
- for command in ["stop","reload"]:
+ for command in ["stop", "reload"]:
print "\t%s" % command
+ return 0
- return 0;
-
-
-def fb303_wrapper(command, port, trans_factory = None, prot_factory = None):
+def fb303_wrapper(command, port, trans_factory=None, prot_factory=None):
sock = TSocket.TSocket('localhost', port)
# use input transport factory if provided
@@ -179,11 +178,11 @@ def main():
# parse command line options
parser = OptionParser()
- commands=["stop","counters","status","reload","version","name","alive"]
+ commands = ["stop", "counters", "status", "reload", "version", "name", "alive"]
parser.add_option("-c", "--command", dest="command", help="execute this API",
choices=commands, default="status")
- parser.add_option("-p","--port",dest="port",help="the service's port",
+ parser.add_option("-p", "--port", dest="port", help="the service's port",
default=9082)
(options, args) = parser.parse_args()
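
service_ctrl() above can also be called directly from Python, bypassing the option parser. A minimal sketch, assuming an fb303 service is listening on the script's default port:

    from fb303_scripts.fb303_simple_mgmt import service_ctrl

    # Prints each counter as "name: value"; returns 0 on success, 3 on failure.
    exit_code = service_ctrl('counters', 9082)
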
diff --git a/contrib/fb303/py/setup.py b/contrib/fb303/py/setup.py
index 6710c8f61..4321ce258 100644
--- a/contrib/fb303/py/setup.py
+++ b/contrib/fb303/py/setup.py
@@ -24,26 +24,25 @@ try:
from setuptools import setup, Extension
except:
from distutils.core import setup, Extension, Command
-
-setup(name = 'thrift_fb303',
- version = '1.0.0-dev',
- description = 'Python bindings for the Apache Thrift FB303',
- author = ['Thrift Developers'],
- author_email = ['dev@thrift.apache.org'],
- url = 'http://thrift.apache.org',
- license = 'Apache License 2.0',
- packages = [
- 'fb303',
- 'fb303_scripts',
- ],
- classifiers = [
- 'Development Status :: 5 - Production/Stable',
- 'Environment :: Console',
- 'Intended Audience :: Developers',
- 'Programming Language :: Python',
- 'Programming Language :: Python :: 2',
- 'Topic :: Software Development :: Libraries',
- 'Topic :: System :: Networking'
- ],
-)
+setup(name='thrift_fb303',
+ version='1.0.0-dev',
+ description='Python bindings for the Apache Thrift FB303',
+ author=['Thrift Developers'],
+ author_email=['dev@thrift.apache.org'],
+ url='http://thrift.apache.org',
+ license='Apache License 2.0',
+ packages=[
+ 'fb303',
+ 'fb303_scripts',
+ ],
+ classifiers=[
+ 'Development Status :: 5 - Production/Stable',
+ 'Environment :: Console',
+ 'Intended Audience :: Developers',
+ 'Programming Language :: Python',
+ 'Programming Language :: Python :: 2',
+ 'Topic :: Software Development :: Libraries',
+ 'Topic :: System :: Networking'
+ ],
+ )
diff --git a/contrib/parse_profiling.py b/contrib/parse_profiling.py
index 3d46fb832..0be5f29ed 100755
--- a/contrib/parse_profiling.py
+++ b/contrib/parse_profiling.py
@@ -46,6 +46,8 @@ class AddressInfo(object):
g_addrs_by_filename = {}
+
+
def get_address(filename, address):
"""
Retrieve an AddressInfo object for the specified object file and address.
@@ -103,12 +105,12 @@ def translate_file_addresses(filename, addresses, options):
idx = file_and_line.rfind(':')
if idx < 0:
msg = 'expected file and line number from addr2line; got %r' % \
- (file_and_line,)
+ (file_and_line,)
msg += '\nfile=%r, address=%r' % (filename, address.address)
raise Exception(msg)
address.sourceFile = file_and_line[:idx]
- address.sourceLine = file_and_line[idx+1:]
+ address.sourceLine = file_and_line[idx + 1:]
(remaining_out, cmd_err) = proc.communicate()
retcode = proc.wait()
@@ -180,7 +182,7 @@ def process_file(in_file, out_file, options):
virt_call_regex = re.compile(r'^\s*T_VIRTUAL_CALL: (\d+) calls on (.*):$')
gen_prot_regex = re.compile(
- r'^\s*T_GENERIC_PROTOCOL: (\d+) calls to (.*) with a (.*):$')
+ r'^\s*T_GENERIC_PROTOCOL: (\d+) calls to (.*) with a (.*):$')
bt_regex = re.compile(r'^\s*#(\d+)\s*(.*) \[(0x[0-9A-Za-z]+)\]$')
# Parse all of the input, and store it as Entry objects
@@ -209,7 +211,7 @@ def process_file(in_file, out_file, options):
# "_Z" to the type name to make it look like an external name.
type_name = '_Z' + type_name
header = 'T_VIRTUAL_CALL: %d calls on "%s"' % \
- (num_calls, type_name)
+ (num_calls, type_name)
if current_entry is not None:
entries.append(current_entry)
current_entry = Entry(header)
@@ -224,7 +226,7 @@ def process_file(in_file, out_file, options):
type_name1 = '_Z' + type_name1
type_name2 = '_Z' + type_name2
header = 'T_GENERIC_PROTOCOL: %d calls to "%s" with a "%s"' % \
- (num_calls, type_name1, type_name2)
+ (num_calls, type_name1, type_name2)
if current_entry is not None:
entries.append(current_entry)
current_entry = Entry(header)
diff --git a/contrib/zeromq/TZmqClient.py b/contrib/zeromq/TZmqClient.py
index d56069733..1bd60a1e5 100644
--- a/contrib/zeromq/TZmqClient.py
+++ b/contrib/zeromq/TZmqClient.py
@@ -20,44 +20,45 @@ import zmq
from cStringIO import StringIO
from thrift.transport.TTransport import TTransportBase, CReadableTransport
+
class TZmqClient(TTransportBase, CReadableTransport):
- def __init__(self, ctx, endpoint, sock_type):
- self._sock = ctx.socket(sock_type)
- self._endpoint = endpoint
- self._wbuf = StringIO()
- self._rbuf = StringIO()
-
- def open(self):
- self._sock.connect(self._endpoint)
-
- def read(self, size):
- ret = self._rbuf.read(size)
- if len(ret) != 0:
- return ret
- self._read_message()
- return self._rbuf.read(size)
-
- def _read_message(self):
- msg = self._sock.recv()
- self._rbuf = StringIO(msg)
-
- def write(self, buf):
- self._wbuf.write(buf)
-
- def flush(self):
- msg = self._wbuf.getvalue()
- self._wbuf = StringIO()
- self._sock.send(msg)
-
- # Implement the CReadableTransport interface.
- @property
- def cstringio_buf(self):
- return self._rbuf
-
- # NOTE: This will probably not actually work.
- def cstringio_refill(self, prefix, reqlen):
- while len(prefix) < reqlen:
- self.read_message()
- prefix += self._rbuf.getvalue()
- self._rbuf = StringIO(prefix)
- return self._rbuf
+ def __init__(self, ctx, endpoint, sock_type):
+ self._sock = ctx.socket(sock_type)
+ self._endpoint = endpoint
+ self._wbuf = StringIO()
+ self._rbuf = StringIO()
+
+ def open(self):
+ self._sock.connect(self._endpoint)
+
+ def read(self, size):
+ ret = self._rbuf.read(size)
+ if len(ret) != 0:
+ return ret
+ self._read_message()
+ return self._rbuf.read(size)
+
+ def _read_message(self):
+ msg = self._sock.recv()
+ self._rbuf = StringIO(msg)
+
+ def write(self, buf):
+ self._wbuf.write(buf)
+
+ def flush(self):
+ msg = self._wbuf.getvalue()
+ self._wbuf = StringIO()
+ self._sock.send(msg)
+
+ # Implement the CReadableTransport interface.
+ @property
+ def cstringio_buf(self):
+ return self._rbuf
+
+ # NOTE: This will probably not actually work.
+ def cstringio_refill(self, prefix, reqlen):
+ while len(prefix) < reqlen:
+ self.read_message()
+ prefix += self._rbuf.getvalue()
+ self._rbuf = StringIO(prefix)
+ return self._rbuf
diff --git a/contrib/zeromq/TZmqServer.py b/contrib/zeromq/TZmqServer.py
index c83cc8d5d..15c1543ac 100644
--- a/contrib/zeromq/TZmqServer.py
+++ b/contrib/zeromq/TZmqServer.py
@@ -21,58 +21,59 @@ import zmq
import thrift.server.TServer
import thrift.transport.TTransport
+
class TZmqServer(thrift.server.TServer.TServer):
- def __init__(self, processor, ctx, endpoint, sock_type):
- thrift.server.TServer.TServer.__init__(self, processor, None)
- self.zmq_type = sock_type
- self.socket = ctx.socket(sock_type)
- self.socket.bind(endpoint)
+ def __init__(self, processor, ctx, endpoint, sock_type):
+ thrift.server.TServer.TServer.__init__(self, processor, None)
+ self.zmq_type = sock_type
+ self.socket = ctx.socket(sock_type)
+ self.socket.bind(endpoint)
- def serveOne(self):
- msg = self.socket.recv()
- itrans = thrift.transport.TTransport.TMemoryBuffer(msg)
- otrans = thrift.transport.TTransport.TMemoryBuffer()
- iprot = self.inputProtocolFactory.getProtocol(itrans)
- oprot = self.outputProtocolFactory.getProtocol(otrans)
+ def serveOne(self):
+ msg = self.socket.recv()
+ itrans = thrift.transport.TTransport.TMemoryBuffer(msg)
+ otrans = thrift.transport.TTransport.TMemoryBuffer()
+ iprot = self.inputProtocolFactory.getProtocol(itrans)
+ oprot = self.outputProtocolFactory.getProtocol(otrans)
- try:
- self.processor.process(iprot, oprot)
- except Exception:
- logging.exception("Exception while processing request")
- # Fall through and send back a response, even if empty or incomplete.
+ try:
+ self.processor.process(iprot, oprot)
+ except Exception:
+ logging.exception("Exception while processing request")
+ # Fall through and send back a response, even if empty or incomplete.
- if self.zmq_type == zmq.REP:
- msg = otrans.getvalue()
- self.socket.send(msg)
+ if self.zmq_type == zmq.REP:
+ msg = otrans.getvalue()
+ self.socket.send(msg)
- def serve(self):
- while True:
- self.serveOne()
+ def serve(self):
+ while True:
+ self.serveOne()
class TZmqMultiServer(object):
- def __init__(self):
- self.servers = []
+ def __init__(self):
+ self.servers = []
- def serveOne(self, timeout = -1):
- self._serveActive(self._setupPoll(), timeout)
+ def serveOne(self, timeout=-1):
+ self._serveActive(self._setupPoll(), timeout)
- def serveForever(self):
- poll_info = self._setupPoll()
- while True:
- self._serveActive(poll_info, -1)
+ def serveForever(self):
+ poll_info = self._setupPoll()
+ while True:
+ self._serveActive(poll_info, -1)
- def _setupPoll(self):
- server_map = {}
- poller = zmq.Poller()
- for server in self.servers:
- server_map[server.socket] = server
- poller.register(server.socket, zmq.POLLIN)
- return (server_map, poller)
+ def _setupPoll(self):
+ server_map = {}
+ poller = zmq.Poller()
+ for server in self.servers:
+ server_map[server.socket] = server
+ poller.register(server.socket, zmq.POLLIN)
+ return (server_map, poller)
- def _serveActive(self, poll_info, timeout):
- (server_map, poller) = poll_info
- ready = dict(poller.poll())
- for sock, state in ready.items():
- assert (state & zmq.POLLIN) != 0
- server_map[sock].serveOne()
+ def _serveActive(self, poll_info, timeout):
+ (server_map, poller) = poll_info
+ ready = dict(poller.poll())
+ for sock, state in ready.items():
+ assert (state & zmq.POLLIN) != 0
+ server_map[sock].serveOne()
diff --git a/contrib/zeromq/test-client.py b/contrib/zeromq/test-client.py
index 1886d9cab..753b132d8 100755
--- a/contrib/zeromq/test-client.py
+++ b/contrib/zeromq/test-client.py
@@ -9,28 +9,28 @@ import storage.Storage
def main(args):
- endpoint = "tcp://127.0.0.1:9090"
- socktype = zmq.REQ
- incr = 0
- if len(args) > 1:
- incr = int(args[1])
- if incr:
- socktype = zmq.DOWNSTREAM
- endpoint = "tcp://127.0.0.1:9091"
+ endpoint = "tcp://127.0.0.1:9090"
+ socktype = zmq.REQ
+ incr = 0
+ if len(args) > 1:
+ incr = int(args[1])
+ if incr:
+ socktype = zmq.DOWNSTREAM
+ endpoint = "tcp://127.0.0.1:9091"
- ctx = zmq.Context()
- transport = TZmqClient.TZmqClient(ctx, endpoint, socktype)
- protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocolAccelerated(transport)
- client = storage.Storage.Client(protocol)
- transport.open()
+ ctx = zmq.Context()
+ transport = TZmqClient.TZmqClient(ctx, endpoint, socktype)
+ protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocolAccelerated(transport)
+ client = storage.Storage.Client(protocol)
+ transport.open()
- if incr:
- client.incr(incr)
- time.sleep(0.05)
- else:
- value = client.get()
- print value
+ if incr:
+ client.incr(incr)
+ time.sleep(0.05)
+ else:
+ value = client.get()
+ print value
if __name__ == "__main__":
- main(sys.argv)
+ main(sys.argv)
diff --git a/contrib/zeromq/test-server.py b/contrib/zeromq/test-server.py
index 5767b71fe..c7804d317 100755
--- a/contrib/zeromq/test-server.py
+++ b/contrib/zeromq/test-server.py
@@ -6,28 +6,28 @@ import storage.Storage
class StorageHandler(storage.Storage.Iface):
- def __init__(self):
- self.value = 0
+ def __init__(self):
+ self.value = 0
- def incr(self, amount):
- self.value += amount
+ def incr(self, amount):
+ self.value += amount
- def get(self):
- return self.value
+ def get(self):
+ return self.value
def main():
- handler = StorageHandler()
- processor = storage.Storage.Processor(handler)
+ handler = StorageHandler()
+ processor = storage.Storage.Processor(handler)
- ctx = zmq.Context()
- reqrep_server = TZmqServer.TZmqServer(processor, ctx, "tcp://0.0.0.0:9090", zmq.REP)
- oneway_server = TZmqServer.TZmqServer(processor, ctx, "tcp://0.0.0.0:9091", zmq.UPSTREAM)
- multiserver = TZmqServer.TZmqMultiServer()
- multiserver.servers.append(reqrep_server)
- multiserver.servers.append(oneway_server)
- multiserver.serveForever()
+ ctx = zmq.Context()
+ reqrep_server = TZmqServer.TZmqServer(processor, ctx, "tcp://0.0.0.0:9090", zmq.REP)
+ oneway_server = TZmqServer.TZmqServer(processor, ctx, "tcp://0.0.0.0:9091", zmq.UPSTREAM)
+ multiserver = TZmqServer.TZmqMultiServer()
+ multiserver.servers.append(reqrep_server)
+ multiserver.servers.append(oneway_server)
+ multiserver.serveForever()
if __name__ == "__main__":
- main()
+ main()
diff --git a/lib/py/setup.py b/lib/py/setup.py
index 090544ce9..f57c1a131 100644
--- a/lib/py/setup.py
+++ b/lib/py/setup.py
@@ -9,7 +9,7 @@
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
@@ -24,7 +24,7 @@ import sys
try:
from setuptools import setup, Extension
except:
- from distutils.core import setup, Extension, Command
+ from distutils.core import setup, Extension
from distutils.command.build_ext import build_ext
from distutils.errors import CCompilerError, DistutilsExecError, DistutilsPlatformError
@@ -41,63 +41,66 @@ if sys.platform == 'win32':
else:
ext_errors = (CCompilerError, DistutilsExecError, DistutilsPlatformError)
+
class BuildFailed(Exception):
pass
+
class ve_build_ext(build_ext):
def run(self):
try:
build_ext.run(self)
- except DistutilsPlatformError as x:
+ except DistutilsPlatformError:
raise BuildFailed()
def build_extension(self, ext):
try:
build_ext.build_extension(self, ext)
- except ext_errors as x:
+ except ext_errors:
raise BuildFailed()
+
def run_setup(with_binary):
if with_binary:
extensions = dict(
- ext_modules = [
- Extension('thrift.protocol.fastbinary',
- sources = ['src/protocol/fastbinary.c'],
- include_dirs = include_dirs,
- )
+ ext_modules=[
+ Extension('thrift.protocol.fastbinary',
+ sources=['src/protocol/fastbinary.c'],
+ include_dirs=include_dirs,
+ )
],
cmdclass=dict(build_ext=ve_build_ext)
)
else:
extensions = dict()
- setup(name = 'thrift',
- version = '1.0.0-dev',
- description = 'Python bindings for the Apache Thrift RPC system',
- author = 'Thrift Developers',
- author_email = 'dev@thrift.apache.org',
- url = 'http://thrift.apache.org',
- license = 'Apache License 2.0',
- install_requires=['six>=1.7.2'],
- packages = [
- 'thrift',
- 'thrift.protocol',
- 'thrift.transport',
- 'thrift.server',
- ],
- package_dir = {'thrift' : 'src'},
- classifiers = [
- 'Development Status :: 5 - Production/Stable',
- 'Environment :: Console',
- 'Intended Audience :: Developers',
- 'Programming Language :: Python',
- 'Programming Language :: Python :: 2',
- 'Programming Language :: Python :: 3',
- 'Topic :: Software Development :: Libraries',
- 'Topic :: System :: Networking'
- ],
- **extensions
- )
+ setup(name='thrift',
+ version='1.0.0-dev',
+ description='Python bindings for the Apache Thrift RPC system',
+ author='Thrift Developers',
+ author_email='dev@thrift.apache.org',
+ url='http://thrift.apache.org',
+ license='Apache License 2.0',
+ install_requires=['six>=1.7.2'],
+ packages=[
+ 'thrift',
+ 'thrift.protocol',
+ 'thrift.transport',
+ 'thrift.server',
+ ],
+ package_dir={'thrift': 'src'},
+ classifiers=[
+ 'Development Status :: 5 - Production/Stable',
+ 'Environment :: Console',
+ 'Intended Audience :: Developers',
+ 'Programming Language :: Python',
+ 'Programming Language :: Python :: 2',
+ 'Programming Language :: Python :: 3',
+ 'Topic :: Software Development :: Libraries',
+ 'Topic :: System :: Networking'
+ ],
+ **extensions
+ )
try:
with_binary = False
diff --git a/lib/py/src/TMultiplexedProcessor.py b/lib/py/src/TMultiplexedProcessor.py
index a8d5565c3..581214b31 100644
--- a/lib/py/src/TMultiplexedProcessor.py
+++ b/lib/py/src/TMultiplexedProcessor.py
@@ -20,39 +20,36 @@
from thrift.Thrift import TProcessor, TMessageType, TException
from thrift.protocol import TProtocolDecorator, TMultiplexedProtocol
+
class TMultiplexedProcessor(TProcessor):
- def __init__(self):
- self.services = {}
+ def __init__(self):
+ self.services = {}
- def registerProcessor(self, serviceName, processor):
- self.services[serviceName] = processor
+ def registerProcessor(self, serviceName, processor):
+ self.services[serviceName] = processor
- def process(self, iprot, oprot):
- (name, type, seqid) = iprot.readMessageBegin();
- if type != TMessageType.CALL & type != TMessageType.ONEWAY:
- raise TException("TMultiplex protocol only supports CALL & ONEWAY")
+ def process(self, iprot, oprot):
+ (name, type, seqid) = iprot.readMessageBegin()
+ if type != TMessageType.CALL & type != TMessageType.ONEWAY:
+ raise TException("TMultiplex protocol only supports CALL & ONEWAY")
- index = name.find(TMultiplexedProtocol.SEPARATOR)
- if index < 0:
- raise TException("Service name not found in message name: " + name + ". Did you forget to use TMultiplexProtocol in your client?")
+ index = name.find(TMultiplexedProtocol.SEPARATOR)
+ if index < 0:
+ raise TException("Service name not found in message name: " + name + ". Did you forget to use TMultiplexProtocol in your client?")
- serviceName = name[0:index]
- call = name[index+len(TMultiplexedProtocol.SEPARATOR):]
- if not serviceName in self.services:
- raise TException("Service name not found: " + serviceName + ". Did you forget to call registerProcessor()?")
+ serviceName = name[0:index]
+ call = name[index + len(TMultiplexedProtocol.SEPARATOR):]
+ if serviceName not in self.services:
+ raise TException("Service name not found: " + serviceName + ". Did you forget to call registerProcessor()?")
- standardMessage = (
- call,
- type,
- seqid
- )
- return self.services[serviceName].process(StoredMessageProtocol(iprot, standardMessage), oprot)
+ standardMessage = (call, type, seqid)
+ return self.services[serviceName].process(StoredMessageProtocol(iprot, standardMessage), oprot)
class StoredMessageProtocol(TProtocolDecorator.TProtocolDecorator):
- def __init__(self, protocol, messageBegin):
- TProtocolDecorator.TProtocolDecorator.__init__(self, protocol)
- self.messageBegin = messageBegin
+ def __init__(self, protocol, messageBegin):
+ TProtocolDecorator.TProtocolDecorator.__init__(self, protocol)
+ self.messageBegin = messageBegin
- def readMessageBegin(self):
- return self.messageBegin
+ def readMessageBegin(self):
+ return self.messageBegin
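
For orientation, a sketch of how TMultiplexedProcessor is typically paired with TMultiplexedProtocol; the Calculator and Weather services and their generated modules are hypothetical placeholders, not part of this patch:

    from thrift.TMultiplexedProcessor import TMultiplexedProcessor
    from thrift.protocol import TMultiplexedProtocol

    # Server side: one processor dispatches to several registered services.
    processor = TMultiplexedProcessor()
    processor.registerProcessor("Calculator", calc_processor)   # hypothetical generated processor
    processor.registerProcessor("Weather", weather_processor)   # hypothetical generated processor

    # Client side: wrap a concrete protocol and name the target service, so the
    # server's readMessageBegin() sees prefixed names like "Calculator:add".
    mux = TMultiplexedProtocol.TMultiplexedProtocol(binary_protocol, "Calculator")
    client = Calculator.Client(mux)                              # hypothetical generated client
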
diff --git a/lib/py/src/TSCons.py b/lib/py/src/TSCons.py
index ed2601a7d..bc67d7069 100644
--- a/lib/py/src/TSCons.py
+++ b/lib/py/src/TSCons.py
@@ -20,18 +20,17 @@
from os import path
from SCons.Builder import Builder
from six.moves import map
-from six.moves import zip
def scons_env(env, add=''):
- opath = path.dirname(path.abspath('$TARGET'))
- lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE'
- cppbuild = Builder(action=lstr)
- env.Append(BUILDERS={'ThriftCpp': cppbuild})
+ opath = path.dirname(path.abspath('$TARGET'))
+ lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE'
+ cppbuild = Builder(action=lstr)
+ env.Append(BUILDERS={'ThriftCpp': cppbuild})
def gen_cpp(env, dir, file):
- scons_env(env)
- suffixes = ['_types.h', '_types.cpp']
- targets = map(lambda s: 'gen-cpp/' + file + s, suffixes)
- return env.ThriftCpp(targets, dir + file + '.thrift')
+ scons_env(env)
+ suffixes = ['_types.h', '_types.cpp']
+ targets = map(lambda s: 'gen-cpp/' + file + s, suffixes)
+ return env.ThriftCpp(targets, dir + file + '.thrift')
diff --git a/lib/py/src/TTornado.py b/lib/py/src/TTornado.py
index e3b4df7b2..e01a49f25 100644
--- a/lib/py/src/TTornado.py
+++ b/lib/py/src/TTornado.py
@@ -18,10 +18,9 @@
#
from __future__ import absolute_import
+import logging
import socket
import struct
-import logging
-logger = logging.getLogger(__name__)
from .transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer
@@ -32,6 +31,8 @@ from tornado import gen, iostream, ioloop, tcpserver, concurrent
__all__ = ['TTornadoServer', 'TTornadoStreamTransport']
+logger = logging.getLogger(__name__)
+
class _Lock(object):
def __init__(self):
diff --git a/lib/py/src/Thrift.py b/lib/py/src/Thrift.py
index 11ee79625..c4dabdca0 100644
--- a/lib/py/src/Thrift.py
+++ b/lib/py/src/Thrift.py
@@ -21,170 +21,172 @@ import sys
class TType(object):
- STOP = 0
- VOID = 1
- BOOL = 2
- BYTE = 3
- I08 = 3
- DOUBLE = 4
- I16 = 6
- I32 = 8
- I64 = 10
- STRING = 11
- UTF7 = 11
- STRUCT = 12
- MAP = 13
- SET = 14
- LIST = 15
- UTF8 = 16
- UTF16 = 17
-
- _VALUES_TO_NAMES = ('STOP',
- 'VOID',
- 'BOOL',
- 'BYTE',
- 'DOUBLE',
- None,
- 'I16',
- None,
- 'I32',
- None,
- 'I64',
- 'STRING',
- 'STRUCT',
- 'MAP',
- 'SET',
- 'LIST',
- 'UTF8',
- 'UTF16')
+ STOP = 0
+ VOID = 1
+ BOOL = 2
+ BYTE = 3
+ I08 = 3
+ DOUBLE = 4
+ I16 = 6
+ I32 = 8
+ I64 = 10
+ STRING = 11
+ UTF7 = 11
+ STRUCT = 12
+ MAP = 13
+ SET = 14
+ LIST = 15
+ UTF8 = 16
+ UTF16 = 17
+
+ _VALUES_TO_NAMES = (
+ 'STOP',
+ 'VOID',
+ 'BOOL',
+ 'BYTE',
+ 'DOUBLE',
+ None,
+ 'I16',
+ None,
+ 'I32',
+ None,
+ 'I64',
+ 'STRING',
+ 'STRUCT',
+ 'MAP',
+ 'SET',
+ 'LIST',
+ 'UTF8',
+ 'UTF16',
+ )
class TMessageType(object):
- CALL = 1
- REPLY = 2
- EXCEPTION = 3
- ONEWAY = 4
+ CALL = 1
+ REPLY = 2
+ EXCEPTION = 3
+ ONEWAY = 4
class TProcessor(object):
- """Base class for procsessor, which works on two streams."""
+ """Base class for procsessor, which works on two streams."""
- def process(iprot, oprot):
- pass
+ def process(iprot, oprot):
+ pass
class TException(Exception):
- """Base class for all thrift exceptions."""
+ """Base class for all thrift exceptions."""
- # BaseException.message is deprecated in Python v[2.6,3.0)
- if (2, 6, 0) <= sys.version_info < (3, 0):
- def _get_message(self):
- return self._message
+ # BaseException.message is deprecated in Python v[2.6,3.0)
+ if (2, 6, 0) <= sys.version_info < (3, 0):
+ def _get_message(self):
+ return self._message
- def _set_message(self, message):
- self._message = message
- message = property(_get_message, _set_message)
+ def _set_message(self, message):
+ self._message = message
+ message = property(_get_message, _set_message)
- def __init__(self, message=None):
- Exception.__init__(self, message)
- self.message = message
+ def __init__(self, message=None):
+ Exception.__init__(self, message)
+ self.message = message
class TApplicationException(TException):
- """Application level thrift exceptions."""
-
- UNKNOWN = 0
- UNKNOWN_METHOD = 1
- INVALID_MESSAGE_TYPE = 2
- WRONG_METHOD_NAME = 3
- BAD_SEQUENCE_ID = 4
- MISSING_RESULT = 5
- INTERNAL_ERROR = 6
- PROTOCOL_ERROR = 7
- INVALID_TRANSFORM = 8
- INVALID_PROTOCOL = 9
- UNSUPPORTED_CLIENT_TYPE = 10
-
- def __init__(self, type=UNKNOWN, message=None):
- TException.__init__(self, message)
- self.type = type
-
- def __str__(self):
- if self.message:
- return self.message
- elif self.type == self.UNKNOWN_METHOD:
- return 'Unknown method'
- elif self.type == self.INVALID_MESSAGE_TYPE:
- return 'Invalid message type'
- elif self.type == self.WRONG_METHOD_NAME:
- return 'Wrong method name'
- elif self.type == self.BAD_SEQUENCE_ID:
- return 'Bad sequence ID'
- elif self.type == self.MISSING_RESULT:
- return 'Missing result'
- elif self.type == self.INTERNAL_ERROR:
- return 'Internal error'
- elif self.type == self.PROTOCOL_ERROR:
- return 'Protocol error'
- elif self.type == self.INVALID_TRANSFORM:
- return 'Invalid transform'
- elif self.type == self.INVALID_PROTOCOL:
- return 'Invalid protocol'
- elif self.type == self.UNSUPPORTED_CLIENT_TYPE:
- return 'Unsupported client type'
- else:
- return 'Default (unknown) TApplicationException'
-
- def read(self, iprot):
- iprot.readStructBegin()
- while True:
- (fname, ftype, fid) = iprot.readFieldBegin()
- if ftype == TType.STOP:
- break
- if fid == 1:
- if ftype == TType.STRING:
- self.message = iprot.readString()
+ """Application level thrift exceptions."""
+
+ UNKNOWN = 0
+ UNKNOWN_METHOD = 1
+ INVALID_MESSAGE_TYPE = 2
+ WRONG_METHOD_NAME = 3
+ BAD_SEQUENCE_ID = 4
+ MISSING_RESULT = 5
+ INTERNAL_ERROR = 6
+ PROTOCOL_ERROR = 7
+ INVALID_TRANSFORM = 8
+ INVALID_PROTOCOL = 9
+ UNSUPPORTED_CLIENT_TYPE = 10
+
+ def __init__(self, type=UNKNOWN, message=None):
+ TException.__init__(self, message)
+ self.type = type
+
+ def __str__(self):
+ if self.message:
+ return self.message
+ elif self.type == self.UNKNOWN_METHOD:
+ return 'Unknown method'
+ elif self.type == self.INVALID_MESSAGE_TYPE:
+ return 'Invalid message type'
+ elif self.type == self.WRONG_METHOD_NAME:
+ return 'Wrong method name'
+ elif self.type == self.BAD_SEQUENCE_ID:
+ return 'Bad sequence ID'
+ elif self.type == self.MISSING_RESULT:
+ return 'Missing result'
+ elif self.type == self.INTERNAL_ERROR:
+ return 'Internal error'
+ elif self.type == self.PROTOCOL_ERROR:
+ return 'Protocol error'
+ elif self.type == self.INVALID_TRANSFORM:
+ return 'Invalid transform'
+ elif self.type == self.INVALID_PROTOCOL:
+ return 'Invalid protocol'
+ elif self.type == self.UNSUPPORTED_CLIENT_TYPE:
+ return 'Unsupported client type'
else:
- iprot.skip(ftype)
- elif fid == 2:
- if ftype == TType.I32:
- self.type = iprot.readI32()
- else:
- iprot.skip(ftype)
- else:
- iprot.skip(ftype)
- iprot.readFieldEnd()
- iprot.readStructEnd()
-
- def write(self, oprot):
- oprot.writeStructBegin('TApplicationException')
- if self.message is not None:
- oprot.writeFieldBegin('message', TType.STRING, 1)
- oprot.writeString(self.message)
- oprot.writeFieldEnd()
- if self.type is not None:
- oprot.writeFieldBegin('type', TType.I32, 2)
- oprot.writeI32(self.type)
- oprot.writeFieldEnd()
- oprot.writeFieldStop()
- oprot.writeStructEnd()
+ return 'Default (unknown) TApplicationException'
+
+ def read(self, iprot):
+ iprot.readStructBegin()
+ while True:
+ (fname, ftype, fid) = iprot.readFieldBegin()
+ if ftype == TType.STOP:
+ break
+ if fid == 1:
+ if ftype == TType.STRING:
+ self.message = iprot.readString()
+ else:
+ iprot.skip(ftype)
+ elif fid == 2:
+ if ftype == TType.I32:
+ self.type = iprot.readI32()
+ else:
+ iprot.skip(ftype)
+ else:
+ iprot.skip(ftype)
+ iprot.readFieldEnd()
+ iprot.readStructEnd()
+
+ def write(self, oprot):
+ oprot.writeStructBegin('TApplicationException')
+ if self.message is not None:
+ oprot.writeFieldBegin('message', TType.STRING, 1)
+ oprot.writeString(self.message)
+ oprot.writeFieldEnd()
+ if self.type is not None:
+ oprot.writeFieldBegin('type', TType.I32, 2)
+ oprot.writeI32(self.type)
+ oprot.writeFieldEnd()
+ oprot.writeFieldStop()
+ oprot.writeStructEnd()
class TFrozenDict(dict):
- """A dictionary that is "frozen" like a frozenset"""
+ """A dictionary that is "frozen" like a frozenset"""
- def __init__(self, *args, **kwargs):
- super(TFrozenDict, self).__init__(*args, **kwargs)
- # Sort the items so they will be in a consistent order.
- # XOR in the hash of the class so we don't collide with
- # the hash of a list of tuples.
- self.__hashval = hash(TFrozenDict) ^ hash(tuple(sorted(self.items())))
+ def __init__(self, *args, **kwargs):
+ super(TFrozenDict, self).__init__(*args, **kwargs)
+ # Sort the items so they will be in a consistent order.
+ # XOR in the hash of the class so we don't collide with
+ # the hash of a list of tuples.
+ self.__hashval = hash(TFrozenDict) ^ hash(tuple(sorted(self.items())))
- def __setitem__(self, *args):
- raise TypeError("Can't modify frozen TFreezableDict")
+ def __setitem__(self, *args):
+ raise TypeError("Can't modify frozen TFreezableDict")
- def __delitem__(self, *args):
- raise TypeError("Can't modify frozen TFreezableDict")
+ def __delitem__(self, *args):
+ raise TypeError("Can't modify frozen TFreezableDict")
- def __hash__(self):
- return self.__hashval
+ def __hash__(self):
+ return self.__hashval
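
A short sketch of what TFrozenDict above provides: the precomputed hash makes instances usable as dict keys or set members, and mutation is rejected:

    from thrift.Thrift import TFrozenDict

    fd = TFrozenDict({'a': 1, 'b': 2})
    cache = {fd: 'keyed by a frozen dict'}   # hashable, unlike a plain dict
    try:
        fd['c'] = 3
    except TypeError:
        pass                                 # __setitem__ refuses to modify it
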
diff --git a/lib/py/src/compat.py b/lib/py/src/compat.py
index 06f672ae6..42403eae8 100644
--- a/lib/py/src/compat.py
+++ b/lib/py/src/compat.py
@@ -1,27 +1,46 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
import sys
if sys.version_info[0] == 2:
- from cStringIO import StringIO as BufferIO
+ from cStringIO import StringIO as BufferIO
- def binary_to_str(bin_val):
- return bin_val
+ def binary_to_str(bin_val):
+ return bin_val
- def str_to_binary(str_val):
- return str_val
+ def str_to_binary(str_val):
+ return str_val
else:
- from io import BytesIO as BufferIO
+ from io import BytesIO as BufferIO
- def binary_to_str(bin_val):
- try:
- return bin_val.decode('utf8')
- except:
- return bin_val
+ def binary_to_str(bin_val):
+ try:
+ return bin_val.decode('utf8')
+ except:
+ return bin_val
- def str_to_binary(str_val):
- try:
- return bytes(str_val, 'utf8')
- except:
- return str_val
+ def str_to_binary(str_val):
+ try:
+ return bytes(str_val, 'utf8')
+ except:
+ return str_val
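
A small round-trip sketch of the compat helpers above; on Python 3 they convert between text and bytes, on Python 2 they are pass-throughs:

    from thrift.compat import str_to_binary, binary_to_str

    payload = str_to_binary('hello')   # bytes on Python 3, unchanged str on Python 2
    assert binary_to_str(payload) == 'hello'
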
diff --git a/lib/py/src/protocol/TBase.py b/lib/py/src/protocol/TBase.py
index d106f4e03..87caf0d16 100644
--- a/lib/py/src/protocol/TBase.py
+++ b/lib/py/src/protocol/TBase.py
@@ -21,78 +21,79 @@ from thrift.protocol import TBinaryProtocol
from thrift.transport import TTransport
try:
- from thrift.protocol import fastbinary
+ from thrift.protocol import fastbinary
except:
- fastbinary = None
+ fastbinary = None
class TBase(object):
- __slots__ = ()
-
- def __repr__(self):
- L = ['%s=%r' % (key, getattr(self, key)) for key in self.__slots__]
- return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
-
- def __eq__(self, other):
- if not isinstance(other, self.__class__):
- return False
- for attr in self.__slots__:
- my_val = getattr(self, attr)
- other_val = getattr(other, attr)
- if my_val != other_val:
- return False
- return True
-
- def __ne__(self, other):
- return not (self == other)
-
- def read(self, iprot):
- if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
- isinstance(iprot.trans, TTransport.CReadableTransport) and
- self.thrift_spec is not None and
- fastbinary is not None):
- fastbinary.decode_binary(self,
- iprot.trans,
- (self.__class__, self.thrift_spec),
- iprot.string_length_limit,
- iprot.container_length_limit)
- return
- iprot.readStruct(self, self.thrift_spec)
-
- def write(self, oprot):
- if (oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
- self.thrift_spec is not None and
- fastbinary is not None):
- oprot.trans.write(
- fastbinary.encode_binary(self, (self.__class__, self.thrift_spec)))
- return
- oprot.writeStruct(self, self.thrift_spec)
+ __slots__ = ()
+
+ def __repr__(self):
+ L = ['%s=%r' % (key, getattr(self, key)) for key in self.__slots__]
+ return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
+
+ def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return False
+ for attr in self.__slots__:
+ my_val = getattr(self, attr)
+ other_val = getattr(other, attr)
+ if my_val != other_val:
+ return False
+ return True
+
+ def __ne__(self, other):
+ return not (self == other)
+
+ def read(self, iprot):
+ if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
+ isinstance(iprot.trans, TTransport.CReadableTransport) and
+ self.thrift_spec is not None and
+ fastbinary is not None):
+ fastbinary.decode_binary(self,
+ iprot.trans,
+ (self.__class__, self.thrift_spec),
+ iprot.string_length_limit,
+ iprot.container_length_limit)
+ return
+ iprot.readStruct(self, self.thrift_spec)
+
+ def write(self, oprot):
+ if (oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
+ self.thrift_spec is not None and
+ fastbinary is not None):
+ oprot.trans.write(
+ fastbinary.encode_binary(
+ self, (self.__class__, self.thrift_spec)))
+ return
+ oprot.writeStruct(self, self.thrift_spec)
class TExceptionBase(TBase, Exception):
- pass
+ pass
class TFrozenBase(TBase):
- def __setitem__(self, *args):
- raise TypeError("Can't modify frozen struct")
-
- def __delitem__(self, *args):
- raise TypeError("Can't modify frozen struct")
-
- def __hash__(self, *args):
- return hash(self.__class__) ^ hash(self.__slots__)
-
- @classmethod
- def read(cls, iprot):
- if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
- isinstance(iprot.trans, TTransport.CReadableTransport) and
- cls.thrift_spec is not None and
- fastbinary is not None):
- self = cls()
- return fastbinary.decode_binary(None,
- iprot.trans,
- (self.__class__, self.thrift_spec),
- iprot.string_length_limit,
- iprot.container_length_limit)
- return iprot.readStruct(cls, cls.thrift_spec, True)
+ def __setitem__(self, *args):
+ raise TypeError("Can't modify frozen struct")
+
+ def __delitem__(self, *args):
+ raise TypeError("Can't modify frozen struct")
+
+ def __hash__(self, *args):
+ return hash(self.__class__) ^ hash(self.__slots__)
+
+ @classmethod
+ def read(cls, iprot):
+ if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
+ isinstance(iprot.trans, TTransport.CReadableTransport) and
+ cls.thrift_spec is not None and
+ fastbinary is not None):
+ self = cls()
+ return fastbinary.decode_binary(None,
+ iprot.trans,
+ (self.__class__, self.thrift_spec),
+ iprot.string_length_limit,
+ iprot.container_length_limit)
+ return iprot.readStruct(cls, cls.thrift_spec, True)
diff --git a/lib/py/src/protocol/TBinaryProtocol.py b/lib/py/src/protocol/TBinaryProtocol.py
index db4ea3182..7fce12f07 100644
--- a/lib/py/src/protocol/TBinaryProtocol.py
+++ b/lib/py/src/protocol/TBinaryProtocol.py
@@ -22,264 +22,264 @@ from struct import pack, unpack
class TBinaryProtocol(TProtocolBase):
- """Binary implementation of the Thrift protocol driver."""
+ """Binary implementation of the Thrift protocol driver."""
- # NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be
- # positive, converting this into a long. If we hardcode the int value
- # instead it'll stay in 32 bit-land.
+ # NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be
+ # positive, converting this into a long. If we hardcode the int value
+ # instead it'll stay in 32 bit-land.
- # VERSION_MASK = 0xffff0000
- VERSION_MASK = -65536
+ # VERSION_MASK = 0xffff0000
+ VERSION_MASK = -65536
- # VERSION_1 = 0x80010000
- VERSION_1 = -2147418112
+ # VERSION_1 = 0x80010000
+ VERSION_1 = -2147418112
- TYPE_MASK = 0x000000ff
+ TYPE_MASK = 0x000000ff
- def __init__(self, trans, strictRead=False, strictWrite=True, **kwargs):
- TProtocolBase.__init__(self, trans)
- self.strictRead = strictRead
- self.strictWrite = strictWrite
- self.string_length_limit = kwargs.get('string_length_limit', None)
- self.container_length_limit = kwargs.get('container_length_limit', None)
-
- def _check_string_length(self, length):
- self._check_length(self.string_length_limit, length)
-
- def _check_container_length(self, length):
- self._check_length(self.container_length_limit, length)
-
- def writeMessageBegin(self, name, type, seqid):
- if self.strictWrite:
- self.writeI32(TBinaryProtocol.VERSION_1 | type)
- self.writeString(name)
- self.writeI32(seqid)
- else:
- self.writeString(name)
- self.writeByte(type)
- self.writeI32(seqid)
-
- def writeMessageEnd(self):
- pass
-
- def writeStructBegin(self, name):
- pass
-
- def writeStructEnd(self):
- pass
-
- def writeFieldBegin(self, name, type, id):
- self.writeByte(type)
- self.writeI16(id)
-
- def writeFieldEnd(self):
- pass
-
- def writeFieldStop(self):
- self.writeByte(TType.STOP)
-
- def writeMapBegin(self, ktype, vtype, size):
- self.writeByte(ktype)
- self.writeByte(vtype)
- self.writeI32(size)
-
- def writeMapEnd(self):
- pass
-
- def writeListBegin(self, etype, size):
- self.writeByte(etype)
- self.writeI32(size)
-
- def writeListEnd(self):
- pass
-
- def writeSetBegin(self, etype, size):
- self.writeByte(etype)
- self.writeI32(size)
-
- def writeSetEnd(self):
- pass
-
- def writeBool(self, bool):
- if bool:
- self.writeByte(1)
- else:
- self.writeByte(0)
-
- def writeByte(self, byte):
- buff = pack("!b", byte)
- self.trans.write(buff)
-
- def writeI16(self, i16):
- buff = pack("!h", i16)
- self.trans.write(buff)
-
- def writeI32(self, i32):
- buff = pack("!i", i32)
- self.trans.write(buff)
-
- def writeI64(self, i64):
- buff = pack("!q", i64)
- self.trans.write(buff)
-
- def writeDouble(self, dub):
- buff = pack("!d", dub)
- self.trans.write(buff)
-
- def writeBinary(self, str):
- self.writeI32(len(str))
- self.trans.write(str)
-
- def readMessageBegin(self):
- sz = self.readI32()
- if sz < 0:
- version = sz & TBinaryProtocol.VERSION_MASK
- if version != TBinaryProtocol.VERSION_1:
- raise TProtocolException(
- type=TProtocolException.BAD_VERSION,
- message='Bad version in readMessageBegin: %d' % (sz))
- type = sz & TBinaryProtocol.TYPE_MASK
- name = self.readString()
- seqid = self.readI32()
- else:
- if self.strictRead:
- raise TProtocolException(type=TProtocolException.BAD_VERSION,
- message='No protocol version header')
- name = self.trans.readAll(sz)
- type = self.readByte()
- seqid = self.readI32()
- return (name, type, seqid)
-
- def readMessageEnd(self):
- pass
-
- def readStructBegin(self):
- pass
-
- def readStructEnd(self):
- pass
-
- def readFieldBegin(self):
- type = self.readByte()
- if type == TType.STOP:
- return (None, type, 0)
- id = self.readI16()
- return (None, type, id)
-
- def readFieldEnd(self):
- pass
-
- def readMapBegin(self):
- ktype = self.readByte()
- vtype = self.readByte()
- size = self.readI32()
- self._check_container_length(size)
- return (ktype, vtype, size)
-
- def readMapEnd(self):
- pass
-
- def readListBegin(self):
- etype = self.readByte()
- size = self.readI32()
- self._check_container_length(size)
- return (etype, size)
-
- def readListEnd(self):
- pass
-
- def readSetBegin(self):
- etype = self.readByte()
- size = self.readI32()
- self._check_container_length(size)
- return (etype, size)
-
- def readSetEnd(self):
- pass
-
- def readBool(self):
- byte = self.readByte()
- if byte == 0:
- return False
- return True
-
- def readByte(self):
- buff = self.trans.readAll(1)
- val, = unpack('!b', buff)
- return val
-
- def readI16(self):
- buff = self.trans.readAll(2)
- val, = unpack('!h', buff)
- return val
-
- def readI32(self):
- buff = self.trans.readAll(4)
- val, = unpack('!i', buff)
- return val
-
- def readI64(self):
- buff = self.trans.readAll(8)
- val, = unpack('!q', buff)
- return val
-
- def readDouble(self):
- buff = self.trans.readAll(8)
- val, = unpack('!d', buff)
- return val
-
- def readBinary(self):
- size = self.readI32()
- self._check_string_length(size)
- s = self.trans.readAll(size)
- return s
+ def __init__(self, trans, strictRead=False, strictWrite=True, **kwargs):
+ TProtocolBase.__init__(self, trans)
+ self.strictRead = strictRead
+ self.strictWrite = strictWrite
+ self.string_length_limit = kwargs.get('string_length_limit', None)
+ self.container_length_limit = kwargs.get('container_length_limit', None)
+
+ def _check_string_length(self, length):
+ self._check_length(self.string_length_limit, length)
+
+ def _check_container_length(self, length):
+ self._check_length(self.container_length_limit, length)
+
+ def writeMessageBegin(self, name, type, seqid):
+ if self.strictWrite:
+ self.writeI32(TBinaryProtocol.VERSION_1 | type)
+ self.writeString(name)
+ self.writeI32(seqid)
+ else:
+ self.writeString(name)
+ self.writeByte(type)
+ self.writeI32(seqid)
+
+ def writeMessageEnd(self):
+ pass
+
+ def writeStructBegin(self, name):
+ pass
+
+ def writeStructEnd(self):
+ pass
+
+ def writeFieldBegin(self, name, type, id):
+ self.writeByte(type)
+ self.writeI16(id)
+
+ def writeFieldEnd(self):
+ pass
+
+ def writeFieldStop(self):
+ self.writeByte(TType.STOP)
+
+ def writeMapBegin(self, ktype, vtype, size):
+ self.writeByte(ktype)
+ self.writeByte(vtype)
+ self.writeI32(size)
+
+ def writeMapEnd(self):
+ pass
+
+ def writeListBegin(self, etype, size):
+ self.writeByte(etype)
+ self.writeI32(size)
+
+ def writeListEnd(self):
+ pass
+
+ def writeSetBegin(self, etype, size):
+ self.writeByte(etype)
+ self.writeI32(size)
+
+ def writeSetEnd(self):
+ pass
+
+ def writeBool(self, bool):
+ if bool:
+ self.writeByte(1)
+ else:
+ self.writeByte(0)
+
+ def writeByte(self, byte):
+ buff = pack("!b", byte)
+ self.trans.write(buff)
+
+ def writeI16(self, i16):
+ buff = pack("!h", i16)
+ self.trans.write(buff)
+
+ def writeI32(self, i32):
+ buff = pack("!i", i32)
+ self.trans.write(buff)
+
+ def writeI64(self, i64):
+ buff = pack("!q", i64)
+ self.trans.write(buff)
+
+ def writeDouble(self, dub):
+ buff = pack("!d", dub)
+ self.trans.write(buff)
+
+ def writeBinary(self, str):
+ self.writeI32(len(str))
+ self.trans.write(str)
+
+ def readMessageBegin(self):
+ sz = self.readI32()
+ if sz < 0:
+ version = sz & TBinaryProtocol.VERSION_MASK
+ if version != TBinaryProtocol.VERSION_1:
+ raise TProtocolException(
+ type=TProtocolException.BAD_VERSION,
+ message='Bad version in readMessageBegin: %d' % (sz))
+ type = sz & TBinaryProtocol.TYPE_MASK
+ name = self.readString()
+ seqid = self.readI32()
+ else:
+ if self.strictRead:
+ raise TProtocolException(type=TProtocolException.BAD_VERSION,
+ message='No protocol version header')
+ name = self.trans.readAll(sz)
+ type = self.readByte()
+ seqid = self.readI32()
+ return (name, type, seqid)
+
+ def readMessageEnd(self):
+ pass
+
+ def readStructBegin(self):
+ pass
+
+ def readStructEnd(self):
+ pass
+
+ def readFieldBegin(self):
+ type = self.readByte()
+ if type == TType.STOP:
+ return (None, type, 0)
+ id = self.readI16()
+ return (None, type, id)
+
+ def readFieldEnd(self):
+ pass
+
+ def readMapBegin(self):
+ ktype = self.readByte()
+ vtype = self.readByte()
+ size = self.readI32()
+ self._check_container_length(size)
+ return (ktype, vtype, size)
+
+ def readMapEnd(self):
+ pass
+
+ def readListBegin(self):
+ etype = self.readByte()
+ size = self.readI32()
+ self._check_container_length(size)
+ return (etype, size)
+
+ def readListEnd(self):
+ pass
+
+ def readSetBegin(self):
+ etype = self.readByte()
+ size = self.readI32()
+ self._check_container_length(size)
+ return (etype, size)
+
+ def readSetEnd(self):
+ pass
+
+ def readBool(self):
+ byte = self.readByte()
+ if byte == 0:
+ return False
+ return True
+
+ def readByte(self):
+ buff = self.trans.readAll(1)
+ val, = unpack('!b', buff)
+ return val
+
+ def readI16(self):
+ buff = self.trans.readAll(2)
+ val, = unpack('!h', buff)
+ return val
+
+ def readI32(self):
+ buff = self.trans.readAll(4)
+ val, = unpack('!i', buff)
+ return val
+
+ def readI64(self):
+ buff = self.trans.readAll(8)
+ val, = unpack('!q', buff)
+ return val
+
+ def readDouble(self):
+ buff = self.trans.readAll(8)
+ val, = unpack('!d', buff)
+ return val
+
+ def readBinary(self):
+ size = self.readI32()
+ self._check_string_length(size)
+ s = self.trans.readAll(size)
+ return s
class TBinaryProtocolFactory(object):
- def __init__(self, strictRead=False, strictWrite=True, **kwargs):
- self.strictRead = strictRead
- self.strictWrite = strictWrite
- self.string_length_limit = kwargs.get('string_length_limit', None)
- self.container_length_limit = kwargs.get('container_length_limit', None)
+ def __init__(self, strictRead=False, strictWrite=True, **kwargs):
+ self.strictRead = strictRead
+ self.strictWrite = strictWrite
+ self.string_length_limit = kwargs.get('string_length_limit', None)
+ self.container_length_limit = kwargs.get('container_length_limit', None)
- def getProtocol(self, trans):
- prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite,
- string_length_limit=self.string_length_limit,
- container_length_limit=self.container_length_limit)
- return prot
+ def getProtocol(self, trans):
+ prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite,
+ string_length_limit=self.string_length_limit,
+ container_length_limit=self.container_length_limit)
+ return prot
class TBinaryProtocolAccelerated(TBinaryProtocol):
- """C-Accelerated version of TBinaryProtocol.
-
- This class does not override any of TBinaryProtocol's methods,
- but the generated code recognizes it directly and will call into
- our C module to do the encoding, bypassing this object entirely.
- We inherit from TBinaryProtocol so that the normal TBinaryProtocol
- encoding can happen if the fastbinary module doesn't work for some
- reason. (TODO(dreiss): Make this happen sanely in more cases.)
-
- In order to take advantage of the C module, just use
- TBinaryProtocolAccelerated instead of TBinaryProtocol.
-
- NOTE: This code was contributed by an external developer.
- The internal Thrift team has reviewed and tested it,
- but we cannot guarantee that it is production-ready.
- Please feel free to report bugs and/or success stories
- to the public mailing list.
- """
- pass
+ """C-Accelerated version of TBinaryProtocol.
+
+ This class does not override any of TBinaryProtocol's methods,
+ but the generated code recognizes it directly and will call into
+ our C module to do the encoding, bypassing this object entirely.
+ We inherit from TBinaryProtocol so that the normal TBinaryProtocol
+ encoding can happen if the fastbinary module doesn't work for some
+ reason. (TODO(dreiss): Make this happen sanely in more cases.)
+
+ In order to take advantage of the C module, just use
+ TBinaryProtocolAccelerated instead of TBinaryProtocol.
+
+ NOTE: This code was contributed by an external developer.
+ The internal Thrift team has reviewed and tested it,
+ but we cannot guarantee that it is production-ready.
+ Please feel free to report bugs and/or success stories
+ to the public mailing list.
+ """
+ pass
class TBinaryProtocolAcceleratedFactory(object):
- def __init__(self,
- string_length_limit=None,
- container_length_limit=None):
- self.string_length_limit = string_length_limit
- self.container_length_limit = container_length_limit
-
- def getProtocol(self, trans):
- return TBinaryProtocolAccelerated(
- trans,
- string_length_limit=self.string_length_limit,
- container_length_limit=self.container_length_limit)
+ def __init__(self,
+ string_length_limit=None,
+ container_length_limit=None):
+ self.string_length_limit = string_length_limit
+ self.container_length_limit = container_length_limit
+
+ def getProtocol(self, trans):
+ return TBinaryProtocolAccelerated(
+ trans,
+ string_length_limit=self.string_length_limit,
+ container_length_limit=self.container_length_limit)
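
For orientation, a minimal round trip through plain TBinaryProtocol over an in-memory transport, using only methods visible in this diff; TMemoryBuffer comes from thrift.transport.TTransport:

    from thrift.protocol.TBinaryProtocol import TBinaryProtocol
    from thrift.transport.TTransport import TMemoryBuffer

    # Encode an i32 followed by a length-prefixed binary blob.
    out = TMemoryBuffer()
    oprot = TBinaryProtocol(out)
    oprot.writeI32(42)
    oprot.writeBinary(b'thrift')

    # Decode the same bytes with a fresh transport/protocol pair.
    inp = TMemoryBuffer(out.getvalue())
    iprot = TBinaryProtocol(inp)
    assert iprot.readI32() == 42
    assert iprot.readBinary() == b'thrift'
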
diff --git a/lib/py/src/protocol/TCompactProtocol.py b/lib/py/src/protocol/TCompactProtocol.py
index 3d9c0e6e3..8d3db1a9d 100644
--- a/lib/py/src/protocol/TCompactProtocol.py
+++ b/lib/py/src/protocol/TCompactProtocol.py
@@ -36,390 +36,391 @@ BOOL_READ = 8
def make_helper(v_from, container):
- def helper(func):
- def nested(self, *args, **kwargs):
- assert self.state in (v_from, container), (self.state, v_from, container)
- return func(self, *args, **kwargs)
- return nested
- return helper
+ def helper(func):
+ def nested(self, *args, **kwargs):
+ assert self.state in (v_from, container), (self.state, v_from, container)
+ return func(self, *args, **kwargs)
+ return nested
+ return helper
writer = make_helper(VALUE_WRITE, CONTAINER_WRITE)
reader = make_helper(VALUE_READ, CONTAINER_READ)
def makeZigZag(n, bits):
- checkIntegerLimits(n, bits)
- return (n << 1) ^ (n >> (bits - 1))
+ checkIntegerLimits(n, bits)
+ return (n << 1) ^ (n >> (bits - 1))
def fromZigZag(n):
- return (n >> 1) ^ -(n & 1)
+ return (n >> 1) ^ -(n & 1)
def writeVarint(trans, n):
- out = bytearray()
- while True:
- if n & ~0x7f == 0:
- out.append(n)
- break
- else:
- out.append((n & 0xff) | 0x80)
- n = n >> 7
- trans.write(bytes(out))
+ out = bytearray()
+ while True:
+ if n & ~0x7f == 0:
+ out.append(n)
+ break
+ else:
+ out.append((n & 0xff) | 0x80)
+ n = n >> 7
+ trans.write(bytes(out))
def readVarint(trans):
- result = 0
- shift = 0
- while True:
- x = trans.readAll(1)
- byte = ord(x)
- result |= (byte & 0x7f) << shift
- if byte >> 7 == 0:
- return result
- shift += 7
+ result = 0
+ shift = 0
+ while True:
+ x = trans.readAll(1)
+ byte = ord(x)
+ result |= (byte & 0x7f) << shift
+ if byte >> 7 == 0:
+ return result
+ shift += 7
class CompactType(object):
- STOP = 0x00
- TRUE = 0x01
- FALSE = 0x02
- BYTE = 0x03
- I16 = 0x04
- I32 = 0x05
- I64 = 0x06
- DOUBLE = 0x07
- BINARY = 0x08
- LIST = 0x09
- SET = 0x0A
- MAP = 0x0B
- STRUCT = 0x0C
-
-CTYPES = {TType.STOP: CompactType.STOP,
- TType.BOOL: CompactType.TRUE, # used for collection
- TType.BYTE: CompactType.BYTE,
- TType.I16: CompactType.I16,
- TType.I32: CompactType.I32,
- TType.I64: CompactType.I64,
- TType.DOUBLE: CompactType.DOUBLE,
- TType.STRING: CompactType.BINARY,
- TType.STRUCT: CompactType.STRUCT,
- TType.LIST: CompactType.LIST,
- TType.SET: CompactType.SET,
- TType.MAP: CompactType.MAP
- }
+ STOP = 0x00
+ TRUE = 0x01
+ FALSE = 0x02
+ BYTE = 0x03
+ I16 = 0x04
+ I32 = 0x05
+ I64 = 0x06
+ DOUBLE = 0x07
+ BINARY = 0x08
+ LIST = 0x09
+ SET = 0x0A
+ MAP = 0x0B
+ STRUCT = 0x0C
+
+CTYPES = {
+ TType.STOP: CompactType.STOP,
+ TType.BOOL: CompactType.TRUE, # used for collection
+ TType.BYTE: CompactType.BYTE,
+ TType.I16: CompactType.I16,
+ TType.I32: CompactType.I32,
+ TType.I64: CompactType.I64,
+ TType.DOUBLE: CompactType.DOUBLE,
+ TType.STRING: CompactType.BINARY,
+ TType.STRUCT: CompactType.STRUCT,
+ TType.LIST: CompactType.LIST,
+ TType.SET: CompactType.SET,
+ TType.MAP: CompactType.MAP,
+}
TTYPES = {}
for k, v in CTYPES.items():
- TTYPES[v] = k
+ TTYPES[v] = k
TTYPES[CompactType.FALSE] = TType.BOOL
del k
del v
class TCompactProtocol(TProtocolBase):
- """Compact implementation of the Thrift protocol driver."""
-
- PROTOCOL_ID = 0x82
- VERSION = 1
- VERSION_MASK = 0x1f
- TYPE_MASK = 0xe0
- TYPE_BITS = 0x07
- TYPE_SHIFT_AMOUNT = 5
-
- def __init__(self, trans,
- string_length_limit=None,
- container_length_limit=None):
- TProtocolBase.__init__(self, trans)
- self.state = CLEAR
- self.__last_fid = 0
- self.__bool_fid = None
- self.__bool_value = None
- self.__structs = []
- self.__containers = []
- self.string_length_limit = string_length_limit
- self.container_length_limit = container_length_limit
-
- def _check_string_length(self, length):
- self._check_length(self.string_length_limit, length)
-
- def _check_container_length(self, length):
- self._check_length(self.container_length_limit, length)
-
- def __writeVarint(self, n):
- writeVarint(self.trans, n)
-
- def writeMessageBegin(self, name, type, seqid):
- assert self.state == CLEAR
- self.__writeUByte(self.PROTOCOL_ID)
- self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT))
- self.__writeVarint(seqid)
- self.__writeBinary(str_to_binary(name))
- self.state = VALUE_WRITE
-
- def writeMessageEnd(self):
- assert self.state == VALUE_WRITE
- self.state = CLEAR
-
- def writeStructBegin(self, name):
- assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state
- self.__structs.append((self.state, self.__last_fid))
- self.state = FIELD_WRITE
- self.__last_fid = 0
-
- def writeStructEnd(self):
- assert self.state == FIELD_WRITE
- self.state, self.__last_fid = self.__structs.pop()
-
- def writeFieldStop(self):
- self.__writeByte(0)
-
- def __writeFieldHeader(self, type, fid):
- delta = fid - self.__last_fid
- if 0 < delta <= 15:
- self.__writeUByte(delta << 4 | type)
- else:
- self.__writeByte(type)
- self.__writeI16(fid)
- self.__last_fid = fid
-
- def writeFieldBegin(self, name, type, fid):
- assert self.state == FIELD_WRITE, self.state
- if type == TType.BOOL:
- self.state = BOOL_WRITE
- self.__bool_fid = fid
- else:
- self.state = VALUE_WRITE
- self.__writeFieldHeader(CTYPES[type], fid)
-
- def writeFieldEnd(self):
- assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state
- self.state = FIELD_WRITE
-
- def __writeUByte(self, byte):
- self.trans.write(pack('!B', byte))
-
- def __writeByte(self, byte):
- self.trans.write(pack('!b', byte))
-
- def __writeI16(self, i16):
- self.__writeVarint(makeZigZag(i16, 16))
-
- def __writeSize(self, i32):
- self.__writeVarint(i32)
-
- def writeCollectionBegin(self, etype, size):
- assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
- if size <= 14:
- self.__writeUByte(size << 4 | CTYPES[etype])
- else:
- self.__writeUByte(0xf0 | CTYPES[etype])
- self.__writeSize(size)
- self.__containers.append(self.state)
- self.state = CONTAINER_WRITE
- writeSetBegin = writeCollectionBegin
- writeListBegin = writeCollectionBegin
-
- def writeMapBegin(self, ktype, vtype, size):
- assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
- if size == 0:
- self.__writeByte(0)
- else:
- self.__writeSize(size)
- self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype])
- self.__containers.append(self.state)
- self.state = CONTAINER_WRITE
-
- def writeCollectionEnd(self):
- assert self.state == CONTAINER_WRITE, self.state
- self.state = self.__containers.pop()
- writeMapEnd = writeCollectionEnd
- writeSetEnd = writeCollectionEnd
- writeListEnd = writeCollectionEnd
-
- def writeBool(self, bool):
- if self.state == BOOL_WRITE:
- if bool:
- ctype = CompactType.TRUE
- else:
- ctype = CompactType.FALSE
- self.__writeFieldHeader(ctype, self.__bool_fid)
- elif self.state == CONTAINER_WRITE:
- if bool:
- self.__writeByte(CompactType.TRUE)
- else:
- self.__writeByte(CompactType.FALSE)
- else:
- raise AssertionError("Invalid state in compact protocol")
-
- writeByte = writer(__writeByte)
- writeI16 = writer(__writeI16)
-
- @writer
- def writeI32(self, i32):
- self.__writeVarint(makeZigZag(i32, 32))
-
- @writer
- def writeI64(self, i64):
- self.__writeVarint(makeZigZag(i64, 64))
-
- @writer
- def writeDouble(self, dub):
- self.trans.write(pack('<d', dub))
-
- def __writeBinary(self, s):
- self.__writeSize(len(s))
- self.trans.write(s)
- writeBinary = writer(__writeBinary)
-
- def readFieldBegin(self):
- assert self.state == FIELD_READ, self.state
- type = self.__readUByte()
- if type & 0x0f == TType.STOP:
- return (None, 0, 0)
- delta = type >> 4
- if delta == 0:
- fid = self.__readI16()
- else:
- fid = self.__last_fid + delta
- self.__last_fid = fid
- type = type & 0x0f
- if type == CompactType.TRUE:
- self.state = BOOL_READ
- self.__bool_value = True
- elif type == CompactType.FALSE:
- self.state = BOOL_READ
- self.__bool_value = False
- else:
- self.state = VALUE_READ
- return (None, self.__getTType(type), fid)
-
- def readFieldEnd(self):
- assert self.state in (VALUE_READ, BOOL_READ), self.state
- self.state = FIELD_READ
-
- def __readUByte(self):
- result, = unpack('!B', self.trans.readAll(1))
- return result
-
- def __readByte(self):
- result, = unpack('!b', self.trans.readAll(1))
- return result
-
- def __readVarint(self):
- return readVarint(self.trans)
-
- def __readZigZag(self):
- return fromZigZag(self.__readVarint())
-
- def __readSize(self):
- result = self.__readVarint()
- if result < 0:
- raise TProtocolException("Length < 0")
- return result
-
- def readMessageBegin(self):
- assert self.state == CLEAR
- proto_id = self.__readUByte()
- if proto_id != self.PROTOCOL_ID:
- raise TProtocolException(TProtocolException.BAD_VERSION,
- 'Bad protocol id in the message: %d' % proto_id)
- ver_type = self.__readUByte()
- type = (ver_type >> self.TYPE_SHIFT_AMOUNT) & self.TYPE_BITS
- version = ver_type & self.VERSION_MASK
- if version != self.VERSION:
- raise TProtocolException(TProtocolException.BAD_VERSION,
- 'Bad version: %d (expect %d)' % (version, self.VERSION))
- seqid = self.__readVarint()
- name = binary_to_str(self.__readBinary())
- return (name, type, seqid)
-
- def readMessageEnd(self):
- assert self.state == CLEAR
- assert len(self.__structs) == 0
-
- def readStructBegin(self):
- assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state
- self.__structs.append((self.state, self.__last_fid))
- self.state = FIELD_READ
- self.__last_fid = 0
-
- def readStructEnd(self):
- assert self.state == FIELD_READ
- self.state, self.__last_fid = self.__structs.pop()
-
- def readCollectionBegin(self):
- assert self.state in (VALUE_READ, CONTAINER_READ), self.state
- size_type = self.__readUByte()
- size = size_type >> 4
- type = self.__getTType(size_type)
- if size == 15:
- size = self.__readSize()
- self._check_container_length(size)
- self.__containers.append(self.state)
- self.state = CONTAINER_READ
- return type, size
- readSetBegin = readCollectionBegin
- readListBegin = readCollectionBegin
-
- def readMapBegin(self):
- assert self.state in (VALUE_READ, CONTAINER_READ), self.state
- size = self.__readSize()
- self._check_container_length(size)
- types = 0
- if size > 0:
- types = self.__readUByte()
- vtype = self.__getTType(types)
- ktype = self.__getTType(types >> 4)
- self.__containers.append(self.state)
- self.state = CONTAINER_READ
- return (ktype, vtype, size)
-
- def readCollectionEnd(self):
- assert self.state == CONTAINER_READ, self.state
- self.state = self.__containers.pop()
- readSetEnd = readCollectionEnd
- readListEnd = readCollectionEnd
- readMapEnd = readCollectionEnd
-
- def readBool(self):
- if self.state == BOOL_READ:
- return self.__bool_value == CompactType.TRUE
- elif self.state == CONTAINER_READ:
- return self.__readByte() == CompactType.TRUE
- else:
- raise AssertionError("Invalid state in compact protocol: %d" %
- self.state)
-
- readByte = reader(__readByte)
- __readI16 = __readZigZag
- readI16 = reader(__readZigZag)
- readI32 = reader(__readZigZag)
- readI64 = reader(__readZigZag)
-
- @reader
- def readDouble(self):
- buff = self.trans.readAll(8)
- val, = unpack('<d', buff)
- return val
-
- def __readBinary(self):
- size = self.__readSize()
- self._check_string_length(size)
- return self.trans.readAll(size)
- readBinary = reader(__readBinary)
-
- def __getTType(self, byte):
- return TTYPES[byte & 0x0f]
+ """Compact implementation of the Thrift protocol driver."""
+
+ PROTOCOL_ID = 0x82
+ VERSION = 1
+ VERSION_MASK = 0x1f
+ TYPE_MASK = 0xe0
+ TYPE_BITS = 0x07
+ TYPE_SHIFT_AMOUNT = 5
+
+ def __init__(self, trans,
+ string_length_limit=None,
+ container_length_limit=None):
+ TProtocolBase.__init__(self, trans)
+ self.state = CLEAR
+ self.__last_fid = 0
+ self.__bool_fid = None
+ self.__bool_value = None
+ self.__structs = []
+ self.__containers = []
+ self.string_length_limit = string_length_limit
+ self.container_length_limit = container_length_limit
+
+ def _check_string_length(self, length):
+ self._check_length(self.string_length_limit, length)
+
+ def _check_container_length(self, length):
+ self._check_length(self.container_length_limit, length)
+
+ def __writeVarint(self, n):
+ writeVarint(self.trans, n)
+
+ def writeMessageBegin(self, name, type, seqid):
+ assert self.state == CLEAR
+ self.__writeUByte(self.PROTOCOL_ID)
+ self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT))
+ self.__writeVarint(seqid)
+ self.__writeBinary(str_to_binary(name))
+ self.state = VALUE_WRITE
+
+ def writeMessageEnd(self):
+ assert self.state == VALUE_WRITE
+ self.state = CLEAR
+
+ def writeStructBegin(self, name):
+ assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state
+ self.__structs.append((self.state, self.__last_fid))
+ self.state = FIELD_WRITE
+ self.__last_fid = 0
+
+ def writeStructEnd(self):
+ assert self.state == FIELD_WRITE
+ self.state, self.__last_fid = self.__structs.pop()
+
+ def writeFieldStop(self):
+ self.__writeByte(0)
+
+ def __writeFieldHeader(self, type, fid):
+ delta = fid - self.__last_fid
+ if 0 < delta <= 15:
+ self.__writeUByte(delta << 4 | type)
+ else:
+ self.__writeByte(type)
+ self.__writeI16(fid)
+ self.__last_fid = fid
+
+ def writeFieldBegin(self, name, type, fid):
+ assert self.state == FIELD_WRITE, self.state
+ if type == TType.BOOL:
+ self.state = BOOL_WRITE
+ self.__bool_fid = fid
+ else:
+ self.state = VALUE_WRITE
+ self.__writeFieldHeader(CTYPES[type], fid)
+
+ def writeFieldEnd(self):
+ assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state
+ self.state = FIELD_WRITE
+
+ def __writeUByte(self, byte):
+ self.trans.write(pack('!B', byte))
+
+ def __writeByte(self, byte):
+ self.trans.write(pack('!b', byte))
+
+ def __writeI16(self, i16):
+ self.__writeVarint(makeZigZag(i16, 16))
+
+ def __writeSize(self, i32):
+ self.__writeVarint(i32)
+
+ def writeCollectionBegin(self, etype, size):
+ assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
+ if size <= 14:
+ self.__writeUByte(size << 4 | CTYPES[etype])
+ else:
+ self.__writeUByte(0xf0 | CTYPES[etype])
+ self.__writeSize(size)
+ self.__containers.append(self.state)
+ self.state = CONTAINER_WRITE
+ writeSetBegin = writeCollectionBegin
+ writeListBegin = writeCollectionBegin
+
+ def writeMapBegin(self, ktype, vtype, size):
+ assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
+ if size == 0:
+ self.__writeByte(0)
+ else:
+ self.__writeSize(size)
+ self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype])
+ self.__containers.append(self.state)
+ self.state = CONTAINER_WRITE
+
+ def writeCollectionEnd(self):
+ assert self.state == CONTAINER_WRITE, self.state
+ self.state = self.__containers.pop()
+ writeMapEnd = writeCollectionEnd
+ writeSetEnd = writeCollectionEnd
+ writeListEnd = writeCollectionEnd
+
+ def writeBool(self, bool):
+ if self.state == BOOL_WRITE:
+ if bool:
+ ctype = CompactType.TRUE
+ else:
+ ctype = CompactType.FALSE
+ self.__writeFieldHeader(ctype, self.__bool_fid)
+ elif self.state == CONTAINER_WRITE:
+ if bool:
+ self.__writeByte(CompactType.TRUE)
+ else:
+ self.__writeByte(CompactType.FALSE)
+ else:
+ raise AssertionError("Invalid state in compact protocol")
+
+ writeByte = writer(__writeByte)
+ writeI16 = writer(__writeI16)
+
+ @writer
+ def writeI32(self, i32):
+ self.__writeVarint(makeZigZag(i32, 32))
+
+ @writer
+ def writeI64(self, i64):
+ self.__writeVarint(makeZigZag(i64, 64))
+
+ @writer
+ def writeDouble(self, dub):
+ self.trans.write(pack('<d', dub))
+
+ def __writeBinary(self, s):
+ self.__writeSize(len(s))
+ self.trans.write(s)
+ writeBinary = writer(__writeBinary)
+
+ def readFieldBegin(self):
+ assert self.state == FIELD_READ, self.state
+ type = self.__readUByte()
+ if type & 0x0f == TType.STOP:
+ return (None, 0, 0)
+ delta = type >> 4
+ if delta == 0:
+ fid = self.__readI16()
+ else:
+ fid = self.__last_fid + delta
+ self.__last_fid = fid
+ type = type & 0x0f
+ if type == CompactType.TRUE:
+ self.state = BOOL_READ
+ self.__bool_value = True
+ elif type == CompactType.FALSE:
+ self.state = BOOL_READ
+ self.__bool_value = False
+ else:
+ self.state = VALUE_READ
+ return (None, self.__getTType(type), fid)
+
+ def readFieldEnd(self):
+ assert self.state in (VALUE_READ, BOOL_READ), self.state
+ self.state = FIELD_READ
+
+ def __readUByte(self):
+ result, = unpack('!B', self.trans.readAll(1))
+ return result
+
+ def __readByte(self):
+ result, = unpack('!b', self.trans.readAll(1))
+ return result
+
+ def __readVarint(self):
+ return readVarint(self.trans)
+
+ def __readZigZag(self):
+ return fromZigZag(self.__readVarint())
+
+ def __readSize(self):
+ result = self.__readVarint()
+ if result < 0:
+ raise TProtocolException("Length < 0")
+ return result
+
+ def readMessageBegin(self):
+ assert self.state == CLEAR
+ proto_id = self.__readUByte()
+ if proto_id != self.PROTOCOL_ID:
+ raise TProtocolException(TProtocolException.BAD_VERSION,
+ 'Bad protocol id in the message: %d' % proto_id)
+ ver_type = self.__readUByte()
+ type = (ver_type >> self.TYPE_SHIFT_AMOUNT) & self.TYPE_BITS
+ version = ver_type & self.VERSION_MASK
+ if version != self.VERSION:
+ raise TProtocolException(TProtocolException.BAD_VERSION,
+ 'Bad version: %d (expect %d)' % (version, self.VERSION))
+ seqid = self.__readVarint()
+ name = binary_to_str(self.__readBinary())
+ return (name, type, seqid)
+
+ def readMessageEnd(self):
+ assert self.state == CLEAR
+ assert len(self.__structs) == 0
+
+ def readStructBegin(self):
+ assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state
+ self.__structs.append((self.state, self.__last_fid))
+ self.state = FIELD_READ
+ self.__last_fid = 0
+
+ def readStructEnd(self):
+ assert self.state == FIELD_READ
+ self.state, self.__last_fid = self.__structs.pop()
+
+ def readCollectionBegin(self):
+ assert self.state in (VALUE_READ, CONTAINER_READ), self.state
+ size_type = self.__readUByte()
+ size = size_type >> 4
+ type = self.__getTType(size_type)
+ if size == 15:
+ size = self.__readSize()
+ self._check_container_length(size)
+ self.__containers.append(self.state)
+ self.state = CONTAINER_READ
+ return type, size
+ readSetBegin = readCollectionBegin
+ readListBegin = readCollectionBegin
+
+ def readMapBegin(self):
+ assert self.state in (VALUE_READ, CONTAINER_READ), self.state
+ size = self.__readSize()
+ self._check_container_length(size)
+ types = 0
+ if size > 0:
+ types = self.__readUByte()
+ vtype = self.__getTType(types)
+ ktype = self.__getTType(types >> 4)
+ self.__containers.append(self.state)
+ self.state = CONTAINER_READ
+ return (ktype, vtype, size)
+
+ def readCollectionEnd(self):
+ assert self.state == CONTAINER_READ, self.state
+ self.state = self.__containers.pop()
+ readSetEnd = readCollectionEnd
+ readListEnd = readCollectionEnd
+ readMapEnd = readCollectionEnd
+
+ def readBool(self):
+ if self.state == BOOL_READ:
+ return self.__bool_value == CompactType.TRUE
+ elif self.state == CONTAINER_READ:
+ return self.__readByte() == CompactType.TRUE
+ else:
+ raise AssertionError("Invalid state in compact protocol: %d" %
+ self.state)
+
+ readByte = reader(__readByte)
+ __readI16 = __readZigZag
+ readI16 = reader(__readZigZag)
+ readI32 = reader(__readZigZag)
+ readI64 = reader(__readZigZag)
+
+ @reader
+ def readDouble(self):
+ buff = self.trans.readAll(8)
+ val, = unpack('<d', buff)
+ return val
+
+ def __readBinary(self):
+ size = self.__readSize()
+ self._check_string_length(size)
+ return self.trans.readAll(size)
+ readBinary = reader(__readBinary)
+
+ def __getTType(self, byte):
+ return TTYPES[byte & 0x0f]
class TCompactProtocolFactory(object):
- def __init__(self,
- string_length_limit=None,
- container_length_limit=None):
- self.string_length_limit = string_length_limit
- self.container_length_limit = container_length_limit
-
- def getProtocol(self, trans):
- return TCompactProtocol(trans,
- self.string_length_limit,
- self.container_length_limit)
+ def __init__(self,
+ string_length_limit=None,
+ container_length_limit=None):
+ self.string_length_limit = string_length_limit
+ self.container_length_limit = container_length_limit
+
+ def getProtocol(self, trans):
+ return TCompactProtocol(trans,
+ self.string_length_limit,
+ self.container_length_limit)
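
Two encodings carried through the reindented compact protocol above deserve a standalone illustration: zigzag, which folds signed integers into unsigned ones so values near zero stay short, and the 7-bit varint used to serialize them. The sketch below re-implements makeZigZag, fromZigZag, and the core of writeVarint purely for illustration; it is not the library code itself:

def make_zigzag(n, bits):
    # 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, 2 -> 4, ...
    return (n << 1) ^ (n >> (bits - 1))

def from_zigzag(n):
    return (n >> 1) ^ -(n & 1)

def to_varint(n):
    # Little-endian base-128: 7 payload bits per byte, MSB set on all but the last.
    out = bytearray()
    while n & ~0x7f:
        out.append((n & 0x7f) | 0x80)
        n >>= 7
    out.append(n)
    return bytes(out)

assert make_zigzag(-1, 32) == 1 and from_zigzag(1) == -1
assert to_varint(300) == b'\xac\x02'
assert to_varint(make_zigzag(-300, 32)) == b'\xd7\x04'

This also explains the field header logic in __writeFieldHeader: a field id within 15 of the previous one is packed into the high nibble of the type byte, and anything else falls back to a full zigzag-varint i16, so runs of sequentially numbered fields cost one header byte each.
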
diff --git a/lib/py/src/protocol/TJSONProtocol.py b/lib/py/src/protocol/TJSONProtocol.py
index f9e65fbf2..db2099a34 100644
--- a/lib/py/src/protocol/TJSONProtocol.py
+++ b/lib/py/src/protocol/TJSONProtocol.py
@@ -17,7 +17,8 @@
# under the License.
#
-from .TProtocol import TType, TProtocolBase, TProtocolException, checkIntegerLimits
+from .TProtocol import (TType, TProtocolBase, TProtocolException,
+ checkIntegerLimits)
import base64
import math
import sys
@@ -45,14 +46,14 @@ ZERO = b'0'
ESCSEQ0 = ord('\\')
ESCSEQ1 = ord('u')
ESCAPE_CHAR_VALS = {
- '"': '\\"',
- '\\': '\\\\',
- '\b': '\\b',
- '\f': '\\f',
- '\n': '\\n',
- '\r': '\\r',
- '\t': '\\t',
- # '/': '\\/',
+ '"': '\\"',
+ '\\': '\\\\',
+ '\b': '\\b',
+ '\f': '\\f',
+ '\n': '\\n',
+ '\r': '\\r',
+ '\t': '\\t',
+ # '/': '\\/',
}
ESCAPE_CHARS = {
b'"': '"',
@@ -66,519 +67,527 @@ ESCAPE_CHARS = {
}
NUMERIC_CHAR = b'+-.0123456789Ee'
-CTYPES = {TType.BOOL: 'tf',
- TType.BYTE: 'i8',
- TType.I16: 'i16',
- TType.I32: 'i32',
- TType.I64: 'i64',
- TType.DOUBLE: 'dbl',
- TType.STRING: 'str',
- TType.STRUCT: 'rec',
- TType.LIST: 'lst',
- TType.SET: 'set',
- TType.MAP: 'map'}
+CTYPES = {
+ TType.BOOL: 'tf',
+ TType.BYTE: 'i8',
+ TType.I16: 'i16',
+ TType.I32: 'i32',
+ TType.I64: 'i64',
+ TType.DOUBLE: 'dbl',
+ TType.STRING: 'str',
+ TType.STRUCT: 'rec',
+ TType.LIST: 'lst',
+ TType.SET: 'set',
+ TType.MAP: 'map',
+}
JTYPES = {}
for key in CTYPES.keys():
- JTYPES[CTYPES[key]] = key
+ JTYPES[CTYPES[key]] = key
class JSONBaseContext(object):
- def __init__(self, protocol):
- self.protocol = protocol
- self.first = True
+ def __init__(self, protocol):
+ self.protocol = protocol
+ self.first = True
- def doIO(self, function):
- pass
+ def doIO(self, function):
+ pass
- def write(self):
- pass
+ def write(self):
+ pass
- def read(self):
- pass
+ def read(self):
+ pass
- def escapeNum(self):
- return False
+ def escapeNum(self):
+ return False
- def __str__(self):
- return self.__class__.__name__
+ def __str__(self):
+ return self.__class__.__name__
class JSONListContext(JSONBaseContext):
- def doIO(self, function):
- if self.first is True:
- self.first = False
- else:
- function(COMMA)
+ def doIO(self, function):
+ if self.first is True:
+ self.first = False
+ else:
+ function(COMMA)
- def write(self):
- self.doIO(self.protocol.trans.write)
+ def write(self):
+ self.doIO(self.protocol.trans.write)
- def read(self):
- self.doIO(self.protocol.readJSONSyntaxChar)
+ def read(self):
+ self.doIO(self.protocol.readJSONSyntaxChar)
class JSONPairContext(JSONBaseContext):
- def __init__(self, protocol):
- super(JSONPairContext, self).__init__(protocol)
- self.colon = True
+ def __init__(self, protocol):
+ super(JSONPairContext, self).__init__(protocol)
+ self.colon = True
- def doIO(self, function):
- if self.first:
- self.first = False
- self.colon = True
- else:
- function(COLON if self.colon else COMMA)
- self.colon = not self.colon
+ def doIO(self, function):
+ if self.first:
+ self.first = False
+ self.colon = True
+ else:
+ function(COLON if self.colon else COMMA)
+ self.colon = not self.colon
- def write(self):
- self.doIO(self.protocol.trans.write)
+ def write(self):
+ self.doIO(self.protocol.trans.write)
- def read(self):
- self.doIO(self.protocol.readJSONSyntaxChar)
+ def read(self):
+ self.doIO(self.protocol.readJSONSyntaxChar)
- def escapeNum(self):
- return self.colon
+ def escapeNum(self):
+ return self.colon
- def __str__(self):
- return '%s, colon=%s' % (self.__class__.__name__, self.colon)
+ def __str__(self):
+ return '%s, colon=%s' % (self.__class__.__name__, self.colon)
class LookaheadReader():
- hasData = False
- data = ''
+ hasData = False
+ data = ''
- def __init__(self, protocol):
- self.protocol = protocol
+ def __init__(self, protocol):
+ self.protocol = protocol
- def read(self):
- if self.hasData is True:
- self.hasData = False
- else:
- self.data = self.protocol.trans.read(1)
- return self.data
+ def read(self):
+ if self.hasData is True:
+ self.hasData = False
+ else:
+ self.data = self.protocol.trans.read(1)
+ return self.data
- def peek(self):
- if self.hasData is False:
- self.data = self.protocol.trans.read(1)
- self.hasData = True
- return self.data
+ def peek(self):
+ if self.hasData is False:
+ self.data = self.protocol.trans.read(1)
+ self.hasData = True
+ return self.data
class TJSONProtocolBase(TProtocolBase):
- def __init__(self, trans):
- TProtocolBase.__init__(self, trans)
- self.resetWriteContext()
- self.resetReadContext()
-
- # We don't have length limit implementation for JSON protocols
- @property
- def string_length_limit(senf):
- return None
-
- @property
- def container_length_limit(senf):
- return None
-
- def resetWriteContext(self):
- self.context = JSONBaseContext(self)
- self.contextStack = [self.context]
-
- def resetReadContext(self):
- self.resetWriteContext()
- self.reader = LookaheadReader(self)
-
- def pushContext(self, ctx):
- self.contextStack.append(ctx)
- self.context = ctx
-
- def popContext(self):
- self.contextStack.pop()
- if self.contextStack:
- self.context = self.contextStack[-1]
- else:
- self.context = JSONBaseContext(self)
-
- def writeJSONString(self, string):
- self.context.write()
- json_str = ['"']
- for s in string:
- escaped = ESCAPE_CHAR_VALS.get(s, s)
- json_str.append(escaped)
- json_str.append('"')
- self.trans.write(str_to_binary(''.join(json_str)))
-
- def writeJSONNumber(self, number, formatter='{0}'):
- self.context.write()
- jsNumber = str(formatter.format(number)).encode('ascii')
- if self.context.escapeNum():
- self.trans.write(QUOTE)
- self.trans.write(jsNumber)
- self.trans.write(QUOTE)
- else:
- self.trans.write(jsNumber)
-
- def writeJSONBase64(self, binary):
- self.context.write()
- self.trans.write(QUOTE)
- self.trans.write(base64.b64encode(binary))
- self.trans.write(QUOTE)
-
- def writeJSONObjectStart(self):
- self.context.write()
- self.trans.write(LBRACE)
- self.pushContext(JSONPairContext(self))
-
- def writeJSONObjectEnd(self):
- self.popContext()
- self.trans.write(RBRACE)
-
- def writeJSONArrayStart(self):
- self.context.write()
- self.trans.write(LBRACKET)
- self.pushContext(JSONListContext(self))
-
- def writeJSONArrayEnd(self):
- self.popContext()
- self.trans.write(RBRACKET)
-
- def readJSONSyntaxChar(self, character):
- current = self.reader.read()
- if character != current:
- raise TProtocolException(TProtocolException.INVALID_DATA,
- "Unexpected character: %s" % current)
-
- def _isHighSurrogate(self, codeunit):
- return codeunit >= 0xd800 and codeunit <= 0xdbff
-
- def _isLowSurrogate(self, codeunit):
- return codeunit >= 0xdc00 and codeunit <= 0xdfff
-
- def _toChar(self, high, low=None):
- if not low:
- if sys.version_info[0] == 2:
- return ("\\u%04x" % high).decode('unicode-escape').encode('utf-8')
- else:
- return chr(high)
- else:
- codepoint = (1 << 16) + ((high & 0x3ff) << 10)
- codepoint += low & 0x3ff
- if sys.version_info[0] == 2:
- s = "\\U%08x" % codepoint
- return s.decode('unicode-escape').encode('utf-8')
- else:
- return chr(codepoint)
-
- def readJSONString(self, skipContext):
- highSurrogate = None
- string = []
- if skipContext is False:
- self.context.read()
- self.readJSONSyntaxChar(QUOTE)
- while True:
- character = self.reader.read()
- if character == QUOTE:
- break
- if ord(character) == ESCSEQ0:
- character = self.reader.read()
- if ord(character) == ESCSEQ1:
- character = self.trans.read(4).decode('ascii')
- codeunit = int(character, 16)
- if self._isHighSurrogate(codeunit):
- if highSurrogate:
- raise TProtocolException(TProtocolException.INVALID_DATA,
- "Expected low surrogate char")
- highSurrogate = codeunit
- continue
- elif self._isLowSurrogate(codeunit):
- if not highSurrogate:
- raise TProtocolException(TProtocolException.INVALID_DATA,
- "Expected high surrogate char")
- character = self._toChar(highSurrogate, codeunit)
- highSurrogate = None
- else:
- character = self._toChar(codeunit)
+ def __init__(self, trans):
+ TProtocolBase.__init__(self, trans)
+ self.resetWriteContext()
+ self.resetReadContext()
+
+ # We don't have length limit implementation for JSON protocols
+ @property
+    def string_length_limit(self):
+ return None
+
+ @property
+    def container_length_limit(self):
+ return None
+
+ def resetWriteContext(self):
+ self.context = JSONBaseContext(self)
+ self.contextStack = [self.context]
+
+ def resetReadContext(self):
+ self.resetWriteContext()
+ self.reader = LookaheadReader(self)
+
+ def pushContext(self, ctx):
+ self.contextStack.append(ctx)
+ self.context = ctx
+
+ def popContext(self):
+ self.contextStack.pop()
+ if self.contextStack:
+ self.context = self.contextStack[-1]
+ else:
+ self.context = JSONBaseContext(self)
+
+ def writeJSONString(self, string):
+ self.context.write()
+ json_str = ['"']
+ for s in string:
+ escaped = ESCAPE_CHAR_VALS.get(s, s)
+ json_str.append(escaped)
+ json_str.append('"')
+ self.trans.write(str_to_binary(''.join(json_str)))
+
+ def writeJSONNumber(self, number, formatter='{0}'):
+ self.context.write()
+ jsNumber = str(formatter.format(number)).encode('ascii')
+ if self.context.escapeNum():
+ self.trans.write(QUOTE)
+ self.trans.write(jsNumber)
+ self.trans.write(QUOTE)
else:
- if character not in ESCAPE_CHARS:
+ self.trans.write(jsNumber)
+
+ def writeJSONBase64(self, binary):
+ self.context.write()
+ self.trans.write(QUOTE)
+ self.trans.write(base64.b64encode(binary))
+ self.trans.write(QUOTE)
+
+ def writeJSONObjectStart(self):
+ self.context.write()
+ self.trans.write(LBRACE)
+ self.pushContext(JSONPairContext(self))
+
+ def writeJSONObjectEnd(self):
+ self.popContext()
+ self.trans.write(RBRACE)
+
+ def writeJSONArrayStart(self):
+ self.context.write()
+ self.trans.write(LBRACKET)
+ self.pushContext(JSONListContext(self))
+
+ def writeJSONArrayEnd(self):
+ self.popContext()
+ self.trans.write(RBRACKET)
+
+ def readJSONSyntaxChar(self, character):
+ current = self.reader.read()
+ if character != current:
raise TProtocolException(TProtocolException.INVALID_DATA,
- "Expected control char")
- character = ESCAPE_CHARS[character]
- elif character in ESCAPE_CHAR_VALS:
- raise TProtocolException(TProtocolException.INVALID_DATA,
- "Unescaped control char")
- elif sys.version_info[0] > 2:
- utf8_bytes = bytearray([ord(character)])
- while ord(self.reader.peek()) >= 0x80:
- utf8_bytes.append(ord(self.reader.read()))
- character = utf8_bytes.decode('utf8')
- string.append(character)
-
- if highSurrogate:
- raise TProtocolException(TProtocolException.INVALID_DATA,
- "Expected low surrogate char")
- return ''.join(string)
-
- def isJSONNumeric(self, character):
- return (True if NUMERIC_CHAR.find(character) != - 1 else False)
-
- def readJSONQuotes(self):
- if (self.context.escapeNum()):
- self.readJSONSyntaxChar(QUOTE)
-
- def readJSONNumericChars(self):
- numeric = []
- while True:
- character = self.reader.peek()
- if self.isJSONNumeric(character) is False:
- break
- numeric.append(self.reader.read())
- return b''.join(numeric).decode('ascii')
-
- def readJSONInteger(self):
- self.context.read()
- self.readJSONQuotes()
- numeric = self.readJSONNumericChars()
- self.readJSONQuotes()
- try:
- return int(numeric)
- except ValueError:
- raise TProtocolException(TProtocolException.INVALID_DATA,
- "Bad data encounted in numeric data")
-
- def readJSONDouble(self):
- self.context.read()
- if self.reader.peek() == QUOTE:
- string = self.readJSONString(True)
- try:
- double = float(string)
- if (self.context.escapeNum is False and
- not math.isinf(double) and
- not math.isnan(double)):
- raise TProtocolException(TProtocolException.INVALID_DATA,
- "Numeric data unexpectedly quoted")
- return double
- except ValueError:
- raise TProtocolException(TProtocolException.INVALID_DATA,
- "Bad data encounted in numeric data")
- else:
- if self.context.escapeNum() is True:
+ "Unexpected character: %s" % current)
+
+ def _isHighSurrogate(self, codeunit):
+ return codeunit >= 0xd800 and codeunit <= 0xdbff
+
+ def _isLowSurrogate(self, codeunit):
+ return codeunit >= 0xdc00 and codeunit <= 0xdfff
+
+ def _toChar(self, high, low=None):
+ if not low:
+ if sys.version_info[0] == 2:
+ return ("\\u%04x" % high).decode('unicode-escape') \
+ .encode('utf-8')
+ else:
+ return chr(high)
+ else:
+ codepoint = (1 << 16) + ((high & 0x3ff) << 10)
+ codepoint += low & 0x3ff
+ if sys.version_info[0] == 2:
+ s = "\\U%08x" % codepoint
+ return s.decode('unicode-escape').encode('utf-8')
+ else:
+ return chr(codepoint)
+
+ def readJSONString(self, skipContext):
+ highSurrogate = None
+ string = []
+ if skipContext is False:
+ self.context.read()
self.readJSONSyntaxChar(QUOTE)
- try:
- return float(self.readJSONNumericChars())
- except ValueError:
- raise TProtocolException(TProtocolException.INVALID_DATA,
- "Bad data encounted in numeric data")
-
- def readJSONBase64(self):
- string = self.readJSONString(False)
- size = len(string)
- m = size % 4
- # Force padding since b64encode method does not allow it
- if m != 0:
- for i in range(4 - m):
- string += '='
- return base64.b64decode(string)
-
- def readJSONObjectStart(self):
- self.context.read()
- self.readJSONSyntaxChar(LBRACE)
- self.pushContext(JSONPairContext(self))
-
- def readJSONObjectEnd(self):
- self.readJSONSyntaxChar(RBRACE)
- self.popContext()
-
- def readJSONArrayStart(self):
- self.context.read()
- self.readJSONSyntaxChar(LBRACKET)
- self.pushContext(JSONListContext(self))
-
- def readJSONArrayEnd(self):
- self.readJSONSyntaxChar(RBRACKET)
- self.popContext()
+ while True:
+ character = self.reader.read()
+ if character == QUOTE:
+ break
+ if ord(character) == ESCSEQ0:
+ character = self.reader.read()
+ if ord(character) == ESCSEQ1:
+ character = self.trans.read(4).decode('ascii')
+ codeunit = int(character, 16)
+ if self._isHighSurrogate(codeunit):
+ if highSurrogate:
+ raise TProtocolException(
+ TProtocolException.INVALID_DATA,
+ "Expected low surrogate char")
+ highSurrogate = codeunit
+ continue
+ elif self._isLowSurrogate(codeunit):
+ if not highSurrogate:
+ raise TProtocolException(
+ TProtocolException.INVALID_DATA,
+ "Expected high surrogate char")
+ character = self._toChar(highSurrogate, codeunit)
+ highSurrogate = None
+ else:
+ character = self._toChar(codeunit)
+ else:
+ if character not in ESCAPE_CHARS:
+ raise TProtocolException(
+ TProtocolException.INVALID_DATA,
+ "Expected control char")
+ character = ESCAPE_CHARS[character]
+ elif character in ESCAPE_CHAR_VALS:
+ raise TProtocolException(TProtocolException.INVALID_DATA,
+ "Unescaped control char")
+ elif sys.version_info[0] > 2:
+ utf8_bytes = bytearray([ord(character)])
+ while ord(self.reader.peek()) >= 0x80:
+ utf8_bytes.append(ord(self.reader.read()))
+ character = utf8_bytes.decode('utf8')
+ string.append(character)
+
+ if highSurrogate:
+ raise TProtocolException(TProtocolException.INVALID_DATA,
+ "Expected low surrogate char")
+ return ''.join(string)
+
+ def isJSONNumeric(self, character):
+        return NUMERIC_CHAR.find(character) != -1
+
+ def readJSONQuotes(self):
+ if (self.context.escapeNum()):
+ self.readJSONSyntaxChar(QUOTE)
+
+ def readJSONNumericChars(self):
+ numeric = []
+ while True:
+ character = self.reader.peek()
+ if self.isJSONNumeric(character) is False:
+ break
+ numeric.append(self.reader.read())
+ return b''.join(numeric).decode('ascii')
+
+ def readJSONInteger(self):
+ self.context.read()
+ self.readJSONQuotes()
+ numeric = self.readJSONNumericChars()
+ self.readJSONQuotes()
+ try:
+ return int(numeric)
+ except ValueError:
+ raise TProtocolException(TProtocolException.INVALID_DATA,
+                                     "Bad data encountered in numeric data")
+
+ def readJSONDouble(self):
+ self.context.read()
+ if self.reader.peek() == QUOTE:
+ string = self.readJSONString(True)
+ try:
+ double = float(string)
+                if (self.context.escapeNum() is False and
+ not math.isinf(double) and
+ not math.isnan(double)):
+ raise TProtocolException(
+ TProtocolException.INVALID_DATA,
+ "Numeric data unexpectedly quoted")
+ return double
+ except ValueError:
+ raise TProtocolException(TProtocolException.INVALID_DATA,
+                                         "Bad data encountered in numeric data")
+ else:
+ if self.context.escapeNum() is True:
+ self.readJSONSyntaxChar(QUOTE)
+ try:
+ return float(self.readJSONNumericChars())
+ except ValueError:
+ raise TProtocolException(TProtocolException.INVALID_DATA,
+                                         "Bad data encountered in numeric data")
+
+ def readJSONBase64(self):
+ string = self.readJSONString(False)
+ size = len(string)
+ m = size % 4
+        # Force padding since b64decode requires it and the sender may omit it
+ if m != 0:
+ for i in range(4 - m):
+ string += '='
+ return base64.b64decode(string)
+
+ def readJSONObjectStart(self):
+ self.context.read()
+ self.readJSONSyntaxChar(LBRACE)
+ self.pushContext(JSONPairContext(self))
+
+ def readJSONObjectEnd(self):
+ self.readJSONSyntaxChar(RBRACE)
+ self.popContext()
+
+ def readJSONArrayStart(self):
+ self.context.read()
+ self.readJSONSyntaxChar(LBRACKET)
+ self.pushContext(JSONListContext(self))
+
+ def readJSONArrayEnd(self):
+ self.readJSONSyntaxChar(RBRACKET)
+ self.popContext()
class TJSONProtocol(TJSONProtocolBase):
- def readMessageBegin(self):
- self.resetReadContext()
- self.readJSONArrayStart()
- if self.readJSONInteger() != VERSION:
- raise TProtocolException(TProtocolException.BAD_VERSION,
- "Message contained bad version.")
- name = self.readJSONString(False)
- typen = self.readJSONInteger()
- seqid = self.readJSONInteger()
- return (name, typen, seqid)
-
- def readMessageEnd(self):
- self.readJSONArrayEnd()
-
- def readStructBegin(self):
- self.readJSONObjectStart()
-
- def readStructEnd(self):
- self.readJSONObjectEnd()
-
- def readFieldBegin(self):
- character = self.reader.peek()
- ttype = 0
- id = 0
- if character == RBRACE:
- ttype = TType.STOP
- else:
- id = self.readJSONInteger()
- self.readJSONObjectStart()
- ttype = JTYPES[self.readJSONString(False)]
- return (None, ttype, id)
-
- def readFieldEnd(self):
- self.readJSONObjectEnd()
-
- def readMapBegin(self):
- self.readJSONArrayStart()
- keyType = JTYPES[self.readJSONString(False)]
- valueType = JTYPES[self.readJSONString(False)]
- size = self.readJSONInteger()
- self.readJSONObjectStart()
- return (keyType, valueType, size)
-
- def readMapEnd(self):
- self.readJSONObjectEnd()
- self.readJSONArrayEnd()
-
- def readCollectionBegin(self):
- self.readJSONArrayStart()
- elemType = JTYPES[self.readJSONString(False)]
- size = self.readJSONInteger()
- return (elemType, size)
- readListBegin = readCollectionBegin
- readSetBegin = readCollectionBegin
-
- def readCollectionEnd(self):
- self.readJSONArrayEnd()
- readSetEnd = readCollectionEnd
- readListEnd = readCollectionEnd
-
- def readBool(self):
- return (False if self.readJSONInteger() == 0 else True)
-
- def readNumber(self):
- return self.readJSONInteger()
- readByte = readNumber
- readI16 = readNumber
- readI32 = readNumber
- readI64 = readNumber
-
- def readDouble(self):
- return self.readJSONDouble()
-
- def readString(self):
- return self.readJSONString(False)
-
- def readBinary(self):
- return self.readJSONBase64()
-
- def writeMessageBegin(self, name, request_type, seqid):
- self.resetWriteContext()
- self.writeJSONArrayStart()
- self.writeJSONNumber(VERSION)
- self.writeJSONString(name)
- self.writeJSONNumber(request_type)
- self.writeJSONNumber(seqid)
-
- def writeMessageEnd(self):
- self.writeJSONArrayEnd()
-
- def writeStructBegin(self, name):
- self.writeJSONObjectStart()
-
- def writeStructEnd(self):
- self.writeJSONObjectEnd()
-
- def writeFieldBegin(self, name, ttype, id):
- self.writeJSONNumber(id)
- self.writeJSONObjectStart()
- self.writeJSONString(CTYPES[ttype])
-
- def writeFieldEnd(self):
- self.writeJSONObjectEnd()
-
- def writeFieldStop(self):
- pass
-
- def writeMapBegin(self, ktype, vtype, size):
- self.writeJSONArrayStart()
- self.writeJSONString(CTYPES[ktype])
- self.writeJSONString(CTYPES[vtype])
- self.writeJSONNumber(size)
- self.writeJSONObjectStart()
-
- def writeMapEnd(self):
- self.writeJSONObjectEnd()
- self.writeJSONArrayEnd()
-
- def writeListBegin(self, etype, size):
- self.writeJSONArrayStart()
- self.writeJSONString(CTYPES[etype])
- self.writeJSONNumber(size)
-
- def writeListEnd(self):
- self.writeJSONArrayEnd()
-
- def writeSetBegin(self, etype, size):
- self.writeJSONArrayStart()
- self.writeJSONString(CTYPES[etype])
- self.writeJSONNumber(size)
+ def readMessageBegin(self):
+ self.resetReadContext()
+ self.readJSONArrayStart()
+ if self.readJSONInteger() != VERSION:
+ raise TProtocolException(TProtocolException.BAD_VERSION,
+ "Message contained bad version.")
+ name = self.readJSONString(False)
+ typen = self.readJSONInteger()
+ seqid = self.readJSONInteger()
+ return (name, typen, seqid)
+
+ def readMessageEnd(self):
+ self.readJSONArrayEnd()
+
+ def readStructBegin(self):
+ self.readJSONObjectStart()
+
+ def readStructEnd(self):
+ self.readJSONObjectEnd()
+
+ def readFieldBegin(self):
+ character = self.reader.peek()
+ ttype = 0
+ id = 0
+ if character == RBRACE:
+ ttype = TType.STOP
+ else:
+ id = self.readJSONInteger()
+ self.readJSONObjectStart()
+ ttype = JTYPES[self.readJSONString(False)]
+ return (None, ttype, id)
+
+ def readFieldEnd(self):
+ self.readJSONObjectEnd()
+
+ def readMapBegin(self):
+ self.readJSONArrayStart()
+ keyType = JTYPES[self.readJSONString(False)]
+ valueType = JTYPES[self.readJSONString(False)]
+ size = self.readJSONInteger()
+ self.readJSONObjectStart()
+ return (keyType, valueType, size)
+
+ def readMapEnd(self):
+ self.readJSONObjectEnd()
+ self.readJSONArrayEnd()
+
+ def readCollectionBegin(self):
+ self.readJSONArrayStart()
+ elemType = JTYPES[self.readJSONString(False)]
+ size = self.readJSONInteger()
+ return (elemType, size)
+ readListBegin = readCollectionBegin
+ readSetBegin = readCollectionBegin
+
+ def readCollectionEnd(self):
+ self.readJSONArrayEnd()
+ readSetEnd = readCollectionEnd
+ readListEnd = readCollectionEnd
+
+ def readBool(self):
+ return (False if self.readJSONInteger() == 0 else True)
+
+ def readNumber(self):
+ return self.readJSONInteger()
+ readByte = readNumber
+ readI16 = readNumber
+ readI32 = readNumber
+ readI64 = readNumber
+
+ def readDouble(self):
+ return self.readJSONDouble()
+
+ def readString(self):
+ return self.readJSONString(False)
+
+ def readBinary(self):
+ return self.readJSONBase64()
+
+ def writeMessageBegin(self, name, request_type, seqid):
+ self.resetWriteContext()
+ self.writeJSONArrayStart()
+ self.writeJSONNumber(VERSION)
+ self.writeJSONString(name)
+ self.writeJSONNumber(request_type)
+ self.writeJSONNumber(seqid)
+
+ def writeMessageEnd(self):
+ self.writeJSONArrayEnd()
+
+ def writeStructBegin(self, name):
+ self.writeJSONObjectStart()
+
+ def writeStructEnd(self):
+ self.writeJSONObjectEnd()
- def writeSetEnd(self):
- self.writeJSONArrayEnd()
+ def writeFieldBegin(self, name, ttype, id):
+ self.writeJSONNumber(id)
+ self.writeJSONObjectStart()
+ self.writeJSONString(CTYPES[ttype])
- def writeBool(self, boolean):
- self.writeJSONNumber(1 if boolean is True else 0)
-
- def writeByte(self, byte):
- checkIntegerLimits(byte, 8)
- self.writeJSONNumber(byte)
+ def writeFieldEnd(self):
+ self.writeJSONObjectEnd()
- def writeI16(self, i16):
- checkIntegerLimits(i16, 16)
- self.writeJSONNumber(i16)
+ def writeFieldStop(self):
+ pass
- def writeI32(self, i32):
- checkIntegerLimits(i32, 32)
- self.writeJSONNumber(i32)
+ def writeMapBegin(self, ktype, vtype, size):
+ self.writeJSONArrayStart()
+ self.writeJSONString(CTYPES[ktype])
+ self.writeJSONString(CTYPES[vtype])
+ self.writeJSONNumber(size)
+ self.writeJSONObjectStart()
- def writeI64(self, i64):
- checkIntegerLimits(i64, 64)
- self.writeJSONNumber(i64)
+ def writeMapEnd(self):
+ self.writeJSONObjectEnd()
+ self.writeJSONArrayEnd()
- def writeDouble(self, dbl):
- # 17 significant digits should be just enough for any double precision value.
- self.writeJSONNumber(dbl, '{0:.17g}')
+ def writeListBegin(self, etype, size):
+ self.writeJSONArrayStart()
+ self.writeJSONString(CTYPES[etype])
+ self.writeJSONNumber(size)
+
+ def writeListEnd(self):
+ self.writeJSONArrayEnd()
- def writeString(self, string):
- self.writeJSONString(string)
+ def writeSetBegin(self, etype, size):
+ self.writeJSONArrayStart()
+ self.writeJSONString(CTYPES[etype])
+ self.writeJSONNumber(size)
- def writeBinary(self, binary):
- self.writeJSONBase64(binary)
+ def writeSetEnd(self):
+ self.writeJSONArrayEnd()
+
+ def writeBool(self, boolean):
+ self.writeJSONNumber(1 if boolean is True else 0)
+
+ def writeByte(self, byte):
+ checkIntegerLimits(byte, 8)
+ self.writeJSONNumber(byte)
+
+ def writeI16(self, i16):
+ checkIntegerLimits(i16, 16)
+ self.writeJSONNumber(i16)
+
+ def writeI32(self, i32):
+ checkIntegerLimits(i32, 32)
+ self.writeJSONNumber(i32)
+
+ def writeI64(self, i64):
+ checkIntegerLimits(i64, 64)
+ self.writeJSONNumber(i64)
+
+ def writeDouble(self, dbl):
+ # 17 significant digits should be just enough for any double precision
+ # value.
+ self.writeJSONNumber(dbl, '{0:.17g}')
+
+ def writeString(self, string):
+ self.writeJSONString(string)
+
+ def writeBinary(self, binary):
+ self.writeJSONBase64(binary)
class TJSONProtocolFactory(object):
- def getProtocol(self, trans):
- return TJSONProtocol(trans)
+ def getProtocol(self, trans):
+ return TJSONProtocol(trans)
- @property
- def string_length_limit(senf):
- return None
+ @property
+    def string_length_limit(self):
+ return None
- @property
- def container_length_limit(senf):
- return None
+ @property
+    def container_length_limit(self):
+ return None
class TSimpleJSONProtocol(TJSONProtocolBase):
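
To make the JSON wire format concrete, here is a minimal sketch that writes an empty call into an in-memory transport using only methods visible in the hunk above; the method name 'ping' and the seqid are placeholder values:

from thrift.Thrift import TMessageType
from thrift.protocol import TJSONProtocol
from thrift.transport import TTransport

trans = TTransport.TMemoryBuffer()
proto = TJSONProtocol.TJSONProtocolFactory().getProtocol(trans)

proto.writeMessageBegin('ping', TMessageType.CALL, 0)   # [1,"ping",1,0
proto.writeStructBegin('ping_args')                     # ,{
proto.writeFieldStop()                                  # no fields
proto.writeStructEnd()                                  # }
proto.writeMessageEnd()                                 # ]

print(trans.getvalue())   # expected: b'[1,"ping",1,0,{}]'

The JSONListContext/JSONPairContext stack supplies the commas and colons; note that the struct name passed to writeStructBegin never reaches the wire.
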
diff --git a/lib/py/src/protocol/TMultiplexedProtocol.py b/lib/py/src/protocol/TMultiplexedProtocol.py
index d25f367b5..309f896d0 100644
--- a/lib/py/src/protocol/TMultiplexedProtocol.py
+++ b/lib/py/src/protocol/TMultiplexedProtocol.py
@@ -22,18 +22,19 @@ from thrift.protocol import TProtocolDecorator
SEPARATOR = ":"
+
class TMultiplexedProtocol(TProtocolDecorator.TProtocolDecorator):
- def __init__(self, protocol, serviceName):
- TProtocolDecorator.TProtocolDecorator.__init__(self, protocol)
- self.serviceName = serviceName
+ def __init__(self, protocol, serviceName):
+ TProtocolDecorator.TProtocolDecorator.__init__(self, protocol)
+ self.serviceName = serviceName
- def writeMessageBegin(self, name, type, seqid):
- if (type == TMessageType.CALL or
- type == TMessageType.ONEWAY):
- self.protocol.writeMessageBegin(
- self.serviceName + SEPARATOR + name,
- type,
- seqid
- )
- else:
- self.protocol.writeMessageBegin(name, type, seqid)
+ def writeMessageBegin(self, name, type, seqid):
+ if (type == TMessageType.CALL or
+ type == TMessageType.ONEWAY):
+ self.protocol.writeMessageBegin(
+ self.serviceName + SEPARATOR + name,
+ type,
+ seqid
+ )
+ else:
+ self.protocol.writeMessageBegin(name, type, seqid)
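
For context, the decorator reindented above only rewrites outgoing CALL and ONEWAY message names to "<serviceName>:<method>" so that a TMultiplexedProcessor on the server side can route them; every other call is delegated to the wrapped protocol unchanged. A hedged sketch follows, in which 'Calculator' and the commented-out generated client are placeholders from a hypothetical IDL and the endpoint is an example:

from thrift.protocol import TBinaryProtocol, TMultiplexedProtocol
from thrift.transport import TSocket, TTransport

socket = TSocket.TSocket('localhost', 9090)         # example endpoint
transport = TTransport.TBufferedTransport(socket)
base = TBinaryProtocol.TBinaryProtocol(transport)

# CALL/ONEWAY messages written through this wrapper are renamed
# "Calculator:<method>"; other message types pass through untouched.
proto = TMultiplexedProtocol.TMultiplexedProtocol(base, 'Calculator')

# client = Calculator.Client(proto)    # hypothetical generated client
# transport.open(); client.add(1, 2)
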
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py
index d9aa2e82b..ed6938bb6 100644
--- a/lib/py/src/protocol/TProtocol.py
+++ b/lib/py/src/protocol/TProtocol.py
@@ -28,373 +28,373 @@ from six.moves import zip
class TProtocolException(TException):
- """Custom Protocol Exception class"""
+ """Custom Protocol Exception class"""
- UNKNOWN = 0
- INVALID_DATA = 1
- NEGATIVE_SIZE = 2
- SIZE_LIMIT = 3
- BAD_VERSION = 4
- NOT_IMPLEMENTED = 5
- DEPTH_LIMIT = 6
+ UNKNOWN = 0
+ INVALID_DATA = 1
+ NEGATIVE_SIZE = 2
+ SIZE_LIMIT = 3
+ BAD_VERSION = 4
+ NOT_IMPLEMENTED = 5
+ DEPTH_LIMIT = 6
- def __init__(self, type=UNKNOWN, message=None):
- TException.__init__(self, message)
- self.type = type
+ def __init__(self, type=UNKNOWN, message=None):
+ TException.__init__(self, message)
+ self.type = type
class TProtocolBase(object):
- """Base class for Thrift protocol driver."""
+ """Base class for Thrift protocol driver."""
- def __init__(self, trans):
- self.trans = trans
+ def __init__(self, trans):
+ self.trans = trans
- @staticmethod
- def _check_length(limit, length):
- if length < 0:
- raise TTransportException(TTransportException.NEGATIVE_SIZE,
- 'Negative length: %d' % length)
- if limit is not None and length > limit:
- raise TTransportException(TTransportException.SIZE_LIMIT,
- 'Length exceeded max allowed: %d' % limit)
+ @staticmethod
+ def _check_length(limit, length):
+ if length < 0:
+ raise TTransportException(TTransportException.NEGATIVE_SIZE,
+ 'Negative length: %d' % length)
+ if limit is not None and length > limit:
+ raise TTransportException(TTransportException.SIZE_LIMIT,
+ 'Length exceeded max allowed: %d' % limit)
- def writeMessageBegin(self, name, ttype, seqid):
- pass
+ def writeMessageBegin(self, name, ttype, seqid):
+ pass
- def writeMessageEnd(self):
- pass
+ def writeMessageEnd(self):
+ pass
- def writeStructBegin(self, name):
- pass
+ def writeStructBegin(self, name):
+ pass
- def writeStructEnd(self):
- pass
+ def writeStructEnd(self):
+ pass
- def writeFieldBegin(self, name, ttype, fid):
- pass
+ def writeFieldBegin(self, name, ttype, fid):
+ pass
- def writeFieldEnd(self):
- pass
+ def writeFieldEnd(self):
+ pass
- def writeFieldStop(self):
- pass
+ def writeFieldStop(self):
+ pass
- def writeMapBegin(self, ktype, vtype, size):
- pass
+ def writeMapBegin(self, ktype, vtype, size):
+ pass
- def writeMapEnd(self):
- pass
+ def writeMapEnd(self):
+ pass
- def writeListBegin(self, etype, size):
- pass
+ def writeListBegin(self, etype, size):
+ pass
- def writeListEnd(self):
- pass
+ def writeListEnd(self):
+ pass
- def writeSetBegin(self, etype, size):
- pass
+ def writeSetBegin(self, etype, size):
+ pass
- def writeSetEnd(self):
- pass
+ def writeSetEnd(self):
+ pass
- def writeBool(self, bool_val):
- pass
+ def writeBool(self, bool_val):
+ pass
- def writeByte(self, byte):
- pass
+ def writeByte(self, byte):
+ pass
- def writeI16(self, i16):
- pass
+ def writeI16(self, i16):
+ pass
- def writeI32(self, i32):
- pass
+ def writeI32(self, i32):
+ pass
- def writeI64(self, i64):
- pass
+ def writeI64(self, i64):
+ pass
- def writeDouble(self, dub):
- pass
+ def writeDouble(self, dub):
+ pass
- def writeString(self, str_val):
- self.writeBinary(str_to_binary(str_val))
+ def writeString(self, str_val):
+ self.writeBinary(str_to_binary(str_val))
- def writeBinary(self, str_val):
- pass
+ def writeBinary(self, str_val):
+ pass
- def writeUtf8(self, str_val):
- self.writeString(str_val.encode('utf8'))
+ def writeUtf8(self, str_val):
+ self.writeString(str_val.encode('utf8'))
- def readMessageBegin(self):
- pass
+ def readMessageBegin(self):
+ pass
- def readMessageEnd(self):
- pass
+ def readMessageEnd(self):
+ pass
- def readStructBegin(self):
- pass
+ def readStructBegin(self):
+ pass
- def readStructEnd(self):
- pass
+ def readStructEnd(self):
+ pass
- def readFieldBegin(self):
- pass
+ def readFieldBegin(self):
+ pass
- def readFieldEnd(self):
- pass
+ def readFieldEnd(self):
+ pass
- def readMapBegin(self):
- pass
+ def readMapBegin(self):
+ pass
- def readMapEnd(self):
- pass
+ def readMapEnd(self):
+ pass
- def readListBegin(self):
- pass
+ def readListBegin(self):
+ pass
- def readListEnd(self):
- pass
+ def readListEnd(self):
+ pass
- def readSetBegin(self):
- pass
+ def readSetBegin(self):
+ pass
- def readSetEnd(self):
- pass
+ def readSetEnd(self):
+ pass
- def readBool(self):
- pass
+ def readBool(self):
+ pass
- def readByte(self):
- pass
+ def readByte(self):
+ pass
- def readI16(self):
- pass
+ def readI16(self):
+ pass
- def readI32(self):
- pass
+ def readI32(self):
+ pass
- def readI64(self):
- pass
+ def readI64(self):
+ pass
- def readDouble(self):
- pass
+ def readDouble(self):
+ pass
- def readString(self):
- return binary_to_str(self.readBinary())
+ def readString(self):
+ return binary_to_str(self.readBinary())
- def readBinary(self):
- pass
+ def readBinary(self):
+ pass
- def readUtf8(self):
- return self.readString().decode('utf8')
+ def readUtf8(self):
+ return self.readString().decode('utf8')
- def skip(self, ttype):
- if ttype == TType.STOP:
- return
- elif ttype == TType.BOOL:
- self.readBool()
- elif ttype == TType.BYTE:
- self.readByte()
- elif ttype == TType.I16:
- self.readI16()
- elif ttype == TType.I32:
- self.readI32()
- elif ttype == TType.I64:
- self.readI64()
- elif ttype == TType.DOUBLE:
- self.readDouble()
- elif ttype == TType.STRING:
- self.readString()
- elif ttype == TType.STRUCT:
- name = self.readStructBegin()
- while True:
- (name, ttype, id) = self.readFieldBegin()
+ def skip(self, ttype):
if ttype == TType.STOP:
- break
- self.skip(ttype)
- self.readFieldEnd()
- self.readStructEnd()
- elif ttype == TType.MAP:
- (ktype, vtype, size) = self.readMapBegin()
- for i in range(size):
- self.skip(ktype)
- self.skip(vtype)
- self.readMapEnd()
- elif ttype == TType.SET:
- (etype, size) = self.readSetBegin()
- for i in range(size):
- self.skip(etype)
- self.readSetEnd()
- elif ttype == TType.LIST:
- (etype, size) = self.readListBegin()
- for i in range(size):
- self.skip(etype)
- self.readListEnd()
-
- # tuple of: ( 'reader method' name, is_container bool, 'writer_method' name )
- _TTYPE_HANDLERS = (
- (None, None, False), # 0 TType.STOP
- (None, None, False), # 1 TType.VOID # TODO: handle void?
- ('readBool', 'writeBool', False), # 2 TType.BOOL
- ('readByte', 'writeByte', False), # 3 TType.BYTE and I08
- ('readDouble', 'writeDouble', False), # 4 TType.DOUBLE
- (None, None, False), # 5 undefined
- ('readI16', 'writeI16', False), # 6 TType.I16
- (None, None, False), # 7 undefined
- ('readI32', 'writeI32', False), # 8 TType.I32
- (None, None, False), # 9 undefined
- ('readI64', 'writeI64', False), # 10 TType.I64
- ('readString', 'writeString', False), # 11 TType.STRING and UTF7
- ('readContainerStruct', 'writeContainerStruct', True), # 12 *.STRUCT
- ('readContainerMap', 'writeContainerMap', True), # 13 TType.MAP
- ('readContainerSet', 'writeContainerSet', True), # 14 TType.SET
- ('readContainerList', 'writeContainerList', True), # 15 TType.LIST
- (None, None, False), # 16 TType.UTF8 # TODO: handle utf8 types?
- (None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types?
- )
-
- def _ttype_handlers(self, ttype, spec):
- if spec == 'BINARY':
- if ttype != TType.STRING:
- raise TProtocolException(type=TProtocolException.INVALID_DATA,
- message='Invalid binary field type %d' % ttype)
- return ('readBinary', 'writeBinary', False)
- if sys.version_info[0] == 2 and spec == 'UTF8':
- if ttype != TType.STRING:
- raise TProtocolException(type=TProtocolException.INVALID_DATA,
- message='Invalid string field type %d' % ttype)
- return ('readUtf8', 'writeUtf8', False)
- return self._TTYPE_HANDLERS[ttype] if ttype < len(self._TTYPE_HANDLERS) else (None, None, False)
-
- def _read_by_ttype(self, ttype, spec, espec):
- reader_name, _, is_container = self._ttype_handlers(ttype, spec)
- if reader_name is None:
- raise TProtocolException(type=TProtocolException.INVALID_DATA,
- message='Invalid type %d' % (ttype))
- reader_func = getattr(self, reader_name)
- read = (lambda: reader_func(espec)) if is_container else reader_func
- while True:
- yield read()
-
- def readFieldByTType(self, ttype, spec):
- return self._read_by_ttype(ttype, spec, spec).next()
-
- def readContainerList(self, spec):
- ttype, tspec, is_immutable = spec
- (list_type, list_len) = self.readListBegin()
- # TODO: compare types we just decoded with thrift_spec
- elems = islice(self._read_by_ttype(ttype, spec, tspec), list_len)
- results = (tuple if is_immutable else list)(elems)
- self.readListEnd()
- return results
-
- def readContainerSet(self, spec):
- ttype, tspec, is_immutable = spec
- (set_type, set_len) = self.readSetBegin()
- # TODO: compare types we just decoded with thrift_spec
- elems = islice(self._read_by_ttype(ttype, spec, tspec), set_len)
- results = (frozenset if is_immutable else set)(elems)
- self.readSetEnd()
- return results
-
- def readContainerStruct(self, spec):
- (obj_class, obj_spec) = spec
- obj = obj_class()
- obj.read(self)
- return obj
-
- def readContainerMap(self, spec):
- ktype, kspec, vtype, vspec, is_immutable = spec
- (map_ktype, map_vtype, map_len) = self.readMapBegin()
- # TODO: compare types we just decoded with thrift_spec and
- # abort/skip if types disagree
- keys = self._read_by_ttype(ktype, spec, kspec)
- vals = self._read_by_ttype(vtype, spec, vspec)
- keyvals = islice(zip(keys, vals), map_len)
- results = (TFrozenDict if is_immutable else dict)(keyvals)
- self.readMapEnd()
- return results
-
- def readStruct(self, obj, thrift_spec, is_immutable=False):
- if is_immutable:
- fields = {}
- self.readStructBegin()
- while True:
- (fname, ftype, fid) = self.readFieldBegin()
- if ftype == TType.STOP:
- break
- try:
- field = thrift_spec[fid]
- except IndexError:
- self.skip(ftype)
- else:
- if field is not None and ftype == field[1]:
- fname = field[2]
- fspec = field[3]
- val = self.readFieldByTType(ftype, fspec)
- if is_immutable:
- fields[fname] = val
- else:
- setattr(obj, fname, val)
- else:
- self.skip(ftype)
- self.readFieldEnd()
- self.readStructEnd()
- if is_immutable:
- return obj(**fields)
-
- def writeContainerStruct(self, val, spec):
- val.write(self)
-
- def writeContainerList(self, val, spec):
- ttype, tspec, _ = spec
- self.writeListBegin(ttype, len(val))
- for _ in self._write_by_ttype(ttype, val, spec, tspec):
- pass
- self.writeListEnd()
-
- def writeContainerSet(self, val, spec):
- ttype, tspec, _ = spec
- self.writeSetBegin(ttype, len(val))
- for _ in self._write_by_ttype(ttype, val, spec, tspec):
- pass
- self.writeSetEnd()
-
- def writeContainerMap(self, val, spec):
- ktype, kspec, vtype, vspec, _ = spec
- self.writeMapBegin(ktype, vtype, len(val))
- for _ in zip(self._write_by_ttype(ktype, six.iterkeys(val), spec, kspec),
- self._write_by_ttype(vtype, six.itervalues(val), spec, vspec)):
- pass
- self.writeMapEnd()
-
- def writeStruct(self, obj, thrift_spec):
- self.writeStructBegin(obj.__class__.__name__)
- for field in thrift_spec:
- if field is None:
- continue
- fname = field[2]
- val = getattr(obj, fname)
- if val is None:
- # skip writing out unset fields
- continue
- fid = field[0]
- ftype = field[1]
- fspec = field[3]
- self.writeFieldBegin(fname, ftype, fid)
- self.writeFieldByTType(ftype, val, fspec)
- self.writeFieldEnd()
- self.writeFieldStop()
- self.writeStructEnd()
-
- def _write_by_ttype(self, ttype, vals, spec, espec):
- _, writer_name, is_container = self._ttype_handlers(ttype, spec)
- writer_func = getattr(self, writer_name)
- write = (lambda v: writer_func(v, espec)) if is_container else writer_func
- for v in vals:
- yield write(v)
-
- def writeFieldByTType(self, ttype, val, spec):
- self._write_by_ttype(ttype, [val], spec, spec).next()
+ return
+ elif ttype == TType.BOOL:
+ self.readBool()
+ elif ttype == TType.BYTE:
+ self.readByte()
+ elif ttype == TType.I16:
+ self.readI16()
+ elif ttype == TType.I32:
+ self.readI32()
+ elif ttype == TType.I64:
+ self.readI64()
+ elif ttype == TType.DOUBLE:
+ self.readDouble()
+ elif ttype == TType.STRING:
+ self.readString()
+ elif ttype == TType.STRUCT:
+ name = self.readStructBegin()
+ while True:
+ (name, ttype, id) = self.readFieldBegin()
+ if ttype == TType.STOP:
+ break
+ self.skip(ttype)
+ self.readFieldEnd()
+ self.readStructEnd()
+ elif ttype == TType.MAP:
+ (ktype, vtype, size) = self.readMapBegin()
+ for i in range(size):
+ self.skip(ktype)
+ self.skip(vtype)
+ self.readMapEnd()
+ elif ttype == TType.SET:
+ (etype, size) = self.readSetBegin()
+ for i in range(size):
+ self.skip(etype)
+ self.readSetEnd()
+ elif ttype == TType.LIST:
+ (etype, size) = self.readListBegin()
+ for i in range(size):
+ self.skip(etype)
+ self.readListEnd()
+
+    # tuple of: ('reader method' name, 'writer method' name, is_container bool)
+ _TTYPE_HANDLERS = (
+ (None, None, False), # 0 TType.STOP
+ (None, None, False), # 1 TType.VOID # TODO: handle void?
+ ('readBool', 'writeBool', False), # 2 TType.BOOL
+ ('readByte', 'writeByte', False), # 3 TType.BYTE and I08
+ ('readDouble', 'writeDouble', False), # 4 TType.DOUBLE
+ (None, None, False), # 5 undefined
+ ('readI16', 'writeI16', False), # 6 TType.I16
+ (None, None, False), # 7 undefined
+ ('readI32', 'writeI32', False), # 8 TType.I32
+ (None, None, False), # 9 undefined
+ ('readI64', 'writeI64', False), # 10 TType.I64
+ ('readString', 'writeString', False), # 11 TType.STRING and UTF7
+ ('readContainerStruct', 'writeContainerStruct', True), # 12 *.STRUCT
+ ('readContainerMap', 'writeContainerMap', True), # 13 TType.MAP
+ ('readContainerSet', 'writeContainerSet', True), # 14 TType.SET
+ ('readContainerList', 'writeContainerList', True), # 15 TType.LIST
+ (None, None, False), # 16 TType.UTF8 # TODO: handle utf8 types?
+ (None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types?
+ )
+
+ def _ttype_handlers(self, ttype, spec):
+ if spec == 'BINARY':
+ if ttype != TType.STRING:
+ raise TProtocolException(type=TProtocolException.INVALID_DATA,
+ message='Invalid binary field type %d' % ttype)
+ return ('readBinary', 'writeBinary', False)
+ if sys.version_info[0] == 2 and spec == 'UTF8':
+ if ttype != TType.STRING:
+ raise TProtocolException(type=TProtocolException.INVALID_DATA,
+ message='Invalid string field type %d' % ttype)
+ return ('readUtf8', 'writeUtf8', False)
+ return self._TTYPE_HANDLERS[ttype] if ttype < len(self._TTYPE_HANDLERS) else (None, None, False)
+
+ def _read_by_ttype(self, ttype, spec, espec):
+ reader_name, _, is_container = self._ttype_handlers(ttype, spec)
+ if reader_name is None:
+ raise TProtocolException(type=TProtocolException.INVALID_DATA,
+ message='Invalid type %d' % (ttype))
+ reader_func = getattr(self, reader_name)
+ read = (lambda: reader_func(espec)) if is_container else reader_func
+ while True:
+ yield read()
+
+ def readFieldByTType(self, ttype, spec):
+ return self._read_by_ttype(ttype, spec, spec).next()
+
+ def readContainerList(self, spec):
+ ttype, tspec, is_immutable = spec
+ (list_type, list_len) = self.readListBegin()
+ # TODO: compare types we just decoded with thrift_spec
+ elems = islice(self._read_by_ttype(ttype, spec, tspec), list_len)
+ results = (tuple if is_immutable else list)(elems)
+ self.readListEnd()
+ return results
+
+ def readContainerSet(self, spec):
+ ttype, tspec, is_immutable = spec
+ (set_type, set_len) = self.readSetBegin()
+ # TODO: compare types we just decoded with thrift_spec
+ elems = islice(self._read_by_ttype(ttype, spec, tspec), set_len)
+ results = (frozenset if is_immutable else set)(elems)
+ self.readSetEnd()
+ return results
+
+ def readContainerStruct(self, spec):
+ (obj_class, obj_spec) = spec
+ obj = obj_class()
+ obj.read(self)
+ return obj
+
+ def readContainerMap(self, spec):
+ ktype, kspec, vtype, vspec, is_immutable = spec
+ (map_ktype, map_vtype, map_len) = self.readMapBegin()
+ # TODO: compare types we just decoded with thrift_spec and
+ # abort/skip if types disagree
+ keys = self._read_by_ttype(ktype, spec, kspec)
+ vals = self._read_by_ttype(vtype, spec, vspec)
+ keyvals = islice(zip(keys, vals), map_len)
+ results = (TFrozenDict if is_immutable else dict)(keyvals)
+ self.readMapEnd()
+ return results
+
+ def readStruct(self, obj, thrift_spec, is_immutable=False):
+ if is_immutable:
+ fields = {}
+ self.readStructBegin()
+ while True:
+ (fname, ftype, fid) = self.readFieldBegin()
+ if ftype == TType.STOP:
+ break
+ try:
+ field = thrift_spec[fid]
+ except IndexError:
+ self.skip(ftype)
+ else:
+ if field is not None and ftype == field[1]:
+ fname = field[2]
+ fspec = field[3]
+ val = self.readFieldByTType(ftype, fspec)
+ if is_immutable:
+ fields[fname] = val
+ else:
+ setattr(obj, fname, val)
+ else:
+ self.skip(ftype)
+ self.readFieldEnd()
+ self.readStructEnd()
+ if is_immutable:
+ return obj(**fields)
+
+ def writeContainerStruct(self, val, spec):
+ val.write(self)
+
+ def writeContainerList(self, val, spec):
+ ttype, tspec, _ = spec
+ self.writeListBegin(ttype, len(val))
+ for _ in self._write_by_ttype(ttype, val, spec, tspec):
+ pass
+ self.writeListEnd()
+
+ def writeContainerSet(self, val, spec):
+ ttype, tspec, _ = spec
+ self.writeSetBegin(ttype, len(val))
+ for _ in self._write_by_ttype(ttype, val, spec, tspec):
+ pass
+ self.writeSetEnd()
+
+ def writeContainerMap(self, val, spec):
+ ktype, kspec, vtype, vspec, _ = spec
+ self.writeMapBegin(ktype, vtype, len(val))
+ for _ in zip(self._write_by_ttype(ktype, six.iterkeys(val), spec, kspec),
+ self._write_by_ttype(vtype, six.itervalues(val), spec, vspec)):
+ pass
+ self.writeMapEnd()
+
+ def writeStruct(self, obj, thrift_spec):
+ self.writeStructBegin(obj.__class__.__name__)
+ for field in thrift_spec:
+ if field is None:
+ continue
+ fname = field[2]
+ val = getattr(obj, fname)
+ if val is None:
+ # skip writing out unset fields
+ continue
+ fid = field[0]
+ ftype = field[1]
+ fspec = field[3]
+ self.writeFieldBegin(fname, ftype, fid)
+ self.writeFieldByTType(ftype, val, fspec)
+ self.writeFieldEnd()
+ self.writeFieldStop()
+ self.writeStructEnd()
+
+ def _write_by_ttype(self, ttype, vals, spec, espec):
+ _, writer_name, is_container = self._ttype_handlers(ttype, spec)
+ writer_func = getattr(self, writer_name)
+ write = (lambda v: writer_func(v, espec)) if is_container else writer_func
+ for v in vals:
+ yield write(v)
+
+ def writeFieldByTType(self, ttype, val, spec):
+ self._write_by_ttype(ttype, [val], spec, spec).next()
def checkIntegerLimits(i, bits):
@@ -408,10 +408,10 @@ def checkIntegerLimits(i, bits):
raise TProtocolException(TProtocolException.INVALID_DATA,
"i32 requires -2147483648 <= number <= 2147483647")
elif bits == 64 and (i < -9223372036854775808 or i > 9223372036854775807):
- raise TProtocolException(TProtocolException.INVALID_DATA,
- "i64 requires -9223372036854775808 <= number <= 9223372036854775807")
+ raise TProtocolException(TProtocolException.INVALID_DATA,
+ "i64 requires -9223372036854775808 <= number <= 9223372036854775807")
class TProtocolFactory(object):
- def getProtocol(self, trans):
- pass
+ def getProtocol(self, trans):
+ pass
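
The readStruct/writeStruct helpers reformatted above walk a generated thrift_spec tuple and dispatch each field through the _TTYPE_HANDLERS table. A minimal sketch of that flow, using a hypothetical hand-written Point struct (generated code normally supplies thrift_spec); it assumes the Python 2 code path above, since readFieldByTType still calls .next():

    from thrift.Thrift import TType
    from thrift.protocol import TBinaryProtocol
    from thrift.transport import TTransport

    class Point(object):
        # thrift_spec is indexed by field id: (fid, ftype, fname, fspec, default)
        thrift_spec = (
            None,                             # field id 0 is unused
            (1, TType.I32, 'x', None, None),  # 1: i32 x
            (2, TType.I32, 'y', None, None),  # 2: i32 y
        )

        def __init__(self, x=None, y=None):
            self.x, self.y = x, y

        def write(self, oprot):
            oprot.writeStruct(self, self.thrift_spec)

        def read(self, iprot):
            iprot.readStruct(self, self.thrift_spec)

    wbuf = TTransport.TMemoryBuffer()
    Point(3, 4).write(TBinaryProtocol.TBinaryProtocol(wbuf))   # writeStruct walks thrift_spec

    rbuf = TTransport.TMemoryBuffer(wbuf.getvalue())
    decoded = Point()
    decoded.read(TBinaryProtocol.TBinaryProtocol(rbuf))        # readStruct dispatches per field type
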
diff --git a/lib/py/src/protocol/TProtocolDecorator.py b/lib/py/src/protocol/TProtocolDecorator.py
index bf50bfad8..8b270a466 100644
--- a/lib/py/src/protocol/TProtocolDecorator.py
+++ b/lib/py/src/protocol/TProtocolDecorator.py
@@ -17,26 +17,34 @@
# under the License.
#
+import types
+
from thrift.protocol.TProtocol import TProtocolBase
-from types import *
+
class TProtocolDecorator():
- def __init__(self, protocol):
- TProtocolBase(protocol)
- self.protocol = protocol
+ def __init__(self, protocol):
+ TProtocolBase(protocol)
+ self.protocol = protocol
- def __getattr__(self, name):
- if hasattr(self.protocol, name):
- member = getattr(self.protocol, name)
- if type(member) in [MethodType, FunctionType, LambdaType, BuiltinFunctionType, BuiltinMethodType]:
- return lambda *args, **kwargs: self._wrap(member, args, kwargs)
- else:
- return member
- raise AttributeError(name)
+ def __getattr__(self, name):
+ if hasattr(self.protocol, name):
+ member = getattr(self.protocol, name)
+ if type(member) in [
+ types.MethodType,
+ types.FunctionType,
+ types.LambdaType,
+ types.BuiltinFunctionType,
+ types.BuiltinMethodType,
+ ]:
+ return lambda *args, **kwargs: self._wrap(member, args, kwargs)
+ else:
+ return member
+ raise AttributeError(name)
- def _wrap(self, func, args, kwargs):
- if type(func) == MethodType:
- result = func(*args, **kwargs)
- else:
- result = func(self.protocol, *args, **kwargs)
- return result
+ def _wrap(self, func, args, kwargs):
+ if isinstance(func, types.MethodType):
+ result = func(*args, **kwargs)
+ else:
+ result = func(self.protocol, *args, **kwargs)
+ return result
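
TProtocolDecorator's __getattr__ above forwards any attribute it does not define to the wrapped protocol and wraps callables, so a subclass only needs to override the calls it wants to intercept (this is the pattern TMultiplexedProtocol relies on). A minimal sketch of that pattern; LoggingProtocol is a hypothetical name, not part of this patch:

    from thrift.protocol.TProtocolDecorator import TProtocolDecorator

    class LoggingProtocol(TProtocolDecorator):
        def __init__(self, protocol):
            TProtocolDecorator.__init__(self, protocol)

        def writeMessageBegin(self, name, type, seqid):
            # Intercept one call; everything else falls through __getattr__.
            print('-> %s (seqid=%d)' % (name, seqid))
            self.protocol.writeMessageBegin(name, type, seqid)
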
diff --git a/lib/py/src/protocol/__init__.py b/lib/py/src/protocol/__init__.py
index 7eefb458a..7148f66b3 100644
--- a/lib/py/src/protocol/__init__.py
+++ b/lib/py/src/protocol/__init__.py
@@ -17,4 +17,5 @@
# under the License.
#
-__all__ = ['fastbinary', 'TBase', 'TBinaryProtocol', 'TCompactProtocol', 'TJSONProtocol', 'TProtocol']
+__all__ = ['fastbinary', 'TBase', 'TBinaryProtocol', 'TCompactProtocol',
+ 'TJSONProtocol', 'TProtocol']
diff --git a/lib/py/src/server/THttpServer.py b/lib/py/src/server/THttpServer.py
index bf3b0e342..1b501a7aa 100644
--- a/lib/py/src/server/THttpServer.py
+++ b/lib/py/src/server/THttpServer.py
@@ -24,64 +24,64 @@ from thrift.transport import TTransport
class ResponseException(Exception):
- """Allows handlers to override the HTTP response
+ """Allows handlers to override the HTTP response
- Normally, THttpServer always sends a 200 response. If a handler wants
- to override this behavior (e.g., to simulate a misconfigured or
- overloaded web server during testing), it can raise a ResponseException.
- The function passed to the constructor will be called with the
- RequestHandler as its only argument.
- """
- def __init__(self, handler):
- self.handler = handler
+ Normally, THttpServer always sends a 200 response. If a handler wants
+ to override this behavior (e.g., to simulate a misconfigured or
+ overloaded web server during testing), it can raise a ResponseException.
+ The function passed to the constructor will be called with the
+ RequestHandler as its only argument.
+ """
+ def __init__(self, handler):
+ self.handler = handler
class THttpServer(TServer.TServer):
- """A simple HTTP-based Thrift server
-
- This class is not very performant, but it is useful (for example) for
- acting as a mock version of an Apache-based PHP Thrift endpoint.
- """
- def __init__(self,
- processor,
- server_address,
- inputProtocolFactory,
- outputProtocolFactory=None,
- server_class=BaseHTTPServer.HTTPServer):
- """Set up protocol factories and HTTP server.
+ """A simple HTTP-based Thrift server
- See BaseHTTPServer for server_address.
- See TServer for protocol factories.
+ This class is not very performant, but it is useful (for example) for
+ acting as a mock version of an Apache-based PHP Thrift endpoint.
"""
- if outputProtocolFactory is None:
- outputProtocolFactory = inputProtocolFactory
+ def __init__(self,
+ processor,
+ server_address,
+ inputProtocolFactory,
+ outputProtocolFactory=None,
+ server_class=BaseHTTPServer.HTTPServer):
+ """Set up protocol factories and HTTP server.
+
+ See BaseHTTPServer for server_address.
+ See TServer for protocol factories.
+ """
+ if outputProtocolFactory is None:
+ outputProtocolFactory = inputProtocolFactory
- TServer.TServer.__init__(self, processor, None, None, None,
- inputProtocolFactory, outputProtocolFactory)
+ TServer.TServer.__init__(self, processor, None, None, None,
+ inputProtocolFactory, outputProtocolFactory)
- thttpserver = self
+ thttpserver = self
- class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler):
- def do_POST(self):
- # Don't care about the request path.
- itrans = TTransport.TFileObjectTransport(self.rfile)
- otrans = TTransport.TFileObjectTransport(self.wfile)
- itrans = TTransport.TBufferedTransport(
- itrans, int(self.headers['Content-Length']))
- otrans = TTransport.TMemoryBuffer()
- iprot = thttpserver.inputProtocolFactory.getProtocol(itrans)
- oprot = thttpserver.outputProtocolFactory.getProtocol(otrans)
- try:
- thttpserver.processor.process(iprot, oprot)
- except ResponseException as exn:
- exn.handler(self)
- else:
- self.send_response(200)
- self.send_header("content-type", "application/x-thrift")
- self.end_headers()
- self.wfile.write(otrans.getvalue())
+ class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler):
+ def do_POST(self):
+ # Don't care about the request path.
+ itrans = TTransport.TFileObjectTransport(self.rfile)
+ otrans = TTransport.TFileObjectTransport(self.wfile)
+ itrans = TTransport.TBufferedTransport(
+ itrans, int(self.headers['Content-Length']))
+ otrans = TTransport.TMemoryBuffer()
+ iprot = thttpserver.inputProtocolFactory.getProtocol(itrans)
+ oprot = thttpserver.outputProtocolFactory.getProtocol(otrans)
+ try:
+ thttpserver.processor.process(iprot, oprot)
+ except ResponseException as exn:
+ exn.handler(self)
+ else:
+ self.send_response(200)
+ self.send_header("content-type", "application/x-thrift")
+ self.end_headers()
+ self.wfile.write(otrans.getvalue())
- self.httpd = server_class(server_address, RequestHander)
+ self.httpd = server_class(server_address, RequestHander)
- def serve(self):
- self.httpd.serve_forever()
+ def serve(self):
+ self.httpd.serve_forever()
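
A rough usage sketch for the server above. Calculator and CalculatorHandler stand in for a generated service and its handler implementation; they are not part of this patch:

    from thrift.protocol import TBinaryProtocol
    from thrift.server import THttpServer

    processor = Calculator.Processor(CalculatorHandler())   # hypothetical generated service
    pfactory = TBinaryProtocol.TBinaryProtocolFactory()

    # server_address is passed straight to BaseHTTPServer.HTTPServer
    server = THttpServer.THttpServer(processor, ('', 9090), pfactory)
    server.serve()
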
diff --git a/lib/py/src/server/TNonblockingServer.py b/lib/py/src/server/TNonblockingServer.py
index a930a8091..87031c137 100644
--- a/lib/py/src/server/TNonblockingServer.py
+++ b/lib/py/src/server/TNonblockingServer.py
@@ -24,13 +24,12 @@ only from the main thread.
The thread pool should be sized for concurrent tasks, not
maximum connections
"""
-import threading
-import socket
-import select
-import struct
import logging
-logger = logging.getLogger(__name__)
+import select
+import socket
+import struct
+import threading
from six.moves import queue
@@ -39,6 +38,8 @@ from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory
__all__ = ['TNonblockingServer']
+logger = logging.getLogger(__name__)
+
class Worker(threading.Thread):
"""Worker is a small helper to process incoming connection."""
@@ -127,7 +128,7 @@ class Connection(object):
self.len, = struct.unpack('!i', self.message)
if self.len < 0:
logger.error("negative frame size, it seems client "
- "doesn't use FramedTransport")
+ "doesn't use FramedTransport")
self.close()
elif self.len == 0:
logger.error("empty frame, it's really strange")
@@ -149,7 +150,7 @@ class Connection(object):
read = self.socket.recv(self.len - len(self.message))
if len(read) == 0:
logger.error("can't read frame from socket (get %d of "
- "%d bytes)" % (len(self.message), self.len))
+ "%d bytes)" % (len(self.message), self.len))
self.close()
return
self.message += read
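
The error branches above fire when a peer sends a negative or empty 4-byte frame header, i.e. when it is not using a framed transport. A client-side sketch with the TFramedTransport that TNonblockingServer expects; the host, port and Calculator client are assumptions:

    from thrift.protocol import TBinaryProtocol
    from thrift.transport import TSocket, TTransport

    socket = TSocket.TSocket('localhost', 9090)
    transport = TTransport.TFramedTransport(socket)     # sends the '!i' length frame unpacked above
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    client = Calculator.Client(protocol)                # hypothetical generated client
    transport.open()
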
diff --git a/lib/py/src/server/TProcessPoolServer.py b/lib/py/src/server/TProcessPoolServer.py
index b2c2308a9..fe6dc8162 100644
--- a/lib/py/src/server/TProcessPoolServer.py
+++ b/lib/py/src/server/TProcessPoolServer.py
@@ -19,13 +19,14 @@
import logging
-logger = logging.getLogger(__name__)
-from multiprocessing import Process, Value, Condition, reduction
+from multiprocessing import Process, Value, Condition
from .TServer import TServer
from thrift.transport.TTransport import TTransportException
+logger = logging.getLogger(__name__)
+
class TProcessPoolServer(TServer):
"""Server with a fixed size pool of worker subprocesses to service requests
@@ -59,7 +60,7 @@ class TProcessPoolServer(TServer):
try:
client = self.serverTransport.accept()
if not client:
- continue
+ continue
self.serveClient(client)
except (KeyboardInterrupt, SystemExit):
return 0
@@ -76,7 +77,7 @@ class TProcessPoolServer(TServer):
try:
while True:
self.processor.process(iprot, oprot)
- except TTransportException as tx:
+ except TTransportException:
pass
except Exception as x:
logger.exception(x)
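
The loop above runs inside each forked worker, serving one connection at a time until the transport drops. A rough sketch of wiring the pool server up with the TServer-style 4-argument constructor; the processor comes from hypothetical generated code, and setNumWorkers is assumed to be the pool-size setter:

    from thrift.protocol import TBinaryProtocol
    from thrift.server import TProcessPoolServer
    from thrift.transport import TSocket, TTransport

    transport = TSocket.TServerSocket(port=9090)
    tfactory = TTransport.TBufferedTransportFactory()
    pfactory = TBinaryProtocol.TBinaryProtocolFactory()

    server = TProcessPoolServer.TProcessPoolServer(processor, transport, tfactory, pfactory)
    server.setNumWorkers(4)   # assumed helper; adjust to the actual API if it differs
    server.serve()
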
diff --git a/lib/py/src/server/TServer.py b/lib/py/src/server/TServer.py
index 30f063b43..d5d9c98a9 100644
--- a/lib/py/src/server/TServer.py
+++ b/lib/py/src/server/TServer.py
@@ -18,262 +18,259 @@
#
from six.moves import queue
+import logging
import os
-import sys
import threading
-import traceback
-
-import logging
-logger = logging.getLogger(__name__)
-from thrift.Thrift import TProcessor
from thrift.protocol import TBinaryProtocol
from thrift.transport import TTransport
+logger = logging.getLogger(__name__)
+
class TServer(object):
- """Base interface for a server, which must have a serve() method.
-
- Three constructors for all servers:
- 1) (processor, serverTransport)
- 2) (processor, serverTransport, transportFactory, protocolFactory)
- 3) (processor, serverTransport,
- inputTransportFactory, outputTransportFactory,
- inputProtocolFactory, outputProtocolFactory)
- """
- def __init__(self, *args):
- if (len(args) == 2):
- self.__initArgs__(args[0], args[1],
- TTransport.TTransportFactoryBase(),
- TTransport.TTransportFactoryBase(),
- TBinaryProtocol.TBinaryProtocolFactory(),
- TBinaryProtocol.TBinaryProtocolFactory())
- elif (len(args) == 4):
- self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3])
- elif (len(args) == 6):
- self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5])
-
- def __initArgs__(self, processor, serverTransport,
- inputTransportFactory, outputTransportFactory,
- inputProtocolFactory, outputProtocolFactory):
- self.processor = processor
- self.serverTransport = serverTransport
- self.inputTransportFactory = inputTransportFactory
- self.outputTransportFactory = outputTransportFactory
- self.inputProtocolFactory = inputProtocolFactory
- self.outputProtocolFactory = outputProtocolFactory
-
- def serve(self):
- pass
+ """Base interface for a server, which must have a serve() method.
+
+ Three constructors for all servers:
+ 1) (processor, serverTransport)
+ 2) (processor, serverTransport, transportFactory, protocolFactory)
+ 3) (processor, serverTransport,
+ inputTransportFactory, outputTransportFactory,
+ inputProtocolFactory, outputProtocolFactory)
+ """
+ def __init__(self, *args):
+ if (len(args) == 2):
+ self.__initArgs__(args[0], args[1],
+ TTransport.TTransportFactoryBase(),
+ TTransport.TTransportFactoryBase(),
+ TBinaryProtocol.TBinaryProtocolFactory(),
+ TBinaryProtocol.TBinaryProtocolFactory())
+ elif (len(args) == 4):
+ self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3])
+ elif (len(args) == 6):
+ self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5])
+
+ def __initArgs__(self, processor, serverTransport,
+ inputTransportFactory, outputTransportFactory,
+ inputProtocolFactory, outputProtocolFactory):
+ self.processor = processor
+ self.serverTransport = serverTransport
+ self.inputTransportFactory = inputTransportFactory
+ self.outputTransportFactory = outputTransportFactory
+ self.inputProtocolFactory = inputProtocolFactory
+ self.outputProtocolFactory = outputProtocolFactory
+
+ def serve(self):
+ pass
class TSimpleServer(TServer):
- """Simple single-threaded server that just pumps around one transport."""
-
- def __init__(self, *args):
- TServer.__init__(self, *args)
-
- def serve(self):
- self.serverTransport.listen()
- while True:
- client = self.serverTransport.accept()
- if not client:
- continue
- itrans = self.inputTransportFactory.getTransport(client)
- otrans = self.outputTransportFactory.getTransport(client)
- iprot = self.inputProtocolFactory.getProtocol(itrans)
- oprot = self.outputProtocolFactory.getProtocol(otrans)
- try:
+ """Simple single-threaded server that just pumps around one transport."""
+
+ def __init__(self, *args):
+ TServer.__init__(self, *args)
+
+ def serve(self):
+ self.serverTransport.listen()
while True:
- self.processor.process(iprot, oprot)
- except TTransport.TTransportException as tx:
- pass
- except Exception as x:
- logger.exception(x)
+ client = self.serverTransport.accept()
+ if not client:
+ continue
+ itrans = self.inputTransportFactory.getTransport(client)
+ otrans = self.outputTransportFactory.getTransport(client)
+ iprot = self.inputProtocolFactory.getProtocol(itrans)
+ oprot = self.outputProtocolFactory.getProtocol(otrans)
+ try:
+ while True:
+ self.processor.process(iprot, oprot)
+ except TTransport.TTransportException:
+ pass
+ except Exception as x:
+ logger.exception(x)
- itrans.close()
- otrans.close()
+ itrans.close()
+ otrans.close()
class TThreadedServer(TServer):
- """Threaded server that spawns a new thread per each connection."""
-
- def __init__(self, *args, **kwargs):
- TServer.__init__(self, *args)
- self.daemon = kwargs.get("daemon", False)
-
- def serve(self):
- self.serverTransport.listen()
- while True:
- try:
- client = self.serverTransport.accept()
- if not client:
- continue
- t = threading.Thread(target=self.handle, args=(client,))
- t.setDaemon(self.daemon)
- t.start()
- except KeyboardInterrupt:
- raise
- except Exception as x:
- logger.exception(x)
-
- def handle(self, client):
- itrans = self.inputTransportFactory.getTransport(client)
- otrans = self.outputTransportFactory.getTransport(client)
- iprot = self.inputProtocolFactory.getProtocol(itrans)
- oprot = self.outputProtocolFactory.getProtocol(otrans)
- try:
- while True:
- self.processor.process(iprot, oprot)
- except TTransport.TTransportException as tx:
- pass
- except Exception as x:
- logger.exception(x)
-
- itrans.close()
- otrans.close()
+ """Threaded server that spawns a new thread per each connection."""
+
+ def __init__(self, *args, **kwargs):
+ TServer.__init__(self, *args)
+ self.daemon = kwargs.get("daemon", False)
+
+ def serve(self):
+ self.serverTransport.listen()
+ while True:
+ try:
+ client = self.serverTransport.accept()
+ if not client:
+ continue
+ t = threading.Thread(target=self.handle, args=(client,))
+ t.setDaemon(self.daemon)
+ t.start()
+ except KeyboardInterrupt:
+ raise
+ except Exception as x:
+ logger.exception(x)
+
+ def handle(self, client):
+ itrans = self.inputTransportFactory.getTransport(client)
+ otrans = self.outputTransportFactory.getTransport(client)
+ iprot = self.inputProtocolFactory.getProtocol(itrans)
+ oprot = self.outputProtocolFactory.getProtocol(otrans)
+ try:
+ while True:
+ self.processor.process(iprot, oprot)
+ except TTransport.TTransportException:
+ pass
+ except Exception as x:
+ logger.exception(x)
+
+ itrans.close()
+ otrans.close()
class TThreadPoolServer(TServer):
- """Server with a fixed size pool of threads which service requests."""
-
- def __init__(self, *args, **kwargs):
- TServer.__init__(self, *args)
- self.clients = queue.Queue()
- self.threads = 10
- self.daemon = kwargs.get("daemon", False)
-
- def setNumThreads(self, num):
- """Set the number of worker threads that should be created"""
- self.threads = num
-
- def serveThread(self):
- """Loop around getting clients from the shared queue and process them."""
- while True:
- try:
- client = self.clients.get()
- self.serveClient(client)
- except Exception as x:
- logger.exception(x)
-
- def serveClient(self, client):
- """Process input/output from a client for as long as possible"""
- itrans = self.inputTransportFactory.getTransport(client)
- otrans = self.outputTransportFactory.getTransport(client)
- iprot = self.inputProtocolFactory.getProtocol(itrans)
- oprot = self.outputProtocolFactory.getProtocol(otrans)
- try:
- while True:
- self.processor.process(iprot, oprot)
- except TTransport.TTransportException as tx:
- pass
- except Exception as x:
- logger.exception(x)
-
- itrans.close()
- otrans.close()
-
- def serve(self):
- """Start a fixed number of worker threads and put client into a queue"""
- for i in range(self.threads):
- try:
- t = threading.Thread(target=self.serveThread)
- t.setDaemon(self.daemon)
- t.start()
- except Exception as x:
- logger.exception(x)
-
- # Pump the socket for clients
- self.serverTransport.listen()
- while True:
- try:
- client = self.serverTransport.accept()
- if not client:
- continue
- self.clients.put(client)
- except Exception as x:
- logger.exception(x)
+ """Server with a fixed size pool of threads which service requests."""
+ def __init__(self, *args, **kwargs):
+ TServer.__init__(self, *args)
+ self.clients = queue.Queue()
+ self.threads = 10
+ self.daemon = kwargs.get("daemon", False)
-class TForkingServer(TServer):
- """A Thrift server that forks a new process for each request
-
- This is more scalable than the threaded server as it does not cause
- GIL contention.
-
- Note that this has different semantics from the threading server.
- Specifically, updates to shared variables will no longer be shared.
- It will also not work on windows.
-
- This code is heavily inspired by SocketServer.ForkingMixIn in the
- Python stdlib.
- """
- def __init__(self, *args):
- TServer.__init__(self, *args)
- self.children = []
-
- def serve(self):
- def try_close(file):
- try:
- file.close()
- except IOError as e:
- logger.warning(e, exc_info=True)
-
- self.serverTransport.listen()
- while True:
- client = self.serverTransport.accept()
- if not client:
- continue
- try:
- pid = os.fork()
-
- if pid: # parent
- # add before collect, otherwise you race w/ waitpid
- self.children.append(pid)
- self.collect_children()
-
- # Parent must close socket or the connection may not get
- # closed promptly
- itrans = self.inputTransportFactory.getTransport(client)
- otrans = self.outputTransportFactory.getTransport(client)
- try_close(itrans)
- try_close(otrans)
- else:
- itrans = self.inputTransportFactory.getTransport(client)
- otrans = self.outputTransportFactory.getTransport(client)
-
- iprot = self.inputProtocolFactory.getProtocol(itrans)
- oprot = self.outputProtocolFactory.getProtocol(otrans)
-
- ecode = 0
- try:
+ def setNumThreads(self, num):
+ """Set the number of worker threads that should be created"""
+ self.threads = num
+
+ def serveThread(self):
+ """Loop around getting clients from the shared queue and process them."""
+ while True:
try:
- while True:
+ client = self.clients.get()
+ self.serveClient(client)
+ except Exception as x:
+ logger.exception(x)
+
+ def serveClient(self, client):
+ """Process input/output from a client for as long as possible"""
+ itrans = self.inputTransportFactory.getTransport(client)
+ otrans = self.outputTransportFactory.getTransport(client)
+ iprot = self.inputProtocolFactory.getProtocol(itrans)
+ oprot = self.outputProtocolFactory.getProtocol(otrans)
+ try:
+ while True:
self.processor.process(iprot, oprot)
- except TTransport.TTransportException:
- pass
- except Exception as e:
- logger.exception(e)
- ecode = 1
- finally:
- try_close(itrans)
- try_close(otrans)
+ except TTransport.TTransportException:
+ pass
+ except Exception as x:
+ logger.exception(x)
- os._exit(ecode)
+ itrans.close()
+ otrans.close()
- except TTransport.TTransportException:
- pass
- except Exception as x:
- logger.exception(x)
-
- def collect_children(self):
- while self.children:
- try:
- pid, status = os.waitpid(0, os.WNOHANG)
- except os.error:
- pid = None
-
- if pid:
- self.children.remove(pid)
- else:
- break
+ def serve(self):
+ """Start a fixed number of worker threads and put client into a queue"""
+ for i in range(self.threads):
+ try:
+ t = threading.Thread(target=self.serveThread)
+ t.setDaemon(self.daemon)
+ t.start()
+ except Exception as x:
+ logger.exception(x)
+
+ # Pump the socket for clients
+ self.serverTransport.listen()
+ while True:
+ try:
+ client = self.serverTransport.accept()
+ if not client:
+ continue
+ self.clients.put(client)
+ except Exception as x:
+ logger.exception(x)
+
+
+class TForkingServer(TServer):
+ """A Thrift server that forks a new process for each request
+
+ This is more scalable than the threaded server as it does not cause
+ GIL contention.
+
+ Note that this has different semantics from the threading server.
+ Specifically, updates to shared variables will no longer be shared.
+    It will also not work on Windows.
+
+ This code is heavily inspired by SocketServer.ForkingMixIn in the
+ Python stdlib.
+ """
+ def __init__(self, *args):
+ TServer.__init__(self, *args)
+ self.children = []
+
+ def serve(self):
+ def try_close(file):
+ try:
+ file.close()
+ except IOError as e:
+ logger.warning(e, exc_info=True)
+
+ self.serverTransport.listen()
+ while True:
+ client = self.serverTransport.accept()
+ if not client:
+ continue
+ try:
+ pid = os.fork()
+
+ if pid: # parent
+ # add before collect, otherwise you race w/ waitpid
+ self.children.append(pid)
+ self.collect_children()
+
+ # Parent must close socket or the connection may not get
+ # closed promptly
+ itrans = self.inputTransportFactory.getTransport(client)
+ otrans = self.outputTransportFactory.getTransport(client)
+ try_close(itrans)
+ try_close(otrans)
+ else:
+ itrans = self.inputTransportFactory.getTransport(client)
+ otrans = self.outputTransportFactory.getTransport(client)
+
+ iprot = self.inputProtocolFactory.getProtocol(itrans)
+ oprot = self.outputProtocolFactory.getProtocol(otrans)
+
+ ecode = 0
+ try:
+ try:
+ while True:
+ self.processor.process(iprot, oprot)
+ except TTransport.TTransportException:
+ pass
+ except Exception as e:
+ logger.exception(e)
+ ecode = 1
+ finally:
+ try_close(itrans)
+ try_close(otrans)
+
+ os._exit(ecode)
+
+ except TTransport.TTransportException:
+ pass
+ except Exception as x:
+ logger.exception(x)
+
+ def collect_children(self):
+ while self.children:
+ try:
+ pid, status = os.waitpid(0, os.WNOHANG)
+ except os.error:
+ pid = None
+
+ if pid:
+ self.children.remove(pid)
+ else:
+ break
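
The constructor forms documented on TServer above (2, 4, or 6 arguments) apply to every server class in this file. A sketch of the common 4-argument form with the threaded server; the processor again stands in for hypothetical generated code:

    from thrift.protocol import TBinaryProtocol
    from thrift.server import TServer
    from thrift.transport import TSocket, TTransport

    transport = TSocket.TServerSocket(port=9090)
    tfactory = TTransport.TBufferedTransportFactory()
    pfactory = TBinaryProtocol.TBinaryProtocolFactory()

    # 2-arg form would default to TTransportFactoryBase + TBinaryProtocolFactory:
    #   TServer.TSimpleServer(processor, transport)
    server = TServer.TThreadedServer(processor, transport, tfactory, pfactory, daemon=True)
    server.serve()
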
diff --git a/lib/py/src/transport/THttpClient.py b/lib/py/src/transport/THttpClient.py
index 5abd41c70..95f118cb4 100644
--- a/lib/py/src/transport/THttpClient.py
+++ b/lib/py/src/transport/THttpClient.py
@@ -26,130 +26,130 @@ import warnings
from six.moves import urllib
from six.moves import http_client
-from .TTransport import *
+from .TTransport import TTransportBase
import six
class THttpClient(TTransportBase):
- """Http implementation of TTransport base."""
-
- def __init__(self, uri_or_host, port=None, path=None):
- """THttpClient supports two different types constructor parameters.
-
- THttpClient(host, port, path) - deprecated
- THttpClient(uri)
-
- Only the second supports https.
- """
- if port is not None:
- warnings.warn(
- "Please use the THttpClient('http://host:port/path') syntax",
- DeprecationWarning,
- stacklevel=2)
- self.host = uri_or_host
- self.port = port
- assert path
- self.path = path
- self.scheme = 'http'
- else:
- parsed = urllib.parse.urlparse(uri_or_host)
- self.scheme = parsed.scheme
- assert self.scheme in ('http', 'https')
- if self.scheme == 'http':
- self.port = parsed.port or http_client.HTTP_PORT
- elif self.scheme == 'https':
- self.port = parsed.port or http_client.HTTPS_PORT
- self.host = parsed.hostname
- self.path = parsed.path
- if parsed.query:
- self.path += '?%s' % parsed.query
- self.__wbuf = BytesIO()
- self.__http = None
- self.__http_response = None
- self.__timeout = None
- self.__custom_headers = None
-
- def open(self):
- if self.scheme == 'http':
- self.__http = http_client.HTTPConnection(self.host, self.port)
- else:
- self.__http = http_client.HTTPSConnection(self.host, self.port)
-
- def close(self):
- self.__http.close()
- self.__http = None
- self.__http_response = None
-
- def isOpen(self):
- return self.__http is not None
-
- def setTimeout(self, ms):
- if not hasattr(socket, 'getdefaulttimeout'):
- raise NotImplementedError
-
- if ms is None:
- self.__timeout = None
- else:
- self.__timeout = ms / 1000.0
-
- def setCustomHeaders(self, headers):
- self.__custom_headers = headers
-
- def read(self, sz):
- return self.__http_response.read(sz)
-
- def write(self, buf):
- self.__wbuf.write(buf)
-
- def __withTimeout(f):
- def _f(*args, **kwargs):
- orig_timeout = socket.getdefaulttimeout()
- socket.setdefaulttimeout(args[0].__timeout)
- try:
- result = f(*args, **kwargs)
- finally:
- socket.setdefaulttimeout(orig_timeout)
- return result
- return _f
-
- def flush(self):
- if self.isOpen():
- self.close()
- self.open()
-
- # Pull data out of buffer
- data = self.__wbuf.getvalue()
- self.__wbuf = BytesIO()
-
- # HTTP request
- self.__http.putrequest('POST', self.path)
-
- # Write headers
- self.__http.putheader('Content-Type', 'application/x-thrift')
- self.__http.putheader('Content-Length', str(len(data)))
-
- if not self.__custom_headers or 'User-Agent' not in self.__custom_headers:
- user_agent = 'Python/THttpClient'
- script = os.path.basename(sys.argv[0])
- if script:
- user_agent = '%s (%s)' % (user_agent, urllib.parse.quote(script))
- self.__http.putheader('User-Agent', user_agent)
-
- if self.__custom_headers:
- for key, val in six.iteritems(self.__custom_headers):
- self.__http.putheader(key, val)
-
- self.__http.endheaders()
-
- # Write payload
- self.__http.send(data)
-
- # Get reply to flush the request
- self.__http_response = self.__http.getresponse()
- self.code = self.__http_response.status
- self.message = self.__http_response.reason
- self.headers = self.__http_response.msg
-
- # Decorate if we know how to timeout
- if hasattr(socket, 'getdefaulttimeout'):
- flush = __withTimeout(flush)
+ """Http implementation of TTransport base."""
+
+ def __init__(self, uri_or_host, port=None, path=None):
+ """THttpClient supports two different types constructor parameters.
+
+ THttpClient(host, port, path) - deprecated
+ THttpClient(uri)
+
+ Only the second supports https.
+ """
+ if port is not None:
+ warnings.warn(
+ "Please use the THttpClient('http://host:port/path') syntax",
+ DeprecationWarning,
+ stacklevel=2)
+ self.host = uri_or_host
+ self.port = port
+ assert path
+ self.path = path
+ self.scheme = 'http'
+ else:
+ parsed = urllib.parse.urlparse(uri_or_host)
+ self.scheme = parsed.scheme
+ assert self.scheme in ('http', 'https')
+ if self.scheme == 'http':
+ self.port = parsed.port or http_client.HTTP_PORT
+ elif self.scheme == 'https':
+ self.port = parsed.port or http_client.HTTPS_PORT
+ self.host = parsed.hostname
+ self.path = parsed.path
+ if parsed.query:
+ self.path += '?%s' % parsed.query
+ self.__wbuf = BytesIO()
+ self.__http = None
+ self.__http_response = None
+ self.__timeout = None
+ self.__custom_headers = None
+
+ def open(self):
+ if self.scheme == 'http':
+ self.__http = http_client.HTTPConnection(self.host, self.port)
+ else:
+ self.__http = http_client.HTTPSConnection(self.host, self.port)
+
+ def close(self):
+ self.__http.close()
+ self.__http = None
+ self.__http_response = None
+
+ def isOpen(self):
+ return self.__http is not None
+
+ def setTimeout(self, ms):
+ if not hasattr(socket, 'getdefaulttimeout'):
+ raise NotImplementedError
+
+ if ms is None:
+ self.__timeout = None
+ else:
+ self.__timeout = ms / 1000.0
+
+ def setCustomHeaders(self, headers):
+ self.__custom_headers = headers
+
+ def read(self, sz):
+ return self.__http_response.read(sz)
+
+ def write(self, buf):
+ self.__wbuf.write(buf)
+
+ def __withTimeout(f):
+ def _f(*args, **kwargs):
+ orig_timeout = socket.getdefaulttimeout()
+ socket.setdefaulttimeout(args[0].__timeout)
+ try:
+ result = f(*args, **kwargs)
+ finally:
+ socket.setdefaulttimeout(orig_timeout)
+ return result
+ return _f
+
+ def flush(self):
+ if self.isOpen():
+ self.close()
+ self.open()
+
+ # Pull data out of buffer
+ data = self.__wbuf.getvalue()
+ self.__wbuf = BytesIO()
+
+ # HTTP request
+ self.__http.putrequest('POST', self.path)
+
+ # Write headers
+ self.__http.putheader('Content-Type', 'application/x-thrift')
+ self.__http.putheader('Content-Length', str(len(data)))
+
+ if not self.__custom_headers or 'User-Agent' not in self.__custom_headers:
+ user_agent = 'Python/THttpClient'
+ script = os.path.basename(sys.argv[0])
+ if script:
+ user_agent = '%s (%s)' % (user_agent, urllib.parse.quote(script))
+ self.__http.putheader('User-Agent', user_agent)
+
+ if self.__custom_headers:
+ for key, val in six.iteritems(self.__custom_headers):
+ self.__http.putheader(key, val)
+
+ self.__http.endheaders()
+
+ # Write payload
+ self.__http.send(data)
+
+ # Get reply to flush the request
+ self.__http_response = self.__http.getresponse()
+ self.code = self.__http_response.status
+ self.message = self.__http_response.reason
+ self.headers = self.__http_response.msg
+
+ # Decorate if we know how to timeout
+ if hasattr(socket, 'getdefaulttimeout'):
+ flush = __withTimeout(flush)
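
Usage sketch for the client transport above, showing the URI-style constructor plus the timeout and custom-header hooks; the endpoint and header values are placeholders:

    from thrift.protocol import TBinaryProtocol
    from thrift.transport import THttpClient

    transport = THttpClient.THttpClient('http://localhost:9090/thrift')
    transport.setTimeout(5000)                        # milliseconds; converted to seconds above
    transport.setCustomHeaders({'X-Request-Id': 'abc123'})
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    transport.open()
    # client = Calculator.Client(protocol)            # hypothetical generated client
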
diff --git a/lib/py/src/transport/TSSLSocket.py b/lib/py/src/transport/TSSLSocket.py
index 9be0912f9..3f1a909df 100644
--- a/lib/py/src/transport/TSSLSocket.py
+++ b/lib/py/src/transport/TSSLSocket.py
@@ -32,345 +32,345 @@ warnings.filterwarnings('default', category=DeprecationWarning, module=__name__)
class TSSLBase(object):
- # SSLContext is not available for Python < 2.7.9
- _has_ssl_context = sys.hexversion >= 0x020709F0
-
- # ciphers argument is not available for Python < 2.7.0
- _has_ciphers = sys.hexversion >= 0x020700F0
-
- # For pythoon >= 2.7.9, use latest TLS that both client and server supports.
- # SSL 2.0 and 3.0 are disabled via ssl.OP_NO_SSLv2 and ssl.OP_NO_SSLv3.
- # For pythoon < 2.7.9, use TLS 1.0 since TLSv1_X nare OP_NO_SSLvX are unavailable.
- _default_protocol = ssl.PROTOCOL_SSLv23 if _has_ssl_context else ssl.PROTOCOL_TLSv1
-
- def _init_context(self, ssl_version):
- if self._has_ssl_context:
- self._context = ssl.SSLContext(ssl_version)
- if self._context.protocol == ssl.PROTOCOL_SSLv23:
- self._context.options |= ssl.OP_NO_SSLv2
- self._context.options |= ssl.OP_NO_SSLv3
- else:
- self._context = None
- self._ssl_version = ssl_version
-
- @property
- def ssl_version(self):
- if self._has_ssl_context:
- return self.ssl_context.protocol
- else:
- return self._ssl_version
-
- @property
- def ssl_context(self):
- return self._context
-
- SSL_VERSION = _default_protocol
- """
+ # SSLContext is not available for Python < 2.7.9
+ _has_ssl_context = sys.hexversion >= 0x020709F0
+
+ # ciphers argument is not available for Python < 2.7.0
+ _has_ciphers = sys.hexversion >= 0x020700F0
+
+    # For Python >= 2.7.9, use the latest TLS version that both client and server support.
+    # SSL 2.0 and 3.0 are disabled via ssl.OP_NO_SSLv2 and ssl.OP_NO_SSLv3.
+    # For Python < 2.7.9, use TLS 1.0 since TLSv1_X and OP_NO_SSLvX are unavailable.
+ _default_protocol = ssl.PROTOCOL_SSLv23 if _has_ssl_context else ssl.PROTOCOL_TLSv1
+
+ def _init_context(self, ssl_version):
+ if self._has_ssl_context:
+ self._context = ssl.SSLContext(ssl_version)
+ if self._context.protocol == ssl.PROTOCOL_SSLv23:
+ self._context.options |= ssl.OP_NO_SSLv2
+ self._context.options |= ssl.OP_NO_SSLv3
+ else:
+ self._context = None
+ self._ssl_version = ssl_version
+
+ @property
+ def ssl_version(self):
+ if self._has_ssl_context:
+ return self.ssl_context.protocol
+ else:
+ return self._ssl_version
+
+ @property
+ def ssl_context(self):
+ return self._context
+
+ SSL_VERSION = _default_protocol
+ """
Default SSL version.
For backward compatibility, it can be modified.
Use __init__ keyword argument "ssl_version" instead.
"""
- def _deprecated_arg(self, args, kwargs, pos, key):
- if len(args) <= pos:
- return
- real_pos = pos + 3
- warnings.warn(
- '%dth positional argument is deprecated. Use keyward argument insteand.' % real_pos,
- DeprecationWarning)
- if key in kwargs:
- raise TypeError('Duplicate argument: %dth argument and %s keyward argument.', (real_pos, key))
- kwargs[key] = args[pos]
-
- def _unix_socket_arg(self, host, port, args, kwargs):
- key = 'unix_socket'
- if host is None and port is None and len(args) == 1 and key not in kwargs:
- kwargs[key] = args[0]
- return True
- return False
-
- def __getattr__(self, key):
- if key == 'SSL_VERSION':
- warnings.warn('Use ssl_version attribute instead.', DeprecationWarning)
- return self.ssl_version
-
- def __init__(self, server_side, host, ssl_opts):
- self._server_side = server_side
- if TSSLBase.SSL_VERSION != self._default_protocol:
- warnings.warn('SSL_VERSION is deprecated. Use ssl_version keyward argument instead.', DeprecationWarning)
- self._context = ssl_opts.pop('ssl_context', None)
- self._server_hostname = None
- if not self._server_side:
- self._server_hostname = ssl_opts.pop('server_hostname', host)
- if self._context:
- self._custom_context = True
- if ssl_opts:
- raise ValueError('Incompatible arguments: ssl_context and %s' % ' '.join(ssl_opts.keys()))
- if not self._has_ssl_context:
- raise ValueError('ssl_context is not available for this version of Python')
- else:
- self._custom_context = False
- ssl_version = ssl_opts.pop('ssl_version', TSSLBase.SSL_VERSION)
- self._init_context(ssl_version)
- self.cert_reqs = ssl_opts.pop('cert_reqs', ssl.CERT_REQUIRED)
- self.ca_certs = ssl_opts.pop('ca_certs', None)
- self.keyfile = ssl_opts.pop('keyfile', None)
- self.certfile = ssl_opts.pop('certfile', None)
- self.ciphers = ssl_opts.pop('ciphers', None)
-
- if ssl_opts:
- raise ValueError('Unknown keyword arguments: ', ' '.join(ssl_opts.keys()))
-
- if self.cert_reqs != ssl.CERT_NONE:
- if not self.ca_certs:
- raise ValueError('ca_certs is needed when cert_reqs is not ssl.CERT_NONE')
- if not os.access(self.ca_certs, os.R_OK):
- raise IOError('Certificate Authority ca_certs file "%s" '
- 'is not readable, cannot validate SSL '
- 'certificates.' % (self.ca_certs))
-
- @property
- def certfile(self):
- return self._certfile
-
- @certfile.setter
- def certfile(self, certfile):
- if self._server_side and not certfile:
- raise ValueError('certfile is needed for server-side')
- if certfile and not os.access(certfile, os.R_OK):
- raise IOError('No such certfile found: %s' % (certfile))
- self._certfile = certfile
-
- def _wrap_socket(self, sock):
- if self._has_ssl_context:
- if not self._custom_context:
- self.ssl_context.verify_mode = self.cert_reqs
- if self.certfile:
- self.ssl_context.load_cert_chain(self.certfile, self.keyfile)
- if self.ciphers:
- self.ssl_context.set_ciphers(self.ciphers)
- if self.ca_certs:
- self.ssl_context.load_verify_locations(self.ca_certs)
- return self.ssl_context.wrap_socket(sock, server_side=self._server_side,
- server_hostname=self._server_hostname)
- else:
- ssl_opts = {
- 'ssl_version': self._ssl_version,
- 'server_side': self._server_side,
- 'ca_certs': self.ca_certs,
- 'keyfile': self.keyfile,
- 'certfile': self.certfile,
- 'cert_reqs': self.cert_reqs,
- }
- if self.ciphers:
- if self._has_ciphers:
- ssl_opts['ciphers'] = self.ciphers
+ def _deprecated_arg(self, args, kwargs, pos, key):
+ if len(args) <= pos:
+ return
+ real_pos = pos + 3
+ warnings.warn(
+            '%dth positional argument is deprecated. Use keyword argument instead.' % real_pos,
+ DeprecationWarning)
+ if key in kwargs:
+            raise TypeError('Duplicate argument: %dth argument and %s keyword argument.' % (real_pos, key))
+ kwargs[key] = args[pos]
+
+ def _unix_socket_arg(self, host, port, args, kwargs):
+ key = 'unix_socket'
+ if host is None and port is None and len(args) == 1 and key not in kwargs:
+ kwargs[key] = args[0]
+ return True
+ return False
+
+ def __getattr__(self, key):
+ if key == 'SSL_VERSION':
+ warnings.warn('Use ssl_version attribute instead.', DeprecationWarning)
+ return self.ssl_version
+
+ def __init__(self, server_side, host, ssl_opts):
+ self._server_side = server_side
+ if TSSLBase.SSL_VERSION != self._default_protocol:
+            warnings.warn('SSL_VERSION is deprecated. Use the ssl_version keyword argument instead.', DeprecationWarning)
+ self._context = ssl_opts.pop('ssl_context', None)
+ self._server_hostname = None
+ if not self._server_side:
+ self._server_hostname = ssl_opts.pop('server_hostname', host)
+ if self._context:
+ self._custom_context = True
+ if ssl_opts:
+ raise ValueError('Incompatible arguments: ssl_context and %s' % ' '.join(ssl_opts.keys()))
+ if not self._has_ssl_context:
+ raise ValueError('ssl_context is not available for this version of Python')
+ else:
+ self._custom_context = False
+ ssl_version = ssl_opts.pop('ssl_version', TSSLBase.SSL_VERSION)
+ self._init_context(ssl_version)
+ self.cert_reqs = ssl_opts.pop('cert_reqs', ssl.CERT_REQUIRED)
+ self.ca_certs = ssl_opts.pop('ca_certs', None)
+ self.keyfile = ssl_opts.pop('keyfile', None)
+ self.certfile = ssl_opts.pop('certfile', None)
+ self.ciphers = ssl_opts.pop('ciphers', None)
+
+ if ssl_opts:
+ raise ValueError('Unknown keyword arguments: ', ' '.join(ssl_opts.keys()))
+
+ if self.cert_reqs != ssl.CERT_NONE:
+ if not self.ca_certs:
+ raise ValueError('ca_certs is needed when cert_reqs is not ssl.CERT_NONE')
+ if not os.access(self.ca_certs, os.R_OK):
+ raise IOError('Certificate Authority ca_certs file "%s" '
+ 'is not readable, cannot validate SSL '
+ 'certificates.' % (self.ca_certs))
+
+ @property
+ def certfile(self):
+ return self._certfile
+
+ @certfile.setter
+ def certfile(self, certfile):
+ if self._server_side and not certfile:
+ raise ValueError('certfile is needed for server-side')
+ if certfile and not os.access(certfile, os.R_OK):
+ raise IOError('No such certfile found: %s' % (certfile))
+ self._certfile = certfile
+
+ def _wrap_socket(self, sock):
+ if self._has_ssl_context:
+ if not self._custom_context:
+ self.ssl_context.verify_mode = self.cert_reqs
+ if self.certfile:
+ self.ssl_context.load_cert_chain(self.certfile, self.keyfile)
+ if self.ciphers:
+ self.ssl_context.set_ciphers(self.ciphers)
+ if self.ca_certs:
+ self.ssl_context.load_verify_locations(self.ca_certs)
+ return self.ssl_context.wrap_socket(sock, server_side=self._server_side,
+ server_hostname=self._server_hostname)
else:
- logger.warning('ciphers is specified but ignored due to old Python version')
- return ssl.wrap_socket(sock, **ssl_opts)
+ ssl_opts = {
+ 'ssl_version': self._ssl_version,
+ 'server_side': self._server_side,
+ 'ca_certs': self.ca_certs,
+ 'keyfile': self.keyfile,
+ 'certfile': self.certfile,
+ 'cert_reqs': self.cert_reqs,
+ }
+ if self.ciphers:
+ if self._has_ciphers:
+ ssl_opts['ciphers'] = self.ciphers
+ else:
+ logger.warning('ciphers is specified but ignored due to old Python version')
+ return ssl.wrap_socket(sock, **ssl_opts)
class TSSLSocket(TSocket.TSocket, TSSLBase):
- """
- SSL implementation of TSocket
-
- This class creates outbound sockets wrapped using the
- python standard ssl module for encrypted connections.
- """
+ """
+ SSL implementation of TSocket
- # New signature
- # def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args):
- # Deprecated signature
- # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
- def __init__(self, host='localhost', port=9090, *args, **kwargs):
- """Positional arguments: ``host``, ``port``, ``unix_socket``
-
- Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``,
- ``ca_certs``, ``ciphers`` (Python 2.7.0 or later),
- ``server_hostname`` (Python 2.7.9 or later)
- Passed to ssl.wrap_socket. See ssl.wrap_socket documentation.
-
- Alternative keywoard arguments: (Python 2.7.9 or later)
- ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
- ``server_hostname``: Passed to SSLContext.wrap_socket
+ This class creates outbound sockets wrapped using the
+ python standard ssl module for encrypted connections.
"""
- self.is_valid = False
- self.peercert = None
-
- if args:
- if len(args) > 6:
- raise TypeError('Too many positional argument')
- if not self._unix_socket_arg(host, port, args, kwargs):
- self._deprecated_arg(args, kwargs, 0, 'validate')
- self._deprecated_arg(args, kwargs, 1, 'ca_certs')
- self._deprecated_arg(args, kwargs, 2, 'keyfile')
- self._deprecated_arg(args, kwargs, 3, 'certfile')
- self._deprecated_arg(args, kwargs, 4, 'unix_socket')
- self._deprecated_arg(args, kwargs, 5, 'ciphers')
-
- validate = kwargs.pop('validate', None)
- if validate is not None:
- cert_reqs_name = 'CERT_REQUIRED' if validate else 'CERT_NONE'
- warnings.warn(
- 'validate is deprecated. Use cert_reqs=ssl.%s instead' % cert_reqs_name,
- DeprecationWarning)
- if 'cert_reqs' in kwargs:
- raise TypeError('Cannot specify both validate and cert_reqs')
- kwargs['cert_reqs'] = ssl.CERT_REQUIRED if validate else ssl.CERT_NONE
-
- unix_socket = kwargs.pop('unix_socket', None)
- TSSLBase.__init__(self, False, host, kwargs)
- TSocket.TSocket.__init__(self, host, port, unix_socket)
-
- @property
- def validate(self):
- warnings.warn('Use cert_reqs instead', DeprecationWarning)
- return self.cert_reqs != ssl.CERT_NONE
-
- @validate.setter
- def validate(self, value):
- warnings.warn('Use cert_reqs instead', DeprecationWarning)
- self.cert_reqs = ssl.CERT_REQUIRED if value else ssl.CERT_NONE
-
- def open(self):
- try:
- res0 = self._resolveAddr()
- for res in res0:
- sock_family, sock_type = res[0:2]
- ip_port = res[4]
- plain_sock = socket.socket(sock_family, sock_type)
- self.handle = self._wrap_socket(plain_sock)
- self.handle.settimeout(self._timeout)
+
+ # New signature
+ # def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args):
+ # Deprecated signature
+ # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
+ def __init__(self, host='localhost', port=9090, *args, **kwargs):
+ """Positional arguments: ``host``, ``port``, ``unix_socket``
+
+ Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``,
+ ``ca_certs``, ``ciphers`` (Python 2.7.0 or later),
+ ``server_hostname`` (Python 2.7.9 or later)
+ Passed to ssl.wrap_socket. See ssl.wrap_socket documentation.
+
+        Alternative keyword arguments: (Python 2.7.9 or later)
+ ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
+ ``server_hostname``: Passed to SSLContext.wrap_socket
+ """
+ self.is_valid = False
+ self.peercert = None
+
+ if args:
+ if len(args) > 6:
+                raise TypeError('Too many positional arguments')
+ if not self._unix_socket_arg(host, port, args, kwargs):
+ self._deprecated_arg(args, kwargs, 0, 'validate')
+ self._deprecated_arg(args, kwargs, 1, 'ca_certs')
+ self._deprecated_arg(args, kwargs, 2, 'keyfile')
+ self._deprecated_arg(args, kwargs, 3, 'certfile')
+ self._deprecated_arg(args, kwargs, 4, 'unix_socket')
+ self._deprecated_arg(args, kwargs, 5, 'ciphers')
+
+ validate = kwargs.pop('validate', None)
+ if validate is not None:
+ cert_reqs_name = 'CERT_REQUIRED' if validate else 'CERT_NONE'
+ warnings.warn(
+ 'validate is deprecated. Use cert_reqs=ssl.%s instead' % cert_reqs_name,
+ DeprecationWarning)
+ if 'cert_reqs' in kwargs:
+ raise TypeError('Cannot specify both validate and cert_reqs')
+ kwargs['cert_reqs'] = ssl.CERT_REQUIRED if validate else ssl.CERT_NONE
+
+ unix_socket = kwargs.pop('unix_socket', None)
+ TSSLBase.__init__(self, False, host, kwargs)
+ TSocket.TSocket.__init__(self, host, port, unix_socket)
+
+ @property
+ def validate(self):
+ warnings.warn('Use cert_reqs instead', DeprecationWarning)
+ return self.cert_reqs != ssl.CERT_NONE
+
+ @validate.setter
+ def validate(self, value):
+ warnings.warn('Use cert_reqs instead', DeprecationWarning)
+ self.cert_reqs = ssl.CERT_REQUIRED if value else ssl.CERT_NONE
+
+ def open(self):
try:
- self.handle.connect(ip_port)
+ res0 = self._resolveAddr()
+ for res in res0:
+ sock_family, sock_type = res[0:2]
+ ip_port = res[4]
+ plain_sock = socket.socket(sock_family, sock_type)
+ self.handle = self._wrap_socket(plain_sock)
+ self.handle.settimeout(self._timeout)
+ try:
+ self.handle.connect(ip_port)
+ except socket.error as e:
+ if res is not res0[-1]:
+ logger.warning('Error while connecting with %s. Trying next one.', ip_port, exc_info=True)
+ continue
+ else:
+ raise
+ break
except socket.error as e:
- if res is not res0[-1]:
- logger.warning('Error while connecting with %s. Trying next one.', ip_port, exc_info=True)
- continue
- else:
- raise
- break
- except socket.error as e:
- if self._unix_socket:
- message = 'Could not connect to secure socket %s: %s' \
- % (self._unix_socket, e)
- else:
- message = 'Could not connect to %s:%d: %s' % (self.host, self.port, e)
- logger.error('Error while connecting with %s.', ip_port, exc_info=True)
- raise TTransportException(type=TTransportException.NOT_OPEN,
- message=message)
- if self.validate:
- self._validate_cert()
-
- def _validate_cert(self):
- """internal method to validate the peer's SSL certificate, and to check the
- commonName of the certificate to ensure it matches the hostname we
- used to make this connection. Does not support subjectAltName records
- in certificates.
-
- raises TTransportException if the certificate fails validation.
- """
- cert = self.handle.getpeercert()
- self.peercert = cert
- if 'subject' not in cert:
- raise TTransportException(
- type=TTransportException.NOT_OPEN,
- message='No SSL certificate found from %s:%s' % (self.host, self.port))
- fields = cert['subject']
- for field in fields:
- # ensure structure we get back is what we expect
- if not isinstance(field, tuple):
- continue
- cert_pair = field[0]
- if len(cert_pair) < 2:
- continue
- cert_key, cert_value = cert_pair[0:2]
- if cert_key != 'commonName':
- continue
- certhost = cert_value
- # this check should be performed by some sort of Access Manager
- if certhost == self.host:
- # success, cert commonName matches desired hostname
- self.is_valid = True
- return
- else:
+ if self._unix_socket:
+ message = 'Could not connect to secure socket %s: %s' \
+ % (self._unix_socket, e)
+ else:
+ message = 'Could not connect to %s:%d: %s' % (self.host, self.port, e)
+ logger.error('Error while connecting with %s.', ip_port, exc_info=True)
+ raise TTransportException(type=TTransportException.NOT_OPEN,
+ message=message)
+ if self.validate:
+ self._validate_cert()
+
+ def _validate_cert(self):
+ """internal method to validate the peer's SSL certificate, and to check the
+ commonName of the certificate to ensure it matches the hostname we
+ used to make this connection. Does not support subjectAltName records
+ in certificates.
+
+ raises TTransportException if the certificate fails validation.
+ """
+ cert = self.handle.getpeercert()
+ self.peercert = cert
+ if 'subject' not in cert:
+ raise TTransportException(
+ type=TTransportException.NOT_OPEN,
+ message='No SSL certificate found from %s:%s' % (self.host, self.port))
+ fields = cert['subject']
+ for field in fields:
+ # ensure structure we get back is what we expect
+ if not isinstance(field, tuple):
+ continue
+ cert_pair = field[0]
+ if len(cert_pair) < 2:
+ continue
+ cert_key, cert_value = cert_pair[0:2]
+ if cert_key != 'commonName':
+ continue
+ certhost = cert_value
+ # this check should be performed by some sort of Access Manager
+ if certhost == self.host:
+ # success, cert commonName matches desired hostname
+ self.is_valid = True
+ return
+ else:
+ raise TTransportException(
+ type=TTransportException.UNKNOWN,
+ message='Hostname we connected to "%s" doesn\'t match certificate '
+ 'provided commonName "%s"' % (self.host, certhost))
raise TTransportException(
- type=TTransportException.UNKNOWN,
- message='Hostname we connected to "%s" doesn\'t match certificate '
- 'provided commonName "%s"' % (self.host, certhost))
- raise TTransportException(
- type=TTransportException.UNKNOWN,
- message='Could not validate SSL certificate from '
- 'host "%s". Cert=%s' % (self.host, cert))
+ type=TTransportException.UNKNOWN,
+ message='Could not validate SSL certificate from '
+ 'host "%s". Cert=%s' % (self.host, cert))
class TSSLServerSocket(TSocket.TServerSocket, TSSLBase):
- """SSL implementation of TServerSocket
-
- This uses the ssl module's wrap_socket() method to provide SSL
- negotiated encryption.
- """
+ """SSL implementation of TServerSocket
- # New signature
- # def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args):
- # Deprecated signature
- # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
- def __init__(self, host=None, port=9090, *args, **kwargs):
- """Positional arguments: ``host``, ``port``, ``unix_socket``
-
- Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``,
- ``ca_certs``, ``ciphers`` (Python 2.7.0 or later)
- See ssl.wrap_socket documentation.
-
- Alternative keywoard arguments: (Python 2.7.9 or later)
- ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
- ``server_hostname``: Passed to SSLContext.wrap_socket
- """
- if args:
- if len(args) > 3:
- raise TypeError('Too many positional argument')
- if not self._unix_socket_arg(host, port, args, kwargs):
- self._deprecated_arg(args, kwargs, 0, 'certfile')
- self._deprecated_arg(args, kwargs, 1, 'unix_socket')
- self._deprecated_arg(args, kwargs, 2, 'ciphers')
-
- if 'ssl_context' not in kwargs:
- # Preserve existing behaviors for default values
- if 'cert_reqs' not in kwargs:
- kwargs['cert_reqs'] = ssl.CERT_NONE
- if'certfile' not in kwargs:
- kwargs['certfile'] = 'cert.pem'
-
- unix_socket = kwargs.pop('unix_socket', None)
- TSSLBase.__init__(self, True, None, kwargs)
- TSocket.TServerSocket.__init__(self, host, port, unix_socket)
-
- def setCertfile(self, certfile):
- """Set or change the server certificate file used to wrap new connections.
-
- @param certfile: The filename of the server certificate,
- i.e. '/etc/certs/server.pem'
- @type certfile: str
-
- Raises an IOError exception if the certfile is not present or unreadable.
+ This uses the ssl module's wrap_socket() method to provide SSL
+ negotiated encryption.
"""
- warnings.warn('Use certfile property instead.', DeprecationWarning)
- self.certfile = certfile
-
- def accept(self):
- plain_client, addr = self.handle.accept()
- try:
- client = self._wrap_socket(plain_client)
- except ssl.SSLError:
- logger.error('Error while accepting from %s', addr, exc_info=True)
- # failed handshake/ssl wrap, close socket to client
- plain_client.close()
- # raise
- # We can't raise the exception, because it kills most TServer derived
- # serve() methods.
- # Instead, return None, and let the TServer instance deal with it in
- # other exception handling. (but TSimpleServer dies anyway)
- return None
- result = TSocket.TSocket()
- result.setHandle(client)
- return result
+
+ # New signature
+ # def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args):
+ # Deprecated signature
+ # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
+ def __init__(self, host=None, port=9090, *args, **kwargs):
+ """Positional arguments: ``host``, ``port``, ``unix_socket``
+
+ Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``,
+ ``ca_certs``, ``ciphers`` (Python 2.7.0 or later)
+ See ssl.wrap_socket documentation.
+
+        Alternative keyword arguments: (Python 2.7.9 or later)
+ ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
+ ``server_hostname``: Passed to SSLContext.wrap_socket
+ """
+ if args:
+ if len(args) > 3:
+                raise TypeError('Too many positional arguments')
+ if not self._unix_socket_arg(host, port, args, kwargs):
+ self._deprecated_arg(args, kwargs, 0, 'certfile')
+ self._deprecated_arg(args, kwargs, 1, 'unix_socket')
+ self._deprecated_arg(args, kwargs, 2, 'ciphers')
+
+ if 'ssl_context' not in kwargs:
+ # Preserve existing behaviors for default values
+ if 'cert_reqs' not in kwargs:
+ kwargs['cert_reqs'] = ssl.CERT_NONE
+            if 'certfile' not in kwargs:
+ kwargs['certfile'] = 'cert.pem'
+
+ unix_socket = kwargs.pop('unix_socket', None)
+ TSSLBase.__init__(self, True, None, kwargs)
+ TSocket.TServerSocket.__init__(self, host, port, unix_socket)
+
+ def setCertfile(self, certfile):
+ """Set or change the server certificate file used to wrap new connections.
+
+ @param certfile: The filename of the server certificate,
+               e.g. '/etc/certs/server.pem'
+ @type certfile: str
+
+ Raises an IOError exception if the certfile is not present or unreadable.
+ """
+ warnings.warn('Use certfile property instead.', DeprecationWarning)
+ self.certfile = certfile
+
+ def accept(self):
+ plain_client, addr = self.handle.accept()
+ try:
+ client = self._wrap_socket(plain_client)
+ except ssl.SSLError:
+ logger.error('Error while accepting from %s', addr, exc_info=True)
+ # failed handshake/ssl wrap, close socket to client
+ plain_client.close()
+ # raise
+ # We can't raise the exception, because it kills most TServer derived
+ # serve() methods.
+ # Instead, return None, and let the TServer instance deal with it in
+ # other exception handling. (but TSimpleServer dies anyway)
+ return None
+ result = TSocket.TSocket()
+ result.setHandle(client)
+ return result
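
As a quick orientation for the API reworked above, here is a minimal client/server sketch using the TSSLSocket and TSSLServerSocket signatures shown in this file. The certificate paths are placeholder assumptions, and the accept() call runs in a helper thread (as the test suite's ServerAcceptor does) because it blocks until the TLS handshake finishes.

    import ssl
    import threading
    from thrift.transport import TSSLSocket

    # Placeholder certificate files; not part of this change.
    server = TSSLSocket.TSSLServerSocket(port=9090,
                                         certfile='server.pem',
                                         keyfile='server.key')
    server.listen()

    def accept_one():
        # accept() wraps the plain socket with SSL and returns a TSocket
        # around the negotiated connection (or None on handshake failure).
        conn = server.accept()
        if conn:
            conn.close()

    t = threading.Thread(target=accept_one)
    t.start()

    # Assumes the server certificate's commonName is 'localhost',
    # since cert_reqs=CERT_REQUIRED triggers _validate_cert().
    client = TSSLSocket.TSSLSocket('localhost', 9090,
                                   cert_reqs=ssl.CERT_REQUIRED,
                                   ca_certs='server.pem')
    client.open()   # connects and performs the TLS handshake
    client.close()
    t.join()
    server.close()
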
diff --git a/lib/py/src/transport/TSocket.py b/lib/py/src/transport/TSocket.py
index cb204a4a0..a8ed4b7dc 100644
--- a/lib/py/src/transport/TSocket.py
+++ b/lib/py/src/transport/TSocket.py
@@ -22,159 +22,159 @@ import os
import socket
import sys
-from .TTransport import *
+from .TTransport import TTransportBase, TTransportException, TServerTransportBase
class TSocketBase(TTransportBase):
- def _resolveAddr(self):
- if self._unix_socket is not None:
- return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None,
- self._unix_socket)]
- else:
- return socket.getaddrinfo(self.host,
- self.port,
- self._socket_family,
- socket.SOCK_STREAM,
- 0,
- socket.AI_PASSIVE | socket.AI_ADDRCONFIG)
-
- def close(self):
- if self.handle:
- self.handle.close()
- self.handle = None
+ def _resolveAddr(self):
+ if self._unix_socket is not None:
+ return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None,
+ self._unix_socket)]
+ else:
+ return socket.getaddrinfo(self.host,
+ self.port,
+ self._socket_family,
+ socket.SOCK_STREAM,
+ 0,
+ socket.AI_PASSIVE | socket.AI_ADDRCONFIG)
+
+ def close(self):
+ if self.handle:
+ self.handle.close()
+ self.handle = None
class TSocket(TSocketBase):
- """Socket implementation of TTransport base."""
-
- def __init__(self, host='localhost', port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
- """Initialize a TSocket
-
- @param host(str) The host to connect to.
- @param port(int) The (TCP) port to connect to.
- @param unix_socket(str) The filename of a unix socket to connect to.
- (host and port will be ignored.)
- @param socket_family(int) The socket family to use with this socket.
- """
- self.host = host
- self.port = port
- self.handle = None
- self._unix_socket = unix_socket
- self._timeout = None
- self._socket_family = socket_family
-
- def setHandle(self, h):
- self.handle = h
-
- def isOpen(self):
- return self.handle is not None
-
- def setTimeout(self, ms):
- if ms is None:
- self._timeout = None
- else:
- self._timeout = ms / 1000.0
-
- if self.handle is not None:
- self.handle.settimeout(self._timeout)
-
- def open(self):
- try:
- res0 = self._resolveAddr()
- for res in res0:
- self.handle = socket.socket(res[0], res[1])
- self.handle.settimeout(self._timeout)
+ """Socket implementation of TTransport base."""
+
+ def __init__(self, host='localhost', port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
+ """Initialize a TSocket
+
+ @param host(str) The host to connect to.
+ @param port(int) The (TCP) port to connect to.
+ @param unix_socket(str) The filename of a unix socket to connect to.
+ (host and port will be ignored.)
+ @param socket_family(int) The socket family to use with this socket.
+ """
+ self.host = host
+ self.port = port
+ self.handle = None
+ self._unix_socket = unix_socket
+ self._timeout = None
+ self._socket_family = socket_family
+
+ def setHandle(self, h):
+ self.handle = h
+
+ def isOpen(self):
+ return self.handle is not None
+
+ def setTimeout(self, ms):
+ if ms is None:
+ self._timeout = None
+ else:
+ self._timeout = ms / 1000.0
+
+ if self.handle is not None:
+ self.handle.settimeout(self._timeout)
+
+ def open(self):
+ try:
+ res0 = self._resolveAddr()
+ for res in res0:
+ self.handle = socket.socket(res[0], res[1])
+ self.handle.settimeout(self._timeout)
+ try:
+ self.handle.connect(res[4])
+ except socket.error as e:
+ if res is not res0[-1]:
+ continue
+ else:
+ raise e
+ break
+ except socket.error as e:
+ if self._unix_socket:
+ message = 'Could not connect to socket %s' % self._unix_socket
+ else:
+ message = 'Could not connect to %s:%d' % (self.host, self.port)
+ raise TTransportException(type=TTransportException.NOT_OPEN,
+ message=message)
+
+ def read(self, sz):
try:
- self.handle.connect(res[4])
+ buff = self.handle.recv(sz)
except socket.error as e:
- if res is not res0[-1]:
- continue
- else:
- raise e
- break
- except socket.error as e:
- if self._unix_socket:
- message = 'Could not connect to socket %s' % self._unix_socket
- else:
- message = 'Could not connect to %s:%d' % (self.host, self.port)
- raise TTransportException(type=TTransportException.NOT_OPEN,
- message=message)
-
- def read(self, sz):
- try:
- buff = self.handle.recv(sz)
- except socket.error as e:
- if (e.args[0] == errno.ECONNRESET and
- (sys.platform == 'darwin' or sys.platform.startswith('freebsd'))):
- # freebsd and Mach don't follow POSIX semantic of recv
- # and fail with ECONNRESET if peer performed shutdown.
- # See corresponding comment and code in TSocket::read()
- # in lib/cpp/src/transport/TSocket.cpp.
- self.close()
- # Trigger the check to raise the END_OF_FILE exception below.
- buff = ''
- else:
- raise
- if len(buff) == 0:
- raise TTransportException(type=TTransportException.END_OF_FILE,
- message='TSocket read 0 bytes')
- return buff
-
- def write(self, buff):
- if not self.handle:
- raise TTransportException(type=TTransportException.NOT_OPEN,
- message='Transport not open')
- sent = 0
- have = len(buff)
- while sent < have:
- plus = self.handle.send(buff)
- if plus == 0:
- raise TTransportException(type=TTransportException.END_OF_FILE,
- message='TSocket sent 0 bytes')
- sent += plus
- buff = buff[plus:]
-
- def flush(self):
- pass
+ if (e.args[0] == errno.ECONNRESET and
+ (sys.platform == 'darwin' or sys.platform.startswith('freebsd'))):
+                # freebsd and Mach don't follow POSIX semantics of recv
+                # and fail with ECONNRESET if the peer performed shutdown.
+ # See corresponding comment and code in TSocket::read()
+ # in lib/cpp/src/transport/TSocket.cpp.
+ self.close()
+ # Trigger the check to raise the END_OF_FILE exception below.
+ buff = ''
+ else:
+ raise
+ if len(buff) == 0:
+ raise TTransportException(type=TTransportException.END_OF_FILE,
+ message='TSocket read 0 bytes')
+ return buff
+
+ def write(self, buff):
+ if not self.handle:
+ raise TTransportException(type=TTransportException.NOT_OPEN,
+ message='Transport not open')
+ sent = 0
+ have = len(buff)
+ while sent < have:
+ plus = self.handle.send(buff)
+ if plus == 0:
+ raise TTransportException(type=TTransportException.END_OF_FILE,
+ message='TSocket sent 0 bytes')
+ sent += plus
+ buff = buff[plus:]
+
+ def flush(self):
+ pass
class TServerSocket(TSocketBase, TServerTransportBase):
- """Socket implementation of TServerTransport base."""
-
- def __init__(self, host=None, port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
- self.host = host
- self.port = port
- self._unix_socket = unix_socket
- self._socket_family = socket_family
- self.handle = None
-
- def listen(self):
- res0 = self._resolveAddr()
- socket_family = self._socket_family == socket.AF_UNSPEC and socket.AF_INET6 or self._socket_family
- for res in res0:
- if res[0] is socket_family or res is res0[-1]:
- break
-
- # We need remove the old unix socket if the file exists and
- # nobody is listening on it.
- if self._unix_socket:
- tmp = socket.socket(res[0], res[1])
- try:
- tmp.connect(res[4])
- except socket.error as err:
- eno, message = err.args
- if eno == errno.ECONNREFUSED:
- os.unlink(res[4])
-
- self.handle = socket.socket(res[0], res[1])
- self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- if hasattr(self.handle, 'settimeout'):
- self.handle.settimeout(None)
- self.handle.bind(res[4])
- self.handle.listen(128)
-
- def accept(self):
- client, addr = self.handle.accept()
- result = TSocket()
- result.setHandle(client)
- return result
+ """Socket implementation of TServerTransport base."""
+
+ def __init__(self, host=None, port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
+ self.host = host
+ self.port = port
+ self._unix_socket = unix_socket
+ self._socket_family = socket_family
+ self.handle = None
+
+ def listen(self):
+ res0 = self._resolveAddr()
+ socket_family = self._socket_family == socket.AF_UNSPEC and socket.AF_INET6 or self._socket_family
+ for res in res0:
+ if res[0] is socket_family or res is res0[-1]:
+ break
+
+        # We need to remove the old unix socket if the file exists and
+ # nobody is listening on it.
+ if self._unix_socket:
+ tmp = socket.socket(res[0], res[1])
+ try:
+ tmp.connect(res[4])
+ except socket.error as err:
+ eno, message = err.args
+ if eno == errno.ECONNREFUSED:
+ os.unlink(res[4])
+
+ self.handle = socket.socket(res[0], res[1])
+ self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ if hasattr(self.handle, 'settimeout'):
+ self.handle.settimeout(None)
+ self.handle.bind(res[4])
+ self.handle.listen(128)
+
+ def accept(self):
+ client, addr = self.handle.accept()
+ result = TSocket()
+ result.setHandle(client)
+ return result
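
For reference, a minimal client-side sketch of the TSocket interface above; the host, port and payload are illustrative and a server is assumed to be listening elsewhere.

    from thrift.transport import TSocket, TTransport

    sock = TSocket.TSocket('localhost', 9090)
    sock.setTimeout(5000)        # milliseconds; stored internally as 5.0 seconds
    try:
        sock.open()              # resolves the address and connects
        sock.write(b'ping')      # loops on send() until every byte is written
        sock.flush()             # no-op for TSocket, kept for interface parity
        reply = sock.readAll(4)  # readAll() repeats read() until 4 bytes arrive
    except TTransport.TTransportException as e:
        print('transport error, type=%d: %s' % (e.type, e))
    finally:
        sock.close()
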
diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py
index f99b3b9ba..6669891cd 100644
--- a/lib/py/src/transport/TTransport.py
+++ b/lib/py/src/transport/TTransport.py
@@ -23,427 +23,426 @@ from ..compat import BufferIO
class TTransportException(TException):
- """Custom Transport Exception class"""
+ """Custom Transport Exception class"""
- UNKNOWN = 0
- NOT_OPEN = 1
- ALREADY_OPEN = 2
- TIMED_OUT = 3
- END_OF_FILE = 4
- NEGATIVE_SIZE = 5
- SIZE_LIMIT = 6
+ UNKNOWN = 0
+ NOT_OPEN = 1
+ ALREADY_OPEN = 2
+ TIMED_OUT = 3
+ END_OF_FILE = 4
+ NEGATIVE_SIZE = 5
+ SIZE_LIMIT = 6
- def __init__(self, type=UNKNOWN, message=None):
- TException.__init__(self, message)
- self.type = type
+ def __init__(self, type=UNKNOWN, message=None):
+ TException.__init__(self, message)
+ self.type = type
class TTransportBase(object):
- """Base class for Thrift transport layer."""
+ """Base class for Thrift transport layer."""
- def isOpen(self):
- pass
+ def isOpen(self):
+ pass
- def open(self):
- pass
+ def open(self):
+ pass
- def close(self):
- pass
+ def close(self):
+ pass
- def read(self, sz):
- pass
+ def read(self, sz):
+ pass
- def readAll(self, sz):
- buff = b''
- have = 0
- while (have < sz):
- chunk = self.read(sz - have)
- have += len(chunk)
- buff += chunk
+ def readAll(self, sz):
+ buff = b''
+ have = 0
+ while (have < sz):
+ chunk = self.read(sz - have)
+ have += len(chunk)
+ buff += chunk
- if len(chunk) == 0:
- raise EOFError()
+ if len(chunk) == 0:
+ raise EOFError()
- return buff
+ return buff
- def write(self, buf):
- pass
+ def write(self, buf):
+ pass
- def flush(self):
- pass
+ def flush(self):
+ pass
# This class should be thought of as an interface.
class CReadableTransport(object):
- """base class for transports that are readable from C"""
+ """base class for transports that are readable from C"""
- # TODO(dreiss): Think about changing this interface to allow us to use
- # a (Python, not c) StringIO instead, because it allows
- # you to write after reading.
+ # TODO(dreiss): Think about changing this interface to allow us to use
+ # a (Python, not c) StringIO instead, because it allows
+ # you to write after reading.
- # NOTE: This is a classic class, so properties will NOT work
- # correctly for setting.
- @property
- def cstringio_buf(self):
- """A cStringIO buffer that contains the current chunk we are reading."""
- pass
+ # NOTE: This is a classic class, so properties will NOT work
+ # correctly for setting.
+ @property
+ def cstringio_buf(self):
+ """A cStringIO buffer that contains the current chunk we are reading."""
+ pass
- def cstringio_refill(self, partialread, reqlen):
- """Refills cstringio_buf.
+ def cstringio_refill(self, partialread, reqlen):
+ """Refills cstringio_buf.
- Returns the currently used buffer (which can but need not be the same as
- the old cstringio_buf). partialread is what the C code has read from the
- buffer, and should be inserted into the buffer before any more reads. The
- return value must be a new, not borrowed reference. Something along the
- lines of self._buf should be fine.
+ Returns the currently used buffer (which can but need not be the same as
+ the old cstringio_buf). partialread is what the C code has read from the
+ buffer, and should be inserted into the buffer before any more reads. The
+ return value must be a new, not borrowed reference. Something along the
+ lines of self._buf should be fine.
- If reqlen bytes can't be read, throw EOFError.
- """
- pass
+ If reqlen bytes can't be read, throw EOFError.
+ """
+ pass
class TServerTransportBase(object):
- """Base class for Thrift server transports."""
+ """Base class for Thrift server transports."""
- def listen(self):
- pass
+ def listen(self):
+ pass
- def accept(self):
- pass
+ def accept(self):
+ pass
- def close(self):
- pass
+ def close(self):
+ pass
class TTransportFactoryBase(object):
- """Base class for a Transport Factory"""
+ """Base class for a Transport Factory"""
- def getTransport(self, trans):
- return trans
+ def getTransport(self, trans):
+ return trans
class TBufferedTransportFactory(object):
- """Factory transport that builds buffered transports"""
+ """Factory transport that builds buffered transports"""
- def getTransport(self, trans):
- buffered = TBufferedTransport(trans)
- return buffered
+ def getTransport(self, trans):
+ buffered = TBufferedTransport(trans)
+ return buffered
class TBufferedTransport(TTransportBase, CReadableTransport):
- """Class that wraps another transport and buffers its I/O.
-
- The implementation uses a (configurable) fixed-size read buffer
- but buffers all writes until a flush is performed.
- """
- DEFAULT_BUFFER = 4096
-
- def __init__(self, trans, rbuf_size=DEFAULT_BUFFER):
- self.__trans = trans
- self.__wbuf = BufferIO()
- # Pass string argument to initialize read buffer as cStringIO.InputType
- self.__rbuf = BufferIO(b'')
- self.__rbuf_size = rbuf_size
-
- def isOpen(self):
- return self.__trans.isOpen()
-
- def open(self):
- return self.__trans.open()
-
- def close(self):
- return self.__trans.close()
-
- def read(self, sz):
- ret = self.__rbuf.read(sz)
- if len(ret) != 0:
- return ret
- self.__rbuf = BufferIO(self.__trans.read(max(sz, self.__rbuf_size)))
- return self.__rbuf.read(sz)
-
- def write(self, buf):
- try:
- self.__wbuf.write(buf)
- except Exception as e:
- # on exception reset wbuf so it doesn't contain a partial function call
- self.__wbuf = BufferIO()
- raise e
- self.__wbuf.getvalue()
-
- def flush(self):
- out = self.__wbuf.getvalue()
- # reset wbuf before write/flush to preserve state on underlying failure
- self.__wbuf = BufferIO()
- self.__trans.write(out)
- self.__trans.flush()
-
- # Implement the CReadableTransport interface.
- @property
- def cstringio_buf(self):
- return self.__rbuf
-
- def cstringio_refill(self, partialread, reqlen):
- retstring = partialread
- if reqlen < self.__rbuf_size:
- # try to make a read of as much as we can.
- retstring += self.__trans.read(self.__rbuf_size)
-
- # but make sure we do read reqlen bytes.
- if len(retstring) < reqlen:
- retstring += self.__trans.readAll(reqlen - len(retstring))
-
- self.__rbuf = BufferIO(retstring)
- return self.__rbuf
+ """Class that wraps another transport and buffers its I/O.
+
+ The implementation uses a (configurable) fixed-size read buffer
+ but buffers all writes until a flush is performed.
+ """
+ DEFAULT_BUFFER = 4096
+
+ def __init__(self, trans, rbuf_size=DEFAULT_BUFFER):
+ self.__trans = trans
+ self.__wbuf = BufferIO()
+ # Pass string argument to initialize read buffer as cStringIO.InputType
+ self.__rbuf = BufferIO(b'')
+ self.__rbuf_size = rbuf_size
+
+ def isOpen(self):
+ return self.__trans.isOpen()
+
+ def open(self):
+ return self.__trans.open()
+
+ def close(self):
+ return self.__trans.close()
+
+ def read(self, sz):
+ ret = self.__rbuf.read(sz)
+ if len(ret) != 0:
+ return ret
+ self.__rbuf = BufferIO(self.__trans.read(max(sz, self.__rbuf_size)))
+ return self.__rbuf.read(sz)
+
+ def write(self, buf):
+ try:
+ self.__wbuf.write(buf)
+ except Exception as e:
+ # on exception reset wbuf so it doesn't contain a partial function call
+ self.__wbuf = BufferIO()
+ raise e
+ self.__wbuf.getvalue()
+
+ def flush(self):
+ out = self.__wbuf.getvalue()
+ # reset wbuf before write/flush to preserve state on underlying failure
+ self.__wbuf = BufferIO()
+ self.__trans.write(out)
+ self.__trans.flush()
+
+ # Implement the CReadableTransport interface.
+ @property
+ def cstringio_buf(self):
+ return self.__rbuf
+
+ def cstringio_refill(self, partialread, reqlen):
+ retstring = partialread
+ if reqlen < self.__rbuf_size:
+ # try to make a read of as much as we can.
+ retstring += self.__trans.read(self.__rbuf_size)
+
+ # but make sure we do read reqlen bytes.
+ if len(retstring) < reqlen:
+ retstring += self.__trans.readAll(reqlen - len(retstring))
+
+ self.__rbuf = BufferIO(retstring)
+ return self.__rbuf
class TMemoryBuffer(TTransportBase, CReadableTransport):
- """Wraps a cBytesIO object as a TTransport.
+ """Wraps a cBytesIO object as a TTransport.
- NOTE: Unlike the C++ version of this class, you cannot write to it
- then immediately read from it. If you want to read from a
- TMemoryBuffer, you must either pass a string to the constructor.
- TODO(dreiss): Make this work like the C++ version.
- """
+ NOTE: Unlike the C++ version of this class, you cannot write to it
+ then immediately read from it. If you want to read from a
+    TMemoryBuffer, you must pass a string to the constructor.
+ TODO(dreiss): Make this work like the C++ version.
+ """
- def __init__(self, value=None):
- """value -- a value to read from for stringio
+ def __init__(self, value=None):
+ """value -- a value to read from for stringio
- If value is set, this will be a transport for reading,
- otherwise, it is for writing"""
- if value is not None:
- self._buffer = BufferIO(value)
- else:
- self._buffer = BufferIO()
+ If value is set, this will be a transport for reading,
+ otherwise, it is for writing"""
+ if value is not None:
+ self._buffer = BufferIO(value)
+ else:
+ self._buffer = BufferIO()
- def isOpen(self):
- return not self._buffer.closed
+ def isOpen(self):
+ return not self._buffer.closed
- def open(self):
- pass
+ def open(self):
+ pass
- def close(self):
- self._buffer.close()
+ def close(self):
+ self._buffer.close()
- def read(self, sz):
- return self._buffer.read(sz)
+ def read(self, sz):
+ return self._buffer.read(sz)
- def write(self, buf):
- self._buffer.write(buf)
+ def write(self, buf):
+ self._buffer.write(buf)
- def flush(self):
- pass
+ def flush(self):
+ pass
- def getvalue(self):
- return self._buffer.getvalue()
+ def getvalue(self):
+ return self._buffer.getvalue()
- # Implement the CReadableTransport interface.
- @property
- def cstringio_buf(self):
- return self._buffer
+ # Implement the CReadableTransport interface.
+ @property
+ def cstringio_buf(self):
+ return self._buffer
- def cstringio_refill(self, partialread, reqlen):
- # only one shot at reading...
- raise EOFError()
+ def cstringio_refill(self, partialread, reqlen):
+ # only one shot at reading...
+ raise EOFError()
class TFramedTransportFactory(object):
- """Factory transport that builds framed transports"""
+ """Factory transport that builds framed transports"""
- def getTransport(self, trans):
- framed = TFramedTransport(trans)
- return framed
+ def getTransport(self, trans):
+ framed = TFramedTransport(trans)
+ return framed
class TFramedTransport(TTransportBase, CReadableTransport):
- """Class that wraps another transport and frames its I/O when writing."""
-
- def __init__(self, trans,):
- self.__trans = trans
- self.__rbuf = BufferIO(b'')
- self.__wbuf = BufferIO()
-
- def isOpen(self):
- return self.__trans.isOpen()
-
- def open(self):
- return self.__trans.open()
-
- def close(self):
- return self.__trans.close()
-
- def read(self, sz):
- ret = self.__rbuf.read(sz)
- if len(ret) != 0:
- return ret
-
- self.readFrame()
- return self.__rbuf.read(sz)
-
- def readFrame(self):
- buff = self.__trans.readAll(4)
- sz, = unpack('!i', buff)
- self.__rbuf = BufferIO(self.__trans.readAll(sz))
-
- def write(self, buf):
- self.__wbuf.write(buf)
-
- def flush(self):
- wout = self.__wbuf.getvalue()
- wsz = len(wout)
- # reset wbuf before write/flush to preserve state on underlying failure
- self.__wbuf = BufferIO()
- # N.B.: Doing this string concatenation is WAY cheaper than making
- # two separate calls to the underlying socket object. Socket writes in
- # Python turn out to be REALLY expensive, but it seems to do a pretty
- # good job of managing string buffer operations without excessive copies
- buf = pack("!i", wsz) + wout
- self.__trans.write(buf)
- self.__trans.flush()
-
- # Implement the CReadableTransport interface.
- @property
- def cstringio_buf(self):
- return self.__rbuf
-
- def cstringio_refill(self, prefix, reqlen):
- # self.__rbuf will already be empty here because fastbinary doesn't
- # ask for a refill until the previous buffer is empty. Therefore,
- # we can start reading new frames immediately.
- while len(prefix) < reqlen:
- self.readFrame()
- prefix += self.__rbuf.getvalue()
- self.__rbuf = BufferIO(prefix)
- return self.__rbuf
+ """Class that wraps another transport and frames its I/O when writing."""
+
+ def __init__(self, trans,):
+ self.__trans = trans
+ self.__rbuf = BufferIO(b'')
+ self.__wbuf = BufferIO()
+
+ def isOpen(self):
+ return self.__trans.isOpen()
+
+ def open(self):
+ return self.__trans.open()
+
+ def close(self):
+ return self.__trans.close()
+
+ def read(self, sz):
+ ret = self.__rbuf.read(sz)
+ if len(ret) != 0:
+ return ret
+
+ self.readFrame()
+ return self.__rbuf.read(sz)
+
+ def readFrame(self):
+ buff = self.__trans.readAll(4)
+ sz, = unpack('!i', buff)
+ self.__rbuf = BufferIO(self.__trans.readAll(sz))
+
+ def write(self, buf):
+ self.__wbuf.write(buf)
+
+ def flush(self):
+ wout = self.__wbuf.getvalue()
+ wsz = len(wout)
+ # reset wbuf before write/flush to preserve state on underlying failure
+ self.__wbuf = BufferIO()
+ # N.B.: Doing this string concatenation is WAY cheaper than making
+ # two separate calls to the underlying socket object. Socket writes in
+ # Python turn out to be REALLY expensive, but it seems to do a pretty
+ # good job of managing string buffer operations without excessive copies
+ buf = pack("!i", wsz) + wout
+ self.__trans.write(buf)
+ self.__trans.flush()
+
+ # Implement the CReadableTransport interface.
+ @property
+ def cstringio_buf(self):
+ return self.__rbuf
+
+ def cstringio_refill(self, prefix, reqlen):
+ # self.__rbuf will already be empty here because fastbinary doesn't
+ # ask for a refill until the previous buffer is empty. Therefore,
+ # we can start reading new frames immediately.
+ while len(prefix) < reqlen:
+ self.readFrame()
+ prefix += self.__rbuf.getvalue()
+ self.__rbuf = BufferIO(prefix)
+ return self.__rbuf
class TFileObjectTransport(TTransportBase):
- """Wraps a file-like object to make it work as a Thrift transport."""
+ """Wraps a file-like object to make it work as a Thrift transport."""
- def __init__(self, fileobj):
- self.fileobj = fileobj
+ def __init__(self, fileobj):
+ self.fileobj = fileobj
- def isOpen(self):
- return True
+ def isOpen(self):
+ return True
- def close(self):
- self.fileobj.close()
+ def close(self):
+ self.fileobj.close()
- def read(self, sz):
- return self.fileobj.read(sz)
+ def read(self, sz):
+ return self.fileobj.read(sz)
- def write(self, buf):
- self.fileobj.write(buf)
+ def write(self, buf):
+ self.fileobj.write(buf)
- def flush(self):
- self.fileobj.flush()
+ def flush(self):
+ self.fileobj.flush()
class TSaslClientTransport(TTransportBase, CReadableTransport):
- """
- SASL transport
- """
-
- START = 1
- OK = 2
- BAD = 3
- ERROR = 4
- COMPLETE = 5
-
- def __init__(self, transport, host, service, mechanism='GSSAPI',
- **sasl_kwargs):
"""
- transport: an underlying transport to use, typically just a TSocket
- host: the name of the server, from a SASL perspective
- service: the name of the server's service, from a SASL perspective
- mechanism: the name of the preferred mechanism to use
-
- All other kwargs will be passed to the puresasl.client.SASLClient
- constructor.
+ SASL transport
"""
- from puresasl.client import SASLClient
-
- self.transport = transport
- self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
-
- self.__wbuf = BufferIO()
- self.__rbuf = BufferIO(b'')
-
- def open(self):
- if not self.transport.isOpen():
- self.transport.open()
-
- self.send_sasl_msg(self.START, self.sasl.mechanism)
- self.send_sasl_msg(self.OK, self.sasl.process())
-
- while True:
- status, challenge = self.recv_sasl_msg()
- if status == self.OK:
- self.send_sasl_msg(self.OK, self.sasl.process(challenge))
- elif status == self.COMPLETE:
- if not self.sasl.complete:
- raise TTransportException("The server erroneously indicated "
- "that SASL negotiation was complete")
+ START = 1
+ OK = 2
+ BAD = 3
+ ERROR = 4
+ COMPLETE = 5
+
+ def __init__(self, transport, host, service, mechanism='GSSAPI',
+ **sasl_kwargs):
+ """
+ transport: an underlying transport to use, typically just a TSocket
+ host: the name of the server, from a SASL perspective
+ service: the name of the server's service, from a SASL perspective
+ mechanism: the name of the preferred mechanism to use
+
+ All other kwargs will be passed to the puresasl.client.SASLClient
+ constructor.
+ """
+
+ from puresasl.client import SASLClient
+
+ self.transport = transport
+ self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
+
+ self.__wbuf = BufferIO()
+ self.__rbuf = BufferIO(b'')
+
+ def open(self):
+ if not self.transport.isOpen():
+ self.transport.open()
+
+ self.send_sasl_msg(self.START, self.sasl.mechanism)
+ self.send_sasl_msg(self.OK, self.sasl.process())
+
+ while True:
+ status, challenge = self.recv_sasl_msg()
+ if status == self.OK:
+ self.send_sasl_msg(self.OK, self.sasl.process(challenge))
+ elif status == self.COMPLETE:
+ if not self.sasl.complete:
+ raise TTransportException("The server erroneously indicated "
+ "that SASL negotiation was complete")
+ else:
+ break
+ else:
+ raise TTransportException("Bad SASL negotiation status: %d (%s)"
+ % (status, challenge))
+
+ def send_sasl_msg(self, status, body):
+ header = pack(">BI", status, len(body))
+ self.transport.write(header + body)
+ self.transport.flush()
+
+ def recv_sasl_msg(self):
+ header = self.transport.readAll(5)
+ status, length = unpack(">BI", header)
+ if length > 0:
+ payload = self.transport.readAll(length)
else:
- break
- else:
- raise TTransportException("Bad SASL negotiation status: %d (%s)"
- % (status, challenge))
-
- def send_sasl_msg(self, status, body):
- header = pack(">BI", status, len(body))
- self.transport.write(header + body)
- self.transport.flush()
-
- def recv_sasl_msg(self):
- header = self.transport.readAll(5)
- status, length = unpack(">BI", header)
- if length > 0:
- payload = self.transport.readAll(length)
- else:
- payload = ""
- return status, payload
-
- def write(self, data):
- self.__wbuf.write(data)
-
- def flush(self):
- data = self.__wbuf.getvalue()
- encoded = self.sasl.wrap(data)
- self.transport.write(''.join((pack("!i", len(encoded)), encoded)))
- self.transport.flush()
- self.__wbuf = BufferIO()
-
- def read(self, sz):
- ret = self.__rbuf.read(sz)
- if len(ret) != 0:
- return ret
-
- self._read_frame()
- return self.__rbuf.read(sz)
-
- def _read_frame(self):
- header = self.transport.readAll(4)
- length, = unpack('!i', header)
- encoded = self.transport.readAll(length)
- self.__rbuf = BufferIO(self.sasl.unwrap(encoded))
-
- def close(self):
- self.sasl.dispose()
- self.transport.close()
-
- # based on TFramedTransport
- @property
- def cstringio_buf(self):
- return self.__rbuf
-
- def cstringio_refill(self, prefix, reqlen):
- # self.__rbuf will already be empty here because fastbinary doesn't
- # ask for a refill until the previous buffer is empty. Therefore,
- # we can start reading new frames immediately.
- while len(prefix) < reqlen:
- self._read_frame()
- prefix += self.__rbuf.getvalue()
- self.__rbuf = BufferIO(prefix)
- return self.__rbuf
-
+ payload = ""
+ return status, payload
+
+ def write(self, data):
+ self.__wbuf.write(data)
+
+ def flush(self):
+ data = self.__wbuf.getvalue()
+ encoded = self.sasl.wrap(data)
+ self.transport.write(''.join((pack("!i", len(encoded)), encoded)))
+ self.transport.flush()
+ self.__wbuf = BufferIO()
+
+ def read(self, sz):
+ ret = self.__rbuf.read(sz)
+ if len(ret) != 0:
+ return ret
+
+ self._read_frame()
+ return self.__rbuf.read(sz)
+
+ def _read_frame(self):
+ header = self.transport.readAll(4)
+ length, = unpack('!i', header)
+ encoded = self.transport.readAll(length)
+ self.__rbuf = BufferIO(self.sasl.unwrap(encoded))
+
+ def close(self):
+ self.sasl.dispose()
+ self.transport.close()
+
+ # based on TFramedTransport
+ @property
+ def cstringio_buf(self):
+ return self.__rbuf
+
+ def cstringio_refill(self, prefix, reqlen):
+ # self.__rbuf will already be empty here because fastbinary doesn't
+ # ask for a refill until the previous buffer is empty. Therefore,
+ # we can start reading new frames immediately.
+ while len(prefix) < reqlen:
+ self._read_frame()
+ prefix += self.__rbuf.getvalue()
+ self.__rbuf = BufferIO(prefix)
+ return self.__rbuf
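
To make the framing behaviour concrete, here is a small round trip through TFramedTransport using TMemoryBuffer as the underlying transport (illustrative only; real code would wrap a TSocket):

    from thrift.transport import TTransport

    # Write side: flush() prepends a 4-byte big-endian frame length.
    out = TTransport.TMemoryBuffer()
    framed_out = TTransport.TFramedTransport(out)
    framed_out.write(b'hello')
    framed_out.flush()
    frame = out.getvalue()        # b'\x00\x00\x00\x05hello'

    # Read side: readFrame() consumes the length prefix, then the payload.
    framed_in = TTransport.TFramedTransport(TTransport.TMemoryBuffer(frame))
    assert framed_in.read(5) == b'hello'
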
diff --git a/lib/py/src/transport/TTwisted.py b/lib/py/src/transport/TTwisted.py
index 6149a6c8e..5710b573d 100644
--- a/lib/py/src/transport/TTwisted.py
+++ b/lib/py/src/transport/TTwisted.py
@@ -120,7 +120,7 @@ class ThriftSASLClientProtocol(ThriftClientProtocol):
MAX_LENGTH = 2 ** 31 - 1
def __init__(self, client_class, iprot_factory, oprot_factory=None,
- host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
+ host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
"""
host: the name of the server, from a SASL perspective
service: the name of the server's service, from a SASL perspective
@@ -236,7 +236,7 @@ class ThriftServerProtocol(basic.Int32StringReceiver):
d = self.factory.processor.process(iprot, oprot)
d.addCallbacks(self.processOk, self.processError,
- callbackArgs=(tmo,))
+ callbackArgs=(tmo,))
class IThriftServerFactory(Interface):
@@ -288,7 +288,7 @@ class ThriftClientFactory(ClientFactory):
def buildProtocol(self, addr):
p = self.protocol(self.client_class, self.iprot_factory,
- self.oprot_factory)
+ self.oprot_factory)
p.factory = self
return p
@@ -298,7 +298,7 @@ class ThriftResource(resource.Resource):
allowedMethods = ('POST',)
def __init__(self, processor, inputProtocolFactory,
- outputProtocolFactory=None):
+ outputProtocolFactory=None):
resource.Resource.__init__(self)
self.inputProtocolFactory = inputProtocolFactory
if outputProtocolFactory is None:
diff --git a/lib/py/src/transport/TZlibTransport.py b/lib/py/src/transport/TZlibTransport.py
index 7fe5853ee..e84857924 100644
--- a/lib/py/src/transport/TZlibTransport.py
+++ b/lib/py/src/transport/TZlibTransport.py
@@ -29,220 +29,220 @@ from ..compat import BufferIO
class TZlibTransportFactory(object):
- """Factory transport that builds zlib compressed transports.
-
- This factory caches the last single client/transport that it was passed
- and returns the same TZlibTransport object that was created.
-
- This caching means the TServer class will get the _same_ transport
- object for both input and output transports from this factory.
- (For non-threaded scenarios only, since the cache only holds one object)
-
- The purpose of this caching is to allocate only one TZlibTransport where
- only one is really needed (since it must have separate read/write buffers),
- and makes the statistics from getCompSavings() and getCompRatio()
- easier to understand.
- """
- # class scoped cache of last transport given and zlibtransport returned
- _last_trans = None
- _last_z = None
-
- def getTransport(self, trans, compresslevel=9):
- """Wrap a transport, trans, with the TZlibTransport
- compressed transport class, returning a new
- transport to the caller.
-
- @param compresslevel: The zlib compression level, ranging
- from 0 (no compression) to 9 (best compression). Defaults to 9.
- @type compresslevel: int
-
- This method returns a TZlibTransport which wraps the
- passed C{trans} TTransport derived instance.
- """
- if trans == self._last_trans:
- return self._last_z
- ztrans = TZlibTransport(trans, compresslevel)
- self._last_trans = trans
- self._last_z = ztrans
- return ztrans
+ """Factory transport that builds zlib compressed transports.
+ This factory caches the last single client/transport that it was passed
+ and returns the same TZlibTransport object that was created.
-class TZlibTransport(TTransportBase, CReadableTransport):
- """Class that wraps a transport with zlib, compressing writes
- and decompresses reads, using the python standard
- library zlib module.
- """
- # Read buffer size for the python fastbinary C extension,
- # the TBinaryProtocolAccelerated class.
- DEFAULT_BUFFSIZE = 4096
-
- def __init__(self, trans, compresslevel=9):
- """Create a new TZlibTransport, wrapping C{trans}, another
- TTransport derived object.
-
- @param trans: A thrift transport object, i.e. a TSocket() object.
- @type trans: TTransport
- @param compresslevel: The zlib compression level, ranging
- from 0 (no compression) to 9 (best compression). Default is 9.
- @type compresslevel: int
- """
- self.__trans = trans
- self.compresslevel = compresslevel
- self.__rbuf = BufferIO()
- self.__wbuf = BufferIO()
- self._init_zlib()
- self._init_stats()
-
- def _reinit_buffers(self):
- """Internal method to initialize/reset the internal StringIO objects
- for read and write buffers.
- """
- self.__rbuf = BufferIO()
- self.__wbuf = BufferIO()
+ This caching means the TServer class will get the _same_ transport
+ object for both input and output transports from this factory.
+ (For non-threaded scenarios only, since the cache only holds one object)
- def _init_stats(self):
- """Internal method to reset the internal statistics counters
- for compression ratios and bandwidth savings.
+ The purpose of this caching is to allocate only one TZlibTransport where
+ only one is really needed (since it must have separate read/write buffers),
+ and makes the statistics from getCompSavings() and getCompRatio()
+ easier to understand.
"""
- self.bytes_in = 0
- self.bytes_out = 0
- self.bytes_in_comp = 0
- self.bytes_out_comp = 0
-
- def _init_zlib(self):
- """Internal method for setting up the zlib compression and
- decompression objects.
- """
- self._zcomp_read = zlib.decompressobj()
- self._zcomp_write = zlib.compressobj(self.compresslevel)
-
- def getCompRatio(self):
- """Get the current measured compression ratios (in,out) from
- this transport.
-
- Returns a tuple of:
- (inbound_compression_ratio, outbound_compression_ratio)
-
- The compression ratios are computed as:
- compressed / uncompressed
+ # class scoped cache of last transport given and zlibtransport returned
+ _last_trans = None
+ _last_z = None
+
+ def getTransport(self, trans, compresslevel=9):
+ """Wrap a transport, trans, with the TZlibTransport
+ compressed transport class, returning a new
+ transport to the caller.
+
+ @param compresslevel: The zlib compression level, ranging
+ from 0 (no compression) to 9 (best compression). Defaults to 9.
+ @type compresslevel: int
+
+ This method returns a TZlibTransport which wraps the
+ passed C{trans} TTransport derived instance.
+ """
+ if trans == self._last_trans:
+ return self._last_z
+ ztrans = TZlibTransport(trans, compresslevel)
+ self._last_trans = trans
+ self._last_z = ztrans
+ return ztrans
- E.g., data that compresses by 10x will have a ratio of: 0.10
- and data that compresses to half of ts original size will
- have a ratio of 0.5
- None is returned if no bytes have yet been processed in
- a particular direction.
- """
- r_percent, w_percent = (None, None)
- if self.bytes_in > 0:
- r_percent = self.bytes_in_comp / self.bytes_in
- if self.bytes_out > 0:
- w_percent = self.bytes_out_comp / self.bytes_out
- return (r_percent, w_percent)
-
- def getCompSavings(self):
- """Get the current count of saved bytes due to data
- compression.
-
- Returns a tuple of:
- (inbound_saved_bytes, outbound_saved_bytes)
-
- Note: if compression is actually expanding your
- data (only likely with very tiny thrift objects), then
- the values returned will be negative.
- """
- r_saved = self.bytes_in - self.bytes_in_comp
- w_saved = self.bytes_out - self.bytes_out_comp
- return (r_saved, w_saved)
-
- def isOpen(self):
- """Return the underlying transport's open status"""
- return self.__trans.isOpen()
-
- def open(self):
- """Open the underlying transport"""
- self._init_stats()
- return self.__trans.open()
-
- def listen(self):
- """Invoke the underlying transport's listen() method"""
- self.__trans.listen()
-
- def accept(self):
- """Accept connections on the underlying transport"""
- return self.__trans.accept()
-
- def close(self):
- """Close the underlying transport,"""
- self._reinit_buffers()
- self._init_zlib()
- return self.__trans.close()
-
- def read(self, sz):
- """Read up to sz bytes from the decompressed bytes buffer, and
- read from the underlying transport if the decompression
- buffer is empty.
- """
- ret = self.__rbuf.read(sz)
- if len(ret) > 0:
- return ret
- # keep reading from transport until something comes back
- while True:
- if self.readComp(sz):
- break
- ret = self.__rbuf.read(sz)
- return ret
-
- def readComp(self, sz):
- """Read compressed data from the underlying transport, then
- decompress it and append it to the internal StringIO read buffer
- """
- zbuf = self.__trans.read(sz)
- zbuf = self._zcomp_read.unconsumed_tail + zbuf
- buf = self._zcomp_read.decompress(zbuf)
- self.bytes_in += len(zbuf)
- self.bytes_in_comp += len(buf)
- old = self.__rbuf.read()
- self.__rbuf = BufferIO(old + buf)
- if len(old) + len(buf) == 0:
- return False
- return True
-
- def write(self, buf):
- """Write some bytes, putting them into the internal write
- buffer for eventual compression.
- """
- self.__wbuf.write(buf)
-
- def flush(self):
- """Flush any queued up data in the write buffer and ensure the
- compression buffer is flushed out to the underlying transport
+class TZlibTransport(TTransportBase, CReadableTransport):
+ """Class that wraps a transport with zlib, compressing writes
+ and decompresses reads, using the python standard
+ library zlib module.
"""
- wout = self.__wbuf.getvalue()
- if len(wout) > 0:
- zbuf = self._zcomp_write.compress(wout)
- self.bytes_out += len(wout)
- self.bytes_out_comp += len(zbuf)
- else:
- zbuf = ''
- ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH)
- self.bytes_out_comp += len(ztail)
- if (len(zbuf) + len(ztail)) > 0:
- self.__wbuf = BufferIO()
- self.__trans.write(zbuf + ztail)
- self.__trans.flush()
-
- @property
- def cstringio_buf(self):
- """Implement the CReadableTransport interface"""
- return self.__rbuf
-
- def cstringio_refill(self, partialread, reqlen):
- """Implement the CReadableTransport interface for refill"""
- retstring = partialread
- if reqlen < self.DEFAULT_BUFFSIZE:
- retstring += self.read(self.DEFAULT_BUFFSIZE)
- while len(retstring) < reqlen:
- retstring += self.read(reqlen - len(retstring))
- self.__rbuf = BufferIO(retstring)
- return self.__rbuf
+ # Read buffer size for the python fastbinary C extension,
+ # the TBinaryProtocolAccelerated class.
+ DEFAULT_BUFFSIZE = 4096
+
+ def __init__(self, trans, compresslevel=9):
+ """Create a new TZlibTransport, wrapping C{trans}, another
+ TTransport derived object.
+
+        @param trans: A thrift transport object, e.g. a TSocket() object.
+ @type trans: TTransport
+ @param compresslevel: The zlib compression level, ranging
+ from 0 (no compression) to 9 (best compression). Default is 9.
+ @type compresslevel: int
+ """
+ self.__trans = trans
+ self.compresslevel = compresslevel
+ self.__rbuf = BufferIO()
+ self.__wbuf = BufferIO()
+ self._init_zlib()
+ self._init_stats()
+
+ def _reinit_buffers(self):
+ """Internal method to initialize/reset the internal StringIO objects
+ for read and write buffers.
+ """
+ self.__rbuf = BufferIO()
+ self.__wbuf = BufferIO()
+
+ def _init_stats(self):
+ """Internal method to reset the internal statistics counters
+ for compression ratios and bandwidth savings.
+ """
+ self.bytes_in = 0
+ self.bytes_out = 0
+ self.bytes_in_comp = 0
+ self.bytes_out_comp = 0
+
+ def _init_zlib(self):
+ """Internal method for setting up the zlib compression and
+ decompression objects.
+ """
+ self._zcomp_read = zlib.decompressobj()
+ self._zcomp_write = zlib.compressobj(self.compresslevel)
+
+ def getCompRatio(self):
+ """Get the current measured compression ratios (in,out) from
+ this transport.
+
+ Returns a tuple of:
+ (inbound_compression_ratio, outbound_compression_ratio)
+
+ The compression ratios are computed as:
+ compressed / uncompressed
+
+ E.g., data that compresses by 10x will have a ratio of: 0.10
+        and data that compresses to half of its original size will
+ have a ratio of 0.5
+
+ None is returned if no bytes have yet been processed in
+ a particular direction.
+ """
+ r_percent, w_percent = (None, None)
+ if self.bytes_in > 0:
+ r_percent = self.bytes_in_comp / self.bytes_in
+ if self.bytes_out > 0:
+ w_percent = self.bytes_out_comp / self.bytes_out
+ return (r_percent, w_percent)
+
+ def getCompSavings(self):
+ """Get the current count of saved bytes due to data
+ compression.
+
+ Returns a tuple of:
+ (inbound_saved_bytes, outbound_saved_bytes)
+
+ Note: if compression is actually expanding your
+ data (only likely with very tiny thrift objects), then
+ the values returned will be negative.
+ """
+ r_saved = self.bytes_in - self.bytes_in_comp
+ w_saved = self.bytes_out - self.bytes_out_comp
+ return (r_saved, w_saved)
+
+ def isOpen(self):
+ """Return the underlying transport's open status"""
+ return self.__trans.isOpen()
+
+ def open(self):
+ """Open the underlying transport"""
+ self._init_stats()
+ return self.__trans.open()
+
+ def listen(self):
+ """Invoke the underlying transport's listen() method"""
+ self.__trans.listen()
+
+ def accept(self):
+ """Accept connections on the underlying transport"""
+ return self.__trans.accept()
+
+ def close(self):
+ """Close the underlying transport,"""
+ self._reinit_buffers()
+ self._init_zlib()
+ return self.__trans.close()
+
+ def read(self, sz):
+ """Read up to sz bytes from the decompressed bytes buffer, and
+ read from the underlying transport if the decompression
+ buffer is empty.
+ """
+ ret = self.__rbuf.read(sz)
+ if len(ret) > 0:
+ return ret
+ # keep reading from transport until something comes back
+ while True:
+ if self.readComp(sz):
+ break
+ ret = self.__rbuf.read(sz)
+ return ret
+
+ def readComp(self, sz):
+ """Read compressed data from the underlying transport, then
+ decompress it and append it to the internal StringIO read buffer
+ """
+ zbuf = self.__trans.read(sz)
+ zbuf = self._zcomp_read.unconsumed_tail + zbuf
+ buf = self._zcomp_read.decompress(zbuf)
+ self.bytes_in += len(zbuf)
+ self.bytes_in_comp += len(buf)
+ old = self.__rbuf.read()
+ self.__rbuf = BufferIO(old + buf)
+ if len(old) + len(buf) == 0:
+ return False
+ return True
+
+ def write(self, buf):
+ """Write some bytes, putting them into the internal write
+ buffer for eventual compression.
+ """
+ self.__wbuf.write(buf)
+
+ def flush(self):
+ """Flush any queued up data in the write buffer and ensure the
+ compression buffer is flushed out to the underlying transport
+ """
+ wout = self.__wbuf.getvalue()
+ if len(wout) > 0:
+ zbuf = self._zcomp_write.compress(wout)
+ self.bytes_out += len(wout)
+ self.bytes_out_comp += len(zbuf)
+ else:
+            zbuf = b''
+ ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH)
+ self.bytes_out_comp += len(ztail)
+ if (len(zbuf) + len(ztail)) > 0:
+ self.__wbuf = BufferIO()
+ self.__trans.write(zbuf + ztail)
+ self.__trans.flush()
+
+ @property
+ def cstringio_buf(self):
+ """Implement the CReadableTransport interface"""
+ return self.__rbuf
+
+ def cstringio_refill(self, partialread, reqlen):
+ """Implement the CReadableTransport interface for refill"""
+ retstring = partialread
+ if reqlen < self.DEFAULT_BUFFSIZE:
+ retstring += self.read(self.DEFAULT_BUFFSIZE)
+ while len(retstring) < reqlen:
+ retstring += self.read(reqlen - len(retstring))
+ self.__rbuf = BufferIO(retstring)
+ return self.__rbuf
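
As a small illustration of the statistics described in the docstrings above, again using TMemoryBuffer as the wrapped transport (the payload is an arbitrary example):

    from thrift.transport import TTransport
    from thrift.transport import TZlibTransport

    inner = TTransport.TMemoryBuffer()
    ztrans = TZlibTransport.TZlibTransport(inner, compresslevel=9)
    ztrans.write(b'hello world ' * 100)   # repetitive data compresses well
    ztrans.flush()   # deflates with Z_SYNC_FLUSH and writes to the inner transport

    # Only the write direction has seen data, so the read-side ratio is None.
    print(ztrans.getCompRatio())    # e.g. (None, <compressed/uncompressed>)
    print(ztrans.getCompSavings())  # (0, <bytes saved on writes>)
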
diff --git a/lib/py/test/_import_local_thrift.py b/lib/py/test/_import_local_thrift.py
index 30c1abcc0..174166969 100644
--- a/lib/py/test/_import_local_thrift.py
+++ b/lib/py/test/_import_local_thrift.py
@@ -6,8 +6,8 @@ SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))
if sys.version_info[0] == 2:
- import glob
- libdir = glob.glob(os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib.*'))[0]
- sys.path.insert(0, libdir)
+ import glob
+ libdir = glob.glob(os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib.*'))[0]
+ sys.path.insert(0, libdir)
else:
- sys.path.insert(0, os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib'))
+ sys.path.insert(0, os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib'))
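
Based on the code shown, the helper's job is to put the locally built thrift package ahead of any system-wide install on sys.path. A test script would then use it roughly along these lines (a sketch, not taken from the test suite itself):

    import _import_local_thrift  # noqa: adjusts sys.path before thrift is imported
    from thrift.transport import TTransport
    from thrift.protocol import TBinaryProtocol
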
diff --git a/lib/py/test/test_sslsocket.py b/lib/py/test/test_sslsocket.py
index b7c3802fe..fa156a0fd 100644
--- a/lib/py/test/test_sslsocket.py
+++ b/lib/py/test/test_sslsocket.py
@@ -46,231 +46,231 @@ TEST_CIPHERS = 'DES-CBC3-SHA'
class ServerAcceptor(threading.Thread):
- def __init__(self, server):
- super(ServerAcceptor, self).__init__()
- self._server = server
- self.client = None
+ def __init__(self, server):
+ super(ServerAcceptor, self).__init__()
+ self._server = server
+ self.client = None
- def run(self):
- self._server.listen()
- self.client = self._server.accept()
+ def run(self):
+ self._server.listen()
+ self.client = self._server.accept()
# Python 2.6 compat
class AssertRaises(object):
- def __init__(self, expected):
- self._expected = expected
+ def __init__(self, expected):
+ self._expected = expected
- def __enter__(self):
- pass
+ def __enter__(self):
+ pass
- def __exit__(self, exc_type, exc_value, traceback):
- if not exc_type or not issubclass(exc_type, self._expected):
- raise Exception('fail')
- return True
+ def __exit__(self, exc_type, exc_value, traceback):
+ if not exc_type or not issubclass(exc_type, self._expected):
+ raise Exception('fail')
+ return True
class TSSLSocketTest(unittest.TestCase):
- def _assert_connection_failure(self, server, client):
- try:
- acc = ServerAcceptor(server)
- acc.start()
- time.sleep(CONNECT_DELAY)
- client.setTimeout(CONNECT_TIMEOUT)
- with self._assert_raises(Exception):
- client.open()
- select.select([], [client.handle], [], CONNECT_TIMEOUT)
- # self.assertIsNone(acc.client)
- self.assertTrue(acc.client is None)
- finally:
- server.close()
- client.close()
-
- def _assert_raises(self, exc):
- if sys.hexversion >= 0x020700F0:
- return self.assertRaises(exc)
- else:
- return AssertRaises(exc)
-
- def _assert_connection_success(self, server, client):
- try:
- acc = ServerAcceptor(server)
- acc.start()
- time.sleep(0.15)
- client.setTimeout(CONNECT_TIMEOUT)
- client.open()
- select.select([], [client.handle], [], CONNECT_TIMEOUT)
- # self.assertIsNotNone(acc.client)
- self.assertTrue(acc.client is not None)
- finally:
- server.close()
- client.close()
-
- # deprecated feature
- def test_deprecation(self):
- with warnings.catch_warnings(record=True) as w:
- warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*')
- TSSLSocket('localhost', TEST_PORT, validate=True, ca_certs=SERVER_CERT)
- self.assertEqual(len(w), 1)
-
- with warnings.catch_warnings(record=True) as w:
- warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*')
- # Deprecated signature
- # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
- client = TSSLSocket('localhost', TEST_PORT, True, SERVER_CERT, CLIENT_KEY, CLIENT_CERT, None, TEST_CIPHERS)
- self.assertEqual(len(w), 7)
-
- with warnings.catch_warnings(record=True) as w:
- warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*')
- # Deprecated signature
- # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
- server = TSSLServerSocket(None, TEST_PORT, SERVER_PEM, None, TEST_CIPHERS)
- self.assertEqual(len(w), 3)
-
- self._assert_connection_success(server, client)
-
- # deprecated feature
- def test_set_cert_reqs_by_validate(self):
- c1 = TSSLSocket('localhost', TEST_PORT, validate=True, ca_certs=SERVER_CERT)
- self.assertEqual(c1.cert_reqs, ssl.CERT_REQUIRED)
-
- c1 = TSSLSocket('localhost', TEST_PORT, validate=False)
- self.assertEqual(c1.cert_reqs, ssl.CERT_NONE)
-
- # deprecated feature
- def test_set_validate_by_cert_reqs(self):
- c1 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE)
- self.assertFalse(c1.validate)
-
- c2 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
- self.assertTrue(c2.validate)
-
- c3 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_OPTIONAL, ca_certs=SERVER_CERT)
- self.assertTrue(c3.validate)
-
- def test_unix_domain_socket(self):
- if platform.system() == 'Windows':
- print('skipping test_unix_domain_socket')
- return
- server = TSSLServerSocket(unix_socket=TEST_ADDR, keyfile=SERVER_KEY, certfile=SERVER_CERT)
- client = TSSLSocket(None, None, TEST_ADDR, cert_reqs=ssl.CERT_NONE)
- self._assert_connection_success(server, client)
-
- def test_server_cert(self):
- server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
- client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
- self._assert_connection_success(server, client)
-
- server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
- # server cert on in ca_certs
- client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=CLIENT_CERT)
- self._assert_connection_failure(server, client)
-
- server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
- client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE)
- self._assert_connection_success(server, client)
-
- def test_set_server_cert(self):
- server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=CLIENT_CERT)
- with self._assert_raises(Exception):
- server.certfile = 'foo'
- with self._assert_raises(Exception):
- server.certfile = None
- server.certfile = SERVER_CERT
- client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
- self._assert_connection_success(server, client)
-
- def test_client_cert(self):
- server = TSSLServerSocket(
- port=TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
- certfile=SERVER_CERT, ca_certs=CLIENT_CERT)
- client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
- self._assert_connection_success(server, client)
-
- def test_ciphers(self):
- server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
- client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS)
- self._assert_connection_success(server, client)
-
- if not TSSLSocket._has_ciphers:
- # unittest.skip is not available for Python 2.6
- print('skipping test_ciphers')
- return
- server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
- client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers='NULL')
- self._assert_connection_failure(server, client)
-
- server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
- client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers='NULL')
- self._assert_connection_failure(server, client)
-
- def test_ssl2_and_ssl3_disabled(self):
- if not hasattr(ssl, 'PROTOCOL_SSLv3'):
- print('PROTOCOL_SSLv3 is not available')
- else:
- server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
- client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
- self._assert_connection_failure(server, client)
-
- server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
- client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT)
- self._assert_connection_failure(server, client)
-
- if not hasattr(ssl, 'PROTOCOL_SSLv2'):
- print('PROTOCOL_SSLv2 is not available')
- else:
- server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
- client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
- self._assert_connection_failure(server, client)
-
- server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
- client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT)
- self._assert_connection_failure(server, client)
-
- def test_newer_tls(self):
- if not TSSLSocket._has_ssl_context:
- # unittest.skip is not available for Python 2.6
- print('skipping test_newer_tls')
- return
- if not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
- print('PROTOCOL_TLSv1_2 is not available')
- else:
- server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
- client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
- self._assert_connection_success(server, client)
-
- if not hasattr(ssl, 'PROTOCOL_TLSv1_1'):
- print('PROTOCOL_TLSv1_1 is not available')
- else:
- server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
- client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
- self._assert_connection_success(server, client)
-
- if not hasattr(ssl, 'PROTOCOL_TLSv1_1') or not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
- print('PROTOCOL_TLSv1_1 and/or PROTOCOL_TLSv1_2 is not available')
- else:
- server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
- client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
- self._assert_connection_failure(server, client)
-
- def test_ssl_context(self):
- if not TSSLSocket._has_ssl_context:
- # unittest.skip is not available for Python 2.6
- print('skipping test_ssl_context')
- return
- server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
- server_context.load_cert_chain(SERVER_CERT, SERVER_KEY)
- server_context.load_verify_locations(CLIENT_CERT)
-
- client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
- client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY)
- client_context.load_verify_locations(SERVER_CERT)
-
- server = TSSLServerSocket(port=TEST_PORT, ssl_context=server_context)
- client = TSSLSocket('localhost', TEST_PORT, ssl_context=client_context)
- self._assert_connection_success(server, client)
+ def _assert_connection_failure(self, server, client):
+ try:
+ acc = ServerAcceptor(server)
+ acc.start()
+ time.sleep(CONNECT_DELAY)
+ client.setTimeout(CONNECT_TIMEOUT)
+ with self._assert_raises(Exception):
+ client.open()
+ select.select([], [client.handle], [], CONNECT_TIMEOUT)
+ # self.assertIsNone(acc.client)
+ self.assertTrue(acc.client is None)
+ finally:
+ server.close()
+ client.close()
+
+ def _assert_raises(self, exc):
+ if sys.hexversion >= 0x020700F0:
+ return self.assertRaises(exc)
+ else:
+ return AssertRaises(exc)
+
+ def _assert_connection_success(self, server, client):
+ try:
+ acc = ServerAcceptor(server)
+ acc.start()
+ time.sleep(0.15)
+ client.setTimeout(CONNECT_TIMEOUT)
+ client.open()
+ select.select([], [client.handle], [], CONNECT_TIMEOUT)
+ # self.assertIsNotNone(acc.client)
+ self.assertTrue(acc.client is not None)
+ finally:
+ server.close()
+ client.close()
+
+ # deprecated feature
+ def test_deprecation(self):
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*')
+ TSSLSocket('localhost', TEST_PORT, validate=True, ca_certs=SERVER_CERT)
+ self.assertEqual(len(w), 1)
+
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*')
+ # Deprecated signature
+ # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
+ client = TSSLSocket('localhost', TEST_PORT, True, SERVER_CERT, CLIENT_KEY, CLIENT_CERT, None, TEST_CIPHERS)
+ self.assertEqual(len(w), 7)
+
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*')
+ # Deprecated signature
+ # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
+ server = TSSLServerSocket(None, TEST_PORT, SERVER_PEM, None, TEST_CIPHERS)
+ self.assertEqual(len(w), 3)
+
+ self._assert_connection_success(server, client)
+
+ # deprecated feature
+ def test_set_cert_reqs_by_validate(self):
+ c1 = TSSLSocket('localhost', TEST_PORT, validate=True, ca_certs=SERVER_CERT)
+ self.assertEqual(c1.cert_reqs, ssl.CERT_REQUIRED)
+
+ c1 = TSSLSocket('localhost', TEST_PORT, validate=False)
+ self.assertEqual(c1.cert_reqs, ssl.CERT_NONE)
+
+ # deprecated feature
+ def test_set_validate_by_cert_reqs(self):
+ c1 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE)
+ self.assertFalse(c1.validate)
+
+ c2 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
+ self.assertTrue(c2.validate)
+
+ c3 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_OPTIONAL, ca_certs=SERVER_CERT)
+ self.assertTrue(c3.validate)
+
+ def test_unix_domain_socket(self):
+ if platform.system() == 'Windows':
+ print('skipping test_unix_domain_socket')
+ return
+ server = TSSLServerSocket(unix_socket=TEST_ADDR, keyfile=SERVER_KEY, certfile=SERVER_CERT)
+ client = TSSLSocket(None, None, TEST_ADDR, cert_reqs=ssl.CERT_NONE)
+ self._assert_connection_success(server, client)
+
+ def test_server_cert(self):
+ server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
+ client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
+ self._assert_connection_success(server, client)
+
+ server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
+        # server cert not in ca_certs
+ client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=CLIENT_CERT)
+ self._assert_connection_failure(server, client)
+
+ server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
+ client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE)
+ self._assert_connection_success(server, client)
+
+ def test_set_server_cert(self):
+ server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=CLIENT_CERT)
+ with self._assert_raises(Exception):
+ server.certfile = 'foo'
+ with self._assert_raises(Exception):
+ server.certfile = None
+ server.certfile = SERVER_CERT
+ client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
+ self._assert_connection_success(server, client)
+
+ def test_client_cert(self):
+ server = TSSLServerSocket(
+ port=TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
+ certfile=SERVER_CERT, ca_certs=CLIENT_CERT)
+ client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
+ self._assert_connection_success(server, client)
+
+ def test_ciphers(self):
+ server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
+ client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS)
+ self._assert_connection_success(server, client)
+
+ if not TSSLSocket._has_ciphers:
+ # unittest.skip is not available for Python 2.6
+ print('skipping test_ciphers')
+ return
+ server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
+ client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers='NULL')
+ self._assert_connection_failure(server, client)
+
+ server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
+ client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers='NULL')
+ self._assert_connection_failure(server, client)
+
+ def test_ssl2_and_ssl3_disabled(self):
+ if not hasattr(ssl, 'PROTOCOL_SSLv3'):
+ print('PROTOCOL_SSLv3 is not available')
+ else:
+ server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
+ client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
+ self._assert_connection_failure(server, client)
+
+ server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
+ client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT)
+ self._assert_connection_failure(server, client)
+
+ if not hasattr(ssl, 'PROTOCOL_SSLv2'):
+ print('PROTOCOL_SSLv2 is not available')
+ else:
+ server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
+ client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
+ self._assert_connection_failure(server, client)
+
+ server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
+ client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT)
+ self._assert_connection_failure(server, client)
+
+ def test_newer_tls(self):
+ if not TSSLSocket._has_ssl_context:
+ # unittest.skip is not available for Python 2.6
+ print('skipping test_newer_tls')
+ return
+ if not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
+ print('PROTOCOL_TLSv1_2 is not available')
+ else:
+ server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
+ client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
+ self._assert_connection_success(server, client)
+
+ if not hasattr(ssl, 'PROTOCOL_TLSv1_1'):
+ print('PROTOCOL_TLSv1_1 is not available')
+ else:
+ server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
+ client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
+ self._assert_connection_success(server, client)
+
+ if not hasattr(ssl, 'PROTOCOL_TLSv1_1') or not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
+ print('PROTOCOL_TLSv1_1 and/or PROTOCOL_TLSv1_2 is not available')
+ else:
+ server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
+ client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
+ self._assert_connection_failure(server, client)
+
+ def test_ssl_context(self):
+ if not TSSLSocket._has_ssl_context:
+ # unittest.skip is not available for Python 2.6
+ print('skipping test_ssl_context')
+ return
+ server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
+ server_context.load_cert_chain(SERVER_CERT, SERVER_KEY)
+ server_context.load_verify_locations(CLIENT_CERT)
+
+ client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
+ client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY)
+ client_context.load_verify_locations(SERVER_CERT)
+
+ server = TSSLServerSocket(port=TEST_PORT, ssl_context=server_context)
+ client = TSSLSocket('localhost', TEST_PORT, ssl_context=client_context)
+ self._assert_connection_success(server, client)
if __name__ == '__main__':
- # import logging
- # logging.basicConfig(level=logging.DEBUG)
- unittest.main()
+ # import logging
+ # logging.basicConfig(level=logging.DEBUG)
+ unittest.main()
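
A minimal standalone sketch of the ssl_context code path that test_ssl_context above exercises; the PEM file names are placeholders rather than the repository's test fixtures, and disabling hostname checking is only an assumption for self-signed certificates whose CN is not 'localhost'.

import ssl
from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket

# Placeholder certificate/key paths -- substitute real PEM files.
server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
server_context.load_cert_chain('server.crt', 'server.key')
server_context.load_verify_locations('client.crt')

client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
client_context.load_cert_chain('client.crt', 'client.key')
client_context.load_verify_locations('server.crt')
client_context.check_hostname = False  # assumption: self-signed cert without CN=localhost

server = TSSLServerSocket(port=9090, ssl_context=server_context)
client = TSSLSocket('localhost', 9090, ssl_context=client_context)
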
diff --git a/lib/py/test/thrift_json.py b/lib/py/test/thrift_json.py
index e60aabacf..5ba7dd585 100644
--- a/lib/py/test/thrift_json.py
+++ b/lib/py/test/thrift_json.py
@@ -31,20 +31,20 @@ from thrift.transport import TTransport
# mklink /D thrift ..\src
#
+
class TestJSONString(unittest.TestCase):
- def test_escaped_unicode_string(self):
- unicode_json = b'"hello \\u0e01\\u0e02\\u0e03\\ud835\\udcab\\udb40\\udc70 unicode"'
- unicode_text = u'hello \u0e01\u0e02\u0e03\U0001D4AB\U000E0070 unicode'
+ def test_escaped_unicode_string(self):
+ unicode_json = b'"hello \\u0e01\\u0e02\\u0e03\\ud835\\udcab\\udb40\\udc70 unicode"'
+ unicode_text = u'hello \u0e01\u0e02\u0e03\U0001D4AB\U000E0070 unicode'
- buf = TTransport.TMemoryBuffer(unicode_json)
- transport = TTransport.TBufferedTransportFactory().getTransport(buf)
- protocol = TJSONProtocol(transport)
+ buf = TTransport.TMemoryBuffer(unicode_json)
+ transport = TTransport.TBufferedTransportFactory().getTransport(buf)
+ protocol = TJSONProtocol(transport)
- if sys.version_info[0] == 2:
- unicode_text = unicode_text.encode('utf8')
- self.assertEqual(protocol.readString(), unicode_text)
+ if sys.version_info[0] == 2:
+ unicode_text = unicode_text.encode('utf8')
+ self.assertEqual(protocol.readString(), unicode_text)
if __name__ == '__main__':
- unittest.main()
-
+ unittest.main()
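
The escaped-unicode read verified above can be reproduced outside the test harness; a short illustrative sketch with a plain ASCII payload (the byte string below is made up, not the test fixture).

from thrift.protocol.TJSONProtocol import TJSONProtocol
from thrift.transport import TTransport

buf = TTransport.TMemoryBuffer(b'"hello unicode"')  # JSON-encoded string payload
transport = TTransport.TBufferedTransportFactory().getTransport(buf)
protocol = TJSONProtocol(transport)
print(protocol.readString())  # prints: hello unicode
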
diff --git a/test/crossrunner/collect.py b/test/crossrunner/collect.py
index f92b9e2d7..e91ac0b43 100644
--- a/test/crossrunner/collect.py
+++ b/test/crossrunner/collect.py
@@ -40,13 +40,13 @@ from .util import merge_dict
# (e.g. "binary" is equivalent to "binary:binary" in tests.json)
#
VALID_JSON_KEYS = [
- 'name', # name of the library, typically a language name
- 'workdir', # work directory where command is executed
- 'command', # test command
- 'extra_args', # args appended to command after other args are appended
- 'remote_args', # args added to the other side of the program
- 'join_args', # whether args should be passed as single concatenated string
- 'env', # additional environmental variable
+ 'name', # name of the library, typically a language name
+ 'workdir', # work directory where command is executed
+ 'command', # test command
+ 'extra_args', # args appended to command after other args are appended
+ 'remote_args', # args added to the other side of the program
+ 'join_args', # whether args should be passed as single concatenated string
+    'env',          # additional environment variables
]
DEFAULT_DELAY = 1
@@ -54,102 +54,102 @@ DEFAULT_TIMEOUT = 5
def _collect_testlibs(config, server_match, client_match=[None]):
- """Collects server/client configurations from library configurations"""
- def expand_libs(config):
- for lib in config:
- sv = lib.pop('server', None)
- cl = lib.pop('client', None)
- yield lib, sv, cl
-
- def yield_testlibs(base_configs, configs, match):
- for base, conf in zip(base_configs, configs):
- if conf:
- if not match or base['name'] in match:
- platforms = conf.get('platforms') or base.get('platforms')
- if not platforms or platform.system() in platforms:
- yield merge_dict(base, conf)
-
- libs, svs, cls = zip(*expand_libs(config))
- servers = list(yield_testlibs(libs, svs, server_match))
- clients = list(yield_testlibs(libs, cls, client_match))
- return servers, clients
+ """Collects server/client configurations from library configurations"""
+ def expand_libs(config):
+ for lib in config:
+ sv = lib.pop('server', None)
+ cl = lib.pop('client', None)
+ yield lib, sv, cl
+
+ def yield_testlibs(base_configs, configs, match):
+ for base, conf in zip(base_configs, configs):
+ if conf:
+ if not match or base['name'] in match:
+ platforms = conf.get('platforms') or base.get('platforms')
+ if not platforms or platform.system() in platforms:
+ yield merge_dict(base, conf)
+
+ libs, svs, cls = zip(*expand_libs(config))
+ servers = list(yield_testlibs(libs, svs, server_match))
+ clients = list(yield_testlibs(libs, cls, client_match))
+ return servers, clients
def collect_features(config, match):
- res = list(map(re.compile, match))
- return list(filter(lambda c: any(map(lambda r: r.search(c['name']), res)), config))
+ res = list(map(re.compile, match))
+ return list(filter(lambda c: any(map(lambda r: r.search(c['name']), res)), config))
def _do_collect_tests(servers, clients):
- def intersection(key, o1, o2):
- """intersection of two collections.
- collections are replaced with sets the first time"""
- def cached_set(o, key):
- v = o[key]
- if not isinstance(v, set):
- v = set(v)
- o[key] = v
- return v
- return cached_set(o1, key) & cached_set(o2, key)
-
- def intersect_with_spec(key, o1, o2):
- # store as set of (spec, impl) tuple
- def cached_set(o):
- def to_spec_impl_tuples(values):
- for v in values:
- spec, _, impl = v.partition(':')
- yield spec, impl or spec
- v = o[key]
- if not isinstance(v, set):
- v = set(to_spec_impl_tuples(set(v)))
- o[key] = v
- return v
- for spec1, impl1 in cached_set(o1):
- for spec2, impl2 in cached_set(o2):
- if spec1 == spec2:
- name = impl1 if impl1 == impl2 else '%s-%s' % (impl1, impl2)
- yield name, impl1, impl2
-
- def maybe_max(key, o1, o2, default):
- """maximum of two if present, otherwise defult value"""
- v1 = o1.get(key)
- v2 = o2.get(key)
- return max(v1, v2) if v1 and v2 else v1 or v2 or default
-
- def filter_with_validkeys(o):
- ret = {}
- for key in VALID_JSON_KEYS:
- if key in o:
- ret[key] = o[key]
- return ret
-
- def merge_metadata(o, **ret):
- for key in VALID_JSON_KEYS:
- if key in o:
- ret[key] = o[key]
- return ret
-
- for sv, cl in product(servers, clients):
- for proto, proto1, proto2 in intersect_with_spec('protocols', sv, cl):
- for trans, trans1, trans2 in intersect_with_spec('transports', sv, cl):
- for sock in intersection('sockets', sv, cl):
- yield {
- 'server': merge_metadata(sv, **{'protocol': proto1, 'transport': trans1}),
- 'client': merge_metadata(cl, **{'protocol': proto2, 'transport': trans2}),
- 'delay': maybe_max('delay', sv, cl, DEFAULT_DELAY),
- 'timeout': maybe_max('timeout', sv, cl, DEFAULT_TIMEOUT),
- 'protocol': proto,
- 'transport': trans,
- 'socket': sock
- }
+ def intersection(key, o1, o2):
+ """intersection of two collections.
+ collections are replaced with sets the first time"""
+ def cached_set(o, key):
+ v = o[key]
+ if not isinstance(v, set):
+ v = set(v)
+ o[key] = v
+ return v
+ return cached_set(o1, key) & cached_set(o2, key)
+
+ def intersect_with_spec(key, o1, o2):
+ # store as set of (spec, impl) tuple
+ def cached_set(o):
+ def to_spec_impl_tuples(values):
+ for v in values:
+ spec, _, impl = v.partition(':')
+ yield spec, impl or spec
+ v = o[key]
+ if not isinstance(v, set):
+ v = set(to_spec_impl_tuples(set(v)))
+ o[key] = v
+ return v
+ for spec1, impl1 in cached_set(o1):
+ for spec2, impl2 in cached_set(o2):
+ if spec1 == spec2:
+ name = impl1 if impl1 == impl2 else '%s-%s' % (impl1, impl2)
+ yield name, impl1, impl2
+
+ def maybe_max(key, o1, o2, default):
+        """maximum of two if present, otherwise default value"""
+ v1 = o1.get(key)
+ v2 = o2.get(key)
+ return max(v1, v2) if v1 and v2 else v1 or v2 or default
+
+ def filter_with_validkeys(o):
+ ret = {}
+ for key in VALID_JSON_KEYS:
+ if key in o:
+ ret[key] = o[key]
+ return ret
+
+ def merge_metadata(o, **ret):
+ for key in VALID_JSON_KEYS:
+ if key in o:
+ ret[key] = o[key]
+ return ret
+
+ for sv, cl in product(servers, clients):
+ for proto, proto1, proto2 in intersect_with_spec('protocols', sv, cl):
+ for trans, trans1, trans2 in intersect_with_spec('transports', sv, cl):
+ for sock in intersection('sockets', sv, cl):
+ yield {
+ 'server': merge_metadata(sv, **{'protocol': proto1, 'transport': trans1}),
+ 'client': merge_metadata(cl, **{'protocol': proto2, 'transport': trans2}),
+ 'delay': maybe_max('delay', sv, cl, DEFAULT_DELAY),
+ 'timeout': maybe_max('timeout', sv, cl, DEFAULT_TIMEOUT),
+ 'protocol': proto,
+ 'transport': trans,
+ 'socket': sock
+ }
def collect_cross_tests(tests_dict, server_match, client_match):
- sv, cl = _collect_testlibs(tests_dict, server_match, client_match)
- return list(_do_collect_tests(sv, cl))
+ sv, cl = _collect_testlibs(tests_dict, server_match, client_match)
+ return list(_do_collect_tests(sv, cl))
def collect_feature_tests(tests_dict, features_dict, server_match, feature_match):
- sv, _ = _collect_testlibs(tests_dict, server_match)
- ft = collect_features(features_dict, feature_match)
- return list(_do_collect_tests(sv, ft))
+ sv, _ = _collect_testlibs(tests_dict, server_match)
+ ft = collect_features(features_dict, feature_match)
+ return list(_do_collect_tests(sv, ft))
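
For readers unfamiliar with the 'spec:impl' shorthand handled by intersect_with_spec() above (e.g. "binary" is equivalent to "binary:binary" in tests.json), a small illustration with made-up values:

def to_spec_impl(value):
    # mirrors the partition logic inside intersect_with_spec()
    spec, _, impl = value.partition(':')
    return spec, impl or spec

print(to_spec_impl('binary'))        # ('binary', 'binary')
print(to_spec_impl('binary:accel'))  # ('binary', 'accel')
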
diff --git a/test/crossrunner/compat.py b/test/crossrunner/compat.py
index 6ab9d713b..f1ca91bb3 100644
--- a/test/crossrunner/compat.py
+++ b/test/crossrunner/compat.py
@@ -2,23 +2,23 @@ import os
import sys
if sys.version_info[0] == 2:
- _ENCODE = sys.getfilesystemencoding()
+ _ENCODE = sys.getfilesystemencoding()
- def path_join(*args):
- bin_args = map(lambda a: a.decode(_ENCODE), args)
- return os.path.join(*bin_args).encode(_ENCODE)
+ def path_join(*args):
+ bin_args = map(lambda a: a.decode(_ENCODE), args)
+ return os.path.join(*bin_args).encode(_ENCODE)
- def str_join(s, l):
- bin_args = map(lambda a: a.decode(_ENCODE), l)
- b = s.decode(_ENCODE)
- return b.join(bin_args).encode(_ENCODE)
+ def str_join(s, l):
+ bin_args = map(lambda a: a.decode(_ENCODE), l)
+ b = s.decode(_ENCODE)
+ return b.join(bin_args).encode(_ENCODE)
- logfile_open = open
+ logfile_open = open
else:
- path_join = os.path.join
- str_join = str.join
+ path_join = os.path.join
+ str_join = str.join
- def logfile_open(*args):
- return open(*args, errors='replace')
+ def logfile_open(*args):
+ return open(*args, errors='replace')
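
A short usage sketch of the helpers above: on Python 3, logfile_open() replaces undecodable bytes coming from test output instead of raising, while Python 2 falls back to the built-in open(). The import path assumes the test directory is on sys.path, and 'example.log' is a placeholder name.

from crossrunner.compat import logfile_open, str_join

with logfile_open('example.log', 'w+') as fp:
    fp.write(str_join(' ', ['some', 'test', 'output']) + '\n')
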
diff --git a/test/crossrunner/report.py b/test/crossrunner/report.py
index be7271cb1..cc5f26fe2 100644
--- a/test/crossrunner/report.py
+++ b/test/crossrunner/report.py
@@ -39,396 +39,396 @@ FAIL_JSON = 'known_failures_%s.json'
def generate_known_failures(testdir, overwrite, save, out):
- def collect_failures(results):
- success_index = 5
- for r in results:
- if not r[success_index]:
- yield TestEntry.get_name(*r)
- try:
- with logfile_open(path_join(testdir, RESULT_JSON), 'r') as fp:
- results = json.load(fp)
- except IOError:
- sys.stderr.write('Unable to load last result. Did you run tests ?\n')
- return False
- fails = collect_failures(results['results'])
- if not overwrite:
- known = load_known_failures(testdir)
- known.extend(fails)
- fails = known
- fails_json = json.dumps(sorted(set(fails)), indent=2, separators=(',', ': '))
- if save:
- with logfile_open(os.path.join(testdir, FAIL_JSON % platform.system()), 'w+') as fp:
- fp.write(fails_json)
- sys.stdout.write('Successfully updated known failures.\n')
- if out:
- sys.stdout.write(fails_json)
- sys.stdout.write('\n')
- return True
+ def collect_failures(results):
+ success_index = 5
+ for r in results:
+ if not r[success_index]:
+ yield TestEntry.get_name(*r)
+ try:
+ with logfile_open(path_join(testdir, RESULT_JSON), 'r') as fp:
+ results = json.load(fp)
+ except IOError:
+        sys.stderr.write('Unable to load last result. Did you run tests?\n')
+ return False
+ fails = collect_failures(results['results'])
+ if not overwrite:
+ known = load_known_failures(testdir)
+ known.extend(fails)
+ fails = known
+ fails_json = json.dumps(sorted(set(fails)), indent=2, separators=(',', ': '))
+ if save:
+ with logfile_open(os.path.join(testdir, FAIL_JSON % platform.system()), 'w+') as fp:
+ fp.write(fails_json)
+ sys.stdout.write('Successfully updated known failures.\n')
+ if out:
+ sys.stdout.write(fails_json)
+ sys.stdout.write('\n')
+ return True
def load_known_failures(testdir):
- try:
- with logfile_open(path_join(testdir, FAIL_JSON % platform.system()), 'r') as fp:
- return json.load(fp)
- except IOError:
- return []
+ try:
+ with logfile_open(path_join(testdir, FAIL_JSON % platform.system()), 'r') as fp:
+ return json.load(fp)
+ except IOError:
+ return []
class TestReporter(object):
- # Unfortunately, standard library doesn't handle timezone well
- # DATETIME_FORMAT = '%a %b %d %H:%M:%S %Z %Y'
- DATETIME_FORMAT = '%a %b %d %H:%M:%S %Y'
+ # Unfortunately, standard library doesn't handle timezone well
+ # DATETIME_FORMAT = '%a %b %d %H:%M:%S %Z %Y'
+ DATETIME_FORMAT = '%a %b %d %H:%M:%S %Y'
- def __init__(self):
- self._log = multiprocessing.get_logger()
- self._lock = multiprocessing.Lock()
+ def __init__(self):
+ self._log = multiprocessing.get_logger()
+ self._lock = multiprocessing.Lock()
- @classmethod
- def test_logfile(cls, test_name, prog_kind, dir=None):
- relpath = path_join('log', '%s_%s.log' % (test_name, prog_kind))
- return relpath if not dir else os.path.realpath(path_join(dir, relpath))
+ @classmethod
+ def test_logfile(cls, test_name, prog_kind, dir=None):
+ relpath = path_join('log', '%s_%s.log' % (test_name, prog_kind))
+ return relpath if not dir else os.path.realpath(path_join(dir, relpath))
- def _start(self):
- self._start_time = time.time()
+ def _start(self):
+ self._start_time = time.time()
- @property
- def _elapsed(self):
- return time.time() - self._start_time
+ @property
+ def _elapsed(self):
+ return time.time() - self._start_time
- @classmethod
- def _format_date(cls):
- return '%s' % datetime.datetime.now().strftime(cls.DATETIME_FORMAT)
+ @classmethod
+ def _format_date(cls):
+ return '%s' % datetime.datetime.now().strftime(cls.DATETIME_FORMAT)
- def _print_date(self):
- print(self._format_date(), file=self.out)
+ def _print_date(self):
+ print(self._format_date(), file=self.out)
- def _print_bar(self, out=None):
- print(
- '==========================================================================',
- file=(out or self.out))
+ def _print_bar(self, out=None):
+ print(
+ '==========================================================================',
+ file=(out or self.out))
- def _print_exec_time(self):
- print('Test execution took {:.1f} seconds.'.format(self._elapsed), file=self.out)
+ def _print_exec_time(self):
+ print('Test execution took {:.1f} seconds.'.format(self._elapsed), file=self.out)
class ExecReporter(TestReporter):
- def __init__(self, testdir, test, prog):
- super(ExecReporter, self).__init__()
- self._test = test
- self._prog = prog
- self.logpath = self.test_logfile(test.name, prog.kind, testdir)
- self.out = None
-
- def begin(self):
- self._start()
- self._open()
- if self.out and not self.out.closed:
- self._print_header()
- else:
- self._log.debug('Output stream is not available.')
-
- def end(self, returncode):
- self._lock.acquire()
- try:
- if self.out and not self.out.closed:
- self._print_footer(returncode)
- self._close()
+ def __init__(self, testdir, test, prog):
+ super(ExecReporter, self).__init__()
+ self._test = test
+ self._prog = prog
+ self.logpath = self.test_logfile(test.name, prog.kind, testdir)
self.out = None
- else:
- self._log.debug('Output stream is not available.')
- finally:
- self._lock.release()
-
- def killed(self):
- print(file=self.out)
- print('Server process is successfully killed.', file=self.out)
- self.end(None)
-
- def died(self):
- print(file=self.out)
- print('*** Server process has died unexpectedly ***', file=self.out)
- self.end(None)
-
- _init_failure_exprs = {
- 'server': list(map(re.compile, [
- '[Aa]ddress already in use',
- 'Could not bind',
- 'EADDRINUSE',
- ])),
- 'client': list(map(re.compile, [
- '[Cc]onnection refused',
- 'Could not connect to localhost',
- 'ECONNREFUSED',
- 'No such file or directory', # domain socket
- ])),
- }
-
- def maybe_false_positive(self):
- """Searches through log file for socket bind error.
- Returns True if suspicious expression is found, otherwise False"""
- try:
- if self.out and not self.out.closed:
+
+ def begin(self):
+ self._start()
+ self._open()
+ if self.out and not self.out.closed:
+ self._print_header()
+ else:
+ self._log.debug('Output stream is not available.')
+
+ def end(self, returncode):
+ self._lock.acquire()
+ try:
+ if self.out and not self.out.closed:
+ self._print_footer(returncode)
+ self._close()
+ self.out = None
+ else:
+ self._log.debug('Output stream is not available.')
+ finally:
+ self._lock.release()
+
+ def killed(self):
+ print(file=self.out)
+ print('Server process is successfully killed.', file=self.out)
+ self.end(None)
+
+ def died(self):
+ print(file=self.out)
+ print('*** Server process has died unexpectedly ***', file=self.out)
+ self.end(None)
+
+ _init_failure_exprs = {
+ 'server': list(map(re.compile, [
+ '[Aa]ddress already in use',
+ 'Could not bind',
+ 'EADDRINUSE',
+ ])),
+ 'client': list(map(re.compile, [
+ '[Cc]onnection refused',
+ 'Could not connect to localhost',
+ 'ECONNREFUSED',
+ 'No such file or directory', # domain socket
+ ])),
+ }
+
+ def maybe_false_positive(self):
+ """Searches through log file for socket bind error.
+ Returns True if suspicious expression is found, otherwise False"""
+ try:
+ if self.out and not self.out.closed:
+ self.out.flush()
+ exprs = self._init_failure_exprs[self._prog.kind]
+
+ def match(line):
+ for expr in exprs:
+ if expr.search(line):
+ return True
+
+ with logfile_open(self.logpath, 'r') as fp:
+ if any(map(match, fp)):
+ return True
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except Exception as ex:
+ self._log.warn('[%s]: Error while detecting false positive: %s' % (self._test.name, str(ex)))
+ self._log.info(traceback.print_exc())
+ return False
+
+ def _open(self):
+ self.out = logfile_open(self.logpath, 'w+')
+
+ def _close(self):
+ self.out.close()
+
+ def _print_header(self):
+ self._print_date()
+ print('Executing: %s' % str_join(' ', self._prog.command), file=self.out)
+ print('Directory: %s' % self._prog.workdir, file=self.out)
+ print('config:delay: %s' % self._test.delay, file=self.out)
+ print('config:timeout: %s' % self._test.timeout, file=self.out)
+ self._print_bar()
self.out.flush()
- exprs = self._init_failure_exprs[self._prog.kind]
-
- def match(line):
- for expr in exprs:
- if expr.search(line):
- return True
-
- with logfile_open(self.logpath, 'r') as fp:
- if any(map(match, fp)):
- return True
- except (KeyboardInterrupt, SystemExit):
- raise
- except Exception as ex:
- self._log.warn('[%s]: Error while detecting false positive: %s' % (self._test.name, str(ex)))
- self._log.info(traceback.print_exc())
- return False
-
- def _open(self):
- self.out = logfile_open(self.logpath, 'w+')
-
- def _close(self):
- self.out.close()
-
- def _print_header(self):
- self._print_date()
- print('Executing: %s' % str_join(' ', self._prog.command), file=self.out)
- print('Directory: %s' % self._prog.workdir, file=self.out)
- print('config:delay: %s' % self._test.delay, file=self.out)
- print('config:timeout: %s' % self._test.timeout, file=self.out)
- self._print_bar()
- self.out.flush()
-
- def _print_footer(self, returncode=None):
- self._print_bar()
- if returncode is not None:
- print('Return code: %d' % returncode, file=self.out)
- else:
- print('Process is killed.', file=self.out)
- self._print_exec_time()
- self._print_date()
+ def _print_footer(self, returncode=None):
+ self._print_bar()
+ if returncode is not None:
+ print('Return code: %d' % returncode, file=self.out)
+ else:
+ print('Process is killed.', file=self.out)
+ self._print_exec_time()
+ self._print_date()
-class SummaryReporter(TestReporter):
- def __init__(self, basedir, testdir_relative, concurrent=True):
- super(SummaryReporter, self).__init__()
- self._basedir = basedir
- self._testdir_rel = testdir_relative
- self.logdir = path_join(self.testdir, LOG_DIR)
- self.out_path = path_join(self.testdir, RESULT_JSON)
- self.concurrent = concurrent
- self.out = sys.stdout
- self._platform = platform.system()
- self._revision = self._get_revision()
- self._tests = []
- if not os.path.exists(self.logdir):
- os.mkdir(self.logdir)
- self._known_failures = load_known_failures(self.testdir)
- self._unexpected_success = []
- self._flaky_success = []
- self._unexpected_failure = []
- self._expected_failure = []
- self._print_header()
-
- @property
- def testdir(self):
- return path_join(self._basedir, self._testdir_rel)
-
- def _result_string(self, test):
- if test.success:
- if test.retry_count == 0:
- return 'success'
- elif test.retry_count == 1:
- return 'flaky(1 retry)'
- else:
- return 'flaky(%d retries)' % test.retry_count
- elif test.expired:
- return 'failure(timeout)'
- else:
- return 'failure(%d)' % test.returncode
-
- def _get_revision(self):
- p = subprocess.Popen(['git', 'rev-parse', '--short', 'HEAD'],
- cwd=self.testdir, stdout=subprocess.PIPE)
- out, _ = p.communicate()
- return out.strip()
-
- def _format_test(self, test, with_result=True):
- name = '%s-%s' % (test.server.name, test.client.name)
- trans = '%s-%s' % (test.transport, test.socket)
- if not with_result:
- return '{:24s}{:13s}{:25s}'.format(name[:23], test.protocol[:12], trans[:24])
- else:
- return '{:24s}{:13s}{:25s}{:s}\n'.format(name[:23], test.protocol[:12], trans[:24], self._result_string(test))
-
- def _print_test_header(self):
- self._print_bar()
- print(
- '{:24s}{:13s}{:25s}{:s}'.format('server-client:', 'protocol:', 'transport:', 'result:'),
- file=self.out)
-
- def _print_header(self):
- self._start()
- print('Apache Thrift - Integration Test Suite', file=self.out)
- self._print_date()
- self._print_test_header()
-
- def _print_unexpected_failure(self):
- if len(self._unexpected_failure) > 0:
- self.out.writelines([
- '*** Following %d failures were unexpected ***:\n' % len(self._unexpected_failure),
- 'If it is introduced by you, please fix it before submitting the code.\n',
- # 'If not, please report at https://issues.apache.org/jira/browse/THRIFT\n',
- ])
- self._print_test_header()
- for i in self._unexpected_failure:
- self.out.write(self._format_test(self._tests[i]))
- self._print_bar()
- else:
- print('No unexpected failures.', file=self.out)
-
- def _print_flaky_success(self):
- if len(self._flaky_success) > 0:
- print(
- 'Following %d tests were expected to cleanly succeed but needed retry:' % len(self._flaky_success),
- file=self.out)
- self._print_test_header()
- for i in self._flaky_success:
- self.out.write(self._format_test(self._tests[i]))
- self._print_bar()
-
- def _print_unexpected_success(self):
- if len(self._unexpected_success) > 0:
- print(
- 'Following %d tests were known to fail but succeeded (maybe flaky):' % len(self._unexpected_success),
- file=self.out)
- self._print_test_header()
- for i in self._unexpected_success:
- self.out.write(self._format_test(self._tests[i]))
- self._print_bar()
-
- def _http_server_command(self, port):
- if sys.version_info[0] < 3:
- return 'python -m SimpleHTTPServer %d' % port
- else:
- return 'python -m http.server %d' % port
-
- def _print_footer(self):
- fail_count = len(self._expected_failure) + len(self._unexpected_failure)
- self._print_bar()
- self._print_unexpected_success()
- self._print_flaky_success()
- self._print_unexpected_failure()
- self._write_html_data()
- self._assemble_log('unexpected failures', self._unexpected_failure)
- self._assemble_log('known failures', self._expected_failure)
- self.out.writelines([
- 'You can browse results at:\n',
- '\tfile://%s/%s\n' % (self.testdir, RESULT_HTML),
- '# If you use Chrome, run:\n',
- '# \tcd %s\n#\t%s\n' % (self._basedir, self._http_server_command(8001)),
- '# then browse:\n',
- '# \thttp://localhost:%d/%s/\n' % (8001, self._testdir_rel),
- 'Full log for each test is here:\n',
- '\ttest/log/client_server_protocol_transport_client.log\n',
- '\ttest/log/client_server_protocol_transport_server.log\n',
- '%d failed of %d tests in total.\n' % (fail_count, len(self._tests)),
- ])
- self._print_exec_time()
- self._print_date()
-
- def _render_result(self, test):
- return [
- test.server.name,
- test.client.name,
- test.protocol,
- test.transport,
- test.socket,
- test.success,
- test.as_expected,
- test.returncode,
- {
- 'server': self.test_logfile(test.name, test.server.kind),
- 'client': self.test_logfile(test.name, test.client.kind),
- },
- ]
-
- def _write_html_data(self):
- """Writes JSON data to be read by result html"""
- results = [self._render_result(r) for r in self._tests]
- with logfile_open(self.out_path, 'w+') as fp:
- fp.write(json.dumps({
- 'date': self._format_date(),
- 'revision': str(self._revision),
- 'platform': self._platform,
- 'duration': '{:.1f}'.format(self._elapsed),
- 'results': results,
- }, indent=2))
-
- def _assemble_log(self, title, indexes):
- if len(indexes) > 0:
- def add_prog_log(fp, test, prog_kind):
- print('*************************** %s message ***************************' % prog_kind,
- file=fp)
- path = self.test_logfile(test.name, prog_kind, self.testdir)
- if os.path.exists(path):
- with logfile_open(path, 'r') as prog_fp:
- print(prog_fp.read(), file=fp)
- filename = title.replace(' ', '_') + '.log'
- with logfile_open(os.path.join(self.logdir, filename), 'w+') as fp:
- for test in map(self._tests.__getitem__, indexes):
- fp.write('TEST: [%s]\n' % test.name)
- add_prog_log(fp, test, test.server.kind)
- add_prog_log(fp, test, test.client.kind)
- fp.write('**********************************************************************\n\n')
- print('%s are logged to %s/%s/%s' % (title.capitalize(), self._testdir_rel, LOG_DIR, filename))
-
- def end(self):
- self._print_footer()
- return len(self._unexpected_failure) == 0
-
- def add_test(self, test_dict):
- test = TestEntry(self.testdir, **test_dict)
- self._lock.acquire()
- try:
- if not self.concurrent:
- self.out.write(self._format_test(test, False))
- self.out.flush()
- self._tests.append(test)
- return len(self._tests) - 1
- finally:
- self._lock.release()
- def add_result(self, index, returncode, expired, retry_count):
- self._lock.acquire()
- try:
- failed = returncode is None or returncode != 0
- flaky = not failed and retry_count != 0
- test = self._tests[index]
- known = test.name in self._known_failures
- if failed:
- if known:
- self._log.debug('%s failed as expected' % test.name)
- self._expected_failure.append(index)
+class SummaryReporter(TestReporter):
+ def __init__(self, basedir, testdir_relative, concurrent=True):
+ super(SummaryReporter, self).__init__()
+ self._basedir = basedir
+ self._testdir_rel = testdir_relative
+ self.logdir = path_join(self.testdir, LOG_DIR)
+ self.out_path = path_join(self.testdir, RESULT_JSON)
+ self.concurrent = concurrent
+ self.out = sys.stdout
+ self._platform = platform.system()
+ self._revision = self._get_revision()
+ self._tests = []
+ if not os.path.exists(self.logdir):
+ os.mkdir(self.logdir)
+ self._known_failures = load_known_failures(self.testdir)
+ self._unexpected_success = []
+ self._flaky_success = []
+ self._unexpected_failure = []
+ self._expected_failure = []
+ self._print_header()
+
+ @property
+ def testdir(self):
+ return path_join(self._basedir, self._testdir_rel)
+
+ def _result_string(self, test):
+ if test.success:
+ if test.retry_count == 0:
+ return 'success'
+ elif test.retry_count == 1:
+ return 'flaky(1 retry)'
+ else:
+ return 'flaky(%d retries)' % test.retry_count
+ elif test.expired:
+ return 'failure(timeout)'
+ else:
+ return 'failure(%d)' % test.returncode
+
+ def _get_revision(self):
+ p = subprocess.Popen(['git', 'rev-parse', '--short', 'HEAD'],
+ cwd=self.testdir, stdout=subprocess.PIPE)
+ out, _ = p.communicate()
+ return out.strip()
+
+ def _format_test(self, test, with_result=True):
+ name = '%s-%s' % (test.server.name, test.client.name)
+ trans = '%s-%s' % (test.transport, test.socket)
+ if not with_result:
+ return '{:24s}{:13s}{:25s}'.format(name[:23], test.protocol[:12], trans[:24])
+ else:
+ return '{:24s}{:13s}{:25s}{:s}\n'.format(name[:23], test.protocol[:12], trans[:24], self._result_string(test))
+
+ def _print_test_header(self):
+ self._print_bar()
+ print(
+ '{:24s}{:13s}{:25s}{:s}'.format('server-client:', 'protocol:', 'transport:', 'result:'),
+ file=self.out)
+
+ def _print_header(self):
+ self._start()
+ print('Apache Thrift - Integration Test Suite', file=self.out)
+ self._print_date()
+ self._print_test_header()
+
+ def _print_unexpected_failure(self):
+ if len(self._unexpected_failure) > 0:
+ self.out.writelines([
+ '*** Following %d failures were unexpected ***:\n' % len(self._unexpected_failure),
+ 'If it is introduced by you, please fix it before submitting the code.\n',
+ # 'If not, please report at https://issues.apache.org/jira/browse/THRIFT\n',
+ ])
+ self._print_test_header()
+ for i in self._unexpected_failure:
+ self.out.write(self._format_test(self._tests[i]))
+ self._print_bar()
+ else:
+ print('No unexpected failures.', file=self.out)
+
+ def _print_flaky_success(self):
+ if len(self._flaky_success) > 0:
+ print(
+ 'Following %d tests were expected to cleanly succeed but needed retry:' % len(self._flaky_success),
+ file=self.out)
+ self._print_test_header()
+ for i in self._flaky_success:
+ self.out.write(self._format_test(self._tests[i]))
+ self._print_bar()
+
+ def _print_unexpected_success(self):
+ if len(self._unexpected_success) > 0:
+ print(
+ 'Following %d tests were known to fail but succeeded (maybe flaky):' % len(self._unexpected_success),
+ file=self.out)
+ self._print_test_header()
+ for i in self._unexpected_success:
+ self.out.write(self._format_test(self._tests[i]))
+ self._print_bar()
+
+ def _http_server_command(self, port):
+ if sys.version_info[0] < 3:
+ return 'python -m SimpleHTTPServer %d' % port
else:
- self._log.info('unexpected failure: %s' % test.name)
- self._unexpected_failure.append(index)
- elif flaky and not known:
- self._log.info('unexpected flaky success: %s' % test.name)
- self._flaky_success.append(index)
- elif not flaky and known:
- self._log.info('unexpected success: %s' % test.name)
- self._unexpected_success.append(index)
- test.success = not failed
- test.returncode = returncode
- test.retry_count = retry_count
- test.expired = expired
- test.as_expected = known == failed
- if not self.concurrent:
- self.out.write(self._result_string(test) + '\n')
- else:
- self.out.write(self._format_test(test))
- finally:
- self._lock.release()
+ return 'python -m http.server %d' % port
+
+ def _print_footer(self):
+ fail_count = len(self._expected_failure) + len(self._unexpected_failure)
+ self._print_bar()
+ self._print_unexpected_success()
+ self._print_flaky_success()
+ self._print_unexpected_failure()
+ self._write_html_data()
+ self._assemble_log('unexpected failures', self._unexpected_failure)
+ self._assemble_log('known failures', self._expected_failure)
+ self.out.writelines([
+ 'You can browse results at:\n',
+ '\tfile://%s/%s\n' % (self.testdir, RESULT_HTML),
+ '# If you use Chrome, run:\n',
+ '# \tcd %s\n#\t%s\n' % (self._basedir, self._http_server_command(8001)),
+ '# then browse:\n',
+ '# \thttp://localhost:%d/%s/\n' % (8001, self._testdir_rel),
+ 'Full log for each test is here:\n',
+ '\ttest/log/client_server_protocol_transport_client.log\n',
+ '\ttest/log/client_server_protocol_transport_server.log\n',
+ '%d failed of %d tests in total.\n' % (fail_count, len(self._tests)),
+ ])
+ self._print_exec_time()
+ self._print_date()
+
+ def _render_result(self, test):
+ return [
+ test.server.name,
+ test.client.name,
+ test.protocol,
+ test.transport,
+ test.socket,
+ test.success,
+ test.as_expected,
+ test.returncode,
+ {
+ 'server': self.test_logfile(test.name, test.server.kind),
+ 'client': self.test_logfile(test.name, test.client.kind),
+ },
+ ]
+
+ def _write_html_data(self):
+ """Writes JSON data to be read by result html"""
+ results = [self._render_result(r) for r in self._tests]
+ with logfile_open(self.out_path, 'w+') as fp:
+ fp.write(json.dumps({
+ 'date': self._format_date(),
+ 'revision': str(self._revision),
+ 'platform': self._platform,
+ 'duration': '{:.1f}'.format(self._elapsed),
+ 'results': results,
+ }, indent=2))
+
+ def _assemble_log(self, title, indexes):
+ if len(indexes) > 0:
+ def add_prog_log(fp, test, prog_kind):
+ print('*************************** %s message ***************************' % prog_kind,
+ file=fp)
+ path = self.test_logfile(test.name, prog_kind, self.testdir)
+ if os.path.exists(path):
+ with logfile_open(path, 'r') as prog_fp:
+ print(prog_fp.read(), file=fp)
+ filename = title.replace(' ', '_') + '.log'
+ with logfile_open(os.path.join(self.logdir, filename), 'w+') as fp:
+ for test in map(self._tests.__getitem__, indexes):
+ fp.write('TEST: [%s]\n' % test.name)
+ add_prog_log(fp, test, test.server.kind)
+ add_prog_log(fp, test, test.client.kind)
+ fp.write('**********************************************************************\n\n')
+ print('%s are logged to %s/%s/%s' % (title.capitalize(), self._testdir_rel, LOG_DIR, filename))
+
+ def end(self):
+ self._print_footer()
+ return len(self._unexpected_failure) == 0
+
+ def add_test(self, test_dict):
+ test = TestEntry(self.testdir, **test_dict)
+ self._lock.acquire()
+ try:
+ if not self.concurrent:
+ self.out.write(self._format_test(test, False))
+ self.out.flush()
+ self._tests.append(test)
+ return len(self._tests) - 1
+ finally:
+ self._lock.release()
+
+ def add_result(self, index, returncode, expired, retry_count):
+ self._lock.acquire()
+ try:
+ failed = returncode is None or returncode != 0
+ flaky = not failed and retry_count != 0
+ test = self._tests[index]
+ known = test.name in self._known_failures
+ if failed:
+ if known:
+ self._log.debug('%s failed as expected' % test.name)
+ self._expected_failure.append(index)
+ else:
+ self._log.info('unexpected failure: %s' % test.name)
+ self._unexpected_failure.append(index)
+ elif flaky and not known:
+ self._log.info('unexpected flaky success: %s' % test.name)
+ self._flaky_success.append(index)
+ elif not flaky and known:
+ self._log.info('unexpected success: %s' % test.name)
+ self._unexpected_success.append(index)
+ test.success = not failed
+ test.returncode = returncode
+ test.retry_count = retry_count
+ test.expired = expired
+ test.as_expected = known == failed
+ if not self.concurrent:
+ self.out.write(self._result_string(test) + '\n')
+ else:
+ self.out.write(self._format_test(test))
+ finally:
+ self._lock.release()
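
For orientation, the file read by load_known_failures() above is a flat JSON array of test names (as produced by TestEntry.get_name()); a hand-written sketch with made-up entries:

import json
import platform

fails = ['c_glib-py_binary_buffered-ip', 'java-go_compact_framed-ip']  # illustrative names only
with open('known_failures_%s.json' % platform.system(), 'w+') as fp:
    fp.write(json.dumps(sorted(set(fails)), indent=2, separators=(',', ': ')))
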
diff --git a/test/crossrunner/run.py b/test/crossrunner/run.py
index 68bd92869..18c162357 100644
--- a/test/crossrunner/run.py
+++ b/test/crossrunner/run.py
@@ -39,307 +39,307 @@ RESULT_ERROR = 64
class ExecutionContext(object):
- def __init__(self, cmd, cwd, env, report):
- self._log = multiprocessing.get_logger()
- self.report = report
- self.cmd = cmd
- self.cwd = cwd
- self.env = env
- self.timer = None
- self.expired = False
- self.killed = False
-
- def _expire(self):
- self._log.info('Timeout')
- self.expired = True
- self.kill()
-
- def kill(self):
- self._log.debug('Killing process : %d' % self.proc.pid)
- self.killed = True
- if platform.system() != 'Windows':
- try:
- os.killpg(self.proc.pid, signal.SIGKILL)
- except Exception:
- self._log.info('Failed to kill process group', exc_info=sys.exc_info())
- try:
- self.proc.kill()
- except Exception:
- self._log.info('Failed to kill process', exc_info=sys.exc_info())
-
- def _popen_args(self):
- args = {
- 'cwd': self.cwd,
- 'env': self.env,
- 'stdout': self.report.out,
- 'stderr': subprocess.STDOUT,
- }
- # make sure child processes doesn't remain after killing
- if platform.system() == 'Windows':
- DETACHED_PROCESS = 0x00000008
- args.update(creationflags=DETACHED_PROCESS | subprocess.CREATE_NEW_PROCESS_GROUP)
- else:
- args.update(preexec_fn=os.setsid)
- return args
-
- def start(self, timeout=0):
- joined = str_join(' ', self.cmd)
- self._log.debug('COMMAND: %s', joined)
- self._log.debug('WORKDIR: %s', self.cwd)
- self._log.debug('LOGFILE: %s', self.report.logpath)
- self.report.begin()
- self.proc = subprocess.Popen(self.cmd, **self._popen_args())
- if timeout > 0:
- self.timer = threading.Timer(timeout, self._expire)
- self.timer.start()
- return self._scoped()
-
- @contextlib.contextmanager
- def _scoped(self):
- yield self
- self._log.debug('Killing scoped process')
- if self.proc.poll() is None:
- self.kill()
- self.report.killed()
- else:
- self._log.debug('Process died unexpectedly')
- self.report.died()
-
- def wait(self):
- self.proc.communicate()
- if self.timer:
- self.timer.cancel()
- self.report.end(self.returncode)
-
- @property
- def returncode(self):
- return self.proc.returncode if self.proc else None
+ def __init__(self, cmd, cwd, env, report):
+ self._log = multiprocessing.get_logger()
+ self.report = report
+ self.cmd = cmd
+ self.cwd = cwd
+ self.env = env
+ self.timer = None
+ self.expired = False
+ self.killed = False
+
+ def _expire(self):
+ self._log.info('Timeout')
+ self.expired = True
+ self.kill()
+
+ def kill(self):
+ self._log.debug('Killing process : %d' % self.proc.pid)
+ self.killed = True
+ if platform.system() != 'Windows':
+ try:
+ os.killpg(self.proc.pid, signal.SIGKILL)
+ except Exception:
+ self._log.info('Failed to kill process group', exc_info=sys.exc_info())
+ try:
+ self.proc.kill()
+ except Exception:
+ self._log.info('Failed to kill process', exc_info=sys.exc_info())
+
+ def _popen_args(self):
+ args = {
+ 'cwd': self.cwd,
+ 'env': self.env,
+ 'stdout': self.report.out,
+ 'stderr': subprocess.STDOUT,
+ }
+        # make sure child processes don't remain after killing
+ if platform.system() == 'Windows':
+ DETACHED_PROCESS = 0x00000008
+ args.update(creationflags=DETACHED_PROCESS | subprocess.CREATE_NEW_PROCESS_GROUP)
+ else:
+ args.update(preexec_fn=os.setsid)
+ return args
+
+ def start(self, timeout=0):
+ joined = str_join(' ', self.cmd)
+ self._log.debug('COMMAND: %s', joined)
+ self._log.debug('WORKDIR: %s', self.cwd)
+ self._log.debug('LOGFILE: %s', self.report.logpath)
+ self.report.begin()
+ self.proc = subprocess.Popen(self.cmd, **self._popen_args())
+ if timeout > 0:
+ self.timer = threading.Timer(timeout, self._expire)
+ self.timer.start()
+ return self._scoped()
+
+ @contextlib.contextmanager
+ def _scoped(self):
+ yield self
+ self._log.debug('Killing scoped process')
+ if self.proc.poll() is None:
+ self.kill()
+ self.report.killed()
+ else:
+ self._log.debug('Process died unexpectedly')
+ self.report.died()
+
+ def wait(self):
+ self.proc.communicate()
+ if self.timer:
+ self.timer.cancel()
+ self.report.end(self.returncode)
+
+ @property
+ def returncode(self):
+ return self.proc.returncode if self.proc else None
def exec_context(port, logdir, test, prog):
- report = ExecReporter(logdir, test, prog)
- prog.build_command(port)
- return ExecutionContext(prog.command, prog.workdir, prog.env, report)
+ report = ExecReporter(logdir, test, prog)
+ prog.build_command(port)
+ return ExecutionContext(prog.command, prog.workdir, prog.env, report)
def run_test(testdir, logdir, test_dict, max_retry, async=True):
- try:
- logger = multiprocessing.get_logger()
- max_bind_retry = 3
- retry_count = 0
- bind_retry_count = 0
- test = TestEntry(testdir, **test_dict)
- while True:
- if stop.is_set():
- logger.debug('Skipping because shutting down')
- return (retry_count, None)
- logger.debug('Start')
- with PortAllocator.alloc_port_scoped(ports, test.socket) as port:
- logger.debug('Start with port %d' % port)
- sv = exec_context(port, logdir, test, test.server)
- cl = exec_context(port, logdir, test, test.client)
-
- logger.debug('Starting server')
- with sv.start():
- if test.delay > 0:
- logger.debug('Delaying client for %.2f seconds' % test.delay)
- time.sleep(test.delay)
- connect_retry_count = 0
- max_connect_retry = 10
- connect_retry_wait = 0.5
- while True:
- logger.debug('Starting client')
- cl.start(test.timeout)
- logger.debug('Waiting client')
- cl.wait()
- if not cl.report.maybe_false_positive() or connect_retry_count >= max_connect_retry:
- if connect_retry_count > 0 and connect_retry_count < max_connect_retry:
- logger.warn('[%s]: Connected after %d retry (%.2f sec each)' % (test.server.name, connect_retry_count, connect_retry_wait))
- # Wait for 50ms to see if server does not die at the end.
- time.sleep(0.05)
- break
- logger.debug('Server may not be ready, waiting %.2f second...' % connect_retry_wait)
- time.sleep(connect_retry_wait)
- connect_retry_count += 1
-
- if sv.report.maybe_false_positive() and bind_retry_count < max_bind_retry:
- logger.warn('[%s]: Detected socket bind failure, retrying...', test.server.name)
- bind_retry_count += 1
- else:
- if cl.expired:
- result = RESULT_TIMEOUT
- elif not sv.killed and cl.proc.returncode == 0:
- # Server should be alive at the end.
- result = RESULT_ERROR
- else:
- result = cl.proc.returncode
-
- if result == 0 or retry_count >= max_retry:
- return (retry_count, result)
- else:
- logger.info('[%s-%s]: test failed, retrying...', test.server.name, test.client.name)
- retry_count += 1
- except (KeyboardInterrupt, SystemExit):
- logger.info('Interrupted execution')
- if not async:
- raise
- stop.set()
- return None
- except:
- if not async:
- raise
- logger.warn('Error executing [%s]', test.name, exc_info=sys.exc_info())
- return (retry_count, RESULT_ERROR)
+ try:
+ logger = multiprocessing.get_logger()
+ max_bind_retry = 3
+ retry_count = 0
+ bind_retry_count = 0
+ test = TestEntry(testdir, **test_dict)
+ while True:
+ if stop.is_set():
+ logger.debug('Skipping because shutting down')
+ return (retry_count, None)
+ logger.debug('Start')
+ with PortAllocator.alloc_port_scoped(ports, test.socket) as port:
+ logger.debug('Start with port %d' % port)
+ sv = exec_context(port, logdir, test, test.server)
+ cl = exec_context(port, logdir, test, test.client)
+
+ logger.debug('Starting server')
+ with sv.start():
+ if test.delay > 0:
+ logger.debug('Delaying client for %.2f seconds' % test.delay)
+ time.sleep(test.delay)
+ connect_retry_count = 0
+ max_connect_retry = 10
+ connect_retry_wait = 0.5
+ while True:
+ logger.debug('Starting client')
+ cl.start(test.timeout)
+ logger.debug('Waiting client')
+ cl.wait()
+ if not cl.report.maybe_false_positive() or connect_retry_count >= max_connect_retry:
+ if connect_retry_count > 0 and connect_retry_count < max_connect_retry:
+ logger.warn('[%s]: Connected after %d retry (%.2f sec each)' % (test.server.name, connect_retry_count, connect_retry_wait))
+                            # Wait 50ms to check that the server does not die at the end.
+ time.sleep(0.05)
+ break
+ logger.debug('Server may not be ready, waiting %.2f second...' % connect_retry_wait)
+ time.sleep(connect_retry_wait)
+ connect_retry_count += 1
+
+ if sv.report.maybe_false_positive() and bind_retry_count < max_bind_retry:
+ logger.warn('[%s]: Detected socket bind failure, retrying...', test.server.name)
+ bind_retry_count += 1
+ else:
+ if cl.expired:
+ result = RESULT_TIMEOUT
+ elif not sv.killed and cl.proc.returncode == 0:
+ # Server should be alive at the end.
+ result = RESULT_ERROR
+ else:
+ result = cl.proc.returncode
+
+ if result == 0 or retry_count >= max_retry:
+ return (retry_count, result)
+ else:
+ logger.info('[%s-%s]: test failed, retrying...', test.server.name, test.client.name)
+ retry_count += 1
+ except (KeyboardInterrupt, SystemExit):
+ logger.info('Interrupted execution')
+ if not async:
+ raise
+ stop.set()
+ return None
+ except:
+ if not async:
+ raise
+ logger.warn('Error executing [%s]', test.name, exc_info=sys.exc_info())
+ return (retry_count, RESULT_ERROR)
class PortAllocator(object):
- def __init__(self):
- self._log = multiprocessing.get_logger()
- self._lock = multiprocessing.Lock()
- self._ports = set()
- self._dom_ports = set()
- self._last_alloc = 0
-
- def _get_tcp_port(self):
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- sock.bind(('127.0.0.1', 0))
- port = sock.getsockname()[1]
- self._lock.acquire()
- try:
- ok = port not in self._ports
- if ok:
- self._ports.add(port)
- self._last_alloc = time.time()
- finally:
- self._lock.release()
- sock.close()
- return port if ok else self._get_tcp_port()
-
- def _get_domain_port(self):
- port = random.randint(1024, 65536)
- self._lock.acquire()
- try:
- ok = port not in self._dom_ports
- if ok:
- self._dom_ports.add(port)
- finally:
- self._lock.release()
- return port if ok else self._get_domain_port()
-
- def alloc_port(self, socket_type):
- if socket_type in ('domain', 'abstract'):
- return self._get_domain_port()
- else:
- return self._get_tcp_port()
-
- # static method for inter-process invokation
- @staticmethod
- @contextlib.contextmanager
- def alloc_port_scoped(allocator, socket_type):
- port = allocator.alloc_port(socket_type)
- yield port
- allocator.free_port(socket_type, port)
-
- def free_port(self, socket_type, port):
- self._log.debug('free_port')
- self._lock.acquire()
- try:
- if socket_type == 'domain':
- self._dom_ports.remove(port)
- path = domain_socket_path(port)
- if os.path.exists(path):
- os.remove(path)
- elif socket_type == 'abstract':
- self._dom_ports.remove(port)
- else:
- self._ports.remove(port)
- except IOError:
- self._log.info('Error while freeing port', exc_info=sys.exc_info())
- finally:
- self._lock.release()
+ def __init__(self):
+ self._log = multiprocessing.get_logger()
+ self._lock = multiprocessing.Lock()
+ self._ports = set()
+ self._dom_ports = set()
+ self._last_alloc = 0
+
+ def _get_tcp_port(self):
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.bind(('127.0.0.1', 0))
+ port = sock.getsockname()[1]
+ self._lock.acquire()
+ try:
+ ok = port not in self._ports
+ if ok:
+ self._ports.add(port)
+ self._last_alloc = time.time()
+ finally:
+ self._lock.release()
+ sock.close()
+ return port if ok else self._get_tcp_port()
+
+ def _get_domain_port(self):
+ port = random.randint(1024, 65536)
+ self._lock.acquire()
+ try:
+ ok = port not in self._dom_ports
+ if ok:
+ self._dom_ports.add(port)
+ finally:
+ self._lock.release()
+ return port if ok else self._get_domain_port()
+
+ def alloc_port(self, socket_type):
+ if socket_type in ('domain', 'abstract'):
+ return self._get_domain_port()
+ else:
+ return self._get_tcp_port()
+
+    # static method for inter-process invocation
+ @staticmethod
+ @contextlib.contextmanager
+ def alloc_port_scoped(allocator, socket_type):
+ port = allocator.alloc_port(socket_type)
+ yield port
+ allocator.free_port(socket_type, port)
+
+ def free_port(self, socket_type, port):
+ self._log.debug('free_port')
+ self._lock.acquire()
+ try:
+ if socket_type == 'domain':
+ self._dom_ports.remove(port)
+ path = domain_socket_path(port)
+ if os.path.exists(path):
+ os.remove(path)
+ elif socket_type == 'abstract':
+ self._dom_ports.remove(port)
+ else:
+ self._ports.remove(port)
+ except IOError:
+ self._log.info('Error while freeing port', exc_info=sys.exc_info())
+ finally:
+ self._lock.release()
class NonAsyncResult(object):
- def __init__(self, value):
- self._value = value
+ def __init__(self, value):
+ self._value = value
- def get(self, timeout=None):
- return self._value
+ def get(self, timeout=None):
+ return self._value
- def wait(self, timeout=None):
- pass
+ def wait(self, timeout=None):
+ pass
- def ready(self):
- return True
+ def ready(self):
+ return True
- def successful(self):
- return self._value == 0
+ def successful(self):
+ return self._value == 0
class TestDispatcher(object):
- def __init__(self, testdir, basedir, logdir_rel, concurrency):
- self._log = multiprocessing.get_logger()
- self.testdir = testdir
- self._report = SummaryReporter(basedir, logdir_rel, concurrency > 1)
- self.logdir = self._report.testdir
- # seems needed for python 2.x to handle keyboard interrupt
- self._stop = multiprocessing.Event()
- self._async = concurrency > 1
- if not self._async:
- self._pool = None
- global stop
- global ports
- stop = self._stop
- ports = PortAllocator()
- else:
- self._m = multiprocessing.managers.BaseManager()
- self._m.register('ports', PortAllocator)
- self._m.start()
- self._pool = multiprocessing.Pool(concurrency, self._pool_init, (self._m.address,))
- self._log.debug(
- 'TestDispatcher started with %d concurrent jobs' % concurrency)
-
- def _pool_init(self, address):
- global stop
- global m
- global ports
- stop = self._stop
- m = multiprocessing.managers.BaseManager(address)
- m.connect()
- ports = m.ports()
-
- def _dispatch_sync(self, test, cont, max_retry):
- r = run_test(self.testdir, self.logdir, test, max_retry, False)
- cont(r)
- return NonAsyncResult(r)
-
- def _dispatch_async(self, test, cont, max_retry):
- self._log.debug('_dispatch_async')
- return self._pool.apply_async(func=run_test, args=(self.testdir, self.logdir, test, max_retry), callback=cont)
-
- def dispatch(self, test, max_retry):
- index = self._report.add_test(test)
-
- def cont(result):
- if not self._stop.is_set():
- retry_count, returncode = result
- self._log.debug('freeing port')
- self._log.debug('adding result')
- self._report.add_result(index, returncode, returncode == RESULT_TIMEOUT, retry_count)
- self._log.debug('finish continuation')
- fn = self._dispatch_async if self._async else self._dispatch_sync
- return fn(test, cont, max_retry)
-
- def wait(self):
- if self._async:
- self._pool.close()
- self._pool.join()
- self._m.shutdown()
- return self._report.end()
-
- def terminate(self):
- self._stop.set()
- if self._async:
- self._pool.terminate()
- self._pool.join()
- self._m.shutdown()
+ def __init__(self, testdir, basedir, logdir_rel, concurrency):
+ self._log = multiprocessing.get_logger()
+ self.testdir = testdir
+ self._report = SummaryReporter(basedir, logdir_rel, concurrency > 1)
+ self.logdir = self._report.testdir
+        # This seems to be needed for Python 2.x to handle keyboard interrupts.
+ self._stop = multiprocessing.Event()
+ self._async = concurrency > 1
+ if not self._async:
+ self._pool = None
+ global stop
+ global ports
+ stop = self._stop
+ ports = PortAllocator()
+ else:
+ self._m = multiprocessing.managers.BaseManager()
+ self._m.register('ports', PortAllocator)
+ self._m.start()
+ self._pool = multiprocessing.Pool(concurrency, self._pool_init, (self._m.address,))
+ self._log.debug(
+ 'TestDispatcher started with %d concurrent jobs' % concurrency)
+
+ def _pool_init(self, address):
+ global stop
+ global m
+ global ports
+ stop = self._stop
+ m = multiprocessing.managers.BaseManager(address)
+ m.connect()
+ ports = m.ports()
+
+ def _dispatch_sync(self, test, cont, max_retry):
+ r = run_test(self.testdir, self.logdir, test, max_retry, False)
+ cont(r)
+ return NonAsyncResult(r)
+
+ def _dispatch_async(self, test, cont, max_retry):
+ self._log.debug('_dispatch_async')
+ return self._pool.apply_async(func=run_test, args=(self.testdir, self.logdir, test, max_retry), callback=cont)
+
+ def dispatch(self, test, max_retry):
+ index = self._report.add_test(test)
+
+ def cont(result):
+ if not self._stop.is_set():
+ retry_count, returncode = result
+ self._log.debug('freeing port')
+ self._log.debug('adding result')
+ self._report.add_result(index, returncode, returncode == RESULT_TIMEOUT, retry_count)
+ self._log.debug('finish continuation')
+ fn = self._dispatch_async if self._async else self._dispatch_sync
+ return fn(test, cont, max_retry)
+
+ def wait(self):
+ if self._async:
+ self._pool.close()
+ self._pool.join()
+ self._m.shutdown()
+ return self._report.end()
+
+ def terminate(self):
+ self._stop.set()
+ if self._async:
+ self._pool.terminate()
+ self._pool.join()
+ self._m.shutdown()
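
The _get_tcp_port helper above leans on the operating system: binding to port 0 makes the kernel hand out a free ephemeral port. A minimal standalone sketch of just that step, independent of the crossrunner classes (pick_free_tcp_port is a name invented for this example):

    import socket

    def pick_free_tcp_port(host='127.0.0.1'):
        # Bind to port 0 so the kernel assigns an unused ephemeral port,
        # then read the chosen number back before closing the socket.
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.bind((host, 0))
            return sock.getsockname()[1]
        finally:
            sock.close()

    if __name__ == '__main__':
        # The exact number varies per run; nothing keeps the port reserved after
        # close(), which is why PortAllocator also tracks handed-out ports in a set.
        print(pick_free_tcp_port())
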
diff --git a/test/crossrunner/test.py b/test/crossrunner/test.py
index fc90f7f30..dcc8a9416 100644
--- a/test/crossrunner/test.py
+++ b/test/crossrunner/test.py
@@ -26,118 +26,118 @@ from .util import merge_dict
def domain_socket_path(port):
- return '/tmp/ThriftTest.thrift.%d' % port
+ return '/tmp/ThriftTest.thrift.%d' % port
class TestProgram(object):
- def __init__(self, kind, name, protocol, transport, socket, workdir, command, env=None,
- extra_args=[], extra_args2=[], join_args=False, **kwargs):
- self.kind = kind
- self.name = name
- self.protocol = protocol
- self.transport = transport
- self.socket = socket
- self.workdir = workdir
- self.command = None
- self._base_command = self._fix_cmd_path(command)
- if env:
- self.env = copy.copy(os.environ)
- self.env.update(env)
- else:
- self.env = os.environ
- self._extra_args = extra_args
- self._extra_args2 = extra_args2
- self._join_args = join_args
-
- def _fix_cmd_path(self, cmd):
- # if the arg is a file in the current directory, make it path
- def abs_if_exists(arg):
- p = path_join(self.workdir, arg)
- return p if os.path.exists(p) else arg
-
- if cmd[0] == 'python':
- cmd[0] = sys.executable
- else:
- cmd[0] = abs_if_exists(cmd[0])
- return cmd
-
- def _socket_args(self, socket, port):
- return {
- 'ip-ssl': ['--ssl'],
- 'domain': ['--domain-socket=%s' % domain_socket_path(port)],
- 'abstract': ['--abstract-namespace', '--domain-socket=%s' % domain_socket_path(port)],
- }.get(socket, None)
-
- def build_command(self, port):
- cmd = copy.copy(self._base_command)
- args = copy.copy(self._extra_args2)
- args.append('--protocol=' + self.protocol)
- args.append('--transport=' + self.transport)
- socket_args = self._socket_args(self.socket, port)
- if socket_args:
- args += socket_args
- args.append('--port=%d' % port)
- if self._join_args:
- cmd.append('%s' % " ".join(args))
- else:
- cmd.extend(args)
- if self._extra_args:
- cmd.extend(self._extra_args)
- self.command = cmd
- return self.command
+ def __init__(self, kind, name, protocol, transport, socket, workdir, command, env=None,
+ extra_args=[], extra_args2=[], join_args=False, **kwargs):
+ self.kind = kind
+ self.name = name
+ self.protocol = protocol
+ self.transport = transport
+ self.socket = socket
+ self.workdir = workdir
+ self.command = None
+ self._base_command = self._fix_cmd_path(command)
+ if env:
+ self.env = copy.copy(os.environ)
+ self.env.update(env)
+ else:
+ self.env = os.environ
+ self._extra_args = extra_args
+ self._extra_args2 = extra_args2
+ self._join_args = join_args
+
+ def _fix_cmd_path(self, cmd):
+        # if the arg is a file under the work directory, convert it to an absolute path
+ def abs_if_exists(arg):
+ p = path_join(self.workdir, arg)
+ return p if os.path.exists(p) else arg
+
+ if cmd[0] == 'python':
+ cmd[0] = sys.executable
+ else:
+ cmd[0] = abs_if_exists(cmd[0])
+ return cmd
+
+ def _socket_args(self, socket, port):
+ return {
+ 'ip-ssl': ['--ssl'],
+ 'domain': ['--domain-socket=%s' % domain_socket_path(port)],
+ 'abstract': ['--abstract-namespace', '--domain-socket=%s' % domain_socket_path(port)],
+ }.get(socket, None)
+
+ def build_command(self, port):
+ cmd = copy.copy(self._base_command)
+ args = copy.copy(self._extra_args2)
+ args.append('--protocol=' + self.protocol)
+ args.append('--transport=' + self.transport)
+ socket_args = self._socket_args(self.socket, port)
+ if socket_args:
+ args += socket_args
+ args.append('--port=%d' % port)
+ if self._join_args:
+ cmd.append('%s' % " ".join(args))
+ else:
+ cmd.extend(args)
+ if self._extra_args:
+ cmd.extend(self._extra_args)
+ self.command = cmd
+ return self.command
class TestEntry(object):
- def __init__(self, testdir, server, client, delay, timeout, **kwargs):
- self.testdir = testdir
- self._log = multiprocessing.get_logger()
- self._config = kwargs
- self.protocol = kwargs['protocol']
- self.transport = kwargs['transport']
- self.socket = kwargs['socket']
- srv_dict = self._fix_workdir(merge_dict(self._config, server))
- cli_dict = self._fix_workdir(merge_dict(self._config, client))
- cli_dict['extra_args2'] = srv_dict.pop('remote_args', [])
- srv_dict['extra_args2'] = cli_dict.pop('remote_args', [])
- self.server = TestProgram('server', **srv_dict)
- self.client = TestProgram('client', **cli_dict)
- self.delay = delay
- self.timeout = timeout
- self._name = None
- # results
- self.success = None
- self.as_expected = None
- self.returncode = None
- self.expired = False
- self.retry_count = 0
-
- def _fix_workdir(self, config):
- key = 'workdir'
- path = config.get(key, None)
- if not path:
- path = self.testdir
- if os.path.isabs(path):
- path = os.path.realpath(path)
- else:
- path = os.path.realpath(path_join(self.testdir, path))
- config.update({key: path})
- return config
-
- @classmethod
- def get_name(cls, server, client, proto, trans, sock, *args):
- return '%s-%s_%s_%s-%s' % (server, client, proto, trans, sock)
-
- @property
- def name(self):
- if not self._name:
- self._name = self.get_name(
- self.server.name, self.client.name, self.protocol, self.transport, self.socket)
- return self._name
-
- @property
- def transport_name(self):
- return '%s-%s' % (self.transport, self.socket)
+ def __init__(self, testdir, server, client, delay, timeout, **kwargs):
+ self.testdir = testdir
+ self._log = multiprocessing.get_logger()
+ self._config = kwargs
+ self.protocol = kwargs['protocol']
+ self.transport = kwargs['transport']
+ self.socket = kwargs['socket']
+ srv_dict = self._fix_workdir(merge_dict(self._config, server))
+ cli_dict = self._fix_workdir(merge_dict(self._config, client))
+ cli_dict['extra_args2'] = srv_dict.pop('remote_args', [])
+ srv_dict['extra_args2'] = cli_dict.pop('remote_args', [])
+ self.server = TestProgram('server', **srv_dict)
+ self.client = TestProgram('client', **cli_dict)
+ self.delay = delay
+ self.timeout = timeout
+ self._name = None
+ # results
+ self.success = None
+ self.as_expected = None
+ self.returncode = None
+ self.expired = False
+ self.retry_count = 0
+
+ def _fix_workdir(self, config):
+ key = 'workdir'
+ path = config.get(key, None)
+ if not path:
+ path = self.testdir
+ if os.path.isabs(path):
+ path = os.path.realpath(path)
+ else:
+ path = os.path.realpath(path_join(self.testdir, path))
+ config.update({key: path})
+ return config
+
+ @classmethod
+ def get_name(cls, server, client, proto, trans, sock, *args):
+ return '%s-%s_%s_%s-%s' % (server, client, proto, trans, sock)
+
+ @property
+ def name(self):
+ if not self._name:
+ self._name = self.get_name(
+ self.server.name, self.client.name, self.protocol, self.transport, self.socket)
+ return self._name
+
+ @property
+ def transport_name(self):
+ return '%s-%s' % (self.transport, self.socket)
def test_name(server, client, protocol, transport, socket, **kwargs):
- return TestEntry.get_name(server['name'], client['name'], protocol, transport, socket)
+ return TestEntry.get_name(server['name'], client['name'], protocol, transport, socket)
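
The get_name scheme above is what ties a test's report entry to its log files. A standalone illustration of the format (the component values are made up for the example, not taken from a real run):

    def cross_test_name(server, client, proto, trans, sock):
        # Mirrors TestEntry.get_name: '<server>-<client>_<protocol>_<transport>-<socket>'
        return '%s-%s_%s_%s-%s' % (server, client, proto, trans, sock)

    print(cross_test_name('cpp', 'py', 'binary', 'buffered', 'ip'))
    # -> cpp-py_binary_buffered-ip
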
diff --git a/test/crossrunner/util.py b/test/crossrunner/util.py
index 750ed475e..e2d195a22 100644
--- a/test/crossrunner/util.py
+++ b/test/crossrunner/util.py
@@ -21,11 +21,11 @@ import copy
def merge_dict(base, update):
- """Update dict concatenating list values"""
- res = copy.deepcopy(base)
- for k, v in list(update.items()):
- if k in list(res.keys()) and isinstance(v, list):
- res[k].extend(v)
- else:
- res[k] = v
- return res
+ """Update dict concatenating list values"""
+ res = copy.deepcopy(base)
+ for k, v in list(update.items()):
+ if k in list(res.keys()) and isinstance(v, list):
+ res[k].extend(v)
+ else:
+ res[k] = v
+ return res
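
To make the merge semantics above concrete: list values are concatenated while everything else is overwritten by the update. A standalone copy of the function with an illustrative input (the dictionaries are made up for the example):

    import copy

    def merge_dict(base, update):
        # Same semantics as crossrunner.util.merge_dict: list values are
        # concatenated, all other values are overwritten by `update`.
        res = copy.deepcopy(base)
        for k, v in update.items():
            if k in res and isinstance(v, list):
                res[k].extend(v)
            else:
                res[k] = v
        return res

    base = {'command': ['python', 'TestServer.py'], 'protocol': 'binary'}
    override = {'command': ['--ssl'], 'protocol': 'compact'}
    print(merge_dict(base, override))
    # -> {'command': ['python', 'TestServer.py', '--ssl'], 'protocol': 'compact'}
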
diff --git a/test/features/container_limit.py b/test/features/container_limit.py
index 4a7da6065..beed0c5ec 100644
--- a/test/features/container_limit.py
+++ b/test/features/container_limit.py
@@ -10,63 +10,63 @@ from thrift.Thrift import TMessageType, TType
# TODO: generate from ThriftTest.thrift
def test_list(proto, value):
- method_name = 'testList'
- ttype = TType.LIST
- etype = TType.I32
- proto.writeMessageBegin(method_name, TMessageType.CALL, 3)
- proto.writeStructBegin(method_name + '_args')
- proto.writeFieldBegin('thing', ttype, 1)
- proto.writeListBegin(etype, len(value))
- for e in value:
- proto.writeI32(e)
- proto.writeListEnd()
- proto.writeFieldEnd()
- proto.writeFieldStop()
- proto.writeStructEnd()
- proto.writeMessageEnd()
- proto.trans.flush()
+ method_name = 'testList'
+ ttype = TType.LIST
+ etype = TType.I32
+ proto.writeMessageBegin(method_name, TMessageType.CALL, 3)
+ proto.writeStructBegin(method_name + '_args')
+ proto.writeFieldBegin('thing', ttype, 1)
+ proto.writeListBegin(etype, len(value))
+ for e in value:
+ proto.writeI32(e)
+ proto.writeListEnd()
+ proto.writeFieldEnd()
+ proto.writeFieldStop()
+ proto.writeStructEnd()
+ proto.writeMessageEnd()
+ proto.trans.flush()
- _, mtype, _ = proto.readMessageBegin()
- assert mtype == TMessageType.REPLY
- proto.readStructBegin()
- _, ftype, fid = proto.readFieldBegin()
- assert fid == 0
- assert ftype == ttype
- etype2, len2 = proto.readListBegin()
- assert etype == etype2
- assert len2 == len(value)
- for i in range(len2):
- v = proto.readI32()
- assert v == value[i]
- proto.readListEnd()
- proto.readFieldEnd()
- _, ftype, _ = proto.readFieldBegin()
- assert ftype == TType.STOP
- proto.readStructEnd()
- proto.readMessageEnd()
+ _, mtype, _ = proto.readMessageBegin()
+ assert mtype == TMessageType.REPLY
+ proto.readStructBegin()
+ _, ftype, fid = proto.readFieldBegin()
+ assert fid == 0
+ assert ftype == ttype
+ etype2, len2 = proto.readListBegin()
+ assert etype == etype2
+ assert len2 == len(value)
+ for i in range(len2):
+ v = proto.readI32()
+ assert v == value[i]
+ proto.readListEnd()
+ proto.readFieldEnd()
+ _, ftype, _ = proto.readFieldBegin()
+ assert ftype == TType.STOP
+ proto.readStructEnd()
+ proto.readMessageEnd()
def main(argv):
- p = argparse.ArgumentParser()
- add_common_args(p)
- p.add_argument('--limit', type=int)
- args = p.parse_args()
- proto = init_protocol(args)
- # TODO: test set and map
- test_list(proto, list(range(args.limit - 1)))
- test_list(proto, list(range(args.limit - 1)))
- print('[OK]: limit - 1')
- test_list(proto, list(range(args.limit)))
- test_list(proto, list(range(args.limit)))
- print('[OK]: just limit')
- try:
- test_list(proto, list(range(args.limit + 1)))
- except:
- print('[OK]: limit + 1')
- else:
- print('[ERROR]: limit + 1')
- assert False
+ p = argparse.ArgumentParser()
+ add_common_args(p)
+ p.add_argument('--limit', type=int)
+ args = p.parse_args()
+ proto = init_protocol(args)
+ # TODO: test set and map
+ test_list(proto, list(range(args.limit - 1)))
+ test_list(proto, list(range(args.limit - 1)))
+ print('[OK]: limit - 1')
+ test_list(proto, list(range(args.limit)))
+ test_list(proto, list(range(args.limit)))
+ print('[OK]: just limit')
+ try:
+ test_list(proto, list(range(args.limit + 1)))
+ except:
+ print('[OK]: limit + 1')
+ else:
+ print('[ERROR]: limit + 1')
+ assert False
if __name__ == '__main__':
- sys.exit(main(sys.argv[1:]))
+ sys.exit(main(sys.argv[1:]))
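
The script above probes the configured container limit at three points: just under, exactly at, and just over the limit, and only the last is expected to fail. A generic standalone sketch of that boundary-probe pattern (probe_limit and fake_send are names invented for this example):

    def probe_limit(send, limit):
        # Just under and exactly at the limit must succeed ...
        send(limit - 1)
        send(limit)
        # ... while one element past the limit must be rejected.
        try:
            send(limit + 1)
        except Exception:
            return True   # the limit is enforced, as expected
        return False      # the over-limit payload was accepted

    def fake_send(n, limit=10):
        # Stand-in for a real round trip such as test_list(proto, range(n)).
        if n > limit:
            raise ValueError('container too large')

    print(probe_limit(fake_send, 10))  # -> True
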
diff --git a/test/features/local_thrift/__init__.py b/test/features/local_thrift/__init__.py
index 383ee5f40..0a0bb0b66 100644
--- a/test/features/local_thrift/__init__.py
+++ b/test/features/local_thrift/__init__.py
@@ -5,10 +5,10 @@ SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))
if sys.version_info[0] == 2:
- import glob
- libdir = glob.glob(os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib.*'))[0]
- sys.path.insert(0, libdir)
- thrift = __import__('thrift')
+ import glob
+ libdir = glob.glob(os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib.*'))[0]
+ sys.path.insert(0, libdir)
+ thrift = __import__('thrift')
else:
- sys.path.insert(0, os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib'))
- thrift = __import__('thrift')
+ sys.path.insert(0, os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib'))
+ thrift = __import__('thrift')
diff --git a/test/features/string_limit.py b/test/features/string_limit.py
index b4d48acdb..3c68b3ea3 100644
--- a/test/features/string_limit.py
+++ b/test/features/string_limit.py
@@ -10,52 +10,52 @@ from thrift.Thrift import TMessageType, TType
# TODO: generate from ThriftTest.thrift
def test_string(proto, value):
- method_name = 'testString'
- ttype = TType.STRING
- proto.writeMessageBegin(method_name, TMessageType.CALL, 3)
- proto.writeStructBegin(method_name + '_args')
- proto.writeFieldBegin('thing', ttype, 1)
- proto.writeString(value)
- proto.writeFieldEnd()
- proto.writeFieldStop()
- proto.writeStructEnd()
- proto.writeMessageEnd()
- proto.trans.flush()
-
- _, mtype, _ = proto.readMessageBegin()
- assert mtype == TMessageType.REPLY
- proto.readStructBegin()
- _, ftype, fid = proto.readFieldBegin()
- assert fid == 0
- assert ftype == ttype
- result = proto.readString()
- proto.readFieldEnd()
- _, ftype, _ = proto.readFieldBegin()
- assert ftype == TType.STOP
- proto.readStructEnd()
- proto.readMessageEnd()
- assert value == result
+ method_name = 'testString'
+ ttype = TType.STRING
+ proto.writeMessageBegin(method_name, TMessageType.CALL, 3)
+ proto.writeStructBegin(method_name + '_args')
+ proto.writeFieldBegin('thing', ttype, 1)
+ proto.writeString(value)
+ proto.writeFieldEnd()
+ proto.writeFieldStop()
+ proto.writeStructEnd()
+ proto.writeMessageEnd()
+ proto.trans.flush()
+
+ _, mtype, _ = proto.readMessageBegin()
+ assert mtype == TMessageType.REPLY
+ proto.readStructBegin()
+ _, ftype, fid = proto.readFieldBegin()
+ assert fid == 0
+ assert ftype == ttype
+ result = proto.readString()
+ proto.readFieldEnd()
+ _, ftype, _ = proto.readFieldBegin()
+ assert ftype == TType.STOP
+ proto.readStructEnd()
+ proto.readMessageEnd()
+ assert value == result
def main(argv):
- p = argparse.ArgumentParser()
- add_common_args(p)
- p.add_argument('--limit', type=int)
- args = p.parse_args()
- proto = init_protocol(args)
- test_string(proto, 'a' * (args.limit - 1))
- test_string(proto, 'a' * (args.limit - 1))
- print('[OK]: limit - 1')
- test_string(proto, 'a' * args.limit)
- test_string(proto, 'a' * args.limit)
- print('[OK]: just limit')
- try:
- test_string(proto, 'a' * (args.limit + 1))
- except:
- print('[OK]: limit + 1')
- else:
- print('[ERROR]: limit + 1')
- assert False
+ p = argparse.ArgumentParser()
+ add_common_args(p)
+ p.add_argument('--limit', type=int)
+ args = p.parse_args()
+ proto = init_protocol(args)
+ test_string(proto, 'a' * (args.limit - 1))
+ test_string(proto, 'a' * (args.limit - 1))
+ print('[OK]: limit - 1')
+ test_string(proto, 'a' * args.limit)
+ test_string(proto, 'a' * args.limit)
+ print('[OK]: just limit')
+ try:
+ test_string(proto, 'a' * (args.limit + 1))
+ except:
+ print('[OK]: limit + 1')
+ else:
+ print('[ERROR]: limit + 1')
+ assert False
if __name__ == '__main__':
- main(sys.argv[1:])
+ main(sys.argv[1:])
diff --git a/test/features/theader_binary.py b/test/features/theader_binary.py
index 62a26715d..02e010b8b 100644
--- a/test/features/theader_binary.py
+++ b/test/features/theader_binary.py
@@ -14,57 +14,57 @@ from thrift.protocol.TCompactProtocol import TCompactProtocol
def test_void(proto):
- proto.writeMessageBegin('testVoid', TMessageType.CALL, 3)
- proto.writeStructBegin('testVoid_args')
- proto.writeFieldStop()
- proto.writeStructEnd()
- proto.writeMessageEnd()
- proto.trans.flush()
+ proto.writeMessageBegin('testVoid', TMessageType.CALL, 3)
+ proto.writeStructBegin('testVoid_args')
+ proto.writeFieldStop()
+ proto.writeStructEnd()
+ proto.writeMessageEnd()
+ proto.trans.flush()
- _, mtype, _ = proto.readMessageBegin()
- assert mtype == TMessageType.REPLY
- proto.readStructBegin()
- _, ftype, _ = proto.readFieldBegin()
- assert ftype == TType.STOP
- proto.readStructEnd()
- proto.readMessageEnd()
+ _, mtype, _ = proto.readMessageBegin()
+ assert mtype == TMessageType.REPLY
+ proto.readStructBegin()
+ _, ftype, _ = proto.readFieldBegin()
+ assert ftype == TType.STOP
+ proto.readStructEnd()
+ proto.readMessageEnd()
# THeader stack should accept binary protocol with optionally framed transport
def main(argv):
- p = argparse.ArgumentParser()
- add_common_args(p)
- # Since THeaderTransport acts as framed transport when detected frame, we
- # cannot use --transport=framed as it would result in 2 layered frames.
- p.add_argument('--override-transport')
- p.add_argument('--override-protocol')
- args = p.parse_args()
- assert args.protocol == 'header'
- assert args.transport == 'buffered'
- assert not args.ssl
+ p = argparse.ArgumentParser()
+ add_common_args(p)
+    # Since THeaderTransport acts as a framed transport when it detects a frame,
+    # we cannot use --transport=framed, as that would result in two layered frames.
+ p.add_argument('--override-transport')
+ p.add_argument('--override-protocol')
+ args = p.parse_args()
+ assert args.protocol == 'header'
+ assert args.transport == 'buffered'
+ assert not args.ssl
- sock = TSocket(args.host, args.port, socket_family=socket.AF_INET)
- if not args.override_transport or args.override_transport == 'buffered':
- trans = TBufferedTransport(sock)
- elif args.override_transport == 'framed':
- print('TFRAMED')
- trans = TFramedTransport(sock)
- else:
- raise ValueError('invalid transport')
- trans.open()
+ sock = TSocket(args.host, args.port, socket_family=socket.AF_INET)
+ if not args.override_transport or args.override_transport == 'buffered':
+ trans = TBufferedTransport(sock)
+ elif args.override_transport == 'framed':
+ print('TFRAMED')
+ trans = TFramedTransport(sock)
+ else:
+ raise ValueError('invalid transport')
+ trans.open()
- if not args.override_protocol or args.override_protocol == 'binary':
- proto = TBinaryProtocol(trans)
- elif args.override_protocol == 'compact':
- proto = TCompactProtocol(trans)
- else:
- raise ValueError('invalid transport')
+ if not args.override_protocol or args.override_protocol == 'binary':
+ proto = TBinaryProtocol(trans)
+ elif args.override_protocol == 'compact':
+ proto = TCompactProtocol(trans)
+ else:
+        raise ValueError('invalid protocol')
- test_void(proto)
- test_void(proto)
+ test_void(proto)
+ test_void(proto)
- trans.close()
+ trans.close()
if __name__ == '__main__':
- sys.exit(main(sys.argv[1:]))
+ sys.exit(main(sys.argv[1:]))
diff --git a/test/features/util.py b/test/features/util.py
index e36413629..e4997d0b7 100644
--- a/test/features/util.py
+++ b/test/features/util.py
@@ -11,30 +11,30 @@ from thrift.protocol.TJSONProtocol import TJSONProtocol
def add_common_args(p):
- p.add_argument('--host', default='localhost')
- p.add_argument('--port', type=int, default=9090)
- p.add_argument('--protocol', default='binary')
- p.add_argument('--transport', default='buffered')
- p.add_argument('--ssl', action='store_true')
+ p.add_argument('--host', default='localhost')
+ p.add_argument('--port', type=int, default=9090)
+ p.add_argument('--protocol', default='binary')
+ p.add_argument('--transport', default='buffered')
+ p.add_argument('--ssl', action='store_true')
def parse_common_args(argv):
- p = argparse.ArgumentParser()
- add_common_args(p)
- return p.parse_args(argv)
+ p = argparse.ArgumentParser()
+ add_common_args(p)
+ return p.parse_args(argv)
def init_protocol(args):
- sock = TSocket(args.host, args.port, socket_family=socket.AF_INET)
- sock.setTimeout(500)
- trans = {
- 'buffered': TBufferedTransport,
- 'framed': TFramedTransport,
- 'http': THttpClient,
- }[args.transport](sock)
- trans.open()
- return {
- 'binary': TBinaryProtocol,
- 'compact': TCompactProtocol,
- 'json': TJSONProtocol,
- }[args.protocol](trans)
+ sock = TSocket(args.host, args.port, socket_family=socket.AF_INET)
+ sock.setTimeout(500)
+ trans = {
+ 'buffered': TBufferedTransport,
+ 'framed': TFramedTransport,
+ 'http': THttpClient,
+ }[args.transport](sock)
+ trans.open()
+ return {
+ 'binary': TBinaryProtocol,
+ 'compact': TCompactProtocol,
+ 'json': TJSONProtocol,
+ }[args.protocol](trans)
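
init_protocol above selects the transport and protocol classes with a dictionary lookup instead of an if/elif chain. A minimal standalone version of that dispatch pattern, with dummy classes standing in for the Thrift ones:

    class Buffered(object):
        def __init__(self, inner):
            self.inner = inner

    class Framed(object):
        def __init__(self, inner):
            self.inner = inner

    def wrap(kind, inner):
        # Unsupported kinds raise KeyError, just like the lookups in init_protocol.
        return {
            'buffered': Buffered,
            'framed': Framed,
        }[kind](inner)

    print(type(wrap('framed', object())).__name__)  # -> Framed
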
diff --git a/test/py.tornado/test_suite.py b/test/py.tornado/test_suite.py
index e0bf91356..b9ce78181 100755
--- a/test/py.tornado/test_suite.py
+++ b/test/py.tornado/test_suite.py
@@ -27,7 +27,7 @@ import time
import unittest
basepath = os.path.abspath(os.path.dirname(__file__))
-sys.path.insert(0, basepath+'/gen-py.tornado')
+sys.path.insert(0, basepath + '/gen-py.tornado')
sys.path.insert(0, glob.glob(os.path.join(basepath, '../../lib/py/build/lib*'))[0])
try:
diff --git a/test/py.twisted/test_suite.py b/test/py.twisted/test_suite.py
index 2c07baaf8..3a59bb1f1 100755
--- a/test/py.twisted/test_suite.py
+++ b/test/py.twisted/test_suite.py
@@ -19,7 +19,10 @@
# under the License.
#
-import sys, os, glob, time
+import sys
+import os
+import glob
+import time
basepath = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, os.path.join(basepath, 'gen-py.twisted'))
sys.path.insert(0, glob.glob(os.path.join(basepath, '../../lib/py/build/lib.*'))[0])
@@ -35,6 +38,7 @@ from twisted.internet.protocol import ClientCreator
from zope.interface import implements
+
class TestHandler:
implements(ThriftTest.Iface)
@@ -100,6 +104,7 @@ class TestHandler:
def testTypedef(self, thing):
return thing
+
class ThriftTestCase(unittest.TestCase):
@defer.inlineCallbacks
@@ -109,15 +114,15 @@ class ThriftTestCase(unittest.TestCase):
self.pfactory = TBinaryProtocol.TBinaryProtocolFactory()
self.server = reactor.listenTCP(0,
- TTwisted.ThriftServerFactory(self.processor,
- self.pfactory), interface="127.0.0.1")
+ TTwisted.ThriftServerFactory(self.processor,
+ self.pfactory), interface="127.0.0.1")
self.portNo = self.server.getHost().port
self.txclient = yield ClientCreator(reactor,
- TTwisted.ThriftClientProtocol,
- ThriftTest.Client,
- self.pfactory).connectTCP("127.0.0.1", self.portNo)
+ TTwisted.ThriftClientProtocol,
+ ThriftTest.Client,
+ self.pfactory).connectTCP("127.0.0.1", self.portNo)
self.client = self.txclient.client
@defer.inlineCallbacks
@@ -179,7 +184,7 @@ class ThriftTestCase(unittest.TestCase):
try:
yield self.client.testException("throw_undeclared")
self.fail("should have thrown exception")
- except Exception: # type is undefined
+ except Exception: # type is undefined
pass
@defer.inlineCallbacks
diff --git a/test/py/FastbinaryTest.py b/test/py/FastbinaryTest.py
index 9d258fdbf..a8718dce1 100755
--- a/test/py/FastbinaryTest.py
+++ b/test/py/FastbinaryTest.py
@@ -41,11 +41,11 @@ from DebugProtoTest.ttypes import Backwards, Bonk, Empty, HolyMoley, OneOfEach,
class TDevNullTransport(TTransport.TTransportBase):
- def __init__(self):
- pass
+ def __init__(self):
+ pass
- def isOpen(self):
- return True
+ def isOpen(self):
+ return True
ooe1 = OneOfEach()
ooe1.im_true = True
@@ -71,8 +71,8 @@ ooe2.zomg_unicode = u"\xd3\x80\xe2\x85\xae\xce\x9d\x20"\
u"\xc7\x83\xe2\x80\xbc"
if sys.version_info[0] == 2 and os.environ.get('THRIFT_TEST_PY_NO_UTF8STRINGS'):
- ooe1.zomg_unicode = ooe1.zomg_unicode.encode('utf8')
- ooe2.zomg_unicode = ooe2.zomg_unicode.encode('utf8')
+ ooe1.zomg_unicode = ooe1.zomg_unicode.encode('utf8')
+ ooe2.zomg_unicode = ooe2.zomg_unicode.encode('utf8')
hm = HolyMoley(**{"big": [], "contain": set(), "bonks": {}})
hm.big.append(ooe1)
@@ -86,13 +86,13 @@ hm.contain.add(())
hm.bonks["nothing"] = []
hm.bonks["something"] = [
- Bonk(**{"type": 1, "message": "Wait."}),
- Bonk(**{"type": 2, "message": "What?"}),
+ Bonk(**{"type": 1, "message": "Wait."}),
+ Bonk(**{"type": 2, "message": "What?"}),
]
hm.bonks["poe"] = [
- Bonk(**{"type": 3, "message": "quoth"}),
- Bonk(**{"type": 4, "message": "the raven"}),
- Bonk(**{"type": 5, "message": "nevermore"}),
+ Bonk(**{"type": 3, "message": "quoth"}),
+ Bonk(**{"type": 4, "message": "the raven"}),
+ Bonk(**{"type": 5, "message": "nevermore"}),
]
rs = RandomStuff()
@@ -112,110 +112,110 @@ my_zero = Srv.Janky_result(**{"success": 5})
def check_write(o):
- trans_fast = TTransport.TMemoryBuffer()
- trans_slow = TTransport.TMemoryBuffer()
- prot_fast = TBinaryProtocol.TBinaryProtocolAccelerated(trans_fast)
- prot_slow = TBinaryProtocol.TBinaryProtocol(trans_slow)
+ trans_fast = TTransport.TMemoryBuffer()
+ trans_slow = TTransport.TMemoryBuffer()
+ prot_fast = TBinaryProtocol.TBinaryProtocolAccelerated(trans_fast)
+ prot_slow = TBinaryProtocol.TBinaryProtocol(trans_slow)
- o.write(prot_fast)
- o.write(prot_slow)
- ORIG = trans_slow.getvalue()
- MINE = trans_fast.getvalue()
- if ORIG != MINE:
- print("mine: %s\norig: %s" % (repr(MINE), repr(ORIG)))
+ o.write(prot_fast)
+ o.write(prot_slow)
+ ORIG = trans_slow.getvalue()
+ MINE = trans_fast.getvalue()
+ if ORIG != MINE:
+ print("mine: %s\norig: %s" % (repr(MINE), repr(ORIG)))
def check_read(o):
- prot = TBinaryProtocol.TBinaryProtocol(TTransport.TMemoryBuffer())
- o.write(prot)
-
- slow_version_binary = prot.trans.getvalue()
-
- prot = TBinaryProtocol.TBinaryProtocolAccelerated(
- TTransport.TMemoryBuffer(slow_version_binary))
- c = o.__class__()
- c.read(prot)
- if c != o:
- print("copy: ")
- pprint(eval(repr(c)))
- print("orig: ")
- pprint(eval(repr(o)))
-
- prot = TBinaryProtocol.TBinaryProtocolAccelerated(
- TTransport.TBufferedTransport(
- TTransport.TMemoryBuffer(slow_version_binary)))
- c = o.__class__()
- c.read(prot)
- if c != o:
- print("copy: ")
- pprint(eval(repr(c)))
- print("orig: ")
- pprint(eval(repr(o)))
+ prot = TBinaryProtocol.TBinaryProtocol(TTransport.TMemoryBuffer())
+ o.write(prot)
+
+ slow_version_binary = prot.trans.getvalue()
+
+ prot = TBinaryProtocol.TBinaryProtocolAccelerated(
+ TTransport.TMemoryBuffer(slow_version_binary))
+ c = o.__class__()
+ c.read(prot)
+ if c != o:
+ print("copy: ")
+ pprint(eval(repr(c)))
+ print("orig: ")
+ pprint(eval(repr(o)))
+
+ prot = TBinaryProtocol.TBinaryProtocolAccelerated(
+ TTransport.TBufferedTransport(
+ TTransport.TMemoryBuffer(slow_version_binary)))
+ c = o.__class__()
+ c.read(prot)
+ if c != o:
+ print("copy: ")
+ pprint(eval(repr(c)))
+ print("orig: ")
+ pprint(eval(repr(o)))
def do_test():
- check_write(hm)
- check_read(HolyMoley())
- no_set = deepcopy(hm)
- no_set.contain = set()
- check_read(no_set)
- check_write(rs)
- check_read(rs)
- check_write(rshuge)
- check_read(rshuge)
- check_write(my_zero)
- check_read(my_zero)
- check_read(Backwards(**{"first_tag2": 4, "second_tag1": 2}))
-
- # One case where the serialized form changes, but only superficially.
- o = Backwards(**{"first_tag2": 4, "second_tag1": 2})
- trans_fast = TTransport.TMemoryBuffer()
- trans_slow = TTransport.TMemoryBuffer()
- prot_fast = TBinaryProtocol.TBinaryProtocolAccelerated(trans_fast)
- prot_slow = TBinaryProtocol.TBinaryProtocol(trans_slow)
-
- o.write(prot_fast)
- o.write(prot_slow)
- ORIG = trans_slow.getvalue()
- MINE = trans_fast.getvalue()
- assert id(ORIG) != id(MINE)
-
- prot = TBinaryProtocol.TBinaryProtocolAccelerated(TTransport.TMemoryBuffer())
- o.write(prot)
- prot = TBinaryProtocol.TBinaryProtocol(
- TTransport.TMemoryBuffer(prot.trans.getvalue()))
- c = o.__class__()
- c.read(prot)
- if c != o:
- print("copy: ")
- pprint(eval(repr(c)))
- print("orig: ")
- pprint(eval(repr(o)))
+ check_write(hm)
+ check_read(HolyMoley())
+ no_set = deepcopy(hm)
+ no_set.contain = set()
+ check_read(no_set)
+ check_write(rs)
+ check_read(rs)
+ check_write(rshuge)
+ check_read(rshuge)
+ check_write(my_zero)
+ check_read(my_zero)
+ check_read(Backwards(**{"first_tag2": 4, "second_tag1": 2}))
+
+ # One case where the serialized form changes, but only superficially.
+ o = Backwards(**{"first_tag2": 4, "second_tag1": 2})
+ trans_fast = TTransport.TMemoryBuffer()
+ trans_slow = TTransport.TMemoryBuffer()
+ prot_fast = TBinaryProtocol.TBinaryProtocolAccelerated(trans_fast)
+ prot_slow = TBinaryProtocol.TBinaryProtocol(trans_slow)
+
+ o.write(prot_fast)
+ o.write(prot_slow)
+ ORIG = trans_slow.getvalue()
+ MINE = trans_fast.getvalue()
+ assert id(ORIG) != id(MINE)
+
+ prot = TBinaryProtocol.TBinaryProtocolAccelerated(TTransport.TMemoryBuffer())
+ o.write(prot)
+ prot = TBinaryProtocol.TBinaryProtocol(
+ TTransport.TMemoryBuffer(prot.trans.getvalue()))
+ c = o.__class__()
+ c.read(prot)
+ if c != o:
+ print("copy: ")
+ pprint(eval(repr(c)))
+ print("orig: ")
+ pprint(eval(repr(o)))
def do_benchmark(iters=5000):
- setup = """
+ setup = """
from __main__ import hm, rs, TDevNullTransport
from thrift.protocol import TBinaryProtocol
trans = TDevNullTransport()
prot = TBinaryProtocol.TBinaryProtocol%s(trans)
"""
- setup_fast = setup % "Accelerated"
- setup_slow = setup % ""
+ setup_fast = setup % "Accelerated"
+ setup_slow = setup % ""
- print("Starting Benchmarks")
+ print("Starting Benchmarks")
- print("HolyMoley Standard = %f" %
- timeit.Timer('hm.write(prot)', setup_slow).timeit(number=iters))
- print("HolyMoley Acceler. = %f" %
- timeit.Timer('hm.write(prot)', setup_fast).timeit(number=iters))
+ print("HolyMoley Standard = %f" %
+ timeit.Timer('hm.write(prot)', setup_slow).timeit(number=iters))
+ print("HolyMoley Acceler. = %f" %
+ timeit.Timer('hm.write(prot)', setup_fast).timeit(number=iters))
- print("FastStruct Standard = %f" %
- timeit.Timer('rs.write(prot)', setup_slow).timeit(number=iters))
- print("FastStruct Acceler. = %f" %
- timeit.Timer('rs.write(prot)', setup_fast).timeit(number=iters))
+ print("FastStruct Standard = %f" %
+ timeit.Timer('rs.write(prot)', setup_slow).timeit(number=iters))
+ print("FastStruct Acceler. = %f" %
+ timeit.Timer('rs.write(prot)', setup_fast).timeit(number=iters))
if __name__ == '__main__':
- do_test()
- do_benchmark()
+ do_test()
+ do_benchmark()
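
do_benchmark above compares the accelerated and pure-Python binary protocols with timeit. The same measurement pattern, reduced to a self-contained toy comparison of two integer encoders (Python 3; the encoder names are invented for the example and the absolute figures are not meaningful, only the relative difference is):

    import timeit

    def encode_slow(values):
        out = b''
        for v in values:
            out += v.to_bytes(4, 'big', signed=True)
        return out

    def encode_fast(values):
        return b''.join(v.to_bytes(4, 'big', signed=True) for v in values)

    data = list(range(1000))
    for name, stmt in (('slow', 'encode_slow(data)'), ('fast', 'encode_fast(data)')):
        # number=1000 keeps the run short while still separating the two variants.
        print('%-4s = %f' % (name, timeit.timeit(stmt, globals=globals(), number=1000)))
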
diff --git a/test/py/RunClientServer.py b/test/py/RunClientServer.py
index d5ebd6a6d..98ead431d 100755
--- a/test/py/RunClientServer.py
+++ b/test/py/RunClientServer.py
@@ -37,13 +37,13 @@ DEFAULT_LIBDIR_GLOB = os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib.*')
DEFAULT_LIBDIR_PY3 = os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib')
SCRIPTS = [
- 'FastbinaryTest.py',
- 'TestFrozen.py',
- 'TSimpleJSONProtocolTest.py',
- 'SerializationTest.py',
- 'TestEof.py',
- 'TestSyntax.py',
- 'TestSocket.py',
+ 'FastbinaryTest.py',
+ 'TestFrozen.py',
+ 'TSimpleJSONProtocolTest.py',
+ 'SerializationTest.py',
+ 'TestEof.py',
+ 'TestSyntax.py',
+ 'TestSocket.py',
]
FRAMED = ["TNonblockingServer"]
SKIP_ZLIB = ['TNonblockingServer', 'THttpServer']
@@ -51,20 +51,20 @@ SKIP_SSL = ['TNonblockingServer', 'THttpServer']
EXTRA_DELAY = dict(TProcessPoolServer=5.5)
PROTOS = [
- 'accel',
- 'binary',
- 'compact',
- 'json',
+ 'accel',
+ 'binary',
+ 'compact',
+ 'json',
]
SERVERS = [
- "TSimpleServer",
- "TThreadedServer",
- "TThreadPoolServer",
- "TProcessPoolServer",
- "TForkingServer",
- "TNonblockingServer",
- "THttpServer",
+ "TSimpleServer",
+ "TThreadedServer",
+ "TThreadPoolServer",
+ "TProcessPoolServer",
+ "TForkingServer",
+ "TNonblockingServer",
+ "THttpServer",
]
@@ -73,246 +73,246 @@ def relfile(fname):
def setup_pypath(libdir, gendir):
- dirs = [libdir, gendir]
- env = copy.deepcopy(os.environ)
- pypath = env.get('PYTHONPATH', None)
- if pypath:
- dirs.append(pypath)
- env['PYTHONPATH'] = ':'.join(dirs)
- if gendir.endswith('gen-py-no_utf8strings'):
- env['THRIFT_TEST_PY_NO_UTF8STRINGS'] = '1'
- return env
+ dirs = [libdir, gendir]
+ env = copy.deepcopy(os.environ)
+ pypath = env.get('PYTHONPATH', None)
+ if pypath:
+ dirs.append(pypath)
+ env['PYTHONPATH'] = ':'.join(dirs)
+ if gendir.endswith('gen-py-no_utf8strings'):
+ env['THRIFT_TEST_PY_NO_UTF8STRINGS'] = '1'
+ return env
def runScriptTest(libdir, genbase, genpydir, script):
- env = setup_pypath(libdir, os.path.join(genbase, genpydir))
- script_args = [sys.executable, relfile(script)]
- print('\nTesting script: %s\n----' % (' '.join(script_args)))
- ret = subprocess.call(script_args, env=env)
- if ret != 0:
- print('*** FAILED ***', file=sys.stderr)
- print('LIBDIR: %s' % libdir, file=sys.stderr)
- print('PY_GEN: %s' % genpydir, file=sys.stderr)
- print('SCRIPT: %s' % script, file=sys.stderr)
- raise Exception("Script subprocess failed, retcode=%d, args: %s" % (ret, ' '.join(script_args)))
+ env = setup_pypath(libdir, os.path.join(genbase, genpydir))
+ script_args = [sys.executable, relfile(script)]
+ print('\nTesting script: %s\n----' % (' '.join(script_args)))
+ ret = subprocess.call(script_args, env=env)
+ if ret != 0:
+ print('*** FAILED ***', file=sys.stderr)
+ print('LIBDIR: %s' % libdir, file=sys.stderr)
+ print('PY_GEN: %s' % genpydir, file=sys.stderr)
+ print('SCRIPT: %s' % script, file=sys.stderr)
+ raise Exception("Script subprocess failed, retcode=%d, args: %s" % (ret, ' '.join(script_args)))
def runServiceTest(libdir, genbase, genpydir, server_class, proto, port, use_zlib, use_ssl, verbose):
- env = setup_pypath(libdir, os.path.join(genbase, genpydir))
- # Build command line arguments
- server_args = [sys.executable, relfile('TestServer.py')]
- cli_args = [sys.executable, relfile('TestClient.py')]
- for which in (server_args, cli_args):
- which.append('--protocol=%s' % proto) # accel, binary, compact or json
- which.append('--port=%d' % port) # default to 9090
- if use_zlib:
- which.append('--zlib')
- if use_ssl:
- which.append('--ssl')
- if verbose == 0:
- which.append('-q')
- if verbose == 2:
- which.append('-v')
- # server-specific option to select server class
- server_args.append(server_class)
- # client-specific cmdline options
- if server_class in FRAMED:
- cli_args.append('--transport=framed')
- else:
- cli_args.append('--transport=buffered')
- if server_class == 'THttpServer':
- cli_args.append('--http=/')
- if verbose > 0:
- print('Testing server %s: %s' % (server_class, ' '.join(server_args)))
- serverproc = subprocess.Popen(server_args, env=env)
-
- def ensureServerAlive():
- if serverproc.poll() is not None:
- print(('FAIL: Server process (%s) failed with retcode %d')
- % (' '.join(server_args), serverproc.returncode))
- raise Exception('Server subprocess %s died, args: %s'
- % (server_class, ' '.join(server_args)))
-
- # Wait for the server to start accepting connections on the given port.
- sock = socket.socket()
- sleep_time = 0.1 # Seconds
- max_attempts = 100
- try:
- attempt = 0
- while sock.connect_ex(('127.0.0.1', port)) != 0:
- attempt += 1
- if attempt >= max_attempts:
- raise Exception("TestServer not ready on port %d after %.2f seconds"
- % (port, sleep_time * attempt))
- ensureServerAlive()
- time.sleep(sleep_time)
- finally:
- sock.close()
-
- try:
+ env = setup_pypath(libdir, os.path.join(genbase, genpydir))
+ # Build command line arguments
+ server_args = [sys.executable, relfile('TestServer.py')]
+ cli_args = [sys.executable, relfile('TestClient.py')]
+ for which in (server_args, cli_args):
+ which.append('--protocol=%s' % proto) # accel, binary, compact or json
+        which.append('--port=%d' % port)  # defaults to 9090
+ if use_zlib:
+ which.append('--zlib')
+ if use_ssl:
+ which.append('--ssl')
+ if verbose == 0:
+ which.append('-q')
+ if verbose == 2:
+ which.append('-v')
+ # server-specific option to select server class
+ server_args.append(server_class)
+ # client-specific cmdline options
+ if server_class in FRAMED:
+ cli_args.append('--transport=framed')
+ else:
+ cli_args.append('--transport=buffered')
+ if server_class == 'THttpServer':
+ cli_args.append('--http=/')
if verbose > 0:
- print('Testing client: %s' % (' '.join(cli_args)))
- ret = subprocess.call(cli_args, env=env)
- if ret != 0:
- print('*** FAILED ***', file=sys.stderr)
- print('LIBDIR: %s' % libdir, file=sys.stderr)
- print('PY_GEN: %s' % genpydir, file=sys.stderr)
- raise Exception("Client subprocess failed, retcode=%d, args: %s" % (ret, ' '.join(cli_args)))
- finally:
- # check that server didn't die
- ensureServerAlive()
- extra_sleep = EXTRA_DELAY.get(server_class, 0)
- if extra_sleep > 0 and verbose > 0:
- print('Giving %s (proto=%s,zlib=%s,ssl=%s) an extra %d seconds for child'
- 'processes to terminate via alarm'
- % (server_class, proto, use_zlib, use_ssl, extra_sleep))
- time.sleep(extra_sleep)
- os.kill(serverproc.pid, signal.SIGKILL)
- serverproc.wait()
+ print('Testing server %s: %s' % (server_class, ' '.join(server_args)))
+ serverproc = subprocess.Popen(server_args, env=env)
+
+ def ensureServerAlive():
+ if serverproc.poll() is not None:
+ print(('FAIL: Server process (%s) failed with retcode %d')
+ % (' '.join(server_args), serverproc.returncode))
+ raise Exception('Server subprocess %s died, args: %s'
+ % (server_class, ' '.join(server_args)))
+
+ # Wait for the server to start accepting connections on the given port.
+ sock = socket.socket()
+ sleep_time = 0.1 # Seconds
+ max_attempts = 100
+ try:
+ attempt = 0
+ while sock.connect_ex(('127.0.0.1', port)) != 0:
+ attempt += 1
+ if attempt >= max_attempts:
+ raise Exception("TestServer not ready on port %d after %.2f seconds"
+ % (port, sleep_time * attempt))
+ ensureServerAlive()
+ time.sleep(sleep_time)
+ finally:
+ sock.close()
+
+ try:
+ if verbose > 0:
+ print('Testing client: %s' % (' '.join(cli_args)))
+ ret = subprocess.call(cli_args, env=env)
+ if ret != 0:
+ print('*** FAILED ***', file=sys.stderr)
+ print('LIBDIR: %s' % libdir, file=sys.stderr)
+ print('PY_GEN: %s' % genpydir, file=sys.stderr)
+ raise Exception("Client subprocess failed, retcode=%d, args: %s" % (ret, ' '.join(cli_args)))
+ finally:
+ # check that server didn't die
+ ensureServerAlive()
+ extra_sleep = EXTRA_DELAY.get(server_class, 0)
+ if extra_sleep > 0 and verbose > 0:
+            print('Giving %s (proto=%s,zlib=%s,ssl=%s) an extra %d seconds for child '
+ 'processes to terminate via alarm'
+ % (server_class, proto, use_zlib, use_ssl, extra_sleep))
+ time.sleep(extra_sleep)
+ os.kill(serverproc.pid, signal.SIGKILL)
+ serverproc.wait()
class TestCases(object):
- def __init__(self, genbase, libdir, port, gendirs, servers, verbose):
- self.genbase = genbase
- self.libdir = libdir
- self.port = port
- self.verbose = verbose
- self.gendirs = gendirs
- self.servers = servers
-
- def default_conf(self):
- return {
- 'gendir': self.gendirs[0],
- 'server': self.servers[0],
- 'proto': PROTOS[0],
- 'zlib': False,
- 'ssl': False,
- }
-
- def run(self, conf, test_count):
- with_zlib = conf['zlib']
- with_ssl = conf['ssl']
- try_server = conf['server']
- try_proto = conf['proto']
- genpydir = conf['gendir']
- # skip any servers that don't work with the Zlib transport
- if with_zlib and try_server in SKIP_ZLIB:
- return False
- # skip any servers that don't work with SSL
- if with_ssl and try_server in SKIP_SSL:
- return False
- if self.verbose > 0:
- print('\nTest run #%d: (includes %s) Server=%s, Proto=%s, zlib=%s, SSL=%s'
- % (test_count, genpydir, try_server, try_proto, with_zlib, with_ssl))
- runServiceTest(self.libdir, self.genbase, genpydir, try_server, try_proto, self.port, with_zlib, with_ssl, self.verbose)
- if self.verbose > 0:
- print('OK: Finished (includes %s) %s / %s proto / zlib=%s / SSL=%s. %d combinations tested.'
- % (genpydir, try_server, try_proto, with_zlib, with_ssl, test_count))
- return True
-
- def test_feature(self, name, values):
- test_count = 0
- conf = self.default_conf()
- for try_server in values:
- conf[name] = try_server
- if self.run(conf, test_count):
- test_count += 1
- return test_count
-
- def run_all_tests(self):
- test_count = 0
- for try_server in self.servers:
- for genpydir in self.gendirs:
- for try_proto in PROTOS:
- for with_zlib in (False, True):
- # skip any servers that don't work with the Zlib transport
- if with_zlib and try_server in SKIP_ZLIB:
- continue
- for with_ssl in (False, True):
- # skip any servers that don't work with SSL
- if with_ssl and try_server in SKIP_SSL:
- continue
- test_count += 1
- if self.verbose > 0:
- print('\nTest run #%d: (includes %s) Server=%s, Proto=%s, zlib=%s, SSL=%s'
- % (test_count, genpydir, try_server, try_proto, with_zlib, with_ssl))
- runServiceTest(self.libdir, self.genbase, genpydir, try_server, try_proto, self.port, with_zlib, with_ssl)
- if self.verbose > 0:
- print('OK: Finished (includes %s) %s / %s proto / zlib=%s / SSL=%s. %d combinations tested.'
- % (genpydir, try_server, try_proto, with_zlib, with_ssl, test_count))
- return test_count
+ def __init__(self, genbase, libdir, port, gendirs, servers, verbose):
+ self.genbase = genbase
+ self.libdir = libdir
+ self.port = port
+ self.verbose = verbose
+ self.gendirs = gendirs
+ self.servers = servers
+
+ def default_conf(self):
+ return {
+ 'gendir': self.gendirs[0],
+ 'server': self.servers[0],
+ 'proto': PROTOS[0],
+ 'zlib': False,
+ 'ssl': False,
+ }
+
+ def run(self, conf, test_count):
+ with_zlib = conf['zlib']
+ with_ssl = conf['ssl']
+ try_server = conf['server']
+ try_proto = conf['proto']
+ genpydir = conf['gendir']
+ # skip any servers that don't work with the Zlib transport
+ if with_zlib and try_server in SKIP_ZLIB:
+ return False
+ # skip any servers that don't work with SSL
+ if with_ssl and try_server in SKIP_SSL:
+ return False
+ if self.verbose > 0:
+ print('\nTest run #%d: (includes %s) Server=%s, Proto=%s, zlib=%s, SSL=%s'
+ % (test_count, genpydir, try_server, try_proto, with_zlib, with_ssl))
+ runServiceTest(self.libdir, self.genbase, genpydir, try_server, try_proto, self.port, with_zlib, with_ssl, self.verbose)
+ if self.verbose > 0:
+ print('OK: Finished (includes %s) %s / %s proto / zlib=%s / SSL=%s. %d combinations tested.'
+ % (genpydir, try_server, try_proto, with_zlib, with_ssl, test_count))
+ return True
+
+ def test_feature(self, name, values):
+ test_count = 0
+ conf = self.default_conf()
+ for try_server in values:
+ conf[name] = try_server
+ if self.run(conf, test_count):
+ test_count += 1
+ return test_count
+
+ def run_all_tests(self):
+ test_count = 0
+ for try_server in self.servers:
+ for genpydir in self.gendirs:
+ for try_proto in PROTOS:
+ for with_zlib in (False, True):
+ # skip any servers that don't work with the Zlib transport
+ if with_zlib and try_server in SKIP_ZLIB:
+ continue
+ for with_ssl in (False, True):
+ # skip any servers that don't work with SSL
+ if with_ssl and try_server in SKIP_SSL:
+ continue
+ test_count += 1
+ if self.verbose > 0:
+ print('\nTest run #%d: (includes %s) Server=%s, Proto=%s, zlib=%s, SSL=%s'
+ % (test_count, genpydir, try_server, try_proto, with_zlib, with_ssl))
+                            runServiceTest(self.libdir, self.genbase, genpydir, try_server, try_proto, self.port, with_zlib, with_ssl, self.verbose)
+ if self.verbose > 0:
+ print('OK: Finished (includes %s) %s / %s proto / zlib=%s / SSL=%s. %d combinations tested.'
+ % (genpydir, try_server, try_proto, with_zlib, with_ssl, test_count))
+ return test_count
def default_libdir():
- if sys.version_info[0] == 2:
- return glob.glob(DEFAULT_LIBDIR_GLOB)[0]
- else:
- return DEFAULT_LIBDIR_PY3
+ if sys.version_info[0] == 2:
+ return glob.glob(DEFAULT_LIBDIR_GLOB)[0]
+ else:
+ return DEFAULT_LIBDIR_PY3
def main():
- parser = OptionParser()
- parser.add_option('--all', action="store_true", dest='all')
- parser.add_option('--genpydirs', type='string', dest='genpydirs',
- default='default,slots,oldstyle,no_utf8strings,dynamic,dynamicslots',
- help='directory extensions for generated code, used as suffixes for \"gen-py-*\" added sys.path for individual tests')
- parser.add_option("--port", type="int", dest="port", default=9090,
- help="port number for server to listen on")
- parser.add_option('-v', '--verbose', action="store_const",
- dest="verbose", const=2,
- help="verbose output")
- parser.add_option('-q', '--quiet', action="store_const",
- dest="verbose", const=0,
- help="minimal output")
- parser.add_option('-L', '--libdir', dest="libdir", default=default_libdir(),
- help="directory path that contains Thrift Python library")
- parser.add_option('--gen-base', dest="gen_base", default=SCRIPT_DIR,
- help="directory path that contains Thrift Python library")
- parser.set_defaults(verbose=1)
- options, args = parser.parse_args()
-
- generated_dirs = []
- for gp_dir in options.genpydirs.split(','):
- generated_dirs.append('gen-py-%s' % (gp_dir))
-
- # commandline permits a single class name to be specified to override SERVERS=[...]
- servers = SERVERS
- if len(args) == 1:
- if args[0] in SERVERS:
- servers = args
+ parser = OptionParser()
+ parser.add_option('--all', action="store_true", dest='all')
+ parser.add_option('--genpydirs', type='string', dest='genpydirs',
+ default='default,slots,oldstyle,no_utf8strings,dynamic,dynamicslots',
+                      help='directory extensions for generated code, used as suffixes for "gen-py-*" directories added to sys.path for individual tests')
+ parser.add_option("--port", type="int", dest="port", default=9090,
+ help="port number for server to listen on")
+ parser.add_option('-v', '--verbose', action="store_const",
+ dest="verbose", const=2,
+ help="verbose output")
+ parser.add_option('-q', '--quiet', action="store_const",
+ dest="verbose", const=0,
+ help="minimal output")
+ parser.add_option('-L', '--libdir', dest="libdir", default=default_libdir(),
+ help="directory path that contains Thrift Python library")
+ parser.add_option('--gen-base', dest="gen_base", default=SCRIPT_DIR,
+                      help="base directory path that contains the generated code directories")
+ parser.set_defaults(verbose=1)
+ options, args = parser.parse_args()
+
+ generated_dirs = []
+ for gp_dir in options.genpydirs.split(','):
+ generated_dirs.append('gen-py-%s' % (gp_dir))
+
+ # commandline permits a single class name to be specified to override SERVERS=[...]
+ servers = SERVERS
+ if len(args) == 1:
+ if args[0] in SERVERS:
+ servers = args
+ else:
+ print('Unavailable server type "%s", please choose one of: %s' % (args[0], servers))
+ sys.exit(0)
+
+ tests = TestCases(options.gen_base, options.libdir, options.port, generated_dirs, servers, options.verbose)
+
+ # run tests without a client/server first
+ print('----------------')
+ print(' Executing individual test scripts with various generated code directories')
+ print(' Directories to be tested: ' + ', '.join(generated_dirs))
+ print(' Scripts to be tested: ' + ', '.join(SCRIPTS))
+ print('----------------')
+ for genpydir in generated_dirs:
+ for script in SCRIPTS:
+ runScriptTest(options.libdir, options.gen_base, genpydir, script)
+
+ print('----------------')
+ print(' Executing Client/Server tests with various generated code directories')
+ print(' Servers to be tested: ' + ', '.join(servers))
+ print(' Directories to be tested: ' + ', '.join(generated_dirs))
+ print(' Protocols to be tested: ' + ', '.join(PROTOS))
+ print(' Options to be tested: ZLIB(yes/no), SSL(yes/no)')
+ print('----------------')
+
+ if options.all:
+ tests.run_all_tests()
else:
- print('Unavailable server type "%s", please choose one of: %s' % (args[0], servers))
- sys.exit(0)
-
- tests = TestCases(options.gen_base, options.libdir, options.port, generated_dirs, servers, options.verbose)
-
- # run tests without a client/server first
- print('----------------')
- print(' Executing individual test scripts with various generated code directories')
- print(' Directories to be tested: ' + ', '.join(generated_dirs))
- print(' Scripts to be tested: ' + ', '.join(SCRIPTS))
- print('----------------')
- for genpydir in generated_dirs:
- for script in SCRIPTS:
- runScriptTest(options.libdir, options.gen_base, genpydir, script)
-
- print('----------------')
- print(' Executing Client/Server tests with various generated code directories')
- print(' Servers to be tested: ' + ', '.join(servers))
- print(' Directories to be tested: ' + ', '.join(generated_dirs))
- print(' Protocols to be tested: ' + ', '.join(PROTOS))
- print(' Options to be tested: ZLIB(yes/no), SSL(yes/no)')
- print('----------------')
-
- if options.all:
- tests.run_all_tests()
- else:
- tests.test_feature('gendir', generated_dirs)
- tests.test_feature('server', servers)
- tests.test_feature('proto', PROTOS)
- tests.test_feature('zlib', [False, True])
- tests.test_feature('ssl', [False, True])
+ tests.test_feature('gendir', generated_dirs)
+ tests.test_feature('server', servers)
+ tests.test_feature('proto', PROTOS)
+ tests.test_feature('zlib', [False, True])
+ tests.test_feature('ssl', [False, True])
if __name__ == '__main__':
- sys.exit(main())
+ sys.exit(main())
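
The runner above isolates each child process by handing it an adjusted PYTHONPATH via the env argument of subprocess rather than mutating its own environment. A minimal standalone sketch of that technique (the directory names are placeholders; unlike setup_pypath, it joins with os.pathsep instead of a hard-coded ':'):

    import os
    import subprocess
    import sys

    def run_with_pythonpath(script_args, extra_dirs):
        # Prepend extra_dirs to the child's PYTHONPATH without touching the
        # parent's os.environ.
        env = dict(os.environ)
        dirs = list(extra_dirs)
        if env.get('PYTHONPATH'):
            dirs.append(env['PYTHONPATH'])
        env['PYTHONPATH'] = os.pathsep.join(dirs)
        return subprocess.call(script_args, env=env)

    # The child simply echoes the PYTHONPATH it received.
    run_with_pythonpath(
        [sys.executable, '-c', 'import os; print(os.environ["PYTHONPATH"])'],
        ['/tmp/libdir', '/tmp/gen-py'])
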
diff --git a/test/py/SerializationTest.py b/test/py/SerializationTest.py
index d4755cf2a..65a149599 100755
--- a/test/py/SerializationTest.py
+++ b/test/py/SerializationTest.py
@@ -30,341 +30,342 @@ import unittest
class AbstractTest(unittest.TestCase):
- def setUp(self):
- self.v1obj = VersioningTestV1(
- begin_in_both=12345,
- old_string='aaa',
- end_in_both=54321,
- )
-
- self.v2obj = VersioningTestV2(
- begin_in_both=12345,
- newint=1,
- newbyte=2,
- newshort=3,
- newlong=4,
- newdouble=5.0,
- newstruct=Bonk(message="Hello!", type=123),
- newlist=[7,8,9],
- newset=set([42,1,8]),
- newmap={1:2,2:3},
- newstring="Hola!",
- end_in_both=54321,
- )
-
- self.bools = Bools(im_true=True, im_false=False)
- self.bools_flipped = Bools(im_true=False, im_false=True)
-
- self.large_deltas = LargeDeltas (
- b1=self.bools,
- b10=self.bools_flipped,
- b100=self.bools,
- check_true=True,
- b1000=self.bools_flipped,
- check_false=False,
- vertwo2000=VersioningTestV2(newstruct=Bonk(message='World!', type=314)),
- a_set2500=set(['lazy', 'brown', 'cow']),
- vertwo3000=VersioningTestV2(newset=set([2, 3, 5, 7, 11])),
- big_numbers=[2**8, 2**16, 2**31-1, -(2**31-1)]
- )
-
- self.compact_struct = CompactProtoTestStruct(
- a_byte = 127,
- a_i16=32000,
- a_i32=1000000000,
- a_i64=0xffffffffff,
- a_double=5.6789,
- a_string="my string",
- true_field=True,
- false_field=False,
- empty_struct_field=Empty(),
- byte_list=[-127, -1, 0, 1, 127],
- i16_list=[-1, 0, 1, 0x7fff],
- i32_list= [-1, 0, 0xff, 0xffff, 0xffffff, 0x7fffffff],
- i64_list=[-1, 0, 0xff, 0xffff, 0xffffff, 0xffffffff, 0xffffffffff, 0xffffffffffff, 0xffffffffffffff, 0x7fffffffffffffff],
- double_list=[0.1, 0.2, 0.3],
- string_list=["first", "second", "third"],
- boolean_list=[True, True, True, False, False, False],
- struct_list=[Empty(), Empty()],
- byte_set=set([-127, -1, 0, 1, 127]),
- i16_set=set([-1, 0, 1, 0x7fff]),
- i32_set=set([1, 2, 3]),
- i64_set=set([-1, 0, 0xff, 0xffff, 0xffffff, 0xffffffff, 0xffffffffff, 0xffffffffffff, 0xffffffffffffff, 0x7fffffffffffffff]),
- double_set=set([0.1, 0.2, 0.3]),
- string_set=set(["first", "second", "third"]),
- boolean_set=set([True, False]),
- #struct_set=set([Empty()]), # unhashable instance
- byte_byte_map={1 : 2},
- i16_byte_map={1 : 1, -1 : 1, 0x7fff : 1},
- i32_byte_map={1 : 1, -1 : 1, 0x7fffffff : 1},
- i64_byte_map={0 : 1, 1 : 1, -1 : 1, 0x7fffffffffffffff : 1},
- double_byte_map={-1.1 : 1, 1.1 : 1},
- string_byte_map={"first" : 1, "second" : 2, "third" : 3, "" : 0},
- boolean_byte_map={True : 1, False: 0},
- byte_i16_map={1 : 1, 2 : -1, 3 : 0x7fff},
- byte_i32_map={1 : 1, 2 : -1, 3 : 0x7fffffff},
- byte_i64_map={1 : 1, 2 : -1, 3 : 0x7fffffffffffffff},
- byte_double_map={1 : 0.1, 2 : -0.1, 3 : 1000000.1},
- byte_string_map={1 : "", 2 : "blah", 3 : "loooooooooooooong string"},
- byte_boolean_map={1 : True, 2 : False},
- #list_byte_map # unhashable
- #set_byte_map={set([1, 2, 3]) : 1, set([0, 1]) : 2, set([]) : 0}, # unhashable
- #map_byte_map # unhashable
- byte_map_map={0 : {}, 1 : {1 : 1}, 2 : {1 : 1, 2 : 2}},
- byte_set_map={0 : set([]), 1 : set([1]), 2 : set([1, 2])},
- byte_list_map={0 : [], 1 : [1], 2 : [1, 2]},
- )
-
- self.nested_lists_i32x2 = NestedListsI32x2(
- [
- [ 1, 1, 2 ],
- [ 2, 7, 9 ],
- [ 3, 5, 8 ]
- ]
- )
-
- self.nested_lists_i32x3 = NestedListsI32x3(
- [
- [
- [ 2, 7, 9 ],
- [ 3, 5, 8 ]
- ],
- [
- [ 1, 1, 2 ],
- [ 1, 4, 9 ]
- ]
- ]
- )
-
- self.nested_mixedx2 = NestedMixedx2( int_set_list=[
- set([1,2,3]),
- set([1,4,9]),
- set([1,2,3,5,8,13,21]),
- set([-1, 0, 1])
- ],
- # note, the sets below are sets of chars, since the strings are iterated
- map_int_strset={ 10:set('abc'), 20:set('def'), 30:set('GHI') },
- map_int_strset_list=[
- { 10:set('abc'), 20:set('def'), 30:set('GHI') },
- { 100:set('lmn'), 200:set('opq'), 300:set('RST') },
- { 1000:set('uvw'), 2000:set('wxy'), 3000:set('XYZ') }
- ]
- )
-
- self.nested_lists_bonk = NestedListsBonk(
- [
- [
- [
- Bonk(message='inner A first', type=1),
- Bonk(message='inner A second', type=1)
- ],
- [
- Bonk(message='inner B first', type=2),
- Bonk(message='inner B second', type=2)
- ]
- ]
- ]
- )
-
- self.list_bonks = ListBonks(
- [
- Bonk(message='inner A', type=1),
- Bonk(message='inner B', type=2),
- Bonk(message='inner C', type=0)
- ]
- )
-
- def _serialize(self, obj):
- trans = TTransport.TMemoryBuffer()
- prot = self.protocol_factory.getProtocol(trans)
- obj.write(prot)
- return trans.getvalue()
-
- def _deserialize(self, objtype, data):
- prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data))
- ret = objtype()
- ret.read(prot)
- return ret
-
- def testForwards(self):
- obj = self._deserialize(VersioningTestV2, self._serialize(self.v1obj))
- self.assertEquals(obj.begin_in_both, self.v1obj.begin_in_both)
- self.assertEquals(obj.end_in_both, self.v1obj.end_in_both)
-
- def testBackwards(self):
- obj = self._deserialize(VersioningTestV1, self._serialize(self.v2obj))
- self.assertEquals(obj.begin_in_both, self.v2obj.begin_in_both)
- self.assertEquals(obj.end_in_both, self.v2obj.end_in_both)
-
- def testSerializeV1(self):
- obj = self._deserialize(VersioningTestV1, self._serialize(self.v1obj))
- self.assertEquals(obj, self.v1obj)
-
- def testSerializeV2(self):
- obj = self._deserialize(VersioningTestV2, self._serialize(self.v2obj))
- self.assertEquals(obj, self.v2obj)
-
- def testBools(self):
- self.assertNotEquals(self.bools, self.bools_flipped)
- self.assertNotEquals(self.bools, self.v1obj)
- obj = self._deserialize(Bools, self._serialize(self.bools))
- self.assertEquals(obj, self.bools)
- obj = self._deserialize(Bools, self._serialize(self.bools_flipped))
- self.assertEquals(obj, self.bools_flipped)
- rep = repr(self.bools)
- self.assertTrue(len(rep) > 0)
-
- def testLargeDeltas(self):
- # test large field deltas (meaningful in CompactProto only)
- obj = self._deserialize(LargeDeltas, self._serialize(self.large_deltas))
- self.assertEquals(obj, self.large_deltas)
- rep = repr(self.large_deltas)
- self.assertTrue(len(rep) > 0)
-
- def testNestedListsI32x2(self):
- obj = self._deserialize(NestedListsI32x2, self._serialize(self.nested_lists_i32x2))
- self.assertEquals(obj, self.nested_lists_i32x2)
- rep = repr(self.nested_lists_i32x2)
- self.assertTrue(len(rep) > 0)
-
- def testNestedListsI32x3(self):
- obj = self._deserialize(NestedListsI32x3, self._serialize(self.nested_lists_i32x3))
- self.assertEquals(obj, self.nested_lists_i32x3)
- rep = repr(self.nested_lists_i32x3)
- self.assertTrue(len(rep) > 0)
-
- def testNestedMixedx2(self):
- obj = self._deserialize(NestedMixedx2, self._serialize(self.nested_mixedx2))
- self.assertEquals(obj, self.nested_mixedx2)
- rep = repr(self.nested_mixedx2)
- self.assertTrue(len(rep) > 0)
-
- def testNestedListsBonk(self):
- obj = self._deserialize(NestedListsBonk, self._serialize(self.nested_lists_bonk))
- self.assertEquals(obj, self.nested_lists_bonk)
- rep = repr(self.nested_lists_bonk)
- self.assertTrue(len(rep) > 0)
-
- def testListBonks(self):
- obj = self._deserialize(ListBonks, self._serialize(self.list_bonks))
- self.assertEquals(obj, self.list_bonks)
- rep = repr(self.list_bonks)
- self.assertTrue(len(rep) > 0)
-
- def testCompactStruct(self):
- # test large field deltas (meaningful in CompactProto only)
- obj = self._deserialize(CompactProtoTestStruct, self._serialize(self.compact_struct))
- self.assertEquals(obj, self.compact_struct)
- rep = repr(self.compact_struct)
- self.assertTrue(len(rep) > 0)
-
- def testIntegerLimits(self):
- if (sys.version_info[0] == 2 and sys.version_info[1] <= 6):
- print('Skipping testIntegerLimits for Python 2.6')
- return
- bad_values = [CompactProtoTestStruct(a_byte=128), CompactProtoTestStruct(a_byte=-129),
- CompactProtoTestStruct(a_i16=32768), CompactProtoTestStruct(a_i16=-32769),
- CompactProtoTestStruct(a_i32=2147483648), CompactProtoTestStruct(a_i32=-2147483649),
- CompactProtoTestStruct(a_i64=9223372036854775808), CompactProtoTestStruct(a_i64=-9223372036854775809)
+ def setUp(self):
+ self.v1obj = VersioningTestV1(
+ begin_in_both=12345,
+ old_string='aaa',
+ end_in_both=54321,
+ )
+
+ self.v2obj = VersioningTestV2(
+ begin_in_both=12345,
+ newint=1,
+ newbyte=2,
+ newshort=3,
+ newlong=4,
+ newdouble=5.0,
+ newstruct=Bonk(message="Hello!", type=123),
+ newlist=[7, 8, 9],
+ newset=set([42, 1, 8]),
+ newmap={1: 2, 2: 3},
+ newstring="Hola!",
+ end_in_both=54321,
+ )
+
+ self.bools = Bools(im_true=True, im_false=False)
+ self.bools_flipped = Bools(im_true=False, im_false=True)
+
+ self.large_deltas = LargeDeltas(
+ b1=self.bools,
+ b10=self.bools_flipped,
+ b100=self.bools,
+ check_true=True,
+ b1000=self.bools_flipped,
+ check_false=False,
+ vertwo2000=VersioningTestV2(newstruct=Bonk(message='World!', type=314)),
+ a_set2500=set(['lazy', 'brown', 'cow']),
+ vertwo3000=VersioningTestV2(newset=set([2, 3, 5, 7, 11])),
+ big_numbers=[2**8, 2**16, 2**31 - 1, -(2**31 - 1)]
+ )
+
+ self.compact_struct = CompactProtoTestStruct(
+ a_byte=127,
+ a_i16=32000,
+ a_i32=1000000000,
+ a_i64=0xffffffffff,
+ a_double=5.6789,
+ a_string="my string",
+ true_field=True,
+ false_field=False,
+ empty_struct_field=Empty(),
+ byte_list=[-127, -1, 0, 1, 127],
+ i16_list=[-1, 0, 1, 0x7fff],
+ i32_list=[-1, 0, 0xff, 0xffff, 0xffffff, 0x7fffffff],
+ i64_list=[-1, 0, 0xff, 0xffff, 0xffffff, 0xffffffff, 0xffffffffff, 0xffffffffffff, 0xffffffffffffff, 0x7fffffffffffffff],
+ double_list=[0.1, 0.2, 0.3],
+ string_list=["first", "second", "third"],
+ boolean_list=[True, True, True, False, False, False],
+ struct_list=[Empty(), Empty()],
+ byte_set=set([-127, -1, 0, 1, 127]),
+ i16_set=set([-1, 0, 1, 0x7fff]),
+ i32_set=set([1, 2, 3]),
+ i64_set=set([-1, 0, 0xff, 0xffff, 0xffffff, 0xffffffff, 0xffffffffff, 0xffffffffffff, 0xffffffffffffff, 0x7fffffffffffffff]),
+ double_set=set([0.1, 0.2, 0.3]),
+ string_set=set(["first", "second", "third"]),
+ boolean_set=set([True, False]),
+ # struct_set=set([Empty()]), # unhashable instance
+ byte_byte_map={1: 2},
+ i16_byte_map={1: 1, -1: 1, 0x7fff: 1},
+ i32_byte_map={1: 1, -1: 1, 0x7fffffff: 1},
+ i64_byte_map={0: 1, 1: 1, -1: 1, 0x7fffffffffffffff: 1},
+ double_byte_map={-1.1: 1, 1.1: 1},
+ string_byte_map={"first": 1, "second": 2, "third": 3, "": 0},
+ boolean_byte_map={True: 1, False: 0},
+ byte_i16_map={1: 1, 2: -1, 3: 0x7fff},
+ byte_i32_map={1: 1, 2: -1, 3: 0x7fffffff},
+ byte_i64_map={1: 1, 2: -1, 3: 0x7fffffffffffffff},
+ byte_double_map={1: 0.1, 2: -0.1, 3: 1000000.1},
+ byte_string_map={1: "", 2: "blah", 3: "loooooooooooooong string"},
+ byte_boolean_map={1: True, 2: False},
+ # list_byte_map # unhashable
+ # set_byte_map={set([1, 2, 3]) : 1, set([0, 1]) : 2, set([]) : 0}, # unhashable
+ # map_byte_map # unhashable
+ byte_map_map={0: {}, 1: {1: 1}, 2: {1: 1, 2: 2}},
+ byte_set_map={0: set([]), 1: set([1]), 2: set([1, 2])},
+ byte_list_map={0: [], 1: [1], 2: [1, 2]},
+ )
+
+ self.nested_lists_i32x2 = NestedListsI32x2(
+ [
+ [1, 1, 2],
+ [2, 7, 9],
+ [3, 5, 8]
+ ]
+ )
+
+ self.nested_lists_i32x3 = NestedListsI32x3(
+ [
+ [
+ [2, 7, 9],
+ [3, 5, 8]
+ ],
+ [
+ [1, 1, 2],
+ [1, 4, 9]
]
-
- for value in bad_values:
- self.assertRaises(Exception, self._serialize, value)
+ ]
+ )
+
+ self.nested_mixedx2 = NestedMixedx2(int_set_list=[
+ set([1, 2, 3]),
+ set([1, 4, 9]),
+ set([1, 2, 3, 5, 8, 13, 21]),
+ set([-1, 0, 1])
+ ],
+ # note, the sets below are sets of chars, since the strings are iterated
+ map_int_strset={10: set('abc'), 20: set('def'), 30: set('GHI')},
+ map_int_strset_list=[
+ {10: set('abc'), 20: set('def'), 30: set('GHI')},
+ {100: set('lmn'), 200: set('opq'), 300: set('RST')},
+ {1000: set('uvw'), 2000: set('wxy'), 3000: set('XYZ')}
+ ]
+ )
+
+ self.nested_lists_bonk = NestedListsBonk(
+ [
+ [
+ [
+ Bonk(message='inner A first', type=1),
+ Bonk(message='inner A second', type=1)
+ ],
+ [
+ Bonk(message='inner B first', type=2),
+ Bonk(message='inner B second', type=2)
+ ]
+ ]
+ ]
+ )
+
+ self.list_bonks = ListBonks(
+ [
+ Bonk(message='inner A', type=1),
+ Bonk(message='inner B', type=2),
+ Bonk(message='inner C', type=0)
+ ]
+ )
+
+ def _serialize(self, obj):
+ trans = TTransport.TMemoryBuffer()
+ prot = self.protocol_factory.getProtocol(trans)
+ obj.write(prot)
+ return trans.getvalue()
+
+ def _deserialize(self, objtype, data):
+ prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data))
+ ret = objtype()
+ ret.read(prot)
+ return ret
+
+ def testForwards(self):
+ obj = self._deserialize(VersioningTestV2, self._serialize(self.v1obj))
+ self.assertEquals(obj.begin_in_both, self.v1obj.begin_in_both)
+ self.assertEquals(obj.end_in_both, self.v1obj.end_in_both)
+
+ def testBackwards(self):
+ obj = self._deserialize(VersioningTestV1, self._serialize(self.v2obj))
+ self.assertEquals(obj.begin_in_both, self.v2obj.begin_in_both)
+ self.assertEquals(obj.end_in_both, self.v2obj.end_in_both)
+
+ def testSerializeV1(self):
+ obj = self._deserialize(VersioningTestV1, self._serialize(self.v1obj))
+ self.assertEquals(obj, self.v1obj)
+
+ def testSerializeV2(self):
+ obj = self._deserialize(VersioningTestV2, self._serialize(self.v2obj))
+ self.assertEquals(obj, self.v2obj)
+
+ def testBools(self):
+ self.assertNotEquals(self.bools, self.bools_flipped)
+ self.assertNotEquals(self.bools, self.v1obj)
+ obj = self._deserialize(Bools, self._serialize(self.bools))
+ self.assertEquals(obj, self.bools)
+ obj = self._deserialize(Bools, self._serialize(self.bools_flipped))
+ self.assertEquals(obj, self.bools_flipped)
+ rep = repr(self.bools)
+ self.assertTrue(len(rep) > 0)
+
+ def testLargeDeltas(self):
+ # test large field deltas (meaningful in CompactProto only)
+ obj = self._deserialize(LargeDeltas, self._serialize(self.large_deltas))
+ self.assertEquals(obj, self.large_deltas)
+ rep = repr(self.large_deltas)
+ self.assertTrue(len(rep) > 0)
+
+ def testNestedListsI32x2(self):
+ obj = self._deserialize(NestedListsI32x2, self._serialize(self.nested_lists_i32x2))
+ self.assertEquals(obj, self.nested_lists_i32x2)
+ rep = repr(self.nested_lists_i32x2)
+ self.assertTrue(len(rep) > 0)
+
+ def testNestedListsI32x3(self):
+ obj = self._deserialize(NestedListsI32x3, self._serialize(self.nested_lists_i32x3))
+ self.assertEquals(obj, self.nested_lists_i32x3)
+ rep = repr(self.nested_lists_i32x3)
+ self.assertTrue(len(rep) > 0)
+
+ def testNestedMixedx2(self):
+ obj = self._deserialize(NestedMixedx2, self._serialize(self.nested_mixedx2))
+ self.assertEquals(obj, self.nested_mixedx2)
+ rep = repr(self.nested_mixedx2)
+ self.assertTrue(len(rep) > 0)
+
+ def testNestedListsBonk(self):
+ obj = self._deserialize(NestedListsBonk, self._serialize(self.nested_lists_bonk))
+ self.assertEquals(obj, self.nested_lists_bonk)
+ rep = repr(self.nested_lists_bonk)
+ self.assertTrue(len(rep) > 0)
+
+ def testListBonks(self):
+ obj = self._deserialize(ListBonks, self._serialize(self.list_bonks))
+ self.assertEquals(obj, self.list_bonks)
+ rep = repr(self.list_bonks)
+ self.assertTrue(len(rep) > 0)
+
+ def testCompactStruct(self):
+ # test large field deltas (meaningful in CompactProto only)
+ obj = self._deserialize(CompactProtoTestStruct, self._serialize(self.compact_struct))
+ self.assertEquals(obj, self.compact_struct)
+ rep = repr(self.compact_struct)
+ self.assertTrue(len(rep) > 0)
+
+ def testIntegerLimits(self):
+ if (sys.version_info[0] == 2 and sys.version_info[1] <= 6):
+ print('Skipping testIntegerLimits for Python 2.6')
+ return
+ bad_values = [CompactProtoTestStruct(a_byte=128), CompactProtoTestStruct(a_byte=-129),
+ CompactProtoTestStruct(a_i16=32768), CompactProtoTestStruct(a_i16=-32769),
+ CompactProtoTestStruct(a_i32=2147483648), CompactProtoTestStruct(a_i32=-2147483649),
+ CompactProtoTestStruct(a_i64=9223372036854775808), CompactProtoTestStruct(a_i64=-9223372036854775809)
+ ]
+
+ for value in bad_values:
+ self.assertRaises(Exception, self._serialize, value)
class NormalBinaryTest(AbstractTest):
- protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()
+ protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()
class AcceleratedBinaryTest(AbstractTest):
- protocol_factory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory()
+ protocol_factory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory()
class CompactProtocolTest(AbstractTest):
- protocol_factory = TCompactProtocol.TCompactProtocolFactory()
+ protocol_factory = TCompactProtocol.TCompactProtocolFactory()
class JSONProtocolTest(AbstractTest):
- protocol_factory = TJSONProtocol.TJSONProtocolFactory()
+ protocol_factory = TJSONProtocol.TJSONProtocolFactory()
class AcceleratedFramedTest(unittest.TestCase):
- def testSplit(self):
- """Test FramedTransport and BinaryProtocolAccelerated
+ def testSplit(self):
+ """Test FramedTransport and BinaryProtocolAccelerated
+
+ Tests that TBinaryProtocolAccelerated and TFramedTransport
+ play nicely together when a read spans a frame"""
+
+ protocol_factory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory()
+ bigstring = "".join(chr(byte) for byte in range(ord("a"), ord("z") + 1))
+
+ databuf = TTransport.TMemoryBuffer()
+ prot = protocol_factory.getProtocol(databuf)
+ prot.writeI32(42)
+ prot.writeString(bigstring)
+ prot.writeI16(24)
+ data = databuf.getvalue()
+ cutpoint = len(data) // 2
+ parts = [data[:cutpoint], data[cutpoint:]]
+
+ framed_buffer = TTransport.TMemoryBuffer()
+ framed_writer = TTransport.TFramedTransport(framed_buffer)
+ for part in parts:
+ framed_writer.write(part)
+ framed_writer.flush()
+ self.assertEquals(len(framed_buffer.getvalue()), len(data) + 8)
+
+ # Recreate framed_buffer so we can read from it.
+ framed_buffer = TTransport.TMemoryBuffer(framed_buffer.getvalue())
+ framed_reader = TTransport.TFramedTransport(framed_buffer)
+ prot = protocol_factory.getProtocol(framed_reader)
+ self.assertEqual(prot.readI32(), 42)
+ self.assertEqual(prot.readString(), bigstring)
+ self.assertEqual(prot.readI16(), 24)
- Tests that TBinaryProtocolAccelerated and TFramedTransport
- play nicely together when a read spans a frame"""
-
- protocol_factory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory()
- bigstring = "".join(chr(byte) for byte in range(ord("a"), ord("z")+1))
-
- databuf = TTransport.TMemoryBuffer()
- prot = protocol_factory.getProtocol(databuf)
- prot.writeI32(42)
- prot.writeString(bigstring)
- prot.writeI16(24)
- data = databuf.getvalue()
- cutpoint = len(data) // 2
- parts = [ data[:cutpoint], data[cutpoint:] ]
-
- framed_buffer = TTransport.TMemoryBuffer()
- framed_writer = TTransport.TFramedTransport(framed_buffer)
- for part in parts:
- framed_writer.write(part)
- framed_writer.flush()
- self.assertEquals(len(framed_buffer.getvalue()), len(data) + 8)
-
- # Recreate framed_buffer so we can read from it.
- framed_buffer = TTransport.TMemoryBuffer(framed_buffer.getvalue())
- framed_reader = TTransport.TFramedTransport(framed_buffer)
- prot = protocol_factory.getProtocol(framed_reader)
- self.assertEqual(prot.readI32(), 42)
- self.assertEqual(prot.readString(), bigstring)
- self.assertEqual(prot.readI16(), 24)
class SerializersTest(unittest.TestCase):
- def testSerializeThenDeserialize(self):
- obj = Xtruct2(i32_thing=1,
- struct_thing=Xtruct(string_thing="foo"))
+ def testSerializeThenDeserialize(self):
+ obj = Xtruct2(i32_thing=1,
+ struct_thing=Xtruct(string_thing="foo"))
- s1 = serialize(obj)
- for i in range(10):
- self.assertEquals(s1, serialize(obj))
- objcopy = Xtruct2()
- deserialize(objcopy, serialize(obj))
- self.assertEquals(obj, objcopy)
+ s1 = serialize(obj)
+ for i in range(10):
+ self.assertEquals(s1, serialize(obj))
+ objcopy = Xtruct2()
+ deserialize(objcopy, serialize(obj))
+ self.assertEquals(obj, objcopy)
- obj = Xtruct(string_thing="bar")
- objcopy = Xtruct()
- deserialize(objcopy, serialize(obj))
- self.assertEquals(obj, objcopy)
+ obj = Xtruct(string_thing="bar")
+ objcopy = Xtruct()
+ deserialize(objcopy, serialize(obj))
+ self.assertEquals(obj, objcopy)
- # test booleans
- obj = Bools(im_true=True, im_false=False)
- objcopy = Bools()
- deserialize(objcopy, serialize(obj))
- self.assertEquals(obj, objcopy)
+ # test booleans
+ obj = Bools(im_true=True, im_false=False)
+ objcopy = Bools()
+ deserialize(objcopy, serialize(obj))
+ self.assertEquals(obj, objcopy)
- # test enums
- for num, name in Numberz._VALUES_TO_NAMES.items():
- obj = Bonk(message='enum Numberz value %d is string %s' % (num, name), type=num)
- objcopy = Bonk()
- deserialize(objcopy, serialize(obj))
- self.assertEquals(obj, objcopy)
+ # test enums
+ for num, name in Numberz._VALUES_TO_NAMES.items():
+ obj = Bonk(message='enum Numberz value %d is string %s' % (num, name), type=num)
+ objcopy = Bonk()
+ deserialize(objcopy, serialize(obj))
+ self.assertEquals(obj, objcopy)
def suite():
- suite = unittest.TestSuite()
- loader = unittest.TestLoader()
+ suite = unittest.TestSuite()
+ loader = unittest.TestLoader()
- suite.addTest(loader.loadTestsFromTestCase(NormalBinaryTest))
- suite.addTest(loader.loadTestsFromTestCase(AcceleratedBinaryTest))
- suite.addTest(loader.loadTestsFromTestCase(CompactProtocolTest))
- suite.addTest(loader.loadTestsFromTestCase(JSONProtocolTest))
- suite.addTest(loader.loadTestsFromTestCase(AcceleratedFramedTest))
- suite.addTest(loader.loadTestsFromTestCase(SerializersTest))
- return suite
+ suite.addTest(loader.loadTestsFromTestCase(NormalBinaryTest))
+ suite.addTest(loader.loadTestsFromTestCase(AcceleratedBinaryTest))
+ suite.addTest(loader.loadTestsFromTestCase(CompactProtocolTest))
+ suite.addTest(loader.loadTestsFromTestCase(JSONProtocolTest))
+ suite.addTest(loader.loadTestsFromTestCase(AcceleratedFramedTest))
+ suite.addTest(loader.loadTestsFromTestCase(SerializersTest))
+ return suite
if __name__ == "__main__":
- unittest.main(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2))
+ unittest.main(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2))
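
For orientation, the _serialize/_deserialize helpers that every protocol subclass above shares reduce to a small in-memory roundtrip. A minimal standalone sketch of the same pattern (assuming the thrift Python library and the generated ThriftTest module are importable, as these test scripts arrange via --genpydir):

# Sketch of the in-memory roundtrip used by AbstractTest; names below mirror the test,
# but this is an illustration, not part of the patch.
from thrift.transport import TTransport
from thrift.protocol import TBinaryProtocol
from ThriftTest.ttypes import Xtruct  # generated code, assumed to be on sys.path

def roundtrip(obj, objtype, protocol_factory=TBinaryProtocol.TBinaryProtocolFactory()):
    # Serialize into an in-memory buffer...
    write_buf = TTransport.TMemoryBuffer()
    obj.write(protocol_factory.getProtocol(write_buf))
    data = write_buf.getvalue()
    # ...then read the bytes back through a fresh buffer.
    read_buf = TTransport.TMemoryBuffer(data)
    copy = objtype()
    copy.read(protocol_factory.getProtocol(read_buf))
    return copy

original = Xtruct(string_thing="Zero", byte_thing=1, i32_thing=-3, i64_thing=-5)
assert roundtrip(original, Xtruct) == original
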
diff --git a/test/py/TSimpleJSONProtocolTest.py b/test/py/TSimpleJSONProtocolTest.py
index 1ed8c1574..72987602b 100644
--- a/test/py/TSimpleJSONProtocolTest.py
+++ b/test/py/TSimpleJSONProtocolTest.py
@@ -28,81 +28,81 @@ import unittest
class SimpleJSONProtocolTest(unittest.TestCase):
- protocol_factory = TJSONProtocol.TSimpleJSONProtocolFactory()
-
- def _assertDictEqual(self, a, b, msg=None):
- if hasattr(self, 'assertDictEqual'):
- # assertDictEqual only in Python 2.7. Depends on your machine.
- self.assertDictEqual(a, b, msg)
- return
-
- # Substitute implementation not as good as unittest library's
- self.assertEquals(len(a), len(b), msg)
- for k, v in a.iteritems():
- self.assertTrue(k in b, msg)
- self.assertEquals(b.get(k), v, msg)
-
- def _serialize(self, obj):
- trans = TTransport.TMemoryBuffer()
- prot = self.protocol_factory.getProtocol(trans)
- obj.write(prot)
- return trans.getvalue()
-
- def _deserialize(self, objtype, data):
- prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data))
- ret = objtype()
- ret.read(prot)
- return ret
-
- def testWriteOnly(self):
- self.assertRaises(NotImplementedError,
- self._deserialize, VersioningTestV1, b'{}')
-
- def testSimpleMessage(self):
- v1obj = VersioningTestV1(
- begin_in_both=12345,
- old_string='aaa',
- end_in_both=54321)
- expected = dict(begin_in_both=v1obj.begin_in_both,
- old_string=v1obj.old_string,
- end_in_both=v1obj.end_in_both)
- actual = json.loads(self._serialize(v1obj).decode('ascii'))
-
- self._assertDictEqual(expected, actual)
-
- def testComplicated(self):
- v2obj = VersioningTestV2(
- begin_in_both=12345,
- newint=1,
- newbyte=2,
- newshort=3,
- newlong=4,
- newdouble=5.0,
- newstruct=Bonk(message="Hello!", type=123),
- newlist=[7, 8, 9],
- newset=set([42, 1, 8]),
- newmap={1: 2, 2: 3},
- newstring="Hola!",
- end_in_both=54321)
- expected = dict(begin_in_both=v2obj.begin_in_both,
- newint=v2obj.newint,
- newbyte=v2obj.newbyte,
- newshort=v2obj.newshort,
- newlong=v2obj.newlong,
- newdouble=v2obj.newdouble,
- newstruct=dict(message=v2obj.newstruct.message,
- type=v2obj.newstruct.type),
- newlist=v2obj.newlist,
- newset=list(v2obj.newset),
- newmap=v2obj.newmap,
- newstring=v2obj.newstring,
- end_in_both=v2obj.end_in_both)
-
- # Need to load/dump because map keys get escaped.
- expected = json.loads(json.dumps(expected))
- actual = json.loads(self._serialize(v2obj).decode('ascii'))
- self._assertDictEqual(expected, actual)
+ protocol_factory = TJSONProtocol.TSimpleJSONProtocolFactory()
+
+ def _assertDictEqual(self, a, b, msg=None):
+ if hasattr(self, 'assertDictEqual'):
+            # assertDictEqual is only available in Python 2.7 and later; fall back when it is missing.
+ self.assertDictEqual(a, b, msg)
+ return
+
+ # Substitute implementation not as good as unittest library's
+ self.assertEquals(len(a), len(b), msg)
+ for k, v in a.iteritems():
+ self.assertTrue(k in b, msg)
+ self.assertEquals(b.get(k), v, msg)
+
+ def _serialize(self, obj):
+ trans = TTransport.TMemoryBuffer()
+ prot = self.protocol_factory.getProtocol(trans)
+ obj.write(prot)
+ return trans.getvalue()
+
+ def _deserialize(self, objtype, data):
+ prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data))
+ ret = objtype()
+ ret.read(prot)
+ return ret
+
+ def testWriteOnly(self):
+ self.assertRaises(NotImplementedError,
+ self._deserialize, VersioningTestV1, b'{}')
+
+ def testSimpleMessage(self):
+ v1obj = VersioningTestV1(
+ begin_in_both=12345,
+ old_string='aaa',
+ end_in_both=54321)
+ expected = dict(begin_in_both=v1obj.begin_in_both,
+ old_string=v1obj.old_string,
+ end_in_both=v1obj.end_in_both)
+ actual = json.loads(self._serialize(v1obj).decode('ascii'))
+
+ self._assertDictEqual(expected, actual)
+
+ def testComplicated(self):
+ v2obj = VersioningTestV2(
+ begin_in_both=12345,
+ newint=1,
+ newbyte=2,
+ newshort=3,
+ newlong=4,
+ newdouble=5.0,
+ newstruct=Bonk(message="Hello!", type=123),
+ newlist=[7, 8, 9],
+ newset=set([42, 1, 8]),
+ newmap={1: 2, 2: 3},
+ newstring="Hola!",
+ end_in_both=54321)
+ expected = dict(begin_in_both=v2obj.begin_in_both,
+ newint=v2obj.newint,
+ newbyte=v2obj.newbyte,
+ newshort=v2obj.newshort,
+ newlong=v2obj.newlong,
+ newdouble=v2obj.newdouble,
+ newstruct=dict(message=v2obj.newstruct.message,
+ type=v2obj.newstruct.type),
+ newlist=v2obj.newlist,
+ newset=list(v2obj.newset),
+ newmap=v2obj.newmap,
+ newstring=v2obj.newstring,
+ end_in_both=v2obj.end_in_both)
+
+ # Need to load/dump because map keys get escaped.
+ expected = json.loads(json.dumps(expected))
+ actual = json.loads(self._serialize(v2obj).decode('ascii'))
+ self._assertDictEqual(expected, actual)
if __name__ == '__main__':
- unittest.main()
+ unittest.main()
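
As a standalone illustration of what SimpleJSONProtocolTest checks: TSimpleJSONProtocol emits plain JSON keyed by field name, which can be inspected with the json module, but it is write-only, so reading raises NotImplementedError (the testWriteOnly assertion). A minimal sketch, again assuming the generated ThriftTest module is importable:

import json
from thrift.transport import TTransport
from thrift.protocol import TJSONProtocol
from ThriftTest.ttypes import VersioningTestV1  # generated code, assumed on sys.path

factory = TJSONProtocol.TSimpleJSONProtocolFactory()
buf = TTransport.TMemoryBuffer()
VersioningTestV1(begin_in_both=12345, old_string='aaa', end_in_both=54321).write(
    factory.getProtocol(buf))

# The output is ordinary JSON keyed by field name.
print(json.loads(buf.getvalue().decode('ascii')))

# The protocol is write-only: attempting to read it back fails.
try:
    VersioningTestV1().read(factory.getProtocol(TTransport.TMemoryBuffer(buf.getvalue())))
except NotImplementedError:
    pass
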
diff --git a/test/py/TestClient.py b/test/py/TestClient.py
index 347329e08..bc7650dcc 100755
--- a/test/py/TestClient.py
+++ b/test/py/TestClient.py
@@ -32,42 +32,42 @@ DEFAULT_LIBDIR_GLOB = os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib.*')
class AbstractTest(unittest.TestCase):
- def setUp(self):
- if options.http_path:
- self.transport = THttpClient.THttpClient(options.host, port=options.port, path=options.http_path)
- else:
- if options.ssl:
- from thrift.transport import TSSLSocket
- socket = TSSLSocket.TSSLSocket(options.host, options.port, validate=False)
- else:
- socket = TSocket.TSocket(options.host, options.port)
- # frame or buffer depending upon args
- self.transport = TTransport.TBufferedTransport(socket)
- if options.trans == 'framed':
- self.transport = TTransport.TFramedTransport(socket)
- elif options.trans == 'buffered':
- self.transport = TTransport.TBufferedTransport(socket)
- elif options.trans == '':
- raise AssertionError('Unknown --transport option: %s' % options.trans)
- if options.zlib:
- self.transport = TZlibTransport.TZlibTransport(self.transport, 9)
- self.transport.open()
- protocol = self.get_protocol(self.transport)
- self.client = ThriftTest.Client(protocol)
-
- def tearDown(self):
- self.transport.close()
-
- def testVoid(self):
- print('testVoid')
- self.client.testVoid()
-
- def testString(self):
- print('testString')
- self.assertEqual(self.client.testString('Python' * 20), 'Python' * 20)
- self.assertEqual(self.client.testString(''), '')
- s1 = u'\b\t\n/\\\\\r{}:パイソン"'
- s2 = u"""Afrikaans, Alemannisch, Aragonés, العربية, مصرى,
+ def setUp(self):
+ if options.http_path:
+ self.transport = THttpClient.THttpClient(options.host, port=options.port, path=options.http_path)
+ else:
+ if options.ssl:
+ from thrift.transport import TSSLSocket
+ socket = TSSLSocket.TSSLSocket(options.host, options.port, validate=False)
+ else:
+ socket = TSocket.TSocket(options.host, options.port)
+ # frame or buffer depending upon args
+ self.transport = TTransport.TBufferedTransport(socket)
+ if options.trans == 'framed':
+ self.transport = TTransport.TFramedTransport(socket)
+ elif options.trans == 'buffered':
+ self.transport = TTransport.TBufferedTransport(socket)
+ elif options.trans == '':
+ raise AssertionError('Unknown --transport option: %s' % options.trans)
+ if options.zlib:
+ self.transport = TZlibTransport.TZlibTransport(self.transport, 9)
+ self.transport.open()
+ protocol = self.get_protocol(self.transport)
+ self.client = ThriftTest.Client(protocol)
+
+ def tearDown(self):
+ self.transport.close()
+
+ def testVoid(self):
+ print('testVoid')
+ self.client.testVoid()
+
+ def testString(self):
+ print('testString')
+ self.assertEqual(self.client.testString('Python' * 20), 'Python' * 20)
+ self.assertEqual(self.client.testString(''), '')
+ s1 = u'\b\t\n/\\\\\r{}:パイソン"'
+ s2 = u"""Afrikaans, Alemannisch, Aragonés, العربية, مصرى,
Asturianu, Aymar aru, Azərbaycan, Башҡорт, Boarisch, Žemaitėška,
Беларуская, Беларуская (тарашкевіца), Български, Bamanankan,
বাংলা, Brezhoneg, Bosanski, Català, Mìng-dĕ̤ng-ngṳ̄, Нохчийн,
@@ -92,199 +92,199 @@ class AbstractTest(unittest.TestCase):
Türkçe, Татарча/Tatarça, Українська, اردو, Tiếng Việt, Volapük,
Walon, Winaray, 吴语, isiXhosa, ייִדיש, Yorùbá, Zeêuws, 中文,
Bân-lâm-gú, 粵語"""
- if sys.version_info[0] == 2 and os.environ.get('THRIFT_TEST_PY_NO_UTF8STRINGS'):
- s1 = s1.encode('utf8')
- s2 = s2.encode('utf8')
- self.assertEqual(self.client.testString(s1), s1)
- self.assertEqual(self.client.testString(s2), s2)
-
- def testBool(self):
- print('testBool')
- self.assertEqual(self.client.testBool(True), True)
- self.assertEqual(self.client.testBool(False), False)
-
- def testByte(self):
- print('testByte')
- self.assertEqual(self.client.testByte(63), 63)
- self.assertEqual(self.client.testByte(-127), -127)
-
- def testI32(self):
- print('testI32')
- self.assertEqual(self.client.testI32(-1), -1)
- self.assertEqual(self.client.testI32(0), 0)
-
- def testI64(self):
- print('testI64')
- self.assertEqual(self.client.testI64(1), 1)
- self.assertEqual(self.client.testI64(-34359738368), -34359738368)
-
- def testDouble(self):
- print('testDouble')
- self.assertEqual(self.client.testDouble(-5.235098235), -5.235098235)
- self.assertEqual(self.client.testDouble(0), 0)
- self.assertEqual(self.client.testDouble(-1), -1)
- self.assertEqual(self.client.testDouble(-0.000341012439638598279), -0.000341012439638598279)
-
- def testBinary(self):
- print('testBinary')
- val = bytearray([i for i in range(0, 256)])
- self.assertEqual(bytearray(self.client.testBinary(bytes(val))), val)
-
- def testStruct(self):
- print('testStruct')
- x = Xtruct()
- x.string_thing = "Zero"
- x.byte_thing = 1
- x.i32_thing = -3
- x.i64_thing = -5
- y = self.client.testStruct(x)
- self.assertEqual(y, x)
-
- def testNest(self):
- print('testNest')
- inner = Xtruct(string_thing="Zero", byte_thing=1, i32_thing=-3, i64_thing=-5)
- x = Xtruct2(struct_thing=inner, byte_thing=0, i32_thing=0)
- y = self.client.testNest(x)
- self.assertEqual(y, x)
-
- def testMap(self):
- print('testMap')
- x = {0: 1, 1: 2, 2: 3, 3: 4, -1: -2}
- y = self.client.testMap(x)
- self.assertEqual(y, x)
-
- def testSet(self):
- print('testSet')
- x = set([8, 1, 42])
- y = self.client.testSet(x)
- self.assertEqual(y, x)
-
- def testList(self):
- print('testList')
- x = [1, 4, 9, -42]
- y = self.client.testList(x)
- self.assertEqual(y, x)
-
- def testEnum(self):
- print('testEnum')
- x = Numberz.FIVE
- y = self.client.testEnum(x)
- self.assertEqual(y, x)
-
- def testTypedef(self):
- print('testTypedef')
- x = 0xffffffffffffff # 7 bytes of 0xff
- y = self.client.testTypedef(x)
- self.assertEqual(y, x)
-
- def testMapMap(self):
- print('testMapMap')
- x = {
- -4: {-4: -4, -3: -3, -2: -2, -1: -1},
- 4: {4: 4, 3: 3, 2: 2, 1: 1},
- }
- y = self.client.testMapMap(42)
- self.assertEqual(y, x)
-
- def testMulti(self):
- print('testMulti')
- xpected = Xtruct(string_thing='Hello2', byte_thing=74, i32_thing=0xff00ff, i64_thing=0xffffffffd0d0)
- y = self.client.testMulti(xpected.byte_thing,
- xpected.i32_thing,
- xpected.i64_thing,
- {0: 'abc'},
- Numberz.FIVE,
- 0xf0f0f0)
- self.assertEqual(y, xpected)
-
- def testException(self):
- print('testException')
- self.client.testException('Safe')
- try:
- self.client.testException('Xception')
- self.fail("should have gotten exception")
- except Xception as x:
- self.assertEqual(x.errorCode, 1001)
- self.assertEqual(x.message, 'Xception')
- # TODO ensure same behavior for repr within generated python variants
- # ensure exception's repr method works
- # x_repr = repr(x)
- # self.assertEqual(x_repr, 'Xception(errorCode=1001, message=\'Xception\')')
-
- try:
- self.client.testException('TException')
- self.fail("should have gotten exception")
- except TException as x:
- pass
-
- # Should not throw
- self.client.testException('success')
-
- def testMultiException(self):
- print('testMultiException')
- try:
- self.client.testMultiException('Xception', 'ignore')
- except Xception as ex:
- self.assertEqual(ex.errorCode, 1001)
- self.assertEqual(ex.message, 'This is an Xception')
-
- try:
- self.client.testMultiException('Xception2', 'ignore')
- except Xception2 as ex:
- self.assertEqual(ex.errorCode, 2002)
- self.assertEqual(ex.struct_thing.string_thing, 'This is an Xception2')
-
- y = self.client.testMultiException('success', 'foobar')
- self.assertEqual(y.string_thing, 'foobar')
-
- def testOneway(self):
- print('testOneway')
- start = time.time()
- self.client.testOneway(1) # type is int, not float
- end = time.time()
- self.assertTrue(end - start < 3,
- "oneway sleep took %f sec" % (end - start))
-
- def testOnewayThenNormal(self):
- print('testOnewayThenNormal')
- self.client.testOneway(1) # type is int, not float
- self.assertEqual(self.client.testString('Python'), 'Python')
+ if sys.version_info[0] == 2 and os.environ.get('THRIFT_TEST_PY_NO_UTF8STRINGS'):
+ s1 = s1.encode('utf8')
+ s2 = s2.encode('utf8')
+ self.assertEqual(self.client.testString(s1), s1)
+ self.assertEqual(self.client.testString(s2), s2)
+
+ def testBool(self):
+ print('testBool')
+ self.assertEqual(self.client.testBool(True), True)
+ self.assertEqual(self.client.testBool(False), False)
+
+ def testByte(self):
+ print('testByte')
+ self.assertEqual(self.client.testByte(63), 63)
+ self.assertEqual(self.client.testByte(-127), -127)
+
+ def testI32(self):
+ print('testI32')
+ self.assertEqual(self.client.testI32(-1), -1)
+ self.assertEqual(self.client.testI32(0), 0)
+
+ def testI64(self):
+ print('testI64')
+ self.assertEqual(self.client.testI64(1), 1)
+ self.assertEqual(self.client.testI64(-34359738368), -34359738368)
+
+ def testDouble(self):
+ print('testDouble')
+ self.assertEqual(self.client.testDouble(-5.235098235), -5.235098235)
+ self.assertEqual(self.client.testDouble(0), 0)
+ self.assertEqual(self.client.testDouble(-1), -1)
+ self.assertEqual(self.client.testDouble(-0.000341012439638598279), -0.000341012439638598279)
+
+ def testBinary(self):
+ print('testBinary')
+ val = bytearray([i for i in range(0, 256)])
+ self.assertEqual(bytearray(self.client.testBinary(bytes(val))), val)
+
+ def testStruct(self):
+ print('testStruct')
+ x = Xtruct()
+ x.string_thing = "Zero"
+ x.byte_thing = 1
+ x.i32_thing = -3
+ x.i64_thing = -5
+ y = self.client.testStruct(x)
+ self.assertEqual(y, x)
+
+ def testNest(self):
+ print('testNest')
+ inner = Xtruct(string_thing="Zero", byte_thing=1, i32_thing=-3, i64_thing=-5)
+ x = Xtruct2(struct_thing=inner, byte_thing=0, i32_thing=0)
+ y = self.client.testNest(x)
+ self.assertEqual(y, x)
+
+ def testMap(self):
+ print('testMap')
+ x = {0: 1, 1: 2, 2: 3, 3: 4, -1: -2}
+ y = self.client.testMap(x)
+ self.assertEqual(y, x)
+
+ def testSet(self):
+ print('testSet')
+ x = set([8, 1, 42])
+ y = self.client.testSet(x)
+ self.assertEqual(y, x)
+
+ def testList(self):
+ print('testList')
+ x = [1, 4, 9, -42]
+ y = self.client.testList(x)
+ self.assertEqual(y, x)
+
+ def testEnum(self):
+ print('testEnum')
+ x = Numberz.FIVE
+ y = self.client.testEnum(x)
+ self.assertEqual(y, x)
+
+ def testTypedef(self):
+ print('testTypedef')
+ x = 0xffffffffffffff # 7 bytes of 0xff
+ y = self.client.testTypedef(x)
+ self.assertEqual(y, x)
+
+ def testMapMap(self):
+ print('testMapMap')
+ x = {
+ -4: {-4: -4, -3: -3, -2: -2, -1: -1},
+ 4: {4: 4, 3: 3, 2: 2, 1: 1},
+ }
+ y = self.client.testMapMap(42)
+ self.assertEqual(y, x)
+
+ def testMulti(self):
+ print('testMulti')
+ xpected = Xtruct(string_thing='Hello2', byte_thing=74, i32_thing=0xff00ff, i64_thing=0xffffffffd0d0)
+ y = self.client.testMulti(xpected.byte_thing,
+ xpected.i32_thing,
+ xpected.i64_thing,
+ {0: 'abc'},
+ Numberz.FIVE,
+ 0xf0f0f0)
+ self.assertEqual(y, xpected)
+
+ def testException(self):
+ print('testException')
+ self.client.testException('Safe')
+ try:
+ self.client.testException('Xception')
+ self.fail("should have gotten exception")
+ except Xception as x:
+ self.assertEqual(x.errorCode, 1001)
+ self.assertEqual(x.message, 'Xception')
+ # TODO ensure same behavior for repr within generated python variants
+ # ensure exception's repr method works
+ # x_repr = repr(x)
+ # self.assertEqual(x_repr, 'Xception(errorCode=1001, message=\'Xception\')')
+
+ try:
+ self.client.testException('TException')
+ self.fail("should have gotten exception")
+ except TException as x:
+ pass
+
+ # Should not throw
+ self.client.testException('success')
+
+ def testMultiException(self):
+ print('testMultiException')
+ try:
+ self.client.testMultiException('Xception', 'ignore')
+ except Xception as ex:
+ self.assertEqual(ex.errorCode, 1001)
+ self.assertEqual(ex.message, 'This is an Xception')
+
+ try:
+ self.client.testMultiException('Xception2', 'ignore')
+ except Xception2 as ex:
+ self.assertEqual(ex.errorCode, 2002)
+ self.assertEqual(ex.struct_thing.string_thing, 'This is an Xception2')
+
+ y = self.client.testMultiException('success', 'foobar')
+ self.assertEqual(y.string_thing, 'foobar')
+
+ def testOneway(self):
+ print('testOneway')
+ start = time.time()
+ self.client.testOneway(1) # type is int, not float
+ end = time.time()
+ self.assertTrue(end - start < 3,
+ "oneway sleep took %f sec" % (end - start))
+
+ def testOnewayThenNormal(self):
+ print('testOnewayThenNormal')
+ self.client.testOneway(1) # type is int, not float
+ self.assertEqual(self.client.testString('Python'), 'Python')
class NormalBinaryTest(AbstractTest):
- def get_protocol(self, transport):
- return TBinaryProtocol.TBinaryProtocolFactory().getProtocol(transport)
+ def get_protocol(self, transport):
+ return TBinaryProtocol.TBinaryProtocolFactory().getProtocol(transport)
class CompactTest(AbstractTest):
- def get_protocol(self, transport):
- return TCompactProtocol.TCompactProtocolFactory().getProtocol(transport)
+ def get_protocol(self, transport):
+ return TCompactProtocol.TCompactProtocolFactory().getProtocol(transport)
class JSONTest(AbstractTest):
- def get_protocol(self, transport):
- return TJSONProtocol.TJSONProtocolFactory().getProtocol(transport)
+ def get_protocol(self, transport):
+ return TJSONProtocol.TJSONProtocolFactory().getProtocol(transport)
class AcceleratedBinaryTest(AbstractTest):
- def get_protocol(self, transport):
- return TBinaryProtocol.TBinaryProtocolAcceleratedFactory().getProtocol(transport)
+ def get_protocol(self, transport):
+ return TBinaryProtocol.TBinaryProtocolAcceleratedFactory().getProtocol(transport)
def suite():
- suite = unittest.TestSuite()
- loader = unittest.TestLoader()
- if options.proto == 'binary': # look for --proto on cmdline
- suite.addTest(loader.loadTestsFromTestCase(NormalBinaryTest))
- elif options.proto == 'accel':
- suite.addTest(loader.loadTestsFromTestCase(AcceleratedBinaryTest))
- elif options.proto == 'compact':
- suite.addTest(loader.loadTestsFromTestCase(CompactTest))
- elif options.proto == 'json':
- suite.addTest(loader.loadTestsFromTestCase(JSONTest))
- else:
- raise AssertionError('Unknown protocol given with --protocol: %s' % options.proto)
- return suite
+ suite = unittest.TestSuite()
+ loader = unittest.TestLoader()
+ if options.proto == 'binary': # look for --proto on cmdline
+ suite.addTest(loader.loadTestsFromTestCase(NormalBinaryTest))
+ elif options.proto == 'accel':
+ suite.addTest(loader.loadTestsFromTestCase(AcceleratedBinaryTest))
+ elif options.proto == 'compact':
+ suite.addTest(loader.loadTestsFromTestCase(CompactTest))
+ elif options.proto == 'json':
+ suite.addTest(loader.loadTestsFromTestCase(JSONTest))
+ else:
+ raise AssertionError('Unknown protocol given with --protocol: %s' % options.proto)
+ return suite
class OwnArgsTestProgram(unittest.TestProgram):
@@ -296,50 +296,50 @@ class OwnArgsTestProgram(unittest.TestProgram):
self.createTests()
if __name__ == "__main__":
- parser = OptionParser()
- parser.add_option('--libpydir', type='string', dest='libpydir',
- help='include this directory in sys.path for locating library code')
- parser.add_option('--genpydir', type='string', dest='genpydir',
- help='include this directory in sys.path for locating generated code')
- parser.add_option("--port", type="int", dest="port",
- help="connect to server at port")
- parser.add_option("--host", type="string", dest="host",
- help="connect to server")
- parser.add_option("--zlib", action="store_true", dest="zlib",
- help="use zlib wrapper for compressed transport")
- parser.add_option("--ssl", action="store_true", dest="ssl",
- help="use SSL for encrypted transport")
- parser.add_option("--http", dest="http_path",
- help="Use the HTTP transport with the specified path")
- parser.add_option('-v', '--verbose', action="store_const",
- dest="verbose", const=2,
- help="verbose output")
- parser.add_option('-q', '--quiet', action="store_const",
- dest="verbose", const=0,
- help="minimal output")
- parser.add_option('--protocol', dest="proto", type="string",
- help="protocol to use, one of: accel, binary, compact, json")
- parser.add_option('--transport', dest="trans", type="string",
- help="transport to use, one of: buffered, framed")
- parser.set_defaults(framed=False, http_path=None, verbose=1, host='localhost', port=9090, proto='binary')
- options, args = parser.parse_args()
-
- if options.genpydir:
- sys.path.insert(0, os.path.join(SCRIPT_DIR, options.genpydir))
- if options.libpydir:
- sys.path.insert(0, glob.glob(options.libpydir)[0])
- else:
- sys.path.insert(0, glob.glob(DEFAULT_LIBDIR_GLOB)[0])
-
- from ThriftTest import ThriftTest
- from ThriftTest.ttypes import Xtruct, Xtruct2, Numberz, Xception, Xception2
- from thrift.Thrift import TException
- from thrift.transport import TTransport
- from thrift.transport import TSocket
- from thrift.transport import THttpClient
- from thrift.transport import TZlibTransport
- from thrift.protocol import TBinaryProtocol
- from thrift.protocol import TCompactProtocol
- from thrift.protocol import TJSONProtocol
-
- OwnArgsTestProgram(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=1))
+ parser = OptionParser()
+ parser.add_option('--libpydir', type='string', dest='libpydir',
+ help='include this directory in sys.path for locating library code')
+ parser.add_option('--genpydir', type='string', dest='genpydir',
+ help='include this directory in sys.path for locating generated code')
+ parser.add_option("--port", type="int", dest="port",
+ help="connect to server at port")
+ parser.add_option("--host", type="string", dest="host",
+ help="connect to server")
+ parser.add_option("--zlib", action="store_true", dest="zlib",
+ help="use zlib wrapper for compressed transport")
+ parser.add_option("--ssl", action="store_true", dest="ssl",
+ help="use SSL for encrypted transport")
+ parser.add_option("--http", dest="http_path",
+ help="Use the HTTP transport with the specified path")
+ parser.add_option('-v', '--verbose', action="store_const",
+ dest="verbose", const=2,
+ help="verbose output")
+ parser.add_option('-q', '--quiet', action="store_const",
+ dest="verbose", const=0,
+ help="minimal output")
+ parser.add_option('--protocol', dest="proto", type="string",
+ help="protocol to use, one of: accel, binary, compact, json")
+ parser.add_option('--transport', dest="trans", type="string",
+ help="transport to use, one of: buffered, framed")
+ parser.set_defaults(framed=False, http_path=None, verbose=1, host='localhost', port=9090, proto='binary')
+ options, args = parser.parse_args()
+
+ if options.genpydir:
+ sys.path.insert(0, os.path.join(SCRIPT_DIR, options.genpydir))
+ if options.libpydir:
+ sys.path.insert(0, glob.glob(options.libpydir)[0])
+ else:
+ sys.path.insert(0, glob.glob(DEFAULT_LIBDIR_GLOB)[0])
+
+ from ThriftTest import ThriftTest
+ from ThriftTest.ttypes import Xtruct, Xtruct2, Numberz, Xception, Xception2
+ from thrift.Thrift import TException
+ from thrift.transport import TTransport
+ from thrift.transport import TSocket
+ from thrift.transport import THttpClient
+ from thrift.transport import TZlibTransport
+ from thrift.protocol import TBinaryProtocol
+ from thrift.protocol import TCompactProtocol
+ from thrift.protocol import TJSONProtocol
+
+ OwnArgsTestProgram(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=1))
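
For reference, the non-HTTP transport stack that setUp builds above boils down to socket -> buffered (or framed) transport -> protocol -> client. A minimal sketch, assuming a cross-test server is already listening on localhost:9090 and the generated ThriftTest module is importable:

from thrift.transport import TSocket, TTransport
from thrift.protocol import TBinaryProtocol
from ThriftTest import ThriftTest  # generated code, assumed on sys.path

# Plain socket, wrapped in a buffered transport (use TFramedTransport for --transport=framed),
# then a protocol and the generated client on top.
socket = TSocket.TSocket('localhost', 9090)
transport = TTransport.TBufferedTransport(socket)
protocol = TBinaryProtocol.TBinaryProtocolFactory().getProtocol(transport)
client = ThriftTest.Client(protocol)

transport.open()
try:
    assert client.testString('Python') == 'Python'
finally:
    transport.close()
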
diff --git a/test/py/TestEof.py b/test/py/TestEof.py
index 661463822..0239fc621 100755
--- a/test/py/TestEof.py
+++ b/test/py/TestEof.py
@@ -28,99 +28,99 @@ import unittest
class TestEof(unittest.TestCase):
- def make_data(self, pfactory=None):
- trans = TTransport.TMemoryBuffer()
- if pfactory:
- prot = pfactory.getProtocol(trans)
- else:
- prot = TBinaryProtocol.TBinaryProtocol(trans)
-
- x = Xtruct()
- x.string_thing = "Zero"
- x.byte_thing = 0
-
- x.write(prot)
-
- x = Xtruct()
- x.string_thing = "One"
- x.byte_thing = 1
-
- x.write(prot)
-
- return trans.getvalue()
-
- def testTransportReadAll(self):
- """Test that readAll on any type of transport throws an EOFError"""
- trans = TTransport.TMemoryBuffer(self.make_data())
- trans.readAll(1)
-
- try:
- trans.readAll(10000)
- except EOFError:
- return
-
- self.fail("Should have gotten EOFError")
-
- def eofTestHelper(self, pfactory):
- trans = TTransport.TMemoryBuffer(self.make_data(pfactory))
- prot = pfactory.getProtocol(trans)
-
- x = Xtruct()
- x.read(prot)
- self.assertEqual(x.string_thing, "Zero")
- self.assertEqual(x.byte_thing, 0)
-
- x = Xtruct()
- x.read(prot)
- self.assertEqual(x.string_thing, "One")
- self.assertEqual(x.byte_thing, 1)
-
- try:
- x = Xtruct()
- x.read(prot)
- except EOFError:
- return
-
- self.fail("Should have gotten EOFError")
-
- def eofTestHelperStress(self, pfactory):
- """Teest the ability of TBinaryProtocol to deal with the removal of every byte in the file"""
- # TODO: we should make sure this covers more of the code paths
-
- data = self.make_data(pfactory)
- for i in range(0, len(data) + 1):
- trans = TTransport.TMemoryBuffer(data[0:i])
- prot = pfactory.getProtocol(trans)
- try:
+ def make_data(self, pfactory=None):
+ trans = TTransport.TMemoryBuffer()
+ if pfactory:
+ prot = pfactory.getProtocol(trans)
+ else:
+ prot = TBinaryProtocol.TBinaryProtocol(trans)
+
x = Xtruct()
- x.read(prot)
- x.read(prot)
- x.read(prot)
- except EOFError:
- continue
- self.fail("Should have gotten an EOFError")
+ x.string_thing = "Zero"
+ x.byte_thing = 0
+
+ x.write(prot)
+
+ x = Xtruct()
+ x.string_thing = "One"
+ x.byte_thing = 1
+
+ x.write(prot)
- def testBinaryProtocolEof(self):
- """Test that TBinaryProtocol throws an EOFError when it reaches the end of the stream"""
- self.eofTestHelper(TBinaryProtocol.TBinaryProtocolFactory())
- self.eofTestHelperStress(TBinaryProtocol.TBinaryProtocolFactory())
+ return trans.getvalue()
- def testBinaryProtocolAcceleratedEof(self):
- """Test that TBinaryProtocolAccelerated throws an EOFError when it reaches the end of the stream"""
- self.eofTestHelper(TBinaryProtocol.TBinaryProtocolAcceleratedFactory())
- self.eofTestHelperStress(TBinaryProtocol.TBinaryProtocolAcceleratedFactory())
+ def testTransportReadAll(self):
+ """Test that readAll on any type of transport throws an EOFError"""
+ trans = TTransport.TMemoryBuffer(self.make_data())
+ trans.readAll(1)
- def testCompactProtocolEof(self):
- """Test that TCompactProtocol throws an EOFError when it reaches the end of the stream"""
- self.eofTestHelper(TCompactProtocol.TCompactProtocolFactory())
- self.eofTestHelperStress(TCompactProtocol.TCompactProtocolFactory())
+ try:
+ trans.readAll(10000)
+ except EOFError:
+ return
+
+ self.fail("Should have gotten EOFError")
+
+ def eofTestHelper(self, pfactory):
+ trans = TTransport.TMemoryBuffer(self.make_data(pfactory))
+ prot = pfactory.getProtocol(trans)
+
+ x = Xtruct()
+ x.read(prot)
+ self.assertEqual(x.string_thing, "Zero")
+ self.assertEqual(x.byte_thing, 0)
+
+ x = Xtruct()
+ x.read(prot)
+ self.assertEqual(x.string_thing, "One")
+ self.assertEqual(x.byte_thing, 1)
+
+ try:
+ x = Xtruct()
+ x.read(prot)
+ except EOFError:
+ return
+
+ self.fail("Should have gotten EOFError")
+
+ def eofTestHelperStress(self, pfactory):
+ """Teest the ability of TBinaryProtocol to deal with the removal of every byte in the file"""
+ # TODO: we should make sure this covers more of the code paths
+
+ data = self.make_data(pfactory)
+ for i in range(0, len(data) + 1):
+ trans = TTransport.TMemoryBuffer(data[0:i])
+ prot = pfactory.getProtocol(trans)
+ try:
+ x = Xtruct()
+ x.read(prot)
+ x.read(prot)
+ x.read(prot)
+ except EOFError:
+ continue
+ self.fail("Should have gotten an EOFError")
+
+ def testBinaryProtocolEof(self):
+ """Test that TBinaryProtocol throws an EOFError when it reaches the end of the stream"""
+ self.eofTestHelper(TBinaryProtocol.TBinaryProtocolFactory())
+ self.eofTestHelperStress(TBinaryProtocol.TBinaryProtocolFactory())
+
+ def testBinaryProtocolAcceleratedEof(self):
+ """Test that TBinaryProtocolAccelerated throws an EOFError when it reaches the end of the stream"""
+ self.eofTestHelper(TBinaryProtocol.TBinaryProtocolAcceleratedFactory())
+ self.eofTestHelperStress(TBinaryProtocol.TBinaryProtocolAcceleratedFactory())
+
+ def testCompactProtocolEof(self):
+ """Test that TCompactProtocol throws an EOFError when it reaches the end of the stream"""
+ self.eofTestHelper(TCompactProtocol.TCompactProtocolFactory())
+ self.eofTestHelperStress(TCompactProtocol.TCompactProtocolFactory())
def suite():
- suite = unittest.TestSuite()
- loader = unittest.TestLoader()
- suite.addTest(loader.loadTestsFromTestCase(TestEof))
- return suite
+ suite = unittest.TestSuite()
+ loader = unittest.TestLoader()
+ suite.addTest(loader.loadTestsFromTestCase(TestEof))
+ return suite
if __name__ == "__main__":
- unittest.main(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2))
+ unittest.main(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2))
diff --git a/test/py/TestFrozen.py b/test/py/TestFrozen.py
index 76750ad88..30a6a557f 100755
--- a/test/py/TestFrozen.py
+++ b/test/py/TestFrozen.py
@@ -28,89 +28,89 @@ import unittest
class TestFrozenBase(unittest.TestCase):
- def _roundtrip(self, src, dst):
- otrans = TTransport.TMemoryBuffer()
- optoro = self.protocol(otrans)
- src.write(optoro)
- itrans = TTransport.TMemoryBuffer(otrans.getvalue())
- iproto = self.protocol(itrans)
- return dst.read(iproto) or dst
-
- def test_dict_is_hashable_only_after_frozen(self):
- d0 = {}
- self.assertFalse(isinstance(d0, collections.Hashable))
- d1 = TFrozenDict(d0)
- self.assertTrue(isinstance(d1, collections.Hashable))
-
- def test_struct_with_collection_fields(self):
- pass
-
- def test_set(self):
- """Test that annotated set field can be serialized and deserialized"""
- x = CompactProtoTestStruct(set_byte_map={
- frozenset([42, 100, -100]): 99,
- frozenset([0]): 100,
- frozenset([]): 0,
- })
- x2 = self._roundtrip(x, CompactProtoTestStruct())
- self.assertEqual(x2.set_byte_map[frozenset([42, 100, -100])], 99)
- self.assertEqual(x2.set_byte_map[frozenset([0])], 100)
- self.assertEqual(x2.set_byte_map[frozenset([])], 0)
-
- def test_map(self):
- """Test that annotated map field can be serialized and deserialized"""
- x = CompactProtoTestStruct(map_byte_map={
- TFrozenDict({42: 42, 100: -100}): 99,
- TFrozenDict({0: 0}): 100,
- TFrozenDict({}): 0,
- })
- x2 = self._roundtrip(x, CompactProtoTestStruct())
- self.assertEqual(x2.map_byte_map[TFrozenDict({42: 42, 100: -100})], 99)
- self.assertEqual(x2.map_byte_map[TFrozenDict({0: 0})], 100)
- self.assertEqual(x2.map_byte_map[TFrozenDict({})], 0)
-
- def test_list(self):
- """Test that annotated list field can be serialized and deserialized"""
- x = CompactProtoTestStruct(list_byte_map={
- (42, 100, -100): 99,
- (0,): 100,
- (): 0,
- })
- x2 = self._roundtrip(x, CompactProtoTestStruct())
- self.assertEqual(x2.list_byte_map[(42, 100, -100)], 99)
- self.assertEqual(x2.list_byte_map[(0,)], 100)
- self.assertEqual(x2.list_byte_map[()], 0)
-
- def test_empty_struct(self):
- """Test that annotated empty struct can be serialized and deserialized"""
- x = CompactProtoTestStruct(empty_struct_field=Empty())
- x2 = self._roundtrip(x, CompactProtoTestStruct())
- self.assertEqual(x2.empty_struct_field, Empty())
-
- def test_struct(self):
- """Test that annotated struct can be serialized and deserialized"""
- x = Wrapper(foo=Empty())
- self.assertEqual(x.foo, Empty())
- x2 = self._roundtrip(x, Wrapper)
- self.assertEqual(x2.foo, Empty())
+ def _roundtrip(self, src, dst):
+ otrans = TTransport.TMemoryBuffer()
+ optoro = self.protocol(otrans)
+ src.write(optoro)
+ itrans = TTransport.TMemoryBuffer(otrans.getvalue())
+ iproto = self.protocol(itrans)
+ return dst.read(iproto) or dst
+
+ def test_dict_is_hashable_only_after_frozen(self):
+ d0 = {}
+ self.assertFalse(isinstance(d0, collections.Hashable))
+ d1 = TFrozenDict(d0)
+ self.assertTrue(isinstance(d1, collections.Hashable))
+
+ def test_struct_with_collection_fields(self):
+ pass
+
+ def test_set(self):
+ """Test that annotated set field can be serialized and deserialized"""
+ x = CompactProtoTestStruct(set_byte_map={
+ frozenset([42, 100, -100]): 99,
+ frozenset([0]): 100,
+ frozenset([]): 0,
+ })
+ x2 = self._roundtrip(x, CompactProtoTestStruct())
+ self.assertEqual(x2.set_byte_map[frozenset([42, 100, -100])], 99)
+ self.assertEqual(x2.set_byte_map[frozenset([0])], 100)
+ self.assertEqual(x2.set_byte_map[frozenset([])], 0)
+
+ def test_map(self):
+ """Test that annotated map field can be serialized and deserialized"""
+ x = CompactProtoTestStruct(map_byte_map={
+ TFrozenDict({42: 42, 100: -100}): 99,
+ TFrozenDict({0: 0}): 100,
+ TFrozenDict({}): 0,
+ })
+ x2 = self._roundtrip(x, CompactProtoTestStruct())
+ self.assertEqual(x2.map_byte_map[TFrozenDict({42: 42, 100: -100})], 99)
+ self.assertEqual(x2.map_byte_map[TFrozenDict({0: 0})], 100)
+ self.assertEqual(x2.map_byte_map[TFrozenDict({})], 0)
+
+ def test_list(self):
+ """Test that annotated list field can be serialized and deserialized"""
+ x = CompactProtoTestStruct(list_byte_map={
+ (42, 100, -100): 99,
+ (0,): 100,
+ (): 0,
+ })
+ x2 = self._roundtrip(x, CompactProtoTestStruct())
+ self.assertEqual(x2.list_byte_map[(42, 100, -100)], 99)
+ self.assertEqual(x2.list_byte_map[(0,)], 100)
+ self.assertEqual(x2.list_byte_map[()], 0)
+
+ def test_empty_struct(self):
+ """Test that annotated empty struct can be serialized and deserialized"""
+ x = CompactProtoTestStruct(empty_struct_field=Empty())
+ x2 = self._roundtrip(x, CompactProtoTestStruct())
+ self.assertEqual(x2.empty_struct_field, Empty())
+
+ def test_struct(self):
+ """Test that annotated struct can be serialized and deserialized"""
+ x = Wrapper(foo=Empty())
+ self.assertEqual(x.foo, Empty())
+ x2 = self._roundtrip(x, Wrapper)
+ self.assertEqual(x2.foo, Empty())
class TestFrozen(TestFrozenBase):
- def protocol(self, trans):
- return TBinaryProtocol.TBinaryProtocolFactory().getProtocol(trans)
+ def protocol(self, trans):
+ return TBinaryProtocol.TBinaryProtocolFactory().getProtocol(trans)
class TestFrozenAccelerated(TestFrozenBase):
- def protocol(self, trans):
- return TBinaryProtocol.TBinaryProtocolAcceleratedFactory().getProtocol(trans)
+ def protocol(self, trans):
+ return TBinaryProtocol.TBinaryProtocolAcceleratedFactory().getProtocol(trans)
def suite():
- suite = unittest.TestSuite()
- loader = unittest.TestLoader()
- suite.addTest(loader.loadTestsFromTestCase(TestFrozen))
- suite.addTest(loader.loadTestsFromTestCase(TestFrozenAccelerated))
- return suite
+ suite = unittest.TestSuite()
+ loader = unittest.TestLoader()
+ suite.addTest(loader.loadTestsFromTestCase(TestFrozen))
+ suite.addTest(loader.loadTestsFromTestCase(TestFrozenAccelerated))
+ return suite
if __name__ == "__main__":
- unittest.main(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2))
+ unittest.main(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2))
diff --git a/test/py/TestServer.py b/test/py/TestServer.py
index f12a9fe76..ef93509b2 100755
--- a/test/py/TestServer.py
+++ b/test/py/TestServer.py
@@ -32,287 +32,287 @@ DEFAULT_LIBDIR_GLOB = os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib.*')
class TestHandler(object):
- def testVoid(self):
- if options.verbose > 1:
- logging.info('testVoid()')
-
- def testString(self, str):
- if options.verbose > 1:
- logging.info('testString(%s)' % str)
- return str
-
- def testBool(self, boolean):
- if options.verbose > 1:
- logging.info('testBool(%s)' % str(boolean).lower())
- return boolean
-
- def testByte(self, byte):
- if options.verbose > 1:
- logging.info('testByte(%d)' % byte)
- return byte
-
- def testI16(self, i16):
- if options.verbose > 1:
- logging.info('testI16(%d)' % i16)
- return i16
-
- def testI32(self, i32):
- if options.verbose > 1:
- logging.info('testI32(%d)' % i32)
- return i32
-
- def testI64(self, i64):
- if options.verbose > 1:
- logging.info('testI64(%d)' % i64)
- return i64
-
- def testDouble(self, dub):
- if options.verbose > 1:
- logging.info('testDouble(%f)' % dub)
- return dub
-
- def testBinary(self, thing):
- if options.verbose > 1:
- logging.info('testBinary()') # TODO: hex output
- return thing
-
- def testStruct(self, thing):
- if options.verbose > 1:
- logging.info('testStruct({%s, %s, %s, %s})' % (thing.string_thing, thing.byte_thing, thing.i32_thing, thing.i64_thing))
- return thing
-
- def testException(self, arg):
- # if options.verbose > 1:
- logging.info('testException(%s)' % arg)
- if arg == 'Xception':
- raise Xception(errorCode=1001, message=arg)
- elif arg == 'TException':
- raise TException(message='This is a TException')
-
- def testMultiException(self, arg0, arg1):
- if options.verbose > 1:
- logging.info('testMultiException(%s, %s)' % (arg0, arg1))
- if arg0 == 'Xception':
- raise Xception(errorCode=1001, message='This is an Xception')
- elif arg0 == 'Xception2':
- raise Xception2(
- errorCode=2002,
- struct_thing=Xtruct(string_thing='This is an Xception2'))
- return Xtruct(string_thing=arg1)
-
- def testOneway(self, seconds):
- if options.verbose > 1:
- logging.info('testOneway(%d) => sleeping...' % seconds)
- time.sleep(seconds / 3) # be quick
- if options.verbose > 1:
- logging.info('done sleeping')
-
- def testNest(self, thing):
- if options.verbose > 1:
- logging.info('testNest(%s)' % thing)
- return thing
-
- def testMap(self, thing):
- if options.verbose > 1:
- logging.info('testMap(%s)' % thing)
- return thing
-
- def testStringMap(self, thing):
- if options.verbose > 1:
- logging.info('testStringMap(%s)' % thing)
- return thing
-
- def testSet(self, thing):
- if options.verbose > 1:
- logging.info('testSet(%s)' % thing)
- return thing
-
- def testList(self, thing):
- if options.verbose > 1:
- logging.info('testList(%s)' % thing)
- return thing
-
- def testEnum(self, thing):
- if options.verbose > 1:
- logging.info('testEnum(%s)' % thing)
- return thing
-
- def testTypedef(self, thing):
- if options.verbose > 1:
- logging.info('testTypedef(%s)' % thing)
- return thing
-
- def testMapMap(self, thing):
- if options.verbose > 1:
- logging.info('testMapMap(%s)' % thing)
- return {
- -4: {
- -4: -4,
- -3: -3,
- -2: -2,
- -1: -1,
- },
- 4: {
- 4: 4,
- 3: 3,
- 2: 2,
- 1: 1,
- },
- }
-
- def testInsanity(self, argument):
- if options.verbose > 1:
- logging.info('testInsanity(%s)' % argument)
- return {
- 1: {
- 2: argument,
- 3: argument,
- },
- 2: {6: Insanity()},
- }
-
- def testMulti(self, arg0, arg1, arg2, arg3, arg4, arg5):
- if options.verbose > 1:
- logging.info('testMulti(%s)' % [arg0, arg1, arg2, arg3, arg4, arg5])
- return Xtruct(string_thing='Hello2',
- byte_thing=arg0, i32_thing=arg1, i64_thing=arg2)
+ def testVoid(self):
+ if options.verbose > 1:
+ logging.info('testVoid()')
+
+ def testString(self, str):
+ if options.verbose > 1:
+ logging.info('testString(%s)' % str)
+ return str
+
+ def testBool(self, boolean):
+ if options.verbose > 1:
+ logging.info('testBool(%s)' % str(boolean).lower())
+ return boolean
+
+ def testByte(self, byte):
+ if options.verbose > 1:
+ logging.info('testByte(%d)' % byte)
+ return byte
+
+ def testI16(self, i16):
+ if options.verbose > 1:
+ logging.info('testI16(%d)' % i16)
+ return i16
+
+ def testI32(self, i32):
+ if options.verbose > 1:
+ logging.info('testI32(%d)' % i32)
+ return i32
+
+ def testI64(self, i64):
+ if options.verbose > 1:
+ logging.info('testI64(%d)' % i64)
+ return i64
+
+ def testDouble(self, dub):
+ if options.verbose > 1:
+ logging.info('testDouble(%f)' % dub)
+ return dub
+
+ def testBinary(self, thing):
+ if options.verbose > 1:
+ logging.info('testBinary()') # TODO: hex output
+ return thing
+
+ def testStruct(self, thing):
+ if options.verbose > 1:
+ logging.info('testStruct({%s, %s, %s, %s})' % (thing.string_thing, thing.byte_thing, thing.i32_thing, thing.i64_thing))
+ return thing
+
+ def testException(self, arg):
+ # if options.verbose > 1:
+ logging.info('testException(%s)' % arg)
+ if arg == 'Xception':
+ raise Xception(errorCode=1001, message=arg)
+ elif arg == 'TException':
+ raise TException(message='This is a TException')
+
+ def testMultiException(self, arg0, arg1):
+ if options.verbose > 1:
+ logging.info('testMultiException(%s, %s)' % (arg0, arg1))
+ if arg0 == 'Xception':
+ raise Xception(errorCode=1001, message='This is an Xception')
+ elif arg0 == 'Xception2':
+ raise Xception2(
+ errorCode=2002,
+ struct_thing=Xtruct(string_thing='This is an Xception2'))
+ return Xtruct(string_thing=arg1)
+
+ def testOneway(self, seconds):
+ if options.verbose > 1:
+ logging.info('testOneway(%d) => sleeping...' % seconds)
+ time.sleep(seconds / 3) # be quick
+ if options.verbose > 1:
+ logging.info('done sleeping')
+
+ def testNest(self, thing):
+ if options.verbose > 1:
+ logging.info('testNest(%s)' % thing)
+ return thing
+
+ def testMap(self, thing):
+ if options.verbose > 1:
+ logging.info('testMap(%s)' % thing)
+ return thing
+
+ def testStringMap(self, thing):
+ if options.verbose > 1:
+ logging.info('testStringMap(%s)' % thing)
+ return thing
+
+ def testSet(self, thing):
+ if options.verbose > 1:
+ logging.info('testSet(%s)' % thing)
+ return thing
+
+ def testList(self, thing):
+ if options.verbose > 1:
+ logging.info('testList(%s)' % thing)
+ return thing
+
+ def testEnum(self, thing):
+ if options.verbose > 1:
+ logging.info('testEnum(%s)' % thing)
+ return thing
+
+ def testTypedef(self, thing):
+ if options.verbose > 1:
+ logging.info('testTypedef(%s)' % thing)
+ return thing
+
+ def testMapMap(self, thing):
+ if options.verbose > 1:
+ logging.info('testMapMap(%s)' % thing)
+ return {
+ -4: {
+ -4: -4,
+ -3: -3,
+ -2: -2,
+ -1: -1,
+ },
+ 4: {
+ 4: 4,
+ 3: 3,
+ 2: 2,
+ 1: 1,
+ },
+ }
+
+ def testInsanity(self, argument):
+ if options.verbose > 1:
+ logging.info('testInsanity(%s)' % argument)
+ return {
+ 1: {
+ 2: argument,
+ 3: argument,
+ },
+ 2: {6: Insanity()},
+ }
+
+ def testMulti(self, arg0, arg1, arg2, arg3, arg4, arg5):
+ if options.verbose > 1:
+ logging.info('testMulti(%s)' % [arg0, arg1, arg2, arg3, arg4, arg5])
+ return Xtruct(string_thing='Hello2',
+ byte_thing=arg0, i32_thing=arg1, i64_thing=arg2)
def main(options):
- # set up the protocol factory from the --protocol option
- prot_factories = {
- 'binary': TBinaryProtocol.TBinaryProtocolFactory,
- 'accel': TBinaryProtocol.TBinaryProtocolAcceleratedFactory,
- 'compact': TCompactProtocol.TCompactProtocolFactory,
- 'json': TJSONProtocol.TJSONProtocolFactory,
- }
- pfactory_cls = prot_factories.get(options.proto, None)
- if pfactory_cls is None:
- raise AssertionError('Unknown --protocol option: %s' % options.proto)
- pfactory = pfactory_cls()
- try:
- pfactory.string_length_limit = options.string_limit
- pfactory.container_length_limit = options.container_limit
- except:
- # Ignore errors for those protocols that do not support length limits
- pass
-
- # get the server type (TSimpleServer, TNonblockingServer, etc...)
- if len(args) > 1:
- raise AssertionError('Only one server type may be specified, not multiple types.')
- server_type = args[0]
-
- # Set up the handler and processor objects
- handler = TestHandler()
- processor = ThriftTest.Processor(handler)
-
- # Handle THttpServer as a special case
- if server_type == 'THttpServer':
- server = THttpServer.THttpServer(processor, ('', options.port), pfactory)
- server.serve()
- sys.exit(0)
-
- # set up server transport and transport factory
-
- abs_key_path = os.path.join(os.path.dirname(SCRIPT_DIR), 'keys', 'server.pem')
-
- host = None
- if options.ssl:
- from thrift.transport import TSSLSocket
- transport = TSSLSocket.TSSLServerSocket(host, options.port, certfile=abs_key_path)
- else:
- transport = TSocket.TServerSocket(host, options.port)
- tfactory = TTransport.TBufferedTransportFactory()
- if options.trans == 'buffered':
- tfactory = TTransport.TBufferedTransportFactory()
- elif options.trans == 'framed':
- tfactory = TTransport.TFramedTransportFactory()
- elif options.trans == '':
- raise AssertionError('Unknown --transport option: %s' % options.trans)
- else:
+ # set up the protocol factory from the --protocol option
+ prot_factories = {
+ 'binary': TBinaryProtocol.TBinaryProtocolFactory,
+ 'accel': TBinaryProtocol.TBinaryProtocolAcceleratedFactory,
+ 'compact': TCompactProtocol.TCompactProtocolFactory,
+ 'json': TJSONProtocol.TJSONProtocolFactory,
+ }
+ pfactory_cls = prot_factories.get(options.proto, None)
+ if pfactory_cls is None:
+ raise AssertionError('Unknown --protocol option: %s' % options.proto)
+ pfactory = pfactory_cls()
+ try:
+ pfactory.string_length_limit = options.string_limit
+ pfactory.container_length_limit = options.container_limit
+ except:
+ # Ignore errors for those protocols that do not support length limits
+ pass
+
+ # get the server type (TSimpleServer, TNonblockingServer, etc...)
+ if len(args) > 1:
+ raise AssertionError('Only one server type may be specified, not multiple types.')
+ server_type = args[0]
+
+ # Set up the handler and processor objects
+ handler = TestHandler()
+ processor = ThriftTest.Processor(handler)
+
+ # Handle THttpServer as a special case
+ if server_type == 'THttpServer':
+ server = THttpServer.THttpServer(processor, ('', options.port), pfactory)
+ server.serve()
+ sys.exit(0)
+
+ # set up server transport and transport factory
+
+ abs_key_path = os.path.join(os.path.dirname(SCRIPT_DIR), 'keys', 'server.pem')
+
+ host = None
+ if options.ssl:
+ from thrift.transport import TSSLSocket
+ transport = TSSLSocket.TSSLServerSocket(host, options.port, certfile=abs_key_path)
+ else:
+ transport = TSocket.TServerSocket(host, options.port)
tfactory = TTransport.TBufferedTransportFactory()
- # if --zlib, then wrap server transport, and use a different transport factory
- if options.zlib:
- transport = TZlibTransport.TZlibTransport(transport) # wrap with zlib
- tfactory = TZlibTransport.TZlibTransportFactory()
-
- # do server-specific setup here:
- if server_type == "TNonblockingServer":
- server = TNonblockingServer.TNonblockingServer(processor, transport, inputProtocolFactory=pfactory)
- elif server_type == "TProcessPoolServer":
- import signal
- from thrift.server import TProcessPoolServer
- server = TProcessPoolServer.TProcessPoolServer(processor, transport, tfactory, pfactory)
- server.setNumWorkers(5)
-
- def set_alarm():
- def clean_shutdown(signum, frame):
- for worker in server.workers:
- if options.verbose > 0:
- logging.info('Terminating worker: %s' % worker)
- worker.terminate()
- if options.verbose > 0:
- logging.info('Requesting server to stop()')
- try:
- server.stop()
- except:
- pass
- signal.signal(signal.SIGALRM, clean_shutdown)
- signal.alarm(4)
- set_alarm()
- else:
- # look up server class dynamically to instantiate server
- ServerClass = getattr(TServer, server_type)
- server = ServerClass(processor, transport, tfactory, pfactory)
- # enter server main loop
- server.serve()
+ if options.trans == 'buffered':
+ tfactory = TTransport.TBufferedTransportFactory()
+ elif options.trans == 'framed':
+ tfactory = TTransport.TFramedTransportFactory()
+ elif options.trans == '':
+ raise AssertionError('Unknown --transport option: %s' % options.trans)
+ else:
+ tfactory = TTransport.TBufferedTransportFactory()
+ # if --zlib, then wrap server transport, and use a different transport factory
+ if options.zlib:
+ transport = TZlibTransport.TZlibTransport(transport) # wrap with zlib
+ tfactory = TZlibTransport.TZlibTransportFactory()
+
+ # do server-specific setup here:
+ if server_type == "TNonblockingServer":
+ server = TNonblockingServer.TNonblockingServer(processor, transport, inputProtocolFactory=pfactory)
+ elif server_type == "TProcessPoolServer":
+ import signal
+ from thrift.server import TProcessPoolServer
+ server = TProcessPoolServer.TProcessPoolServer(processor, transport, tfactory, pfactory)
+ server.setNumWorkers(5)
+
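+ # Schedule a SIGALRM so clean_shutdown() terminates the pool workers and stops the server after ~4 seconds.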
+ def set_alarm():
+ def clean_shutdown(signum, frame):
+ for worker in server.workers:
+ if options.verbose > 0:
+ logging.info('Terminating worker: %s' % worker)
+ worker.terminate()
+ if options.verbose > 0:
+ logging.info('Requesting server to stop()')
+ try:
+ server.stop()
+ except:
+ pass
+ signal.signal(signal.SIGALRM, clean_shutdown)
+ signal.alarm(4)
+ set_alarm()
+ else:
+ # look up server class dynamically to instantiate server
+ ServerClass = getattr(TServer, server_type)
+ server = ServerClass(processor, transport, tfactory, pfactory)
+ # enter server main loop
+ server.serve()
if __name__ == '__main__':
- parser = OptionParser()
- parser.add_option('--libpydir', type='string', dest='libpydir',
- help='add this directory to sys.path for locating library code')
- parser.add_option('--genpydir', type='string', dest='genpydir',
- default='gen-py',
- help='add this directory to sys.path for locating generated code')
- parser.add_option("--port", type="int", dest="port",
- help="port number for server to listen on")
- parser.add_option("--zlib", action="store_true", dest="zlib",
- help="use zlib wrapper for compressed transport")
- parser.add_option("--ssl", action="store_true", dest="ssl",
- help="use SSL for encrypted transport")
- parser.add_option('-v', '--verbose', action="store_const",
- dest="verbose", const=2,
- help="verbose output")
- parser.add_option('-q', '--quiet', action="store_const",
- dest="verbose", const=0,
- help="minimal output")
- parser.add_option('--protocol', dest="proto", type="string",
- help="protocol to use, one of: accel, binary, compact, json")
- parser.add_option('--transport', dest="trans", type="string",
- help="transport to use, one of: buffered, framed")
- parser.add_option('--container-limit', dest='container_limit', type='int', default=None)
- parser.add_option('--string-limit', dest='string_limit', type='int', default=None)
- parser.set_defaults(port=9090, verbose=1, proto='binary')
- options, args = parser.parse_args()
-
- # Print TServer log to stdout so that the test-runner can redirect it to log files
- logging.basicConfig(level=options.verbose)
-
- sys.path.insert(0, os.path.join(SCRIPT_DIR, options.genpydir))
- if options.libpydir:
- sys.path.insert(0, glob.glob(options.libpydir)[0])
- else:
- sys.path.insert(0, glob.glob(DEFAULT_LIBDIR_GLOB)[0])
-
- from ThriftTest import ThriftTest
- from ThriftTest.ttypes import Xtruct, Xception, Xception2, Insanity
- from thrift.Thrift import TException
- from thrift.transport import TTransport
- from thrift.transport import TSocket
- from thrift.transport import TZlibTransport
- from thrift.protocol import TBinaryProtocol
- from thrift.protocol import TCompactProtocol
- from thrift.protocol import TJSONProtocol
- from thrift.server import TServer, TNonblockingServer, THttpServer
-
- sys.exit(main(options))
+ parser = OptionParser()
+ parser.add_option('--libpydir', type='string', dest='libpydir',
+ help='add this directory to sys.path for locating library code')
+ parser.add_option('--genpydir', type='string', dest='genpydir',
+ default='gen-py',
+ help='add this directory to sys.path for locating generated code')
+ parser.add_option("--port", type="int", dest="port",
+ help="port number for server to listen on")
+ parser.add_option("--zlib", action="store_true", dest="zlib",
+ help="use zlib wrapper for compressed transport")
+ parser.add_option("--ssl", action="store_true", dest="ssl",
+ help="use SSL for encrypted transport")
+ parser.add_option('-v', '--verbose', action="store_const",
+ dest="verbose", const=2,
+ help="verbose output")
+ parser.add_option('-q', '--quiet', action="store_const",
+ dest="verbose", const=0,
+ help="minimal output")
+ parser.add_option('--protocol', dest="proto", type="string",
+ help="protocol to use, one of: accel, binary, compact, json")
+ parser.add_option('--transport', dest="trans", type="string",
+ help="transport to use, one of: buffered, framed")
+ parser.add_option('--container-limit', dest='container_limit', type='int', default=None)
+ parser.add_option('--string-limit', dest='string_limit', type='int', default=None)
+ parser.set_defaults(port=9090, verbose=1, proto='binary')
+ options, args = parser.parse_args()
+
+ # Print TServer log to stdout so that the test-runner can redirect it to log files
+ logging.basicConfig(level=options.verbose)
+
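+ # Put the generated code directory and the thrift library directory on sys.path before the imports below.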
+ sys.path.insert(0, os.path.join(SCRIPT_DIR, options.genpydir))
+ if options.libpydir:
+ sys.path.insert(0, glob.glob(options.libpydir)[0])
+ else:
+ sys.path.insert(0, glob.glob(DEFAULT_LIBDIR_GLOB)[0])
+
+ from ThriftTest import ThriftTest
+ from ThriftTest.ttypes import Xtruct, Xception, Xception2, Insanity
+ from thrift.Thrift import TException
+ from thrift.transport import TTransport
+ from thrift.transport import TSocket
+ from thrift.transport import TZlibTransport
+ from thrift.protocol import TBinaryProtocol
+ from thrift.protocol import TCompactProtocol
+ from thrift.protocol import TJSONProtocol
+ from thrift.server import TServer, TNonblockingServer, THttpServer
+
+ sys.exit(main(options))
diff --git a/test/py/TestSocket.py b/test/py/TestSocket.py
index a01be85ac..9b578cca4 100755
--- a/test/py/TestSocket.py
+++ b/test/py/TestSocket.py
@@ -68,10 +68,10 @@ class TimeoutTest(unittest.TestCase):
self.assert_(time.time() - starttime < 5.0)
if __name__ == '__main__':
- suite = unittest.TestSuite()
- loader = unittest.TestLoader()
+ suite = unittest.TestSuite()
+ loader = unittest.TestLoader()
- suite.addTest(loader.loadTestsFromTestCase(TimeoutTest))
+ suite.addTest(loader.loadTestsFromTestCase(TimeoutTest))
- testRunner = unittest.TextTestRunner(verbosity=2)
- testRunner.run(suite)
+ testRunner = unittest.TextTestRunner(verbosity=2)
+ testRunner.run(suite)
diff --git a/test/test.py b/test/test.py
index a5bcd9bb2..42babebb3 100755
--- a/test/test.py
+++ b/test/test.py
@@ -46,124 +46,124 @@ CONFIG_FILE = 'tests.json'
def run_cross_tests(server_match, client_match, jobs, skip_known_failures, retry_count):
- logger = multiprocessing.get_logger()
- logger.debug('Collecting tests')
- with open(path_join(TEST_DIR, CONFIG_FILE), 'r') as fp:
- j = json.load(fp)
- tests = crossrunner.collect_cross_tests(j, server_match, client_match)
- if not tests:
- print('No test found that matches the criteria', file=sys.stderr)
- print(' servers: %s' % server_match, file=sys.stderr)
- print(' clients: %s' % client_match, file=sys.stderr)
- return False
- if skip_known_failures:
- logger.debug('Skipping known failures')
- known = crossrunner.load_known_failures(TEST_DIR)
- tests = list(filter(lambda t: crossrunner.test_name(**t) not in known, tests))
-
- dispatcher = crossrunner.TestDispatcher(TEST_DIR, ROOT_DIR, TEST_DIR_RELATIVE, jobs)
- logger.debug('Executing %d tests' % len(tests))
- try:
- for r in [dispatcher.dispatch(test, retry_count) for test in tests]:
- r.wait()
- logger.debug('Waiting for completion')
- return dispatcher.wait()
- except (KeyboardInterrupt, SystemExit):
- logger.debug('Interrupted, shutting down')
- dispatcher.terminate()
- return False
+ logger = multiprocessing.get_logger()
+ logger.debug('Collecting tests')
+ with open(path_join(TEST_DIR, CONFIG_FILE), 'r') as fp:
+ j = json.load(fp)
+ tests = crossrunner.collect_cross_tests(j, server_match, client_match)
+ if not tests:
+ print('No test found that matches the criteria', file=sys.stderr)
+ print(' servers: %s' % server_match, file=sys.stderr)
+ print(' clients: %s' % client_match, file=sys.stderr)
+ return False
+ if skip_known_failures:
+ logger.debug('Skipping known failures')
+ known = crossrunner.load_known_failures(TEST_DIR)
+ tests = list(filter(lambda t: crossrunner.test_name(**t) not in known, tests))
+
+ dispatcher = crossrunner.TestDispatcher(TEST_DIR, ROOT_DIR, TEST_DIR_RELATIVE, jobs)
+ logger.debug('Executing %d tests' % len(tests))
+ try:
+ for r in [dispatcher.dispatch(test, retry_count) for test in tests]:
+ r.wait()
+ logger.debug('Waiting for completion')
+ return dispatcher.wait()
+ except (KeyboardInterrupt, SystemExit):
+ logger.debug('Interrupted, shutting down')
+ dispatcher.terminate()
+ return False
def run_feature_tests(server_match, feature_match, jobs, skip_known_failures, retry_count):
- basedir = path_join(ROOT_DIR, FEATURE_DIR_RELATIVE)
- logger = multiprocessing.get_logger()
- logger.debug('Collecting tests')
- with open(path_join(TEST_DIR, CONFIG_FILE), 'r') as fp:
- j = json.load(fp)
- with open(path_join(basedir, CONFIG_FILE), 'r') as fp:
- j2 = json.load(fp)
- tests = crossrunner.collect_feature_tests(j, j2, server_match, feature_match)
- if not tests:
- print('No test found that matches the criteria', file=sys.stderr)
- print(' servers: %s' % server_match, file=sys.stderr)
- print(' features: %s' % feature_match, file=sys.stderr)
- return False
- if skip_known_failures:
- logger.debug('Skipping known failures')
- known = crossrunner.load_known_failures(basedir)
- tests = list(filter(lambda t: crossrunner.test_name(**t) not in known, tests))
-
- dispatcher = crossrunner.TestDispatcher(TEST_DIR, ROOT_DIR, FEATURE_DIR_RELATIVE, jobs)
- logger.debug('Executing %d tests' % len(tests))
- try:
- for r in [dispatcher.dispatch(test, retry_count) for test in tests]:
- r.wait()
- logger.debug('Waiting for completion')
- return dispatcher.wait()
- except (KeyboardInterrupt, SystemExit):
- logger.debug('Interrupted, shutting down')
- dispatcher.terminate()
- return False
+ basedir = path_join(ROOT_DIR, FEATURE_DIR_RELATIVE)
+ logger = multiprocessing.get_logger()
+ logger.debug('Collecting tests')
+ with open(path_join(TEST_DIR, CONFIG_FILE), 'r') as fp:
+ j = json.load(fp)
+ with open(path_join(basedir, CONFIG_FILE), 'r') as fp:
+ j2 = json.load(fp)
+ tests = crossrunner.collect_feature_tests(j, j2, server_match, feature_match)
+ if not tests:
+ print('No test found that matches the criteria', file=sys.stderr)
+ print(' servers: %s' % server_match, file=sys.stderr)
+ print(' features: %s' % feature_match, file=sys.stderr)
+ return False
+ if skip_known_failures:
+ logger.debug('Skipping known failures')
+ known = crossrunner.load_known_failures(basedir)
+ tests = list(filter(lambda t: crossrunner.test_name(**t) not in known, tests))
+
+ dispatcher = crossrunner.TestDispatcher(TEST_DIR, ROOT_DIR, FEATURE_DIR_RELATIVE, jobs)
+ logger.debug('Executing %d tests' % len(tests))
+ try:
+ for r in [dispatcher.dispatch(test, retry_count) for test in tests]:
+ r.wait()
+ logger.debug('Waiting for completion')
+ return dispatcher.wait()
+ except (KeyboardInterrupt, SystemExit):
+ logger.debug('Interrupted, shutting down')
+ dispatcher.terminate()
+ return False
def default_concurrenty():
- try:
- return int(os.environ.get('THRIFT_CROSSTEST_CONCURRENCY'))
- except (TypeError, ValueError):
- # Since much time is spent sleeping, use many threads
- return int(multiprocessing.cpu_count() * 1.25) + 1
+ try:
+ return int(os.environ.get('THRIFT_CROSSTEST_CONCURRENCY'))
+ except (TypeError, ValueError):
+ # Since much time is spent sleeping, use many threads
+ return int(multiprocessing.cpu_count() * 1.25) + 1
def main(argv):
- parser = argparse.ArgumentParser()
- parser.add_argument('--server', default='', nargs='*',
- help='list of servers to test')
- parser.add_argument('--client', default='', nargs='*',
- help='list of clients to test')
- parser.add_argument('-F', '--features', nargs='*', default=None,
- help='run server feature tests instead of cross language tests')
- parser.add_argument('-s', '--skip-known-failures', action='store_true', dest='skip_known_failures',
- help='do not execute tests that are known to fail')
- parser.add_argument('-r', '--retry-count', type=int,
- default=0, help='maximum retry on failure')
- parser.add_argument('-j', '--jobs', type=int,
- default=default_concurrenty(),
- help='number of concurrent test executions')
-
- g = parser.add_argument_group(title='Advanced')
- g.add_argument('-v', '--verbose', action='store_const',
- dest='log_level', const=logging.DEBUG, default=logging.WARNING,
- help='show debug output for test runner')
- g.add_argument('-P', '--print-expected-failures', choices=['merge', 'overwrite'],
- dest='print_failures',
- help="generate expected failures based on last result and print to stdout")
- g.add_argument('-U', '--update-expected-failures', choices=['merge', 'overwrite'],
- dest='update_failures',
- help="generate expected failures based on last result and save to default file location")
- options = parser.parse_args(argv)
-
- logger = multiprocessing.log_to_stderr()
- logger.setLevel(options.log_level)
-
- if options.features is not None and options.client:
- print('Cannot specify both --features and --client', file=sys.stderr)
- return 1
-
- # Allow multiple args separated with ',' for backward compatibility
- server_match = list(chain(*[x.split(',') for x in options.server]))
- client_match = list(chain(*[x.split(',') for x in options.client]))
-
- if options.update_failures or options.print_failures:
- dire = path_join(ROOT_DIR, FEATURE_DIR_RELATIVE) if options.features is not None else TEST_DIR
- res = crossrunner.generate_known_failures(
- dire, options.update_failures == 'overwrite',
- options.update_failures, options.print_failures)
- elif options.features is not None:
- features = options.features or ['.*']
- res = run_feature_tests(server_match, features, options.jobs, options.skip_known_failures, options.retry_count)
- else:
- res = run_cross_tests(server_match, client_match, options.jobs, options.skip_known_failures, options.retry_count)
- return 0 if res else 1
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--server', default='', nargs='*',
+ help='list of servers to test')
+ parser.add_argument('--client', default='', nargs='*',
+ help='list of clients to test')
+ parser.add_argument('-F', '--features', nargs='*', default=None,
+ help='run server feature tests instead of cross language tests')
+ parser.add_argument('-s', '--skip-known-failures', action='store_true', dest='skip_known_failures',
+ help='do not execute tests that are known to fail')
+ parser.add_argument('-r', '--retry-count', type=int,
+ default=0, help='maximum retry on failure')
+ parser.add_argument('-j', '--jobs', type=int,
+ default=default_concurrenty(),
+ help='number of concurrent test executions')
+
+ g = parser.add_argument_group(title='Advanced')
+ g.add_argument('-v', '--verbose', action='store_const',
+ dest='log_level', const=logging.DEBUG, default=logging.WARNING,
+ help='show debug output for test runner')
+ g.add_argument('-P', '--print-expected-failures', choices=['merge', 'overwrite'],
+ dest='print_failures',
+ help="generate expected failures based on last result and print to stdout")
+ g.add_argument('-U', '--update-expected-failures', choices=['merge', 'overwrite'],
+ dest='update_failures',
+ help="generate expected failures based on last result and save to default file location")
+ options = parser.parse_args(argv)
+
+ logger = multiprocessing.log_to_stderr()
+ logger.setLevel(options.log_level)
+
+ if options.features is not None and options.client:
+ print('Cannot specify both --features and --client', file=sys.stderr)
+ return 1
+
+ # Allow multiple args separated with ',' for backward compatibility
+ server_match = list(chain(*[x.split(',') for x in options.server]))
+ client_match = list(chain(*[x.split(',') for x in options.client]))
+
+ if options.update_failures or options.print_failures:
+ dire = path_join(ROOT_DIR, FEATURE_DIR_RELATIVE) if options.features is not None else TEST_DIR
+ res = crossrunner.generate_known_failures(
+ dire, options.update_failures == 'overwrite',
+ options.update_failures, options.print_failures)
+ elif options.features is not None:
+ features = options.features or ['.*']
+ res = run_feature_tests(server_match, features, options.jobs, options.skip_known_failures, options.retry_count)
+ else:
+ res = run_cross_tests(server_match, client_match, options.jobs, options.skip_known_failures, options.retry_count)
+ return 0 if res else 1
if __name__ == '__main__':
- sys.exit(main(sys.argv[1:]))
+ sys.exit(main(sys.argv[1:]))
diff --git a/tutorial/php/runserver.py b/tutorial/php/runserver.py
index ae29fed9c..077daa102 100755
--- a/tutorial/php/runserver.py
+++ b/tutorial/php/runserver.py
@@ -26,7 +26,8 @@ import CGIHTTPServer
# chdir(2) into the tutorial directory.
os.chdir(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
+
class Handler(CGIHTTPServer.CGIHTTPRequestHandler):
- cgi_directories = ['/php']
+ cgi_directories = ['/php']
BaseHTTPServer.HTTPServer(('', 8080), Handler).serve_forever()