diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 714 |
1 files changed, 492 insertions, 222 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index a5f545de9..4448f7c7b 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1,5 +1,5 @@ # sql/compiler.py -# Copyright (C) 2005-2013 the SQLAlchemy authors and contributors <see AUTHORS file> +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file> # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -23,13 +23,12 @@ To generate user-defined SQL strings, see """ import re -import sys -from .. import schema, engine, util, exc, types -from . import ( - operators, functions, util as sql_util, visitors, expression as sql -) +from . import schema, sqltypes, operators, functions, \ + util as sql_util, visitors, elements, selectable, base +from .. import util, exc import decimal import itertools +import operator RESERVED_WORDS = set([ 'all', 'analyse', 'analyze', 'and', 'any', 'array', @@ -115,6 +114,7 @@ OPERATORS = { operators.asc_op: ' ASC', operators.nullsfirst_op: ' NULLS FIRST', operators.nullslast_op: ' NULLS LAST', + } FUNCTIONS = { @@ -150,14 +150,122 @@ EXTRACT_MAP = { } COMPOUND_KEYWORDS = { - sql.CompoundSelect.UNION: 'UNION', - sql.CompoundSelect.UNION_ALL: 'UNION ALL', - sql.CompoundSelect.EXCEPT: 'EXCEPT', - sql.CompoundSelect.EXCEPT_ALL: 'EXCEPT ALL', - sql.CompoundSelect.INTERSECT: 'INTERSECT', - sql.CompoundSelect.INTERSECT_ALL: 'INTERSECT ALL' + selectable.CompoundSelect.UNION: 'UNION', + selectable.CompoundSelect.UNION_ALL: 'UNION ALL', + selectable.CompoundSelect.EXCEPT: 'EXCEPT', + selectable.CompoundSelect.EXCEPT_ALL: 'EXCEPT ALL', + selectable.CompoundSelect.INTERSECT: 'INTERSECT', + selectable.CompoundSelect.INTERSECT_ALL: 'INTERSECT ALL' } +class Compiled(object): + """Represent a compiled SQL or DDL expression. + + The ``__str__`` method of the ``Compiled`` object should produce + the actual text of the statement. ``Compiled`` objects are + specific to their underlying database dialect, and also may + or may not be specific to the columns referenced within a + particular set of bind parameters. In no case should the + ``Compiled`` object be dependent on the actual values of those + bind parameters, even though it may reference those values as + defaults. + """ + + def __init__(self, dialect, statement, bind=None, + compile_kwargs=util.immutabledict()): + """Construct a new ``Compiled`` object. + + :param dialect: ``Dialect`` to compile against. + + :param statement: ``ClauseElement`` to be compiled. + + :param bind: Optional Engine or Connection to compile this + statement against. + + :param compile_kwargs: additional kwargs that will be + passed to the initial call to :meth:`.Compiled.process`. + + .. versionadded:: 0.8 + + """ + + self.dialect = dialect + self.bind = bind + if statement is not None: + self.statement = statement + self.can_execute = statement.supports_execution + self.string = self.process(self.statement, **compile_kwargs) + + @util.deprecated("0.7", ":class:`.Compiled` objects now compile " + "within the constructor.") + def compile(self): + """Produce the internal string representation of this element. + """ + pass + + def _execute_on_connection(self, connection, multiparams, params): + return connection._execute_compiled(self, multiparams, params) + + @property + def sql_compiler(self): + """Return a Compiled that is capable of processing SQL expressions. + + If this compiler is one, it would likely just return 'self'. + + """ + + raise NotImplementedError() + + def process(self, obj, **kwargs): + return obj._compiler_dispatch(self, **kwargs) + + def __str__(self): + """Return the string text of the generated SQL or DDL.""" + + return self.string or '' + + def construct_params(self, params=None): + """Return the bind params for this compiled object. + + :param params: a dict of string/object pairs whose values will + override bind values compiled in to the + statement. + """ + + raise NotImplementedError() + + @property + def params(self): + """Return the bind params for this compiled object.""" + return self.construct_params() + + def execute(self, *multiparams, **params): + """Execute this compiled object.""" + + e = self.bind + if e is None: + raise exc.UnboundExecutionError( + "This Compiled object is not bound to any Engine " + "or Connection.") + return e._execute_compiled(self, multiparams, params) + + def scalar(self, *multiparams, **params): + """Execute this compiled object and return the result's + scalar value.""" + + return self.execute(*multiparams, **params).scalar() + + +class TypeCompiler(object): + """Produces DDL specification for TypeEngine objects.""" + + def __init__(self, dialect): + self.dialect = dialect + + def process(self, type_): + return type_._compiler_dispatch(self) + + class _CompileLabel(visitors.Visitable): """lightweight label object which acts as an expression.Label.""" @@ -178,12 +286,8 @@ class _CompileLabel(visitors.Visitable): def type(self): return self.element.type - @property - def quote(self): - return self.element.quote - -class SQLCompiler(engine.Compiled): +class SQLCompiler(Compiled): """Default implementation of Compiled. Compiles ClauseElements into SQL strings. Uses a similar visit @@ -284,7 +388,7 @@ class SQLCompiler(engine.Compiled): # a map which tracks "truncated" names based on # dialect.label_length or dialect.max_identifier_length self.truncated_names = {} - engine.Compiled.__init__(self, dialect, statement, **kwargs) + Compiled.__init__(self, dialect, statement, **kwargs) if self.positional and dialect.paramstyle == 'numeric': self._apply_numbered_params() @@ -397,7 +501,7 @@ class SQLCompiler(engine.Compiled): render_label_only = render_label_as_label is label if render_label_only or render_label_with_as: - if isinstance(label.name, sql._truncated_label): + if isinstance(label.name, elements._truncated_label): labelname = self._truncated_identifier("colident", label.name) else: labelname = label.name @@ -432,7 +536,7 @@ class SQLCompiler(engine.Compiled): "its 'name' is assigned.") is_literal = column.is_literal - if not is_literal and isinstance(name, sql._truncated_label): + if not is_literal and isinstance(name, elements._truncated_label): name = self._truncated_identifier("colident", name) if add_to_result_map is not None: @@ -446,24 +550,22 @@ class SQLCompiler(engine.Compiled): if is_literal: name = self.escape_literal_column(name) else: - name = self.preparer.quote(name, column.quote) + name = self.preparer.quote(name) table = column.table if table is None or not include_table or not table.named_with_column: return name else: if table.schema: - schema_prefix = self.preparer.quote_schema( - table.schema, - table.quote_schema) + '.' + schema_prefix = self.preparer.quote_schema(table.schema) + '.' else: schema_prefix = '' tablename = table.name - if isinstance(tablename, sql._truncated_label): + if isinstance(tablename, elements._truncated_label): tablename = self._truncated_identifier("alias", tablename) return schema_prefix + \ - self.preparer.quote(tablename, table.quote) + \ + self.preparer.quote(tablename) + \ "." + name def escape_literal_column(self, text): @@ -484,20 +586,13 @@ class SQLCompiler(engine.Compiled): def post_process_text(self, text): return text - def visit_textclause(self, textclause, **kwargs): - if textclause.typemap is not None: - for colname, type_ in textclause.typemap.items(): - self.result_map[colname - if self.dialect.case_sensitive - else colname.lower()] = \ - (colname, None, type_) - + def visit_textclause(self, textclause, **kw): def do_bindparam(m): name = m.group(1) - if name in textclause.bindparams: - return self.process(textclause.bindparams[name]) + if name in textclause._bindparams: + return self.process(textclause._bindparams[name], **kw) else: - return self.bindparam_string(name, **kwargs) + return self.bindparam_string(name, **kw) # un-escape any \:params return BIND_PARAMS_ESC.sub(lambda m: m.group(1), @@ -505,14 +600,47 @@ class SQLCompiler(engine.Compiled): self.post_process_text(textclause.text)) ) + def visit_text_as_from(self, taf, iswrapper=False, + compound_index=0, force_result_map=False, + asfrom=False, + parens=True, **kw): + + toplevel = not self.stack + entry = self._default_stack_entry if toplevel else self.stack[-1] + + populate_result_map = force_result_map or ( + compound_index == 0 and ( + toplevel or \ + entry['iswrapper'] + ) + ) + + if populate_result_map: + for c in taf.c: + self._add_to_result_map( + c.key, c.key, (c,), c.type + ) + + text = self.process(taf.element, **kw) + if asfrom and parens: + text = "(%s)" % text + return text + + def visit_null(self, expr, **kw): return 'NULL' def visit_true(self, expr, **kw): - return 'true' + if self.dialect.supports_native_boolean: + return 'true' + else: + return "1" def visit_false(self, expr, **kw): - return 'false' + if self.dialect.supports_native_boolean: + return 'false' + else: + return "0" def visit_clauselist(self, clauselist, order_by_select=None, **kw): if order_by_select is not None: @@ -619,6 +747,7 @@ class SQLCompiler(engine.Compiled): def function_argspec(self, func, **kwargs): return func.clause_expr._compiler_dispatch(self, **kwargs) + def visit_compound_select(self, cs, asfrom=False, parens=True, compound_index=0, **kwargs): toplevel = not self.stack @@ -684,11 +813,23 @@ class SQLCompiler(engine.Compiled): raise exc.CompileError( "Unary expression has no operator or modifier") + def visit_istrue_unary_operator(self, element, operator, **kw): + if self.dialect.supports_native_boolean: + return self.process(element.element, **kw) + else: + return "%s = 1" % self.process(element.element, **kw) + + def visit_isfalse_unary_operator(self, element, operator, **kw): + if self.dialect.supports_native_boolean: + return "NOT %s" % self.process(element.element, **kw) + else: + return "%s = 0" % self.process(element.element, **kw) + def visit_binary(self, binary, **kw): # don't allow "? = ?" to render if self.ansi_bind_rules and \ - isinstance(binary.left, sql.BindParameter) and \ - isinstance(binary.right, sql.BindParameter): + isinstance(binary.left, elements.BindParameter) and \ + isinstance(binary.right, elements.BindParameter): kw['literal_binds'] = True operator = binary.operator @@ -728,7 +869,7 @@ class SQLCompiler(engine.Compiled): @util.memoized_property def _like_percent_literal(self): - return sql.literal_column("'%'", type_=types.String()) + return elements.literal_column("'%'", type_=sqltypes.STRINGTYPE) def visit_contains_op_binary(self, binary, operator, **kw): binary = binary._clone() @@ -772,39 +913,49 @@ class SQLCompiler(engine.Compiled): def visit_like_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) + + # TODO: use ternary here, not "and"/ "or" return '%s LIKE %s' % ( binary.left._compiler_dispatch(self, **kw), binary.right._compiler_dispatch(self, **kw)) \ - + (escape and - (' ESCAPE ' + self.render_literal_value(escape, None)) - or '') + + ( + ' ESCAPE ' + + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape else '' + ) def visit_notlike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) return '%s NOT LIKE %s' % ( binary.left._compiler_dispatch(self, **kw), binary.right._compiler_dispatch(self, **kw)) \ - + (escape and - (' ESCAPE ' + self.render_literal_value(escape, None)) - or '') + + ( + ' ESCAPE ' + + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape else '' + ) def visit_ilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) return 'lower(%s) LIKE lower(%s)' % ( binary.left._compiler_dispatch(self, **kw), binary.right._compiler_dispatch(self, **kw)) \ - + (escape and - (' ESCAPE ' + self.render_literal_value(escape, None)) - or '') + + ( + ' ESCAPE ' + + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape else '' + ) def visit_notilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) return 'lower(%s) NOT LIKE lower(%s)' % ( binary.left._compiler_dispatch(self, **kw), binary.right._compiler_dispatch(self, **kw)) \ - + (escape and - (' ESCAPE ' + self.render_literal_value(escape, None)) - or '') + + ( + ' ESCAPE ' + + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape else '' + ) def visit_bindparam(self, bindparam, within_columns_clause=False, literal_binds=False, @@ -820,8 +971,9 @@ class SQLCompiler(engine.Compiled): (within_columns_clause and \ self.ansi_bind_rules): if bindparam.value is None: - raise exc.CompileError("Bind parameter without a " - "renderable value not allowed here.") + raise exc.CompileError("Bind parameter '%s' without a " + "renderable value not allowed here." + % bindparam.key) return self.render_literal_bindparam(bindparam, within_columns_clause=True, **kwargs) @@ -851,13 +1003,10 @@ class SQLCompiler(engine.Compiled): self.binds[bindparam.key] = self.binds[name] = bindparam - return self.bindparam_string(name, quote=bindparam.quote, **kwargs) + return self.bindparam_string(name, **kwargs) def render_literal_bindparam(self, bindparam, **kw): value = bindparam.value - processor = bindparam.type._cached_bind_processor(self.dialect) - if processor: - value = processor(value) return self.render_literal_value(value, bindparam.type) def render_literal_value(self, value, type_): @@ -870,15 +1019,10 @@ class SQLCompiler(engine.Compiled): of the DBAPI. """ - if isinstance(value, util.string_types): - value = value.replace("'", "''") - return "'%s'" % value - elif value is None: - return "NULL" - elif isinstance(value, (float, ) + util.int_types): - return repr(value) - elif isinstance(value, decimal.Decimal): - return str(value) + + processor = type_._cached_literal_processor(self.dialect) + if processor: + return processor(value) else: raise NotImplementedError( "Don't know how to literal-quote value %r" % value) @@ -888,7 +1032,7 @@ class SQLCompiler(engine.Compiled): return self.bind_names[bindparam] bind_name = bindparam.key - if isinstance(bind_name, sql._truncated_label): + if isinstance(bind_name, elements._truncated_label): bind_name = self._truncated_identifier("bindparam", bind_name) # add to bind_names for translation @@ -921,8 +1065,7 @@ class SQLCompiler(engine.Compiled): self.anon_map[derived] = anonymous_counter + 1 return derived + "_" + str(anonymous_counter) - def bindparam_string(self, name, quote=None, - positional_names=None, **kw): + def bindparam_string(self, name, positional_names=None, **kw): if self.positional: if positional_names is not None: positional_names.append(name) @@ -937,7 +1080,7 @@ class SQLCompiler(engine.Compiled): if self.positional: kwargs['positional_names'] = self.cte_positional - if isinstance(cte.name, sql._truncated_label): + if isinstance(cte.name, elements._truncated_label): cte_name = self._truncated_identifier("alias", cte.name) else: cte_name = cte.name @@ -947,7 +1090,7 @@ class SQLCompiler(engine.Compiled): # we've generated a same-named CTE that we are enclosed in, # or this is the same CTE. just return the name. if cte in existing_cte._restates or cte is existing_cte: - return cte_name + return self.preparer.format_alias(cte, cte_name) elif existing_cte in cte._restates: # we've generated a same-named CTE that is # enclosed in us - we take precedence, so @@ -961,19 +1104,24 @@ class SQLCompiler(engine.Compiled): self.ctes_by_name[cte_name] = cte - if cte.cte_alias: - if isinstance(cte.cte_alias, sql._truncated_label): - cte_alias = self._truncated_identifier("alias", cte.cte_alias) - else: - cte_alias = cte.cte_alias - if not cte.cte_alias and cte not in self.ctes: + if cte._cte_alias is not None: + orig_cte = cte._cte_alias + if orig_cte not in self.ctes: + self.visit_cte(orig_cte) + cte_alias_name = cte._cte_alias.name + if isinstance(cte_alias_name, elements._truncated_label): + cte_alias_name = self._truncated_identifier("alias", cte_alias_name) + else: + orig_cte = cte + cte_alias_name = None + if not cte_alias_name and cte not in self.ctes: if cte.recursive: self.ctes_recursive = True text = self.preparer.format_alias(cte, cte_name) if cte.recursive: - if isinstance(cte.original, sql.Select): + if isinstance(cte.original, selectable.Select): col_source = cte.original - elif isinstance(cte.original, sql.CompoundSelect): + elif isinstance(cte.original, selectable.CompoundSelect): col_source = cte.original.selects[0] else: assert False @@ -989,9 +1137,10 @@ class SQLCompiler(engine.Compiled): self, asfrom=True, **kwargs ) self.ctes[cte] = text + if asfrom: - if cte.cte_alias: - text = self.preparer.format_alias(cte, cte_alias) + if cte_alias_name: + text = self.preparer.format_alias(cte, cte_alias_name) text += " AS " + cte_name else: return self.preparer.format_alias(cte, cte_name) @@ -1001,7 +1150,7 @@ class SQLCompiler(engine.Compiled): iscrud=False, fromhints=None, **kwargs): if asfrom or ashint: - if isinstance(alias.name, sql._truncated_label): + if isinstance(alias.name, elements._truncated_label): alias_name = self._truncated_identifier("alias", alias.name) else: alias_name = alias.name @@ -1059,7 +1208,7 @@ class SQLCompiler(engine.Compiled): if not within_columns_clause: result_expr = col_expr - elif isinstance(column, sql.Label): + elif isinstance(column, elements.Label): if col_expr is not column: result_expr = _CompileLabel( col_expr, @@ -1078,23 +1227,23 @@ class SQLCompiler(engine.Compiled): elif \ asfrom and \ - isinstance(column, sql.ColumnClause) and \ + isinstance(column, elements.ColumnClause) and \ not column.is_literal and \ column.table is not None and \ - not isinstance(column.table, sql.Select): + not isinstance(column.table, selectable.Select): result_expr = _CompileLabel(col_expr, - sql._as_truncated(column.name), + elements._as_truncated(column.name), alt_names=(column.key,)) elif not isinstance(column, - (sql.UnaryExpression, sql.TextClause)) \ + (elements.UnaryExpression, elements.TextClause)) \ and (not hasattr(column, 'name') or \ - isinstance(column, sql.Function)): + isinstance(column, functions.Function)): result_expr = _CompileLabel(col_expr, column.anon_label) elif col_expr is not column: # TODO: are we sure "column" has a .name and .key here ? - # assert isinstance(column, sql.ColumnClause) + # assert isinstance(column, elements.ColumnClause) result_expr = _CompileLabel(col_expr, - sql._as_truncated(column.name), + elements._as_truncated(column.name), alt_names=(column.key,)) else: result_expr = col_expr @@ -1137,8 +1286,8 @@ class SQLCompiler(engine.Compiled): # as this whole system won't work for custom Join/Select # subclasses where compilation routines # call down to compiler.visit_join(), compiler.visit_select() - join_name = sql.Join.__visit_name__ - select_name = sql.Select.__visit_name__ + join_name = selectable.Join.__visit_name__ + select_name = selectable.Select.__visit_name__ def visit(element, **kw): if element in column_translate[-1]: @@ -1150,24 +1299,27 @@ class SQLCompiler(engine.Compiled): newelem = cloned[element] = element._clone() if newelem.__visit_name__ is join_name and \ - isinstance(newelem.right, sql.FromGrouping): + isinstance(newelem.right, selectable.FromGrouping): newelem._reset_exported() newelem.left = visit(newelem.left, **kw) right = visit(newelem.right, **kw) - selectable = sql.select( + selectable_ = selectable.Select( [right.element], use_labels=True).alias() - for c in selectable.c: - c._label = c._key_label = c.name + for c in selectable_.c: + c._key_label = c.key + c._label = c.name + translate_dict = dict( - zip(right.element.c, selectable.c) - ) - translate_dict[right.element.left] = selectable - translate_dict[right.element.right] = selectable + zip(newelem.right.element.c, selectable_.c) + ) + + translate_dict[right.element.left] = selectable_ + translate_dict[right.element.right] = selectable_ # propagate translations that we've gained # from nested visit(newelem.right) outwards @@ -1183,7 +1335,8 @@ class SQLCompiler(engine.Compiled): column_translate[-1].update(translate_dict) - newelem.right = selectable + newelem.right = selectable_ + newelem.onclause = visit(newelem.onclause, **kw) elif newelem.__visit_name__ is select_name: column_translate.append({}) @@ -1199,6 +1352,7 @@ class SQLCompiler(engine.Compiled): def _transform_result_map_for_nested_joins(self, select, transformed_select): inner_col = dict((c._key_label, c) for c in transformed_select.inner_columns) + d = dict( (inner_col[c._key_label], c) for c in select.inner_columns @@ -1291,7 +1445,7 @@ class SQLCompiler(engine.Compiled): explicit_correlate_froms=correlate_froms, implicit_correlate_froms=asfrom_froms) - new_correlate_froms = set(sql._from_objects(*froms)) + new_correlate_froms = set(selectable._from_objects(*froms)) all_correlate_froms = new_correlate_froms.union(correlate_froms) new_entry = { @@ -1382,9 +1536,11 @@ class SQLCompiler(engine.Compiled): text += self.order_by_clause(select, order_by_select=order_by_select, **kwargs) + if select._limit is not None or select._offset is not None: text += self.limit_clause(select) - if select.for_update: + + if select._for_update_arg is not None: text += self.for_update_clause(select) if self.ctes and \ @@ -1440,10 +1596,7 @@ class SQLCompiler(engine.Compiled): return "" def for_update_clause(self, select): - if select.for_update: - return " FOR UPDATE" - else: - return "" + return " FOR UPDATE" def returning_clause(self, stmt, returning_cols): raise exc.CompileError( @@ -1453,23 +1606,21 @@ class SQLCompiler(engine.Compiled): def limit_clause(self, select): text = "" if select._limit is not None: - text += "\n LIMIT " + self.process(sql.literal(select._limit)) + text += "\n LIMIT " + self.process(elements.literal(select._limit)) if select._offset is not None: if select._limit is None: text += "\n LIMIT -1" - text += " OFFSET " + self.process(sql.literal(select._offset)) + text += " OFFSET " + self.process(elements.literal(select._offset)) return text def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, fromhints=None, **kwargs): if asfrom or ashint: if getattr(table, "schema", None): - ret = self.preparer.quote_schema(table.schema, - table.quote_schema) + \ - "." + self.preparer.quote(table.name, - table.quote) + ret = self.preparer.quote_schema(table.schema) + \ + "." + self.preparer.quote(table.name) else: - ret = self.preparer.quote(table.name, table.quote) + ret = self.preparer.quote(table.name) if fromhints and table in fromhints: ret = self.format_from_hint_text(ret, table, fromhints[table], iscrud) @@ -1488,7 +1639,7 @@ class SQLCompiler(engine.Compiled): def visit_insert(self, insert_stmt, **kw): self.isinsert = True - colparams = self._get_colparams(insert_stmt) + colparams = self._get_colparams(insert_stmt, **kw) if not colparams and \ not self.dialect.supports_default_values and \ @@ -1621,7 +1772,7 @@ class SQLCompiler(engine.Compiled): table_text = self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw) - colparams = self._get_colparams(update_stmt, extra_froms) + colparams = self._get_colparams(update_stmt, **kw) if update_stmt._hints: dialect_hints = dict([ @@ -1651,11 +1802,12 @@ class SQLCompiler(engine.Compiled): '=' + c[1] for c in colparams ) - if update_stmt._returning: - self.returning = update_stmt._returning + if self.returning or update_stmt._returning: + if not self.returning: + self.returning = update_stmt._returning if self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, update_stmt._returning) + update_stmt, self.returning) if extra_froms: extra_from_text = self.update_from_clause( @@ -1675,7 +1827,7 @@ class SQLCompiler(engine.Compiled): if self.returning and not self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, update_stmt._returning) + update_stmt, self.returning) self.stack.pop(-1) @@ -1684,13 +1836,45 @@ class SQLCompiler(engine.Compiled): def _create_crud_bind_param(self, col, value, required=False, name=None): if name is None: name = col.key - bindparam = sql.bindparam(name, value, - type_=col.type, required=required, - quote=col.quote) + bindparam = elements.BindParameter(name, value, + type_=col.type, required=required) bindparam._is_crud = True return bindparam._compiler_dispatch(self) - def _get_colparams(self, stmt, extra_tables=None): + @util.memoized_property + def _key_getters_for_crud_column(self): + if self.isupdate and self.statement._extra_froms: + # when extra tables are present, refer to the columns + # in those extra tables as table-qualified, including in + # dictionaries and when rendering bind param names. + # the "main" table of the statement remains unqualified, + # allowing the most compatibility with a non-multi-table + # statement. + _et = set(self.statement._extra_froms) + def _column_as_key(key): + str_key = elements._column_as_key(key) + if hasattr(key, 'table') and key.table in _et: + return (key.table.name, str_key) + else: + return str_key + def _getattr_col_key(col): + if col.table in _et: + return (col.table.name, col.key) + else: + return col.key + def _col_bind_name(col): + if col.table in _et: + return "%s_%s" % (col.table.name, col.key) + else: + return col.key + + else: + _column_as_key = elements._column_as_key + _getattr_col_key = _col_bind_name = operator.attrgetter("key") + + return _column_as_key, _getattr_col_key, _col_bind_name + + def _get_colparams(self, stmt, **kw): """create a set of tuples representing column/string pairs for use in an INSERT or UPDATE statement. @@ -1719,12 +1903,18 @@ class SQLCompiler(engine.Compiled): else: stmt_parameters = stmt.parameters + # getters - these are normally just column.key, + # but in the case of mysql multi-table update, the rules for + # .key must conditionally take tablename into account + _column_as_key, _getattr_col_key, _col_bind_name = \ + self._key_getters_for_crud_column + # if we have statement parameters - set defaults in the # compiled params if self.column_keys is None: parameters = {} else: - parameters = dict((sql._column_as_key(key), REQUIRED) + parameters = dict((_column_as_key(key), REQUIRED) for key in self.column_keys if not stmt_parameters or key not in stmt_parameters) @@ -1734,17 +1924,19 @@ class SQLCompiler(engine.Compiled): if stmt_parameters is not None: for k, v in stmt_parameters.items(): - colkey = sql._column_as_key(k) + colkey = _column_as_key(k) if colkey is not None: parameters.setdefault(colkey, v) else: # a non-Column expression on the left side; # add it to values() in an "as-is" state, # coercing right side to bound param - if sql._is_literal(v): - v = self.process(sql.bindparam(None, v, type_=k.type)) + if elements._is_literal(v): + v = self.process( + elements.BindParameter(None, v, type_=k.type), + **kw) else: - v = self.process(v.self_group()) + v = self.process(v.self_group(), **kw) values.append((k, v)) @@ -1756,30 +1948,44 @@ class SQLCompiler(engine.Compiled): self.dialect.implicit_returning and \ stmt.table.implicit_returning + if self.isinsert: + implicit_return_defaults = implicit_returning and stmt._return_defaults + elif self.isupdate: + implicit_return_defaults = self.dialect.implicit_returning and \ + stmt.table.implicit_returning and \ + stmt._return_defaults + + if implicit_return_defaults: + if stmt._return_defaults is True: + implicit_return_defaults = set(stmt.table.c) + else: + implicit_return_defaults = set(stmt._return_defaults) + postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid check_columns = {} + # special logic that only occurs for multi-table UPDATE # statements - if extra_tables and stmt_parameters: + if self.isupdate and stmt._extra_froms and stmt_parameters: normalized_params = dict( - (sql._clause_element_as_expr(c), param) + (elements._clause_element_as_expr(c), param) for c, param in stmt_parameters.items() ) - assert self.isupdate affected_tables = set() - for t in extra_tables: + for t in stmt._extra_froms: for c in t.c: if c in normalized_params: affected_tables.add(t) - check_columns[c.key] = c + check_columns[_getattr_col_key(c)] = c value = normalized_params[c] - if sql._is_literal(value): + if elements._is_literal(value): value = self._create_crud_bind_param( - c, value, required=value is REQUIRED) + c, value, required=value is REQUIRED, + name=_col_bind_name(c)) else: self.postfetch.append(c) - value = self.process(value.self_group()) + value = self.process(value.self_group(), **kw) values.append((c, value)) # determine tables which are actually # to be updated - process onupdate and @@ -1791,36 +1997,60 @@ class SQLCompiler(engine.Compiled): elif c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: values.append( - (c, self.process(c.onupdate.arg.self_group())) + (c, self.process( + c.onupdate.arg.self_group(), + **kw) + ) ) self.postfetch.append(c) else: values.append( - (c, self._create_crud_bind_param(c, None)) + (c, self._create_crud_bind_param( + c, None, name=_col_bind_name(c) + ) + ) ) self.prefetch.append(c) elif c.server_onupdate is not None: self.postfetch.append(c) - # iterating through columns at the top to maintain ordering. - # otherwise we might iterate through individual sets of - # "defaults", "primary key cols", etc. - for c in stmt.table.columns: - if c.key in parameters and c.key not in check_columns: - value = parameters.pop(c.key) - if sql._is_literal(value): + if self.isinsert and stmt.select_names: + # for an insert from select, we can only use names that + # are given, so only select for those names. + cols = (stmt.table.c[_column_as_key(name)] + for name in stmt.select_names) + else: + # iterate through all table columns to maintain + # ordering, even for those cols that aren't included + cols = stmt.table.columns + + for c in cols: + col_key = _getattr_col_key(c) + if col_key in parameters and col_key not in check_columns: + value = parameters.pop(col_key) + if elements._is_literal(value): value = self._create_crud_bind_param( c, value, required=value is REQUIRED, - name=c.key + name=_col_bind_name(c) if not stmt._has_multi_parameters - else "%s_0" % c.key + else "%s_0" % _col_bind_name(c) ) - elif c.primary_key and implicit_returning: - self.returning.append(c) - value = self.process(value.self_group()) else: - self.postfetch.append(c) - value = self.process(value.self_group()) + if isinstance(value, elements.BindParameter) and \ + value.type._isnull: + value = value._clone() + value.type = c.type + + if c.primary_key and implicit_returning: + self.returning.append(c) + value = self.process(value.self_group(), **kw) + elif implicit_return_defaults and \ + c in implicit_return_defaults: + self.returning.append(c) + value = self.process(value.self_group(), **kw) + else: + self.postfetch.append(c) + value = self.process(value.self_group(), **kw) values.append((c, value)) elif self.isinsert: @@ -1838,13 +2068,13 @@ class SQLCompiler(engine.Compiled): if self.dialect.supports_sequences and \ (not c.default.optional or \ not self.dialect.sequences_optional): - proc = self.process(c.default) + proc = self.process(c.default, **kw) values.append((c, proc)) self.returning.append(c) elif c.default.is_clause_element: values.append( (c, - self.process(c.default.arg.self_group())) + self.process(c.default.arg.self_group(), **kw)) ) self.returning.append(c) else: @@ -1855,7 +2085,13 @@ class SQLCompiler(engine.Compiled): else: self.returning.append(c) else: - if c.default is not None or \ + if ( + c.default is not None and + ( + not c.default.is_sequence or + self.dialect.supports_sequences + ) + ) or \ c is stmt.table._autoincrement_column and ( self.dialect.supports_sequences or self.dialect.preexecute_autoincrement_sequences @@ -1872,16 +2108,22 @@ class SQLCompiler(engine.Compiled): if self.dialect.supports_sequences and \ (not c.default.optional or \ not self.dialect.sequences_optional): - proc = self.process(c.default) + proc = self.process(c.default, **kw) values.append((c, proc)) - if not c.primary_key: + if implicit_return_defaults and \ + c in implicit_return_defaults: + self.returning.append(c) + elif not c.primary_key: self.postfetch.append(c) elif c.default.is_clause_element: values.append( - (c, self.process(c.default.arg.self_group())) + (c, self.process(c.default.arg.self_group(), **kw)) ) - if not c.primary_key: + if implicit_return_defaults and \ + c in implicit_return_defaults: + self.returning.append(c) + elif not c.primary_key: # dont add primary key column to postfetch self.postfetch.append(c) else: @@ -1890,32 +2132,49 @@ class SQLCompiler(engine.Compiled): ) self.prefetch.append(c) elif c.server_default is not None: - if not c.primary_key: + if implicit_return_defaults and \ + c in implicit_return_defaults: + self.returning.append(c) + elif not c.primary_key: self.postfetch.append(c) + elif implicit_return_defaults and \ + c in implicit_return_defaults: + self.returning.append(c) elif self.isupdate: if c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: values.append( - (c, self.process(c.onupdate.arg.self_group())) + (c, self.process(c.onupdate.arg.self_group(), **kw)) ) - self.postfetch.append(c) + if implicit_return_defaults and \ + c in implicit_return_defaults: + self.returning.append(c) + else: + self.postfetch.append(c) else: values.append( (c, self._create_crud_bind_param(c, None)) ) self.prefetch.append(c) elif c.server_onupdate is not None: - self.postfetch.append(c) + if implicit_return_defaults and \ + c in implicit_return_defaults: + self.returning.append(c) + else: + self.postfetch.append(c) + elif implicit_return_defaults and \ + c in implicit_return_defaults: + self.returning.append(c) if parameters and stmt_parameters: check = set(parameters).intersection( - sql._column_as_key(k) for k in stmt.parameters + _column_as_key(k) for k in stmt.parameters ).difference(check_columns) if check: raise exc.CompileError( "Unconsumed column names: %s" % - (", ".join(check)) + (", ".join("%s" % c for c in check)) ) if stmt._has_multi_parameters: @@ -1924,17 +2183,17 @@ class SQLCompiler(engine.Compiled): values.extend( [ - ( - c, - self._create_crud_bind_param( - c, row[c.key], - name="%s_%d" % (c.key, i + 1) - ) - if c.key in row else param - ) - for (c, param) in values_0 - ] - for i, row in enumerate(stmt.parameters[1:]) + ( + c, + self._create_crud_bind_param( + c, row[c.key], + name="%s_%d" % (c.key, i + 1) + ) + if c.key in row else param + ) + for (c, param) in values_0 + ] + for i, row in enumerate(stmt.parameters[1:]) ) return values @@ -2005,7 +2264,7 @@ class SQLCompiler(engine.Compiled): self.preparer.format_savepoint(savepoint_stmt) -class DDLCompiler(engine.Compiled): +class DDLCompiler(Compiled): @util.memoized_property def sql_compiler(self): @@ -2042,11 +2301,11 @@ class DDLCompiler(engine.Compiled): return self.sql_compiler.post_process_text(ddl.statement % context) def visit_create_schema(self, create): - schema = self.preparer.format_schema(create.element, create.quote) + schema = self.preparer.format_schema(create.element) return "CREATE SCHEMA " + schema def visit_drop_schema(self, drop): - schema = self.preparer.format_schema(drop.element, drop.quote) + schema = self.preparer.format_schema(drop.element) text = "DROP SCHEMA " + schema if drop.cascade: text += " CASCADE" @@ -2068,11 +2327,13 @@ class DDLCompiler(engine.Compiled): for create_column in create.columns: column = create_column.element try: - text += separator - separator = ", \n" - text += "\t" + self.process(create_column, + processed = self.process(create_column, first_pk=column.primary_key and not first_pk) + if processed is not None: + text += separator + separator = ", \n" + text += "\t" + processed if column.primary_key: first_pk = True except exc.CompileError as ce: @@ -2093,6 +2354,9 @@ class DDLCompiler(engine.Compiled): def visit_create_column(self, create, first_pk=False): column = create.element + if column.system: + return None + text = self.get_column_specification( column, first_pk=first_pk @@ -2156,7 +2420,7 @@ class DDLCompiler(engine.Compiled): use_schema=include_table_schema), ', '.join( self.sql_compiler.process(expr, - include_table=False) for + include_table=False, literal_binds=True) for expr in index.expressions) ) return text @@ -2169,13 +2433,12 @@ class DDLCompiler(engine.Compiled): def _prepared_index_name(self, index, include_schema=False): if include_schema and index.table is not None and index.table.schema: schema = index.table.schema - schema_name = self.preparer.quote_schema(schema, - index.table.quote_schema) + schema_name = self.preparer.quote_schema(schema) else: schema_name = None ident = index.name - if isinstance(ident, sql._truncated_label): + if isinstance(ident, elements._truncated_label): max_ = self.dialect.max_index_name_length or \ self.dialect.max_identifier_length if len(ident) > max_: @@ -2184,9 +2447,7 @@ class DDLCompiler(engine.Compiled): else: self.dialect.validate_identifier(ident) - index_name = self.preparer.quote( - ident, - index.quote) + index_name = self.preparer.quote(ident) if schema_name: index_name = schema_name + "." + index_name @@ -2246,8 +2507,9 @@ class DDLCompiler(engine.Compiled): if constraint.name is not None: text += "CONSTRAINT %s " % \ self.preparer.format_constraint(constraint) - sqltext = sql_util.expression_as_ddl(constraint.sqltext) - text += "CHECK (%s)" % self.sql_compiler.process(sqltext) + text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext, + include_table=False, + literal_binds=True) text += self.define_constraint_deferrability(constraint) return text @@ -2268,7 +2530,7 @@ class DDLCompiler(engine.Compiled): text += "CONSTRAINT %s " % \ self.preparer.format_constraint(constraint) text += "PRIMARY KEY " - text += "(%s)" % ', '.join(self.preparer.quote(c.name, c.quote) + text += "(%s)" % ', '.join(self.preparer.quote(c.name) for c in constraint) text += self.define_constraint_deferrability(constraint) return text @@ -2281,11 +2543,11 @@ class DDLCompiler(engine.Compiled): preparer.format_constraint(constraint) remote_table = list(constraint._elements.values())[0].column.table text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % ( - ', '.join(preparer.quote(f.parent.name, f.parent.quote) + ', '.join(preparer.quote(f.parent.name) for f in constraint._elements.values()), self.define_constraint_remote_table( constraint, remote_table, preparer), - ', '.join(preparer.quote(f.column.name, f.column.quote) + ', '.join(preparer.quote(f.column.name) for f in constraint._elements.values()) ) text += self.define_constraint_match(constraint) @@ -2299,12 +2561,14 @@ class DDLCompiler(engine.Compiled): return preparer.format_table(table) def visit_unique_constraint(self, constraint): + if len(constraint) == 0: + return '' text = "" if constraint.name is not None: text += "CONSTRAINT %s " % \ self.preparer.format_constraint(constraint) text += "UNIQUE (%s)" % ( - ', '.join(self.preparer.quote(c.name, c.quote) + ', '.join(self.preparer.quote(c.name) for c in constraint)) text += self.define_constraint_deferrability(constraint) return text @@ -2335,7 +2599,7 @@ class DDLCompiler(engine.Compiled): return text -class GenericTypeCompiler(engine.TypeCompiler): +class GenericTypeCompiler(TypeCompiler): def visit_FLOAT(self, type_): return "FLOAT" @@ -2558,15 +2822,25 @@ class IdentifierPreparer(object): or not self.legal_characters.match(util.text_type(value)) or (lc_value != value)) - def quote_schema(self, schema, force): - """Quote a schema. + def quote_schema(self, schema, force=None): + """Conditionally quote a schema. + + Subclasses can override this to provide database-dependent + quoting behavior for schema names. + + the 'force' flag should be considered deprecated. - Subclasses should override this to provide database-dependent - quoting behavior. """ return self.quote(schema, force) - def quote(self, ident, force): + def quote(self, ident, force=None): + """Conditionally quote an identifier. + + the 'force' flag should be considered deprecated. + """ + + force = getattr(ident, "quote", None) + if force is None: if ident in self._strings: return self._strings[ident] @@ -2582,38 +2856,35 @@ class IdentifierPreparer(object): return ident def format_sequence(self, sequence, use_schema=True): - name = self.quote(sequence.name, sequence.quote) - if not self.omit_schema and use_schema and \ - sequence.schema is not None: - name = self.quote_schema(sequence.schema, sequence.quote) + \ - "." + name + name = self.quote(sequence.name) + if not self.omit_schema and use_schema and sequence.schema is not None: + name = self.quote_schema(sequence.schema) + "." + name return name def format_label(self, label, name=None): - return self.quote(name or label.name, label.quote) + return self.quote(name or label.name) def format_alias(self, alias, name=None): - return self.quote(name or alias.name, alias.quote) + return self.quote(name or alias.name) def format_savepoint(self, savepoint, name=None): - return self.quote(name or savepoint.ident, savepoint.quote) + return self.quote(name or savepoint.ident) def format_constraint(self, constraint): - return self.quote(constraint.name, constraint.quote) + return self.quote(constraint.name) def format_table(self, table, use_schema=True, name=None): """Prepare a quoted table and schema name.""" if name is None: name = table.name - result = self.quote(name, table.quote) + result = self.quote(name) if not self.omit_schema and use_schema \ and getattr(table, "schema", None): - result = self.quote_schema(table.schema, table.quote_schema) + \ - "." + result + result = self.quote_schema(table.schema) + "." + result return result - def format_schema(self, name, quote): + def format_schema(self, name, quote=None): """Prepare a quoted schema name.""" return self.quote(name, quote) @@ -2628,10 +2899,9 @@ class IdentifierPreparer(object): if use_table: return self.format_table( column.table, use_schema=False, - name=table_name) + "." + \ - self.quote(name, column.quote) + name=table_name) + "." + self.quote(name) else: - return self.quote(name, column.quote) + return self.quote(name) else: # literal textual elements get stuck into ColumnClause a lot, # which shouldn't get quoted @@ -2651,7 +2921,7 @@ class IdentifierPreparer(object): if not self.omit_schema and use_schema and \ getattr(table, 'schema', None): - return (self.quote_schema(table.schema, table.quote_schema), + return (self.quote_schema(table.schema), self.format_table(table, use_schema=False)) else: return (self.format_table(table, use_schema=False), ) |