summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-09-13 11:00:46 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-09-13 11:18:19 -0400
commit14b634d7065446d146456eed006c4969a7972b1a (patch)
treeaa94ea605196a585de4b43f195df9db143133a0f
parent479dbc99e7fc5a60f846992c0cca8542047a8933 (diff)
downloadsqlalchemy-14b634d7065446d146456eed006c4969a7972b1a.tar.gz
Add type awareness to evaluator
Fixed regression where using ORM update() with synchronize_session='fetch' would fail due to the use of evaluators that are now used to determine the in-Python value for expressions in the the SET clause when refreshing objects; if the evaluators make use of math operators against non-numeric values such as PostgreSQL JSONB, the non-evaluable condition would fail to be detected correctly. The evaluator now limits the use of math mutation operators to numeric types only, with the exception of "+" that continues to work for strings as well. SQLAlchemy 2.0 may alter this further by fetching the SET values completely rather than using evaluation. Fixes: #8507 Change-Id: Icf7120ccbf4266499df6bb3e05159c9f50971d69
-rw-r--r--doc/build/changelog/unreleased_14/8507.rst13
-rw-r--r--lib/sqlalchemy/orm/evaluator.py68
-rw-r--r--test/orm/test_evaluator.py64
3 files changed, 126 insertions, 19 deletions
diff --git a/doc/build/changelog/unreleased_14/8507.rst b/doc/build/changelog/unreleased_14/8507.rst
new file mode 100644
index 000000000..07944da75
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/8507.rst
@@ -0,0 +1,13 @@
+.. change::
+ :tags: bug, orm, regression
+ :tickets: 8507
+
+ Fixed regression where using ORM update() with synchronize_session='fetch'
+ would fail due to the use of evaluators that are now used to determine the
+ in-Python value for expressions in the the SET clause when refreshing
+ objects; if the evaluators make use of math operators against non-numeric
+ values such as PostgreSQL JSONB, the non-evaluable condition would fail to
+ be detected correctly. The evaluator now limits the use of math mutation
+ operators to numeric types only, with the exception of "+" that continues
+ to work for strings as well. SQLAlchemy 2.0 may alter this further by
+ fetching the SET values completely rather than using evaluation.
diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py
index 72936d1ab..b3129afdd 100644
--- a/lib/sqlalchemy/orm/evaluator.py
+++ b/lib/sqlalchemy/orm/evaluator.py
@@ -16,6 +16,8 @@ from .. import inspect
from .. import util
from ..sql import and_
from ..sql import operators
+from ..sql.sqltypes import Integer
+from ..sql.sqltypes import Numeric
class UnevaluatableError(exc.InvalidRequestError):
@@ -120,7 +122,7 @@ class EvaluatorCompiler:
dispatch = f"visit_{clause.operator.__name__.rstrip('_')}_binary_op"
meth = getattr(self, dispatch, None)
if meth:
- return meth(clause.operator, eval_left, eval_right)
+ return meth(clause.operator, eval_left, eval_right, clause)
else:
raise UnevaluatableError(
f"Cannot evaluate {type(clause).__name__} with "
@@ -165,9 +167,13 @@ class EvaluatorCompiler:
return evaluate
- def visit_custom_op_binary_op(self, operator, eval_left, eval_right):
+ def visit_custom_op_binary_op(
+ self, operator, eval_left, eval_right, clause
+ ):
if operator.python_impl:
- return self._straight_evaluate(operator, eval_left, eval_right)
+ return self._straight_evaluate(
+ operator, eval_left, eval_right, clause
+ )
else:
raise UnevaluatableError(
f"Custom operator {operator.opstring!r} can't be evaluated "
@@ -175,19 +181,19 @@ class EvaluatorCompiler:
"`.python_impl`."
)
- def visit_is_binary_op(self, operator, eval_left, eval_right):
+ def visit_is_binary_op(self, operator, eval_left, eval_right, clause):
def evaluate(obj):
return eval_left(obj) == eval_right(obj)
return evaluate
- def visit_is_not_binary_op(self, operator, eval_left, eval_right):
+ def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause):
def evaluate(obj):
return eval_left(obj) != eval_right(obj)
return evaluate
- def _straight_evaluate(self, operator, eval_left, eval_right):
+ def _straight_evaluate(self, operator, eval_left, eval_right, clause):
def evaluate(obj):
left_val = eval_left(obj)
right_val = eval_right(obj)
@@ -197,11 +203,25 @@ class EvaluatorCompiler:
return evaluate
- visit_add_binary_op = _straight_evaluate
- visit_mul_binary_op = _straight_evaluate
- visit_sub_binary_op = _straight_evaluate
- visit_mod_binary_op = _straight_evaluate
- visit_truediv_binary_op = _straight_evaluate
+ def _straight_evaluate_numeric_only(
+ self, operator, eval_left, eval_right, clause
+ ):
+ if clause.left.type._type_affinity not in (
+ Numeric,
+ Integer,
+ ) or clause.right.type._type_affinity not in (Numeric, Integer):
+ raise UnevaluatableError(
+ f'Cannot evaluate math operator "{operator.__name__}" for '
+ f"datatypes {clause.left.type}, {clause.right.type}"
+ )
+
+ return self._straight_evaluate(operator, eval_left, eval_right, clause)
+
+ visit_add_binary_op = _straight_evaluate_numeric_only
+ visit_mul_binary_op = _straight_evaluate_numeric_only
+ visit_sub_binary_op = _straight_evaluate_numeric_only
+ visit_mod_binary_op = _straight_evaluate_numeric_only
+ visit_truediv_binary_op = _straight_evaluate_numeric_only
visit_lt_binary_op = _straight_evaluate
visit_le_binary_op = _straight_evaluate
visit_ne_binary_op = _straight_evaluate
@@ -209,33 +229,43 @@ class EvaluatorCompiler:
visit_ge_binary_op = _straight_evaluate
visit_eq_binary_op = _straight_evaluate
- def visit_in_op_binary_op(self, operator, eval_left, eval_right):
+ def visit_in_op_binary_op(self, operator, eval_left, eval_right, clause):
return self._straight_evaluate(
lambda a, b: a in b if a is not _NO_OBJECT else None,
eval_left,
eval_right,
+ clause,
)
- def visit_not_in_op_binary_op(self, operator, eval_left, eval_right):
+ def visit_not_in_op_binary_op(
+ self, operator, eval_left, eval_right, clause
+ ):
return self._straight_evaluate(
lambda a, b: a not in b if a is not _NO_OBJECT else None,
eval_left,
eval_right,
+ clause,
)
- def visit_concat_op_binary_op(self, operator, eval_left, eval_right):
+ def visit_concat_op_binary_op(
+ self, operator, eval_left, eval_right, clause
+ ):
return self._straight_evaluate(
- lambda a, b: a + b, eval_left, eval_right
+ lambda a, b: a + b, eval_left, eval_right, clause
)
- def visit_startswith_op_binary_op(self, operator, eval_left, eval_right):
+ def visit_startswith_op_binary_op(
+ self, operator, eval_left, eval_right, clause
+ ):
return self._straight_evaluate(
- lambda a, b: a.startswith(b), eval_left, eval_right
+ lambda a, b: a.startswith(b), eval_left, eval_right, clause
)
- def visit_endswith_op_binary_op(self, operator, eval_left, eval_right):
+ def visit_endswith_op_binary_op(
+ self, operator, eval_left, eval_right, clause
+ ):
return self._straight_evaluate(
- lambda a, b: a.endswith(b), eval_left, eval_right
+ lambda a, b: a.endswith(b), eval_left, eval_right, clause
)
def visit_unary(self, clause):
diff --git a/test/orm/test_evaluator.py b/test/orm/test_evaluator.py
index 104e47ae8..ff40cd201 100644
--- a/test/orm/test_evaluator.py
+++ b/test/orm/test_evaluator.py
@@ -5,6 +5,7 @@ from sqlalchemy import bindparam
from sqlalchemy import ForeignKey
from sqlalchemy import inspect
from sqlalchemy import Integer
+from sqlalchemy import JSON
from sqlalchemy import not_
from sqlalchemy import or_
from sqlalchemy import String
@@ -16,6 +17,7 @@ from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.orm import relationship
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
+from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_warnings
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
@@ -53,6 +55,7 @@ class EvaluateTest(fixtures.MappedTest):
Column("id", Integer, primary_key=True),
Column("name", String(64)),
Column("othername", String(64)),
+ Column("json", JSON),
)
@classmethod
@@ -343,6 +346,67 @@ class EvaluateTest(fixtures.MappedTest):
],
)
+ @testing.combinations(
+ (lambda User: User.id + 5, "id", 10, 15, None),
+ (
+ # note this one uses concat_op, not operator.add
+ lambda User: User.name + " name",
+ "name",
+ "some value",
+ "some value name",
+ None,
+ ),
+ (
+ lambda User: User.id + "name",
+ "id",
+ 10,
+ evaluator.UnevaluatableError,
+ r"Cannot evaluate math operator \"add\" for "
+ r"datatypes INTEGER, VARCHAR",
+ ),
+ (
+ lambda User: User.json + 12,
+ "json",
+ {"foo": "bar"},
+ evaluator.UnevaluatableError,
+ r"Cannot evaluate math operator \"add\" for "
+ r"datatypes JSON, INTEGER",
+ ),
+ (
+ lambda User: User.json - 12,
+ "json",
+ {"foo": "bar"},
+ evaluator.UnevaluatableError,
+ r"Cannot evaluate math operator \"sub\" for "
+ r"datatypes JSON, INTEGER",
+ ),
+ (
+ lambda User: User.json - "foo",
+ "json",
+ {"foo": "bar"},
+ evaluator.UnevaluatableError,
+ r"Cannot evaluate math operator \"sub\" for "
+ r"datatypes JSON, VARCHAR",
+ ),
+ )
+ def test_math_op_type_exclusions(
+ self, expr, attrname, initial_value, expected, message
+ ):
+ """test #8507"""
+
+ User = self.classes.User
+
+ expr = testing.resolve_lambda(expr, User=User)
+
+ if expected is evaluator.UnevaluatableError:
+ with expect_raises_message(evaluator.UnevaluatableError, message):
+ compiler.process(expr)
+ else:
+ obj = User(**{attrname: initial_value})
+
+ new_value = compiler.process(expr)(obj)
+ eq_(new_value, expected)
+
class M2OEvaluateTest(fixtures.DeclarativeMappedTest):
@classmethod