diff options
Diffstat (limited to 'lib/sqlalchemy/testing/suite')
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_reflection.py | 771 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_types.py | 222 |
2 files changed, 489 insertions, 504 deletions
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index bef3abb59..6c3c1005a 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -207,10 +207,10 @@ class QuotedNameArgumentTest(fixtures.TablesTest): ] for name in names: query = "CREATE VIEW %s AS SELECT * FROM %s" % ( - testing.db.dialect.identifier_preparer.quote( + config.db.dialect.identifier_preparer.quote( "view %s" % name ), - testing.db.dialect.identifier_preparer.quote(name), + config.db.dialect.identifier_preparer.quote(name), ) event.listen(metadata, "after_create", DDL(query)) @@ -219,7 +219,7 @@ class QuotedNameArgumentTest(fixtures.TablesTest): "before_drop", DDL( "DROP VIEW %s" - % testing.db.dialect.identifier_preparer.quote( + % config.db.dialect.identifier_preparer.quote( "view %s" % name ) ), @@ -233,52 +233,52 @@ class QuotedNameArgumentTest(fixtures.TablesTest): @quote_fixtures def test_get_table_options(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) insp.get_table_options(name) @quote_fixtures @testing.requires.view_column_reflection def test_get_view_definition(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) assert insp.get_view_definition("view %s" % name) @quote_fixtures def test_get_columns(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) assert insp.get_columns(name) @quote_fixtures def test_get_pk_constraint(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) assert insp.get_pk_constraint(name) @quote_fixtures def test_get_foreign_keys(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) assert insp.get_foreign_keys(name) @quote_fixtures def test_get_indexes(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) assert insp.get_indexes(name) @quote_fixtures @testing.requires.unique_constraint_reflection def test_get_unique_constraints(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) assert insp.get_unique_constraints(name) @quote_fixtures @testing.requires.comment_reflection def test_get_table_comment(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) assert insp.get_table_comment(name) @quote_fixtures @testing.requires.check_constraint_reflection def test_get_check_constraints(self, name): - insp = inspect(testing.db) + insp = inspect(config.db) assert insp.get_check_constraints(name) @@ -508,7 +508,7 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.schema_reflection def test_get_schema_names(self): - insp = inspect(testing.db) + insp = inspect(self.bind) self.assert_(testing.config.test_schema in insp.get_schema_names()) @@ -520,13 +520,28 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.schema_reflection def test_get_default_schema_name(self): - insp = inspect(testing.db) - eq_(insp.default_schema_name, testing.db.dialect.default_schema_name) - - @testing.provide_metadata - def _test_get_table_names( - self, schema=None, table_type="table", order_by=None + insp = inspect(self.bind) + eq_(insp.default_schema_name, self.bind.dialect.default_schema_name) + + @testing.combinations( + (None, True, False, False), + (None, True, False, True, testing.requires.schemas), + ("foreign_key", True, False, False), + (None, False, True, False), + (None, False, True, True, testing.requires.schemas), + (None, True, True, False), + (None, True, True, True, testing.requires.schemas), + argnames="order_by,include_plain,include_views,use_schema", + ) + def test_get_table_names( + self, connection, order_by, include_plain, include_views, use_schema ): + + if use_schema: + schema = config.test_schema + else: + schema = None + _ignore_tables = [ "comment_test", "noncol_idx_test_pk", @@ -535,16 +550,16 @@ class ComponentReflectionTest(fixtures.TablesTest): "remote_table", "remote_table_2", ] - meta = self.metadata - insp = inspect(meta.bind) + insp = inspect(connection) - if table_type == "view": + if include_views: table_names = insp.get_view_names(schema) table_names.sort() answer = ["email_addresses_v", "users_v"] eq_(sorted(table_names), answer) - else: + + if include_plain: if order_by: tables = [ rec[0] @@ -576,15 +591,6 @@ class ComponentReflectionTest(fixtures.TablesTest): temp_table_names = insp.get_temp_view_names() eq_(sorted(temp_table_names), ["user_tmp_v"]) - @testing.requires.table_reflection - def test_get_table_names(self): - self._test_get_table_names() - - @testing.requires.table_reflection - @testing.requires.foreign_key_constraint_reflection - def test_get_table_names_fks(self): - self._test_get_table_names(order_by="foreign_key") - @testing.requires.comment_reflection def test_get_comments(self): self._test_get_comments() @@ -595,7 +601,7 @@ class ComponentReflectionTest(fixtures.TablesTest): self._test_get_comments(testing.config.test_schema) def _test_get_comments(self, schema=None): - insp = inspect(testing.db) + insp = inspect(self.bind) eq_( insp.get_table_comment("comment_test", schema=schema), @@ -621,35 +627,27 @@ class ComponentReflectionTest(fixtures.TablesTest): ], ) - @testing.requires.table_reflection - @testing.requires.schemas - def test_get_table_names_with_schema(self): - self._test_get_table_names(testing.config.test_schema) - - @testing.requires.view_column_reflection - def test_get_view_names(self): - self._test_get_table_names(table_type="view") - - @testing.requires.view_column_reflection - @testing.requires.schemas - def test_get_view_names_with_schema(self): - self._test_get_table_names( - testing.config.test_schema, table_type="view" - ) - - @testing.requires.table_reflection - @testing.requires.view_column_reflection - def test_get_tables_and_views(self): - self._test_get_table_names() - self._test_get_table_names(table_type="view") + @testing.combinations( + (False, False), + (False, True, testing.requires.schemas), + (True, False), + (False, True, testing.requires.schemas), + argnames="use_views,use_schema", + ) + def test_get_columns(self, connection, use_views, use_schema): + + if use_schema: + schema = config.test_schema + else: + schema = None - def _test_get_columns(self, schema=None, table_type="table"): - meta = MetaData(testing.db) users, addresses = (self.tables.users, self.tables.email_addresses) - table_names = ["users", "email_addresses"] - if table_type == "view": + if use_views: table_names = ["users_v", "email_addresses_v"] - insp = inspect(meta.bind) + else: + table_names = ["users", "email_addresses"] + + insp = inspect(connection) for table_name, table in zip(table_names, (users, addresses)): schema_name = schema cols = insp.get_columns(table_name, schema=schema_name) @@ -699,67 +697,13 @@ class ComponentReflectionTest(fixtures.TablesTest): if not col.primary_key: assert cols[i]["default"] is None - @testing.requires.table_reflection - def test_get_columns(self): - self._test_get_columns() - - @testing.provide_metadata - def _type_round_trip(self, *types): - t = Table( - "t", - self.metadata, - *[Column("t%d" % i, type_) for i, type_ in enumerate(types)] - ) - t.create() - - return [ - c["type"] for c in inspect(self.metadata.bind).get_columns("t") - ] - - @testing.requires.table_reflection - def test_numeric_reflection(self): - for typ in self._type_round_trip(sql_types.Numeric(18, 5)): - assert isinstance(typ, sql_types.Numeric) - eq_(typ.precision, 18) - eq_(typ.scale, 5) - - @testing.requires.table_reflection - def test_varchar_reflection(self): - typ = self._type_round_trip(sql_types.String(52))[0] - assert isinstance(typ, sql_types.String) - eq_(typ.length, 52) - - @testing.requires.table_reflection - @testing.provide_metadata - def test_nullable_reflection(self): - t = Table( - "t", - self.metadata, - Column("a", Integer, nullable=True), - Column("b", Integer, nullable=False), - ) - t.create() - eq_( - dict( - (col["name"], col["nullable"]) - for col in inspect(self.metadata.bind).get_columns("t") - ), - {"a": True, "b": False}, - ) - - @testing.requires.table_reflection - @testing.requires.schemas - def test_get_columns_with_schema(self): - self._test_get_columns(schema=testing.config.test_schema) - @testing.requires.temp_table_reflection def test_get_temp_table_columns(self): table_name = get_temp_table_name( - config, config.db, "user_tmp_%s" % config.ident + config, self.bind, "user_tmp_%s" % config.ident ) - meta = MetaData(self.bind) user_tmp = self.tables[table_name] - insp = inspect(meta.bind) + insp = inspect(self.bind) cols = insp.get_columns(table_name) self.assert_(len(cols) > 0, len(cols)) @@ -774,22 +718,18 @@ class ComponentReflectionTest(fixtures.TablesTest): cols = insp.get_columns("user_tmp_v") eq_([col["name"] for col in cols], ["id", "name", "foo"]) - @testing.requires.view_column_reflection - def test_get_view_columns(self): - self._test_get_columns(table_type="view") - - @testing.requires.view_column_reflection - @testing.requires.schemas - def test_get_view_columns_with_schema(self): - self._test_get_columns( - schema=testing.config.test_schema, table_type="view" - ) + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + @testing.requires.primary_key_constraint_reflection + def test_get_pk_constraint(self, connection, use_schema): + if use_schema: + schema = testing.config.test_schema + else: + schema = None - @testing.provide_metadata - def _test_get_pk_constraint(self, schema=None): - meta = self.metadata users, addresses = self.tables.users, self.tables.email_addresses - insp = inspect(meta.bind) + insp = inspect(connection) users_cons = insp.get_pk_constraint(users.name, schema=schema) users_pkeys = users_cons["constrained_columns"] @@ -802,21 +742,18 @@ class ComponentReflectionTest(fixtures.TablesTest): with testing.requires.reflects_pk_names.fail_if(): eq_(addr_cons["name"], "email_ad_pk") - @testing.requires.primary_key_constraint_reflection - def test_get_pk_constraint(self): - self._test_get_pk_constraint() - - @testing.requires.table_reflection - @testing.requires.primary_key_constraint_reflection - @testing.requires.schemas - def test_get_pk_constraint_with_schema(self): - self._test_get_pk_constraint(schema=testing.config.test_schema) + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + @testing.requires.foreign_key_constraint_reflection + def test_get_foreign_keys(self, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None - @testing.provide_metadata - def _test_get_foreign_keys(self, schema=None): - meta = self.metadata users, addresses = (self.tables.users, self.tables.email_addresses) - insp = inspect(meta.bind) + insp = inspect(connection) expected_schema = schema # users @@ -845,25 +782,16 @@ class ComponentReflectionTest(fixtures.TablesTest): eq_(fkey1["referred_columns"], ["user_id"]) eq_(fkey1["constrained_columns"], ["remote_user_id"]) - @testing.requires.foreign_key_constraint_reflection - def test_get_foreign_keys(self): - self._test_get_foreign_keys() - - @testing.requires.foreign_key_constraint_reflection - @testing.requires.schemas - def test_get_foreign_keys_with_schema(self): - self._test_get_foreign_keys(schema=testing.config.test_schema) - @testing.requires.cross_schema_fk_reflection @testing.requires.schemas def test_get_inter_schema_foreign_keys(self): local_table, remote_table, remote_table_2 = self.tables( - "%s.local_table" % testing.db.dialect.default_schema_name, + "%s.local_table" % self.bind.dialect.default_schema_name, "%s.remote_table" % testing.config.test_schema, "%s.remote_table_2" % testing.config.test_schema, ) - insp = inspect(config.db) + insp = inspect(self.bind) local_fkeys = insp.get_foreign_keys(local_table.name) eq_(len(local_fkeys), 1) @@ -883,85 +811,12 @@ class ComponentReflectionTest(fixtures.TablesTest): assert fkey2["referred_schema"] in ( None, - testing.db.dialect.default_schema_name, + self.bind.dialect.default_schema_name, ) eq_(fkey2["referred_table"], local_table.name) eq_(fkey2["referred_columns"], ["id"]) eq_(fkey2["constrained_columns"], ["local_id"]) - @testing.requires.foreign_key_constraint_option_reflection_ondelete - def test_get_foreign_key_options_ondelete(self): - self._test_get_foreign_key_options(ondelete="CASCADE") - - @testing.requires.foreign_key_constraint_option_reflection_onupdate - def test_get_foreign_key_options_onupdate(self): - self._test_get_foreign_key_options(onupdate="SET NULL") - - @testing.requires.foreign_key_constraint_option_reflection_onupdate - def test_get_foreign_key_options_onupdate_noaction(self): - self._test_get_foreign_key_options(onupdate="NO ACTION", expected={}) - - @testing.requires.fk_constraint_option_reflection_ondelete_noaction - def test_get_foreign_key_options_ondelete_noaction(self): - self._test_get_foreign_key_options(ondelete="NO ACTION", expected={}) - - @testing.requires.fk_constraint_option_reflection_onupdate_restrict - def test_get_foreign_key_options_onupdate_restrict(self): - self._test_get_foreign_key_options(onupdate="RESTRICT") - - @testing.requires.fk_constraint_option_reflection_ondelete_restrict - def test_get_foreign_key_options_ondelete_restrict(self): - self._test_get_foreign_key_options(ondelete="RESTRICT") - - @testing.provide_metadata - def _test_get_foreign_key_options(self, expected=None, **options): - meta = self.metadata - - if expected is None: - expected = options - - Table( - "x", - meta, - Column("id", Integer, primary_key=True), - test_needs_fk=True, - ) - - Table( - "table", - meta, - Column("id", Integer, primary_key=True), - Column("x_id", Integer, sa.ForeignKey("x.id", name="xid")), - Column("test", String(10)), - test_needs_fk=True, - ) - - Table( - "user", - meta, - Column("id", Integer, primary_key=True), - Column("name", String(50), nullable=False), - Column("tid", Integer), - sa.ForeignKeyConstraint( - ["tid"], ["table.id"], name="myfk", **options - ), - test_needs_fk=True, - ) - - meta.create_all() - - insp = inspect(meta.bind) - - # test 'options' is always present for a backend - # that can reflect these, since alembic looks for this - opts = insp.get_foreign_keys("table")[0]["options"] - - eq_(dict((k, opts[k]) for k in opts if opts[k]), {}) - - opts = insp.get_foreign_keys("user")[0]["options"] - eq_(opts, expected) - # eq_(dict((k, opts[k]) for k in opts if opts[k]), expected) - def _assert_insp_indexes(self, indexes, expected_indexes): index_names = [d["name"] for d in indexes] for e_index in expected_indexes: @@ -970,13 +825,19 @@ class ComponentReflectionTest(fixtures.TablesTest): for key in e_index: eq_(e_index[key], index[key]) - @testing.provide_metadata - def _test_get_indexes(self, schema=None): - meta = self.metadata + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + def test_get_indexes(self, connection, use_schema): + + if use_schema: + schema = config.test_schema + else: + schema = None # The database may decide to create indexes for foreign keys, etc. # so there may be more indexes than expected. - insp = inspect(meta.bind) + insp = inspect(self.bind) indexes = insp.get_indexes("users", schema=schema) expected_indexes = [ { @@ -992,19 +853,15 @@ class ComponentReflectionTest(fixtures.TablesTest): ] self._assert_insp_indexes(indexes, expected_indexes) + @testing.combinations( + ("noncol_idx_test_nopk", "noncol_idx_nopk"), + ("noncol_idx_test_pk", "noncol_idx_pk"), + argnames="tname,ixname", + ) @testing.requires.index_reflection - def test_get_indexes(self): - self._test_get_indexes() - - @testing.requires.index_reflection - @testing.requires.schemas - def test_get_indexes_with_schema(self): - self._test_get_indexes(schema=testing.config.test_schema) - - @testing.provide_metadata - def _test_get_noncol_index(self, tname, ixname): - meta = self.metadata - insp = inspect(meta.bind) + @testing.requires.indexes_with_ascdesc + def test_get_noncol_index(self, connection, tname, ixname): + insp = inspect(connection) indexes = insp.get_indexes(tname) # reflecting an index that has "x DESC" in it as the column. @@ -1013,85 +870,11 @@ class ComponentReflectionTest(fixtures.TablesTest): expected_indexes = [{"unique": False, "name": ixname}] self._assert_insp_indexes(indexes, expected_indexes) - t = Table(tname, meta, autoload_with=meta.bind) + t = Table(tname, MetaData(), autoload_with=connection) eq_(len(t.indexes), 1) is_(list(t.indexes)[0].table, t) eq_(list(t.indexes)[0].name, ixname) - @testing.requires.index_reflection - @testing.requires.indexes_with_ascdesc - def test_get_noncol_index_no_pk(self): - self._test_get_noncol_index("noncol_idx_test_nopk", "noncol_idx_nopk") - - @testing.requires.index_reflection - @testing.requires.indexes_with_ascdesc - def test_get_noncol_index_pk(self): - self._test_get_noncol_index("noncol_idx_test_pk", "noncol_idx_pk") - - @testing.requires.indexes_with_expressions - @testing.provide_metadata - def test_reflect_expression_based_indexes(self): - t = Table( - "t", - self.metadata, - Column("x", String(30)), - Column("y", String(30)), - ) - - Index("t_idx", func.lower(t.c.x), func.lower(t.c.y)) - - Index("t_idx_2", t.c.x) - - self.metadata.create_all(testing.db) - - insp = inspect(testing.db) - - expected = [ - {"name": "t_idx_2", "column_names": ["x"], "unique": False} - ] - if testing.requires.index_reflects_included_columns.enabled: - expected[0]["include_columns"] = [] - - with expect_warnings( - "Skipped unsupported reflection of expression-based index t_idx" - ): - eq_( - insp.get_indexes("t"), - expected, - ) - - @testing.requires.index_reflects_included_columns - @testing.provide_metadata - def test_reflect_covering_index(self): - t = Table( - "t", - self.metadata, - Column("x", String(30)), - Column("y", String(30)), - ) - idx = Index("t_idx", t.c.x) - idx.dialect_options[testing.db.name]["include"] = ["y"] - - self.metadata.create_all(testing.db) - - insp = inspect(testing.db) - - eq_( - insp.get_indexes("t"), - [ - { - "name": "t_idx", - "column_names": ["x"], - "include_columns": ["y"], - "unique": False, - } - ], - ) - - @testing.requires.unique_constraint_reflection - def test_get_unique_constraints(self): - self._test_get_unique_constraints() - @testing.requires.temp_table_reflection @testing.requires.unique_constraint_reflection def test_get_temp_table_unique_constraints(self): @@ -1130,19 +913,22 @@ class ComponentReflectionTest(fixtures.TablesTest): expected, ) + @testing.combinations( + (True, testing.requires.schemas), (False,), argnames="use_schema" + ) @testing.requires.unique_constraint_reflection - @testing.requires.schemas - def test_get_unique_constraints_with_schema(self): - self._test_get_unique_constraints(schema=testing.config.test_schema) - - @testing.provide_metadata - def _test_get_unique_constraints(self, schema=None): + def test_get_unique_constraints(self, metadata, connection, use_schema): # SQLite dialect needs to parse the names of the constraints # separately from what it gets from PRAGMA index_list(), and # then matches them up. so same set of column_names in two # constraints will confuse it. Perhaps we should no longer # bother with index_list() here since we have the whole # CREATE TABLE? + + if use_schema: + schema = config.test_schema + else: + schema = None uniques = sorted( [ {"name": "unique_a", "column_names": ["a"]}, @@ -1154,10 +940,9 @@ class ComponentReflectionTest(fixtures.TablesTest): ], key=operator.itemgetter("name"), ) - orig_meta = self.metadata table = Table( "testtbl", - orig_meta, + metadata, Column("a", sa.String(20)), Column("b", sa.String(30)), Column("c", sa.Integer), @@ -1170,9 +955,9 @@ class ComponentReflectionTest(fixtures.TablesTest): table.append_constraint( sa.UniqueConstraint(*uc["column_names"], name=uc["name"]) ) - orig_meta.create_all() + table.create(connection) - inspector = inspect(orig_meta.bind) + inspector = inspect(connection) reflected = sorted( inspector.get_unique_constraints("testtbl", schema=schema), key=operator.itemgetter("name"), @@ -1192,7 +977,7 @@ class ComponentReflectionTest(fixtures.TablesTest): reflected = Table( "testtbl", reflected_metadata, - autoload_with=orig_meta.bind, + autoload_with=connection, schema=schema, ) @@ -1214,30 +999,90 @@ class ComponentReflectionTest(fixtures.TablesTest): eq_(names_that_duplicate_index, idx_names) eq_(uq_names, set()) - @testing.requires.check_constraint_reflection - def test_get_check_constraints(self): - self._test_get_check_constraints() + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + def test_get_view_definition(self, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + view_name1 = "users_v" + view_name2 = "email_addresses_v" + insp = inspect(connection) + v1 = insp.get_view_definition(view_name1, schema=schema) + self.assert_(v1) + v2 = insp.get_view_definition(view_name2, schema=schema) + self.assert_(v2) + + # why is this here if it's PG specific ? + @testing.combinations( + ("users", False), + ("users", True, testing.requires.schemas), + argnames="table_name,use_schema", + ) + @testing.only_on("postgresql", "PG specific feature") + def test_get_table_oid(self, connection, table_name, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + insp = inspect(connection) + oid = insp.get_table_oid(table_name, schema) + self.assert_(isinstance(oid, int)) + + @testing.requires.table_reflection + def test_autoincrement_col(self): + """test that 'autoincrement' is reflected according to sqla's policy. + + Don't mark this test as unsupported for any backend ! + + (technically it fails with MySQL InnoDB since "id" comes before "id2") + + A backend is better off not returning "autoincrement" at all, + instead of potentially returning "False" for an auto-incrementing + primary key column. + + """ + + insp = inspect(self.bind) + + for tname, cname in [ + ("users", "user_id"), + ("email_addresses", "address_id"), + ("dingalings", "dingaling_id"), + ]: + cols = insp.get_columns(tname) + id_ = {c["name"]: c for c in cols}[cname] + assert id_.get("autoincrement", True) + + +class ComponentReflectionTestExtra(fixtures.TestBase): + __backend__ = True + + @testing.combinations( + (True, testing.requires.schemas), (False,), argnames="use_schema" + ) @testing.requires.check_constraint_reflection - @testing.requires.schemas - def test_get_check_constraints_schema(self): - self._test_get_check_constraints(schema=testing.config.test_schema) + def test_get_check_constraints(self, metadata, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None - @testing.provide_metadata - def _test_get_check_constraints(self, schema=None): - orig_meta = self.metadata Table( "sa_cc", - orig_meta, + metadata, Column("a", Integer()), sa.CheckConstraint("a > 1 AND a < 5", name="cc1"), sa.CheckConstraint("a = 1 OR (a > 2 AND a < 5)", name="cc2"), schema=schema, ) - orig_meta.create_all() + metadata.create_all(connection) - inspector = inspect(orig_meta.bind) + inspector = inspect(connection) reflected = sorted( inspector.get_check_constraints("sa_cc", schema=schema), key=operator.itemgetter("name"), @@ -1263,67 +1108,200 @@ class ComponentReflectionTest(fixtures.TablesTest): ], ) - @testing.provide_metadata - def _test_get_view_definition(self, schema=None): - meta = self.metadata - view_name1 = "users_v" - view_name2 = "email_addresses_v" - insp = inspect(meta.bind) - v1 = insp.get_view_definition(view_name1, schema=schema) - self.assert_(v1) - v2 = insp.get_view_definition(view_name2, schema=schema) - self.assert_(v2) + @testing.requires.indexes_with_expressions + def test_reflect_expression_based_indexes(self, metadata, connection): + t = Table( + "t", + metadata, + Column("x", String(30)), + Column("y", String(30)), + ) - @testing.requires.view_reflection - def test_get_view_definition(self): - self._test_get_view_definition() + Index("t_idx", func.lower(t.c.x), func.lower(t.c.y)) - @testing.requires.view_reflection - @testing.requires.schemas - def test_get_view_definition_with_schema(self): - self._test_get_view_definition(schema=testing.config.test_schema) + Index("t_idx_2", t.c.x) - @testing.only_on("postgresql", "PG specific feature") - @testing.provide_metadata - def _test_get_table_oid(self, table_name, schema=None): - meta = self.metadata - insp = inspect(meta.bind) - oid = insp.get_table_oid(table_name, schema) - self.assert_(isinstance(oid, int)) + metadata.create_all(connection) - def test_get_table_oid(self): - self._test_get_table_oid("users") + insp = inspect(connection) - @testing.requires.schemas - def test_get_table_oid_with_schema(self): - self._test_get_table_oid("users", schema=testing.config.test_schema) + expected = [ + {"name": "t_idx_2", "column_names": ["x"], "unique": False} + ] + if testing.requires.index_reflects_included_columns.enabled: + expected[0]["include_columns"] = [] + + with expect_warnings( + "Skipped unsupported reflection of expression-based index t_idx" + ): + eq_( + insp.get_indexes("t"), + expected, + ) + + @testing.requires.index_reflects_included_columns + def test_reflect_covering_index(self, metadata, connection): + t = Table( + "t", + metadata, + Column("x", String(30)), + Column("y", String(30)), + ) + idx = Index("t_idx", t.c.x) + idx.dialect_options[connection.engine.name]["include"] = ["y"] + + metadata.create_all(connection) + + insp = inspect(connection) + + eq_( + insp.get_indexes("t"), + [ + { + "name": "t_idx", + "column_names": ["x"], + "include_columns": ["y"], + "unique": False, + } + ], + ) + + def _type_round_trip(self, connection, metadata, *types): + t = Table( + "t", + metadata, + *[Column("t%d" % i, type_) for i, type_ in enumerate(types)] + ) + t.create(connection) + + return [c["type"] for c in inspect(connection).get_columns("t")] @testing.requires.table_reflection - @testing.provide_metadata - def test_autoincrement_col(self): - """test that 'autoincrement' is reflected according to sqla's policy. + def test_numeric_reflection(self, connection, metadata): + for typ in self._type_round_trip( + connection, metadata, sql_types.Numeric(18, 5) + ): + assert isinstance(typ, sql_types.Numeric) + eq_(typ.precision, 18) + eq_(typ.scale, 5) - Don't mark this test as unsupported for any backend ! + @testing.requires.table_reflection + def test_varchar_reflection(self, connection, metadata): + typ = self._type_round_trip( + connection, metadata, sql_types.String(52) + )[0] + assert isinstance(typ, sql_types.String) + eq_(typ.length, 52) - (technically it fails with MySQL InnoDB since "id" comes before "id2") + @testing.requires.table_reflection + def test_nullable_reflection(self, connection, metadata): + t = Table( + "t", + metadata, + Column("a", Integer, nullable=True), + Column("b", Integer, nullable=False), + ) + t.create(connection) + eq_( + dict( + (col["name"], col["nullable"]) + for col in inspect(connection).get_columns("t") + ), + {"a": True, "b": False}, + ) - A backend is better off not returning "autoincrement" at all, - instead of potentially returning "False" for an auto-incrementing - primary key column. + @testing.combinations( + ( + None, + "CASCADE", + None, + testing.requires.foreign_key_constraint_option_reflection_ondelete, + ), + ( + None, + None, + "SET NULL", + testing.requires.foreign_key_constraint_option_reflection_onupdate, + ), + ( + {}, + None, + "NO ACTION", + testing.requires.foreign_key_constraint_option_reflection_onupdate, + ), + ( + {}, + "NO ACTION", + None, + testing.requires.fk_constraint_option_reflection_ondelete_noaction, + ), + ( + None, + None, + "RESTRICT", + testing.requires.fk_constraint_option_reflection_onupdate_restrict, + ), + ( + None, + "RESTRICT", + None, + testing.requires.fk_constraint_option_reflection_ondelete_restrict, + ), + argnames="expected,ondelete,onupdate", + ) + def test_get_foreign_key_options( + self, connection, metadata, expected, ondelete, onupdate + ): + options = {} + if ondelete: + options["ondelete"] = ondelete + if onupdate: + options["onupdate"] = onupdate - """ + if expected is None: + expected = options - meta = self.metadata - insp = inspect(meta.bind) + Table( + "x", + metadata, + Column("id", Integer, primary_key=True), + test_needs_fk=True, + ) - for tname, cname in [ - ("users", "user_id"), - ("email_addresses", "address_id"), - ("dingalings", "dingaling_id"), - ]: - cols = insp.get_columns(tname) - id_ = {c["name"]: c for c in cols}[cname] - assert id_.get("autoincrement", True) + Table( + "table", + metadata, + Column("id", Integer, primary_key=True), + Column("x_id", Integer, sa.ForeignKey("x.id", name="xid")), + Column("test", String(10)), + test_needs_fk=True, + ) + + Table( + "user", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("tid", Integer), + sa.ForeignKeyConstraint( + ["tid"], ["table.id"], name="myfk", **options + ), + test_needs_fk=True, + ) + + metadata.create_all(connection) + + insp = inspect(connection) + + # test 'options' is always present for a backend + # that can reflect these, since alembic looks for this + opts = insp.get_foreign_keys("table")[0]["options"] + + eq_(dict((k, opts[k]) for k in opts if opts[k]), {}) + + opts = insp.get_foreign_keys("user")[0]["options"] + eq_(opts, expected) + # eq_(dict((k, opts[k]) for k in opts if opts[k]), expected) class NormalizedNameTest(fixtures.TablesTest): @@ -1348,21 +1326,21 @@ class NormalizedNameTest(fixtures.TablesTest): m2 = MetaData() t2_ref = Table( - quoted_name("t2", quote=True), m2, autoload_with=testing.db + quoted_name("t2", quote=True), m2, autoload_with=config.db ) t1_ref = m2.tables["t1"] assert t2_ref.c.t1id.references(t1_ref.c.id) m3 = MetaData() m3.reflect( - testing.db, only=lambda name, m: name.lower() in ("t1", "t2") + config.db, only=lambda name, m: name.lower() in ("t1", "t2") ) assert m3.tables["t2"].c.t1id.references(m3.tables["t1"].c.id) def test_get_table_names(self): tablenames = [ t - for t in inspect(testing.db).get_table_names() + for t in inspect(config.db).get_table_names() if t.lower() in ("t1", "t2") ] @@ -1637,20 +1615,16 @@ class CompositeKeyReflectionTest(fixtures.TablesTest): ) @testing.requires.primary_key_constraint_reflection - @testing.provide_metadata def test_pk_column_order(self): # test for issue #5661 - meta = self.metadata - insp = inspect(meta.bind) + insp = inspect(self.bind) primary_key = insp.get_pk_constraint(self.tables.tb1.name) eq_(primary_key.get("constrained_columns"), ["name", "id", "attr"]) @testing.requires.foreign_key_constraint_reflection - @testing.provide_metadata def test_fk_column_order(self): # test for issue #5661 - meta = self.metadata - insp = inspect(meta.bind) + insp = inspect(self.bind) foreign_keys = insp.get_foreign_keys(self.tables.tb2.name) eq_(len(foreign_keys), 1) fkey1 = foreign_keys[0] @@ -1660,6 +1634,7 @@ class CompositeKeyReflectionTest(fixtures.TablesTest): __all__ = ( "ComponentReflectionTest", + "ComponentReflectionTestExtra", "QuotedNameArgumentTest", "HasTableTest", "HasIndexTest", diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 43777239c..3a5e02c32 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -47,18 +47,19 @@ from ...util import u class _LiteralRoundTripFixture(object): supports_whereclause = True - @testing.provide_metadata - def _literal_round_trip(self, type_, input_, output, filter_=None): + @testing.fixture + def literal_round_trip(self, metadata, connection): """test literal rendering """ # for literal, we test the literal render in an INSERT # into a typed column. we can then SELECT it back as its # official type; ideally we'd be able to use CAST here # but MySQL in particular can't CAST fully - t = Table("t", self.metadata, Column("x", type_)) - t.create() - with testing.db.begin() as conn: + def run(type_, input_, output, filter_=None): + t = Table("t", metadata, Column("x", type_)) + t.create(connection) + for value in input_: ins = ( t.insert() @@ -68,7 +69,7 @@ class _LiteralRoundTripFixture(object): compile_kwargs=dict(literal_binds=True), ) ) - conn.execute(ins) + connection.execute(ins) if self.supports_whereclause: stmt = t.select().where(t.c.x == literal(value)) @@ -79,12 +80,14 @@ class _LiteralRoundTripFixture(object): dialect=testing.db.dialect, compile_kwargs=dict(literal_binds=True), ) - for row in conn.execute(stmt): + for row in connection.execute(stmt): value = row[0] if filter_ is not None: value = filter_(value) assert value in output + return run + class _UnicodeFixture(_LiteralRoundTripFixture, fixtures.TestBase): __requires__ = ("unicode_data",) @@ -149,11 +152,11 @@ class _UnicodeFixture(_LiteralRoundTripFixture, fixtures.TestBase): row = connection.execute(select(unicode_table.c.unicode_data)).first() eq_(row, (u(""),)) - def test_literal(self): - self._literal_round_trip(self.datatype, [self.data], [self.data]) + def test_literal(self, literal_round_trip): + literal_round_trip(self.datatype, [self.data], [self.data]) - def test_literal_non_ascii(self): - self._literal_round_trip( + def test_literal_non_ascii(self, literal_round_trip): + literal_round_trip( self.datatype, [util.u("réve🐍 illé")], [util.u("réve🐍 illé")] ) @@ -227,25 +230,25 @@ class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest): row = connection.execute(select(text_table.c.text_data)).first() eq_(row, (None,)) - def test_literal(self): - self._literal_round_trip(Text, ["some text"], ["some text"]) + def test_literal(self, literal_round_trip): + literal_round_trip(Text, ["some text"], ["some text"]) - def test_literal_non_ascii(self): - self._literal_round_trip( + def test_literal_non_ascii(self, literal_round_trip): + literal_round_trip( Text, [util.u("réve🐍 illé")], [util.u("réve🐍 illé")] ) - def test_literal_quoting(self): + def test_literal_quoting(self, literal_round_trip): data = """some 'text' hey "hi there" that's text""" - self._literal_round_trip(Text, [data], [data]) + literal_round_trip(Text, [data], [data]) - def test_literal_backslashes(self): + def test_literal_backslashes(self, literal_round_trip): data = r"backslash one \ backslash two \\ end" - self._literal_round_trip(Text, [data], [data]) + literal_round_trip(Text, [data], [data]) - def test_literal_percentsigns(self): + def test_literal_percentsigns(self, literal_round_trip): data = r"percent % signs %% percent" - self._literal_round_trip(Text, [data], [data]) + literal_round_trip(Text, [data], [data]) class StringTest(_LiteralRoundTripFixture, fixtures.TestBase): @@ -259,23 +262,23 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase): foo.create(config.db) foo.drop(config.db) - def test_literal(self): + def test_literal(self, literal_round_trip): # note that in Python 3, this invokes the Unicode # datatype for the literal part because all strings are unicode - self._literal_round_trip(String(40), ["some text"], ["some text"]) + literal_round_trip(String(40), ["some text"], ["some text"]) - def test_literal_non_ascii(self): - self._literal_round_trip( + def test_literal_non_ascii(self, literal_round_trip): + literal_round_trip( String(40), [util.u("réve🐍 illé")], [util.u("réve🐍 illé")] ) - def test_literal_quoting(self): + def test_literal_quoting(self, literal_round_trip): data = """some 'text' hey "hi there" that's text""" - self._literal_round_trip(String(40), [data], [data]) + literal_round_trip(String(40), [data], [data]) - def test_literal_backslashes(self): + def test_literal_backslashes(self, literal_round_trip): data = r"backslash one \ backslash two \\ end" - self._literal_round_trip(String(40), [data], [data]) + literal_round_trip(String(40), [data], [data]) class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase): @@ -331,9 +334,9 @@ class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase): eq_(row, (None,)) @testing.requires.datetime_literals - def test_literal(self): + def test_literal(self, literal_round_trip): compare = self.compare or self.data - self._literal_round_trip(self.datatype, [self.data], [compare]) + literal_round_trip(self.datatype, [self.data], [compare]) @testing.requires.standalone_null_binds_whereclause def test_null_bound_comparison(self): @@ -430,36 +433,41 @@ class DateHistoricTest(_DateFixture, fixtures.TablesTest): class IntegerTest(_LiteralRoundTripFixture, fixtures.TestBase): __backend__ = True - def test_literal(self): - self._literal_round_trip(Integer, [5], [5]) + def test_literal(self, literal_round_trip): + literal_round_trip(Integer, [5], [5]) - def test_huge_int(self, connection): - self._round_trip(BigInteger, 1376537018368127, connection) + def test_huge_int(self, integer_round_trip): + integer_round_trip(BigInteger, 1376537018368127) - @testing.provide_metadata - def _round_trip(self, datatype, data, connection): - metadata = self.metadata - int_table = Table( - "integer_table", - metadata, - Column( - "id", Integer, primary_key=True, test_needs_autoincrement=True - ), - Column("integer_data", datatype), - ) + @testing.fixture + def integer_round_trip(self, metadata, connection): + def run(datatype, data): + int_table = Table( + "integer_table", + metadata, + Column( + "id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("integer_data", datatype), + ) - metadata.create_all(config.db) + metadata.create_all(config.db) - connection.execute(int_table.insert(), {"integer_data": data}) + connection.execute(int_table.insert(), {"integer_data": data}) - row = connection.execute(select(int_table.c.integer_data)).first() + row = connection.execute(select(int_table.c.integer_data)).first() - eq_(row, (data,)) + eq_(row, (data,)) - if util.py3k: - assert isinstance(row[0], int) - else: - assert isinstance(row[0], (long, int)) # noqa + if util.py3k: + assert isinstance(row[0], int) + else: + assert isinstance(row[0], (long, int)) # noqa + + return run class CastTypeDecoratorTest(_LiteralRoundTripFixture, fixtures.TestBase): @@ -481,12 +489,10 @@ class CastTypeDecoratorTest(_LiteralRoundTripFixture, fixtures.TestBase): return StringAsInt() - @testing.provide_metadata - def test_special_type(self, connection, string_as_int): + def test_special_type(self, metadata, connection, string_as_int): type_ = string_as_int - metadata = self.metadata t = Table("t", metadata, Column("x", type_)) t.create(connection) @@ -504,42 +510,46 @@ class CastTypeDecoratorTest(_LiteralRoundTripFixture, fixtures.TestBase): class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): __backend__ = True - @testing.emits_warning(r".*does \*not\* support Decimal objects natively") - @testing.provide_metadata - def _do_test(self, type_, input_, output, filter_=None, check_scale=False): - metadata = self.metadata - t = Table("t", metadata, Column("x", type_)) - t.create() - with config.db.begin() as conn: - conn.execute(t.insert(), [{"x": x} for x in input_]) - - result = {row[0] for row in conn.execute(t.select())} - output = set(output) - if filter_: - result = set(filter_(x) for x in result) - output = set(filter_(x) for x in output) - eq_(result, output) - if check_scale: - eq_([str(x) for x in result], [str(x) for x in output]) + @testing.fixture + def do_numeric_test(self, metadata): + @testing.emits_warning( + r".*does \*not\* support Decimal objects natively" + ) + def run(type_, input_, output, filter_=None, check_scale=False): + t = Table("t", metadata, Column("x", type_)) + t.create(testing.db) + with config.db.begin() as conn: + conn.execute(t.insert(), [{"x": x} for x in input_]) + + result = {row[0] for row in conn.execute(t.select())} + output = set(output) + if filter_: + result = set(filter_(x) for x in result) + output = set(filter_(x) for x in output) + eq_(result, output) + if check_scale: + eq_([str(x) for x in result], [str(x) for x in output]) + + return run @testing.emits_warning(r".*does \*not\* support Decimal objects natively") - def test_render_literal_numeric(self): - self._literal_round_trip( + def test_render_literal_numeric(self, literal_round_trip): + literal_round_trip( Numeric(precision=8, scale=4), [15.7563, decimal.Decimal("15.7563")], [decimal.Decimal("15.7563")], ) @testing.emits_warning(r".*does \*not\* support Decimal objects natively") - def test_render_literal_numeric_asfloat(self): - self._literal_round_trip( + def test_render_literal_numeric_asfloat(self, literal_round_trip): + literal_round_trip( Numeric(precision=8, scale=4, asdecimal=False), [15.7563, decimal.Decimal("15.7563")], [15.7563], ) - def test_render_literal_float(self): - self._literal_round_trip( + def test_render_literal_float(self, literal_round_trip): + literal_round_trip( Float(4), [15.7563, decimal.Decimal("15.7563")], [15.7563], @@ -547,49 +557,49 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): ) @testing.requires.precision_generic_float_type - def test_float_custom_scale(self): - self._do_test( + def test_float_custom_scale(self, do_numeric_test): + do_numeric_test( Float(None, decimal_return_scale=7, asdecimal=True), [15.7563827, decimal.Decimal("15.7563827")], [decimal.Decimal("15.7563827")], check_scale=True, ) - def test_numeric_as_decimal(self): - self._do_test( + def test_numeric_as_decimal(self, do_numeric_test): + do_numeric_test( Numeric(precision=8, scale=4), [15.7563, decimal.Decimal("15.7563")], [decimal.Decimal("15.7563")], ) - def test_numeric_as_float(self): - self._do_test( + def test_numeric_as_float(self, do_numeric_test): + do_numeric_test( Numeric(precision=8, scale=4, asdecimal=False), [15.7563, decimal.Decimal("15.7563")], [15.7563], ) @testing.requires.fetch_null_from_numeric - def test_numeric_null_as_decimal(self): - self._do_test(Numeric(precision=8, scale=4), [None], [None]) + def test_numeric_null_as_decimal(self, do_numeric_test): + do_numeric_test(Numeric(precision=8, scale=4), [None], [None]) @testing.requires.fetch_null_from_numeric - def test_numeric_null_as_float(self): - self._do_test( + def test_numeric_null_as_float(self, do_numeric_test): + do_numeric_test( Numeric(precision=8, scale=4, asdecimal=False), [None], [None] ) @testing.requires.floats_to_four_decimals - def test_float_as_decimal(self): - self._do_test( + def test_float_as_decimal(self, do_numeric_test): + do_numeric_test( Float(precision=8, asdecimal=True), [15.7563, decimal.Decimal("15.7563"), None], [decimal.Decimal("15.7563"), None], filter_=lambda n: n is not None and round(n, 4) or None, ) - def test_float_as_float(self): - self._do_test( + def test_float_as_float(self, do_numeric_test): + do_numeric_test( Float(precision=8), [15.7563, decimal.Decimal("15.7563")], [15.7563], @@ -621,7 +631,7 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): eq_(val, expr) @testing.requires.precision_numerics_general - def test_precision_decimal(self): + def test_precision_decimal(self, do_numeric_test): numbers = set( [ decimal.Decimal("54.234246451650"), @@ -630,10 +640,10 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): ] ) - self._do_test(Numeric(precision=18, scale=12), numbers, numbers) + do_numeric_test(Numeric(precision=18, scale=12), numbers, numbers) @testing.requires.precision_numerics_enotation_large - def test_enotation_decimal(self): + def test_enotation_decimal(self, do_numeric_test): """test exceedingly small decimals. Decimal reports values with E notation when the exponent @@ -657,10 +667,10 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): decimal.Decimal("696E-12"), ] ) - self._do_test(Numeric(precision=18, scale=14), numbers, numbers) + do_numeric_test(Numeric(precision=18, scale=14), numbers, numbers) @testing.requires.precision_numerics_enotation_large - def test_enotation_decimal_large(self): + def test_enotation_decimal_large(self, do_numeric_test): """test exceedingly large decimals.""" numbers = set( @@ -671,10 +681,10 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): decimal.Decimal("00000000000000.1E+12"), ] ) - self._do_test(Numeric(precision=25, scale=2), numbers, numbers) + do_numeric_test(Numeric(precision=25, scale=2), numbers, numbers) @testing.requires.precision_numerics_many_significant_digits - def test_many_significant_digits(self): + def test_many_significant_digits(self, do_numeric_test): numbers = set( [ decimal.Decimal("31943874831932418390.01"), @@ -682,12 +692,12 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): decimal.Decimal("87673.594069654243"), ] ) - self._do_test(Numeric(precision=38, scale=12), numbers, numbers) + do_numeric_test(Numeric(precision=38, scale=12), numbers, numbers) @testing.requires.precision_numerics_retains_significant_digits - def test_numeric_no_decimal(self): + def test_numeric_no_decimal(self, do_numeric_test): numbers = set([decimal.Decimal("1.000")]) - self._do_test( + do_numeric_test( Numeric(precision=5, scale=3), numbers, numbers, check_scale=True ) @@ -705,8 +715,8 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest): Column("unconstrained_value", Boolean(create_constraint=False)), ) - def test_render_literal_bool(self): - self._literal_round_trip(Boolean(), [True, False], [True, False]) + def test_render_literal_bool(self, literal_round_trip): + literal_round_trip(Boolean(), [True, False], [True, False]) def test_round_trip(self, connection): boolean_table = self.tables.boolean_table |
