summaryrefslogtreecommitdiff
path: root/django/test
diff options
context:
space:
mode:
authorArthur Koziel <arthur@arthurkoziel.com>2010-09-13 00:04:27 +0000
committerArthur Koziel <arthur@arthurkoziel.com>2010-09-13 00:04:27 +0000
commitdd49269c7db008b2567f50cb03c4d3d9b321daa1 (patch)
tree326dd25bb045ac016cda7966b43cbdfe1f67d699 /django/test
parentc9b188c4ec939abbe48dae5a371276742e64b6b8 (diff)
downloaddjango-soc2010/app-loading.tar.gz
[soc2010/app-loading] merged trunkarchive/soc2010/app-loadingsoc2010/app-loading
git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2010/app-loading@13818 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Diffstat (limited to 'django/test')
-rw-r--r--django/test/__init__.py1
-rw-r--r--django/test/client.py31
-rw-r--r--django/test/testcases.py11
-rw-r--r--django/test/utils.py25
4 files changed, 51 insertions, 17 deletions
diff --git a/django/test/__init__.py b/django/test/__init__.py
index 957b293e12..c996ed49d6 100644
--- a/django/test/__init__.py
+++ b/django/test/__init__.py
@@ -4,3 +4,4 @@ Django Unit Test and Doctest framework.
from django.test.client import Client
from django.test.testcases import TestCase, TransactionTestCase
+from django.test.utils import Approximate
diff --git a/django/test/client.py b/django/test/client.py
index e5a16b6e79..08e3ff6b71 100644
--- a/django/test/client.py
+++ b/django/test/client.py
@@ -3,6 +3,7 @@ from urlparse import urlparse, urlunparse, urlsplit
import sys
import os
import re
+import mimetypes
try:
from cStringIO import StringIO
except ImportError:
@@ -54,6 +55,10 @@ class ClientHandler(BaseHandler):
Uses the WSGI interface to compose requests, but returns
the raw HttpResponse object
"""
+ def __init__(self, enforce_csrf_checks=True, *args, **kwargs):
+ self.enforce_csrf_checks = enforce_csrf_checks
+ super(ClientHandler, self).__init__(*args, **kwargs)
+
def __call__(self, environ):
from django.conf import settings
from django.core import signals
@@ -70,7 +75,7 @@ class ClientHandler(BaseHandler):
# CsrfViewMiddleware. This makes life easier, and is probably
# required for backwards compatibility with external tests against
# admin views.
- request._dont_enforce_csrf_checks = True
+ request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
response = self.get_response(request)
# Apply response middleware.
@@ -138,11 +143,14 @@ def encode_multipart(boundary, data):
def encode_file(boundary, key, file):
to_str = lambda s: smart_str(s, settings.DEFAULT_CHARSET)
+ content_type = mimetypes.guess_type(file.name)[0]
+ if content_type is None:
+ content_type = 'application/octet-stream'
return [
'--' + boundary,
'Content-Disposition: form-data; name="%s"; filename="%s"' \
% (to_str(key), to_str(os.path.basename(file.name))),
- 'Content-Type: application/octet-stream',
+ 'Content-Type: %s' % content_type,
'',
file.read()
]
@@ -165,8 +173,8 @@ class Client(object):
contexts and templates produced by a view, rather than the
HTML rendered to the end-user.
"""
- def __init__(self, **defaults):
- self.handler = ClientHandler()
+ def __init__(self, enforce_csrf_checks=False, **defaults):
+ self.handler = ClientHandler(enforce_csrf_checks)
self.defaults = defaults
self.cookies = SimpleCookie()
self.exc_info = None
@@ -289,7 +297,7 @@ class Client(object):
response = self.request(**r)
if follow:
- response = self._handle_redirects(response)
+ response = self._handle_redirects(response, **extra)
return response
def post(self, path, data={}, content_type=MULTIPART_CONTENT,
@@ -321,7 +329,7 @@ class Client(object):
response = self.request(**r)
if follow:
- response = self._handle_redirects(response)
+ response = self._handle_redirects(response, **extra)
return response
def head(self, path, data={}, follow=False, **extra):
@@ -340,7 +348,7 @@ class Client(object):
response = self.request(**r)
if follow:
- response = self._handle_redirects(response)
+ response = self._handle_redirects(response, **extra)
return response
def options(self, path, data={}, follow=False, **extra):
@@ -358,7 +366,7 @@ class Client(object):
response = self.request(**r)
if follow:
- response = self._handle_redirects(response)
+ response = self._handle_redirects(response, **extra)
return response
def put(self, path, data={}, content_type=MULTIPART_CONTENT,
@@ -390,7 +398,7 @@ class Client(object):
response = self.request(**r)
if follow:
- response = self._handle_redirects(response)
+ response = self._handle_redirects(response, **extra)
return response
def delete(self, path, data={}, follow=False, **extra):
@@ -408,7 +416,7 @@ class Client(object):
response = self.request(**r)
if follow:
- response = self._handle_redirects(response)
+ response = self._handle_redirects(response, **extra)
return response
def login(self, **credentials):
@@ -463,7 +471,7 @@ class Client(object):
session.delete(session_key=session_cookie.value)
self.cookies = SimpleCookie()
- def _handle_redirects(self, response):
+ def _handle_redirects(self, response, **extra):
"Follows any redirects by requesting responses from the server using GET."
response.redirect_chain = []
@@ -474,7 +482,6 @@ class Client(object):
redirect_chain = response.redirect_chain
redirect_chain.append((url, response.status_code))
- extra = {}
if scheme:
extra['wsgi.url_scheme'] = scheme
diff --git a/django/test/testcases.py b/django/test/testcases.py
index 2f8acad68c..10bd6c6c9f 100644
--- a/django/test/testcases.py
+++ b/django/test/testcases.py
@@ -347,7 +347,7 @@ class TransactionTestCase(unittest.TestCase):
def assertContains(self, response, text, count=None, status_code=200,
msg_prefix=''):
"""
- Asserts that a response indicates that a page was retrieved
+ Asserts that a response indicates that some content was retrieved
successfully, (i.e., the HTTP status code was as expected), and that
``text`` occurs ``count`` times in the content of the response.
If ``count`` is None, the count doesn't matter - the assertion is true
@@ -357,7 +357,7 @@ class TransactionTestCase(unittest.TestCase):
msg_prefix += ": "
self.assertEqual(response.status_code, status_code,
- msg_prefix + "Couldn't retrieve page: Response code was %d"
+ msg_prefix + "Couldn't retrieve content: Response code was %d"
" (expected %d)" % (response.status_code, status_code))
text = smart_str(text, response._charset)
real_count = response.content.count(text)
@@ -372,7 +372,7 @@ class TransactionTestCase(unittest.TestCase):
def assertNotContains(self, response, text, status_code=200,
msg_prefix=''):
"""
- Asserts that a response indicates that a page was retrieved
+ Asserts that a response indicates that some content was retrieved
successfully, (i.e., the HTTP status code was as expected), and that
``text`` doesn't occurs in the content of the response.
"""
@@ -380,7 +380,7 @@ class TransactionTestCase(unittest.TestCase):
msg_prefix += ": "
self.assertEqual(response.status_code, status_code,
- msg_prefix + "Couldn't retrieve page: Response code was %d"
+ msg_prefix + "Couldn't retrieve content: Response code was %d"
" (expected %d)" % (response.status_code, status_code))
text = smart_str(text, response._charset)
self.assertEqual(response.content.count(text), 0,
@@ -466,6 +466,9 @@ class TransactionTestCase(unittest.TestCase):
msg_prefix + "Template '%s' was used unexpectedly in rendering"
" the response" % template_name)
+ def assertQuerysetEqual(self, qs, values, transform=repr):
+ return self.assertEqual(map(transform, qs), values)
+
def connections_support_transactions():
"""
Returns True if all connections support transactions. This is messy
diff --git a/django/test/utils.py b/django/test/utils.py
index b6ab39901b..8ecb5a0e60 100644
--- a/django/test/utils.py
+++ b/django/test/utils.py
@@ -1,4 +1,6 @@
-import sys, time, os
+import sys
+import time
+import os
from django.conf import settings
from django.core import mail
from django.core.mail.backends import locmem
@@ -6,6 +8,21 @@ from django.test import signals
from django.template import Template
from django.utils.translation import deactivate
+
+class Approximate(object):
+ def __init__(self, val, places=7):
+ self.val = val
+ self.places = places
+
+ def __repr__(self):
+ return repr(self.val)
+
+ def __eq__(self, other):
+ if self.val == other:
+ return True
+ return round(abs(self.val-other), self.places) == 0
+
+
class ContextList(list):
"""A wrapper that provides direct key access to context items contained
in a list of context objects.
@@ -19,6 +36,12 @@ class ContextList(list):
else:
return super(ContextList, self).__getitem__(key)
+ def __contains__(self, key):
+ try:
+ value = self[key]
+ except KeyError:
+ return False
+ return True
def instrumented_test_render(self, context):
"""