diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-05-25 08:47:29 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-05-25 10:07:31 -0400 |
| commit | a5d481eaa5bff958692fc3b0024f0b9b1c4f56c6 (patch) | |
| tree | e1a53b2d9ee12ac5d49c14eb85483c0a118745c7 | |
| parent | d7b131d2dfc4c519b23d9ed29364036ef88b1863 (diff) | |
| download | sqlalchemy-a5d481eaa5bff958692fc3b0024f0b9b1c4f56c6.tar.gz | |
apply bindparam escape name to processors dictionary
Fixed SQL compiler issue where the "bind processing" function for a bound
parameter would not be correctly applied to a bound value if the bound
parameter's name were "escaped". Concretely, this applies, among other
cases, to Oracle when a :class:`.Column` has a name that itself requires
quoting, such that the quoting-required name is then used for the bound
parameters generated within DML statements, and the datatype in use
requires bind processing, such as the :class:`.Enum` datatype.
Fixes: #8053
Change-Id: I39d060a87e240b4ebcfccaa9c535e971b7255d99
| -rw-r--r-- | doc/build/changelog/unreleased_14/8053.rst | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 10 | ||||
| -rw-r--r-- | test/dialect/oracle/test_dialect.py | 19 | ||||
| -rw-r--r-- | test/ext/mypy/plain_files/sql_operations.py | 2 | ||||
| -rw-r--r-- | test/sql/test_compiler.py | 42 |
5 files changed, 82 insertions, 2 deletions
diff --git a/doc/build/changelog/unreleased_14/8053.rst b/doc/build/changelog/unreleased_14/8053.rst new file mode 100644 index 000000000..316b63859 --- /dev/null +++ b/doc/build/changelog/unreleased_14/8053.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, oracle + :tickets: 8053 + + Fixed SQL compiler issue where the "bind processing" function for a bound + parameter would not be correctly applied to a bound value if the bound + parameter's name were "escaped". Concretely, this applies, among other + cases, to Oracle when a :class:`.Column` has a name that itself requires + quoting, such that the quoting-required name is then used for the bound + parameters generated within DML statements, and the datatype in use + requires bind processing, such as the :class:`.Enum` datatype. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 0eae31a1a..63ed45a96 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1143,10 +1143,18 @@ class SQLCompiler(Compiled): str, Union[_BindProcessorType[Any], Sequence[_BindProcessorType[Any]]] ]: + _escaped_bind_names = self.escaped_bind_names + has_escaped_names = bool(_escaped_bind_names) + # mypy is not able to see the two value types as the above Union, # it just sees "object". don't know how to resolve return dict( - (key, value) # type: ignore + ( + _escaped_bind_names.get(key, key) + if has_escaped_names + else key, + value, + ) # type: ignore for key, value in ( ( self.bind_names[bindparam], diff --git a/test/dialect/oracle/test_dialect.py b/test/dialect/oracle/test_dialect.py index 26a29b73e..8d74c1f48 100644 --- a/test/dialect/oracle/test_dialect.py +++ b/test/dialect/oracle/test_dialect.py @@ -7,6 +7,7 @@ from unittest.mock import Mock from sqlalchemy import bindparam from sqlalchemy import Computed from sqlalchemy import create_engine +from sqlalchemy import Enum from sqlalchemy import exc from sqlalchemy import Float from sqlalchemy import func @@ -33,6 +34,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertions import expect_raises_message from sqlalchemy.testing.schema import Column +from sqlalchemy.testing.schema import pep435_enum from sqlalchemy.testing.schema import Table from sqlalchemy.testing.suite import test_select @@ -527,6 +529,23 @@ class QuotedBindRoundTripTest(fixtures.TestBase): 4, ) + def test_param_w_processors(self, metadata, connection): + """test #8053""" + + SomeEnum = pep435_enum("SomeEnum") + one = SomeEnum("one", 1) + SomeEnum("two", 2) + + t = Table( + "t", + metadata, + Column("_id", Integer, primary_key=True), + Column("_data", Enum(SomeEnum)), + ) + t.create(connection) + connection.execute(t.insert(), {"_id": 1, "_data": one}) + eq_(connection.scalar(select(t.c._data)), one) + def test_numeric_bind_in_crud(self, metadata, connection): t = Table("asfd", metadata, Column("100K", Integer)) t.create(connection) diff --git a/test/ext/mypy/plain_files/sql_operations.py b/test/ext/mypy/plain_files/sql_operations.py index b4d0bd006..0ed0df661 100644 --- a/test/ext/mypy/plain_files/sql_operations.py +++ b/test/ext/mypy/plain_files/sql_operations.py @@ -56,7 +56,7 @@ if typing.TYPE_CHECKING: # EXPECTED_RE_TYPE: sqlalchemy..*BinaryExpression\[builtins.bool\] reveal_type(expr2) - # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[Union\[builtins.float, decimal.Decimal\]\] + # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[Union\[builtins.float, .*\.Decimal\]\] reveal_type(expr3) # EXPECTED_RE_TYPE: sqlalchemy..*UnaryExpression\[builtins.int.?\] diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 4e8e2ac13..4e40ae0a2 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -25,6 +25,7 @@ from sqlalchemy import Column from sqlalchemy import Date from sqlalchemy import desc from sqlalchemy import distinct +from sqlalchemy import Enum from sqlalchemy import exc from sqlalchemy import except_ from sqlalchemy import exists @@ -96,6 +97,7 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing import ne_ +from sqlalchemy.testing.schema import pep435_enum table1 = table( "mytable", @@ -3745,6 +3747,46 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase): s, ) + def test_bind_param_escaping(self): + """general bind param escape unit tests added as a result of + #8053 + # + #""" + + SomeEnum = pep435_enum("SomeEnum") + one = SomeEnum("one", 1) + SomeEnum("two", 2) + + t = Table( + "t", + MetaData(), + Column("_id", Integer, primary_key=True), + Column("_data", Enum(SomeEnum)), + ) + + class MyCompiler(compiler.SQLCompiler): + def bindparam_string(self, name, **kw): + kw["escaped_from"] = name + return super(MyCompiler, self).bindparam_string( + '"%s"' % name, **kw + ) + + dialect = default.DefaultDialect() + dialect.statement_compiler = MyCompiler + + self.assert_compile( + t.insert(), + 'INSERT INTO t (_id, _data) VALUES (:"_id", :"_data")', + dialect=dialect, + ) + + compiled = t.insert().compile( + dialect=dialect, compile_kwargs=dict(compile_keys=("_id", "_data")) + ) + params = compiled.construct_params({"_id": 1, "_data": one}) + eq_(params, {'"_id"': 1, '"_data"': one}) + eq_(compiled._bind_processors, {'"_data"': mock.ANY}) + def test_expanding_non_expanding_conflict(self): """test #8018""" |
