summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorGord Thompson <gord@gordthompson.com>2020-08-12 14:46:59 -0600
committerGord Thompson <gord@gordthompson.com>2020-09-01 08:05:51 -0600
commit516131c40da9c8cd304061850e2d98e309966dd5 (patch)
tree148c91095e3021de6881fc0d328980538d3fcea0 /lib/sqlalchemy
parent301c3f3579ace1ef1c28067904b57dd789620eae (diff)
downloadsqlalchemy-516131c40da9c8cd304061850e2d98e309966dd5.tar.gz
Improve reflection for mssql temporary tables
Fixes: #5506 Change-Id: I718474d76e3c630a1b71e07eaa20cefb104d11de
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py37
-rw-r--r--lib/sqlalchemy/dialects/mssql/information_schema.py22
-rw-r--r--lib/sqlalchemy/dialects/mssql/provision.py12
-rw-r--r--lib/sqlalchemy/testing/provision.py15
-rw-r--r--lib/sqlalchemy/testing/suite/test_reflection.py14
5 files changed, 91 insertions, 9 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index f38c537fd..ed17fb863 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -2913,11 +2913,46 @@ class MSDialect(default.DefaultDialect):
view_def = rp.scalar()
return view_def
+ def _get_internal_temp_table_name(self, connection, tablename):
+ result = connection.execute(
+ sql.text(
+ "select table_name "
+ "from tempdb.information_schema.tables "
+ "where table_name like :p1"
+ ),
+ {
+ "p1": tablename
+ + (("___%") if not tablename.startswith("##") else "")
+ },
+ ).fetchall()
+ if len(result) > 1:
+ raise exc.UnreflectableTableError(
+ "Found more than one temporary table named '%s' in tempdb "
+ "at this time. Cannot reliably resolve that name to its "
+ "internal table name." % tablename
+ )
+ elif len(result) == 0:
+ raise exc.NoSuchTableError(
+ "Unable to find a temporary table named '%s' in tempdb."
+ % tablename
+ )
+ else:
+ return result[0][0]
+
@reflection.cache
@_db_plus_owner
def get_columns(self, connection, tablename, dbname, owner, schema, **kw):
+ is_temp_table = tablename.startswith("#")
+ if is_temp_table:
+ tablename = self._get_internal_temp_table_name(
+ connection, tablename
+ )
# Get base columns
- columns = ischema.columns
+ columns = (
+ ischema.mssql_temp_table_columns
+ if is_temp_table
+ else ischema.columns
+ )
computed_cols = ischema.computed_columns
if owner:
whereclause = sql.and_(
diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py
index 6cdde8386..f80110b7d 100644
--- a/lib/sqlalchemy/dialects/mssql/information_schema.py
+++ b/lib/sqlalchemy/dialects/mssql/information_schema.py
@@ -5,9 +5,6 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-# TODO: should be using the sys. catalog with SQL Server, not information
-# schema
-
from ... import cast
from ... import Column
from ... import MetaData
@@ -93,6 +90,25 @@ columns = Table(
schema="INFORMATION_SCHEMA",
)
+mssql_temp_table_columns = Table(
+ "COLUMNS",
+ ischema,
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
+ Column("IS_NULLABLE", Integer, key="is_nullable"),
+ Column("DATA_TYPE", String, key="data_type"),
+ Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
+ Column(
+ "CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"
+ ),
+ Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
+ Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
+ Column("COLUMN_DEFAULT", Integer, key="column_default"),
+ Column("COLLATION_NAME", String, key="collation_name"),
+ schema="tempdb.INFORMATION_SCHEMA",
+)
+
constraints = Table(
"TABLE_CONSTRAINTS",
ischema,
diff --git a/lib/sqlalchemy/dialects/mssql/provision.py b/lib/sqlalchemy/dialects/mssql/provision.py
index a5131eae6..269eb164f 100644
--- a/lib/sqlalchemy/dialects/mssql/provision.py
+++ b/lib/sqlalchemy/dialects/mssql/provision.py
@@ -2,8 +2,10 @@ from ... import create_engine
from ... import exc
from ...testing.provision import create_db
from ...testing.provision import drop_db
+from ...testing.provision import get_temp_table_name
from ...testing.provision import log
from ...testing.provision import run_reap_dbs
+from ...testing.provision import temp_table_keyword_args
@create_db.for_db("mssql")
@@ -72,3 +74,13 @@ def _reap_mssql_dbs(url, idents):
log.info(
"Dropped %d out of %d stale databases detected", dropped, total
)
+
+
+@temp_table_keyword_args.for_db("mssql")
+def _mssql_temp_table_keyword_args(cfg, eng):
+ return {}
+
+
+@get_temp_table_name.for_db("mssql")
+def _mssql_get_temp_table_name(cfg, eng, base_name):
+ return "#" + base_name
diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py
index 0edaae490..8bdad357c 100644
--- a/lib/sqlalchemy/testing/provision.py
+++ b/lib/sqlalchemy/testing/provision.py
@@ -296,3 +296,18 @@ def temp_table_keyword_args(cfg, eng):
raise NotImplementedError(
"no temp table keyword args routine for cfg: %s" % eng.url
)
+
+
+@register.init
+def get_temp_table_name(cfg, eng, base_name):
+ """Specify table name for creating a temporary Table.
+
+ Dialect-specific implementations of this method will return the
+ name to use when creating a temporary table for testing,
+ e.g., in the define_temp_tables method of the
+ ComponentReflectionTest class in suite/test_reflection.py
+
+ Default to just the base name since that's what most dialects will
+ use. The mssql dialect's implementation will need a "#" prepended.
+ """
+ return base_name
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py
index 151be757a..94ec22c1e 100644
--- a/lib/sqlalchemy/testing/suite/test_reflection.py
+++ b/lib/sqlalchemy/testing/suite/test_reflection.py
@@ -8,6 +8,7 @@ from .. import eq_
from .. import expect_warnings
from .. import fixtures
from .. import is_
+from ..provision import get_temp_table_name
from ..provision import temp_table_keyword_args
from ..schema import Column
from ..schema import Table
@@ -442,8 +443,9 @@ class ComponentReflectionTest(fixtures.TablesTest):
@classmethod
def define_temp_tables(cls, metadata):
kw = temp_table_keyword_args(config, config.db)
+ table_name = get_temp_table_name(config, config.db, "user_tmp")
user_tmp = Table(
- "user_tmp",
+ table_name,
metadata,
Column("id", sa.INT, primary_key=True),
Column("name", sa.VARCHAR(50)),
@@ -736,10 +738,11 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.temp_table_reflection
def test_get_temp_table_columns(self):
+ table_name = get_temp_table_name(config, config.db, "user_tmp")
meta = MetaData(self.bind)
- user_tmp = self.tables.user_tmp
+ user_tmp = self.tables[table_name]
insp = inspect(meta.bind)
- cols = insp.get_columns("user_tmp")
+ cols = insp.get_columns(table_name)
self.assert_(len(cols) > 0, len(cols))
for i, col in enumerate(user_tmp.columns):
@@ -1051,10 +1054,11 @@ class ComponentReflectionTest(fixtures.TablesTest):
refl.pop("duplicates_index", None)
eq_(reflected, [{"column_names": ["name"], "name": "user_tmp_uq"}])
- @testing.requires.temp_table_reflection
+ @testing.requires.temp_table_reflect_indexes
def test_get_temp_table_indexes(self):
insp = inspect(self.bind)
- indexes = insp.get_indexes("user_tmp")
+ table_name = get_temp_table_name(config, config.db, "user_tmp")
+ indexes = insp.get_indexes(table_name)
for ind in indexes:
ind.pop("dialect_options", None)
eq_(