diff options
| author | Gord Thompson <gord@gordthompson.com> | 2020-08-12 14:46:59 -0600 |
|---|---|---|
| committer | Gord Thompson <gord@gordthompson.com> | 2020-09-01 08:05:51 -0600 |
| commit | 516131c40da9c8cd304061850e2d98e309966dd5 (patch) | |
| tree | 148c91095e3021de6881fc0d328980538d3fcea0 /lib/sqlalchemy/dialects | |
| parent | 301c3f3579ace1ef1c28067904b57dd789620eae (diff) | |
| download | sqlalchemy-516131c40da9c8cd304061850e2d98e309966dd5.tar.gz | |
Improve reflection for mssql temporary tables
Fixes: #5506
Change-Id: I718474d76e3c630a1b71e07eaa20cefb104d11de
Diffstat (limited to 'lib/sqlalchemy/dialects')
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 37 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/information_schema.py | 22 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/provision.py | 12 |
3 files changed, 67 insertions, 4 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index f38c537fd..ed17fb863 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2913,11 +2913,46 @@ class MSDialect(default.DefaultDialect): view_def = rp.scalar() return view_def + def _get_internal_temp_table_name(self, connection, tablename): + result = connection.execute( + sql.text( + "select table_name " + "from tempdb.information_schema.tables " + "where table_name like :p1" + ), + { + "p1": tablename + + (("___%") if not tablename.startswith("##") else "") + }, + ).fetchall() + if len(result) > 1: + raise exc.UnreflectableTableError( + "Found more than one temporary table named '%s' in tempdb " + "at this time. Cannot reliably resolve that name to its " + "internal table name." % tablename + ) + elif len(result) == 0: + raise exc.NoSuchTableError( + "Unable to find a temporary table named '%s' in tempdb." + % tablename + ) + else: + return result[0][0] + @reflection.cache @_db_plus_owner def get_columns(self, connection, tablename, dbname, owner, schema, **kw): + is_temp_table = tablename.startswith("#") + if is_temp_table: + tablename = self._get_internal_temp_table_name( + connection, tablename + ) # Get base columns - columns = ischema.columns + columns = ( + ischema.mssql_temp_table_columns + if is_temp_table + else ischema.columns + ) computed_cols = ischema.computed_columns if owner: whereclause = sql.and_( diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py index 6cdde8386..f80110b7d 100644 --- a/lib/sqlalchemy/dialects/mssql/information_schema.py +++ b/lib/sqlalchemy/dialects/mssql/information_schema.py @@ -5,9 +5,6 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -# TODO: should be using the sys. catalog with SQL Server, not information -# schema - from ... import cast from ... import Column from ... import MetaData @@ -93,6 +90,25 @@ columns = Table( schema="INFORMATION_SCHEMA", ) +mssql_temp_table_columns = Table( + "COLUMNS", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("IS_NULLABLE", Integer, key="is_nullable"), + Column("DATA_TYPE", String, key="data_type"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + Column( + "CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length" + ), + Column("NUMERIC_PRECISION", Integer, key="numeric_precision"), + Column("NUMERIC_SCALE", Integer, key="numeric_scale"), + Column("COLUMN_DEFAULT", Integer, key="column_default"), + Column("COLLATION_NAME", String, key="collation_name"), + schema="tempdb.INFORMATION_SCHEMA", +) + constraints = Table( "TABLE_CONSTRAINTS", ischema, diff --git a/lib/sqlalchemy/dialects/mssql/provision.py b/lib/sqlalchemy/dialects/mssql/provision.py index a5131eae6..269eb164f 100644 --- a/lib/sqlalchemy/dialects/mssql/provision.py +++ b/lib/sqlalchemy/dialects/mssql/provision.py @@ -2,8 +2,10 @@ from ... import create_engine from ... import exc from ...testing.provision import create_db from ...testing.provision import drop_db +from ...testing.provision import get_temp_table_name from ...testing.provision import log from ...testing.provision import run_reap_dbs +from ...testing.provision import temp_table_keyword_args @create_db.for_db("mssql") @@ -72,3 +74,13 @@ def _reap_mssql_dbs(url, idents): log.info( "Dropped %d out of %d stale databases detected", dropped, total ) + + +@temp_table_keyword_args.for_db("mssql") +def _mssql_temp_table_keyword_args(cfg, eng): + return {} + + +@get_temp_table_name.for_db("mssql") +def _mssql_get_temp_table_name(cfg, eng, base_name): + return "#" + base_name |
