diff options
| author | Jon Dufresne <jon.dufresne@gmail.com> | 2017-11-28 05:12:28 -0800 |
|---|---|---|
| committer | Tim Graham <timograham@gmail.com> | 2017-11-28 11:28:09 -0500 |
| commit | 7a6fbf36b1fdb8978ea0842075ccce83bcd63789 (patch) | |
| tree | a48e72be7e31066d92de43a85458b0bfb7e2cfcd /tests | |
| parent | 3308085838f520db49f606b72345a301c1cf2a3e (diff) | |
| download | django-7a6fbf36b1fdb8978ea0842075ccce83bcd63789.tar.gz | |
Fixed #28853 -- Updated connection.cursor() uses to use a context manager.
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/backends/postgresql/test_introspection.py | 24 | ||||
| -rw-r--r-- | tests/backends/postgresql/tests.py | 14 | ||||
| -rw-r--r-- | tests/backends/sqlite/tests.py | 16 | ||||
| -rw-r--r-- | tests/backends/tests.py | 87 | ||||
| -rw-r--r-- | tests/db_utils/tests.py | 8 |
5 files changed, 75 insertions, 74 deletions
diff --git a/tests/backends/postgresql/test_introspection.py b/tests/backends/postgresql/test_introspection.py index cfa801a77f..4dcadbd733 100644 --- a/tests/backends/postgresql/test_introspection.py +++ b/tests/backends/postgresql/test_introspection.py @@ -9,15 +9,15 @@ from ..models import Person @unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL") class DatabaseSequenceTests(TestCase): def test_get_sequences(self): - cursor = connection.cursor() - seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table) - self.assertEqual( - seqs, - [{'table': Person._meta.db_table, 'column': 'id', 'name': 'backends_person_id_seq'}] - ) - cursor.execute('ALTER SEQUENCE backends_person_id_seq RENAME TO pers_seq') - seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table) - self.assertEqual( - seqs, - [{'table': Person._meta.db_table, 'column': 'id', 'name': 'pers_seq'}] - ) + with connection.cursor() as cursor: + seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table) + self.assertEqual( + seqs, + [{'table': Person._meta.db_table, 'column': 'id', 'name': 'backends_person_id_seq'}] + ) + cursor.execute('ALTER SEQUENCE backends_person_id_seq RENAME TO pers_seq') + seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table) + self.assertEqual( + seqs, + [{'table': Person._meta.db_table, 'column': 'id', 'name': 'pers_seq'}] + ) diff --git a/tests/backends/postgresql/tests.py b/tests/backends/postgresql/tests.py index 140fbbc444..caa8b87fb9 100644 --- a/tests/backends/postgresql/tests.py +++ b/tests/backends/postgresql/tests.py @@ -44,10 +44,10 @@ class Tests(TestCase): # Ensure the database default time zone is different than # the time zone in new_connection.settings_dict. We can # get the default time zone by reset & show. - cursor = new_connection.cursor() - cursor.execute("RESET TIMEZONE") - cursor.execute("SHOW TIMEZONE") - db_default_tz = cursor.fetchone()[0] + with new_connection.cursor() as cursor: + cursor.execute("RESET TIMEZONE") + cursor.execute("SHOW TIMEZONE") + db_default_tz = cursor.fetchone()[0] new_tz = 'Europe/Paris' if db_default_tz == 'UTC' else 'UTC' new_connection.close() @@ -59,12 +59,12 @@ class Tests(TestCase): # time zone, run a query and rollback. with self.settings(TIME_ZONE=new_tz): new_connection.set_autocommit(False) - cursor = new_connection.cursor() new_connection.rollback() # Now let's see if the rollback rolled back the SET TIME ZONE. - cursor.execute("SHOW TIMEZONE") - tz = cursor.fetchone()[0] + with new_connection.cursor() as cursor: + cursor.execute("SHOW TIMEZONE") + tz = cursor.fetchone()[0] self.assertEqual(new_tz, tz) finally: diff --git a/tests/backends/sqlite/tests.py b/tests/backends/sqlite/tests.py index 0c07f95e6f..0fbb139278 100644 --- a/tests/backends/sqlite/tests.py +++ b/tests/backends/sqlite/tests.py @@ -82,11 +82,11 @@ class LastExecutedQueryTest(TestCase): # If SQLITE_MAX_VARIABLE_NUMBER (default = 999) has been changed to be # greater than SQLITE_MAX_COLUMN (default = 2000), last_executed_query # can hit the SQLITE_MAX_COLUMN limit (#26063). - cursor = connection.cursor() - sql = "SELECT MAX(%s)" % ", ".join(["%s"] * 2001) - params = list(range(2001)) - # This should not raise an exception. - cursor.db.ops.last_executed_query(cursor.cursor, sql, params) + with connection.cursor() as cursor: + sql = "SELECT MAX(%s)" % ", ".join(["%s"] * 2001) + params = list(range(2001)) + # This should not raise an exception. + cursor.db.ops.last_executed_query(cursor.cursor, sql, params) @unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests') @@ -97,9 +97,9 @@ class EscapingChecks(TestCase): """ def test_parameter_escaping(self): # '%s' escaping support for sqlite3 (#13648). - cursor = connection.cursor() - cursor.execute("select strftime('%s', date('now'))") - response = cursor.fetchall()[0][0] + with connection.cursor() as cursor: + cursor.execute("select strftime('%s', date('now'))") + response = cursor.fetchall()[0][0] # response should be an non-zero integer self.assertTrue(int(response)) diff --git a/tests/backends/tests.py b/tests/backends/tests.py index 6d38625a98..32d76577a1 100644 --- a/tests/backends/tests.py +++ b/tests/backends/tests.py @@ -56,8 +56,8 @@ class LastExecutedQueryTest(TestCase): last_executed_query should not raise an exception even if no previous query has been run. """ - cursor = connection.cursor() - connection.ops.last_executed_query(cursor, '', ()) + with connection.cursor() as cursor: + connection.ops.last_executed_query(cursor, '', ()) def test_debug_sql(self): list(Reporter.objects.filter(first_name="test")) @@ -78,16 +78,16 @@ class ParameterHandlingTest(TestCase): def test_bad_parameter_count(self): "An executemany call with too many/not enough parameters will raise an exception (Refs #12612)" - cursor = connection.cursor() - query = ('INSERT INTO %s (%s, %s) VALUES (%%s, %%s)' % ( - connection.introspection.table_name_converter('backends_square'), - connection.ops.quote_name('root'), - connection.ops.quote_name('square') - )) - with self.assertRaises(Exception): - cursor.executemany(query, [(1, 2, 3)]) - with self.assertRaises(Exception): - cursor.executemany(query, [(1,)]) + with connection.cursor() as cursor: + query = ('INSERT INTO %s (%s, %s) VALUES (%%s, %%s)' % ( + connection.introspection.table_name_converter('backends_square'), + connection.ops.quote_name('root'), + connection.ops.quote_name('square') + )) + with self.assertRaises(Exception): + cursor.executemany(query, [(1, 2, 3)]) + with self.assertRaises(Exception): + cursor.executemany(query, [(1,)]) class LongNameTest(TransactionTestCase): @@ -133,9 +133,10 @@ class LongNameTest(TransactionTestCase): 'table': VLM._meta.db_table }, ] - cursor = connection.cursor() - for statement in connection.ops.sql_flush(no_style(), tables, sequences): - cursor.execute(statement) + sql_list = connection.ops.sql_flush(no_style(), tables, sequences) + with connection.cursor() as cursor: + for statement in sql_list: + cursor.execute(statement) class SequenceResetTest(TestCase): @@ -146,10 +147,10 @@ class SequenceResetTest(TestCase): Post.objects.create(id=10, name='1st post', text='hello world') # Reset the sequences for the database - cursor = connection.cursor() commands = connections[DEFAULT_DB_ALIAS].ops.sequence_reset_sql(no_style(), [Post]) - for sql in commands: - cursor.execute(sql) + with connection.cursor() as cursor: + for sql in commands: + cursor.execute(sql) # If we create a new object now, it should have a PK greater # than the PK we specified manually. @@ -192,14 +193,14 @@ class EscapingChecks(TestCase): bare_select_suffix = connection.features.bare_select_suffix def test_paramless_no_escaping(self): - cursor = connection.cursor() - cursor.execute("SELECT '%s'" + self.bare_select_suffix) - self.assertEqual(cursor.fetchall()[0][0], '%s') + with connection.cursor() as cursor: + cursor.execute("SELECT '%s'" + self.bare_select_suffix) + self.assertEqual(cursor.fetchall()[0][0], '%s') def test_parameter_escaping(self): - cursor = connection.cursor() - cursor.execute("SELECT '%%', %s" + self.bare_select_suffix, ('%d',)) - self.assertEqual(cursor.fetchall()[0], ('%', '%d')) + with connection.cursor() as cursor: + cursor.execute("SELECT '%%', %s" + self.bare_select_suffix, ('%d',)) + self.assertEqual(cursor.fetchall()[0], ('%', '%d')) @override_settings(DEBUG=True) @@ -215,7 +216,6 @@ class BackendTestCase(TransactionTestCase): self.create_squares(args, 'format', True) def create_squares(self, args, paramstyle, multiple): - cursor = connection.cursor() opts = Square._meta tbl = connection.introspection.table_name_converter(opts.db_table) f1 = connection.ops.quote_name(opts.get_field('root').column) @@ -226,10 +226,11 @@ class BackendTestCase(TransactionTestCase): query = 'INSERT INTO %s (%s, %s) VALUES (%%(root)s, %%(square)s)' % (tbl, f1, f2) else: raise ValueError("unsupported paramstyle in test") - if multiple: - cursor.executemany(query, args) - else: - cursor.execute(query, args) + with connection.cursor() as cursor: + if multiple: + cursor.executemany(query, args) + else: + cursor.execute(query, args) def test_cursor_executemany(self): # Test cursor.executemany #4896 @@ -297,18 +298,18 @@ class BackendTestCase(TransactionTestCase): Person(first_name="Clark", last_name="Kent").save() opts2 = Person._meta f3, f4 = opts2.get_field('first_name'), opts2.get_field('last_name') - cursor = connection.cursor() - cursor.execute( - 'SELECT %s, %s FROM %s ORDER BY %s' % ( - qn(f3.column), - qn(f4.column), - connection.introspection.table_name_converter(opts2.db_table), - qn(f3.column), + with connection.cursor() as cursor: + cursor.execute( + 'SELECT %s, %s FROM %s ORDER BY %s' % ( + qn(f3.column), + qn(f4.column), + connection.introspection.table_name_converter(opts2.db_table), + qn(f3.column), + ) ) - ) - self.assertEqual(cursor.fetchone(), ('Clark', 'Kent')) - self.assertEqual(list(cursor.fetchmany(2)), [('Jane', 'Doe'), ('John', 'Doe')]) - self.assertEqual(list(cursor.fetchall()), [('Mary', 'Agnelline'), ('Peter', 'Parker')]) + self.assertEqual(cursor.fetchone(), ('Clark', 'Kent')) + self.assertEqual(list(cursor.fetchmany(2)), [('Jane', 'Doe'), ('John', 'Doe')]) + self.assertEqual(list(cursor.fetchall()), [('Mary', 'Agnelline'), ('Peter', 'Parker')]) def test_unicode_password(self): old_password = connection.settings_dict['PASSWORD'] @@ -344,10 +345,10 @@ class BackendTestCase(TransactionTestCase): def test_duplicate_table_error(self): """ Creating an existing table returns a DatabaseError """ - cursor = connection.cursor() query = 'CREATE TABLE %s (id INTEGER);' % Article._meta.db_table - with self.assertRaises(DatabaseError): - cursor.execute(query) + with connection.cursor() as cursor: + with self.assertRaises(DatabaseError): + cursor.execute(query) def test_cursor_contextmanager(self): """ diff --git a/tests/db_utils/tests.py b/tests/db_utils/tests.py index 2a45342df5..4e35e6bb8b 100644 --- a/tests/db_utils/tests.py +++ b/tests/db_utils/tests.py @@ -26,10 +26,10 @@ class DatabaseErrorWrapperTests(TestCase): @unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL test') def test_reraising_backend_specific_database_exception(self): - cursor = connection.cursor() - msg = 'table "X" does not exist' - with self.assertRaisesMessage(ProgrammingError, msg) as cm: - cursor.execute('DROP TABLE "X"') + with connection.cursor() as cursor: + msg = 'table "X" does not exist' + with self.assertRaisesMessage(ProgrammingError, msg) as cm: + cursor.execute('DROP TABLE "X"') self.assertNotEqual(type(cm.exception), type(cm.exception.__cause__)) self.assertIsNotNone(cm.exception.__cause__) self.assertIsNotNone(cm.exception.__cause__.pgcode) |
