diff options
| author | Gord Thompson <gord@gordthompson.com> | 2021-08-20 09:50:10 -0600 |
|---|---|---|
| committer | Gord Thompson <gord@gordthompson.com> | 2021-08-23 11:01:10 -0600 |
| commit | 207ec35c2e175f5fcf68e886d5e61a0678c2d6fc (patch) | |
| tree | 290f9a750802bacf4b38308389abf22ebfe2ae36 /lib/sqlalchemy/dialects | |
| parent | 5219a0e25ba733faf8f35775681bf3d57016267e (diff) | |
| download | sqlalchemy-207ec35c2e175f5fcf68e886d5e61a0678c2d6fc.tar.gz | |
Fix has_table() to exclude other sessions' local temp tables
Fixes: #6910
Change-Id: I9986566e1195d42ad7e9a01f0f84ef2074576257
Diffstat (limited to 'lib/sqlalchemy/dialects')
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 20 |
1 files changed, 17 insertions, 3 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index c11166735..8607edeca 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -758,6 +758,7 @@ from ... import Identity from ... import schema as sa_schema from ... import Sequence from ... import sql +from ... import text from ... import types as sqltypes from ... import util from ...engine import cursor as _cursor @@ -2823,8 +2824,16 @@ class MSDialect(default.DefaultDialect): ) ) - result = connection.execute(s.limit(1)) - return result.scalar() is not None + table_name = connection.execute(s.limit(1)).scalar() + if table_name: + # #6910: verify it's not a temp table from another session + obj_id = connection.execute( + text("SELECT object_id(:table_name)"), + {"table_name": "tempdb.dbo.[{}]".format(table_name)}, + ).scalar() + return bool(obj_id) + else: + return False else: tables = ischema.tables @@ -3009,7 +3018,12 @@ class MSDialect(default.DefaultDialect): return view_def def _temp_table_name_like_pattern(self, tablename): - return tablename + (("___%") if not tablename.startswith("##") else "") + # LIKE uses '%' to match zero or more characters and '_' to match any + # single character. We want to match literal underscores, so T-SQL + # requires that we enclose them in square brackets. + return tablename + ( + ("[_][_][_]%") if not tablename.startswith("##") else "" + ) def _get_internal_temp_table_name(self, connection, tablename): # it's likely that schema is always "dbo", but since we can |
