diff options
author | pje <pje@571e12c6-e1fa-0310-aee7-ff1267fa46bd> | 2004-10-05 23:11:22 +0000 |
---|---|---|
committer | pje <pje@571e12c6-e1fa-0310-aee7-ff1267fa46bd> | 2004-10-05 23:11:22 +0000 |
commit | 6d514a124dfd65e805c4fe61273550e983101574 (patch) | |
tree | 32d1387a1a9a9a806f08fba611d7aa4c7dcb063b | |
parent | 7a2c9a1c49e81c5658e125996dad584318bf55f7 (diff) | |
download | wsgiref-6d514a124dfd65e805c4fe61273550e983101574.tar.gz |
Add a wsgiref.headers.Headers class to allow easy HTTP response header
manipulation.
git-svn-id: svn://svn.eby-sarna.com/svnroot/wsgiref@248 571e12c6-e1fa-0310-aee7-ff1267fa46bd
-rw-r--r-- | src/wsgiref/headers.py | 205 | ||||
-rw-r--r-- | src/wsgiref/tests/__init__.py | 64 | ||||
-rw-r--r-- | src/wsgiref/tests/test_headers.py | 90 | ||||
-rw-r--r-- | src/wsgiref/tests/test_util.py | 76 |
4 files changed, 397 insertions, 38 deletions
diff --git a/src/wsgiref/headers.py b/src/wsgiref/headers.py new file mode 100644 index 0000000..9435738 --- /dev/null +++ b/src/wsgiref/headers.py @@ -0,0 +1,205 @@ +"""Manage HTTP Response Headers + +Much of this module is red-handedly pilfered from email.Message in the stdlib, +so portions are Copyright (C) 2001,2002 Python Software Foundation, and were +written by Barry Warsaw. +""" + +from types import ListType, TupleType + +# Regular expression that matches `special' characters in parameters, the +# existance of which force quoting of the parameter value. +import re +tspecials = re.compile(r'[ \(\)<>@,;:\\"/\[\]\?=]') + +def _formatparam(param, value=None, quote=1): + """Convenience function to format and return a key=value pair. + + This will quote the value if needed or if quote is true. + """ + if value is not None and len(value) > 0: + if quote or tspecials.search(value): + value = value.replace('\\', '\\\\').replace('"', r'\"') + return '%s="%s"' % (param, value) + else: + return '%s=%s' % (param, value) + else: + return param + + + + + + + + + + + + + + +class Headers: + + """Manage a collection of HTTP response headers""" + + def __init__(self,headers): + if type(headers) is not ListType: + raise TypeError("Headers must be a list of name/value tuples") + self._headers = headers + + def __len__(self): + """Return the total number of headers, including duplicates.""" + return len(self._headers) + + def __setitem__(self, name, val): + """Set the value of a header.""" + del self[name] + self._headers.append((name, val)) + + def __delitem__(self,name): + """Delete all occurrences of a header, if present. + + Does *not* raise an exception if the header is missing. + """ + name = name.lower() + self._headers[:] = [kv for kv in self._headers if kv[0].lower()<>name] + + def __getitem__(self,name): + """Get the first header value for 'name' + + Return None if the header is missing instead of raising an exception. + + Note that if the header appeared multiple times, the first exactly which + occurrance gets returned is undefined. Use getall() to get all + the values matching a header field name. + """ + return self.get(name) + + + + + + def has_key(self, name): + """Return true if the message contains the header.""" + return self.get(name) is not None + + __contains__ = has_key + + + def get_all(self, name): + """Return a list of all the values for the named field. + + These will be sorted in the order they appeared in the original header + list or were added to this instance, and may contain duplicates. Any + fields deleted and re-inserted are always appended to the header list. + If no fields exist with the given name, returns an empty list. + """ + name = name.lower() + return [kv[1] for kv in self._headers if kv[0].lower()==name] + + + def get(self,name,default=None): + """Get the first header value for 'name', or return 'default'""" + name = name.lower() + for k,v in self._headers: + if k.lower()==name: + return v + return default + + + def keys(self): + """Return a list of all the header field names. + + These will be sorted in the order they appeared in the original header + list, or were added to this instance, and may contain duplicates. + Any fields deleted and re-inserted are always appended to the header + list. + """ + return [k for k, v in self._headers] + + + + + def values(self): + """Return a list of all header values. + + These will be sorted in the order they appeared in the original header + list, or were added to this instance, and may contain duplicates. + Any fields deleted and re-inserted are always appended to the header + list. + """ + return [v for k, v in self._headers] + + def items(self): + """Get all the header fields and values. + + These will be sorted in the order they were in the original header + list, or were added to this instance, and may contain duplicates. + Any fields deleted and re-inserted are always appended to the header + list. + """ + return self._headers[:] + + def __repr__(self): + return "Headers(%s)" % `self._headers` + + def __str__(self): + """str() returns the formatted headers, complete with end line, + suitable for direct HTTP transmission.""" + return '\r\n'.join(["%s: %s" % kv for kv in self._headers])+'\r\n'*2 + + + + + + + + + + + + + + + def add_header(self, _name, _value, **_params): + """Extended header setting. + + _name is the header field to add. keyword arguments can be used to set + additional parameters for the header field, with underscores converted + to dashes. Normally the parameter will be added as key="value" unless + value is None, in which case only the key will be added. + + Example: + + h.add_header('content-disposition', 'attachment', filename='bud.gif') + + Note that unlike the corresponding 'email.Message' method, this does + *not* handle '(charset, language, value)' tuples: all values must be + strings or None. + """ + parts = [] + if _value is not None: + parts.append(_value) + for k, v in _params.items(): + if v is None: + parts.append(k.replace('_', '-')) + else: + parts.append(_formatparam(k.replace('_', '-'), v)) + self._headers.append((_name, "; ".join(parts))) + + + + + + + + + + + + + + + + diff --git a/src/wsgiref/tests/__init__.py b/src/wsgiref/tests/__init__.py index a4eb4b5..6b61776 100644 --- a/src/wsgiref/tests/__init__.py +++ b/src/wsgiref/tests/__init__.py @@ -1,11 +1,52 @@ from unittest import TestSuite, TestCase, makeSuite +def compare_generic_iter(make_it,match): + """Utility to compare a generic 2.1/2.2+ iterator with an iterable + + If running under Python 2.2+, this tests the iterator using iter()/next(), + as well as __getitem__. 'make_it' must be a function returning a fresh + iterator to be tested (since this may test the iterator twice).""" + + it = make_it() + n = 0 + for item in match: + assert it[n]==item + n+=1 + try: + it[n] + except IndexError: + pass + else: + raise AssertionError("Too many items from __getitem__",it) + + try: + iter, StopIteration + except NameError: + pass + else: + # Only test iter mode under 2.2+ + it = make_it() + assert iter(it) is it + for item in match: + assert it.next()==item + try: + it.next() + except StopIteration: + pass + else: + raise AssertionError("Too many items from .next()",it) + + + + def test_suite(): from wsgiref.tests import test_util + from wsgiref.tests import test_headers tests = [ test_util.test_suite(), + test_headers.test_suite(), ] return TestSuite(tests) @@ -16,3 +57,26 @@ def test_suite(): + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/wsgiref/tests/test_headers.py b/src/wsgiref/tests/test_headers.py new file mode 100644 index 0000000..0ce53fe --- /dev/null +++ b/src/wsgiref/tests/test_headers.py @@ -0,0 +1,90 @@ +from unittest import TestCase, TestSuite, makeSuite +from wsgiref.headers import Headers +from wsgiref.tests import compare_generic_iter + + +class HeaderTests(TestCase): + + def testMappingInterface(self): + test = [('x','y')] + self.assertEqual(len(Headers([])),0) + self.assertEqual(len(Headers(test[:])),1) + self.assertEqual(Headers(test[:]).keys(), ['x']) + self.assertEqual(Headers(test[:]).values(), ['y']) + self.assertEqual(Headers(test[:]).items(), test) + self.failIf(Headers(test).items() is test) # must be copy! + + h=Headers([]) + del h['foo'] # should not raise an error + + h['Foo'] = 'bar' + for m in h.has_key, h.__contains__, h.get, h.get_all, h.__getitem__: + self.failUnless(m('foo')) + self.failUnless(m('Foo')) + self.failUnless(m('FOO')) + self.failIf(m('bar')) + + self.assertEqual(h['foo'],'bar') + h['foo'] = 'baz' + self.assertEqual(h['FOO'],'baz') + self.assertEqual(h.get_all('foo'),['baz']) + + self.assertEqual(h.get("foo","whee"), "baz") + self.assertEqual(h.get("zoo","whee"), "whee") + + def testRequireList(self): + self.assertRaises(TypeError, Headers, "foo") + + + + + + def testExtras(self): + h = Headers([]) + h.add_header('foo','bar',baz="spam") + self.assertEqual(h['foo'], 'bar; baz="spam"') + h.add_header('Foo','bar',cheese=None) + self.assertEqual(h.get_all('foo'), + ['bar; baz="spam"', 'bar; cheese']) + self.assertEqual(str(h), + 'foo: bar; baz="spam"\r\n' + 'Foo: bar; cheese\r\n' + '\r\n' + ) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +TestClasses = ( + HeaderTests, +) + +def test_suite(): + return TestSuite([makeSuite(t,'test') for t in TestClasses]) + + diff --git a/src/wsgiref/tests/test_util.py b/src/wsgiref/tests/test_util.py index b43f28e..91caac6 100644 --- a/src/wsgiref/tests/test_util.py +++ b/src/wsgiref/tests/test_util.py @@ -1,8 +1,8 @@ from unittest import TestCase, TestSuite, makeSuite from wsgiref import util +from wsgiref.tests import compare_generic_iter from StringIO import StringIO - class UtilityTests(TestCase): def checkShift(self,sn_in,pi_in,part,sn_out,pi_out): @@ -13,7 +13,6 @@ class UtilityTests(TestCase): self.assertEqual(env['SCRIPT_NAME'],sn_out) return env - def checkDefault(self, key, value, alt=None): # Check defaulting when empty env = {} @@ -28,17 +27,6 @@ class UtilityTests(TestCase): util.setup_testing_defaults(env) self.failUnless(env[key] is alt) - - - - - - - - - - - def checkCrossDefault(self,key,value,**kw): util.setup_testing_defaults(kw) self.assertEqual(kw[key],value) @@ -52,32 +40,22 @@ class UtilityTests(TestCase): self.assertEqual(util.request_uri(kw,query),uri) def checkFW(self,text,size,match): - sio = StringIO(text) - fw = util.FileWrapper(sio,size) - n = 0 - for item in match: - self.assertEqual(fw[n],item) - n+=1 - self.assertRaises(IndexError, fw.__getitem__, n) - - try: - iter, StopIteration - except NameError: - pass - else: - # Only test iter mode under 2.2+ - sio = StringIO(text) - fw = util.FileWrapper(sio,size) - self.failUnless(iter(fw) is fw) - for item in match: - self.assertEqual(fw.next(),item) - self.assertRaises(StopIteration, fw.next) - self.failIf(sio.closed) - fw.close() - self.failUnless(sio.closed) + def make_it(text=text,size=size): + return util.FileWrapper(StringIO(text),size) + + compare_generic_iter(make_it,match) + + it = make_it() + self.failIf(it.filelike.closed) + + for item in it: + pass + self.failIf(it.filelike.closed) + it.close() + self.failUnless(it.filelike.closed) def testSimpleShifts(self): @@ -87,6 +65,7 @@ class UtilityTests(TestCase): self.checkShift('/a','/x/y', 'x', '/a/x', '/y') self.checkShift('/a','/x/', 'x', '/a/x', '/') + def testNormalizedShifts(self): self.checkShift('/a/b', '/../y', '..', '/a', '/y') self.checkShift('', '/../y', '..', '', '/y') @@ -100,6 +79,7 @@ class UtilityTests(TestCase): self.checkShift('/a/b', '/x//', 'x', '/a/b/x', '/') self.checkShift('/a/b', '/.', None, '/a/b', '') + def testDefaults(self): for key, value in [ ('SERVER_NAME','127.0.0.1'), @@ -119,8 +99,6 @@ class UtilityTests(TestCase): self.checkDefault(key,value) - - def testCrossDefaults(self): self.checkCrossDefault('HTTP_HOST',"foo.bar",SERVER_NAME="foo.bar") self.checkCrossDefault('wsgi.url_scheme',"https",HTTPS="on") @@ -130,6 +108,7 @@ class UtilityTests(TestCase): self.checkCrossDefault('SERVER_PORT',"80",HTTPS="foo") self.checkCrossDefault('SERVER_PORT',"443",HTTPS="on") + def testGuessScheme(self): self.assertEqual(util.guess_scheme({}), "http") self.assertEqual(util.guess_scheme({'HTTPS':"foo"}), "http") @@ -137,6 +116,11 @@ class UtilityTests(TestCase): self.assertEqual(util.guess_scheme({'HTTPS':"yes"}), "https") self.assertEqual(util.guess_scheme({'HTTPS':"1"}), "https") + + + + + def testAppURIs(self): self.checkAppURI("http://127.0.0.1/") self.checkAppURI("http://127.0.0.1/spam", SCRIPT_NAME="/spam") @@ -162,6 +146,22 @@ class UtilityTests(TestCase): def testFileWrapper(self): self.checkFW("xyz"*50, 120, ["xyz"*40,"xyz"*10]) + + + + + + + + + + + + + + + + TestClasses = ( UtilityTests, ) |