diff options
| author | Gord Thompson <gord@gordthompson.com> | 2020-08-12 14:46:59 -0600 |
|---|---|---|
| committer | Gord Thompson <gord@gordthompson.com> | 2020-09-01 08:05:51 -0600 |
| commit | 516131c40da9c8cd304061850e2d98e309966dd5 (patch) | |
| tree | 148c91095e3021de6881fc0d328980538d3fcea0 /lib/sqlalchemy | |
| parent | 301c3f3579ace1ef1c28067904b57dd789620eae (diff) | |
| download | sqlalchemy-516131c40da9c8cd304061850e2d98e309966dd5.tar.gz | |
Improve reflection for mssql temporary tables
Fixes: #5506
Change-Id: I718474d76e3c630a1b71e07eaa20cefb104d11de
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 37 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/information_schema.py | 22 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/provision.py | 12 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/provision.py | 15 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_reflection.py | 14 |
5 files changed, 91 insertions, 9 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index f38c537fd..ed17fb863 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2913,11 +2913,46 @@ class MSDialect(default.DefaultDialect): view_def = rp.scalar() return view_def + def _get_internal_temp_table_name(self, connection, tablename): + result = connection.execute( + sql.text( + "select table_name " + "from tempdb.information_schema.tables " + "where table_name like :p1" + ), + { + "p1": tablename + + (("___%") if not tablename.startswith("##") else "") + }, + ).fetchall() + if len(result) > 1: + raise exc.UnreflectableTableError( + "Found more than one temporary table named '%s' in tempdb " + "at this time. Cannot reliably resolve that name to its " + "internal table name." % tablename + ) + elif len(result) == 0: + raise exc.NoSuchTableError( + "Unable to find a temporary table named '%s' in tempdb." + % tablename + ) + else: + return result[0][0] + @reflection.cache @_db_plus_owner def get_columns(self, connection, tablename, dbname, owner, schema, **kw): + is_temp_table = tablename.startswith("#") + if is_temp_table: + tablename = self._get_internal_temp_table_name( + connection, tablename + ) # Get base columns - columns = ischema.columns + columns = ( + ischema.mssql_temp_table_columns + if is_temp_table + else ischema.columns + ) computed_cols = ischema.computed_columns if owner: whereclause = sql.and_( diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py index 6cdde8386..f80110b7d 100644 --- a/lib/sqlalchemy/dialects/mssql/information_schema.py +++ b/lib/sqlalchemy/dialects/mssql/information_schema.py @@ -5,9 +5,6 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -# TODO: should be using the sys. catalog with SQL Server, not information -# schema - from ... import cast from ... import Column from ... import MetaData @@ -93,6 +90,25 @@ columns = Table( schema="INFORMATION_SCHEMA", ) +mssql_temp_table_columns = Table( + "COLUMNS", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("IS_NULLABLE", Integer, key="is_nullable"), + Column("DATA_TYPE", String, key="data_type"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + Column( + "CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length" + ), + Column("NUMERIC_PRECISION", Integer, key="numeric_precision"), + Column("NUMERIC_SCALE", Integer, key="numeric_scale"), + Column("COLUMN_DEFAULT", Integer, key="column_default"), + Column("COLLATION_NAME", String, key="collation_name"), + schema="tempdb.INFORMATION_SCHEMA", +) + constraints = Table( "TABLE_CONSTRAINTS", ischema, diff --git a/lib/sqlalchemy/dialects/mssql/provision.py b/lib/sqlalchemy/dialects/mssql/provision.py index a5131eae6..269eb164f 100644 --- a/lib/sqlalchemy/dialects/mssql/provision.py +++ b/lib/sqlalchemy/dialects/mssql/provision.py @@ -2,8 +2,10 @@ from ... import create_engine from ... import exc from ...testing.provision import create_db from ...testing.provision import drop_db +from ...testing.provision import get_temp_table_name from ...testing.provision import log from ...testing.provision import run_reap_dbs +from ...testing.provision import temp_table_keyword_args @create_db.for_db("mssql") @@ -72,3 +74,13 @@ def _reap_mssql_dbs(url, idents): log.info( "Dropped %d out of %d stale databases detected", dropped, total ) + + +@temp_table_keyword_args.for_db("mssql") +def _mssql_temp_table_keyword_args(cfg, eng): + return {} + + +@get_temp_table_name.for_db("mssql") +def _mssql_get_temp_table_name(cfg, eng, base_name): + return "#" + base_name diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py index 0edaae490..8bdad357c 100644 --- a/lib/sqlalchemy/testing/provision.py +++ b/lib/sqlalchemy/testing/provision.py @@ -296,3 +296,18 @@ def temp_table_keyword_args(cfg, eng): raise NotImplementedError( "no temp table keyword args routine for cfg: %s" % eng.url ) + + +@register.init +def get_temp_table_name(cfg, eng, base_name): + """Specify table name for creating a temporary Table. + + Dialect-specific implementations of this method will return the + name to use when creating a temporary table for testing, + e.g., in the define_temp_tables method of the + ComponentReflectionTest class in suite/test_reflection.py + + Default to just the base name since that's what most dialects will + use. The mssql dialect's implementation will need a "#" prepended. + """ + return base_name diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index 151be757a..94ec22c1e 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -8,6 +8,7 @@ from .. import eq_ from .. import expect_warnings from .. import fixtures from .. import is_ +from ..provision import get_temp_table_name from ..provision import temp_table_keyword_args from ..schema import Column from ..schema import Table @@ -442,8 +443,9 @@ class ComponentReflectionTest(fixtures.TablesTest): @classmethod def define_temp_tables(cls, metadata): kw = temp_table_keyword_args(config, config.db) + table_name = get_temp_table_name(config, config.db, "user_tmp") user_tmp = Table( - "user_tmp", + table_name, metadata, Column("id", sa.INT, primary_key=True), Column("name", sa.VARCHAR(50)), @@ -736,10 +738,11 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.temp_table_reflection def test_get_temp_table_columns(self): + table_name = get_temp_table_name(config, config.db, "user_tmp") meta = MetaData(self.bind) - user_tmp = self.tables.user_tmp + user_tmp = self.tables[table_name] insp = inspect(meta.bind) - cols = insp.get_columns("user_tmp") + cols = insp.get_columns(table_name) self.assert_(len(cols) > 0, len(cols)) for i, col in enumerate(user_tmp.columns): @@ -1051,10 +1054,11 @@ class ComponentReflectionTest(fixtures.TablesTest): refl.pop("duplicates_index", None) eq_(reflected, [{"column_names": ["name"], "name": "user_tmp_uq"}]) - @testing.requires.temp_table_reflection + @testing.requires.temp_table_reflect_indexes def test_get_temp_table_indexes(self): insp = inspect(self.bind) - indexes = insp.get_indexes("user_tmp") + table_name = get_temp_table_name(config, config.db, "user_tmp") + indexes = insp.get_indexes(table_name) for ind in indexes: ind.pop("dialect_options", None) eq_( |
