diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-12-21 10:22:43 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2021-01-03 13:22:29 -0500 |
| commit | fd3c063dd68b289814af724689165418de5e4408 (patch) | |
| tree | 3a13c1cd3bd58b8b5b88bc3294e491aca63ecf0b /lib/sqlalchemy/testing/suite | |
| parent | dd41a5e61a30a2d05ee09f583fdfde1f1c204807 (diff) | |
| download | sqlalchemy-fd3c063dd68b289814af724689165418de5e4408.tar.gz | |
remove metadata.bind use from test suite
Importantly, this means we can remove bound metadata from
the fixtures that are used by Alembic's test suite.
Hopefully this is the last change needed to make Alembic
fully compatible with SQLAlchemy 1.4/2.0.
Start moving from @testing.provide_metadata to a pytest
metadata fixture. This does not seem to have any negative
effects even though TablesTest uses a "self.metadata" attribute.
Change-Id: Iae6ab95938a7e92b6d42086aec534af27b5577d3
Diffstat (limited to 'lib/sqlalchemy/testing/suite')
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_reflection.py | 771 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_types.py | 222 |
2 files changed, 489 insertions, 504 deletions
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index bef3abb59..6c3c1005a 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -207,10 +207,10 @@ class QuotedNameArgumentTest(fixtures.TablesTest): ] for name in names: query = "CREATE VIEW %s AS SELECT * FROM %s" % ( - testing.db.dialect.identifier_preparer.quote( + config.db.dialect.identifier_preparer.quote( "view %s" % name ), - testing.db.dialect.identifier_preparer.quote(name), + config.db.dialect.identifier_preparer.quote(name), ) event.listen(metadata, "after_create", DDL(query)) @@ -219,7 +219,7 @@ class QuotedNameArgumentTest(fixtures.TablesTest): "before_drop", DDL( "DROP VIEW %s" - % testing.db.dialect.identifier_preparer.quote( + % config.db.dialect.identifier_preparer.quote( "view %s" % name ) ), @@ -233,52 +233,52 @@ class QuotedNameArgumentTest(fixtures.TablesTest): @quote_fixtures def test_get_table_options(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) insp.get_table_options(name) @quote_fixtures @testing.requires.view_column_reflection def test_get_view_definition(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) assert insp.get_view_definition("view %s" % name) @quote_fixtures def test_get_columns(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) assert insp.get_columns(name) @quote_fixtures def test_get_pk_constraint(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) assert insp.get_pk_constraint(name) @quote_fixtures def test_get_foreign_keys(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) assert insp.get_foreign_keys(name) @quote_fixtures def test_get_indexes(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) assert insp.get_indexes(name) @quote_fixtures @testing.requires.unique_constraint_reflection def test_get_unique_constraints(self, name): - 
insp = inspect(testing.db) + insp = inspect(config.db) assert insp.get_unique_constraints(name) @quote_fixtures @testing.requires.comment_reflection def test_get_table_comment(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) assert insp.get_table_comment(name) @quote_fixtures @testing.requires.check_constraint_reflection def test_get_check_constraints(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) assert insp.get_check_constraints(name) @@ -508,7 +508,7 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.schema_reflection def test_get_schema_names(self): - insp = inspect(testing.db) + insp = inspect(self.bind) self.assert_(testing.config.test_schema in insp.get_schema_names()) @@ -520,13 +520,28 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.schema_reflection def test_get_default_schema_name(self): - insp = inspect(testing.db) - eq_(insp.default_schema_name, testing.db.dialect.default_schema_name) - - @testing.provide_metadata - def _test_get_table_names( - self, schema=None, table_type="table", order_by=None + insp = inspect(self.bind) + eq_(insp.default_schema_name, self.bind.dialect.default_schema_name) + + @testing.combinations( + (None, True, False, False), + (None, True, False, True, testing.requires.schemas), + ("foreign_key", True, False, False), + (None, False, True, False), + (None, False, True, True, testing.requires.schemas), + (None, True, True, False), + (None, True, True, True, testing.requires.schemas), + argnames="order_by,include_plain,include_views,use_schema", + ) + def test_get_table_names( + self, connection, order_by, include_plain, include_views, use_schema ): + + if use_schema: + schema = config.test_schema + else: + schema = None + _ignore_tables = [ "comment_test", "noncol_idx_test_pk", @@ -535,16 +550,16 @@ class ComponentReflectionTest(fixtures.TablesTest): "remote_table", "remote_table_2", ] - meta = self.metadata - insp = inspect(meta.bind) + 
insp = inspect(connection) - if table_type == "view": + if include_views: table_names = insp.get_view_names(schema) table_names.sort() answer = ["email_addresses_v", "users_v"] eq_(sorted(table_names), answer) - else: + + if include_plain: if order_by: tables = [ rec[0] @@ -576,15 +591,6 @@ class ComponentReflectionTest(fixtures.TablesTest): temp_table_names = insp.get_temp_view_names() eq_(sorted(temp_table_names), ["user_tmp_v"]) - @testing.requires.table_reflection - def test_get_table_names(self): - self._test_get_table_names() - - @testing.requires.table_reflection - @testing.requires.foreign_key_constraint_reflection - def test_get_table_names_fks(self): - self._test_get_table_names(order_by="foreign_key") - @testing.requires.comment_reflection def test_get_comments(self): self._test_get_comments() @@ -595,7 +601,7 @@ class ComponentReflectionTest(fixtures.TablesTest): self._test_get_comments(testing.config.test_schema) def _test_get_comments(self, schema=None): - insp = inspect(testing.db) + insp = inspect(self.bind) eq_( insp.get_table_comment("comment_test", schema=schema), @@ -621,35 +627,27 @@ class ComponentReflectionTest(fixtures.TablesTest): ], ) - @testing.requires.table_reflection - @testing.requires.schemas - def test_get_table_names_with_schema(self): - self._test_get_table_names(testing.config.test_schema) - - @testing.requires.view_column_reflection - def test_get_view_names(self): - self._test_get_table_names(table_type="view") - - @testing.requires.view_column_reflection - @testing.requires.schemas - def test_get_view_names_with_schema(self): - self._test_get_table_names( - testing.config.test_schema, table_type="view" - ) - - @testing.requires.table_reflection - @testing.requires.view_column_reflection - def test_get_tables_and_views(self): - self._test_get_table_names() - self._test_get_table_names(table_type="view") + @testing.combinations( + (False, False), + (False, True, testing.requires.schemas), + (True, False), + (False, True, 
testing.requires.schemas), + argnames="use_views,use_schema", + ) + def test_get_columns(self, connection, use_views, use_schema): + + if use_schema: + schema = config.test_schema + else: + schema = None - def _test_get_columns(self, schema=None, table_type="table"): - meta = MetaData(testing.db) users, addresses = (self.tables.users, self.tables.email_addresses) - table_names = ["users", "email_addresses"] - if table_type == "view": + if use_views: table_names = ["users_v", "email_addresses_v"] - insp = inspect(meta.bind) + else: + table_names = ["users", "email_addresses"] + + insp = inspect(connection) for table_name, table in zip(table_names, (users, addresses)): schema_name = schema cols = insp.get_columns(table_name, schema=schema_name) @@ -699,67 +697,13 @@ class ComponentReflectionTest(fixtures.TablesTest): if not col.primary_key: assert cols[i]["default"] is None - @testing.requires.table_reflection - def test_get_columns(self): - self._test_get_columns() - - @testing.provide_metadata - def _type_round_trip(self, *types): - t = Table( - "t", - self.metadata, - *[Column("t%d" % i, type_) for i, type_ in enumerate(types)] - ) - t.create() - - return [ - c["type"] for c in inspect(self.metadata.bind).get_columns("t") - ] - - @testing.requires.table_reflection - def test_numeric_reflection(self): - for typ in self._type_round_trip(sql_types.Numeric(18, 5)): - assert isinstance(typ, sql_types.Numeric) - eq_(typ.precision, 18) - eq_(typ.scale, 5) - - @testing.requires.table_reflection - def test_varchar_reflection(self): - typ = self._type_round_trip(sql_types.String(52))[0] - assert isinstance(typ, sql_types.String) - eq_(typ.length, 52) - - @testing.requires.table_reflection - @testing.provide_metadata - def test_nullable_reflection(self): - t = Table( - "t", - self.metadata, - Column("a", Integer, nullable=True), - Column("b", Integer, nullable=False), - ) - t.create() - eq_( - dict( - (col["name"], col["nullable"]) - for col in 
inspect(self.metadata.bind).get_columns("t") - ), - {"a": True, "b": False}, - ) - - @testing.requires.table_reflection - @testing.requires.schemas - def test_get_columns_with_schema(self): - self._test_get_columns(schema=testing.config.test_schema) - @testing.requires.temp_table_reflection def test_get_temp_table_columns(self): table_name = get_temp_table_name( - config, config.db, "user_tmp_%s" % config.ident + config, self.bind, "user_tmp_%s" % config.ident ) - meta = MetaData(self.bind) user_tmp = self.tables[table_name] - insp = inspect(meta.bind) + insp = inspect(self.bind) cols = insp.get_columns(table_name) self.assert_(len(cols) > 0, len(cols)) @@ -774,22 +718,18 @@ class ComponentReflectionTest(fixtures.TablesTest): cols = insp.get_columns("user_tmp_v") eq_([col["name"] for col in cols], ["id", "name", "foo"]) - @testing.requires.view_column_reflection - def test_get_view_columns(self): - self._test_get_columns(table_type="view") - - @testing.requires.view_column_reflection - @testing.requires.schemas - def test_get_view_columns_with_schema(self): - self._test_get_columns( - schema=testing.config.test_schema, table_type="view" - ) + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + @testing.requires.primary_key_constraint_reflection + def test_get_pk_constraint(self, connection, use_schema): + if use_schema: + schema = testing.config.test_schema + else: + schema = None - @testing.provide_metadata - def _test_get_pk_constraint(self, schema=None): - meta = self.metadata users, addresses = self.tables.users, self.tables.email_addresses - insp = inspect(meta.bind) + insp = inspect(connection) users_cons = insp.get_pk_constraint(users.name, schema=schema) users_pkeys = users_cons["constrained_columns"] @@ -802,21 +742,18 @@ class ComponentReflectionTest(fixtures.TablesTest): with testing.requires.reflects_pk_names.fail_if(): eq_(addr_cons["name"], "email_ad_pk") - 
@testing.requires.primary_key_constraint_reflection - def test_get_pk_constraint(self): - self._test_get_pk_constraint() - - @testing.requires.table_reflection - @testing.requires.primary_key_constraint_reflection - @testing.requires.schemas - def test_get_pk_constraint_with_schema(self): - self._test_get_pk_constraint(schema=testing.config.test_schema) + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + @testing.requires.foreign_key_constraint_reflection + def test_get_foreign_keys(self, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None - @testing.provide_metadata - def _test_get_foreign_keys(self, schema=None): - meta = self.metadata users, addresses = (self.tables.users, self.tables.email_addresses) - insp = inspect(meta.bind) + insp = inspect(connection) expected_schema = schema # users @@ -845,25 +782,16 @@ class ComponentReflectionTest(fixtures.TablesTest): eq_(fkey1["referred_columns"], ["user_id"]) eq_(fkey1["constrained_columns"], ["remote_user_id"]) - @testing.requires.foreign_key_constraint_reflection - def test_get_foreign_keys(self): - self._test_get_foreign_keys() - - @testing.requires.foreign_key_constraint_reflection - @testing.requires.schemas - def test_get_foreign_keys_with_schema(self): - self._test_get_foreign_keys(schema=testing.config.test_schema) - @testing.requires.cross_schema_fk_reflection @testing.requires.schemas def test_get_inter_schema_foreign_keys(self): local_table, remote_table, remote_table_2 = self.tables( - "%s.local_table" % testing.db.dialect.default_schema_name, + "%s.local_table" % self.bind.dialect.default_schema_name, "%s.remote_table" % testing.config.test_schema, "%s.remote_table_2" % testing.config.test_schema, ) - insp = inspect(config.db) + insp = inspect(self.bind) local_fkeys = insp.get_foreign_keys(local_table.name) eq_(len(local_fkeys), 1) @@ -883,85 +811,12 @@ class ComponentReflectionTest(fixtures.TablesTest): assert 
fkey2["referred_schema"] in ( None, - testing.db.dialect.default_schema_name, + self.bind.dialect.default_schema_name, ) eq_(fkey2["referred_table"], local_table.name) eq_(fkey2["referred_columns"], ["id"]) eq_(fkey2["constrained_columns"], ["local_id"]) - @testing.requires.foreign_key_constraint_option_reflection_ondelete - def test_get_foreign_key_options_ondelete(self): - self._test_get_foreign_key_options(ondelete="CASCADE") - - @testing.requires.foreign_key_constraint_option_reflection_onupdate - def test_get_foreign_key_options_onupdate(self): - self._test_get_foreign_key_options(onupdate="SET NULL") - - @testing.requires.foreign_key_constraint_option_reflection_onupdate - def test_get_foreign_key_options_onupdate_noaction(self): - self._test_get_foreign_key_options(onupdate="NO ACTION", expected={}) - - @testing.requires.fk_constraint_option_reflection_ondelete_noaction - def test_get_foreign_key_options_ondelete_noaction(self): - self._test_get_foreign_key_options(ondelete="NO ACTION", expected={}) - - @testing.requires.fk_constraint_option_reflection_onupdate_restrict - def test_get_foreign_key_options_onupdate_restrict(self): - self._test_get_foreign_key_options(onupdate="RESTRICT") - - @testing.requires.fk_constraint_option_reflection_ondelete_restrict - def test_get_foreign_key_options_ondelete_restrict(self): - self._test_get_foreign_key_options(ondelete="RESTRICT") - - @testing.provide_metadata - def _test_get_foreign_key_options(self, expected=None, **options): - meta = self.metadata - - if expected is None: - expected = options - - Table( - "x", - meta, - Column("id", Integer, primary_key=True), - test_needs_fk=True, - ) - - Table( - "table", - meta, - Column("id", Integer, primary_key=True), - Column("x_id", Integer, sa.ForeignKey("x.id", name="xid")), - Column("test", String(10)), - test_needs_fk=True, - ) - - Table( - "user", - meta, - Column("id", Integer, primary_key=True), - Column("name", String(50), nullable=False), - Column("tid", Integer), 
- sa.ForeignKeyConstraint( - ["tid"], ["table.id"], name="myfk", **options - ), - test_needs_fk=True, - ) - - meta.create_all() - - insp = inspect(meta.bind) - - # test 'options' is always present for a backend - # that can reflect these, since alembic looks for this - opts = insp.get_foreign_keys("table")[0]["options"] - - eq_(dict((k, opts[k]) for k in opts if opts[k]), {}) - - opts = insp.get_foreign_keys("user")[0]["options"] - eq_(opts, expected) - # eq_(dict((k, opts[k]) for k in opts if opts[k]), expected) - def _assert_insp_indexes(self, indexes, expected_indexes): index_names = [d["name"] for d in indexes] for e_index in expected_indexes: @@ -970,13 +825,19 @@ class ComponentReflectionTest(fixtures.TablesTest): for key in e_index: eq_(e_index[key], index[key]) - @testing.provide_metadata - def _test_get_indexes(self, schema=None): - meta = self.metadata + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + def test_get_indexes(self, connection, use_schema): + + if use_schema: + schema = config.test_schema + else: + schema = None # The database may decide to create indexes for foreign keys, etc. # so there may be more indexes than expected. 
- insp = inspect(meta.bind) + insp = inspect(self.bind) indexes = insp.get_indexes("users", schema=schema) expected_indexes = [ { @@ -992,19 +853,15 @@ class ComponentReflectionTest(fixtures.TablesTest): ] self._assert_insp_indexes(indexes, expected_indexes) + @testing.combinations( + ("noncol_idx_test_nopk", "noncol_idx_nopk"), + ("noncol_idx_test_pk", "noncol_idx_pk"), + argnames="tname,ixname", + ) @testing.requires.index_reflection - def test_get_indexes(self): - self._test_get_indexes() - - @testing.requires.index_reflection - @testing.requires.schemas - def test_get_indexes_with_schema(self): - self._test_get_indexes(schema=testing.config.test_schema) - - @testing.provide_metadata - def _test_get_noncol_index(self, tname, ixname): - meta = self.metadata - insp = inspect(meta.bind) + @testing.requires.indexes_with_ascdesc + def test_get_noncol_index(self, connection, tname, ixname): + insp = inspect(connection) indexes = insp.get_indexes(tname) # reflecting an index that has "x DESC" in it as the column. 
@@ -1013,85 +870,11 @@ class ComponentReflectionTest(fixtures.TablesTest): expected_indexes = [{"unique": False, "name": ixname}] self._assert_insp_indexes(indexes, expected_indexes) - t = Table(tname, meta, autoload_with=meta.bind) + t = Table(tname, MetaData(), autoload_with=connection) eq_(len(t.indexes), 1) is_(list(t.indexes)[0].table, t) eq_(list(t.indexes)[0].name, ixname) - @testing.requires.index_reflection - @testing.requires.indexes_with_ascdesc - def test_get_noncol_index_no_pk(self): - self._test_get_noncol_index("noncol_idx_test_nopk", "noncol_idx_nopk") - - @testing.requires.index_reflection - @testing.requires.indexes_with_ascdesc - def test_get_noncol_index_pk(self): - self._test_get_noncol_index("noncol_idx_test_pk", "noncol_idx_pk") - - @testing.requires.indexes_with_expressions - @testing.provide_metadata - def test_reflect_expression_based_indexes(self): - t = Table( - "t", - self.metadata, - Column("x", String(30)), - Column("y", String(30)), - ) - - Index("t_idx", func.lower(t.c.x), func.lower(t.c.y)) - - Index("t_idx_2", t.c.x) - - self.metadata.create_all(testing.db) - - insp = inspect(testing.db) - - expected = [ - {"name": "t_idx_2", "column_names": ["x"], "unique": False} - ] - if testing.requires.index_reflects_included_columns.enabled: - expected[0]["include_columns"] = [] - - with expect_warnings( - "Skipped unsupported reflection of expression-based index t_idx" - ): - eq_( - insp.get_indexes("t"), - expected, - ) - - @testing.requires.index_reflects_included_columns - @testing.provide_metadata - def test_reflect_covering_index(self): - t = Table( - "t", - self.metadata, - Column("x", String(30)), - Column("y", String(30)), - ) - idx = Index("t_idx", t.c.x) - idx.dialect_options[testing.db.name]["include"] = ["y"] - - self.metadata.create_all(testing.db) - - insp = inspect(testing.db) - - eq_( - insp.get_indexes("t"), - [ - { - "name": "t_idx", - "column_names": ["x"], - "include_columns": ["y"], - "unique": False, - } - ], - ) - - 
@testing.requires.unique_constraint_reflection - def test_get_unique_constraints(self): - self._test_get_unique_constraints() - @testing.requires.temp_table_reflection @testing.requires.unique_constraint_reflection def test_get_temp_table_unique_constraints(self): @@ -1130,19 +913,22 @@ class ComponentReflectionTest(fixtures.TablesTest): expected, ) + @testing.combinations( + (True, testing.requires.schemas), (False,), argnames="use_schema" + ) @testing.requires.unique_constraint_reflection - @testing.requires.schemas - def test_get_unique_constraints_with_schema(self): - self._test_get_unique_constraints(schema=testing.config.test_schema) - - @testing.provide_metadata - def _test_get_unique_constraints(self, schema=None): + def test_get_unique_constraints(self, metadata, connection, use_schema): # SQLite dialect needs to parse the names of the constraints # separately from what it gets from PRAGMA index_list(), and # then matches them up. so same set of column_names in two # constraints will confuse it. Perhaps we should no longer # bother with index_list() here since we have the whole # CREATE TABLE? 
+ + if use_schema: + schema = config.test_schema + else: + schema = None uniques = sorted( [ {"name": "unique_a", "column_names": ["a"]}, @@ -1154,10 +940,9 @@ class ComponentReflectionTest(fixtures.TablesTest): ], key=operator.itemgetter("name"), ) - orig_meta = self.metadata table = Table( "testtbl", - orig_meta, + metadata, Column("a", sa.String(20)), Column("b", sa.String(30)), Column("c", sa.Integer), @@ -1170,9 +955,9 @@ class ComponentReflectionTest(fixtures.TablesTest): table.append_constraint( sa.UniqueConstraint(*uc["column_names"], name=uc["name"]) ) - orig_meta.create_all() + table.create(connection) - inspector = inspect(orig_meta.bind) + inspector = inspect(connection) reflected = sorted( inspector.get_unique_constraints("testtbl", schema=schema), key=operator.itemgetter("name"), @@ -1192,7 +977,7 @@ class ComponentReflectionTest(fixtures.TablesTest): reflected = Table( "testtbl", reflected_metadata, - autoload_with=orig_meta.bind, + autoload_with=connection, schema=schema, ) @@ -1214,30 +999,90 @@ class ComponentReflectionTest(fixtures.TablesTest): eq_(names_that_duplicate_index, idx_names) eq_(uq_names, set()) - @testing.requires.check_constraint_reflection - def test_get_check_constraints(self): - self._test_get_check_constraints() + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + def test_get_view_definition(self, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + view_name1 = "users_v" + view_name2 = "email_addresses_v" + insp = inspect(connection) + v1 = insp.get_view_definition(view_name1, schema=schema) + self.assert_(v1) + v2 = insp.get_view_definition(view_name2, schema=schema) + self.assert_(v2) + + # why is this here if it's PG specific ? 
+ @testing.combinations( + ("users", False), + ("users", True, testing.requires.schemas), + argnames="table_name,use_schema", + ) + @testing.only_on("postgresql", "PG specific feature") + def test_get_table_oid(self, connection, table_name, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + insp = inspect(connection) + oid = insp.get_table_oid(table_name, schema) + self.assert_(isinstance(oid, int)) + + @testing.requires.table_reflection + def test_autoincrement_col(self): + """test that 'autoincrement' is reflected according to sqla's policy. + + Don't mark this test as unsupported for any backend ! + + (technically it fails with MySQL InnoDB since "id" comes before "id2") + + A backend is better off not returning "autoincrement" at all, + instead of potentially returning "False" for an auto-incrementing + primary key column. + + """ + + insp = inspect(self.bind) + + for tname, cname in [ + ("users", "user_id"), + ("email_addresses", "address_id"), + ("dingalings", "dingaling_id"), + ]: + cols = insp.get_columns(tname) + id_ = {c["name"]: c for c in cols}[cname] + assert id_.get("autoincrement", True) + + +class ComponentReflectionTestExtra(fixtures.TestBase): + __backend__ = True + + @testing.combinations( + (True, testing.requires.schemas), (False,), argnames="use_schema" + ) @testing.requires.check_constraint_reflection - @testing.requires.schemas - def test_get_check_constraints_schema(self): - self._test_get_check_constraints(schema=testing.config.test_schema) + def test_get_check_constraints(self, metadata, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None - @testing.provide_metadata - def _test_get_check_constraints(self, schema=None): - orig_meta = self.metadata Table( "sa_cc", - orig_meta, + metadata, Column("a", Integer()), sa.CheckConstraint("a > 1 AND a < 5", name="cc1"), sa.CheckConstraint("a = 1 OR (a > 2 AND a < 5)", name="cc2"), schema=schema, ) - 
orig_meta.create_all() + metadata.create_all(connection) - inspector = inspect(orig_meta.bind) + inspector = inspect(connection) reflected = sorted( inspector.get_check_constraints("sa_cc", schema=schema), key=operator.itemgetter("name"), @@ -1263,67 +1108,200 @@ class ComponentReflectionTest(fixtures.TablesTest): ], ) - @testing.provide_metadata - def _test_get_view_definition(self, schema=None): - meta = self.metadata - view_name1 = "users_v" - view_name2 = "email_addresses_v" - insp = inspect(meta.bind) - v1 = insp.get_view_definition(view_name1, schema=schema) - self.assert_(v1) - v2 = insp.get_view_definition(view_name2, schema=schema) - self.assert_(v2) + @testing.requires.indexes_with_expressions + def test_reflect_expression_based_indexes(self, metadata, connection): + t = Table( + "t", + metadata, + Column("x", String(30)), + Column("y", String(30)), + ) - @testing.requires.view_reflection - def test_get_view_definition(self): - self._test_get_view_definition() + Index("t_idx", func.lower(t.c.x), func.lower(t.c.y)) - @testing.requires.view_reflection - @testing.requires.schemas - def test_get_view_definition_with_schema(self): - self._test_get_view_definition(schema=testing.config.test_schema) + Index("t_idx_2", t.c.x) - @testing.only_on("postgresql", "PG specific feature") - @testing.provide_metadata - def _test_get_table_oid(self, table_name, schema=None): - meta = self.metadata - insp = inspect(meta.bind) - oid = insp.get_table_oid(table_name, schema) - self.assert_(isinstance(oid, int)) + metadata.create_all(connection) - def test_get_table_oid(self): - self._test_get_table_oid("users") + insp = inspect(connection) - @testing.requires.schemas - def test_get_table_oid_with_schema(self): - self._test_get_table_oid("users", schema=testing.config.test_schema) + expected = [ + {"name": "t_idx_2", "column_names": ["x"], "unique": False} + ] + if testing.requires.index_reflects_included_columns.enabled: + expected[0]["include_columns"] = [] + + with 
expect_warnings( + "Skipped unsupported reflection of expression-based index t_idx" + ): + eq_( + insp.get_indexes("t"), + expected, + ) + + @testing.requires.index_reflects_included_columns + def test_reflect_covering_index(self, metadata, connection): + t = Table( + "t", + metadata, + Column("x", String(30)), + Column("y", String(30)), + ) + idx = Index("t_idx", t.c.x) + idx.dialect_options[connection.engine.name]["include"] = ["y"] + + metadata.create_all(connection) + + insp = inspect(connection) + + eq_( + insp.get_indexes("t"), + [ + { + "name": "t_idx", + "column_names": ["x"], + "include_columns": ["y"], + "unique": False, + } + ], + ) + + def _type_round_trip(self, connection, metadata, *types): + t = Table( + "t", + metadata, + *[Column("t%d" % i, type_) for i, type_ in enumerate(types)] + ) + t.create(connection) + + return [c["type"] for c in inspect(connection).get_columns("t")] @testing.requires.table_reflection - @testing.provide_metadata - def test_autoincrement_col(self): - """test that 'autoincrement' is reflected according to sqla's policy. + def test_numeric_reflection(self, connection, metadata): + for typ in self._type_round_trip( + connection, metadata, sql_types.Numeric(18, 5) + ): + assert isinstance(typ, sql_types.Numeric) + eq_(typ.precision, 18) + eq_(typ.scale, 5) - Don't mark this test as unsupported for any backend ! 
+ @testing.requires.table_reflection + def test_varchar_reflection(self, connection, metadata): + typ = self._type_round_trip( + connection, metadata, sql_types.String(52) + )[0] + assert isinstance(typ, sql_types.String) + eq_(typ.length, 52) - (technically it fails with MySQL InnoDB since "id" comes before "id2") + @testing.requires.table_reflection + def test_nullable_reflection(self, connection, metadata): + t = Table( + "t", + metadata, + Column("a", Integer, nullable=True), + Column("b", Integer, nullable=False), + ) + t.create(connection) + eq_( + dict( + (col["name"], col["nullable"]) + for col in inspect(connection).get_columns("t") + ), + {"a": True, "b": False}, + ) - A backend is better off not returning "autoincrement" at all, - instead of potentially returning "False" for an auto-incrementing - primary key column. + @testing.combinations( + ( + None, + "CASCADE", + None, + testing.requires.foreign_key_constraint_option_reflection_ondelete, + ), + ( + None, + None, + "SET NULL", + testing.requires.foreign_key_constraint_option_reflection_onupdate, + ), + ( + {}, + None, + "NO ACTION", + testing.requires.foreign_key_constraint_option_reflection_onupdate, + ), + ( + {}, + "NO ACTION", + None, + testing.requires.fk_constraint_option_reflection_ondelete_noaction, + ), + ( + None, + None, + "RESTRICT", + testing.requires.fk_constraint_option_reflection_onupdate_restrict, + ), + ( + None, + "RESTRICT", + None, + testing.requires.fk_constraint_option_reflection_ondelete_restrict, + ), + argnames="expected,ondelete,onupdate", + ) + def test_get_foreign_key_options( + self, connection, metadata, expected, ondelete, onupdate + ): + options = {} + if ondelete: + options["ondelete"] = ondelete + if onupdate: + options["onupdate"] = onupdate - """ + if expected is None: + expected = options - meta = self.metadata - insp = inspect(meta.bind) + Table( + "x", + metadata, + Column("id", Integer, primary_key=True), + test_needs_fk=True, + ) - for tname, cname in [ - 
("users", "user_id"), - ("email_addresses", "address_id"), - ("dingalings", "dingaling_id"), - ]: - cols = insp.get_columns(tname) - id_ = {c["name"]: c for c in cols}[cname] - assert id_.get("autoincrement", True) + Table( + "table", + metadata, + Column("id", Integer, primary_key=True), + Column("x_id", Integer, sa.ForeignKey("x.id", name="xid")), + Column("test", String(10)), + test_needs_fk=True, + ) + + Table( + "user", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("tid", Integer), + sa.ForeignKeyConstraint( + ["tid"], ["table.id"], name="myfk", **options + ), + test_needs_fk=True, + ) + + metadata.create_all(connection) + + insp = inspect(connection) + + # test 'options' is always present for a backend + # that can reflect these, since alembic looks for this + opts = insp.get_foreign_keys("table")[0]["options"] + + eq_(dict((k, opts[k]) for k in opts if opts[k]), {}) + + opts = insp.get_foreign_keys("user")[0]["options"] + eq_(opts, expected) + # eq_(dict((k, opts[k]) for k in opts if opts[k]), expected) class NormalizedNameTest(fixtures.TablesTest): @@ -1348,21 +1326,21 @@ class NormalizedNameTest(fixtures.TablesTest): m2 = MetaData() t2_ref = Table( - quoted_name("t2", quote=True), m2, autoload_with=testing.db + quoted_name("t2", quote=True), m2, autoload_with=config.db ) t1_ref = m2.tables["t1"] assert t2_ref.c.t1id.references(t1_ref.c.id) m3 = MetaData() m3.reflect( - testing.db, only=lambda name, m: name.lower() in ("t1", "t2") + config.db, only=lambda name, m: name.lower() in ("t1", "t2") ) assert m3.tables["t2"].c.t1id.references(m3.tables["t1"].c.id) def test_get_table_names(self): tablenames = [ t - for t in inspect(testing.db).get_table_names() + for t in inspect(config.db).get_table_names() if t.lower() in ("t1", "t2") ] @@ -1637,20 +1615,16 @@ class CompositeKeyReflectionTest(fixtures.TablesTest): ) @testing.requires.primary_key_constraint_reflection - @testing.provide_metadata def 
test_pk_column_order(self): # test for issue #5661 - meta = self.metadata - insp = inspect(meta.bind) + insp = inspect(self.bind) primary_key = insp.get_pk_constraint(self.tables.tb1.name) eq_(primary_key.get("constrained_columns"), ["name", "id", "attr"]) @testing.requires.foreign_key_constraint_reflection - @testing.provide_metadata def test_fk_column_order(self): # test for issue #5661 - meta = self.metadata - insp = inspect(meta.bind) + insp = inspect(self.bind) foreign_keys = insp.get_foreign_keys(self.tables.tb2.name) eq_(len(foreign_keys), 1) fkey1 = foreign_keys[0] @@ -1660,6 +1634,7 @@ class CompositeKeyReflectionTest(fixtures.TablesTest): __all__ = ( "ComponentReflectionTest", + "ComponentReflectionTestExtra", "QuotedNameArgumentTest", "HasTableTest", "HasIndexTest", diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 43777239c..3a5e02c32 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -47,18 +47,19 @@ from ...util import u class _LiteralRoundTripFixture(object): supports_whereclause = True - @testing.provide_metadata - def _literal_round_trip(self, type_, input_, output, filter_=None): + @testing.fixture + def literal_round_trip(self, metadata, connection): """test literal rendering """ # for literal, we test the literal render in an INSERT # into a typed column. 
we can then SELECT it back as its # official type; ideally we'd be able to use CAST here # but MySQL in particular can't CAST fully - t = Table("t", self.metadata, Column("x", type_)) - t.create() - with testing.db.begin() as conn: + def run(type_, input_, output, filter_=None): + t = Table("t", metadata, Column("x", type_)) + t.create(connection) + for value in input_: ins = ( t.insert() @@ -68,7 +69,7 @@ class _LiteralRoundTripFixture(object): compile_kwargs=dict(literal_binds=True), ) ) - conn.execute(ins) + connection.execute(ins) if self.supports_whereclause: stmt = t.select().where(t.c.x == literal(value)) @@ -79,12 +80,14 @@ class _LiteralRoundTripFixture(object): dialect=testing.db.dialect, compile_kwargs=dict(literal_binds=True), ) - for row in conn.execute(stmt): + for row in connection.execute(stmt): value = row[0] if filter_ is not None: value = filter_(value) assert value in output + return run + class _UnicodeFixture(_LiteralRoundTripFixture, fixtures.TestBase): __requires__ = ("unicode_data",) @@ -149,11 +152,11 @@ class _UnicodeFixture(_LiteralRoundTripFixture, fixtures.TestBase): row = connection.execute(select(unicode_table.c.unicode_data)).first() eq_(row, (u(""),)) - def test_literal(self): - self._literal_round_trip(self.datatype, [self.data], [self.data]) + def test_literal(self, literal_round_trip): + literal_round_trip(self.datatype, [self.data], [self.data]) - def test_literal_non_ascii(self): - self._literal_round_trip( + def test_literal_non_ascii(self, literal_round_trip): + literal_round_trip( self.datatype, [util.u("réve🐍 illé")], [util.u("réve🐍 illé")] ) @@ -227,25 +230,25 @@ class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest): row = connection.execute(select(text_table.c.text_data)).first() eq_(row, (None,)) - def test_literal(self): - self._literal_round_trip(Text, ["some text"], ["some text"]) + def test_literal(self, literal_round_trip): + literal_round_trip(Text, ["some text"], ["some text"]) - def 
test_literal_non_ascii(self): - self._literal_round_trip( + def test_literal_non_ascii(self, literal_round_trip): + literal_round_trip( Text, [util.u("réve🐍 illé")], [util.u("réve🐍 illé")] ) - def test_literal_quoting(self): + def test_literal_quoting(self, literal_round_trip): data = """some 'text' hey "hi there" that's text""" - self._literal_round_trip(Text, [data], [data]) + literal_round_trip(Text, [data], [data]) - def test_literal_backslashes(self): + def test_literal_backslashes(self, literal_round_trip): data = r"backslash one \ backslash two \\ end" - self._literal_round_trip(Text, [data], [data]) + literal_round_trip(Text, [data], [data]) - def test_literal_percentsigns(self): + def test_literal_percentsigns(self, literal_round_trip): data = r"percent % signs %% percent" - self._literal_round_trip(Text, [data], [data]) + literal_round_trip(Text, [data], [data]) class StringTest(_LiteralRoundTripFixture, fixtures.TestBase): @@ -259,23 +262,23 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase): foo.create(config.db) foo.drop(config.db) - def test_literal(self): + def test_literal(self, literal_round_trip): # note that in Python 3, this invokes the Unicode # datatype for the literal part because all strings are unicode - self._literal_round_trip(String(40), ["some text"], ["some text"]) + literal_round_trip(String(40), ["some text"], ["some text"]) - def test_literal_non_ascii(self): - self._literal_round_trip( + def test_literal_non_ascii(self, literal_round_trip): + literal_round_trip( String(40), [util.u("réve🐍 illé")], [util.u("réve🐍 illé")] ) - def test_literal_quoting(self): + def test_literal_quoting(self, literal_round_trip): data = """some 'text' hey "hi there" that's text""" - self._literal_round_trip(String(40), [data], [data]) + literal_round_trip(String(40), [data], [data]) - def test_literal_backslashes(self): + def test_literal_backslashes(self, literal_round_trip): data = r"backslash one \ backslash two \\ end" - 
self._literal_round_trip(String(40), [data], [data]) + literal_round_trip(String(40), [data], [data]) class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase): @@ -331,9 +334,9 @@ class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase): eq_(row, (None,)) @testing.requires.datetime_literals - def test_literal(self): + def test_literal(self, literal_round_trip): compare = self.compare or self.data - self._literal_round_trip(self.datatype, [self.data], [compare]) + literal_round_trip(self.datatype, [self.data], [compare]) @testing.requires.standalone_null_binds_whereclause def test_null_bound_comparison(self): @@ -430,36 +433,41 @@ class DateHistoricTest(_DateFixture, fixtures.TablesTest): class IntegerTest(_LiteralRoundTripFixture, fixtures.TestBase): __backend__ = True - def test_literal(self): - self._literal_round_trip(Integer, [5], [5]) + def test_literal(self, literal_round_trip): + literal_round_trip(Integer, [5], [5]) - def test_huge_int(self, connection): - self._round_trip(BigInteger, 1376537018368127, connection) + def test_huge_int(self, integer_round_trip): + integer_round_trip(BigInteger, 1376537018368127) - @testing.provide_metadata - def _round_trip(self, datatype, data, connection): - metadata = self.metadata - int_table = Table( - "integer_table", - metadata, - Column( - "id", Integer, primary_key=True, test_needs_autoincrement=True - ), - Column("integer_data", datatype), - ) + @testing.fixture + def integer_round_trip(self, metadata, connection): + def run(datatype, data): + int_table = Table( + "integer_table", + metadata, + Column( + "id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("integer_data", datatype), + ) - metadata.create_all(config.db) + metadata.create_all(config.db) - connection.execute(int_table.insert(), {"integer_data": data}) + connection.execute(int_table.insert(), {"integer_data": data}) - row = connection.execute(select(int_table.c.integer_data)).first() + row = 
connection.execute(select(int_table.c.integer_data)).first() - eq_(row, (data,)) + eq_(row, (data,)) - if util.py3k: - assert isinstance(row[0], int) - else: - assert isinstance(row[0], (long, int)) # noqa + if util.py3k: + assert isinstance(row[0], int) + else: + assert isinstance(row[0], (long, int)) # noqa + + return run class CastTypeDecoratorTest(_LiteralRoundTripFixture, fixtures.TestBase): @@ -481,12 +489,10 @@ class CastTypeDecoratorTest(_LiteralRoundTripFixture, fixtures.TestBase): return StringAsInt() - @testing.provide_metadata - def test_special_type(self, connection, string_as_int): + def test_special_type(self, metadata, connection, string_as_int): type_ = string_as_int - metadata = self.metadata t = Table("t", metadata, Column("x", type_)) t.create(connection) @@ -504,42 +510,46 @@ class CastTypeDecoratorTest(_LiteralRoundTripFixture, fixtures.TestBase): class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): __backend__ = True - @testing.emits_warning(r".*does \*not\* support Decimal objects natively") - @testing.provide_metadata - def _do_test(self, type_, input_, output, filter_=None, check_scale=False): - metadata = self.metadata - t = Table("t", metadata, Column("x", type_)) - t.create() - with config.db.begin() as conn: - conn.execute(t.insert(), [{"x": x} for x in input_]) - - result = {row[0] for row in conn.execute(t.select())} - output = set(output) - if filter_: - result = set(filter_(x) for x in result) - output = set(filter_(x) for x in output) - eq_(result, output) - if check_scale: - eq_([str(x) for x in result], [str(x) for x in output]) + @testing.fixture + def do_numeric_test(self, metadata): + @testing.emits_warning( + r".*does \*not\* support Decimal objects natively" + ) + def run(type_, input_, output, filter_=None, check_scale=False): + t = Table("t", metadata, Column("x", type_)) + t.create(testing.db) + with config.db.begin() as conn: + conn.execute(t.insert(), [{"x": x} for x in input_]) + + result = {row[0] for row 
in conn.execute(t.select())} + output = set(output) + if filter_: + result = set(filter_(x) for x in result) + output = set(filter_(x) for x in output) + eq_(result, output) + if check_scale: + eq_([str(x) for x in result], [str(x) for x in output]) + + return run @testing.emits_warning(r".*does \*not\* support Decimal objects natively") - def test_render_literal_numeric(self): - self._literal_round_trip( + def test_render_literal_numeric(self, literal_round_trip): + literal_round_trip( Numeric(precision=8, scale=4), [15.7563, decimal.Decimal("15.7563")], [decimal.Decimal("15.7563")], ) @testing.emits_warning(r".*does \*not\* support Decimal objects natively") - def test_render_literal_numeric_asfloat(self): - self._literal_round_trip( + def test_render_literal_numeric_asfloat(self, literal_round_trip): + literal_round_trip( Numeric(precision=8, scale=4, asdecimal=False), [15.7563, decimal.Decimal("15.7563")], [15.7563], ) - def test_render_literal_float(self): - self._literal_round_trip( + def test_render_literal_float(self, literal_round_trip): + literal_round_trip( Float(4), [15.7563, decimal.Decimal("15.7563")], [15.7563], @@ -547,49 +557,49 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): ) @testing.requires.precision_generic_float_type - def test_float_custom_scale(self): - self._do_test( + def test_float_custom_scale(self, do_numeric_test): + do_numeric_test( Float(None, decimal_return_scale=7, asdecimal=True), [15.7563827, decimal.Decimal("15.7563827")], [decimal.Decimal("15.7563827")], check_scale=True, ) - def test_numeric_as_decimal(self): - self._do_test( + def test_numeric_as_decimal(self, do_numeric_test): + do_numeric_test( Numeric(precision=8, scale=4), [15.7563, decimal.Decimal("15.7563")], [decimal.Decimal("15.7563")], ) - def test_numeric_as_float(self): - self._do_test( + def test_numeric_as_float(self, do_numeric_test): + do_numeric_test( Numeric(precision=8, scale=4, asdecimal=False), [15.7563, decimal.Decimal("15.7563")], 
[15.7563], ) @testing.requires.fetch_null_from_numeric - def test_numeric_null_as_decimal(self): - self._do_test(Numeric(precision=8, scale=4), [None], [None]) + def test_numeric_null_as_decimal(self, do_numeric_test): + do_numeric_test(Numeric(precision=8, scale=4), [None], [None]) @testing.requires.fetch_null_from_numeric - def test_numeric_null_as_float(self): - self._do_test( + def test_numeric_null_as_float(self, do_numeric_test): + do_numeric_test( Numeric(precision=8, scale=4, asdecimal=False), [None], [None] ) @testing.requires.floats_to_four_decimals - def test_float_as_decimal(self): - self._do_test( + def test_float_as_decimal(self, do_numeric_test): + do_numeric_test( Float(precision=8, asdecimal=True), [15.7563, decimal.Decimal("15.7563"), None], [decimal.Decimal("15.7563"), None], filter_=lambda n: n is not None and round(n, 4) or None, ) - def test_float_as_float(self): - self._do_test( + def test_float_as_float(self, do_numeric_test): + do_numeric_test( Float(precision=8), [15.7563, decimal.Decimal("15.7563")], [15.7563], @@ -621,7 +631,7 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): eq_(val, expr) @testing.requires.precision_numerics_general - def test_precision_decimal(self): + def test_precision_decimal(self, do_numeric_test): numbers = set( [ decimal.Decimal("54.234246451650"), @@ -630,10 +640,10 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): ] ) - self._do_test(Numeric(precision=18, scale=12), numbers, numbers) + do_numeric_test(Numeric(precision=18, scale=12), numbers, numbers) @testing.requires.precision_numerics_enotation_large - def test_enotation_decimal(self): + def test_enotation_decimal(self, do_numeric_test): """test exceedingly small decimals. 
Decimal reports values with E notation when the exponent @@ -657,10 +667,10 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): decimal.Decimal("696E-12"), ] ) - self._do_test(Numeric(precision=18, scale=14), numbers, numbers) + do_numeric_test(Numeric(precision=18, scale=14), numbers, numbers) @testing.requires.precision_numerics_enotation_large - def test_enotation_decimal_large(self): + def test_enotation_decimal_large(self, do_numeric_test): """test exceedingly large decimals.""" numbers = set( @@ -671,10 +681,10 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): decimal.Decimal("00000000000000.1E+12"), ] ) - self._do_test(Numeric(precision=25, scale=2), numbers, numbers) + do_numeric_test(Numeric(precision=25, scale=2), numbers, numbers) @testing.requires.precision_numerics_many_significant_digits - def test_many_significant_digits(self): + def test_many_significant_digits(self, do_numeric_test): numbers = set( [ decimal.Decimal("31943874831932418390.01"), @@ -682,12 +692,12 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): decimal.Decimal("87673.594069654243"), ] ) - self._do_test(Numeric(precision=38, scale=12), numbers, numbers) + do_numeric_test(Numeric(precision=38, scale=12), numbers, numbers) @testing.requires.precision_numerics_retains_significant_digits - def test_numeric_no_decimal(self): + def test_numeric_no_decimal(self, do_numeric_test): numbers = set([decimal.Decimal("1.000")]) - self._do_test( + do_numeric_test( Numeric(precision=5, scale=3), numbers, numbers, check_scale=True ) @@ -705,8 +715,8 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest): Column("unconstrained_value", Boolean(create_constraint=False)), ) - def test_render_literal_bool(self): - self._literal_round_trip(Boolean(), [True, False], [True, False]) + def test_render_literal_bool(self, literal_round_trip): + literal_round_trip(Boolean(), [True, False], [True, False]) def test_round_trip(self, connection): 
boolean_table = self.tables.boolean_table |
