summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2021-02-08 11:58:15 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2021-02-08 13:16:02 -0500
commitb348e82dcf5eca1fc8496c941dc1ac2ffd60eed9 (patch)
treef95e6a111f14b7517a3cb0b0e7b42b9a6487217a /lib/sqlalchemy
parent146efafcb436660e7891d3b34d05cd794c45268d (diff)
downloadsqlalchemy-b348e82dcf5eca1fc8496c941dc1ac2ffd60eed9.tar.gz
Add identifier_preparer per-execution context for schema translates
Fixed bug where the "schema_translate_map" feature failed to be taken into account for the use case of direct execution of :class:`_schema.DefaultGenerator` objects such as sequences, which included the case where they were "pre-executed" in order to generate primary key values when implicit_returning was disabled. Fixes: #5929 Change-Id: I3fed1d0af28be5ce9c9bb572524dcc8411633f60 (cherry picked from commit 2385ebb19366efeb35415298166ac18668864c51)
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/firebird/base.py2
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py6
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py5
-rw-r--r--lib/sqlalchemy/dialects/oracle/base.py6
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py2
-rw-r--r--lib/sqlalchemy/engine/default.py11
-rw-r--r--lib/sqlalchemy/testing/suite/test_sequence.py68
-rw-r--r--lib/sqlalchemy/util/compat.py3
8 files changed, 94 insertions, 9 deletions
diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py
index 28fefa5b7..9138a81a9 100644
--- a/lib/sqlalchemy/dialects/firebird/base.py
+++ b/lib/sqlalchemy/dialects/firebird/base.py
@@ -612,7 +612,7 @@ class FBExecutionContext(default.DefaultExecutionContext):
return self._execute_scalar(
"SELECT gen_id(%s, 1) FROM rdb$database"
- % self.dialect.identifier_preparer.format_sequence(seq),
+ % self.identifier_preparer.format_sequence(seq),
type_,
)
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index 22f329730..debfa55b1 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -1495,7 +1495,7 @@ class MSExecutionContext(default.DefaultExecutionContext):
self.cursor,
self._opt_encode(
"SET IDENTITY_INSERT %s ON"
- % self.dialect.identifier_preparer.format_table(tbl)
+ % self.identifier_preparer.format_table(tbl)
),
(),
self,
@@ -1531,7 +1531,7 @@ class MSExecutionContext(default.DefaultExecutionContext):
self.cursor,
self._opt_encode(
"SET IDENTITY_INSERT %s OFF"
- % self.dialect.identifier_preparer.format_table(
+ % self.identifier_preparer.format_table(
self.compiled.statement.table
)
),
@@ -1548,7 +1548,7 @@ class MSExecutionContext(default.DefaultExecutionContext):
self.cursor.execute(
self._opt_encode(
"SET IDENTITY_INSERT %s OFF"
- % self.dialect.identifier_preparer.format_table(
+ % self.identifier_preparer.format_table(
self.compiled.statement.table
)
)
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index c41d6acf7..47e4dff94 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -901,9 +901,13 @@ from ...types import BLOB
from ...types import BOOLEAN
from ...types import DATE
from ...types import VARBINARY
+from ...util import compat
from ...util import topological
+if compat.TYPE_CHECKING:
+ from typing import Any
+
RESERVED_WORDS = set(
[
"accessible",
@@ -3053,6 +3057,7 @@ class MySQLDialect(default.DefaultDialect):
return parser.parse(sql, charset)
def _detect_charset(self, connection):
+ # type: (Any) -> str
raise NotImplementedError()
def _detect_casing(self, connection):
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index c476554bd..c62116572 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -958,9 +958,7 @@ class OracleCompiler(compiler.SQLCompiler):
return self.process(vc.column, **kw) + "(+)"
def visit_sequence(self, seq, **kw):
- return (
- self.dialect.identifier_preparer.format_sequence(seq) + ".nextval"
- )
+ return self.preparer.format_sequence(seq) + ".nextval"
def get_render_as_alias_suffix(self, alias_name_text):
"""Oracle doesn't like ``FROM table AS alias``"""
@@ -1281,7 +1279,7 @@ class OracleExecutionContext(default.DefaultExecutionContext):
def fire_sequence(self, seq, type_):
return self._execute_scalar(
"SELECT "
- + self.dialect.identifier_preparer.format_sequence(seq)
+ + self.identifier_preparer.format_sequence(seq)
+ ".nextval FROM DUAL",
type_,
)
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 39e11aa61..7dec6d818 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -2487,7 +2487,7 @@ class PGExecutionContext(default.DefaultExecutionContext):
return self._execute_scalar(
(
"select nextval('%s')"
- % self.dialect.identifier_preparer.format_sequence(seq)
+ % self.identifier_preparer.format_sequence(seq)
),
type_,
)
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 1c0a87b4a..59eac7e0d 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -1057,6 +1057,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
return self
@util.memoized_property
+ def identifier_preparer(self):
+ if self.compiled:
+ return self.compiled.preparer
+ elif "schema_translate_map" in self.execution_options:
+ return self.dialect.identifier_preparer._with_schema_translate(
+ self.execution_options["schema_translate_map"]
+ )
+ else:
+ return self.dialect.identifier_preparer
+
+ @util.memoized_property
def engine(self):
return self.root_connection.engine
diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py
index 22ae7d43c..6c80f9487 100644
--- a/lib/sqlalchemy/testing/suite/test_sequence.py
+++ b/lib/sqlalchemy/testing/suite/test_sequence.py
@@ -39,6 +39,34 @@ class SequenceTest(fixtures.TablesTest):
Column("data", String(50)),
)
+ Table(
+ "seq_no_returning",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Sequence("noret_id_seq"),
+ primary_key=True,
+ ),
+ Column("data", String(50)),
+ implicit_returning=False,
+ )
+
+ if testing.requires.schemas.enabled:
+ Table(
+ "seq_no_returning_sch",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Sequence("noret_sch_id_seq", schema=config.test_schema),
+ primary_key=True,
+ ),
+ Column("data", String(50)),
+ implicit_returning=False,
+ schema=config.test_schema,
+ )
+
def test_insert_roundtrip(self):
config.db.execute(self.tables.seq_pk.insert(), data="some data")
self._assert_round_trip(self.tables.seq_pk, config.db)
@@ -62,6 +90,46 @@ class SequenceTest(fixtures.TablesTest):
row = conn.execute(table.select()).first()
eq_(row, (1, "some data"))
+ def test_insert_roundtrip_no_implicit_returning(self, connection):
+ connection.execute(
+ self.tables.seq_no_returning.insert(), dict(data="some data")
+ )
+ self._assert_round_trip(self.tables.seq_no_returning, connection)
+
+ @testing.combinations((True,), (False,), argnames="implicit_returning")
+ @testing.requires.schemas
+ def test_insert_roundtrip_translate(self, connection, implicit_returning):
+
+ seq_no_returning = Table(
+ "seq_no_returning_sch",
+ MetaData(),
+ Column(
+ "id",
+ Integer,
+ Sequence("noret_sch_id_seq", schema="alt_schema"),
+ primary_key=True,
+ ),
+ Column("data", String(50)),
+ implicit_returning=implicit_returning,
+ schema="alt_schema",
+ )
+
+ connection = connection.execution_options(
+ schema_translate_map={"alt_schema": config.test_schema}
+ )
+ connection.execute(seq_no_returning.insert(), dict(data="some data"))
+ self._assert_round_trip(seq_no_returning, connection)
+
+ @testing.requires.schemas
+ def test_nextval_direct_schema_translate(self, connection):
+ seq = Sequence("noret_sch_id_seq", schema="alt_schema")
+ connection = connection.execution_options(
+ schema_translate_map={"alt_schema": config.test_schema}
+ )
+
+ r = connection.execute(seq)
+ eq_(r, testing.db.dialect.default_sequence_base)
+
class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase):
__requires__ = ("sequences",)
diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py
index aed39366d..a1d55376d 100644
--- a/lib/sqlalchemy/util/compat.py
+++ b/lib/sqlalchemy/util/compat.py
@@ -186,6 +186,8 @@ if py3k:
# as the __traceback__ object creates a cycle
del exception, replace_context, from_, with_traceback
+ from typing import TYPE_CHECKING
+
def u(s):
return s
@@ -299,6 +301,7 @@ else:
" raise exception\n"
)
+ TYPE_CHECKING = False
if py35: