summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorpje <pje@571e12c6-e1fa-0310-aee7-ff1267fa46bd>2004-10-05 23:11:22 +0000
committerpje <pje@571e12c6-e1fa-0310-aee7-ff1267fa46bd>2004-10-05 23:11:22 +0000
commit6d514a124dfd65e805c4fe61273550e983101574 (patch)
tree32d1387a1a9a9a806f08fba611d7aa4c7dcb063b
parent7a2c9a1c49e81c5658e125996dad584318bf55f7 (diff)
downloadwsgiref-6d514a124dfd65e805c4fe61273550e983101574.tar.gz
Add a wsgiref.headers.Headers class to allow easy HTTP response header
manipulation. git-svn-id: svn://svn.eby-sarna.com/svnroot/wsgiref@248 571e12c6-e1fa-0310-aee7-ff1267fa46bd
-rw-r--r--src/wsgiref/headers.py205
-rw-r--r--src/wsgiref/tests/__init__.py64
-rw-r--r--src/wsgiref/tests/test_headers.py90
-rw-r--r--src/wsgiref/tests/test_util.py76
4 files changed, 397 insertions, 38 deletions
diff --git a/src/wsgiref/headers.py b/src/wsgiref/headers.py
new file mode 100644
index 0000000..9435738
--- /dev/null
+++ b/src/wsgiref/headers.py
@@ -0,0 +1,205 @@
+"""Manage HTTP Response Headers
+
+Much of this module is red-handedly pilfered from email.Message in the stdlib,
+so portions are Copyright (C) 2001,2002 Python Software Foundation, and were
+written by Barry Warsaw.
+"""
+
+from types import ListType, TupleType
+
+# Regular expression that matches `special' characters in parameters, the
+# existance of which force quoting of the parameter value.
+import re
+tspecials = re.compile(r'[ \(\)<>@,;:\\"/\[\]\?=]')
+
+def _formatparam(param, value=None, quote=1):
+ """Convenience function to format and return a key=value pair.
+
+ This will quote the value if needed or if quote is true.
+ """
+ if value is not None and len(value) > 0:
+ if quote or tspecials.search(value):
+ value = value.replace('\\', '\\\\').replace('"', r'\"')
+ return '%s="%s"' % (param, value)
+ else:
+ return '%s=%s' % (param, value)
+ else:
+ return param
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+class Headers:
+
+ """Manage a collection of HTTP response headers"""
+
+ def __init__(self,headers):
+ if type(headers) is not ListType:
+ raise TypeError("Headers must be a list of name/value tuples")
+ self._headers = headers
+
+ def __len__(self):
+ """Return the total number of headers, including duplicates."""
+ return len(self._headers)
+
+ def __setitem__(self, name, val):
+ """Set the value of a header."""
+ del self[name]
+ self._headers.append((name, val))
+
+ def __delitem__(self,name):
+ """Delete all occurrences of a header, if present.
+
+ Does *not* raise an exception if the header is missing.
+ """
+ name = name.lower()
+ self._headers[:] = [kv for kv in self._headers if kv[0].lower()<>name]
+
+ def __getitem__(self,name):
+ """Get the first header value for 'name'
+
+ Return None if the header is missing instead of raising an exception.
+
+ Note that if the header appeared multiple times, the first exactly which
+ occurrance gets returned is undefined. Use getall() to get all
+ the values matching a header field name.
+ """
+ return self.get(name)
+
+
+
+
+
+ def has_key(self, name):
+ """Return true if the message contains the header."""
+ return self.get(name) is not None
+
+ __contains__ = has_key
+
+
+ def get_all(self, name):
+ """Return a list of all the values for the named field.
+
+ These will be sorted in the order they appeared in the original header
+ list or were added to this instance, and may contain duplicates. Any
+ fields deleted and re-inserted are always appended to the header list.
+ If no fields exist with the given name, returns an empty list.
+ """
+ name = name.lower()
+ return [kv[1] for kv in self._headers if kv[0].lower()==name]
+
+
+ def get(self,name,default=None):
+ """Get the first header value for 'name', or return 'default'"""
+ name = name.lower()
+ for k,v in self._headers:
+ if k.lower()==name:
+ return v
+ return default
+
+
+ def keys(self):
+ """Return a list of all the header field names.
+
+ These will be sorted in the order they appeared in the original header
+ list, or were added to this instance, and may contain duplicates.
+ Any fields deleted and re-inserted are always appended to the header
+ list.
+ """
+ return [k for k, v in self._headers]
+
+
+
+
+ def values(self):
+ """Return a list of all header values.
+
+ These will be sorted in the order they appeared in the original header
+ list, or were added to this instance, and may contain duplicates.
+ Any fields deleted and re-inserted are always appended to the header
+ list.
+ """
+ return [v for k, v in self._headers]
+
+ def items(self):
+ """Get all the header fields and values.
+
+ These will be sorted in the order they were in the original header
+ list, or were added to this instance, and may contain duplicates.
+ Any fields deleted and re-inserted are always appended to the header
+ list.
+ """
+ return self._headers[:]
+
+ def __repr__(self):
+ return "Headers(%s)" % `self._headers`
+
+ def __str__(self):
+ """str() returns the formatted headers, complete with end line,
+ suitable for direct HTTP transmission."""
+ return '\r\n'.join(["%s: %s" % kv for kv in self._headers])+'\r\n'*2
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ def add_header(self, _name, _value, **_params):
+ """Extended header setting.
+
+ _name is the header field to add. keyword arguments can be used to set
+ additional parameters for the header field, with underscores converted
+ to dashes. Normally the parameter will be added as key="value" unless
+ value is None, in which case only the key will be added.
+
+ Example:
+
+ h.add_header('content-disposition', 'attachment', filename='bud.gif')
+
+ Note that unlike the corresponding 'email.Message' method, this does
+ *not* handle '(charset, language, value)' tuples: all values must be
+ strings or None.
+ """
+ parts = []
+ if _value is not None:
+ parts.append(_value)
+ for k, v in _params.items():
+ if v is None:
+ parts.append(k.replace('_', '-'))
+ else:
+ parts.append(_formatparam(k.replace('_', '-'), v))
+ self._headers.append((_name, "; ".join(parts)))
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/wsgiref/tests/__init__.py b/src/wsgiref/tests/__init__.py
index a4eb4b5..6b61776 100644
--- a/src/wsgiref/tests/__init__.py
+++ b/src/wsgiref/tests/__init__.py
@@ -1,11 +1,52 @@
from unittest import TestSuite, TestCase, makeSuite
+def compare_generic_iter(make_it,match):
+ """Utility to compare a generic 2.1/2.2+ iterator with an iterable
+
+ If running under Python 2.2+, this tests the iterator using iter()/next(),
+ as well as __getitem__. 'make_it' must be a function returning a fresh
+ iterator to be tested (since this may test the iterator twice)."""
+
+ it = make_it()
+ n = 0
+ for item in match:
+ assert it[n]==item
+ n+=1
+ try:
+ it[n]
+ except IndexError:
+ pass
+ else:
+ raise AssertionError("Too many items from __getitem__",it)
+
+ try:
+ iter, StopIteration
+ except NameError:
+ pass
+ else:
+ # Only test iter mode under 2.2+
+ it = make_it()
+ assert iter(it) is it
+ for item in match:
+ assert it.next()==item
+ try:
+ it.next()
+ except StopIteration:
+ pass
+ else:
+ raise AssertionError("Too many items from .next()",it)
+
+
+
+
def test_suite():
from wsgiref.tests import test_util
+ from wsgiref.tests import test_headers
tests = [
test_util.test_suite(),
+ test_headers.test_suite(),
]
return TestSuite(tests)
@@ -16,3 +57,26 @@ def test_suite():
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/wsgiref/tests/test_headers.py b/src/wsgiref/tests/test_headers.py
new file mode 100644
index 0000000..0ce53fe
--- /dev/null
+++ b/src/wsgiref/tests/test_headers.py
@@ -0,0 +1,90 @@
+from unittest import TestCase, TestSuite, makeSuite
+from wsgiref.headers import Headers
+from wsgiref.tests import compare_generic_iter
+
+
+class HeaderTests(TestCase):
+
+ def testMappingInterface(self):
+ test = [('x','y')]
+ self.assertEqual(len(Headers([])),0)
+ self.assertEqual(len(Headers(test[:])),1)
+ self.assertEqual(Headers(test[:]).keys(), ['x'])
+ self.assertEqual(Headers(test[:]).values(), ['y'])
+ self.assertEqual(Headers(test[:]).items(), test)
+ self.failIf(Headers(test).items() is test) # must be copy!
+
+ h=Headers([])
+ del h['foo'] # should not raise an error
+
+ h['Foo'] = 'bar'
+ for m in h.has_key, h.__contains__, h.get, h.get_all, h.__getitem__:
+ self.failUnless(m('foo'))
+ self.failUnless(m('Foo'))
+ self.failUnless(m('FOO'))
+ self.failIf(m('bar'))
+
+ self.assertEqual(h['foo'],'bar')
+ h['foo'] = 'baz'
+ self.assertEqual(h['FOO'],'baz')
+ self.assertEqual(h.get_all('foo'),['baz'])
+
+ self.assertEqual(h.get("foo","whee"), "baz")
+ self.assertEqual(h.get("zoo","whee"), "whee")
+
+ def testRequireList(self):
+ self.assertRaises(TypeError, Headers, "foo")
+
+
+
+
+
+ def testExtras(self):
+ h = Headers([])
+ h.add_header('foo','bar',baz="spam")
+ self.assertEqual(h['foo'], 'bar; baz="spam"')
+ h.add_header('Foo','bar',cheese=None)
+ self.assertEqual(h.get_all('foo'),
+ ['bar; baz="spam"', 'bar; cheese'])
+ self.assertEqual(str(h),
+ 'foo: bar; baz="spam"\r\n'
+ 'Foo: bar; cheese\r\n'
+ '\r\n'
+ )
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+TestClasses = (
+ HeaderTests,
+)
+
+def test_suite():
+ return TestSuite([makeSuite(t,'test') for t in TestClasses])
+
+
diff --git a/src/wsgiref/tests/test_util.py b/src/wsgiref/tests/test_util.py
index b43f28e..91caac6 100644
--- a/src/wsgiref/tests/test_util.py
+++ b/src/wsgiref/tests/test_util.py
@@ -1,8 +1,8 @@
from unittest import TestCase, TestSuite, makeSuite
from wsgiref import util
+from wsgiref.tests import compare_generic_iter
from StringIO import StringIO
-
class UtilityTests(TestCase):
def checkShift(self,sn_in,pi_in,part,sn_out,pi_out):
@@ -13,7 +13,6 @@ class UtilityTests(TestCase):
self.assertEqual(env['SCRIPT_NAME'],sn_out)
return env
-
def checkDefault(self, key, value, alt=None):
# Check defaulting when empty
env = {}
@@ -28,17 +27,6 @@ class UtilityTests(TestCase):
util.setup_testing_defaults(env)
self.failUnless(env[key] is alt)
-
-
-
-
-
-
-
-
-
-
-
def checkCrossDefault(self,key,value,**kw):
util.setup_testing_defaults(kw)
self.assertEqual(kw[key],value)
@@ -52,32 +40,22 @@ class UtilityTests(TestCase):
self.assertEqual(util.request_uri(kw,query),uri)
def checkFW(self,text,size,match):
- sio = StringIO(text)
- fw = util.FileWrapper(sio,size)
- n = 0
- for item in match:
- self.assertEqual(fw[n],item)
- n+=1
- self.assertRaises(IndexError, fw.__getitem__, n)
-
- try:
- iter, StopIteration
- except NameError:
- pass
- else:
- # Only test iter mode under 2.2+
- sio = StringIO(text)
- fw = util.FileWrapper(sio,size)
- self.failUnless(iter(fw) is fw)
- for item in match:
- self.assertEqual(fw.next(),item)
- self.assertRaises(StopIteration, fw.next)
- self.failIf(sio.closed)
- fw.close()
- self.failUnless(sio.closed)
+ def make_it(text=text,size=size):
+ return util.FileWrapper(StringIO(text),size)
+
+ compare_generic_iter(make_it,match)
+
+ it = make_it()
+ self.failIf(it.filelike.closed)
+
+ for item in it:
+ pass
+ self.failIf(it.filelike.closed)
+ it.close()
+ self.failUnless(it.filelike.closed)
def testSimpleShifts(self):
@@ -87,6 +65,7 @@ class UtilityTests(TestCase):
self.checkShift('/a','/x/y', 'x', '/a/x', '/y')
self.checkShift('/a','/x/', 'x', '/a/x', '/')
+
def testNormalizedShifts(self):
self.checkShift('/a/b', '/../y', '..', '/a', '/y')
self.checkShift('', '/../y', '..', '', '/y')
@@ -100,6 +79,7 @@ class UtilityTests(TestCase):
self.checkShift('/a/b', '/x//', 'x', '/a/b/x', '/')
self.checkShift('/a/b', '/.', None, '/a/b', '')
+
def testDefaults(self):
for key, value in [
('SERVER_NAME','127.0.0.1'),
@@ -119,8 +99,6 @@ class UtilityTests(TestCase):
self.checkDefault(key,value)
-
-
def testCrossDefaults(self):
self.checkCrossDefault('HTTP_HOST',"foo.bar",SERVER_NAME="foo.bar")
self.checkCrossDefault('wsgi.url_scheme',"https",HTTPS="on")
@@ -130,6 +108,7 @@ class UtilityTests(TestCase):
self.checkCrossDefault('SERVER_PORT',"80",HTTPS="foo")
self.checkCrossDefault('SERVER_PORT',"443",HTTPS="on")
+
def testGuessScheme(self):
self.assertEqual(util.guess_scheme({}), "http")
self.assertEqual(util.guess_scheme({'HTTPS':"foo"}), "http")
@@ -137,6 +116,11 @@ class UtilityTests(TestCase):
self.assertEqual(util.guess_scheme({'HTTPS':"yes"}), "https")
self.assertEqual(util.guess_scheme({'HTTPS':"1"}), "https")
+
+
+
+
+
def testAppURIs(self):
self.checkAppURI("http://127.0.0.1/")
self.checkAppURI("http://127.0.0.1/spam", SCRIPT_NAME="/spam")
@@ -162,6 +146,22 @@ class UtilityTests(TestCase):
def testFileWrapper(self):
self.checkFW("xyz"*50, 120, ["xyz"*40,"xyz"*10])
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
TestClasses = (
UtilityTests,
)