diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-09-13 11:00:46 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-09-13 11:18:19 -0400 |
| commit | 14b634d7065446d146456eed006c4969a7972b1a (patch) | |
| tree | aa94ea605196a585de4b43f195df9db143133a0f | |
| parent | 479dbc99e7fc5a60f846992c0cca8542047a8933 (diff) | |
| download | sqlalchemy-14b634d7065446d146456eed006c4969a7972b1a.tar.gz | |
Add type awareness to evaluator
Fixed regression where using ORM update() with synchronize_session='fetch'
would fail due to the use of evaluators that are now used to determine the
in-Python value for expressions in the the SET clause when refreshing
objects; if the evaluators make use of math operators against non-numeric
values such as PostgreSQL JSONB, the non-evaluable condition would fail to
be detected correctly. The evaluator now limits the use of math mutation
operators to numeric types only, with the exception of "+" that continues
to work for strings as well. SQLAlchemy 2.0 may alter this further by
fetching the SET values completely rather than using evaluation.
Fixes: #8507
Change-Id: Icf7120ccbf4266499df6bb3e05159c9f50971d69
| -rw-r--r-- | doc/build/changelog/unreleased_14/8507.rst | 13 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/evaluator.py | 68 | ||||
| -rw-r--r-- | test/orm/test_evaluator.py | 64 |
3 files changed, 126 insertions, 19 deletions
diff --git a/doc/build/changelog/unreleased_14/8507.rst b/doc/build/changelog/unreleased_14/8507.rst new file mode 100644 index 000000000..07944da75 --- /dev/null +++ b/doc/build/changelog/unreleased_14/8507.rst @@ -0,0 +1,13 @@ +.. change:: + :tags: bug, orm, regression + :tickets: 8507 + + Fixed regression where using ORM update() with synchronize_session='fetch' + would fail due to the use of evaluators that are now used to determine the + in-Python value for expressions in the the SET clause when refreshing + objects; if the evaluators make use of math operators against non-numeric + values such as PostgreSQL JSONB, the non-evaluable condition would fail to + be detected correctly. The evaluator now limits the use of math mutation + operators to numeric types only, with the exception of "+" that continues + to work for strings as well. SQLAlchemy 2.0 may alter this further by + fetching the SET values completely rather than using evaluation. diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py index 72936d1ab..b3129afdd 100644 --- a/lib/sqlalchemy/orm/evaluator.py +++ b/lib/sqlalchemy/orm/evaluator.py @@ -16,6 +16,8 @@ from .. import inspect from .. import util from ..sql import and_ from ..sql import operators +from ..sql.sqltypes import Integer +from ..sql.sqltypes import Numeric class UnevaluatableError(exc.InvalidRequestError): @@ -120,7 +122,7 @@ class EvaluatorCompiler: dispatch = f"visit_{clause.operator.__name__.rstrip('_')}_binary_op" meth = getattr(self, dispatch, None) if meth: - return meth(clause.operator, eval_left, eval_right) + return meth(clause.operator, eval_left, eval_right, clause) else: raise UnevaluatableError( f"Cannot evaluate {type(clause).__name__} with " @@ -165,9 +167,13 @@ class EvaluatorCompiler: return evaluate - def visit_custom_op_binary_op(self, operator, eval_left, eval_right): + def visit_custom_op_binary_op( + self, operator, eval_left, eval_right, clause + ): if operator.python_impl: - return self._straight_evaluate(operator, eval_left, eval_right) + return self._straight_evaluate( + operator, eval_left, eval_right, clause + ) else: raise UnevaluatableError( f"Custom operator {operator.opstring!r} can't be evaluated " @@ -175,19 +181,19 @@ class EvaluatorCompiler: "`.python_impl`." ) - def visit_is_binary_op(self, operator, eval_left, eval_right): + def visit_is_binary_op(self, operator, eval_left, eval_right, clause): def evaluate(obj): return eval_left(obj) == eval_right(obj) return evaluate - def visit_is_not_binary_op(self, operator, eval_left, eval_right): + def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause): def evaluate(obj): return eval_left(obj) != eval_right(obj) return evaluate - def _straight_evaluate(self, operator, eval_left, eval_right): + def _straight_evaluate(self, operator, eval_left, eval_right, clause): def evaluate(obj): left_val = eval_left(obj) right_val = eval_right(obj) @@ -197,11 +203,25 @@ class EvaluatorCompiler: return evaluate - visit_add_binary_op = _straight_evaluate - visit_mul_binary_op = _straight_evaluate - visit_sub_binary_op = _straight_evaluate - visit_mod_binary_op = _straight_evaluate - visit_truediv_binary_op = _straight_evaluate + def _straight_evaluate_numeric_only( + self, operator, eval_left, eval_right, clause + ): + if clause.left.type._type_affinity not in ( + Numeric, + Integer, + ) or clause.right.type._type_affinity not in (Numeric, Integer): + raise UnevaluatableError( + f'Cannot evaluate math operator "{operator.__name__}" for ' + f"datatypes {clause.left.type}, {clause.right.type}" + ) + + return self._straight_evaluate(operator, eval_left, eval_right, clause) + + visit_add_binary_op = _straight_evaluate_numeric_only + visit_mul_binary_op = _straight_evaluate_numeric_only + visit_sub_binary_op = _straight_evaluate_numeric_only + visit_mod_binary_op = _straight_evaluate_numeric_only + visit_truediv_binary_op = _straight_evaluate_numeric_only visit_lt_binary_op = _straight_evaluate visit_le_binary_op = _straight_evaluate visit_ne_binary_op = _straight_evaluate @@ -209,33 +229,43 @@ class EvaluatorCompiler: visit_ge_binary_op = _straight_evaluate visit_eq_binary_op = _straight_evaluate - def visit_in_op_binary_op(self, operator, eval_left, eval_right): + def visit_in_op_binary_op(self, operator, eval_left, eval_right, clause): return self._straight_evaluate( lambda a, b: a in b if a is not _NO_OBJECT else None, eval_left, eval_right, + clause, ) - def visit_not_in_op_binary_op(self, operator, eval_left, eval_right): + def visit_not_in_op_binary_op( + self, operator, eval_left, eval_right, clause + ): return self._straight_evaluate( lambda a, b: a not in b if a is not _NO_OBJECT else None, eval_left, eval_right, + clause, ) - def visit_concat_op_binary_op(self, operator, eval_left, eval_right): + def visit_concat_op_binary_op( + self, operator, eval_left, eval_right, clause + ): return self._straight_evaluate( - lambda a, b: a + b, eval_left, eval_right + lambda a, b: a + b, eval_left, eval_right, clause ) - def visit_startswith_op_binary_op(self, operator, eval_left, eval_right): + def visit_startswith_op_binary_op( + self, operator, eval_left, eval_right, clause + ): return self._straight_evaluate( - lambda a, b: a.startswith(b), eval_left, eval_right + lambda a, b: a.startswith(b), eval_left, eval_right, clause ) - def visit_endswith_op_binary_op(self, operator, eval_left, eval_right): + def visit_endswith_op_binary_op( + self, operator, eval_left, eval_right, clause + ): return self._straight_evaluate( - lambda a, b: a.endswith(b), eval_left, eval_right + lambda a, b: a.endswith(b), eval_left, eval_right, clause ) def visit_unary(self, clause): diff --git a/test/orm/test_evaluator.py b/test/orm/test_evaluator.py index 104e47ae8..ff40cd201 100644 --- a/test/orm/test_evaluator.py +++ b/test/orm/test_evaluator.py @@ -5,6 +5,7 @@ from sqlalchemy import bindparam from sqlalchemy import ForeignKey from sqlalchemy import inspect from sqlalchemy import Integer +from sqlalchemy import JSON from sqlalchemy import not_ from sqlalchemy import or_ from sqlalchemy import String @@ -16,6 +17,7 @@ from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import relationship from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message +from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ @@ -53,6 +55,7 @@ class EvaluateTest(fixtures.MappedTest): Column("id", Integer, primary_key=True), Column("name", String(64)), Column("othername", String(64)), + Column("json", JSON), ) @classmethod @@ -343,6 +346,67 @@ class EvaluateTest(fixtures.MappedTest): ], ) + @testing.combinations( + (lambda User: User.id + 5, "id", 10, 15, None), + ( + # note this one uses concat_op, not operator.add + lambda User: User.name + " name", + "name", + "some value", + "some value name", + None, + ), + ( + lambda User: User.id + "name", + "id", + 10, + evaluator.UnevaluatableError, + r"Cannot evaluate math operator \"add\" for " + r"datatypes INTEGER, VARCHAR", + ), + ( + lambda User: User.json + 12, + "json", + {"foo": "bar"}, + evaluator.UnevaluatableError, + r"Cannot evaluate math operator \"add\" for " + r"datatypes JSON, INTEGER", + ), + ( + lambda User: User.json - 12, + "json", + {"foo": "bar"}, + evaluator.UnevaluatableError, + r"Cannot evaluate math operator \"sub\" for " + r"datatypes JSON, INTEGER", + ), + ( + lambda User: User.json - "foo", + "json", + {"foo": "bar"}, + evaluator.UnevaluatableError, + r"Cannot evaluate math operator \"sub\" for " + r"datatypes JSON, VARCHAR", + ), + ) + def test_math_op_type_exclusions( + self, expr, attrname, initial_value, expected, message + ): + """test #8507""" + + User = self.classes.User + + expr = testing.resolve_lambda(expr, User=User) + + if expected is evaluator.UnevaluatableError: + with expect_raises_message(evaluator.UnevaluatableError, message): + compiler.process(expr) + else: + obj = User(**{attrname: initial_value}) + + new_value = compiler.process(expr)(obj) + eq_(new_value, expected) + class M2OEvaluateTest(fixtures.DeclarativeMappedTest): @classmethod |
