summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/elements.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/elements.py')
-rw-r--r--lib/sqlalchemy/sql/elements.py138
1 files changed, 94 insertions, 44 deletions
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 8e1b623a7..60c816ee6 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -381,6 +381,7 @@ class ClauseElement(
try:
traverse_internals = self._traverse_internals
except AttributeError:
+ # user-defined classes may not have a _traverse_internals
return
for attrname, obj, meth in _copy_internals.run_generated_dispatch(
@@ -410,6 +411,7 @@ class ClauseElement(
try:
traverse_internals = self._traverse_internals
except AttributeError:
+ # user-defined classes may not have a _traverse_internals
return []
return itertools.chain.from_iterable(
@@ -516,10 +518,62 @@ class ClauseElement(
dialect = bind.dialect
elif self.bind:
dialect = self.bind.dialect
- bind = self.bind
else:
dialect = default.StrCompileDialect()
- return self._compiler(dialect, bind=bind, **kw)
+
+ return self._compiler(dialect, **kw)
+
+ def _compile_w_cache(
+ self,
+ dialect,
+ compiled_cache=None,
+ column_keys=None,
+ inline=False,
+ schema_translate_map=None,
+ **kw
+ ):
+ if compiled_cache is not None:
+ elem_cache_key = self._generate_cache_key()
+ else:
+ elem_cache_key = None
+
+ cache_hit = False
+
+ if elem_cache_key:
+ cache_key, extracted_params = elem_cache_key
+ key = (
+ dialect,
+ cache_key,
+ tuple(column_keys),
+ bool(schema_translate_map),
+ inline,
+ )
+ compiled_sql = compiled_cache.get(key)
+
+ if compiled_sql is None:
+ compiled_sql = self._compiler(
+ dialect,
+ cache_key=elem_cache_key,
+ column_keys=column_keys,
+ inline=inline,
+ schema_translate_map=schema_translate_map,
+ **kw
+ )
+ compiled_cache[key] = compiled_sql
+ else:
+ cache_hit = True
+ else:
+ extracted_params = None
+ compiled_sql = self._compiler(
+ dialect,
+ cache_key=elem_cache_key,
+ column_keys=column_keys,
+ inline=inline,
+ schema_translate_map=schema_translate_map,
+ **kw
+ )
+
+ return compiled_sql, extracted_params, cache_hit
def _compiler(self, dialect, **kw):
"""Return a compiler appropriate for this ClauseElement, given a
@@ -1035,6 +1089,10 @@ class BindParameter(roles.InElementRole, ColumnElement):
_is_bind_parameter = True
_key_is_anon = False
+ # bindparam implements its own _gen_cache_key() method however
+ # we check subclasses for this flag, else no cache key is generated
+ inherit_cache = True
+
def __init__(
self,
key,
@@ -1396,6 +1454,13 @@ class BindParameter(roles.InElementRole, ColumnElement):
return c
def _gen_cache_key(self, anon_map, bindparams):
+ _gen_cache_ok = self.__class__.__dict__.get("inherit_cache", False)
+
+ if not _gen_cache_ok:
+ if anon_map is not None:
+ anon_map[NO_CACHE] = True
+ return None
+
idself = id(self)
if idself in anon_map:
return (anon_map[idself], self.__class__)
@@ -2082,6 +2147,7 @@ class ClauseList(
roles.InElementRole,
roles.OrderByRole,
roles.ColumnsClauseRole,
+ roles.DMLColumnRole,
ClauseElement,
):
"""Describe a list of clauses, separated by an operator.
@@ -2174,6 +2240,7 @@ class ClauseList(
class BooleanClauseList(ClauseList, ColumnElement):
__visit_name__ = "clauselist"
+ inherit_cache = True
_tuple_values = False
@@ -3428,6 +3495,8 @@ class CollectionAggregate(UnaryExpression):
class AsBoolean(WrapsColumnExpression, UnaryExpression):
+ inherit_cache = True
+
def __init__(self, element, operator, negate):
self.element = element
self.type = type_api.BOOLEANTYPE
@@ -3474,6 +3543,7 @@ class BinaryExpression(ColumnElement):
("operator", InternalTraversal.dp_operator),
("negate", InternalTraversal.dp_operator),
("modifiers", InternalTraversal.dp_plain_dict),
+ ("type", InternalTraversal.dp_type,), # affects JSON CAST operators
]
_is_implicitly_boolean = True
@@ -3482,41 +3552,6 @@ class BinaryExpression(ColumnElement):
"""
- def _gen_cache_key(self, anon_map, bindparams):
- # inlined for performance
-
- idself = id(self)
-
- if idself in anon_map:
- return (anon_map[idself], self.__class__)
- else:
- # inline of
- # id_ = anon_map[idself]
- anon_map[idself] = id_ = str(anon_map.index)
- anon_map.index += 1
-
- if self._cache_key_traversal is NO_CACHE:
- anon_map[NO_CACHE] = True
- return None
-
- result = (id_, self.__class__)
-
- return result + (
- ("left", self.left._gen_cache_key(anon_map, bindparams)),
- ("right", self.right._gen_cache_key(anon_map, bindparams)),
- ("operator", self.operator),
- ("negate", self.negate),
- (
- "modifiers",
- tuple(
- (key, self.modifiers[key])
- for key in sorted(self.modifiers)
- )
- if self.modifiers
- else None,
- ),
- )
-
def __init__(
self, left, right, operator, type_=None, negate=None, modifiers=None
):
@@ -3587,15 +3622,30 @@ class Slice(ColumnElement):
__visit_name__ = "slice"
_traverse_internals = [
- ("start", InternalTraversal.dp_plain_obj),
- ("stop", InternalTraversal.dp_plain_obj),
- ("step", InternalTraversal.dp_plain_obj),
+ ("start", InternalTraversal.dp_clauseelement),
+ ("stop", InternalTraversal.dp_clauseelement),
+ ("step", InternalTraversal.dp_clauseelement),
]
- def __init__(self, start, stop, step):
- self.start = start
- self.stop = stop
- self.step = step
+ def __init__(self, start, stop, step, _name=None):
+ self.start = coercions.expect(
+ roles.ExpressionElementRole,
+ start,
+ name=_name,
+ type_=type_api.INTEGERTYPE,
+ )
+ self.stop = coercions.expect(
+ roles.ExpressionElementRole,
+ stop,
+ name=_name,
+ type_=type_api.INTEGERTYPE,
+ )
+ self.step = coercions.expect(
+ roles.ExpressionElementRole,
+ step,
+ name=_name,
+ type_=type_api.INTEGERTYPE,
+ )
self.type = type_api.NULLTYPE
def self_group(self, against=None):