summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects
diff options
context:
space:
mode:
authorGord Thompson <gord@gordthompson.com>2021-08-20 09:50:10 -0600
committerGord Thompson <gord@gordthompson.com>2021-08-23 11:01:10 -0600
commit207ec35c2e175f5fcf68e886d5e61a0678c2d6fc (patch)
tree290f9a750802bacf4b38308389abf22ebfe2ae36 /lib/sqlalchemy/dialects
parent5219a0e25ba733faf8f35775681bf3d57016267e (diff)
downloadsqlalchemy-207ec35c2e175f5fcf68e886d5e61a0678c2d6fc.tar.gz
Fix has_table() to exclude other sessions' local temp tables
Fixes: #6910 Change-Id: I9986566e1195d42ad7e9a01f0f84ef2074576257
Diffstat (limited to 'lib/sqlalchemy/dialects')
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py20
1 files changed, 17 insertions, 3 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index c11166735..8607edeca 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -758,6 +758,7 @@ from ... import Identity
from ... import schema as sa_schema
from ... import Sequence
from ... import sql
+from ... import text
from ... import types as sqltypes
from ... import util
from ...engine import cursor as _cursor
@@ -2823,8 +2824,16 @@ class MSDialect(default.DefaultDialect):
)
)
- result = connection.execute(s.limit(1))
- return result.scalar() is not None
+ table_name = connection.execute(s.limit(1)).scalar()
+ if table_name:
+ # #6910: verify it's not a temp table from another session
+ obj_id = connection.execute(
+ text("SELECT object_id(:table_name)"),
+ {"table_name": "tempdb.dbo.[{}]".format(table_name)},
+ ).scalar()
+ return bool(obj_id)
+ else:
+ return False
else:
tables = ischema.tables
@@ -3009,7 +3018,12 @@ class MSDialect(default.DefaultDialect):
return view_def
def _temp_table_name_like_pattern(self, tablename):
- return tablename + (("___%") if not tablename.startswith("##") else "")
+ # LIKE uses '%' to match zero or more characters and '_' to match any
+ # single character. We want to match literal underscores, so T-SQL
+ # requires that we enclose them in square brackets.
+ return tablename + (
+ ("[_][_][_]%") if not tablename.startswith("##") else ""
+ )
def _get_internal_temp_table_name(self, connection, tablename):
# it's likely that schema is always "dbo", but since we can