summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects
diff options
context:
space:
mode:
authorGord Thompson <gord@gordthompson.com>2020-08-12 14:46:59 -0600
committerGord Thompson <gord@gordthompson.com>2020-09-01 08:05:51 -0600
commit516131c40da9c8cd304061850e2d98e309966dd5 (patch)
tree148c91095e3021de6881fc0d328980538d3fcea0 /lib/sqlalchemy/dialects
parent301c3f3579ace1ef1c28067904b57dd789620eae (diff)
downloadsqlalchemy-516131c40da9c8cd304061850e2d98e309966dd5.tar.gz
Improve reflection for mssql temporary tables
Fixes: #5506 Change-Id: I718474d76e3c630a1b71e07eaa20cefb104d11de
Diffstat (limited to 'lib/sqlalchemy/dialects')
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py37
-rw-r--r--lib/sqlalchemy/dialects/mssql/information_schema.py22
-rw-r--r--lib/sqlalchemy/dialects/mssql/provision.py12
3 files changed, 67 insertions, 4 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index f38c537fd..ed17fb863 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -2913,11 +2913,46 @@ class MSDialect(default.DefaultDialect):
view_def = rp.scalar()
return view_def
+ def _get_internal_temp_table_name(self, connection, tablename):
+ result = connection.execute(
+ sql.text(
+ "select table_name "
+ "from tempdb.information_schema.tables "
+ "where table_name like :p1"
+ ),
+ {
+ "p1": tablename
+ + (("___%") if not tablename.startswith("##") else "")
+ },
+ ).fetchall()
+ if len(result) > 1:
+ raise exc.UnreflectableTableError(
+ "Found more than one temporary table named '%s' in tempdb "
+ "at this time. Cannot reliably resolve that name to its "
+ "internal table name." % tablename
+ )
+ elif len(result) == 0:
+ raise exc.NoSuchTableError(
+ "Unable to find a temporary table named '%s' in tempdb."
+ % tablename
+ )
+ else:
+ return result[0][0]
+
@reflection.cache
@_db_plus_owner
def get_columns(self, connection, tablename, dbname, owner, schema, **kw):
+ is_temp_table = tablename.startswith("#")
+ if is_temp_table:
+ tablename = self._get_internal_temp_table_name(
+ connection, tablename
+ )
# Get base columns
- columns = ischema.columns
+ columns = (
+ ischema.mssql_temp_table_columns
+ if is_temp_table
+ else ischema.columns
+ )
computed_cols = ischema.computed_columns
if owner:
whereclause = sql.and_(
diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py
index 6cdde8386..f80110b7d 100644
--- a/lib/sqlalchemy/dialects/mssql/information_schema.py
+++ b/lib/sqlalchemy/dialects/mssql/information_schema.py
@@ -5,9 +5,6 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-# TODO: should be using the sys. catalog with SQL Server, not information
-# schema
-
from ... import cast
from ... import Column
from ... import MetaData
@@ -93,6 +90,25 @@ columns = Table(
schema="INFORMATION_SCHEMA",
)
+mssql_temp_table_columns = Table(
+ "COLUMNS",
+ ischema,
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
+ Column("IS_NULLABLE", Integer, key="is_nullable"),
+ Column("DATA_TYPE", String, key="data_type"),
+ Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
+ Column(
+ "CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"
+ ),
+ Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
+ Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
+ Column("COLUMN_DEFAULT", Integer, key="column_default"),
+ Column("COLLATION_NAME", String, key="collation_name"),
+ schema="tempdb.INFORMATION_SCHEMA",
+)
+
constraints = Table(
"TABLE_CONSTRAINTS",
ischema,
diff --git a/lib/sqlalchemy/dialects/mssql/provision.py b/lib/sqlalchemy/dialects/mssql/provision.py
index a5131eae6..269eb164f 100644
--- a/lib/sqlalchemy/dialects/mssql/provision.py
+++ b/lib/sqlalchemy/dialects/mssql/provision.py
@@ -2,8 +2,10 @@ from ... import create_engine
from ... import exc
from ...testing.provision import create_db
from ...testing.provision import drop_db
+from ...testing.provision import get_temp_table_name
from ...testing.provision import log
from ...testing.provision import run_reap_dbs
+from ...testing.provision import temp_table_keyword_args
@create_db.for_db("mssql")
@@ -72,3 +74,13 @@ def _reap_mssql_dbs(url, idents):
log.info(
"Dropped %d out of %d stale databases detected", dropped, total
)
+
+
+@temp_table_keyword_args.for_db("mssql")
+def _mssql_temp_table_keyword_args(cfg, eng):
+ return {}
+
+
+@get_temp_table_name.for_db("mssql")
+def _mssql_get_temp_table_name(cfg, eng, base_name):
+ return "#" + base_name