diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-11-15 16:58:50 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-12-11 13:26:05 -0500 |
| commit | ba5cbf9366e9b2c5ed8e27e91815d7a2c3b63e41 (patch) | |
| tree | 038f2263d581d5e49d74731af68febc4bf64eb19 /test/sql | |
| parent | 87d58b6d8188ccff808b3207d5f9398bb9adf9b9 (diff) | |
| download | sqlalchemy-ba5cbf9366e9b2c5ed8e27e91815d7a2c3b63e41.tar.gz | |
correct for "autocommit" deprecation warning
Ensure no autocommit warnings occur internally or
within tests.
Also includes fixes for SQL Server full text tests
which apparently have not been working at all for a long
time, as it used long removed APIs. CI has not had
fulltext running for some years and is now installed.
Change-Id: Id806e1856c9da9f0a9eac88cebc7a94ecc95eb96
Diffstat (limited to 'test/sql')
| -rw-r--r-- | test/sql/test_defaults.py | 14 | ||||
| -rw-r--r-- | test/sql/test_delete.py | 62 | ||||
| -rw-r--r-- | test/sql/test_deprecations.py | 252 | ||||
| -rw-r--r-- | test/sql/test_query.py | 39 | ||||
| -rw-r--r-- | test/sql/test_quote.py | 123 | ||||
| -rw-r--r-- | test/sql/test_resultset.py | 81 | ||||
| -rw-r--r-- | test/sql/test_returning.py | 74 | ||||
| -rw-r--r-- | test/sql/test_sequences.py | 64 | ||||
| -rw-r--r-- | test/sql/test_type_expressions.py | 28 | ||||
| -rw-r--r-- | test/sql/test_types.py | 385 | ||||
| -rw-r--r-- | test/sql/test_update.py | 86 |
11 files changed, 660 insertions, 548 deletions
diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index 4a6ebd0c8..2a2e70bc3 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -1012,9 +1012,7 @@ class PKIncrementTest(fixtures.TablesTest): Column("str1", String(20)), ) - # TODO: add coverage for increment on a secondary column in a key - @testing.fails_on("firebird", "Data type unknown") - def _test_autoincrement(self, connection): + def test_autoincrement(self, connection): aitable = self.tables.aitable ids = set() @@ -1064,14 +1062,6 @@ class PKIncrementTest(fixtures.TablesTest): ], ) - def test_autoincrement_autocommit(self): - with testing.db.connect() as conn: - self._test_autoincrement(conn) - - def test_autoincrement_transaction(self): - with testing.db.begin() as conn: - self._test_autoincrement(conn) - class EmptyInsertTest(fixtures.TestBase): __backend__ = True @@ -1267,7 +1257,7 @@ class SpecialTypePKTest(fixtures.TestBase): implicit_returning=implicit_returning, ) - with testing.db.connect() as conn: + with testing.db.begin() as conn: t.create(conn) r = conn.execute(t.insert().values(data=5)) diff --git a/test/sql/test_delete.py b/test/sql/test_delete.py index 934022560..6f7b3f8f5 100644 --- a/test/sql/test_delete.py +++ b/test/sql/test_delete.py @@ -308,32 +308,31 @@ class DeleteFromRoundTripTest(fixtures.TablesTest): ) @testing.requires.delete_from - def test_exec_two_table(self): + def test_exec_two_table(self, connection): users, addresses = self.tables.users, self.tables.addresses dingalings = self.tables.dingalings - with testing.db.connect() as conn: - conn.execute(dingalings.delete()) # fk violation otherwise + connection.execute(dingalings.delete()) # fk violation otherwise - conn.execute( - addresses.delete() - .where(users.c.id == addresses.c.user_id) - .where(users.c.name == "ed") - ) + connection.execute( + addresses.delete() + .where(users.c.id == addresses.c.user_id) + .where(users.c.name == "ed") + ) - expected = [ - (1, 7, "x", "jack@bean.com"), - (5, 9, "x", "fred@fred.com"), - ] - self._assert_table(addresses, expected) + expected = [ + (1, 7, "x", "jack@bean.com"), + (5, 9, "x", "fred@fred.com"), + ] + self._assert_table(connection, addresses, expected) @testing.requires.delete_from - def test_exec_three_table(self): + def test_exec_three_table(self, connection): users = self.tables.users addresses = self.tables.addresses dingalings = self.tables.dingalings - testing.db.execute( + connection.execute( dingalings.delete() .where(users.c.id == addresses.c.user_id) .where(users.c.name == "ed") @@ -341,34 +340,33 @@ class DeleteFromRoundTripTest(fixtures.TablesTest): ) expected = [(2, 5, "ding 2/5")] - self._assert_table(dingalings, expected) + self._assert_table(connection, dingalings, expected) @testing.requires.delete_from - def test_exec_two_table_plus_alias(self): + def test_exec_two_table_plus_alias(self, connection): users, addresses = self.tables.users, self.tables.addresses dingalings = self.tables.dingalings - with testing.db.connect() as conn: - conn.execute(dingalings.delete()) # fk violation otherwise - a1 = addresses.alias() - conn.execute( - addresses.delete() - .where(users.c.id == addresses.c.user_id) - .where(users.c.name == "ed") - .where(a1.c.id == addresses.c.id) - ) + connection.execute(dingalings.delete()) # fk violation otherwise + a1 = addresses.alias() + connection.execute( + addresses.delete() + .where(users.c.id == addresses.c.user_id) + .where(users.c.name == "ed") + .where(a1.c.id == addresses.c.id) + ) expected = [(1, 7, "x", "jack@bean.com"), (5, 9, "x", "fred@fred.com")] - self._assert_table(addresses, expected) + self._assert_table(connection, addresses, expected) @testing.requires.delete_from - def test_exec_alias_plus_table(self): + def test_exec_alias_plus_table(self, connection): users, addresses = self.tables.users, self.tables.addresses dingalings = self.tables.dingalings d1 = dingalings.alias() - testing.db.execute( + connection.execute( delete(d1) .where(users.c.id == addresses.c.user_id) .where(users.c.name == "ed") @@ -376,8 +374,8 @@ class DeleteFromRoundTripTest(fixtures.TablesTest): ) expected = [(2, 5, "ding 2/5")] - self._assert_table(dingalings, expected) + self._assert_table(connection, dingalings, expected) - def _assert_table(self, table, expected): + def _assert_table(self, connection, table, expected): stmt = table.select().order_by(table.c.id) - eq_(testing.db.execute(stmt).fetchall(), expected) + eq_(connection.execute(stmt).fetchall(), expected) diff --git a/test/sql/test_deprecations.py b/test/sql/test_deprecations.py index c0d2e87e8..e082cf55d 100644 --- a/test/sql/test_deprecations.py +++ b/test/sql/test_deprecations.py @@ -23,6 +23,7 @@ from sqlalchemy import MetaData from sqlalchemy import null from sqlalchemy import or_ from sqlalchemy import select +from sqlalchemy import Sequence from sqlalchemy import sql from sqlalchemy import String from sqlalchemy import table @@ -1271,6 +1272,165 @@ class KeyTargetingTest(fixtures.TablesTest): in_(stmt.c.keyed2_b, row) +class PKIncrementTest(fixtures.TablesTest): + run_define_tables = "each" + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "aitable", + metadata, + Column( + "id", + Integer, + Sequence("ai_id_seq", optional=True), + primary_key=True, + ), + Column("int1", Integer), + Column("str1", String(20)), + ) + + def _test_autoincrement(self, connection): + aitable = self.tables.aitable + + ids = set() + rs = connection.execute(aitable.insert(), int1=1) + last = rs.inserted_primary_key[0] + self.assert_(last) + self.assert_(last not in ids) + ids.add(last) + + rs = connection.execute(aitable.insert(), str1="row 2") + last = rs.inserted_primary_key[0] + self.assert_(last) + self.assert_(last not in ids) + ids.add(last) + + rs = connection.execute(aitable.insert(), int1=3, str1="row 3") + last = rs.inserted_primary_key[0] + self.assert_(last) + self.assert_(last not in ids) + ids.add(last) + + rs = connection.execute( + aitable.insert().values({"int1": func.length("four")}) + ) + last = rs.inserted_primary_key[0] + self.assert_(last) + self.assert_(last not in ids) + ids.add(last) + + eq_( + ids, + set( + range( + testing.db.dialect.default_sequence_base, + testing.db.dialect.default_sequence_base + 4, + ) + ), + ) + + eq_( + list(connection.execute(aitable.select().order_by(aitable.c.id))), + [ + (testing.db.dialect.default_sequence_base, 1, None), + (testing.db.dialect.default_sequence_base + 1, None, "row 2"), + (testing.db.dialect.default_sequence_base + 2, 3, "row 3"), + (testing.db.dialect.default_sequence_base + 3, 4, None), + ], + ) + + def test_autoincrement_autocommit(self): + with testing.db.connect() as conn: + with testing.expect_deprecated_20( + "The current statement is being autocommitted using " + "implicit autocommit, " + ): + self._test_autoincrement(conn) + + +class ConnectionlessCursorResultTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "users", + metadata, + Column( + "user_id", INT, primary_key=True, test_needs_autoincrement=True + ), + Column("user_name", VARCHAR(20)), + test_needs_acid=True, + ) + + def test_connectionless_autoclose_rows_exhausted(self): + users = self.tables.users + with testing.db.begin() as conn: + conn.execute(users.insert(), dict(user_id=1, user_name="john")) + + with testing.expect_deprecated_20( + r"The (?:Executable|Engine)\.(?:execute|scalar)\(\) method" + ): + result = testing.db.execute(text("select * from users")) + connection = result.connection + assert not connection.closed + eq_(result.fetchone(), (1, "john")) + assert not connection.closed + eq_(result.fetchone(), None) + assert connection.closed + + @testing.requires.returning + def test_connectionless_autoclose_crud_rows_exhausted(self): + users = self.tables.users + stmt = ( + users.insert() + .values(user_id=1, user_name="john") + .returning(users.c.user_id) + ) + with testing.expect_deprecated_20( + r"The (?:Executable|Engine)\.(?:execute|scalar)\(\) method" + ): + result = testing.db.execute(stmt) + connection = result.connection + assert not connection.closed + eq_(result.fetchone(), (1,)) + assert not connection.closed + eq_(result.fetchone(), None) + assert connection.closed + + def test_connectionless_autoclose_no_rows(self): + with testing.expect_deprecated_20( + r"The (?:Executable|Engine)\.(?:execute|scalar)\(\) method" + ): + result = testing.db.execute(text("select * from users")) + connection = result.connection + assert not connection.closed + eq_(result.fetchone(), None) + assert connection.closed + + @testing.requires.updateable_autoincrement_pks + def test_connectionless_autoclose_no_metadata(self): + with testing.expect_deprecated_20( + r"The (?:Executable|Engine)\.(?:execute|scalar)\(\) method" + ): + result = testing.db.execute(text("update users set user_id=5")) + connection = result.connection + assert connection.closed + + assert_raises_message( + exc.ResourceClosedError, + "This result object does not return rows.", + result.fetchone, + ) + assert_raises_message( + exc.ResourceClosedError, + "This result object does not return rows.", + result.keys, + ) + + class CursorResultTest(fixtures.TablesTest): __backend__ = True @@ -1436,7 +1596,7 @@ class CursorResultTest(fixtures.TablesTest): def test_pickled_rows(self): users = self.tables.users addresses = self.tables.addresses - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.execute(users.delete()) conn.execute( users.insert(), @@ -2319,3 +2479,93 @@ class LegacyOperatorTest(AssertsCompiledSQL, fixtures.TestBase): _op_modern = getattr(operators.ColumnOperators, _modern) _op_legacy = getattr(operators.ColumnOperators, _legacy) assert _op_modern == _op_legacy + + +class LegacySequenceExecTest(fixtures.TestBase): + __requires__ = ("sequences",) + __backend__ = True + + @classmethod + def setup_class(cls): + cls.seq = Sequence("my_sequence") + cls.seq.create(testing.db) + + @classmethod + def teardown_class(cls): + cls.seq.drop(testing.db) + + def _assert_seq_result(self, ret): + """asserts return of next_value is an int""" + + assert isinstance(ret, util.int_types) + assert ret >= testing.db.dialect.default_sequence_base + + def test_implicit_connectionless(self): + with testing.expect_deprecated_20( + r"The MetaData.bind argument is deprecated" + ): + s = Sequence("my_sequence", metadata=MetaData(testing.db)) + + with testing.expect_deprecated_20( + r"The DefaultGenerator.execute\(\) method is considered legacy " + "as of the 1.x", + ): + self._assert_seq_result(s.execute()) + + def test_explicit(self, connection): + s = Sequence("my_sequence") + with testing.expect_deprecated_20( + r"The DefaultGenerator.execute\(\) method is considered legacy" + ): + self._assert_seq_result(s.execute(connection)) + + def test_explicit_optional(self): + """test dialect executes a Sequence, returns nextval, whether + or not "optional" is set""" + + s = Sequence("my_sequence", optional=True) + with testing.expect_deprecated_20( + r"The DefaultGenerator.execute\(\) method is considered legacy" + ): + self._assert_seq_result(s.execute(testing.db)) + + def test_func_implicit_connectionless_execute(self): + """test func.next_value().execute()/.scalar() works + with connectionless execution.""" + + with testing.expect_deprecated_20( + r"The MetaData.bind argument is deprecated" + ): + s = Sequence("my_sequence", metadata=MetaData(testing.db)) + with testing.expect_deprecated_20( + r"The Executable.execute\(\) method is considered legacy" + ): + self._assert_seq_result(s.next_value().execute().scalar()) + + def test_func_explicit(self): + s = Sequence("my_sequence") + with testing.expect_deprecated_20( + r"The Engine.scalar\(\) method is considered legacy" + ): + self._assert_seq_result(testing.db.scalar(s.next_value())) + + def test_func_implicit_connectionless_scalar(self): + """test func.next_value().execute()/.scalar() works. """ + + with testing.expect_deprecated_20( + r"The MetaData.bind argument is deprecated" + ): + s = Sequence("my_sequence", metadata=MetaData(testing.db)) + with testing.expect_deprecated_20( + r"The Executable.execute\(\) method is considered legacy" + ): + self._assert_seq_result(s.next_value().scalar()) + + def test_func_embedded_select(self): + """test can use next_value() in select column expr""" + + s = Sequence("my_sequence") + with testing.expect_deprecated_20( + r"The Engine.scalar\(\) method is considered legacy" + ): + self._assert_seq_result(testing.db.scalar(select(s.next_value()))) diff --git a/test/sql/test_query.py b/test/sql/test_query.py index 7d05462ab..6d26f7975 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -84,7 +84,7 @@ class QueryTest(fixtures.TestBase): @engines.close_first def teardown(self): - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.execute(addresses.delete()) conn.execute(users.delete()) conn.execute(users2.delete()) @@ -878,21 +878,22 @@ class RequiredBindTest(fixtures.TablesTest): ) def _assert_raises(self, stmt, params): - assert_raises_message( - exc.StatementError, - "A value is required for bind parameter 'x'", - testing.db.execute, - stmt, - **params - ) + with testing.db.connect() as conn: + assert_raises_message( + exc.StatementError, + "A value is required for bind parameter 'x'", + conn.execute, + stmt, + **params + ) - assert_raises_message( - exc.StatementError, - "A value is required for bind parameter 'x'", - testing.db.execute, - stmt, - params, - ) + assert_raises_message( + exc.StatementError, + "A value is required for bind parameter 'x'", + conn.execute, + stmt, + params, + ) def test_insert(self): stmt = self.tables.foo.insert().values( @@ -953,7 +954,7 @@ class LimitTest(fixtures.TestBase): ) metadata.create_all() - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.execute(users.insert(), user_id=1, user_name="john") conn.execute( addresses.insert(), address_id=1, user_id=1, address="addr1" @@ -1105,7 +1106,7 @@ class CompoundTest(fixtures.TestBase): ) metadata.create_all() - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.execute( t1.insert(), [ @@ -1470,7 +1471,7 @@ class JoinTest(fixtures.TestBase): metadata.drop_all() metadata.create_all() - with testing.db.connect() as conn: + with testing.db.begin() as conn: # t1.10 -> t2.20 -> t3.30 # t1.11 -> t2.21 # t1.12 @@ -1823,7 +1824,7 @@ class OperatorTest(fixtures.TestBase): ) metadata.create_all() - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.execute( flds.insert(), [dict(intcol=5, strcol="foo"), dict(intcol=13, strcol="bar")], diff --git a/test/sql/test_quote.py b/test/sql/test_quote.py index 1c023e7b1..a78d6c16b 100644 --- a/test/sql/test_quote.py +++ b/test/sql/test_quote.py @@ -25,19 +25,12 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing.util import picklers -class QuoteExecTest(fixtures.TestBase): +class QuoteExecTest(fixtures.TablesTest): __backend__ = True @classmethod - def setup_class(cls): - # TODO: figure out which databases/which identifiers allow special - # characters to be used, such as: spaces, quote characters, - # punctuation characters, set up tests for those as well. - - global table1, table2 - metadata = MetaData(testing.db) - - table1 = Table( + def define_tables(cls, metadata): + Table( "WorstCase1", metadata, Column("lowercase", Integer, primary_key=True), @@ -45,7 +38,7 @@ class QuoteExecTest(fixtures.TestBase): Column("MixedCase", Integer), Column("ASC", Integer, key="a123"), ) - table2 = Table( + Table( "WorstCase2", metadata, Column("desc", Integer, primary_key=True, key="d123"), @@ -53,18 +46,6 @@ class QuoteExecTest(fixtures.TestBase): Column("MixedCase", Integer), ) - table1.create() - table2.create() - - def teardown(self): - table1.delete().execute() - table2.delete().execute() - - @classmethod - def teardown_class(cls): - table1.drop() - table2.drop() - def test_reflect(self): meta2 = MetaData() t2 = Table("WorstCase1", meta2, autoload_with=testing.db, quote=True) @@ -88,25 +69,22 @@ class QuoteExecTest(fixtures.TestBase): assert "MixedCase" in t2.c @testing.provide_metadata - def test_has_table_case_sensitive(self): + def test_has_table_case_sensitive(self, connection): preparer = testing.db.dialect.identifier_preparer - with testing.db.connect() as conn: - if conn.dialect.requires_name_normalize: - conn.exec_driver_sql("CREATE TABLE TAB1 (id INTEGER)") - else: - conn.exec_driver_sql("CREATE TABLE tab1 (id INTEGER)") - conn.exec_driver_sql( - "CREATE TABLE %s (id INTEGER)" - % preparer.quote_identifier("tab2") - ) - conn.exec_driver_sql( - "CREATE TABLE %s (id INTEGER)" - % preparer.quote_identifier("TAB3") - ) - conn.exec_driver_sql( - "CREATE TABLE %s (id INTEGER)" - % preparer.quote_identifier("TAB4") - ) + conn = connection + if conn.dialect.requires_name_normalize: + conn.exec_driver_sql("CREATE TABLE TAB1 (id INTEGER)") + else: + conn.exec_driver_sql("CREATE TABLE tab1 (id INTEGER)") + conn.exec_driver_sql( + "CREATE TABLE %s (id INTEGER)" % preparer.quote_identifier("tab2") + ) + conn.exec_driver_sql( + "CREATE TABLE %s (id INTEGER)" % preparer.quote_identifier("TAB3") + ) + conn.exec_driver_sql( + "CREATE TABLE %s (id INTEGER)" % preparer.quote_identifier("TAB4") + ) t1 = Table( "tab1", self.metadata, Column("id", Integer, primary_key=True) @@ -127,7 +105,7 @@ class QuoteExecTest(fixtures.TestBase): quote=True, ) - insp = inspect(testing.db) + insp = inspect(connection) assert insp.has_table(t1.name) eq_([c["name"] for c in insp.get_columns(t1.name)], ["id"]) @@ -140,16 +118,24 @@ class QuoteExecTest(fixtures.TestBase): assert insp.has_table(t4.name) eq_([c["name"] for c in insp.get_columns(t4.name)], ["id"]) - def test_basic(self): - table1.insert().execute( - {"lowercase": 1, "UPPERCASE": 2, "MixedCase": 3, "a123": 4}, - {"lowercase": 2, "UPPERCASE": 2, "MixedCase": 3, "a123": 4}, - {"lowercase": 4, "UPPERCASE": 3, "MixedCase": 2, "a123": 1}, + def test_basic(self, connection): + table1, table2 = self.tables("WorstCase1", "WorstCase2") + + connection.execute( + table1.insert(), + [ + {"lowercase": 1, "UPPERCASE": 2, "MixedCase": 3, "a123": 4}, + {"lowercase": 2, "UPPERCASE": 2, "MixedCase": 3, "a123": 4}, + {"lowercase": 4, "UPPERCASE": 3, "MixedCase": 2, "a123": 1}, + ], ) - table2.insert().execute( - {"d123": 1, "u123": 2, "MixedCase": 3}, - {"d123": 2, "u123": 2, "MixedCase": 3}, - {"d123": 4, "u123": 3, "MixedCase": 2}, + connection.execute( + table2.insert(), + [ + {"d123": 1, "u123": 2, "MixedCase": 3}, + {"d123": 2, "u123": 2, "MixedCase": 3}, + {"d123": 4, "u123": 3, "MixedCase": 2}, + ], ) columns = [ @@ -158,23 +144,30 @@ class QuoteExecTest(fixtures.TestBase): table1.c.MixedCase, table1.c.a123, ] - result = select(columns).execute().fetchall() + result = connection.execute(select(columns)).all() assert result == [(1, 2, 3, 4), (2, 2, 3, 4), (4, 3, 2, 1)] columns = [table2.c.d123, table2.c.u123, table2.c.MixedCase] - result = select(columns).execute().fetchall() + result = connection.execute(select(columns)).all() assert result == [(1, 2, 3), (2, 2, 3), (4, 3, 2)] - def test_use_labels(self): - table1.insert().execute( - {"lowercase": 1, "UPPERCASE": 2, "MixedCase": 3, "a123": 4}, - {"lowercase": 2, "UPPERCASE": 2, "MixedCase": 3, "a123": 4}, - {"lowercase": 4, "UPPERCASE": 3, "MixedCase": 2, "a123": 1}, - ) - table2.insert().execute( - {"d123": 1, "u123": 2, "MixedCase": 3}, - {"d123": 2, "u123": 2, "MixedCase": 3}, - {"d123": 4, "u123": 3, "MixedCase": 2}, + def test_use_labels(self, connection): + table1, table2 = self.tables("WorstCase1", "WorstCase2") + connection.execute( + table1.insert(), + [ + {"lowercase": 1, "UPPERCASE": 2, "MixedCase": 3, "a123": 4}, + {"lowercase": 2, "UPPERCASE": 2, "MixedCase": 3, "a123": 4}, + {"lowercase": 4, "UPPERCASE": 3, "MixedCase": 2, "a123": 1}, + ], + ) + connection.execute( + table2.insert(), + [ + {"d123": 1, "u123": 2, "MixedCase": 3}, + {"d123": 2, "u123": 2, "MixedCase": 3}, + {"d123": 4, "u123": 3, "MixedCase": 2}, + ], ) columns = [ @@ -183,11 +176,11 @@ class QuoteExecTest(fixtures.TestBase): table1.c.MixedCase, table1.c.a123, ] - result = select(columns, use_labels=True).execute().fetchall() + result = connection.execute(select(columns).apply_labels()).fetchall() assert result == [(1, 2, 3, 4), (2, 2, 3, 4), (4, 3, 2, 1)] columns = [table2.c.d123, table2.c.u123, table2.c.MixedCase] - result = select(columns, use_labels=True).execute().fetchall() + result = connection.execute(select(columns).apply_labels()).all() assert result == [(1, 2, 3), (2, 2, 3), (4, 3, 2)] diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index 9ef533be3..db0e0d4c8 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -615,63 +615,6 @@ class CursorResultTest(fixtures.TablesTest): result.fetchone, ) - def test_connectionless_autoclose_rows_exhausted(self): - # TODO: deprecate for 2.0 - users = self.tables.users - with testing.db.connect() as conn: - conn.execute(users.insert(), dict(user_id=1, user_name="john")) - - result = testing.db.execute(text("select * from users")) - connection = result.connection - assert not connection.closed - eq_(result.fetchone(), (1, "john")) - assert not connection.closed - eq_(result.fetchone(), None) - assert connection.closed - - @testing.requires.returning - def test_connectionless_autoclose_crud_rows_exhausted(self): - # TODO: deprecate for 2.0 - users = self.tables.users - stmt = ( - users.insert() - .values(user_id=1, user_name="john") - .returning(users.c.user_id) - ) - result = testing.db.execute(stmt) - connection = result.connection - assert not connection.closed - eq_(result.fetchone(), (1,)) - assert not connection.closed - eq_(result.fetchone(), None) - assert connection.closed - - def test_connectionless_autoclose_no_rows(self): - # TODO: deprecate for 2.0 - result = testing.db.execute(text("select * from users")) - connection = result.connection - assert not connection.closed - eq_(result.fetchone(), None) - assert connection.closed - - @testing.requires.updateable_autoincrement_pks - def test_connectionless_autoclose_no_metadata(self): - # TODO: deprecate for 2.0 - result = testing.db.execute(text("update users set user_id=5")) - connection = result.connection - assert connection.closed - - assert_raises_message( - exc.ResourceClosedError, - "This result object does not return rows.", - result.fetchone, - ) - assert_raises_message( - exc.ResourceClosedError, - "This result object does not return rows.", - result.keys, - ) - def test_row_case_sensitive(self, connection): row = connection.execute( select( @@ -1285,7 +1228,7 @@ class CursorResultTest(fixtures.TablesTest): with patch.object( engine.dialect.execution_ctx_cls, "rowcount" ) as mock_rowcount: - with engine.connect() as conn: + with engine.begin() as conn: mock_rowcount.__get__ = Mock() conn.execute( t.insert(), {"data": "d1"}, {"data": "d2"}, {"data": "d3"} @@ -1362,20 +1305,14 @@ class CursorResultTest(fixtures.TablesTest): eq_(row[1:0:-1], ("Uno",)) @testing.requires.cextensions - def test_row_c_sequence_check(self): - # TODO: modernize for 2.0 - metadata = MetaData() - metadata.bind = "sqlite://" - users = Table( - "users", - metadata, - Column("id", Integer, primary_key=True), - Column("name", String(40)), - ) - users.create() + @testing.provide_metadata + def test_row_c_sequence_check(self, connection): + users = self.tables.users2 - users.insert().execute(name="Test") - row = users.select().execute().fetchone() + connection.execute(users.insert(), dict(user_id=1, user_name="Test")) + row = connection.execute( + users.select().where(users.c.user_id == 1) + ).fetchone() s = util.StringIO() writer = csv.writer(s) @@ -2340,7 +2277,7 @@ class AlternateCursorResultTest(fixtures.TablesTest): @testing.fixture def row_growth_fixture(self): with self._proxy_fixture(_cursor.BufferedRowCursorFetchStrategy): - with self.engine.connect() as conn: + with self.engine.begin() as conn: conn.execute( self.table.insert(), [{"x": i, "y": "t_%d" % i} for i in range(15, 3000)], diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index 065205c45..9f2afd7b7 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -23,9 +23,6 @@ from sqlalchemy.testing.schema import Table from sqlalchemy.types import TypeDecorator -table = GoofyType = seq = None - - class ReturnCombinationTests(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "postgresql" @@ -92,14 +89,14 @@ class ReturnCombinationTests(fixtures.TestBase, AssertsCompiledSQL): ) -class ReturningTest(fixtures.TestBase, AssertsExecutionResults): +class ReturningTest(fixtures.TablesTest, AssertsExecutionResults): __requires__ = ("returning",) __backend__ = True - def setup(self): - meta = MetaData(testing.db) - global table, GoofyType + run_create_tables = "each" + @classmethod + def define_tables(cls, metadata): class GoofyType(TypeDecorator): impl = String @@ -113,9 +110,11 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): return None return value + "BAR" - table = Table( + cls.GoofyType = GoofyType + + Table( "tables", - meta, + metadata, Column( "id", Integer, primary_key=True, test_needs_autoincrement=True ), @@ -123,14 +122,9 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): Column("full", Boolean), Column("goofy", GoofyType(50)), ) - with testing.db.connect() as conn: - table.create(conn, checkfirst=True) - - def teardown(self): - with testing.db.connect() as conn: - table.drop(conn) def test_column_targeting(self, connection): + table = self.tables.tables result = connection.execute( table.insert().returning(table.c.id, table.c.full), {"persons": 1, "full": False}, @@ -155,6 +149,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): @testing.fails_on("firebird", "fb can't handle returning x AS y") def test_labeling(self, connection): + table = self.tables.tables result = connection.execute( table.insert() .values(persons=6) @@ -167,6 +162,8 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): "firebird", "fb/kintersbasdb can't handle the bind params" ) def test_anon_expressions(self, connection): + table = self.tables.tables + GoofyType = self.GoofyType result = connection.execute( table.insert() .values(goofy="someOTHERgoofy") @@ -182,6 +179,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): eq_(row[0], 30) def test_update_returning(self, connection): + table = self.tables.tables connection.execute( table.insert(), [{"persons": 5, "full": False}, {"persons": 3, "full": False}], @@ -201,6 +199,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): @testing.requires.full_returning def test_update_full_returning(self, connection): + table = self.tables.tables connection.execute( table.insert(), [{"persons": 5, "full": False}, {"persons": 3, "full": False}], @@ -215,6 +214,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): @testing.requires.full_returning def test_delete_full_returning(self, connection): + table = self.tables.tables connection.execute( table.insert(), [{"persons": 5, "full": False}, {"persons": 3, "full": False}], @@ -226,6 +226,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): eq_(result.fetchall(), [(1, False), (2, False)]) def test_insert_returning(self, connection): + table = self.tables.tables result = connection.execute( table.insert().returning(table.c.id), {"persons": 1, "full": False} ) @@ -234,6 +235,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): @testing.requires.multivalues_inserts def test_multirow_returning(self, connection): + table = self.tables.tables ins = ( table.insert() .returning(table.c.id, table.c.persons) @@ -249,6 +251,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): eq_(result.fetchall(), [(1, 1), (2, 2), (3, 3)]) def test_no_ipk_on_returning(self, connection): + table = self.tables.tables result = connection.execute( table.insert().returning(table.c.id), {"persons": 1, "full": False} ) @@ -274,6 +277,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): eq_([dict(row._mapping) for row in result4], [{"persons": 10}]) def test_delete_returning(self, connection): + table = self.tables.tables connection.execute( table.insert(), [{"persons": 5, "full": False}, {"persons": 3, "full": False}], @@ -319,17 +323,16 @@ class CompositeStatementTest(fixtures.TestBase): eq_(result.scalar(), 5) -class SequenceReturningTest(fixtures.TestBase): +class SequenceReturningTest(fixtures.TablesTest): __requires__ = "returning", "sequences" __backend__ = True - def setup(self): - meta = MetaData(testing.db) - global table, seq + @classmethod + def define_tables(cls, metadata): seq = Sequence("tid_seq") - table = Table( + Table( "tables", - meta, + metadata, Column( "id", Integer, @@ -338,38 +341,32 @@ class SequenceReturningTest(fixtures.TestBase): ), Column("data", String(50)), ) - with testing.db.connect() as conn: - table.create(conn, checkfirst=True) - - def teardown(self): - with testing.db.connect() as conn: - table.drop(conn) + cls.sequences.tid_seq = seq def test_insert(self, connection): + table = self.tables.tables r = connection.execute( table.insert().values(data="hi").returning(table.c.id) ) eq_(r.first(), tuple([testing.db.dialect.default_sequence_base])) eq_( - connection.execute(seq), + connection.execute(self.sequences.tid_seq), testing.db.dialect.default_sequence_base + 1, ) -class KeyReturningTest(fixtures.TestBase, AssertsExecutionResults): +class KeyReturningTest(fixtures.TablesTest, AssertsExecutionResults): """test returning() works with columns that define 'key'.""" __requires__ = ("returning",) __backend__ = True - def setup(self): - meta = MetaData(testing.db) - global table - - table = Table( + @classmethod + def define_tables(cls, metadata): + Table( "tables", - meta, + metadata, Column( "id", Integer, @@ -379,16 +376,11 @@ class KeyReturningTest(fixtures.TestBase, AssertsExecutionResults): ), Column("data", String(20)), ) - with testing.db.connect() as conn: - table.create(conn, checkfirst=True) - - def teardown(self): - with testing.db.connect() as conn: - table.drop(conn) @testing.exclude("firebird", "<", (2, 0), "2.0+ feature") @testing.exclude("postgresql", "<", (8, 2), "8.2+ feature") def test_insert(self, connection): + table = self.tables.tables result = connection.execute( table.insert().returning(table.c.foo_id), data="somedata" ) diff --git a/test/sql/test_sequences.py b/test/sql/test_sequences.py index e609a8a91..1809e0cca 100644 --- a/test/sql/test_sequences.py +++ b/test/sql/test_sequences.py @@ -95,64 +95,6 @@ class SequenceDDLTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) -class LegacySequenceExecTest(fixtures.TestBase): - __requires__ = ("sequences",) - __backend__ = True - - @classmethod - def setup_class(cls): - cls.seq = Sequence("my_sequence") - cls.seq.create(testing.db) - - @classmethod - def teardown_class(cls): - cls.seq.drop(testing.db) - - def _assert_seq_result(self, ret): - """asserts return of next_value is an int""" - - assert isinstance(ret, util.int_types) - assert ret >= testing.db.dialect.default_sequence_base - - def test_implicit_connectionless(self): - s = Sequence("my_sequence", metadata=MetaData(testing.db)) - self._assert_seq_result(s.execute()) - - def test_explicit(self, connection): - s = Sequence("my_sequence") - self._assert_seq_result(s.execute(connection)) - - def test_explicit_optional(self): - """test dialect executes a Sequence, returns nextval, whether - or not "optional" is set""" - - s = Sequence("my_sequence", optional=True) - self._assert_seq_result(s.execute(testing.db)) - - def test_func_implicit_connectionless_execute(self): - """test func.next_value().execute()/.scalar() works - with connectionless execution.""" - - s = Sequence("my_sequence", metadata=MetaData(testing.db)) - self._assert_seq_result(s.next_value().execute().scalar()) - - def test_func_explicit(self): - s = Sequence("my_sequence") - self._assert_seq_result(testing.db.scalar(s.next_value())) - - def test_func_implicit_connectionless_scalar(self): - """test func.next_value().execute()/.scalar() works. """ - - s = Sequence("my_sequence", metadata=MetaData(testing.db)) - self._assert_seq_result(s.next_value().scalar()) - - def test_func_embedded_select(self): - """test can use next_value() in select column expr""" - - s = Sequence("my_sequence") - self._assert_seq_result(testing.db.scalar(select(s.next_value()))) - - class SequenceExecTest(fixtures.TestBase): __requires__ = ("sequences",) __backend__ = True @@ -247,7 +189,7 @@ class SequenceExecTest(fixtures.TestBase): s = Sequence("my_sequence_here", metadata=metadata) e = engines.testing_engine(options={"implicit_returning": False}) - with e.connect() as conn: + with e.begin() as conn: t1.create(conn) s.create(conn) @@ -279,7 +221,7 @@ class SequenceExecTest(fixtures.TestBase): t1.create(testing.db) e = engines.testing_engine(options={"implicit_returning": True}) - with e.connect() as conn: + with e.begin() as conn: r = conn.execute(t1.insert().values(x=s.next_value())) self._assert_seq_result(r.inserted_primary_key[0]) @@ -476,7 +418,7 @@ class TableBoundSequenceTest(fixtures.TablesTest): engine = engines.testing_engine(options={"implicit_returning": False}) - with engine.connect() as conn: + with engine.begin() as conn: result = conn.execute(sometable.insert(), dict(name="somename")) eq_(result.postfetch_cols(), [sometable.c.obj_id]) diff --git a/test/sql/test_type_expressions.py b/test/sql/test_type_expressions.py index 09ade319e..719f8e318 100644 --- a/test/sql/test_type_expressions.py +++ b/test/sql/test_type_expressions.py @@ -359,34 +359,34 @@ class RoundTripTestBase(object): [("X1", "Y1"), ("X2", "Y2"), ("X3", "Y3")], ) - def test_targeting_no_labels(self): - testing.db.execute( + def test_targeting_no_labels(self, connection): + connection.execute( self.tables.test_table.insert(), {"x": "X1", "y": "Y1"} ) - row = testing.db.execute(select(self.tables.test_table)).first() + row = connection.execute(select(self.tables.test_table)).first() eq_(row._mapping[self.tables.test_table.c.y], "Y1") - def test_targeting_by_string(self): - testing.db.execute( + def test_targeting_by_string(self, connection): + connection.execute( self.tables.test_table.insert(), {"x": "X1", "y": "Y1"} ) - row = testing.db.execute(select(self.tables.test_table)).first() + row = connection.execute(select(self.tables.test_table)).first() eq_(row._mapping["y"], "Y1") - def test_targeting_apply_labels(self): - testing.db.execute( + def test_targeting_apply_labels(self, connection): + connection.execute( self.tables.test_table.insert(), {"x": "X1", "y": "Y1"} ) - row = testing.db.execute( + row = connection.execute( select(self.tables.test_table).apply_labels() ).first() eq_(row._mapping[self.tables.test_table.c.y], "Y1") - def test_targeting_individual_labels(self): - testing.db.execute( + def test_targeting_individual_labels(self, connection): + connection.execute( self.tables.test_table.insert(), {"x": "X1", "y": "Y1"} ) - row = testing.db.execute( + row = connection.execute( select( self.tables.test_table.c.x.label("xbar"), self.tables.test_table.c.y.label("ybar"), @@ -450,9 +450,9 @@ class ReturningTest(fixtures.TablesTest): ) @testing.provide_metadata - def test_insert_returning(self): + def test_insert_returning(self, connection): table = self.tables.test_table - result = testing.db.execute( + result = connection.execute( table.insert().returning(table.c.y), {"x": "xvalue"} ) eq_(result.first(), ("yvalue",)) diff --git a/test/sql/test_types.py b/test/sql/test_types.py index fd1783e09..3f89d438a 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -535,49 +535,48 @@ class _UserDefinedTypeFixture(object): class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest): __backend__ = True - def _data_fixture(self): + def _data_fixture(self, connection): users = self.tables.users - with testing.db.connect() as conn: - conn.execute( - users.insert(), - dict( - user_id=2, - goofy="jack", - goofy2="jack", - goofy4=util.u("jack"), - goofy7=util.u("jack"), - goofy8=12, - goofy9=12, - ), - ) - conn.execute( - users.insert(), - dict( - user_id=3, - goofy="lala", - goofy2="lala", - goofy4=util.u("lala"), - goofy7=util.u("lala"), - goofy8=15, - goofy9=15, - ), - ) - conn.execute( - users.insert(), - dict( - user_id=4, - goofy="fred", - goofy2="fred", - goofy4=util.u("fred"), - goofy7=util.u("fred"), - goofy8=9, - goofy9=9, - ), - ) + connection.execute( + users.insert(), + dict( + user_id=2, + goofy="jack", + goofy2="jack", + goofy4=util.u("jack"), + goofy7=util.u("jack"), + goofy8=12, + goofy9=12, + ), + ) + connection.execute( + users.insert(), + dict( + user_id=3, + goofy="lala", + goofy2="lala", + goofy4=util.u("lala"), + goofy7=util.u("lala"), + goofy8=15, + goofy9=15, + ), + ) + connection.execute( + users.insert(), + dict( + user_id=4, + goofy="fred", + goofy2="fred", + goofy4=util.u("fred"), + goofy7=util.u("fred"), + goofy8=9, + goofy9=9, + ), + ) def test_processing(self, connection): users = self.tables.users - self._data_fixture() + self._data_fixture(connection) result = connection.execute( users.select().order_by(users.c.user_id) @@ -601,7 +600,7 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest): def test_plain_in(self, connection): users = self.tables.users - self._data_fixture() + self._data_fixture(connection) stmt = ( select(users.c.user_id, users.c.goofy8) @@ -613,7 +612,7 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest): def test_expanding_in(self, connection): users = self.tables.users - self._data_fixture() + self._data_fixture(connection) stmt = ( select(users.c.user_id, users.c.goofy8) @@ -1225,41 +1224,38 @@ class VariantTest(fixtures.TestBase, AssertsCompiledSQL): @testing.only_on("sqlite") @testing.provide_metadata - def test_round_trip(self): + def test_round_trip(self, connection): variant = self.UTypeOne().with_variant(self.UTypeTwo(), "sqlite") t = Table("t", self.metadata, Column("x", variant)) - with testing.db.connect() as conn: - t.create(conn) + t.create(connection) - conn.execute(t.insert(), x="foo") + connection.execute(t.insert(), x="foo") - eq_(conn.scalar(select(t.c.x).where(t.c.x == "foo")), "fooUTWO") + eq_(connection.scalar(select(t.c.x).where(t.c.x == "foo")), "fooUTWO") @testing.only_on("sqlite") @testing.provide_metadata - def test_round_trip_sqlite_datetime(self): + def test_round_trip_sqlite_datetime(self, connection): variant = DateTime().with_variant( dialects.sqlite.DATETIME(truncate_microseconds=True), "sqlite" ) t = Table("t", self.metadata, Column("x", variant)) - with testing.db.connect() as conn: - t.create(conn) + t.create(connection) - conn.execute( - t.insert(), x=datetime.datetime(2015, 4, 18, 10, 15, 17, 4839) - ) + connection.execute( + t.insert(), x=datetime.datetime(2015, 4, 18, 10, 15, 17, 4839) + ) - eq_( - conn.scalar( - select(t.c.x).where( - t.c.x - == datetime.datetime(2015, 4, 18, 10, 15, 17, 1059) - ) - ), - datetime.datetime(2015, 4, 18, 10, 15, 17), - ) + eq_( + connection.scalar( + select(t.c.x).where( + t.c.x == datetime.datetime(2015, 4, 18, 10, 15, 17, 1059) + ) + ), + datetime.datetime(2015, 4, 18, 10, 15, 17), + ) class UnicodeTest(fixtures.TestBase): @@ -1702,14 +1698,25 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): 2, ) - with testing.db.connect() as conn: - self.metadata.create_all(conn) + self.metadata.create_all(testing.db) + + # not using the connection fixture because we need to rollback and + # start again in the middle + with testing.db.connect() as connection: + # postgresql needs this in order to continue after the exception + trans = connection.begin() assert_raises( (exc.DBAPIError,), - conn.exec_driver_sql, + connection.exec_driver_sql, "insert into my_table " "(data) values('four')", ) - conn.exec_driver_sql("insert into my_table (data) values ('two')") + trans.rollback() + + with connection.begin(): + connection.exec_driver_sql( + "insert into my_table (data) values ('two')" + ) + eq_(connection.execute(select(t.c.data)).scalar(), "two") @testing.requires.enforces_check_constraints @testing.provide_metadata @@ -1747,34 +1754,44 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): 2, ) - with testing.db.connect() as conn: - self.metadata.create_all(conn) + self.metadata.create_all(testing.db) + + # not using the connection fixture because we need to rollback and + # start again in the middle + with testing.db.connect() as connection: + # postgresql needs this in order to continue after the exception + trans = connection.begin() assert_raises( (exc.DBAPIError,), - conn.exec_driver_sql, + connection.exec_driver_sql, "insert into my_table " "(data) values('two')", ) - conn.exec_driver_sql("insert into my_table (data) values ('four')") + trans.rollback() - def test_skip_check_constraint(self): - with testing.db.connect() as conn: - conn.exec_driver_sql( - "insert into non_native_enum_table " - "(id, someotherenum) values(1, 'four')" - ) - eq_( - conn.exec_driver_sql( - "select someotherenum from non_native_enum_table" - ).scalar(), - "four", - ) - assert_raises_message( - LookupError, - "'four' is not among the defined enum values. " - "Enum name: None. Possible values: one, two, three", - conn.scalar, - select(self.tables.non_native_enum_table.c.someotherenum), - ) + with connection.begin(): + connection.exec_driver_sql( + "insert into my_table (data) values ('four')" + ) + eq_(connection.execute(select(t.c.data)).scalar(), "four") + + def test_skip_check_constraint(self, connection): + connection.exec_driver_sql( + "insert into non_native_enum_table " + "(id, someotherenum) values(1, 'four')" + ) + eq_( + connection.exec_driver_sql( + "select someotherenum from non_native_enum_table" + ).scalar(), + "four", + ) + assert_raises_message( + LookupError, + "'four' is not among the defined enum values. " + "Enum name: None. Possible values: one, two, three", + connection.scalar, + select(self.tables.non_native_enum_table.c.someotherenum), + ) def test_non_native_round_trip(self, connection): non_native_enum_table = self.tables["non_native_enum_table"] @@ -2086,15 +2103,15 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): eq_(e.length, 42) -binary_table = MyPickleType = metadata = None +MyPickleType = None -class BinaryTest(fixtures.TestBase, AssertsExecutionResults): +class BinaryTest(fixtures.TablesTest, AssertsExecutionResults): __backend__ = True @classmethod - def setup_class(cls): - global binary_table, MyPickleType, metadata + def define_tables(cls, metadata): + global MyPickleType class MyPickleType(types.TypeDecorator): impl = PickleType @@ -2109,8 +2126,7 @@ class BinaryTest(fixtures.TestBase, AssertsExecutionResults): value.stuff = "this is the right stuff" return value - metadata = MetaData(testing.db) - binary_table = Table( + Table( "binary_table", metadata, Column( @@ -2125,19 +2141,11 @@ class BinaryTest(fixtures.TestBase, AssertsExecutionResults): Column("pickled", PickleType), Column("mypickle", MyPickleType), ) - metadata.create_all() - - @engines.close_first - def teardown(self): - with testing.db.connect() as conn: - conn.execute(binary_table.delete()) - - @classmethod - def teardown_class(cls): - metadata.drop_all() @testing.requires.non_broken_binary def test_round_trip(self, connection): + binary_table = self.tables.binary_table + testobj1 = pickleable.Foo("im foo 1") testobj2 = pickleable.Foo("im foo 2") testobj3 = pickleable.Foo("im foo 3") @@ -2197,6 +2205,7 @@ class BinaryTest(fixtures.TestBase, AssertsExecutionResults): @testing.requires.binary_comparisons def test_comparison(self, connection): """test that type coercion occurs on comparison for binary""" + binary_table = self.tables.binary_table expr = binary_table.c.data == "foo" assert isinstance(expr.right.type, LargeBinary) @@ -2419,17 +2428,17 @@ class ArrayTest(fixtures.TestBase): assert isinstance(arrtable.c.strarr[1:3].type, MyArray) -test_table = meta = MyCustomType = MyTypeDec = None +MyCustomType = MyTypeDec = None class ExpressionTest( - fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL + fixtures.TablesTest, AssertsExecutionResults, AssertsCompiledSQL ): __dialect__ = "default" @classmethod - def setup_class(cls): - global test_table, meta, MyCustomType, MyTypeDec + def define_tables(cls, metadata): + global MyCustomType, MyTypeDec class MyCustomType(types.UserDefinedType): def get_col_spec(self): @@ -2463,10 +2472,9 @@ class ExpressionTest( def process_result_value(self, value, dialect): return value + "BIND_OUT" - meta = MetaData(testing.db) - test_table = Table( + Table( "test", - meta, + metadata, Column("id", Integer, primary_key=True), Column("data", String(30)), Column("atimestamp", Date), @@ -2474,25 +2482,22 @@ class ExpressionTest( Column("bvalue", MyTypeDec(50)), ) - meta.create_all() - - with testing.db.connect() as conn: - conn.execute( - test_table.insert(), - { - "id": 1, - "data": "somedata", - "atimestamp": datetime.date(2007, 10, 15), - "avalue": 25, - "bvalue": "foo", - }, - ) - @classmethod - def teardown_class(cls): - meta.drop_all() + def insert_data(cls, connection): + test_table = cls.tables.test + connection.execute( + test_table.insert(), + { + "id": 1, + "data": "somedata", + "atimestamp": datetime.date(2007, 10, 15), + "avalue": 25, + "bvalue": "foo", + }, + ) def test_control(self, connection): + test_table = self.tables.test assert ( connection.exec_driver_sql("select avalue from test").scalar() == 250 @@ -2513,6 +2518,9 @@ class ExpressionTest( def test_bind_adapt(self, connection): # test an untyped bind gets the left side's type + + test_table = self.tables.test + expr = test_table.c.atimestamp == bindparam("thedate") eq_(expr.right.type._type_affinity, Date) @@ -2565,6 +2573,8 @@ class ExpressionTest( ) def test_grouped_bind_adapt(self): + test_table = self.tables.test + expr = test_table.c.atimestamp == elements.Grouping( bindparam("thedate") ) @@ -2579,6 +2589,8 @@ class ExpressionTest( eq_(expr.right.element.element.type._type_affinity, Date) def test_bind_adapt_update(self): + test_table = self.tables.test + bp = bindparam("somevalue") stmt = test_table.update().values(avalue=bp) compiled = stmt.compile() @@ -2586,13 +2598,17 @@ class ExpressionTest( eq_(compiled.binds["somevalue"].type._type_affinity, MyCustomType) def test_bind_adapt_insert(self): + test_table = self.tables.test bp = bindparam("somevalue") + stmt = test_table.insert().values(avalue=bp) compiled = stmt.compile() eq_(bp.type._type_affinity, types.NullType) eq_(compiled.binds["somevalue"].type._type_affinity, MyCustomType) def test_bind_adapt_expression(self): + test_table = self.tables.test + bp = bindparam("somevalue") stmt = test_table.c.avalue == bp eq_(bp.type._type_affinity, types.NullType) @@ -2629,6 +2645,8 @@ class ExpressionTest( is_(literal(data).type.__class__, expected) def test_typedec_operator_adapt(self, connection): + test_table = self.tables.test + expr = test_table.c.bvalue + "hi" assert expr.type.__class__ is MyTypeDec @@ -2846,6 +2864,8 @@ class ExpressionTest( eq_(expr.type, types.NULLTYPE) def test_distinct(self, connection): + test_table = self.tables.test + s = select(distinct(test_table.c.avalue)) eq_(connection.execute(s).scalar(), 25) @@ -3004,17 +3024,18 @@ class NumericRawSQLTest(fixtures.TestBase): __backend__ = True - def _fixture(self, metadata, type_, data): + def _fixture(self, connection, metadata, type_, data): t = Table("t", metadata, Column("val", type_)) - metadata.create_all() - with testing.db.connect() as conn: - conn.execute(t.insert(), val=data) + metadata.create_all(connection) + connection.execute(t.insert(), val=data) @testing.fails_on("sqlite", "Doesn't provide Decimal results natively") @testing.provide_metadata def test_decimal_fp(self, connection): metadata = self.metadata - self._fixture(metadata, Numeric(10, 5), decimal.Decimal("45.5")) + self._fixture( + connection, metadata, Numeric(10, 5), decimal.Decimal("45.5") + ) val = connection.exec_driver_sql("select val from t").scalar() assert isinstance(val, decimal.Decimal) eq_(val, decimal.Decimal("45.5")) @@ -3023,7 +3044,9 @@ class NumericRawSQLTest(fixtures.TestBase): @testing.provide_metadata def test_decimal_int(self, connection): metadata = self.metadata - self._fixture(metadata, Numeric(10, 5), decimal.Decimal("45")) + self._fixture( + connection, metadata, Numeric(10, 5), decimal.Decimal("45") + ) val = connection.exec_driver_sql("select val from t").scalar() assert isinstance(val, decimal.Decimal) eq_(val, decimal.Decimal("45")) @@ -3031,7 +3054,7 @@ class NumericRawSQLTest(fixtures.TestBase): @testing.provide_metadata def test_ints(self, connection): metadata = self.metadata - self._fixture(metadata, Integer, 45) + self._fixture(connection, metadata, Integer, 45) val = connection.exec_driver_sql("select val from t").scalar() assert isinstance(val, util.int_types) eq_(val, 45) @@ -3039,7 +3062,7 @@ class NumericRawSQLTest(fixtures.TestBase): @testing.provide_metadata def test_float(self, connection): metadata = self.metadata - self._fixture(metadata, Float, 46.583) + self._fixture(connection, metadata, Float, 46.583) val = connection.exec_driver_sql("select val from t").scalar() assert isinstance(val, float) @@ -3050,19 +3073,14 @@ class NumericRawSQLTest(fixtures.TestBase): eq_(val, 46.583) -interval_table = metadata = None - - -class IntervalTest(fixtures.TestBase, AssertsExecutionResults): +class IntervalTest(fixtures.TablesTest, AssertsExecutionResults): __backend__ = True @classmethod - def setup_class(cls): - global interval_table, metadata - metadata = MetaData(testing.db) - interval_table = Table( - "intervaltable", + def define_tables(cls, metadata): + Table( + "intervals", metadata, Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -3074,16 +3092,6 @@ class IntervalTest(fixtures.TestBase, AssertsExecutionResults): ), Column("non_native_interval", Interval(native=False)), ) - metadata.create_all() - - @engines.close_first - def teardown(self): - with testing.db.connect() as conn: - conn.execute(interval_table.delete()) - - @classmethod - def teardown_class(cls): - metadata.drop_all() def test_non_native_adapt(self): interval = Interval(native=False) @@ -3092,30 +3100,32 @@ class IntervalTest(fixtures.TestBase, AssertsExecutionResults): assert adapted.native is False eq_(str(adapted), "DATETIME") - def test_roundtrip(self): + def test_roundtrip(self, connection): + interval_table = self.tables.intervals + small_delta = datetime.timedelta(days=15, seconds=5874) delta = datetime.timedelta(14) - with testing.db.begin() as conn: - conn.execute( - interval_table.insert(), - native_interval=small_delta, - native_interval_args=delta, - non_native_interval=delta, - ) - row = conn.execute(interval_table.select()).first() + connection.execute( + interval_table.insert(), + native_interval=small_delta, + native_interval_args=delta, + non_native_interval=delta, + ) + row = connection.execute(interval_table.select()).first() eq_(row.native_interval, small_delta) eq_(row.native_interval_args, delta) eq_(row.non_native_interval, delta) - def test_null(self): - with testing.db.begin() as conn: - conn.execute( - interval_table.insert(), - id=1, - native_inverval=None, - non_native_interval=None, - ) - row = conn.execute(interval_table.select()).first() + def test_null(self, connection): + interval_table = self.tables.intervals + + connection.execute( + interval_table.insert(), + id=1, + native_inverval=None, + non_native_interval=None, + ) + row = connection.execute(interval_table.select()).first() eq_(row.native_interval, None) eq_(row.native_interval_args, None) eq_(row.non_native_interval, None) @@ -3215,25 +3225,24 @@ class BooleanTest( ) @testing.requires.non_native_boolean_unconstrained - def test_nonnative_processor_coerces_integer_to_boolean(self): + def test_nonnative_processor_coerces_integer_to_boolean(self, connection): boolean_table = self.tables.boolean_table - with testing.db.connect() as conn: - conn.exec_driver_sql( - "insert into boolean_table (id, unconstrained_value) " - "values (1, 5)" - ) + connection.exec_driver_sql( + "insert into boolean_table (id, unconstrained_value) " + "values (1, 5)" + ) - eq_( - conn.exec_driver_sql( - "select unconstrained_value from boolean_table" - ).scalar(), - 5, - ) + eq_( + connection.exec_driver_sql( + "select unconstrained_value from boolean_table" + ).scalar(), + 5, + ) - eq_( - conn.scalar(select(boolean_table.c.unconstrained_value)), - True, - ) + eq_( + connection.scalar(select(boolean_table.c.unconstrained_value)), + True, + ) def test_bind_processor_coercion_native_true(self): proc = Boolean().bind_processor( diff --git a/test/sql/test_update.py b/test/sql/test_update.py index ec96af207..946a01651 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -1263,10 +1263,10 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): __backend__ = True @testing.requires.update_from - def test_exec_two_table(self): + def test_exec_two_table(self, connection): users, addresses = self.tables.users, self.tables.addresses - testing.db.execute( + connection.execute( addresses.update() .values(email_address=users.c.name) .where(users.c.id == addresses.c.user_id) @@ -1280,14 +1280,14 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): (4, 8, "x", "ed"), (5, 9, "x", "fred@fred.com"), ] - self._assert_addresses(addresses, expected) + self._assert_addresses(connection, addresses, expected) @testing.requires.update_from - def test_exec_two_table_plus_alias(self): + def test_exec_two_table_plus_alias(self, connection): users, addresses = self.tables.users, self.tables.addresses a1 = addresses.alias() - testing.db.execute( + connection.execute( addresses.update() .values(email_address=users.c.name) .where(users.c.id == a1.c.user_id) @@ -1302,15 +1302,15 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): (4, 8, "x", "ed"), (5, 9, "x", "fred@fred.com"), ] - self._assert_addresses(addresses, expected) + self._assert_addresses(connection, addresses, expected) @testing.requires.update_from - def test_exec_three_table(self): + def test_exec_three_table(self, connection): users = self.tables.users addresses = self.tables.addresses dingalings = self.tables.dingalings - testing.db.execute( + connection.execute( addresses.update() .values(email_address=users.c.name) .where(users.c.id == addresses.c.user_id) @@ -1326,15 +1326,15 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): (4, 8, "x", "ed@lala.com"), (5, 9, "x", "fred@fred.com"), ] - self._assert_addresses(addresses, expected) + self._assert_addresses(connection, addresses, expected) @testing.only_on("mysql", "Multi table update") - def test_exec_multitable(self): + def test_exec_multitable(self, connection): users, addresses = self.tables.users, self.tables.addresses values = {addresses.c.email_address: "updated", users.c.name: "ed2"} - testing.db.execute( + connection.execute( addresses.update() .values(values) .where(users.c.id == addresses.c.user_id) @@ -1348,18 +1348,18 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): (4, 8, "x", "updated"), (5, 9, "x", "fred@fred.com"), ] - self._assert_addresses(addresses, expected) + self._assert_addresses(connection, addresses, expected) expected = [(7, "jack"), (8, "ed2"), (9, "fred"), (10, "chuck")] - self._assert_users(users, expected) + self._assert_users(connection, users, expected) @testing.only_on("mysql", "Multi table update") - def test_exec_join_multitable(self): + def test_exec_join_multitable(self, connection): users, addresses = self.tables.users, self.tables.addresses values = {addresses.c.email_address: "updated", users.c.name: "ed2"} - testing.db.execute( + connection.execute( update(users.join(addresses)) .values(values) .where(users.c.name == "ed") @@ -1372,18 +1372,18 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): (4, 8, "x", "updated"), (5, 9, "x", "fred@fred.com"), ] - self._assert_addresses(addresses, expected) + self._assert_addresses(connection, addresses, expected) expected = [(7, "jack"), (8, "ed2"), (9, "fred"), (10, "chuck")] - self._assert_users(users, expected) + self._assert_users(connection, users, expected) @testing.only_on("mysql", "Multi table update") - def test_exec_multitable_same_name(self): + def test_exec_multitable_same_name(self, connection): users, addresses = self.tables.users, self.tables.addresses values = {addresses.c.name: "ad_ed2", users.c.name: "ed2"} - testing.db.execute( + connection.execute( addresses.update() .values(values) .where(users.c.id == addresses.c.user_id) @@ -1397,18 +1397,18 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): (4, 8, "ad_ed2", "ed@lala.com"), (5, 9, "x", "fred@fred.com"), ] - self._assert_addresses(addresses, expected) + self._assert_addresses(connection, addresses, expected) expected = [(7, "jack"), (8, "ed2"), (9, "fred"), (10, "chuck")] - self._assert_users(users, expected) + self._assert_users(connection, users, expected) - def _assert_addresses(self, addresses, expected): + def _assert_addresses(self, connection, addresses, expected): stmt = addresses.select().order_by(addresses.c.id) - eq_(testing.db.execute(stmt).fetchall(), expected) + eq_(connection.execute(stmt).fetchall(), expected) - def _assert_users(self, users, expected): + def _assert_users(self, connection, users, expected): stmt = users.select().order_by(users.c.id) - eq_(testing.db.execute(stmt).fetchall(), expected) + eq_(connection.execute(stmt).fetchall(), expected) class UpdateFromMultiTableUpdateDefaultsTest( @@ -1472,12 +1472,12 @@ class UpdateFromMultiTableUpdateDefaultsTest( ) @testing.only_on("mysql", "Multi table update") - def test_defaults_second_table(self): + def test_defaults_second_table(self, connection): users, addresses = self.tables.users, self.tables.addresses values = {addresses.c.email_address: "updated", users.c.name: "ed2"} - ret = testing.db.execute( + ret = connection.execute( addresses.update() .values(values) .where(users.c.id == addresses.c.user_id) @@ -1491,18 +1491,18 @@ class UpdateFromMultiTableUpdateDefaultsTest( (3, 8, "updated"), (4, 9, "fred@fred.com"), ] - self._assert_addresses(addresses, expected) + self._assert_addresses(connection, addresses, expected) expected = [(8, "ed2", "im the update"), (9, "fred", "value")] - self._assert_users(users, expected) + self._assert_users(connection, users, expected) @testing.only_on("mysql", "Multi table update") - def test_defaults_second_table_same_name(self): + def test_defaults_second_table_same_name(self, connection): users, foobar = self.tables.users, self.tables.foobar values = {foobar.c.data: foobar.c.data + "a", users.c.name: "ed2"} - ret = testing.db.execute( + ret = connection.execute( users.update() .values(values) .where(users.c.id == foobar.c.user_id) @@ -1519,16 +1519,16 @@ class UpdateFromMultiTableUpdateDefaultsTest( (3, 8, "d2a", "im the other update"), (4, 9, "d3", None), ] - self._assert_foobar(foobar, expected) + self._assert_foobar(connection, foobar, expected) expected = [(8, "ed2", "im the update"), (9, "fred", "value")] - self._assert_users(users, expected) + self._assert_users(connection, users, expected) @testing.only_on("mysql", "Multi table update") - def test_no_defaults_second_table(self): + def test_no_defaults_second_table(self, connection): users, addresses = self.tables.users, self.tables.addresses - ret = testing.db.execute( + ret = connection.execute( addresses.update() .values({"email_address": users.c.name}) .where(users.c.id == addresses.c.user_id) @@ -1538,20 +1538,20 @@ class UpdateFromMultiTableUpdateDefaultsTest( eq_(ret.prefetch_cols(), []) expected = [(2, 8, "ed"), (3, 8, "ed"), (4, 9, "fred@fred.com")] - self._assert_addresses(addresses, expected) + self._assert_addresses(connection, addresses, expected) # users table not actually updated, so no onupdate expected = [(8, "ed", "value"), (9, "fred", "value")] - self._assert_users(users, expected) + self._assert_users(connection, users, expected) - def _assert_foobar(self, foobar, expected): + def _assert_foobar(self, connection, foobar, expected): stmt = foobar.select().order_by(foobar.c.id) - eq_(testing.db.execute(stmt).fetchall(), expected) + eq_(connection.execute(stmt).fetchall(), expected) - def _assert_addresses(self, addresses, expected): + def _assert_addresses(self, connection, addresses, expected): stmt = addresses.select().order_by(addresses.c.id) - eq_(testing.db.execute(stmt).fetchall(), expected) + eq_(connection.execute(stmt).fetchall(), expected) - def _assert_users(self, users, expected): + def _assert_users(self, connection, users, expected): stmt = users.select().order_by(users.c.id) - eq_(testing.db.execute(stmt).fetchall(), expected) + eq_(connection.execute(stmt).fetchall(), expected) |
