summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py714
1 files changed, 492 insertions, 222 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index a5f545de9..4448f7c7b 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1,5 +1,5 @@
# sql/compiler.py
-# Copyright (C) 2005-2013 the SQLAlchemy authors and contributors <see AUTHORS file>
+# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -23,13 +23,12 @@ To generate user-defined SQL strings, see
"""
import re
-import sys
-from .. import schema, engine, util, exc, types
-from . import (
- operators, functions, util as sql_util, visitors, expression as sql
-)
+from . import schema, sqltypes, operators, functions, \
+ util as sql_util, visitors, elements, selectable, base
+from .. import util, exc
import decimal
import itertools
+import operator
RESERVED_WORDS = set([
'all', 'analyse', 'analyze', 'and', 'any', 'array',
@@ -115,6 +114,7 @@ OPERATORS = {
operators.asc_op: ' ASC',
operators.nullsfirst_op: ' NULLS FIRST',
operators.nullslast_op: ' NULLS LAST',
+
}
FUNCTIONS = {
@@ -150,14 +150,122 @@ EXTRACT_MAP = {
}
COMPOUND_KEYWORDS = {
- sql.CompoundSelect.UNION: 'UNION',
- sql.CompoundSelect.UNION_ALL: 'UNION ALL',
- sql.CompoundSelect.EXCEPT: 'EXCEPT',
- sql.CompoundSelect.EXCEPT_ALL: 'EXCEPT ALL',
- sql.CompoundSelect.INTERSECT: 'INTERSECT',
- sql.CompoundSelect.INTERSECT_ALL: 'INTERSECT ALL'
+ selectable.CompoundSelect.UNION: 'UNION',
+ selectable.CompoundSelect.UNION_ALL: 'UNION ALL',
+ selectable.CompoundSelect.EXCEPT: 'EXCEPT',
+ selectable.CompoundSelect.EXCEPT_ALL: 'EXCEPT ALL',
+ selectable.CompoundSelect.INTERSECT: 'INTERSECT',
+ selectable.CompoundSelect.INTERSECT_ALL: 'INTERSECT ALL'
}
+class Compiled(object):
+ """Represent a compiled SQL or DDL expression.
+
+ The ``__str__`` method of the ``Compiled`` object should produce
+ the actual text of the statement. ``Compiled`` objects are
+ specific to their underlying database dialect, and also may
+ or may not be specific to the columns referenced within a
+ particular set of bind parameters. In no case should the
+ ``Compiled`` object be dependent on the actual values of those
+ bind parameters, even though it may reference those values as
+ defaults.
+ """
+
+ def __init__(self, dialect, statement, bind=None,
+ compile_kwargs=util.immutabledict()):
+ """Construct a new ``Compiled`` object.
+
+ :param dialect: ``Dialect`` to compile against.
+
+ :param statement: ``ClauseElement`` to be compiled.
+
+ :param bind: Optional Engine or Connection to compile this
+ statement against.
+
+ :param compile_kwargs: additional kwargs that will be
+ passed to the initial call to :meth:`.Compiled.process`.
+
+ .. versionadded:: 0.8
+
+ """
+
+ self.dialect = dialect
+ self.bind = bind
+ if statement is not None:
+ self.statement = statement
+ self.can_execute = statement.supports_execution
+ self.string = self.process(self.statement, **compile_kwargs)
+
+ @util.deprecated("0.7", ":class:`.Compiled` objects now compile "
+ "within the constructor.")
+ def compile(self):
+ """Produce the internal string representation of this element.
+ """
+ pass
+
+ def _execute_on_connection(self, connection, multiparams, params):
+ return connection._execute_compiled(self, multiparams, params)
+
+ @property
+ def sql_compiler(self):
+ """Return a Compiled that is capable of processing SQL expressions.
+
+ If this compiler is one, it would likely just return 'self'.
+
+ """
+
+ raise NotImplementedError()
+
+ def process(self, obj, **kwargs):
+ return obj._compiler_dispatch(self, **kwargs)
+
+ def __str__(self):
+ """Return the string text of the generated SQL or DDL."""
+
+ return self.string or ''
+
+ def construct_params(self, params=None):
+ """Return the bind params for this compiled object.
+
+ :param params: a dict of string/object pairs whose values will
+ override bind values compiled in to the
+ statement.
+ """
+
+ raise NotImplementedError()
+
+ @property
+ def params(self):
+ """Return the bind params for this compiled object."""
+ return self.construct_params()
+
+ def execute(self, *multiparams, **params):
+ """Execute this compiled object."""
+
+ e = self.bind
+ if e is None:
+ raise exc.UnboundExecutionError(
+ "This Compiled object is not bound to any Engine "
+ "or Connection.")
+ return e._execute_compiled(self, multiparams, params)
+
+ def scalar(self, *multiparams, **params):
+ """Execute this compiled object and return the result's
+ scalar value."""
+
+ return self.execute(*multiparams, **params).scalar()
+
+
+class TypeCompiler(object):
+ """Produces DDL specification for TypeEngine objects."""
+
+ def __init__(self, dialect):
+ self.dialect = dialect
+
+ def process(self, type_):
+ return type_._compiler_dispatch(self)
+
+
class _CompileLabel(visitors.Visitable):
"""lightweight label object which acts as an expression.Label."""
@@ -178,12 +286,8 @@ class _CompileLabel(visitors.Visitable):
def type(self):
return self.element.type
- @property
- def quote(self):
- return self.element.quote
-
-class SQLCompiler(engine.Compiled):
+class SQLCompiler(Compiled):
"""Default implementation of Compiled.
Compiles ClauseElements into SQL strings. Uses a similar visit
@@ -284,7 +388,7 @@ class SQLCompiler(engine.Compiled):
# a map which tracks "truncated" names based on
# dialect.label_length or dialect.max_identifier_length
self.truncated_names = {}
- engine.Compiled.__init__(self, dialect, statement, **kwargs)
+ Compiled.__init__(self, dialect, statement, **kwargs)
if self.positional and dialect.paramstyle == 'numeric':
self._apply_numbered_params()
@@ -397,7 +501,7 @@ class SQLCompiler(engine.Compiled):
render_label_only = render_label_as_label is label
if render_label_only or render_label_with_as:
- if isinstance(label.name, sql._truncated_label):
+ if isinstance(label.name, elements._truncated_label):
labelname = self._truncated_identifier("colident", label.name)
else:
labelname = label.name
@@ -432,7 +536,7 @@ class SQLCompiler(engine.Compiled):
"its 'name' is assigned.")
is_literal = column.is_literal
- if not is_literal and isinstance(name, sql._truncated_label):
+ if not is_literal and isinstance(name, elements._truncated_label):
name = self._truncated_identifier("colident", name)
if add_to_result_map is not None:
@@ -446,24 +550,22 @@ class SQLCompiler(engine.Compiled):
if is_literal:
name = self.escape_literal_column(name)
else:
- name = self.preparer.quote(name, column.quote)
+ name = self.preparer.quote(name)
table = column.table
if table is None or not include_table or not table.named_with_column:
return name
else:
if table.schema:
- schema_prefix = self.preparer.quote_schema(
- table.schema,
- table.quote_schema) + '.'
+ schema_prefix = self.preparer.quote_schema(table.schema) + '.'
else:
schema_prefix = ''
tablename = table.name
- if isinstance(tablename, sql._truncated_label):
+ if isinstance(tablename, elements._truncated_label):
tablename = self._truncated_identifier("alias", tablename)
return schema_prefix + \
- self.preparer.quote(tablename, table.quote) + \
+ self.preparer.quote(tablename) + \
"." + name
def escape_literal_column(self, text):
@@ -484,20 +586,13 @@ class SQLCompiler(engine.Compiled):
def post_process_text(self, text):
return text
- def visit_textclause(self, textclause, **kwargs):
- if textclause.typemap is not None:
- for colname, type_ in textclause.typemap.items():
- self.result_map[colname
- if self.dialect.case_sensitive
- else colname.lower()] = \
- (colname, None, type_)
-
+ def visit_textclause(self, textclause, **kw):
def do_bindparam(m):
name = m.group(1)
- if name in textclause.bindparams:
- return self.process(textclause.bindparams[name])
+ if name in textclause._bindparams:
+ return self.process(textclause._bindparams[name], **kw)
else:
- return self.bindparam_string(name, **kwargs)
+ return self.bindparam_string(name, **kw)
# un-escape any \:params
return BIND_PARAMS_ESC.sub(lambda m: m.group(1),
@@ -505,14 +600,47 @@ class SQLCompiler(engine.Compiled):
self.post_process_text(textclause.text))
)
+ def visit_text_as_from(self, taf, iswrapper=False,
+ compound_index=0, force_result_map=False,
+ asfrom=False,
+ parens=True, **kw):
+
+ toplevel = not self.stack
+ entry = self._default_stack_entry if toplevel else self.stack[-1]
+
+ populate_result_map = force_result_map or (
+ compound_index == 0 and (
+ toplevel or \
+ entry['iswrapper']
+ )
+ )
+
+ if populate_result_map:
+ for c in taf.c:
+ self._add_to_result_map(
+ c.key, c.key, (c,), c.type
+ )
+
+ text = self.process(taf.element, **kw)
+ if asfrom and parens:
+ text = "(%s)" % text
+ return text
+
+
def visit_null(self, expr, **kw):
return 'NULL'
def visit_true(self, expr, **kw):
- return 'true'
+ if self.dialect.supports_native_boolean:
+ return 'true'
+ else:
+ return "1"
def visit_false(self, expr, **kw):
- return 'false'
+ if self.dialect.supports_native_boolean:
+ return 'false'
+ else:
+ return "0"
def visit_clauselist(self, clauselist, order_by_select=None, **kw):
if order_by_select is not None:
@@ -619,6 +747,7 @@ class SQLCompiler(engine.Compiled):
def function_argspec(self, func, **kwargs):
return func.clause_expr._compiler_dispatch(self, **kwargs)
+
def visit_compound_select(self, cs, asfrom=False,
parens=True, compound_index=0, **kwargs):
toplevel = not self.stack
@@ -684,11 +813,23 @@ class SQLCompiler(engine.Compiled):
raise exc.CompileError(
"Unary expression has no operator or modifier")
+ def visit_istrue_unary_operator(self, element, operator, **kw):
+ if self.dialect.supports_native_boolean:
+ return self.process(element.element, **kw)
+ else:
+ return "%s = 1" % self.process(element.element, **kw)
+
+ def visit_isfalse_unary_operator(self, element, operator, **kw):
+ if self.dialect.supports_native_boolean:
+ return "NOT %s" % self.process(element.element, **kw)
+ else:
+ return "%s = 0" % self.process(element.element, **kw)
+
def visit_binary(self, binary, **kw):
# don't allow "? = ?" to render
if self.ansi_bind_rules and \
- isinstance(binary.left, sql.BindParameter) and \
- isinstance(binary.right, sql.BindParameter):
+ isinstance(binary.left, elements.BindParameter) and \
+ isinstance(binary.right, elements.BindParameter):
kw['literal_binds'] = True
operator = binary.operator
@@ -728,7 +869,7 @@ class SQLCompiler(engine.Compiled):
@util.memoized_property
def _like_percent_literal(self):
- return sql.literal_column("'%'", type_=types.String())
+ return elements.literal_column("'%'", type_=sqltypes.STRINGTYPE)
def visit_contains_op_binary(self, binary, operator, **kw):
binary = binary._clone()
@@ -772,39 +913,49 @@ class SQLCompiler(engine.Compiled):
def visit_like_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
+
+ # TODO: use ternary here, not "and"/ "or"
return '%s LIKE %s' % (
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw)) \
- + (escape and
- (' ESCAPE ' + self.render_literal_value(escape, None))
- or '')
+ + (
+ ' ESCAPE ' +
+ self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape else ''
+ )
def visit_notlike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
return '%s NOT LIKE %s' % (
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw)) \
- + (escape and
- (' ESCAPE ' + self.render_literal_value(escape, None))
- or '')
+ + (
+ ' ESCAPE ' +
+ self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape else ''
+ )
def visit_ilike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
return 'lower(%s) LIKE lower(%s)' % (
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw)) \
- + (escape and
- (' ESCAPE ' + self.render_literal_value(escape, None))
- or '')
+ + (
+ ' ESCAPE ' +
+ self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape else ''
+ )
def visit_notilike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
return 'lower(%s) NOT LIKE lower(%s)' % (
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw)) \
- + (escape and
- (' ESCAPE ' + self.render_literal_value(escape, None))
- or '')
+ + (
+ ' ESCAPE ' +
+ self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape else ''
+ )
def visit_bindparam(self, bindparam, within_columns_clause=False,
literal_binds=False,
@@ -820,8 +971,9 @@ class SQLCompiler(engine.Compiled):
(within_columns_clause and \
self.ansi_bind_rules):
if bindparam.value is None:
- raise exc.CompileError("Bind parameter without a "
- "renderable value not allowed here.")
+ raise exc.CompileError("Bind parameter '%s' without a "
+ "renderable value not allowed here."
+ % bindparam.key)
return self.render_literal_bindparam(bindparam,
within_columns_clause=True, **kwargs)
@@ -851,13 +1003,10 @@ class SQLCompiler(engine.Compiled):
self.binds[bindparam.key] = self.binds[name] = bindparam
- return self.bindparam_string(name, quote=bindparam.quote, **kwargs)
+ return self.bindparam_string(name, **kwargs)
def render_literal_bindparam(self, bindparam, **kw):
value = bindparam.value
- processor = bindparam.type._cached_bind_processor(self.dialect)
- if processor:
- value = processor(value)
return self.render_literal_value(value, bindparam.type)
def render_literal_value(self, value, type_):
@@ -870,15 +1019,10 @@ class SQLCompiler(engine.Compiled):
of the DBAPI.
"""
- if isinstance(value, util.string_types):
- value = value.replace("'", "''")
- return "'%s'" % value
- elif value is None:
- return "NULL"
- elif isinstance(value, (float, ) + util.int_types):
- return repr(value)
- elif isinstance(value, decimal.Decimal):
- return str(value)
+
+ processor = type_._cached_literal_processor(self.dialect)
+ if processor:
+ return processor(value)
else:
raise NotImplementedError(
"Don't know how to literal-quote value %r" % value)
@@ -888,7 +1032,7 @@ class SQLCompiler(engine.Compiled):
return self.bind_names[bindparam]
bind_name = bindparam.key
- if isinstance(bind_name, sql._truncated_label):
+ if isinstance(bind_name, elements._truncated_label):
bind_name = self._truncated_identifier("bindparam", bind_name)
# add to bind_names for translation
@@ -921,8 +1065,7 @@ class SQLCompiler(engine.Compiled):
self.anon_map[derived] = anonymous_counter + 1
return derived + "_" + str(anonymous_counter)
- def bindparam_string(self, name, quote=None,
- positional_names=None, **kw):
+ def bindparam_string(self, name, positional_names=None, **kw):
if self.positional:
if positional_names is not None:
positional_names.append(name)
@@ -937,7 +1080,7 @@ class SQLCompiler(engine.Compiled):
if self.positional:
kwargs['positional_names'] = self.cte_positional
- if isinstance(cte.name, sql._truncated_label):
+ if isinstance(cte.name, elements._truncated_label):
cte_name = self._truncated_identifier("alias", cte.name)
else:
cte_name = cte.name
@@ -947,7 +1090,7 @@ class SQLCompiler(engine.Compiled):
# we've generated a same-named CTE that we are enclosed in,
# or this is the same CTE. just return the name.
if cte in existing_cte._restates or cte is existing_cte:
- return cte_name
+ return self.preparer.format_alias(cte, cte_name)
elif existing_cte in cte._restates:
# we've generated a same-named CTE that is
# enclosed in us - we take precedence, so
@@ -961,19 +1104,24 @@ class SQLCompiler(engine.Compiled):
self.ctes_by_name[cte_name] = cte
- if cte.cte_alias:
- if isinstance(cte.cte_alias, sql._truncated_label):
- cte_alias = self._truncated_identifier("alias", cte.cte_alias)
- else:
- cte_alias = cte.cte_alias
- if not cte.cte_alias and cte not in self.ctes:
+ if cte._cte_alias is not None:
+ orig_cte = cte._cte_alias
+ if orig_cte not in self.ctes:
+ self.visit_cte(orig_cte)
+ cte_alias_name = cte._cte_alias.name
+ if isinstance(cte_alias_name, elements._truncated_label):
+ cte_alias_name = self._truncated_identifier("alias", cte_alias_name)
+ else:
+ orig_cte = cte
+ cte_alias_name = None
+ if not cte_alias_name and cte not in self.ctes:
if cte.recursive:
self.ctes_recursive = True
text = self.preparer.format_alias(cte, cte_name)
if cte.recursive:
- if isinstance(cte.original, sql.Select):
+ if isinstance(cte.original, selectable.Select):
col_source = cte.original
- elif isinstance(cte.original, sql.CompoundSelect):
+ elif isinstance(cte.original, selectable.CompoundSelect):
col_source = cte.original.selects[0]
else:
assert False
@@ -989,9 +1137,10 @@ class SQLCompiler(engine.Compiled):
self, asfrom=True, **kwargs
)
self.ctes[cte] = text
+
if asfrom:
- if cte.cte_alias:
- text = self.preparer.format_alias(cte, cte_alias)
+ if cte_alias_name:
+ text = self.preparer.format_alias(cte, cte_alias_name)
text += " AS " + cte_name
else:
return self.preparer.format_alias(cte, cte_name)
@@ -1001,7 +1150,7 @@ class SQLCompiler(engine.Compiled):
iscrud=False,
fromhints=None, **kwargs):
if asfrom or ashint:
- if isinstance(alias.name, sql._truncated_label):
+ if isinstance(alias.name, elements._truncated_label):
alias_name = self._truncated_identifier("alias", alias.name)
else:
alias_name = alias.name
@@ -1059,7 +1208,7 @@ class SQLCompiler(engine.Compiled):
if not within_columns_clause:
result_expr = col_expr
- elif isinstance(column, sql.Label):
+ elif isinstance(column, elements.Label):
if col_expr is not column:
result_expr = _CompileLabel(
col_expr,
@@ -1078,23 +1227,23 @@ class SQLCompiler(engine.Compiled):
elif \
asfrom and \
- isinstance(column, sql.ColumnClause) and \
+ isinstance(column, elements.ColumnClause) and \
not column.is_literal and \
column.table is not None and \
- not isinstance(column.table, sql.Select):
+ not isinstance(column.table, selectable.Select):
result_expr = _CompileLabel(col_expr,
- sql._as_truncated(column.name),
+ elements._as_truncated(column.name),
alt_names=(column.key,))
elif not isinstance(column,
- (sql.UnaryExpression, sql.TextClause)) \
+ (elements.UnaryExpression, elements.TextClause)) \
and (not hasattr(column, 'name') or \
- isinstance(column, sql.Function)):
+ isinstance(column, functions.Function)):
result_expr = _CompileLabel(col_expr, column.anon_label)
elif col_expr is not column:
# TODO: are we sure "column" has a .name and .key here ?
- # assert isinstance(column, sql.ColumnClause)
+ # assert isinstance(column, elements.ColumnClause)
result_expr = _CompileLabel(col_expr,
- sql._as_truncated(column.name),
+ elements._as_truncated(column.name),
alt_names=(column.key,))
else:
result_expr = col_expr
@@ -1137,8 +1286,8 @@ class SQLCompiler(engine.Compiled):
# as this whole system won't work for custom Join/Select
# subclasses where compilation routines
# call down to compiler.visit_join(), compiler.visit_select()
- join_name = sql.Join.__visit_name__
- select_name = sql.Select.__visit_name__
+ join_name = selectable.Join.__visit_name__
+ select_name = selectable.Select.__visit_name__
def visit(element, **kw):
if element in column_translate[-1]:
@@ -1150,24 +1299,27 @@ class SQLCompiler(engine.Compiled):
newelem = cloned[element] = element._clone()
if newelem.__visit_name__ is join_name and \
- isinstance(newelem.right, sql.FromGrouping):
+ isinstance(newelem.right, selectable.FromGrouping):
newelem._reset_exported()
newelem.left = visit(newelem.left, **kw)
right = visit(newelem.right, **kw)
- selectable = sql.select(
+ selectable_ = selectable.Select(
[right.element],
use_labels=True).alias()
- for c in selectable.c:
- c._label = c._key_label = c.name
+ for c in selectable_.c:
+ c._key_label = c.key
+ c._label = c.name
+
translate_dict = dict(
- zip(right.element.c, selectable.c)
- )
- translate_dict[right.element.left] = selectable
- translate_dict[right.element.right] = selectable
+ zip(newelem.right.element.c, selectable_.c)
+ )
+
+ translate_dict[right.element.left] = selectable_
+ translate_dict[right.element.right] = selectable_
# propagate translations that we've gained
# from nested visit(newelem.right) outwards
@@ -1183,7 +1335,8 @@ class SQLCompiler(engine.Compiled):
column_translate[-1].update(translate_dict)
- newelem.right = selectable
+ newelem.right = selectable_
+
newelem.onclause = visit(newelem.onclause, **kw)
elif newelem.__visit_name__ is select_name:
column_translate.append({})
@@ -1199,6 +1352,7 @@ class SQLCompiler(engine.Compiled):
def _transform_result_map_for_nested_joins(self, select, transformed_select):
inner_col = dict((c._key_label, c) for
c in transformed_select.inner_columns)
+
d = dict(
(inner_col[c._key_label], c)
for c in select.inner_columns
@@ -1291,7 +1445,7 @@ class SQLCompiler(engine.Compiled):
explicit_correlate_froms=correlate_froms,
implicit_correlate_froms=asfrom_froms)
- new_correlate_froms = set(sql._from_objects(*froms))
+ new_correlate_froms = set(selectable._from_objects(*froms))
all_correlate_froms = new_correlate_froms.union(correlate_froms)
new_entry = {
@@ -1382,9 +1536,11 @@ class SQLCompiler(engine.Compiled):
text += self.order_by_clause(select,
order_by_select=order_by_select, **kwargs)
+
if select._limit is not None or select._offset is not None:
text += self.limit_clause(select)
- if select.for_update:
+
+ if select._for_update_arg is not None:
text += self.for_update_clause(select)
if self.ctes and \
@@ -1440,10 +1596,7 @@ class SQLCompiler(engine.Compiled):
return ""
def for_update_clause(self, select):
- if select.for_update:
- return " FOR UPDATE"
- else:
- return ""
+ return " FOR UPDATE"
def returning_clause(self, stmt, returning_cols):
raise exc.CompileError(
@@ -1453,23 +1606,21 @@ class SQLCompiler(engine.Compiled):
def limit_clause(self, select):
text = ""
if select._limit is not None:
- text += "\n LIMIT " + self.process(sql.literal(select._limit))
+ text += "\n LIMIT " + self.process(elements.literal(select._limit))
if select._offset is not None:
if select._limit is None:
text += "\n LIMIT -1"
- text += " OFFSET " + self.process(sql.literal(select._offset))
+ text += " OFFSET " + self.process(elements.literal(select._offset))
return text
def visit_table(self, table, asfrom=False, iscrud=False, ashint=False,
fromhints=None, **kwargs):
if asfrom or ashint:
if getattr(table, "schema", None):
- ret = self.preparer.quote_schema(table.schema,
- table.quote_schema) + \
- "." + self.preparer.quote(table.name,
- table.quote)
+ ret = self.preparer.quote_schema(table.schema) + \
+ "." + self.preparer.quote(table.name)
else:
- ret = self.preparer.quote(table.name, table.quote)
+ ret = self.preparer.quote(table.name)
if fromhints and table in fromhints:
ret = self.format_from_hint_text(ret, table,
fromhints[table], iscrud)
@@ -1488,7 +1639,7 @@ class SQLCompiler(engine.Compiled):
def visit_insert(self, insert_stmt, **kw):
self.isinsert = True
- colparams = self._get_colparams(insert_stmt)
+ colparams = self._get_colparams(insert_stmt, **kw)
if not colparams and \
not self.dialect.supports_default_values and \
@@ -1621,7 +1772,7 @@ class SQLCompiler(engine.Compiled):
table_text = self.update_tables_clause(update_stmt, update_stmt.table,
extra_froms, **kw)
- colparams = self._get_colparams(update_stmt, extra_froms)
+ colparams = self._get_colparams(update_stmt, **kw)
if update_stmt._hints:
dialect_hints = dict([
@@ -1651,11 +1802,12 @@ class SQLCompiler(engine.Compiled):
'=' + c[1] for c in colparams
)
- if update_stmt._returning:
- self.returning = update_stmt._returning
+ if self.returning or update_stmt._returning:
+ if not self.returning:
+ self.returning = update_stmt._returning
if self.returning_precedes_values:
text += " " + self.returning_clause(
- update_stmt, update_stmt._returning)
+ update_stmt, self.returning)
if extra_froms:
extra_from_text = self.update_from_clause(
@@ -1675,7 +1827,7 @@ class SQLCompiler(engine.Compiled):
if self.returning and not self.returning_precedes_values:
text += " " + self.returning_clause(
- update_stmt, update_stmt._returning)
+ update_stmt, self.returning)
self.stack.pop(-1)
@@ -1684,13 +1836,45 @@ class SQLCompiler(engine.Compiled):
def _create_crud_bind_param(self, col, value, required=False, name=None):
if name is None:
name = col.key
- bindparam = sql.bindparam(name, value,
- type_=col.type, required=required,
- quote=col.quote)
+ bindparam = elements.BindParameter(name, value,
+ type_=col.type, required=required)
bindparam._is_crud = True
return bindparam._compiler_dispatch(self)
- def _get_colparams(self, stmt, extra_tables=None):
+ @util.memoized_property
+ def _key_getters_for_crud_column(self):
+ if self.isupdate and self.statement._extra_froms:
+ # when extra tables are present, refer to the columns
+ # in those extra tables as table-qualified, including in
+ # dictionaries and when rendering bind param names.
+ # the "main" table of the statement remains unqualified,
+ # allowing the most compatibility with a non-multi-table
+ # statement.
+ _et = set(self.statement._extra_froms)
+ def _column_as_key(key):
+ str_key = elements._column_as_key(key)
+ if hasattr(key, 'table') and key.table in _et:
+ return (key.table.name, str_key)
+ else:
+ return str_key
+ def _getattr_col_key(col):
+ if col.table in _et:
+ return (col.table.name, col.key)
+ else:
+ return col.key
+ def _col_bind_name(col):
+ if col.table in _et:
+ return "%s_%s" % (col.table.name, col.key)
+ else:
+ return col.key
+
+ else:
+ _column_as_key = elements._column_as_key
+ _getattr_col_key = _col_bind_name = operator.attrgetter("key")
+
+ return _column_as_key, _getattr_col_key, _col_bind_name
+
+ def _get_colparams(self, stmt, **kw):
"""create a set of tuples representing column/string pairs for use
in an INSERT or UPDATE statement.
@@ -1719,12 +1903,18 @@ class SQLCompiler(engine.Compiled):
else:
stmt_parameters = stmt.parameters
+ # getters - these are normally just column.key,
+ # but in the case of mysql multi-table update, the rules for
+ # .key must conditionally take tablename into account
+ _column_as_key, _getattr_col_key, _col_bind_name = \
+ self._key_getters_for_crud_column
+
# if we have statement parameters - set defaults in the
# compiled params
if self.column_keys is None:
parameters = {}
else:
- parameters = dict((sql._column_as_key(key), REQUIRED)
+ parameters = dict((_column_as_key(key), REQUIRED)
for key in self.column_keys
if not stmt_parameters or
key not in stmt_parameters)
@@ -1734,17 +1924,19 @@ class SQLCompiler(engine.Compiled):
if stmt_parameters is not None:
for k, v in stmt_parameters.items():
- colkey = sql._column_as_key(k)
+ colkey = _column_as_key(k)
if colkey is not None:
parameters.setdefault(colkey, v)
else:
# a non-Column expression on the left side;
# add it to values() in an "as-is" state,
# coercing right side to bound param
- if sql._is_literal(v):
- v = self.process(sql.bindparam(None, v, type_=k.type))
+ if elements._is_literal(v):
+ v = self.process(
+ elements.BindParameter(None, v, type_=k.type),
+ **kw)
else:
- v = self.process(v.self_group())
+ v = self.process(v.self_group(), **kw)
values.append((k, v))
@@ -1756,30 +1948,44 @@ class SQLCompiler(engine.Compiled):
self.dialect.implicit_returning and \
stmt.table.implicit_returning
+ if self.isinsert:
+ implicit_return_defaults = implicit_returning and stmt._return_defaults
+ elif self.isupdate:
+ implicit_return_defaults = self.dialect.implicit_returning and \
+ stmt.table.implicit_returning and \
+ stmt._return_defaults
+
+ if implicit_return_defaults:
+ if stmt._return_defaults is True:
+ implicit_return_defaults = set(stmt.table.c)
+ else:
+ implicit_return_defaults = set(stmt._return_defaults)
+
postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid
check_columns = {}
+
# special logic that only occurs for multi-table UPDATE
# statements
- if extra_tables and stmt_parameters:
+ if self.isupdate and stmt._extra_froms and stmt_parameters:
normalized_params = dict(
- (sql._clause_element_as_expr(c), param)
+ (elements._clause_element_as_expr(c), param)
for c, param in stmt_parameters.items()
)
- assert self.isupdate
affected_tables = set()
- for t in extra_tables:
+ for t in stmt._extra_froms:
for c in t.c:
if c in normalized_params:
affected_tables.add(t)
- check_columns[c.key] = c
+ check_columns[_getattr_col_key(c)] = c
value = normalized_params[c]
- if sql._is_literal(value):
+ if elements._is_literal(value):
value = self._create_crud_bind_param(
- c, value, required=value is REQUIRED)
+ c, value, required=value is REQUIRED,
+ name=_col_bind_name(c))
else:
self.postfetch.append(c)
- value = self.process(value.self_group())
+ value = self.process(value.self_group(), **kw)
values.append((c, value))
# determine tables which are actually
# to be updated - process onupdate and
@@ -1791,36 +1997,60 @@ class SQLCompiler(engine.Compiled):
elif c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
values.append(
- (c, self.process(c.onupdate.arg.self_group()))
+ (c, self.process(
+ c.onupdate.arg.self_group(),
+ **kw)
+ )
)
self.postfetch.append(c)
else:
values.append(
- (c, self._create_crud_bind_param(c, None))
+ (c, self._create_crud_bind_param(
+ c, None, name=_col_bind_name(c)
+ )
+ )
)
self.prefetch.append(c)
elif c.server_onupdate is not None:
self.postfetch.append(c)
- # iterating through columns at the top to maintain ordering.
- # otherwise we might iterate through individual sets of
- # "defaults", "primary key cols", etc.
- for c in stmt.table.columns:
- if c.key in parameters and c.key not in check_columns:
- value = parameters.pop(c.key)
- if sql._is_literal(value):
+ if self.isinsert and stmt.select_names:
+ # for an insert from select, we can only use names that
+ # are given, so only select for those names.
+ cols = (stmt.table.c[_column_as_key(name)]
+ for name in stmt.select_names)
+ else:
+ # iterate through all table columns to maintain
+ # ordering, even for those cols that aren't included
+ cols = stmt.table.columns
+
+ for c in cols:
+ col_key = _getattr_col_key(c)
+ if col_key in parameters and col_key not in check_columns:
+ value = parameters.pop(col_key)
+ if elements._is_literal(value):
value = self._create_crud_bind_param(
c, value, required=value is REQUIRED,
- name=c.key
+ name=_col_bind_name(c)
if not stmt._has_multi_parameters
- else "%s_0" % c.key
+ else "%s_0" % _col_bind_name(c)
)
- elif c.primary_key and implicit_returning:
- self.returning.append(c)
- value = self.process(value.self_group())
else:
- self.postfetch.append(c)
- value = self.process(value.self_group())
+ if isinstance(value, elements.BindParameter) and \
+ value.type._isnull:
+ value = value._clone()
+ value.type = c.type
+
+ if c.primary_key and implicit_returning:
+ self.returning.append(c)
+ value = self.process(value.self_group(), **kw)
+ elif implicit_return_defaults and \
+ c in implicit_return_defaults:
+ self.returning.append(c)
+ value = self.process(value.self_group(), **kw)
+ else:
+ self.postfetch.append(c)
+ value = self.process(value.self_group(), **kw)
values.append((c, value))
elif self.isinsert:
@@ -1838,13 +2068,13 @@ class SQLCompiler(engine.Compiled):
if self.dialect.supports_sequences and \
(not c.default.optional or \
not self.dialect.sequences_optional):
- proc = self.process(c.default)
+ proc = self.process(c.default, **kw)
values.append((c, proc))
self.returning.append(c)
elif c.default.is_clause_element:
values.append(
(c,
- self.process(c.default.arg.self_group()))
+ self.process(c.default.arg.self_group(), **kw))
)
self.returning.append(c)
else:
@@ -1855,7 +2085,13 @@ class SQLCompiler(engine.Compiled):
else:
self.returning.append(c)
else:
- if c.default is not None or \
+ if (
+ c.default is not None and
+ (
+ not c.default.is_sequence or
+ self.dialect.supports_sequences
+ )
+ ) or \
c is stmt.table._autoincrement_column and (
self.dialect.supports_sequences or
self.dialect.preexecute_autoincrement_sequences
@@ -1872,16 +2108,22 @@ class SQLCompiler(engine.Compiled):
if self.dialect.supports_sequences and \
(not c.default.optional or \
not self.dialect.sequences_optional):
- proc = self.process(c.default)
+ proc = self.process(c.default, **kw)
values.append((c, proc))
- if not c.primary_key:
+ if implicit_return_defaults and \
+ c in implicit_return_defaults:
+ self.returning.append(c)
+ elif not c.primary_key:
self.postfetch.append(c)
elif c.default.is_clause_element:
values.append(
- (c, self.process(c.default.arg.self_group()))
+ (c, self.process(c.default.arg.self_group(), **kw))
)
- if not c.primary_key:
+ if implicit_return_defaults and \
+ c in implicit_return_defaults:
+ self.returning.append(c)
+ elif not c.primary_key:
# dont add primary key column to postfetch
self.postfetch.append(c)
else:
@@ -1890,32 +2132,49 @@ class SQLCompiler(engine.Compiled):
)
self.prefetch.append(c)
elif c.server_default is not None:
- if not c.primary_key:
+ if implicit_return_defaults and \
+ c in implicit_return_defaults:
+ self.returning.append(c)
+ elif not c.primary_key:
self.postfetch.append(c)
+ elif implicit_return_defaults and \
+ c in implicit_return_defaults:
+ self.returning.append(c)
elif self.isupdate:
if c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
values.append(
- (c, self.process(c.onupdate.arg.self_group()))
+ (c, self.process(c.onupdate.arg.self_group(), **kw))
)
- self.postfetch.append(c)
+ if implicit_return_defaults and \
+ c in implicit_return_defaults:
+ self.returning.append(c)
+ else:
+ self.postfetch.append(c)
else:
values.append(
(c, self._create_crud_bind_param(c, None))
)
self.prefetch.append(c)
elif c.server_onupdate is not None:
- self.postfetch.append(c)
+ if implicit_return_defaults and \
+ c in implicit_return_defaults:
+ self.returning.append(c)
+ else:
+ self.postfetch.append(c)
+ elif implicit_return_defaults and \
+ c in implicit_return_defaults:
+ self.returning.append(c)
if parameters and stmt_parameters:
check = set(parameters).intersection(
- sql._column_as_key(k) for k in stmt.parameters
+ _column_as_key(k) for k in stmt.parameters
).difference(check_columns)
if check:
raise exc.CompileError(
"Unconsumed column names: %s" %
- (", ".join(check))
+ (", ".join("%s" % c for c in check))
)
if stmt._has_multi_parameters:
@@ -1924,17 +2183,17 @@ class SQLCompiler(engine.Compiled):
values.extend(
[
- (
- c,
- self._create_crud_bind_param(
- c, row[c.key],
- name="%s_%d" % (c.key, i + 1)
- )
- if c.key in row else param
- )
- for (c, param) in values_0
- ]
- for i, row in enumerate(stmt.parameters[1:])
+ (
+ c,
+ self._create_crud_bind_param(
+ c, row[c.key],
+ name="%s_%d" % (c.key, i + 1)
+ )
+ if c.key in row else param
+ )
+ for (c, param) in values_0
+ ]
+ for i, row in enumerate(stmt.parameters[1:])
)
return values
@@ -2005,7 +2264,7 @@ class SQLCompiler(engine.Compiled):
self.preparer.format_savepoint(savepoint_stmt)
-class DDLCompiler(engine.Compiled):
+class DDLCompiler(Compiled):
@util.memoized_property
def sql_compiler(self):
@@ -2042,11 +2301,11 @@ class DDLCompiler(engine.Compiled):
return self.sql_compiler.post_process_text(ddl.statement % context)
def visit_create_schema(self, create):
- schema = self.preparer.format_schema(create.element, create.quote)
+ schema = self.preparer.format_schema(create.element)
return "CREATE SCHEMA " + schema
def visit_drop_schema(self, drop):
- schema = self.preparer.format_schema(drop.element, drop.quote)
+ schema = self.preparer.format_schema(drop.element)
text = "DROP SCHEMA " + schema
if drop.cascade:
text += " CASCADE"
@@ -2068,11 +2327,13 @@ class DDLCompiler(engine.Compiled):
for create_column in create.columns:
column = create_column.element
try:
- text += separator
- separator = ", \n"
- text += "\t" + self.process(create_column,
+ processed = self.process(create_column,
first_pk=column.primary_key
and not first_pk)
+ if processed is not None:
+ text += separator
+ separator = ", \n"
+ text += "\t" + processed
if column.primary_key:
first_pk = True
except exc.CompileError as ce:
@@ -2093,6 +2354,9 @@ class DDLCompiler(engine.Compiled):
def visit_create_column(self, create, first_pk=False):
column = create.element
+ if column.system:
+ return None
+
text = self.get_column_specification(
column,
first_pk=first_pk
@@ -2156,7 +2420,7 @@ class DDLCompiler(engine.Compiled):
use_schema=include_table_schema),
', '.join(
self.sql_compiler.process(expr,
- include_table=False) for
+ include_table=False, literal_binds=True) for
expr in index.expressions)
)
return text
@@ -2169,13 +2433,12 @@ class DDLCompiler(engine.Compiled):
def _prepared_index_name(self, index, include_schema=False):
if include_schema and index.table is not None and index.table.schema:
schema = index.table.schema
- schema_name = self.preparer.quote_schema(schema,
- index.table.quote_schema)
+ schema_name = self.preparer.quote_schema(schema)
else:
schema_name = None
ident = index.name
- if isinstance(ident, sql._truncated_label):
+ if isinstance(ident, elements._truncated_label):
max_ = self.dialect.max_index_name_length or \
self.dialect.max_identifier_length
if len(ident) > max_:
@@ -2184,9 +2447,7 @@ class DDLCompiler(engine.Compiled):
else:
self.dialect.validate_identifier(ident)
- index_name = self.preparer.quote(
- ident,
- index.quote)
+ index_name = self.preparer.quote(ident)
if schema_name:
index_name = schema_name + "." + index_name
@@ -2246,8 +2507,9 @@ class DDLCompiler(engine.Compiled):
if constraint.name is not None:
text += "CONSTRAINT %s " % \
self.preparer.format_constraint(constraint)
- sqltext = sql_util.expression_as_ddl(constraint.sqltext)
- text += "CHECK (%s)" % self.sql_compiler.process(sqltext)
+ text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext,
+ include_table=False,
+ literal_binds=True)
text += self.define_constraint_deferrability(constraint)
return text
@@ -2268,7 +2530,7 @@ class DDLCompiler(engine.Compiled):
text += "CONSTRAINT %s " % \
self.preparer.format_constraint(constraint)
text += "PRIMARY KEY "
- text += "(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
+ text += "(%s)" % ', '.join(self.preparer.quote(c.name)
for c in constraint)
text += self.define_constraint_deferrability(constraint)
return text
@@ -2281,11 +2543,11 @@ class DDLCompiler(engine.Compiled):
preparer.format_constraint(constraint)
remote_table = list(constraint._elements.values())[0].column.table
text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
- ', '.join(preparer.quote(f.parent.name, f.parent.quote)
+ ', '.join(preparer.quote(f.parent.name)
for f in constraint._elements.values()),
self.define_constraint_remote_table(
constraint, remote_table, preparer),
- ', '.join(preparer.quote(f.column.name, f.column.quote)
+ ', '.join(preparer.quote(f.column.name)
for f in constraint._elements.values())
)
text += self.define_constraint_match(constraint)
@@ -2299,12 +2561,14 @@ class DDLCompiler(engine.Compiled):
return preparer.format_table(table)
def visit_unique_constraint(self, constraint):
+ if len(constraint) == 0:
+ return ''
text = ""
if constraint.name is not None:
text += "CONSTRAINT %s " % \
self.preparer.format_constraint(constraint)
text += "UNIQUE (%s)" % (
- ', '.join(self.preparer.quote(c.name, c.quote)
+ ', '.join(self.preparer.quote(c.name)
for c in constraint))
text += self.define_constraint_deferrability(constraint)
return text
@@ -2335,7 +2599,7 @@ class DDLCompiler(engine.Compiled):
return text
-class GenericTypeCompiler(engine.TypeCompiler):
+class GenericTypeCompiler(TypeCompiler):
def visit_FLOAT(self, type_):
return "FLOAT"
@@ -2558,15 +2822,25 @@ class IdentifierPreparer(object):
or not self.legal_characters.match(util.text_type(value))
or (lc_value != value))
- def quote_schema(self, schema, force):
- """Quote a schema.
+ def quote_schema(self, schema, force=None):
+ """Conditionally quote a schema.
+
+ Subclasses can override this to provide database-dependent
+ quoting behavior for schema names.
+
+ the 'force' flag should be considered deprecated.
- Subclasses should override this to provide database-dependent
- quoting behavior.
"""
return self.quote(schema, force)
- def quote(self, ident, force):
+ def quote(self, ident, force=None):
+ """Conditionally quote an identifier.
+
+ the 'force' flag should be considered deprecated.
+ """
+
+ force = getattr(ident, "quote", None)
+
if force is None:
if ident in self._strings:
return self._strings[ident]
@@ -2582,38 +2856,35 @@ class IdentifierPreparer(object):
return ident
def format_sequence(self, sequence, use_schema=True):
- name = self.quote(sequence.name, sequence.quote)
- if not self.omit_schema and use_schema and \
- sequence.schema is not None:
- name = self.quote_schema(sequence.schema, sequence.quote) + \
- "." + name
+ name = self.quote(sequence.name)
+ if not self.omit_schema and use_schema and sequence.schema is not None:
+ name = self.quote_schema(sequence.schema) + "." + name
return name
def format_label(self, label, name=None):
- return self.quote(name or label.name, label.quote)
+ return self.quote(name or label.name)
def format_alias(self, alias, name=None):
- return self.quote(name or alias.name, alias.quote)
+ return self.quote(name or alias.name)
def format_savepoint(self, savepoint, name=None):
- return self.quote(name or savepoint.ident, savepoint.quote)
+ return self.quote(name or savepoint.ident)
def format_constraint(self, constraint):
- return self.quote(constraint.name, constraint.quote)
+ return self.quote(constraint.name)
def format_table(self, table, use_schema=True, name=None):
"""Prepare a quoted table and schema name."""
if name is None:
name = table.name
- result = self.quote(name, table.quote)
+ result = self.quote(name)
if not self.omit_schema and use_schema \
and getattr(table, "schema", None):
- result = self.quote_schema(table.schema, table.quote_schema) + \
- "." + result
+ result = self.quote_schema(table.schema) + "." + result
return result
- def format_schema(self, name, quote):
+ def format_schema(self, name, quote=None):
"""Prepare a quoted schema name."""
return self.quote(name, quote)
@@ -2628,10 +2899,9 @@ class IdentifierPreparer(object):
if use_table:
return self.format_table(
column.table, use_schema=False,
- name=table_name) + "." + \
- self.quote(name, column.quote)
+ name=table_name) + "." + self.quote(name)
else:
- return self.quote(name, column.quote)
+ return self.quote(name)
else:
# literal textual elements get stuck into ColumnClause a lot,
# which shouldn't get quoted
@@ -2651,7 +2921,7 @@ class IdentifierPreparer(object):
if not self.omit_schema and use_schema and \
getattr(table, 'schema', None):
- return (self.quote_schema(table.schema, table.quote_schema),
+ return (self.quote_schema(table.schema),
self.format_table(table, use_schema=False))
else:
return (self.format_table(table, use_schema=False), )