summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/compiler.py108
-rw-r--r--lib/sqlalchemy/sql/expression.py71
2 files changed, 89 insertions, 90 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 6aab22a79..e8cc3378e 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -90,7 +90,7 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
operators = OPERATORS
- def __init__(self, dialect, statement, parameters=None, **kwargs):
+ def __init__(self, dialect, statement, parameters=None, inline=False, **kwargs):
"""Construct a new ``DefaultCompiler`` object.
dialect
@@ -113,6 +113,9 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
# if we are insert/update. set to true when we visit an INSERT or UPDATE
self.isinsert = self.isupdate = False
+ # compile INSERT/UPDATE defaults/sequences inlined (no pre-execute)
+ self.inline = inline or getattr(statement, 'inline', False)
+
# a dictionary of bind parameter keys to _BindParamClause instances.
self.binds = {}
@@ -151,12 +154,6 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
# an IdentifierPreparer that formats the quoting of identifiers
self.preparer = self.dialect.identifier_preparer
- # for UPDATE and INSERT statements, a set of columns whos values are being set
- # from a SQL expression (i.e., not one of the bind parameter values). if present,
- # default-value logic in the Dialect knows not to fire off column defaults
- # and also knows postfetching will be needed to get the values represented by these
- # parameters.
- self.inline_params = None
def after_compile(self):
# this re will search for params like :param
@@ -615,26 +612,14 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
def uses_sequences_for_inserts(self):
return False
-
- def visit_insert(self, insert_stmt):
- # search for columns who will be required to have an explicit bound value.
- # for inserts, this includes Python-side defaults, columns with sequences for dialects
- # that support sequences, and primary key columns for dialects that explicitly insert
- # pre-generated primary key values
- required_cols = [
- c for c in insert_stmt.table.c
- if \
- isinstance(c, schema.SchemaItem) and \
- (self.parameters is None or self.parameters.get(c.key, None) is None) and \
- (
- ((c.primary_key or isinstance(c.default, schema.Sequence)) and self.uses_sequences_for_inserts()) or
- isinstance(c.default, schema.ColumnDefault)
- )
- ]
+ def visit_sequence(self, seq):
+ raise NotImplementedError()
+
+ def visit_insert(self, insert_stmt):
self.isinsert = True
- colparams = self._get_colparams(insert_stmt, required_cols)
+ colparams = self._get_colparams(insert_stmt)
return ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" +
" VALUES (" + string.join([c[1] for c in colparams], ', ') + ")")
@@ -642,17 +627,8 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
def visit_update(self, update_stmt):
self.stack.append({'from':util.Set([update_stmt.table])})
- # search for columns who will be required to have an explicit bound value.
- # for updates, this includes Python-side "onupdate" defaults.
- required_cols = [c for c in update_stmt.table.c
- if
- isinstance(c, schema.SchemaItem) and \
- (self.parameters is None or self.parameters.get(c.key, None) is None) and
- isinstance(c.onupdate, schema.ColumnDefault)
- ]
-
self.isupdate = True
- colparams = self._get_colparams(update_stmt, required_cols)
+ colparams = self._get_colparams(update_stmt)
text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams], ', ')
@@ -663,13 +639,10 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
return text
- def _get_colparams(self, stmt, required_cols):
+ def _get_colparams(self, stmt):
"""create a set of tuples representing column/string pairs for use
in an INSERT or UPDATE statement.
- This method may generate new bind params within this compiled
- based on the given set of "required columns", which are required
- to have a value set in the statement.
"""
def create_bind_param(col, value):
@@ -677,8 +650,9 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
self.binds[col.key] = bindparam
return self.bindparam_string(self._truncate_bindparam(bindparam))
- self.inline_params = util.Set()
-
+ self.postfetch = util.Set()
+ self.prefetch = util.Set()
+
def to_col(key):
if not isinstance(key, sql._ColumnClause):
return stmt.table.columns.get(unicode(key), key)
@@ -701,23 +675,53 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
for k, v in stmt.parameters.iteritems():
parameters.setdefault(getattr(k, 'key', k), v)
- for col in required_cols:
- parameters.setdefault(col.key, None)
-
# create a list of column assignment clauses as tuples
values = []
for c in stmt.table.columns:
if c.key in parameters:
value = parameters[c.key]
- else:
- continue
- if sql._is_literal(value):
- value = create_bind_param(c, value)
- else:
- self.inline_params.add(c)
- value = self.process(value)
- values.append((c, value))
-
+ if sql._is_literal(value):
+ value = create_bind_param(c, value)
+ else:
+ self.postfetch.add(c)
+ value = self.process(value.self_group())
+ values.append((c, value))
+ elif isinstance(c, schema.Column):
+ if self.isinsert:
+ if isinstance(c.default, schema.ColumnDefault):
+ if self.inline and isinstance(c.default.arg, sql.ClauseElement):
+ values.append((c, self.process(c.default.arg)))
+ self.postfetch.add(c)
+ else:
+ values.append((c, create_bind_param(c, None)))
+ self.prefetch.add(c)
+ elif isinstance(c.default, schema.PassiveDefault):
+ if c.primary_key and self.uses_sequences_for_inserts() and not self.inline:
+ values.append((c, create_bind_param(c, None)))
+ self.prefetch.add(c)
+ else:
+ self.postfetch.add(c)
+ elif (c.primary_key or isinstance(c.default, schema.Sequence)) and self.uses_sequences_for_inserts():
+ if self.inline:
+ if c.default is not None:
+ proc = self.process(c.default)
+ if proc is not None:
+ values.append((c, proc))
+ self.postfetch.add(c)
+ else:
+ print "ISINSERT, HAS A SEQUENCE, IS PRIMARY KEY, ADDING PREFETCH:", c.key
+ values.append((c, create_bind_param(c, None)))
+ self.prefetch.add(c)
+ elif self.isupdate:
+ if isinstance(c.onupdate, schema.ColumnDefault):
+ if self.inline and isinstance(c.onupdate.arg, sql.ClauseElement):
+ values.append((c, self.process(c.onupdate.arg)))
+ self.postfetch.add(c)
+ else:
+ values.append((c, create_bind_param(c, None)))
+ self.prefetch.add(c)
+ elif isinstance(c.onupdate, schema.PassiveDefault):
+ self.postfetch.add(c)
return values
def visit_delete(self, delete_stmt):
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index b31ccbe44..ac56289e8 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -249,7 +249,7 @@ def subquery(alias, *args, **kwargs):
return Select(*args, **kwargs).alias(alias)
-def insert(table, values = None, **kwargs):
+def insert(table, values=None, inline=False):
"""Return an [sqlalchemy.sql#Insert] clause element.
Similar functionality is available via the ``insert()`` method on
@@ -266,6 +266,10 @@ def insert(table, values = None, **kwargs):
bind parameters also are None during the compile phase, then the
column specifications will be generated from the full list of
table columns.
+
+ inline
+ if True, SQL defaults will be compiled 'inline' into the statement
+ and not pre-executed.
If both `values` and compile-time bind parameters are present, the
compile-time bind parameters override the information specified
@@ -283,9 +287,9 @@ def insert(table, values = None, **kwargs):
against the ``INSERT`` statement.
"""
- return Insert(table, values, **kwargs)
+ return Insert(table, values, inline=inline)
-def update(table, whereclause = None, values = None, **kwargs):
+def update(table, whereclause=None, values=None, inline=False):
"""Return an [sqlalchemy.sql#Update] clause element.
Similar functionality is available via the ``update()`` method on
@@ -307,6 +311,11 @@ def update(table, whereclause = None, values = None, **kwargs):
``SET`` conditions will be generated from the full list of table
columns.
+ inline
+ if True, SQL defaults will be compiled 'inline' into the statement
+ and not pre-executed.
+
+
If both `values` and compile-time bind parameters are present, the
compile-time bind parameters override the information specified
within `values` on a per-key basis.
@@ -323,7 +332,7 @@ def update(table, whereclause = None, values = None, **kwargs):
against the ``UPDATE`` statement.
"""
- return Update(table, whereclause, values, **kwargs)
+ return Update(table, whereclause=whereclause, values=values, inline=inline)
def delete(table, whereclause = None, **kwargs):
"""Return a [sqlalchemy.sql#Delete] clause element.
@@ -959,14 +968,14 @@ class ClauseElement(object):
compile_params = multiparams[0]
else:
compile_params = params
- return self.compile(bind=self.bind, parameters=compile_params).execute(*multiparams, **params)
+ return self.compile(bind=self.bind, parameters=compile_params, inline=(len(multiparams) > 1)).execute(*multiparams, **params)
def scalar(self, *multiparams, **params):
"""Compile and execute this ``ClauseElement``, returning the result's scalar representation."""
return self.execute(*multiparams, **params).scalar()
- def compile(self, bind=None, parameters=None, compiler=None, dialect=None):
+ def compile(self, bind=None, parameters=None, compiler=None, dialect=None, inline=False):
"""Compile this SQL expression.
Uses the given ``Compiler``, or the given ``AbstractDialect``
@@ -995,16 +1004,16 @@ class ClauseElement(object):
if compiler is None:
if dialect is not None:
- compiler = dialect.statement_compiler(dialect, self, parameters)
+ compiler = dialect.statement_compiler(dialect, self, parameters, inline=inline)
elif bind is not None:
- compiler = bind.statement_compiler(self, parameters)
+ compiler = bind.statement_compiler(self, parameters, inline=inline)
elif self.bind is not None:
- compiler = self.bind.statement_compiler(self, parameters)
+ compiler = self.bind.statement_compiler(self, parameters, inline=inline)
if compiler is None:
from sqlalchemy.engine.default import DefaultDialect
dialect = DefaultDialect()
- compiler = dialect.statement_compiler(dialect, self, parameters=parameters)
+ compiler = dialect.statement_compiler(dialect, self, parameters=parameters, inline=inline)
compiler.compile()
return compiler
@@ -2705,13 +2714,13 @@ class TableClause(FromClause):
def select(self, whereclause = None, **params):
return select([self], whereclause, **params)
- def insert(self, values = None):
- return insert(self, values=values)
+ def insert(self, values=None, inline=False):
+ return insert(self, values=values, inline=inline)
- def update(self, whereclause = None, values = None):
- return update(self, whereclause, values)
+ def update(self, whereclause=None, values=None, inline=False):
+ return update(self, whereclause=whereclause, values=values, inline=inline)
- def delete(self, whereclause = None):
+ def delete(self, whereclause=None):
return delete(self, whereclause)
def _get_from_objects(self, **modifiers):
@@ -3213,41 +3222,26 @@ class _UpdateBase(ClauseElement):
return iter([self.table])
def _process_colparams(self, parameters):
- """Receive the *values* of an ``INSERT`` or ``UPDATE`` statement and construct appropriate bind parameters."""
if parameters is None:
return None
if isinstance(parameters, (list, tuple)):
pp = {}
- i = 0
- for c in self.table.c:
+ for i, c in enumerate(self.table.c):
pp[c.key] = parameters[i]
- i +=1
- parameters = pp
-
- for key in parameters.keys():
- value = parameters[key]
- if isinstance(value, ClauseElement):
- parameters[key] = value.self_group()
- elif _is_literal(value):
- if _is_literal(key):
- col = self.table.c[key]
- else:
- col = key
- try:
- parameters[key] = bindparam(col, value, unique=True)
- except KeyError:
- del parameters[key]
- return parameters
-
+ return pp
+ else:
+ return parameters
+
def _find_engine(self):
return self.table.bind
class Insert(_UpdateBase):
- def __init__(self, table, values=None):
+ def __init__(self, table, values=None, inline=False):
self.table = table
self.select = None
+ self.inline=inline
self.parameters = self._process_colparams(values)
def get_children(self, **kwargs):
@@ -3271,9 +3265,10 @@ class Insert(_UpdateBase):
return u
class Update(_UpdateBase):
- def __init__(self, table, whereclause, values=None):
+ def __init__(self, table, whereclause, values=None, inline=False):
self.table = table
self._whereclause = whereclause
+ self.inline = inline
self.parameters = self._process_colparams(values)
def get_children(self, **kwargs):