diff options
author | Arthur Koziel <arthur@arthurkoziel.com> | 2010-09-13 00:04:27 +0000 |
---|---|---|
committer | Arthur Koziel <arthur@arthurkoziel.com> | 2010-09-13 00:04:27 +0000 |
commit | dd49269c7db008b2567f50cb03c4d3d9b321daa1 (patch) | |
tree | 326dd25bb045ac016cda7966b43cbdfe1f67d699 /django/test | |
parent | c9b188c4ec939abbe48dae5a371276742e64b6b8 (diff) | |
download | django-soc2010/app-loading.tar.gz |
[soc2010/app-loading] merged trunkarchive/soc2010/app-loadingsoc2010/app-loading
git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2010/app-loading@13818 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Diffstat (limited to 'django/test')
-rw-r--r-- | django/test/__init__.py | 1 | ||||
-rw-r--r-- | django/test/client.py | 31 | ||||
-rw-r--r-- | django/test/testcases.py | 11 | ||||
-rw-r--r-- | django/test/utils.py | 25 |
4 files changed, 51 insertions, 17 deletions
diff --git a/django/test/__init__.py b/django/test/__init__.py index 957b293e12..c996ed49d6 100644 --- a/django/test/__init__.py +++ b/django/test/__init__.py @@ -4,3 +4,4 @@ Django Unit Test and Doctest framework. from django.test.client import Client from django.test.testcases import TestCase, TransactionTestCase +from django.test.utils import Approximate diff --git a/django/test/client.py b/django/test/client.py index e5a16b6e79..08e3ff6b71 100644 --- a/django/test/client.py +++ b/django/test/client.py @@ -3,6 +3,7 @@ from urlparse import urlparse, urlunparse, urlsplit import sys import os import re +import mimetypes try: from cStringIO import StringIO except ImportError: @@ -54,6 +55,10 @@ class ClientHandler(BaseHandler): Uses the WSGI interface to compose requests, but returns the raw HttpResponse object """ + def __init__(self, enforce_csrf_checks=True, *args, **kwargs): + self.enforce_csrf_checks = enforce_csrf_checks + super(ClientHandler, self).__init__(*args, **kwargs) + def __call__(self, environ): from django.conf import settings from django.core import signals @@ -70,7 +75,7 @@ class ClientHandler(BaseHandler): # CsrfViewMiddleware. This makes life easier, and is probably # required for backwards compatibility with external tests against # admin views. - request._dont_enforce_csrf_checks = True + request._dont_enforce_csrf_checks = not self.enforce_csrf_checks response = self.get_response(request) # Apply response middleware. @@ -138,11 +143,14 @@ def encode_multipart(boundary, data): def encode_file(boundary, key, file): to_str = lambda s: smart_str(s, settings.DEFAULT_CHARSET) + content_type = mimetypes.guess_type(file.name)[0] + if content_type is None: + content_type = 'application/octet-stream' return [ '--' + boundary, 'Content-Disposition: form-data; name="%s"; filename="%s"' \ % (to_str(key), to_str(os.path.basename(file.name))), - 'Content-Type: application/octet-stream', + 'Content-Type: %s' % content_type, '', file.read() ] @@ -165,8 +173,8 @@ class Client(object): contexts and templates produced by a view, rather than the HTML rendered to the end-user. """ - def __init__(self, **defaults): - self.handler = ClientHandler() + def __init__(self, enforce_csrf_checks=False, **defaults): + self.handler = ClientHandler(enforce_csrf_checks) self.defaults = defaults self.cookies = SimpleCookie() self.exc_info = None @@ -289,7 +297,7 @@ class Client(object): response = self.request(**r) if follow: - response = self._handle_redirects(response) + response = self._handle_redirects(response, **extra) return response def post(self, path, data={}, content_type=MULTIPART_CONTENT, @@ -321,7 +329,7 @@ class Client(object): response = self.request(**r) if follow: - response = self._handle_redirects(response) + response = self._handle_redirects(response, **extra) return response def head(self, path, data={}, follow=False, **extra): @@ -340,7 +348,7 @@ class Client(object): response = self.request(**r) if follow: - response = self._handle_redirects(response) + response = self._handle_redirects(response, **extra) return response def options(self, path, data={}, follow=False, **extra): @@ -358,7 +366,7 @@ class Client(object): response = self.request(**r) if follow: - response = self._handle_redirects(response) + response = self._handle_redirects(response, **extra) return response def put(self, path, data={}, content_type=MULTIPART_CONTENT, @@ -390,7 +398,7 @@ class Client(object): response = self.request(**r) if follow: - response = self._handle_redirects(response) + response = self._handle_redirects(response, **extra) return response def delete(self, path, data={}, follow=False, **extra): @@ -408,7 +416,7 @@ class Client(object): response = self.request(**r) if follow: - response = self._handle_redirects(response) + response = self._handle_redirects(response, **extra) return response def login(self, **credentials): @@ -463,7 +471,7 @@ class Client(object): session.delete(session_key=session_cookie.value) self.cookies = SimpleCookie() - def _handle_redirects(self, response): + def _handle_redirects(self, response, **extra): "Follows any redirects by requesting responses from the server using GET." response.redirect_chain = [] @@ -474,7 +482,6 @@ class Client(object): redirect_chain = response.redirect_chain redirect_chain.append((url, response.status_code)) - extra = {} if scheme: extra['wsgi.url_scheme'] = scheme diff --git a/django/test/testcases.py b/django/test/testcases.py index 2f8acad68c..10bd6c6c9f 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -347,7 +347,7 @@ class TransactionTestCase(unittest.TestCase): def assertContains(self, response, text, count=None, status_code=200, msg_prefix=''): """ - Asserts that a response indicates that a page was retrieved + Asserts that a response indicates that some content was retrieved successfully, (i.e., the HTTP status code was as expected), and that ``text`` occurs ``count`` times in the content of the response. If ``count`` is None, the count doesn't matter - the assertion is true @@ -357,7 +357,7 @@ class TransactionTestCase(unittest.TestCase): msg_prefix += ": " self.assertEqual(response.status_code, status_code, - msg_prefix + "Couldn't retrieve page: Response code was %d" + msg_prefix + "Couldn't retrieve content: Response code was %d" " (expected %d)" % (response.status_code, status_code)) text = smart_str(text, response._charset) real_count = response.content.count(text) @@ -372,7 +372,7 @@ class TransactionTestCase(unittest.TestCase): def assertNotContains(self, response, text, status_code=200, msg_prefix=''): """ - Asserts that a response indicates that a page was retrieved + Asserts that a response indicates that some content was retrieved successfully, (i.e., the HTTP status code was as expected), and that ``text`` doesn't occurs in the content of the response. """ @@ -380,7 +380,7 @@ class TransactionTestCase(unittest.TestCase): msg_prefix += ": " self.assertEqual(response.status_code, status_code, - msg_prefix + "Couldn't retrieve page: Response code was %d" + msg_prefix + "Couldn't retrieve content: Response code was %d" " (expected %d)" % (response.status_code, status_code)) text = smart_str(text, response._charset) self.assertEqual(response.content.count(text), 0, @@ -466,6 +466,9 @@ class TransactionTestCase(unittest.TestCase): msg_prefix + "Template '%s' was used unexpectedly in rendering" " the response" % template_name) + def assertQuerysetEqual(self, qs, values, transform=repr): + return self.assertEqual(map(transform, qs), values) + def connections_support_transactions(): """ Returns True if all connections support transactions. This is messy diff --git a/django/test/utils.py b/django/test/utils.py index b6ab39901b..8ecb5a0e60 100644 --- a/django/test/utils.py +++ b/django/test/utils.py @@ -1,4 +1,6 @@ -import sys, time, os +import sys +import time +import os from django.conf import settings from django.core import mail from django.core.mail.backends import locmem @@ -6,6 +8,21 @@ from django.test import signals from django.template import Template from django.utils.translation import deactivate + +class Approximate(object): + def __init__(self, val, places=7): + self.val = val + self.places = places + + def __repr__(self): + return repr(self.val) + + def __eq__(self, other): + if self.val == other: + return True + return round(abs(self.val-other), self.places) == 0 + + class ContextList(list): """A wrapper that provides direct key access to context items contained in a list of context objects. @@ -19,6 +36,12 @@ class ContextList(list): else: return super(ContextList, self).__getitem__(key) + def __contains__(self, key): + try: + value = self[key] + except KeyError: + return False + return True def instrumented_test_render(self, context): """ |