diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/backends/tests.py | 107 | ||||
-rw-r--r-- | tests/servers/test_liveserverthread.py | 5 | ||||
-rw-r--r-- | tests/staticfiles_tests/test_liveserver.py | 3 |
3 files changed, 69 insertions, 46 deletions
diff --git a/tests/backends/tests.py b/tests/backends/tests.py index 7e4e665758..6138a3626c 100644 --- a/tests/backends/tests.py +++ b/tests/backends/tests.py @@ -605,21 +605,25 @@ class ThreadTests(TransactionTestCase): connection = connections[DEFAULT_DB_ALIAS] # Allow thread sharing so the connection can be closed by the # main thread. - connection.allow_thread_sharing = True + connection.inc_thread_sharing() connection.cursor() connections_dict[id(connection)] = connection - for x in range(2): - t = threading.Thread(target=runner) - t.start() - t.join() - # Each created connection got different inner connection. - self.assertEqual(len({conn.connection for conn in connections_dict.values()}), 3) - # Finish by closing the connections opened by the other threads (the - # connection opened in the main thread will automatically be closed on - # teardown). - for conn in connections_dict.values(): - if conn is not connection: - conn.close() + try: + for x in range(2): + t = threading.Thread(target=runner) + t.start() + t.join() + # Each created connection got different inner connection. + self.assertEqual(len({conn.connection for conn in connections_dict.values()}), 3) + finally: + # Finish by closing the connections opened by the other threads + # (the connection opened in the main thread will automatically be + # closed on teardown). + for conn in connections_dict.values(): + if conn is not connection: + if conn.allow_thread_sharing: + conn.close() + conn.dec_thread_sharing() def test_connections_thread_local(self): """ @@ -636,19 +640,23 @@ class ThreadTests(TransactionTestCase): for conn in connections.all(): # Allow thread sharing so the connection can be closed by the # main thread. - conn.allow_thread_sharing = True + conn.inc_thread_sharing() connections_dict[id(conn)] = conn - for x in range(2): - t = threading.Thread(target=runner) - t.start() - t.join() - self.assertEqual(len(connections_dict), 6) - # Finish by closing the connections opened by the other threads (the - # connection opened in the main thread will automatically be closed on - # teardown). - for conn in connections_dict.values(): - if conn is not connection: - conn.close() + try: + for x in range(2): + t = threading.Thread(target=runner) + t.start() + t.join() + self.assertEqual(len(connections_dict), 6) + finally: + # Finish by closing the connections opened by the other threads + # (the connection opened in the main thread will automatically be + # closed on teardown). + for conn in connections_dict.values(): + if conn is not connection: + if conn.allow_thread_sharing: + conn.close() + conn.dec_thread_sharing() def test_pass_connection_between_threads(self): """ @@ -668,25 +676,21 @@ class ThreadTests(TransactionTestCase): t.start() t.join() - # Without touching allow_thread_sharing, which should be False by default. - exceptions = [] - do_thread() - # Forbidden! - self.assertIsInstance(exceptions[0], DatabaseError) - - # If explicitly setting allow_thread_sharing to False - connections['default'].allow_thread_sharing = False + # Without touching thread sharing, which should be False by default. exceptions = [] do_thread() # Forbidden! self.assertIsInstance(exceptions[0], DatabaseError) - # If explicitly setting allow_thread_sharing to True - connections['default'].allow_thread_sharing = True - exceptions = [] - do_thread() - # All good - self.assertEqual(exceptions, []) + # After calling inc_thread_sharing() on the connection. + connections['default'].inc_thread_sharing() + try: + exceptions = [] + do_thread() + # All good + self.assertEqual(exceptions, []) + finally: + connections['default'].dec_thread_sharing() def test_closing_non_shared_connections(self): """ @@ -721,16 +725,33 @@ class ThreadTests(TransactionTestCase): except DatabaseError as e: exceptions.add(e) # Enable thread sharing - connections['default'].allow_thread_sharing = True - t2 = threading.Thread(target=runner2, args=[connections['default']]) - t2.start() - t2.join() + connections['default'].inc_thread_sharing() + try: + t2 = threading.Thread(target=runner2, args=[connections['default']]) + t2.start() + t2.join() + finally: + connections['default'].dec_thread_sharing() t1 = threading.Thread(target=runner1) t1.start() t1.join() # No exception was raised self.assertEqual(len(exceptions), 0) + def test_thread_sharing_count(self): + self.assertIs(connection.allow_thread_sharing, False) + connection.inc_thread_sharing() + self.assertIs(connection.allow_thread_sharing, True) + connection.inc_thread_sharing() + self.assertIs(connection.allow_thread_sharing, True) + connection.dec_thread_sharing() + self.assertIs(connection.allow_thread_sharing, True) + connection.dec_thread_sharing() + self.assertIs(connection.allow_thread_sharing, False) + msg = 'Cannot decrement the thread sharing count below zero.' + with self.assertRaisesMessage(RuntimeError, msg): + connection.dec_thread_sharing() + class MySQLPKZeroTests(TestCase): """ diff --git a/tests/servers/test_liveserverthread.py b/tests/servers/test_liveserverthread.py index d39aac8183..9762b53791 100644 --- a/tests/servers/test_liveserverthread.py +++ b/tests/servers/test_liveserverthread.py @@ -18,11 +18,10 @@ class LiveServerThreadTest(TestCase): # Pass a connection to the thread to check they are being closed. connections_override = {DEFAULT_DB_ALIAS: conn} - saved_sharing = conn.allow_thread_sharing + conn.inc_thread_sharing() try: - conn.allow_thread_sharing = True self.assertTrue(conn.is_usable()) self.run_live_server_thread(connections_override) self.assertFalse(conn.is_usable()) finally: - conn.allow_thread_sharing = saved_sharing + conn.dec_thread_sharing() diff --git a/tests/staticfiles_tests/test_liveserver.py b/tests/staticfiles_tests/test_liveserver.py index 264242bbae..820fa5bc89 100644 --- a/tests/staticfiles_tests/test_liveserver.py +++ b/tests/staticfiles_tests/test_liveserver.py @@ -64,6 +64,9 @@ class StaticLiveServerChecks(LiveServerBase): # app without having set the required STATIC_URL setting.") pass finally: + # Use del to avoid decrementing the database thread sharing count a + # second time. + del cls.server_thread super().tearDownClass() def test_test_test(self): |