summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorJon Dufresne <jon.dufresne@gmail.com>2017-11-28 05:12:28 -0800
committerTim Graham <timograham@gmail.com>2017-11-28 11:28:09 -0500
commit7a6fbf36b1fdb8978ea0842075ccce83bcd63789 (patch)
treea48e72be7e31066d92de43a85458b0bfb7e2cfcd /tests
parent3308085838f520db49f606b72345a301c1cf2a3e (diff)
downloaddjango-7a6fbf36b1fdb8978ea0842075ccce83bcd63789.tar.gz
Fixed #28853 -- Updated connection.cursor() uses to use a context manager.
Diffstat (limited to 'tests')
-rw-r--r--tests/backends/postgresql/test_introspection.py24
-rw-r--r--tests/backends/postgresql/tests.py14
-rw-r--r--tests/backends/sqlite/tests.py16
-rw-r--r--tests/backends/tests.py87
-rw-r--r--tests/db_utils/tests.py8
5 files changed, 75 insertions, 74 deletions
diff --git a/tests/backends/postgresql/test_introspection.py b/tests/backends/postgresql/test_introspection.py
index cfa801a77f..4dcadbd733 100644
--- a/tests/backends/postgresql/test_introspection.py
+++ b/tests/backends/postgresql/test_introspection.py
@@ -9,15 +9,15 @@ from ..models import Person
@unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL")
class DatabaseSequenceTests(TestCase):
def test_get_sequences(self):
- cursor = connection.cursor()
- seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table)
- self.assertEqual(
- seqs,
- [{'table': Person._meta.db_table, 'column': 'id', 'name': 'backends_person_id_seq'}]
- )
- cursor.execute('ALTER SEQUENCE backends_person_id_seq RENAME TO pers_seq')
- seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table)
- self.assertEqual(
- seqs,
- [{'table': Person._meta.db_table, 'column': 'id', 'name': 'pers_seq'}]
- )
+ with connection.cursor() as cursor:
+ seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table)
+ self.assertEqual(
+ seqs,
+ [{'table': Person._meta.db_table, 'column': 'id', 'name': 'backends_person_id_seq'}]
+ )
+ cursor.execute('ALTER SEQUENCE backends_person_id_seq RENAME TO pers_seq')
+ seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table)
+ self.assertEqual(
+ seqs,
+ [{'table': Person._meta.db_table, 'column': 'id', 'name': 'pers_seq'}]
+ )
diff --git a/tests/backends/postgresql/tests.py b/tests/backends/postgresql/tests.py
index 140fbbc444..caa8b87fb9 100644
--- a/tests/backends/postgresql/tests.py
+++ b/tests/backends/postgresql/tests.py
@@ -44,10 +44,10 @@ class Tests(TestCase):
# Ensure the database default time zone is different than
# the time zone in new_connection.settings_dict. We can
# get the default time zone by reset & show.
- cursor = new_connection.cursor()
- cursor.execute("RESET TIMEZONE")
- cursor.execute("SHOW TIMEZONE")
- db_default_tz = cursor.fetchone()[0]
+ with new_connection.cursor() as cursor:
+ cursor.execute("RESET TIMEZONE")
+ cursor.execute("SHOW TIMEZONE")
+ db_default_tz = cursor.fetchone()[0]
new_tz = 'Europe/Paris' if db_default_tz == 'UTC' else 'UTC'
new_connection.close()
@@ -59,12 +59,12 @@ class Tests(TestCase):
# time zone, run a query and rollback.
with self.settings(TIME_ZONE=new_tz):
new_connection.set_autocommit(False)
- cursor = new_connection.cursor()
new_connection.rollback()
# Now let's see if the rollback rolled back the SET TIME ZONE.
- cursor.execute("SHOW TIMEZONE")
- tz = cursor.fetchone()[0]
+ with new_connection.cursor() as cursor:
+ cursor.execute("SHOW TIMEZONE")
+ tz = cursor.fetchone()[0]
self.assertEqual(new_tz, tz)
finally:
diff --git a/tests/backends/sqlite/tests.py b/tests/backends/sqlite/tests.py
index 0c07f95e6f..0fbb139278 100644
--- a/tests/backends/sqlite/tests.py
+++ b/tests/backends/sqlite/tests.py
@@ -82,11 +82,11 @@ class LastExecutedQueryTest(TestCase):
# If SQLITE_MAX_VARIABLE_NUMBER (default = 999) has been changed to be
# greater than SQLITE_MAX_COLUMN (default = 2000), last_executed_query
# can hit the SQLITE_MAX_COLUMN limit (#26063).
- cursor = connection.cursor()
- sql = "SELECT MAX(%s)" % ", ".join(["%s"] * 2001)
- params = list(range(2001))
- # This should not raise an exception.
- cursor.db.ops.last_executed_query(cursor.cursor, sql, params)
+ with connection.cursor() as cursor:
+ sql = "SELECT MAX(%s)" % ", ".join(["%s"] * 2001)
+ params = list(range(2001))
+ # This should not raise an exception.
+ cursor.db.ops.last_executed_query(cursor.cursor, sql, params)
@unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests')
@@ -97,9 +97,9 @@ class EscapingChecks(TestCase):
"""
def test_parameter_escaping(self):
# '%s' escaping support for sqlite3 (#13648).
- cursor = connection.cursor()
- cursor.execute("select strftime('%s', date('now'))")
- response = cursor.fetchall()[0][0]
+ with connection.cursor() as cursor:
+ cursor.execute("select strftime('%s', date('now'))")
+ response = cursor.fetchall()[0][0]
# response should be an non-zero integer
self.assertTrue(int(response))
diff --git a/tests/backends/tests.py b/tests/backends/tests.py
index 6d38625a98..32d76577a1 100644
--- a/tests/backends/tests.py
+++ b/tests/backends/tests.py
@@ -56,8 +56,8 @@ class LastExecutedQueryTest(TestCase):
last_executed_query should not raise an exception even if no previous
query has been run.
"""
- cursor = connection.cursor()
- connection.ops.last_executed_query(cursor, '', ())
+ with connection.cursor() as cursor:
+ connection.ops.last_executed_query(cursor, '', ())
def test_debug_sql(self):
list(Reporter.objects.filter(first_name="test"))
@@ -78,16 +78,16 @@ class ParameterHandlingTest(TestCase):
def test_bad_parameter_count(self):
"An executemany call with too many/not enough parameters will raise an exception (Refs #12612)"
- cursor = connection.cursor()
- query = ('INSERT INTO %s (%s, %s) VALUES (%%s, %%s)' % (
- connection.introspection.table_name_converter('backends_square'),
- connection.ops.quote_name('root'),
- connection.ops.quote_name('square')
- ))
- with self.assertRaises(Exception):
- cursor.executemany(query, [(1, 2, 3)])
- with self.assertRaises(Exception):
- cursor.executemany(query, [(1,)])
+ with connection.cursor() as cursor:
+ query = ('INSERT INTO %s (%s, %s) VALUES (%%s, %%s)' % (
+ connection.introspection.table_name_converter('backends_square'),
+ connection.ops.quote_name('root'),
+ connection.ops.quote_name('square')
+ ))
+ with self.assertRaises(Exception):
+ cursor.executemany(query, [(1, 2, 3)])
+ with self.assertRaises(Exception):
+ cursor.executemany(query, [(1,)])
class LongNameTest(TransactionTestCase):
@@ -133,9 +133,10 @@ class LongNameTest(TransactionTestCase):
'table': VLM._meta.db_table
},
]
- cursor = connection.cursor()
- for statement in connection.ops.sql_flush(no_style(), tables, sequences):
- cursor.execute(statement)
+ sql_list = connection.ops.sql_flush(no_style(), tables, sequences)
+ with connection.cursor() as cursor:
+ for statement in sql_list:
+ cursor.execute(statement)
class SequenceResetTest(TestCase):
@@ -146,10 +147,10 @@ class SequenceResetTest(TestCase):
Post.objects.create(id=10, name='1st post', text='hello world')
# Reset the sequences for the database
- cursor = connection.cursor()
commands = connections[DEFAULT_DB_ALIAS].ops.sequence_reset_sql(no_style(), [Post])
- for sql in commands:
- cursor.execute(sql)
+ with connection.cursor() as cursor:
+ for sql in commands:
+ cursor.execute(sql)
# If we create a new object now, it should have a PK greater
# than the PK we specified manually.
@@ -192,14 +193,14 @@ class EscapingChecks(TestCase):
bare_select_suffix = connection.features.bare_select_suffix
def test_paramless_no_escaping(self):
- cursor = connection.cursor()
- cursor.execute("SELECT '%s'" + self.bare_select_suffix)
- self.assertEqual(cursor.fetchall()[0][0], '%s')
+ with connection.cursor() as cursor:
+ cursor.execute("SELECT '%s'" + self.bare_select_suffix)
+ self.assertEqual(cursor.fetchall()[0][0], '%s')
def test_parameter_escaping(self):
- cursor = connection.cursor()
- cursor.execute("SELECT '%%', %s" + self.bare_select_suffix, ('%d',))
- self.assertEqual(cursor.fetchall()[0], ('%', '%d'))
+ with connection.cursor() as cursor:
+ cursor.execute("SELECT '%%', %s" + self.bare_select_suffix, ('%d',))
+ self.assertEqual(cursor.fetchall()[0], ('%', '%d'))
@override_settings(DEBUG=True)
@@ -215,7 +216,6 @@ class BackendTestCase(TransactionTestCase):
self.create_squares(args, 'format', True)
def create_squares(self, args, paramstyle, multiple):
- cursor = connection.cursor()
opts = Square._meta
tbl = connection.introspection.table_name_converter(opts.db_table)
f1 = connection.ops.quote_name(opts.get_field('root').column)
@@ -226,10 +226,11 @@ class BackendTestCase(TransactionTestCase):
query = 'INSERT INTO %s (%s, %s) VALUES (%%(root)s, %%(square)s)' % (tbl, f1, f2)
else:
raise ValueError("unsupported paramstyle in test")
- if multiple:
- cursor.executemany(query, args)
- else:
- cursor.execute(query, args)
+ with connection.cursor() as cursor:
+ if multiple:
+ cursor.executemany(query, args)
+ else:
+ cursor.execute(query, args)
def test_cursor_executemany(self):
# Test cursor.executemany #4896
@@ -297,18 +298,18 @@ class BackendTestCase(TransactionTestCase):
Person(first_name="Clark", last_name="Kent").save()
opts2 = Person._meta
f3, f4 = opts2.get_field('first_name'), opts2.get_field('last_name')
- cursor = connection.cursor()
- cursor.execute(
- 'SELECT %s, %s FROM %s ORDER BY %s' % (
- qn(f3.column),
- qn(f4.column),
- connection.introspection.table_name_converter(opts2.db_table),
- qn(f3.column),
+ with connection.cursor() as cursor:
+ cursor.execute(
+ 'SELECT %s, %s FROM %s ORDER BY %s' % (
+ qn(f3.column),
+ qn(f4.column),
+ connection.introspection.table_name_converter(opts2.db_table),
+ qn(f3.column),
+ )
)
- )
- self.assertEqual(cursor.fetchone(), ('Clark', 'Kent'))
- self.assertEqual(list(cursor.fetchmany(2)), [('Jane', 'Doe'), ('John', 'Doe')])
- self.assertEqual(list(cursor.fetchall()), [('Mary', 'Agnelline'), ('Peter', 'Parker')])
+ self.assertEqual(cursor.fetchone(), ('Clark', 'Kent'))
+ self.assertEqual(list(cursor.fetchmany(2)), [('Jane', 'Doe'), ('John', 'Doe')])
+ self.assertEqual(list(cursor.fetchall()), [('Mary', 'Agnelline'), ('Peter', 'Parker')])
def test_unicode_password(self):
old_password = connection.settings_dict['PASSWORD']
@@ -344,10 +345,10 @@ class BackendTestCase(TransactionTestCase):
def test_duplicate_table_error(self):
""" Creating an existing table returns a DatabaseError """
- cursor = connection.cursor()
query = 'CREATE TABLE %s (id INTEGER);' % Article._meta.db_table
- with self.assertRaises(DatabaseError):
- cursor.execute(query)
+ with connection.cursor() as cursor:
+ with self.assertRaises(DatabaseError):
+ cursor.execute(query)
def test_cursor_contextmanager(self):
"""
diff --git a/tests/db_utils/tests.py b/tests/db_utils/tests.py
index 2a45342df5..4e35e6bb8b 100644
--- a/tests/db_utils/tests.py
+++ b/tests/db_utils/tests.py
@@ -26,10 +26,10 @@ class DatabaseErrorWrapperTests(TestCase):
@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL test')
def test_reraising_backend_specific_database_exception(self):
- cursor = connection.cursor()
- msg = 'table "X" does not exist'
- with self.assertRaisesMessage(ProgrammingError, msg) as cm:
- cursor.execute('DROP TABLE "X"')
+ with connection.cursor() as cursor:
+ msg = 'table "X" does not exist'
+ with self.assertRaisesMessage(ProgrammingError, msg) as cm:
+ cursor.execute('DROP TABLE "X"')
self.assertNotEqual(type(cm.exception), type(cm.exception.__cause__))
self.assertIsNotNone(cm.exception.__cause__)
self.assertIsNotNone(cm.exception.__cause__.pgcode)