summaryrefslogtreecommitdiff
path: root/lib/py/src
diff options
context:
space:
mode:
Diffstat (limited to 'lib/py/src')
-rw-r--r--lib/py/src/server/THttpServer.py30
1 files changed, 23 insertions, 7 deletions
diff --git a/lib/py/src/server/THttpServer.py b/lib/py/src/server/THttpServer.py
index 9bf80cbb8..3047d9c00 100644
--- a/lib/py/src/server/THttpServer.py
+++ b/lib/py/src/server/THttpServer.py
@@ -22,6 +22,19 @@ import BaseHTTPServer
from thrift.server import TServer
from thrift.transport import TTransport
+class ResponseException(Exception):
+ """Allows handlers to override the HTTP response
+
+ Normally, THttpServer always sends a 200 response. If a handler wants
+ to override this behavior (e.g., to simulate a misconfigured or
+ overloaded web server during testing), it can raise a ResponseException.
+ The function passed to the constructor will be called with the
+ RequestHandler as its only argument.
+ """
+ def __init__(self, handler):
+ self.handler = handler
+
+
class THttpServer(TServer.TServer):
"""A simple HTTP-based Thrift server
@@ -47,18 +60,21 @@ class THttpServer(TServer.TServer):
class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler):
def do_POST(self):
# Don't care about the request path.
- self.send_response(200)
- self.send_header("content-type", "application/x-thrift")
- self.end_headers()
-
itrans = TTransport.TFileObjectTransport(self.rfile)
otrans = TTransport.TFileObjectTransport(self.wfile)
itrans = TTransport.TBufferedTransport(itrans, int(self.headers['Content-Length']))
- otrans = TTransport.TBufferedTransport(otrans)
+ otrans = TTransport.TMemoryBuffer()
iprot = thttpserver.inputProtocolFactory.getProtocol(itrans)
oprot = thttpserver.outputProtocolFactory.getProtocol(otrans)
- thttpserver.processor.process(iprot, oprot)
- otrans.flush()
+ try:
+ thttpserver.processor.process(iprot, oprot)
+ except ResponseException, exn:
+ exn.handler(self)
+ else:
+ self.send_response(200)
+ self.send_header("content-type", "application/x-thrift")
+ self.end_headers()
+ self.wfile.write(otrans.getvalue())
self.httpd = server_class(server_address, RequestHander)