summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--django/contrib/gis/db/backends/mysql/introspection.py6
-rw-r--r--django/contrib/gis/db/backends/oracle/introspection.py6
-rw-r--r--django/contrib/gis/db/backends/postgis/introspection.py12
-rw-r--r--django/contrib/gis/db/backends/spatialite/introspection.py6
-rw-r--r--django/db/backends/base/base.py5
-rw-r--r--django/db/backends/base/introspection.py29
-rw-r--r--django/db/backends/mysql/base.py57
-rw-r--r--django/db/backends/oracle/creation.py147
-rw-r--r--django/db/backends/sqlite3/base.py58
-rw-r--r--django/db/migrations/executor.py3
-rw-r--r--django/test/testcases.py6
-rw-r--r--docs/topics/db/multi-db.txt3
-rw-r--r--docs/topics/db/sql.txt4
-rw-r--r--tests/backends/postgresql/test_introspection.py24
-rw-r--r--tests/backends/postgresql/tests.py14
-rw-r--r--tests/backends/sqlite/tests.py16
-rw-r--r--tests/backends/tests.py87
-rw-r--r--tests/db_utils/tests.py8
18 files changed, 234 insertions, 257 deletions
diff --git a/django/contrib/gis/db/backends/mysql/introspection.py b/django/contrib/gis/db/backends/mysql/introspection.py
index 364cdeecfe..ba125ec4af 100644
--- a/django/contrib/gis/db/backends/mysql/introspection.py
+++ b/django/contrib/gis/db/backends/mysql/introspection.py
@@ -11,8 +11,7 @@ class MySQLIntrospection(DatabaseIntrospection):
data_types_reverse[FIELD_TYPE.GEOMETRY] = 'GeometryField'
def get_geometry_type(self, table_name, geo_col):
- cursor = self.connection.cursor()
- try:
+ with self.connection.cursor() as cursor:
# In order to get the specific geometry type of the field,
# we introspect on the table definition using `DESCRIBE`.
cursor.execute('DESCRIBE %s' %
@@ -27,9 +26,6 @@ class MySQLIntrospection(DatabaseIntrospection):
field_type = OGRGeomType(typ).django
field_params = {}
break
- finally:
- cursor.close()
-
return field_type, field_params
def supports_spatial_index(self, cursor, table_name):
diff --git a/django/contrib/gis/db/backends/oracle/introspection.py b/django/contrib/gis/db/backends/oracle/introspection.py
index 446dc78216..7f4f886a34 100644
--- a/django/contrib/gis/db/backends/oracle/introspection.py
+++ b/django/contrib/gis/db/backends/oracle/introspection.py
@@ -11,8 +11,7 @@ class OracleIntrospection(DatabaseIntrospection):
data_types_reverse[cx_Oracle.OBJECT] = 'GeometryField'
def get_geometry_type(self, table_name, geo_col):
- cursor = self.connection.cursor()
- try:
+ with self.connection.cursor() as cursor:
# Querying USER_SDO_GEOM_METADATA to get the SRID and dimension information.
try:
cursor.execute(
@@ -40,7 +39,4 @@ class OracleIntrospection(DatabaseIntrospection):
dim = dim.size()
if dim != 2:
field_params['dim'] = dim
- finally:
- cursor.close()
-
return field_type, field_params
diff --git a/django/contrib/gis/db/backends/postgis/introspection.py b/django/contrib/gis/db/backends/postgis/introspection.py
index 3a90ebf5c5..97fa7480e6 100644
--- a/django/contrib/gis/db/backends/postgis/introspection.py
+++ b/django/contrib/gis/db/backends/postgis/introspection.py
@@ -59,15 +59,11 @@ class PostGISIntrospection(DatabaseIntrospection):
# to query the PostgreSQL pg_type table corresponding to the
# PostGIS custom data types.
oid_sql = 'SELECT "oid" FROM "pg_type" WHERE "typname" = %s'
- cursor = self.connection.cursor()
- try:
+ with self.connection.cursor() as cursor:
for field_type in field_types:
cursor.execute(oid_sql, (field_type[0],))
for result in cursor.fetchall():
postgis_types[result[0]] = field_type[1]
- finally:
- cursor.close()
-
return postgis_types
def get_field_type(self, data_type, description):
@@ -88,8 +84,7 @@ class PostGISIntrospection(DatabaseIntrospection):
PointField or a PolygonField). Thus, this routine queries the PostGIS
metadata tables to determine the geometry type.
"""
- cursor = self.connection.cursor()
- try:
+ with self.connection.cursor() as cursor:
try:
# First seeing if this geometry column is in the `geometry_columns`
cursor.execute('SELECT "coord_dimension", "srid", "type" '
@@ -122,7 +117,4 @@ class PostGISIntrospection(DatabaseIntrospection):
field_params['srid'] = srid
if dim != 2:
field_params['dim'] = dim
- finally:
- cursor.close()
-
return field_type, field_params
diff --git a/django/contrib/gis/db/backends/spatialite/introspection.py b/django/contrib/gis/db/backends/spatialite/introspection.py
index 98fda427d0..5cd5613b69 100644
--- a/django/contrib/gis/db/backends/spatialite/introspection.py
+++ b/django/contrib/gis/db/backends/spatialite/introspection.py
@@ -25,8 +25,7 @@ class SpatiaLiteIntrospection(DatabaseIntrospection):
data_types_reverse = GeoFlexibleFieldLookupDict()
def get_geometry_type(self, table_name, geo_col):
- cursor = self.connection.cursor()
- try:
+ with self.connection.cursor() as cursor:
# Querying the `geometry_columns` table to get additional metadata.
cursor.execute('SELECT coord_dimension, srid, geometry_type '
'FROM geometry_columns '
@@ -55,9 +54,6 @@ class SpatiaLiteIntrospection(DatabaseIntrospection):
field_params['srid'] = srid
if (isinstance(dim, str) and 'Z' in dim) or dim == 3:
field_params['dim'] = 3
- finally:
- cursor.close()
-
return field_type, field_params
def get_constraints(self, cursor, table_name):
diff --git a/django/db/backends/base/base.py b/django/db/backends/base/base.py
index 468eb16e14..53c3c30063 100644
--- a/django/db/backends/base/base.py
+++ b/django/db/backends/base/base.py
@@ -573,11 +573,10 @@ class BaseDatabaseWrapper:
Provide a cursor: with self.temporary_connection() as cursor: ...
"""
must_close = self.connection is None
- cursor = self.cursor()
try:
- yield cursor
+ with self.cursor() as cursor:
+ yield cursor
finally:
- cursor.close()
if must_close:
self.close()
diff --git a/django/db/backends/base/introspection.py b/django/db/backends/base/introspection.py
index 154ae22bf1..8fe8966a21 100644
--- a/django/db/backends/base/introspection.py
+++ b/django/db/backends/base/introspection.py
@@ -116,21 +116,20 @@ class BaseDatabaseIntrospection:
from django.db import router
sequence_list = []
- cursor = self.connection.cursor()
-
- for app_config in apps.get_app_configs():
- for model in router.get_migratable_models(app_config, self.connection.alias):
- if not model._meta.managed:
- continue
- if model._meta.swapped:
- continue
- sequence_list.extend(self.get_sequences(cursor, model._meta.db_table, model._meta.local_fields))
- for f in model._meta.local_many_to_many:
- # If this is an m2m using an intermediate table,
- # we don't need to reset the sequence.
- if f.remote_field.through is None:
- sequence = self.get_sequences(cursor, f.m2m_db_table())
- sequence_list.extend(sequence or [{'table': f.m2m_db_table(), 'column': None}])
+ with self.connection.cursor() as cursor:
+ for app_config in apps.get_app_configs():
+ for model in router.get_migratable_models(app_config, self.connection.alias):
+ if not model._meta.managed:
+ continue
+ if model._meta.swapped:
+ continue
+ sequence_list.extend(self.get_sequences(cursor, model._meta.db_table, model._meta.local_fields))
+ for f in model._meta.local_many_to_many:
+ # If this is an m2m using an intermediate table,
+ # we don't need to reset the sequence.
+ if f.remote_field.through is None:
+ sequence = self.get_sequences(cursor, f.m2m_db_table())
+ sequence_list.extend(sequence or [{'table': f.m2m_db_table(), 'column': None}])
return sequence_list
def get_sequences(self, cursor, table_name, table_fields=()):
diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py
index abf8f55736..6b82afc45d 100644
--- a/django/db/backends/mysql/base.py
+++ b/django/db/backends/mysql/base.py
@@ -294,36 +294,37 @@ class DatabaseWrapper(BaseDatabaseWrapper):
Backends can override this method if they can more directly apply
constraint checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE")
"""
- cursor = self.cursor()
- if table_names is None:
- table_names = self.introspection.table_names(cursor)
- for table_name in table_names:
- primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
- if not primary_key_column_name:
- continue
- key_columns = self.introspection.get_key_columns(cursor, table_name)
- for column_name, referenced_table_name, referenced_column_name in key_columns:
- cursor.execute(
- """
- SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
- LEFT JOIN `%s` as REFERRED
- ON (REFERRING.`%s` = REFERRED.`%s`)
- WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
- """ % (
- primary_key_column_name, column_name, table_name,
- referenced_table_name, column_name, referenced_column_name,
- column_name, referenced_column_name,
- )
- )
- for bad_row in cursor.fetchall():
- raise utils.IntegrityError(
- "The row in table '%s' with primary key '%s' has an invalid "
- "foreign key: %s.%s contains a value '%s' that does not have a corresponding value in %s.%s."
- % (
- table_name, bad_row[0], table_name, column_name,
- bad_row[1], referenced_table_name, referenced_column_name,
+ with self.cursor() as cursor:
+ if table_names is None:
+ table_names = self.introspection.table_names(cursor)
+ for table_name in table_names:
+ primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
+ if not primary_key_column_name:
+ continue
+ key_columns = self.introspection.get_key_columns(cursor, table_name)
+ for column_name, referenced_table_name, referenced_column_name in key_columns:
+ cursor.execute(
+ """
+ SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
+ LEFT JOIN `%s` as REFERRED
+ ON (REFERRING.`%s` = REFERRED.`%s`)
+ WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
+ """ % (
+ primary_key_column_name, column_name, table_name,
+ referenced_table_name, column_name, referenced_column_name,
+ column_name, referenced_column_name,
)
)
+ for bad_row in cursor.fetchall():
+ raise utils.IntegrityError(
+ "The row in table '%s' with primary key '%s' has an invalid "
+ "foreign key: %s.%s contains a value '%s' that does not "
+ "have a corresponding value in %s.%s."
+ % (
+ table_name, bad_row[0], table_name, column_name,
+ bad_row[1], referenced_table_name, referenced_column_name,
+ )
+ )
def is_usable(self):
try:
diff --git a/django/db/backends/oracle/creation.py b/django/db/backends/oracle/creation.py
index fe053c54f7..aa57e6a6ab 100644
--- a/django/db/backends/oracle/creation.py
+++ b/django/db/backends/oracle/creation.py
@@ -30,75 +30,72 @@ class DatabaseCreation(BaseDatabaseCreation):
def _create_test_db(self, verbosity=1, autoclobber=False, keepdb=False):
parameters = self._get_test_db_params()
- cursor = self._maindb_connection.cursor()
- if self._test_database_create():
- try:
- self._execute_test_db_creation(cursor, parameters, verbosity, keepdb)
- except Exception as e:
- if 'ORA-01543' not in str(e):
- # All errors except "tablespace already exists" cancel tests
- sys.stderr.write("Got an error creating the test database: %s\n" % e)
- sys.exit(2)
- if not autoclobber:
- confirm = input(
- "It appears the test database, %s, already exists. "
- "Type 'yes' to delete it, or 'no' to cancel: " % parameters['user'])
- if autoclobber or confirm == 'yes':
- if verbosity >= 1:
- print("Destroying old test database for alias '%s'..." % self.connection.alias)
- try:
- self._execute_test_db_destruction(cursor, parameters, verbosity)
- except DatabaseError as e:
- if 'ORA-29857' in str(e):
- self._handle_objects_preventing_db_destruction(cursor, parameters,
- verbosity, autoclobber)
- else:
- # Ran into a database error that isn't about leftover objects in the tablespace
+ with self._maindb_connection.cursor() as cursor:
+ if self._test_database_create():
+ try:
+ self._execute_test_db_creation(cursor, parameters, verbosity, keepdb)
+ except Exception as e:
+ if 'ORA-01543' not in str(e):
+ # All errors except "tablespace already exists" cancel tests
+ sys.stderr.write("Got an error creating the test database: %s\n" % e)
+ sys.exit(2)
+ if not autoclobber:
+ confirm = input(
+ "It appears the test database, %s, already exists. "
+ "Type 'yes' to delete it, or 'no' to cancel: " % parameters['user'])
+ if autoclobber or confirm == 'yes':
+ if verbosity >= 1:
+ print("Destroying old test database for alias '%s'..." % self.connection.alias)
+ try:
+ self._execute_test_db_destruction(cursor, parameters, verbosity)
+ except DatabaseError as e:
+ if 'ORA-29857' in str(e):
+ self._handle_objects_preventing_db_destruction(cursor, parameters,
+ verbosity, autoclobber)
+ else:
+ # Ran into a database error that isn't about leftover objects in the tablespace
+ sys.stderr.write("Got an error destroying the old test database: %s\n" % e)
+ sys.exit(2)
+ except Exception as e:
sys.stderr.write("Got an error destroying the old test database: %s\n" % e)
sys.exit(2)
- except Exception as e:
- sys.stderr.write("Got an error destroying the old test database: %s\n" % e)
- sys.exit(2)
- try:
- self._execute_test_db_creation(cursor, parameters, verbosity, keepdb)
- except Exception as e:
- sys.stderr.write("Got an error recreating the test database: %s\n" % e)
- sys.exit(2)
- else:
- print("Tests cancelled.")
- sys.exit(1)
+ try:
+ self._execute_test_db_creation(cursor, parameters, verbosity, keepdb)
+ except Exception as e:
+ sys.stderr.write("Got an error recreating the test database: %s\n" % e)
+ sys.exit(2)
+ else:
+ print("Tests cancelled.")
+ sys.exit(1)
- if self._test_user_create():
- if verbosity >= 1:
- print("Creating test user...")
- try:
- self._create_test_user(cursor, parameters, verbosity, keepdb)
- except Exception as e:
- if 'ORA-01920' not in str(e):
- # All errors except "user already exists" cancel tests
- sys.stderr.write("Got an error creating the test user: %s\n" % e)
- sys.exit(2)
- if not autoclobber:
- confirm = input(
- "It appears the test user, %s, already exists. Type "
- "'yes' to delete it, or 'no' to cancel: " % parameters['user'])
- if autoclobber or confirm == 'yes':
- try:
- if verbosity >= 1:
- print("Destroying old test user...")
- self._destroy_test_user(cursor, parameters, verbosity)
- if verbosity >= 1:
- print("Creating test user...")
- self._create_test_user(cursor, parameters, verbosity, keepdb)
- except Exception as e:
- sys.stderr.write("Got an error recreating the test user: %s\n" % e)
+ if self._test_user_create():
+ if verbosity >= 1:
+ print("Creating test user...")
+ try:
+ self._create_test_user(cursor, parameters, verbosity, keepdb)
+ except Exception as e:
+ if 'ORA-01920' not in str(e):
+ # All errors except "user already exists" cancel tests
+ sys.stderr.write("Got an error creating the test user: %s\n" % e)
sys.exit(2)
- else:
- print("Tests cancelled.")
- sys.exit(1)
-
- # Cursor must be closed before closing connection.
- cursor.close()
+ if not autoclobber:
+ confirm = input(
+ "It appears the test user, %s, already exists. Type "
+ "'yes' to delete it, or 'no' to cancel: " % parameters['user'])
+ if autoclobber or confirm == 'yes':
+ try:
+ if verbosity >= 1:
+ print("Destroying old test user...")
+ self._destroy_test_user(cursor, parameters, verbosity)
+ if verbosity >= 1:
+ print("Creating test user...")
+ self._create_test_user(cursor, parameters, verbosity, keepdb)
+ except Exception as e:
+ sys.stderr.write("Got an error recreating the test user: %s\n" % e)
+ sys.exit(2)
+ else:
+ print("Tests cancelled.")
+ sys.exit(1)
self._maindb_connection.close() # done with main user -- test user and tablespaces created
self._switch_to_test_user(parameters)
return self.connection.settings_dict['NAME']
@@ -175,17 +172,15 @@ class DatabaseCreation(BaseDatabaseCreation):
self.connection.settings_dict['PASSWORD'] = self.connection.settings_dict['SAVED_PASSWORD']
self.connection.close()
parameters = self._get_test_db_params()
- cursor = self._maindb_connection.cursor()
- if self._test_user_create():
- if verbosity >= 1:
- print('Destroying test user...')
- self._destroy_test_user(cursor, parameters, verbosity)
- if self._test_database_create():
- if verbosity >= 1:
- print('Destroying test database tables...')
- self._execute_test_db_destruction(cursor, parameters, verbosity)
- # Cursor must be closed before closing connection.
- cursor.close()
+ with self._maindb_connection.cursor() as cursor:
+ if self._test_user_create():
+ if verbosity >= 1:
+ print('Destroying test user...')
+ self._destroy_test_user(cursor, parameters, verbosity)
+ if self._test_database_create():
+ if verbosity >= 1:
+ print('Destroying test database tables...')
+ self._execute_test_db_destruction(cursor, parameters, verbosity)
self._maindb_connection.close()
def _execute_test_db_creation(self, cursor, parameters, verbosity, keepdb=False):
diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py
index 40b205cf30..13f9a575e2 100644
--- a/django/db/backends/sqlite3/base.py
+++ b/django/db/backends/sqlite3/base.py
@@ -237,37 +237,37 @@ class DatabaseWrapper(BaseDatabaseWrapper):
Backends can override this method if they can more directly apply
constraint checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE")
"""
- cursor = self.cursor()
- if table_names is None:
- table_names = self.introspection.table_names(cursor)
- for table_name in table_names:
- primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
- if not primary_key_column_name:
- continue
- key_columns = self.introspection.get_key_columns(cursor, table_name)
- for column_name, referenced_table_name, referenced_column_name in key_columns:
- cursor.execute(
- """
- SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
- LEFT JOIN `%s` as REFERRED
- ON (REFERRING.`%s` = REFERRED.`%s`)
- WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
- """
- % (
- primary_key_column_name, column_name, table_name,
- referenced_table_name, column_name, referenced_column_name,
- column_name, referenced_column_name,
- )
- )
- for bad_row in cursor.fetchall():
- raise utils.IntegrityError(
- "The row in table '%s' with primary key '%s' has an "
- "invalid foreign key: %s.%s contains a value '%s' that "
- "does not have a corresponding value in %s.%s." % (
- table_name, bad_row[0], table_name, column_name,
- bad_row[1], referenced_table_name, referenced_column_name,
+ with self.cursor() as cursor:
+ if table_names is None:
+ table_names = self.introspection.table_names(cursor)
+ for table_name in table_names:
+ primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
+ if not primary_key_column_name:
+ continue
+ key_columns = self.introspection.get_key_columns(cursor, table_name)
+ for column_name, referenced_table_name, referenced_column_name in key_columns:
+ cursor.execute(
+ """
+ SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
+ LEFT JOIN `%s` as REFERRED
+ ON (REFERRING.`%s` = REFERRED.`%s`)
+ WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
+ """
+ % (
+ primary_key_column_name, column_name, table_name,
+ referenced_table_name, column_name, referenced_column_name,
+ column_name, referenced_column_name,
)
)
+ for bad_row in cursor.fetchall():
+ raise utils.IntegrityError(
+ "The row in table '%s' with primary key '%s' has an "
+ "invalid foreign key: %s.%s contains a value '%s' that "
+ "does not have a corresponding value in %s.%s." % (
+ table_name, bad_row[0], table_name, column_name,
+ bad_row[1], referenced_table_name, referenced_column_name,
+ )
+ )
def is_usable(self):
return True
diff --git a/django/db/migrations/executor.py b/django/db/migrations/executor.py
index ea7bc70db3..ebaf75634b 100644
--- a/django/db/migrations/executor.py
+++ b/django/db/migrations/executor.py
@@ -322,7 +322,8 @@ class MigrationExecutor:
apps = after_state.apps
found_create_model_migration = False
found_add_field_migration = False
- existing_table_names = self.connection.introspection.table_names(self.connection.cursor())
+ with self.connection.cursor() as cursor:
+ existing_table_names = self.connection.introspection.table_names(cursor)
# Make sure all create model and add field operations are done
for operation in migration.operations:
if isinstance(operation, migrations.CreateModel):
diff --git a/django/test/testcases.py b/django/test/testcases.py
index afab58f2dc..4c32c87da5 100644
--- a/django/test/testcases.py
+++ b/django/test/testcases.py
@@ -852,9 +852,9 @@ class TransactionTestCase(SimpleTestCase):
no_style(), conn.introspection.sequence_list())
if sql_list:
with transaction.atomic(using=db_name):
- cursor = conn.cursor()
- for sql in sql_list:
- cursor.execute(sql)
+ with conn.cursor() as cursor:
+ for sql in sql_list:
+ cursor.execute(sql)
def _fixture_setup(self):
for db_name in self._databases_names(include_mirrors=False):
diff --git a/docs/topics/db/multi-db.txt b/docs/topics/db/multi-db.txt
index 78f3fe23d9..394d25cfd9 100644
--- a/docs/topics/db/multi-db.txt
+++ b/docs/topics/db/multi-db.txt
@@ -664,7 +664,8 @@ object that allows you to retrieve a specific connection using its
alias::
from django.db import connections
- cursor = connections['my_db_alias'].cursor()
+ with connections['my_db_alias'].cursor() as cursor:
+ ...
Limitations of multiple databases
=================================
diff --git a/docs/topics/db/sql.txt b/docs/topics/db/sql.txt
index dc2bcc50b6..96d79a999c 100644
--- a/docs/topics/db/sql.txt
+++ b/docs/topics/db/sql.txt
@@ -279,8 +279,8 @@ object that allows you to retrieve a specific connection using its
alias::
from django.db import connections
- cursor = connections['my_db_alias'].cursor()
- # Your code here...
+ with connections['my_db_alias'].cursor() as cursor:
+ # Your code here...
By default, the Python DB API will return results without their field names,
which means you end up with a ``list`` of values, rather than a ``dict``. At a
diff --git a/tests/backends/postgresql/test_introspection.py b/tests/backends/postgresql/test_introspection.py
index cfa801a77f..4dcadbd733 100644
--- a/tests/backends/postgresql/test_introspection.py
+++ b/tests/backends/postgresql/test_introspection.py
@@ -9,15 +9,15 @@ from ..models import Person
@unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL")
class DatabaseSequenceTests(TestCase):
def test_get_sequences(self):
- cursor = connection.cursor()
- seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table)
- self.assertEqual(
- seqs,
- [{'table': Person._meta.db_table, 'column': 'id', 'name': 'backends_person_id_seq'}]
- )
- cursor.execute('ALTER SEQUENCE backends_person_id_seq RENAME TO pers_seq')
- seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table)
- self.assertEqual(
- seqs,
- [{'table': Person._meta.db_table, 'column': 'id', 'name': 'pers_seq'}]
- )
+ with connection.cursor() as cursor:
+ seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table)
+ self.assertEqual(
+ seqs,
+ [{'table': Person._meta.db_table, 'column': 'id', 'name': 'backends_person_id_seq'}]
+ )
+ cursor.execute('ALTER SEQUENCE backends_person_id_seq RENAME TO pers_seq')
+ seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table)
+ self.assertEqual(
+ seqs,
+ [{'table': Person._meta.db_table, 'column': 'id', 'name': 'pers_seq'}]
+ )
diff --git a/tests/backends/postgresql/tests.py b/tests/backends/postgresql/tests.py
index 140fbbc444..caa8b87fb9 100644
--- a/tests/backends/postgresql/tests.py
+++ b/tests/backends/postgresql/tests.py
@@ -44,10 +44,10 @@ class Tests(TestCase):
# Ensure the database default time zone is different than
# the time zone in new_connection.settings_dict. We can
# get the default time zone by reset & show.
- cursor = new_connection.cursor()
- cursor.execute("RESET TIMEZONE")
- cursor.execute("SHOW TIMEZONE")
- db_default_tz = cursor.fetchone()[0]
+ with new_connection.cursor() as cursor:
+ cursor.execute("RESET TIMEZONE")
+ cursor.execute("SHOW TIMEZONE")
+ db_default_tz = cursor.fetchone()[0]
new_tz = 'Europe/Paris' if db_default_tz == 'UTC' else 'UTC'
new_connection.close()
@@ -59,12 +59,12 @@ class Tests(TestCase):
# time zone, run a query and rollback.
with self.settings(TIME_ZONE=new_tz):
new_connection.set_autocommit(False)
- cursor = new_connection.cursor()
new_connection.rollback()
# Now let's see if the rollback rolled back the SET TIME ZONE.
- cursor.execute("SHOW TIMEZONE")
- tz = cursor.fetchone()[0]
+ with new_connection.cursor() as cursor:
+ cursor.execute("SHOW TIMEZONE")
+ tz = cursor.fetchone()[0]
self.assertEqual(new_tz, tz)
finally:
diff --git a/tests/backends/sqlite/tests.py b/tests/backends/sqlite/tests.py
index 0c07f95e6f..0fbb139278 100644
--- a/tests/backends/sqlite/tests.py
+++ b/tests/backends/sqlite/tests.py
@@ -82,11 +82,11 @@ class LastExecutedQueryTest(TestCase):
# If SQLITE_MAX_VARIABLE_NUMBER (default = 999) has been changed to be
# greater than SQLITE_MAX_COLUMN (default = 2000), last_executed_query
# can hit the SQLITE_MAX_COLUMN limit (#26063).
- cursor = connection.cursor()
- sql = "SELECT MAX(%s)" % ", ".join(["%s"] * 2001)
- params = list(range(2001))
- # This should not raise an exception.
- cursor.db.ops.last_executed_query(cursor.cursor, sql, params)
+ with connection.cursor() as cursor:
+ sql = "SELECT MAX(%s)" % ", ".join(["%s"] * 2001)
+ params = list(range(2001))
+ # This should not raise an exception.
+ cursor.db.ops.last_executed_query(cursor.cursor, sql, params)
@unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests')
@@ -97,9 +97,9 @@ class EscapingChecks(TestCase):
"""
def test_parameter_escaping(self):
# '%s' escaping support for sqlite3 (#13648).
- cursor = connection.cursor()
- cursor.execute("select strftime('%s', date('now'))")
- response = cursor.fetchall()[0][0]
+ with connection.cursor() as cursor:
+ cursor.execute("select strftime('%s', date('now'))")
+ response = cursor.fetchall()[0][0]
# response should be an non-zero integer
self.assertTrue(int(response))
diff --git a/tests/backends/tests.py b/tests/backends/tests.py
index 6d38625a98..32d76577a1 100644
--- a/tests/backends/tests.py
+++ b/tests/backends/tests.py
@@ -56,8 +56,8 @@ class LastExecutedQueryTest(TestCase):
last_executed_query should not raise an exception even if no previous
query has been run.
"""
- cursor = connection.cursor()
- connection.ops.last_executed_query(cursor, '', ())
+ with connection.cursor() as cursor:
+ connection.ops.last_executed_query(cursor, '', ())
def test_debug_sql(self):
list(Reporter.objects.filter(first_name="test"))
@@ -78,16 +78,16 @@ class ParameterHandlingTest(TestCase):
def test_bad_parameter_count(self):
"An executemany call with too many/not enough parameters will raise an exception (Refs #12612)"
- cursor = connection.cursor()
- query = ('INSERT INTO %s (%s, %s) VALUES (%%s, %%s)' % (
- connection.introspection.table_name_converter('backends_square'),
- connection.ops.quote_name('root'),
- connection.ops.quote_name('square')
- ))
- with self.assertRaises(Exception):
- cursor.executemany(query, [(1, 2, 3)])
- with self.assertRaises(Exception):
- cursor.executemany(query, [(1,)])
+ with connection.cursor() as cursor:
+ query = ('INSERT INTO %s (%s, %s) VALUES (%%s, %%s)' % (
+ connection.introspection.table_name_converter('backends_square'),
+ connection.ops.quote_name('root'),
+ connection.ops.quote_name('square')
+ ))
+ with self.assertRaises(Exception):
+ cursor.executemany(query, [(1, 2, 3)])
+ with self.assertRaises(Exception):
+ cursor.executemany(query, [(1,)])
class LongNameTest(TransactionTestCase):
@@ -133,9 +133,10 @@ class LongNameTest(TransactionTestCase):
'table': VLM._meta.db_table
},
]
- cursor = connection.cursor()
- for statement in connection.ops.sql_flush(no_style(), tables, sequences):
- cursor.execute(statement)
+ sql_list = connection.ops.sql_flush(no_style(), tables, sequences)
+ with connection.cursor() as cursor:
+ for statement in sql_list:
+ cursor.execute(statement)
class SequenceResetTest(TestCase):
@@ -146,10 +147,10 @@ class SequenceResetTest(TestCase):
Post.objects.create(id=10, name='1st post', text='hello world')
# Reset the sequences for the database
- cursor = connection.cursor()
commands = connections[DEFAULT_DB_ALIAS].ops.sequence_reset_sql(no_style(), [Post])
- for sql in commands:
- cursor.execute(sql)
+ with connection.cursor() as cursor:
+ for sql in commands:
+ cursor.execute(sql)
# If we create a new object now, it should have a PK greater
# than the PK we specified manually.
@@ -192,14 +193,14 @@ class EscapingChecks(TestCase):
bare_select_suffix = connection.features.bare_select_suffix
def test_paramless_no_escaping(self):
- cursor = connection.cursor()
- cursor.execute("SELECT '%s'" + self.bare_select_suffix)
- self.assertEqual(cursor.fetchall()[0][0], '%s')
+ with connection.cursor() as cursor:
+ cursor.execute("SELECT '%s'" + self.bare_select_suffix)
+ self.assertEqual(cursor.fetchall()[0][0], '%s')
def test_parameter_escaping(self):
- cursor = connection.cursor()
- cursor.execute("SELECT '%%', %s" + self.bare_select_suffix, ('%d',))
- self.assertEqual(cursor.fetchall()[0], ('%', '%d'))
+ with connection.cursor() as cursor:
+ cursor.execute("SELECT '%%', %s" + self.bare_select_suffix, ('%d',))
+ self.assertEqual(cursor.fetchall()[0], ('%', '%d'))
@override_settings(DEBUG=True)
@@ -215,7 +216,6 @@ class BackendTestCase(TransactionTestCase):
self.create_squares(args, 'format', True)
def create_squares(self, args, paramstyle, multiple):
- cursor = connection.cursor()
opts = Square._meta
tbl = connection.introspection.table_name_converter(opts.db_table)
f1 = connection.ops.quote_name(opts.get_field('root').column)
@@ -226,10 +226,11 @@ class BackendTestCase(TransactionTestCase):
query = 'INSERT INTO %s (%s, %s) VALUES (%%(root)s, %%(square)s)' % (tbl, f1, f2)
else:
raise ValueError("unsupported paramstyle in test")
- if multiple:
- cursor.executemany(query, args)
- else:
- cursor.execute(query, args)
+ with connection.cursor() as cursor:
+ if multiple:
+ cursor.executemany(query, args)
+ else:
+ cursor.execute(query, args)
def test_cursor_executemany(self):
# Test cursor.executemany #4896
@@ -297,18 +298,18 @@ class BackendTestCase(TransactionTestCase):
Person(first_name="Clark", last_name="Kent").save()
opts2 = Person._meta
f3, f4 = opts2.get_field('first_name'), opts2.get_field('last_name')
- cursor = connection.cursor()
- cursor.execute(
- 'SELECT %s, %s FROM %s ORDER BY %s' % (
- qn(f3.column),
- qn(f4.column),
- connection.introspection.table_name_converter(opts2.db_table),
- qn(f3.column),
+ with connection.cursor() as cursor:
+ cursor.execute(
+ 'SELECT %s, %s FROM %s ORDER BY %s' % (
+ qn(f3.column),
+ qn(f4.column),
+ connection.introspection.table_name_converter(opts2.db_table),
+ qn(f3.column),
+ )
)
- )
- self.assertEqual(cursor.fetchone(), ('Clark', 'Kent'))
- self.assertEqual(list(cursor.fetchmany(2)), [('Jane', 'Doe'), ('John', 'Doe')])
- self.assertEqual(list(cursor.fetchall()), [('Mary', 'Agnelline'), ('Peter', 'Parker')])
+ self.assertEqual(cursor.fetchone(), ('Clark', 'Kent'))
+ self.assertEqual(list(cursor.fetchmany(2)), [('Jane', 'Doe'), ('John', 'Doe')])
+ self.assertEqual(list(cursor.fetchall()), [('Mary', 'Agnelline'), ('Peter', 'Parker')])
def test_unicode_password(self):
old_password = connection.settings_dict['PASSWORD']
@@ -344,10 +345,10 @@ class BackendTestCase(TransactionTestCase):
def test_duplicate_table_error(self):
""" Creating an existing table returns a DatabaseError """
- cursor = connection.cursor()
query = 'CREATE TABLE %s (id INTEGER);' % Article._meta.db_table
- with self.assertRaises(DatabaseError):
- cursor.execute(query)
+ with connection.cursor() as cursor:
+ with self.assertRaises(DatabaseError):
+ cursor.execute(query)
def test_cursor_contextmanager(self):
"""
diff --git a/tests/db_utils/tests.py b/tests/db_utils/tests.py
index 2a45342df5..4e35e6bb8b 100644
--- a/tests/db_utils/tests.py
+++ b/tests/db_utils/tests.py
@@ -26,10 +26,10 @@ class DatabaseErrorWrapperTests(TestCase):
@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL test')
def test_reraising_backend_specific_database_exception(self):
- cursor = connection.cursor()
- msg = 'table "X" does not exist'
- with self.assertRaisesMessage(ProgrammingError, msg) as cm:
- cursor.execute('DROP TABLE "X"')
+ with connection.cursor() as cursor:
+ msg = 'table "X" does not exist'
+ with self.assertRaisesMessage(ProgrammingError, msg) as cm:
+ cursor.execute('DROP TABLE "X"')
self.assertNotEqual(type(cm.exception), type(cm.exception.__cause__))
self.assertIsNotNone(cm.exception.__cause__)
self.assertIsNotNone(cm.exception.__cause__.pgcode)