summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/backends/tests.py107
-rw-r--r--tests/servers/test_liveserverthread.py5
-rw-r--r--tests/staticfiles_tests/test_liveserver.py3
3 files changed, 69 insertions, 46 deletions
diff --git a/tests/backends/tests.py b/tests/backends/tests.py
index 7e4e665758..6138a3626c 100644
--- a/tests/backends/tests.py
+++ b/tests/backends/tests.py
@@ -605,21 +605,25 @@ class ThreadTests(TransactionTestCase):
connection = connections[DEFAULT_DB_ALIAS]
# Allow thread sharing so the connection can be closed by the
# main thread.
- connection.allow_thread_sharing = True
+ connection.inc_thread_sharing()
connection.cursor()
connections_dict[id(connection)] = connection
- for x in range(2):
- t = threading.Thread(target=runner)
- t.start()
- t.join()
- # Each created connection got different inner connection.
- self.assertEqual(len({conn.connection for conn in connections_dict.values()}), 3)
- # Finish by closing the connections opened by the other threads (the
- # connection opened in the main thread will automatically be closed on
- # teardown).
- for conn in connections_dict.values():
- if conn is not connection:
- conn.close()
+ try:
+ for x in range(2):
+ t = threading.Thread(target=runner)
+ t.start()
+ t.join()
+ # Each created connection got different inner connection.
+ self.assertEqual(len({conn.connection for conn in connections_dict.values()}), 3)
+ finally:
+ # Finish by closing the connections opened by the other threads
+ # (the connection opened in the main thread will automatically be
+ # closed on teardown).
+ for conn in connections_dict.values():
+ if conn is not connection:
+ if conn.allow_thread_sharing:
+ conn.close()
+ conn.dec_thread_sharing()
def test_connections_thread_local(self):
"""
@@ -636,19 +640,23 @@ class ThreadTests(TransactionTestCase):
for conn in connections.all():
# Allow thread sharing so the connection can be closed by the
# main thread.
- conn.allow_thread_sharing = True
+ conn.inc_thread_sharing()
connections_dict[id(conn)] = conn
- for x in range(2):
- t = threading.Thread(target=runner)
- t.start()
- t.join()
- self.assertEqual(len(connections_dict), 6)
- # Finish by closing the connections opened by the other threads (the
- # connection opened in the main thread will automatically be closed on
- # teardown).
- for conn in connections_dict.values():
- if conn is not connection:
- conn.close()
+ try:
+ for x in range(2):
+ t = threading.Thread(target=runner)
+ t.start()
+ t.join()
+ self.assertEqual(len(connections_dict), 6)
+ finally:
+ # Finish by closing the connections opened by the other threads
+ # (the connection opened in the main thread will automatically be
+ # closed on teardown).
+ for conn in connections_dict.values():
+ if conn is not connection:
+ if conn.allow_thread_sharing:
+ conn.close()
+ conn.dec_thread_sharing()
def test_pass_connection_between_threads(self):
"""
@@ -668,25 +676,21 @@ class ThreadTests(TransactionTestCase):
t.start()
t.join()
- # Without touching allow_thread_sharing, which should be False by default.
- exceptions = []
- do_thread()
- # Forbidden!
- self.assertIsInstance(exceptions[0], DatabaseError)
-
- # If explicitly setting allow_thread_sharing to False
- connections['default'].allow_thread_sharing = False
+ # Without touching thread sharing, which should be False by default.
exceptions = []
do_thread()
# Forbidden!
self.assertIsInstance(exceptions[0], DatabaseError)
- # If explicitly setting allow_thread_sharing to True
- connections['default'].allow_thread_sharing = True
- exceptions = []
- do_thread()
- # All good
- self.assertEqual(exceptions, [])
+ # After calling inc_thread_sharing() on the connection.
+ connections['default'].inc_thread_sharing()
+ try:
+ exceptions = []
+ do_thread()
+ # All good
+ self.assertEqual(exceptions, [])
+ finally:
+ connections['default'].dec_thread_sharing()
def test_closing_non_shared_connections(self):
"""
@@ -721,16 +725,33 @@ class ThreadTests(TransactionTestCase):
except DatabaseError as e:
exceptions.add(e)
# Enable thread sharing
- connections['default'].allow_thread_sharing = True
- t2 = threading.Thread(target=runner2, args=[connections['default']])
- t2.start()
- t2.join()
+ connections['default'].inc_thread_sharing()
+ try:
+ t2 = threading.Thread(target=runner2, args=[connections['default']])
+ t2.start()
+ t2.join()
+ finally:
+ connections['default'].dec_thread_sharing()
t1 = threading.Thread(target=runner1)
t1.start()
t1.join()
# No exception was raised
self.assertEqual(len(exceptions), 0)
+ def test_thread_sharing_count(self):
+ self.assertIs(connection.allow_thread_sharing, False)
+ connection.inc_thread_sharing()
+ self.assertIs(connection.allow_thread_sharing, True)
+ connection.inc_thread_sharing()
+ self.assertIs(connection.allow_thread_sharing, True)
+ connection.dec_thread_sharing()
+ self.assertIs(connection.allow_thread_sharing, True)
+ connection.dec_thread_sharing()
+ self.assertIs(connection.allow_thread_sharing, False)
+ msg = 'Cannot decrement the thread sharing count below zero.'
+ with self.assertRaisesMessage(RuntimeError, msg):
+ connection.dec_thread_sharing()
+
class MySQLPKZeroTests(TestCase):
"""
diff --git a/tests/servers/test_liveserverthread.py b/tests/servers/test_liveserverthread.py
index d39aac8183..9762b53791 100644
--- a/tests/servers/test_liveserverthread.py
+++ b/tests/servers/test_liveserverthread.py
@@ -18,11 +18,10 @@ class LiveServerThreadTest(TestCase):
# Pass a connection to the thread to check they are being closed.
connections_override = {DEFAULT_DB_ALIAS: conn}
- saved_sharing = conn.allow_thread_sharing
+ conn.inc_thread_sharing()
try:
- conn.allow_thread_sharing = True
self.assertTrue(conn.is_usable())
self.run_live_server_thread(connections_override)
self.assertFalse(conn.is_usable())
finally:
- conn.allow_thread_sharing = saved_sharing
+ conn.dec_thread_sharing()
diff --git a/tests/staticfiles_tests/test_liveserver.py b/tests/staticfiles_tests/test_liveserver.py
index 264242bbae..820fa5bc89 100644
--- a/tests/staticfiles_tests/test_liveserver.py
+++ b/tests/staticfiles_tests/test_liveserver.py
@@ -64,6 +64,9 @@ class StaticLiveServerChecks(LiveServerBase):
# app without having set the required STATIC_URL setting.")
pass
finally:
+ # Use del to avoid decrementing the database thread sharing count a
+ # second time.
+ del cls.server_thread
super().tearDownClass()
def test_test_test(self):