diff options
author | django-bot <ops@djangoproject.com> | 2022-02-03 20:24:19 +0100 |
---|---|---|
committer | Mariusz Felisiak <felisiak.mariusz@gmail.com> | 2022-02-07 20:37:05 +0100 |
commit | 9c19aff7c7561e3a82978a272ecdaad40dda5c00 (patch) | |
tree | f0506b668a013d0063e5fba3dbf4863b466713ba /tests/backends | |
parent | f68fa8b45dfac545cfc4111d4e52804c86db68d3 (diff) | |
download | django-9c19aff7c7561e3a82978a272ecdaad40dda5c00.tar.gz |
Refs #33476 -- Reformatted code with Black.
Diffstat (limited to 'tests/backends')
34 files changed, 1202 insertions, 831 deletions
diff --git a/tests/backends/base/app_unmigrated/migrations/0001_initial.py b/tests/backends/base/app_unmigrated/migrations/0001_initial.py index 2481756a5c..38ecc948e4 100644 --- a/tests/backends/base/app_unmigrated/migrations/0001_initial.py +++ b/tests/backends/base/app_unmigrated/migrations/0001_initial.py @@ -8,10 +8,18 @@ class Migration(migrations.Migration): operations = [ migrations.CreateModel( - name='Foo', + name="Foo", fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), - ('name', models.CharField(max_length=255)), + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=255)), ], ), ] diff --git a/tests/backends/base/app_unmigrated/models.py b/tests/backends/base/app_unmigrated/models.py index 0c1f64f61d..4bb6074a84 100644 --- a/tests/backends/base/app_unmigrated/models.py +++ b/tests/backends/base/app_unmigrated/models.py @@ -5,4 +5,4 @@ class Foo(models.Model): name = models.CharField(max_length=255) class Meta: - app_label = 'app_unmigrated' + app_label = "app_unmigrated" diff --git a/tests/backends/base/test_base.py b/tests/backends/base/test_base.py index b2ebea5829..00ef766d5d 100644 --- a/tests/backends/base/test_base.py +++ b/tests/backends/base/test_base.py @@ -8,7 +8,6 @@ from ..models import Square class DatabaseWrapperTests(SimpleTestCase): - def test_repr(self): conn = connections[DEFAULT_DB_ALIAS] self.assertEqual( @@ -25,12 +24,12 @@ class DatabaseWrapperTests(SimpleTestCase): conn = connections[DEFAULT_DB_ALIAS] conn_class = type(conn) attr_names = [ - ('client_class', 'client'), - ('creation_class', 'creation'), - ('features_class', 'features'), - ('introspection_class', 'introspection'), - ('ops_class', 'ops'), - ('validation_class', 'validation'), + ("client_class", "client"), + ("creation_class", "creation"), + ("features_class", "features"), + ("introspection_class", "introspection"), + ("ops_class", "ops"), + ("validation_class", "validation"), ] for class_attr_name, instance_attr_name in attr_names: class_attr_value = getattr(conn_class, class_attr_name) @@ -39,23 +38,22 @@ class DatabaseWrapperTests(SimpleTestCase): self.assertIsInstance(instance_attr_value, class_attr_value) def test_initialization_display_name(self): - self.assertEqual(BaseDatabaseWrapper.display_name, 'unknown') - self.assertNotEqual(connection.display_name, 'unknown') + self.assertEqual(BaseDatabaseWrapper.display_name, "unknown") + self.assertNotEqual(connection.display_name, "unknown") class ExecuteWrapperTests(TestCase): - @staticmethod def call_execute(connection, params=None): - ret_val = '1' if params is None else '%s' - sql = 'SELECT ' + ret_val + connection.features.bare_select_suffix + ret_val = "1" if params is None else "%s" + sql = "SELECT " + ret_val + connection.features.bare_select_suffix with connection.cursor() as cursor: cursor.execute(sql, params) def call_executemany(self, connection, params=None): # executemany() must use an update query. Make sure it does nothing # by putting a false condition in the WHERE clause. - sql = 'DELETE FROM {} WHERE 0=1 AND 0=%s'.format(Square._meta.db_table) + sql = "DELETE FROM {} WHERE 0=1 AND 0=%s".format(Square._meta.db_table) if params is None: params = [(i,) for i in range(3)] with connection.cursor() as cursor: @@ -71,10 +69,10 @@ class ExecuteWrapperTests(TestCase): self.call_execute(connection) self.assertTrue(wrapper.called) (_, sql, params, many, context), _ = wrapper.call_args - self.assertIn('SELECT', sql) + self.assertIn("SELECT", sql) self.assertIsNone(params) self.assertIs(many, False) - self.assertEqual(context['connection'], connection) + self.assertEqual(context["connection"], connection) def test_wrapper_invoked_many(self): wrapper = self.mock_wrapper() @@ -82,16 +80,16 @@ class ExecuteWrapperTests(TestCase): self.call_executemany(connection) self.assertTrue(wrapper.called) (_, sql, param_list, many, context), _ = wrapper.call_args - self.assertIn('DELETE', sql) + self.assertIn("DELETE", sql) self.assertIsInstance(param_list, (list, tuple)) self.assertIs(many, True) - self.assertEqual(context['connection'], connection) + self.assertEqual(context["connection"], connection) def test_database_queried(self): wrapper = self.mock_wrapper() with connection.execute_wrapper(wrapper): with connection.cursor() as cursor: - sql = 'SELECT 17' + connection.features.bare_select_suffix + sql = "SELECT 17" + connection.features.bare_select_suffix cursor.execute(sql) seventeen = cursor.fetchall() self.assertEqual(list(seventeen), [(17,)]) @@ -100,7 +98,9 @@ class ExecuteWrapperTests(TestCase): def test_nested_wrapper_invoked(self): outer_wrapper = self.mock_wrapper() inner_wrapper = self.mock_wrapper() - with connection.execute_wrapper(outer_wrapper), connection.execute_wrapper(inner_wrapper): + with connection.execute_wrapper(outer_wrapper), connection.execute_wrapper( + inner_wrapper + ): self.call_execute(connection) self.assertEqual(inner_wrapper.call_count, 1) self.call_executemany(connection) @@ -109,9 +109,12 @@ class ExecuteWrapperTests(TestCase): def test_outer_wrapper_blocks(self): def blocker(*args): pass + wrapper = self.mock_wrapper() c = connection # This alias shortens the next line. - with c.execute_wrapper(wrapper), c.execute_wrapper(blocker), c.execute_wrapper(wrapper): + with c.execute_wrapper(wrapper), c.execute_wrapper(blocker), c.execute_wrapper( + wrapper + ): with c.cursor() as cursor: cursor.execute("The database never sees this") self.assertEqual(wrapper.call_count, 1) @@ -128,16 +131,16 @@ class ExecuteWrapperTests(TestCase): def test_wrapper_connection_specific(self): wrapper = self.mock_wrapper() - with connections['other'].execute_wrapper(wrapper): - self.assertEqual(connections['other'].execute_wrappers, [wrapper]) + with connections["other"].execute_wrapper(wrapper): + self.assertEqual(connections["other"].execute_wrappers, [wrapper]) self.call_execute(connection) self.assertFalse(wrapper.called) self.assertEqual(connection.execute_wrappers, []) - self.assertEqual(connections['other'].execute_wrappers, []) + self.assertEqual(connections["other"].execute_wrappers, []) class ConnectionHealthChecksTests(SimpleTestCase): - databases = {'default'} + databases = {"default"} def setUp(self): # All test cases here need newly configured and created connections. @@ -146,25 +149,28 @@ class ConnectionHealthChecksTests(SimpleTestCase): self.addCleanup(connection.close) def patch_settings_dict(self, conn_health_checks): - self.settings_dict_patcher = patch.dict(connection.settings_dict, { - **connection.settings_dict, - 'CONN_MAX_AGE': None, - 'CONN_HEALTH_CHECKS': conn_health_checks, - }) + self.settings_dict_patcher = patch.dict( + connection.settings_dict, + { + **connection.settings_dict, + "CONN_MAX_AGE": None, + "CONN_HEALTH_CHECKS": conn_health_checks, + }, + ) self.settings_dict_patcher.start() self.addCleanup(self.settings_dict_patcher.stop) def run_query(self): with connection.cursor() as cursor: - cursor.execute('SELECT 42' + connection.features.bare_select_suffix) + cursor.execute("SELECT 42" + connection.features.bare_select_suffix) - @skipUnlessDBFeature('test_db_allows_multiple_connections') + @skipUnlessDBFeature("test_db_allows_multiple_connections") def test_health_checks_enabled(self): self.patch_settings_dict(conn_health_checks=True) self.assertIsNone(connection.connection) # Newly created connections are considered healthy without performing # the health check. - with patch.object(connection, 'is_usable', side_effect=AssertionError): + with patch.object(connection, "is_usable", side_effect=AssertionError): self.run_query() old_connection = connection.connection @@ -173,7 +179,9 @@ class ConnectionHealthChecksTests(SimpleTestCase): self.assertIs(old_connection, connection.connection) # Simulate connection health check failing. - with patch.object(connection, 'is_usable', return_value=False) as mocked_is_usable: + with patch.object( + connection, "is_usable", return_value=False + ) as mocked_is_usable: self.run_query() new_connection = connection.connection # A new connection is established. @@ -194,13 +202,13 @@ class ConnectionHealthChecksTests(SimpleTestCase): self.run_query() self.assertIs(new_connection, connection.connection) - @skipUnlessDBFeature('test_db_allows_multiple_connections') + @skipUnlessDBFeature("test_db_allows_multiple_connections") def test_health_checks_enabled_errors_occurred(self): self.patch_settings_dict(conn_health_checks=True) self.assertIsNone(connection.connection) # Newly created connections are considered healthy without performing # the health check. - with patch.object(connection, 'is_usable', side_effect=AssertionError): + with patch.object(connection, "is_usable", side_effect=AssertionError): self.run_query() old_connection = connection.connection @@ -213,16 +221,16 @@ class ConnectionHealthChecksTests(SimpleTestCase): # No additional health checks after the one in # close_if_unusable_or_obsolete() are executed during this "request" # when running queries. - with patch.object(connection, 'is_usable', side_effect=AssertionError): + with patch.object(connection, "is_usable", side_effect=AssertionError): self.run_query() - @skipUnlessDBFeature('test_db_allows_multiple_connections') + @skipUnlessDBFeature("test_db_allows_multiple_connections") def test_health_checks_disabled(self): self.patch_settings_dict(conn_health_checks=False) self.assertIsNone(connection.connection) # Newly created connections are considered healthy without performing # the health check. - with patch.object(connection, 'is_usable', side_effect=AssertionError): + with patch.object(connection, "is_usable", side_effect=AssertionError): self.run_query() old_connection = connection.connection @@ -231,7 +239,7 @@ class ConnectionHealthChecksTests(SimpleTestCase): # Persistent connections are enabled (connection is not). self.assertIs(old_connection, connection.connection) # Health checks are not performed. - with patch.object(connection, 'is_usable', side_effect=AssertionError): + with patch.object(connection, "is_usable", side_effect=AssertionError): self.run_query() # Health check wasn't performed and the connection is unchanged. self.assertIs(old_connection, connection.connection) @@ -240,13 +248,13 @@ class ConnectionHealthChecksTests(SimpleTestCase): # the current "request". self.assertIs(old_connection, connection.connection) - @skipUnlessDBFeature('test_db_allows_multiple_connections') + @skipUnlessDBFeature("test_db_allows_multiple_connections") def test_set_autocommit_health_checks_enabled(self): self.patch_settings_dict(conn_health_checks=True) self.assertIsNone(connection.connection) # Newly created connections are considered healthy without performing # the health check. - with patch.object(connection, 'is_usable', side_effect=AssertionError): + with patch.object(connection, "is_usable", side_effect=AssertionError): # Simulate outermost atomic block: changing autocommit for # a connection. connection.set_autocommit(False) @@ -261,7 +269,9 @@ class ConnectionHealthChecksTests(SimpleTestCase): self.assertIs(old_connection, connection.connection) # Simulate connection health check failing. - with patch.object(connection, 'is_usable', return_value=False) as mocked_is_usable: + with patch.object( + connection, "is_usable", return_value=False + ) as mocked_is_usable: # Simulate outermost atomic block: changing autocommit for # a connection. connection.set_autocommit(False) diff --git a/tests/backends/base/test_client.py b/tests/backends/base/test_client.py index d9e5cc8883..7ed0604c32 100644 --- a/tests/backends/base/test_client.py +++ b/tests/backends/base/test_client.py @@ -11,8 +11,8 @@ class SimpleDatabaseClientTests(SimpleTestCase): def test_settings_to_cmd_args_env(self): msg = ( - 'subclasses of BaseDatabaseClient must provide a ' - 'settings_to_cmd_args_env() method or override a runshell().' + "subclasses of BaseDatabaseClient must provide a " + "settings_to_cmd_args_env() method or override a runshell()." ) with self.assertRaisesMessage(NotImplementedError, msg): self.client.settings_to_cmd_args_env(None, None) @@ -20,10 +20,10 @@ class SimpleDatabaseClientTests(SimpleTestCase): def test_runshell_use_environ(self): for env in [None, {}]: with self.subTest(env=env): - with mock.patch('subprocess.run') as run: + with mock.patch("subprocess.run") as run: with mock.patch.object( BaseDatabaseClient, - 'settings_to_cmd_args_env', + "settings_to_cmd_args_env", return_value=([], env), ): self.client.runshell(None) diff --git a/tests/backends/base/test_creation.py b/tests/backends/base/test_creation.py index 825fb872ed..9593e13462 100644 --- a/tests/backends/base/test_creation.py +++ b/tests/backends/base/test_creation.py @@ -4,14 +4,16 @@ import os from unittest import mock from django.db import DEFAULT_DB_ALIAS, connection, connections -from django.db.backends.base.creation import ( - TEST_DATABASE_PREFIX, BaseDatabaseCreation, -) +from django.db.backends.base.creation import TEST_DATABASE_PREFIX, BaseDatabaseCreation from django.test import SimpleTestCase, TransactionTestCase from django.test.utils import override_settings from ..models import ( - CircularA, CircularB, Object, ObjectReference, ObjectSelfReference, + CircularA, + CircularB, + Object, + ObjectReference, + ObjectSelfReference, SchoolClass, ) @@ -29,135 +31,146 @@ def get_connection_copy(): class TestDbSignatureTests(SimpleTestCase): def test_default_name(self): # A test db name isn't set. - prod_name = 'hodor' + prod_name = "hodor" test_connection = get_connection_copy() - test_connection.settings_dict['NAME'] = prod_name - test_connection.settings_dict['TEST'] = {'NAME': None} + test_connection.settings_dict["NAME"] = prod_name + test_connection.settings_dict["TEST"] = {"NAME": None} signature = BaseDatabaseCreation(test_connection).test_db_signature() self.assertEqual(signature[3], TEST_DATABASE_PREFIX + prod_name) def test_custom_test_name(self): # A regular test db name is set. - test_name = 'hodor' + test_name = "hodor" test_connection = get_connection_copy() - test_connection.settings_dict['TEST'] = {'NAME': test_name} + test_connection.settings_dict["TEST"] = {"NAME": test_name} signature = BaseDatabaseCreation(test_connection).test_db_signature() self.assertEqual(signature[3], test_name) def test_custom_test_name_with_test_prefix(self): # A test db name prefixed with TEST_DATABASE_PREFIX is set. - test_name = TEST_DATABASE_PREFIX + 'hodor' + test_name = TEST_DATABASE_PREFIX + "hodor" test_connection = get_connection_copy() - test_connection.settings_dict['TEST'] = {'NAME': test_name} + test_connection.settings_dict["TEST"] = {"NAME": test_name} signature = BaseDatabaseCreation(test_connection).test_db_signature() self.assertEqual(signature[3], test_name) -@override_settings(INSTALLED_APPS=['backends.base.app_unmigrated']) -@mock.patch.object(connection, 'ensure_connection') -@mock.patch.object(connection, 'prepare_database') -@mock.patch('django.db.migrations.recorder.MigrationRecorder.has_table', return_value=False) -@mock.patch('django.core.management.commands.migrate.Command.sync_apps') +@override_settings(INSTALLED_APPS=["backends.base.app_unmigrated"]) +@mock.patch.object(connection, "ensure_connection") +@mock.patch.object(connection, "prepare_database") +@mock.patch( + "django.db.migrations.recorder.MigrationRecorder.has_table", return_value=False +) +@mock.patch("django.core.management.commands.migrate.Command.sync_apps") class TestDbCreationTests(SimpleTestCase): - available_apps = ['backends.base.app_unmigrated'] + available_apps = ["backends.base.app_unmigrated"] - @mock.patch('django.db.migrations.executor.MigrationExecutor.migrate') - def test_migrate_test_setting_false(self, mocked_migrate, mocked_sync_apps, *mocked_objects): + @mock.patch("django.db.migrations.executor.MigrationExecutor.migrate") + def test_migrate_test_setting_false( + self, mocked_migrate, mocked_sync_apps, *mocked_objects + ): test_connection = get_connection_copy() - test_connection.settings_dict['TEST']['MIGRATE'] = False + test_connection.settings_dict["TEST"]["MIGRATE"] = False creation = test_connection.creation_class(test_connection) - if connection.vendor == 'oracle': + if connection.vendor == "oracle": # Don't close connection on Oracle. creation.connection.close = mock.Mock() - old_database_name = test_connection.settings_dict['NAME'] + old_database_name = test_connection.settings_dict["NAME"] try: - with mock.patch.object(creation, '_create_test_db'): + with mock.patch.object(creation, "_create_test_db"): creation.create_test_db(verbosity=0, autoclobber=True, serialize=False) # Migrations don't run. mocked_migrate.assert_called() args, kwargs = mocked_migrate.call_args self.assertEqual(args, ([],)) - self.assertEqual(kwargs['plan'], []) + self.assertEqual(kwargs["plan"], []) # App is synced. mocked_sync_apps.assert_called() mocked_args, _ = mocked_sync_apps.call_args - self.assertEqual(mocked_args[1], {'app_unmigrated'}) + self.assertEqual(mocked_args[1], {"app_unmigrated"}) finally: - with mock.patch.object(creation, '_destroy_test_db'): + with mock.patch.object(creation, "_destroy_test_db"): creation.destroy_test_db(old_database_name, verbosity=0) - @mock.patch('django.db.migrations.executor.MigrationRecorder.ensure_schema') + @mock.patch("django.db.migrations.executor.MigrationRecorder.ensure_schema") def test_migrate_test_setting_false_ensure_schema( - self, mocked_ensure_schema, mocked_sync_apps, *mocked_objects, + self, + mocked_ensure_schema, + mocked_sync_apps, + *mocked_objects, ): test_connection = get_connection_copy() - test_connection.settings_dict['TEST']['MIGRATE'] = False + test_connection.settings_dict["TEST"]["MIGRATE"] = False creation = test_connection.creation_class(test_connection) - if connection.vendor == 'oracle': + if connection.vendor == "oracle": # Don't close connection on Oracle. creation.connection.close = mock.Mock() - old_database_name = test_connection.settings_dict['NAME'] + old_database_name = test_connection.settings_dict["NAME"] try: - with mock.patch.object(creation, '_create_test_db'): + with mock.patch.object(creation, "_create_test_db"): creation.create_test_db(verbosity=0, autoclobber=True, serialize=False) # The django_migrations table is not created. mocked_ensure_schema.assert_not_called() # App is synced. mocked_sync_apps.assert_called() mocked_args, _ = mocked_sync_apps.call_args - self.assertEqual(mocked_args[1], {'app_unmigrated'}) + self.assertEqual(mocked_args[1], {"app_unmigrated"}) finally: - with mock.patch.object(creation, '_destroy_test_db'): + with mock.patch.object(creation, "_destroy_test_db"): creation.destroy_test_db(old_database_name, verbosity=0) - @mock.patch('django.db.migrations.executor.MigrationExecutor.migrate') - def test_migrate_test_setting_true(self, mocked_migrate, mocked_sync_apps, *mocked_objects): + @mock.patch("django.db.migrations.executor.MigrationExecutor.migrate") + def test_migrate_test_setting_true( + self, mocked_migrate, mocked_sync_apps, *mocked_objects + ): test_connection = get_connection_copy() - test_connection.settings_dict['TEST']['MIGRATE'] = True + test_connection.settings_dict["TEST"]["MIGRATE"] = True creation = test_connection.creation_class(test_connection) - if connection.vendor == 'oracle': + if connection.vendor == "oracle": # Don't close connection on Oracle. creation.connection.close = mock.Mock() - old_database_name = test_connection.settings_dict['NAME'] + old_database_name = test_connection.settings_dict["NAME"] try: - with mock.patch.object(creation, '_create_test_db'): + with mock.patch.object(creation, "_create_test_db"): creation.create_test_db(verbosity=0, autoclobber=True, serialize=False) # Migrations run. mocked_migrate.assert_called() args, kwargs = mocked_migrate.call_args - self.assertEqual(args, ([('app_unmigrated', '0001_initial')],)) - self.assertEqual(len(kwargs['plan']), 1) + self.assertEqual(args, ([("app_unmigrated", "0001_initial")],)) + self.assertEqual(len(kwargs["plan"]), 1) # App is not synced. mocked_sync_apps.assert_not_called() finally: - with mock.patch.object(creation, '_destroy_test_db'): + with mock.patch.object(creation, "_destroy_test_db"): creation.destroy_test_db(old_database_name, verbosity=0) - @mock.patch.dict(os.environ, {'RUNNING_DJANGOS_TEST_SUITE': ''}) - @mock.patch('django.db.migrations.executor.MigrationExecutor.migrate') - @mock.patch.object(BaseDatabaseCreation, 'mark_expected_failures_and_skips') - def test_mark_expected_failures_and_skips_call(self, mark_expected_failures_and_skips, *mocked_objects): + @mock.patch.dict(os.environ, {"RUNNING_DJANGOS_TEST_SUITE": ""}) + @mock.patch("django.db.migrations.executor.MigrationExecutor.migrate") + @mock.patch.object(BaseDatabaseCreation, "mark_expected_failures_and_skips") + def test_mark_expected_failures_and_skips_call( + self, mark_expected_failures_and_skips, *mocked_objects + ): """ mark_expected_failures_and_skips() isn't called unless RUNNING_DJANGOS_TEST_SUITE is 'true'. """ test_connection = get_connection_copy() creation = test_connection.creation_class(test_connection) - if connection.vendor == 'oracle': + if connection.vendor == "oracle": # Don't close connection on Oracle. creation.connection.close = mock.Mock() - old_database_name = test_connection.settings_dict['NAME'] + old_database_name = test_connection.settings_dict["NAME"] try: - with mock.patch.object(creation, '_create_test_db'): + with mock.patch.object(creation, "_create_test_db"): creation.create_test_db(verbosity=0, autoclobber=True, serialize=False) self.assertIs(mark_expected_failures_and_skips.called, False) finally: - with mock.patch.object(creation, '_destroy_test_db'): + with mock.patch.object(creation, "_destroy_test_db"): creation.destroy_test_db(old_database_name, verbosity=0) class TestDeserializeDbFromString(TransactionTestCase): - available_apps = ['backends'] + available_apps = ["backends"] def test_circular_reference(self): # deserialize_db_from_string() handles circular references. @@ -184,38 +197,38 @@ class TestDeserializeDbFromString(TransactionTestCase): def test_self_reference(self): # serialize_db_to_string() and deserialize_db_from_string() handles # self references. - obj_1 = ObjectSelfReference.objects.create(key='X') - obj_2 = ObjectSelfReference.objects.create(key='Y', obj=obj_1) + obj_1 = ObjectSelfReference.objects.create(key="X") + obj_2 = ObjectSelfReference.objects.create(key="Y", obj=obj_1) obj_1.obj = obj_2 obj_1.save() # Serialize objects. - with mock.patch('django.db.migrations.loader.MigrationLoader') as loader: + with mock.patch("django.db.migrations.loader.MigrationLoader") as loader: # serialize_db_to_string() serializes only migrated apps, so mark # the backends app as migrated. loader_instance = loader.return_value - loader_instance.migrated_apps = {'backends'} + loader_instance.migrated_apps = {"backends"} data = connection.creation.serialize_db_to_string() ObjectSelfReference.objects.all().delete() # Deserialize objects. connection.creation.deserialize_db_from_string(data) - obj_1 = ObjectSelfReference.objects.get(key='X') - obj_2 = ObjectSelfReference.objects.get(key='Y') + obj_1 = ObjectSelfReference.objects.get(key="X") + obj_2 = ObjectSelfReference.objects.get(key="Y") self.assertEqual(obj_1.obj, obj_2) self.assertEqual(obj_2.obj, obj_1) def test_circular_reference_with_natural_key(self): # serialize_db_to_string() and deserialize_db_from_string() handles # circular references for models with natural keys. - obj_a = CircularA.objects.create(key='A') - obj_b = CircularB.objects.create(key='B', obj=obj_a) + obj_a = CircularA.objects.create(key="A") + obj_b = CircularB.objects.create(key="B", obj=obj_a) obj_a.obj = obj_b obj_a.save() # Serialize objects. - with mock.patch('django.db.migrations.loader.MigrationLoader') as loader: + with mock.patch("django.db.migrations.loader.MigrationLoader") as loader: # serialize_db_to_string() serializes only migrated apps, so mark # the backends app as migrated. loader_instance = loader.return_value - loader_instance.migrated_apps = {'backends'} + loader_instance.migrated_apps = {"backends"} data = connection.creation.serialize_db_to_string() CircularA.objects.all().delete() CircularB.objects.all().delete() @@ -228,11 +241,11 @@ class TestDeserializeDbFromString(TransactionTestCase): def test_serialize_db_to_string_base_manager(self): SchoolClass.objects.create(year=1000, last_updated=datetime.datetime.now()) - with mock.patch('django.db.migrations.loader.MigrationLoader') as loader: + with mock.patch("django.db.migrations.loader.MigrationLoader") as loader: # serialize_db_to_string() serializes only migrated apps, so mark # the backends app as migrated. loader_instance = loader.return_value - loader_instance.migrated_apps = {'backends'} + loader_instance.migrated_apps = {"backends"} data = connection.creation.serialize_db_to_string() self.assertIn('"model": "backends.schoolclass"', data) self.assertIn('"year": 1000', data) @@ -256,14 +269,14 @@ class TestMarkTests(SimpleTestCase): test_connection = get_connection_copy() creation = BaseDatabaseCreation(test_connection) creation.connection.features.django_test_expected_failures = { - 'backends.base.test_creation.expected_failure_test_function', + "backends.base.test_creation.expected_failure_test_function", } creation.connection.features.django_test_skips = { - 'skip test class': { - 'backends.base.test_creation.SkipTestClass', + "skip test class": { + "backends.base.test_creation.SkipTestClass", }, - 'skip test function': { - 'backends.base.test_creation.skip_test_function', + "skip test function": { + "backends.base.test_creation.skip_test_function", }, } creation.mark_expected_failures_and_skips() @@ -274,10 +287,10 @@ class TestMarkTests(SimpleTestCase): self.assertIs(SkipTestClass.__unittest_skip__, True) self.assertEqual( SkipTestClass.__unittest_skip_why__, - 'skip test class', + "skip test class", ) self.assertIs(skip_test_function.__unittest_skip__, True) self.assertEqual( skip_test_function.__unittest_skip_why__, - 'skip test function', + "skip test function", ) diff --git a/tests/backends/base/test_features.py b/tests/backends/base/test_features.py index 9b67cfec47..f8a8e945e0 100644 --- a/tests/backends/base/test_features.py +++ b/tests/backends/base/test_features.py @@ -3,6 +3,5 @@ from django.test import SimpleTestCase class TestDatabaseFeatures(SimpleTestCase): - def test_nonexistent_feature(self): - self.assertFalse(hasattr(connection.features, 'nonexistent')) + self.assertFalse(hasattr(connection.features, "nonexistent")) diff --git a/tests/backends/base/test_introspection.py b/tests/backends/base/test_introspection.py index 636d6683c0..284e98f2ae 100644 --- a/tests/backends/base/test_introspection.py +++ b/tests/backends/base/test_introspection.py @@ -5,33 +5,33 @@ from django.test import SimpleTestCase class SimpleDatabaseIntrospectionTests(SimpleTestCase): may_require_msg = ( - 'subclasses of BaseDatabaseIntrospection may require a %s() method' + "subclasses of BaseDatabaseIntrospection may require a %s() method" ) def setUp(self): self.introspection = BaseDatabaseIntrospection(connection=connection) def test_get_table_list(self): - msg = self.may_require_msg % 'get_table_list' + msg = self.may_require_msg % "get_table_list" with self.assertRaisesMessage(NotImplementedError, msg): self.introspection.get_table_list(None) def test_get_table_description(self): - msg = self.may_require_msg % 'get_table_description' + msg = self.may_require_msg % "get_table_description" with self.assertRaisesMessage(NotImplementedError, msg): self.introspection.get_table_description(None, None) def test_get_sequences(self): - msg = self.may_require_msg % 'get_sequences' + msg = self.may_require_msg % "get_sequences" with self.assertRaisesMessage(NotImplementedError, msg): self.introspection.get_sequences(None, None) def test_get_relations(self): - msg = self.may_require_msg % 'get_relations' + msg = self.may_require_msg % "get_relations" with self.assertRaisesMessage(NotImplementedError, msg): self.introspection.get_relations(None, None) def test_get_constraints(self): - msg = self.may_require_msg % 'get_constraints' + msg = self.may_require_msg % "get_constraints" with self.assertRaisesMessage(NotImplementedError, msg): self.introspection.get_constraints(None, None) diff --git a/tests/backends/base/test_operations.py b/tests/backends/base/test_operations.py index 535cb20f41..b19b7ee558 100644 --- a/tests/backends/base/test_operations.py +++ b/tests/backends/base/test_operations.py @@ -5,7 +5,10 @@ from django.db import NotSupportedError, connection, transaction from django.db.backends.base.operations import BaseDatabaseOperations from django.db.models import DurationField, Value from django.test import ( - SimpleTestCase, TestCase, TransactionTestCase, override_settings, + SimpleTestCase, + TestCase, + TransactionTestCase, + override_settings, skipIfDBFeature, ) from django.utils import timezone @@ -14,60 +17,70 @@ from ..models import Author, Book class SimpleDatabaseOperationTests(SimpleTestCase): - may_require_msg = 'subclasses of BaseDatabaseOperations may require a %s() method' + may_require_msg = "subclasses of BaseDatabaseOperations may require a %s() method" def setUp(self): self.ops = BaseDatabaseOperations(connection=connection) def test_deferrable_sql(self): - self.assertEqual(self.ops.deferrable_sql(), '') + self.assertEqual(self.ops.deferrable_sql(), "") def test_end_transaction_rollback(self): - self.assertEqual(self.ops.end_transaction_sql(success=False), 'ROLLBACK;') + self.assertEqual(self.ops.end_transaction_sql(success=False), "ROLLBACK;") def test_no_limit_value(self): - with self.assertRaisesMessage(NotImplementedError, self.may_require_msg % 'no_limit_value'): + with self.assertRaisesMessage( + NotImplementedError, self.may_require_msg % "no_limit_value" + ): self.ops.no_limit_value() def test_quote_name(self): - with self.assertRaisesMessage(NotImplementedError, self.may_require_msg % 'quote_name'): - self.ops.quote_name('a') + with self.assertRaisesMessage( + NotImplementedError, self.may_require_msg % "quote_name" + ): + self.ops.quote_name("a") def test_regex_lookup(self): - with self.assertRaisesMessage(NotImplementedError, self.may_require_msg % 'regex_lookup'): - self.ops.regex_lookup(lookup_type='regex') + with self.assertRaisesMessage( + NotImplementedError, self.may_require_msg % "regex_lookup" + ): + self.ops.regex_lookup(lookup_type="regex") def test_set_time_zone_sql(self): - self.assertEqual(self.ops.set_time_zone_sql(), '') + self.assertEqual(self.ops.set_time_zone_sql(), "") def test_sql_flush(self): - msg = 'subclasses of BaseDatabaseOperations must provide an sql_flush() method' + msg = "subclasses of BaseDatabaseOperations must provide an sql_flush() method" with self.assertRaisesMessage(NotImplementedError, msg): self.ops.sql_flush(None, None) def test_pk_default_value(self): - self.assertEqual(self.ops.pk_default_value(), 'DEFAULT') + self.assertEqual(self.ops.pk_default_value(), "DEFAULT") def test_tablespace_sql(self): - self.assertEqual(self.ops.tablespace_sql(None), '') + self.assertEqual(self.ops.tablespace_sql(None), "") def test_sequence_reset_by_name_sql(self): self.assertEqual(self.ops.sequence_reset_by_name_sql(None, []), []) def test_adapt_unknown_value_decimal(self): - value = decimal.Decimal('3.14') + value = decimal.Decimal("3.14") self.assertEqual( self.ops.adapt_unknown_value(value), - self.ops.adapt_decimalfield_value(value) + self.ops.adapt_decimalfield_value(value), ) def test_adapt_unknown_value_date(self): value = timezone.now().date() - self.assertEqual(self.ops.adapt_unknown_value(value), self.ops.adapt_datefield_value(value)) + self.assertEqual( + self.ops.adapt_unknown_value(value), self.ops.adapt_datefield_value(value) + ) def test_adapt_unknown_value_time(self): value = timezone.now().time() - self.assertEqual(self.ops.adapt_unknown_value(value), self.ops.adapt_timefield_value(value)) + self.assertEqual( + self.ops.adapt_unknown_value(value), self.ops.adapt_timefield_value(value) + ) def test_adapt_timefield_value_none(self): self.assertIsNone(self.ops.adapt_timefield_value(None)) @@ -84,7 +97,7 @@ class SimpleDatabaseOperationTests(SimpleTestCase): self.assertEqual(self.ops.adapt_datetimefield_value(value), value) def test_adapt_timefield_value(self): - msg = 'Django does not support timezone-aware times.' + msg = "Django does not support timezone-aware times." with self.assertRaisesMessage(ValueError, msg): self.ops.adapt_timefield_value(timezone.make_aware(timezone.now())) @@ -94,40 +107,56 @@ class SimpleDatabaseOperationTests(SimpleTestCase): self.assertEqual(self.ops.adapt_timefield_value(now), str(now)) def test_format_for_duration_arithmetic(self): - msg = self.may_require_msg % 'format_for_duration_arithmetic' + msg = self.may_require_msg % "format_for_duration_arithmetic" with self.assertRaisesMessage(NotImplementedError, msg): self.ops.format_for_duration_arithmetic(None) def test_date_extract_sql(self): - with self.assertRaisesMessage(NotImplementedError, self.may_require_msg % 'date_extract_sql'): + with self.assertRaisesMessage( + NotImplementedError, self.may_require_msg % "date_extract_sql" + ): self.ops.date_extract_sql(None, None) def test_time_extract_sql(self): - with self.assertRaisesMessage(NotImplementedError, self.may_require_msg % 'date_extract_sql'): + with self.assertRaisesMessage( + NotImplementedError, self.may_require_msg % "date_extract_sql" + ): self.ops.time_extract_sql(None, None) def test_date_trunc_sql(self): - with self.assertRaisesMessage(NotImplementedError, self.may_require_msg % 'date_trunc_sql'): + with self.assertRaisesMessage( + NotImplementedError, self.may_require_msg % "date_trunc_sql" + ): self.ops.date_trunc_sql(None, None) def test_time_trunc_sql(self): - with self.assertRaisesMessage(NotImplementedError, self.may_require_msg % 'time_trunc_sql'): + with self.assertRaisesMessage( + NotImplementedError, self.may_require_msg % "time_trunc_sql" + ): self.ops.time_trunc_sql(None, None) def test_datetime_trunc_sql(self): - with self.assertRaisesMessage(NotImplementedError, self.may_require_msg % 'datetime_trunc_sql'): + with self.assertRaisesMessage( + NotImplementedError, self.may_require_msg % "datetime_trunc_sql" + ): self.ops.datetime_trunc_sql(None, None, None) def test_datetime_cast_date_sql(self): - with self.assertRaisesMessage(NotImplementedError, self.may_require_msg % 'datetime_cast_date_sql'): + with self.assertRaisesMessage( + NotImplementedError, self.may_require_msg % "datetime_cast_date_sql" + ): self.ops.datetime_cast_date_sql(None, None) def test_datetime_cast_time_sql(self): - with self.assertRaisesMessage(NotImplementedError, self.may_require_msg % 'datetime_cast_time_sql'): + with self.assertRaisesMessage( + NotImplementedError, self.may_require_msg % "datetime_cast_time_sql" + ): self.ops.datetime_cast_time_sql(None, None) def test_datetime_extract_sql(self): - with self.assertRaisesMessage(NotImplementedError, self.may_require_msg % 'datetime_extract_sql'): + with self.assertRaisesMessage( + NotImplementedError, self.may_require_msg % "datetime_extract_sql" + ): self.ops.datetime_extract_sql(None, None, None) @@ -135,41 +164,41 @@ class DatabaseOperationTests(TestCase): def setUp(self): self.ops = BaseDatabaseOperations(connection=connection) - @skipIfDBFeature('supports_over_clause') + @skipIfDBFeature("supports_over_clause") def test_window_frame_raise_not_supported_error(self): - msg = 'This backend does not support window expressions.' + msg = "This backend does not support window expressions." with self.assertRaisesMessage(NotSupportedError, msg): self.ops.window_frame_rows_start_end() - @skipIfDBFeature('can_distinct_on_fields') + @skipIfDBFeature("can_distinct_on_fields") def test_distinct_on_fields(self): - msg = 'DISTINCT ON fields is not supported by this database backend' + msg = "DISTINCT ON fields is not supported by this database backend" with self.assertRaisesMessage(NotSupportedError, msg): - self.ops.distinct_sql(['a', 'b'], None) + self.ops.distinct_sql(["a", "b"], None) - @skipIfDBFeature('supports_temporal_subtraction') + @skipIfDBFeature("supports_temporal_subtraction") def test_subtract_temporals(self): duration_field = DurationField() duration_field_internal_type = duration_field.get_internal_type() msg = ( - 'This backend does not support %s subtraction.' % - duration_field_internal_type + "This backend does not support %s subtraction." + % duration_field_internal_type ) with self.assertRaisesMessage(NotSupportedError, msg): self.ops.subtract_temporals(duration_field_internal_type, None, None) class SqlFlushTests(TransactionTestCase): - available_apps = ['backends'] + available_apps = ["backends"] def test_sql_flush_no_tables(self): self.assertEqual(connection.ops.sql_flush(no_style(), []), []) def test_execute_sql_flush_statements(self): with transaction.atomic(): - author = Author.objects.create(name='George Orwell') + author = Author.objects.create(name="George Orwell") Book.objects.create(author=author) - author = Author.objects.create(name='Harper Lee') + author = Author.objects.create(name="Harper Lee") Book.objects.create(author=author) Book.objects.create(author=author) self.assertIs(Author.objects.exists(), True) @@ -187,7 +216,7 @@ class SqlFlushTests(TransactionTestCase): self.assertIs(Author.objects.exists(), False) self.assertIs(Book.objects.exists(), False) if connection.features.supports_sequence_reset: - author = Author.objects.create(name='F. Scott Fitzgerald') + author = Author.objects.create(name="F. Scott Fitzgerald") self.assertEqual(author.pk, 1) book = Book.objects.create(author=author) self.assertEqual(book.pk, 1) diff --git a/tests/backends/base/test_schema.py b/tests/backends/base/test_schema.py index 2ecad098a6..5409789b13 100644 --- a/tests/backends/base/test_schema.py +++ b/tests/backends/base/test_schema.py @@ -4,9 +4,9 @@ from django.test import SimpleTestCase class SchemaEditorTests(SimpleTestCase): - def test_effective_default_callable(self): """SchemaEditor.effective_default() shouldn't call callable defaults.""" + class MyStr(str): def __call__(self): return self diff --git a/tests/backends/models.py b/tests/backends/models.py index 15c81052ea..99e9e86f44 100644 --- a/tests/backends/models.py +++ b/tests/backends/models.py @@ -1,6 +1,4 @@ -from django.contrib.contenttypes.fields import ( - GenericForeignKey, GenericRelation, -) +from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation from django.contrib.contenttypes.models import ContentType from django.db import models @@ -18,7 +16,7 @@ class Person(models.Model): last_name = models.CharField(max_length=20) def __str__(self): - return '%s %s' % (self.first_name, self.last_name) + return "%s %s" % (self.first_name, self.last_name) class SchoolClassManager(models.Manager): @@ -35,25 +33,33 @@ class SchoolClass(models.Model): class VeryLongModelNameZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ(models.Model): - primary_key_is_quite_long_zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz = models.AutoField(primary_key=True) - charfield_is_quite_long_zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz = models.CharField(max_length=100) - m2m_also_quite_long_zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz = models.ManyToManyField(Person, blank=True) + primary_key_is_quite_long_zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz = models.AutoField( + primary_key=True + ) + charfield_is_quite_long_zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz = models.CharField( + max_length=100 + ) + m2m_also_quite_long_zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz = ( + models.ManyToManyField(Person, blank=True) + ) class Tag(models.Model): name = models.CharField(max_length=30) - content_type = models.ForeignKey(ContentType, models.CASCADE, related_name='backend_tags') + content_type = models.ForeignKey( + ContentType, models.CASCADE, related_name="backend_tags" + ) object_id = models.PositiveIntegerField() - content_object = GenericForeignKey('content_type', 'object_id') + content_object = GenericForeignKey("content_type", "object_id") class Post(models.Model): name = models.CharField(max_length=30) text = models.TextField() - tags = GenericRelation('Tag') + tags = GenericRelation("Tag") class Meta: - db_table = 'CaseSensitive_Post' + db_table = "CaseSensitive_Post" class Reporter(models.Model): @@ -77,7 +83,7 @@ class Article(models.Model): ReporterProxy, models.SET_NULL, null=True, - related_name='reporter_proxy', + related_name="reporter_proxy", ) def __str__(self): @@ -95,8 +101,10 @@ class Item(models.Model): class Object(models.Model): - related_objects = models.ManyToManyField("self", db_constraint=False, symmetrical=False) - obj_ref = models.ForeignKey('ObjectReference', models.CASCADE, null=True) + related_objects = models.ManyToManyField( + "self", db_constraint=False, symmetrical=False + ) + obj_ref = models.ForeignKey("ObjectReference", models.CASCADE, null=True) def __str__(self): return str(self.id) @@ -111,12 +119,12 @@ class ObjectReference(models.Model): class ObjectSelfReference(models.Model): key = models.CharField(max_length=3, unique=True) - obj = models.ForeignKey('ObjectSelfReference', models.SET_NULL, null=True) + obj = models.ForeignKey("ObjectSelfReference", models.SET_NULL, null=True) class CircularA(models.Model): key = models.CharField(max_length=3, unique=True) - obj = models.ForeignKey('CircularB', models.SET_NULL, null=True) + obj = models.ForeignKey("CircularB", models.SET_NULL, null=True) def natural_key(self): return (self.key,) @@ -124,7 +132,7 @@ class CircularA(models.Model): class CircularB(models.Model): key = models.CharField(max_length=3, unique=True) - obj = models.ForeignKey('CircularA', models.SET_NULL, null=True) + obj = models.ForeignKey("CircularA", models.SET_NULL, null=True) def natural_key(self): return (self.key,) @@ -139,12 +147,12 @@ class Author(models.Model): class Book(models.Model): - author = models.ForeignKey(Author, models.CASCADE, to_field='name') + author = models.ForeignKey(Author, models.CASCADE, to_field="name") class SQLKeywordsModel(models.Model): - id = models.AutoField(primary_key=True, db_column='select') - reporter = models.ForeignKey(Reporter, models.CASCADE, db_column='where') + id = models.AutoField(primary_key=True, db_column="select") + reporter = models.ForeignKey(Reporter, models.CASCADE, db_column="where") class Meta: - db_table = 'order' + db_table = "order" diff --git a/tests/backends/mysql/test_creation.py b/tests/backends/mysql/test_creation.py index 0d3480adea..151d00ff3f 100644 --- a/tests/backends/mysql/test_creation.py +++ b/tests/backends/mysql/test_creation.py @@ -9,34 +9,39 @@ from django.db.backends.mysql.creation import DatabaseCreation from django.test import SimpleTestCase -@unittest.skipUnless(connection.vendor == 'mysql', 'MySQL tests') +@unittest.skipUnless(connection.vendor == "mysql", "MySQL tests") class DatabaseCreationTests(SimpleTestCase): - def _execute_raise_database_exists(self, cursor, parameters, keepdb=False): - raise DatabaseError(1007, "Can't create database '%s'; database exists" % parameters['dbname']) + raise DatabaseError( + 1007, "Can't create database '%s'; database exists" % parameters["dbname"] + ) def _execute_raise_access_denied(self, cursor, parameters, keepdb=False): raise DatabaseError(1044, "Access denied for user") def patch_test_db_creation(self, execute_create_test_db): - return mock.patch.object(BaseDatabaseCreation, '_execute_create_test_db', execute_create_test_db) + return mock.patch.object( + BaseDatabaseCreation, "_execute_create_test_db", execute_create_test_db + ) - @mock.patch('sys.stdout', new_callable=StringIO) - @mock.patch('sys.stderr', new_callable=StringIO) + @mock.patch("sys.stdout", new_callable=StringIO) + @mock.patch("sys.stderr", new_callable=StringIO) def test_create_test_db_database_exists(self, *mocked_objects): # Simulate test database creation raising "database exists" creation = DatabaseCreation(connection) with self.patch_test_db_creation(self._execute_raise_database_exists): - with mock.patch('builtins.input', return_value='no'): + with mock.patch("builtins.input", return_value="no"): with self.assertRaises(SystemExit): # SystemExit is raised if the user answers "no" to the # prompt asking if it's okay to delete the test database. - creation._create_test_db(verbosity=0, autoclobber=False, keepdb=False) + creation._create_test_db( + verbosity=0, autoclobber=False, keepdb=False + ) # "Database exists" shouldn't appear when keepdb is on creation._create_test_db(verbosity=0, autoclobber=False, keepdb=True) - @mock.patch('sys.stdout', new_callable=StringIO) - @mock.patch('sys.stderr', new_callable=StringIO) + @mock.patch("sys.stdout", new_callable=StringIO) + @mock.patch("sys.stderr", new_callable=StringIO) def test_create_test_db_unexpected_error(self, *mocked_objects): # Simulate test database creation raising unexpected error creation = DatabaseCreation(connection) @@ -47,8 +52,8 @@ class DatabaseCreationTests(SimpleTestCase): def test_clone_test_db_database_exists(self): creation = DatabaseCreation(connection) with self.patch_test_db_creation(self._execute_raise_database_exists): - with mock.patch.object(DatabaseCreation, '_clone_db') as _clone_db: - creation._clone_test_db('suffix', verbosity=0, keepdb=True) + with mock.patch.object(DatabaseCreation, "_clone_db") as _clone_db: + creation._clone_test_db("suffix", verbosity=0, keepdb=True) _clone_db.assert_not_called() def test_clone_test_db_options_ordering(self): @@ -56,30 +61,32 @@ class DatabaseCreationTests(SimpleTestCase): try: saved_settings = connection.settings_dict connection.settings_dict = { - 'NAME': 'source_db', - 'USER': '', - 'PASSWORD': '', - 'PORT': '', - 'HOST': '', - 'ENGINE': 'django.db.backends.mysql', - 'OPTIONS': { - 'read_default_file': 'my.cnf', + "NAME": "source_db", + "USER": "", + "PASSWORD": "", + "PORT": "", + "HOST": "", + "ENGINE": "django.db.backends.mysql", + "OPTIONS": { + "read_default_file": "my.cnf", }, } - with mock.patch.object(subprocess, 'Popen') as mocked_popen: - creation._clone_db('source_db', 'target_db') - mocked_popen.assert_has_calls([ - mock.call( - [ - 'mysqldump', - '--defaults-file=my.cnf', - '--routines', - '--events', - 'source_db', - ], - stdout=subprocess.PIPE, - env=None, - ), - ]) + with mock.patch.object(subprocess, "Popen") as mocked_popen: + creation._clone_db("source_db", "target_db") + mocked_popen.assert_has_calls( + [ + mock.call( + [ + "mysqldump", + "--defaults-file=my.cnf", + "--routines", + "--events", + "source_db", + ], + stdout=subprocess.PIPE, + env=None, + ), + ] + ) finally: connection.settings_dict = saved_settings diff --git a/tests/backends/mysql/test_features.py b/tests/backends/mysql/test_features.py index 5d27890a5d..ec5bd442fb 100644 --- a/tests/backends/mysql/test_features.py +++ b/tests/backends/mysql/test_features.py @@ -5,17 +5,20 @@ from django.db.backends.mysql.features import DatabaseFeatures from django.test import TestCase -@skipUnless(connection.vendor == 'mysql', 'MySQL tests') +@skipUnless(connection.vendor == "mysql", "MySQL tests") class TestFeatures(TestCase): - def test_supports_transactions(self): """ All storage engines except MyISAM support transactions. """ - with mock.patch('django.db.connection.features._mysql_storage_engine', 'InnoDB'): + with mock.patch( + "django.db.connection.features._mysql_storage_engine", "InnoDB" + ): self.assertTrue(connection.features.supports_transactions) del connection.features.supports_transactions - with mock.patch('django.db.connection.features._mysql_storage_engine', 'MyISAM'): + with mock.patch( + "django.db.connection.features._mysql_storage_engine", "MyISAM" + ): self.assertFalse(connection.features.supports_transactions) del connection.features.supports_transactions @@ -35,6 +38,6 @@ class TestFeatures(TestCase): def test_allows_auto_pk_0(self): with mock.MagicMock() as _connection: - _connection.sql_mode = {'NO_AUTO_VALUE_ON_ZERO'} + _connection.sql_mode = {"NO_AUTO_VALUE_ON_ZERO"} database_features = DatabaseFeatures(_connection) self.assertIs(database_features.allows_auto_pk_0, True) diff --git a/tests/backends/mysql/test_introspection.py b/tests/backends/mysql/test_introspection.py index 4f13622eda..c1247de232 100644 --- a/tests/backends/mysql/test_introspection.py +++ b/tests/backends/mysql/test_introspection.py @@ -4,24 +4,24 @@ from django.db import connection, connections from django.test import TestCase -@skipUnless(connection.vendor == 'mysql', 'MySQL tests') +@skipUnless(connection.vendor == "mysql", "MySQL tests") class ParsingTests(TestCase): def test_parse_constraint_columns(self): _parse_constraint_columns = connection.introspection._parse_constraint_columns tests = ( - ('`height` >= 0', ['height'], ['height']), - ('`cost` BETWEEN 1 AND 10', ['cost'], ['cost']), - ('`ref1` > `ref2`', ['id', 'ref1', 'ref2'], ['ref1', 'ref2']), + ("`height` >= 0", ["height"], ["height"]), + ("`cost` BETWEEN 1 AND 10", ["cost"], ["cost"]), + ("`ref1` > `ref2`", ["id", "ref1", "ref2"], ["ref1", "ref2"]), ( - '`start` IS NULL OR `end` IS NULL OR `start` < `end`', - ['id', 'start', 'end'], - ['start', 'end'], + "`start` IS NULL OR `end` IS NULL OR `start` < `end`", + ["id", "start", "end"], + ["start", "end"], ), - ('JSON_VALID(`json_field`)', ['json_field'], ['json_field']), - ('CHAR_LENGTH(`name`) > 2', ['name'], ['name']), - ("lower(`ref1`) != 'test'", ['id', 'owe', 'ref1'], ['ref1']), - ("lower(`ref1`) != 'test'", ['id', 'lower', 'ref1'], ['ref1']), - ("`name` LIKE 'test%'", ['name'], ['name']), + ("JSON_VALID(`json_field`)", ["json_field"], ["json_field"]), + ("CHAR_LENGTH(`name`) > 2", ["name"], ["name"]), + ("lower(`ref1`) != 'test'", ["id", "owe", "ref1"], ["ref1"]), + ("lower(`ref1`) != 'test'", ["id", "lower", "ref1"], ["ref1"]), + ("`name` LIKE 'test%'", ["name"], ["name"]), ) for check_clause, table_columns, expected_columns in tests: with self.subTest(check_clause): @@ -29,28 +29,32 @@ class ParsingTests(TestCase): self.assertEqual(list(check_columns), expected_columns) -@skipUnless(connection.vendor == 'mysql', 'MySQL tests') +@skipUnless(connection.vendor == "mysql", "MySQL tests") class StorageEngineTests(TestCase): - databases = {'default', 'other'} + databases = {"default", "other"} def test_get_storage_engine(self): - table_name = 'test_storage_engine' - create_sql = 'CREATE TABLE %s (id INTEGER) ENGINE = %%s' % table_name - drop_sql = 'DROP TABLE %s' % table_name - default_connection = connections['default'] - other_connection = connections['other'] + table_name = "test_storage_engine" + create_sql = "CREATE TABLE %s (id INTEGER) ENGINE = %%s" % table_name + drop_sql = "DROP TABLE %s" % table_name + default_connection = connections["default"] + other_connection = connections["other"] try: with default_connection.cursor() as cursor: - cursor.execute(create_sql % 'InnoDB') + cursor.execute(create_sql % "InnoDB") self.assertEqual( - default_connection.introspection.get_storage_engine(cursor, table_name), - 'InnoDB', + default_connection.introspection.get_storage_engine( + cursor, table_name + ), + "InnoDB", ) with other_connection.cursor() as cursor: - cursor.execute(create_sql % 'MyISAM') + cursor.execute(create_sql % "MyISAM") self.assertEqual( - other_connection.introspection.get_storage_engine(cursor, table_name), - 'MyISAM', + other_connection.introspection.get_storage_engine( + cursor, table_name + ), + "MyISAM", ) finally: with default_connection.cursor() as cursor: diff --git a/tests/backends/mysql/test_operations.py b/tests/backends/mysql/test_operations.py index a98e8963b7..bd6170f299 100644 --- a/tests/backends/mysql/test_operations.py +++ b/tests/backends/mysql/test_operations.py @@ -7,7 +7,7 @@ from django.test import SimpleTestCase from ..models import Person, Tag -@unittest.skipUnless(connection.vendor == 'mysql', 'MySQL tests.') +@unittest.skipUnless(connection.vendor == "mysql", "MySQL tests.") class MySQLOperationsTests(SimpleTestCase): def test_sql_flush(self): # allow_cascade doesn't change statements on MySQL. @@ -20,10 +20,10 @@ class MySQLOperationsTests(SimpleTestCase): allow_cascade=allow_cascade, ), [ - 'SET FOREIGN_KEY_CHECKS = 0;', - 'DELETE FROM `backends_person`;', - 'DELETE FROM `backends_tag`;', - 'SET FOREIGN_KEY_CHECKS = 1;', + "SET FOREIGN_KEY_CHECKS = 0;", + "DELETE FROM `backends_person`;", + "DELETE FROM `backends_tag`;", + "SET FOREIGN_KEY_CHECKS = 1;", ], ) @@ -39,9 +39,9 @@ class MySQLOperationsTests(SimpleTestCase): allow_cascade=allow_cascade, ), [ - 'SET FOREIGN_KEY_CHECKS = 0;', - 'TRUNCATE `backends_person`;', - 'TRUNCATE `backends_tag`;', - 'SET FOREIGN_KEY_CHECKS = 1;', + "SET FOREIGN_KEY_CHECKS = 0;", + "TRUNCATE `backends_person`;", + "TRUNCATE `backends_tag`;", + "SET FOREIGN_KEY_CHECKS = 1;", ], ) diff --git a/tests/backends/mysql/test_schema.py b/tests/backends/mysql/test_schema.py index 44f4a07b18..2fb7fea9c5 100644 --- a/tests/backends/mysql/test_schema.py +++ b/tests/backends/mysql/test_schema.py @@ -4,18 +4,19 @@ from django.db import connection from django.test import TestCase -@unittest.skipUnless(connection.vendor == 'mysql', 'MySQL tests') +@unittest.skipUnless(connection.vendor == "mysql", "MySQL tests") class SchemaEditorTests(TestCase): def test_quote_value(self): import MySQLdb + editor = connection.schema_editor() tested_values = [ - ('string', "'string'"), - ('¿Tú hablas inglés?', "'¿Tú hablas inglés?'"), - (b'bytes', b"'bytes'"), - (42, '42'), - (1.754, '1.754e0' if MySQLdb.version_info >= (1, 3, 14) else '1.754'), - (False, b'0' if MySQLdb.version_info >= (1, 4, 0) else '0'), + ("string", "'string'"), + ("¿Tú hablas inglés?", "'¿Tú hablas inglés?'"), + (b"bytes", b"'bytes'"), + (42, "42"), + (1.754, "1.754e0" if MySQLdb.version_info >= (1, 3, 14) else "1.754"), + (False, b"0" if MySQLdb.version_info >= (1, 4, 0) else "0"), ] for value, expected in tested_values: with self.subTest(value=value): diff --git a/tests/backends/mysql/tests.py b/tests/backends/mysql/tests.py index 02fc312abc..6ea289e151 100644 --- a/tests/backends/mysql/tests.py +++ b/tests/backends/mysql/tests.py @@ -14,20 +14,21 @@ def get_connection(): @override_settings(DEBUG=True) -@unittest.skipUnless(connection.vendor == 'mysql', 'MySQL tests') +@unittest.skipUnless(connection.vendor == "mysql", "MySQL tests") class IsolationLevelTests(TestCase): - read_committed = 'read committed' - repeatable_read = 'repeatable read' + read_committed = "read committed" + repeatable_read = "repeatable read" isolation_values = { - level: level.upper() - for level in (read_committed, repeatable_read) + level: level.upper() for level in (read_committed, repeatable_read) } @classmethod def setUpClass(cls): super().setUpClass() - configured_isolation_level = connection.isolation_level or cls.isolation_values[cls.repeatable_read] + configured_isolation_level = ( + connection.isolation_level or cls.isolation_values[cls.repeatable_read] + ) cls.configured_isolation_level = configured_isolation_level.upper() cls.other_isolation_level = ( cls.read_committed @@ -38,50 +39,58 @@ class IsolationLevelTests(TestCase): @staticmethod def get_isolation_level(connection): with connection.cursor() as cursor: - cursor.execute("SHOW VARIABLES WHERE variable_name IN ('transaction_isolation', 'tx_isolation')") - return cursor.fetchone()[1].replace('-', ' ') + cursor.execute( + "SHOW VARIABLES WHERE variable_name IN ('transaction_isolation', 'tx_isolation')" + ) + return cursor.fetchone()[1].replace("-", " ") def test_auto_is_null_auto_config(self): - query = 'set sql_auto_is_null = 0' + query = "set sql_auto_is_null = 0" connection.init_connection_state() - last_query = connection.queries[-1]['sql'].lower() + last_query = connection.queries[-1]["sql"].lower() if connection.features.is_sql_auto_is_null_enabled: self.assertIn(query, last_query) else: self.assertNotIn(query, last_query) def test_connect_isolation_level(self): - self.assertEqual(self.get_isolation_level(connection), self.configured_isolation_level) + self.assertEqual( + self.get_isolation_level(connection), self.configured_isolation_level + ) def test_setting_isolation_level(self): with get_connection() as new_connection: - new_connection.settings_dict['OPTIONS']['isolation_level'] = self.other_isolation_level + new_connection.settings_dict["OPTIONS"][ + "isolation_level" + ] = self.other_isolation_level self.assertEqual( self.get_isolation_level(new_connection), - self.isolation_values[self.other_isolation_level] + self.isolation_values[self.other_isolation_level], ) def test_uppercase_isolation_level(self): # Upper case values are also accepted in 'isolation_level'. with get_connection() as new_connection: - new_connection.settings_dict['OPTIONS']['isolation_level'] = self.other_isolation_level.upper() + new_connection.settings_dict["OPTIONS"][ + "isolation_level" + ] = self.other_isolation_level.upper() self.assertEqual( self.get_isolation_level(new_connection), - self.isolation_values[self.other_isolation_level] + self.isolation_values[self.other_isolation_level], ) def test_default_isolation_level(self): # If not specified in settings, the default is read committed. with get_connection() as new_connection: - new_connection.settings_dict['OPTIONS'].pop('isolation_level', None) + new_connection.settings_dict["OPTIONS"].pop("isolation_level", None) self.assertEqual( self.get_isolation_level(new_connection), - self.isolation_values[self.read_committed] + self.isolation_values[self.read_committed], ) def test_isolation_level_validation(self): new_connection = connection.copy() - new_connection.settings_dict['OPTIONS']['isolation_level'] = 'xxx' + new_connection.settings_dict["OPTIONS"]["isolation_level"] = "xxx" msg = ( "Invalid transaction isolation level 'xxx' specified.\n" "Use one of 'read committed', 'read uncommitted', " diff --git a/tests/backends/oracle/test_creation.py b/tests/backends/oracle/test_creation.py index 930d2520e2..6f6b9fd233 100644 --- a/tests/backends/oracle/test_creation.py +++ b/tests/backends/oracle/test_creation.py @@ -7,16 +7,19 @@ from django.db.backends.oracle.creation import DatabaseCreation from django.test import TestCase -@unittest.skipUnless(connection.vendor == 'oracle', 'Oracle tests') -@mock.patch.object(DatabaseCreation, '_maindb_connection', return_value=connection) -@mock.patch('sys.stdout', new_callable=StringIO) -@mock.patch('sys.stderr', new_callable=StringIO) +@unittest.skipUnless(connection.vendor == "oracle", "Oracle tests") +@mock.patch.object(DatabaseCreation, "_maindb_connection", return_value=connection) +@mock.patch("sys.stdout", new_callable=StringIO) +@mock.patch("sys.stderr", new_callable=StringIO) class DatabaseCreationTests(TestCase): - - def _execute_raise_user_already_exists(self, cursor, statements, parameters, verbosity, allow_quiet_fail=False): + def _execute_raise_user_already_exists( + self, cursor, statements, parameters, verbosity, allow_quiet_fail=False + ): # Raise "user already exists" only in test user creation - if statements and statements[0].startswith('CREATE USER'): - raise DatabaseError("ORA-01920: user name 'string' conflicts with another user or role name") + if statements and statements[0].startswith("CREATE USER"): + raise DatabaseError( + "ORA-01920: user name 'string' conflicts with another user or role name" + ) def _execute_raise_tablespace_already_exists( self, cursor, statements, parameters, verbosity, allow_quiet_fail=False @@ -30,17 +33,21 @@ class DatabaseCreationTests(TestCase): def _test_database_passwd(self): # Mocked to avoid test user password changed - return connection.settings_dict['SAVED_PASSWORD'] + return connection.settings_dict["SAVED_PASSWORD"] def patch_execute_statements(self, execute_statements): - return mock.patch.object(DatabaseCreation, '_execute_statements', execute_statements) + return mock.patch.object( + DatabaseCreation, "_execute_statements", execute_statements + ) - @mock.patch.object(DatabaseCreation, '_test_user_create', return_value=False) + @mock.patch.object(DatabaseCreation, "_test_user_create", return_value=False) def test_create_test_db(self, *mocked_objects): creation = DatabaseCreation(connection) # Simulate test database creation raising "tablespace already exists" - with self.patch_execute_statements(self._execute_raise_tablespace_already_exists): - with mock.patch('builtins.input', return_value='no'): + with self.patch_execute_statements( + self._execute_raise_tablespace_already_exists + ): + with mock.patch("builtins.input", return_value="no"): with self.assertRaises(SystemExit): # SystemExit is raised if the user answers "no" to the # prompt asking if it's okay to delete the test tablespace. @@ -54,13 +61,15 @@ class DatabaseCreationTests(TestCase): with self.assertRaises(SystemExit): creation._create_test_db(verbosity=0, keepdb=True) - @mock.patch.object(DatabaseCreation, '_test_database_create', return_value=False) + @mock.patch.object(DatabaseCreation, "_test_database_create", return_value=False) def test_create_test_user(self, *mocked_objects): creation = DatabaseCreation(connection) - with mock.patch.object(DatabaseCreation, '_test_database_passwd', self._test_database_passwd): + with mock.patch.object( + DatabaseCreation, "_test_database_passwd", self._test_database_passwd + ): # Simulate test user creation raising "user already exists" with self.patch_execute_statements(self._execute_raise_user_already_exists): - with mock.patch('builtins.input', return_value='no'): + with mock.patch("builtins.input", return_value="no"): with self.assertRaises(SystemExit): # SystemExit is raised if the user answers "no" to the # prompt asking if it's okay to delete the test user. @@ -68,27 +77,35 @@ class DatabaseCreationTests(TestCase): # "User already exists" error is ignored when keepdb is on creation._create_test_db(verbosity=0, keepdb=True) # Simulate test user creation raising unexpected error - with self.patch_execute_statements(self._execute_raise_insufficient_privileges): + with self.patch_execute_statements( + self._execute_raise_insufficient_privileges + ): with self.assertRaises(SystemExit): creation._create_test_db(verbosity=0, keepdb=False) with self.assertRaises(SystemExit): creation._create_test_db(verbosity=0, keepdb=True) def test_oracle_managed_files(self, *mocked_objects): - def _execute_capture_statements(self, cursor, statements, parameters, verbosity, allow_quiet_fail=False): + def _execute_capture_statements( + self, cursor, statements, parameters, verbosity, allow_quiet_fail=False + ): self.tblspace_sqls = statements creation = DatabaseCreation(connection) # Simulate test database creation with Oracle Managed File (OMF) # tablespaces. - with mock.patch.object(DatabaseCreation, '_test_database_oracle_managed_files', return_value=True): + with mock.patch.object( + DatabaseCreation, "_test_database_oracle_managed_files", return_value=True + ): with self.patch_execute_statements(_execute_capture_statements): with connection.cursor() as cursor: - creation._execute_test_db_creation(cursor, creation._get_test_db_params(), verbosity=0) + creation._execute_test_db_creation( + cursor, creation._get_test_db_params(), verbosity=0 + ) tblspace_sql, tblspace_tmp_sql = creation.tblspace_sqls # Datafile names shouldn't appear. - self.assertIn('DATAFILE SIZE', tblspace_sql) - self.assertIn('TEMPFILE SIZE', tblspace_tmp_sql) + self.assertIn("DATAFILE SIZE", tblspace_sql) + self.assertIn("TEMPFILE SIZE", tblspace_tmp_sql) # REUSE cannot be used with OMF. - self.assertNotIn('REUSE', tblspace_sql) - self.assertNotIn('REUSE', tblspace_tmp_sql) + self.assertNotIn("REUSE", tblspace_sql) + self.assertNotIn("REUSE", tblspace_tmp_sql) diff --git a/tests/backends/oracle/test_introspection.py b/tests/backends/oracle/test_introspection.py index 6664dda3d3..c6a143eb2d 100644 --- a/tests/backends/oracle/test_introspection.py +++ b/tests/backends/oracle/test_introspection.py @@ -6,24 +6,30 @@ from django.test import TransactionTestCase from ..models import Square -@unittest.skipUnless(connection.vendor == 'oracle', 'Oracle tests') +@unittest.skipUnless(connection.vendor == "oracle", "Oracle tests") class DatabaseSequenceTests(TransactionTestCase): available_apps = [] def test_get_sequences(self): with connection.cursor() as cursor: - seqs = connection.introspection.get_sequences(cursor, Square._meta.db_table, Square._meta.local_fields) + seqs = connection.introspection.get_sequences( + cursor, Square._meta.db_table, Square._meta.local_fields + ) self.assertEqual(len(seqs), 1) - self.assertIsNotNone(seqs[0]['name']) - self.assertEqual(seqs[0]['table'], Square._meta.db_table) - self.assertEqual(seqs[0]['column'], 'id') + self.assertIsNotNone(seqs[0]["name"]) + self.assertEqual(seqs[0]["table"], Square._meta.db_table) + self.assertEqual(seqs[0]["column"], "id") def test_get_sequences_manually_created_index(self): with connection.cursor() as cursor: with connection.schema_editor() as editor: - editor._drop_identity(Square._meta.db_table, 'id') - seqs = connection.introspection.get_sequences(cursor, Square._meta.db_table, Square._meta.local_fields) - self.assertEqual(seqs, [{'table': Square._meta.db_table, 'column': 'id'}]) + editor._drop_identity(Square._meta.db_table, "id") + seqs = connection.introspection.get_sequences( + cursor, Square._meta.db_table, Square._meta.local_fields + ) + self.assertEqual( + seqs, [{"table": Square._meta.db_table, "column": "id"}] + ) # Recreate model, because adding identity is impossible. editor.delete_model(Square) editor.create_model(Square) diff --git a/tests/backends/oracle/test_operations.py b/tests/backends/oracle/test_operations.py index 54fe73139d..523bdcda8a 100644 --- a/tests/backends/oracle/test_operations.py +++ b/tests/backends/oracle/test_operations.py @@ -7,13 +7,15 @@ from django.test import TransactionTestCase from ..models import Person, Tag -@unittest.skipUnless(connection.vendor == 'oracle', 'Oracle tests') +@unittest.skipUnless(connection.vendor == "oracle", "Oracle tests") class OperationsTests(TransactionTestCase): - available_apps = ['backends'] + available_apps = ["backends"] def test_sequence_name_truncation(self): - seq_name = connection.ops._get_no_autofield_sequence_name('schema_authorwithevenlongee869') - self.assertEqual(seq_name, 'SCHEMA_AUTHORWITHEVENLOB0B8_SQ') + seq_name = connection.ops._get_no_autofield_sequence_name( + "schema_authorwithevenlongee869" + ) + self.assertEqual(seq_name, "SCHEMA_AUTHORWITHEVENLOB0B8_SQ") def test_bulk_batch_size(self): # Oracle restricts the number of parameters in a query. @@ -21,11 +23,11 @@ class OperationsTests(TransactionTestCase): self.assertEqual(connection.ops.bulk_batch_size([], objects), len(objects)) # Each field is a parameter for each object. self.assertEqual( - connection.ops.bulk_batch_size(['id'], objects), + connection.ops.bulk_batch_size(["id"], objects), connection.features.max_query_params, ) self.assertEqual( - connection.ops.bulk_batch_size(['id', 'other'], objects), + connection.ops.bulk_batch_size(["id", "other"], objects), connection.features.max_query_params // 2, ) @@ -105,8 +107,8 @@ class OperationsTests(TransactionTestCase): ) # Sequences. self.assertEqual(len(statements[4:]), 2) - self.assertIn('BACKENDS_PERSON_SQ', statements[4]) - self.assertIn('BACKENDS_TAG_SQ', statements[5]) + self.assertIn("BACKENDS_PERSON_SQ", statements[4]) + self.assertIn("BACKENDS_TAG_SQ", statements[5]) def test_sql_flush_sequences_allow_cascade(self): statements = connection.ops.sql_flush( @@ -136,6 +138,6 @@ class OperationsTests(TransactionTestCase): ) # Sequences. self.assertEqual(len(statements[5:]), 3) - self.assertIn('BACKENDS_PERSON_SQ', statements[5]) - self.assertIn('BACKENDS_VERYLONGMODELN7BE2_SQ', statements[6]) - self.assertIn('BACKENDS_TAG_SQ', statements[7]) + self.assertIn("BACKENDS_PERSON_SQ", statements[5]) + self.assertIn("BACKENDS_VERYLONGMODELN7BE2_SQ", statements[6]) + self.assertIn("BACKENDS_TAG_SQ", statements[7]) diff --git a/tests/backends/oracle/tests.py b/tests/backends/oracle/tests.py index 6d31d33d57..a518db2c64 100644 --- a/tests/backends/oracle/tests.py +++ b/tests/backends/oracle/tests.py @@ -4,14 +4,11 @@ from django.db import DatabaseError, connection from django.db.models import BooleanField from django.test import TransactionTestCase -from ..models import ( - Square, VeryLongModelNameZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ, -) +from ..models import Square, VeryLongModelNameZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ -@unittest.skipUnless(connection.vendor == 'oracle', 'Oracle tests') +@unittest.skipUnless(connection.vendor == "oracle", "Oracle tests") class Tests(unittest.TestCase): - def test_quote_name(self): """'%' chars are escaped for query execution.""" name = '"SOME%NAME"' @@ -21,21 +18,24 @@ class Tests(unittest.TestCase): def test_quote_name_db_table(self): model = VeryLongModelNameZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ db_table = model._meta.db_table.upper() - self.assertEqual(f'"{db_table}"', connection.ops.quote_name( - 'backends_verylongmodelnamezzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz', - )) + self.assertEqual( + f'"{db_table}"', + connection.ops.quote_name( + "backends_verylongmodelnamezzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz", + ), + ) def test_dbms_session(self): """A stored procedure can be called through a cursor wrapper.""" with connection.cursor() as cursor: - cursor.callproc('DBMS_SESSION.SET_IDENTIFIER', ['_django_testing!']) + cursor.callproc("DBMS_SESSION.SET_IDENTIFIER", ["_django_testing!"]) def test_cursor_var(self): """Cursor variables can be passed as query parameters.""" with connection.cursor() as cursor: var = cursor.var(str) cursor.execute("BEGIN %s := 'X'; END; ", [var]) - self.assertEqual(var.getvalue(), 'X') + self.assertEqual(var.getvalue(), "X") def test_order_of_nls_parameters(self): """ @@ -53,41 +53,46 @@ class Tests(unittest.TestCase): """Boolean fields have check constraints on their values.""" for field in (BooleanField(), BooleanField(null=True)): with self.subTest(field=field): - field.set_attributes_from_name('is_nice') + field.set_attributes_from_name("is_nice") self.assertIn('"IS_NICE" IN (0,1)', field.db_check(connection)) -@unittest.skipUnless(connection.vendor == 'oracle', 'Oracle tests') +@unittest.skipUnless(connection.vendor == "oracle", "Oracle tests") class TransactionalTests(TransactionTestCase): - available_apps = ['backends'] + available_apps = ["backends"] def test_hidden_no_data_found_exception(self): # "ORA-1403: no data found" exception is hidden by Oracle OCI library # when an INSERT statement is used with a RETURNING clause (see #28859). with connection.cursor() as cursor: # Create trigger that raises "ORA-1403: no data found". - cursor.execute(""" + cursor.execute( + """ CREATE OR REPLACE TRIGGER "TRG_NO_DATA_FOUND" AFTER INSERT ON "BACKENDS_SQUARE" FOR EACH ROW BEGIN RAISE NO_DATA_FOUND; END; - """) + """ + ) try: - with self.assertRaisesMessage(DatabaseError, ( - 'The database did not return a new row id. Probably "ORA-1403: ' - 'no data found" was raised internally but was hidden by the ' - 'Oracle OCI library (see https://code.djangoproject.com/ticket/28859).' - )): + with self.assertRaisesMessage( + DatabaseError, + ( + 'The database did not return a new row id. Probably "ORA-1403: ' + 'no data found" was raised internally but was hidden by the ' + "Oracle OCI library (see https://code.djangoproject.com/ticket/28859)." + ), + ): Square.objects.create(root=2, square=4) finally: with connection.cursor() as cursor: cursor.execute('DROP TRIGGER "TRG_NO_DATA_FOUND"') def test_password_with_at_sign(self): - old_password = connection.settings_dict['PASSWORD'] - connection.settings_dict['PASSWORD'] = 'p@ssword' + old_password = connection.settings_dict["PASSWORD"] + connection.settings_dict["PASSWORD"] = "p@ssword" try: self.assertIn( '/"p@ssword"@', @@ -97,6 +102,6 @@ class TransactionalTests(TransactionTestCase): connection.cursor() # Database exception: "ORA-01017: invalid username/password" is # expected. - self.assertIn('ORA-01017', context.exception.args[0].message) + self.assertIn("ORA-01017", context.exception.args[0].message) finally: - connection.settings_dict['PASSWORD'] = old_password + connection.settings_dict["PASSWORD"] = old_password diff --git a/tests/backends/postgresql/test_creation.py b/tests/backends/postgresql/test_creation.py index eb68c6e471..319029334d 100644 --- a/tests/backends/postgresql/test_creation.py +++ b/tests/backends/postgresql/test_creation.py @@ -18,12 +18,11 @@ else: from django.db.backends.postgresql.creation import DatabaseCreation -@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL tests') +@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL tests") class DatabaseCreationTests(SimpleTestCase): - @contextmanager def changed_test_settings(self, **kwargs): - settings = connection.settings_dict['TEST'] + settings = connection.settings_dict["TEST"] saved_values = {} for name in kwargs: if name in settings: @@ -47,66 +46,80 @@ class DatabaseCreationTests(SimpleTestCase): self.assertEqual(suffix, expected) def test_sql_table_creation_suffix_with_none_settings(self): - settings = {'CHARSET': None, 'TEMPLATE': None} + settings = {"CHARSET": None, "TEMPLATE": None} self.check_sql_table_creation_suffix(settings, "") def test_sql_table_creation_suffix_with_encoding(self): - settings = {'CHARSET': 'UTF8'} + settings = {"CHARSET": "UTF8"} self.check_sql_table_creation_suffix(settings, "WITH ENCODING 'UTF8'") def test_sql_table_creation_suffix_with_template(self): - settings = {'TEMPLATE': 'template0'} + settings = {"TEMPLATE": "template0"} self.check_sql_table_creation_suffix(settings, 'WITH TEMPLATE "template0"') def test_sql_table_creation_suffix_with_encoding_and_template(self): - settings = {'CHARSET': 'UTF8', 'TEMPLATE': 'template0'} - self.check_sql_table_creation_suffix(settings, '''WITH ENCODING 'UTF8' TEMPLATE "template0"''') + settings = {"CHARSET": "UTF8", "TEMPLATE": "template0"} + self.check_sql_table_creation_suffix( + settings, '''WITH ENCODING 'UTF8' TEMPLATE "template0"''' + ) def test_sql_table_creation_raises_with_collation(self): - settings = {'COLLATION': 'test'} + settings = {"COLLATION": "test"} msg = ( - 'PostgreSQL does not support collation setting at database ' - 'creation time.' + "PostgreSQL does not support collation setting at database " + "creation time." ) with self.assertRaisesMessage(ImproperlyConfigured, msg): self.check_sql_table_creation_suffix(settings, None) def _execute_raise_database_already_exists(self, cursor, parameters, keepdb=False): - error = DatabaseError('database %s already exists' % parameters['dbname']) + error = DatabaseError("database %s already exists" % parameters["dbname"]) error.pgcode = errorcodes.DUPLICATE_DATABASE raise DatabaseError() from error def _execute_raise_permission_denied(self, cursor, parameters, keepdb=False): - error = DatabaseError('permission denied to create database') + error = DatabaseError("permission denied to create database") error.pgcode = errorcodes.INSUFFICIENT_PRIVILEGE raise DatabaseError() from error def patch_test_db_creation(self, execute_create_test_db): - return mock.patch.object(BaseDatabaseCreation, '_execute_create_test_db', execute_create_test_db) + return mock.patch.object( + BaseDatabaseCreation, "_execute_create_test_db", execute_create_test_db + ) - @mock.patch('sys.stdout', new_callable=StringIO) - @mock.patch('sys.stderr', new_callable=StringIO) + @mock.patch("sys.stdout", new_callable=StringIO) + @mock.patch("sys.stderr", new_callable=StringIO) def test_create_test_db(self, *mocked_objects): creation = DatabaseCreation(connection) # Simulate test database creation raising "database already exists" with self.patch_test_db_creation(self._execute_raise_database_already_exists): - with mock.patch('builtins.input', return_value='no'): + with mock.patch("builtins.input", return_value="no"): with self.assertRaises(SystemExit): # SystemExit is raised if the user answers "no" to the # prompt asking if it's okay to delete the test database. - creation._create_test_db(verbosity=0, autoclobber=False, keepdb=False) + creation._create_test_db( + verbosity=0, autoclobber=False, keepdb=False + ) # "Database already exists" error is ignored when keepdb is on creation._create_test_db(verbosity=0, autoclobber=False, keepdb=True) # Simulate test database creation raising unexpected error with self.patch_test_db_creation(self._execute_raise_permission_denied): - with mock.patch.object(DatabaseCreation, '_database_exists', return_value=False): + with mock.patch.object( + DatabaseCreation, "_database_exists", return_value=False + ): with self.assertRaises(SystemExit): - creation._create_test_db(verbosity=0, autoclobber=False, keepdb=False) + creation._create_test_db( + verbosity=0, autoclobber=False, keepdb=False + ) with self.assertRaises(SystemExit): - creation._create_test_db(verbosity=0, autoclobber=False, keepdb=True) + creation._create_test_db( + verbosity=0, autoclobber=False, keepdb=True + ) # Simulate test database creation raising "insufficient privileges". # An error shouldn't appear when keepdb is on and the database already # exists. with self.patch_test_db_creation(self._execute_raise_permission_denied): - with mock.patch.object(DatabaseCreation, '_database_exists', return_value=True): + with mock.patch.object( + DatabaseCreation, "_database_exists", return_value=True + ): creation._create_test_db(verbosity=0, autoclobber=False, keepdb=True) diff --git a/tests/backends/postgresql/test_introspection.py b/tests/backends/postgresql/test_introspection.py index 4dcadbd733..dc95d6ad23 100644 --- a/tests/backends/postgresql/test_introspection.py +++ b/tests/backends/postgresql/test_introspection.py @@ -6,18 +6,24 @@ from django.test import TestCase from ..models import Person -@unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL") +@unittest.skipUnless(connection.vendor == "postgresql", "Test only for PostgreSQL") class DatabaseSequenceTests(TestCase): def test_get_sequences(self): with connection.cursor() as cursor: seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table) self.assertEqual( seqs, - [{'table': Person._meta.db_table, 'column': 'id', 'name': 'backends_person_id_seq'}] + [ + { + "table": Person._meta.db_table, + "column": "id", + "name": "backends_person_id_seq", + } + ], ) - cursor.execute('ALTER SEQUENCE backends_person_id_seq RENAME TO pers_seq') + cursor.execute("ALTER SEQUENCE backends_person_id_seq RENAME TO pers_seq") seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table) self.assertEqual( seqs, - [{'table': Person._meta.db_table, 'column': 'id', 'name': 'pers_seq'}] + [{"table": Person._meta.db_table, "column": "id", "name": "pers_seq"}], ) diff --git a/tests/backends/postgresql/test_operations.py b/tests/backends/postgresql/test_operations.py index 821bb29cee..c2f2417923 100644 --- a/tests/backends/postgresql/test_operations.py +++ b/tests/backends/postgresql/test_operations.py @@ -7,7 +7,7 @@ from django.test import SimpleTestCase from ..models import Person, Tag -@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL tests.') +@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL tests.") class PostgreSQLOperationsTests(SimpleTestCase): def test_sql_flush(self): self.assertEqual( diff --git a/tests/backends/postgresql/test_server_side_cursors.py b/tests/backends/postgresql/test_server_side_cursors.py index 0cc3423a9b..705e798c23 100644 --- a/tests/backends/postgresql/test_server_side_cursors.py +++ b/tests/backends/postgresql/test_server_side_cursors.py @@ -9,19 +9,23 @@ from django.test import TestCase from ..models import Person -@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL tests') +@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL tests") class ServerSideCursorsPostgres(TestCase): - cursor_fields = 'name, statement, is_holdable, is_binary, is_scrollable, creation_time' - PostgresCursor = namedtuple('PostgresCursor', cursor_fields) + cursor_fields = ( + "name, statement, is_holdable, is_binary, is_scrollable, creation_time" + ) + PostgresCursor = namedtuple("PostgresCursor", cursor_fields) @classmethod def setUpTestData(cls): - Person.objects.create(first_name='a', last_name='a') - Person.objects.create(first_name='b', last_name='b') + Person.objects.create(first_name="a", last_name="a") + Person.objects.create(first_name="b", last_name="b") def inspect_cursors(self): with connection.cursor() as cursor: - cursor.execute('SELECT {fields} FROM pg_cursors;'.format(fields=self.cursor_fields)) + cursor.execute( + "SELECT {fields} FROM pg_cursors;".format(fields=self.cursor_fields) + ) cursors = cursor.fetchall() return [self.PostgresCursor._make(cursor) for cursor in cursors] @@ -30,7 +34,9 @@ class ServerSideCursorsPostgres(TestCase): for setting in kwargs: original_value = connection.settings_dict.get(setting) if setting in connection.settings_dict: - self.addCleanup(operator.setitem, connection.settings_dict, setting, original_value) + self.addCleanup( + operator.setitem, connection.settings_dict, setting, original_value + ) else: self.addCleanup(operator.delitem, connection.settings_dict, setting) @@ -42,7 +48,7 @@ class ServerSideCursorsPostgres(TestCase): cursors = self.inspect_cursors() self.assertEqual(len(cursors), num_expected) for cursor in cursors: - self.assertIn('_django_curs_', cursor.name) + self.assertIn("_django_curs_", cursor.name) self.assertFalse(cursor.is_scrollable) self.assertFalse(cursor.is_holdable) self.assertFalse(cursor.is_binary) @@ -54,17 +60,23 @@ class ServerSideCursorsPostgres(TestCase): self.assertUsesCursor(Person.objects.iterator()) def test_values(self): - self.assertUsesCursor(Person.objects.values('first_name').iterator()) + self.assertUsesCursor(Person.objects.values("first_name").iterator()) def test_values_list(self): - self.assertUsesCursor(Person.objects.values_list('first_name').iterator()) + self.assertUsesCursor(Person.objects.values_list("first_name").iterator()) def test_values_list_flat(self): - self.assertUsesCursor(Person.objects.values_list('first_name', flat=True).iterator()) + self.assertUsesCursor( + Person.objects.values_list("first_name", flat=True).iterator() + ) def test_values_list_fields_not_equal_to_names(self): - expr = models.Count('id') - self.assertUsesCursor(Person.objects.annotate(id__count=expr).values_list(expr, 'id__count').iterator()) + expr = models.Count("id") + self.assertUsesCursor( + Person.objects.annotate(id__count=expr) + .values_list(expr, "id__count") + .iterator() + ) def test_server_side_cursor_many_cursors(self): persons = Person.objects.iterator() diff --git a/tests/backends/postgresql/tests.py b/tests/backends/postgresql/tests.py index 1905147f6f..af08f6f286 100644 --- a/tests/backends/postgresql/tests.py +++ b/tests/backends/postgresql/tests.py @@ -9,9 +9,9 @@ from django.db.backends.base.base import BaseDatabaseWrapper from django.test import TestCase, override_settings -@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL tests') +@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL tests") class Tests(TestCase): - databases = {'default', 'other'} + databases = {"default", "other"} def test_nodb_cursor(self): """ @@ -21,14 +21,14 @@ class Tests(TestCase): orig_connect = BaseDatabaseWrapper.connect def mocked_connect(self): - if self.settings_dict['NAME'] is None: + if self.settings_dict["NAME"] is None: raise DatabaseError() return orig_connect(self) with connection._nodb_cursor() as cursor: self.assertIs(cursor.closed, False) self.assertIsNotNone(cursor.db.connection) - self.assertIsNone(cursor.db.settings_dict['NAME']) + self.assertIsNone(cursor.db.settings_dict["NAME"]) self.assertIs(cursor.closed, True) self.assertIsNone(cursor.db.connection) @@ -41,24 +41,29 @@ class Tests(TestCase): "database and will use the first PostgreSQL database instead." ) with self.assertWarnsMessage(RuntimeWarning, msg): - with mock.patch('django.db.backends.base.base.BaseDatabaseWrapper.connect', - side_effect=mocked_connect, autospec=True): + with mock.patch( + "django.db.backends.base.base.BaseDatabaseWrapper.connect", + side_effect=mocked_connect, + autospec=True, + ): with mock.patch.object( connection, - 'settings_dict', - {**connection.settings_dict, 'NAME': 'postgres'}, + "settings_dict", + {**connection.settings_dict, "NAME": "postgres"}, ): with connection._nodb_cursor() as cursor: self.assertIs(cursor.closed, False) self.assertIsNotNone(cursor.db.connection) self.assertIs(cursor.closed, True) self.assertIsNone(cursor.db.connection) - self.assertIsNotNone(cursor.db.settings_dict['NAME']) - self.assertEqual(cursor.db.settings_dict['NAME'], connections['other'].settings_dict['NAME']) + self.assertIsNotNone(cursor.db.settings_dict["NAME"]) + self.assertEqual( + cursor.db.settings_dict["NAME"], connections["other"].settings_dict["NAME"] + ) # Cursor is yielded only for the first PostgreSQL database. with self.assertWarnsMessage(RuntimeWarning, msg): with mock.patch( - 'django.db.backends.base.base.BaseDatabaseWrapper.connect', + "django.db.backends.base.base.BaseDatabaseWrapper.connect", side_effect=mocked_connect, autospec=True, ): @@ -71,13 +76,14 @@ class Tests(TestCase): _nodb_cursor() re-raises authentication failure to the 'postgres' db when other connection to the PostgreSQL database isn't available. """ + def mocked_connect(self): raise DatabaseError() def mocked_all(self): test_connection = copy.copy(connections[DEFAULT_DB_ALIAS]) test_connection.settings_dict = copy.deepcopy(connection.settings_dict) - test_connection.settings_dict['NAME'] = 'postgres' + test_connection.settings_dict["NAME"] = "postgres" return [test_connection] msg = ( @@ -89,12 +95,12 @@ class Tests(TestCase): ) with self.assertWarnsMessage(RuntimeWarning, msg): mocker_connections_all = mock.patch( - 'django.utils.connection.BaseConnectionHandler.all', + "django.utils.connection.BaseConnectionHandler.all", side_effect=mocked_all, autospec=True, ) mocker_connect = mock.patch( - 'django.db.backends.base.base.BaseDatabaseWrapper.connect', + "django.db.backends.base.base.BaseDatabaseWrapper.connect", side_effect=mocked_connect, autospec=True, ) @@ -104,27 +110,29 @@ class Tests(TestCase): pass def test_nodb_cursor_reraise_exceptions(self): - with self.assertRaisesMessage(DatabaseError, 'exception'): + with self.assertRaisesMessage(DatabaseError, "exception"): with connection._nodb_cursor(): - raise DatabaseError('exception') + raise DatabaseError("exception") def test_database_name_too_long(self): from django.db.backends.postgresql.base import DatabaseWrapper + settings = connection.settings_dict.copy() max_name_length = connection.ops.max_name_length() - settings['NAME'] = 'a' + (max_name_length * 'a') + settings["NAME"] = "a" + (max_name_length * "a") msg = ( "The database name '%s' (%d characters) is longer than " "PostgreSQL's limit of %s characters. Supply a shorter NAME in " "settings.DATABASES." - ) % (settings['NAME'], max_name_length + 1, max_name_length) + ) % (settings["NAME"], max_name_length + 1, max_name_length) with self.assertRaisesMessage(ImproperlyConfigured, msg): DatabaseWrapper(settings).get_connection_params() def test_database_name_empty(self): from django.db.backends.postgresql.base import DatabaseWrapper + settings = connection.settings_dict.copy() - settings['NAME'] = '' + settings["NAME"] = "" msg = ( "settings.DATABASES is improperly configured. Please supply the " "NAME or OPTIONS['service'] value." @@ -134,22 +142,24 @@ class Tests(TestCase): def test_service_name(self): from django.db.backends.postgresql.base import DatabaseWrapper + settings = connection.settings_dict.copy() - settings['OPTIONS'] = {'service': 'my_service'} - settings['NAME'] = '' + settings["OPTIONS"] = {"service": "my_service"} + settings["NAME"] = "" params = DatabaseWrapper(settings).get_connection_params() - self.assertEqual(params['service'], 'my_service') - self.assertNotIn('database', params) + self.assertEqual(params["service"], "my_service") + self.assertNotIn("database", params) def test_service_name_default_db(self): # None is used to connect to the default 'postgres' db. from django.db.backends.postgresql.base import DatabaseWrapper + settings = connection.settings_dict.copy() - settings['NAME'] = None - settings['OPTIONS'] = {'service': 'django_test'} + settings["NAME"] = None + settings["OPTIONS"] = {"service": "django_test"} params = DatabaseWrapper(settings).get_connection_params() - self.assertEqual(params['database'], 'postgres') - self.assertNotIn('service', params) + self.assertEqual(params["database"], "postgres") + self.assertNotIn("service", params) def test_connect_and_rollback(self): """ @@ -165,7 +175,7 @@ class Tests(TestCase): cursor.execute("RESET TIMEZONE") cursor.execute("SHOW TIMEZONE") db_default_tz = cursor.fetchone()[0] - new_tz = 'Europe/Paris' if db_default_tz == 'UTC' else 'UTC' + new_tz = "Europe/Paris" if db_default_tz == "UTC" else "UTC" new_connection.close() # Invalidate timezone name cache, because the setting_changed @@ -193,7 +203,7 @@ class Tests(TestCase): after setting the time zone when AUTOCOMMIT is False (#21452). """ new_connection = connection.copy() - new_connection.settings_dict['AUTOCOMMIT'] = False + new_connection.settings_dict["AUTOCOMMIT"] = False try: # Open a database connection. @@ -207,9 +217,7 @@ class Tests(TestCase): The transaction level can be configured with DATABASES ['OPTIONS']['isolation_level']. """ - from psycopg2.extensions import ( - ISOLATION_LEVEL_SERIALIZABLE as serializable, - ) + from psycopg2.extensions import ISOLATION_LEVEL_SERIALIZABLE as serializable # Since this is a django.test.TestCase, a transaction is in progress # and the isolation level isn't reported as 0. This test assumes that @@ -218,7 +226,7 @@ class Tests(TestCase): self.assertIsNone(connection.connection.isolation_level) new_connection = connection.copy() - new_connection.settings_dict['OPTIONS']['isolation_level'] = serializable + new_connection.settings_dict["OPTIONS"]["isolation_level"] = serializable try: # Start a transaction so the isolation level isn't reported as 0. new_connection.set_autocommit(False) @@ -230,7 +238,7 @@ class Tests(TestCase): def test_connect_no_is_usable_checks(self): new_connection = connection.copy() try: - with mock.patch.object(new_connection, 'is_usable') as is_usable: + with mock.patch.object(new_connection, "is_usable") as is_usable: new_connection.connect() is_usable.assert_not_called() finally: @@ -238,49 +246,60 @@ class Tests(TestCase): def _select(self, val): with connection.cursor() as cursor: - cursor.execute('SELECT %s', (val,)) + cursor.execute("SELECT %s", (val,)) return cursor.fetchone()[0] def test_select_ascii_array(self): - a = ['awef'] + a = ["awef"] b = self._select(a) self.assertEqual(a[0], b[0]) def test_select_unicode_array(self): - a = ['ᄲawef'] + a = ["ᄲawef"] b = self._select(a) self.assertEqual(a[0], b[0]) def test_lookup_cast(self): from django.db.backends.postgresql.operations import DatabaseOperations + do = DatabaseOperations(connection=None) lookups = ( - 'iexact', 'contains', 'icontains', 'startswith', 'istartswith', - 'endswith', 'iendswith', 'regex', 'iregex', + "iexact", + "contains", + "icontains", + "startswith", + "istartswith", + "endswith", + "iendswith", + "regex", + "iregex", ) for lookup in lookups: with self.subTest(lookup=lookup): - self.assertIn('::text', do.lookup_cast(lookup)) + self.assertIn("::text", do.lookup_cast(lookup)) for lookup in lookups: - for field_type in ('CICharField', 'CIEmailField', 'CITextField'): + for field_type in ("CICharField", "CIEmailField", "CITextField"): with self.subTest(lookup=lookup, field_type=field_type): - self.assertIn('::citext', do.lookup_cast(lookup, internal_type=field_type)) + self.assertIn( + "::citext", do.lookup_cast(lookup, internal_type=field_type) + ) def test_correct_extraction_psycopg2_version(self): from django.db.backends.postgresql.base import psycopg2_version - with mock.patch('psycopg2.__version__', '4.2.1 (dt dec pq3 ext lo64)'): + + with mock.patch("psycopg2.__version__", "4.2.1 (dt dec pq3 ext lo64)"): self.assertEqual(psycopg2_version(), (4, 2, 1)) - with mock.patch('psycopg2.__version__', '4.2b0.dev1 (dt dec pq3 ext lo64)'): + with mock.patch("psycopg2.__version__", "4.2b0.dev1 (dt dec pq3 ext lo64)"): self.assertEqual(psycopg2_version(), (4, 2)) @override_settings(DEBUG=True) def test_copy_cursors(self): out = StringIO() - copy_expert_sql = 'COPY django_session TO STDOUT (FORMAT CSV, HEADER)' + copy_expert_sql = "COPY django_session TO STDOUT (FORMAT CSV, HEADER)" with connection.cursor() as cursor: cursor.copy_expert(copy_expert_sql, out) - cursor.copy_to(out, 'django_session') + cursor.copy_to(out, "django_session") self.assertEqual( - [q['sql'] for q in connection.queries], - [copy_expert_sql, 'COPY django_session TO STDOUT'], + [q["sql"] for q in connection.queries], + [copy_expert_sql, "COPY django_session TO STDOUT"], ) diff --git a/tests/backends/sqlite/test_creation.py b/tests/backends/sqlite/test_creation.py index 6ec4262f73..ab1640c04e 100644 --- a/tests/backends/sqlite/test_creation.py +++ b/tests/backends/sqlite/test_creation.py @@ -5,15 +5,17 @@ from django.db import DEFAULT_DB_ALIAS, connection, connections from django.test import SimpleTestCase -@unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests') +@unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests") class TestDbSignatureTests(SimpleTestCase): def test_custom_test_name(self): test_connection = copy.copy(connections[DEFAULT_DB_ALIAS]) - test_connection.settings_dict = copy.deepcopy(connections[DEFAULT_DB_ALIAS].settings_dict) - test_connection.settings_dict['NAME'] = None - test_connection.settings_dict['TEST']['NAME'] = 'custom.sqlite.db' + test_connection.settings_dict = copy.deepcopy( + connections[DEFAULT_DB_ALIAS].settings_dict + ) + test_connection.settings_dict["NAME"] = None + test_connection.settings_dict["TEST"]["NAME"] = "custom.sqlite.db" signature = test_connection.creation_class(test_connection).test_db_signature() - self.assertEqual(signature, (None, 'custom.sqlite.db')) + self.assertEqual(signature, (None, "custom.sqlite.db")) def test_get_test_db_clone_settings_name(self): test_connection = copy.copy(connections[DEFAULT_DB_ALIAS]) @@ -21,13 +23,13 @@ class TestDbSignatureTests(SimpleTestCase): connections[DEFAULT_DB_ALIAS].settings_dict, ) tests = [ - ('test.sqlite3', 'test_1.sqlite3'), - ('test', 'test_1'), + ("test.sqlite3", "test_1.sqlite3"), + ("test", "test_1"), ] for test_db_name, expected_clone_name in tests: with self.subTest(test_db_name=test_db_name): - test_connection.settings_dict['NAME'] = test_db_name - test_connection.settings_dict['TEST']['NAME'] = test_db_name + test_connection.settings_dict["NAME"] = test_db_name + test_connection.settings_dict["TEST"]["NAME"] = test_db_name creation_class = test_connection.creation_class(test_connection) - clone_settings_dict = creation_class.get_test_db_clone_settings('1') - self.assertEqual(clone_settings_dict['NAME'], expected_clone_name) + clone_settings_dict = creation_class.get_test_db_clone_settings("1") + self.assertEqual(clone_settings_dict["NAME"], expected_clone_name) diff --git a/tests/backends/sqlite/test_features.py b/tests/backends/sqlite/test_features.py index 9b74794408..50ccbbd3cc 100644 --- a/tests/backends/sqlite/test_features.py +++ b/tests/backends/sqlite/test_features.py @@ -4,14 +4,14 @@ from django.db import OperationalError, connection from django.test import TestCase -@skipUnless(connection.vendor == 'sqlite', 'SQLite tests.') +@skipUnless(connection.vendor == "sqlite", "SQLite tests.") class FeaturesTests(TestCase): def test_supports_json_field_operational_error(self): - if hasattr(connection.features, 'supports_json_field'): + if hasattr(connection.features, "supports_json_field"): del connection.features.supports_json_field - msg = 'unable to open database file' + msg = "unable to open database file" with mock.patch( - 'django.db.backends.base.base.BaseDatabaseWrapper.cursor', + "django.db.backends.base.base.BaseDatabaseWrapper.cursor", side_effect=OperationalError(msg), ): with self.assertRaisesMessage(OperationalError, msg): diff --git a/tests/backends/sqlite/test_functions.py b/tests/backends/sqlite/test_functions.py index 0659717799..1102d5873e 100644 --- a/tests/backends/sqlite/test_functions.py +++ b/tests/backends/sqlite/test_functions.py @@ -1,5 +1,7 @@ from django.db.backends.sqlite3._functions import ( - _sqlite_date_trunc, _sqlite_datetime_trunc, _sqlite_time_trunc, + _sqlite_date_trunc, + _sqlite_datetime_trunc, + _sqlite_time_trunc, ) from django.test import SimpleTestCase @@ -8,14 +10,14 @@ class FunctionTests(SimpleTestCase): def test_sqlite_date_trunc(self): msg = "Unsupported lookup type: 'unknown-lookup'" with self.assertRaisesMessage(ValueError, msg): - _sqlite_date_trunc('unknown-lookup', '2005-08-11', None, None) + _sqlite_date_trunc("unknown-lookup", "2005-08-11", None, None) def test_sqlite_datetime_trunc(self): msg = "Unsupported lookup type: 'unknown-lookup'" with self.assertRaisesMessage(ValueError, msg): - _sqlite_datetime_trunc('unknown-lookup', '2005-08-11 1:00:00', None, None) + _sqlite_datetime_trunc("unknown-lookup", "2005-08-11 1:00:00", None, None) def test_sqlite_time_trunc(self): msg = "Unsupported lookup type: 'unknown-lookup'" with self.assertRaisesMessage(ValueError, msg): - _sqlite_time_trunc('unknown-lookup', '2005-08-11 1:00:00', None, None) + _sqlite_time_trunc("unknown-lookup", "2005-08-11 1:00:00", None, None) diff --git a/tests/backends/sqlite/test_introspection.py b/tests/backends/sqlite/test_introspection.py index 9331b5bb1a..2997ac9595 100644 --- a/tests/backends/sqlite/test_introspection.py +++ b/tests/backends/sqlite/test_introspection.py @@ -6,7 +6,7 @@ from django.db import connection from django.test import TestCase -@unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests') +@unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests") class IntrospectionTests(TestCase): def test_get_primary_key_column(self): """ @@ -14,19 +14,26 @@ class IntrospectionTests(TestCase): quotation. """ testable_column_strings = ( - ('id', 'id'), ('[id]', 'id'), ('`id`', 'id'), ('"id"', 'id'), - ('[id col]', 'id col'), ('`id col`', 'id col'), ('"id col"', 'id col') + ("id", "id"), + ("[id]", "id"), + ("`id`", "id"), + ('"id"', "id"), + ("[id col]", "id col"), + ("`id col`", "id col"), + ('"id col"', "id col"), ) with connection.cursor() as cursor: for column, expected_string in testable_column_strings: - sql = 'CREATE TABLE test_primary (%s int PRIMARY KEY NOT NULL)' % column + sql = "CREATE TABLE test_primary (%s int PRIMARY KEY NOT NULL)" % column with self.subTest(column=column): try: cursor.execute(sql) - field = connection.introspection.get_primary_key_column(cursor, 'test_primary') + field = connection.introspection.get_primary_key_column( + cursor, "test_primary" + ) self.assertEqual(field, expected_string) finally: - cursor.execute('DROP TABLE test_primary') + cursor.execute("DROP TABLE test_primary") def test_get_primary_key_column_pk_constraint(self): sql = """ @@ -41,38 +48,43 @@ class IntrospectionTests(TestCase): cursor.execute(sql) field = connection.introspection.get_primary_key_column( cursor, - 'test_primary', + "test_primary", ) - self.assertEqual(field, 'id') + self.assertEqual(field, "id") finally: - cursor.execute('DROP TABLE test_primary') + cursor.execute("DROP TABLE test_primary") -@unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests') +@unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests") class ParsingTests(TestCase): def parse_definition(self, sql, columns): """Parse a column or constraint definition.""" statement = sqlparse.parse(sql)[0] tokens = (token for token in statement.flatten() if not token.is_whitespace) with connection.cursor(): - return connection.introspection._parse_column_or_constraint_definition(tokens, set(columns)) + return connection.introspection._parse_column_or_constraint_definition( + tokens, set(columns) + ) def assertConstraint(self, constraint_details, cols, unique=False, check=False): - self.assertEqual(constraint_details, { - 'unique': unique, - 'columns': cols, - 'primary_key': False, - 'foreign_key': None, - 'check': check, - 'index': False, - }) + self.assertEqual( + constraint_details, + { + "unique": unique, + "columns": cols, + "primary_key": False, + "foreign_key": None, + "check": check, + "index": False, + }, + ) def test_unique_column(self): tests = ( - ('"ref" integer UNIQUE,', ['ref']), - ('ref integer UNIQUE,', ['ref']), - ('"customname" integer UNIQUE,', ['customname']), - ('customname integer UNIQUE,', ['customname']), + ('"ref" integer UNIQUE,', ["ref"]), + ("ref integer UNIQUE,", ["ref"]), + ('"customname" integer UNIQUE,', ["customname"]), + ("customname integer UNIQUE,", ["customname"]), ) for sql, columns in tests: with self.subTest(sql=sql): @@ -83,10 +95,18 @@ class ParsingTests(TestCase): def test_unique_constraint(self): tests = ( - ('CONSTRAINT "ref" UNIQUE ("ref"),', 'ref', ['ref']), - ('CONSTRAINT ref UNIQUE (ref),', 'ref', ['ref']), - ('CONSTRAINT "customname1" UNIQUE ("customname2"),', 'customname1', ['customname2']), - ('CONSTRAINT customname1 UNIQUE (customname2),', 'customname1', ['customname2']), + ('CONSTRAINT "ref" UNIQUE ("ref"),', "ref", ["ref"]), + ("CONSTRAINT ref UNIQUE (ref),", "ref", ["ref"]), + ( + 'CONSTRAINT "customname1" UNIQUE ("customname2"),', + "customname1", + ["customname2"], + ), + ( + "CONSTRAINT customname1 UNIQUE (customname2),", + "customname1", + ["customname2"], + ), ) for sql, constraint_name, columns in tests: with self.subTest(sql=sql): @@ -97,8 +117,12 @@ class ParsingTests(TestCase): def test_unique_constraint_multicolumn(self): tests = ( - ('CONSTRAINT "ref" UNIQUE ("ref", "customname"),', 'ref', ['ref', 'customname']), - ('CONSTRAINT ref UNIQUE (ref, customname),', 'ref', ['ref', 'customname']), + ( + 'CONSTRAINT "ref" UNIQUE ("ref", "customname"),', + "ref", + ["ref", "customname"], + ), + ("CONSTRAINT ref UNIQUE (ref, customname),", "ref", ["ref", "customname"]), ) for sql, constraint_name, columns in tests: with self.subTest(sql=sql): @@ -109,10 +133,16 @@ class ParsingTests(TestCase): def test_check_column(self): tests = ( - ('"ref" varchar(255) CHECK ("ref" != \'test\'),', ['ref']), - ('ref varchar(255) CHECK (ref != \'test\'),', ['ref']), - ('"customname1" varchar(255) CHECK ("customname2" != \'test\'),', ['customname2']), - ('customname1 varchar(255) CHECK (customname2 != \'test\'),', ['customname2']), + ('"ref" varchar(255) CHECK ("ref" != \'test\'),', ["ref"]), + ("ref varchar(255) CHECK (ref != 'test'),", ["ref"]), + ( + '"customname1" varchar(255) CHECK ("customname2" != \'test\'),', + ["customname2"], + ), + ( + "customname1 varchar(255) CHECK (customname2 != 'test'),", + ["customname2"], + ), ) for sql, columns in tests: with self.subTest(sql=sql): @@ -123,10 +153,18 @@ class ParsingTests(TestCase): def test_check_constraint(self): tests = ( - ('CONSTRAINT "ref" CHECK ("ref" != \'test\'),', 'ref', ['ref']), - ('CONSTRAINT ref CHECK (ref != \'test\'),', 'ref', ['ref']), - ('CONSTRAINT "customname1" CHECK ("customname2" != \'test\'),', 'customname1', ['customname2']), - ('CONSTRAINT customname1 CHECK (customname2 != \'test\'),', 'customname1', ['customname2']), + ('CONSTRAINT "ref" CHECK ("ref" != \'test\'),', "ref", ["ref"]), + ("CONSTRAINT ref CHECK (ref != 'test'),", "ref", ["ref"]), + ( + 'CONSTRAINT "customname1" CHECK ("customname2" != \'test\'),', + "customname1", + ["customname2"], + ), + ( + "CONSTRAINT customname1 CHECK (customname2 != 'test'),", + "customname1", + ["customname2"], + ), ) for sql, constraint_name, columns in tests: with self.subTest(sql=sql): @@ -137,9 +175,12 @@ class ParsingTests(TestCase): def test_check_column_with_operators_and_functions(self): tests = ( - ('"ref" integer CHECK ("ref" BETWEEN 1 AND 10),', ['ref']), - ('"ref" varchar(255) CHECK ("ref" LIKE \'test%\'),', ['ref']), - ('"ref" varchar(255) CHECK (LENGTH(ref) > "max_length"),', ['ref', 'max_length']), + ('"ref" integer CHECK ("ref" BETWEEN 1 AND 10),', ["ref"]), + ('"ref" varchar(255) CHECK ("ref" LIKE \'test%\'),', ["ref"]), + ( + '"ref" varchar(255) CHECK (LENGTH(ref) > "max_length"),', + ["ref", "max_length"], + ), ) for sql, columns in tests: with self.subTest(sql=sql): @@ -150,8 +191,8 @@ class ParsingTests(TestCase): def test_check_and_unique_column(self): tests = ( - ('"ref" varchar(255) CHECK ("ref" != \'test\') UNIQUE,', ['ref']), - ('ref varchar(255) UNIQUE CHECK (ref != \'test\'),', ['ref']), + ('"ref" varchar(255) CHECK ("ref" != \'test\') UNIQUE,', ["ref"]), + ("ref varchar(255) UNIQUE CHECK (ref != 'test'),", ["ref"]), ) for sql, columns in tests: with self.subTest(sql=sql): diff --git a/tests/backends/sqlite/test_operations.py b/tests/backends/sqlite/test_operations.py index 863a978580..3ff055248d 100644 --- a/tests/backends/sqlite/test_operations.py +++ b/tests/backends/sqlite/test_operations.py @@ -7,7 +7,7 @@ from django.test import TestCase from ..models import Person, Tag -@unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests.') +@unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests.") class SQLiteOperationsTests(TestCase): def test_sql_flush(self): self.assertEqual( @@ -34,7 +34,7 @@ class SQLiteOperationsTests(TestCase): 'DELETE FROM "backends_person";', 'DELETE FROM "backends_tag";', 'DELETE FROM "backends_verylongmodelnamezzzzzzzzzzzzzzzzzzzzzz' - 'zzzzzzzzzzzzzzzzzzzz_m2m_also_quite_long_zzzzzzzzzzzzzzzzzzzz' + "zzzzzzzzzzzzzzzzzzzz_m2m_also_quite_long_zzzzzzzzzzzzzzzzzzzz" 'zzzzzzzzzzzzzzzzzzzzzzz";', ], ) @@ -50,7 +50,7 @@ class SQLiteOperationsTests(TestCase): 'DELETE FROM "backends_person";', 'DELETE FROM "backends_tag";', 'UPDATE "sqlite_sequence" SET "seq" = 0 WHERE "name" IN ' - '(\'backends_person\', \'backends_tag\');', + "('backends_person', 'backends_tag');", ], ) @@ -68,13 +68,16 @@ class SQLiteOperationsTests(TestCase): 'DELETE FROM "backends_person";', 'DELETE FROM "backends_tag";', 'DELETE FROM "backends_verylongmodelnamezzzzzzzzzzzzzzzzzzzzzz' - 'zzzzzzzzzzzzzzzzzzzz_m2m_also_quite_long_zzzzzzzzzzzzzzzzzzzz' + "zzzzzzzzzzzzzzzzzzzz_m2m_also_quite_long_zzzzzzzzzzzzzzzzzzzz" 'zzzzzzzzzzzzzzzzzzzzzzz";', ], ) - self.assertIs(statements[-1].startswith( - 'UPDATE "sqlite_sequence" SET "seq" = 0 WHERE "name" IN (' - ), True) + self.assertIs( + statements[-1].startswith( + 'UPDATE "sqlite_sequence" SET "seq" = 0 WHERE "name" IN (' + ), + True, + ) self.assertIn("'backends_person'", statements[-1]) self.assertIn("'backends_tag'", statements[-1]) self.assertIn( diff --git a/tests/backends/sqlite/tests.py b/tests/backends/sqlite/tests.py index 07477a5c71..e167e09dcf 100644 --- a/tests/backends/sqlite/tests.py +++ b/tests/backends/sqlite/tests.py @@ -12,7 +12,10 @@ from django.db import NotSupportedError, connection, transaction from django.db.models import Aggregate, Avg, CharField, StdDev, Sum, Variance from django.db.utils import ConnectionHandler from django.test import ( - TestCase, TransactionTestCase, override_settings, skipIfDBFeature, + TestCase, + TransactionTestCase, + override_settings, + skipIfDBFeature, ) from django.test.utils import isolate_apps @@ -25,35 +28,43 @@ except ImproperlyConfigured: pass -@unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests') +@unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests") class Tests(TestCase): longMessage = True def test_check_sqlite_version(self): - msg = 'SQLite 3.9.0 or later is required (found 3.8.11.1).' - with mock.patch.object(dbapi2, 'sqlite_version_info', (3, 8, 11, 1)), \ - mock.patch.object(dbapi2, 'sqlite_version', '3.8.11.1'), \ - self.assertRaisesMessage(ImproperlyConfigured, msg): + msg = "SQLite 3.9.0 or later is required (found 3.8.11.1)." + with mock.patch.object( + dbapi2, "sqlite_version_info", (3, 8, 11, 1) + ), mock.patch.object( + dbapi2, "sqlite_version", "3.8.11.1" + ), self.assertRaisesMessage( + ImproperlyConfigured, msg + ): check_sqlite_version() def test_aggregation(self): """Raise NotSupportedError when aggregating on date/time fields.""" for aggregate in (Sum, Avg, Variance, StdDev): with self.assertRaises(NotSupportedError): - Item.objects.all().aggregate(aggregate('time')) + Item.objects.all().aggregate(aggregate("time")) with self.assertRaises(NotSupportedError): - Item.objects.all().aggregate(aggregate('date')) + Item.objects.all().aggregate(aggregate("date")) with self.assertRaises(NotSupportedError): - Item.objects.all().aggregate(aggregate('last_modified')) + Item.objects.all().aggregate(aggregate("last_modified")) with self.assertRaises(NotSupportedError): Item.objects.all().aggregate( - **{'complex': aggregate('last_modified') + aggregate('last_modified')} + **{ + "complex": aggregate("last_modified") + + aggregate("last_modified") + } ) def test_distinct_aggregation(self): class DistinctAggregate(Aggregate): allow_distinct = True - aggregate = DistinctAggregate('first', 'second', distinct=True) + + aggregate = DistinctAggregate("first", "second", distinct=True) msg = ( "SQLite doesn't support DISTINCT on aggregate functions accepting " "multiple arguments." @@ -67,32 +78,36 @@ class Tests(TestCase): class DistinctAggregate(Aggregate): allow_distinct = True - aggregate = DistinctAggregate('first', 'second', distinct=False) + aggregate = DistinctAggregate("first", "second", distinct=False) connection.ops.check_expression_support(aggregate) def test_memory_db_test_name(self): """A named in-memory db should be allowed where supported.""" from django.db.backends.sqlite3.base import DatabaseWrapper + settings_dict = { - 'TEST': { - 'NAME': 'file:memorydb_test?mode=memory&cache=shared', + "TEST": { + "NAME": "file:memorydb_test?mode=memory&cache=shared", } } creation = DatabaseWrapper(settings_dict).creation - self.assertEqual(creation._get_test_db_name(), creation.connection.settings_dict['TEST']['NAME']) + self.assertEqual( + creation._get_test_db_name(), + creation.connection.settings_dict["TEST"]["NAME"], + ) def test_regexp_function(self): tests = ( - ('test', r'[0-9]+', False), - ('test', r'[a-z]+', True), - ('test', None, None), - (None, r'[a-z]+', None), + ("test", r"[0-9]+", False), + ("test", r"[a-z]+", True), + ("test", None, None), + (None, r"[a-z]+", None), (None, None, None), ) for string, pattern, expected in tests: with self.subTest((string, pattern)): with connection.cursor() as cursor: - cursor.execute('SELECT %s REGEXP %s', [string, pattern]) + cursor.execute("SELECT %s REGEXP %s", [string, pattern]) value = cursor.fetchone()[0] value = bool(value) if value in {0, 1} else value self.assertIs(value, expected) @@ -100,22 +115,22 @@ class Tests(TestCase): def test_pathlib_name(self): with tempfile.TemporaryDirectory() as tmp: settings_dict = { - 'default': { - 'ENGINE': 'django.db.backends.sqlite3', - 'NAME': Path(tmp) / 'test.db', + "default": { + "ENGINE": "django.db.backends.sqlite3", + "NAME": Path(tmp) / "test.db", }, } connections = ConnectionHandler(settings_dict) - connections['default'].ensure_connection() - connections['default'].close() - self.assertTrue(os.path.isfile(os.path.join(tmp, 'test.db'))) + connections["default"].ensure_connection() + connections["default"].close() + self.assertTrue(os.path.isfile(os.path.join(tmp, "test.db"))) -@unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests') -@isolate_apps('backends') +@unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests") +@isolate_apps("backends") class SchemaTests(TransactionTestCase): - available_apps = ['backends'] + available_apps = ["backends"] def test_autoincrement(self): """ @@ -128,9 +143,9 @@ class SchemaTests(TransactionTestCase): match = re.search('"id" ([^,]+),', statements[0]) self.assertIsNotNone(match) self.assertEqual( - 'integer NOT NULL PRIMARY KEY AUTOINCREMENT', + "integer NOT NULL PRIMARY KEY AUTOINCREMENT", match[1], - 'Wrong SQL used to create an auto-increment column on SQLite' + "Wrong SQL used to create an auto-increment column on SQLite", ) def test_disable_constraint_checking_failure_disallowed(self): @@ -139,11 +154,11 @@ class SchemaTests(TransactionTestCase): foreign key constraint checks are not disabled beforehand. """ msg = ( - 'SQLite schema editor cannot be used while foreign key ' - 'constraint checks are enabled. Make sure to disable them ' - 'before entering a transaction.atomic() context because ' - 'SQLite does not support disabling them in the middle of ' - 'a multi-statement transaction.' + "SQLite schema editor cannot be used while foreign key " + "constraint checks are enabled. Make sure to disable them " + "before entering a transaction.atomic() context because " + "SQLite does not support disabling them in the middle of " + "a multi-statement transaction." ) with self.assertRaisesMessage(NotSupportedError, msg): with transaction.atomic(), connection.schema_editor(atomic=True): @@ -154,23 +169,25 @@ class SchemaTests(TransactionTestCase): SQLite schema editor is usable within an outer transaction as long as foreign key constraints checks are disabled beforehand. """ + def constraint_checks_enabled(): with connection.cursor() as cursor: - return bool(cursor.execute('PRAGMA foreign_keys').fetchone()[0]) + return bool(cursor.execute("PRAGMA foreign_keys").fetchone()[0]) + with connection.constraint_checks_disabled(), transaction.atomic(): with connection.schema_editor(atomic=True): self.assertFalse(constraint_checks_enabled()) self.assertFalse(constraint_checks_enabled()) self.assertTrue(constraint_checks_enabled()) - @skipIfDBFeature('supports_atomic_references_rename') + @skipIfDBFeature("supports_atomic_references_rename") def test_field_rename_inside_atomic_block(self): """ NotImplementedError is raised when a model field rename is attempted inside an atomic block. """ new_field = CharField(max_length=255, unique=True) - new_field.set_attributes_from_name('renamed') + new_field.set_attributes_from_name("renamed") msg = ( "Renaming the 'backends_author'.'name' column while in a " "transaction is not supported on SQLite < 3.26 because it would " @@ -179,9 +196,9 @@ class SchemaTests(TransactionTestCase): ) with self.assertRaisesMessage(NotSupportedError, msg): with connection.schema_editor(atomic=True) as editor: - editor.alter_field(Author, Author._meta.get_field('name'), new_field) + editor.alter_field(Author, Author._meta.get_field("name"), new_field) - @skipIfDBFeature('supports_atomic_references_rename') + @skipIfDBFeature("supports_atomic_references_rename") def test_table_rename_inside_atomic_block(self): """ NotImplementedError is raised when a table rename is attempted inside @@ -197,16 +214,15 @@ class SchemaTests(TransactionTestCase): editor.alter_db_table(Author, "backends_author", "renamed_table") -@unittest.skipUnless(connection.vendor == 'sqlite', 'Test only for SQLite') +@unittest.skipUnless(connection.vendor == "sqlite", "Test only for SQLite") @override_settings(DEBUG=True) class LastExecutedQueryTest(TestCase): - def test_no_interpolation(self): # This shouldn't raise an exception (#17158) query = "SELECT strftime('%Y', 'now');" with connection.cursor() as cursor: cursor.execute(query) - self.assertEqual(connection.queries[-1]['sql'], query) + self.assertEqual(connection.queries[-1]["sql"], query) def test_parameter_quoting(self): # The implementation of last_executed_queries isn't optimal. It's @@ -217,7 +233,7 @@ class LastExecutedQueryTest(TestCase): cursor.execute(query, params) # Note that the single quote is repeated substituted = "SELECT '\"''\\'" - self.assertEqual(connection.queries[-1]['sql'], substituted) + self.assertEqual(connection.queries[-1]["sql"], substituted) def test_large_number_of_parameters(self): # If SQLITE_MAX_VARIABLE_NUMBER (default = 999) has been changed to be @@ -230,12 +246,13 @@ class LastExecutedQueryTest(TestCase): cursor.db.ops.last_executed_query(cursor.cursor, sql, params) -@unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests') +@unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests") class EscapingChecks(TestCase): """ All tests in this test case are also run with settings.DEBUG=True in EscapingChecksDebug test case, to also test CursorDebugWrapper. """ + def test_parameter_escaping(self): # '%s' escaping support for sqlite3 (#13648). with connection.cursor() as cursor: @@ -245,19 +262,20 @@ class EscapingChecks(TestCase): self.assertTrue(int(response)) -@unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests') +@unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests") @override_settings(DEBUG=True) class EscapingChecksDebug(EscapingChecks): pass -@unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests') +@unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests") class ThreadSharing(TransactionTestCase): - available_apps = ['backends'] + available_apps = ["backends"] def test_database_sharing_in_threads(self): def create_object(): Object.objects.create() + create_object() thread = threading.Thread(target=create_object) thread.start() diff --git a/tests/backends/test_ddl_references.py b/tests/backends/test_ddl_references.py index ab3ba6ccd2..86984ed3e8 100644 --- a/tests/backends/test_ddl_references.py +++ b/tests/backends/test_ddl_references.py @@ -1,6 +1,11 @@ from django.db import connection from django.db.backends.ddl_references import ( - Columns, Expressions, ForeignKeyName, IndexName, Statement, Table, + Columns, + Expressions, + ForeignKeyName, + IndexName, + Statement, + Table, ) from django.db.models import ExpressionList, F from django.db.models.functions import Upper @@ -13,119 +18,150 @@ from .models import Person class TableTests(SimpleTestCase): def setUp(self): - self.reference = Table('table', lambda table: table.upper()) + self.reference = Table("table", lambda table: table.upper()) def test_references_table(self): - self.assertIs(self.reference.references_table('table'), True) - self.assertIs(self.reference.references_table('other'), False) + self.assertIs(self.reference.references_table("table"), True) + self.assertIs(self.reference.references_table("other"), False) def test_rename_table_references(self): - self.reference.rename_table_references('other', 'table') - self.assertIs(self.reference.references_table('table'), True) - self.assertIs(self.reference.references_table('other'), False) - self.reference.rename_table_references('table', 'other') - self.assertIs(self.reference.references_table('table'), False) - self.assertIs(self.reference.references_table('other'), True) + self.reference.rename_table_references("other", "table") + self.assertIs(self.reference.references_table("table"), True) + self.assertIs(self.reference.references_table("other"), False) + self.reference.rename_table_references("table", "other") + self.assertIs(self.reference.references_table("table"), False) + self.assertIs(self.reference.references_table("other"), True) def test_repr(self): self.assertEqual(repr(self.reference), "<Table 'TABLE'>") def test_str(self): - self.assertEqual(str(self.reference), 'TABLE') + self.assertEqual(str(self.reference), "TABLE") class ColumnsTests(TableTests): def setUp(self): self.reference = Columns( - 'table', ['first_column', 'second_column'], lambda column: column.upper() + "table", ["first_column", "second_column"], lambda column: column.upper() ) def test_references_column(self): - self.assertIs(self.reference.references_column('other', 'first_column'), False) - self.assertIs(self.reference.references_column('table', 'third_column'), False) - self.assertIs(self.reference.references_column('table', 'first_column'), True) + self.assertIs(self.reference.references_column("other", "first_column"), False) + self.assertIs(self.reference.references_column("table", "third_column"), False) + self.assertIs(self.reference.references_column("table", "first_column"), True) def test_rename_column_references(self): - self.reference.rename_column_references('other', 'first_column', 'third_column') - self.assertIs(self.reference.references_column('table', 'first_column'), True) - self.assertIs(self.reference.references_column('table', 'third_column'), False) - self.assertIs(self.reference.references_column('other', 'third_column'), False) - self.reference.rename_column_references('table', 'third_column', 'first_column') - self.assertIs(self.reference.references_column('table', 'first_column'), True) - self.assertIs(self.reference.references_column('table', 'third_column'), False) - self.reference.rename_column_references('table', 'first_column', 'third_column') - self.assertIs(self.reference.references_column('table', 'first_column'), False) - self.assertIs(self.reference.references_column('table', 'third_column'), True) + self.reference.rename_column_references("other", "first_column", "third_column") + self.assertIs(self.reference.references_column("table", "first_column"), True) + self.assertIs(self.reference.references_column("table", "third_column"), False) + self.assertIs(self.reference.references_column("other", "third_column"), False) + self.reference.rename_column_references("table", "third_column", "first_column") + self.assertIs(self.reference.references_column("table", "first_column"), True) + self.assertIs(self.reference.references_column("table", "third_column"), False) + self.reference.rename_column_references("table", "first_column", "third_column") + self.assertIs(self.reference.references_column("table", "first_column"), False) + self.assertIs(self.reference.references_column("table", "third_column"), True) def test_repr(self): - self.assertEqual(repr(self.reference), "<Columns 'FIRST_COLUMN, SECOND_COLUMN'>") + self.assertEqual( + repr(self.reference), "<Columns 'FIRST_COLUMN, SECOND_COLUMN'>" + ) def test_str(self): - self.assertEqual(str(self.reference), 'FIRST_COLUMN, SECOND_COLUMN') + self.assertEqual(str(self.reference), "FIRST_COLUMN, SECOND_COLUMN") class IndexNameTests(ColumnsTests): def setUp(self): def create_index_name(table_name, column_names, suffix): - return ', '.join("%s_%s_%s" % (table_name, column_name, suffix) for column_name in column_names) + return ", ".join( + "%s_%s_%s" % (table_name, column_name, suffix) + for column_name in column_names + ) + self.reference = IndexName( - 'table', ['first_column', 'second_column'], 'suffix', create_index_name + "table", ["first_column", "second_column"], "suffix", create_index_name ) def test_repr(self): - self.assertEqual(repr(self.reference), "<IndexName 'table_first_column_suffix, table_second_column_suffix'>") + self.assertEqual( + repr(self.reference), + "<IndexName 'table_first_column_suffix, table_second_column_suffix'>", + ) def test_str(self): - self.assertEqual(str(self.reference), 'table_first_column_suffix, table_second_column_suffix') + self.assertEqual( + str(self.reference), "table_first_column_suffix, table_second_column_suffix" + ) class ForeignKeyNameTests(IndexNameTests): def setUp(self): def create_foreign_key_name(table_name, column_names, suffix): - return ', '.join("%s_%s_%s" % (table_name, column_name, suffix) for column_name in column_names) + return ", ".join( + "%s_%s_%s" % (table_name, column_name, suffix) + for column_name in column_names + ) + self.reference = ForeignKeyName( - 'table', ['first_column', 'second_column'], - 'to_table', ['to_first_column', 'to_second_column'], - '%(to_table)s_%(to_column)s_fk', + "table", + ["first_column", "second_column"], + "to_table", + ["to_first_column", "to_second_column"], + "%(to_table)s_%(to_column)s_fk", create_foreign_key_name, ) def test_references_table(self): super().test_references_table() - self.assertIs(self.reference.references_table('to_table'), True) + self.assertIs(self.reference.references_table("to_table"), True) def test_references_column(self): super().test_references_column() - self.assertIs(self.reference.references_column('to_table', 'second_column'), False) - self.assertIs(self.reference.references_column('to_table', 'to_second_column'), True) + self.assertIs( + self.reference.references_column("to_table", "second_column"), False + ) + self.assertIs( + self.reference.references_column("to_table", "to_second_column"), True + ) def test_rename_table_references(self): super().test_rename_table_references() - self.reference.rename_table_references('to_table', 'other_to_table') - self.assertIs(self.reference.references_table('other_to_table'), True) - self.assertIs(self.reference.references_table('to_table'), False) + self.reference.rename_table_references("to_table", "other_to_table") + self.assertIs(self.reference.references_table("other_to_table"), True) + self.assertIs(self.reference.references_table("to_table"), False) def test_rename_column_references(self): super().test_rename_column_references() - self.reference.rename_column_references('to_table', 'second_column', 'third_column') - self.assertIs(self.reference.references_column('table', 'second_column'), True) - self.assertIs(self.reference.references_column('to_table', 'to_second_column'), True) - self.reference.rename_column_references('to_table', 'to_first_column', 'to_third_column') - self.assertIs(self.reference.references_column('to_table', 'to_first_column'), False) - self.assertIs(self.reference.references_column('to_table', 'to_third_column'), True) + self.reference.rename_column_references( + "to_table", "second_column", "third_column" + ) + self.assertIs(self.reference.references_column("table", "second_column"), True) + self.assertIs( + self.reference.references_column("to_table", "to_second_column"), True + ) + self.reference.rename_column_references( + "to_table", "to_first_column", "to_third_column" + ) + self.assertIs( + self.reference.references_column("to_table", "to_first_column"), False + ) + self.assertIs( + self.reference.references_column("to_table", "to_third_column"), True + ) def test_repr(self): self.assertEqual( repr(self.reference), "<ForeignKeyName 'table_first_column_to_table_to_first_column_fk, " - "table_second_column_to_table_to_first_column_fk'>" + "table_second_column_to_table_to_first_column_fk'>", ) def test_str(self): self.assertEqual( str(self.reference), - 'table_first_column_to_table_to_first_column_fk, ' - 'table_second_column_to_table_to_first_column_fk' + "table_first_column_to_table_to_first_column_fk, " + "table_second_column_to_table_to_first_column_fk", ) @@ -158,36 +194,48 @@ class MockReference: class StatementTests(SimpleTestCase): def test_references_table(self): - statement = Statement('', reference=MockReference('', {'table'}, {}), non_reference='') - self.assertIs(statement.references_table('table'), True) - self.assertIs(statement.references_table('other'), False) + statement = Statement( + "", reference=MockReference("", {"table"}, {}), non_reference="" + ) + self.assertIs(statement.references_table("table"), True) + self.assertIs(statement.references_table("other"), False) def test_references_column(self): - statement = Statement('', reference=MockReference('', {}, {('table', 'column')}), non_reference='') - self.assertIs(statement.references_column('table', 'column'), True) - self.assertIs(statement.references_column('other', 'column'), False) + statement = Statement( + "", reference=MockReference("", {}, {("table", "column")}), non_reference="" + ) + self.assertIs(statement.references_column("table", "column"), True) + self.assertIs(statement.references_column("other", "column"), False) def test_rename_table_references(self): - reference = MockReference('', {'table'}, {}) - statement = Statement('', reference=reference, non_reference='') - statement.rename_table_references('table', 'other') - self.assertEqual(reference.referenced_tables, {'other'}) + reference = MockReference("", {"table"}, {}) + statement = Statement("", reference=reference, non_reference="") + statement.rename_table_references("table", "other") + self.assertEqual(reference.referenced_tables, {"other"}) def test_rename_column_references(self): - reference = MockReference('', {}, {('table', 'column')}) - statement = Statement('', reference=reference, non_reference='') - statement.rename_column_references('table', 'column', 'other') - self.assertEqual(reference.referenced_columns, {('table', 'other')}) + reference = MockReference("", {}, {("table", "column")}) + statement = Statement("", reference=reference, non_reference="") + statement.rename_column_references("table", "column", "other") + self.assertEqual(reference.referenced_columns, {("table", "other")}) def test_repr(self): - reference = MockReference('reference', {}, {}) - statement = Statement("%(reference)s - %(non_reference)s", reference=reference, non_reference='non_reference') + reference = MockReference("reference", {}, {}) + statement = Statement( + "%(reference)s - %(non_reference)s", + reference=reference, + non_reference="non_reference", + ) self.assertEqual(repr(statement), "<Statement 'reference - non_reference'>") def test_str(self): - reference = MockReference('reference', {}, {}) - statement = Statement("%(reference)s - %(non_reference)s", reference=reference, non_reference='non_reference') - self.assertEqual(str(statement), 'reference - non_reference') + reference = MockReference("reference", {}, {}) + statement = Statement( + "%(reference)s - %(non_reference)s", + reference=reference, + non_reference="non_reference", + ) + self.assertEqual(str(statement), "reference - non_reference") class ExpressionsTests(TransactionTestCase): @@ -199,9 +247,9 @@ class ExpressionsTests(TransactionTestCase): self.expressions = Expressions( table=Person._meta.db_table, expressions=ExpressionList( - IndexExpression(F('first_name')), - IndexExpression(F('last_name').desc()), - IndexExpression(Upper('last_name')), + IndexExpression(F("first_name")), + IndexExpression(F("last_name").desc()), + IndexExpression(Upper("last_name")), ).resolve_expression(compiler.query), compiler=compiler, quote_value=self.editor.quote_value, @@ -209,23 +257,24 @@ class ExpressionsTests(TransactionTestCase): def test_references_table(self): self.assertIs(self.expressions.references_table(Person._meta.db_table), True) - self.assertIs(self.expressions.references_table('other'), False) + self.assertIs(self.expressions.references_table("other"), False) def test_references_column(self): table = Person._meta.db_table - self.assertIs(self.expressions.references_column(table, 'first_name'), True) - self.assertIs(self.expressions.references_column(table, 'last_name'), True) - self.assertIs(self.expressions.references_column(table, 'other'), False) + self.assertIs(self.expressions.references_column(table, "first_name"), True) + self.assertIs(self.expressions.references_column(table, "last_name"), True) + self.assertIs(self.expressions.references_column(table, "other"), False) def test_rename_table_references(self): table = Person._meta.db_table - self.expressions.rename_table_references(table, 'other') + self.expressions.rename_table_references(table, "other") self.assertIs(self.expressions.references_table(table), False) - self.assertIs(self.expressions.references_table('other'), True) + self.assertIs(self.expressions.references_table("other"), True) self.assertIn( - '%s.%s' % ( - self.editor.quote_name('other'), - self.editor.quote_name('first_name'), + "%s.%s" + % ( + self.editor.quote_name("other"), + self.editor.quote_name("first_name"), ), str(self.expressions), ) @@ -236,39 +285,39 @@ class ExpressionsTests(TransactionTestCase): expressions = Expressions( table=table, expressions=ExpressionList( - IndexExpression(Upper('last_name')), - IndexExpression(F('first_name')), + IndexExpression(Upper("last_name")), + IndexExpression(F("first_name")), ).resolve_expression(compiler.query), compiler=compiler, quote_value=self.editor.quote_value, ) - expressions.rename_table_references(table, 'other') + expressions.rename_table_references(table, "other") self.assertIs(expressions.references_table(table), False) - self.assertIs(expressions.references_table('other'), True) - expected_str = '(UPPER(%s)), %s' % ( - self.editor.quote_name('last_name'), - self.editor.quote_name('first_name'), + self.assertIs(expressions.references_table("other"), True) + expected_str = "(UPPER(%s)), %s" % ( + self.editor.quote_name("last_name"), + self.editor.quote_name("first_name"), ) self.assertEqual(str(expressions), expected_str) def test_rename_column_references(self): table = Person._meta.db_table - self.expressions.rename_column_references(table, 'first_name', 'other') - self.assertIs(self.expressions.references_column(table, 'other'), True) - self.assertIs(self.expressions.references_column(table, 'first_name'), False) + self.expressions.rename_column_references(table, "first_name", "other") + self.assertIs(self.expressions.references_column(table, "other"), True) + self.assertIs(self.expressions.references_column(table, "first_name"), False) self.assertIn( - '%s.%s' % (self.editor.quote_name(table), self.editor.quote_name('other')), + "%s.%s" % (self.editor.quote_name(table), self.editor.quote_name("other")), str(self.expressions), ) def test_str(self): table_name = self.editor.quote_name(Person._meta.db_table) - expected_str = '%s.%s, %s.%s DESC, (UPPER(%s.%s))' % ( + expected_str = "%s.%s, %s.%s DESC, (UPPER(%s.%s))" % ( table_name, - self.editor.quote_name('first_name'), + self.editor.quote_name("first_name"), table_name, - self.editor.quote_name('last_name'), + self.editor.quote_name("last_name"), table_name, - self.editor.quote_name('last_name'), + self.editor.quote_name("last_name"), ) self.assertEqual(str(self.expressions), expected_str) diff --git a/tests/backends/test_utils.py b/tests/backends/test_utils.py index 54819829fd..1b830eaced 100644 --- a/tests/backends/test_utils.py +++ b/tests/backends/test_utils.py @@ -3,72 +3,87 @@ from decimal import Decimal, Rounded from django.db import NotSupportedError, connection from django.db.backends.utils import ( - format_number, split_identifier, split_tzname_delta, truncate_name, + format_number, + split_identifier, + split_tzname_delta, + truncate_name, ) from django.test import ( - SimpleTestCase, TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature, + SimpleTestCase, + TransactionTestCase, + skipIfDBFeature, + skipUnlessDBFeature, ) class TestUtils(SimpleTestCase): - def test_truncate_name(self): - self.assertEqual(truncate_name('some_table', 10), 'some_table') - self.assertEqual(truncate_name('some_long_table', 10), 'some_la38a') - self.assertEqual(truncate_name('some_long_table', 10, 3), 'some_loa38') - self.assertEqual(truncate_name('some_long_table'), 'some_long_table') + self.assertEqual(truncate_name("some_table", 10), "some_table") + self.assertEqual(truncate_name("some_long_table", 10), "some_la38a") + self.assertEqual(truncate_name("some_long_table", 10, 3), "some_loa38") + self.assertEqual(truncate_name("some_long_table"), "some_long_table") # "user"."table" syntax - self.assertEqual(truncate_name('username"."some_table', 10), 'username"."some_table') - self.assertEqual(truncate_name('username"."some_long_table', 10), 'username"."some_la38a') - self.assertEqual(truncate_name('username"."some_long_table', 10, 3), 'username"."some_loa38') + self.assertEqual( + truncate_name('username"."some_table', 10), 'username"."some_table' + ) + self.assertEqual( + truncate_name('username"."some_long_table', 10), 'username"."some_la38a' + ) + self.assertEqual( + truncate_name('username"."some_long_table', 10, 3), 'username"."some_loa38' + ) def test_split_identifier(self): - self.assertEqual(split_identifier('some_table'), ('', 'some_table')) - self.assertEqual(split_identifier('"some_table"'), ('', 'some_table')) - self.assertEqual(split_identifier('namespace"."some_table'), ('namespace', 'some_table')) - self.assertEqual(split_identifier('"namespace"."some_table"'), ('namespace', 'some_table')) + self.assertEqual(split_identifier("some_table"), ("", "some_table")) + self.assertEqual(split_identifier('"some_table"'), ("", "some_table")) + self.assertEqual( + split_identifier('namespace"."some_table'), ("namespace", "some_table") + ) + self.assertEqual( + split_identifier('"namespace"."some_table"'), ("namespace", "some_table") + ) def test_format_number(self): def equal(value, max_d, places, result): self.assertEqual(format_number(Decimal(value), max_d, places), result) - equal('0', 12, 3, '0.000') - equal('0', 12, 8, '0.00000000') - equal('1', 12, 9, '1.000000000') - equal('0.00000000', 12, 8, '0.00000000') - equal('0.000000004', 12, 8, '0.00000000') - equal('0.000000008', 12, 8, '0.00000001') - equal('0.000000000000000000999', 10, 8, '0.00000000') - equal('0.1234567890', 12, 10, '0.1234567890') - equal('0.1234567890', 12, 9, '0.123456789') - equal('0.1234567890', 12, 8, '0.12345679') - equal('0.1234567890', 12, 5, '0.12346') - equal('0.1234567890', 12, 3, '0.123') - equal('0.1234567890', 12, 1, '0.1') - equal('0.1234567890', 12, 0, '0') - equal('0.1234567890', None, 0, '0') - equal('1234567890.1234567890', None, 0, '1234567890') - equal('1234567890.1234567890', None, 2, '1234567890.12') - equal('0.1234', 5, None, '0.1234') - equal('123.12', 5, None, '123.12') + equal("0", 12, 3, "0.000") + equal("0", 12, 8, "0.00000000") + equal("1", 12, 9, "1.000000000") + equal("0.00000000", 12, 8, "0.00000000") + equal("0.000000004", 12, 8, "0.00000000") + equal("0.000000008", 12, 8, "0.00000001") + equal("0.000000000000000000999", 10, 8, "0.00000000") + equal("0.1234567890", 12, 10, "0.1234567890") + equal("0.1234567890", 12, 9, "0.123456789") + equal("0.1234567890", 12, 8, "0.12345679") + equal("0.1234567890", 12, 5, "0.12346") + equal("0.1234567890", 12, 3, "0.123") + equal("0.1234567890", 12, 1, "0.1") + equal("0.1234567890", 12, 0, "0") + equal("0.1234567890", None, 0, "0") + equal("1234567890.1234567890", None, 0, "1234567890") + equal("1234567890.1234567890", None, 2, "1234567890.12") + equal("0.1234", 5, None, "0.1234") + equal("123.12", 5, None, "123.12") with self.assertRaises(Rounded): - equal('0.1234567890', 5, None, '0.12346') + equal("0.1234567890", 5, None, "0.12346") with self.assertRaises(Rounded): - equal('1234567890.1234', 5, None, '1234600000') + equal("1234567890.1234", 5, None, "1234600000") def test_split_tzname_delta(self): tests = [ - ('Asia/Ust+Nera', ('Asia/Ust+Nera', None, None)), - ('Asia/Ust-Nera', ('Asia/Ust-Nera', None, None)), - ('Asia/Ust+Nera-02:00', ('Asia/Ust+Nera', '-', '02:00')), - ('Asia/Ust-Nera+05:00', ('Asia/Ust-Nera', '+', '05:00')), - ('America/Coral_Harbour-01:00', ('America/Coral_Harbour', '-', '01:00')), - ('America/Coral_Harbour+02:30', ('America/Coral_Harbour', '+', '02:30')), - ('UTC+15:00', ('UTC', '+', '15:00')), - ('UTC-04:43', ('UTC', '-', '04:43')), - ('UTC', ('UTC', None, None)), - ('UTC+1', ('UTC+1', None, None)), + ("Asia/Ust+Nera", ("Asia/Ust+Nera", None, None)), + ("Asia/Ust-Nera", ("Asia/Ust-Nera", None, None)), + ("Asia/Ust+Nera-02:00", ("Asia/Ust+Nera", "-", "02:00")), + ("Asia/Ust-Nera+05:00", ("Asia/Ust-Nera", "+", "05:00")), + ("America/Coral_Harbour-01:00", ("America/Coral_Harbour", "-", "01:00")), + ("America/Coral_Harbour+02:30", ("America/Coral_Harbour", "+", "02:30")), + ("UTC+15:00", ("UTC", "+", "15:00")), + ("UTC-04:43", ("UTC", "-", "04:43")), + ("UTC", ("UTC", None, None)), + ("UTC+1", ("UTC+1", None, None)), ] for tzname, expected in tests: with self.subTest(tzname=tzname): @@ -84,25 +99,38 @@ class CursorWrapperTests(TransactionTestCase): # Use a new cursor because in MySQL a procedure can't be used in the # same cursor in which it was created. with connection.cursor() as cursor: - cursor.callproc('test_procedure', params, kparams) + cursor.callproc("test_procedure", params, kparams) with connection.schema_editor() as editor: - editor.remove_procedure('test_procedure', param_types) + editor.remove_procedure("test_procedure", param_types) - @skipUnlessDBFeature('create_test_procedure_without_params_sql') + @skipUnlessDBFeature("create_test_procedure_without_params_sql") def test_callproc_without_params(self): - self._test_procedure(connection.features.create_test_procedure_without_params_sql, [], []) + self._test_procedure( + connection.features.create_test_procedure_without_params_sql, [], [] + ) - @skipUnlessDBFeature('create_test_procedure_with_int_param_sql') + @skipUnlessDBFeature("create_test_procedure_with_int_param_sql") def test_callproc_with_int_params(self): - self._test_procedure(connection.features.create_test_procedure_with_int_param_sql, [1], ['INTEGER']) - - @skipUnlessDBFeature('create_test_procedure_with_int_param_sql', 'supports_callproc_kwargs') + self._test_procedure( + connection.features.create_test_procedure_with_int_param_sql, + [1], + ["INTEGER"], + ) + + @skipUnlessDBFeature( + "create_test_procedure_with_int_param_sql", "supports_callproc_kwargs" + ) def test_callproc_kparams(self): - self._test_procedure(connection.features.create_test_procedure_with_int_param_sql, [], ['INTEGER'], {'P_I': 1}) - - @skipIfDBFeature('supports_callproc_kwargs') + self._test_procedure( + connection.features.create_test_procedure_with_int_param_sql, + [], + ["INTEGER"], + {"P_I": 1}, + ) + + @skipIfDBFeature("supports_callproc_kwargs") def test_unsupported_callproc_kparams_raises_error(self): - msg = 'Keyword parameters for callproc are not supported on this database backend.' + msg = "Keyword parameters for callproc are not supported on this database backend." with self.assertRaisesMessage(NotSupportedError, msg): with connection.cursor() as cursor: - cursor.callproc('test_procedure', [], {'P_I': 1}) + cursor.callproc("test_procedure", [], {"P_I": 1}) diff --git a/tests/backends/tests.py b/tests/backends/tests.py index 2291a76c75..c6c79b4bf7 100644 --- a/tests/backends/tests.py +++ b/tests/backends/tests.py @@ -7,27 +7,43 @@ from unittest import mock from django.core.management.color import no_style from django.db import ( - DEFAULT_DB_ALIAS, DatabaseError, IntegrityError, connection, connections, - reset_queries, transaction, + DEFAULT_DB_ALIAS, + DatabaseError, + IntegrityError, + connection, + connections, + reset_queries, + transaction, ) from django.db.backends.base.base import BaseDatabaseWrapper from django.db.backends.signals import connection_created from django.db.backends.utils import CursorWrapper from django.db.models.sql.constants import CURSOR from django.test import ( - TestCase, TransactionTestCase, override_settings, skipIfDBFeature, + TestCase, + TransactionTestCase, + override_settings, + skipIfDBFeature, skipUnlessDBFeature, ) from .models import ( - Article, Object, ObjectReference, Person, Post, RawData, Reporter, - ReporterProxy, SchoolClass, SQLKeywordsModel, Square, + Article, + Object, + ObjectReference, + Person, + Post, + RawData, + Reporter, + ReporterProxy, + SchoolClass, + SQLKeywordsModel, + Square, VeryLongModelNameZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ, ) class DateQuotingTest(TestCase): - def test_django_date_trunc(self): """ Test the custom ``django_date_trunc method``, in particular against @@ -35,7 +51,7 @@ class DateQuotingTest(TestCase): """ updated = datetime.datetime(2010, 2, 20) SchoolClass.objects.create(year=2009, last_updated=updated) - years = SchoolClass.objects.dates('last_updated', 'year') + years = SchoolClass.objects.dates("last_updated", "year") self.assertEqual(list(years), [datetime.date(2010, 1, 1)]) def test_django_date_extract(self): @@ -51,26 +67,27 @@ class DateQuotingTest(TestCase): @override_settings(DEBUG=True) class LastExecutedQueryTest(TestCase): - def test_last_executed_query_without_previous_query(self): """ last_executed_query should not raise an exception even if no previous query has been run. """ with connection.cursor() as cursor: - connection.ops.last_executed_query(cursor, '', ()) + connection.ops.last_executed_query(cursor, "", ()) def test_debug_sql(self): list(Reporter.objects.filter(first_name="test")) - sql = connection.queries[-1]['sql'].lower() + sql = connection.queries[-1]["sql"].lower() self.assertIn("select", sql) self.assertIn(Reporter._meta.db_table, sql) def test_query_encoding(self): """last_executed_query() returns a string.""" - data = RawData.objects.filter(raw_data=b'\x00\x46 \xFE').extra(select={'föö': 1}) + data = RawData.objects.filter(raw_data=b"\x00\x46 \xFE").extra( + select={"föö": 1} + ) sql, params = data.query.sql_with_params() - with data.query.get_compiler('default').execute_sql(CURSOR) as cursor: + with data.query.get_compiler("default").execute_sql(CURSOR) as cursor: last_sql = cursor.db.ops.last_executed_query(cursor, sql, params) self.assertIsInstance(last_sql, str) @@ -92,16 +109,16 @@ class LastExecutedQueryTest(TestCase): str(qs.query), ) - @skipUnlessDBFeature('supports_paramstyle_pyformat') + @skipUnlessDBFeature("supports_paramstyle_pyformat") def test_last_executed_query_dict(self): square_opts = Square._meta - sql = 'INSERT INTO %s (%s, %s) VALUES (%%(root)s, %%(square)s)' % ( + sql = "INSERT INTO %s (%s, %s) VALUES (%%(root)s, %%(square)s)" % ( connection.introspection.identifier_converter(square_opts.db_table), - connection.ops.quote_name(square_opts.get_field('root').column), - connection.ops.quote_name(square_opts.get_field('square').column), + connection.ops.quote_name(square_opts.get_field("root").column), + connection.ops.quote_name(square_opts.get_field("square").column), ) with connection.cursor() as cursor: - params = {'root': 2, 'square': 4} + params = {"root": 2, "square": 4} cursor.execute(sql, params) self.assertEqual( cursor.db.ops.last_executed_query(cursor, sql, params), @@ -110,15 +127,14 @@ class LastExecutedQueryTest(TestCase): class ParameterHandlingTest(TestCase): - def test_bad_parameter_count(self): "An executemany call with too many/not enough parameters will raise an exception (Refs #12612)" with connection.cursor() as cursor: - query = ('INSERT INTO %s (%s, %s) VALUES (%%s, %%s)' % ( - connection.introspection.identifier_converter('backends_square'), - connection.ops.quote_name('root'), - connection.ops.quote_name('square') - )) + query = "INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % ( + connection.introspection.identifier_converter("backends_square"), + connection.ops.quote_name("root"), + connection.ops.quote_name("square"), + ) with self.assertRaises(Exception): cursor.executemany(query, [(1, 2, 3)]) with self.assertRaises(Exception): @@ -132,7 +148,8 @@ class LongNameTest(TransactionTestCase): the correct sequence name in last_insert_id and other places, so check it is. Refs #8901. """ - available_apps = ['backends'] + + available_apps = ["backends"] def test_sequence_name_length_limits_create(self): """Test creation of model with long name and long pk name doesn't error. Ref #8901""" @@ -143,8 +160,10 @@ class LongNameTest(TransactionTestCase): An m2m save of a model with a long name and a long m2m field name doesn't error (#8901). """ - obj = VeryLongModelNameZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ.objects.create() - rel_obj = Person.objects.create(first_name='Django', last_name='Reinhardt') + obj = ( + VeryLongModelNameZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ.objects.create() + ) + rel_obj = Person.objects.create(first_name="Django", last_name="Reinhardt") obj.m2m_also_quite_long_zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz.add(rel_obj) def test_sequence_name_length_limits_flush(self): @@ -157,7 +176,9 @@ class LongNameTest(TransactionTestCase): # Some convenience aliases VLM = VeryLongModelNameZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ - VLM_m2m = VLM.m2m_also_quite_long_zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz.through + VLM_m2m = ( + VLM.m2m_also_quite_long_zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz.through + ) tables = [ VLM._meta.db_table, VLM_m2m._meta.db_table, @@ -167,21 +188,22 @@ class LongNameTest(TransactionTestCase): class SequenceResetTest(TestCase): - def test_generic_relation(self): "Sequence names are correct when resetting generic relations (Ref #13941)" # Create an object with a manually specified PK - Post.objects.create(id=10, name='1st post', text='hello world') + Post.objects.create(id=10, name="1st post", text="hello world") # Reset the sequences for the database - commands = connections[DEFAULT_DB_ALIAS].ops.sequence_reset_sql(no_style(), [Post]) + commands = connections[DEFAULT_DB_ALIAS].ops.sequence_reset_sql( + no_style(), [Post] + ) with connection.cursor() as cursor: for sql in commands: cursor.execute(sql) # If we create a new object now, it should have a PK greater # than the PK we specified manually. - obj = Post.objects.create(name='New post', text='goodbye world') + obj = Post.objects.create(name="New post", text="goodbye world") self.assertGreater(obj.pk, 10) @@ -193,7 +215,7 @@ class ConnectionCreatedSignalTest(TransactionTestCase): # Unfortunately with sqlite3 the in-memory test database cannot be closed, # and so it cannot be re-opened during testing. - @skipUnlessDBFeature('test_db_allows_multiple_connections') + @skipUnlessDBFeature("test_db_allows_multiple_connections") def test_signal(self): data = {} @@ -224,12 +246,12 @@ class EscapingChecks(TestCase): def test_paramless_no_escaping(self): with connection.cursor() as cursor: cursor.execute("SELECT '%s'" + self.bare_select_suffix) - self.assertEqual(cursor.fetchall()[0][0], '%s') + self.assertEqual(cursor.fetchall()[0][0], "%s") def test_parameter_escaping(self): with connection.cursor() as cursor: - cursor.execute("SELECT '%%', %s" + self.bare_select_suffix, ('%d',)) - self.assertEqual(cursor.fetchall()[0], ('%', '%d')) + cursor.execute("SELECT '%%', %s" + self.bare_select_suffix, ("%d",)) + self.assertEqual(cursor.fetchall()[0], ("%", "%d")) @override_settings(DEBUG=True) @@ -239,20 +261,24 @@ class EscapingChecksDebug(EscapingChecks): class BackendTestCase(TransactionTestCase): - available_apps = ['backends'] + available_apps = ["backends"] def create_squares_with_executemany(self, args): - self.create_squares(args, 'format', True) + self.create_squares(args, "format", True) def create_squares(self, args, paramstyle, multiple): opts = Square._meta tbl = connection.introspection.identifier_converter(opts.db_table) - f1 = connection.ops.quote_name(opts.get_field('root').column) - f2 = connection.ops.quote_name(opts.get_field('square').column) - if paramstyle == 'format': - query = 'INSERT INTO %s (%s, %s) VALUES (%%s, %%s)' % (tbl, f1, f2) - elif paramstyle == 'pyformat': - query = 'INSERT INTO %s (%s, %s) VALUES (%%(root)s, %%(square)s)' % (tbl, f1, f2) + f1 = connection.ops.quote_name(opts.get_field("root").column) + f2 = connection.ops.quote_name(opts.get_field("square").column) + if paramstyle == "format": + query = "INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % (tbl, f1, f2) + elif paramstyle == "pyformat": + query = "INSERT INTO %s (%s, %s) VALUES (%%(root)s, %%(square)s)" % ( + tbl, + f1, + f2, + ) else: raise ValueError("unsupported paramstyle in test") with connection.cursor() as cursor: @@ -263,12 +289,12 @@ class BackendTestCase(TransactionTestCase): def test_cursor_executemany(self): # Test cursor.executemany #4896 - args = [(i, i ** 2) for i in range(-5, 6)] + args = [(i, i**2) for i in range(-5, 6)] self.create_squares_with_executemany(args) self.assertEqual(Square.objects.count(), 11) for i in range(-5, 6): square = Square.objects.get(root=i) - self.assertEqual(square.square, i ** 2) + self.assertEqual(square.square, i**2) def test_cursor_executemany_with_empty_params_list(self): # Test executemany with params=[] does nothing #4765 @@ -278,43 +304,43 @@ class BackendTestCase(TransactionTestCase): def test_cursor_executemany_with_iterator(self): # Test executemany accepts iterators #10320 - args = ((i, i ** 2) for i in range(-3, 2)) + args = ((i, i**2) for i in range(-3, 2)) self.create_squares_with_executemany(args) self.assertEqual(Square.objects.count(), 5) - args = ((i, i ** 2) for i in range(3, 7)) + args = ((i, i**2) for i in range(3, 7)) with override_settings(DEBUG=True): # same test for DebugCursorWrapper self.create_squares_with_executemany(args) self.assertEqual(Square.objects.count(), 9) - @skipUnlessDBFeature('supports_paramstyle_pyformat') + @skipUnlessDBFeature("supports_paramstyle_pyformat") def test_cursor_execute_with_pyformat(self): # Support pyformat style passing of parameters #10070 - args = {'root': 3, 'square': 9} - self.create_squares(args, 'pyformat', multiple=False) + args = {"root": 3, "square": 9} + self.create_squares(args, "pyformat", multiple=False) self.assertEqual(Square.objects.count(), 1) - @skipUnlessDBFeature('supports_paramstyle_pyformat') + @skipUnlessDBFeature("supports_paramstyle_pyformat") def test_cursor_executemany_with_pyformat(self): # Support pyformat style passing of parameters #10070 - args = [{'root': i, 'square': i ** 2} for i in range(-5, 6)] - self.create_squares(args, 'pyformat', multiple=True) + args = [{"root": i, "square": i**2} for i in range(-5, 6)] + self.create_squares(args, "pyformat", multiple=True) self.assertEqual(Square.objects.count(), 11) for i in range(-5, 6): square = Square.objects.get(root=i) - self.assertEqual(square.square, i ** 2) + self.assertEqual(square.square, i**2) - @skipUnlessDBFeature('supports_paramstyle_pyformat') + @skipUnlessDBFeature("supports_paramstyle_pyformat") def test_cursor_executemany_with_pyformat_iterator(self): - args = ({'root': i, 'square': i ** 2} for i in range(-3, 2)) - self.create_squares(args, 'pyformat', multiple=True) + args = ({"root": i, "square": i**2} for i in range(-3, 2)) + self.create_squares(args, "pyformat", multiple=True) self.assertEqual(Square.objects.count(), 5) - args = ({'root': i, 'square': i ** 2} for i in range(3, 7)) + args = ({"root": i, "square": i**2} for i in range(3, 7)) with override_settings(DEBUG=True): # same test for DebugCursorWrapper - self.create_squares(args, 'pyformat', multiple=True) + self.create_squares(args, "pyformat", multiple=True) self.assertEqual(Square.objects.count(), 9) def test_unicode_fetches(self): @@ -326,23 +352,28 @@ class BackendTestCase(TransactionTestCase): Person(first_name="Peter", last_name="Parker").save() Person(first_name="Clark", last_name="Kent").save() opts2 = Person._meta - f3, f4 = opts2.get_field('first_name'), opts2.get_field('last_name') + f3, f4 = opts2.get_field("first_name"), opts2.get_field("last_name") with connection.cursor() as cursor: cursor.execute( - 'SELECT %s, %s FROM %s ORDER BY %s' % ( + "SELECT %s, %s FROM %s ORDER BY %s" + % ( qn(f3.column), qn(f4.column), connection.introspection.identifier_converter(opts2.db_table), qn(f3.column), ) ) - self.assertEqual(cursor.fetchone(), ('Clark', 'Kent')) - self.assertEqual(list(cursor.fetchmany(2)), [('Jane', 'Doe'), ('John', 'Doe')]) - self.assertEqual(list(cursor.fetchall()), [('Mary', 'Agnelline'), ('Peter', 'Parker')]) + self.assertEqual(cursor.fetchone(), ("Clark", "Kent")) + self.assertEqual( + list(cursor.fetchmany(2)), [("Jane", "Doe"), ("John", "Doe")] + ) + self.assertEqual( + list(cursor.fetchall()), [("Mary", "Agnelline"), ("Peter", "Parker")] + ) def test_unicode_password(self): - old_password = connection.settings_dict['PASSWORD'] - connection.settings_dict['PASSWORD'] = "françois" + old_password = connection.settings_dict["PASSWORD"] + connection.settings_dict["PASSWORD"] = "françois" try: with connection.cursor(): pass @@ -350,14 +381,14 @@ class BackendTestCase(TransactionTestCase): # As password is probably wrong, a database exception is expected pass except Exception as e: - self.fail('Unexpected error raised with Unicode password: %s' % e) + self.fail("Unexpected error raised with Unicode password: %s" % e) finally: - connection.settings_dict['PASSWORD'] = old_password + connection.settings_dict["PASSWORD"] = old_password def test_database_operations_helper_class(self): # Ticket #13630 - self.assertTrue(hasattr(connection, 'ops')) - self.assertTrue(hasattr(connection.ops, 'connection')) + self.assertTrue(hasattr(connection, "ops")) + self.assertTrue(hasattr(connection.ops, "connection")) self.assertEqual(connection, connection.ops.connection) def test_database_operations_init(self): @@ -373,8 +404,8 @@ class BackendTestCase(TransactionTestCase): self.assertIn(connection.features.can_introspect_foreign_keys, (True, False)) def test_duplicate_table_error(self): - """ Creating an existing table returns a DatabaseError """ - query = 'CREATE TABLE %s (id INTEGER);' % Article._meta.db_table + """Creating an existing table returns a DatabaseError""" + query = "CREATE TABLE %s (id INTEGER);" % Article._meta.db_table with connection.cursor() as cursor: with self.assertRaises(DatabaseError): cursor.execute(query) @@ -392,8 +423,10 @@ class BackendTestCase(TransactionTestCase): # cursor should be closed, so no queries should be possible. cursor.execute("SELECT 1" + connection.features.bare_select_suffix) - @unittest.skipUnless(connection.vendor == 'postgresql', - "Psycopg2 specific cursor.closed attribute needed") + @unittest.skipUnless( + connection.vendor == "postgresql", + "Psycopg2 specific cursor.closed attribute needed", + ) def test_cursor_contextmanager_closing(self): # There isn't a generic way to test that cursors are closed, but # psycopg2 offers us a way to check that by closed attribute. @@ -403,7 +436,7 @@ class BackendTestCase(TransactionTestCase): self.assertTrue(cursor.closed) # Unfortunately with sqlite3 the in-memory test database cannot be closed. - @skipUnlessDBFeature('test_db_allows_multiple_connections') + @skipUnlessDBFeature("test_db_allows_multiple_connections") def test_is_usable_after_database_disconnects(self): """ is_usable() doesn't crash when the database disconnects (#21553). @@ -429,34 +462,34 @@ class BackendTestCase(TransactionTestCase): """ Test the documented API of connection.queries. """ - sql = 'SELECT 1' + connection.features.bare_select_suffix + sql = "SELECT 1" + connection.features.bare_select_suffix with connection.cursor() as cursor: reset_queries() cursor.execute(sql) self.assertEqual(1, len(connection.queries)) self.assertIsInstance(connection.queries, list) self.assertIsInstance(connection.queries[0], dict) - self.assertEqual(list(connection.queries[0]), ['sql', 'time']) - self.assertEqual(connection.queries[0]['sql'], sql) + self.assertEqual(list(connection.queries[0]), ["sql", "time"]) + self.assertEqual(connection.queries[0]["sql"], sql) reset_queries() self.assertEqual(0, len(connection.queries)) - sql = ('INSERT INTO %s (%s, %s) VALUES (%%s, %%s)' % ( - connection.introspection.identifier_converter('backends_square'), - connection.ops.quote_name('root'), - connection.ops.quote_name('square'), - )) + sql = "INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % ( + connection.introspection.identifier_converter("backends_square"), + connection.ops.quote_name("root"), + connection.ops.quote_name("square"), + ) with connection.cursor() as cursor: cursor.executemany(sql, [(1, 1), (2, 4)]) self.assertEqual(1, len(connection.queries)) self.assertIsInstance(connection.queries, list) self.assertIsInstance(connection.queries[0], dict) - self.assertEqual(list(connection.queries[0]), ['sql', 'time']) - self.assertEqual(connection.queries[0]['sql'], '2 times: %s' % sql) + self.assertEqual(list(connection.queries[0]), ["sql", "time"]) + self.assertEqual(connection.queries[0]["sql"], "2 times: %s" % sql) # Unfortunately with sqlite3 the in-memory test database cannot be closed. - @skipUnlessDBFeature('test_db_allows_multiple_connections') + @skipUnlessDBFeature("test_db_allows_multiple_connections") @override_settings(DEBUG=True) def test_queries_limit(self): """ @@ -492,22 +525,22 @@ class BackendTestCase(TransactionTestCase): BaseDatabaseWrapper.queries_limit = old_queries_limit new_connection.close() - @mock.patch('django.db.backends.utils.logger') + @mock.patch("django.db.backends.utils.logger") @override_settings(DEBUG=True) def test_queries_logger(self, mocked_logger): - sql = 'SELECT 1' + connection.features.bare_select_suffix + sql = "SELECT 1" + connection.features.bare_select_suffix with connection.cursor() as cursor: cursor.execute(sql) params, kwargs = mocked_logger.debug.call_args - self.assertIn('; alias=%s', params[0]) + self.assertIn("; alias=%s", params[0]) self.assertEqual(params[2], sql) self.assertIsNone(params[3]) self.assertEqual(params[4], connection.alias) self.assertEqual( - list(kwargs['extra']), - ['duration', 'sql', 'params', 'alias'], + list(kwargs["extra"]), + ["duration", "sql", "params", "alias"], ) - self.assertEqual(tuple(kwargs['extra'].values()), params[1:]) + self.assertEqual(tuple(kwargs["extra"].values()), params[1:]) def test_timezone_none_use_tz_false(self): connection.ensure_connection() @@ -519,18 +552,22 @@ class BackendTestCase(TransactionTestCase): # between MySQL+InnoDB and MySQL+MYISAM (something we currently can't do). class FkConstraintsTests(TransactionTestCase): - available_apps = ['backends'] + available_apps = ["backends"] def setUp(self): # Create a Reporter. - self.r = Reporter.objects.create(first_name='John', last_name='Smith') + self.r = Reporter.objects.create(first_name="John", last_name="Smith") def test_integrity_checks_on_creation(self): """ Try to create a model instance that violates a FK constraint. If it fails it should fail with IntegrityError. """ - a1 = Article(headline="This is a test", pub_date=datetime.datetime(2005, 7, 27), reporter_id=30) + a1 = Article( + headline="This is a test", + pub_date=datetime.datetime(2005, 7, 27), + reporter_id=30, + ) try: a1.save() except IntegrityError: @@ -540,7 +577,8 @@ class FkConstraintsTests(TransactionTestCase): # Now that we know this backend supports integrity checks we make sure # constraints are also enforced for proxy Refs #17519 a2 = Article( - headline='This is another test', reporter=self.r, + headline="This is another test", + reporter=self.r, pub_date=datetime.datetime(2012, 8, 3), reporter_proxy_id=30, ) @@ -553,7 +591,11 @@ class FkConstraintsTests(TransactionTestCase): If it fails it should fail with IntegrityError. """ # Create an Article. - Article.objects.create(headline="Test article", pub_date=datetime.datetime(2010, 9, 4), reporter=self.r) + Article.objects.create( + headline="Test article", + pub_date=datetime.datetime(2010, 9, 4), + reporter=self.r, + ) # Retrieve it from the DB a1 = Article.objects.get(headline="Test article") a1.reporter_id = 30 @@ -568,12 +610,13 @@ class FkConstraintsTests(TransactionTestCase): # Create another article r_proxy = ReporterProxy.objects.get(pk=self.r.pk) Article.objects.create( - headline='Another article', + headline="Another article", pub_date=datetime.datetime(1988, 5, 15), - reporter=self.r, reporter_proxy=r_proxy, + reporter=self.r, + reporter_proxy=r_proxy, ) # Retrieve the second article from the DB - a2 = Article.objects.get(headline='Another article') + a2 = Article.objects.get(headline="Another article") a2.reporter_proxy_id = 30 with self.assertRaises(IntegrityError): a2.save() @@ -651,13 +694,13 @@ class FkConstraintsTests(TransactionTestCase): with connection.constraint_checks_disabled(): obj.save() with self.assertRaises(IntegrityError): - connection.check_constraints(table_names=['order']) + connection.check_constraints(table_names=["order"]) transaction.set_rollback(True) class ThreadTests(TransactionTestCase): - available_apps = ['backends'] + available_apps = ["backends"] def test_default_connection_thread_local(self): """ @@ -675,6 +718,7 @@ class ThreadTests(TransactionTestCase): # Passing django.db.connection between threads doesn't work while # connections[DEFAULT_DB_ALIAS] does. from django.db import connections + connection = connections[DEFAULT_DB_ALIAS] # Allow thread sharing so the connection can be closed by the # main thread. @@ -682,13 +726,16 @@ class ThreadTests(TransactionTestCase): with connection.cursor(): pass connections_dict[id(connection)] = connection + try: for x in range(2): t = threading.Thread(target=runner) t.start() t.join() # Each created connection got different inner connection. - self.assertEqual(len({conn.connection for conn in connections_dict.values()}), 3) + self.assertEqual( + len({conn.connection for conn in connections_dict.values()}), 3 + ) finally: # Finish by closing the connections opened by the other threads # (the connection opened in the main thread will automatically be @@ -710,11 +757,13 @@ class ThreadTests(TransactionTestCase): def runner(): from django.db import connections + for conn in connections.all(): # Allow thread sharing so the connection can be closed by the # main thread. conn.inc_thread_sharing() connections_dict[id(conn)] = conn + try: num_new_threads = 2 for x in range(num_new_threads): @@ -743,12 +792,14 @@ class ThreadTests(TransactionTestCase): def do_thread(): def runner(main_thread_connection): from django.db import connections - connections['default'] = main_thread_connection + + connections["default"] = main_thread_connection try: Person.objects.get(first_name="John", last_name="Doe") except Exception as e: exceptions.append(e) - t = threading.Thread(target=runner, args=[connections['default']]) + + t = threading.Thread(target=runner, args=[connections["default"]]) t.start() t.join() @@ -757,17 +808,17 @@ class ThreadTests(TransactionTestCase): do_thread() # Forbidden! self.assertIsInstance(exceptions[0], DatabaseError) - connections['default'].close() + connections["default"].close() # After calling inc_thread_sharing() on the connection. - connections['default'].inc_thread_sharing() + connections["default"].inc_thread_sharing() try: exceptions = [] do_thread() # All good self.assertEqual(exceptions, []) finally: - connections['default'].dec_thread_sharing() + connections["default"].dec_thread_sharing() def test_closing_non_shared_connections(self): """ @@ -783,9 +834,11 @@ class ThreadTests(TransactionTestCase): other_thread_connection.close() except DatabaseError as e: exceptions.add(e) - t2 = threading.Thread(target=runner2, args=[connections['default']]) + + t2 = threading.Thread(target=runner2, args=[connections["default"]]) t2.start() t2.join() + t1 = threading.Thread(target=runner1) t1.start() t1.join() @@ -801,14 +854,16 @@ class ThreadTests(TransactionTestCase): other_thread_connection.close() except DatabaseError as e: exceptions.add(e) + # Enable thread sharing - connections['default'].inc_thread_sharing() + connections["default"].inc_thread_sharing() try: - t2 = threading.Thread(target=runner2, args=[connections['default']]) + t2 = threading.Thread(target=runner2, args=[connections["default"]]) t2.start() t2.join() finally: - connections['default'].dec_thread_sharing() + connections["default"].dec_thread_sharing() + t1 = threading.Thread(target=runner1) t1.start() t1.join() @@ -825,7 +880,7 @@ class ThreadTests(TransactionTestCase): self.assertIs(connection.allow_thread_sharing, True) connection.dec_thread_sharing() self.assertIs(connection.allow_thread_sharing, False) - msg = 'Cannot decrement the thread sharing count below zero.' + msg = "Cannot decrement the thread sharing count below zero." with self.assertRaisesMessage(RuntimeError, msg): connection.dec_thread_sharing() @@ -836,14 +891,14 @@ class MySQLPKZeroTests(TestCase): does not allow zero for autoincrement primary key if the NO_AUTO_VALUE_ON_ZERO SQL mode is not enabled. """ - @skipIfDBFeature('allows_auto_pk_0') + + @skipIfDBFeature("allows_auto_pk_0") def test_zero_as_autoval(self): with self.assertRaises(ValueError): Square.objects.create(id=0, root=0, square=1) class DBConstraintTestCase(TestCase): - def test_can_reference_existent(self): obj = Object.objects.create() ref = ObjectReference.objects.create(obj=obj) @@ -867,7 +922,9 @@ class DBConstraintTestCase(TestCase): self.assertEqual(Object.objects.count(), 2) self.assertEqual(obj.related_objects.count(), 1) - intermediary_model = Object._meta.get_field("related_objects").remote_field.through + intermediary_model = Object._meta.get_field( + "related_objects" + ).remote_field.through intermediary_model.objects.create(from_object_id=obj.id, to_object_id=12345) self.assertEqual(obj.related_objects.count(), 1) self.assertEqual(intermediary_model.objects.count(), 2) |