diff options
author | Allan Saddi <allan@saddi.com> | 2005-04-15 01:33:01 +0000 |
---|---|---|
committer | Allan Saddi <allan@saddi.com> | 2005-04-15 01:33:01 +0000 |
commit | e8f091226f39a888c019b1637fb8a47927b8a4ab (patch) | |
tree | 8ee47923b0f8ba625f705fecc29bed128d0e2219 | |
parent | edfd863888c497d6da96a0f8b462a0ebb35dfe34 (diff) | |
download | flup-e8f091226f39a888c019b1637fb8a47927b8a4ab.tar.gz |
flup package, first cut.
-rw-r--r-- | flup/__init__.py | 1 | ||||
-rw-r--r-- | flup/middleware/error.py | 352 | ||||
-rw-r--r-- | flup/middleware/gzip.py | 247 | ||||
-rw-r--r-- | flup/middleware/session.py | 742 | ||||
-rw-r--r-- | flup/publisher/__init__.py | 1 | ||||
-rw-r--r-- | flup/resolver/__init__.py | 1 | ||||
-rw-r--r-- | flup/server/__init__.py | 1 | ||||
-rw-r--r-- | flup/server/ajp.py | 1195 | ||||
-rw-r--r-- | flup/server/ajp_fork.py | 1025 | ||||
-rw-r--r-- | flup/server/fcgi.py | 1306 | ||||
-rw-r--r-- | flup/server/fcgi_fork.py | 1169 | ||||
-rw-r--r-- | flup/server/prefork.py | 364 | ||||
-rw-r--r-- | flup/server/scgi.py | 699 | ||||
-rw-r--r-- | flup/server/scgi_fork.py | 528 | ||||
-rw-r--r-- | flup/server/threadpool.py | 113 |
15 files changed, 7744 insertions, 0 deletions
diff --git a/flup/__init__.py b/flup/__init__.py new file mode 100644 index 0000000..792d600 --- /dev/null +++ b/flup/__init__.py @@ -0,0 +1 @@ +# diff --git a/flup/middleware/error.py b/flup/middleware/error.py new file mode 100644 index 0000000..637e842 --- /dev/null +++ b/flup/middleware/error.py @@ -0,0 +1,352 @@ +# Copyright (c) 2005 Allan Saddi <allan@saddi.com> +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS +# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF +# SUCH DAMAGE. 
+# +# $Id$ + +__author__ = 'Allan Saddi <allan@saddi.com>' +__version__ = '$Revision$' + +import sys +import os +import traceback +import time +from email.Message import Message +from email.MIMEMultipart import MIMEMultipart +from email.MIMEText import MIMEText +import smtplib + +try: + import thread +except ImportError: + import dummy_thread as thread + +__all__ = ['ErrorMiddleware'] + +def _wrapIterator(appIter, errorMiddleware, environ, start_response): + """ + Wrapper around the application's iterator which catches any unhandled + exceptions. Forwards close() and __len__ to the application iterator, + if necessary. + """ + class metaIterWrapper(type): + def __init__(cls, name, bases, clsdict): + super(metaIterWrapper, cls).__init__(name, bases, clsdict) + if hasattr(appIter, '__len__'): + cls.__len__ = appIter.__len__ + + class iterWrapper(object): + __metaclass__ = metaIterWrapper + def __init__(self): + self._next = iter(appIter).next + if hasattr(appIter, 'close'): + self.close = appIter.close + + def __iter__(self): + return self + + def next(self): + try: + return self._next() + except StopIteration: + raise + except: + errorMiddleware.exceptionHandler(environ) + + # I'm not sure I like this next part. + try: + errorIter = errorMiddleware.displayErrorPage(environ, + start_response) + except: + # Headers already sent, what can be done? + raise + else: + # The exception occurred early enough for start_response() + # to succeed. Swap iterators! + self._next = iter(errorIter).next + return self._next() + + return iterWrapper() + +class ErrorMiddleware(object): + """ + Middleware that catches any unhandled exceptions from the application. + Displays a (static) error page to the user while emailing details about + the exception to an administrator. 
+ """ + def __init__(self, application, adminAddress, + fromAddress='wsgiapp', + smtpHost='localhost', + + applicationName=None, + + errorPageMimeType='text/html', + errorPage=None, + errorPageFile='error.html', + + emailInterval=15, + intervalCheckFile='errorEmailCheck', + + debug=False): + """ + Explanation of parameters: + + application - WSGI application. + + adminAddress - Email address of administrator. + fromAddress - Email address that the error email should appear to + originate from. By default 'wsgiapp@hostname.of.server'. + smtpHost - SMTP email server, through which to send the email. + + applicationName - Name of your WSGI application, to help differentiate + it from other applications in email. By default, this is the Python + name of the application object. (You should explicitly set this + if you use other middleware components, otherwise the name + deduced will probably be that of a middleware component.) + + errorPageMimeType - MIME type of the static error page. 'text/html' + by default. + errorPage - String representing the body of the static error page. + If None (the default), errorPageFile must point to an existing file. + errorPageFile - File from which to take the static error page (may + be relative to current directory or an absolute filename). + + emailInterval - Minimum number of minutes between error mailings, + to prevent the administrator's mailbox from filling up. + intervalCheckFile - When running in one-shot mode (as determined by + the 'wsgi.run_once' environment variable), this file is used to + keep track of the last time an email was sent. May be relative + (to the current directory) or an absolute filename. + + debug - If True, will attempt to display the traceback as a webpage. + No email is sent. If False (the default), the static error page is + displayed and the error email is sent, if necessary. 
+ """ + self._application = application + + self._adminAddress = adminAddress + self._fromAddress = fromAddress + self._smtpHost = smtpHost + + # Set up a generic application name if not specified. + if applicationName is None: + applicationName = [] + if application.__module__ != '__main__': + applicationName.append('%s.' % application.__module__) + applicationName.append(application.__name__) + applicationName = ''.join(applicationName) + self._applicationName = applicationName + + self._errorPageMimeType = errorPageMimeType + # If errorPage was unspecified, set it from the static file + # specified by errorPageFile. + if errorPage is None: + f = open(errorPageFile) + errorPage = f.read() + f.close + self._errorPage = errorPage + + self._emailInterval = emailInterval * 60 + self._lastEmailTime = 0 + self._intervalCheckFile = intervalCheckFile + + # Set up displayErrorPage appropriately. + self._debug = debug + if debug: + self.displayErrorPage = self._displayDebugPage + else: + self.displayErrorPage = self._displayErrorPage + + # Lock for _lastEmailTime + self._lock = thread.allocate_lock() + + def _displayErrorPage(self, environ, start_response): + """ + Displays the static error page. May be overridden. (Maybe you'd + rather redirect or something?) This is basically a mini-WSGI + application, except that start_response() is called with the third + argument. + + Really, there's nothing keeping you from overriding this method + and displaying a dynamic error page. But I thought it might be safer + to display a static page. :) + """ + start_response('200 OK', [('Content-Type', self._errorPageMimeType), + ('Content-Length', + str(len(self._errorPage)))], + sys.exc_info()) + return [self._errorPage] + + def _displayDebugPage(self, environ, start_response): + """ + When debugging, display an informative traceback of the exception. 
+ """ + import cgitb + result = [cgitb.html(sys.exc_info())] + start_response('200 OK', [('Content-Type', 'text/html'), + ('Content-Length', str(len(result[0])))], + sys.exc_info()) + return result + + def _generateHTMLErrorEmail(self): + """ + Generates the HTML version of the error email. Must return a string. + """ + import cgitb + return cgitb.html(sys.exc_info()) + + def _generatePlainErrorEmail(self): + """ + Generates the plain-text version of the error email. Must return a + string. + """ + import cgitb + return cgitb.text(sys.exc_info()) + + def _generateErrorEmail(self): + """ + Generates the error email. Must return an instance of email.Message + or subclass. + + This implementation generates a MIME multipart/alternative email with + an HTML description of the error and a simpler plain-text alternative + of the traceback. + """ + msg = MIMEMultipart('alternative') + msg.attach(MIMEText(self._generatePlainErrorEmail())) + msg.attach(MIMEText(self._generateHTMLErrorEmail(), 'html')) + return msg + + def _sendErrorEmail(self, environ): + """ + Sends the error email as generated by _generateErrorEmail(). If + anything goes wrong sending the email, the exception is caught + and reported to wsgi.errors. I don't think there's really much else + that can be done in that case. + """ + msg = self._generateErrorEmail() + + msg['From'] = self._fromAddress + msg['To'] = self._adminAddress + msg['Subject'] = '%s: unhandled exception' % self._applicationName + + try: + server = smtplib.SMTP(self._smtpHost) + server.sendmail(self._fromAddress, self._adminAddress, + msg.as_string()) + server.quit() + except Exception, e: + stderr = environ['wsgi.errors'] + stderr.write('%s: Failed to send error email: %r %s\n' % + (self.__class__.__name__, e, e)) + stderr.flush() + + def _shouldSendEmail(self, environ): + """ + Returns True if an email should be sent. 
The last time an email was + sent is tracked by either an instance variable (if oneShot is False), + or the mtime of a file on the filesystem (if oneShot is True). + """ + if self._debug or self._adminAddress is None: + # Never send email when debugging or when there's no admin + # address. + return False + + now = time.time() + if not environ['wsgi.run_once']: + self._lock.acquire() + ret = (self._lastEmailTime + self._emailInterval) < now + if ret: + self._lastEmailTime = now + self._lock.release() + else: + # The following should be protected, but do I *really* want + # to get into the mess of using filesystem and file-based locks? + # At worse, multiple emails get sent. + ret = True + + try: + mtime = os.path.getmtime(self._intervalCheckFile) + except: + # Assume file doesn't exist, which is OK. Send email + # unconditionally. + pass + else: + if (mtime + self._emailInterval) >= now: + ret = False + + if ret: + # NB: If _intervalCheckFile cannot be created or written to + # for whatever reason, you will *always* get an error email. + try: + open(self._intervalCheckFile, 'w').close() + except: + # Probably a good idea to report failure. + stderr = environ['wsgi.errors'] + stderr.write('%s: Error writing intervalCheckFile %r\n' + % (self.__class__.__name__, + self._intervalCheckFile)) + stderr.flush() + return ret + + def exceptionHandler(self, environ): + """ + Common handling of exceptions. + """ + # Unconditionally report to wsgi.errors. + stderr = environ['wsgi.errors'] + traceback.print_exc(file=stderr) + stderr.flush() + + # Send error email, if needed. + if self._shouldSendEmail(environ): + self._sendErrorEmail(environ) + + def __call__(self, environ, start_response): + """ + WSGI application interface. Simply wraps the call to the application + with a try ... except. All the fancy stuff happens in the except + clause. 
+ """ + try: + return _wrapIterator(self._application(environ, start_response), + self, environ, start_response) + except: + # Report the exception. + self.exceptionHandler(environ) + + # Display static error page. + return self.displayErrorPage(environ, start_response) + +if __name__ == '__main__': + def myapp(environ, start_response): + start_response('200 OK', [('Content-Type', 'text/plain')]) + raise RuntimeError, "I'm broken!" + return ['Hello World!\n'] + + # Note - email address is taken from sys.argv[1]. I'm not leaving + # my email address here. ;) + app = ErrorMiddleware(myapp, sys.argv[1]) + + from ajp import WSGIServer + WSGIServer(app).run() diff --git a/flup/middleware/gzip.py b/flup/middleware/gzip.py new file mode 100644 index 0000000..4f4b7b1 --- /dev/null +++ b/flup/middleware/gzip.py @@ -0,0 +1,247 @@ +# Copyright (c) 2005 Allan Saddi <allan@saddi.com> +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS +# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF +# SUCH DAMAGE. +# +# $Id$ + +__author__ = 'Allan Saddi <allan@saddi.com>' +__version__ = '$Revision$' + +import struct +import time +import zlib + +__all__ = ['GzipMiddleware'] + +# This gzip middleware component differentiates itself from others in that +# it (hopefully) follows the spec more closely. Namely with regard to the +# application iterator and buffering. (It doesn't buffer.) +# See <http://www.python.org/peps/pep-0333.html#middleware-handling-of-block-boundaries> +# +# Of course this all comes with a price... just LOOK at this mess! :) +# +# The inner workings of gzip and the gzip file format were gleaned from gzip.py + +def _gzipHeader(): + """Returns a gzip header (with no filename).""" + # See GzipFile._write_gzip_header in gzip.py + return '\037\213' \ + '\010' \ + '\0' + \ + struct.pack('<L', long(time.time())) + \ + '\002' \ + '\377' + +class _iterWrapper(object): + """ + gzip iterator wrapper. It ensures that: the application iterator's close() + method (if any) is called by the parent server; and at least one value + is yielded each time the application's iterator yields a value. + + If the application's iterator yields N values, this iterator will yield + N+1 values. This is to account for the gzip trailer. + """ + def __init__(self, appIter, gzipMiddleware): + self._g = gzipMiddleware + self._next = iter(appIter).next + + self._last = False # True if appIter has yielded last value. 
+ self._trailerSent = False + + if hasattr(appIter, 'close'): + self.close = appIter.close + + def __iter__(self): + return self + + # This would've been a lot easier had I used a generator. But then I'd have + # to wrap the generator anyway to ensure that any existing close() method + # was called. (Calling it within the generator is not the same thing, + # namely it does not ensure that it will be called no matter what!) + def next(self): + if not self._last: + # Need to catch StopIteration here so we can append trailer. + try: + data = self._next() + except StopIteration: + self._last = True + + if not self._last: + if self._g.gzipOk: + return self._g.gzipData(data) + else: + return data + else: + # See if trailer needs to be sent. + if self._g.headerSent and not self._trailerSent: + self._trailerSent = True + return self._g.gzipTrailer() + # Otherwise, that's the end of this iterator. + raise StopIteration + +class _gzipMiddleware(object): + """ + The actual gzip middleware component. Holds compression state as well + implementations of start_response and write. Instantiated before each + call to the underlying application. + + This class is private. See GzipMiddleware for the public interface. + """ + def __init__(self, start_response, mimeTypes, compresslevel): + self._start_response = start_response + self._mimeTypes = mimeTypes + + self.gzipOk = False + self.headerSent = False + + # See GzipFile.__init__ and GzipFile._init_write in gzip.py + self._crc = zlib.crc32('') + self._size = 0 + self._compress = zlib.compressobj(compresslevel, + zlib.DEFLATED, + -zlib.MAX_WBITS, + zlib.DEF_MEM_LEVEL, + 0) + + def gzipData(self, data): + """ + Compresses the given data, prepending the gzip header if necessary. + Returns the result as a string. 
+ """ + if not self.headerSent: + self.headerSent = True + out = _gzipHeader() + else: + out = '' + + # See GzipFile.write in gzip.py + length = len(data) + if length > 0: + self._size += length + self._crc = zlib.crc32(data, self._crc) + out += self._compress.compress(data) + return out + + def gzipTrailer(self): + # See GzipFile.close in gzip.py + return self._compress.flush() + \ + struct.pack('<l', self._crc) + \ + struct.pack('<L', self._size & 0xffffffffL) + + def start_response(self, status, headers, exc_info=None): + self.gzipOk = False + + # Scan the headers. Only allow gzip compression if the Content-Type + # is one that we're flagged to compress AND the headers do not + # already contain Content-Encoding. + for name,value in headers: + name = name.lower() + if name == 'content-type' and value in self._mimeTypes: + self.gzipOk = True + elif name == 'content-encoding': + self.gzipOk = False + break + + if self.gzipOk: + # Remove Content-Length, if present, because compression will + # most surely change it. (And unfortunately, we can't predict + # the final size...) + headers = [(name,value) for name,value in headers + if name.lower() != 'content-length'] + headers.append(('Content-Encoding', 'gzip')) + + _write = self._start_response(status, headers, exc_info) + + if self.gzipOk: + def write_gzip(data): + _write(self.gzipData(data)) + return write_gzip + else: + return _write + +class GzipMiddleware(object): + """ + WSGI middleware component that gzip compresses the application's output + (if the client supports gzip compression - gleaned from the + Accept-Encoding request header). + + mimeTypes should be a list of Content-Types that are OK to compress. + + compresslevel is the gzip compression level, an integer from 1 to 9; 1 + is the fastest and produces the least compression, and 9 is the slowest, + producing the most compression. 
+ """ + def __init__(self, application, mimeTypes=None, compresslevel=9): + if mimeTypes is None: + mimeTypes = ['text/html'] + + self._application = application + self._mimeTypes = mimeTypes + self._compresslevel = compresslevel + + def __call__(self, environ, start_response): + """WSGI application interface.""" + # If the client doesn't support gzip encoding, just pass through + # directly to the application. + if 'gzip' not in environ.get('HTTP_ACCEPT_ENCODING', ''): + return self._application(environ, start_response) + + # All of the work is done in _gzipMiddleware and _iterWrapper. + g = _gzipMiddleware(start_response, self._mimeTypes, + self._compresslevel) + + result = self._application(environ, g.start_response) + + # See if it's a length 1 iterable... + try: + shortcut = len(result) == 1 + except: + shortcut = False + + if shortcut: + # Special handling if application returns a length 1 iterable: + # also return a length 1 iterable! + try: + i = iter(result) + # Hmmm, if we get a StopIteration here, the application's + # broken (__len__ lied!) + data = i.next() + if g.gzipOk: + return [g.gzipData(data) + g.gzipTrailer()] + else: + return [data] + finally: + if hasattr(result, 'close'): + result.close() + + return _iterWrapper(result, g) + +if __name__ == '__main__': + def myapp(environ, start_response): + start_response('200 OK', [('Content-Type', 'text/html')]) + return ['Hello World!\n'] + app = GzipMiddleware(myapp) + + from ajp import WSGIServer + import logging + WSGIServer(app, loggingLevel=logging.DEBUG).run() diff --git a/flup/middleware/session.py b/flup/middleware/session.py new file mode 100644 index 0000000..81818d6 --- /dev/null +++ b/flup/middleware/session.py @@ -0,0 +1,742 @@ +# Copyright (c) 2005 Allan Saddi <allan@saddi.com> +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# 1. 
Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS +# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF +# SUCH DAMAGE. +# +# $Id$ + +__author__ = 'Allan Saddi <allan@saddi.com>' +__version__ = '$Revision$' + +import os +import errno +import string +import time +import weakref +import atexit +import shelve +import cPickle as pickle + +try: + import thread +except ImportError: + import dummy_thread as thread + +__all__ = ['Session', + 'SessionStore', + 'MemorySessionStore', + 'ShelveSessionStore', + 'DiskSessionStore', + 'SessionMiddleware'] + +class Session(dict): + """ + Session objects, basically dictionaries. + """ + identifierLength = 32 + # Would be nice if len(identifierChars) were some power of 2. 
+ identifierChars = string.digits + string.letters + '-_' + + def __init__(self, identifier): + super(Session, self).__init__() + + assert self.isIdentifierValid(identifier) + self._identifier = identifier + + self._creationTime = self._lastAccessTime = time.time() + self._isValid = True + + def _get_identifier(self): + return self._identifier + identifier = property(_get_identifier, None, None, + 'Unique identifier for Session within its Store') + + def _get_creationTime(self): + return self._creationTime + creationTime = property(_get_creationTime, None, None, + 'Time when Session was created') + + def _get_lastAccessTime(self): + return self._lastAccessTime + lastAccessTime = property(_get_lastAccessTime, None, None, + 'Time when Session was last accessed') + + def _get_isValid(self): + return self._isValid + isValid = property(_get_isValid, None, None, + 'Whether or not this Session is valid') + + def touch(self): + """Update Session's access time.""" + self._lastAccessTime = time.time() + + def invalidate(self): + """Invalidate this Session.""" + self.clear() + self._creationTime = self._lastAccessTime = 0 + self._isValid = False + + @classmethod + def isIdentifierValid(cls, ident): + """ + Returns whether or not the given string *could be* a valid session + identifier. + """ + if type(ident) is str and len(ident) == cls.identifierLength: + for c in ident: + if c not in cls.identifierChars: + return False + return True + return False + + @classmethod + def generateIdentifier(cls): + """ + Generate a random session identifier. + """ + raw = os.urandom(cls.identifierLength) + + sessId = '' + for c in raw: + # So we lose 2 bits per random byte... + sessId += cls.identifierChars[ord(c) % len(cls.identifierChars)] + return sessId + +def _shutdown(ref): + store = ref() + if store is not None: + store.shutdown() + +class SessionStore(object): + """ + Abstract base class for session stores. You first acquire a session by + calling createSession() or checkOutSession(). 
After using the session, + you must call checkInSession(). You must not keep references to sessions + outside of a check in/check out block. Always obtain a fresh reference. + + Some external mechanism must be set up to call periodic() periodically + (perhaps every 5 minutes). + + After timeout minutes of inactivity, sessions are deleted. + """ + _sessionClass = Session + + def __init__(self, timeout=60, sessionClass=None): + self._lock = thread.allocate_lock() + + # Timeout in minutes + self._sessionTimeout = timeout + + if sessionClass is not None: + self._sessionClass = sessionClass + + self._checkOutList = {} + self._shutdownRan = False + + # Ensure shutdown is called. + atexit.register(_shutdown, weakref.ref(self)) + + # Public interface. + + def createSession(self): + """ + Create a new session with a unique identifier. Should never fail. + (Will raise a RuntimeError in the rare event that it does.) + + The newly-created session should eventually be released by + a call to checkInSession(). + """ + assert not self._shutdownRan + self._lock.acquire() + try: + attempts = 0 + while attempts < 10000: + sessId = self._sessionClass.generateIdentifier() + sess = self._createSession(sessId) + if sess is not None: break + attempts += 1 + + if attempts >= 10000: + raise RuntimeError, self.__class__.__name__ + \ + '.createSession() failed' + + assert sess.identifier not in self._checkOutList + self._checkOutList[sess.identifier] = sess + return sess + finally: + self._lock.release() + + def checkOutSession(self, identifier): + """ + Checks out a session for use. Returns the session if it exists, + otherwise returns None. If this call succeeds, the session + will be touch()'ed and locked from use by other processes. + Therefore, it should eventually be released by a call to + checkInSession(). 
+ """ + assert not self._shutdownRan + + if not self._sessionClass.isIdentifierValid(identifier): + return None + + self._lock.acquire() + try: + sess = self._loadSession(identifier) + if sess is not None: + assert sess.identifier not in self._checkOutList + self._checkOutList[sess.identifier] = sess + sess.touch() + return sess + finally: + self._lock.release() + + def checkInSession(self, session): + """ + Returns the session for use by other threads/processes. Safe to + pass None. + """ + assert not self._shutdownRan + + if session is None: + return + + self._lock.acquire() + try: + assert session.identifier in self._checkOutList + if session.isValid: + self._saveSession(session) + else: + self._deleteSession(session.identifier) + del self._checkOutList[session.identifier] + finally: + self._lock.release() + + def shutdown(self): + """Clean up outstanding sessions.""" + self._lock.acquire() + try: + if not self._shutdownRan: + # Save or delete any sessions that are still out there. + for key,sess in self._checkOutList.items(): + if sess.isValid: + self._saveSession(sess) + else: + self._deleteSession(sess.identifier) + self._checkOutList.clear() + self._shutdown() + self._shutdownRan = True + finally: + self._lock.release() + + def __del__(self): + self.shutdown() + + def periodic(self): + """Timeout old sessions. Should be called periodically.""" + self._lock.acquire() + try: + if not self._shutdownRan: + self._periodic() + finally: + self._lock.release() + + # To be implemented by subclasses. self._lock will be held whenever + # these are called and for methods that take an identifier, + # the identifier will be guaranteed to be valid (but it will not + # necessarily exist). + + def _createSession(self, identifier): + """ + Attempt to create the session with the given identifier. If + successful, return the newly-created session, which must + also be implicitly locked from use by other processes. 
(The + session returned should be an instance of self._sessionClass.) + If unsuccessful, return None. + """ + raise NotImplementedError, self.__class__.__name__ + '._createSession' + + def _loadSession(self, identifier): + """ + Load the session with the identifier from secondary storage returning + None if it does not exist. If the load is successful, the session + must be locked from use by other processes. + """ + raise NotImplementedError, self.__class__.__name__ + '._loadSession' + + def _saveSession(self, session): + """ + Store the session into secondary storage. Also implicitly releases + the session for use by other processes. + """ + raise NotImplementedError, self.__class__.__name__ + '._saveSession' + + def _deleteSession(self, identifier): + """ + Deletes the session from secondary storage. Must be OK to pass + in an invalid (non-existant) identifier. If the session did exist, + it must be released for use by other processes. + """ + raise NotImplementedError, self.__class__.__name__ + '._deleteSession' + + def _periodic(self): + """Remove timedout sessions from secondary storage.""" + raise NotImplementedError, self.__class__.__name__ + '._periodic' + + def _shutdown(self): + """Performs necessary shutdown actions for secondary store.""" + raise NotImplementedError, self.__class__.__name__ + '._shutdown' + + # Utilities + + def _isSessionTimedout(self, session, now=time.time()): + return (session.lastAccessTime + self._sessionTimeout * 60) < now + +class MemorySessionStore(SessionStore): + """ + Memory-based session store. Great for persistent applications, terrible + for one-shot ones. :) + """ + def __init__(self, *a, **kw): + super(MemorySessionStore, self).__init__(*a, **kw) + + # Our "secondary store". 
+ self._secondaryStore = {} + + def _createSession(self, identifier): + if self._secondaryStore.has_key(identifier): + return None + sess = self._sessionClass(identifier) + self._secondaryStore[sess.identifier] = sess + return sess + + def _loadSession(self, identifier): + return self._secondaryStore.get(identifier, None) + + def _saveSession(self, session): + self._secondaryStore[session.identifier] = session + + def _deleteSession(self, identifier): + if self._secondaryStore.has_key(identifier): + del self._secondaryStore[identifier] + + def _periodic(self): + now = time.time() + for key,sess in self._secondaryStore.items(): + if self._isSessionTimedout(sess, now): + del self._secondaryStore[key] + + def _shutdown(self): + pass + +class ShelveSessionStore(SessionStore): + """ + Session store based on Python "shelves." Only use if you can guarantee + that storeFile will NOT be accessed concurrently by other instances. + (In other processes, threads, anywhere!) + """ + def __init__(self, storeFile='sessions', *a, **kw): + super(ShelveSessionStore, self).__init__(*a, **kw) + + self._secondaryStore = shelve.open(storeFile, + protocol=pickle.HIGHEST_PROTOCOL) + + def _createSession(self, identifier): + if self._secondaryStore.has_key(identifier): + return None + sess = self._sessionClass(identifier) + self._secondaryStore[sess.identifier] = sess + return sess + + def _loadSession(self, identifier): + return self._secondaryStore.get(identifier, None) + + def _saveSession(self, session): + self._secondaryStore[session.identifier] = session + + def _deleteSession(self, identifier): + if self._secondaryStore.has_key(identifier): + del self._secondaryStore[identifier] + + def _periodic(self): + now = time.time() + for key,sess in self._secondaryStore.items(): + if self._isSessionTimedout(sess, now): + del self._secondaryStore[key] + + def _shutdown(self): + self._secondaryStore.close() + +class DiskSessionStore(SessionStore): + """ + Disk-based session store that stores 
each session as its own file + within a specified directory. Should be safe for concurrent use. + (As long as the underlying OS/filesystem respects create()'s O_EXCL.) + """ + def __init__(self, storeDir='sessions', *a, **kw): + super(DiskSessionStore, self).__init__(*a, **kw) + + self._sessionDir = storeDir + if not os.access(self._sessionDir, os.F_OK): + # Doesn't exist, try to create it. + os.mkdir(self._sessionDir) + + def _filenameForSession(self, identifier): + return os.path.join(self._sessionDir, identifier + '.sess') + + def _lockSession(self, identifier, block=True): + fn = self._filenameForSession(identifier) + '.lock' + while True: + try: + fd = os.open(fn, os.O_WRONLY|os.O_CREAT|os.O_EXCL) + except OSError, e: + if e.errno != errno.EEXIST: + raise + else: + os.close(fd) + break + + if not block: + return False + + # See if the lock is stale. If so, remove it. + try: + now = time.time() + mtime = os.path.getmtime(fn) + if (mtime + 60) < now: + os.unlink(fn) + except OSError, e: + if e.errno != errno.ENOENT: + raise + + time.sleep(0.1) + + return True + + def _unlockSession(self, identifier): + fn = self._filenameForSession(identifier) + '.lock' + os.unlink(fn) # Need to catch errors? + + def _createSession(self, identifier): + fn = self._filenameForSession(identifier) + lfn = fn + '.lock' + # Attempt to create the file's *lock* first. + lfd = fd = -1 + try: + lfd = os.open(lfn, os.O_WRONLY|os.O_CREAT|os.O_EXCL) + fd = os.open(fn, os.O_WRONLY|os.O_CREAT|os.O_EXCL) + except OSError, e: + if e.errno == errno.EEXIST: + if lfd >= 0: + # Remove lockfile. + os.close(lfd) + os.unlink(lfn) + return None + raise + else: + # Success. 
+ os.close(fd) + os.close(lfd) + return self._sessionClass(identifier) + + def _loadSession(self, identifier, block=True): + if not self._lockSession(identifier, block): + return None + try: + return pickle.load(open(self._filenameForSession(identifier))) + except: + self._unlockSession(identifier) + return None + + def _saveSession(self, session): + f = open(self._filenameForSession(session.identifier), 'w+') + pickle.dump(session, f, protocol=pickle.HIGHEST_PROTOCOL) + f.close() + self._unlockSession(session.identifier) + + def _deleteSession(self, identifier): + try: + os.unlink(self._filenameForSession(identifier)) + except: + pass + self._unlockSession(identifier) + + def _periodic(self): + now = time.time() + sessions = os.listdir(self._sessionDir) + for name in sessions: + if not name.endswith('.sess'): + continue + identifier = name[:-5] + if not self._sessionClass.isIdentifierValid(identifier): + continue + # Not very efficient. + sess = self._loadSession(identifier, block=False) + if sess is None: + continue + if self._isSessionTimedout(sess, now): + self._deleteSession(identifier) + else: + self._unlockSession(identifier) + + def _shutdown(self): + pass + +# SessionMiddleware stuff. + +from Cookie import SimpleCookie +import cgi +import urlparse + +class SessionService(object): + """ + WSGI extension API passed to applications as + environ['com.saddi.service.session']. + + Public API: (assume service = environ['com.saddi.service.session']) + service.session - Returns the Session associated with the client. + service.hasSession - True if the client is currently associated with + a Session. + service.isSessionNew - True if the Session was created in this + transaction. + service.hasSessionExpired - True if the client is associated with a + non-existent Session. + service.encodesSessionInURL - True if the Session ID should be encoded in + the URL. (read/write) + service.encodeURL(url) - Returns url encoded with Session ID (if + necessary). 
+ """ + _expiredSessionIdentifier = 'expired session' + + def __init__(self, store, environ, + cookieName='_SID_', + fieldName='_SID_'): + self._store = store + self._cookieName = cookieName + self._fieldName = fieldName + + self._session = None + self._newSession = False + self._expired = False + self.encodesSessionInURL = False + + if __debug__: self._closed = False + + self._loadExistingSession(environ) + + def _loadSessionFromCookie(self, environ): + """ + Attempt to load the associated session using the identifier from + the cookie. + """ + C = SimpleCookie(environ.get('HTTP_COOKIE')) + morsel = C.get(self._cookieName, None) + if morsel is not None: + self._session = self._store.checkOutSession(morsel.value) + self._expired = self._session is None + + def _loadSessionFromQueryString(self, environ): + """ + Attempt to load the associated session using the identifier from + the query string. + """ + qs = cgi.parse_qsl(environ.get('QUERY_STRING', '')) + for name,value in qs: + if name == self._fieldName: + self._session = self._store.checkOutSession(value) + self._expired = self._session is None + self.encodesSessionInURL = True + break + + def _loadExistingSession(self, environ): + """Attempt to associate with an existing Session.""" + # Try cookie first. + self._loadSessionFromCookie(environ) + + # Next, try query string. + if self._session is None: + self._loadSessionFromQueryString(environ) + + def _sessionIdentifier(self): + """Returns the identifier of the current session.""" + assert self._session is not None + return self._session.identifier + + def _shouldAddCookie(self): + """ + Returns True if the session cookie should be added to the header + (if not encoding the session ID in the URL). The cookie is added if + one of these three conditions are true: a) the session was just + created, b) the session is no longer valid, or c) the client is + associated with a non-existent session. 
+ """ + return self._newSession or \ + (self._session is not None and not self._session.isValid) or \ + (self._session is None and self._expired) + + def addCookie(self, headers): + """Adds Set-Cookie header if needed.""" + if not self.encodesSessionInURL and self._shouldAddCookie(): + if self._session is not None: + sessId = self._sessionIdentifier() + expireCookie = not self._session.isValid + else: + sessId = self._expiredSessionIdentifier + expireCookie = True + + C = SimpleCookie() + name = self._cookieName + C[name] = sessId + C[name]['path'] = '/' + if expireCookie: + # Expire cookie + C[name]['expires'] = -365*24*60*60 + C[name]['max-age'] = 0 + headers.append(('Set-Cookie', C[name].OutputString())) + + def close(self): + """Checks session back into session store.""" + if self._session is None: + return + # Check the session back in and get rid of our reference. + self._store.checkInSession(self._session) + self._session = None + if __debug__: self._closed = True + + # Public API + + def _get_session(self): + assert not self._closed + if self._session is None: + self._session = self._store.createSession() + self._newSession = True + + assert self._session is not None + return self._session + session = property(_get_session, None, None, + 'Returns the Session object associated with this ' + 'client') + + def _get_hasSession(self): + assert not self._closed + return self._session is not None + hasSession = property(_get_hasSession, None, None, + 'True if a Session currently exists for this client') + + def _get_isSessionNew(self): + assert not self._closed + return self._newSession + isSessionNew = property(_get_isSessionNew, None, None, + 'True if the Session was created in this ' + 'transaction') + + def _get_hasSessionExpired(self): + assert not self._closed + return self._expired + hasSessionExpired = property(_get_hasSessionExpired, None, None, + 'True if the client was associated with a ' + 'non-existent Session') + + # Utilities + + def encodeURL(self, 
url): + """Encodes session ID in URL, if necessary.""" + assert not self._closed + if not self.encodesSessionInURL or self._session is None: + return url + u = list(urlparse.urlsplit(url)) + q = '%s=%s' % (self._fieldName, self._sessionIdentifier()) + if u[3]: + u[3] = q + '&' + u[3] + else: + u[3] = q + return urlparse.urlunsplit(u) + +def _addClose(appIter, closeFunc): + """ + Wraps an iterator so that its close() method calls closeFunc. Respects + the existence of __len__ and the iterator's own close() method. + + Need to use metaclass magic because __len__ and next are not + recognized unless they're part of the class. (Can't assign at + __init__ time.) + """ + class metaIterWrapper(type): + def __init__(cls, name, bases, clsdict): + super(metaIterWrapper, cls).__init__(name, bases, clsdict) + if hasattr(appIter, '__len__'): + cls.__len__ = appIter.__len__ + cls.next = iter(appIter).next + if hasattr(appIter, 'close'): + def _close(self): + appIter.close() + closeFunc() + cls.close = _close + else: + cls.close = closeFunc + + class iterWrapper(object): + __metaclass__ = metaIterWrapper + def __iter__(self): + return self + + return iterWrapper() + +class SessionMiddleware(object): + """ + WSGI middleware that adds a session service. A SessionService instance + is passed to the application in environ['com.saddi.service.session']. + A references to this instance should not be saved. (A new instance is + instantiated with every call to the application.) 
+ """ + _serviceClass = SessionService + + def __init__(self, store, application, serviceClass=None, **kw): + self._store = store + self._application = application + if serviceClass is not None: + self._serviceClass = serviceClass + self._serviceKW = kw + + def __call__(self, environ, start_response): + service = self._serviceClass(self._store, environ, **self._serviceKW) + environ['com.saddi.service.session'] = service + + def my_start_response(status, headers, exc_info=None): + service.addCookie(headers) + return start_response(status, headers, exc_info) + + try: + result = self._application(environ, my_start_response) + except: + # If anything goes wrong, ensure the session is checked back in. + service.close() + raise + + # The iterator must be unconditionally wrapped, just in case it + # is a generator. (In which case, we may not know that a Session + # has been checked out until completion of the first iteration.) + return _addClose(result, service.close) + +if __name__ == '__main__': + mss = MemorySessionStore(timeout=5) +# sss = ShelveSessionStore(timeout=5) + dss = DiskSessionStore(timeout=5) diff --git a/flup/publisher/__init__.py b/flup/publisher/__init__.py new file mode 100644 index 0000000..792d600 --- /dev/null +++ b/flup/publisher/__init__.py @@ -0,0 +1 @@ +# diff --git a/flup/resolver/__init__.py b/flup/resolver/__init__.py new file mode 100644 index 0000000..792d600 --- /dev/null +++ b/flup/resolver/__init__.py @@ -0,0 +1 @@ +# diff --git a/flup/server/__init__.py b/flup/server/__init__.py new file mode 100644 index 0000000..792d600 --- /dev/null +++ b/flup/server/__init__.py @@ -0,0 +1 @@ +# diff --git a/flup/server/ajp.py b/flup/server/ajp.py new file mode 100644 index 0000000..baa2deb --- /dev/null +++ b/flup/server/ajp.py @@ -0,0 +1,1195 @@ +# Copyright (c) 2005 Allan Saddi <allan@saddi.com> +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS +# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF +# SUCH DAMAGE. +# +# $Id$ + +""" +ajp - an AJP 1.3/WSGI gateway. + +For more information about AJP and AJP connectors for your web server, see +<http://jakarta.apache.org/tomcat/connectors-doc/>. + +For more information about the Web Server Gateway Interface, see +<http://www.python.org/peps/pep-0333.html>. + +Example usage: + + #!/usr/bin/env python + import sys + from myapplication import app # Assume app is your WSGI application object + from ajp import WSGIServer + ret = WSGIServer(app).run() + sys.exit(ret and 42 or 0) + +See the documentation for WSGIServer for more information. + +About the bit of logic at the end: +Upon receiving SIGHUP, the python script will exit with status code 42. 
This +can be used by a wrapper script to determine if the python script should be +re-run. When a SIGINT or SIGTERM is received, the script exits with status +code 0, possibly indicating a normal exit. + +Example wrapper script: + + #!/bin/sh + STATUS=42 + while test $STATUS -eq 42; do + python "$@" that_script_above.py + STATUS=$? + done + +Example workers.properties (for mod_jk): + + worker.list=foo + worker.foo.port=8009 + worker.foo.host=localhost + worker.foo.type=ajp13 + +Example httpd.conf (for mod_jk): + + JkWorkersFile /path/to/workers.properties + JkMount /* foo + +Note that if you mount your ajp application anywhere but the root ("/"), you +SHOULD specifiy scriptName to the WSGIServer constructor. This will ensure +that SCRIPT_NAME/PATH_INFO are correctly deduced. +""" + +__author__ = 'Allan Saddi <allan@saddi.com>' +__version__ = '$Revision$' + +import sys +import socket +import select +import struct +import signal +import logging +import errno +import datetime +import time + +# Unfortunately, for now, threads are required. +import thread +import threading + +__all__ = ['WSGIServer'] + +# Packet header prefixes. +SERVER_PREFIX = '\x12\x34' +CONTAINER_PREFIX = 'AB' + +# Server packet types. +PKTTYPE_FWD_REQ = '\x02' +PKTTYPE_SHUTDOWN = '\x07' +PKTTYPE_PING = '\x08' +PKTTYPE_CPING = '\x0a' + +# Container packet types. +PKTTYPE_SEND_BODY = '\x03' +PKTTYPE_SEND_HEADERS = '\x04' +PKTTYPE_END_RESPONSE = '\x05' +PKTTYPE_GET_BODY = '\x06' +PKTTYPE_CPONG = '\x09' + +# Code tables for methods/headers/attributes. 
+methodTable = [ + None, + 'OPTIONS', + 'GET', + 'HEAD', + 'POST', + 'PUT', + 'DELETE', + 'TRACE', + 'PROPFIND', + 'PROPPATCH', + 'MKCOL', + 'COPY', + 'MOVE', + 'LOCK', + 'UNLOCK', + 'ACL', + 'REPORT', + 'VERSION-CONTROL', + 'CHECKIN', + 'CHECKOUT', + 'UNCHECKOUT', + 'SEARCH', + 'MKWORKSPACE', + 'UPDATE', + 'LABEL', + 'MERGE', + 'BASELINE_CONTROL', + 'MKACTIVITY' + ] + +requestHeaderTable = [ + None, + 'Accept', + 'Accept-Charset', + 'Accept-Encoding', + 'Accept-Language', + 'Authorization', + 'Connection', + 'Content-Type', + 'Content-Length', + 'Cookie', + 'Cookie2', + 'Host', + 'Pragma', + 'Referer', + 'User-Agent' + ] + +attributeTable = [ + None, + 'CONTEXT', + 'SERVLET_PATH', + 'REMOTE_USER', + 'AUTH_TYPE', + 'QUERY_STRING', + 'JVM_ROUTE', + 'SSL_CERT', + 'SSL_CIPHER', + 'SSL_SESSION', + None, # name follows + 'SSL_KEY_SIZE' + ] + +responseHeaderTable = [ + None, + 'content-type', + 'content-language', + 'content-length', + 'date', + 'last-modified', + 'location', + 'set-cookie', + 'set-cookie2', + 'servlet-engine', + 'status', + 'www-authenticate' + ] + +# The main classes use this name for logging. +LoggerName = 'ajp-wsgi' + +# Set up module-level logger. +console = logging.StreamHandler() +console.setLevel(logging.DEBUG) +console.setFormatter(logging.Formatter('%(asctime)s : %(message)s', + '%Y-%m-%d %H:%M:%S')) +logging.getLogger(LoggerName).addHandler(console) +del console + +class ProtocolError(Exception): + """ + Exception raised when the server does something unexpected or + sends garbled data. Usually leads to a Connection closing. + """ + pass + +def decodeString(data, pos=0): + """Decode a string.""" + try: + length = struct.unpack('>H', data[pos:pos+2])[0] + pos += 2 + if length == 0xffff: # This was undocumented! 
+ return '', pos + s = data[pos:pos+length] + return s, pos+length+1 # Don't forget NUL + except Exception, e: + raise ProtocolError, 'decodeString: '+str(e) + +def decodeRequestHeader(data, pos=0): + """Decode a request header/value pair.""" + try: + if data[pos] == '\xa0': + # Use table + i = ord(data[pos+1]) + name = requestHeaderTable[i] + if name is None: + raise ValueError, 'bad request header code' + pos += 2 + else: + name, pos = decodeString(data, pos) + value, pos = decodeString(data, pos) + return name, value, pos + except Exception, e: + raise ProtocolError, 'decodeRequestHeader: '+str(e) + +def decodeAttribute(data, pos=0): + """Decode a request attribute.""" + try: + i = ord(data[pos]) + pos += 1 + if i == 0xff: + # end + return None, None, pos + elif i == 0x0a: + # name follows + name, pos = decodeString(data, pos) + elif i == 0x0b: + # Special handling of SSL_KEY_SIZE. + name = attributeTable[i] + # Value is an int, not a string. + value = struct.unpack('>H', data[pos:pos+2])[0] + return name, str(value), pos+2 + else: + name = attributeTable[i] + if name is None: + raise ValueError, 'bad attribute code' + value, pos = decodeString(data, pos) + return name, value, pos + except Exception, e: + raise ProtocolError, 'decodeAttribute: '+str(e) + +def encodeString(s): + """Encode a string.""" + return struct.pack('>H', len(s)) + s + '\x00' + +def encodeResponseHeader(name, value): + """Encode a response header/value pair.""" + lname = name.lower() + if lname in responseHeaderTable: + # Use table + i = responseHeaderTable.index(lname) + out = '\xa0' + chr(i) + else: + out = encodeString(name) + out += encodeString(value) + return out + +class Packet(object): + """An AJP message packet.""" + def __init__(self): + self.data = '' + # Don't set this on write, it will be calculated automatically. + self.length = 0 + + def _recvall(sock, length): + """ + Attempts to receive length bytes from a socket, blocking if necessary. 
+ (Socket may be blocking or non-blocking.) + """ + dataList = [] + recvLen = 0 + while length: + try: + data = sock.recv(length) + except socket.error, e: + if e[0] == errno.EAGAIN: + select.select([sock], [], []) + continue + else: + raise + if not data: # EOF + break + dataList.append(data) + dataLen = len(data) + recvLen += dataLen + length -= dataLen + return ''.join(dataList), recvLen + _recvall = staticmethod(_recvall) + + def read(self, sock): + """Attempt to read a packet from the server.""" + try: + header, length = self._recvall(sock, 4) + except socket.error: + # Treat any sort of socket errors as EOF (close Connection). + raise EOFError + + if length < 4: + raise EOFError + + if header[:2] != SERVER_PREFIX: + raise ProtocolError, 'invalid header' + + self.length = struct.unpack('>H', header[2:4])[0] + if self.length: + try: + self.data, length = self._recvall(sock, self.length) + except socket.error: + raise EOFError + + if length < self.length: + raise EOFError + + def _sendall(sock, data): + """ + Writes data to a socket and does not return until all the data is sent. + """ + length = len(data) + while length: + try: + sent = sock.send(data) + except socket.error, e: + if e[0] == errno.EPIPE: + return # Don't bother raising an exception. Just ignore. + elif e[0] == errno.EAGAIN: + select.select([], [sock], []) + continue + else: + raise + data = data[sent:] + length -= sent + _sendall = staticmethod(_sendall) + + def write(self, sock): + """Send a packet to the server.""" + self.length = len(self.data) + self._sendall(sock, CONTAINER_PREFIX + struct.pack('>H', self.length)) + if self.length: + self._sendall(sock, self.data) + +class InputStream(object): + """ + File-like object that represents the request body (if any). Supports + the bare mininum methods required by the WSGI spec. Thanks to + StringIO for ideas. + """ + def __init__(self, conn): + self._conn = conn + + # See WSGIServer. 
+ self._shrinkThreshold = conn.server.inputStreamShrinkThreshold + + self._buf = '' + self._bufList = [] + self._pos = 0 # Current read position. + self._avail = 0 # Number of bytes currently available. + self._length = 0 # Set to Content-Length in request. + + self.logger = logging.getLogger(LoggerName) + + def bytesAvailForAdd(self): + return self._length - self._avail + + def _shrinkBuffer(self): + """Gets rid of already read data (since we can't rewind).""" + if self._pos >= self._shrinkThreshold: + self._buf = self._buf[self._pos:] + self._avail -= self._pos + self._length -= self._pos + self._pos = 0 + + assert self._avail >= 0 and self._length >= 0 + + def _waitForData(self): + toAdd = min(self.bytesAvailForAdd(), 0xffff) + assert toAdd > 0 + pkt = Packet() + pkt.data = PKTTYPE_GET_BODY + \ + struct.pack('>H', toAdd) + self._conn.writePacket(pkt) + self._conn.processInput() + + def read(self, n=-1): + if self._pos == self._length: + return '' + while True: + if n < 0 or (self._avail - self._pos) < n: + # Not enough data available. + if not self.bytesAvailForAdd(): + # And there's no more coming. + newPos = self._avail + break + else: + # Ask for more data and wait. + self._waitForData() + continue + else: + newPos = self._pos + n + break + # Merge buffer list, if necessary. + if self._bufList: + self._buf += ''.join(self._bufList) + self._bufList = [] + r = self._buf[self._pos:newPos] + self._pos = newPos + self._shrinkBuffer() + return r + + def readline(self, length=None): + if self._pos == self._length: + return '' + while True: + # Unfortunately, we need to merge the buffer list early. + if self._bufList: + self._buf += ''.join(self._bufList) + self._bufList = [] + # Find newline. + i = self._buf.find('\n', self._pos) + if i < 0: + # Not found? + if not self.bytesAvailForAdd(): + # No more data coming. + newPos = self._avail + break + else: + # Wait for more to come. 
+ self._waitForData() + continue + else: + newPos = i + 1 + break + if length is not None: + if self._pos + length < newPos: + newPos = self._pos + length + r = self._buf[self._pos:newPos] + self._pos = newPos + self._shrinkBuffer() + return r + + def readlines(self, sizehint=0): + total = 0 + lines = [] + line = self.readline() + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline() + return lines + + def __iter__(self): + return self + + def next(self): + r = self.readline() + if not r: + raise StopIteration + return r + + def setDataLength(self, length): + """ + Once Content-Length is known, Request calls this method to set it. + """ + self._length = length + + def addData(self, data): + """ + Adds data from the server to this InputStream. Note that we never ask + the server for data beyond the Content-Length, so the server should + never send us an EOF (empty string argument). + """ + if not data: + raise ProtocolError, 'short data' + self._bufList.append(data) + length = len(data) + self._avail += length + if self._avail > self._length: + raise ProtocolError, 'too much data' + +class Request(object): + """ + A Request object. A more fitting name would probably be Transaction, but + it's named Request to mirror my FastCGI driver. :) This object + encapsulates all the data about the HTTP request and allows the handler + to send a response. + + The only attributes/methods that the handler should concern itself + with are: environ, input, startResponse(), and write(). + """ + # Do not ever change the following value. 
+ _maxWrite = 8192 - 4 - 3 # 8k - pkt header - send body header + + def __init__(self, conn): + self._conn = conn + + self.environ = { + 'SCRIPT_NAME': conn.server.scriptName + } + self.input = InputStream(conn) + + self._headersSent = False + + self.logger = logging.getLogger(LoggerName) + + def run(self): + self.logger.info('%s %s', + self.environ['REQUEST_METHOD'], + self.environ['REQUEST_URI']) + + start = datetime.datetime.now() + + try: + self._conn.server.handler(self) + except: + self.logger.exception('Exception caught from handler') + if not self._headersSent: + self._conn.server.error(self) + + end = datetime.datetime.now() + + # Notify server of end of response (reuse flag is set to true). + pkt = Packet() + pkt.data = PKTTYPE_END_RESPONSE + '\x01' + self._conn.writePacket(pkt) + + handlerTime = end - start + self.logger.debug('%s %s done (%.3f secs)', + self.environ['REQUEST_METHOD'], + self.environ['REQUEST_URI'], + handlerTime.seconds + + handlerTime.microseconds / 1000000.0) + + # The following methods are called from the Connection to set up this + # Request. 
+ + def setMethod(self, value): + self.environ['REQUEST_METHOD'] = value + + def setProtocol(self, value): + self.environ['SERVER_PROTOCOL'] = value + + def setRequestURI(self, value): + self.environ['REQUEST_URI'] = value + + scriptName = self._conn.server.scriptName + if not value.startswith(scriptName): + self.logger.warning('scriptName does not match request URI') + + self.environ['PATH_INFO'] = value[len(scriptName):] + + def setRemoteAddr(self, value): + self.environ['REMOTE_ADDR'] = value + + def setRemoteHost(self, value): + self.environ['REMOTE_HOST'] = value + + def setServerName(self, value): + self.environ['SERVER_NAME'] = value + + def setServerPort(self, value): + self.environ['SERVER_PORT'] = str(value) + + def setIsSSL(self, value): + if value: + self.environ['HTTPS'] = 'on' + + def addHeader(self, name, value): + name = name.replace('-', '_').upper() + if name in ('CONTENT_TYPE', 'CONTENT_LENGTH'): + self.environ[name] = value + if name == 'CONTENT_LENGTH': + length = int(value) + self.input.setDataLength(length) + else: + self.environ['HTTP_'+name] = value + + def addAttribute(self, name, value): + self.environ[name] = value + + # The only two methods that should be called from the handler. + + def startResponse(self, statusCode, statusMsg, headers): + """ + Begin the HTTP response. This must only be called once and it + must be called before any calls to write(). + + statusCode is the integer status code (e.g. 200). statusMsg + is the associated reason message (e.g.'OK'). headers is a list + of 2-tuples - header name/value pairs. (Both header name and value + must be strings.) + """ + assert not self._headersSent, 'Headers already sent!' 
+ + pkt = Packet() + pkt.data = PKTTYPE_SEND_HEADERS + \ + struct.pack('>H', statusCode) + \ + encodeString(statusMsg) + \ + struct.pack('>H', len(headers)) + \ + ''.join([encodeResponseHeader(name, value) + for name,value in headers]) + + self._conn.writePacket(pkt) + + self._headersSent = True + + def write(self, data): + """ + Write data (which comprises the response body). Note that due to + restrictions on AJP packet size, we limit our writes to 8185 bytes + each packet. + """ + assert self._headersSent, 'Headers must be sent first!' + + bytesLeft = len(data) + while bytesLeft: + toWrite = min(bytesLeft, self._maxWrite) + + pkt = Packet() + pkt.data = PKTTYPE_SEND_BODY + \ + struct.pack('>H', toWrite) + \ + data[:toWrite] + self._conn.writePacket(pkt) + + data = data[toWrite:] + bytesLeft -= toWrite + +class Connection(object): + """ + A single Connection with the server. Requests are not multiplexed over the + same connection, so at any given time, the Connection is either + waiting for a request, or processing a single request. + """ + def __init__(self, sock, addr, server): + self.server = server + self._sock = sock + self._addr = addr + + self._request = None + + self.logger = logging.getLogger(LoggerName) + + def run(self): + self.logger.debug('Connection starting up (%s:%d)', + self._addr[0], self._addr[1]) + + # Main loop. Errors will cause the loop to be exited and + # the socket to be closed. + while True: + try: + self.processInput() + except ProtocolError, e: + self.logger.error("Protocol error '%s'", str(e)) + break + except EOFError: + break + except: + self.logger.exception('Exception caught in Connection') + break + + self.logger.debug('Connection shutting down (%s:%d)', + self._addr[0], self._addr[1]) + + self._sock.close() + + def processInput(self): + """Wait for and process a single packet.""" + pkt = Packet() + select.select([self._sock], [], []) + pkt.read(self._sock) + + # Body chunks have no packet type code. 
+ if self._request is not None: + self._processBody(pkt) + return + + if not pkt.length: + raise ProtocolError, 'unexpected empty packet' + + pkttype = pkt.data[0] + if pkttype == PKTTYPE_FWD_REQ: + self._forwardRequest(pkt) + elif pkttype == PKTTYPE_SHUTDOWN: + self._shutdown(pkt) + elif pkttype == PKTTYPE_PING: + self._ping(pkt) + elif pkttype == PKTTYPE_CPING: + self._cping(pkt) + else: + raise ProtocolError, 'unknown packet type' + + def _forwardRequest(self, pkt): + """ + Creates a Request object, fills it in from the packet, then runs it. + """ + assert self._request is None + + req = self.server.requestClass(self) + i = ord(pkt.data[1]) + method = methodTable[i] + if method is None: + raise ValueError, 'bad method field' + req.setMethod(method) + value, pos = decodeString(pkt.data, 2) + req.setProtocol(value) + value, pos = decodeString(pkt.data, pos) + req.setRequestURI(value) + value, pos = decodeString(pkt.data, pos) + req.setRemoteAddr(value) + value, pos = decodeString(pkt.data, pos) + req.setRemoteHost(value) + value, pos = decodeString(pkt.data, pos) + req.setServerName(value) + value = struct.unpack('>H', pkt.data[pos:pos+2])[0] + req.setServerPort(value) + i = ord(pkt.data[pos+2]) + req.setIsSSL(i != 0) + + # Request headers. + numHeaders = struct.unpack('>H', pkt.data[pos+3:pos+5])[0] + pos += 5 + for i in range(numHeaders): + name, value, pos = decodeRequestHeader(pkt.data, pos) + req.addHeader(name, value) + + # Attributes. + while True: + name, value, pos = decodeAttribute(pkt.data, pos) + if name is None: + break + req.addAttribute(name, value) + + self._request = req + + # Read first body chunk, if needed. + if req.input.bytesAvailForAdd(): + self.processInput() + + # Run Request. 
+ req.run() + + self._request = None + + def _shutdown(self, pkt): + """Not sure what to do with this yet.""" + self.logger.info('Received shutdown request from server') + + def _ping(self, pkt): + """I have no idea what this packet means.""" + self.logger.debug('Received ping') + + def _cping(self, pkt): + """Respond to a PING (CPING) packet.""" + self.logger.debug('Received PING, sending PONG') + pkt = Packet() + pkt.data = PKTTYPE_CPONG + self.writePacket(pkt) + + def _processBody(self, pkt): + """ + Handles a body chunk from the server by appending it to the + InputStream. + """ + if pkt.length: + length = struct.unpack('>H', pkt.data[:2])[0] + self._request.input.addData(pkt.data[2:2+length]) + else: + # Shouldn't really ever get here. + self._request.input.addData('') + + def writePacket(self, pkt): + """Sends a Packet to the server.""" + pkt.write(self._sock) + +class ThreadPool(object): + """ + Thread pool that maintains the number of idle threads between + minSpare and maxSpare inclusive. By default, there is no limit on + the number of threads that can be started, but this can be controlled + by maxThreads. + """ + def __init__(self, minSpare=1, maxSpare=5, maxThreads=sys.maxint): + self._minSpare = minSpare + self._maxSpare = maxSpare + self._maxThreads = max(minSpare, maxThreads) + + self._lock = threading.Condition() + self._workQueue = [] + self._idleCount = self._workerCount = maxSpare + + # Start the minimum number of worker threads. + for i in range(maxSpare): + thread.start_new_thread(self._worker, ()) + + def addJob(self, job, allowQueuing=True): + """ + Adds a job to the work queue. The job object should have a run() + method. If allowQueuing is True (the default), the job will be + added to the work queue regardless if there are any idle threads + ready. (The only way for there to be no idle threads is if maxThreads + is some reasonable, finite limit.) 
+ + Otherwise, if allowQueuing is False, and there are no more idle + threads, the job will not be queued. + + Returns True if the job was queued, False otherwise. + """ + self._lock.acquire() + try: + # Maintain minimum number of spares. + while self._idleCount < self._minSpare and \ + self._workerCount < self._maxThreads: + self._workerCount += 1 + self._idleCount += 1 + thread.start_new_thread(self._worker, ()) + + # Hand off the job. + if self._idleCount or allowQueuing: + self._workQueue.append(job) + self._lock.notify() + return True + else: + return False + finally: + self._lock.release() + + def _worker(self): + """ + Worker thread routine. Waits for a job, executes it, repeat. + """ + self._lock.acquire() + while True: + while not self._workQueue: + self._lock.wait() + + # We have a job to do... + job = self._workQueue.pop(0) + + assert self._idleCount > 0 + self._idleCount -= 1 + + self._lock.release() + + job.run() + + self._lock.acquire() + + if self._idleCount == self._maxSpare: + break # NB: lock still held + self._idleCount += 1 + assert self._idleCount <= self._maxSpare + + # Die off... + assert self._workerCount > self._maxSpare + self._workerCount -= 1 + + self._lock.release() + +class WSGIServer(object): + """ + AJP1.3/WSGI server. Runs your WSGI application as a persistant program + that understands AJP1.3. Opens up a TCP socket, binds it, and then + waits for forwarded requests from your webserver. + + Why AJP? Two good reasons are that AJP provides load-balancing and + fail-over support. Personally, I just wanted something new to + implement. :) + + Of course you will need an AJP1.3 connector for your webserver (e.g. + mod_jk) - see <http://jakarta.apache.org/tomcat/connectors-doc/>. + """ + # What Request class to use. + requestClass = Request + + # Limits the size of the InputStream's string buffer to this size + 8k. + # Since the InputStream is not seekable, we throw away already-read + # data once this certain amount has been read. 
(The 8k is there because + # it is the maximum size of new data added per chunk.) + inputStreamShrinkThreshold = 102400 - 8192 + + def __init__(self, application, scriptName='', environ=None, + multithreaded=True, + bindAddress=('localhost', 8009), allowedServers=None, + loggingLevel=logging.INFO, **kw): + """ + scriptName is the initial portion of the URL path that "belongs" + to your application. It is used to determine PATH_INFO (which doesn't + seem to be passed in). An empty scriptName means your application + is mounted at the root of your virtual host. + + environ, which must be a dictionary, can contain any additional + environment variables you want to pass to your application. + + Set multithreaded to False if your application is not thread-safe. + + bindAddress is the address to bind to, which must be a tuple of + length 2. The first element is a string, which is the host name + or IPv4 address of a local interface. The 2nd element is the port + number. + + allowedServers must be None or a list of strings representing the + IPv4 addresses of servers allowed to connect. None means accept + connections from anywhere. + + loggingLevel sets the logging level of the module-level logger. + + Any additional keyword arguments are passed to the underlying + ThreadPool. + """ + if environ is None: + environ = {} + + self.application = application + self.scriptName = scriptName + self.environ = environ + self.multithreaded = multithreaded + self._bindAddress = bindAddress + self._allowedServers = allowedServers + + # Used to force single-threadedness. 
+ self._appLock = thread.allocate_lock() + + self._threadPool = ThreadPool(**kw) + + self.logger = logging.getLogger(LoggerName) + self.logger.setLevel(loggingLevel) + + def _setupSocket(self): + """Creates and binds the socket for communication with the server.""" + sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(self._bindAddress) + sock.listen(socket.SOMAXCONN) + return sock + + def _cleanupSocket(self, sock): + """Closes the main socket.""" + sock.close() + + def _isServerAllowed(self, addr): + return self._allowedServers is None or \ + addr[0] in self._allowedServers + + def _installSignalHandlers(self): + self._oldSIGs = [(x,signal.getsignal(x)) for x in + (signal.SIGHUP, signal.SIGINT, signal.SIGTERM)] + signal.signal(signal.SIGHUP, self._hupHandler) + signal.signal(signal.SIGINT, self._intHandler) + signal.signal(signal.SIGTERM, self._intHandler) + + def _restoreSignalHandlers(self): + for signum,handler in self._oldSIGs: + signal.signal(signum, handler) + + def _hupHandler(self, signum, frame): + self._hupReceived = True + self._keepGoing = False + + def _intHandler(self, signum, frame): + self._keepGoing = False + + def run(self, timeout=1.0): + """ + Main loop. Call this after instantiating WSGIServer. SIGHUP, SIGINT, + SIGTERM cause it to cleanup and return. (If a SIGHUP is caught, this + method returns True. Returns False otherwise.) + """ + self.logger.info('%s starting up', self.__class__.__name__) + + try: + sock = self._setupSocket() + except socket.error, e: + self.logger.error('Failed to bind socket (%s), exiting', e[1]) + return False + + self._keepGoing = True + self._hupReceived = False + + # Install signal handlers. 
+ self._installSignalHandlers() + + while self._keepGoing: + try: + r, w, e = select.select([sock], [], [], timeout) + except select.error, e: + if e[0] == errno.EINTR: + continue + raise + + if r: + try: + clientSock, addr = sock.accept() + except socket.error, e: + if e[0] in (errno.EINTR, errno.EAGAIN): + continue + raise + + if not self._isServerAllowed(addr): + self.logger.warning('Server connection from %s disallowed', + addr[0]) + clientSock.close() + continue + + # Hand off to Connection. + conn = Connection(clientSock, addr, self) + if not self._threadPool.addJob(conn, allowQueuing=False): + # No thread left, immediately close the socket to hopefully + # indicate to the web server that we're at our limit... + # and to prevent having too many opened (and useless) + # files. + clientSock.close() + + self._mainloopPeriodic() + + # Restore old signal handlers. + self._restoreSignalHandlers() + + self._cleanupSocket(sock) + + self.logger.info('%s shutting down%s', self.__class__.__name__, + self._hupReceived and ' (reload requested)' or '') + + return self._hupReceived + + def _mainloopPeriodic(self): + """ + Called with just about each iteration of the main loop. Meant to + be overridden. + """ + pass + + def _exit(self, reload=False): + """ + Protected convenience method for subclasses to force an exit. Not + really thread-safe, which is why it isn't public. + """ + if self._keepGoing: + self._keepGoing = False + self._hupReceived = reload + + def handler(self, request): + """ + WSGI handler. Sets up WSGI environment, calls the application, + and sends the application's response. 
+ """ + environ = request.environ + environ.update(self.environ) + + environ['wsgi.version'] = (1,0) + environ['wsgi.input'] = request.input + environ['wsgi.errors'] = sys.stderr + environ['wsgi.multithread'] = self.multithreaded + environ['wsgi.multiprocess'] = True + environ['wsgi.run_once'] = False + + if environ.get('HTTPS', 'off') in ('on', '1'): + environ['wsgi.url_scheme'] = 'https' + else: + environ['wsgi.url_scheme'] = 'http' + + headers_set = [] + headers_sent = [] + result = None + + def write(data): + assert type(data) is str, 'write() argument must be string' + assert headers_set, 'write() before start_response()' + + if not headers_sent: + status, responseHeaders = headers_sent[:] = headers_set + statusCode = int(status[:3]) + statusMsg = status[4:] + found = False + for header,value in responseHeaders: + if header.lower() == 'content-length': + found = True + break + if not found and result is not None: + try: + if len(result) == 1: + responseHeaders.append(('Content-Length', + str(len(data)))) + except: + pass + request.startResponse(statusCode, statusMsg, responseHeaders) + + request.write(data) + + def start_response(status, response_headers, exc_info=None): + if exc_info: + try: + if headers_sent: + # Re-raise if too late + raise exc_info[0], exc_info[1], exc_info[2] + finally: + exc_info = None # avoid dangling circular ref + else: + assert not headers_set, 'Headers already set!' 
+ + assert type(status) is str, 'Status must be a string' + assert len(status) >= 4, 'Status must be at least 4 characters' + assert int(status[:3]), 'Status must begin with 3-digit code' + assert status[3] == ' ', 'Status must have a space after code' + assert type(response_headers) is list, 'Headers must be a list' + if __debug__: + for name,val in response_headers: + assert type(name) is str, 'Header names must be strings' + assert type(val) is str, 'Header values must be strings' + + headers_set[:] = [status, response_headers] + return write + + if not self.multithreaded: + self._appLock.acquire() + try: + result = self.application(environ, start_response) + try: + for data in result: + if data: + write(data) + if not headers_sent: + write('') # in case body was empty + finally: + if hasattr(result, 'close'): + result.close() + finally: + if not self.multithreaded: + self._appLock.release() + + def error(self, request): + """ + Override to provide custom error handling. Ideally, however, + all errors should be caught at the application level. 
+ """ + request.startResponse(200, 'OK', [('Content-Type', 'text/html')]) + import cgitb + request.write(cgitb.html(sys.exc_info())) + +if __name__ == '__main__': + def test_app(environ, start_response): + """Probably not the most efficient example.""" + import cgi + start_response('200 OK', [('Content-Type', 'text/html')]) + yield '<html><head><title>Hello World!</title></head>\n' \ + '<body>\n' \ + '<p>Hello World!</p>\n' \ + '<table border="1">' + names = environ.keys() + names.sort() + for name in names: + yield '<tr><td>%s</td><td>%s</td></tr>\n' % ( + name, cgi.escape(`environ[name]`)) + + form = cgi.FieldStorage(fp=environ['wsgi.input'], environ=environ, + keep_blank_values=1) + if form.list: + yield '<tr><th colspan="2">Form data</th></tr>' + + for field in form.list: + yield '<tr><td>%s</td><td>%s</td></tr>\n' % ( + field.name, field.value) + + yield '</table>\n' \ + '</body></html>\n' + + # Explicitly set bindAddress to *:8009 for testing. + WSGIServer(test_app, + bindAddress=('', 8009), + loggingLevel=logging.DEBUG).run() diff --git a/flup/server/ajp_fork.py b/flup/server/ajp_fork.py new file mode 100644 index 0000000..4df9704 --- /dev/null +++ b/flup/server/ajp_fork.py @@ -0,0 +1,1025 @@ +# Copyright (c) 2005 Allan Saddi <allan@saddi.com> +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS +# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF +# SUCH DAMAGE. +# +# $Id$ + +""" +ajp - an AJP 1.3/WSGI gateway. + +For more information about AJP and AJP connectors for your web server, see +<http://jakarta.apache.org/tomcat/connectors-doc/>. + +For more information about the Web Server Gateway Interface, see +<http://www.python.org/peps/pep-0333.html>. + +Example usage: + + #!/usr/bin/env python + import sys + from myapplication import app # Assume app is your WSGI application object + from ajp import WSGIServer + ret = WSGIServer(app).run() + sys.exit(ret and 42 or 0) + +See the documentation for WSGIServer for more information. + +About the bit of logic at the end: +Upon receiving SIGHUP, the python script will exit with status code 42. This +can be used by a wrapper script to determine if the python script should be +re-run. When a SIGINT or SIGTERM is received, the script exits with status +code 0, possibly indicating a normal exit. + +Example wrapper script: + + #!/bin/sh + STATUS=42 + while test $STATUS -eq 42; do + python "$@" that_script_above.py + STATUS=$? 
+ done + +Example workers.properties (for mod_jk): + + worker.list=foo + worker.foo.port=8009 + worker.foo.host=localhost + worker.foo.type=ajp13 + +Example httpd.conf (for mod_jk): + + JkWorkersFile /path/to/workers.properties + JkMount /* foo + +Note that if you mount your ajp application anywhere but the root ("/"), you +SHOULD specifiy scriptName to the WSGIServer constructor. This will ensure +that SCRIPT_NAME/PATH_INFO are correctly deduced. +""" + +__author__ = 'Allan Saddi <allan@saddi.com>' +__version__ = '$Revision$' + +import sys +import socket +import select +import struct +import logging +import errno +import datetime +import prefork + +__all__ = ['WSGIServer'] + +# Packet header prefixes. +SERVER_PREFIX = '\x12\x34' +CONTAINER_PREFIX = 'AB' + +# Server packet types. +PKTTYPE_FWD_REQ = '\x02' +PKTTYPE_SHUTDOWN = '\x07' +PKTTYPE_PING = '\x08' +PKTTYPE_CPING = '\x0a' + +# Container packet types. +PKTTYPE_SEND_BODY = '\x03' +PKTTYPE_SEND_HEADERS = '\x04' +PKTTYPE_END_RESPONSE = '\x05' +PKTTYPE_GET_BODY = '\x06' +PKTTYPE_CPONG = '\x09' + +# Code tables for methods/headers/attributes. 
+methodTable = [ + None, + 'OPTIONS', + 'GET', + 'HEAD', + 'POST', + 'PUT', + 'DELETE', + 'TRACE', + 'PROPFIND', + 'PROPPATCH', + 'MKCOL', + 'COPY', + 'MOVE', + 'LOCK', + 'UNLOCK', + 'ACL', + 'REPORT', + 'VERSION-CONTROL', + 'CHECKIN', + 'CHECKOUT', + 'UNCHECKOUT', + 'SEARCH', + 'MKWORKSPACE', + 'UPDATE', + 'LABEL', + 'MERGE', + 'BASELINE_CONTROL', + 'MKACTIVITY' + ] + +requestHeaderTable = [ + None, + 'Accept', + 'Accept-Charset', + 'Accept-Encoding', + 'Accept-Language', + 'Authorization', + 'Connection', + 'Content-Type', + 'Content-Length', + 'Cookie', + 'Cookie2', + 'Host', + 'Pragma', + 'Referer', + 'User-Agent' + ] + +attributeTable = [ + None, + 'CONTEXT', + 'SERVLET_PATH', + 'REMOTE_USER', + 'AUTH_TYPE', + 'QUERY_STRING', + 'JVM_ROUTE', + 'SSL_CERT', + 'SSL_CIPHER', + 'SSL_SESSION', + None, # name follows + 'SSL_KEY_SIZE' + ] + +responseHeaderTable = [ + None, + 'content-type', + 'content-language', + 'content-length', + 'date', + 'last-modified', + 'location', + 'set-cookie', + 'set-cookie2', + 'servlet-engine', + 'status', + 'www-authenticate' + ] + +# The main classes use this name for logging. +LoggerName = 'ajp-wsgi' + +# Set up module-level logger. +console = logging.StreamHandler() +console.setLevel(logging.DEBUG) +console.setFormatter(logging.Formatter('%(asctime)s : %(message)s', + '%Y-%m-%d %H:%M:%S')) +logging.getLogger(LoggerName).addHandler(console) +del console + +class ProtocolError(Exception): + """ + Exception raised when the server does something unexpected or + sends garbled data. Usually leads to a Connection closing. + """ + pass + +def decodeString(data, pos=0): + """Decode a string.""" + try: + length = struct.unpack('>H', data[pos:pos+2])[0] + pos += 2 + if length == 0xffff: # This was undocumented! 
+ return '', pos + s = data[pos:pos+length] + return s, pos+length+1 # Don't forget NUL + except Exception, e: + raise ProtocolError, 'decodeString: '+str(e) + +def decodeRequestHeader(data, pos=0): + """Decode a request header/value pair.""" + try: + if data[pos] == '\xa0': + # Use table + i = ord(data[pos+1]) + name = requestHeaderTable[i] + if name is None: + raise ValueError, 'bad request header code' + pos += 2 + else: + name, pos = decodeString(data, pos) + value, pos = decodeString(data, pos) + return name, value, pos + except Exception, e: + raise ProtocolError, 'decodeRequestHeader: '+str(e) + +def decodeAttribute(data, pos=0): + """Decode a request attribute.""" + try: + i = ord(data[pos]) + pos += 1 + if i == 0xff: + # end + return None, None, pos + elif i == 0x0a: + # name follows + name, pos = decodeString(data, pos) + elif i == 0x0b: + # Special handling of SSL_KEY_SIZE. + name = attributeTable[i] + # Value is an int, not a string. + value = struct.unpack('>H', data[pos:pos+2])[0] + return name, str(value), pos+2 + else: + name = attributeTable[i] + if name is None: + raise ValueError, 'bad attribute code' + value, pos = decodeString(data, pos) + return name, value, pos + except Exception, e: + raise ProtocolError, 'decodeAttribute: '+str(e) + +def encodeString(s): + """Encode a string.""" + return struct.pack('>H', len(s)) + s + '\x00' + +def encodeResponseHeader(name, value): + """Encode a response header/value pair.""" + lname = name.lower() + if lname in responseHeaderTable: + # Use table + i = responseHeaderTable.index(lname) + out = '\xa0' + chr(i) + else: + out = encodeString(name) + out += encodeString(value) + return out + +class Packet(object): + """An AJP message packet.""" + def __init__(self): + self.data = '' + # Don't set this on write, it will be calculated automatically. + self.length = 0 + + def _recvall(sock, length): + """ + Attempts to receive length bytes from a socket, blocking if necessary. 
+ (Socket may be blocking or non-blocking.) + """ + dataList = [] + recvLen = 0 + while length: + try: + data = sock.recv(length) + except socket.error, e: + if e[0] == errno.EAGAIN: + select.select([sock], [], []) + continue + else: + raise + if not data: # EOF + break + dataList.append(data) + dataLen = len(data) + recvLen += dataLen + length -= dataLen + return ''.join(dataList), recvLen + _recvall = staticmethod(_recvall) + + def read(self, sock): + """Attempt to read a packet from the server.""" + try: + header, length = self._recvall(sock, 4) + except socket.error: + # Treat any sort of socket errors as EOF (close Connection). + raise EOFError + + if length < 4: + raise EOFError + + if header[:2] != SERVER_PREFIX: + raise ProtocolError, 'invalid header' + + self.length = struct.unpack('>H', header[2:4])[0] + if self.length: + try: + self.data, length = self._recvall(sock, self.length) + except socket.error: + raise EOFError + + if length < self.length: + raise EOFError + + def _sendall(sock, data): + """ + Writes data to a socket and does not return until all the data is sent. + """ + length = len(data) + while length: + try: + sent = sock.send(data) + except socket.error, e: + if e[0] == errno.EPIPE: + return # Don't bother raising an exception. Just ignore. + elif e[0] == errno.EAGAIN: + select.select([], [sock], []) + continue + else: + raise + data = data[sent:] + length -= sent + _sendall = staticmethod(_sendall) + + def write(self, sock): + """Send a packet to the server.""" + self.length = len(self.data) + self._sendall(sock, CONTAINER_PREFIX + struct.pack('>H', self.length)) + if self.length: + self._sendall(sock, self.data) + +class InputStream(object): + """ + File-like object that represents the request body (if any). Supports + the bare mininum methods required by the WSGI spec. Thanks to + StringIO for ideas. + """ + def __init__(self, conn): + self._conn = conn + + # See WSGIServer. 
+ self._shrinkThreshold = conn.server.inputStreamShrinkThreshold + + self._buf = '' + self._bufList = [] + self._pos = 0 # Current read position. + self._avail = 0 # Number of bytes currently available. + self._length = 0 # Set to Content-Length in request. + + self.logger = logging.getLogger(LoggerName) + + def bytesAvailForAdd(self): + return self._length - self._avail + + def _shrinkBuffer(self): + """Gets rid of already read data (since we can't rewind).""" + if self._pos >= self._shrinkThreshold: + self._buf = self._buf[self._pos:] + self._avail -= self._pos + self._length -= self._pos + self._pos = 0 + + assert self._avail >= 0 and self._length >= 0 + + def _waitForData(self): + toAdd = min(self.bytesAvailForAdd(), 0xffff) + assert toAdd > 0 + pkt = Packet() + pkt.data = PKTTYPE_GET_BODY + \ + struct.pack('>H', toAdd) + self._conn.writePacket(pkt) + self._conn.processInput() + + def read(self, n=-1): + if self._pos == self._length: + return '' + while True: + if n < 0 or (self._avail - self._pos) < n: + # Not enough data available. + if not self.bytesAvailForAdd(): + # And there's no more coming. + newPos = self._avail + break + else: + # Ask for more data and wait. + self._waitForData() + continue + else: + newPos = self._pos + n + break + # Merge buffer list, if necessary. + if self._bufList: + self._buf += ''.join(self._bufList) + self._bufList = [] + r = self._buf[self._pos:newPos] + self._pos = newPos + self._shrinkBuffer() + return r + + def readline(self, length=None): + if self._pos == self._length: + return '' + while True: + # Unfortunately, we need to merge the buffer list early. + if self._bufList: + self._buf += ''.join(self._bufList) + self._bufList = [] + # Find newline. + i = self._buf.find('\n', self._pos) + if i < 0: + # Not found? + if not self.bytesAvailForAdd(): + # No more data coming. + newPos = self._avail + break + else: + # Wait for more to come. 
+ self._waitForData() + continue + else: + newPos = i + 1 + break + if length is not None: + if self._pos + length < newPos: + newPos = self._pos + length + r = self._buf[self._pos:newPos] + self._pos = newPos + self._shrinkBuffer() + return r + + def readlines(self, sizehint=0): + total = 0 + lines = [] + line = self.readline() + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline() + return lines + + def __iter__(self): + return self + + def next(self): + r = self.readline() + if not r: + raise StopIteration + return r + + def setDataLength(self, length): + """ + Once Content-Length is known, Request calls this method to set it. + """ + self._length = length + + def addData(self, data): + """ + Adds data from the server to this InputStream. Note that we never ask + the server for data beyond the Content-Length, so the server should + never send us an EOF (empty string argument). + """ + if not data: + raise ProtocolError, 'short data' + self._bufList.append(data) + length = len(data) + self._avail += length + if self._avail > self._length: + raise ProtocolError, 'too much data' + +class Request(object): + """ + A Request object. A more fitting name would probably be Transaction, but + it's named Request to mirror my FastCGI driver. :) This object + encapsulates all the data about the HTTP request and allows the handler + to send a response. + + The only attributes/methods that the handler should concern itself + with are: environ, input, startResponse(), and write(). + """ + # Do not ever change the following value. 
+ _maxWrite = 8192 - 4 - 3 # 8k - pkt header - send body header + + def __init__(self, conn): + self._conn = conn + + self.environ = { + 'SCRIPT_NAME': conn.server.scriptName + } + self.input = InputStream(conn) + + self._headersSent = False + + self.logger = logging.getLogger(LoggerName) + + def run(self): + self.logger.info('%s %s', + self.environ['REQUEST_METHOD'], + self.environ['REQUEST_URI']) + + start = datetime.datetime.now() + + try: + self._conn.server.handler(self) + except: + self.logger.exception('Exception caught from handler') + if not self._headersSent: + self._conn.server.error(self) + + end = datetime.datetime.now() + + # Notify server of end of response (reuse flag is set to true). + pkt = Packet() + pkt.data = PKTTYPE_END_RESPONSE + '\x01' + self._conn.writePacket(pkt) + + handlerTime = end - start + self.logger.debug('%s %s done (%.3f secs)', + self.environ['REQUEST_METHOD'], + self.environ['REQUEST_URI'], + handlerTime.seconds + + handlerTime.microseconds / 1000000.0) + + # The following methods are called from the Connection to set up this + # Request. 
+ + def setMethod(self, value): + self.environ['REQUEST_METHOD'] = value + + def setProtocol(self, value): + self.environ['SERVER_PROTOCOL'] = value + + def setRequestURI(self, value): + self.environ['REQUEST_URI'] = value + + scriptName = self._conn.server.scriptName + if not value.startswith(scriptName): + self.logger.warning('scriptName does not match request URI') + + self.environ['PATH_INFO'] = value[len(scriptName):] + + def setRemoteAddr(self, value): + self.environ['REMOTE_ADDR'] = value + + def setRemoteHost(self, value): + self.environ['REMOTE_HOST'] = value + + def setServerName(self, value): + self.environ['SERVER_NAME'] = value + + def setServerPort(self, value): + self.environ['SERVER_PORT'] = str(value) + + def setIsSSL(self, value): + if value: + self.environ['HTTPS'] = 'on' + + def addHeader(self, name, value): + name = name.replace('-', '_').upper() + if name in ('CONTENT_TYPE', 'CONTENT_LENGTH'): + self.environ[name] = value + if name == 'CONTENT_LENGTH': + length = int(value) + self.input.setDataLength(length) + else: + self.environ['HTTP_'+name] = value + + def addAttribute(self, name, value): + self.environ[name] = value + + # The only two methods that should be called from the handler. + + def startResponse(self, statusCode, statusMsg, headers): + """ + Begin the HTTP response. This must only be called once and it + must be called before any calls to write(). + + statusCode is the integer status code (e.g. 200). statusMsg + is the associated reason message (e.g.'OK'). headers is a list + of 2-tuples - header name/value pairs. (Both header name and value + must be strings.) + """ + assert not self._headersSent, 'Headers already sent!' 
+ + pkt = Packet() + pkt.data = PKTTYPE_SEND_HEADERS + \ + struct.pack('>H', statusCode) + \ + encodeString(statusMsg) + \ + struct.pack('>H', len(headers)) + \ + ''.join([encodeResponseHeader(name, value) + for name,value in headers]) + + self._conn.writePacket(pkt) + + self._headersSent = True + + def write(self, data): + """ + Write data (which comprises the response body). Note that due to + restrictions on AJP packet size, we limit our writes to 8185 bytes + each packet. + """ + assert self._headersSent, 'Headers must be sent first!' + + bytesLeft = len(data) + while bytesLeft: + toWrite = min(bytesLeft, self._maxWrite) + + pkt = Packet() + pkt.data = PKTTYPE_SEND_BODY + \ + struct.pack('>H', toWrite) + \ + data[:toWrite] + self._conn.writePacket(pkt) + + data = data[toWrite:] + bytesLeft -= toWrite + +class Connection(object): + """ + A single Connection with the server. Requests are not multiplexed over the + same connection, so at any given time, the Connection is either + waiting for a request, or processing a single request. + """ + def __init__(self, sock, addr, server): + self.server = server + self._sock = sock + self._addr = addr + + self._request = None + + self.logger = logging.getLogger(LoggerName) + + def run(self): + self.logger.debug('Connection starting up (%s:%d)', + self._addr[0], self._addr[1]) + + # Main loop. Errors will cause the loop to be exited and + # the socket to be closed. + while True: + try: + self.processInput() + except ProtocolError, e: + self.logger.error("Protocol error '%s'", str(e)) + break + except (EOFError, KeyboardInterrupt): + break + except: + self.logger.exception('Exception caught in Connection') + break + + self.logger.debug('Connection shutting down (%s:%d)', + self._addr[0], self._addr[1]) + + self._sock.close() + + def processInput(self): + """Wait for and process a single packet.""" + pkt = Packet() + select.select([self._sock], [], []) + pkt.read(self._sock) + + # Body chunks have no packet type code. 
+ if self._request is not None: + self._processBody(pkt) + return + + if not pkt.length: + raise ProtocolError, 'unexpected empty packet' + + pkttype = pkt.data[0] + if pkttype == PKTTYPE_FWD_REQ: + self._forwardRequest(pkt) + elif pkttype == PKTTYPE_SHUTDOWN: + self._shutdown(pkt) + elif pkttype == PKTTYPE_PING: + self._ping(pkt) + elif pkttype == PKTTYPE_CPING: + self._cping(pkt) + else: + raise ProtocolError, 'unknown packet type' + + def _forwardRequest(self, pkt): + """ + Creates a Request object, fills it in from the packet, then runs it. + """ + assert self._request is None + + req = self.server.requestClass(self) + i = ord(pkt.data[1]) + method = methodTable[i] + if method is None: + raise ValueError, 'bad method field' + req.setMethod(method) + value, pos = decodeString(pkt.data, 2) + req.setProtocol(value) + value, pos = decodeString(pkt.data, pos) + req.setRequestURI(value) + value, pos = decodeString(pkt.data, pos) + req.setRemoteAddr(value) + value, pos = decodeString(pkt.data, pos) + req.setRemoteHost(value) + value, pos = decodeString(pkt.data, pos) + req.setServerName(value) + value = struct.unpack('>H', pkt.data[pos:pos+2])[0] + req.setServerPort(value) + i = ord(pkt.data[pos+2]) + req.setIsSSL(i != 0) + + # Request headers. + numHeaders = struct.unpack('>H', pkt.data[pos+3:pos+5])[0] + pos += 5 + for i in range(numHeaders): + name, value, pos = decodeRequestHeader(pkt.data, pos) + req.addHeader(name, value) + + # Attributes. + while True: + name, value, pos = decodeAttribute(pkt.data, pos) + if name is None: + break + req.addAttribute(name, value) + + self._request = req + + # Read first body chunk, if needed. + if req.input.bytesAvailForAdd(): + self.processInput() + + # Run Request. 
+ req.run() + + self._request = None + + def _shutdown(self, pkt): + """Not sure what to do with this yet.""" + self.logger.info('Received shutdown request from server') + + def _ping(self, pkt): + """I have no idea what this packet means.""" + self.logger.debug('Received ping') + + def _cping(self, pkt): + """Respond to a PING (CPING) packet.""" + self.logger.debug('Received PING, sending PONG') + pkt = Packet() + pkt.data = PKTTYPE_CPONG + self.writePacket(pkt) + + def _processBody(self, pkt): + """ + Handles a body chunk from the server by appending it to the + InputStream. + """ + if pkt.length: + length = struct.unpack('>H', pkt.data[:2])[0] + self._request.input.addData(pkt.data[2:2+length]) + else: + # Shouldn't really ever get here. + self._request.input.addData('') + + def writePacket(self, pkt): + """Sends a Packet to the server.""" + pkt.write(self._sock) + +class WSGIServer(prefork.PreforkServer): + """ + AJP1.3/WSGI server. Runs your WSGI application as a persistant program + that understands AJP1.3. Opens up a TCP socket, binds it, and then + waits for forwarded requests from your webserver. + + Why AJP? Two good reasons are that AJP provides load-balancing and + fail-over support. Personally, I just wanted something new to + implement. :) + + Of course you will need an AJP1.3 connector for your webserver (e.g. + mod_jk) - see <http://jakarta.apache.org/tomcat/connectors-doc/>. + """ + # What Request class to use. + requestClass = Request + + # Limits the size of the InputStream's string buffer to this size + 8k. + # Since the InputStream is not seekable, we throw away already-read + # data once this certain amount has been read. (The 8k is there because + # it is the maximum size of new data added per chunk.) 
+ inputStreamShrinkThreshold = 102400 - 8192 + + def __init__(self, application, scriptName='', environ=None, + bindAddress=('localhost', 8009), allowedServers=None, + loggingLevel=logging.INFO, **kw): + """ + scriptName is the initial portion of the URL path that "belongs" + to your application. It is used to determine PATH_INFO (which doesn't + seem to be passed in). An empty scriptName means your application + is mounted at the root of your virtual host. + + environ, which must be a dictionary, can contain any additional + environment variables you want to pass to your application. + + bindAddress is the address to bind to, which must be a tuple of + length 2. The first element is a string, which is the host name + or IPv4 address of a local interface. The 2nd element is the port + number. + + allowedServers must be None or a list of strings representing the + IPv4 addresses of servers allowed to connect. None means accept + connections from anywhere. + + loggingLevel sets the logging level of the module-level logger. 
+ """ + if kw.has_key('jobClass'): + del kw['jobClass'] + if kw.has_key('jobArgs'): + del kw['jobArgs'] + super(WSGIServer, self).__init__(jobClass=Connection, + jobArgs=(self,), **kw) + + if environ is None: + environ = {} + + self.application = application + self.scriptName = scriptName + self.environ = environ + + self._bindAddress = bindAddress + self._allowedServers = allowedServers + + self.logger = logging.getLogger(LoggerName) + self.logger.setLevel(loggingLevel) + + def _setupSocket(self): + """Creates and binds the socket for communication with the server.""" + sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(self._bindAddress) + sock.listen(socket.SOMAXCONN) + return sock + + def _cleanupSocket(self, sock): + """Closes the main socket.""" + sock.close() + + def _isClientAllowed(self, addr): + ret = self._allowedServers is None or addr[0] in self._allowedServers + if not ret: + self.logger.warning('Server connection from %s disallowed', + addr[0]) + return ret + + def run(self): + """ + Main loop. Call this after instantiating WSGIServer. SIGHUP, SIGINT, + SIGQUIT, SIGTERM cause it to cleanup and return. (If a SIGHUP + is caught, this method returns True. Returns False otherwise.) + """ + self.logger.info('%s starting up', self.__class__.__name__) + + try: + sock = self._setupSocket() + except socket.error, e: + self.logger.error('Failed to bind socket (%s), exiting', e[1]) + return False + + ret = super(WSGIServer, self).run(sock) + + self._cleanupSocket(sock) + + self.logger.info('%s shutting down', self.__class__.__name__) + + return ret + + def handler(self, request): + """ + WSGI handler. Sets up WSGI environment, calls the application, + and sends the application's response. 
+ """ + environ = request.environ + environ.update(self.environ) + + environ['wsgi.version'] = (1,0) + environ['wsgi.input'] = request.input + environ['wsgi.errors'] = sys.stderr + environ['wsgi.multithread'] = False + environ['wsgi.multiprocess'] = True + environ['wsgi.run_once'] = False + + if environ.get('HTTPS', 'off') in ('on', '1'): + environ['wsgi.url_scheme'] = 'https' + else: + environ['wsgi.url_scheme'] = 'http' + + headers_set = [] + headers_sent = [] + result = None + + def write(data): + assert type(data) is str, 'write() argument must be string' + assert headers_set, 'write() before start_response()' + + if not headers_sent: + status, responseHeaders = headers_sent[:] = headers_set + statusCode = int(status[:3]) + statusMsg = status[4:] + found = False + for header,value in responseHeaders: + if header.lower() == 'content-length': + found = True + break + if not found and result is not None: + try: + if len(result) == 1: + responseHeaders.append(('Content-Length', + str(len(data)))) + except: + pass + request.startResponse(statusCode, statusMsg, responseHeaders) + + request.write(data) + + def start_response(status, response_headers, exc_info=None): + if exc_info: + try: + if headers_sent: + # Re-raise if too late + raise exc_info[0], exc_info[1], exc_info[2] + finally: + exc_info = None # avoid dangling circular ref + else: + assert not headers_set, 'Headers already set!' 
+ + assert type(status) is str, 'Status must be a string' + assert len(status) >= 4, 'Status must be at least 4 characters' + assert int(status[:3]), 'Status must begin with 3-digit code' + assert status[3] == ' ', 'Status must have a space after code' + assert type(response_headers) is list, 'Headers must be a list' + if __debug__: + for name,val in response_headers: + assert type(name) is str, 'Header names must be strings' + assert type(val) is str, 'Header values must be strings' + + headers_set[:] = [status, response_headers] + return write + + result = self.application(environ, start_response) + + try: + for data in result: + if data: + write(data) + if not headers_sent: + write('') # in case body was empty + finally: + if hasattr(result, 'close'): + result.close() + + def error(self, request): + """ + Override to provide custom error handling. Ideally, however, + all errors should be caught at the application level. + """ + request.startResponse(200, 'OK', [('Content-Type', 'text/html')]) + import cgitb + request.write(cgitb.html(sys.exc_info())) + +if __name__ == '__main__': + def test_app(environ, start_response): + """Probably not the most efficient example.""" + import cgi + start_response('200 OK', [('Content-Type', 'text/html')]) + yield '<html><head><title>Hello World!</title></head>\n' \ + '<body>\n' \ + '<p>Hello World!</p>\n' \ + '<table border="1">' + names = environ.keys() + names.sort() + for name in names: + yield '<tr><td>%s</td><td>%s</td></tr>\n' % ( + name, cgi.escape(`environ[name]`)) + + form = cgi.FieldStorage(fp=environ['wsgi.input'], environ=environ, + keep_blank_values=1) + if form.list: + yield '<tr><th colspan="2">Form data</th></tr>' + + for field in form.list: + yield '<tr><td>%s</td><td>%s</td></tr>\n' % ( + field.name, field.value) + + yield '</table>\n' \ + '</body></html>\n' + + # Explicitly set bindAddress to *:8009 for testing. 
+ WSGIServer(test_app, + bindAddress=('', 8009), + loggingLevel=logging.DEBUG).run() diff --git a/flup/server/fcgi.py b/flup/server/fcgi.py new file mode 100644 index 0000000..2a536be --- /dev/null +++ b/flup/server/fcgi.py @@ -0,0 +1,1306 @@ +# Copyright (c) 2002, 2003, 2005 Allan Saddi <allan@saddi.com> +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS +# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF +# SUCH DAMAGE. +# +# $Id$ + +""" +fcgi - a FastCGI/WSGI gateway. + +For more information about FastCGI, see <http://www.fastcgi.com/>. + +For more information about the Web Server Gateway Interface, see +<http://www.python.org/peps/pep-0333.html>. 
+ +Example usage: + + #!/usr/bin/env python + from myapplication import app # Assume app is your WSGI application object + from fcgi import WSGIServer + WSGIServer(app).run() + +See the documentation for WSGIServer/Server for more information. + +On most platforms, fcgi will fallback to regular CGI behavior if run in a +non-FastCGI context. If you want to force CGI behavior, set the environment +variable FCGI_FORCE_CGI to "Y" or "y". +""" + +__author__ = 'Allan Saddi <allan@saddi.com>' +__version__ = '$Revision$' + +import sys +import os +import signal +import struct +import cStringIO as StringIO +import select +import socket +import errno +import traceback + +try: + import thread + import threading + thread_available = True +except ImportError: + import dummy_thread as thread + import dummy_threading as threading + thread_available = False + +__all__ = ['WSGIServer'] + +# Constants from the spec. +FCGI_LISTENSOCK_FILENO = 0 + +FCGI_HEADER_LEN = 8 + +FCGI_VERSION_1 = 1 + +FCGI_BEGIN_REQUEST = 1 +FCGI_ABORT_REQUEST = 2 +FCGI_END_REQUEST = 3 +FCGI_PARAMS = 4 +FCGI_STDIN = 5 +FCGI_STDOUT = 6 +FCGI_STDERR = 7 +FCGI_DATA = 8 +FCGI_GET_VALUES = 9 +FCGI_GET_VALUES_RESULT = 10 +FCGI_UNKNOWN_TYPE = 11 +FCGI_MAXTYPE = FCGI_UNKNOWN_TYPE + +FCGI_NULL_REQUEST_ID = 0 + +FCGI_KEEP_CONN = 1 + +FCGI_RESPONDER = 1 +FCGI_AUTHORIZER = 2 +FCGI_FILTER = 3 + +FCGI_REQUEST_COMPLETE = 0 +FCGI_CANT_MPX_CONN = 1 +FCGI_OVERLOADED = 2 +FCGI_UNKNOWN_ROLE = 3 + +FCGI_MAX_CONNS = 'FCGI_MAX_CONNS' +FCGI_MAX_REQS = 'FCGI_MAX_REQS' +FCGI_MPXS_CONNS = 'FCGI_MPXS_CONNS' + +FCGI_Header = '!BBHHBx' +FCGI_BeginRequestBody = '!HB5x' +FCGI_EndRequestBody = '!LB3x' +FCGI_UnknownTypeBody = '!B7x' + +FCGI_EndRequestBody_LEN = struct.calcsize(FCGI_EndRequestBody) +FCGI_UnknownTypeBody_LEN = struct.calcsize(FCGI_UnknownTypeBody) + +if __debug__: + import time + + # Set non-zero to write debug output to a file. 
+ DEBUG = 0 + DEBUGLOG = '/tmp/fcgi.log' + + def _debug(level, msg): + if DEBUG < level: + return + + try: + f = open(DEBUGLOG, 'a') + f.write('%sfcgi: %s\n' % (time.ctime()[4:-4], msg)) + f.close() + except: + pass + +class InputStream(object): + """ + File-like object representing FastCGI input streams (FCGI_STDIN and + FCGI_DATA). Supports the minimum methods required by WSGI spec. + """ + def __init__(self, conn): + self._conn = conn + + # See Server. + self._shrinkThreshold = conn.server.inputStreamShrinkThreshold + + self._buf = '' + self._bufList = [] + self._pos = 0 # Current read position. + self._avail = 0 # Number of bytes currently available. + + self._eof = False # True when server has sent EOF notification. + + def _shrinkBuffer(self): + """Gets rid of already read data (since we can't rewind).""" + if self._pos >= self._shrinkThreshold: + self._buf = self._buf[self._pos:] + self._avail -= self._pos + self._pos = 0 + + assert self._avail >= 0 + + def _waitForData(self): + """Waits for more data to become available.""" + self._conn.process_input() + + def read(self, n=-1): + if self._pos == self._avail and self._eof: + return '' + while True: + if n < 0 or (self._avail - self._pos) < n: + # Not enough data available. + if self._eof: + # And there's no more coming. + newPos = self._avail + break + else: + # Wait for more data. + self._waitForData() + continue + else: + newPos = self._pos + n + break + # Merge buffer list, if necessary. + if self._bufList: + self._buf += ''.join(self._bufList) + self._bufList = [] + r = self._buf[self._pos:newPos] + self._pos = newPos + self._shrinkBuffer() + return r + + def readline(self, length=None): + if self._pos == self._avail and self._eof: + return '' + while True: + # Unfortunately, we need to merge the buffer list early. + if self._bufList: + self._buf += ''.join(self._bufList) + self._bufList = [] + # Find newline. + i = self._buf.find('\n', self._pos) + if i < 0: + # Not found? 
+ if self._eof: + # No more data coming. + newPos = self._avail + break + else: + # Wait for more to come. + self._waitForData() + continue + else: + newPos = i + 1 + break + if length is not None: + if self._pos + length < newPos: + newPos = self._pos + length + r = self._buf[self._pos:newPos] + self._pos = newPos + self._shrinkBuffer() + return r + + def readlines(self, sizehint=0): + total = 0 + lines = [] + line = self.readline() + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline() + return lines + + def __iter__(self): + return self + + def next(self): + r = self.readline() + if not r: + raise StopIteration + return r + + def add_data(self, data): + if not data: + self._eof = True + else: + self._bufList.append(data) + self._avail += len(data) + +class MultiplexedInputStream(InputStream): + """ + A version of InputStream meant to be used with MultiplexedConnections. + Assumes the MultiplexedConnection (the producer) and the Request + (the consumer) are running in different threads. + """ + def __init__(self, conn): + super(MultiplexedInputStream, self).__init__(conn) + + # Arbitrates access to this InputStream (it's used simultaneously + # by a Request and its owning Connection object). + lock = threading.RLock() + + # Notifies Request thread that there is new data available. + self._lock = threading.Condition(lock) + + def _waitForData(self): + # Wait for notification from add_data(). 
+ self._lock.wait() + + def read(self, n=-1): + self._lock.acquire() + try: + return super(MultiplexedInputStream, self).read(n) + finally: + self._lock.release() + + def readline(self, length=None): + self._lock.acquire() + try: + return super(MultiplexedInputStream, self).readline(length) + finally: + self._lock.release() + + def add_data(self, data): + self._lock.acquire() + try: + super(MultiplexedInputStream, self).add_data(data) + self._lock.notify() + finally: + self._lock.release() + +class OutputStream(object): + """ + FastCGI output stream (FCGI_STDOUT/FCGI_STDERR). By default, calls to + write() or writelines() immediately result in Records being sent back + to the server. Buffering should be done in a higher level! + """ + def __init__(self, conn, req, type, buffered=False): + self._conn = conn + self._req = req + self._type = type + self._buffered = buffered + self._bufList = [] # Used if buffered is True + self.dataWritten = False + self.closed = False + + def _write(self, data): + length = len(data) + while length: + toWrite = min(length, self._req.server.maxwrite - FCGI_HEADER_LEN) + + rec = Record(self._type, self._req.requestId) + rec.contentLength = toWrite + rec.contentData = data[:toWrite] + self._conn.writeRecord(rec) + + data = data[toWrite:] + length -= toWrite + + def write(self, data): + assert not self.closed + + if not data: + return + + self.dataWritten = True + + if self._buffered: + self._bufList.append(data) + else: + self._write(data) + + def writelines(self, lines): + assert not self.closed + + for line in lines: + self.write(line) + + def flush(self): + # Only need to flush if this OutputStream is actually buffered. + if self._buffered: + data = ''.join(self._bufList) + self._bufList = [] + self._write(data) + + # Though available, the following should NOT be called by WSGI apps. 
+ def close(self): + """Sends end-of-stream notification, if necessary.""" + if not self.closed and self.dataWritten: + self.flush() + rec = Record(self._type, self._req.requestId) + self._conn.writeRecord(rec) + self.closed = True + +class TeeOutputStream(object): + """ + Simple wrapper around two or more output file-like objects that copies + written data to all streams. + """ + def __init__(self, streamList): + self._streamList = streamList + + def write(self, data): + for f in self._streamList: + f.write(data) + + def writelines(self, lines): + for line in lines: + self.write(line) + + def flush(self): + for f in self._streamList: + f.flush() + +class StdoutWrapper(object): + """ + Wrapper for sys.stdout so we know if data has actually been written. + """ + def __init__(self, stdout): + self._file = stdout + self.dataWritten = False + + def write(self, data): + if data: + self.dataWritten = True + self._file.write(data) + + def writelines(self, lines): + for line in lines: + self.write(line) + + def __getattr__(self, name): + return getattr(self._file, name) + +def decode_pair(s, pos=0): + """ + Decodes a name/value pair. + + The number of bytes decoded as well as the name/value pair + are returned. + """ + nameLength = ord(s[pos]) + if nameLength & 128: + nameLength = struct.unpack('!L', s[pos:pos+4])[0] & 0x7fffffff + pos += 4 + else: + pos += 1 + + valueLength = ord(s[pos]) + if valueLength & 128: + valueLength = struct.unpack('!L', s[pos:pos+4])[0] & 0x7fffffff + pos += 4 + else: + pos += 1 + + name = s[pos:pos+nameLength] + pos += nameLength + value = s[pos:pos+valueLength] + pos += valueLength + + return (pos, (name, value)) + +def encode_pair(name, value): + """ + Encodes a name/value pair. + + The encoded string is returned. 
+ """ + nameLength = len(name) + if nameLength < 128: + s = chr(nameLength) + else: + s = struct.pack('!L', nameLength | 0x80000000L) + + valueLength = len(value) + if valueLength < 128: + s += chr(valueLength) + else: + s += struct.pack('!L', valueLength | 0x80000000L) + + return s + name + value + +class Record(object): + """ + A FastCGI Record. + + Used for encoding/decoding records. + """ + def __init__(self, type=FCGI_UNKNOWN_TYPE, requestId=FCGI_NULL_REQUEST_ID): + self.version = FCGI_VERSION_1 + self.type = type + self.requestId = requestId + self.contentLength = 0 + self.paddingLength = 0 + self.contentData = '' + + def _recvall(sock, length): + """ + Attempts to receive length bytes from a socket, blocking if necessary. + (Socket may be blocking or non-blocking.) + """ + dataList = [] + recvLen = 0 + while length: + try: + data = sock.recv(length) + except socket.error, e: + if e[0] == errno.EAGAIN: + select.select([sock], [], []) + continue + else: + raise + if not data: # EOF + break + dataList.append(data) + dataLen = len(data) + recvLen += dataLen + length -= dataLen + return ''.join(dataList), recvLen + _recvall = staticmethod(_recvall) + + def read(self, sock): + """Read and decode a Record from a socket.""" + try: + header, length = self._recvall(sock, FCGI_HEADER_LEN) + except: + raise EOFError + + if length < FCGI_HEADER_LEN: + raise EOFError + + self.version, self.type, self.requestId, self.contentLength, \ + self.paddingLength = struct.unpack(FCGI_Header, header) + + if __debug__: _debug(9, 'read: fd = %d, type = %d, requestId = %d, ' + 'contentLength = %d' % + (sock.fileno(), self.type, self.requestId, + self.contentLength)) + + if self.contentLength: + try: + self.contentData, length = self._recvall(sock, + self.contentLength) + except: + raise EOFError + + if length < self.contentLength: + raise EOFError + + if self.paddingLength: + try: + self._recvall(sock, self.paddingLength) + except: + raise EOFError + + def _sendall(sock, data): + """ + 
Writes data to a socket and does not return until all the data is sent. + """ + length = len(data) + while length: + try: + sent = sock.send(data) + except socket.error, e: + if e[0] == errno.EPIPE: + return # Don't bother raising an exception. Just ignore. + elif e[0] == errno.EAGAIN: + select.select([], [sock], []) + continue + else: + raise + data = data[sent:] + length -= sent + _sendall = staticmethod(_sendall) + + def write(self, sock): + """Encode and write a Record to a socket.""" + self.paddingLength = -self.contentLength & 7 + + if __debug__: _debug(9, 'write: fd = %d, type = %d, requestId = %d, ' + 'contentLength = %d' % + (sock.fileno(), self.type, self.requestId, + self.contentLength)) + + header = struct.pack(FCGI_Header, self.version, self.type, + self.requestId, self.contentLength, + self.paddingLength) + self._sendall(sock, header) + if self.contentLength: + self._sendall(sock, self.contentData) + if self.paddingLength: + self._sendall(sock, '\x00'*self.paddingLength) + +class Request(object): + """ + Represents a single FastCGI request. + + These objects are passed to your handler and is the main interface + between your handler and the fcgi module. The methods should not + be called by your handler. However, server, params, stdin, stdout, + stderr, and data are free for your handler's use. 
+ """ + def __init__(self, conn, inputStreamClass): + self._conn = conn + + self.server = conn.server + self.params = {} + self.stdin = inputStreamClass(conn) + self.stdout = OutputStream(conn, self, FCGI_STDOUT) + self.stderr = OutputStream(conn, self, FCGI_STDERR, buffered=True) + self.data = inputStreamClass(conn) + + def run(self): + """Runs the handler, flushes the streams, and ends the request.""" + try: + protocolStatus, appStatus = self.server.handler(self) + except: + traceback.print_exc(file=self.stderr) + self.stderr.flush() + if not self.stdout.dataWritten: + self.server.error(self) + + protocolStatus, appStatus = FCGI_REQUEST_COMPLETE, 0 + + if __debug__: _debug(1, 'protocolStatus = %d, appStatus = %d' % + (protocolStatus, appStatus)) + + self._flush() + self._end(appStatus, protocolStatus) + + def _end(self, appStatus=0L, protocolStatus=FCGI_REQUEST_COMPLETE): + self._conn.end_request(self, appStatus, protocolStatus) + + def _flush(self): + self.stdout.close() + self.stderr.close() + +class CGIRequest(Request): + """A normal CGI request disguised as a FastCGI request.""" + def __init__(self, server): + # These are normally filled in by Connection. + self.requestId = 1 + self.role = FCGI_RESPONDER + self.flags = 0 + self.aborted = False + + self.server = server + self.params = dict(os.environ) + self.stdin = sys.stdin + self.stdout = StdoutWrapper(sys.stdout) # Oh, the humanity! + self.stderr = sys.stderr + self.data = StringIO.StringIO() + + def _end(self, appStatus=0L, protocolStatus=FCGI_REQUEST_COMPLETE): + sys.exit(appStatus) + + def _flush(self): + # Not buffered, do nothing. + pass + +class Connection(object): + """ + A Connection with the web server. + + Each Connection is associated with a single socket (which is + connected to the web server) and is responsible for handling all + the FastCGI message processing for that socket. 
+ """ + _multiplexed = False + _inputStreamClass = InputStream + + def __init__(self, sock, addr, server): + self._sock = sock + self._addr = addr + self.server = server + + # Active Requests for this Connection, mapped by request ID. + self._requests = {} + + def _cleanupSocket(self): + """Close the Connection's socket.""" + self._sock.close() + + def run(self): + """Begin processing data from the socket.""" + self._keepGoing = True + while self._keepGoing: + try: + self.process_input() + except EOFError: + break + except (select.error, socket.error), e: + if e[0] == errno.EBADF: # Socket was closed by Request. + break + raise + + self._cleanupSocket() + + def process_input(self): + """Attempt to read a single Record from the socket and process it.""" + # Currently, any children Request threads notify this Connection + # that it is no longer needed by closing the Connection's socket. + # We need to put a timeout on select, otherwise we might get + # stuck in it indefinitely... (I don't like this solution.) + while self._keepGoing: + try: + r, w, e = select.select([self._sock], [], [], 1.0) + except ValueError: + # Sigh. ValueError gets thrown sometimes when passing select + # a closed socket. + raise EOFError + if r: break + if not self._keepGoing: + return + rec = Record() + rec.read(self._sock) + + if rec.type == FCGI_GET_VALUES: + self._do_get_values(rec) + elif rec.type == FCGI_BEGIN_REQUEST: + self._do_begin_request(rec) + elif rec.type == FCGI_ABORT_REQUEST: + self._do_abort_request(rec) + elif rec.type == FCGI_PARAMS: + self._do_params(rec) + elif rec.type == FCGI_STDIN: + self._do_stdin(rec) + elif rec.type == FCGI_DATA: + self._do_data(rec) + elif rec.requestId == FCGI_NULL_REQUEST_ID: + self._do_unknown_type(rec) + else: + # Need to complain about this. + pass + + def writeRecord(self, rec): + """ + Write a Record to the socket. 
+ """ + rec.write(self._sock) + + def end_request(self, req, appStatus=0L, + protocolStatus=FCGI_REQUEST_COMPLETE, remove=True): + """ + End a Request. + + Called by Request objects. An FCGI_END_REQUEST Record is + sent to the web server. If the web server no longer requires + the connection, the socket is closed, thereby ending this + Connection (run() returns). + """ + rec = Record(FCGI_END_REQUEST, req.requestId) + rec.contentData = struct.pack(FCGI_EndRequestBody, appStatus, + protocolStatus) + rec.contentLength = FCGI_EndRequestBody_LEN + self.writeRecord(rec) + + if remove: + del self._requests[req.requestId] + + if __debug__: _debug(2, 'end_request: flags = %d' % req.flags) + + if not (req.flags & FCGI_KEEP_CONN) and not self._requests: + self._sock.close() + self._keepGoing = False + + def _do_get_values(self, inrec): + """Handle an FCGI_GET_VALUES request from the web server.""" + outrec = Record(FCGI_GET_VALUES_RESULT) + + pos = 0 + while pos < inrec.contentLength: + pos, (name, value) = decode_pair(inrec.contentData, pos) + cap = self.server.capability.get(name) + if cap is not None: + outrec.contentData += encode_pair(name, str(cap)) + + outrec.contentLength = len(outrec.contentData) + self.writeRecord(rec) + + def _do_begin_request(self, inrec): + """Handle an FCGI_BEGIN_REQUEST from the web server.""" + role, flags = struct.unpack(FCGI_BeginRequestBody, inrec.contentData) + + req = self.server.request_class(self, self._inputStreamClass) + req.requestId, req.role, req.flags = inrec.requestId, role, flags + req.aborted = False + + if not self._multiplexed and self._requests: + # Can't multiplex requests. + self.end_request(req, 0L, FCGI_CANT_MPX_CONN, remove=False) + else: + self._requests[inrec.requestId] = req + + def _do_abort_request(self, inrec): + """ + Handle an FCGI_ABORT_REQUEST from the web server. + + We just mark a flag in the associated Request. 
+ """ + req = self._requests.get(inrec.requestId) + if req is not None: + req.aborted = True + + def _start_request(self, req): + """Run the request.""" + # Not multiplexed, so run it inline. + req.run() + + def _do_params(self, inrec): + """ + Handle an FCGI_PARAMS Record. + + If the last FCGI_PARAMS Record is received, start the request. + """ + req = self._requests.get(inrec.requestId) + if req is not None: + if inrec.contentLength: + pos = 0 + while pos < inrec.contentLength: + pos, (name, value) = decode_pair(inrec.contentData, pos) + req.params[name] = value + else: + self._start_request(req) + + def _do_stdin(self, inrec): + """Handle the FCGI_STDIN stream.""" + req = self._requests.get(inrec.requestId) + if req is not None: + req.stdin.add_data(inrec.contentData) + + def _do_data(self, inrec): + """Handle the FCGI_DATA stream.""" + req = self._requests.get(inrec.requestId) + if req is not None: + req.data.add_data(inrec.contentData) + + def _do_unknown_type(self, inrec): + """Handle an unknown request type. Respond accordingly.""" + outrec = Record(FCGI_UNKNOWN_TYPE) + outrec.contentData = struct.pack(FCGI_UnknownTypeBody, inrec.type) + outrec.contentLength = FCGI_UnknownTypeBody_LEN + self.writeRecord(rec) + +class MultiplexedConnection(Connection): + """ + A version of Connection capable of handling multiple requests + simultaneously. + """ + _multiplexed = True + _inputStreamClass = MultiplexedInputStream + + def __init__(self, sock, addr, server): + super(MultiplexedConnection, self).__init__(sock, addr, server) + + # Used to arbitrate access to self._requests. + lock = threading.RLock() + + # Notification is posted everytime a request completes, allowing us + # to quit cleanly. + self._lock = threading.Condition(lock) + + def _cleanupSocket(self): + # Wait for any outstanding requests before closing the socket. 
+ self._lock.acquire() + while self._requests: + self._lock.wait() + self._lock.release() + + super(MultiplexedConnection, self)._cleanupSocket() + + def writeRecord(self, rec): + # Must use locking to prevent intermingling of Records from different + # threads. + self._lock.acquire() + try: + # Probably faster than calling super. ;) + rec.write(self._sock) + finally: + self._lock.release() + + def end_request(self, req, appStatus=0L, + protocolStatus=FCGI_REQUEST_COMPLETE, remove=True): + self._lock.acquire() + try: + super(MultiplexedConnection, self).end_request(req, appStatus, + protocolStatus, + remove) + self._lock.notify() + finally: + self._lock.release() + + def _do_begin_request(self, inrec): + self._lock.acquire() + try: + super(MultiplexedConnection, self)._do_begin_request(inrec) + finally: + self._lock.release() + + def _do_abort_request(self, inrec): + self._lock.acquire() + try: + super(MultiplexedConnection, self)._do_abort_request(inrec) + finally: + self._lock.release() + + def _start_request(self, req): + thread.start_new_thread(req.run, ()) + + def _do_params(self, inrec): + self._lock.acquire() + try: + super(MultiplexedConnection, self)._do_params(inrec) + finally: + self._lock.release() + + def _do_stdin(self, inrec): + self._lock.acquire() + try: + super(MultiplexedConnection, self)._do_stdin(inrec) + finally: + self._lock.release() + + def _do_data(self, inrec): + self._lock.acquire() + try: + super(MultiplexedConnection, self)._do_data(inrec) + finally: + self._lock.release() + +class Server(object): + """ + The FastCGI server. + + Waits for connections from the web server, processing each + request. + + If run in a normal CGI context, it will instead instantiate a + CGIRequest and run the handler through there. + """ + request_class = Request + cgirequest_class = CGIRequest + + # Limits the size of the InputStream's string buffer to this size + the + # server's maximum Record size. 
Since the InputStream is not seekable, + # we throw away already-read data once this certain amount has been read. + inputStreamShrinkThreshold = 102400 - 8192 + + def __init__(self, handler=None, maxwrite=8192, bindAddress=None, + multiplexed=False): + """ + handler, if present, must reference a function or method that + takes one argument: a Request object. If handler is not + specified at creation time, Server *must* be subclassed. + (The handler method below is abstract.) + + maxwrite is the maximum number of bytes (per Record) to write + to the server. I've noticed mod_fastcgi has a relatively small + receive buffer (8K or so). + + bindAddress, if present, must either be a string or a 2-tuple. If + present, run() will open its own listening socket. You would use + this if you wanted to run your application as an 'external' FastCGI + app. (i.e. the webserver would no longer be responsible for starting + your app) If a string, it will be interpreted as a filename and a UNIX + socket will be opened. If a tuple, the first element, a string, + is the interface name/IP to bind to, and the second element (an int) + is the port number. + + Set multiplexed to True if you want to handle multiple requests + per connection. Some FastCGI backends (namely mod_fastcgi) don't + multiplex requests at all, so by default this is off (which saves + on thread creation/locking overhead). If threads aren't available, + this keyword is ignored; it's not possible to multiplex requests + at all. + """ + if handler is not None: + self.handler = handler + self.maxwrite = maxwrite + if thread_available: + try: + import resource + # Attempt to glean the maximum number of connections + # from the OS. + maxConns = resource.getrlimit(resource.RLIMIT_NOFILE)[0] + except ImportError: + maxConns = 100 # Just some made up number. + maxReqs = maxConns + if multiplexed: + self._connectionClass = MultiplexedConnection + maxReqs *= 5 # Another made up number. 
+ else: + self._connectionClass = Connection + self.capability = { + FCGI_MAX_CONNS: maxConns, + FCGI_MAX_REQS: maxReqs, + FCGI_MPXS_CONNS: multiplexed and 1 or 0 + } + else: + self._connectionClass = Connection + self.capability = { + # If threads aren't available, these are pretty much correct. + FCGI_MAX_CONNS: 1, + FCGI_MAX_REQS: 1, + FCGI_MPXS_CONNS: 0 + } + self._bindAddress = bindAddress + + def _setupSocket(self): + if self._bindAddress is None: # Run as a normal FastCGI? + isFCGI = True + + sock = socket.fromfd(FCGI_LISTENSOCK_FILENO, socket.AF_INET, + socket.SOCK_STREAM) + try: + sock.getpeername() + except socket.error, e: + if e[0] == errno.ENOTSOCK: + # Not a socket, assume CGI context. + isFCGI = False + elif e[0] != errno.ENOTCONN: + raise + + # FastCGI/CGI discrimination is broken on Mac OS X. + # Set the environment variable FCGI_FORCE_CGI to "Y" or "y" + # if you want to run your app as a simple CGI. (You can do + # this with Apache's mod_env [not loaded by default in OS X + # client, ha ha] and the SetEnv directive.) 
+ if not isFCGI or \ + os.environ.get('FCGI_FORCE_CGI', 'N').upper().startswith('Y'): + req = self.cgirequest_class(self) + req.run() + sys.exit(0) + else: + # Run as a server + if type(self._bindAddress) is str: + # Unix socket + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + os.unlink(self._bindAddress) + except OSError: + pass + else: + # INET socket + assert type(self._bindAddress) is tuple + assert len(self._bindAddress) == 2 + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + + sock.bind(self._bindAddress) + sock.listen(socket.SOMAXCONN) + + return sock + + def _cleanupSocket(self, sock): + """Closes the main socket.""" + sock.close() + + def _installSignalHandlers(self): + self._oldSIGs = [(x,signal.getsignal(x)) for x in + (signal.SIGHUP, signal.SIGINT, signal.SIGTERM)] + signal.signal(signal.SIGHUP, self._hupHandler) + signal.signal(signal.SIGINT, self._intHandler) + signal.signal(signal.SIGTERM, self._intHandler) + + def _restoreSignalHandlers(self): + for signum,handler in self._oldSIGs: + signal.signal(signum, handler) + + def _hupHandler(self, signum, frame): + self._hupReceived = True + self._keepGoing = False + + def _intHandler(self, signum, frame): + self._keepGoing = False + + def run(self, timeout=1.0): + """ + The main loop. Exits on SIGHUP, SIGINT, SIGTERM. Returns True if + SIGHUP was received, False otherwise. + """ + web_server_addrs = os.environ.get('FCGI_WEB_SERVER_ADDRS') + if web_server_addrs is not None: + web_server_addrs = map(lambda x: x.strip(), + web_server_addrs.split(',')) + + sock = self._setupSocket() + + self._keepGoing = True + self._hupReceived = False + + # Install signal handlers. 
+ self._installSignalHandlers() + + while self._keepGoing: + try: + r, w, e = select.select([sock], [], [], timeout) + except select.error, e: + if e[0] == errno.EINTR: + continue + raise + + if r: + try: + clientSock, addr = sock.accept() + except socket.error, e: + if e[0] in (errno.EINTR, errno.EAGAIN): + continue + raise + + if web_server_addrs and \ + (len(addr) != 2 or addr[0] not in web_server_addrs): + clientSock.close() + continue + + # Instantiate a new Connection and begin processing FastCGI + # messages (either in a new thread or this thread). + conn = self._connectionClass(clientSock, addr, self) + thread.start_new_thread(conn.run, ()) + + self._mainloopPeriodic() + + # Restore signal handlers. + self._restoreSignalHandlers() + + self._cleanupSocket(sock) + + return self._hupReceived + + def _mainloopPeriodic(self): + """ + Called with just about each iteration of the main loop. Meant to + be overridden. + """ + pass + + def _exit(self, reload=False): + """ + Protected convenience method for subclasses to force an exit. Not + really thread-safe, which is why it isn't public. + """ + if self._keepGoing: + self._keepGoing = False + self._hupReceived = reload + + def handler(self, req): + """ + Default handler, which just raises an exception. Unless a handler + is passed at initialization time, this must be implemented by + a subclass. + """ + raise NotImplementedError, self.__class__.__name__ + '.handler' + + def error(self, req): + """ + Called by Request if an exception occurs within the handler. May and + should be overridden. + """ + import cgitb + req.stdout.write('Content-Type: text/html\r\n\r\n' + + cgitb.html(sys.exc_info())) + +class WSGIServer(Server): + """ + FastCGI server that supports the Web Server Gateway Interface. See + <http://www.python.org/peps/pep-0333.html>. + """ + def __init__(self, application, environ=None, multithreaded=True, **kw): + """ + environ, if present, must be a dictionary-like object. 
Its + contents will be copied into application's environ. Useful + for passing application-specific variables. + + Set multithreaded to False if your application is not MT-safe. + """ + if kw.has_key('handler'): + del kw['handler'] # Doesn't make sense to let this through + super(WSGIServer, self).__init__(**kw) + + if environ is None: + environ = {} + + self.application = application + self.environ = environ + self.multithreaded = multithreaded + + # Used to force single-threadedness + self._app_lock = thread.allocate_lock() + + def handler(self, req): + """Special handler for WSGI.""" + if req.role != FCGI_RESPONDER: + return FCGI_UNKNOWN_ROLE, 0 + + # Mostly taken from example CGI gateway. + environ = req.params + environ.update(self.environ) + + environ['wsgi.version'] = (1,0) + environ['wsgi.input'] = req.stdin + if self._bindAddress is None: + stderr = req.stderr + else: + stderr = TeeOutputStream((sys.stderr, req.stderr)) + environ['wsgi.errors'] = stderr + environ['wsgi.multithread'] = not isinstance(req, CGIRequest) and \ + thread_available and self.multithreaded + # Rationale for the following: If started by the web server + # (self._bindAddress is None) in either FastCGI or CGI mode, the + # possibility of being spawned multiple times simultaneously is quite + # real. And, if started as an external server, multiple copies may be + # spawned for load-balancing/redundancy. (Though I don't think + # mod_fastcgi supports this?) 
+ environ['wsgi.multiprocess'] = True + environ['wsgi.run_once'] = isinstance(req, CGIRequest) + + if environ.get('HTTPS', 'off') in ('on', '1'): + environ['wsgi.url_scheme'] = 'https' + else: + environ['wsgi.url_scheme'] = 'http' + + self._sanitizeEnv(environ) + + headers_set = [] + headers_sent = [] + result = None + + def write(data): + assert type(data) is str, 'write() argument must be string' + assert headers_set, 'write() before start_response()' + + if not headers_sent: + status, responseHeaders = headers_sent[:] = headers_set + found = False + for header,value in responseHeaders: + if header.lower() == 'content-length': + found = True + break + if not found and result is not None: + try: + if len(result) == 1: + responseHeaders.append(('Content-Length', + str(len(data)))) + except: + pass + s = 'Status: %s\r\n' % status + for header in responseHeaders: + s += '%s: %s\r\n' % header + s += '\r\n' + req.stdout.write(s) + + req.stdout.write(data) + req.stdout.flush() + + def start_response(status, response_headers, exc_info=None): + if exc_info: + try: + if headers_sent: + # Re-raise if too late + raise exc_info[0], exc_info[1], exc_info[2] + finally: + exc_info = None # avoid dangling circular ref + else: + assert not headers_set, 'Headers already set!' 
+ + assert type(status) is str, 'Status must be a string' + assert len(status) >= 4, 'Status must be at least 4 characters' + assert int(status[:3]), 'Status must begin with 3-digit code' + assert status[3] == ' ', 'Status must have a space after code' + assert type(response_headers) is list, 'Headers must be a list' + if __debug__: + for name,val in response_headers: + assert type(name) is str, 'Header names must be strings' + assert type(val) is str, 'Header values must be strings' + + headers_set[:] = [status, response_headers] + return write + + if not self.multithreaded: + self._app_lock.acquire() + try: + result = self.application(environ, start_response) + try: + for data in result: + if data: + write(data) + if not headers_sent: + write('') # in case body was empty + finally: + if hasattr(result, 'close'): + result.close() + finally: + if not self.multithreaded: + self._app_lock.release() + + return FCGI_REQUEST_COMPLETE, 0 + + def _sanitizeEnv(self, environ): + """Ensure certain values are present, if required by WSGI.""" + if not environ.has_key('SCRIPT_NAME'): + environ['SCRIPT_NAME'] = '' + if not environ.has_key('PATH_INFO'): + environ['PATH_INFO'] = '' + + # If any of these are missing, it probably signifies a broken + # server... 
+ for name,default in [('REQUEST_METHOD', 'GET'), + ('SERVER_NAME', 'localhost'), + ('SERVER_PORT', '80'), + ('SERVER_PROTOCOL', 'HTTP/1.0')]: + if not environ.has_key(name): + environ['wsgi.errors'].write('%s: missing FastCGI param %s ' + 'required by WSGI!\n' % + (self.__class__.__name__, name)) + environ[name] = default + +if __name__ == '__main__': + def test_app(environ, start_response): + """Probably not the most efficient example.""" + import cgi + start_response('200 OK', [('Content-Type', 'text/html')]) + yield '<html><head><title>Hello World!</title></head>\n' \ + '<body>\n' \ + '<p>Hello World!</p>\n' \ + '<table border="1">' + names = environ.keys() + names.sort() + for name in names: + yield '<tr><td>%s</td><td>%s</td></tr>\n' % ( + name, cgi.escape(`environ[name]`)) + + form = cgi.FieldStorage(fp=environ['wsgi.input'], environ=environ, + keep_blank_values=1) + if form.list: + yield '<tr><th colspan="2">Form data</th></tr>' + + for field in form.list: + yield '<tr><td>%s</td><td>%s</td></tr>\n' % ( + field.name, field.value) + + yield '</table>\n' \ + '</body></html>\n' + + WSGIServer(test_app).run() diff --git a/flup/server/fcgi_fork.py b/flup/server/fcgi_fork.py new file mode 100644 index 0000000..2d3f1b2 --- /dev/null +++ b/flup/server/fcgi_fork.py @@ -0,0 +1,1169 @@ +# Copyright (c) 2002, 2003, 2005 Allan Saddi <allan@saddi.com> +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS +# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF +# SUCH DAMAGE. +# +# $Id$ + +""" +fcgi - a FastCGI/WSGI gateway. + +For more information about FastCGI, see <http://www.fastcgi.com/>. + +For more information about the Web Server Gateway Interface, see +<http://www.python.org/peps/pep-0333.html>. + +Example usage: + + #!/usr/bin/env python + from myapplication import app # Assume app is your WSGI application object + from fcgi import WSGIServer + WSGIServer(app).run() + +See the documentation for WSGIServer for more information. + +On most platforms, fcgi will fallback to regular CGI behavior if run in a +non-FastCGI context. If you want to force CGI behavior, set the environment +variable FCGI_FORCE_CGI to "Y" or "y". +""" + +__author__ = 'Allan Saddi <allan@saddi.com>' +__version__ = '$Revision$' + +import sys +import os +import signal +import struct +import cStringIO as StringIO +import select +import socket +import errno +import traceback +import prefork + +__all__ = ['WSGIServer'] + +# Constants from the spec. 
+FCGI_LISTENSOCK_FILENO = 0 + +FCGI_HEADER_LEN = 8 + +FCGI_VERSION_1 = 1 + +FCGI_BEGIN_REQUEST = 1 +FCGI_ABORT_REQUEST = 2 +FCGI_END_REQUEST = 3 +FCGI_PARAMS = 4 +FCGI_STDIN = 5 +FCGI_STDOUT = 6 +FCGI_STDERR = 7 +FCGI_DATA = 8 +FCGI_GET_VALUES = 9 +FCGI_GET_VALUES_RESULT = 10 +FCGI_UNKNOWN_TYPE = 11 +FCGI_MAXTYPE = FCGI_UNKNOWN_TYPE + +FCGI_NULL_REQUEST_ID = 0 + +FCGI_KEEP_CONN = 1 + +FCGI_RESPONDER = 1 +FCGI_AUTHORIZER = 2 +FCGI_FILTER = 3 + +FCGI_REQUEST_COMPLETE = 0 +FCGI_CANT_MPX_CONN = 1 +FCGI_OVERLOADED = 2 +FCGI_UNKNOWN_ROLE = 3 + +FCGI_MAX_CONNS = 'FCGI_MAX_CONNS' +FCGI_MAX_REQS = 'FCGI_MAX_REQS' +FCGI_MPXS_CONNS = 'FCGI_MPXS_CONNS' + +FCGI_Header = '!BBHHBx' +FCGI_BeginRequestBody = '!HB5x' +FCGI_EndRequestBody = '!LB3x' +FCGI_UnknownTypeBody = '!B7x' + +FCGI_EndRequestBody_LEN = struct.calcsize(FCGI_EndRequestBody) +FCGI_UnknownTypeBody_LEN = struct.calcsize(FCGI_UnknownTypeBody) + +if __debug__: + import time + + # Set non-zero to write debug output to a file. + DEBUG = 0 + DEBUGLOG = '/tmp/fcgi.log' + + def _debug(level, msg): + if DEBUG < level: + return + + try: + f = open(DEBUGLOG, 'a') + f.write('%sfcgi: %s\n' % (time.ctime()[4:-4], msg)) + f.close() + except: + pass + +class InputStream(object): + """ + File-like object representing FastCGI input streams (FCGI_STDIN and + FCGI_DATA). Supports the minimum methods required by WSGI spec. + """ + def __init__(self, conn): + self._conn = conn + + # See Server. + self._shrinkThreshold = conn.server.inputStreamShrinkThreshold + + self._buf = '' + self._bufList = [] + self._pos = 0 # Current read position. + self._avail = 0 # Number of bytes currently available. + + self._eof = False # True when server has sent EOF notification. 
+ + def _shrinkBuffer(self): + """Gets rid of already read data (since we can't rewind).""" + if self._pos >= self._shrinkThreshold: + self._buf = self._buf[self._pos:] + self._avail -= self._pos + self._pos = 0 + + assert self._avail >= 0 + + def _waitForData(self): + """Waits for more data to become available.""" + self._conn.process_input() + + def read(self, n=-1): + if self._pos == self._avail and self._eof: + return '' + while True: + if n < 0 or (self._avail - self._pos) < n: + # Not enough data available. + if self._eof: + # And there's no more coming. + newPos = self._avail + break + else: + # Wait for more data. + self._waitForData() + continue + else: + newPos = self._pos + n + break + # Merge buffer list, if necessary. + if self._bufList: + self._buf += ''.join(self._bufList) + self._bufList = [] + r = self._buf[self._pos:newPos] + self._pos = newPos + self._shrinkBuffer() + return r + + def readline(self, length=None): + if self._pos == self._avail and self._eof: + return '' + while True: + # Unfortunately, we need to merge the buffer list early. + if self._bufList: + self._buf += ''.join(self._bufList) + self._bufList = [] + # Find newline. + i = self._buf.find('\n', self._pos) + if i < 0: + # Not found? + if self._eof: + # No more data coming. + newPos = self._avail + break + else: + # Wait for more to come. 
+ self._waitForData() + continue + else: + newPos = i + 1 + break + if length is not None: + if self._pos + length < newPos: + newPos = self._pos + length + r = self._buf[self._pos:newPos] + self._pos = newPos + self._shrinkBuffer() + return r + + def readlines(self, sizehint=0): + total = 0 + lines = [] + line = self.readline() + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline() + return lines + + def __iter__(self): + return self + + def next(self): + r = self.readline() + if not r: + raise StopIteration + return r + + def add_data(self, data): + if not data: + self._eof = True + else: + self._bufList.append(data) + self._avail += len(data) + +class MultiplexedInputStream(InputStream): + """ + A version of InputStream meant to be used with MultiplexedConnections. + Assumes the MultiplexedConnection (the producer) and the Request + (the consumer) are running in different threads. + """ + def __init__(self, conn): + super(MultiplexedInputStream, self).__init__(conn) + + # Arbitrates access to this InputStream (it's used simultaneously + # by a Request and its owning Connection object). + lock = threading.RLock() + + # Notifies Request thread that there is new data available. + self._lock = threading.Condition(lock) + + def _waitForData(self): + # Wait for notification from add_data(). + self._lock.wait() + + def read(self, n=-1): + self._lock.acquire() + try: + return super(MultiplexedInputStream, self).read(n) + finally: + self._lock.release() + + def readline(self, length=None): + self._lock.acquire() + try: + return super(MultiplexedInputStream, self).readline(length) + finally: + self._lock.release() + + def add_data(self, data): + self._lock.acquire() + try: + super(MultiplexedInputStream, self).add_data(data) + self._lock.notify() + finally: + self._lock.release() + +class OutputStream(object): + """ + FastCGI output stream (FCGI_STDOUT/FCGI_STDERR). 
By default, calls to + write() or writelines() immediately result in Records being sent back + to the server. Buffering should be done in a higher level! + """ + def __init__(self, conn, req, type, buffered=False): + self._conn = conn + self._req = req + self._type = type + self._buffered = buffered + self._bufList = [] # Used if buffered is True + self.dataWritten = False + self.closed = False + + def _write(self, data): + length = len(data) + while length: + toWrite = min(length, self._req.server.maxwrite - FCGI_HEADER_LEN) + + rec = Record(self._type, self._req.requestId) + rec.contentLength = toWrite + rec.contentData = data[:toWrite] + self._conn.writeRecord(rec) + + data = data[toWrite:] + length -= toWrite + + def write(self, data): + assert not self.closed + + if not data: + return + + self.dataWritten = True + + if self._buffered: + self._bufList.append(data) + else: + self._write(data) + + def writelines(self, lines): + assert not self.closed + + for line in lines: + self.write(line) + + def flush(self): + # Only need to flush if this OutputStream is actually buffered. + if self._buffered: + data = ''.join(self._bufList) + self._bufList = [] + self._write(data) + + # Though available, the following should NOT be called by WSGI apps. + def close(self): + """Sends end-of-stream notification, if necessary.""" + if not self.closed and self.dataWritten: + self.flush() + rec = Record(self._type, self._req.requestId) + self._conn.writeRecord(rec) + self.closed = True + +class TeeOutputStream(object): + """ + Simple wrapper around two or more output file-like objects that copies + written data to all streams. 
+ """ + def __init__(self, streamList): + self._streamList = streamList + + def write(self, data): + for f in self._streamList: + f.write(data) + + def writelines(self, lines): + for line in lines: + self.write(line) + + def flush(self): + for f in self._streamList: + f.flush() + +class StdoutWrapper(object): + """ + Wrapper for sys.stdout so we know if data has actually been written. + """ + def __init__(self, stdout): + self._file = stdout + self.dataWritten = False + + def write(self, data): + if data: + self.dataWritten = True + self._file.write(data) + + def writelines(self, lines): + for line in lines: + self.write(line) + + def __getattr__(self, name): + return getattr(self._file, name) + +def decode_pair(s, pos=0): + """ + Decodes a name/value pair. + + The number of bytes decoded as well as the name/value pair + are returned. + """ + nameLength = ord(s[pos]) + if nameLength & 128: + nameLength = struct.unpack('!L', s[pos:pos+4])[0] & 0x7fffffff + pos += 4 + else: + pos += 1 + + valueLength = ord(s[pos]) + if valueLength & 128: + valueLength = struct.unpack('!L', s[pos:pos+4])[0] & 0x7fffffff + pos += 4 + else: + pos += 1 + + name = s[pos:pos+nameLength] + pos += nameLength + value = s[pos:pos+valueLength] + pos += valueLength + + return (pos, (name, value)) + +def encode_pair(name, value): + """ + Encodes a name/value pair. + + The encoded string is returned. + """ + nameLength = len(name) + if nameLength < 128: + s = chr(nameLength) + else: + s = struct.pack('!L', nameLength | 0x80000000L) + + valueLength = len(value) + if valueLength < 128: + s += chr(valueLength) + else: + s += struct.pack('!L', valueLength | 0x80000000L) + + return s + name + value + +class Record(object): + """ + A FastCGI Record. + + Used for encoding/decoding records. 
+ """ + def __init__(self, type=FCGI_UNKNOWN_TYPE, requestId=FCGI_NULL_REQUEST_ID): + self.version = FCGI_VERSION_1 + self.type = type + self.requestId = requestId + self.contentLength = 0 + self.paddingLength = 0 + self.contentData = '' + + def _recvall(sock, length): + """ + Attempts to receive length bytes from a socket, blocking if necessary. + (Socket may be blocking or non-blocking.) + """ + dataList = [] + recvLen = 0 + while length: + try: + data = sock.recv(length) + except socket.error, e: + if e[0] == errno.EAGAIN: + select.select([sock], [], []) + continue + else: + raise + if not data: # EOF + break + dataList.append(data) + dataLen = len(data) + recvLen += dataLen + length -= dataLen + return ''.join(dataList), recvLen + _recvall = staticmethod(_recvall) + + def read(self, sock): + """Read and decode a Record from a socket.""" + try: + header, length = self._recvall(sock, FCGI_HEADER_LEN) + except: + raise EOFError + + if length < FCGI_HEADER_LEN: + raise EOFError + + self.version, self.type, self.requestId, self.contentLength, \ + self.paddingLength = struct.unpack(FCGI_Header, header) + + if __debug__: _debug(9, 'read: fd = %d, type = %d, requestId = %d, ' + 'contentLength = %d' % + (sock.fileno(), self.type, self.requestId, + self.contentLength)) + + if self.contentLength: + try: + self.contentData, length = self._recvall(sock, + self.contentLength) + except: + raise EOFError + + if length < self.contentLength: + raise EOFError + + if self.paddingLength: + try: + self._recvall(sock, self.paddingLength) + except: + raise EOFError + + def _sendall(sock, data): + """ + Writes data to a socket and does not return until all the data is sent. + """ + length = len(data) + while length: + try: + sent = sock.send(data) + except socket.error, e: + if e[0] == errno.EPIPE: + return # Don't bother raising an exception. Just ignore. 
+ elif e[0] == errno.EAGAIN: + select.select([], [sock], []) + continue + else: + raise + data = data[sent:] + length -= sent + _sendall = staticmethod(_sendall) + + def write(self, sock): + """Encode and write a Record to a socket.""" + self.paddingLength = -self.contentLength & 7 + + if __debug__: _debug(9, 'write: fd = %d, type = %d, requestId = %d, ' + 'contentLength = %d' % + (sock.fileno(), self.type, self.requestId, + self.contentLength)) + + header = struct.pack(FCGI_Header, self.version, self.type, + self.requestId, self.contentLength, + self.paddingLength) + self._sendall(sock, header) + if self.contentLength: + self._sendall(sock, self.contentData) + if self.paddingLength: + self._sendall(sock, '\x00'*self.paddingLength) + +class Request(object): + """ + Represents a single FastCGI request. + + These objects are passed to your handler and is the main interface + between your handler and the fcgi module. The methods should not + be called by your handler. However, server, params, stdin, stdout, + stderr, and data are free for your handler's use. 
+ """ + def __init__(self, conn, inputStreamClass): + self._conn = conn + + self.server = conn.server + self.params = {} + self.stdin = inputStreamClass(conn) + self.stdout = OutputStream(conn, self, FCGI_STDOUT) + self.stderr = OutputStream(conn, self, FCGI_STDERR, buffered=True) + self.data = inputStreamClass(conn) + + def run(self): + """Runs the handler, flushes the streams, and ends the request.""" + try: + protocolStatus, appStatus = self.server.handler(self) + except: + traceback.print_exc(file=self.stderr) + self.stderr.flush() + if not self.stdout.dataWritten: + self.server.error(self) + + protocolStatus, appStatus = FCGI_REQUEST_COMPLETE, 0 + + if __debug__: _debug(1, 'protocolStatus = %d, appStatus = %d' % + (protocolStatus, appStatus)) + + self._flush() + self._end(appStatus, protocolStatus) + + def _end(self, appStatus=0L, protocolStatus=FCGI_REQUEST_COMPLETE): + self._conn.end_request(self, appStatus, protocolStatus) + + def _flush(self): + self.stdout.close() + self.stderr.close() + +class CGIRequest(Request): + """A normal CGI request disguised as a FastCGI request.""" + def __init__(self, server): + # These are normally filled in by Connection. + self.requestId = 1 + self.role = FCGI_RESPONDER + self.flags = 0 + self.aborted = False + + self.server = server + self.params = dict(os.environ) + self.stdin = sys.stdin + self.stdout = StdoutWrapper(sys.stdout) # Oh, the humanity! + self.stderr = sys.stderr + self.data = StringIO.StringIO() + + def _end(self, appStatus=0L, protocolStatus=FCGI_REQUEST_COMPLETE): + sys.exit(appStatus) + + def _flush(self): + # Not buffered, do nothing. + pass + +class Connection(object): + """ + A Connection with the web server. + + Each Connection is associated with a single socket (which is + connected to the web server) and is responsible for handling all + the FastCGI message processing for that socket. 
+ """ + _multiplexed = False + _inputStreamClass = InputStream + + def __init__(self, sock, addr, server): + self._sock = sock + self._addr = addr + self.server = server + + # Active Requests for this Connection, mapped by request ID. + self._requests = {} + + def _cleanupSocket(self): + """Close the Connection's socket.""" + self._sock.close() + + def run(self): + """Begin processing data from the socket.""" + self._keepGoing = True + while self._keepGoing: + try: + self.process_input() + except (EOFError, KeyboardInterrupt): + break + except (select.error, socket.error), e: + if e[0] == errno.EBADF: # Socket was closed by Request. + break + raise + + self._cleanupSocket() + + def process_input(self): + """Attempt to read a single Record from the socket and process it.""" + # Currently, any children Request threads notify this Connection + # that it is no longer needed by closing the Connection's socket. + # We need to put a timeout on select, otherwise we might get + # stuck in it indefinitely... (I don't like this solution.) + while self._keepGoing: + try: + r, w, e = select.select([self._sock], [], [], 1.0) + except ValueError: + # Sigh. ValueError gets thrown sometimes when passing select + # a closed socket. + raise EOFError + if r: break + if not self._keepGoing: + return + rec = Record() + rec.read(self._sock) + + if rec.type == FCGI_GET_VALUES: + self._do_get_values(rec) + elif rec.type == FCGI_BEGIN_REQUEST: + self._do_begin_request(rec) + elif rec.type == FCGI_ABORT_REQUEST: + self._do_abort_request(rec) + elif rec.type == FCGI_PARAMS: + self._do_params(rec) + elif rec.type == FCGI_STDIN: + self._do_stdin(rec) + elif rec.type == FCGI_DATA: + self._do_data(rec) + elif rec.requestId == FCGI_NULL_REQUEST_ID: + self._do_unknown_type(rec) + else: + # Need to complain about this. + pass + + def writeRecord(self, rec): + """ + Write a Record to the socket. 
+ """ + rec.write(self._sock) + + def end_request(self, req, appStatus=0L, + protocolStatus=FCGI_REQUEST_COMPLETE, remove=True): + """ + End a Request. + + Called by Request objects. An FCGI_END_REQUEST Record is + sent to the web server. If the web server no longer requires + the connection, the socket is closed, thereby ending this + Connection (run() returns). + """ + rec = Record(FCGI_END_REQUEST, req.requestId) + rec.contentData = struct.pack(FCGI_EndRequestBody, appStatus, + protocolStatus) + rec.contentLength = FCGI_EndRequestBody_LEN + self.writeRecord(rec) + + if remove: + del self._requests[req.requestId] + + if __debug__: _debug(2, 'end_request: flags = %d' % req.flags) + + if not (req.flags & FCGI_KEEP_CONN) and not self._requests: + self._sock.close() + self._keepGoing = False + + def _do_get_values(self, inrec): + """Handle an FCGI_GET_VALUES request from the web server.""" + outrec = Record(FCGI_GET_VALUES_RESULT) + + pos = 0 + while pos < inrec.contentLength: + pos, (name, value) = decode_pair(inrec.contentData, pos) + cap = self.server.capability.get(name) + if cap is not None: + outrec.contentData += encode_pair(name, str(cap)) + + outrec.contentLength = len(outrec.contentData) + self.writeRecord(rec) + + def _do_begin_request(self, inrec): + """Handle an FCGI_BEGIN_REQUEST from the web server.""" + role, flags = struct.unpack(FCGI_BeginRequestBody, inrec.contentData) + + req = self.server.request_class(self, self._inputStreamClass) + req.requestId, req.role, req.flags = inrec.requestId, role, flags + req.aborted = False + + if not self._multiplexed and self._requests: + # Can't multiplex requests. + self.end_request(req, 0L, FCGI_CANT_MPX_CONN, remove=False) + else: + self._requests[inrec.requestId] = req + + def _do_abort_request(self, inrec): + """ + Handle an FCGI_ABORT_REQUEST from the web server. + + We just mark a flag in the associated Request. 
+ """ + req = self._requests.get(inrec.requestId) + if req is not None: + req.aborted = True + + def _start_request(self, req): + """Run the request.""" + # Not multiplexed, so run it inline. + req.run() + + def _do_params(self, inrec): + """ + Handle an FCGI_PARAMS Record. + + If the last FCGI_PARAMS Record is received, start the request. + """ + req = self._requests.get(inrec.requestId) + if req is not None: + if inrec.contentLength: + pos = 0 + while pos < inrec.contentLength: + pos, (name, value) = decode_pair(inrec.contentData, pos) + req.params[name] = value + else: + self._start_request(req) + + def _do_stdin(self, inrec): + """Handle the FCGI_STDIN stream.""" + req = self._requests.get(inrec.requestId) + if req is not None: + req.stdin.add_data(inrec.contentData) + + def _do_data(self, inrec): + """Handle the FCGI_DATA stream.""" + req = self._requests.get(inrec.requestId) + if req is not None: + req.data.add_data(inrec.contentData) + + def _do_unknown_type(self, inrec): + """Handle an unknown request type. Respond accordingly.""" + outrec = Record(FCGI_UNKNOWN_TYPE) + outrec.contentData = struct.pack(FCGI_UnknownTypeBody, inrec.type) + outrec.contentLength = FCGI_UnknownTypeBody_LEN + self.writeRecord(rec) + +class MultiplexedConnection(Connection): + """ + A version of Connection capable of handling multiple requests + simultaneously. + """ + _multiplexed = True + _inputStreamClass = MultiplexedInputStream + + def __init__(self, sock, addr, server): + super(MultiplexedConnection, self).__init__(sock, addr, server) + + # Used to arbitrate access to self._requests. + lock = threading.RLock() + + # Notification is posted everytime a request completes, allowing us + # to quit cleanly. + self._lock = threading.Condition(lock) + + def _cleanupSocket(self): + # Wait for any outstanding requests before closing the socket. 
+ self._lock.acquire() + while self._requests: + self._lock.wait() + self._lock.release() + + super(MultiplexedConnection, self)._cleanupSocket() + + def writeRecord(self, rec): + # Must use locking to prevent intermingling of Records from different + # threads. + self._lock.acquire() + try: + # Probably faster than calling super. ;) + rec.write(self._sock) + finally: + self._lock.release() + + def end_request(self, req, appStatus=0L, + protocolStatus=FCGI_REQUEST_COMPLETE, remove=True): + self._lock.acquire() + try: + super(MultiplexedConnection, self).end_request(req, appStatus, + protocolStatus, + remove) + self._lock.notify() + finally: + self._lock.release() + + def _do_begin_request(self, inrec): + self._lock.acquire() + try: + super(MultiplexedConnection, self)._do_begin_request(inrec) + finally: + self._lock.release() + + def _do_abort_request(self, inrec): + self._lock.acquire() + try: + super(MultiplexedConnection, self)._do_abort_request(inrec) + finally: + self._lock.release() + + def _start_request(self, req): + thread.start_new_thread(req.run, ()) + + def _do_params(self, inrec): + self._lock.acquire() + try: + super(MultiplexedConnection, self)._do_params(inrec) + finally: + self._lock.release() + + def _do_stdin(self, inrec): + self._lock.acquire() + try: + super(MultiplexedConnection, self)._do_stdin(inrec) + finally: + self._lock.release() + + def _do_data(self, inrec): + self._lock.acquire() + try: + super(MultiplexedConnection, self)._do_data(inrec) + finally: + self._lock.release() + +class WSGIServer(prefork.PreforkServer): + """ + FastCGI server that supports the Web Server Gateway Interface. See + <http://www.python.org/peps/pep-0333.html>. + """ + request_class = Request + cgirequest_class = CGIRequest + + # Limits the size of the InputStream's string buffer to this size + the + # server's maximum Record size. Since the InputStream is not seekable, + # we throw away already-read data once this certain amount has been read. 
+ inputStreamShrinkThreshold = 102400 - 8192 + + def __init__(self, application, environ=None, + maxwrite=8192, bindAddress=None, **kw): + """ + environ, if present, must be a dictionary-like object. Its + contents will be copied into application's environ. Useful + for passing application-specific variables. + + maxwrite is the maximum number of bytes (per Record) to write + to the server. I've noticed mod_fastcgi has a relatively small + receive buffer (8K or so). + + bindAddress, if present, must either be a string or a 2-tuple. If + present, run() will open its own listening socket. You would use + this if you wanted to run your application as an 'external' FastCGI + app. (i.e. the webserver would no longer be responsible for starting + your app) If a string, it will be interpreted as a filename and a UNIX + socket will be opened. If a tuple, the first element, a string, + is the interface name/IP to bind to, and the second element (an int) + is the port number. + """ + if kw.has_key('jobClass'): + del kw['jobClass'] + if kw.has_key('jobArgs'): + del kw['jobArgs'] + super(WSGIServer, self).__init__(jobClass=Connection, + jobArgs=(self,), **kw) + + if environ is None: + environ = {} + + self.application = application + self.environ = environ + + self.maxwrite = maxwrite + try: + import resource + # Attempt to glean the maximum number of connections + # from the OS. + maxConns = resource.getrlimit(resource.RLIMIT_NPROC)[0] + except ImportError: + maxConns = 100 # Just some made up number. + maxReqs = maxConns + self.capability = { + FCGI_MAX_CONNS: maxConns, + FCGI_MAX_REQS: maxReqs, + FCGI_MPXS_CONNS: 0 + } + self._bindAddress = bindAddress + + def _setupSocket(self): + if self._bindAddress is None: # Run as a normal FastCGI? + isFCGI = True + + sock = socket.fromfd(FCGI_LISTENSOCK_FILENO, socket.AF_INET, + socket.SOCK_STREAM) + try: + sock.getpeername() + except socket.error, e: + if e[0] == errno.ENOTSOCK: + # Not a socket, assume CGI context. 
+ isFCGI = False + elif e[0] != errno.ENOTCONN: + raise + + # FastCGI/CGI discrimination is broken on Mac OS X. + # Set the environment variable FCGI_FORCE_CGI to "Y" or "y" + # if you want to run your app as a simple CGI. (You can do + # this with Apache's mod_env [not loaded by default in OS X + # client, ha ha] and the SetEnv directive.) + if not isFCGI or \ + os.environ.get('FCGI_FORCE_CGI', 'N').upper().startswith('Y'): + req = self.cgirequest_class(self) + req.run() + sys.exit(0) + else: + # Run as a server + if type(self._bindAddress) is str: + # Unix socket + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + os.unlink(self._bindAddress) + except OSError: + pass + else: + # INET socket + assert type(self._bindAddress) is tuple + assert len(self._bindAddress) == 2 + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + + sock.bind(self._bindAddress) + sock.listen(socket.SOMAXCONN) + + return sock + + def _cleanupSocket(self, sock): + """Closes the main socket.""" + sock.close() + + def _isClientAllowed(self, addr): + return self._web_server_addrs is None or \ + (len(addr) == 2 and addr[0] in self._web_server_addrs) + + def run(self): + """ + The main loop. Exits on SIGHUP, SIGINT, SIGTERM. Returns True if + SIGHUP was received, False otherwise. + """ + self._web_server_addrs = os.environ.get('FCGI_WEB_SERVER_ADDRS') + if self._web_server_addrs is not None: + self._web_server_addrs = map(lambda x: x.strip(), + self._web_server_addrs.split(',')) + + sock = self._setupSocket() + + ret = super(WSGIServer, self).run(sock) + + self._cleanupSocket(sock) + + return ret + + def handler(self, req): + """Special handler for WSGI.""" + if req.role != FCGI_RESPONDER: + return FCGI_UNKNOWN_ROLE, 0 + + # Mostly taken from example CGI gateway. 
+ environ = req.params + environ.update(self.environ) + + environ['wsgi.version'] = (1,0) + environ['wsgi.input'] = req.stdin + if self._bindAddress is None: + stderr = req.stderr + else: + stderr = TeeOutputStream((sys.stderr, req.stderr)) + environ['wsgi.errors'] = stderr + environ['wsgi.multithread'] = False + environ['wsgi.multiprocess'] = True + environ['wsgi.run_once'] = isinstance(req, CGIRequest) + + if environ.get('HTTPS', 'off') in ('on', '1'): + environ['wsgi.url_scheme'] = 'https' + else: + environ['wsgi.url_scheme'] = 'http' + + self._sanitizeEnv(environ) + + headers_set = [] + headers_sent = [] + result = None + + def write(data): + assert type(data) is str, 'write() argument must be string' + assert headers_set, 'write() before start_response()' + + if not headers_sent: + status, responseHeaders = headers_sent[:] = headers_set + found = False + for header,value in responseHeaders: + if header.lower() == 'content-length': + found = True + break + if not found and result is not None: + try: + if len(result) == 1: + responseHeaders.append(('Content-Length', + str(len(data)))) + except: + pass + s = 'Status: %s\r\n' % status + for header in responseHeaders: + s += '%s: %s\r\n' % header + s += '\r\n' + req.stdout.write(s) + + req.stdout.write(data) + req.stdout.flush() + + def start_response(status, response_headers, exc_info=None): + if exc_info: + try: + if headers_sent: + # Re-raise if too late + raise exc_info[0], exc_info[1], exc_info[2] + finally: + exc_info = None # avoid dangling circular ref + else: + assert not headers_set, 'Headers already set!' 
+ + assert type(status) is str, 'Status must be a string' + assert len(status) >= 4, 'Status must be at least 4 characters' + assert int(status[:3]), 'Status must begin with 3-digit code' + assert status[3] == ' ', 'Status must have a space after code' + assert type(response_headers) is list, 'Headers must be a list' + if __debug__: + for name,val in response_headers: + assert type(name) is str, 'Header names must be strings' + assert type(val) is str, 'Header values must be strings' + + headers_set[:] = [status, response_headers] + return write + + result = self.application(environ, start_response) + try: + for data in result: + if data: + write(data) + if not headers_sent: + write('') # in case body was empty + finally: + if hasattr(result, 'close'): + result.close() + + return FCGI_REQUEST_COMPLETE, 0 + + def _sanitizeEnv(self, environ): + """Ensure certain values are present, if required by WSGI.""" + if not environ.has_key('SCRIPT_NAME'): + environ['SCRIPT_NAME'] = '' + if not environ.has_key('PATH_INFO'): + environ['PATH_INFO'] = '' + + # If any of these are missing, it probably signifies a broken + # server... + for name,default in [('REQUEST_METHOD', 'GET'), + ('SERVER_NAME', 'localhost'), + ('SERVER_PORT', '80'), + ('SERVER_PROTOCOL', 'HTTP/1.0')]: + if not environ.has_key(name): + environ['wsgi.errors'].write('%s: missing FastCGI param %s ' + 'required by WSGI!\n' % + (self.__class__.__name__, name)) + environ[name] = default + + def error(self, req): + """ + Called by Request if an exception occurs within the handler. May and + should be overridden. 
+ """ + import cgitb + req.stdout.write('Content-Type: text/html\r\n\r\n' + + cgitb.html(sys.exc_info())) + +if __name__ == '__main__': + def test_app(environ, start_response): + """Probably not the most efficient example.""" + import cgi + start_response('200 OK', [('Content-Type', 'text/html')]) + yield '<html><head><title>Hello World!</title></head>\n' \ + '<body>\n' \ + '<p>Hello World!</p>\n' \ + '<table border="1">' + names = environ.keys() + names.sort() + for name in names: + yield '<tr><td>%s</td><td>%s</td></tr>\n' % ( + name, cgi.escape(`environ[name]`)) + + form = cgi.FieldStorage(fp=environ['wsgi.input'], environ=environ, + keep_blank_values=1) + if form.list: + yield '<tr><th colspan="2">Form data</th></tr>' + + for field in form.list: + yield '<tr><td>%s</td><td>%s</td></tr>\n' % ( + field.name, field.value) + + yield '</table>\n' \ + '</body></html>\n' + + WSGIServer(test_app).run() diff --git a/flup/server/prefork.py b/flup/server/prefork.py new file mode 100644 index 0000000..191a651 --- /dev/null +++ b/flup/server/prefork.py @@ -0,0 +1,364 @@ +# Copyright (c) 2005 Allan Saddi <allan@saddi.com> +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS +# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF +# SUCH DAMAGE. +# +# $Id$ + +__author__ = 'Allan Saddi <allan@saddi.com>' +__version__ = '$Revision$' + +import sys +import os +import socket +import select +import errno +import signal + +class PreforkServer(object): + """ + A preforked server model conceptually similar to Apache httpd(2). At + any given time, ensures there are at least minSpare children ready to + process new requests (up to a maximum of maxChildren children total). + If the number of idle children is ever above maxSpare, the extra + children are killed. + + jobClass should be a class whose constructor takes at least two + arguments: the client socket and client address. jobArgs, which + must be a list or tuple, is any additional (static) arguments you + wish to pass to the constructor. + + jobClass should have a run() method (taking no arguments) that does + the actual work. When run() returns, the request is considered + complete and the child process moves to idle state. + """ + def __init__(self, minSpare=1, maxSpare=5, maxChildren=50, + jobClass=None, jobArgs=()): + self._minSpare = minSpare + self._maxSpare = maxSpare + self._maxChildren = max(maxSpare, maxChildren) + self._jobClass = jobClass + self._jobArgs = jobArgs + + # Internal state of children. Maps pids to dictionaries with two + # members: 'file' and 'avail'. 'file' is the socket to that + # individidual child and 'avail' is whether or not the child is + # free to process requests. 
+ self._children = {} + + def run(self, sock): + """ + The main loop. Pass a socket that is ready to accept() client + connections. Return value will be True or False indiciating whether + or not the loop was exited due to SIGHUP. + """ + # Set up signal handlers. + self._keepGoing = True + self._hupReceived = False + self._installSignalHandlers() + + # Don't want operations on main socket to block. + sock.setblocking(0) + + # Main loop. + while self._keepGoing: + # Maintain minimum number of children. + while len(self._children) < self._maxSpare: + if not self._spawnChild(sock): break + + # Wait on any socket activity from live children. + r = [x['file'] for x in self._children.values() + if x['file'] is not None] + + if len(r) == len(self._children): + timeout = None + else: + # There are dead children that need to be reaped, ensure + # that they are by timing out, if necessary. + timeout = 2 + + try: + r, w, e = select.select(r, [], [], timeout) + except select.error, e: + if e[0] != errno.EINTR: + raise + + # Scan child sockets and tend to those that need attention. + for child in r: + # Receive status byte. + try: + state = child.recv(1) + except socket.error, e: + if e[0] in (errno.EAGAIN, errno.EINTR): + # Guess it really didn't need attention? + continue + raise + # Try to match it with a child. (Do we need a reverse map?) + for pid,d in self._children.items(): + if child is d['file']: + if state: + # Set availability status accordingly. + self._children[pid]['avail'] = state != '\x00' + else: + # Didn't receive anything. Child is most likely + # dead. + d = self._children[pid] + d['file'].close() + d['file'] = None + d['avail'] = False + + # Reap children. + self._reapChildren() + + # See who and how many children are available. + availList = filter(lambda x: x[1]['avail'], self._children.items()) + avail = len(availList) + + if avail < self._minSpare: + # Need to spawn more children. 
+ while avail < self._minSpare and \ + len(self._children) < self._maxChildren: + if not self._spawnChild(sock): break + avail += 1 + elif avail > self._maxSpare: + # Too many spares, kill off the extras. + pids = [x[0] for x in availList] + pids.sort() + pids = pids[self._maxSpare:] + for pid in pids: + d = self._children[pid] + d['file'].close() + d['file'] = None + d['avail'] = False + + # Clean up all child processes. + self._cleanupChildren() + + # Restore signal handlers. + self._restoreSignalHandlers() + + # Return bool based on whether or not SIGHUP was received. + return self._hupReceived + + def _cleanupChildren(self): + """ + Closes all child sockets (letting those that are available know + that it's time to exit). Sends SIGINT to those that are currently + processing (and hopes that it finishses ASAP). + + Any children remaining after 10 seconds is SIGKILLed. + """ + # Let all children know it's time to go. + for pid,d in self._children.items(): + if d['file'] is not None: + d['file'].close() + d['file'] = None + if not d['avail']: + # Child is unavailable. SIGINT it. + try: + os.kill(pid, signal.SIGINT) + except OSError, e: + if e[0] != errno.ESRCH: + raise + + def alrmHandler(signum, frame): + pass + + # Set up alarm to wake us up after 10 seconds. + oldSIGALRM = signal.getsignal(signal.SIGALRM) + signal.signal(signal.SIGALRM, alrmHandler) + signal.alarm(10) + + # Wait for all children to die. + while len(self._children): + try: + pid, status = os.wait() + except OSError, e: + if e[0] in (errno.ECHILD, errno.EINTR): + break + if self._children.has_key(pid): + del self._children[pid] + + signal.signal(signal.SIGALRM, oldSIGALRM) + + # Forcefully kill any remaining children. 
+ for pid in self._children.keys(): + try: + os.kill(pid, signal.SIGKILL) + except OSError, e: + if e[0] != errno.ESRCH: + raise + + def _reapChildren(self): + """Cleans up self._children whenever children die.""" + while True: + try: + pid, status = os.waitpid(-1, os.WNOHANG) + except OSError, e: + if e[0] == errno.ECHILD: + break + raise + if pid <= 0: + break + if self._children.has_key(pid): # Sanity check. + if self._children[pid]['file'] is not None: + self._children[pid]['file'].close() + del self._children[pid] + + def _spawnChild(self, sock): + """ + Spawn a single child. Returns True if successful, False otherwise. + """ + # This socket pair is used for very simple communication between + # the parent and its children. + parent, child = socket.socketpair() + parent.setblocking(0) + child.setblocking(0) + try: + pid = os.fork() + except OSError, e: + if e[0] in (errno.EAGAIN, errno.ENOMEM): + return False # Can't fork anymore. + raise + if not pid: + # Child + child.close() + # Put child into its own process group. + pid = os.getpid() + os.setpgid(pid, pid) + # Restore signal handlers. + self._restoreSignalHandlers() + # Close copies of child sockets. + for f in [x['file'] for x in self._children.values() + if x['file'] is not None]: + f.close() + self._children = {} + try: + # Enter main loop. + self._child(sock, parent) + except KeyboardInterrupt: + pass + sys.exit(0) + else: + # Parent + parent.close() + d = self._children[pid] = {} + d['file'] = child + d['avail'] = True + return True + + def _isClientAllowed(self, addr): + """Override to provide access control.""" + return True + + def _child(self, sock, parent): + """Main loop for children.""" + while True: + # Wait for any activity on the main socket or parent socket. + r, w, e = select.select([sock, parent], [], []) + + for f in r: + # If there's any activity on the parent socket, it + # means the parent wants us to die or has died itself. + # Either way, exit. 
+ if f is parent: + return + + # Otherwise, there's activity on the main socket... + try: + clientSock, addr = sock.accept() + except socket.error, e: + if e[0] == errno.EAGAIN: + # Or maybe not. + continue + raise + + # Check if this client is allowed. + if not self._isClientAllowed(addr): + clientSock.close() + continue + + # Notify parent we're no longer available. + try: + parent.send('\x00') + except socket.error, e: + # If parent is gone, finish up this request. + if e[0] != errno.EPIPE: + raise + + # Do the job. + self._jobClass(clientSock, addr, *self._jobArgs).run() + + # Tell parent we're free again. + try: + parent.send('\xff') + except socket.error, e: + if e[0] == errno.EPIPE: + # Parent is gone. + return + raise + + # Signal handlers + + def _hupHandler(self, signum, frame): + self._keepGoing = False + self._hupReceived = True + + def _intHandler(self, signum, frame): + self._keepGoing = False + + def _chldHandler(self, signum, frame): + # Do nothing (breaks us out of select and allows us to reap children). 
+ pass + + def _installSignalHandlers(self): + """Installs signal handlers.""" + self._oldSIGs = [(x,signal.getsignal(x)) for x in + (signal.SIGHUP, signal.SIGINT, signal.SIGQUIT, + signal.SIGTERM, signal.SIGCHLD)] + signal.signal(signal.SIGHUP, self._hupHandler) + signal.signal(signal.SIGINT, self._intHandler) + signal.signal(signal.SIGQUIT, self._intHandler) + signal.signal(signal.SIGTERM, self._intHandler) + + def _restoreSignalHandlers(self): + """Restores previous signal handlers.""" + for signum,handler in self._oldSIGs: + signal.signal(signum, handler) + +if __name__ == '__main__': + class TestJob(object): + def __init__(self, sock, addr): + self._sock = sock + self._addr = addr + def run(self): + print "Client connection opened from %s:%d" % self._addr + self._sock.send('Hello World!\n') + self._sock.setblocking(1) + self._sock.recv(1) + self._sock.close() + print "Client connection closed from %s:%d" % self._addr + sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(('', 8080)) + sock.listen(socket.SOMAXCONN) + PreforkServer(maxChildren=10, jobClass=TestJob).run(sock) diff --git a/flup/server/scgi.py b/flup/server/scgi.py new file mode 100644 index 0000000..c4bb20a --- /dev/null +++ b/flup/server/scgi.py @@ -0,0 +1,699 @@ +# Copyright (c) 2005 Allan Saddi <allan@saddi.com> +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS +# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF +# SUCH DAMAGE. +# +# $Id$ + +""" +scgi - an SCGI/WSGI gateway. (I might have to rename this module.) + +For more information about SCGI and mod_scgi for Apache1/Apache2, see +<http://www.mems-exchange.org/software/scgi/>. + +For more information about the Web Server Gateway Interface, see +<http://www.python.org/peps/pep-0333.html>. + +Example usage: + + #!/usr/bin/env python + import sys + from myapplication import app # Assume app is your WSGI application object + from scgi import WSGIServer + ret = WSGIServer(app).run() + sys.exit(ret and 42 or 0) + +See the documentation for WSGIServer for more information. + +About the bit of logic at the end: +Upon receiving SIGHUP, the python script will exit with status code 42. This +can be used by a wrapper script to determine if the python script should be +re-run. When a SIGINT or SIGTERM is received, the script exits with status +code 0, possibly indicating a normal exit. + +Example wrapper script: + + #!/bin/sh + STATUS=42 + while test $STATUS -eq 42; do + python "$@" that_script_above.py + STATUS=$? 
+ done +""" + +__author__ = 'Allan Saddi <allan@saddi.com>' +__version__ = '$Revision$' + +import sys +import logging +import socket +import select +import errno +import cStringIO as StringIO +import signal +import datetime + +# Threads are required. If you want a non-threaded (forking) version, look at +# SWAP <http://www.idyll.org/~t/www-tools/wsgi/>. +import thread +import threading + +__all__ = ['WSGIServer'] + +# The main classes use this name for logging. +LoggerName = 'scgi-wsgi' + +# Set up module-level logger. +console = logging.StreamHandler() +console.setLevel(logging.DEBUG) +console.setFormatter(logging.Formatter('%(asctime)s : %(message)s', + '%Y-%m-%d %H:%M:%S')) +logging.getLogger(LoggerName).addHandler(console) +del console + +class ProtocolError(Exception): + """ + Exception raised when the server does something unexpected or + sends garbled data. Usually leads to a Connection closing. + """ + pass + +def recvall(sock, length): + """ + Attempts to receive length bytes from a socket, blocking if necessary. + (Socket may be blocking or non-blocking.) + """ + dataList = [] + recvLen = 0 + while length: + try: + data = sock.recv(length) + except socket.error, e: + if e[0] == errno.EAGAIN: + select.select([sock], [], []) + continue + else: + raise + if not data: # EOF + break + dataList.append(data) + dataLen = len(data) + recvLen += dataLen + length -= dataLen + return ''.join(dataList), recvLen + +def readNetstring(sock): + """ + Attempt to read a netstring from a socket. + """ + # First attempt to read the length. + size = '' + while True: + try: + c = sock.recv(1) + except socket.error, e: + if e[0] == errno.EAGAIN: + select.select([sock], [], []) + continue + else: + raise + if c == ':': + break + if not c: + raise EOFError + size += c + + # Try to decode the length. + try: + size = int(size) + if size < 0: + raise ValueError + except ValueError: + raise ProtocolError, 'invalid netstring length' + + # Now read the string. 
class StdoutWrapper(object):
    """
    Transparent proxy around the response stream that records whether any
    real (non-empty) data has been written, so an error page is only
    generated while the response is still untouched.
    """
    def __init__(self, stdout):
        self._file = stdout
        self.dataWritten = False

    def write(self, data):
        if data:
            self.dataWritten = True
            self._file.write(data)

    def writelines(self, lines):
        for chunk in lines:
            self.write(chunk)

    def __getattr__(self, name):
        # Delegate everything else (flush, close, ...) to the real stream.
        return getattr(self._file, name)

class Request(object):
    """
    Encapsulates the data of a single request and drives its handler.

    Public attributes:
      environ - Environment variables from web server.
      stdin - File-like object representing the request body.
      stdout - File-like object for writing the response.
    """
    def __init__(self, conn, environ, input, output):
        self._conn = conn
        self.environ = environ
        self.stdin = input
        self.stdout = StdoutWrapper(output)

        self.logger = logging.getLogger(LoggerName)

    def run(self):
        """Invoke the server's handler, logging the request and timing."""
        method = self.environ['REQUEST_METHOD']
        script = self.environ.get('SCRIPT_NAME', '')
        path = self.environ.get('PATH_INFO', '')

        self.logger.info('%s %s%s', method, script, path)

        start = datetime.datetime.now()

        try:
            self._conn.server.handler(self)
        except:
            # Handler blew up: log it, and emit an error page only if
            # nothing has been sent to the client yet.
            self.logger.exception('Exception caught from handler')
            if not self.stdout.dataWritten:
                self._conn.server.error(self)

        elapsed = datetime.datetime.now() - start
        self.logger.debug('%s %s%s done (%.3f secs)',
                          method, script, path,
                          elapsed.seconds +
                            elapsed.microseconds / 1000000.0)
+ """ + def __init__(self, sock, addr, server): + self._sock = sock + self._addr = addr + self.server = server + + self.logger = logging.getLogger(LoggerName) + + def run(self): + self.logger.debug('Connection starting up (%s:%d)', + self._addr[0], self._addr[1]) + + try: + self.processInput() + except EOFError: + pass + except ProtocolError, e: + self.logger.error("Protocol error '%s'", str(e)) + except: + self.logger.exception('Exception caught in Connection') + + self.logger.debug('Connection shutting down (%s:%d)', + self._addr[0], self._addr[1]) + + # All done! + self._sock.close() + + def processInput(self): + # Read headers + headers = readNetstring(self._sock) + headers = headers.split('\x00')[:-1] + if len(headers) % 2 != 0: + raise ProtocolError, 'invalid headers' + environ = {} + for i in range(len(headers) / 2): + environ[headers[2*i]] = headers[2*i+1] + + clen = environ.get('CONTENT_LENGTH') + if clen is None: + raise ProtocolError, 'missing CONTENT_LENGTH' + try: + clen = int(clen) + if clen < 0: + raise ValueError + except ValueError: + raise ProtocolError, 'invalid CONTENT_LENGTH' + + self._sock.setblocking(1) + if clen: + input = self._sock.makefile('r') + else: + # Empty input. + input = StringIO.StringIO() + + # stdout + output = self._sock.makefile('w') + + # Allocate Request + req = Request(self, environ, input, output) + + # Run it. + req.run() + + output.close() + input.close() + +class ThreadPool(object): + """ + Thread pool that maintains the number of idle threads between + minSpare and maxSpare inclusive. By default, there is no limit on + the number of threads that can be started, but this can be controlled + by maxThreads. 
+ """ + def __init__(self, minSpare=1, maxSpare=5, maxThreads=sys.maxint): + self._minSpare = minSpare + self._maxSpare = maxSpare + self._maxThreads = max(minSpare, maxThreads) + + self._lock = threading.Condition() + self._workQueue = [] + self._idleCount = self._workerCount = maxSpare + + # Start the minimum number of worker threads. + for i in range(maxSpare): + thread.start_new_thread(self._worker, ()) + + def addJob(self, job, allowQueuing=True): + """ + Adds a job to the work queue. The job object should have a run() + method. If allowQueuing is True (the default), the job will be + added to the work queue regardless if there are any idle threads + ready. (The only way for there to be no idle threads is if maxThreads + is some reasonable, finite limit.) + + Otherwise, if allowQueuing is False, and there are no more idle + threads, the job will not be queued. + + Returns True if the job was queued, False otherwise. + """ + self._lock.acquire() + try: + # Maintain minimum number of spares. + while self._idleCount < self._minSpare and \ + self._workerCount < self._maxThreads: + self._workerCount += 1 + self._idleCount += 1 + thread.start_new_thread(self._worker, ()) + + # Hand off the job. + if self._idleCount or allowQueuing: + self._workQueue.append(job) + self._lock.notify() + return True + else: + return False + finally: + self._lock.release() + + def _worker(self): + """ + Worker thread routine. Waits for a job, executes it, repeat. + """ + self._lock.acquire() + while True: + while not self._workQueue: + self._lock.wait() + + # We have a job to do... + job = self._workQueue.pop(0) + + assert self._idleCount > 0 + self._idleCount -= 1 + + self._lock.release() + + job.run() + + self._lock.acquire() + + if self._idleCount == self._maxSpare: + break # NB: lock still held + self._idleCount += 1 + assert self._idleCount <= self._maxSpare + + # Die off... 
+ assert self._workerCount > self._maxSpare + self._workerCount -= 1 + + self._lock.release() + +class WSGIServer(object): + """ + SCGI/WSGI server. For information about SCGI (Simple Common Gateway + Interface), see <http://www.mems-exchange.org/software/scgi/>. + + This server is similar to SWAP <http://www.idyll.org/~t/www-tools/wsgi/>, + another SCGI/WSGI server. + + It differs from SWAP in that it isn't based on scgi.scgi_server and + therefore, it allows me to implement concurrency using threads. (Also, + this server was written from scratch and really has no other depedencies.) + Which server to use really boils down to whether you want multithreading + or forking. (But as an aside, I've found scgi.scgi_server's implementation + of preforking to be quite superior. So if your application really doesn't + mind running in multiple processes, go use SWAP. ;) + """ + # What Request class to use. + requestClass = Request + + def __init__(self, application, environ=None, + multithreaded=True, + bindAddress=('localhost', 4000), allowedServers=None, + loggingLevel=logging.INFO, **kw): + """ + environ, which must be a dictionary, can contain any additional + environment variables you want to pass to your application. + + Set multithreaded to False if your application is not thread-safe. + + bindAddress is the address to bind to, which must be a tuple of + length 2. The first element is a string, which is the host name + or IPv4 address of a local interface. The 2nd element is the port + number. + + allowedServers must be None or a list of strings representing the + IPv4 addresses of servers allowed to connect. None means accept + connections from anywhere. + + loggingLevel sets the logging level of the module-level logger. + + Any additional keyword arguments are passed to the underlying + ThreadPool. 
+ """ + if environ is None: + environ = {} + + self.application = application + self.environ = environ + self.multithreaded = multithreaded + self._bindAddress = bindAddress + self._allowedServers = allowedServers + + # Used to force single-threadedness. + self._appLock = thread.allocate_lock() + + self._threadPool = ThreadPool(**kw) + + self.logger = logging.getLogger(LoggerName) + self.logger.setLevel(loggingLevel) + + def _setupSocket(self): + """Creates and binds the socket for communication with the server.""" + sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(self._bindAddress) + sock.listen(socket.SOMAXCONN) + return sock + + def _cleanupSocket(self, sock): + """Closes the main socket.""" + sock.close() + + def _isServerAllowed(self, addr): + return self._allowedServers is None or \ + addr[0] in self._allowedServers + + def _installSignalHandlers(self): + self._oldSIGs = [(x,signal.getsignal(x)) for x in + (signal.SIGHUP, signal.SIGINT, signal.SIGTERM)] + signal.signal(signal.SIGHUP, self._hupHandler) + signal.signal(signal.SIGINT, self._intHandler) + signal.signal(signal.SIGTERM, self._intHandler) + + def _restoreSignalHandlers(self): + for signum,handler in self._oldSIGs: + signal.signal(signum, handler) + + def _hupHandler(self, signum, frame): + self._hupReceived = True + self._keepGoing = False + + def _intHandler(self, signum, frame): + self._keepGoing = False + + def run(self, timeout=1.0): + """ + Main loop. Call this after instantiating WSGIServer. SIGHUP, SIGINT, + SIGTERM cause it to cleanup and return. (If a SIGHUP is caught, this + method returns True. Returns False otherwise.) + """ + self.logger.info('%s starting up', self.__class__.__name__) + + try: + sock = self._setupSocket() + except socket.error, e: + self.logger.error('Failed to bind socket (%s), exiting', e[1]) + return False + + self._keepGoing = True + self._hupReceived = False + + # Install signal handlers. 
+ self._installSignalHandlers() + + while self._keepGoing: + try: + r, w, e = select.select([sock], [], [], timeout) + except select.error, e: + if e[0] == errno.EINTR: + continue + raise + + if r: + try: + clientSock, addr = sock.accept() + except socket.error, e: + if e[0] in (errno.EINTR, errno.EAGAIN): + continue + raise + + if not self._isServerAllowed(addr): + self.logger.warning('Server connection from %s disallowed', + addr[0]) + clientSock.close() + continue + + # Hand off to Connection. + conn = Connection(clientSock, addr, self) + if not self._threadPool.addJob(conn, allowQueuing=False): + # No thread left, immediately close the socket to hopefully + # indicate to the web server that we're at our limit... + # and to prevent having too many opened (and useless) + # files. + clientSock.close() + + self._mainloopPeriodic() + + # Restore old signal handlers. + self._restoreSignalHandlers() + + self._cleanupSocket(sock) + + self.logger.info('%s shutting down%s', self.__class__.__name__, + self._hupReceived and ' (reload requested)' or '') + + return self._hupReceived + + def _mainloopPeriodic(self): + """ + Called with just about each iteration of the main loop. Meant to + be overridden. + """ + pass + + def _exit(self, reload=False): + """ + Protected convenience method for subclasses to force an exit. Not + really thread-safe, which is why it isn't public. + """ + if self._keepGoing: + self._keepGoing = False + self._hupReceived = reload + + def handler(self, request): + """ + WSGI handler. Sets up WSGI environment, calls the application, + and sends the application's response. + """ + environ = request.environ + environ.update(self.environ) + + environ['wsgi.version'] = (1,0) + environ['wsgi.input'] = request.stdin + environ['wsgi.errors'] = sys.stderr + environ['wsgi.multithread'] = self.multithreaded + # AFAIK, the current mod_scgi does not do load-balancing/fail-over. 
+ # So a single application deployment will only run in one process + # at a time, on this server. + environ['wsgi.multiprocess'] = False + environ['wsgi.run_once'] = False + + if environ.get('HTTPS', 'off') in ('on', '1'): + environ['wsgi.url_scheme'] = 'https' + else: + environ['wsgi.url_scheme'] = 'http' + + headers_set = [] + headers_sent = [] + result = None + + def write(data): + assert type(data) is str, 'write() argument must be string' + assert headers_set, 'write() before start_response()' + + if not headers_sent: + status, responseHeaders = headers_sent[:] = headers_set + found = False + for header,value in responseHeaders: + if header.lower() == 'content-length': + found = True + break + if not found and result is not None: + try: + if len(result) == 1: + responseHeaders.append(('Content-Length', + str(len(data)))) + except: + pass + s = 'Status: %s\r\n' % status + for header in responseHeaders: + s += '%s: %s\r\n' % header + s += '\r\n' + try: + request.stdout.write(s) + except socket.error, e: + if e[0] != errno.EPIPE: + raise + + try: + request.stdout.write(data) + request.stdout.flush() + except socket.error, e: + if e[0] != errno.EPIPE: + raise + + def start_response(status, response_headers, exc_info=None): + if exc_info: + try: + if headers_sent: + # Re-raise if too late + raise exc_info[0], exc_info[1], exc_info[2] + finally: + exc_info = None # avoid dangling circular ref + else: + assert not headers_set, 'Headers already set!' 
+ + assert type(status) is str, 'Status must be a string' + assert len(status) >= 4, 'Status must be at least 4 characters' + assert int(status[:3]), 'Status must begin with 3-digit code' + assert status[3] == ' ', 'Status must have a space after code' + assert type(response_headers) is list, 'Headers must be a list' + if __debug__: + for name,val in response_headers: + assert type(name) is str, 'Header names must be strings' + assert type(val) is str, 'Header values must be strings' + + headers_set[:] = [status, response_headers] + return write + + if not self.multithreaded: + self._appLock.acquire() + try: + result = self.application(environ, start_response) + try: + for data in result: + if data: + write(data) + if not headers_sent: + write('') # in case body was empty + finally: + if hasattr(result, 'close'): + result.close() + finally: + if not self.multithreaded: + self._appLock.release() + + def error(self, request): + """ + Override to provide custom error handling. Ideally, however, + all errors should be caught at the application level. 
+ """ + import cgitb + request.stdout.write('Content-Type: text/html\r\n\r\n' + + cgitb.html(sys.exc_info())) + +if __name__ == '__main__': + def test_app(environ, start_response): + """Probably not the most efficient example.""" + import cgi + start_response('200 OK', [('Content-Type', 'text/html')]) + yield '<html><head><title>Hello World!</title></head>\n' \ + '<body>\n' \ + '<p>Hello World!</p>\n' \ + '<table border="1">' + names = environ.keys() + names.sort() + for name in names: + yield '<tr><td>%s</td><td>%s</td></tr>\n' % ( + name, cgi.escape(`environ[name]`)) + + form = cgi.FieldStorage(fp=environ['wsgi.input'], environ=environ, + keep_blank_values=1) + if form.list: + yield '<tr><th colspan="2">Form data</th></tr>' + + for field in form.list: + yield '<tr><td>%s</td><td>%s</td></tr>\n' % ( + field.name, field.value) + + yield '</table>\n' \ + '</body></html>\n' + + WSGIServer(test_app, + loggingLevel=logging.DEBUG).run() diff --git a/flup/server/scgi_fork.py b/flup/server/scgi_fork.py new file mode 100644 index 0000000..05a527c --- /dev/null +++ b/flup/server/scgi_fork.py @@ -0,0 +1,528 @@ +# Copyright (c) 2005 Allan Saddi <allan@saddi.com> +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS +# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF +# SUCH DAMAGE. +# +# $Id$ + +""" +scgi - an SCGI/WSGI gateway. (I might have to rename this module.) + +For more information about SCGI and mod_scgi for Apache1/Apache2, see +<http://www.mems-exchange.org/software/scgi/>. + +For more information about the Web Server Gateway Interface, see +<http://www.python.org/peps/pep-0333.html>. + +Example usage: + + #!/usr/bin/env python + import sys + from myapplication import app # Assume app is your WSGI application object + from scgi import WSGIServer + ret = WSGIServer(app).run() + sys.exit(ret and 42 or 0) + +See the documentation for WSGIServer for more information. + +About the bit of logic at the end: +Upon receiving SIGHUP, the python script will exit with status code 42. This +can be used by a wrapper script to determine if the python script should be +re-run. When a SIGINT or SIGTERM is received, the script exits with status +code 0, possibly indicating a normal exit. + +Example wrapper script: + + #!/bin/sh + STATUS=42 + while test $STATUS -eq 42; do + python "$@" that_script_above.py + STATUS=$? + done +""" + +__author__ = 'Allan Saddi <allan@saddi.com>' +__version__ = '$Revision$' + +import sys +import logging +import socket +import select +import errno +import cStringIO as StringIO +import signal +import datetime +import prefork + +__all__ = ['WSGIServer'] + +# The main classes use this name for logging. +LoggerName = 'scgi-wsgi' + +# Set up module-level logger. 
# Send this module's log output to the console.
console = logging.StreamHandler()
console.setLevel(logging.DEBUG)
console.setFormatter(logging.Formatter('%(asctime)s : %(message)s',
                                       '%Y-%m-%d %H:%M:%S'))
logging.getLogger(LoggerName).addHandler(console)
del console

class ProtocolError(Exception):
    """
    Exception raised when the server does something unexpected or
    sends garbled data. Usually leads to a Connection closing.
    """
    pass

def recvall(sock, length):
    """
    Attempts to receive length bytes from a socket, blocking if necessary.
    (Socket may be blocking or non-blocking.)

    Returns a (data, count) tuple; count < length indicates EOF.
    """
    dataList = []
    recvLen = 0
    while length:
        try:
            data = sock.recv(length)
        except socket.error, e:
            if e[0] == errno.EAGAIN:
                # Non-blocking socket, nothing ready: wait until readable,
                # then retry.
                select.select([sock], [], [])
                continue
            else:
                raise
        if not data: # EOF
            break
        dataList.append(data)
        dataLen = len(data)
        recvLen += dataLen
        length -= dataLen
    return ''.join(dataList), recvLen

def readNetstring(sock):
    """
    Attempt to read a netstring from a socket.

    A netstring is encoded as '<length>:<bytes>,'. Raises EOFError on a
    short read, ProtocolError on malformed input.
    """
    # First attempt to read the length.
    size = ''
    while True:
        try:
            c = sock.recv(1)
        except socket.error, e:
            if e[0] == errno.EAGAIN:
                select.select([sock], [], [])
                continue
            else:
                raise
        if c == ':':
            break
        if not c:
            raise EOFError
        size += c

    # Try to decode the length.
    try:
        size = int(size)
        if size < 0:
            raise ValueError
    except ValueError:
        raise ProtocolError, 'invalid netstring length'

    # Now read the string.
    s, length = recvall(sock, size)

    if length < size:
        raise EOFError

    # Lastly, the trailer.
    trailer, length = recvall(sock, 1)

    if length < 1:
        raise EOFError

    if trailer != ',':
        raise ProtocolError, 'invalid netstring trailer'

    return s

class StdoutWrapper(object):
    """
    Wrapper for sys.stdout so we know if data has actually been written.
    """
    def __init__(self, stdout):
        self._file = stdout
        # True once any non-empty data has been written through us.
        self.dataWritten = False

    def write(self, data):
        # NOTE(review): original formatting is collapsed in this dump;
        # the unconditional pass-through write matches flup's layout.
        if data:
            self.dataWritten = True
        self._file.write(data)

    def writelines(self, lines):
        for line in lines:
            self.write(line)

    def __getattr__(self, name):
        # Delegate anything else (flush, close, ...) to the wrapped file.
        return getattr(self._file, name)

class Request(object):
    """
    Encapsulates data related to a single request.

    Public attributes:
      environ - Environment variables from web server.
      stdin - File-like object representing the request body.
      stdout - File-like object for writing the response.
    """
    def __init__(self, conn, environ, input, output):
        self._conn = conn
        self.environ = environ
        self.stdin = input
        # Wrapped so error() can tell whether anything was already sent.
        self.stdout = StdoutWrapper(output)

        self.logger = logging.getLogger(LoggerName)

    def run(self):
        """Dispatches this request to the server's handler, logging timing."""
        self.logger.info('%s %s%s',
                         self.environ['REQUEST_METHOD'],
                         self.environ.get('SCRIPT_NAME', ''),
                         self.environ.get('PATH_INFO', ''))

        start = datetime.datetime.now()

        try:
            self._conn.server.handler(self)
        except:
            self.logger.exception('Exception caught from handler')
            if not self.stdout.dataWritten:
                # Nothing sent yet; let the server produce an error page.
                self._conn.server.error(self)

        end = datetime.datetime.now()

        # Log wall-clock handler time.
        handlerTime = end - start
        self.logger.debug('%s %s%s done (%.3f secs)',
                          self.environ['REQUEST_METHOD'],
                          self.environ.get('SCRIPT_NAME', ''),
                          self.environ.get('PATH_INFO', ''),
                          handlerTime.seconds +
                          handlerTime.microseconds / 1000000.0)

class Connection(object):
    """
    Represents a single client (web server) connection. A single request
    is handled, after which the socket is closed.
+ """ + def __init__(self, sock, addr, server): + self._sock = sock + self._addr = addr + self.server = server + + self.logger = logging.getLogger(LoggerName) + + def run(self): + self.logger.debug('Connection starting up (%s:%d)', + self._addr[0], self._addr[1]) + + try: + self.processInput() + except EOFError: + pass + except ProtocolError, e: + self.logger.error("Protocol error '%s'", str(e)) + except: + self.logger.exception('Exception caught in Connection') + + self.logger.debug('Connection shutting down (%s:%d)', + self._addr[0], self._addr[1]) + + # All done! + self._sock.close() + + def processInput(self): + # Read headers + headers = readNetstring(self._sock) + headers = headers.split('\x00')[:-1] + if len(headers) % 2 != 0: + raise ProtocolError, 'invalid headers' + environ = {} + for i in range(len(headers) / 2): + environ[headers[2*i]] = headers[2*i+1] + + clen = environ.get('CONTENT_LENGTH') + if clen is None: + raise ProtocolError, 'missing CONTENT_LENGTH' + try: + clen = int(clen) + if clen < 0: + raise ValueError + except ValueError: + raise ProtocolError, 'invalid CONTENT_LENGTH' + + self._sock.setblocking(1) + if clen: + input = self._sock.makefile('r') + else: + # Empty input. + input = StringIO.StringIO() + + # stdout + output = self._sock.makefile('w') + + # Allocate Request + req = Request(self, environ, input, output) + + # Run it. + req.run() + + output.close() + input.close() + +class WSGIServer(prefork.PreforkServer): + """ + SCGI/WSGI server. For information about SCGI (Simple Common Gateway + Interface), see <http://www.mems-exchange.org/software/scgi/>. + + This server is similar to SWAP <http://www.idyll.org/~t/www-tools/wsgi/>, + another SCGI/WSGI server. + + It differs from SWAP in that it isn't based on scgi.scgi_server and + therefore, it allows me to implement concurrency using threads. (Also, + this server was written from scratch and really has no other depedencies.) 
+ Which server to use really boils down to whether you want multithreading + or forking. (But as an aside, I've found scgi.scgi_server's implementation + of preforking to be quite superior. So if your application really doesn't + mind running in multiple processes, go use SWAP. ;) + """ + # What Request class to use. + requestClass = Request + + def __init__(self, application, environ=None, + bindAddress=('localhost', 4000), allowedServers=None, + loggingLevel=logging.INFO, **kw): + """ + environ, which must be a dictionary, can contain any additional + environment variables you want to pass to your application. + + bindAddress is the address to bind to, which must be a tuple of + length 2. The first element is a string, which is the host name + or IPv4 address of a local interface. The 2nd element is the port + number. + + allowedServers must be None or a list of strings representing the + IPv4 addresses of servers allowed to connect. None means accept + connections from anywhere. + + loggingLevel sets the logging level of the module-level logger. + + Any additional keyword arguments are passed to the underlying + ThreadPool. 
+ """ + if kw.has_key('jobClass'): + del kw['jobClass'] + if kw.has_key('jobArgs'): + del kw['jobArgs'] + super(WSGIServer, self).__init__(jobClass=Connection, + jobArgs=(self,), **kw) + + if environ is None: + environ = {} + + self.application = application + self.environ = environ + self._bindAddress = bindAddress + self._allowedServers = allowedServers + + self.logger = logging.getLogger(LoggerName) + self.logger.setLevel(loggingLevel) + + def _setupSocket(self): + """Creates and binds the socket for communication with the server.""" + sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(self._bindAddress) + sock.listen(socket.SOMAXCONN) + return sock + + def _cleanupSocket(self, sock): + """Closes the main socket.""" + sock.close() + + def _isClientAllowed(self, addr): + ret = self._allowedServers is None or addr[0] in self._allowedServers + if not ret: + self.logger.warning('Server connection from %s disallowed', + addr[0]) + return ret + + def run(self): + """ + Main loop. Call this after instantiating WSGIServer. SIGHUP, SIGINT, + SIGQUIT, SIGTERM cause it to cleanup and return. (If a SIGHUP + is caught, this method returns True. Returns False otherwise.) + """ + self.logger.info('%s starting up', self.__class__.__name__) + + try: + sock = self._setupSocket() + except socket.error, e: + self.logger.error('Failed to bind socket (%s), exiting', e[1]) + return False + + ret = super(WSGIServer, self).run(sock) + + self._cleanupSocket(sock) + + self.logger.info('%s shutting down', self.__class__.__name__) + + return ret + + def handler(self, request): + """ + WSGI handler. Sets up WSGI environment, calls the application, + and sends the application's response. 
+ """ + environ = request.environ + environ.update(self.environ) + + environ['wsgi.version'] = (1,0) + environ['wsgi.input'] = request.stdin + environ['wsgi.errors'] = sys.stderr + environ['wsgi.multithread'] = False + environ['wsgi.multiprocess'] = True + environ['wsgi.run_once'] = False + + if environ.get('HTTPS', 'off') in ('on', '1'): + environ['wsgi.url_scheme'] = 'https' + else: + environ['wsgi.url_scheme'] = 'http' + + headers_set = [] + headers_sent = [] + result = None + + def write(data): + assert type(data) is str, 'write() argument must be string' + assert headers_set, 'write() before start_response()' + + if not headers_sent: + status, responseHeaders = headers_sent[:] = headers_set + found = False + for header,value in responseHeaders: + if header.lower() == 'content-length': + found = True + break + if not found and result is not None: + try: + if len(result) == 1: + responseHeaders.append(('Content-Length', + str(len(data)))) + except: + pass + s = 'Status: %s\r\n' % status + for header in responseHeaders: + s += '%s: %s\r\n' % header + s += '\r\n' + try: + request.stdout.write(s) + except socket.error, e: + if e[0] != errno.EPIPE: + raise + + try: + request.stdout.write(data) + request.stdout.flush() + except socket.error, e: + if e[0] != errno.EPIPE: + raise + + def start_response(status, response_headers, exc_info=None): + if exc_info: + try: + if headers_sent: + # Re-raise if too late + raise exc_info[0], exc_info[1], exc_info[2] + finally: + exc_info = None # avoid dangling circular ref + else: + assert not headers_set, 'Headers already set!' 
+ + assert type(status) is str, 'Status must be a string' + assert len(status) >= 4, 'Status must be at least 4 characters' + assert int(status[:3]), 'Status must begin with 3-digit code' + assert status[3] == ' ', 'Status must have a space after code' + assert type(response_headers) is list, 'Headers must be a list' + if __debug__: + for name,val in response_headers: + assert type(name) is str, 'Header names must be strings' + assert type(val) is str, 'Header values must be strings' + + headers_set[:] = [status, response_headers] + return write + + result = self.application(environ, start_response) + try: + for data in result: + if data: + write(data) + if not headers_sent: + write('') # in case body was empty + finally: + if hasattr(result, 'close'): + result.close() + + def error(self, request): + """ + Override to provide custom error handling. Ideally, however, + all errors should be caught at the application level. + """ + import cgitb + request.stdout.write('Content-Type: text/html\r\n\r\n' + + cgitb.html(sys.exc_info())) + +if __name__ == '__main__': + def test_app(environ, start_response): + """Probably not the most efficient example.""" + import cgi + start_response('200 OK', [('Content-Type', 'text/html')]) + yield '<html><head><title>Hello World!</title></head>\n' \ + '<body>\n' \ + '<p>Hello World!</p>\n' \ + '<table border="1">' + names = environ.keys() + names.sort() + for name in names: + yield '<tr><td>%s</td><td>%s</td></tr>\n' % ( + name, cgi.escape(`environ[name]`)) + + form = cgi.FieldStorage(fp=environ['wsgi.input'], environ=environ, + keep_blank_values=1) + if form.list: + yield '<tr><th colspan="2">Form data</th></tr>' + + for field in form.list: + yield '<tr><td>%s</td><td>%s</td></tr>\n' % ( + field.name, field.value) + + yield '</table>\n' \ + '</body></html>\n' + + WSGIServer(test_app, + loggingLevel=logging.DEBUG).run() diff --git a/flup/server/threadpool.py b/flup/server/threadpool.py new file mode 100644 index 0000000..197433f --- 
/dev/null +++ b/flup/server/threadpool.py @@ -0,0 +1,113 @@ +# Copyright (c) 2005 Allan Saddi <allan@saddi.com> +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS +# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF +# SUCH DAMAGE. +# +# $Id$ + +import sys +import thread +import threading + +class ThreadPool(object): + """ + Thread pool that maintains the number of idle threads between + minSpare and maxSpare inclusive. By default, there is no limit on + the number of threads that can be started, but this can be controlled + by maxThreads. 
+ """ + def __init__(self, minSpare=1, maxSpare=5, maxThreads=sys.maxint): + self._minSpare = minSpare + self._maxSpare = maxSpare + self._maxThreads = max(minSpare, maxThreads) + + self._lock = threading.Condition() + self._workQueue = [] + self._idleCount = self._workerCount = maxSpare + + # Start the minimum number of worker threads. + for i in range(maxSpare): + thread.start_new_thread(self._worker, ()) + + def addJob(self, job, allowQueuing=True): + """ + Adds a job to the work queue. The job object should have a run() + method. If allowQueuing is True (the default), the job will be + added to the work queue regardless if there are any idle threads + ready. (The only way for there to be no idle threads is if maxThreads + is some reasonable, finite limit.) + + Otherwise, if allowQueuing is False, and there are no more idle + threads, the job will not be queued. + + Returns True if the job was queued, False otherwise. + """ + self._lock.acquire() + try: + # Maintain minimum number of spares. + while self._idleCount < self._minSpare and \ + self._workerCount < self._maxThreads: + self._workerCount += 1 + self._idleCount += 1 + thread.start_new_thread(self._worker, ()) + + # Hand off the job. + if self._idleCount or allowQueuing: + self._workQueue.append(job) + self._lock.notify() + return True + else: + return False + finally: + self._lock.release() + + def _worker(self): + """ + Worker thread routine. Waits for a job, executes it, repeat. + """ + self._lock.acquire() + while True: + while not self._workQueue: + self._lock.wait() + + # We have a job to do... + job = self._workQueue.pop(0) + + assert self._idleCount > 0 + self._idleCount -= 1 + + self._lock.release() + + job.run() + + self._lock.acquire() + + if self._idleCount == self._maxSpare: + break # NB: lock still held + self._idleCount += 1 + assert self._idleCount <= self._maxSpare + + # Die off... + assert self._workerCount > self._maxSpare + self._workerCount -= 1 + + self._lock.release() |