summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-05-25 08:47:29 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-05-25 10:07:31 -0400
commita5d481eaa5bff958692fc3b0024f0b9b1c4f56c6 (patch)
treee1a53b2d9ee12ac5d49c14eb85483c0a118745c7
parentd7b131d2dfc4c519b23d9ed29364036ef88b1863 (diff)
downloadsqlalchemy-a5d481eaa5bff958692fc3b0024f0b9b1c4f56c6.tar.gz
apply bindparam escape name to processors dictionary
Fixed SQL compiler issue where the "bind processing" function for a bound parameter would not be correctly applied to a bound value if the bound parameter's name were "escaped". Concretely, this applies, among other cases, to Oracle when a :class:`.Column` has a name that itself requires quoting, such that the quoting-required name is then used for the bound parameters generated within DML statements, and the datatype in use requires bind processing, such as the :class:`.Enum` datatype. Fixes: #8053 Change-Id: I39d060a87e240b4ebcfccaa9c535e971b7255d99
-rw-r--r--doc/build/changelog/unreleased_14/8053.rst11
-rw-r--r--lib/sqlalchemy/sql/compiler.py10
-rw-r--r--test/dialect/oracle/test_dialect.py19
-rw-r--r--test/ext/mypy/plain_files/sql_operations.py2
-rw-r--r--test/sql/test_compiler.py42
5 files changed, 82 insertions, 2 deletions
diff --git a/doc/build/changelog/unreleased_14/8053.rst b/doc/build/changelog/unreleased_14/8053.rst
new file mode 100644
index 000000000..316b63859
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/8053.rst
@@ -0,0 +1,11 @@
+.. change::
+ :tags: bug, oracle
+ :tickets: 8053
+
+ Fixed SQL compiler issue where the "bind processing" function for a bound
+ parameter would not be correctly applied to a bound value if the bound
+ parameter's name were "escaped". Concretely, this applies, among other
+ cases, to Oracle when a :class:`.Column` has a name that itself requires
+ quoting, such that the quoting-required name is then used for the bound
+ parameters generated within DML statements, and the datatype in use
+ requires bind processing, such as the :class:`.Enum` datatype.
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 0eae31a1a..63ed45a96 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1143,10 +1143,18 @@ class SQLCompiler(Compiled):
str, Union[_BindProcessorType[Any], Sequence[_BindProcessorType[Any]]]
]:
+ _escaped_bind_names = self.escaped_bind_names
+ has_escaped_names = bool(_escaped_bind_names)
+
# mypy is not able to see the two value types as the above Union,
# it just sees "object". don't know how to resolve
return dict(
- (key, value) # type: ignore
+ (
+ _escaped_bind_names.get(key, key)
+ if has_escaped_names
+ else key,
+ value,
+ ) # type: ignore
for key, value in (
(
self.bind_names[bindparam],
diff --git a/test/dialect/oracle/test_dialect.py b/test/dialect/oracle/test_dialect.py
index 26a29b73e..8d74c1f48 100644
--- a/test/dialect/oracle/test_dialect.py
+++ b/test/dialect/oracle/test_dialect.py
@@ -7,6 +7,7 @@ from unittest.mock import Mock
from sqlalchemy import bindparam
from sqlalchemy import Computed
from sqlalchemy import create_engine
+from sqlalchemy import Enum
from sqlalchemy import exc
from sqlalchemy import Float
from sqlalchemy import func
@@ -33,6 +34,7 @@ from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.assertions import expect_raises_message
from sqlalchemy.testing.schema import Column
+from sqlalchemy.testing.schema import pep435_enum
from sqlalchemy.testing.schema import Table
from sqlalchemy.testing.suite import test_select
@@ -527,6 +529,23 @@ class QuotedBindRoundTripTest(fixtures.TestBase):
4,
)
+ def test_param_w_processors(self, metadata, connection):
+ """test #8053"""
+
+ SomeEnum = pep435_enum("SomeEnum")
+ one = SomeEnum("one", 1)
+ SomeEnum("two", 2)
+
+ t = Table(
+ "t",
+ metadata,
+ Column("_id", Integer, primary_key=True),
+ Column("_data", Enum(SomeEnum)),
+ )
+ t.create(connection)
+ connection.execute(t.insert(), {"_id": 1, "_data": one})
+ eq_(connection.scalar(select(t.c._data)), one)
+
def test_numeric_bind_in_crud(self, metadata, connection):
t = Table("asfd", metadata, Column("100K", Integer))
t.create(connection)
diff --git a/test/ext/mypy/plain_files/sql_operations.py b/test/ext/mypy/plain_files/sql_operations.py
index b4d0bd006..0ed0df661 100644
--- a/test/ext/mypy/plain_files/sql_operations.py
+++ b/test/ext/mypy/plain_files/sql_operations.py
@@ -56,7 +56,7 @@ if typing.TYPE_CHECKING:
# EXPECTED_RE_TYPE: sqlalchemy..*BinaryExpression\[builtins.bool\]
reveal_type(expr2)
- # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[Union\[builtins.float, decimal.Decimal\]\]
+ # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[Union\[builtins.float, .*\.Decimal\]\]
reveal_type(expr3)
# EXPECTED_RE_TYPE: sqlalchemy..*UnaryExpression\[builtins.int.?\]
diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py
index 4e8e2ac13..4e40ae0a2 100644
--- a/test/sql/test_compiler.py
+++ b/test/sql/test_compiler.py
@@ -25,6 +25,7 @@ from sqlalchemy import Column
from sqlalchemy import Date
from sqlalchemy import desc
from sqlalchemy import distinct
+from sqlalchemy import Enum
from sqlalchemy import exc
from sqlalchemy import except_
from sqlalchemy import exists
@@ -96,6 +97,7 @@ from sqlalchemy.testing import is_
from sqlalchemy.testing import is_true
from sqlalchemy.testing import mock
from sqlalchemy.testing import ne_
+from sqlalchemy.testing.schema import pep435_enum
table1 = table(
"mytable",
@@ -3745,6 +3747,46 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase):
s,
)
+ def test_bind_param_escaping(self):
+ """general bind param escape unit tests added as a result of
+ #8053
+ #
+ #"""
+
+ SomeEnum = pep435_enum("SomeEnum")
+ one = SomeEnum("one", 1)
+ SomeEnum("two", 2)
+
+ t = Table(
+ "t",
+ MetaData(),
+ Column("_id", Integer, primary_key=True),
+ Column("_data", Enum(SomeEnum)),
+ )
+
+ class MyCompiler(compiler.SQLCompiler):
+ def bindparam_string(self, name, **kw):
+ kw["escaped_from"] = name
+ return super(MyCompiler, self).bindparam_string(
+ '"%s"' % name, **kw
+ )
+
+ dialect = default.DefaultDialect()
+ dialect.statement_compiler = MyCompiler
+
+ self.assert_compile(
+ t.insert(),
+ 'INSERT INTO t (_id, _data) VALUES (:"_id", :"_data")',
+ dialect=dialect,
+ )
+
+ compiled = t.insert().compile(
+ dialect=dialect, compile_kwargs=dict(compile_keys=("_id", "_data"))
+ )
+ params = compiled.construct_params({"_id": 1, "_data": one})
+ eq_(params, {'"_id"': 1, '"_data"': one})
+ eq_(compiled._bind_processors, {'"_data"': mock.ANY})
+
def test_expanding_non_expanding_conflict(self):
"""test #8018"""