diff options
Diffstat (limited to 'lib/sqlalchemy/sql/elements.py')
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 138 |
1 files changed, 94 insertions, 44 deletions
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 8e1b623a7..60c816ee6 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -381,6 +381,7 @@ class ClauseElement( try: traverse_internals = self._traverse_internals except AttributeError: + # user-defined classes may not have a _traverse_internals return for attrname, obj, meth in _copy_internals.run_generated_dispatch( @@ -410,6 +411,7 @@ class ClauseElement( try: traverse_internals = self._traverse_internals except AttributeError: + # user-defined classes may not have a _traverse_internals return [] return itertools.chain.from_iterable( @@ -516,10 +518,62 @@ class ClauseElement( dialect = bind.dialect elif self.bind: dialect = self.bind.dialect - bind = self.bind else: dialect = default.StrCompileDialect() - return self._compiler(dialect, bind=bind, **kw) + + return self._compiler(dialect, **kw) + + def _compile_w_cache( + self, + dialect, + compiled_cache=None, + column_keys=None, + inline=False, + schema_translate_map=None, + **kw + ): + if compiled_cache is not None: + elem_cache_key = self._generate_cache_key() + else: + elem_cache_key = None + + cache_hit = False + + if elem_cache_key: + cache_key, extracted_params = elem_cache_key + key = ( + dialect, + cache_key, + tuple(column_keys), + bool(schema_translate_map), + inline, + ) + compiled_sql = compiled_cache.get(key) + + if compiled_sql is None: + compiled_sql = self._compiler( + dialect, + cache_key=elem_cache_key, + column_keys=column_keys, + inline=inline, + schema_translate_map=schema_translate_map, + **kw + ) + compiled_cache[key] = compiled_sql + else: + cache_hit = True + else: + extracted_params = None + compiled_sql = self._compiler( + dialect, + cache_key=elem_cache_key, + column_keys=column_keys, + inline=inline, + schema_translate_map=schema_translate_map, + **kw + ) + + return compiled_sql, extracted_params, cache_hit def _compiler(self, dialect, **kw): """Return a compiler appropriate for this ClauseElement, given a @@ -1035,6 +1089,10 @@ class BindParameter(roles.InElementRole, ColumnElement): _is_bind_parameter = True _key_is_anon = False + # bindparam implements its own _gen_cache_key() method however + # we check subclasses for this flag, else no cache key is generated + inherit_cache = True + def __init__( self, key, @@ -1396,6 +1454,13 @@ class BindParameter(roles.InElementRole, ColumnElement): return c def _gen_cache_key(self, anon_map, bindparams): + _gen_cache_ok = self.__class__.__dict__.get("inherit_cache", False) + + if not _gen_cache_ok: + if anon_map is not None: + anon_map[NO_CACHE] = True + return None + idself = id(self) if idself in anon_map: return (anon_map[idself], self.__class__) @@ -2082,6 +2147,7 @@ class ClauseList( roles.InElementRole, roles.OrderByRole, roles.ColumnsClauseRole, + roles.DMLColumnRole, ClauseElement, ): """Describe a list of clauses, separated by an operator. @@ -2174,6 +2240,7 @@ class ClauseList( class BooleanClauseList(ClauseList, ColumnElement): __visit_name__ = "clauselist" + inherit_cache = True _tuple_values = False @@ -3428,6 +3495,8 @@ class CollectionAggregate(UnaryExpression): class AsBoolean(WrapsColumnExpression, UnaryExpression): + inherit_cache = True + def __init__(self, element, operator, negate): self.element = element self.type = type_api.BOOLEANTYPE @@ -3474,6 +3543,7 @@ class BinaryExpression(ColumnElement): ("operator", InternalTraversal.dp_operator), ("negate", InternalTraversal.dp_operator), ("modifiers", InternalTraversal.dp_plain_dict), + ("type", InternalTraversal.dp_type,), # affects JSON CAST operators ] _is_implicitly_boolean = True @@ -3482,41 +3552,6 @@ class BinaryExpression(ColumnElement): """ - def _gen_cache_key(self, anon_map, bindparams): - # inlined for performance - - idself = id(self) - - if idself in anon_map: - return (anon_map[idself], self.__class__) - else: - # inline of - # id_ = anon_map[idself] - anon_map[idself] = id_ = str(anon_map.index) - anon_map.index += 1 - - if self._cache_key_traversal is NO_CACHE: - anon_map[NO_CACHE] = True - return None - - result = (id_, self.__class__) - - return result + ( - ("left", self.left._gen_cache_key(anon_map, bindparams)), - ("right", self.right._gen_cache_key(anon_map, bindparams)), - ("operator", self.operator), - ("negate", self.negate), - ( - "modifiers", - tuple( - (key, self.modifiers[key]) - for key in sorted(self.modifiers) - ) - if self.modifiers - else None, - ), - ) - def __init__( self, left, right, operator, type_=None, negate=None, modifiers=None ): @@ -3587,15 +3622,30 @@ class Slice(ColumnElement): __visit_name__ = "slice" _traverse_internals = [ - ("start", InternalTraversal.dp_plain_obj), - ("stop", InternalTraversal.dp_plain_obj), - ("step", InternalTraversal.dp_plain_obj), + ("start", InternalTraversal.dp_clauseelement), + ("stop", InternalTraversal.dp_clauseelement), + ("step", InternalTraversal.dp_clauseelement), ] - def __init__(self, start, stop, step): - self.start = start - self.stop = stop - self.step = step + def __init__(self, start, stop, step, _name=None): + self.start = coercions.expect( + roles.ExpressionElementRole, + start, + name=_name, + type_=type_api.INTEGERTYPE, + ) + self.stop = coercions.expect( + roles.ExpressionElementRole, + stop, + name=_name, + type_=type_api.INTEGERTYPE, + ) + self.step = coercions.expect( + roles.ExpressionElementRole, + step, + name=_name, + type_=type_api.INTEGERTYPE, + ) self.type = type_api.NULLTYPE def self_group(self, against=None): |
