author     Mike Bayer <mike_mp@zzzcomputing.com>  2007-07-27 04:08:53 +0000
committer  Mike Bayer <mike_mp@zzzcomputing.com>  2007-07-27 04:08:53 +0000
commit     ed4fc64bb0ac61c27bc4af32962fb129e74a36bf (patch)
tree       c1cf2fb7b1cafced82a8898e23d2a0bf5ced8526 /lib/sqlalchemy
parent     3a8e235af64e36b3b711df1f069d32359fe6c967 (diff)
download   sqlalchemy-ed4fc64bb0ac61c27bc4af32962fb129e74a36bf.tar.gz
merging 0.4 branch to trunk. see CHANGES for details. 0.3 moves to maintenance branch in branches/rel_0_3.
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--  lib/sqlalchemy/__init__.py                          8
-rw-r--r--  lib/sqlalchemy/ansisql.py                         749
-rw-r--r--  lib/sqlalchemy/databases/firebird.py               57
-rw-r--r--  lib/sqlalchemy/databases/information_schema.py     13
-rw-r--r--  lib/sqlalchemy/databases/informix.py               63
-rw-r--r--  lib/sqlalchemy/databases/mssql.py                 113
-rw-r--r--  lib/sqlalchemy/databases/mysql.py                 107
-rw-r--r--  lib/sqlalchemy/databases/oracle.py                255
-rw-r--r--  lib/sqlalchemy/databases/postgres.py              293
-rw-r--r--  lib/sqlalchemy/databases/sqlite.py                 41
-rw-r--r--  lib/sqlalchemy/engine/__init__.py                   1
-rw-r--r--  lib/sqlalchemy/engine/base.py                     633
-rw-r--r--  lib/sqlalchemy/engine/default.py                  147
-rw-r--r--  lib/sqlalchemy/engine/strategies.py                21
-rw-r--r--  lib/sqlalchemy/engine/threadlocal.py               23
-rw-r--r--  lib/sqlalchemy/engine/url.py                        8
-rw-r--r--  lib/sqlalchemy/ext/activemapper.py                 11
-rw-r--r--  lib/sqlalchemy/ext/assignmapper.py                 59
-rw-r--r--  lib/sqlalchemy/ext/associationproxy.py             88
-rw-r--r--  lib/sqlalchemy/ext/proxy.py                       113
-rw-r--r--  lib/sqlalchemy/ext/selectresults.py               218
-rw-r--r--  lib/sqlalchemy/ext/sessioncontext.py               28
-rw-r--r--  lib/sqlalchemy/ext/sqlsoup.py                      13
-rw-r--r--  lib/sqlalchemy/mods/legacy_session.py             176
-rw-r--r--  lib/sqlalchemy/mods/selectresults.py                2
-rw-r--r--  lib/sqlalchemy/mods/threadlocal.py                 53
-rw-r--r--  lib/sqlalchemy/orm/__init__.py                    486
-rw-r--r--  lib/sqlalchemy/orm/attributes.py                  836
-rw-r--r--  lib/sqlalchemy/orm/collections.py                1182
-rw-r--r--  lib/sqlalchemy/orm/dependency.py                    6
-rw-r--r--  lib/sqlalchemy/orm/interfaces.py                  496
-rw-r--r--  lib/sqlalchemy/orm/mapper.py                      921
-rw-r--r--  lib/sqlalchemy/orm/properties.py                  269
-rw-r--r--  lib/sqlalchemy/orm/query.py                      1236
-rw-r--r--  lib/sqlalchemy/orm/session.py                     224
-rw-r--r--  lib/sqlalchemy/orm/shard.py                       112
-rw-r--r--  lib/sqlalchemy/orm/strategies.py                  671
-rw-r--r--  lib/sqlalchemy/orm/sync.py                         13
-rw-r--r--  lib/sqlalchemy/orm/unitofwork.py                   61
-rw-r--r--  lib/sqlalchemy/orm/util.py                        187
-rw-r--r--  lib/sqlalchemy/pool.py                             22
-rw-r--r--  lib/sqlalchemy/schema.py                          375
-rw-r--r--  lib/sqlalchemy/sql.py                            1946
-rw-r--r--  lib/sqlalchemy/sql_util.py                         42
-rw-r--r--  lib/sqlalchemy/topological.py                       3
-rw-r--r--  lib/sqlalchemy/types.py                            98
-rw-r--r--  lib/sqlalchemy/util.py                             93
47 files changed, 7159 insertions, 5413 deletions
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py
index ad2615131..6e95fd7e1 100644
--- a/lib/sqlalchemy/__init__.py
+++ b/lib/sqlalchemy/__init__.py
@@ -7,10 +7,8 @@
from sqlalchemy.types import *
from sqlalchemy.sql import *
from sqlalchemy.schema import *
-from sqlalchemy.orm import *
from sqlalchemy.engine import create_engine
-from sqlalchemy.schema import default_metadata
def __figure_version():
try:
@@ -25,8 +23,6 @@ def __figure_version():
return '(not installed)'
except:
return '(not installed)'
-
+
__version__ = __figure_version()
-
-def global_connect(*args, **kwargs):
- default_metadata.connect(*args, **kwargs)
+
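[note: with the wildcard ORM import and global_connect() removed above, 0.4 application code imports the ORM explicitly. A minimal sketch of the resulting import style; the table and engine here are illustrative, not from this commit:

    from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String
    from sqlalchemy.orm import mapper, sessionmaker  # no longer re-exported by the sqlalchemy package

    engine = create_engine('sqlite://')
    metadata = MetaData()
    users = Table('users', metadata,
                  Column('id', Integer, primary_key=True),
                  Column('name', String(40)))
]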
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index 9994d5288..22227d56a 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -10,9 +10,10 @@ Contains default implementations for the abstract objects in the sql
module.
"""
-from sqlalchemy import schema, sql, engine, util, sql_util, exceptions
+import string, re, sets, operator
+
+from sqlalchemy import schema, sql, engine, util, exceptions
from sqlalchemy.engine import default
-import string, re, sets, weakref, random
ANSI_FUNCS = sets.ImmutableSet(['CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP',
'CURRENT_USER', 'LOCALTIME', 'LOCALTIMESTAMP',
@@ -40,6 +41,41 @@ RESERVED_WORDS = util.Set(['all', 'analyse', 'analyze', 'and', 'any', 'array',
LEGAL_CHARACTERS = util.Set(string.ascii_lowercase + string.ascii_uppercase + string.digits + '_$')
ILLEGAL_INITIAL_CHARACTERS = util.Set(string.digits + '$')
+BIND_PARAMS = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
+BIND_PARAMS_ESC = re.compile(r'\x5c(:\w+)(?!:)', re.UNICODE)
+
+OPERATORS = {
+ operator.and_ : 'AND',
+ operator.or_ : 'OR',
+ operator.inv : 'NOT',
+ operator.add : '+',
+ operator.mul : '*',
+ operator.sub : '-',
+ operator.div : '/',
+ operator.mod : '%',
+ operator.truediv : '/',
+ operator.lt : '<',
+ operator.le : '<=',
+ operator.ne : '!=',
+ operator.gt : '>',
+ operator.ge : '>=',
+ operator.eq : '=',
+ sql.ColumnOperators.concat_op : '||',
+ sql.ColumnOperators.like_op : 'LIKE',
+ sql.ColumnOperators.notlike_op : 'NOT LIKE',
+ sql.ColumnOperators.ilike_op : 'ILIKE',
+ sql.ColumnOperators.notilike_op : 'NOT ILIKE',
+ sql.ColumnOperators.between_op : 'BETWEEN',
+ sql.ColumnOperators.in_op : 'IN',
+ sql.ColumnOperators.notin_op : 'NOT IN',
+ sql.ColumnOperators.comma_op : ', ',
+ sql.Operators.from_ : 'FROM',
+ sql.Operators.as_ : 'AS',
+ sql.Operators.exists : 'EXISTS',
+ sql.Operators.is_ : 'IS',
+ sql.Operators.isnot : 'IS NOT'
+}
+
class ANSIDialect(default.DefaultDialect):
def __init__(self, cache_identifiers=True, **kwargs):
super(ANSIDialect,self).__init__(**kwargs)
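[note: the new OPERATORS table maps Python operator callables (and SQLAlchemy's own operator functions) to their SQL string forms, replacing per-clause operator strings. A self-contained sketch of the lookup, using only the stdlib entries shown above:

    import operator

    # subset of the OPERATORS table above, stdlib entries only
    OPERATOR_STRINGS = {
        operator.and_: 'AND',
        operator.or_:  'OR',
        operator.eq:   '=',
        operator.ne:   '!=',
    }

    def operator_string(op):
        # unknown (dialect-specific) operators fall back to str(op),
        # as in ANSICompiler.operator_string()
        return OPERATOR_STRINGS.get(op, str(op))

    assert operator_string(operator.eq) == '='
]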
@@ -66,14 +102,16 @@ class ANSIDialect(default.DefaultDialect):
"""
return ANSIIdentifierPreparer(self)
-class ANSICompiler(sql.Compiled):
+class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
"""Default implementation of Compiled.
Compiles ClauseElements into ANSI-compliant SQL strings.
"""
- __traverse_options__ = {'column_collections':False}
+ __traverse_options__ = {'column_collections':False, 'entry':True}
+ operators = OPERATORS
+
def __init__(self, dialect, statement, parameters=None, **kwargs):
"""Construct a new ``ANSICompiler`` object.
@@ -92,7 +130,7 @@ class ANSICompiler(sql.Compiled):
correspond to the keys present in the parameters.
"""
- sql.Compiled.__init__(self, dialect, statement, parameters, **kwargs)
+ super(ANSICompiler, self).__init__(dialect, statement, parameters, **kwargs)
# if we are insert/update. set to true when we visit an INSERT or UPDATE
self.isinsert = self.isupdate = False
@@ -104,21 +142,6 @@ class ANSICompiler(sql.Compiled):
# actually present in the generated SQL
self.bind_names = {}
- # a dictionary which stores the string representation for every ClauseElement
- # processed by this compiler.
- self.strings = {}
-
- # a dictionary which stores the string representation for ClauseElements
- # processed by this compiler, which are to be used in the FROM clause
- # of a select. items are often placed in "froms" as well as "strings"
- # and sometimes with different representations.
- self.froms = {}
-
- # slightly hacky. maps FROM clauses to WHERE clauses, and used in select
- # generation to modify the WHERE clause of the select. currently a hack
- # used by the oracle module.
- self.wheres = {}
-
# when the compiler visits a SELECT statement, the clause object is appended
# to this stack. various visit operations will check this stack to determine
# additional choices (TODO: it seems to be all typemap stuff. shouldnt this only
@@ -137,12 +160,6 @@ class ANSICompiler(sql.Compiled):
# for aliases
self.generated_ids = {}
- # True if this compiled represents an INSERT
- self.isinsert = False
-
- # True if this compiled represents an UPDATE
- self.isupdate = False
-
# default formatting style for bind parameters
self.bindtemplate = ":%s"
@@ -158,64 +175,76 @@ class ANSICompiler(sql.Compiled):
# an ANSIIdentifierPreparer that formats the quoting of identifiers
self.preparer = dialect.identifier_preparer
-
+
+ # a dictionary containing attributes about all select()
+ # elements located within the clause, regarding which are subqueries, which are
+ # selected from, and which elements should be correlated to an enclosing select.
+ # used mostly to determine the list of FROM elements for each select statement, as well
+ # as some dialect-specific rules regarding subqueries.
+ self.correlate_state = {}
+
# for UPDATE and INSERT statements, a set of columns whos values are being set
# from a SQL expression (i.e., not one of the bind parameter values). if present,
# default-value logic in the Dialect knows not to fire off column defaults
# and also knows postfetching will be needed to get the values represented by these
# parameters.
self.inline_params = None
-
+
def after_compile(self):
# this re will search for params like :param
# it has a negative lookbehind for an extra ':' so that it doesnt match
# postgres '::text' tokens
- match = re.compile(r'(?<!:):([\w_]+)', re.UNICODE)
+ text = self.string
+ if ':' not in text:
+ return
+
if self.paramstyle=='pyformat':
- self.strings[self.statement] = match.sub(lambda m:'%(' + m.group(1) +')s', self.strings[self.statement])
+ text = BIND_PARAMS.sub(lambda m:'%(' + m.group(1) +')s', text)
elif self.positional:
- params = match.finditer(self.strings[self.statement])
+ params = BIND_PARAMS.finditer(text)
for p in params:
self.positiontup.append(p.group(1))
if self.paramstyle=='qmark':
- self.strings[self.statement] = match.sub('?', self.strings[self.statement])
+ text = BIND_PARAMS.sub('?', text)
elif self.paramstyle=='format':
- self.strings[self.statement] = match.sub('%s', self.strings[self.statement])
+ text = BIND_PARAMS.sub('%s', text)
elif self.paramstyle=='numeric':
i = [0]
def getnum(x):
i[0] += 1
return str(i[0])
- self.strings[self.statement] = match.sub(getnum, self.strings[self.statement])
-
- def get_from_text(self, obj):
- return self.froms.get(obj, None)
-
- def get_str(self, obj):
- return self.strings[obj]
-
+ text = BIND_PARAMS.sub(getnum, text)
+ # un-escape any \:params
+ text = BIND_PARAMS_ESC.sub(lambda m: m.group(1), text)
+ self.string = text
+
+ def compile(self):
+ self.string = self.process(self.statement)
+ self.after_compile()
+
+ def process(self, obj, **kwargs):
+ return self.traverse_single(obj, **kwargs)
+
+ def is_subquery(self, select):
+ return self.correlate_state[select].get('is_subquery', False)
+
def get_whereclause(self, obj):
- return self.wheres.get(obj, None)
+ """given a FROM clause, return an additional WHERE condition that should be
+ applied to a SELECT.
+
+ Currently used by Oracle to provide WHERE criterion for JOIN and OUTER JOIN
+ constructs in non-ansi mode.
+ """
+
+ return None
def construct_params(self, params):
- """Return a structure of bind parameters for this compiled object.
-
- This includes bind parameters that might be compiled in via
- the `values` argument of an ``Insert`` or ``Update`` statement
- object, and also the given `**params`. The keys inside of
- `**params` can be any key that matches the
- ``BindParameterClause`` objects compiled within this object.
-
- The output is dependent on the paramstyle of the DBAPI being
- used; if a named style, the return result will be a dictionary
- with keynames matching the compiled statement. If a
- positional style, the output will be a list, with an iterator
- that will return parameter values in an order corresponding to
- the bind positions in the compiled statement.
-
- For an executemany style of call, this method should be called
- for each element in the list of parameter groups that will
- ultimately be executed.
+ """Return a sql.ClauseParameters object.
+
+ Combines the given bind parameter dictionary (string keys to object values)
+ with the _BindParamClause objects stored within this Compiled object
+ to produce a ClauseParameters structure, representing the bind arguments
+ for a single statement execution, or one element of an executemany execution.
"""
if self.parameters is not None:
@@ -225,7 +254,7 @@ class ANSICompiler(sql.Compiled):
bindparams.update(params)
d = sql.ClauseParameters(self.dialect, self.positiontup)
for b in self.binds.values():
- name = self.bind_names.get(b, b.key)
+ name = self.bind_names[b]
d.set_parameter(b, b.value, name)
for key, value in bindparams.iteritems():
@@ -233,7 +262,7 @@ class ANSICompiler(sql.Compiled):
b = self.binds[key]
except KeyError:
continue
- name = self.bind_names.get(b, b.key)
+ name = self.bind_names[b]
d.set_parameter(b, value, name)
return d
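[note: after_compile() now rewrites the finished statement string once, using the module-level BIND_PARAMS patterns, instead of regex-substituting per-clause strings. The behavior can be sketched standalone; note how the negative lookbehind leaves PostgreSQL '::text' casts and backslash-escaped parameters alone:

    import re

    BIND_PARAMS = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
    BIND_PARAMS_ESC = re.compile(r'\x5c(:\w+)(?!:)', re.UNICODE)

    text = "SELECT * FROM users WHERE id = :id AND tag = 'x'::text"

    # pyformat paramstyle: :id -> %(id)s, the '::text' cast is untouched
    pyformat = BIND_PARAMS.sub(lambda m: '%(' + m.group(1) + ')s', text)
    assert '%(id)s' in pyformat and '::text' in pyformat

    # qmark paramstyle: :id -> ?
    assert '?' in BIND_PARAMS.sub('?', text)

    # un-escape any \:params afterwards
    assert BIND_PARAMS_ESC.sub(lambda m: m.group(1), "\\:notparam") == ":notparam"
]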
@@ -246,8 +275,8 @@ class ANSICompiler(sql.Compiled):
return ""
- def visit_grouping(self, grouping):
- self.strings[grouping] = "(" + self.strings[grouping.elem] + ")"
+ def visit_grouping(self, grouping, **kwargs):
+ return "(" + self.process(grouping.elem) + ")"
def visit_label(self, label):
labelname = self._truncated_identifier("colident", label.name)
@@ -256,9 +285,10 @@ class ANSICompiler(sql.Compiled):
self.typemap.setdefault(labelname.lower(), label.obj.type)
if isinstance(label.obj, sql._ColumnClause):
self.column_labels[label.obj._label] = labelname
- self.strings[label] = self.strings[label.obj] + " AS " + self.preparer.format_label(label, labelname)
+ self.column_labels[label.name] = labelname
+ return " ".join([self.process(label.obj), self.operator_string(sql.ColumnOperators.as_), self.preparer.format_label(label, labelname)])
- def visit_column(self, column):
+ def visit_column(self, column, **kwargs):
# there is actually somewhat of a ruleset when you would *not* necessarily
# want to truncate a column identifier, if its mapped to the name of a
# physical column. but thats very hard to identify at this point, and
@@ -269,107 +299,110 @@ class ANSICompiler(sql.Compiled):
else:
name = column.name
+ if len(self.select_stack):
+ # if we are within a visit to a Select, set up the "typemap"
+ # for this column which is used to translate result set values
+ self.typemap.setdefault(name.lower(), column.type)
+ self.column_labels.setdefault(column._label, name.lower())
+
if column.table is None or not column.table.named_with_column():
- self.strings[column] = self.preparer.format_column(column, name=name)
+ return self.preparer.format_column(column, name=name)
else:
if column.table.oid_column is column:
n = self.dialect.oid_column_name(column)
if n is not None:
- self.strings[column] = "%s.%s" % (self.preparer.format_table(column.table, use_schema=False), n)
+ return "%s.%s" % (self.preparer.format_table(column.table, use_schema=False, name=self._anonymize(column.table.name)), n)
elif len(column.table.primary_key) != 0:
pk = list(column.table.primary_key)[0]
pkname = (pk.is_literal and name or self._truncated_identifier("colident", pk.name))
- self.strings[column] = self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname)
+ return self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname, table_name=self._anonymize(column.table.name))
else:
- self.strings[column] = None
+ return None
else:
- self.strings[column] = self.preparer.format_column_with_table(column, column_name=name)
+ return self.preparer.format_column_with_table(column, column_name=name, table_name=self._anonymize(column.table.name))
- if len(self.select_stack):
- # if we are within a visit to a Select, set up the "typemap"
- # for this column which is used to translate result set values
- self.typemap.setdefault(name.lower(), column.type)
- self.column_labels.setdefault(column._label, name.lower())
- def visit_fromclause(self, fromclause):
- self.froms[fromclause] = fromclause.name
+ def visit_fromclause(self, fromclause, **kwargs):
+ return fromclause.name
- def visit_index(self, index):
- self.strings[index] = index.name
+ def visit_index(self, index, **kwargs):
+ return index.name
- def visit_typeclause(self, typeclause):
- self.strings[typeclause] = typeclause.type.dialect_impl(self.dialect).get_col_spec()
+ def visit_typeclause(self, typeclause, **kwargs):
+ return typeclause.type.dialect_impl(self.dialect).get_col_spec()
- def visit_textclause(self, textclause):
- self.strings[textclause] = textclause.text
- self.froms[textclause] = textclause.text
+ def visit_textclause(self, textclause, **kwargs):
+ for bind in textclause.bindparams.values():
+ self.process(bind)
if textclause.typemap is not None:
self.typemap.update(textclause.typemap)
+ return textclause.text
- def visit_null(self, null):
- self.strings[null] = 'NULL'
+ def visit_null(self, null, **kwargs):
+ return 'NULL'
- def visit_clauselist(self, list):
- sep = list.operator
- if sep == ',':
- sep = ', '
- elif sep is None or sep == " ":
+ def visit_clauselist(self, clauselist, **kwargs):
+ sep = clauselist.operator
+ if sep is None:
sep = " "
+ elif sep == sql.ColumnOperators.comma_op:
+ sep = ', '
else:
- sep = " " + sep + " "
- self.strings[list] = string.join([s for s in [self.get_str(c) for c in list.clauses] if s is not None], sep)
+ sep = " " + self.operator_string(clauselist.operator) + " "
+ return string.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None], sep)
def apply_function_parens(self, func):
return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0
- def visit_calculatedclause(self, clause):
- self.strings[clause] = self.get_str(clause.clause_expr)
+ def visit_calculatedclause(self, clause, **kwargs):
+ return self.process(clause.clause_expr)
- def visit_cast(self, cast):
+ def visit_cast(self, cast, **kwargs):
if len(self.select_stack):
# not sure if we want to set the typemap here...
self.typemap.setdefault("CAST", cast.type)
- self.strings[cast] = "CAST(%s AS %s)" % (self.strings[cast.clause],self.strings[cast.typeclause])
+ return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))
- def visit_function(self, func):
+ def visit_function(self, func, **kwargs):
if len(self.select_stack):
self.typemap.setdefault(func.name, func.type)
if not self.apply_function_parens(func):
- self.strings[func] = ".".join(func.packagenames + [func.name])
- self.froms[func] = self.strings[func]
+ return ".".join(func.packagenames + [func.name])
else:
- self.strings[func] = ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.get_str(func.clause_expr)
- self.froms[func] = self.strings[func]
+ return ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.process(func.clause_expr)
- def visit_compound_select(self, cs):
- text = string.join([self.get_str(c) for c in cs.selects], " " + cs.keyword + " ")
- group_by = self.get_str(cs.group_by_clause)
+ def visit_compound_select(self, cs, asfrom=False, **kwargs):
+ text = string.join([self.process(c) for c in cs.selects], " " + cs.keyword + " ")
+ group_by = self.process(cs._group_by_clause)
if group_by:
text += " GROUP BY " + group_by
text += self.order_by_clause(cs)
- text += self.visit_select_postclauses(cs)
- self.strings[cs] = text
- self.froms[cs] = "(" + text + ")"
+ text += (cs._limit or cs._offset) and self.limit_clause(cs) or ""
+
+ if asfrom:
+ return "(" + text + ")"
+ else:
+ return text
- def visit_unary(self, unary):
- s = self.get_str(unary.element)
+ def visit_unary(self, unary, **kwargs):
+ s = self.process(unary.element)
if unary.operator:
- s = unary.operator + " " + s
+ s = self.operator_string(unary.operator) + " " + s
if unary.modifier:
s = s + " " + unary.modifier
- self.strings[unary] = s
+ return s
- def visit_binary(self, binary):
- result = self.get_str(binary.left)
- if binary.operator is not None:
- result += " " + self.binary_operator_string(binary)
- result += " " + self.get_str(binary.right)
- self.strings[binary] = result
-
- def binary_operator_string(self, binary):
- return binary.operator
+ def visit_binary(self, binary, **kwargs):
+ op = self.operator_string(binary.operator)
+ if callable(op):
+ return op(self.process(binary.left), self.process(binary.right))
+ else:
+ return self.process(binary.left) + " " + op + " " + self.process(binary.right)
+
+ def operator_string(self, operator):
+ return self.operators.get(operator, str(operator))
- def visit_bindparam(self, bindparam):
+ def visit_bindparam(self, bindparam, **kwargs):
# apply truncation to the ultimate generated name
if bindparam.shortname != bindparam.key:
@@ -378,7 +411,6 @@ class ANSICompiler(sql.Compiled):
if bindparam.unique:
count = 1
key = bindparam.key
-
# redefine the generated name of the bind param in the case
# that we have multiple conflicting bind parameters.
while self.binds.setdefault(key, bindparam) is not bindparam:
@@ -386,164 +418,167 @@ class ANSICompiler(sql.Compiled):
key = bindparam.key + tag
count += 1
bindparam.key = key
- self.strings[bindparam] = self.bindparam_string(self._truncate_bindparam(bindparam))
+ return self.bindparam_string(self._truncate_bindparam(bindparam))
else:
existing = self.binds.get(bindparam.key)
if existing is not None and existing.unique:
raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key)
- self.strings[bindparam] = self.bindparam_string(self._truncate_bindparam(bindparam))
self.binds[bindparam.key] = bindparam
+ return self.bindparam_string(self._truncate_bindparam(bindparam))
def _truncate_bindparam(self, bindparam):
if bindparam in self.bind_names:
return self.bind_names[bindparam]
bind_name = bindparam.key
- if len(bind_name) > self.dialect.max_identifier_length():
- bind_name = self._truncated_identifier("bindparam", bind_name)
- # add to bind_names for translation
- self.bind_names[bindparam] = bind_name
+ bind_name = self._truncated_identifier("bindparam", bind_name)
+ # add to bind_names for translation
+ self.bind_names[bindparam] = bind_name
+
return bind_name
def _truncated_identifier(self, ident_class, name):
if (ident_class, name) in self.generated_ids:
return self.generated_ids[(ident_class, name)]
- if len(name) > self.dialect.max_identifier_length():
+
+ anonname = self._anonymize(name)
+ if len(anonname) > self.dialect.max_identifier_length():
counter = self.generated_ids.get(ident_class, 1)
truncname = name[0:self.dialect.max_identifier_length() - 6] + "_" + hex(counter)[2:]
self.generated_ids[ident_class] = counter + 1
else:
- truncname = name
+ truncname = anonname
self.generated_ids[(ident_class, name)] = truncname
return truncname
+
+ def _anonymize(self, name):
+ def anon(match):
+ (ident, derived) = match.group(1,2)
+ if ('anonymous', ident) in self.generated_ids:
+ return self.generated_ids[('anonymous', ident)]
+ else:
+ anonymous_counter = self.generated_ids.get('anonymous', 1)
+ newname = derived + "_" + str(anonymous_counter)
+ self.generated_ids['anonymous'] = anonymous_counter + 1
+ self.generated_ids[('anonymous', ident)] = newname
+ return newname
+ return re.sub(r'{ANON (-?\d+) (.*)}', anon, name)
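[note: _anonymize() gives stable generated names to '{ANON <id> <label>}' markers that anonymous aliases embed in their names; within one compilation the same marker always maps to the same name. A standalone sketch of the substitution, with the marker format taken from the regex above:

    import re

    generated_ids = {}

    def anonymize(name):
        # mirrors ANSICompiler._anonymize(): each distinct ANON ident gets
        # '<label>_N' using a per-compilation counter
        def anon(match):
            ident, derived = match.group(1, 2)
            if ('anonymous', ident) not in generated_ids:
                counter = generated_ids.get('anonymous', 1)
                generated_ids['anonymous'] = counter + 1
                generated_ids[('anonymous', ident)] = derived + "_" + str(counter)
            return generated_ids[('anonymous', ident)]
        return re.sub(r'{ANON (-?\d+) (.*)}', anon, name)

    assert anonymize('{ANON -123 users}') == 'users_1'
    assert anonymize('{ANON -123 users}') == 'users_1'  # stable on re-visit
]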
def bindparam_string(self, name):
return self.bindtemplate % name
- def visit_alias(self, alias):
- self.froms[alias] = self.get_from_text(alias.original) + " AS " + self.preparer.format_alias(alias)
- self.strings[alias] = self.get_str(alias.original)
+ def visit_alias(self, alias, asfrom=False, **kwargs):
+ if asfrom:
+ return self.process(alias.original, asfrom=True, **kwargs) + " AS " + self.preparer.format_alias(alias, self._anonymize(alias.name))
+ else:
+ return self.process(alias.original, **kwargs)
- def visit_select(self, select):
- # the actual list of columns to print in the SELECT column list.
- inner_columns = util.OrderedDict()
+ def label_select_column(self, select, column):
+ """convert a column from a select's "columns" clause.
+
+ given a select() and a column element from its inner_columns collection, return a
+ Label object if this column should be labeled in the columns clause. Otherwise,
+ return None and the column will be used as-is.
+
+ The calling method will traverse the returned label to acquire its string
+ representation.
+ """
+
+ # SQLite doesnt like selecting from a subquery where the column
+ # names look like table.colname. so if column is in a "selected from"
+ # subquery, label it synoymously with its column name
+ if \
+ self.correlate_state[select].get('is_selected_from', False) and \
+ isinstance(column, sql._ColumnClause) and \
+ not column.is_literal and \
+ column.table is not None and \
+ not isinstance(column.table, sql.Select):
+ return column.label(column.name)
+ else:
+ return None
+
+ def visit_select(self, select, asfrom=False, **kwargs):
+ select._calculate_correlations(self.correlate_state)
self.select_stack.append(select)
- for c in select._raw_columns:
- if hasattr(c, '_selectable'):
- s = c._selectable()
- else:
- self.traverse(c)
- inner_columns[self.get_str(c)] = c
- continue
- for co in s.columns:
- if select.use_labels:
- labelname = co._label
- if labelname is not None:
- l = co.label(labelname)
- self.traverse(l)
- inner_columns[labelname] = l
- else:
- self.traverse(co)
- inner_columns[self.get_str(co)] = co
- # TODO: figure this out, a ColumnClause with a select as a parent
- # is different from any other kind of parent
- elif select.is_selected_from and isinstance(co, sql._ColumnClause) and not co.is_literal and co.table is not None and not isinstance(co.table, sql.Select):
- # SQLite doesnt like selecting from a subquery where the column
- # names look like table.colname, so add a label synonomous with
- # the column name
- l = co.label(co.name)
- self.traverse(l)
- inner_columns[self.get_str(l.obj)] = l
+
+ # the actual list of columns to print in the SELECT column list.
+ inner_columns = util.OrderedSet()
+
+ froms = select._get_display_froms(self.correlate_state)
+
+ for co in select.inner_columns:
+ if select.use_labels:
+ labelname = co._label
+ if labelname is not None:
+ l = co.label(labelname)
+ inner_columns.add(self.process(l))
else:
self.traverse(co)
- inner_columns[self.get_str(co)] = co
+ inner_columns.add(self.process(co))
+ else:
+ l = self.label_select_column(select, co)
+ if l is not None:
+ inner_columns.add(self.process(l))
+ else:
+ inner_columns.add(self.process(co))
+
self.select_stack.pop(-1)
- collist = string.join([self.get_str(v) for v in inner_columns.values()], ', ')
+ collist = string.join(inner_columns.difference(util.Set([None])), ', ')
- text = "SELECT "
- text += self.visit_select_precolumns(select)
+ text = " ".join(["SELECT"] + [self.process(x) for x in select._prefixes]) + " "
+ text += self.get_select_precolumns(select)
text += collist
- whereclause = select.whereclause
-
- froms = []
- for f in select.froms:
-
- if self.parameters is not None:
- # TODO: whack this feature in 0.4
- # look at our own parameters, see if they
- # are all present in the form of BindParamClauses. if
- # not, then append to the above whereclause column conditions
- # matching those keys
- for c in f.columns:
- if sql.is_column(c) and self.parameters.has_key(c.key) and not self.binds.has_key(c.key):
- value = self.parameters[c.key]
- else:
- continue
- clause = c==value
- if whereclause is not None:
- whereclause = self.traverse(sql.and_(clause, whereclause), stop_on=util.Set([whereclause]))
- else:
- whereclause = clause
- self.traverse(whereclause)
-
- # special thingy used by oracle to redefine a join
+ whereclause = select._whereclause
+
+ from_strings = []
+ for f in froms:
+ from_strings.append(self.process(f, asfrom=True))
+
w = self.get_whereclause(f)
if w is not None:
- # TODO: move this more into the oracle module
if whereclause is not None:
- whereclause = self.traverse(sql.and_(w, whereclause), stop_on=util.Set([whereclause, w]))
+ whereclause = sql.and_(w, whereclause)
else:
whereclause = w
- t = self.get_from_text(f)
- if t is not None:
- froms.append(t)
-
if len(froms):
text += " \nFROM "
- text += string.join(froms, ', ')
+ text += string.join(from_strings, ', ')
else:
text += self.default_from()
if whereclause is not None:
- t = self.get_str(whereclause)
+ t = self.process(whereclause)
if t:
text += " \nWHERE " + t
- group_by = self.get_str(select.group_by_clause)
+ group_by = self.process(select._group_by_clause)
if group_by:
text += " GROUP BY " + group_by
- if select.having is not None:
- t = self.get_str(select.having)
+ if select._having is not None:
+ t = self.process(select._having)
if t:
text += " \nHAVING " + t
text += self.order_by_clause(select)
- text += self.visit_select_postclauses(select)
+ text += (select._limit or select._offset) and self.limit_clause(select) or ""
text += self.for_update_clause(select)
- self.strings[select] = text
- self.froms[select] = "(" + text + ")"
+ if asfrom:
+ return "(" + text + ")"
+ else:
+ return text
- def visit_select_precolumns(self, select):
+ def get_select_precolumns(self, select):
"""Called when building a ``SELECT`` statement, position is just before column list."""
-
- return select.distinct and "DISTINCT " or ""
-
- def visit_select_postclauses(self, select):
- """Called when building a ``SELECT`` statement, position is after all other ``SELECT`` clauses.
-
- Most DB syntaxes put ``LIMIT``/``OFFSET`` here.
- """
-
- return (select.limit or select.offset) and self.limit_clause(select) or ""
+ return select._distinct and "DISTINCT " or ""
def order_by_clause(self, select):
- order_by = self.get_str(select.order_by_clause)
+ order_by = self.process(select._order_by_clause)
if order_by:
return " ORDER BY " + order_by
else:
@@ -557,175 +592,103 @@ class ANSICompiler(sql.Compiled):
def limit_clause(self, select):
text = ""
- if select.limit is not None:
- text += " \n LIMIT " + str(select.limit)
- if select.offset is not None:
- if select.limit is None:
+ if select._limit is not None:
+ text += " \n LIMIT " + str(select._limit)
+ if select._offset is not None:
+ if select._limit is None:
text += " \n LIMIT -1"
- text += " OFFSET " + str(select.offset)
+ text += " OFFSET " + str(select._offset)
return text
- def visit_table(self, table):
- self.froms[table] = self.preparer.format_table(table)
- self.strings[table] = ""
-
- def visit_join(self, join):
- righttext = self.get_from_text(join.right)
- if join.right._group_parenthesized():
- righttext = "(" + righttext + ")"
- if join.isouter:
- self.froms[join] = (self.get_from_text(join.left) + " LEFT OUTER JOIN " + righttext +
- " ON " + self.get_str(join.onclause))
+ def visit_table(self, table, asfrom=False, **kwargs):
+ if asfrom:
+ return self.preparer.format_table(table)
else:
- self.froms[join] = (self.get_from_text(join.left) + " JOIN " + righttext +
- " ON " + self.get_str(join.onclause))
- self.strings[join] = self.froms[join]
-
- def visit_insert_column_default(self, column, default, parameters):
- """Called when visiting an ``Insert`` statement.
-
- For each column in the table that contains a ``ColumnDefault``
- object, add a blank *placeholder* parameter so the ``Insert``
- gets compiled with this column's name in its column and
- ``VALUES`` clauses.
- """
-
- parameters.setdefault(column.key, None)
-
- def visit_update_column_default(self, column, default, parameters):
- """Called when visiting an ``Update`` statement.
-
- For each column in the table that contains a ``ColumnDefault``
- object as an onupdate, add a blank *placeholder* parameter so
- the ``Update`` gets compiled with this column's name as one of
- its ``SET`` clauses.
- """
-
- parameters.setdefault(column.key, None)
-
- def visit_insert_sequence(self, column, sequence, parameters):
- """Called when visiting an ``Insert`` statement.
-
- This may be overridden compilers that support sequences to
- place a blank *placeholder* parameter for each column in the
- table that contains a Sequence object, so the Insert gets
- compiled with this column's name in its column and ``VALUES``
- clauses.
- """
-
- pass
-
- def visit_insert_column(self, column, parameters):
- """Called when visiting an ``Insert`` statement.
-
- This may be overridden by compilers who disallow NULL columns
- being set in an ``Insert`` where there is a default value on
- the column (i.e. postgres), to remove the column for which
- there is a NULL insert from the parameter list.
- """
+ return ""
- pass
+ def visit_join(self, join, asfrom=False, **kwargs):
+ return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \
+ self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause))
+ def uses_sequences_for_inserts(self):
+ return False
+
def visit_insert(self, insert_stmt):
- # scan the table's columns for defaults that have to be pre-set for an INSERT
- # add these columns to the parameter list via visit_insert_XXX methods
- default_params = {}
+
+ # search for columns who will be required to have an explicit bound value.
+ # for inserts, this includes Python-side defaults, columns with sequences for dialects
+ # that support sequences, and primary key columns for dialects that explicitly insert
+ # pre-generated primary key values
+ required_cols = util.Set()
class DefaultVisitor(schema.SchemaVisitor):
- def visit_column(s, c):
- self.visit_insert_column(c, default_params)
+ def visit_column(s, cd):
+ if c.primary_key and self.uses_sequences_for_inserts():
+ required_cols.add(c)
def visit_column_default(s, cd):
- self.visit_insert_column_default(c, cd, default_params)
+ required_cols.add(c)
def visit_sequence(s, seq):
- self.visit_insert_sequence(c, seq, default_params)
+ if self.uses_sequences_for_inserts():
+ required_cols.add(c)
vis = DefaultVisitor()
for c in insert_stmt.table.c:
if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
vis.traverse(c)
self.isinsert = True
- colparams = self._get_colparams(insert_stmt, default_params)
+ colparams = self._get_colparams(insert_stmt, required_cols)
- self.inline_params = util.Set()
- def create_param(col, p):
- if isinstance(p, sql._BindParamClause):
- self.binds[p.key] = p
- if p.shortname is not None:
- self.binds[p.shortname] = p
- return self.bindparam_string(self._truncate_bindparam(p))
- else:
- self.inline_params.add(col)
- self.traverse(p)
- if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement):
- return "(" + self.get_str(p) + ")"
- else:
- return self.get_str(p)
-
- text = ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" +
- " VALUES (" + string.join([create_param(*c) for c in colparams], ', ') + ")")
-
- self.strings[insert_stmt] = text
+ return ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" +
+ " VALUES (" + string.join([c[1] for c in colparams], ', ') + ")")
def visit_update(self, update_stmt):
- # scan the table's columns for onupdates that have to be pre-set for an UPDATE
- # add these columns to the parameter list via visit_update_XXX methods
- default_params = {}
+ update_stmt._calculate_correlations(self.correlate_state)
+
+ # search for columns who will be required to have an explicit bound value.
+ # for updates, this includes Python-side "onupdate" defaults.
+ required_cols = util.Set()
class OnUpdateVisitor(schema.SchemaVisitor):
def visit_column_onupdate(s, cd):
- self.visit_update_column_default(c, cd, default_params)
+ required_cols.add(c)
vis = OnUpdateVisitor()
for c in update_stmt.table.c:
if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
vis.traverse(c)
self.isupdate = True
- colparams = self._get_colparams(update_stmt, default_params)
-
- self.inline_params = util.Set()
- def create_param(col, p):
- if isinstance(p, sql._BindParamClause):
- self.binds[p.key] = p
- self.binds[p.shortname] = p
- return self.bindparam_string(self._truncate_bindparam(p))
- else:
- self.traverse(p)
- self.inline_params.add(col)
- if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement):
- return "(" + self.get_str(p) + ")"
- else:
- return self.get_str(p)
-
- text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), create_param(*c)) for c in colparams], ', ')
-
- if update_stmt.whereclause:
- text += " WHERE " + self.get_str(update_stmt.whereclause)
+ colparams = self._get_colparams(update_stmt, required_cols)
- self.strings[update_stmt] = text
+ text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams], ', ')
+ if update_stmt._whereclause:
+ text += " WHERE " + self.process(update_stmt._whereclause)
- def _get_colparams(self, stmt, default_params):
- """Organize ``UPDATE``/``INSERT`` ``SET``/``VALUES`` parameters into a list of tuples.
-
- Each tuple will contain the ``Column`` and a ``ClauseElement``
- representing the value to be set (usually a ``_BindParamClause``,
- but could also be other SQL expressions.)
-
- The list of tuples will determine the columns that are
- actually rendered into the ``SET``/``VALUES`` clause of the
- rendered ``UPDATE``/``INSERT`` statement. It will also
- determine how to generate the list/dictionary of bind
- parameters at execution time (i.e. ``get_params()``).
+ return text
- This list takes into account the `values` keyword specified
- to the statement, the parameters sent to this Compiled
- instance, and the default bind parameter values corresponding
- to the dialect's behavior for otherwise unspecified primary
- key columns.
+ def _get_colparams(self, stmt, required_cols):
+ """create a set of tuples representing column/string pairs for use
+ in an INSERT or UPDATE statement.
+
+ This method may generate new bind params within this compiled
+ based on the given set of "required columns", which are required
+ to have a value set in the statement.
"""
+ def create_bind_param(col, value):
+ bindparam = sql.bindparam(col.key, value, type_=col.type, unique=True)
+ self.binds[col.key] = bindparam
+ return self.bindparam_string(self._truncate_bindparam(bindparam))
+
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if self.parameters is None and stmt.parameters is None:
- return [(c, sql.bindparam(c.key, type=c.type)) for c in stmt.table.columns]
+ return [(c, create_bind_param(c, None)) for c in stmt.table.columns]
+
+ def create_clause_param(col, value):
+ self.traverse(value)
+ self.inline_params.add(col)
+ return self.process(value)
+
+ self.inline_params = util.Set()
def to_col(key):
if not isinstance(key, sql._ColumnClause):
@@ -744,29 +707,43 @@ class ANSICompiler(sql.Compiled):
for k, v in stmt.parameters.iteritems():
parameters.setdefault(to_col(k), v)
- for k, v in default_params.iteritems():
- parameters.setdefault(to_col(k), v)
+ for col in required_cols:
+ parameters.setdefault(col, None)
# create a list of column assignment clauses as tuples
values = []
for c in stmt.table.columns:
- if parameters.has_key(c):
+ if c in parameters:
value = parameters[c]
if sql._is_literal(value):
- value = sql.bindparam(c.key, value, type=c.type, unique=True)
+ value = create_bind_param(c, value)
+ else:
+ value = create_clause_param(c, value)
values.append((c, value))
+
return values
def visit_delete(self, delete_stmt):
+ delete_stmt._calculate_correlations(self.correlate_state)
+
text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)
- if delete_stmt.whereclause:
- text += " WHERE " + self.get_str(delete_stmt.whereclause)
+ if delete_stmt._whereclause:
+ text += " WHERE " + self.process(delete_stmt._whereclause)
- self.strings[delete_stmt] = text
+ return text
+
+ def visit_savepoint(self, savepoint_stmt):
+ return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
+ def visit_rollback_to_savepoint(self, savepoint_stmt):
+ return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
+
+ def visit_release_savepoint(self, savepoint_stmt):
+ return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
+
def __str__(self):
- return self.get_str(self.statement)
+ return self.string
class ANSISchemaBase(engine.SchemaIterator):
def find_alterables(self, tables):
@@ -795,7 +772,7 @@ class ANSISchemaGenerator(ANSISchemaBase):
def visit_metadata(self, metadata):
collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))]
for table in collection:
- table.accept_visitor(self)
+ self.traverse_single(table)
if self.dialect.supports_alter():
for alterable in self.find_alterables(collection):
self.add_foreignkey(alterable)
@@ -803,9 +780,7 @@ class ANSISchemaGenerator(ANSISchemaBase):
def visit_table(self, table):
for column in table.columns:
if column.default is not None:
- column.default.accept_visitor(self)
- #if column.onupdate is not None:
- # column.onupdate.accept_visitor(visitor)
+ self.traverse_single(column.default)
self.append("\nCREATE TABLE " + self.preparer.format_table(table) + " (")
@@ -820,20 +795,20 @@ class ANSISchemaGenerator(ANSISchemaBase):
if column.primary_key:
first_pk = True
for constraint in column.constraints:
- constraint.accept_visitor(self)
+ self.traverse_single(constraint)
# On some DB order is significant: visit PK first, then the
# other constraints (engine.ReflectionTest.testbasic failed on FB2)
if len(table.primary_key):
- table.primary_key.accept_visitor(self)
+ self.traverse_single(table.primary_key)
for constraint in [c for c in table.constraints if c is not table.primary_key]:
- constraint.accept_visitor(self)
+ self.traverse_single(constraint)
self.append("\n)%s\n\n" % self.post_create_table(table))
self.execute()
if hasattr(table, 'indexes'):
for index in table.indexes:
- index.accept_visitor(self)
+ self.traverse_single(index)
def post_create_table(self, table):
return ''
@@ -870,7 +845,7 @@ class ANSISchemaGenerator(ANSISchemaBase):
if constraint.name is not None:
self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint))
self.append("PRIMARY KEY ")
- self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', ')))
+ self.append("(%s)" % ', '.join([self.preparer.format_column(c) for c in constraint]))
def visit_foreign_key_constraint(self, constraint):
if constraint.use_alter and self.dialect.supports_alter():
@@ -889,9 +864,9 @@ class ANSISchemaGenerator(ANSISchemaBase):
self.append("CONSTRAINT %s " %
preparer.format_constraint(constraint))
self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
- string.join([preparer.format_column(f.parent) for f in constraint.elements], ', '),
+ ', '.join([preparer.format_column(f.parent) for f in constraint.elements]),
preparer.format_table(list(constraint.elements)[0].column.table),
- string.join([preparer.format_column(f.column) for f in constraint.elements], ', ')
+ ', '.join([preparer.format_column(f.column) for f in constraint.elements])
))
if constraint.ondelete is not None:
self.append(" ON DELETE %s" % constraint.ondelete)
@@ -903,17 +878,17 @@ class ANSISchemaGenerator(ANSISchemaBase):
if constraint.name is not None:
self.append("CONSTRAINT %s " %
self.preparer.format_constraint(constraint))
- self.append(" UNIQUE (%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', ')))
+ self.append(" UNIQUE (%s)" % (', '.join([self.preparer.format_column(c) for c in constraint])))
def visit_column(self, column):
pass
def visit_index(self, index):
- preparer = self.preparer
- self.append('CREATE ')
+ preparer = self.preparer
+ self.append("CREATE ")
if index.unique:
- self.append('UNIQUE ')
- self.append('INDEX %s ON %s (%s)' \
+ self.append("UNIQUE ")
+ self.append("INDEX %s ON %s (%s)" \
% (preparer.format_index(index),
preparer.format_table(index.table),
string.join([preparer.format_column(c) for c in index.columns], ', ')))
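[note: the index DDL produced by visit_index() above amounts to one format string; a minimal sketch with the preparer's identifier quoting omitted:

    def create_index_ddl(name, table, columns, unique=False):
        # mirrors ANSISchemaGenerator.visit_index(), minus quoting
        return "CREATE %sINDEX %s ON %s (%s)" % (
            unique and "UNIQUE " or "", name, table, ', '.join(columns))

    assert create_index_ddl('ix_user_name', 'users', ['name'], unique=True) == \
        'CREATE UNIQUE INDEX ix_user_name ON users (name)'
]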
@@ -933,7 +908,7 @@ class ANSISchemaDropper(ANSISchemaBase):
for alterable in self.find_alterables(collection):
self.drop_foreignkey(alterable)
for table in collection:
- table.accept_visitor(self)
+ self.traverse_single(table)
def visit_index(self, index):
self.append("\nDROP INDEX " + self.preparer.format_index(index))
@@ -948,7 +923,7 @@ class ANSISchemaDropper(ANSISchemaBase):
def visit_table(self, table):
for column in table.columns:
if column.default is not None:
- column.default.accept_visitor(self)
+ self.traverse_single(column.default)
self.append("\nDROP TABLE " + self.preparer.format_table(table))
self.execute()
@@ -1048,17 +1023,17 @@ class ANSIIdentifierPreparer(object):
def should_quote(self, object):
return object.quote or self._requires_quotes(object.name, object.case_sensitive)
- def is_natural_case(self, object):
- return object.quote or self._requires_quotes(object.name, object.case_sensitive)
-
def format_sequence(self, sequence):
return self.__generic_obj_format(sequence, sequence.name)
def format_label(self, label, name=None):
return self.__generic_obj_format(label, name or label.name)
- def format_alias(self, alias):
- return self.__generic_obj_format(alias, alias.name)
+ def format_alias(self, alias, name=None):
+ return self.__generic_obj_format(alias, name or alias.name)
+
+ def format_savepoint(self, savepoint):
+ return self.__generic_obj_format(savepoint, savepoint)
def format_constraint(self, constraint):
return self.__generic_obj_format(constraint, constraint.name)
@@ -1076,25 +1051,25 @@ class ANSIIdentifierPreparer(object):
result = self.__generic_obj_format(table, table.schema) + "." + result
return result
- def format_column(self, column, use_table=False, name=None):
+ def format_column(self, column, use_table=False, name=None, table_name=None):
"""Prepare a quoted column name."""
if name is None:
name = column.name
if not getattr(column, 'is_literal', False):
if use_table:
- return self.format_table(column.table, use_schema=False) + "." + self.__generic_obj_format(column, name)
+ return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.__generic_obj_format(column, name)
else:
return self.__generic_obj_format(column, name)
else:
# literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted
if use_table:
- return self.format_table(column.table, use_schema=False) + "." + name
+ return self.format_table(column.table, use_schema=False, name=table_name) + "." + name
else:
return name
- def format_column_with_table(self, column, column_name=None):
+ def format_column_with_table(self, column, column_name=None, table_name=None):
"""Prepare a quoted column name with table name."""
- return self.format_column(column, use_table=True, name=column_name)
+ return self.format_column(column, use_table=True, name=column_name, table_name=table_name)
dialect = ANSIDialect
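[note: the through-line of the ansisql.py changes: visit_* methods now return their SQL fragment (dispatched via process()/traverse_single()) instead of writing into the removed self.strings/self.froms dictionaries, and FROM-clause rendering is selected by the asfrom flag rather than a parallel dictionary. A toy model of the new pattern; the class names here are illustrative:

    class Null(object):
        pass

    class Grouping(object):
        def __init__(self, elem):
            self.elem = elem

    class ToyCompiler(object):
        # 0.3 style stored results in self.strings[clause]; the 0.4 style
        # shown here simply returns the fragment from each visit method.
        def process(self, obj):
            return getattr(self, 'visit_' + type(obj).__name__.lower())(obj)

        def visit_null(self, null):
            return 'NULL'

        def visit_grouping(self, grouping):
            return '(' + self.process(grouping.elem) + ')'

    assert ToyCompiler().process(Grouping(Null())) == '(NULL)'
]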
diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py
index a02781c84..07f07644f 100644
--- a/lib/sqlalchemy/databases/firebird.py
+++ b/lib/sqlalchemy/databases/firebird.py
@@ -5,15 +5,11 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import sys, StringIO, string, types
+import warnings
-from sqlalchemy import util
+from sqlalchemy import util, sql, schema, ansisql, exceptions
import sqlalchemy.engine.default as default
-import sqlalchemy.sql as sql
-import sqlalchemy.schema as schema
-import sqlalchemy.ansisql as ansisql
import sqlalchemy.types as sqltypes
-import sqlalchemy.exceptions as exceptions
_initialized_kb = False
@@ -176,7 +172,7 @@ class FBDialect(ansisql.ANSIDialect):
else:
return False
- def reflecttable(self, connection, table):
+ def reflecttable(self, connection, table, include_columns):
#TODO: map these better
column_func = {
14 : lambda r: sqltypes.String(r['FLEN']), # TEXT
@@ -254,11 +250,20 @@ class FBDialect(ansisql.ANSIDialect):
while row:
name = row['FNAME']
- args = [lower_if_possible(name)]
+ python_name = lower_if_possible(name)
+ if include_columns and python_name not in include_columns:
+ continue
+ args = [python_name]
kw = {}
# get the data types and lengths
- args.append(column_func[row['FTYPE']](row))
+ coltype = column_func.get(row['FTYPE'], None)
+ if coltype is None:
+ warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (str(row['FTYPE']), name)))
+ coltype = sqltypes.NULLTYPE
+ else:
+ coltype = coltype(row)
+ args.append(coltype)
# is it a primary key?
kw['primary_key'] = name in pkfields
@@ -301,39 +306,39 @@ class FBDialect(ansisql.ANSIDialect):
class FBCompiler(ansisql.ANSICompiler):
"""Firebird specific idiosincrasies"""
- def visit_alias(self, alias):
+ def visit_alias(self, alias, asfrom=False, **kwargs):
# Override to not use the AS keyword which FB 1.5 does not like
- self.froms[alias] = self.get_from_text(alias.original) + " " + self.preparer.format_alias(alias)
- self.strings[alias] = self.get_str(alias.original)
+ if asfrom:
+ return self.process(alias.original, asfrom=True) + " " + self.preparer.format_alias(alias)
+ else:
+ return self.process(alias.original, asfrom=True)
def visit_function(self, func):
if len(func.clauses):
- super(FBCompiler, self).visit_function(func)
+ return super(FBCompiler, self).visit_function(func)
else:
- self.strings[func] = func.name
+ return func.name
- def visit_insert_column(self, column, parameters):
- # all column primary key inserts must be explicitly present
- if column.primary_key:
- parameters[column.key] = None
+ def uses_sequences_for_inserts(self):
+ return True
- def visit_select_precolumns(self, select):
+ def get_select_precolumns(self, select):
"""Called when building a ``SELECT`` statement, position is just
before column list Firebird puts the limit and offset right
after the ``SELECT``...
"""
result = ""
- if select.limit:
- result += " FIRST %d " % select.limit
- if select.offset:
- result +=" SKIP %d " % select.offset
- if select.distinct:
+ if select._limit:
+ result += " FIRST %d " % select._limit
+ if select._offset:
+ result +=" SKIP %d " % select._offset
+ if select._distinct:
result += " DISTINCT "
return result
def limit_clause(self, select):
- """Already taken care of in the `visit_select_precolumns` method."""
+ """Already taken care of in the `get_select_precolumns` method."""
return ""
@@ -364,7 +369,7 @@ class FBSchemaDropper(ansisql.ANSISchemaDropper):
class FBDefaultRunner(ansisql.ANSIDefaultRunner):
def exec_default_sql(self, default):
- c = sql.select([default.arg], from_obj=["rdb$database"]).compile(engine=self.connection)
+ c = sql.select([default.arg], from_obj=["rdb$database"]).compile(bind=self.connection)
return self.connection.execute_compiled(c).scalar()
def visit_sequence(self, seq):
diff --git a/lib/sqlalchemy/databases/information_schema.py b/lib/sqlalchemy/databases/information_schema.py
index 81c44dcaa..93f47de15 100644
--- a/lib/sqlalchemy/databases/information_schema.py
+++ b/lib/sqlalchemy/databases/information_schema.py
@@ -1,4 +1,6 @@
-from sqlalchemy import sql, schema, exceptions, select, MetaData, Table, Column, String, Integer
+import sqlalchemy.sql as sql
+import sqlalchemy.exceptions as exceptions
+from sqlalchemy import select, MetaData, Table, Column, String, Integer
from sqlalchemy.schema import PassiveDefault, ForeignKeyConstraint
ischema = MetaData()
@@ -96,8 +98,7 @@ class ISchema(object):
return self.cache[name]
-def reflecttable(connection, table, ischema_names):
-
+def reflecttable(connection, table, include_columns, ischema_names):
key_constraints = pg_key_constraints
if table.schema is not None:
@@ -128,7 +129,9 @@ def reflecttable(connection, table, ischema_names):
row[columns.c.numeric_scale],
row[columns.c.column_default]
)
-
+ if include_columns and name not in include_columns:
+ continue
+
args = []
for a in (charlen, numericprec, numericscale):
if a is not None:
@@ -139,7 +142,7 @@ def reflecttable(connection, table, ischema_names):
colargs= []
if default is not None:
colargs.append(PassiveDefault(sql.text(default)))
- table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))
+ table.append_column(Column(name, coltype, nullable=nullable, *colargs))
if not found_table:
raise exceptions.NoSuchTableError(table.name)
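[note: the new include_columns argument threads through from table reflection; a hypothetical usage sketch, with the engine URL and table/column names being illustrative:

    from sqlalchemy import MetaData, Table, create_engine

    engine = create_engine('postgres://scott:tiger@localhost/test')
    meta = MetaData(engine)

    # reflect only these two columns; anything else found in
    # information_schema is skipped by the
    # 'if include_columns and name not in include_columns' check above
    users = Table('users', meta, autoload=True, include_columns=['id', 'name'])
]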
diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py
index 2fb508280..f3a6cf60e 100644
--- a/lib/sqlalchemy/databases/informix.py
+++ b/lib/sqlalchemy/databases/informix.py
@@ -5,20 +5,11 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+import datetime, warnings
-import sys, StringIO, string , random
-import datetime
-from decimal import Decimal
-
-import sqlalchemy.util as util
-import sqlalchemy.sql as sql
-import sqlalchemy.engine as engine
+from sqlalchemy import sql, schema, ansisql, exceptions, pool
import sqlalchemy.engine.default as default
-import sqlalchemy.schema as schema
-import sqlalchemy.ansisql as ansisql
import sqlalchemy.types as sqltypes
-import sqlalchemy.exceptions as exceptions
-import sqlalchemy.pool as pool
# for offset
@@ -128,7 +119,7 @@ class InfoBoolean(sqltypes.Boolean):
elif value is None:
return None
else:
- return value and True or False
+ return value and True or False
colspecs = {
@@ -262,7 +253,7 @@ class InfoDialect(ansisql.ANSIDialect):
cursor = connection.execute("""select tabname from systables where tabname=?""", table_name.lower() )
return bool( cursor.fetchone() is not None )
- def reflecttable(self, connection, table):
+ def reflecttable(self, connection, table, include_columns):
c = connection.execute ("select distinct OWNER from systables where tabname=?", table.name.lower() )
rows = c.fetchall()
if not rows :
@@ -289,6 +280,10 @@ class InfoDialect(ansisql.ANSIDialect):
raise exceptions.NoSuchTableError(table.name)
for name , colattr , collength , default , colno in rows:
+ name = name.lower()
+ if include_columns and name not in include_columns:
+ continue
+
# in 7.31, coltype = 0x000
# ^^-- column type
# ^-- 1 not null , 0 null
@@ -306,14 +301,16 @@ class InfoDialect(ansisql.ANSIDialect):
scale = 0
coltype = InfoNumeric(precision, scale)
else:
- coltype = ischema_names.get(coltype)
+ try:
+ coltype = ischema_names[coltype]
+ except KeyError:
+ warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (coltype, name)))
+ coltype = sqltypes.NULLTYPE
colargs = []
if default is not None:
colargs.append(schema.PassiveDefault(sql.text(default)))
- name = name.lower()
-
table.append_column(schema.Column(name, coltype, nullable = (nullable == 0), *colargs))
# FK
@@ -372,20 +369,20 @@ class InfoCompiler(ansisql.ANSICompiler):
def default_from(self):
return " from systables where tabname = 'systables' "
- def visit_select_precolumns( self , select ):
- s = select.distinct and "DISTINCT " or ""
+ def get_select_precolumns( self , select ):
+ s = select._distinct and "DISTINCT " or ""
# only has limit
- if select.limit:
- off = select.offset or 0
- s += " FIRST %s " % ( select.limit + off )
+ if select._limit:
+ off = select._offset or 0
+ s += " FIRST %s " % ( select._limit + off )
else:
s += ""
return s
def visit_select(self, select):
- if select.offset:
- self.offset = select.offset
- self.limit = select.limit or 0
+ if select._offset:
+ self.offset = select._offset
+ self.limit = select._limit or 0
# the column in order by clause must in select too
def __label( c ):
@@ -393,13 +390,14 @@ class InfoCompiler(ansisql.ANSICompiler):
return c._label.lower()
except:
return ''
-
+
+ # TODO: dont modify the original select, generate a new one
a = [ __label(c) for c in select._raw_columns ]
for c in select.order_by_clause.clauses:
if ( __label(c) not in a ) and getattr( c , 'name' , '' ) != 'oid':
select.append_column( c )
- ansisql.ANSICompiler.visit_select(self, select)
+ return ansisql.ANSICompiler.visit_select(self, select)
def limit_clause(self, select):
return ""
@@ -414,23 +412,20 @@ class InfoCompiler(ansisql.ANSICompiler):
def visit_function( self , func ):
if func.name.lower() == 'current_date':
- self.strings[func] = "today"
+ return "today"
elif func.name.lower() == 'current_time':
- self.strings[func] = "CURRENT HOUR TO SECOND"
+ return "CURRENT HOUR TO SECOND"
elif func.name.lower() in ( 'current_timestamp' , 'now' ):
- self.strings[func] = "CURRENT YEAR TO SECOND"
+ return "CURRENT YEAR TO SECOND"
else:
- ansisql.ANSICompiler.visit_function( self , func )
+ return ansisql.ANSICompiler.visit_function( self , func )
def visit_clauselist(self, list):
try:
li = [ c for c in list.clauses if c.name != 'oid' ]
except:
li = [ c for c in list.clauses ]
- if list.parens:
- self.strings[list] = "(" + string.join([s for s in [self.get_str(c) for c in li] if s is not None ], ', ') + ")"
- else:
- self.strings[list] = string.join([s for s in [self.get_str(c) for c in li] if s is not None], ', ')
+ return ', '.join([s for s in [self.process(c) for c in li] if s is not None])
class InfoSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, first_pk=False):
diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py
index ba1c0fd9d..206291404 100644
--- a/lib/sqlalchemy/databases/mssql.py
+++ b/lib/sqlalchemy/databases/mssql.py
@@ -25,7 +25,7 @@
* Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on ``INSERT``
-* ``select.limit`` implemented as ``SELECT TOP n``
+* ``select._limit`` implemented as ``SELECT TOP n``
Known issues / TODO:
@@ -39,16 +39,11 @@ Known issues / TODO:
"""
-import sys, StringIO, string, types, re, datetime, random
+import datetime, random, warnings
-import sqlalchemy.sql as sql
-import sqlalchemy.engine as engine
-import sqlalchemy.engine.default as default
-import sqlalchemy.schema as schema
-import sqlalchemy.ansisql as ansisql
+from sqlalchemy import sql, schema, ansisql, exceptions
import sqlalchemy.types as sqltypes
-import sqlalchemy.exceptions as exceptions
-
+from sqlalchemy.engine import default
class MSNumeric(sqltypes.Numeric):
def convert_result_value(self, value, dialect):
@@ -500,7 +495,7 @@ class MSSQLDialect(ansisql.ANSIDialect):
row = c.fetchone()
return row is not None
- def reflecttable(self, connection, table):
+ def reflecttable(self, connection, table, include_columns):
import sqlalchemy.databases.information_schema as ischema
# Get base columns
@@ -532,16 +527,22 @@ class MSSQLDialect(ansisql.ANSIDialect):
row[columns.c.numeric_scale],
row[columns.c.column_default]
)
+ if include_columns and name not in include_columns:
+ continue
args = []
for a in (charlen, numericprec, numericscale):
if a is not None:
args.append(a)
- coltype = self.ischema_names[type]
+ coltype = self.ischema_names.get(type, None)
if coltype == MSString and charlen == -1:
coltype = MSText()
else:
- if coltype == MSNVarchar and charlen == -1:
+ if coltype is None:
+ warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (type, name)))
+ coltype = sqltypes.NULLTYPE
+
+ elif coltype == MSNVarchar and charlen == -1:
charlen = None
coltype = coltype(*args)
colargs= []
@@ -812,12 +813,12 @@ class MSSQLCompiler(ansisql.ANSICompiler):
super(MSSQLCompiler, self).__init__(dialect, statement, parameters, **kwargs)
self.tablealiases = {}
- def visit_select_precolumns(self, select):
+ def get_select_precolumns(self, select):
""" MS-SQL puts TOP, it's version of LIMIT here """
- s = select.distinct and "DISTINCT " or ""
- if select.limit:
- s += "TOP %s " % (select.limit,)
- if select.offset:
+ s = select._distinct and "DISTINCT " or ""
+ if select._limit:
+ s += "TOP %s " % (select._limit,)
+ if select._offset:
raise exceptions.InvalidRequestError('MSSQL does not support LIMIT with an offset')
return s
@@ -825,49 +826,50 @@ class MSSQLCompiler(ansisql.ANSICompiler):
# Limit in mssql is after the select keyword
return ""
- def visit_table(self, table):
+ def _schema_aliased_table(self, table):
+ if getattr(table, 'schema', None) is not None:
+ if not self.tablealiases.has_key(table):
+ self.tablealiases[table] = table.alias()
+ return self.tablealiases[table]
+ else:
+ return None
+
+ def visit_table(self, table, mssql_aliased=False, **kwargs):
+ if mssql_aliased:
+ return super(MSSQLCompiler, self).visit_table(table, **kwargs)
+
# alias schema-qualified tables
- if getattr(table, 'schema', None) is not None and not self.tablealiases.has_key(table):
- alias = table.alias()
- self.tablealiases[table] = alias
- self.traverse(alias)
- self.froms[('alias', table)] = self.froms[table]
- for c in alias.c:
- self.traverse(c)
- self.traverse(alias.oid_column)
- self.tablealiases[alias] = self.froms[table]
- self.froms[table] = self.froms[alias]
+ alias = self._schema_aliased_table(table)
+ if alias is not None:
+ return self.process(alias, mssql_aliased=True, **kwargs)
else:
- super(MSSQLCompiler, self).visit_table(table)
+ return super(MSSQLCompiler, self).visit_table(table, **kwargs)
- def visit_alias(self, alias):
+ def visit_alias(self, alias, **kwargs):
# translate for schema-qualified table aliases
- if self.froms.has_key(('alias', alias.original)):
- self.froms[alias] = self.froms[('alias', alias.original)] + " AS " + alias.name
- self.strings[alias] = ""
- else:
- super(MSSQLCompiler, self).visit_alias(alias)
+ self.tablealiases[alias.original] = alias
+ return super(MSSQLCompiler, self).visit_alias(alias, **kwargs)
def visit_column(self, column):
- # translate for schema-qualified table aliases
- super(MSSQLCompiler, self).visit_column(column)
- if column.table is not None and self.tablealiases.has_key(column.table):
- self.strings[column] = \
- self.strings[self.tablealiases[column.table].corresponding_column(column)]
+ if column.table is not None:
+ # translate for schema-qualified table aliases
+ t = self._schema_aliased_table(column.table)
+ if t is not None:
+ return self.process(t.corresponding_column(column))
+ return super(MSSQLCompiler, self).visit_column(column)
def visit_binary(self, binary):
"""Move bind parameters to the right-hand side of an operator, where possible."""
- if isinstance(binary.left, sql._BindParamClause) and binary.operator == '=':
- binary.left, binary.right = binary.right, binary.left
- super(MSSQLCompiler, self).visit_binary(binary)
-
- def visit_select(self, select):
- # label function calls, so they return a name in cursor.description
- for i,c in enumerate(select._raw_columns):
- if isinstance(c, sql._Function):
- select._raw_columns[i] = c.label(c.name + "_" + hex(random.randint(0, 65535))[2:])
+ if isinstance(binary.left, sql._BindParamClause) and binary.operator == operator.eq:
+ return self.process(sql._BinaryExpression(binary.right, binary.left, binary.operator))
+ else:
+ return super(MSSQLCompiler, self).visit_binary(binary)
- super(MSSQLCompiler, self).visit_select(select)
+ def label_select_column(self, select, column):
+ if isinstance(column, sql._Function):
+ return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:])
+ else:
+ return super(MSSQLCompiler, self).label_select_column(select, column)
function_rewrites = {'current_date': 'getdate',
'length': 'len',
@@ -881,10 +883,10 @@ class MSSQLCompiler(ansisql.ANSICompiler):
return ''
def order_by_clause(self, select):
- order_by = self.get_str(select.order_by_clause)
+ order_by = self.process(select._order_by_clause)
# MSSQL only allows ORDER BY in subqueries if there is a LIMIT
- if order_by and (not select.is_subquery or select.limit):
+ if order_by and (not self.is_subquery(select) or select._limit):
return " ORDER BY " + order_by
else:
return ""
@@ -916,10 +918,12 @@ class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
class MSSQLSchemaDropper(ansisql.ANSISchemaDropper):
def visit_index(self, index):
self.append("\nDROP INDEX %s.%s" % (
- self.preparer.quote_identifier(index.table.name),
- self.preparer.quote_identifier(index.name)))
+ self.preparer.quote_identifier(index.table.name),
+ self.preparer.quote_identifier(index.name)
+ ))
self.execute()
+
class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner):
    # TODO: does ms-sql have standalone sequences?
pass
@@ -940,4 +944,3 @@ dialect = MSSQLDialect
-
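
To illustrate the TOP handling above, a sketch against the 0.4 generative
select API (the table is hypothetical):

    from sqlalchemy import MetaData, Table, Column, Integer, select

    t = Table('t', MetaData(), Column('id', Integer))
    # .limit(10) populates select._limit, which get_select_precolumns()
    # renders as "SELECT TOP 10 ..."; combining it with .offset() raises
    # InvalidRequestError on this dialect
    stmt = select([t]).limit(10)
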
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py
index bac0e5e12..26800e32b 100644
--- a/lib/sqlalchemy/databases/mysql.py
+++ b/lib/sqlalchemy/databases/mysql.py
@@ -4,7 +4,7 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import re, datetime, inspect, warnings, weakref
+import re, datetime, inspect, warnings, weakref, operator
from sqlalchemy import sql, schema, ansisql
from sqlalchemy.engine import default
@@ -12,13 +12,13 @@ import sqlalchemy.types as sqltypes
import sqlalchemy.exceptions as exceptions
import sqlalchemy.util as util
from array import array as _array
+from decimal import Decimal
try:
from threading import Lock
except ImportError:
from dummy_threading import Lock
-
RESERVED_WORDS = util.Set(
['accessible', 'add', 'all', 'alter', 'analyze','and', 'as', 'asc',
'asensitive', 'before', 'between', 'bigint', 'binary', 'blob', 'both',
@@ -60,7 +60,6 @@ RESERVED_WORDS = util.Set(
'accessible', 'linear', 'master_ssl_verify_server_cert', 'range',
'read_only', 'read_write', # 5.1
])
-
_per_connection_mutex = Lock()
class _NumericType(object):
@@ -137,7 +136,7 @@ class _StringType(object):
class MSNumeric(sqltypes.Numeric, _NumericType):
"""MySQL NUMERIC type"""
- def __init__(self, precision = 10, length = 2, **kw):
+ def __init__(self, precision = 10, length = 2, asdecimal=True, **kw):
"""Construct a NUMERIC.
precision
@@ -157,18 +156,27 @@ class MSNumeric(sqltypes.Numeric, _NumericType):
"""
_NumericType.__init__(self, **kw)
- sqltypes.Numeric.__init__(self, precision, length)
-
+ sqltypes.Numeric.__init__(self, precision, length, asdecimal=asdecimal)
+
def get_col_spec(self):
if self.precision is None:
return self._extend("NUMERIC")
else:
return self._extend("NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length})
+ def convert_bind_param(self, value, dialect):
+ return value
+
+ def convert_result_value(self, value, dialect):
+ if not self.asdecimal and isinstance(value, Decimal):
+ return float(value)
+ else:
+ return value
+
class MSDecimal(MSNumeric):
"""MySQL DECIMAL type"""
- def __init__(self, precision=10, length=2, **kw):
+ def __init__(self, precision=10, length=2, asdecimal=True, **kw):
"""Construct a DECIMAL.
precision
@@ -187,7 +195,7 @@ class MSDecimal(MSNumeric):
underlying database API, which continue to be numeric.
"""
- super(MSDecimal, self).__init__(precision, length, **kw)
+ super(MSDecimal, self).__init__(precision, length, asdecimal=asdecimal, **kw)
def get_col_spec(self):
if self.precision is None:
@@ -200,7 +208,7 @@ class MSDecimal(MSNumeric):
class MSDouble(MSNumeric):
"""MySQL DOUBLE type"""
- def __init__(self, precision=10, length=2, **kw):
+ def __init__(self, precision=10, length=2, asdecimal=True, **kw):
"""Construct a DOUBLE.
precision
@@ -222,7 +230,7 @@ class MSDouble(MSNumeric):
if ((precision is None and length is not None) or
(precision is not None and length is None)):
raise exceptions.ArgumentError("You must specify both precision and length or omit both altogether.")
- super(MSDouble, self).__init__(precision, length, **kw)
+ super(MSDouble, self).__init__(precision, length, asdecimal=asdecimal, **kw)
def get_col_spec(self):
if self.precision is not None and self.length is not None:
@@ -235,7 +243,7 @@ class MSDouble(MSNumeric):
class MSFloat(sqltypes.Float, _NumericType):
"""MySQL FLOAT type"""
- def __init__(self, precision=10, length=None, **kw):
+ def __init__(self, precision=10, length=None, asdecimal=False, **kw):
"""Construct a FLOAT.
precision
@@ -257,7 +265,7 @@ class MSFloat(sqltypes.Float, _NumericType):
if length is not None:
self.length=length
_NumericType.__init__(self, **kw)
- sqltypes.Float.__init__(self, precision)
+ sqltypes.Float.__init__(self, precision, asdecimal=asdecimal)
def get_col_spec(self):
if hasattr(self, 'length') and self.length is not None:
@@ -267,6 +275,10 @@ class MSFloat(sqltypes.Float, _NumericType):
else:
return self._extend("FLOAT")
+ def convert_bind_param(self, value, dialect):
+ return value
+
+
class MSInteger(sqltypes.Integer, _NumericType):
"""MySQL INTEGER type"""
@@ -955,7 +967,10 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
if self.compiled.isinsert:
if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:]
-
+
+ def is_select(self):
+ return re.match(r'SELECT|SHOW|DESCRIBE|XA RECOVER', self.statement.lstrip(), re.I) is not None
+
class MySQLDialect(ansisql.ANSIDialect):
def __init__(self, **kwargs):
ansisql.ANSIDialect.__init__(self, default_paramstyle='format', **kwargs)
@@ -1044,6 +1059,27 @@ class MySQLDialect(ansisql.ANSIDialect):
except:
pass
+ def do_begin_twophase(self, connection, xid):
+ connection.execute(sql.text("XA BEGIN :xid", bindparams=[sql.bindparam('xid',xid)]))
+
+ def do_prepare_twophase(self, connection, xid):
+ connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)]))
+ connection.execute(sql.text("XA PREPARE :xid", bindparams=[sql.bindparam('xid',xid)]))
+
+ def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
+ if not is_prepared:
+ connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)]))
+ connection.execute(sql.text("XA ROLLBACK :xid", bindparams=[sql.bindparam('xid',xid)]))
+
+ def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
+ if not is_prepared:
+ self.do_prepare_twophase(connection, xid)
+ connection.execute(sql.text("XA COMMIT :xid", bindparams=[sql.bindparam('xid',xid)]))
+
+ def do_recover_twophase(self, connection):
+ resultset = connection.execute(sql.text("XA RECOVER"))
+ return [row['data'][0:row['gtrid_length']] for row in resultset]
+
def is_disconnect(self, e):
return isinstance(e, self.dbapi.OperationalError) and e.args[0] in (2006, 2013, 2014, 2045, 2055)
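
A sketch of driving these hooks through the engine layer, assuming 0.4's
Connection.begin_twophase() generates an xid and routes to the
do_*_twophase() methods above (DSN and table are hypothetical):

    from sqlalchemy import create_engine

    engine = create_engine('mysql://user:pw@localhost/test')  # hypothetical DSN
    conn = engine.connect()
    xa = conn.begin_twophase()          # XA BEGIN <xid>
    conn.execute("UPDATE t SET x = 1")
    xa.prepare()                        # XA END <xid>; XA PREPARE <xid>
    xa.commit()                         # XA COMMIT <xid>
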
@@ -1088,7 +1124,7 @@ class MySQLDialect(ansisql.ANSIDialect):
version.append(n)
return tuple(version)
- def reflecttable(self, connection, table):
+ def reflecttable(self, connection, table, include_columns):
"""Load column definitions from the server."""
decode_from = self._detect_charset(connection)
@@ -1111,6 +1147,9 @@ class MySQLDialect(ansisql.ANSIDialect):
# leave column names as unicode
name = name.decode(decode_from)
+
+ if include_columns and name not in include_columns:
+ continue
match = re.match(r'(\w+)(\(.*?\))?\s*(\w+)?\s*(\w+)?', type)
col_type = match.group(1)
@@ -1118,7 +1157,11 @@ class MySQLDialect(ansisql.ANSIDialect):
extra_1 = match.group(3)
extra_2 = match.group(4)
- coltype = ischema_names.get(col_type, MSString)
+ try:
+ coltype = ischema_names[col_type]
+ except KeyError:
+ warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (col_type, name)))
+ coltype = sqltypes.NULLTYPE
kw = {}
if extra_1 is not None:
@@ -1156,7 +1199,6 @@ class MySQLDialect(ansisql.ANSIDialect):
if not row:
raise exceptions.NoSuchTableError(table.fullname)
desc = row[1].strip()
- row.close()
tabletype = ''
lastparen = re.search(r'\)[^\)]*\Z', desc)
@@ -1223,7 +1265,6 @@ class MySQLDialect(ansisql.ANSIDialect):
cs = True
else:
                cs = row[1] in ('0', 'OFF', 'off')
- row.close()
cache['lower_case_table_names'] = cs
self.per_connection[raw_connection] = cache
return cache.get('lower_case_table_names')
@@ -1266,14 +1307,21 @@ class _MySQLPythonRowProxy(object):
class MySQLCompiler(ansisql.ANSICompiler):
- def visit_cast(self, cast):
-
+ operators = ansisql.ANSICompiler.operators.copy()
+ operators.update(
+ {
+ sql.ColumnOperators.concat_op : lambda x, y:"concat(%s, %s)" % (x, y),
+ operator.mod : '%%'
+ }
+ )
+
+ def visit_cast(self, cast, **kwargs):
if isinstance(cast.type, (sqltypes.Date, sqltypes.Time, sqltypes.DateTime)):
- return super(MySQLCompiler, self).visit_cast(cast)
+ return super(MySQLCompiler, self).visit_cast(cast, **kwargs)
else:
# so just skip the CAST altogether for now.
# TODO: put whatever MySQL does for CAST here.
- self.strings[cast] = self.strings[cast.clause]
+ return self.process(cast.clause)
def for_update_clause(self, select):
if select.for_update == 'read':
@@ -1283,20 +1331,15 @@ class MySQLCompiler(ansisql.ANSICompiler):
def limit_clause(self, select):
text = ""
- if select.limit is not None:
- text += " \n LIMIT " + str(select.limit)
- if select.offset is not None:
- if select.limit is None:
- # striaght from the MySQL docs, I kid you not
+ if select._limit is not None:
+ text += " \n LIMIT " + str(select._limit)
+ if select._offset is not None:
+ if select._limit is None:
+ # straight from the MySQL docs, I kid you not
text += " \n LIMIT 18446744073709551615"
- text += " OFFSET " + str(select.offset)
+ text += " OFFSET " + str(select._offset)
return text
- def binary_operator_string(self, binary):
- if binary.operator == '%':
- return '%%'
- else:
- return ansisql.ANSICompiler.binary_operator_string(self, binary)
class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, first_pk=False):
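
The LIMIT/OFFSET rendering above, summarized (hypothetical table):

    from sqlalchemy import MetaData, Table, Column, Integer, select

    t = Table('t', MetaData(), Column('id', Integer))
    # limit_clause() above renders:
    #   select([t]).limit(5).offset(10)  ->  ... LIMIT 5 OFFSET 10
    #   select([t]).offset(10)           ->  ... LIMIT 18446744073709551615 OFFSET 10
    stmt = select([t]).offset(10)   # no limit: MySQL's 2**64-1 sentinel is emitted
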
diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py
index 9d7d6a112..d3aa2e268 100644
--- a/lib/sqlalchemy/databases/oracle.py
+++ b/lib/sqlalchemy/databases/oracle.py
@@ -5,9 +5,9 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import sys, StringIO, string, re, warnings
+import datetime, re, warnings, operator
-from sqlalchemy import util, sql, engine, schema, ansisql, exceptions, logging
+from sqlalchemy import util, sql, schema, ansisql, exceptions, logging
from sqlalchemy.engine import default, base
import sqlalchemy.types as sqltypes
@@ -88,8 +88,11 @@ class OracleText(sqltypes.TEXT):
def convert_result_value(self, value, dialect):
if value is None:
return None
- else:
+ elif hasattr(value, 'read'):
+            # cx_oracle doesn't seem to be consistent about returning LOB or str for CLOB
return super(OracleText, self).convert_result_value(value.read(), dialect)
+ else:
+ return super(OracleText, self).convert_result_value(value, dialect)
class OracleRaw(sqltypes.Binary):
@@ -178,25 +181,31 @@ class OracleExecutionContext(default.DefaultExecutionContext):
super(OracleExecutionContext, self).pre_exec()
if self.dialect.auto_setinputsizes:
self.set_input_sizes()
+ if self.compiled_parameters is not None and not isinstance(self.compiled_parameters, list):
+ for key in self.compiled_parameters:
+ (bindparam, name, value) = self.compiled_parameters.get_parameter(key)
+ if bindparam.isoutparam:
+ dbtype = bindparam.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
+ if not hasattr(self, 'out_parameters'):
+ self.out_parameters = {}
+ self.out_parameters[name] = self.cursor.var(dbtype)
+ self.parameters[name] = self.out_parameters[name]
def get_result_proxy(self):
+ if hasattr(self, 'out_parameters'):
+ if self.compiled_parameters is not None:
+ for k in self.out_parameters:
+ type = self.compiled_parameters.get_type(k)
+ self.out_parameters[k] = type.dialect_impl(self.dialect).convert_result_value(self.out_parameters[k].getvalue(), self.dialect)
+ else:
+ for k in self.out_parameters:
+ self.out_parameters[k] = self.out_parameters[k].getvalue()
+
if self.cursor.description is not None:
- if self.dialect.auto_convert_lobs and self.typemap is None:
- typemap = {}
- binary = False
- for column in self.cursor.description:
- type_code = column[1]
- if type_code in self.dialect.ORACLE_BINARY_TYPES:
- binary = True
- typemap[column[0].lower()] = OracleBinary()
- self.typemap = typemap
- if binary:
+ for column in self.cursor.description:
+ type_code = column[1]
+ if type_code in self.dialect.ORACLE_BINARY_TYPES:
return base.BufferedColumnResultProxy(self)
- else:
- for column in self.cursor.description:
- type_code = column[1]
- if type_code in self.dialect.ORACLE_BINARY_TYPES:
- return base.BufferedColumnResultProxy(self)
return base.ResultProxy(self)
@@ -208,11 +217,26 @@ class OracleDialect(ansisql.ANSIDialect):
self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' )
self.auto_setinputsizes = auto_setinputsizes
self.auto_convert_lobs = auto_convert_lobs
+
if self.dbapi is not None:
self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB", "LONG_BINARY", "LONG_STRING"] if hasattr(self.dbapi, k)]
else:
self.ORACLE_BINARY_TYPES = []
+ def dbapi_type_map(self):
+ if self.dbapi is None or not self.auto_convert_lobs:
+ return {}
+ else:
+ return {
+ self.dbapi.NUMBER: OracleInteger(),
+ self.dbapi.CLOB: OracleText(),
+ self.dbapi.BLOB: OracleBinary(),
+ self.dbapi.STRING: OracleString(),
+ self.dbapi.TIMESTAMP: OracleTimestamp(),
+ self.dbapi.BINARY: OracleRaw(),
+ datetime.datetime: OracleDate()
+ }
+
def dbapi(cls):
import cx_Oracle
return cx_Oracle
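
The effect of dbapi_type_map() is easiest to see with a textual statement,
where no compiled type information exists. A sketch (DSN and table are
hypothetical):

    from sqlalchemy import create_engine

    engine = create_engine('oracle://scott:tiger@localhost:1521/orcl')  # hypothetical DSN
    # with auto_convert_lobs at its default, cx_Oracle.CLOB columns in this
    # plain-text result are mapped through OracleText(), so rows yield
    # strings rather than LOB handles
    result = engine.execute("SELECT doc FROM documents")
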
@@ -251,7 +275,7 @@ class OracleDialect(ansisql.ANSIDialect):
return 30
def oid_column_name(self, column):
- if not isinstance(column.table, sql.TableClause) and not isinstance(column.table, sql.Select):
+ if not isinstance(column.table, (sql.TableClause, sql.Select)):
return None
else:
return "rowid"
@@ -341,7 +365,7 @@ class OracleDialect(ansisql.ANSIDialect):
return name, owner, dblink
raise
- def reflecttable(self, connection, table):
+ def reflecttable(self, connection, table, include_columns):
preparer = self.identifier_preparer
if not preparer.should_quote(table):
name = table.name.upper()
@@ -363,6 +387,13 @@ class OracleDialect(ansisql.ANSIDialect):
#print "ROW:" , row
(colname, coltype, length, precision, scale, nullable, default) = (row[0], row[1], row[2], row[3], row[4], row[5]=='Y', row[6])
+            # if the name comes back as all upper case, assume it's case folded
+ if (colname.upper() == colname):
+ colname = colname.lower()
+
+ if include_columns and colname not in include_columns:
+ continue
+
# INTEGER if the scale is 0 and precision is null
# NUMBER if the scale and precision are both null
# NUMBER(9,2) if the precision is 9 and the scale is 2
@@ -382,16 +413,13 @@ class OracleDialect(ansisql.ANSIDialect):
try:
coltype = ischema_names[coltype]
except KeyError:
- raise exceptions.AssertionError("Can't get coltype for type '%s' on colname '%s'" % (coltype, colname))
+ warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (coltype, colname)))
+ coltype = sqltypes.NULLTYPE
colargs = []
if default is not None:
colargs.append(schema.PassiveDefault(sql.text(default)))
- # if name comes back as all upper, assume its case folded
- if (colname.upper() == colname):
- colname = colname.lower()
-
table.append_column(schema.Column(colname, coltype, nullable=nullable, *colargs))
if not len(table.columns):
@@ -458,16 +486,27 @@ class OracleDialect(ansisql.ANSIDialect):
OracleDialect.logger = logging.class_logger(OracleDialect)
+class _OuterJoinColumn(sql.ClauseElement):
+ __visit_name__ = 'outer_join_column'
+ def __init__(self, column):
+ self.column = column
+
class OracleCompiler(ansisql.ANSICompiler):
"""Oracle compiler modifies the lexical structure of Select
statements to work under non-ANSI configured Oracle databases, if
the use_ansi flag is False.
"""
+ operators = ansisql.ANSICompiler.operators.copy()
+ operators.update(
+ {
+ operator.mod : lambda x, y:"mod(%s, %s)" % (x, y)
+ }
+ )
+
def __init__(self, *args, **kwargs):
super(OracleCompiler, self).__init__(*args, **kwargs)
- # we have to modify SELECT objects a little bit, so store state here
- self._select_state = {}
+ self.__wheres = {}
def default_from(self):
"""Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended.
@@ -480,49 +519,46 @@ class OracleCompiler(ansisql.ANSICompiler):
def apply_function_parens(self, func):
return len(func.clauses) > 0
- def visit_join(self, join):
+ def visit_join(self, join, **kwargs):
if self.dialect.use_ansi:
- return ansisql.ANSICompiler.visit_join(self, join)
-
- self.froms[join] = self.get_from_text(join.left) + ", " + self.get_from_text(join.right)
- where = self.wheres.get(join.left, None)
+ return ansisql.ANSICompiler.visit_join(self, join, **kwargs)
+
+ (where, parentjoin) = self.__wheres.get(join, (None, None))
+
+ class VisitOn(sql.ClauseVisitor):
+ def visit_binary(s, binary):
+ if binary.operator == operator.eq:
+ if binary.left.table is join.right:
+ binary.left = _OuterJoinColumn(binary.left)
+ elif binary.right.table is join.right:
+ binary.right = _OuterJoinColumn(binary.right)
+
if where is not None:
- self.wheres[join] = sql.and_(where, join.onclause)
+ self.__wheres[join.left] = self.__wheres[parentjoin] = (sql.and_(VisitOn().traverse(join.onclause, clone=True), where), parentjoin)
else:
- self.wheres[join] = join.onclause
-# self.wheres[join] = sql.and_(self.wheres.get(join.left, None), join.onclause)
- self.strings[join] = self.froms[join]
-
- if join.isouter:
- # if outer join, push on the right side table as the current "outertable"
- self._outertable = join.right
-
- # now re-visit the onclause, which will be used as a where clause
- # (the first visit occured via the Join object itself right before it called visit_join())
- self.traverse(join.onclause)
-
- self._outertable = None
-
- self.wheres[join].accept_visitor(self)
+ self.__wheres[join.left] = self.__wheres[join] = (VisitOn().traverse(join.onclause, clone=True), join)
- def visit_insert_sequence(self, column, sequence, parameters):
- """This is the `sequence` equivalent to ``ANSICompiler``'s
- `visit_insert_column_default` which ensures that the column is
- present in the generated column list.
- """
-
- parameters.setdefault(column.key, None)
+ return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True)
+
+ def get_whereclause(self, f):
+ if f in self.__wheres:
+ return self.__wheres[f][0]
+ else:
+ return None
+
+ def visit_outer_join_column(self, vc):
+ return self.process(vc.column) + "(+)"
+
+ def uses_sequences_for_inserts(self):
+ return True
- def visit_alias(self, alias):
+ def visit_alias(self, alias, asfrom=False, **kwargs):
"""Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??"""
-
- self.froms[alias] = self.get_from_text(alias.original) + " " + alias.name
- self.strings[alias] = self.get_str(alias.original)
-
- def visit_column(self, column):
- ansisql.ANSICompiler.visit_column(self, column)
- if not self.dialect.use_ansi and getattr(self, '_outertable', None) is not None and column.table is self._outertable:
- self.strings[column] = self.strings[column] + "(+)"
+
+ if asfrom:
+ return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + alias.name
+ else:
+ return self.process(alias.original, **kwargs)
def visit_insert(self, insert):
"""``INSERT`` s are required to have the primary keys be explicitly present.
@@ -539,76 +575,35 @@ class OracleCompiler(ansisql.ANSICompiler):
def _TODO_visit_compound_select(self, select):
"""Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle."""
+ pass
- if getattr(select, '_oracle_visit', False):
- # cancel out the compiled order_by on the select
- if hasattr(select, "order_by_clause"):
- self.strings[select.order_by_clause] = ""
- ansisql.ANSICompiler.visit_compound_select(self, select)
- return
-
- if select.limit is not None or select.offset is not None:
- select._oracle_visit = True
- # to use ROW_NUMBER(), an ORDER BY is required.
- orderby = self.strings[select.order_by_clause]
- if not orderby:
- orderby = select.oid_column
- self.traverse(orderby)
- orderby = self.strings[orderby]
- class SelectVisitor(sql.NoColumnVisitor):
- def visit_select(self, select):
- select.append_column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn"))
- SelectVisitor().traverse(select)
- limitselect = sql.select([c for c in select.c if c.key!='ora_rn'])
- if select.offset is not None:
- limitselect.append_whereclause("ora_rn>%d" % select.offset)
- if select.limit is not None:
- limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset))
- else:
- limitselect.append_whereclause("ora_rn<=%d" % select.limit)
- self.traverse(limitselect)
- self.strings[select] = self.strings[limitselect]
- self.froms[select] = self.froms[limitselect]
- else:
- ansisql.ANSICompiler.visit_compound_select(self, select)
-
- def visit_select(self, select):
+ def visit_select(self, select, **kwargs):
"""Look for ``LIMIT`` and OFFSET in a select statement, and if
so tries to wrap it in a subquery with ``row_number()`` criterion.
"""
- # TODO: put a real copy-container on Select and copy, or somehow make this
- # not modify the Select statement
- if self._select_state.get((select, 'visit'), False):
- # cancel out the compiled order_by on the select
- if hasattr(select, "order_by_clause"):
- self.strings[select.order_by_clause] = ""
- ansisql.ANSICompiler.visit_select(self, select)
- return
-
- if select.limit is not None or select.offset is not None:
- self._select_state[(select, 'visit')] = True
+ if not getattr(select, '_oracle_visit', None) and (select._limit is not None or select._offset is not None):
# to use ROW_NUMBER(), an ORDER BY is required.
- orderby = self.strings[select.order_by_clause]
+ orderby = self.process(select._order_by_clause)
if not orderby:
orderby = select.oid_column
self.traverse(orderby)
- orderby = self.strings[orderby]
- if not hasattr(select, '_oracle_visit'):
- select.append_column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn"))
- select._oracle_visit = True
+ orderby = self.process(orderby)
+
+ oldselect = select
+ select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")).order_by(None)
+ select._oracle_visit = True
+
limitselect = sql.select([c for c in select.c if c.key!='ora_rn'])
- if select.offset is not None:
- limitselect.append_whereclause("ora_rn>%d" % select.offset)
- if select.limit is not None:
- limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset))
+ if select._offset is not None:
+ limitselect.append_whereclause("ora_rn>%d" % select._offset)
+ if select._limit is not None:
+ limitselect.append_whereclause("ora_rn<=%d" % (select._limit + select._offset))
else:
- limitselect.append_whereclause("ora_rn<=%d" % select.limit)
- self.traverse(limitselect)
- self.strings[select] = self.strings[limitselect]
- self.froms[select] = self.froms[limitselect]
+ limitselect.append_whereclause("ora_rn<=%d" % select._limit)
+ return self.process(limitselect)
else:
- ansisql.ANSICompiler.visit_select(self, select)
+ return ansisql.ANSICompiler.visit_select(self, select, **kwargs)
def limit_clause(self, select):
return ""
@@ -619,12 +614,6 @@ class OracleCompiler(ansisql.ANSICompiler):
else:
return super(OracleCompiler, self).for_update_clause(select)
- def visit_binary(self, binary):
- if binary.operator == '%':
- self.strings[binary] = ("MOD(%s,%s)"%(self.get_str(binary.left), self.get_str(binary.right)))
- else:
- return ansisql.ANSICompiler.visit_binary(self, binary)
-
class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
@@ -639,22 +628,22 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
return colspec
def visit_sequence(self, sequence):
- if not self.dialect.has_sequence(self.connection, sequence.name):
+ if not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name):
self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
self.execute()
class OracleSchemaDropper(ansisql.ANSISchemaDropper):
def visit_sequence(self, sequence):
- if self.dialect.has_sequence(self.connection, sequence.name):
+ if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name):
self.append("DROP SEQUENCE %s" % sequence.name)
self.execute()
class OracleDefaultRunner(ansisql.ANSIDefaultRunner):
def exec_default_sql(self, default):
- c = sql.select([default.arg], from_obj=["DUAL"]).compile(engine=self.connection)
- return self.connection.execute_compiled(c).scalar()
+ c = sql.select([default.arg], from_obj=["DUAL"]).compile(bind=self.connection)
+ return self.connection.execute(c).scalar()
def visit_sequence(self, seq):
- return self.connection.execute_text("SELECT " + seq.name + ".nextval FROM DUAL").scalar()
+ return self.connection.execute("SELECT " + seq.name + ".nextval FROM DUAL").scalar()
dialect = OracleDialect
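
A sketch of what the ROW_NUMBER() rewrite in visit_select() produces for a
paged query (table hypothetical, output SQL approximate):

    from sqlalchemy import MetaData, Table, Column, Integer, select

    t = Table('t', MetaData(), Column('id', Integer))
    stmt = select([t]).limit(10).offset(20)
    # compiles roughly to:
    #   SELECT id FROM
    #     (SELECT t.id AS id, ROW_NUMBER() OVER (ORDER BY t.rowid) AS ora_rn
    #      FROM t)
    #   WHERE ora_rn > 20 AND ora_rn <= 30
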
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py
index d3726fc1f..b192c4778 100644
--- a/lib/sqlalchemy/databases/postgres.py
+++ b/lib/sqlalchemy/databases/postgres.py
@@ -4,12 +4,13 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import datetime, string, types, re, random, warnings
+import re, random, warnings, operator
-from sqlalchemy import util, sql, schema, ansisql, exceptions
+from sqlalchemy import sql, schema, ansisql, exceptions
from sqlalchemy.engine import base, default
import sqlalchemy.types as sqltypes
from sqlalchemy.databases import information_schema as ischema
+from decimal import Decimal
try:
import mx.DateTime.DateTime as mxDateTime
@@ -28,6 +29,15 @@ class PGNumeric(sqltypes.Numeric):
else:
return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}
+ def convert_bind_param(self, value, dialect):
+ return value
+
+ def convert_result_value(self, value, dialect):
+ if not self.asdecimal and isinstance(value, Decimal):
+ return float(value)
+ else:
+ return value
+
class PGFloat(sqltypes.Float):
def get_col_spec(self):
if not self.precision:
@@ -35,6 +45,7 @@ class PGFloat(sqltypes.Float):
else:
return "FLOAT(%(precision)s)" % {'precision': self.precision}
+
class PGInteger(sqltypes.Integer):
def get_col_spec(self):
return "INTEGER"
@@ -47,74 +58,15 @@ class PGBigInteger(PGInteger):
def get_col_spec(self):
return "BIGINT"
-class PG2DateTime(sqltypes.DateTime):
+class PGDateTime(sqltypes.DateTime):
def get_col_spec(self):
return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
-class PG1DateTime(sqltypes.DateTime):
- def convert_bind_param(self, value, dialect):
- if value is not None:
- if isinstance(value, datetime.datetime):
- seconds = float(str(value.second) + "."
- + str(value.microsecond))
- mx_datetime = mxDateTime(value.year, value.month, value.day,
- value.hour, value.minute,
- seconds)
- return dialect.dbapi.TimestampFromMx(mx_datetime)
- return dialect.dbapi.TimestampFromMx(value)
- else:
- return None
-
- def convert_result_value(self, value, dialect):
- if value is None:
- return None
- second_parts = str(value.second).split(".")
- seconds = int(second_parts[0])
- microseconds = int(float(second_parts[1]))
- return datetime.datetime(value.year, value.month, value.day,
- value.hour, value.minute, seconds,
- microseconds)
-
- def get_col_spec(self):
- return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
-
-class PG2Date(sqltypes.Date):
- def get_col_spec(self):
- return "DATE"
-
-class PG1Date(sqltypes.Date):
- def convert_bind_param(self, value, dialect):
- # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
- # this one doesnt seem to work with the "emulation" mode
- if value is not None:
- return dialect.dbapi.DateFromMx(value)
- else:
- return None
-
- def convert_result_value(self, value, dialect):
- # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
- return value
-
+class PGDate(sqltypes.Date):
def get_col_spec(self):
return "DATE"
-class PG2Time(sqltypes.Time):
- def get_col_spec(self):
- return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
-
-class PG1Time(sqltypes.Time):
- def convert_bind_param(self, value, dialect):
- # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
- # this one doesnt seem to work with the "emulation" mode
- if value is not None:
- return psycopg.TimeFromMx(value)
- else:
- return None
-
- def convert_result_value(self, value, dialect):
- # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
- return value
-
+class PGTime(sqltypes.Time):
def get_col_spec(self):
return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
@@ -142,28 +94,55 @@ class PGBoolean(sqltypes.Boolean):
def get_col_spec(self):
return "BOOLEAN"
-pg2_colspecs = {
+class PGArray(sqltypes.TypeEngine, sqltypes.Concatenable):
+ def __init__(self, item_type):
+ if isinstance(item_type, type):
+ item_type = item_type()
+ self.item_type = item_type
+
+ def dialect_impl(self, dialect):
+ impl = self.__class__.__new__(self.__class__)
+ impl.__dict__.update(self.__dict__)
+ impl.item_type = self.item_type.dialect_impl(dialect)
+ return impl
+ def convert_bind_param(self, value, dialect):
+ if value is None:
+ return value
+ def convert_item(item):
+ if isinstance(item, (list,tuple)):
+ return [convert_item(child) for child in item]
+ else:
+ return self.item_type.convert_bind_param(item, dialect)
+ return [convert_item(item) for item in value]
+ def convert_result_value(self, value, dialect):
+ if value is None:
+ return value
+ def convert_item(item):
+ if isinstance(item, list):
+ return [convert_item(child) for child in item]
+ else:
+ return self.item_type.convert_result_value(item, dialect)
+ # Could specialcase when item_type.convert_result_value is the default identity func
+ return [convert_item(item) for item in value]
+ def get_col_spec(self):
+ return self.item_type.get_col_spec() + '[]'
+
+colspecs = {
sqltypes.Integer : PGInteger,
sqltypes.Smallinteger : PGSmallInteger,
sqltypes.Numeric : PGNumeric,
sqltypes.Float : PGFloat,
- sqltypes.DateTime : PG2DateTime,
- sqltypes.Date : PG2Date,
- sqltypes.Time : PG2Time,
+ sqltypes.DateTime : PGDateTime,
+ sqltypes.Date : PGDate,
+ sqltypes.Time : PGTime,
sqltypes.String : PGString,
sqltypes.Binary : PGBinary,
sqltypes.Boolean : PGBoolean,
sqltypes.TEXT : PGText,
sqltypes.CHAR: PGChar,
}
-pg1_colspecs = pg2_colspecs.copy()
-pg1_colspecs.update({
- sqltypes.DateTime : PG1DateTime,
- sqltypes.Date : PG1Date,
- sqltypes.Time : PG1Time
- })
-
-pg2_ischema_names = {
+
+ischema_names = {
'integer' : PGInteger,
'bigint' : PGBigInteger,
'smallint' : PGSmallInteger,
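
A sketch of declaring an array column with the new PGArray type (table
hypothetical):

    from sqlalchemy import MetaData, Table, Column, Integer
    from sqlalchemy.databases import postgres

    # PGArray wraps any item type (a bare class is instantiated
    # automatically); bind and result values are (possibly nested) Python
    # lists, converted element-by-element through the wrapped type
    t = Table('matrix', MetaData(),
              Column('id', Integer, primary_key=True),
              Column('cells', postgres.PGArray(Integer)))
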
@@ -175,24 +154,17 @@ pg2_ischema_names = {
'real' : PGFloat,
'inet': PGInet,
'double precision' : PGFloat,
- 'timestamp' : PG2DateTime,
- 'timestamp with time zone' : PG2DateTime,
- 'timestamp without time zone' : PG2DateTime,
- 'time with time zone' : PG2Time,
- 'time without time zone' : PG2Time,
- 'date' : PG2Date,
- 'time': PG2Time,
+ 'timestamp' : PGDateTime,
+ 'timestamp with time zone' : PGDateTime,
+ 'timestamp without time zone' : PGDateTime,
+ 'time with time zone' : PGTime,
+ 'time without time zone' : PGTime,
+ 'date' : PGDate,
+ 'time': PGTime,
'bytea' : PGBinary,
'boolean' : PGBoolean,
'interval':PGInterval,
}
-pg1_ischema_names = pg2_ischema_names.copy()
-pg1_ischema_names.update({
- 'timestamp with time zone' : PG1DateTime,
- 'timestamp without time zone' : PG1DateTime,
- 'date' : PG1Date,
- 'time' : PG1Time
- })
def descriptor():
return {'name':'postgres',
@@ -206,11 +178,11 @@ def descriptor():
class PGExecutionContext(default.DefaultExecutionContext):
- def is_select(self):
- return re.match(r'SELECT', self.statement.lstrip(), re.I) and not re.search(r'FOR UPDATE\s*$', self.statement, re.I)
-
+ def _is_server_side(self):
+ return self.dialect.server_side_cursors and self.is_select() and not re.search(r'FOR UPDATE(?: NOWAIT)?\s*$', self.statement, re.I)
+
def create_cursor(self):
- if self.dialect.server_side_cursors and self.is_select():
+ if self._is_server_side():
# use server-side cursors:
# http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
ident = "c" + hex(random.randint(0, 65535))[2:]
@@ -219,7 +191,7 @@ class PGExecutionContext(default.DefaultExecutionContext):
return self.connection.connection.cursor()
def get_result_proxy(self):
- if self.dialect.server_side_cursors and self.is_select():
+ if self._is_server_side():
return base.BufferedRowResultProxy(self)
else:
return base.ResultProxy(self)
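
Server-side cursors are opt-in per engine; a sketch (DSN hypothetical):

    from sqlalchemy import create_engine

    # server_side_cursors is forwarded to PGDialect.__init__(); per
    # _is_server_side() above it only engages for plain SELECTs, not
    # statements ending in FOR UPDATE [NOWAIT]
    engine = create_engine('postgres://user:pw@localhost/test',
                           server_side_cursors=True)
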
@@ -242,31 +214,18 @@ class PGDialect(ansisql.ANSIDialect):
ansisql.ANSIDialect.__init__(self, default_paramstyle='pyformat', **kwargs)
self.use_oids = use_oids
self.server_side_cursors = server_side_cursors
- if self.dbapi is None or not hasattr(self.dbapi, '__version__') or self.dbapi.__version__.startswith('2'):
- self.version = 2
- else:
- self.version = 1
self.use_information_schema = use_information_schema
self.paramstyle = 'pyformat'
def dbapi(cls):
- try:
- import psycopg2 as psycopg
- except ImportError, e:
- try:
- import psycopg
- except ImportError, e2:
- raise e
+ import psycopg2 as psycopg
return psycopg
dbapi = classmethod(dbapi)
def create_connect_args(self, url):
opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port'])
if opts.has_key('port'):
- if self.version == 2:
- opts['port'] = int(opts['port'])
- else:
- opts['port'] = str(opts['port'])
+ opts['port'] = int(opts['port'])
opts.update(url.query)
return ([], opts)
@@ -278,10 +237,7 @@ class PGDialect(ansisql.ANSIDialect):
return 63
def type_descriptor(self, typeobj):
- if self.version == 2:
- return sqltypes.adapt_type(typeobj, pg2_colspecs)
- else:
- return sqltypes.adapt_type(typeobj, pg1_colspecs)
+ return sqltypes.adapt_type(typeobj, colspecs)
def compiler(self, statement, bindparams, **kwargs):
return PGCompiler(self, statement, bindparams, **kwargs)
@@ -292,8 +248,36 @@ class PGDialect(ansisql.ANSIDialect):
def schemadropper(self, *args, **kwargs):
return PGSchemaDropper(self, *args, **kwargs)
- def defaultrunner(self, connection, **kwargs):
- return PGDefaultRunner(connection, **kwargs)
+ def do_begin_twophase(self, connection, xid):
+ self.do_begin(connection.connection)
+
+ def do_prepare_twophase(self, connection, xid):
+ connection.execute(sql.text("PREPARE TRANSACTION %(tid)s", bindparams=[sql.bindparam('tid', xid)]))
+
+ def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
+ if is_prepared:
+ if recover:
+                # FIXME: ugly hack to get out of the transaction context when committing recoverable transactions.
+                # Must find a way to keep the dbapi from opening a transaction.
+ connection.execute(sql.text("ROLLBACK"))
+ connection.execute(sql.text("ROLLBACK PREPARED %(tid)s", bindparams=[sql.bindparam('tid', xid)]))
+ else:
+ self.do_rollback(connection.connection)
+
+ def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
+ if is_prepared:
+ if recover:
+ connection.execute(sql.text("ROLLBACK"))
+ connection.execute(sql.text("COMMIT PREPARED %(tid)s", bindparams=[sql.bindparam('tid', xid)]))
+ else:
+ self.do_commit(connection.connection)
+
+ def do_recover_twophase(self, connection):
+ resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts"))
+ return [row[0] for row in resultset]
+
+ def defaultrunner(self, context, **kwargs):
+ return PGDefaultRunner(context, **kwargs)
def preparer(self):
return PGIdentifierPreparer(self)
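
A sketch of crash recovery with the hooks above, calling the dialect methods
directly (DSN hypothetical):

    from sqlalchemy import create_engine

    engine = create_engine('postgres://user:pw@localhost/test')  # hypothetical DSN
    conn = engine.connect()
    # do_recover_twophase() reads pg_prepared_xacts; each gid can then be
    # resolved via ROLLBACK PREPARED or COMMIT PREPARED
    for xid in engine.dialect.do_recover_twophase(conn):
        engine.dialect.do_rollback_twophase(conn, xid, is_prepared=True, recover=True)
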
@@ -351,14 +335,9 @@ class PGDialect(ansisql.ANSIDialect):
else:
return False
- def reflecttable(self, connection, table):
- if self.version == 2:
- ischema_names = pg2_ischema_names
- else:
- ischema_names = pg1_ischema_names
-
+ def reflecttable(self, connection, table, include_columns):
if self.use_information_schema:
- ischema.reflecttable(connection, table, ischema_names)
+ ischema.reflecttable(connection, table, include_columns, ischema_names)
else:
preparer = self.identifier_preparer
if table.schema is not None:
@@ -387,7 +366,7 @@ class PGDialect(ansisql.ANSIDialect):
ORDER BY a.attnum
""" % schema_where_clause
- s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type=sqltypes.Unicode), sql.bindparam('schema', type=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode})
+ s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode})
c = connection.execute(s, table_name=table.name,
schema=table.schema)
rows = c.fetchall()
@@ -398,9 +377,13 @@ class PGDialect(ansisql.ANSIDialect):
domains = self._load_domains(connection)
for name, format_type, default, notnull, attnum, table_oid in rows:
+ if include_columns and name not in include_columns:
+ continue
+
## strip (30) from character varying(30)
- attype = re.search('([^\(]+)', format_type).group(1)
+ attype = re.search('([^\([]+)', format_type).group(1)
nullable = not notnull
+ is_array = format_type.endswith('[]')
try:
charlen = re.search('\(([\d,]+)\)', format_type).group(1)
@@ -453,6 +436,8 @@ class PGDialect(ansisql.ANSIDialect):
if coltype:
coltype = coltype(*args, **kwargs)
+ if is_array:
+ coltype = PGArray(coltype)
else:
warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (attype, name)))
coltype = sqltypes.NULLTYPE
@@ -517,7 +502,6 @@ class PGDialect(ansisql.ANSIDialect):
table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname))
def _load_domains(self, connection):
-
## Load data types for domains:
SQL_DOMAINS = """
SELECT t.typname as "name",
@@ -554,49 +538,46 @@ class PGDialect(ansisql.ANSIDialect):
-
class PGCompiler(ansisql.ANSICompiler):
- def visit_insert_column(self, column, parameters):
- # all column primary key inserts must be explicitly present
- if column.primary_key:
- parameters[column.key] = None
+ operators = ansisql.ANSICompiler.operators.copy()
+ operators.update(
+ {
+ operator.mod : '%%'
+ }
+ )
- def visit_insert_sequence(self, column, sequence, parameters):
- """this is the 'sequence' equivalent to ANSICompiler's 'visit_insert_column_default' which ensures
- that the column is present in the generated column list"""
- parameters.setdefault(column.key, None)
+ def uses_sequences_for_inserts(self):
+ return True
def limit_clause(self, select):
text = ""
- if select.limit is not None:
- text += " \n LIMIT " + str(select.limit)
- if select.offset is not None:
- if select.limit is None:
+ if select._limit is not None:
+ text += " \n LIMIT " + str(select._limit)
+ if select._offset is not None:
+ if select._limit is None:
text += " \n LIMIT ALL"
- text += " OFFSET " + str(select.offset)
+ text += " OFFSET " + str(select._offset)
return text
- def visit_select_precolumns(self, select):
- if select.distinct:
- if type(select.distinct) == bool:
+ def get_select_precolumns(self, select):
+ if select._distinct:
+ if type(select._distinct) == bool:
return "DISTINCT "
- if type(select.distinct) == list:
+ if type(select._distinct) == list:
dist_set = "DISTINCT ON ("
- for col in select.distinct:
+ for col in select._distinct:
                    dist_set += self.process(col) + ", "
dist_set = dist_set[:-2] + ") "
return dist_set
- return "DISTINCT ON (" + str(select.distinct) + ") "
+ return "DISTINCT ON (" + str(select._distinct) + ") "
else:
return ""
- def binary_operator_string(self, binary):
- if isinstance(binary.type, sqltypes.String) and binary.operator == '+':
- return '||'
- elif binary.operator == '%':
- return '%%'
+ def for_update_clause(self, select):
+ if select.for_update == 'nowait':
+ return " FOR UPDATE NOWAIT"
else:
- return ansisql.ANSICompiler.binary_operator_string(self, binary)
+ return super(PGCompiler, self).for_update_clause(select)
class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
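
A sketch of the list form of distinct handled by get_select_precolumns()
(table hypothetical):

    from sqlalchemy import MetaData, Table, Column, Integer, select

    t = Table('t', MetaData(), Column('a', Integer), Column('b', Integer))
    # a bool renders plain DISTINCT; a list renders DISTINCT ON (a, ...)
    stmt = select([t], distinct=[t.c.a])
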
@@ -617,13 +598,13 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
return colspec
def visit_sequence(self, sequence):
- if not sequence.optional and (not self.dialect.has_sequence(self.connection, sequence.name)):
+ if not sequence.optional and (not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name)):
self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
self.execute()
class PGSchemaDropper(ansisql.ANSISchemaDropper):
def visit_sequence(self, sequence):
- if not sequence.optional and (self.dialect.has_sequence(self.connection, sequence.name)):
+ if not sequence.optional and (not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name)):
self.append("DROP SEQUENCE %s" % sequence.name)
self.execute()
@@ -632,7 +613,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
if column.primary_key:
# passive defaults on primary keys have to be overridden
if isinstance(column.default, schema.PassiveDefault):
- return self.connection.execute_text("select %s" % column.default.arg).scalar()
+ return self.connection.execute("select %s" % column.default.arg).scalar()
elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
sch = column.table.schema
# TODO: this has to build into the Sequence object so we can get the quoting
@@ -641,7 +622,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
else:
exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
- return self.connection.execute_text(exc).scalar()
+ return self.connection.execute(exc).scalar()
return super(ansisql.ANSIDefaultRunner, self).get_column_default(column)
diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py
index 816b1b76a..725ea23e2 100644
--- a/lib/sqlalchemy/databases/sqlite.py
+++ b/lib/sqlalchemy/databases/sqlite.py
@@ -5,9 +5,9 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import sys, StringIO, string, types, re
+import re
-from sqlalchemy import sql, engine, schema, ansisql, exceptions, pool, PassiveDefault
+from sqlalchemy import schema, ansisql, exceptions, pool, PassiveDefault
import sqlalchemy.engine.default as default
import sqlalchemy.types as sqltypes
import datetime,time, warnings
@@ -126,6 +126,7 @@ colspecs = {
pragma_names = {
'INTEGER' : SLInteger,
+ 'INT' : SLInteger,
'SMALLINT' : SLSmallInteger,
'VARCHAR' : SLString,
'CHAR' : SLChar,
@@ -150,8 +151,9 @@ class SQLiteExecutionContext(default.DefaultExecutionContext):
if self.compiled.isinsert:
if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:]
-
- super(SQLiteExecutionContext, self).post_exec()
+
+ def is_select(self):
+ return re.match(r'SELECT|PRAGMA', self.statement.lstrip(), re.I) is not None
class SQLiteDialect(ansisql.ANSIDialect):
@@ -233,7 +235,7 @@ class SQLiteDialect(ansisql.ANSIDialect):
return (row is not None)
- def reflecttable(self, connection, table):
+ def reflecttable(self, connection, table, include_columns):
c = connection.execute("PRAGMA table_info(%s)" % self.preparer().format_table(table), {})
found_table = False
while True:
@@ -244,6 +246,8 @@ class SQLiteDialect(ansisql.ANSIDialect):
found_table = True
(name, type, nullable, has_default, primary_key) = (row[1], row[2].upper(), not row[3], row[4] is not None, row[5])
name = re.sub(r'^\"|\"$', '', name)
+ if include_columns and name not in include_columns:
+ continue
match = re.match(r'(\w+)(\(.*?\))?', type)
if match:
coltype = match.group(1)
@@ -253,7 +257,12 @@ class SQLiteDialect(ansisql.ANSIDialect):
args = ''
#print "coltype: " + repr(coltype) + " args: " + repr(args)
- coltype = pragma_names.get(coltype, SLString)
+ try:
+ coltype = pragma_names[coltype]
+ except KeyError:
+ warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (coltype, name)))
+ coltype = sqltypes.NULLTYPE
+
if args is not None:
args = re.findall(r'(\d+)', args)
#print "args! " +repr(args)
@@ -318,21 +327,21 @@ class SQLiteDialect(ansisql.ANSIDialect):
class SQLiteCompiler(ansisql.ANSICompiler):
def visit_cast(self, cast):
if self.dialect.supports_cast:
- super(SQLiteCompiler, self).visit_cast(cast)
+ return super(SQLiteCompiler, self).visit_cast(cast)
else:
if len(self.select_stack):
# not sure if we want to set the typemap here...
self.typemap.setdefault("CAST", cast.type)
- self.strings[cast] = self.strings[cast.clause]
+ return self.process(cast.clause)
def limit_clause(self, select):
text = ""
- if select.limit is not None:
- text += " \n LIMIT " + str(select.limit)
- if select.offset is not None:
- if select.limit is None:
+ if select._limit is not None:
+ text += " \n LIMIT " + str(select._limit)
+ if select._offset is not None:
+ if select._limit is None:
text += " \n LIMIT -1"
- text += " OFFSET " + str(select.offset)
+ text += " OFFSET " + str(select._offset)
else:
text += " OFFSET 0"
return text
@@ -341,12 +350,6 @@ class SQLiteCompiler(ansisql.ANSICompiler):
# sqlite has no "FOR UPDATE" AFAICT
return ''
- def binary_operator_string(self, binary):
- if isinstance(binary.type, sqltypes.String) and binary.operator == '+':
- return '||'
- else:
- return ansisql.ANSICompiler.binary_operator_string(self, binary)
-
class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py
index 5a4865de0..50d03ea91 100644
--- a/lib/sqlalchemy/engine/__init__.py
+++ b/lib/sqlalchemy/engine/__init__.py
@@ -48,7 +48,6 @@ The package is represented among several individual modules, including:
from sqlalchemy import databases
from sqlalchemy.engine.base import *
from sqlalchemy.engine import strategies
-import re
def engine_descriptors():
"""Provide a listing of all the database implementations supported.
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index d0ca36515..fc4433a47 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -9,33 +9,10 @@ higher-level statement-construction, connection-management,
execution and result contexts."""
from sqlalchemy import exceptions, sql, schema, util, types, logging
-import StringIO, sys, re
+import StringIO, sys, re, random
-class ConnectionProvider(object):
- """Define an interface that returns raw Connection objects (or compatible)."""
-
- def get_connection(self):
- """Return a Connection or compatible object from a DBAPI which also contains a close() method.
-
- It is not defined what context this connection belongs to. It
- may be newly connected, returned from a pool, part of some
- other kind of context such as thread-local, or can be a fixed
- member of this object.
- """
-
- raise NotImplementedError()
-
- def dispose(self):
- """Release all resources corresponding to this ConnectionProvider.
-
- This includes any underlying connection pools.
- """
-
- raise NotImplementedError()
-
-
-class Dialect(sql.AbstractDialect):
+class Dialect(object):
"""Define the behavior of a specific database/DBAPI.
Any aspect of metadata definition, SQL query generation, execution,
@@ -70,11 +47,14 @@ class Dialect(sql.AbstractDialect):
raise NotImplementedError()
- def convert_compiled_params(self, parameters):
- """Build DBAPI execute arguments from a [sqlalchemy.sql#ClauseParameters] instance.
-
- Returns an array or dictionary suitable to pass directly to this ``Dialect`` instance's DBAPI's
- execute method.
+ def dbapi_type_map(self):
+        """Return a mapping of DBAPI type objects present in this Dialect's DBAPI,
+ mapped to TypeEngine implementations used by the dialect.
+
+ This is used to apply types to result sets based on the DBAPI types
+ present in cursor.description; it only takes effect for result sets against
+ textual statements where no explicit typemap was present. Constructed SQL statements
+ always have type information explicitly embedded.
"""
raise NotImplementedError()
@@ -149,11 +129,11 @@ class Dialect(sql.AbstractDialect):
raise NotImplementedError()
- def defaultrunner(self, connection, **kwargs):
+ def defaultrunner(self, execution_context):
"""Return a [sqlalchemy.schema#SchemaVisitor] instance that can execute defaults.
- connection
- a [sqlalchemy.engine#Connection] to use for statement execution
+ execution_context
+ a [sqlalchemy.engine#ExecutionContext] to use for statement execution
"""
@@ -168,11 +148,12 @@ class Dialect(sql.AbstractDialect):
raise NotImplementedError()
- def reflecttable(self, connection, table):
+ def reflecttable(self, connection, table, include_columns=None):
"""Load table description from the database.
Given a [sqlalchemy.engine#Connection] and a [sqlalchemy.schema#Table] object, reflect its
- columns and properties from the database.
+ columns and properties from the database. If include_columns (a list or set) is specified, limit the autoload
+ to the given column names.
"""
raise NotImplementedError()
@@ -222,6 +203,46 @@ class Dialect(sql.AbstractDialect):
raise NotImplementedError()
+ def do_savepoint(self, connection, name):
+ """Create a savepoint with the given name on a SQLAlchemy connection."""
+
+ raise NotImplementedError()
+
+ def do_rollback_to_savepoint(self, connection, name):
+        """Roll back a SQLAlchemy connection to the named savepoint."""
+
+ raise NotImplementedError()
+
+ def do_release_savepoint(self, connection, name):
+        """Release the named savepoint on a SQLAlchemy connection."""
+
+ raise NotImplementedError()
+
+ def do_begin_twophase(self, connection, xid):
+        """Begin a two-phase transaction on the given connection."""
+
+ raise NotImplementedError()
+
+ def do_prepare_twophase(self, connection, xid):
+        """Prepare a two-phase transaction on the given connection."""
+
+ raise NotImplementedError()
+
+ def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
+        """Roll back a two-phase transaction on the given connection."""
+
+ raise NotImplementedError()
+
+ def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
+        """Commit a two-phase transaction on the given connection."""
+
+ raise NotImplementedError()
+
+ def do_recover_twophase(self, connection):
+        """Recover a list of uncommitted prepared two-phase transaction identifiers on the given connection."""
+
+ raise NotImplementedError()
+
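# A concrete dialect would typically emit the corresponding SQL for the
# savepoint hooks; a sketch (statement text is illustrative; the
# DefaultDialect versions later in this patch route these through
# sql.SavepointClause and friends):
class ExampleSavepointDialect(object):
    def do_savepoint(self, connection, name):
        connection.execute("SAVEPOINT %s" % name)
    def do_rollback_to_savepoint(self, connection, name):
        connection.execute("ROLLBACK TO SAVEPOINT %s" % name)
    def do_release_savepoint(self, connection, name):
        connection.execute("RELEASE SAVEPOINT %s" % name)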
def do_executemany(self, cursor, statement, parameters):
"""Provide an implementation of *cursor.executemany(statement, parameters)*."""
@@ -266,19 +287,18 @@ class ExecutionContext(object):
compiled
if passed to constructor, sql.Compiled object being executed
- compiled_parameters
- if passed to constructor, sql.ClauseParameters object
-
statement
string version of the statement to be executed. Is either
passed to the constructor, or must be created from the
sql.Compiled object by the time pre_exec() has completed.
parameters
- "raw" parameters suitable for direct execution by the
- dialect. Either passed to the constructor, or must be
- created from the sql.ClauseParameters object by the time
- pre_exec() has completed.
+      bind parameters passed to the execute() method.  For
+      compiled statements, this is a dictionary or a list
+      of dictionaries.  For textual statements, it should
+      be in a format suitable for the dialect's paramstyle
+      (i.e. a dict or list of dicts for non-positional,
+      a list or list of lists/tuples for positional).
The Dialect should provide an ExecutionContext via the
@@ -288,24 +308,28 @@ class ExecutionContext(object):
"""
def create_cursor(self):
- """Return a new cursor generated this ExecutionContext's connection."""
+ """Return a new cursor generated from this ExecutionContext's connection.
+
+ Some dialects may wish to change the behavior of connection.cursor(),
+        such as postgres, which may return a PG "server side" cursor.
+ """
raise NotImplementedError()
- def pre_exec(self):
+ def pre_execution(self):
"""Called before an execution of a compiled statement.
- If compiled and compiled_parameters were passed to this
+ If a compiled statement was passed to this
ExecutionContext, the `statement` and `parameters` datamembers
        must be initialized by the time this method completes.
"""
raise NotImplementedError()
- def post_exec(self):
+ def post_execution(self):
"""Called after the execution of a compiled statement.
- If compiled was passed to this ExecutionContext,
+ If a compiled statement was passed to this ExecutionContext,
the `last_insert_ids`, `last_inserted_params`, etc.
datamembers should be available after this method
completes.
@@ -313,8 +337,11 @@ class ExecutionContext(object):
raise NotImplementedError()
- def get_result_proxy(self):
- """return a ResultProxy corresponding to this ExecutionContext."""
+ def result(self):
+        """Return a result object corresponding to this ExecutionContext.
+
+        Returns a ResultProxy."""
+
raise NotImplementedError()
def get_rowcount(self):
@@ -361,8 +388,88 @@ class ExecutionContext(object):
raise NotImplementedError()
+class Compiled(object):
+ """Represent a compiled SQL expression.
+
+ The ``__str__`` method of the ``Compiled`` object should produce
+ the actual text of the statement. ``Compiled`` objects are
+ specific to their underlying database dialect, and also may
+ or may not be specific to the columns referenced within a
+ particular set of bind parameters. In no case should the
+ ``Compiled`` object be dependent on the actual values of those
+ bind parameters, even though it may reference those values as
+ defaults.
+ """
+
+ def __init__(self, dialect, statement, parameters, bind=None):
+ """Construct a new ``Compiled`` object.
+
+ statement
+ ``ClauseElement`` to be compiled.
+
+ parameters
+ Optional dictionary indicating a set of bind parameters
+ specified with this ``Compiled`` object. These parameters
+ are the *default* values corresponding to the
+ ``ClauseElement``'s ``_BindParamClauses`` when the
+ ``Compiled`` is executed. In the case of an ``INSERT`` or
+ ``UPDATE`` statement, these parameters will also result in
+ the creation of new ``_BindParamClause`` objects for each
+ key and will also affect the generated column list in an
+ ``INSERT`` statement and the ``SET`` clauses of an
+ ``UPDATE`` statement. The keys of the parameter dictionary
+ can either be the string names of columns or
+ ``_ColumnClause`` objects.
+
+ bind
+ Optional Engine or Connection to compile this statement against.
+ """
+ self.dialect = dialect
+ self.statement = statement
+ self.parameters = parameters
+ self.bind = bind
+ self.can_execute = statement.supports_execution()
+
+ def compile(self):
+ """Produce the internal string representation of this element."""
+
+ raise NotImplementedError()
+
+ def __str__(self):
+ """Return the string text of the generated SQL statement."""
+
+ raise NotImplementedError()
+
+ def get_params(self, **params):
+        """Deprecated; use construct_params() (supports unicode names).
+ """
+
+ return self.construct_params(params)
+
+ def construct_params(self, params):
+ """Return the bind params for this compiled object.
+
+        params is a dict of string/object pairs whose
+        values will override bind values compiled
+        into the statement.
+ """
+ raise NotImplementedError()
+
+ def execute(self, *multiparams, **params):
+ """Execute this compiled object."""
+
+ e = self.bind
+ if e is None:
+ raise exceptions.InvalidRequestError("This Compiled object is not bound to any Engine or Connection.")
+ return e._execute_compiled(self, multiparams, params)
+
+ def scalar(self, *multiparams, **params):
+ """Execute this compiled object and return the result's scalar value."""
+
+ return self.execute(*multiparams, **params).scalar()
+
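# A usage sketch: a Compiled is produced from a ClauseElement and, when
# bound, can be executed directly (`users` and `engine` are hypothetical):
c = users.select().compile(bind=engine)
print str(c)             # the generated SQL text
result = c.execute()     # executes via the bound engine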
-class Connectable(sql.Executor):
+class Connectable(object):
"""Interface for an object that can provide an Engine and a Connection object which correponds to that Engine."""
def contextual_connect(self):
@@ -401,6 +508,7 @@ class Connection(Connectable):
self.__connection = connection or engine.raw_connection()
self.__transaction = None
self.__close_with_result = close_with_result
+ self.__savepoint_seq = 0
def _get_connection(self):
try:
@@ -408,13 +516,18 @@ class Connection(Connectable):
except AttributeError:
raise exceptions.InvalidRequestError("This Connection is closed")
+ def _branch(self):
+        """Return a new Connection which references this Connection's
+        engine and connection, but does not have close_with_result enabled."""
+
+ return Connection(self.__engine, self.__connection)
+
engine = property(lambda s:s.__engine, doc="The Engine with which this Connection is associated.")
dialect = property(lambda s:s.__engine.dialect, doc="Dialect used by this Connection.")
connection = property(_get_connection, doc="The underlying DBAPI connection managed by this Connection.")
should_close_with_result = property(lambda s:s.__close_with_result, doc="Indicates if this Connection should be closed when a corresponding ResultProxy is closed; this is essentially an auto-release mode.")
-
- def _create_transaction(self, parent):
- return Transaction(self, parent)
+ properties = property(lambda s: s._get_connection().properties,
+ doc="A set of per-DBAPI connection properties.")
def connect(self):
"""connect() is implemented to return self so that an incoming Engine or Connection object can be treated similarly."""
@@ -448,12 +561,34 @@ class Connection(Connectable):
self.__connection.detach()
- def begin(self):
+ def begin(self, nested=False):
if self.__transaction is None:
- self.__transaction = self._create_transaction(None)
- return self.__transaction
+ self.__transaction = RootTransaction(self)
+ elif nested:
+ self.__transaction = NestedTransaction(self, self.__transaction)
else:
- return self._create_transaction(self.__transaction)
+ return Transaction(self, self.__transaction)
+ return self.__transaction
+
+ def begin_nested(self):
+ return self.begin(nested=True)
+
+ def begin_twophase(self, xid=None):
+ if self.__transaction is not None:
+ raise exceptions.InvalidRequestError("Cannot start a two phase transaction when a transaction is already started.")
+ if xid is None:
+ xid = "_sa_%032x" % random.randint(0,2**128)
+ self.__transaction = TwoPhaseTransaction(self, xid)
+ return self.__transaction
+
+ def recover_twophase(self):
+ return self.__engine.dialect.do_recover_twophase(self)
+
+ def rollback_prepared(self, xid, recover=False):
+ self.__engine.dialect.do_rollback_twophase(self, xid, recover=recover)
+
+ def commit_prepared(self, xid, recover=False):
+ self.__engine.dialect.do_commit_twophase(self, xid, recover=recover)
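# A usage sketch for nested (savepoint) transactions, assuming a dialect
# with SAVEPOINT support (`conn` and `stmt` are hypothetical):
trans = conn.begin()
nested = conn.begin_nested()      # emits SAVEPOINT
try:
    conn.execute(stmt)
    nested.commit()               # RELEASE SAVEPOINT
except:
    nested.rollback()             # ROLLBACK TO SAVEPOINT
    raise
trans.commit()

# After a crash, prepared two-phase transactions can be finished off:
for xid in conn.recover_twophase():
    conn.commit_prepared(xid, recover=True)   # or rollback_prepared()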
def in_transaction(self):
return self.__transaction is not None
@@ -485,6 +620,45 @@ class Connection(Connectable):
raise exceptions.SQLError(None, None, e)
self.__transaction = None
+ def _savepoint_impl(self, name=None):
+ if name is None:
+ self.__savepoint_seq += 1
+ name = '__sa_savepoint_%s' % self.__savepoint_seq
+ if self.__connection.is_valid:
+ self.__engine.dialect.do_savepoint(self, name)
+ return name
+
+ def _rollback_to_savepoint_impl(self, name, context):
+ if self.__connection.is_valid:
+ self.__engine.dialect.do_rollback_to_savepoint(self, name)
+ self.__transaction = context
+
+ def _release_savepoint_impl(self, name, context):
+ if self.__connection.is_valid:
+ self.__engine.dialect.do_release_savepoint(self, name)
+ self.__transaction = context
+
+ def _begin_twophase_impl(self, xid):
+ if self.__connection.is_valid:
+ self.__engine.dialect.do_begin_twophase(self, xid)
+
+ def _prepare_twophase_impl(self, xid):
+ if self.__connection.is_valid:
+ assert isinstance(self.__transaction, TwoPhaseTransaction)
+ self.__engine.dialect.do_prepare_twophase(self, xid)
+
+ def _rollback_twophase_impl(self, xid, is_prepared):
+ if self.__connection.is_valid:
+ assert isinstance(self.__transaction, TwoPhaseTransaction)
+ self.__engine.dialect.do_rollback_twophase(self, xid, is_prepared)
+ self.__transaction = None
+
+ def _commit_twophase_impl(self, xid, is_prepared):
+ if self.__connection.is_valid:
+ assert isinstance(self.__transaction, TwoPhaseTransaction)
+ self.__engine.dialect.do_commit_twophase(self, xid, is_prepared)
+ self.__transaction = None
+
def _autocommit(self, statement):
"""When no Transaction is present, this is called after executions to provide "autocommit" behavior."""
# TODO: have the dialect determine if autocommit can be set on the connection directly without this
@@ -495,7 +669,7 @@ class Connection(Connectable):
def _autorollback(self):
if not self.in_transaction():
self._rollback_impl()
-
+
def close(self):
try:
c = self.__connection
@@ -514,74 +688,66 @@ class Connection(Connectable):
def execute(self, object, *multiparams, **params):
for c in type(object).__mro__:
if c in Connection.executors:
- return Connection.executors[c](self, object, *multiparams, **params)
+ return Connection.executors[c](self, object, multiparams, params)
else:
raise exceptions.InvalidRequestError("Unexecuteable object type: " + str(type(object)))
- def execute_default(self, default, **kwargs):
- return default.accept_visitor(self.__engine.dialect.defaultrunner(self))
+ def _execute_default(self, default, multiparams=None, params=None):
+ return self.__engine.dialect.defaultrunner(self.__create_execution_context()).traverse_single(default)
+
+ def _execute_text(self, statement, multiparams, params):
+ parameters = self.__distill_params(multiparams, params)
+ context = self.__create_execution_context(statement=statement, parameters=parameters)
+ self.__execute_raw(context)
+ return context.result()
- def execute_text(self, statement, *multiparams, **params):
- if len(multiparams) == 0:
+ def __distill_params(self, multiparams, params):
+ if multiparams is None or len(multiparams) == 0:
parameters = params or None
- elif len(multiparams) == 1 and (isinstance(multiparams[0], list) or isinstance(multiparams[0], tuple) or isinstance(multiparams[0], dict)):
+ elif len(multiparams) == 1 and isinstance(multiparams[0], (list, tuple, dict)):
parameters = multiparams[0]
else:
parameters = list(multiparams)
- context = self._create_execution_context(statement=statement, parameters=parameters)
- self._execute_raw(context)
- return context.get_result_proxy()
-
- def _params_to_listofdicts(self, *multiparams, **params):
- if len(multiparams) == 0:
- return [params]
- elif len(multiparams) == 1:
- if multiparams[0] == None:
- return [{}]
- elif isinstance (multiparams[0], list) or isinstance (multiparams[0], tuple):
- return multiparams[0]
- else:
- return [multiparams[0]]
- else:
- return multiparams
-
- def execute_function(self, func, *multiparams, **params):
- return self.execute_clauseelement(func.select(), *multiparams, **params)
+ return parameters
+
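# __distill_params normalizes the public execute() calling conventions;
# illustrative results (`stmt` is hypothetical):
conn.execute(stmt)                          # parameters: None
conn.execute(stmt, {'x': 1})                # parameters: {'x': 1}
conn.execute(stmt, [{'x': 1}, {'x': 2}])    # parameters: [{'x': 1}, {'x': 2}]
conn.execute(stmt, {'x': 1}, {'x': 2})      # parameters: [{'x': 1}, {'x': 2}]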
+ def _execute_function(self, func, multiparams, params):
+ return self._execute_clauseelement(func.select(), multiparams, params)
- def execute_clauseelement(self, elem, *multiparams, **params):
- executemany = len(multiparams) > 0
+ def _execute_clauseelement(self, elem, multiparams=None, params=None):
+ executemany = multiparams is not None and len(multiparams) > 0
if executemany:
param = multiparams[0]
else:
param = params
- return self.execute_compiled(elem.compile(dialect=self.dialect, parameters=param), *multiparams, **params)
+ return self._execute_compiled(elem.compile(dialect=self.dialect, parameters=param), multiparams, params)
- def execute_compiled(self, compiled, *multiparams, **params):
+ def _execute_compiled(self, compiled, multiparams=None, params=None):
"""Execute a sql.Compiled object."""
if not compiled.can_execute:
raise exceptions.ArgumentError("Not an executeable clause: %s" % (str(compiled)))
- parameters = [compiled.construct_params(m) for m in self._params_to_listofdicts(*multiparams, **params)]
- if len(parameters) == 1:
- parameters = parameters[0]
- context = self._create_execution_context(compiled=compiled, compiled_parameters=parameters)
- context.pre_exec()
- self._execute_raw(context)
- context.post_exec()
- return context.get_result_proxy()
-
- def _create_execution_context(self, **kwargs):
+
+ params = self.__distill_params(multiparams, params)
+ context = self.__create_execution_context(compiled=compiled, parameters=params)
+
+ context.pre_execution()
+ self.__execute_raw(context)
+ context.post_execution()
+ return context.result()
+
+ def __create_execution_context(self, **kwargs):
return self.__engine.dialect.create_execution_context(connection=self, **kwargs)
- def _execute_raw(self, context):
- self.__engine.logger.info(context.statement)
- self.__engine.logger.info(repr(context.parameters))
- if context.parameters is not None and isinstance(context.parameters, list) and len(context.parameters) > 0 and (isinstance(context.parameters[0], list) or isinstance(context.parameters[0], tuple) or isinstance(context.parameters[0], dict)):
- self._executemany(context)
+ def __execute_raw(self, context):
+ if logging.is_info_enabled(self.__engine.logger):
+ self.__engine.logger.info(context.statement)
+ self.__engine.logger.info(repr(context.parameters))
+ if context.parameters is not None and isinstance(context.parameters, list) and len(context.parameters) > 0 and isinstance(context.parameters[0], (list, tuple, dict)):
+ self.__executemany(context)
else:
- self._execute(context)
+ self.__execute(context)
self._autocommit(context.statement)
- def _execute(self, context):
+ def __execute(self, context):
if context.parameters is None:
if context.dialect.positional:
context.parameters = ()
@@ -592,19 +758,19 @@ class Connection(Connectable):
except Exception, e:
if self.dialect.is_disconnect(e):
self.__connection.invalidate(e=e)
- self.engine.connection_provider.dispose()
+ self.engine.dispose()
self._autorollback()
if self.__close_with_result:
self.close()
raise exceptions.SQLError(context.statement, context.parameters, e)
- def _executemany(self, context):
+ def __executemany(self, context):
try:
context.dialect.do_executemany(context.cursor, context.statement, context.parameters, context=context)
except Exception, e:
if self.dialect.is_disconnect(e):
self.__connection.invalidate(e=e)
- self.engine.connection_provider.dispose()
+ self.engine.dispose()
self._autorollback()
if self.__close_with_result:
self.close()
@@ -612,11 +778,11 @@ class Connection(Connectable):
# poor man's multimethod/generic function thingy
executors = {
- sql._Function : execute_function,
- sql.ClauseElement : execute_clauseelement,
- sql.ClauseVisitor : execute_compiled,
- schema.SchemaItem:execute_default,
- str.__mro__[-2] : execute_text
+ sql._Function : _execute_function,
+ sql.ClauseElement : _execute_clauseelement,
+ sql.ClauseVisitor : _execute_compiled,
+ schema.SchemaItem:_execute_default,
+ str.__mro__[-2] : _execute_text
}
def create(self, entity, **kwargs):
@@ -629,10 +795,10 @@ class Connection(Connectable):
return self.__engine.drop(entity, connection=self, **kwargs)
- def reflecttable(self, table, **kwargs):
+ def reflecttable(self, table, include_columns=None):
"""Reflect the columns in the given string table name from the database."""
- return self.__engine.reflecttable(table, connection=self, **kwargs)
+ return self.__engine.reflecttable(table, self, include_columns)
def default_schema_name(self):
return self.__engine.dialect.get_default_schema_name(self)
@@ -647,39 +813,90 @@ class Transaction(object):
"""
def __init__(self, connection, parent):
- self.__connection = connection
- self.__parent = parent or self
- self.__is_active = True
- if self.__parent is self:
- self.__connection._begin_impl()
+ self._connection = connection
+ self._parent = parent or self
+ self._is_active = True
- connection = property(lambda s:s.__connection, doc="The Connection object referenced by this Transaction")
- is_active = property(lambda s:s.__is_active)
+ connection = property(lambda s:s._connection, doc="The Connection object referenced by this Transaction")
+ is_active = property(lambda s:s._is_active)
def rollback(self):
- if not self.__parent.__is_active:
+ if not self._parent._is_active:
return
- if self.__parent is self:
- self.__connection._rollback_impl()
- self.__is_active = False
- else:
- self.__parent.rollback()
+ self._is_active = False
+ self._do_rollback()
+
+ def _do_rollback(self):
+ self._parent.rollback()
def commit(self):
- if not self.__parent.__is_active:
+ if not self._parent._is_active:
raise exceptions.InvalidRequestError("This transaction is inactive")
- if self.__parent is self:
- self.__connection._commit_impl()
- self.__is_active = False
+ self._is_active = False
+ self._do_commit()
+
+ def _do_commit(self):
+ pass
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ if type is None and self._is_active:
+ self.commit()
+ else:
+ self.rollback()
+
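# __enter__/__exit__ make any Transaction usable as a context manager
# under Python 2.5's with-statement: commit on success, rollback on error
# (`conn` and `stmt` are hypothetical):
from __future__ import with_statement

with conn.begin() as trans:
    conn.execute(stmt)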
+class RootTransaction(Transaction):
+ def __init__(self, connection):
+ super(RootTransaction, self).__init__(connection, None)
+ self._connection._begin_impl()
+
+ def _do_rollback(self):
+ self._connection._rollback_impl()
+
+ def _do_commit(self):
+ self._connection._commit_impl()
+
+class NestedTransaction(Transaction):
+ def __init__(self, connection, parent):
+ super(NestedTransaction, self).__init__(connection, parent)
+ self._savepoint = self._connection._savepoint_impl()
+
+ def _do_rollback(self):
+ self._connection._rollback_to_savepoint_impl(self._savepoint, self._parent)
+
+ def _do_commit(self):
+ self._connection._release_savepoint_impl(self._savepoint, self._parent)
+
+class TwoPhaseTransaction(Transaction):
+ def __init__(self, connection, xid):
+ super(TwoPhaseTransaction, self).__init__(connection, None)
+ self._is_prepared = False
+ self.xid = xid
+ self._connection._begin_twophase_impl(self.xid)
+
+ def prepare(self):
+ if not self._parent._is_active:
+ raise exceptions.InvalidRequestError("This transaction is inactive")
+ self._connection._prepare_twophase_impl(self.xid)
+ self._is_prepared = True
+
+ def _do_rollback(self):
+ self._connection._rollback_twophase_impl(self.xid, self._is_prepared)
+
+ def commit(self):
+ self._connection._commit_twophase_impl(self.xid, self._is_prepared)
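# The two-phase flow, sketched: begin, do work, prepare (phase one),
# commit (phase two); `conn` and `stmt` are hypothetical:
xa = conn.begin_twophase()
conn.execute(stmt)
xa.prepare()     # phase one: the database persists the transaction
xa.commit()      # phase two: commits the prepared transaction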
class Engine(Connectable):
"""
- Connects a ConnectionProvider, a Dialect and a CompilerFactory together to
+ Connects a Pool, a Dialect and a CompilerFactory together to
provide a default implementation of SchemaEngine.
"""
- def __init__(self, connection_provider, dialect, echo=None):
- self.connection_provider = connection_provider
+ def __init__(self, pool, dialect, url, echo=None):
+ self.pool = pool
+ self.url = url
self._dialect=dialect
self.echo = echo
self.logger = logging.instance_logger(self)
@@ -688,10 +905,13 @@ class Engine(Connectable):
engine = property(lambda s:s)
dialect = property(lambda s:s._dialect, doc="the [sqlalchemy.engine#Dialect] in use by this engine.")
echo = logging.echo_property()
- url = property(lambda s:s.connection_provider.url, doc="The [sqlalchemy.engine.url#URL] object representing this ``Engine`` object's datasource.")
+
+ def __repr__(self):
+ return 'Engine(%s)' % str(self.url)
def dispose(self):
- self.connection_provider.dispose()
+ self.pool.dispose()
+ self.pool = self.pool.recreate()
def create(self, entity, connection=None, **kwargs):
"""Create a table or index within this engine's database connection given a schema.Table object."""
@@ -703,22 +923,22 @@ class Engine(Connectable):
self._run_visitor(self.dialect.schemadropper, entity, connection=connection, **kwargs)
- def execute_default(self, default, **kwargs):
+ def _execute_default(self, default):
connection = self.contextual_connect()
try:
- return connection.execute_default(default, **kwargs)
+ return connection._execute_default(default)
finally:
connection.close()
def _func(self):
- return sql._FunctionGenerator(engine=self)
+ return sql._FunctionGenerator(bind=self)
func = property(_func)
def text(self, text, *args, **kwargs):
"""Return a sql.text() object for performing literal queries."""
- return sql.text(text, engine=self, *args, **kwargs)
+ return sql.text(text, bind=self, *args, **kwargs)
def _run_visitor(self, visitorcallable, element, connection=None, **kwargs):
if connection is None:
@@ -726,7 +946,7 @@ class Engine(Connectable):
else:
conn = connection
try:
- element.accept_visitor(visitorcallable(conn, **kwargs))
+ visitorcallable(conn, **kwargs).traverse(element)
finally:
if connection is None:
conn.close()
@@ -775,12 +995,12 @@ class Engine(Connectable):
def scalar(self, statement, *multiparams, **params):
return self.execute(statement, *multiparams, **params).scalar()
- def execute_compiled(self, compiled, *multiparams, **params):
+ def _execute_compiled(self, compiled, multiparams, params):
connection = self.contextual_connect(close_with_result=True)
- return connection.execute_compiled(compiled, *multiparams, **params)
+ return connection._execute_compiled(compiled, multiparams, params)
def compiler(self, statement, parameters, **kwargs):
- return self.dialect.compiler(statement, parameters, engine=self, **kwargs)
+ return self.dialect.compiler(statement, parameters, bind=self, **kwargs)
def connect(self, **kwargs):
"""Return a newly allocated Connection object."""
@@ -795,7 +1015,7 @@ class Engine(Connectable):
return Connection(self, close_with_result=close_with_result, **kwargs)
- def reflecttable(self, table, connection=None):
+ def reflecttable(self, table, connection=None, include_columns=None):
"""Given a Table object, reflects its columns and properties from the database."""
if connection is None:
@@ -803,7 +1023,7 @@ class Engine(Connectable):
else:
conn = connection
try:
- self.dialect.reflecttable(conn, table)
+ self.dialect.reflecttable(conn, table, include_columns)
finally:
if connection is None:
conn.close()
@@ -814,7 +1034,7 @@ class Engine(Connectable):
def raw_connection(self):
"""Return a DBAPI connection."""
- return self.connection_provider.get_connection()
+ return self.pool.connect()
def log(self, msg):
"""Log a message using this SQLEngine's logger stream."""
@@ -858,28 +1078,42 @@ class ResultProxy(object):
self.closed = False
self.cursor = context.cursor
self.__echo = logging.is_debug_enabled(context.engine.logger)
- self._init_metadata()
-
- rowcount = property(lambda s:s.context.get_rowcount())
- connection = property(lambda s:s.context.connection)
+ if context.is_select():
+ self._init_metadata()
+ self._rowcount = None
+ else:
+ self._rowcount = context.get_rowcount()
+ self.close()
+
+ connection = property(lambda self:self.context.connection)
+ def _get_rowcount(self):
+ if self._rowcount is not None:
+ return self._rowcount
+ else:
+ return self.context.get_rowcount()
+ rowcount = property(_get_rowcount)
lastrowid = property(lambda s:s.cursor.lastrowid)
+ out_parameters = property(lambda s:s.context.out_parameters)
def _init_metadata(self):
if hasattr(self, '_ResultProxy__props'):
return
- self.__key_cache = {}
self.__props = {}
+ self._key_cache = self._create_key_cache()
self.__keys = []
metadata = self.cursor.description
if metadata is not None:
+ typemap = self.dialect.dbapi_type_map()
+
for i, item in enumerate(metadata):
# sqlite possibly prepending table name to colnames so strip
- colname = item[0].split('.')[-1]
+ colname = self.dialect.decode_result_columnname(item[0].split('.')[-1])
if self.context.typemap is not None:
- type = self.context.typemap.get(colname.lower(), types.NULLTYPE)
+ type = self.context.typemap.get(colname.lower(), typemap.get(item[1], types.NULLTYPE))
else:
- type = types.NULLTYPE
+ type = typemap.get(item[1], types.NULLTYPE)
+
rec = (type, type.dialect_impl(self.dialect), i)
if rec[0] is None:
@@ -889,6 +1123,33 @@ class ResultProxy(object):
self.__keys.append(colname)
self.__props[i] = rec
+ if self.__echo:
+ self.context.engine.logger.debug("Col " + repr(tuple([x[0] for x in metadata])))
+
+ def _create_key_cache(self):
+ # local copies to avoid circular ref against 'self'
+ props = self.__props
+ context = self.context
+ def lookup_key(key):
+ """Given a key, which could be a ColumnElement, string, etc.,
+ matches it to the appropriate key we got from the result set's
+            metadata; then caches it locally for quick re-access."""
+
+ if isinstance(key, int) and key in props:
+ rec = props[key]
+ elif isinstance(key, basestring) and key.lower() in props:
+ rec = props[key.lower()]
+ elif isinstance(key, sql.ColumnElement):
+ label = context.column_labels.get(key._label, key.name).lower()
+ if label in props:
+ rec = props[label]
+
+            if "rec" not in locals():
+ raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % (str(key)))
+
+ return rec
+ return util.PopulateDict(lookup_key)
+
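# util.PopulateDict is assumed here to behave like a dict that invokes
# the given creator on a missing key and caches the result, so each key
# is resolved at most once per result set; the idea, sketched:
class PopulateDictSketch(dict):
    def __init__(self, creator):
        self.creator = creator
    def __missing__(self, key):     # dict.__missing__ requires Python 2.5
        self[key] = value = self.creator(key)
        return value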
def close(self):
"""Close this ResultProxy, and the underlying DBAPI cursor corresponding to the execution.
@@ -904,38 +1165,12 @@ class ResultProxy(object):
self.cursor.close()
if self.connection.should_close_with_result:
self.connection.close()
-
- def _convert_key(self, key):
- """Convert and cache a key.
-
- Given a key, which could be a ColumnElement, string, etc.,
- matches it to the appropriate key we got from the result set's
- metadata; then cache it locally for quick re-access.
- """
-
- if key in self.__key_cache:
- return self.__key_cache[key]
- else:
- if isinstance(key, int) and key in self.__props:
- rec = self.__props[key]
- elif isinstance(key, basestring) and key.lower() in self.__props:
- rec = self.__props[key.lower()]
- elif isinstance(key, sql.ColumnElement):
- label = self.context.column_labels.get(key._label, key.name).lower()
- if label in self.__props:
- rec = self.__props[label]
-
- if not "rec" in locals():
- raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % (str(key)))
-
- self.__key_cache[key] = rec
- return rec
keys = property(lambda s:s.__keys)
def _has_key(self, row, key):
try:
- self._convert_key(key)
+ self._key_cache[key]
return True
except KeyError:
return False
@@ -989,7 +1224,7 @@ class ResultProxy(object):
return self.context.supports_sane_rowcount()
def _get_col(self, row, key):
- rec = self._convert_key(key)
+ rec = self._key_cache[key]
return rec[1].convert_result_value(row[rec[2]], self.dialect)
def _fetchone_impl(self):
@@ -1101,7 +1336,7 @@ class BufferedColumnResultProxy(ResultProxy):
"""
def _get_col(self, row, key):
- rec = self._convert_key(key)
+ rec = self._key_cache[key]
return row[rec[2]]
def _process_row(self, row):
@@ -1152,6 +1387,9 @@ class RowProxy(object):
self.__parent.close()
+ def __contains__(self, key):
+ return self.__parent._has_key(self.__row, key)
+
def __iter__(self):
for i in range(0, len(self.__row)):
yield self.__parent._get_col(self.__row, i)
@@ -1168,7 +1406,11 @@ class RowProxy(object):
return self.__parent._has_key(self.__row, key)
def __getitem__(self, key):
- return self.__parent._get_col(self.__row, key)
+ if isinstance(key, slice):
+ indices = key.indices(len(self))
+ return tuple([self.__parent._get_col(self.__row, i) for i in range(*indices)])
+ else:
+ return self.__parent._get_col(self.__row, key)
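# RowProxy now accepts slices alongside the existing access styles
# (`result` and `users` are hypothetical):
row = result.fetchone()
row[0]; row['name']; row[users.c.name]   # integer / string / column access
row[0:2]                                 # new: returns a tuple of values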
def __getattr__(self, name):
try:
@@ -1226,19 +1468,22 @@ class DefaultRunner(schema.SchemaVisitor):
DefaultRunner to allow database-specific behavior.
"""
- def __init__(self, connection):
- self.connection = connection
- self.dialect = connection.dialect
+ def __init__(self, context):
+ self.context = context
+ # branch the connection so it doesnt close after result
+        # branch the connection so it doesn't close after the result
+ dialect = property(lambda self:self.context.dialect)
+
def get_column_default(self, column):
if column.default is not None:
- return column.default.accept_visitor(self)
+ return self.traverse_single(column.default)
else:
return None
def get_column_onupdate(self, column):
if column.onupdate is not None:
- return column.onupdate.accept_visitor(self)
+ return self.traverse_single(column.onupdate)
else:
return None
@@ -1260,14 +1505,14 @@ class DefaultRunner(schema.SchemaVisitor):
return None
def exec_default_sql(self, default):
- c = sql.select([default.arg]).compile(engine=self.connection)
- return self.connection.execute_compiled(c).scalar()
+ c = sql.select([default.arg]).compile(bind=self.connection)
+ return self.connection._execute_compiled(c).scalar()
def visit_column_onupdate(self, onupdate):
if isinstance(onupdate.arg, sql.ClauseElement):
return self.exec_default_sql(onupdate)
elif callable(onupdate.arg):
- return onupdate.arg()
+ return onupdate.arg(self.context)
else:
return onupdate.arg
@@ -1275,6 +1520,6 @@ class DefaultRunner(schema.SchemaVisitor):
if isinstance(default.arg, sql.ClauseElement):
return self.exec_default_sql(default)
elif callable(default.arg):
- return default.arg()
+ return default.arg(self.context)
else:
return default.arg
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 95f6566e3..962e2ab60 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -4,25 +4,13 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+"""Provide default implementations of per-dialect sqlalchemy.engine classes"""
-from sqlalchemy import schema, exceptions, util, sql, types
-import StringIO, sys, re
+from sqlalchemy import schema, exceptions, sql, types
+import sys, re
from sqlalchemy.engine import base
-"""Provide default implementations of the engine interfaces"""
-class PoolConnectionProvider(base.ConnectionProvider):
- def __init__(self, url, pool):
- self.url = url
- self._pool = pool
-
- def get_connection(self):
- return self._pool.connect()
-
- def dispose(self):
- self._pool.dispose()
- self._pool = self._pool.recreate()
-
class DefaultDialect(base.Dialect):
"""Default implementation of Dialect"""
@@ -33,7 +21,18 @@ class DefaultDialect(base.Dialect):
self._ischema = None
self.dbapi = dbapi
self._figure_paramstyle(paramstyle=paramstyle, default=default_paramstyle)
-
+
+ def decode_result_columnname(self, name):
+        """Decode a name found in cursor.description to a unicode object."""
+
+ return name.decode(self.encoding)
+
+ def dbapi_type_map(self):
+        # Most DBAPIs have problems with this (e.g., psycopg2 types
+        # are unhashable).  So far Oracle can return it.
+
+ return {}
+
def create_execution_context(self, **kwargs):
return DefaultExecutionContext(self, **kwargs)
@@ -88,6 +87,15 @@ class DefaultDialect(base.Dialect):
#print "ENGINE COMMIT ON ", connection.connection
connection.commit()
+
+ def do_savepoint(self, connection, name):
+ connection.execute(sql.SavepointClause(name))
+
+ def do_rollback_to_savepoint(self, connection, name):
+ connection.execute(sql.RollbackToSavepointClause(name))
+
+ def do_release_savepoint(self, connection, name):
+ connection.execute(sql.ReleaseSavepointClause(name))
def do_executemany(self, cursor, statement, parameters, **kwargs):
cursor.executemany(statement, parameters)
@@ -95,8 +103,8 @@ class DefaultDialect(base.Dialect):
def do_execute(self, cursor, statement, parameters, **kwargs):
cursor.execute(statement, parameters)
- def defaultrunner(self, connection):
- return base.DefaultRunner(connection)
+ def defaultrunner(self, context):
+ return base.DefaultRunner(context)
def is_disconnect(self, e):
return False
@@ -107,23 +115,6 @@ class DefaultDialect(base.Dialect):
paramstyle = property(lambda s:s._paramstyle, _set_paramstyle)
- def convert_compiled_params(self, parameters):
- executemany = parameters is not None and isinstance(parameters, list)
- # the bind params are a CompiledParams object. but all the DBAPI's hate
- # that object (or similar). so convert it to a clean
- # dictionary/list/tuple of dictionary/tuple of list
- if parameters is not None:
- if self.positional:
- if executemany:
- parameters = [p.get_raw_list() for p in parameters]
- else:
- parameters = parameters.get_raw_list()
- else:
- if executemany:
- parameters = [p.get_raw_dict() for p in parameters]
- else:
- parameters = parameters.get_raw_dict()
- return parameters
def _figure_paramstyle(self, paramstyle=None, default='named'):
if paramstyle is not None:
@@ -152,29 +143,38 @@ class DefaultDialect(base.Dialect):
ischema = property(_get_ischema, doc="""returns an ISchema object for this engine, which allows access to information_schema tables (if supported)""")
class DefaultExecutionContext(base.ExecutionContext):
- def __init__(self, dialect, connection, compiled=None, compiled_parameters=None, statement=None, parameters=None):
+ def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None):
self.dialect = dialect
self.connection = connection
self.compiled = compiled
- self.compiled_parameters = compiled_parameters
if compiled is not None:
self.typemap = compiled.typemap
self.column_labels = compiled.column_labels
self.statement = unicode(compiled)
- else:
+ if parameters is None:
+ self.compiled_parameters = compiled.construct_params({})
+ elif not isinstance(parameters, (list, tuple)):
+ self.compiled_parameters = compiled.construct_params(parameters)
+ else:
+ self.compiled_parameters = [compiled.construct_params(m or {}) for m in parameters]
+ if len(self.compiled_parameters) == 1:
+ self.compiled_parameters = self.compiled_parameters[0]
+ elif statement is not None:
self.typemap = self.column_labels = None
- self.parameters = self._encode_param_keys(parameters)
+ self.parameters = self.__encode_param_keys(parameters)
self.statement = statement
-
- if not dialect.supports_unicode_statements():
+ else:
+ self.statement = None
+
+ if self.statement is not None and not dialect.supports_unicode_statements():
self.statement = self.statement.encode(self.dialect.encoding)
self.cursor = self.create_cursor()
engine = property(lambda s:s.connection.engine)
- def _encode_param_keys(self, params):
+ def __encode_param_keys(self, params):
"""apply string encoding to the keys of dictionary-based bind parameters"""
if self.dialect.positional or self.dialect.supports_unicode_statements():
return params
@@ -189,16 +189,46 @@ class DefaultExecutionContext(base.ExecutionContext):
return [proc(d) for d in params]
else:
return proc(params)
+
+ def __convert_compiled_params(self, parameters):
+ executemany = parameters is not None and isinstance(parameters, list)
+ encode = not self.dialect.supports_unicode_statements()
+        # the bind params are a ClauseParameters object, but all the DBAPIs
+        # hate that object (or similar), so convert it to a clean
+        # dictionary/list/tuple of dictionary/tuple of list
+ if parameters is not None:
+ if self.dialect.positional:
+ if executemany:
+ parameters = [p.get_raw_list() for p in parameters]
+ else:
+ parameters = parameters.get_raw_list()
+ else:
+ if executemany:
+ parameters = [p.get_raw_dict(encode_keys=encode) for p in parameters]
+ else:
+ parameters = parameters.get_raw_dict(encode_keys=encode)
+ return parameters
def is_select(self):
- return re.match(r'SELECT', self.statement.lstrip(), re.I)
+        """Return True if the statement is expected to have result rows."""
+
+ return re.match(r'SELECT', self.statement.lstrip(), re.I) is not None
def create_cursor(self):
return self.connection.connection.cursor()
-
+
+ def pre_execution(self):
+ self.pre_exec()
+
+ def post_execution(self):
+ self.post_exec()
+
+ def result(self):
+ return self.get_result_proxy()
+
def pre_exec(self):
self._process_defaults()
- self.parameters = self._encode_param_keys(self.dialect.convert_compiled_params(self.compiled_parameters))
+ self.parameters = self.__convert_compiled_params(self.compiled_parameters)
def post_exec(self):
pass
@@ -241,7 +271,7 @@ class DefaultExecutionContext(base.ExecutionContext):
inputsizes = []
for params in plist[0:1]:
for key in params.positional:
- typeengine = params.binds[key].type
+ typeengine = params.get_type(key)
dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
if dbtype is not None:
inputsizes.append(dbtype)
@@ -250,36 +280,23 @@ class DefaultExecutionContext(base.ExecutionContext):
inputsizes = {}
for params in plist[0:1]:
for key in params.keys():
- typeengine = params.binds[key].type
+ typeengine = params.get_type(key)
dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
if dbtype is not None:
inputsizes[key] = dbtype
self.cursor.setinputsizes(**inputsizes)
def _process_defaults(self):
- """``INSERT`` and ``UPDATE`` statements, when compiled, may
- have additional columns added to their ``VALUES`` and ``SET``
- lists corresponding to column defaults/onupdates that are
- present on the ``Table`` object (i.e. ``ColumnDefault``,
- ``Sequence``, ``PassiveDefault``). This method pre-execs
- those ``DefaultGenerator`` objects that require pre-execution
- and sets their values within the parameter list, and flags this
- ExecutionContext about ``PassiveDefault`` objects that may
- require post-fetching the row after it is inserted/updated.
-
- This method relies upon logic within the ``ANSISQLCompiler``
- in its `visit_insert` and `visit_update` methods that add the
- appropriate column clauses to the statement when its being
- compiled, so that these parameters can be bound to the
- statement.
- """
+        """Generate default values for compiled insert/update statements,
+        and generate the last_inserted_ids() collection."""
+ # TODO: cleanup
if self.compiled.isinsert:
if isinstance(self.compiled_parameters, list):
plist = self.compiled_parameters
else:
plist = [self.compiled_parameters]
- drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection))
+ drunner = self.dialect.defaultrunner(self)
self._lastrow_has_defaults = False
for param in plist:
last_inserted_ids = []
@@ -319,7 +336,7 @@ class DefaultExecutionContext(base.ExecutionContext):
plist = self.compiled_parameters
else:
plist = [self.compiled_parameters]
- drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection))
+ drunner = self.dialect.defaultrunner(self)
self._lastrow_has_defaults = False
for param in plist:
# check the "onupdate" status of each column in the table
diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py
index 7d85de9ad..0c59ee8eb 100644
--- a/lib/sqlalchemy/engine/strategies.py
+++ b/lib/sqlalchemy/engine/strategies.py
@@ -4,13 +4,11 @@ By default there are two, one which is the "thread-local" strategy,
one which is the "plain" strategy.
New strategies can be added via constructing a new EngineStrategy
-object which will add itself to the list of available strategies here,
-or replace one of the existing name. this can be accomplished via a
-mod; see the sqlalchemy/mods package for details.
+object which will add itself to the list of available strategies.
"""
-from sqlalchemy.engine import base, default, threadlocal, url
+from sqlalchemy.engine import base, threadlocal, url
from sqlalchemy import util, exceptions
from sqlalchemy import pool as poollib
@@ -92,8 +90,6 @@ class DefaultEngineStrategy(EngineStrategy):
else:
pool = pool
- provider = self.get_pool_provider(u, pool)
-
# create engine.
engineclass = self.get_engine_cls()
engine_args = {}
@@ -105,14 +101,11 @@ class DefaultEngineStrategy(EngineStrategy):
if len(kwargs):
raise TypeError("Invalid argument(s) %s sent to create_engine(), using configuration %s/%s/%s. Please check that the keyword arguments are appropriate for this combination of components." % (','.join(["'%s'" % k for k in kwargs]), dialect.__class__.__name__, pool.__class__.__name__, engineclass.__name__))
- return engineclass(provider, dialect, **engine_args)
+ return engineclass(pool, dialect, u, **engine_args)
def pool_threadlocal(self):
raise NotImplementedError()
- def get_pool_provider(self, url, pool):
- raise NotImplementedError()
-
def get_engine_cls(self):
raise NotImplementedError()
@@ -123,9 +116,6 @@ class PlainEngineStrategy(DefaultEngineStrategy):
def pool_threadlocal(self):
return False
- def get_pool_provider(self, url, pool):
- return default.PoolConnectionProvider(url, pool)
-
def get_engine_cls(self):
return base.Engine
@@ -138,9 +128,6 @@ class ThreadLocalEngineStrategy(DefaultEngineStrategy):
def pool_threadlocal(self):
return True
- def get_pool_provider(self, url, pool):
- return threadlocal.TLocalConnectionProvider(url, pool)
-
def get_engine_cls(self):
return threadlocal.TLEngine
@@ -195,4 +182,4 @@ class MockEngineStrategy(EngineStrategy):
def execute(self, object, *multiparams, **params):
raise NotImplementedError()
-MockEngineStrategy() \ No newline at end of file
+MockEngineStrategy()
diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py
index 2bbb1ed43..b6ba54ea5 100644
--- a/lib/sqlalchemy/engine/threadlocal.py
+++ b/lib/sqlalchemy/engine/threadlocal.py
@@ -1,8 +1,7 @@
-from sqlalchemy import schema, exceptions, util, sql, types
-import StringIO, sys, re
-from sqlalchemy.engine import base, default
+from sqlalchemy import util
+from sqlalchemy.engine import base
-"""Provide a thread-local transactional wrapper around the basic ComposedSQLEngine.
+"""Provide a thread-local transactional wrapper around the root Engine class.
Multiple calls to engine.connect() will return the same connection for
the same thread. also provides begin/commit methods on the engine
@@ -70,11 +69,8 @@ class TLConnection(base.Connection):
self.__opencount += 1
return self
- def _create_transaction(self, parent):
- return TLTransaction(self, parent)
-
def _begin(self):
- return base.Connection.begin(self)
+ return TLTransaction(self)
def in_transaction(self):
return self.session.in_transaction()
@@ -91,7 +87,7 @@ class TLConnection(base.Connection):
self.__opencount = 0
base.Connection.close(self)
-class TLTransaction(base.Transaction):
+class TLTransaction(base.RootTransaction):
def _commit_impl(self):
base.Transaction.commit(self)
@@ -112,7 +108,7 @@ class TLEngine(base.Engine):
"""
def __init__(self, *args, **kwargs):
- """The TLEngine relies upon the ConnectionProvider having
+ """The TLEngine relies upon the Pool having
"threadlocal" behavior, so that once a connection is checked out
for the current thread, you get that same connection
repeatedly.
@@ -124,7 +120,7 @@ class TLEngine(base.Engine):
def raw_connection(self):
"""Return a DBAPI connection."""
- return self.connection_provider.get_connection()
+ return self.pool.connect()
def connect(self, **kwargs):
"""Return a Connection that is not thread-locally scoped.
@@ -133,7 +129,7 @@ class TLEngine(base.Engine):
ComposedSQLEngine.
"""
- return base.Connection(self, self.connection_provider.unique_connection())
+ return base.Connection(self, self.pool.unique_connection())
def _session(self):
if not hasattr(self.context, 'session'):
@@ -156,6 +152,3 @@ class TLEngine(base.Engine):
def rollback(self):
self.session.rollback()
-class TLocalConnectionProvider(default.PoolConnectionProvider):
- def unique_connection(self):
- return self._pool.unique_connection()
diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py
index c5ad90ee9..1da76d7b2 100644
--- a/lib/sqlalchemy/engine/url.py
+++ b/lib/sqlalchemy/engine/url.py
@@ -1,10 +1,8 @@
-import re
-import cgi
-import sys
-import urllib
+"""Provide the URL object as well as the make_url parsing function."""
+
+import re, cgi, sys, urllib
from sqlalchemy import exceptions
-"""Provide the URL object as well as the make_url parsing function."""
class URL(object):
"""Represent the components of a URL used to connect to a database.
diff --git a/lib/sqlalchemy/ext/activemapper.py b/lib/sqlalchemy/ext/activemapper.py
index 2fcf44f61..fa32a5fc3 100644
--- a/lib/sqlalchemy/ext/activemapper.py
+++ b/lib/sqlalchemy/ext/activemapper.py
@@ -1,11 +1,10 @@
-from sqlalchemy import create_session, relation, mapper, \
- join, ThreadLocalMetaData, class_mapper, \
- util, Integer
-from sqlalchemy import and_, or_
+from sqlalchemy import ThreadLocalMetaData, util, Integer
from sqlalchemy import Table, Column, ForeignKey
+from sqlalchemy.orm import class_mapper, relation, create_session
+
from sqlalchemy.ext.sessioncontext import SessionContext
from sqlalchemy.ext.assignmapper import assign_mapper
-from sqlalchemy import backref as create_backref
+from sqlalchemy.orm import backref as create_backref
import sqlalchemy
import inspect
@@ -14,7 +13,7 @@ import sys
#
# the "proxy" to the database engine... this can be swapped out at runtime
#
-metadata = ThreadLocalMetaData("activemapper")
+metadata = ThreadLocalMetaData()
try:
objectstore = sqlalchemy.objectstore
diff --git a/lib/sqlalchemy/ext/assignmapper.py b/lib/sqlalchemy/ext/assignmapper.py
index 4708afd8d..238041702 100644
--- a/lib/sqlalchemy/ext/assignmapper.py
+++ b/lib/sqlalchemy/ext/assignmapper.py
@@ -1,51 +1,50 @@
-from sqlalchemy import mapper, util, Query, exceptions
+from sqlalchemy import util, exceptions
import types
-
-def monkeypatch_query_method(ctx, class_, name):
- def do(self, *args, **kwargs):
- query = Query(class_, session=ctx.current)
- return getattr(query, name)(*args, **kwargs)
- try:
- do.__name__ = name
- except:
- pass
- setattr(class_, name, classmethod(do))
-
-def monkeypatch_objectstore_method(ctx, class_, name):
+from sqlalchemy.orm import mapper
+
+def _monkeypatch_session_method(name, ctx, class_):
def do(self, *args, **kwargs):
session = ctx.current
- if name == "flush":
- # flush expects a list of objects
- self = [self]
return getattr(session, name)(self, *args, **kwargs)
try:
do.__name__ = name
except:
pass
- setattr(class_, name, do)
-
+ if not hasattr(class_, name):
+ setattr(class_, name, do)
+
def assign_mapper(ctx, class_, *args, **kwargs):
+ extension = kwargs.pop('extension', None)
+ if extension is not None:
+ extension = util.to_list(extension)
+ extension.append(ctx.mapper_extension)
+ else:
+ extension = ctx.mapper_extension
+
validate = kwargs.pop('validate', False)
+
if not isinstance(getattr(class_, '__init__'), types.MethodType):
def __init__(self, **kwargs):
for key, value in kwargs.items():
if validate:
- if not key in self.mapper.props:
+ if not self.mapper.get_property(key, resolve_synonyms=False, raiseerr=False):
raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
setattr(self, key, value)
class_.__init__ = __init__
- extension = kwargs.pop('extension', None)
- if extension is not None:
- extension = util.to_list(extension)
- extension.append(ctx.mapper_extension)
- else:
- extension = ctx.mapper_extension
+
+ class query(object):
+ def __getattr__(self, key):
+ return getattr(ctx.current.query(class_), key)
+ def __call__(self):
+ return ctx.current.query(class_)
+
+ if not hasattr(class_, 'query'):
+ class_.query = query()
+
+ for name in ['refresh', 'expire', 'delete', 'expunge', 'update']:
+ _monkeypatch_session_method(name, ctx, class_)
+
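# With assign_mapper, the class gains a `query` attribute proxying to the
# contextual session's Query (`ctx`, `User`, `users_table` hypothetical):
assign_mapper(ctx, User, users_table)
q = User.query()                        # the Query itself
ed = User.query.filter_by(name='ed')    # attribute access proxies to Query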
m = mapper(class_, extension=extension, *args, **kwargs)
class_.mapper = m
- class_.query = classmethod(lambda cls: Query(class_, session=ctx.current))
- for name in ['get', 'filter', 'filter_by', 'select', 'select_by', 'selectfirst', 'selectfirst_by', 'selectone', 'selectone_by', 'get_by', 'join_to', 'join_via', 'count', 'count_by', 'options', 'instances']:
- monkeypatch_query_method(ctx, class_, name)
- for name in ['flush', 'delete', 'expire', 'refresh', 'expunge', 'merge', 'save', 'update', 'save_or_update']:
- monkeypatch_objectstore_method(ctx, class_, name)
return m
diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py
index cdb814702..2dd807222 100644
--- a/lib/sqlalchemy/ext/associationproxy.py
+++ b/lib/sqlalchemy/ext/associationproxy.py
@@ -6,11 +6,10 @@ transparent proxied access to the endpoint of an association object.
See the example ``examples/association/proxied_association.py``.
"""
-from sqlalchemy.orm.attributes import InstrumentedList
+import weakref, itertools
import sqlalchemy.exceptions as exceptions
import sqlalchemy.orm as orm
import sqlalchemy.util as util
-import weakref
def association_proxy(targetcollection, attr, **kw):
"""Convenience function for use in mapped classes. Implements a Python
@@ -109,7 +108,7 @@ class AssociationProxy(object):
self.collection_class = None
def _get_property(self):
- return orm.class_mapper(self.owning_class).props[self.target_collection]
+ return orm.class_mapper(self.owning_class).get_property(self.target_collection)
def _target_class(self):
return self._get_property().mapper.class_
@@ -168,15 +167,7 @@ class AssociationProxy(object):
def _new(self, lazy_collection):
creator = self.creator and self.creator or self.target_class
-
- # Prefer class typing here to spot dicts with the required append()
- # method.
- collection = lazy_collection()
- if isinstance(collection.data, dict):
- self.collection_class = dict
- else:
- self.collection_class = util.duck_type_collection(collection.data)
- del collection
+ self.collection_class = util.duck_type_collection(lazy_collection())
if self.proxy_factory:
return self.proxy_factory(lazy_collection, creator, self.value_attr)
@@ -269,7 +260,33 @@ class _AssociationList(object):
return self._get(self.col[index])
def __setitem__(self, index, value):
- self._set(self.col[index], value)
+ if not isinstance(index, slice):
+ self._set(self.col[index], value)
+ else:
+ if index.stop is None:
+ stop = len(self)
+ elif index.stop < 0:
+ stop = len(self) + index.stop
+ else:
+ stop = index.stop
+ step = index.step or 1
+
+            start = index.start or 0
+            rng = range(start, stop, step)
+            if step == 1:
+                # delete the existing slice, then insert the new values
+                for i in rng:
+                    del self[start]
+                i = start
+ for item in value:
+ self.insert(i, item)
+ i += 1
+ else:
+ if len(value) != len(rng):
+ raise ValueError(
+ "attempt to assign sequence of size %s to "
+ "extended slice of size %s" % (len(value),
+ len(rng)))
+ for i, item in zip(rng, value):
+ self._set(self.col[i], item)
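# Slice assignment on the proxied list now mirrors Python list semantics,
# including the extended-slice length check (`parent.keywords` is a
# hypothetical association proxy):
parent.keywords[0:2] = ['red', 'green']   # contiguous slice: any length
parent.keywords[::2] = ['a', 'b']         # extended slice: lengths must match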
def __delitem__(self, index):
del self.col[index]
@@ -291,9 +308,13 @@ class _AssociationList(object):
del self.col[start:end]
def __iter__(self):
- """Iterate over proxied values. For the actual domain objects,
- iterate over .col instead or just use the underlying collection
- directly from its property on the parent."""
+ """Iterate over proxied values.
+
+ For the actual domain objects, iterate over .col instead or
+ just use the underlying collection directly from its property
+ on the parent.
+ """
+
for member in self.col:
yield self._get(member)
raise StopIteration
@@ -304,6 +325,10 @@ class _AssociationList(object):
item = self._create(value, **kw)
self.col.append(item)
+ def count(self, value):
+ return sum([1 for _ in
+ itertools.ifilter(lambda v: v == value, iter(self))])
+
def extend(self, values):
for v in values:
self.append(v)
@@ -311,6 +336,26 @@ class _AssociationList(object):
def insert(self, index, value):
self.col[index:index] = [self._create(value)]
+ def pop(self, index=-1):
+ return self.getter(self.col.pop(index))
+
+ def remove(self, value):
+ for i, val in enumerate(self):
+ if val == value:
+ del self.col[i]
+ return
+ raise ValueError("value not in list")
+
+ def reverse(self):
+ """Not supported, use reversed(mylist)"""
+
+ raise NotImplementedError
+
+ def sort(self):
+ """Not supported, use sorted(mylist)"""
+
+ raise NotImplementedError
+
def clear(self):
del self.col[0:len(self.col)]
@@ -545,9 +590,7 @@ class _AssociationSet(object):
def add(self, value):
if value not in self:
- # must shove this through InstrumentedList.append() which will
- # eventually call the collection_class .add()
- self.col.append(self._create(value))
+ self.col.add(self._create(value))
# for discard and remove, choosing a more expensive check strategy rather
# than call self.creator()
@@ -567,12 +610,7 @@ class _AssociationSet(object):
def pop(self):
if not self.col:
raise KeyError('pop from an empty set')
- # grumble, pop() is borked on InstrumentedList (#548)
- if isinstance(self.col, InstrumentedList):
- member = list(self.col)[0]
- self.col.remove(member)
- else:
- member = self.col.pop()
+ member = self.col.pop()
return self._get(member)
def update(self, other):
diff --git a/lib/sqlalchemy/ext/proxy.py b/lib/sqlalchemy/ext/proxy.py
deleted file mode 100644
index b81702fc4..000000000
--- a/lib/sqlalchemy/ext/proxy.py
+++ /dev/null
@@ -1,113 +0,0 @@
-try:
- from threading import local
-except ImportError:
- from sqlalchemy.util import ThreadLocal as local
-
-from sqlalchemy import sql
-from sqlalchemy.engine import create_engine, Engine
-
-__all__ = ['BaseProxyEngine', 'AutoConnectEngine', 'ProxyEngine']
-
-class BaseProxyEngine(sql.Executor):
- """Basis for all proxy engines."""
-
- def get_engine(self):
- raise NotImplementedError
-
- def set_engine(self, engine):
- raise NotImplementedError
-
- engine = property(lambda s:s.get_engine(), lambda s,e:s.set_engine(e))
-
- def execute_compiled(self, *args, **kwargs):
- """Override superclass behaviour.
-
- This method is required to be present as it overrides the
- `execute_compiled` present in ``sql.Engine``.
- """
-
- return self.get_engine().execute_compiled(*args, **kwargs)
-
- def compiler(self, *args, **kwargs):
- """Override superclass behaviour.
-
- This method is required to be present as it overrides the
- `compiler` method present in ``sql.Engine``.
- """
-
- return self.get_engine().compiler(*args, **kwargs)
-
- def __getattr__(self, attr):
- """Provide proxying for methods that are not otherwise present on this ``BaseProxyEngine``.
-
- Note that methods which are present on the base class
- ``sql.Engine`` will **not** be proxied through this, and must
- be explicit on this class.
- """
-
- # call get_engine() to give subclasses a chance to change
- # connection establishment behavior
- e = self.get_engine()
- if e is not None:
- return getattr(e, attr)
- raise AttributeError("No connection established in ProxyEngine: "
- " no access to %s" % attr)
-
-class AutoConnectEngine(BaseProxyEngine):
- """An SQLEngine proxy that automatically connects when necessary."""
-
- def __init__(self, dburi, **kwargs):
- BaseProxyEngine.__init__(self)
- self.dburi = dburi
- self.kwargs = kwargs
- self._engine = None
-
- def get_engine(self):
- if self._engine is None:
- if callable(self.dburi):
- dburi = self.dburi()
- else:
- dburi = self.dburi
- self._engine = create_engine(dburi, **self.kwargs)
- return self._engine
-
-
-class ProxyEngine(BaseProxyEngine):
- """Engine proxy for lazy and late initialization.
-
- This engine will delegate access to a real engine set with connect().
- """
-
- def __init__(self, **kwargs):
- BaseProxyEngine.__init__(self)
- # create the local storage for uri->engine map and current engine
- self.storage = local()
- self.kwargs = kwargs
-
- def connect(self, *args, **kwargs):
- """Establish connection to a real engine."""
-
- kwargs.update(self.kwargs)
- if not kwargs:
- key = repr(args)
- else:
- key = "%s, %s" % (repr(args), repr(sorted(kwargs.items())))
- try:
- map = self.storage.connection
- except AttributeError:
- self.storage.connection = {}
- self.storage.engine = None
- map = self.storage.connection
- try:
- self.storage.engine = map[key]
- except KeyError:
- map[key] = create_engine(*args, **kwargs)
- self.storage.engine = map[key]
-
- def get_engine(self):
- if not hasattr(self.storage, 'engine') or self.storage.engine is None:
- raise AttributeError("No connection established")
- return self.storage.engine
-
- def set_engine(self, engine):
- self.storage.engine = engine
diff --git a/lib/sqlalchemy/ext/selectresults.py b/lib/sqlalchemy/ext/selectresults.py
index 68538f3cb..1920b6f92 100644
--- a/lib/sqlalchemy/ext/selectresults.py
+++ b/lib/sqlalchemy/ext/selectresults.py
@@ -1,212 +1,28 @@
+"""SelectResults has been rolled into Query. This class is now just a placeholder."""
+
import sqlalchemy.sql as sql
import sqlalchemy.orm as orm
class SelectResultsExt(orm.MapperExtension):
"""a MapperExtension that provides SelectResults functionality for the
results of query.select_by() and query.select()"""
+
def select_by(self, query, *args, **params):
- return SelectResults(query, query.join_by(*args, **params))
+ q = query
+ for a in args:
+ q = q.filter(a)
+ return q.filter_by(**params)
+
def select(self, query, arg=None, **kwargs):
if isinstance(arg, sql.FromClause) and arg.supports_execution():
return orm.EXT_PASS
else:
- return SelectResults(query, arg, ops=kwargs)
-
-class SelectResults(object):
- """Build a query one component at a time via separate method
- calls, each call transforming the previous ``SelectResults``
- instance into a new ``SelectResults`` instance with further
- limiting criterion added. When interpreted in an iterator context
- (such as via calling ``list(selectresults)``), executes the query.
- """
-
- def __init__(self, query, clause=None, ops={}, joinpoint=None):
- """Construct a new ``SelectResults`` using the given ``Query``
- object and optional ``WHERE`` clause. `ops` is an optional
- dictionary of bind parameter values.
- """
-
- self._query = query
- self._clause = clause
- self._ops = {}
- self._ops.update(ops)
- self._joinpoint = joinpoint or (self._query.table, self._query.mapper)
-
- def options(self,*args, **kwargs):
- """Apply mapper options to the underlying query.
-
- See also ``Query.options``.
- """
-
- new = self.clone()
- new._query = new._query.options(*args, **kwargs)
- return new
-
- def count(self):
- """Execute the SQL ``count()`` function against the ``SelectResults`` criterion."""
-
- return self._query.count(self._clause, **self._ops)
-
- def _col_aggregate(self, col, func):
- """Execute ``func()`` function against the given column.
-
- For performance, only use subselect if `order_by` attribute is set.
- """
-
- if self._ops.get('order_by'):
- s1 = sql.select([col], self._clause, **self._ops).alias('u')
- return sql.select([func(s1.corresponding_column(col))]).scalar()
- else:
- return sql.select([func(col)], self._clause, **self._ops).scalar()
-
- def min(self, col):
- """Execute the SQL ``min()`` function against the given column."""
-
- return self._col_aggregate(col, sql.func.min)
-
- def max(self, col):
- """Execute the SQL ``max()`` function against the given column."""
-
- return self._col_aggregate(col, sql.func.max)
-
- def sum(self, col):
- """Execute the SQL ``sum()`` function against the given column."""
-
- return self._col_aggregate(col, sql.func.sum)
-
- def avg(self, col):
- """Execute the SQL ``avg()`` function against the given column."""
-
- return self._col_aggregate(col, sql.func.avg)
-
- def clone(self):
- """Create a copy of this ``SelectResults``."""
-
- return SelectResults(self._query, self._clause, self._ops.copy(), self._joinpoint)
-
- def filter(self, clause):
- """Apply an additional ``WHERE`` clause against the query."""
-
- new = self.clone()
- new._clause = sql.and_(self._clause, clause)
- return new
-
- def select(self, clause):
- return self.filter(clause)
-
- def select_by(self, *args, **kwargs):
- return self.filter(self._query._join_by(args, kwargs, start=self._joinpoint[1]))
-
- def order_by(self, order_by):
- """Apply an ``ORDER BY`` to the query."""
-
- new = self.clone()
- new._ops['order_by'] = order_by
- return new
-
- def limit(self, limit):
- """Apply a ``LIMIT`` to the query."""
-
- return self[:limit]
-
- def offset(self, offset):
- """Apply an ``OFFSET`` to the query."""
-
- return self[offset:]
-
- def distinct(self):
- """Apply a ``DISTINCT`` to the query."""
-
- new = self.clone()
- new._ops['distinct'] = True
- return new
-
- def list(self):
- """Return the results represented by this ``SelectResults`` as a list.
-
- This results in an execution of the underlying query.
- """
-
- return list(self)
-
- def select_from(self, from_obj):
- """Set the `from_obj` parameter of the query.
-
- `from_obj` is a list of one or more tables.
- """
-
- new = self.clone()
- new._ops['from_obj'] = from_obj
- return new
-
- def join_to(self, prop):
- """Join the table of this ``SelectResults`` to the table located against the given property name.
-
- Subsequent calls to join_to or outerjoin_to will join against
- the rightmost table located from the previous `join_to` or
- `outerjoin_to` call, searching for the property starting with
- the rightmost mapper last located.
- """
-
- new = self.clone()
- (clause, mapper) = self._join_to(prop, outerjoin=False)
- new._ops['from_obj'] = [clause]
- new._joinpoint = (clause, mapper)
- return new
-
- def outerjoin_to(self, prop):
- """Outer join the table of this ``SelectResults`` to the
- table located against the given property name.
-
- Subsequent calls to join_to or outerjoin_to will join against
- the rightmost table located from the previous ``join_to`` or
- ``outerjoin_to`` call, searching for the property starting with
- the rightmost mapper last located.
- """
-
- new = self.clone()
- (clause, mapper) = self._join_to(prop, outerjoin=True)
- new._ops['from_obj'] = [clause]
- new._joinpoint = (clause, mapper)
- return new
-
- def _join_to(self, prop, outerjoin=False):
- [keys,p] = self._query._locate_prop(prop, start=self._joinpoint[1])
- clause = self._joinpoint[0]
- mapper = self._joinpoint[1]
- for key in keys:
- prop = mapper.props[key]
- if outerjoin:
- clause = clause.outerjoin(prop.select_table, prop.get_join(mapper))
- else:
- clause = clause.join(prop.select_table, prop.get_join(mapper))
- mapper = prop.mapper
- return (clause, mapper)
-
- def compile(self):
- return self._query.compile(self._clause, **self._ops)
-
- def __getitem__(self, item):
- if isinstance(item, slice):
- start = item.start
- stop = item.stop
- if (isinstance(start, int) and start < 0) or \
- (isinstance(stop, int) and stop < 0):
- return list(self)[item]
- else:
- res = self.clone()
- if start is not None and stop is not None:
- res._ops.update(dict(offset=self._ops.get('offset', 0)+start, limit=stop-start))
- elif start is None and stop is not None:
- res._ops.update(dict(limit=stop))
- elif start is not None and stop is None:
- res._ops.update(dict(offset=self._ops.get('offset', 0)+start))
- if item.step is not None:
- return list(res)[None:None:item.step]
- else:
- return res
- else:
- return list(self[item:item+1])[0]
-
- def __iter__(self):
- return iter(self._query.select_whereclause(self._clause, **self._ops))
+ if arg is not None:
+ query = query.filter(arg)
+ return query._legacy_select_kwargs(**kwargs)
+
+def SelectResults(query, clause=None, ops={}):
+ if clause is not None:
+ query = query.filter(clause)
+ query = query.options(orm.extension(SelectResultsExt()))
+ return query._legacy_select_kwargs(**ops)
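For migration purposes, a minimal sketch (not part of the patch itself) of how an old SelectResults chain maps onto the generative Query methods it now delegates to; session, User, and users_table are hypothetical:

    # 0.3 style, via this extension:
    #   SelectResults(session.query(User), users_table.c.name == 'ed').list()
    # roughly equivalent 0.4 generative Query usage:
    q = session.query(User)
    q = q.filter(users_table.c.name == 'ed')   # WHERE criterion
    q = q.order_by(users_table.c.id)           # ORDER BY
    users = q.list()                           # execute and return instances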
diff --git a/lib/sqlalchemy/ext/sessioncontext.py b/lib/sqlalchemy/ext/sessioncontext.py
index 2f81e55d2..fcbf29c3f 100644
--- a/lib/sqlalchemy/ext/sessioncontext.py
+++ b/lib/sqlalchemy/ext/sessioncontext.py
@@ -1,5 +1,5 @@
from sqlalchemy.util import ScopedRegistry
-from sqlalchemy.orm.mapper import MapperExtension
+from sqlalchemy.orm import create_session, object_session, MapperExtension, EXT_PASS
__all__ = ['SessionContext', 'SessionContextExt']
@@ -15,16 +15,18 @@ class SessionContext(object):
engine = create_engine(...)
def session_factory():
- return Session(bind_to=engine)
+ return Session(bind=engine)
context = SessionContext(session_factory)
s = context.current # get thread-local session
- context.current = Session(bind_to=other_engine) # set current session
+ context.current = Session(bind=other_engine) # set current session
del context.current # discard the thread-local session (a new one will
# be created on the next call to context.current)
"""
- def __init__(self, session_factory, scopefunc=None):
+ def __init__(self, session_factory=None, scopefunc=None):
+ if session_factory is None:
+ session_factory = create_session
self.registry = ScopedRegistry(session_factory, scopefunc)
super(SessionContext, self).__init__()
@@ -60,3 +62,21 @@ class SessionContextExt(MapperExtension):
def get_session(self):
return self.context.current
+
+ def init_instance(self, mapper, class_, instance, args, kwargs):
+ session = kwargs.pop('_sa_session', self.context.current)
+ session._save_impl(instance, entity_name=kwargs.pop('_sa_entity_name', None))
+ return EXT_PASS
+
+ def init_failed(self, mapper, class_, instance, args, kwargs):
+ object_session(instance).expunge(instance)
+ return EXT_PASS
+
+ def dispose_class(self, mapper, class_):
+ if hasattr(class_, '__init__') and hasattr(class_.__init__, '_oldinit'):
+ if class_.__init__._oldinit is not None:
+ class_.__init__ = class_.__init__._oldinit
+ else:
+ delattr(class_, '__init__')
+
+
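A hedged sketch of the reworked API: SessionContext() now defaults its factory to create_session, and the new MapperExtension hooks save freshly constructed instances into the current session. User and users_table are hypothetical; mapper_extension is the context's extension referenced elsewhere in this patch:

    from sqlalchemy.ext.sessioncontext import SessionContext
    from sqlalchemy.orm import mapper, object_session

    context = SessionContext()      # factory defaults to create_session
    mapper(User, users_table, extension=context.mapper_extension)

    u = User()                      # init_instance() saves u automatically
    assert object_session(u) is context.current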
diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py
index 04e5b49f7..756b5e1e7 100644
--- a/lib/sqlalchemy/ext/sqlsoup.py
+++ b/lib/sqlalchemy/ext/sqlsoup.py
@@ -310,8 +310,8 @@ Boring tests here. Nothing of real expository value.
"""
from sqlalchemy import *
+from sqlalchemy.orm import *
from sqlalchemy.ext.sessioncontext import SessionContext
-from sqlalchemy.ext.assignmapper import assign_mapper
from sqlalchemy.exceptions import *
@@ -392,7 +392,7 @@ class SelectableClassType(type):
def update(cls, whereclause=None, values=None, **kwargs):
_ddl_error(cls)
- def _selectable(cls):
+ def __selectable__(cls):
return cls._table
def __getattr__(cls, attr):
@@ -434,9 +434,7 @@ def _selectable_name(selectable):
return x
def class_for_table(selectable, **mapper_kwargs):
- if not hasattr(selectable, '_selectable') \
- or selectable._selectable() != selectable:
- raise ArgumentError('class_for_table requires a selectable as its argument')
+ selectable = sql._selectable(selectable)
mapname = 'Mapped' + _selectable_name(selectable)
if isinstance(selectable, Table):
klass = TableClassType(mapname, (object,), {})
@@ -520,7 +518,7 @@ class SqlSoup:
def with_labels(self, item):
# TODO give meaningful aliases
- return self.map(item._selectable().select(use_labels=True).alias('foo'))
+ return self.map(sql._selectable(item).select(use_labels=True).alias('foo'))
def join(self, *args, **kwargs):
j = join(*args, **kwargs)
@@ -539,6 +537,9 @@ class SqlSoup:
t = None
self._cache[attr] = t
return t
+
+ def __repr__(self):
+ return 'SqlSoup(%r)' % self._metadata
if __name__ == '__main__':
import doctest
diff --git a/lib/sqlalchemy/mods/legacy_session.py b/lib/sqlalchemy/mods/legacy_session.py
deleted file mode 100644
index e21a5634b..000000000
--- a/lib/sqlalchemy/mods/legacy_session.py
+++ /dev/null
@@ -1,176 +0,0 @@
-"""A plugin that emulates 0.1 Session behavior."""
-
-import sqlalchemy.orm.objectstore as objectstore
-import sqlalchemy.orm.unitofwork as unitofwork
-import sqlalchemy.util as util
-import sqlalchemy
-
-import sqlalchemy.mods.threadlocal
-
-class LegacySession(objectstore.Session):
- def __init__(self, nest_on=None, hash_key=None, **kwargs):
- super(LegacySession, self).__init__(**kwargs)
- self.parent_uow = None
- self.begin_count = 0
- self.nest_on = util.to_list(nest_on)
- self.__pushed_count = 0
-
- def was_pushed(self):
- if self.nest_on is None:
- return
- self.__pushed_count += 1
- if self.__pushed_count == 1:
- for n in self.nest_on:
- n.push_session()
-
- def was_popped(self):
- if self.nest_on is None or self.__pushed_count == 0:
- return
- self.__pushed_count -= 1
- if self.__pushed_count == 0:
- for n in self.nest_on:
- n.pop_session()
-
- class SessionTrans(object):
- """Returned by ``Session.begin()``, denotes a
- transactionalized UnitOfWork instance. Call ``commit()`
- on this to commit the transaction.
- """
-
- def __init__(self, parent, uow, isactive):
- self.__parent = parent
- self.__isactive = isactive
- self.__uow = uow
-
- isactive = property(lambda s:s.__isactive, doc="True if this SessionTrans is the 'active' transaction marker, else its a no-op.")
- parent = property(lambda s:s.__parent, doc="The parent Session of this SessionTrans object.")
- uow = property(lambda s:s.__uow, doc="The parent UnitOfWork corresponding to this transaction.")
-
- def begin(self):
- """Call ``begin()`` on the underlying ``Session`` object,
- returning a new no-op ``SessionTrans`` object.
- """
-
- if self.parent.uow is not self.uow:
- raise InvalidRequestError("This SessionTrans is no longer valid")
- return self.parent.begin()
-
- def commit(self):
- """Commit the transaction noted by this ``SessionTrans`` object."""
-
- self.__parent._trans_commit(self)
- self.__isactive = False
-
- def rollback(self):
- """Roll back the current UnitOfWork transaction, in the
- case that ``begin()`` has been called.
-
- The changes logged since the begin() call are discarded.
- """
-
- self.__parent._trans_rollback(self)
- self.__isactive = False
-
- def begin(self):
- """Begin a new UnitOfWork transaction and return a
- transaction-holding object.
-
- ``commit()`` or ``rollback()`` should be called on the returned object.
-
- ``commit()`` on the ``Session`` will do nothing while a
- transaction is pending, and further calls to ``begin()`` will
- return no-op transactional objects.
- """
-
- if self.parent_uow is not None:
- return LegacySession.SessionTrans(self, self.uow, False)
- self.parent_uow = self.uow
- self.uow = unitofwork.UnitOfWork(identity_map = self.uow.identity_map)
- return LegacySession.SessionTrans(self, self.uow, True)
-
- def commit(self, *objects):
- """Commit the current UnitOfWork transaction.
-
- Called with no arguments, this is only used for *implicit*
- transactions when there was no ``begin()``.
-
- If individual objects are submitted, then only those objects
- are committed, and the begin/commit cycle is not affected.
- """
-
- # if an object list is given, commit just those but dont
- # change begin/commit status
- if len(objects):
- self._commit_uow(*objects)
- self.uow.flush(self, *objects)
- return
- if self.parent_uow is None:
- self._commit_uow()
-
- def _trans_commit(self, trans):
- if trans.uow is self.uow and trans.isactive:
- try:
- self._commit_uow()
- finally:
- self.uow = self.parent_uow
- self.parent_uow = None
-
- def _trans_rollback(self, trans):
- if trans.uow is self.uow:
- self.uow = self.parent_uow
- self.parent_uow = None
-
- def _commit_uow(self, *obj):
- self.was_pushed()
- try:
- self.uow.flush(self, *obj)
- finally:
- self.was_popped()
-
-def begin():
- """Deprecated. Use ``s = Session(new_imap=False)``."""
-
- return objectstore.get_session().begin()
-
-def commit(*obj):
- """Deprecated. Use ``flush(*obj)``."""
-
- objectstore.get_session().flush(*obj)
-
-def uow():
- return objectstore.get_session()
-
-def push_session(sess):
- old = get_session()
- if getattr(sess, '_previous', None) is not None:
- raise InvalidRequestError("Given Session is already pushed onto some thread's stack")
- sess._previous = old
- session_registry.set(sess)
- sess.was_pushed()
-
-def pop_session():
- sess = get_session()
- old = sess._previous
- sess._previous = None
- session_registry.set(old)
- sess.was_popped()
- return old
-
-def using_session(sess, func):
- push_session(sess)
- try:
- return func()
- finally:
- pop_session()
-
-def install_plugin():
- objectstore.Session = LegacySession
- objectstore.session_registry = util.ScopedRegistry(objectstore.Session)
- objectstore.begin = begin
- objectstore.commit = commit
- objectstore.uow = uow
- objectstore.push_session = push_session
- objectstore.pop_session = pop_session
- objectstore.using_session = using_session
-
-install_plugin()
diff --git a/lib/sqlalchemy/mods/selectresults.py b/lib/sqlalchemy/mods/selectresults.py
index ac8de9b06..25bfa2840 100644
--- a/lib/sqlalchemy/mods/selectresults.py
+++ b/lib/sqlalchemy/mods/selectresults.py
@@ -1,4 +1,4 @@
-from sqlalchemy.ext.selectresults import *
+from sqlalchemy.ext.selectresults import SelectResultsExt
from sqlalchemy.orm.mapper import global_extensions
def install_plugin():
diff --git a/lib/sqlalchemy/mods/threadlocal.py b/lib/sqlalchemy/mods/threadlocal.py
deleted file mode 100644
index c8043bc62..000000000
--- a/lib/sqlalchemy/mods/threadlocal.py
+++ /dev/null
@@ -1,53 +0,0 @@
-"""This plugin installs thread-local behavior at the ``Engine`` and ``Session`` level.
-
-The default ``Engine`` strategy will be *threadlocal*, producing
-``TLocalEngine`` instances for create_engine by default.
-
-With this engine, ``connect()`` method will return the same connection
-on the same thread, if it is already checked out from the pool. This
-greatly helps functions that call multiple statements to be able to
-easily use just one connection without explicit ``close`` statements
-on result handles.
-
-On the ``Session`` side, module-level methods will be installed within
-the objectstore module, such as ``flush()``, ``delete()``, etc. which
-call this method on the thread-local session.
-
-Note: this mod creates a global, thread-local session context named
-``sqlalchemy.objectstore``. All mappers created while this mod is
-installed will reference this global context when creating new mapped
-object instances.
-"""
-
-from sqlalchemy import util, engine, mapper
-from sqlalchemy.ext.sessioncontext import SessionContext
-import sqlalchemy.ext.assignmapper as assignmapper
-from sqlalchemy.orm.mapper import global_extensions
-from sqlalchemy.orm.session import Session
-import sqlalchemy
-import sys, types
-
-__all__ = ['Objectstore', 'assign_mapper']
-
-class Objectstore(object):
- def __init__(self, *args, **kwargs):
- self.context = SessionContext(*args, **kwargs)
- def __getattr__(self, name):
- return getattr(self.context.current, name)
- session = property(lambda s:s.context.current)
-
-def assign_mapper(class_, *args, **kwargs):
- assignmapper.assign_mapper(objectstore.context, class_, *args, **kwargs)
-
-objectstore = Objectstore(Session)
-def install_plugin():
- sqlalchemy.objectstore = objectstore
- global_extensions.append(objectstore.context.mapper_extension)
- engine.default_strategy = 'threadlocal'
- sqlalchemy.assign_mapper = assign_mapper
-
-def uninstall_plugin():
- engine.default_strategy = 'plain'
- global_extensions.remove(objectstore.context.mapper_extension)
-
-install_plugin()
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
index 7ef2da897..1982a94f7 100644
--- a/lib/sqlalchemy/orm/__init__.py
+++ b/lib/sqlalchemy/orm/__init__.py
@@ -11,58 +11,229 @@ packages and tying operations to class properties and constructors.
from sqlalchemy import exceptions
from sqlalchemy import util as sautil
-from sqlalchemy.orm.mapper import *
+from sqlalchemy.orm.mapper import Mapper, object_mapper, class_mapper, mapper_registry
+from sqlalchemy.orm.interfaces import SynonymProperty, MapperExtension, EXT_PASS, ExtensionOption, PropComparator
+from sqlalchemy.orm.properties import PropertyLoader, ColumnProperty, CompositeProperty, BackRef
from sqlalchemy.orm import mapper as mapperlib
+from sqlalchemy.orm import collections, strategies
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.util import polymorphic_union
-from sqlalchemy.orm import properties, strategies, interfaces
from sqlalchemy.orm.session import Session as create_session
from sqlalchemy.orm.session import object_session, attribute_manager
-__all__ = ['relation', 'column_property', 'backref', 'eagerload', 'lazyload', 'noload', 'deferred', 'defer', 'undefer', 'extension',
- 'mapper', 'clear_mappers', 'compile_mappers', 'clear_mapper', 'class_mapper', 'object_mapper', 'MapperExtension', 'Query',
- 'cascade_mappers', 'polymorphic_union', 'create_session', 'synonym', 'contains_alias', 'contains_eager', 'EXT_PASS', 'object_session'
- ]
+__all__ = ['relation', 'column_property', 'composite', 'backref', 'eagerload',
+ 'eagerload_all', 'lazyload', 'noload', 'deferred', 'defer', 'undefer',
+ 'undefer_group', 'extension', 'mapper', 'clear_mappers',
+ 'compile_mappers', 'class_mapper', 'object_mapper',
+ 'MapperExtension', 'Query', 'polymorphic_union', 'create_session',
+ 'synonym', 'contains_alias', 'contains_eager', 'EXT_PASS',
+ 'object_session', 'PropComparator'
+ ]
-def relation(*args, **kwargs):
+def relation(argument, secondary=None, **kwargs):
"""Provide a relationship of a primary Mapper to a secondary Mapper.
- This corresponds to a parent-child or associative table relationship.
+ This corresponds to a parent-child or associative table relationship.
+    The constructed property is an instance of [sqlalchemy.orm.properties#PropertyLoader].
+
+ argument
+ a class or Mapper instance, representing the target of the relation.
+
+ secondary
+ for a many-to-many relationship, specifies the intermediary table. The
+ ``secondary`` keyword argument should generally only be used for a table
+ that is not otherwise expressed in any class mapping. In particular,
+ using the Association Object Pattern is
+ generally mutually exclusive against using the ``secondary`` keyword
+ argument.
+
+ \**kwargs follow:
+
+ association
+      Deprecated; as of version 0.3.0 the association keyword is synonymous
+ with applying the "all, delete-orphan" cascade to a "one-to-many"
+ relationship. SA can now automatically reconcile a "delete" and
+ "insert" operation of two objects with the same "identity" in a flush()
+ operation into a single "update" statement, which is the pattern that
+ "association" used to indicate.
+
+ backref
+ indicates the name of a property to be placed on the related mapper's
+ class that will handle this relationship in the other direction,
+ including synchronizing the object attributes on both sides of the
+ relation. Can also point to a ``backref()`` construct for more
+ configurability.
+
+ cascade
+ a string list of cascade rules which determines how persistence
+ operations should be "cascaded" from parent to child.
+
+ collection_class
+      a class or function that returns a new list-holding object. Will be
+ used in place of a plain list for storing elements.
+
+ foreign_keys
+ a list of columns which are to be used as "foreign key" columns.
+      This parameter should be used in conjunction with explicit
+ ``primaryjoin`` and ``secondaryjoin`` (if needed) arguments, and the
+ columns within the ``foreign_keys`` list should be present within
+ those join conditions. Normally, ``relation()`` will inspect the
+ columns within the join conditions to determine which columns are
+ the "foreign key" columns, based on information in the ``Table``
+ metadata. Use this argument when no ForeignKey's are present in the
+ join condition, or to override the table-defined foreign keys.
+
+ foreignkey
+      deprecated. Use the ``foreign_keys`` argument for foreign key
+ specification, or ``remote_side`` for "directional" logic.
+
+ lazy=True
+ specifies how the related items should be loaded. a value of True
+ indicates they should be loaded lazily when the property is first
+ accessed. A value of False indicates they should be loaded by joining
+ against the parent object query, so parent and child are loaded in one
+ round trip (i.e. eagerly). A value of None indicates the related items
+ are not loaded by the mapper in any case; the application will manually
+      insert items into the list in some other way. In all cases, items added
+      to or removed from the parent object's collection (or scalar attribute) will
+ cause the appropriate updates and deletes upon flush(), i.e. this
+ option only affects load operations, not save operations.
+
+ order_by
+ indicates the ordering that should be applied when loading these items.
+
+ passive_deletes=False
+      When True, lazy-loaders will not be executed during the ``flush()``
+ process, which normally occurs in order to locate all existing child
+ items when a parent item is to be deleted. Setting this flag to True is
+ appropriate when ``ON DELETE CASCADE`` rules have been set up on the
+ actual tables so that the database may handle cascading deletes
+ automatically. This strategy is useful particularly for handling the
+ deletion of objects that have very large (and/or deep) child-object
+ collections.
+
+ post_update
+ this indicates that the relationship should be handled by a second
+ UPDATE statement after an INSERT or before a DELETE. Currently, it also
+      will issue an UPDATE after the instance is UPDATEd, although
+ this technically should be improved. This flag is used to handle saving
+ bi-directional dependencies between two individual rows (i.e. each row
+ references the other), where it would otherwise be impossible to INSERT
+ or DELETE both rows fully since one row exists before the other. Use
+ this flag when a particular mapping arrangement will incur two rows
+ that are dependent on each other, such as a table that has a
+ one-to-many relationship to a set of child rows, and also has a column
+ that references a single child row within that list (i.e. both tables
+ contain a foreign key to each other). If a ``flush()`` operation returns
+ an error that a "cyclical dependency" was detected, this is a cue that
+ you might want to use ``post_update`` to "break" the cycle.
+
+ primaryjoin
+ a ClauseElement that will be used as the primary join of this child
+ object against the parent object, or in a many-to-many relationship the
+ join of the primary object to the association table. By default, this
+ value is computed based on the foreign key relationships of the parent
+ and child tables (or association table).
+
+ private=False
+      deprecated. Setting ``private=True`` is the equivalent of setting
+ ``cascade="all, delete-orphan"``, and indicates the lifecycle of child
+ objects should be contained within that of the parent.
+
+ remote_side
+ used for self-referential relationships, indicates the column or list
+ of columns that form the "remote side" of the relationship.
+
+ secondaryjoin
+ a ClauseElement that will be used as the join of an association table
+ to the child object. By default, this value is computed based on the
+ foreign key relationships of the association and child tables.
+
+ uselist=(True|False)
+ a boolean that indicates if this property should be loaded as a list or
+ a scalar. In most cases, this value is determined automatically by
+ ``relation()``, based on the type and direction of the relationship - one
+ to many forms a list, many to one forms a scalar, many to many is a
+ list. If a scalar is desired where normally a list would be present,
+ such as a bi-directional one-to-one relationship, set uselist to False.
+
+ viewonly=False
+ when set to True, the relation is used only for loading objects within
+ the relationship, and has no effect on the unit-of-work flush process.
+ Relations with viewonly can specify any kind of join conditions to
+ provide additional views of related objects onto a parent object. Note
+ that the functionality of a viewonly relationship has its limits -
+ complicated join conditions may not compile into eager or lazy loaders
+ properly. If this is the case, use an alternative method.
+
"""
- if len(args) > 1 and isinstance(args[0], type):
- raise exceptions.ArgumentError("relation(class, table, **kwargs) is deprecated. Please use relation(class, **kwargs) or relation(mapper, **kwargs).")
- return _relation_loader(*args, **kwargs)
+ return PropertyLoader(argument, secondary=secondary, **kwargs)
+
+# return _relation_loader(argument, secondary=secondary, **kwargs)
+
+#def _relation_loader(mapper, secondary=None, primaryjoin=None, secondaryjoin=None, lazy=True, **kwargs):
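As a worked illustration of the arguments documented above, a minimal hypothetical parent/child mapping (User, Address, and the tables are assumptions, not part of this change):

    mapper(Address, addresses_table)
    mapper(User, users_table, properties={
        # one-to-many; orphaned children are deleted on flush()
        'addresses': relation(Address, cascade="all, delete-orphan",
                              backref='user', lazy=True),
    })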
def column_property(*args, **kwargs):
"""Provide a column-level property for use with a Mapper.
+
+ Column-based properties can normally be applied to the mapper's
+ ``properties`` dictionary using the ``schema.Column`` element directly.
+ Use this function when the given column is not directly present within
+ the mapper's selectable; examples include SQL expressions, functions,
+ and scalar SELECT queries.
+
+    Columns that aren't present in the mapper's selectable won't be persisted
+ by the mapper and are effectively "read-only" attributes.
+
+ \*cols
+ list of Column objects to be mapped.
- Normally, custom column-level properties that represent columns
- directly or indirectly present within the mapped selectable
- can just be added to the ``properties`` dictionary directly,
- in which case this function's usage is not necessary.
-
- In the case of a ``ColumnElement`` directly present within the
- ``properties`` dictionary, the given column is converted to be the exact column
- located within the mapped selectable, in the case that the mapped selectable
- is not the exact parent selectable of the given column, but shares a common
- base table relationship with that column.
+ group
+ a group name for this property when marked as deferred.
+
+ deferred
+ when True, the column property is "deferred", meaning that
+ it does not load immediately, and is instead loaded when the
+ attribute is first accessed on an instance. See also
+ [sqlalchemy.orm#deferred()].
+
+ """
- Use this function when the column expression being added does not
- correspond to any single column within the mapped selectable,
- such as a labeled function or scalar-returning subquery, to force the element
- to become a mapped property regardless of it not being present within the
- mapped selectable.
+ return ColumnProperty(*args, **kwargs)
+
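A short sketch of the case described above, mapping a SQL expression that is not itself a table column (names are hypothetical):

    mapper(User, users_table, properties={
        # a read-only attribute computed from two real columns
        'fullname': column_property(
            (users_table.c.firstname + " " + users_table.c.lastname).label('fullname')),
    })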
+def composite(class_, *cols, **kwargs):
+ """Return a composite column-based property for use with a Mapper.
+
+ This is very much like a column-based property except the given class
+ is used to construct values composed of one or more columns. The class must
+ implement a constructor with positional arguments matching the order of
+ columns given, as well as a __colset__() method which returns its attributes
+ in column order.
- Note that persistence of instances is driven from the collection of columns
- within the mapped selectable, so column properties attached to a Mapper which have
- no direct correspondence to the mapped selectable will effectively be non-persisted
- attributes.
+ class\_
+ the "composite type" class.
+
+ \*cols
+ list of Column objects to be mapped.
+
+ group
+ a group name for this property when marked as deferred.
+
+ deferred
+ when True, the column property is "deferred", meaning that
+ it does not load immediately, and is instead loaded when the
+ attribute is first accessed on an instance. See also
+ [sqlalchemy.orm#deferred()].
+
+ comparator
+ an optional instance of [sqlalchemy.orm#PropComparator] which
+ provides SQL expression generation functions for this composite
+ type.
"""
- return properties.ColumnProperty(*args, **kwargs)
-def _relation_loader(mapper, secondary=None, primaryjoin=None, secondaryjoin=None, lazy=True, **kwargs):
- return properties.PropertyLoader(mapper, secondary, primaryjoin, secondaryjoin, lazy=lazy, **kwargs)
+ return CompositeProperty(class_, *cols, **kwargs)
+
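A minimal sketch of a composite type as specified above; Point takes its columns positionally and returns them in column order from __colset__() (Vertex and vertices_table are hypothetical):

    class Point(object):
        def __init__(self, x, y):
            self.x, self.y = x, y
        def __colset__(self):
            # attributes in the same order as the mapped columns
            return [self.x, self.y]

    mapper(Vertex, vertices_table, properties={
        'start': composite(Point, vertices_table.c.x1, vertices_table.c.y1),
    })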
def backref(name, **kwargs):
"""Create a BackRef object with explicit arguments, which are the same arguments one
@@ -72,7 +243,7 @@ def backref(name, **kwargs):
place of a string argument.
"""
- return properties.BackRef(name, **kwargs)
+ return BackRef(name, **kwargs)
def deferred(*columns, **kwargs):
"""Return a ``DeferredColumnProperty``, which indicates this
@@ -82,15 +253,141 @@ def deferred(*columns, **kwargs):
Used with the `properties` dictionary sent to ``mapper()``.
"""
- return properties.ColumnProperty(deferred=True, *columns, **kwargs)
-
-def mapper(class_, table=None, *args, **params):
- """Return a new ``Mapper`` object.
-
- See the ``Mapper`` class for a description of arguments.
+ return ColumnProperty(deferred=True, *columns, **kwargs)
+
+def mapper(class_, local_table=None, *args, **params):
+ """Return a new [sqlalchemy.orm#Mapper] object.
+
+ class\_
+ The class to be mapped.
+
+ local_table
+ The table to which the class is mapped, or None if this
+ mapper inherits from another mapper using concrete table
+ inheritance.
+
+ entity_name
+ A name to be associated with the `class`, to allow alternate
+ mappings for a single class.
+
+ always_refresh
+ If True, all query operations for this mapped class will
+ overwrite all data within object instances that already
+ exist within the session, erasing any in-memory changes with
+ whatever information was loaded from the database. Usage
+ of this flag is highly discouraged; as an alternative,
+ see the method `populate_existing()` on [sqlalchemy.orm.query#Query].
+
+ allow_column_override
+ If True, allows the usage of a ``relation()`` which has the
+ same name as a column in the mapped table. The table column
+ will no longer be mapped.
+
+ allow_null_pks
+ Indicates that composite primary keys where one or more (but
+ not all) columns contain NULL is a valid primary key.
+ Primary keys which contain NULL values usually indicate that
+ a result row does not contain an entity and should be
+ skipped.
+
+ batch
+ Indicates that save operations of multiple entities can be
+      batched together for efficiency. Setting to False indicates
+ that an instance will be fully saved before saving the next
+ instance, which includes inserting/updating all table rows
+ corresponding to the entity as well as calling all
+ ``MapperExtension`` methods corresponding to the save
+ operation.
+
+ column_prefix
+ A string which will be prepended to the `key` name of all
+ Columns when creating column-based properties from the given
+ Table. Does not affect explicitly specified column-based
+      properties.
+
+ concrete
+ If True, indicates this mapper should use concrete table
+ inheritance with its parent mapper.
+
+ extension
+ A [sqlalchemy.orm#MapperExtension] instance or list of
+ ``MapperExtension`` instances which will be applied to all
+ operations by this ``Mapper``.
+
+ inherits
+      Another ``Mapper`` with which this ``Mapper`` will have an
+      inheritance relationship.
+
+ inherit_condition
+ For joined table inheritance, a SQL expression (constructed
+ ``ClauseElement``) which will define how the two tables are
+ joined; defaults to a natural join between the two tables.
+
+ order_by
+      A single ``Column`` or list of ``Columns`` which
+      selection operations should use as the default ordering for
+ entities. Defaults to the OID/ROWID of the table if any, or
+ the first primary key column of the table.
+
+ non_primary
+ Construct a ``Mapper`` that will define only the selection
+ of instances, not their persistence. Any number of non_primary
+ mappers may be created for a particular class.
+
+ polymorphic_on
+ Used with mappers in an inheritance relationship, a ``Column``
+ which will identify the class/mapper combination to be used
+      with a particular row. Requires the polymorphic_identity
+ value to be set for all mappers in the inheritance
+ hierarchy.
+
+ _polymorphic_map
+      Used internally to propagate the full map of polymorphic
+ identifiers to surrogate mappers.
+
+ polymorphic_identity
+ A value which will be stored in the Column denoted by
+ polymorphic_on, corresponding to the *class identity* of
+ this mapper.
+
+ polymorphic_fetch
+ specifies how subclasses mapped through joined-table
+      inheritance will be fetched. Options are 'union',
+      'select', and 'deferred'. If the select_table argument
+ is present, defaults to 'union', otherwise defaults to
+ 'select'.
+
+ properties
+ A dictionary mapping the string names of object attributes
+ to ``MapperProperty`` instances, which define the
+ persistence behavior of that attribute. Note that the
+ columns in the mapped table are automatically converted into
+ ``ColumnProperty`` instances based on the `key` property of
+ each ``Column`` (although they can be overridden using this
+ dictionary).
+
+ primary_key
+ A list of ``Column`` objects which define the *primary key*
+ to be used against this mapper's selectable unit. This is
+ normally simply the primary key of the `local_table`, but
+ can be overridden here.
+
+ select_table
+ A [sqlalchemy.schema#Table] or any [sqlalchemy.sql#Selectable]
+ which will be used to select instances of this mapper's class.
+      Usually used to provide polymorphic loading among several
+ classes in an inheritance hierarchy.
+
+ version_id_col
+ A ``Column`` which must have an integer type that will be
+ used to keep a running *version id* of mapped entities in
+      the database. This is used during save operations to ensure
+ that no other thread or process has updated the instance
+ during the lifetime of the entity, else a
+ ``ConcurrentModificationError`` exception is thrown.
"""
- return Mapper(class_, table, *args, **params)
+ return Mapper(class_, local_table, *args, **params)
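Pulling several of these arguments together, a hedged sketch of a joined-table inheritance mapping (Employee, Manager, and the tables are hypothetical):

    employee_mapper = mapper(Employee, employees_table,
                             polymorphic_on=employees_table.c.type,
                             polymorphic_identity='employee')
    mapper(Manager, managers_table,
           inherits=employee_mapper,       # joined-table inheritance
           polymorphic_identity='manager')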
def synonym(name, proxy=False):
"""Set up `name` as a synonym to another ``MapperProperty``.
@@ -98,7 +395,7 @@ def synonym(name, proxy=False):
Used with the `properties` dictionary sent to ``mapper()``.
"""
- return interfaces.SynonymProperty(name, proxy=proxy)
+ return SynonymProperty(name, proxy=proxy)
def compile_mappers():
"""Compile all mappers that have been defined.
@@ -120,32 +417,13 @@ def clear_mappers():
mapperlib._COMPILE_MUTEX.acquire()
try:
for mapper in mapper_registry.values():
- attribute_manager.reset_class_managed(mapper.class_)
- if hasattr(mapper.class_, 'c'):
- del mapper.class_.c
+ mapper.dispose()
mapper_registry.clear()
# TODO: either dont use ArgSingleton, or
# find a way to clear only ClassKey instances from it
sautil.ArgSingleton.instances.clear()
finally:
mapperlib._COMPILE_MUTEX.release()
-
-def clear_mapper(m):
- """Remove the given mapper from the storage of mappers.
-
- When a new mapper is created for the previous mapper's class, it
- will be used as that classes' new primary mapper.
- """
-
- mapperlib._COMPILE_MUTEX.acquire()
- try:
- del mapper_registry[m.class_key]
- attribute_manager.reset_class_managed(m.class_)
- if hasattr(m.class_, 'c'):
- del m.class_.c
- m.class_key.dispose()
- finally:
- mapperlib._COMPILE_MUTEX.release()
def extension(ext):
"""Return a ``MapperOption`` that will insert the given
@@ -166,6 +444,22 @@ def eagerload(name):
return strategies.EagerLazyOption(name, lazy=False)
+def eagerload_all(name):
+ """Return a ``MapperOption`` that will convert all
+ properties along the given dot-separated path into an
+ eager load.
+
+ e.g::
+ query.options(eagerload_all('orders.items.keywords'))...
+
+ will set all of 'orders', 'orders.items', and 'orders.items.keywords'
+ to load in one eager load.
+
+ Used with ``query.options()``.
+ """
+
+ return strategies.EagerLazyOption(name, lazy=False, chained=True)
+
def lazyload(name):
"""Return a ``MapperOption`` that will convert the property of the
given name into a lazy load.
@@ -175,6 +469,9 @@ def lazyload(name):
return strategies.EagerLazyOption(name, lazy=True)
+def fetchmode(name, type):
+ return strategies.FetchModeOption(name, type)
+
def noload(name):
"""Return a ``MapperOption`` that will convert the property of the
given name into a non-load.
@@ -250,64 +547,11 @@ def undefer(name):
return strategies.DeferredOption(name, defer=False)
+def undefer_group(name):
+ """Return a ``MapperOption`` that will convert the given
+ group of deferred column properties into a non-deferred (regular column) load.
-def cascade_mappers(*classes_or_mappers):
- """Attempt to create a series of ``relations()`` between mappers
- automatically, via introspecting the foreign key relationships of
- the underlying tables.
-
- Given a list of classes and/or mappers, identify the foreign key
- relationships between the given mappers or corresponding class
- mappers, and create ``relation()`` objects representing those
- relationships, including a backreference. Attempt to find the
- *secondary* table in a many-to-many relationship as well.
-
- The names of the relations will be a lowercase version of the
- related class. In the case of one-to-many or many-to-many, the
- name will be *pluralized*, which currently is based on the English
- language (i.e. an 's' or 'es' added to it).
-
- NOTE: this method usually works poorly, and its usage is generally
- not advised.
+ Used with ``query.options()``.
"""
-
- table_to_mapper = {}
- for item in classes_or_mappers:
- if isinstance(item, Mapper):
- m = item
- else:
- klass = item
- m = class_mapper(klass)
- table_to_mapper[m.mapped_table] = m
-
- def pluralize(name):
- # oh crap, do we need locale stuff now
- if name[-1] == 's':
- return name + "es"
- else:
- return name + "s"
-
- for table,mapper in table_to_mapper.iteritems():
- for fk in table.foreign_keys:
- if fk.column.table is table:
- continue
- secondary = None
- try:
- m2 = table_to_mapper[fk.column.table]
- except KeyError:
- if len(fk.column.table.primary_key):
- continue
- for sfk in fk.column.table.foreign_keys:
- if sfk.column.table is table:
- continue
- m2 = table_to_mapper.get(sfk.column.table)
- secondary = fk.column.table
- if m2 is None:
- continue
- if secondary:
- propname = pluralize(m2.class_.__name__.lower())
- propname2 = pluralize(mapper.class_.__name__.lower())
- else:
- propname = m2.class_.__name__.lower()
- propname2 = pluralize(mapper.class_.__name__.lower())
- mapper.add_property(propname, relation(m2, secondary=secondary, backref=propname2))
+ return strategies.UndeferGroupOption(name)
+
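A one-line usage sketch, assuming a hypothetical mapping whose photo columns were deferred with group='photos':

    session.query(Book).options(undefer_group('photos')).list()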
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 9f8a04db8..47ff26085 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -4,37 +4,67 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import util
-from sqlalchemy.orm import util as orm_util
-from sqlalchemy import logging, exceptions
import weakref
-class InstrumentedAttribute(object):
- """A property object that instruments attribute access on object instances.
+from sqlalchemy import util
+from sqlalchemy.orm import util as orm_util, interfaces, collections
+from sqlalchemy.orm.mapper import class_mapper
+from sqlalchemy import logging, exceptions
- All methods correspond to a single attribute on a particular
- class.
- """
- PASSIVE_NORESULT = object()
+PASSIVE_NORESULT = object()
+ATTR_WAS_SET = object()
- def __init__(self, manager, key, uselist, callable_, typecallable, trackparent=False, extension=None, copy_function=None, compare_function=None, mutable_scalars=False, **kwargs):
+class InstrumentedAttribute(interfaces.PropComparator):
+ """attribute access for instrumented classes."""
+
+ def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, compare_function=None, mutable_scalars=False, comparator=None, **kwargs):
+ """Construct an InstrumentedAttribute.
+
+ class_
+ the class to be instrumented.
+
+ manager
+ AttributeManager managing this class
+
+ key
+ string name of the attribute
+
+ callable_
+ optional function which generates a callable based on a parent
+ instance, which produces the "default" values for a scalar or
+ collection attribute when it's first accessed, if not present already.
+
+ trackparent
+ if True, attempt to track if an instance has a parent attached to it
+ via this attribute
+
+ extension
+ an AttributeExtension object which will receive
+ set/delete/append/remove/etc. events
+
+ compare_function
+ a function that compares two values which are normally assignable to this
+ attribute
+
+ mutable_scalars
+ if True, the values which are normally assignable to this attribute can mutate,
+ and need to be compared against a copy of their original contents in order to
+ detect changes on the parent instance
+
+ comparator
+ a sql.Comparator to which class-level compare/math events will be sent
+
+ """
+
+ self.class_ = class_
self.manager = manager
self.key = key
- self.uselist = uselist
self.callable_ = callable_
- self.typecallable= typecallable
self.trackparent = trackparent
self.mutable_scalars = mutable_scalars
- if copy_function is None:
- if uselist:
- self.copy = lambda x:[y for y in x]
- else:
- # scalar values are assumed to be immutable unless a copy function
- # is passed
- self.copy = lambda x:x
- else:
- self.copy = lambda x:copy_function(x)
+ self.comparator = comparator
+ self.copy = None
if compare_function is None:
self.is_equal = lambda x,y: x == y
else:
@@ -42,7 +72,7 @@ class InstrumentedAttribute(object):
self.extensions = util.to_list(extension or [])
def __set__(self, obj, value):
- self.set(None, obj, value)
+ self.set(obj, value, None)
def __delete__(self, obj):
self.delete(None, obj)
@@ -52,17 +82,18 @@ class InstrumentedAttribute(object):
return self
return self.get(obj)
- def check_mutable_modified(self, obj):
- if self.mutable_scalars:
- h = self.get_history(obj, passive=True)
- if h is not None and h.is_modified():
- obj._state['modified'] = True
- return True
- else:
- return False
- else:
- return False
+ def clause_element(self):
+ return self.comparator.clause_element()
+
+ def expression_element(self):
+ return self.comparator.expression_element()
+
+ def operate(self, op, other, **kwargs):
+ return op(self.comparator, other, **kwargs)
+ def reverse_operate(self, op, other, **kwargs):
+ return op(other, self.comparator, **kwargs)
+
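To make the comparator delegation concrete, a hedged sketch (User and session are hypothetical): class-level operators are routed through operate() to the attached comparator, yielding SQL expressions rather than booleans:

    criterion = User.name == 'ed'    # a ClauseElement, not a bool
    session.query(User).filter(criterion)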
def hasparent(self, item, optimistic=False):
"""Return the boolean value of a `hasparent` flag attached to the given item.
@@ -98,8 +129,8 @@ class InstrumentedAttribute(object):
# get the current state. this may trigger a lazy load if
# passive is False.
- current = self.get(obj, passive=passive, raiseerr=False)
- if current is InstrumentedAttribute.PASSIVE_NORESULT:
+ current = self.get(obj, passive=passive)
+ if current is PASSIVE_NORESULT:
return None
return AttributeHistory(self, obj, current, passive=passive)
@@ -123,6 +154,14 @@ class InstrumentedAttribute(object):
else:
obj._state[('callable', self)] = callable_
+ def _get_callable(self, obj):
+ if ('callable', self) in obj._state:
+ return obj._state[('callable', self)]
+ elif self.callable_ is not None:
+ return self.callable_(obj)
+ else:
+ return None
+
def reset(self, obj):
"""Remove any per-instance callable functions corresponding to
this ``InstrumentedAttribute``'s attribute from the given
@@ -148,43 +187,21 @@ class InstrumentedAttribute(object):
except KeyError:
pass
- def _get_callable(self, obj):
- if obj._state.has_key(('callable', self)):
- return obj._state[('callable', self)]
- elif self.callable_ is not None:
- return self.callable_(obj)
- else:
- return None
-
- def _blank_list(self):
- if self.typecallable is not None:
- return self.typecallable()
- else:
- return []
+ def check_mutable_modified(self, obj):
+ return False
def initialize(self, obj):
- """Initialize this attribute on the given object instance.
+ """Initialize this attribute on the given object instance with an empty value."""
- If this is a list-based attribute, a new, blank list will be
- created. if a scalar attribute, the value will be initialized
- to None.
- """
-
- if self.uselist:
- l = InstrumentedList(self, obj, self._blank_list())
- obj.__dict__[self.key] = l
- return l
- else:
- obj.__dict__[self.key] = None
- return None
+ obj.__dict__[self.key] = None
+ return None
- def get(self, obj, passive=False, raiseerr=True):
+ def get(self, obj, passive=False):
"""Retrieve a value from the given object.
If a callable is assembled on this object's attribute, and
passive is False, the callable will be executed and the
- resulting value will be set as the new value for this
- attribute.
+ resulting value will be set as the new value for this attribute.
"""
try:
@@ -193,441 +210,301 @@ class InstrumentedAttribute(object):
state = obj._state
# if an instance-wide "trigger" was set, call that
# and start again
- if state.has_key('trigger'):
+ if 'trigger' in state:
trig = state['trigger']
del state['trigger']
trig()
- return self.get(obj, passive=passive, raiseerr=raiseerr)
-
- if self.uselist:
- callable_ = self._get_callable(obj)
- if callable_ is not None:
- if passive:
- return InstrumentedAttribute.PASSIVE_NORESULT
- self.logger.debug("Executing lazy callable on %s.%s" % (orm_util.instance_str(obj), self.key))
- values = callable_()
- l = InstrumentedList(self, obj, values, init=False)
-
- # if a callable was executed, then its part of the "committed state"
- # if any, so commit the newly loaded data
- orig = state.get('original', None)
- if orig is not None:
- orig.commit_attribute(self, obj, l)
-
+ return self.get(obj, passive=passive)
+
+ callable_ = self._get_callable(obj)
+ if callable_ is not None:
+ if passive:
+ return PASSIVE_NORESULT
+ self.logger.debug("Executing lazy callable on %s.%s" %
+ (orm_util.instance_str(obj), self.key))
+ value = callable_()
+ if value is not ATTR_WAS_SET:
+ return self.set_committed_value(obj, value)
else:
- # note that we arent raising AttributeErrors, just creating a new
- # blank list and setting it.
- # this might be a good thing to be changeable by options.
- l = InstrumentedList(self, obj, self._blank_list(), init=False)
- obj.__dict__[self.key] = l
- return l
- else:
- callable_ = self._get_callable(obj)
- if callable_ is not None:
- if passive:
- return InstrumentedAttribute.PASSIVE_NORESULT
- self.logger.debug("Executing lazy callable on %s.%s" % (orm_util.instance_str(obj), self.key))
- value = callable_()
- obj.__dict__[self.key] = value
-
- # if a callable was executed, then its part of the "committed state"
- # if any, so commit the newly loaded data
- orig = state.get('original', None)
- if orig is not None:
- orig.commit_attribute(self, obj)
- return value
- else:
- # note that we arent raising AttributeErrors, just returning None.
- # this might be a good thing to be changeable by options.
- return None
-
- def set(self, event, obj, value):
- """Set a value on the given object.
-
- `event` is the ``InstrumentedAttribute`` that initiated the
- ``set()` operation and is used to control the depth of a
- circular setter operation.
- """
-
- if event is not self:
- state = obj._state
- # if an instance-wide "trigger" was set, call that
- if state.has_key('trigger'):
- trig = state['trigger']
- del state['trigger']
- trig()
- if self.uselist:
- value = InstrumentedList(self, obj, value)
- old = self.get(obj)
- obj.__dict__[self.key] = value
- state['modified'] = True
- if not self.uselist:
- if self.trackparent:
- if value is not None:
- self.sethasparent(value, True)
- if old is not None:
- self.sethasparent(old, False)
- for ext in self.extensions:
- ext.set(event or self, obj, value, old)
+ return obj.__dict__[self.key]
else:
- # mark all the old elements as detached from the parent
- old.list_replaced()
+ # Return a new, empty value
+ return self.initialize(obj)
- def delete(self, event, obj):
- """Delete a value from the given object.
+ def append(self, obj, value, initiator):
+ self.set(obj, value, initiator)
- `event` is the ``InstrumentedAttribute`` that initiated the
- ``delete()`` operation and is used to control the depth of a
- circular delete operation.
- """
-
- if event is not self:
- try:
- if not self.uselist and (self.trackparent or len(self.extensions)):
- old = self.get(obj)
- del obj.__dict__[self.key]
- except KeyError:
- # TODO: raise this? not consistent with get() ?
- raise AttributeError(self.key)
- obj._state['modified'] = True
- if not self.uselist:
- if self.trackparent:
- if old is not None:
- self.sethasparent(old, False)
- for ext in self.extensions:
- ext.delete(event or self, obj, old)
-
- def append(self, event, obj, value):
- """Append an element to a list based element or sets a scalar
- based element to the given value.
-
- Used by ``GenericBackrefExtension`` to *append* an item
- independent of list/scalar semantics.
-
- `event` is the ``InstrumentedAttribute`` that initiated the
- ``append()`` operation and is used to control the depth of a
- circular append operation.
- """
+ def remove(self, obj, value, initiator):
+ self.set(obj, None, initiator)
- if self.uselist:
- if event is not self:
- self.get(obj).append_with_event(value, event)
- else:
- self.set(event, obj, value)
-
- def remove(self, event, obj, value):
- """Remove an element from a list based element or sets a
- scalar based element to None.
-
- Used by ``GenericBackrefExtension`` to *remove* an item
- independent of list/scalar semantics.
+ def set(self, obj, value, initiator):
+ raise NotImplementedError()
- `event` is the ``InstrumentedAttribute`` that initiated the
- ``remove()`` operation and is used to control the depth of a
- circular remove operation.
+ def set_committed_value(self, obj, value):
+ """set an attribute value on the given instance and 'commit' it.
+
+ this indicates that the given value is the "persisted" value,
+ and history will be logged only if a newly set value is not
+ equal to this value.
+
+ this is typically used by deferred/lazy attribute loaders
+ to set object attributes after the initial load.
"""
- if self.uselist:
- if event is not self:
- self.get(obj).remove_with_event(value, event)
- else:
- self.set(event, obj, None)
+ state = obj._state
+ orig = state.get('original', None)
+ if orig is not None:
+ orig.commit_attribute(self, obj, value)
+ # remove per-instance callable, if any
+ state.pop(('callable', self), None)
+ obj.__dict__[self.key] = value
+ return value
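A hedged sketch of the loader pattern this supports; attr, obj, and load_from_db are hypothetical stand-ins:

    def lazy_callable():
        value = load_from_db()                # fetch the deferred value
        attr.set_committed_value(obj, value)  # install as persisted state;
                                              # no change-history is recorded
        return ATTR_WAS_SET                   # tells get() the value is set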
- def append_event(self, event, obj, value):
- """Called by ``InstrumentedList`` when an item is appended."""
+ def set_raw_value(self, obj, value):
+ obj.__dict__[self.key] = value
+ return value
+ def fire_append_event(self, obj, value, initiator):
obj._state['modified'] = True
if self.trackparent and value is not None:
self.sethasparent(value, True)
for ext in self.extensions:
- ext.append(event or self, obj, value)
-
- def remove_event(self, event, obj, value):
- """Called by ``InstrumentedList`` when an item is removed."""
+ ext.append(obj, value, initiator or self)
+ def fire_remove_event(self, obj, value, initiator):
obj._state['modified'] = True
if self.trackparent and value is not None:
self.sethasparent(value, False)
for ext in self.extensions:
- ext.delete(event or self, obj, value)
+ ext.remove(obj, value, initiator or self)
+
+ def fire_replace_event(self, obj, value, previous, initiator):
+ obj._state['modified'] = True
+ if self.trackparent:
+ if value is not None:
+ self.sethasparent(value, True)
+ if previous is not None:
+ self.sethasparent(previous, False)
+ for ext in self.extensions:
+ ext.set(obj, value, previous, initiator or self)
+
+ property = property(lambda s: class_mapper(s.class_).get_property(s.key),
+ doc="the MapperProperty object associated with this attribute")
InstrumentedAttribute.logger = logging.class_logger(InstrumentedAttribute)
+
+class InstrumentedScalarAttribute(InstrumentedAttribute):
+ """represents a scalar-holding InstrumentedAttribute."""
-class InstrumentedList(object):
- """Instrument a list-based attribute.
-
- All mutator operations (i.e. append, remove, etc.) will fire off
- events to the ``InstrumentedAttribute`` that manages the object's
- attribute. Those events in turn trigger things like backref
- operations and whatever is implemented by
- ``do_list_value_changed`` on ``InstrumentedAttribute``.
-
- Note that this list does a lot less than earlier versions of SA
- list-based attributes, which used ``HistoryArraySet``. This list
- wrapper does **not** maintain setlike semantics, meaning you can add
- as many duplicates as you want (which can break a lot of SQL), and
- also does not do anything related to history tracking.
-
- Please see ticket #213 for information on the future of this
- class, where it will be broken out into more collection-specific
- subtypes.
- """
+ def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, copy_function=None, compare_function=None, mutable_scalars=False, **kwargs):
+ super(InstrumentedScalarAttribute, self).__init__(class_, manager, key,
+ callable_, trackparent=trackparent, extension=extension,
+ compare_function=compare_function, **kwargs)
+ self.mutable_scalars = mutable_scalars
- def __init__(self, attr, obj, data, init=True):
- self.attr = attr
- # this weakref is to prevent circular references between the parent object
- # and the list attribute, which interferes with immediate garbage collection.
- self.__obj = weakref.ref(obj)
- self.key = attr.key
-
- # adapt to lists or sets
- # TODO: make three subclasses of InstrumentedList that come off from a
- # metaclass, based on the type of data sent in
- if attr.typecallable is not None:
- self.data = attr.typecallable()
- else:
- self.data = data or attr._blank_list()
-
- if isinstance(self.data, list):
- self._data_appender = self.data.append
- self._clear_data = self._clear_list
- elif isinstance(self.data, util.Set):
- self._data_appender = self.data.add
- self._clear_data = self._clear_set
- elif isinstance(self.data, dict):
- if hasattr(self.data, 'append'):
- self._data_appender = self.data.append
- else:
- raise exceptions.ArgumentError("Dictionary collection class '%s' must implement an append() method" % type(self.data).__name__)
- self._clear_data = self._clear_dict
- else:
- if hasattr(self.data, 'append'):
- self._data_appender = self.data.append
- elif hasattr(self.data, 'add'):
- self._data_appender = self.data.add
- else:
- raise exceptions.ArgumentError("Collection class '%s' is not of type 'list', 'set', or 'dict' and has no append() or add() method" % type(self.data).__name__)
+ if copy_function is None:
+ copy_function = self.__copy
+ self.copy = copy_function
- if hasattr(self.data, 'clear'):
- self._clear_data = self._clear_set
- else:
- raise exceptions.ArgumentError("Collection class '%s' is not of type 'list', 'set', or 'dict' and has no clear() method" % type(self.data).__name__)
-
- if data is not None and data is not self.data:
- for elem in data:
- self._data_appender(elem)
-
+ def __copy(self, item):
+ # scalar values are assumed to be immutable unless a copy function
+ # is passed
+ return item
- if init:
- for x in self.data:
- self.__setrecord(x)
+ def __delete__(self, obj):
+ old = self.get(obj)
+ del obj.__dict__[self.key]
+ self.fire_remove_event(obj, old, self)
- def list_replaced(self):
- """Fire off delete event handlers for each item in the list
- but doesnt affect the original data list.
- """
+ def check_mutable_modified(self, obj):
+ if self.mutable_scalars:
+ h = self.get_history(obj, passive=True)
+ if h is not None and h.is_modified():
+ obj._state['modified'] = True
+ return True
+ else:
+ return False
+ else:
+ return False
- [self.__delrecord(x) for x in self.data]
+ def set(self, obj, value, initiator):
+ """Set a value on the given object.
- def clear(self):
- """Clear all items in this InstrumentedList and fires off
- delete event handlers for each item.
+ `initiator` is the ``InstrumentedAttribute`` that initiated the
+        ``set()`` operation and is used to control the depth of a circular
+ setter operation.
"""
- self._clear_data()
-
- def _clear_dict(self):
- [self.__delrecord(x) for x in self.data.values()]
- self.data.clear()
-
- def _clear_set(self):
- [self.__delrecord(x) for x in self.data]
- self.data.clear()
-
- def _clear_list(self):
- self[:] = []
-
- def __getstate__(self):
- """Implemented to allow pickling, since `__obj` is a weakref,
- also the ``InstrumentedAttribute`` has callables attached to
- it.
- """
+ if initiator is self:
+ return
- return {'key':self.key, 'obj':self.obj, 'data':self.data}
+ state = obj._state
+ # if an instance-wide "trigger" was set, call that
+ if 'trigger' in state:
+ trig = state['trigger']
+ del state['trigger']
+ trig()
- def __setstate__(self, d):
- """Implemented to allow pickling, since `__obj` is a weakref,
- also the ``InstrumentedAttribute`` has callables attached to it.
- """
+ old = self.get(obj)
+ obj.__dict__[self.key] = value
+ self.fire_replace_event(obj, value, old, initiator)
- self.key = d['key']
- self.__obj = weakref.ref(d['obj'])
- self.data = d['data']
- self.attr = getattr(d['obj'].__class__, self.key)
+ type = property(lambda self: self.property.columns[0].type)
- obj = property(lambda s:s.__obj())
+
+class InstrumentedCollectionAttribute(InstrumentedAttribute):
+ """A collection-holding attribute that instruments changes in membership.
- def unchanged_items(self):
- """Deprecated."""
+ InstrumentedCollectionAttribute holds an arbitrary, user-specified
+ container object (defaulting to a list) and brokers access to the
+ CollectionAdapter, a "view" onto that object that presents consistent
+ bag semantics to the orm layer independent of the user data implementation.
+ """
+
+ def __init__(self, class_, manager, key, callable_, typecallable=None, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs):
+ super(InstrumentedCollectionAttribute, self).__init__(class_, manager,
+ key, callable_, trackparent=trackparent, extension=extension,
+ compare_function=compare_function, **kwargs)
- return self.attr.get_history(self.obj).unchanged_items
+ if copy_function is None:
+ copy_function = self.__copy
+ self.copy = copy_function
- def added_items(self):
- """Deprecated."""
+ if typecallable is None:
+ typecallable = list
+ self.collection_factory = \
+ collections._prepare_instrumentation(typecallable)
+ self.collection_interface = \
+ util.duck_type_collection(self.collection_factory())
- return self.attr.get_history(self.obj).added_items
+ def __copy(self, item):
+ return [y for y in list(collections.collection_adapter(item))]
- def deleted_items(self):
- """Deprecated."""
+ def __set__(self, obj, value):
+ """Replace the current collection with a new one."""
- return self.attr.get_history(self.obj).deleted_items
+ setting_type = util.duck_type_collection(value)
- def __iter__(self):
- return iter(self.data)
+ if value is None or setting_type != self.collection_interface:
+ raise exceptions.ArgumentError(
+ "Incompatible collection type on assignment: %s is not %s-like" %
+ (type(value).__name__, self.collection_interface.__name__))
- def __repr__(self):
- return repr(self.data)
+ if hasattr(value, '_sa_adapter'):
+ self.set(obj, list(getattr(value, '_sa_adapter')), None)
+ elif setting_type == dict:
+ self.set(obj, value.values(), None)
+ else:
+ self.set(obj, value, None)
- def __getattr__(self, attr):
- """Proxy unknown methods and attributes to the underlying
- data array. This allows custom list classes to be used.
- """
+ def __delete__(self, obj):
+ if self.key not in obj.__dict__:
+ return
- return getattr(self.data, attr)
+ obj._state['modified'] = True
- def __setrecord(self, item, event=None):
- self.attr.append_event(event, self.obj, item)
- return True
+ collection = self._get_collection(obj)
+ collection.clear_with_event()
+ del obj.__dict__[self.key]
- def __delrecord(self, item, event=None):
- self.attr.remove_event(event, self.obj, item)
- return True
+ def initialize(self, obj):
+ """Initialize this attribute on the given object instance with an empty collection."""
- def append_with_event(self, item, event):
- self.__setrecord(item, event)
- self._data_appender(item)
+ _, user_data = self._build_collection(obj)
+ obj.__dict__[self.key] = user_data
+ return user_data
- def append_without_event(self, item):
- self._data_appender(item)
+ def append(self, obj, value, initiator):
+ if initiator is self:
+ return
+ collection = self._get_collection(obj)
+ collection.append_with_event(value, initiator)
- def remove_with_event(self, item, event):
- self.__delrecord(item, event)
- self.data.remove(item)
+ def remove(self, obj, value, initiator):
+ if initiator is self:
+ return
+ collection = self._get_collection(obj)
+ collection.remove_with_event(value, initiator)
- def append(self, item, _mapper_nohistory=False):
- """Fire off dependent events, and appends the given item to the underlying list.
+ def set(self, obj, value, initiator):
+ """Set a value on the given object.
- `_mapper_nohistory` is a backwards compatibility hack; call
- ``append_without_event`` instead.
+ `initiator` is the ``InstrumentedAttribute`` that initiated the
+        ``set()`` operation and is used to control the depth of a circular
+ setter operation.
"""
- if _mapper_nohistory:
- self.append_without_event(item)
- else:
- self.__setrecord(item)
- self._data_appender(item)
-
- def __getitem__(self, i):
- return self.data[i]
-
- def __setitem__(self, i, item):
- if isinstance(i, slice):
- self.__setslice__(i.start, i.stop, item)
- else:
- self.__setrecord(item)
- self.data[i] = item
-
- def __delitem__(self, i):
- if isinstance(i, slice):
- self.__delslice__(i.start, i.stop)
- else:
- self.__delrecord(self.data[i], None)
- del self.data[i]
-
- def __lt__(self, other): return self.data < self.__cast(other)
-
- def __le__(self, other): return self.data <= self.__cast(other)
+ if initiator is self:
+ return
- def __eq__(self, other): return self.data == self.__cast(other)
+ state = obj._state
+ # if an instance-wide "trigger" was set, call that
+ if 'trigger' in state:
+ trig = state['trigger']
+ del state['trigger']
+ trig()
- def __ne__(self, other): return self.data != self.__cast(other)
+ old = self.get(obj)
+ old_collection = self._get_collection(obj, old)
- def __gt__(self, other): return self.data > self.__cast(other)
+ new_collection, user_data = self._build_collection(obj)
+ self._load_collection(obj, value or [], emit_events=True,
+ collection=new_collection)
- def __ge__(self, other): return self.data >= self.__cast(other)
+ obj.__dict__[self.key] = user_data
+ state['modified'] = True
- def __cast(self, other):
- if isinstance(other, InstrumentedList): return other.data
- else: return other
+ # mark all the old elements as detached from the parent
+ if old_collection:
+ old_collection.clear_with_event()
+ old_collection.unlink(old)
- def __cmp__(self, other):
- return cmp(self.data, self.__cast(other))
+ def set_committed_value(self, obj, value):
+ """Set an attribute value on the given instance and 'commit' it."""
+
+ state = obj._state
+ orig = state.get('original', None)
- def __contains__(self, item): return item in self.data
+ collection, user_data = self._build_collection(obj)
+ self._load_collection(obj, value or [], emit_events=False,
+ collection=collection)
+ value = user_data
- def __len__(self):
+ if orig is not None:
+ orig.commit_attribute(self, obj, value)
+ # remove per-instance callable, if any
+ state.pop(('callable', self), None)
+ obj.__dict__[self.key] = value
+ return value
+
+ def _build_collection(self, obj):
+ user_data = self.collection_factory()
+ collection = collections.CollectionAdapter(self, obj, user_data)
+ return collection, user_data
+
+ def _load_collection(self, obj, values, emit_events=True, collection=None):
+ collection = collection or self._get_collection(obj)
+ if values is None:
+ return
+ elif emit_events:
+ for item in values:
+ collection.append_with_event(item)
+ else:
+ for item in values:
+ collection.append_without_event(item)
+
+ def _get_collection(self, obj, user_data=None):
+ if user_data is None:
+ user_data = self.get(obj)
try:
- return len(self.data)
- except TypeError:
- return len(list(self.data))
-
- def __setslice__(self, i, j, other):
- [self.__delrecord(x) for x in self.data[i:j]]
- g = [a for a in list(other) if self.__setrecord(a)]
- self.data[i:j] = g
-
- def __delslice__(self, i, j):
- for a in self.data[i:j]:
- self.__delrecord(a)
- del self.data[i:j]
-
- def insert(self, i, item):
- if self.__setrecord(item):
- self.data.insert(i, item)
-
- def pop(self, i=-1):
- item = self.data[i]
- self.__delrecord(item)
- return self.data.pop(i)
-
- def remove(self, item):
- self.__delrecord(item)
- self.data.remove(item)
-
- def discard(self, item):
- if item in self.data:
- self.__delrecord(item)
- self.data.remove(item)
-
- def extend(self, item_list):
- for item in item_list:
- self.append(item)
-
- def __add__(self, other):
- raise NotImplementedError()
-
- def __radd__(self, other):
- raise NotImplementedError()
-
- def __iadd__(self, other):
- raise NotImplementedError()
-
-class AttributeExtension(object):
- """An abstract class which specifies `append`, `delete`, and `set`
- event handlers to be attached to an object property.
- """
-
- def append(self, event, obj, child):
- pass
-
- def delete(self, event, obj, child):
- pass
+ return getattr(user_data, '_sa_adapter')
+ except AttributeError:
+ collections.CollectionAdapter(self, obj, user_data)
+ return getattr(user_data, '_sa_adapter')
- def set(self, event, obj, child, oldchild):
- pass
-class GenericBackrefExtension(AttributeExtension):
+class GenericBackrefExtension(interfaces.AttributeExtension):
"""An extension which synchronizes a two-way relationship.
A typical two-way relationship is a parent object containing a
@@ -639,19 +516,19 @@ class GenericBackrefExtension(AttributeExtension):
def __init__(self, key):
self.key = key
- def set(self, event, obj, child, oldchild):
+ def set(self, obj, child, oldchild, initiator):
if oldchild is child:
return
if oldchild is not None:
- getattr(oldchild.__class__, self.key).remove(event, oldchild, obj)
+ getattr(oldchild.__class__, self.key).remove(oldchild, obj, initiator)
if child is not None:
- getattr(child.__class__, self.key).append(event, child, obj)
+ getattr(child.__class__, self.key).append(child, obj, initiator)
- def append(self, event, obj, child):
- getattr(child.__class__, self.key).append(event, child, obj)
+ def append(self, obj, child, initiator):
+ getattr(child.__class__, self.key).append(child, obj, initiator)
- def delete(self, event, obj, child):
- getattr(child.__class__, self.key).remove(event, child, obj)
+ def remove(self, obj, child, initiator):
+ getattr(child.__class__, self.key).remove(child, obj, initiator)
class CommittedState(object):
"""Store the original state of an object when the ``commit()`
@@ -673,7 +550,7 @@ class CommittedState(object):
"""
if value is CommittedState.NO_VALUE:
- if obj.__dict__.has_key(attr.key):
+ if attr.key in obj.__dict__:
value = obj.__dict__[attr.key]
if value is not CommittedState.NO_VALUE:
self.data[attr.key] = attr.copy(value)
@@ -690,10 +567,13 @@ class CommittedState(object):
def rollback(self, manager, obj):
for attr in manager.managed_attributes(obj.__class__):
if self.data.has_key(attr.key):
- if attr.uselist:
- obj.__dict__[attr.key][:] = self.data[attr.key]
- else:
+ if not isinstance(attr, InstrumentedCollectionAttribute):
obj.__dict__[attr.key] = self.data[attr.key]
+ else:
+ collection = attr._get_collection(obj)
+ collection.clear_without_event()
+ for item in self.data[attr.key]:
+ collection.append_without_event(item)
else:
del obj.__dict__[attr.key]
@@ -718,17 +598,15 @@ class AttributeHistory(object):
else:
original = None
- if attr.uselist:
+ if isinstance(attr, InstrumentedCollectionAttribute):
self._current = current
- else:
- self._current = [current]
- if attr.uselist:
s = util.Set(original or [])
self._added_items = []
self._unchanged_items = []
self._deleted_items = []
if current:
- for a in current:
+ collection = attr._get_collection(obj, current)
+ for a in collection:
if a in s:
self._unchanged_items.append(a)
else:
@@ -737,6 +615,7 @@ class AttributeHistory(object):
if a not in self._unchanged_items:
self._deleted_items.append(a)
else:
+ self._current = [current]
if attr.is_equal(current, original):
self._unchanged_items = [current]
self._added_items = []
@@ -748,7 +627,6 @@ class AttributeHistory(object):
else:
self._deleted_items = []
self._unchanged_items = []
- #print "key", attr.key, "orig", original, "current", current, "added", self._added_items, "unchanged", self._unchanged_items, "deleted", self._deleted_items
def __iter__(self):
return iter(self._current)
@@ -766,24 +644,13 @@ class AttributeHistory(object):
return self._deleted_items
def hasparent(self, obj):
- """Deprecated. This should be called directly from the
- appropriate ``InstrumentedAttribute`` object.
+ """Deprecated. This should be called directly from the appropriate ``InstrumentedAttribute`` object.
"""
return self.attr.hasparent(obj)
class AttributeManager(object):
- """Allow the instrumentation of object attributes.
-
- ``AttributeManager`` is stateless, but can be overridden by
- subclasses to redefine some of its factory operations. Also be
- aware ``AttributeManager`` will cache attributes for a given
- class, allowing not to determine those for each objects (used in
- ``managed_attributes()`` and
- ``noninherited_managed_attributes()``). This cache is cleared for
- a given class while calling ``register_attribute()``, and can be
- cleared using ``clear_attribute_cache()``.
- """
+ """Allow the instrumentation of object attributes."""
def __init__(self):
# will cache attributes, indexed by class objects
@@ -827,7 +694,7 @@ class AttributeManager(object):
o._state['modified'] = False
def managed_attributes(self, class_):
- """Return an iterator of all ``InstrumentedAttribute`` objects
+ """Return a list of all ``InstrumentedAttribute`` objects
associated with the given class.
"""
@@ -878,7 +745,7 @@ class AttributeManager(object):
"""Return an attribute of the given name from the given object.
If the attribute is a scalar, return it as a single-item list,
- otherwise return the list based attribute.
+        otherwise return the contents of the collection-based attribute as a list.
If the attribute's value is to be produced by an unexecuted
callable, the callable will only be executed if the given
@@ -887,10 +754,10 @@ class AttributeManager(object):
attr = getattr(obj.__class__, key)
x = attr.get(obj, passive=passive)
- if x is InstrumentedAttribute.PASSIVE_NORESULT:
+ if x is PASSIVE_NORESULT:
return []
- elif attr.uselist:
- return x
+ elif isinstance(attr, InstrumentedCollectionAttribute):
+ return list(attr._get_collection(obj, x))
else:
return [x]
@@ -921,7 +788,7 @@ class AttributeManager(object):
by ``trigger_history()``.
"""
- return obj._state.has_key('trigger')
+ return 'trigger' in obj._state
def reset_instance_attribute(self, obj, key):
"""Remove any per-instance callable functions corresponding to
@@ -946,10 +813,9 @@ class AttributeManager(object):
"""Return True if the given `key` correponds to an
instrumented property on the given class.
"""
-
return hasattr(class_, key) and isinstance(getattr(class_, key), InstrumentedAttribute)
- def init_instance_attribute(self, obj, key, uselist, callable_=None, **kwargs):
+ def init_instance_attribute(self, obj, key, callable_=None):
"""Initialize an attribute on an instance to either a blank
value, cancelling out any class- or instance-level callables
that were present, or if a `callable` is supplied set the
@@ -964,7 +830,24 @@ class AttributeManager(object):
events back to this ``AttributeManager``.
"""
- return InstrumentedAttribute(self, key, uselist, callable_, typecallable, **kwargs)
+ if uselist:
+ return InstrumentedCollectionAttribute(class_, self, key,
+ callable_,
+ typecallable,
+ **kwargs)
+ else:
+ return InstrumentedScalarAttribute(class_, self, key, callable_,
+ **kwargs)
+
+ def get_attribute(self, obj_or_cls, key):
+ """Register an attribute at the class level to be instrumented
+ for all instances of the class.
+ """
+
+ if isinstance(obj_or_cls, type):
+ return getattr(obj_or_cls, key)
+ else:
+ return getattr(obj_or_cls.__class__, key)
def register_attribute(self, class_, key, uselist, callable_=None, **kwargs):
"""Register an attribute at the class level to be instrumented
@@ -973,10 +856,9 @@ class AttributeManager(object):
# firt invalidate the cache for the given class
# (will be reconstituted as needed, while getting managed attributes)
- self._inherited_attribute_cache.pop(class_,None)
- self._noninherited_attribute_cache.pop(class_,None)
+ self._inherited_attribute_cache.pop(class_, None)
+ self._noninherited_attribute_cache.pop(class_, None)
- #print self, "register attribute", key, "for class", class_
if not hasattr(class_, '_state'):
def _get_state(self):
if not hasattr(self, '_sa_attr_state'):
@@ -987,4 +869,12 @@ class AttributeManager(object):
typecallable = kwargs.pop('typecallable', None)
if isinstance(typecallable, InstrumentedAttribute):
typecallable = None
- setattr(class_, key, self.create_prop(class_, key, uselist, callable_, typecallable=typecallable, **kwargs))
+ setattr(class_, key, self.create_prop(class_, key, uselist, callable_,
+ typecallable=typecallable, **kwargs))
+
+ def init_collection(self, instance, key):
+ """Initialize a collection attribute and return the collection adapter."""
+
+ attr = self.get_attribute(instance, key)
+ user_data = attr.initialize(instance)
+ return attr._get_collection(instance, user_data)
diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py
new file mode 100644
index 000000000..7ade882f5
--- /dev/null
+++ b/lib/sqlalchemy/orm/collections.py
@@ -0,0 +1,1182 @@
+"""Support for collections of mapped entities.
+
+The collections package supplies the machinery used to inform the ORM of
+collection membership changes. An instrumentation via decoration approach is
+used, allowing arbitrary types (including built-ins) to be used as entity
+collections without requiring inheritance from a base class.
+
+Instrumentation decoration relays membership change events to the
+``InstrumentedCollectionAttribute`` that is currently managing the collection.
+The decorators observe function call arguments and return values, tracking
+entities entering or leaving the collection. Two decorator approaches are
+provided. One is a bundle of generic decorators that map function arguments
+and return values to events::
+
+ from sqlalchemy.orm.collections import collection
+ class MyClass(object):
+ # ...
+
+ @collection.adds(1)
+ def store(self, item):
+ self.data.append(item)
+
+ @collection.removes_return()
+ def pop(self):
+ return self.data.pop()
+
+
+The second approach is a bundle of targeted decorators that wrap appropriate
+append and remove notifiers around the mutation methods present in the
+standard Python ``list``, ``set`` and ``dict`` interfaces. These could be
+specified in terms of generic decorator recipes, but are instead hand-tooled for
+increased efficiency. The targeted decorators occasionally implement
+adapter-like behavior, such as mapping bulk-set methods (``extend``, ``update``,
+``__setslice__``, etc.) into the series of atomic mutation events that the ORM
+requires.
+
+The targeted decorators are used internally for automatic instrumentation of
+entity collection classes. Every collection class goes through a
+transformation process roughly like so:
+
+1. If the class is a built-in, substitute a trivial sub-class
+2. If the class is already instrumented, skip the remaining steps
+3. Add in generic decorators
+4. Sniff out the collection interface through duck-typing
+5. Add targeted decoration to any undecorated interface method
+
+This process modifies the class at runtime, decorating methods and adding some
+bookkeeping properties. This isn't possible (or desirable) for built-in
+classes like ``list``, so trivial sub-classes are substituted to hold
+decoration::
+
+ class InstrumentedList(list):
+ pass
+
+Collection classes can be specified in ``relation(collection_class=)`` as
+types or a function that returns an instance. Collection classes are
+inspected and instrumented during the mapper compilation phase. The
+collection_class callable will be executed once to produce a specimen
+instance, and the type of that specimen will be instrumented. Functions that
+return built-in types like ``list`` will be adapted to produce instrumented
+instances.
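+
+For example, a minimal sketch, where ``Parent``, ``Child`` and the table
+objects are hypothetical placeholders::
+
+    mapper(Parent, parents_table, properties={
+        'children': relation(Child, collection_class=set)})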
+
+When extending a known type like ``list``, additional decoration is generally
+not needed.  Odds are, the extension method will delegate to a
+method that's already instrumented. For example::
+
+ class QueueIsh(list):
+ def push(self, item):
+ self.append(item)
+ def shift(self):
+ return self.pop(0)
+
+There's no need to decorate these methods. ``append`` and ``pop`` are already
+instrumented as part of the ``list`` interface. Decorating them would fire
+duplicate events, which should be avoided.
+
+The targeted decoration tries not to rely on other methods in the underlying
+collection class, but some are unavoidable. Many depend on 'read' methods
+being present to properly instrument a 'write', for example, ``__setitem__``
+needs ``__getitem__``.  "Bulk" methods like ``update`` and ``extend`` may also
+be reimplemented in terms of atomic appends and removes, so the ``extend``
+decoration will actually perform many ``append`` operations and not call the
+underlying method at all.
+
+Tight control over bulk operation and the firing of events is also possible by
+implementing the instrumentation internally in your methods. The basic
+instrumentation package works under the general assumption that collection
+mutation will not raise unusual exceptions. If you want to closely
+orchestrate append and remove events with exception management, internal
+instrumentation may be the answer. Within your method,
+``collection_adapter(self)`` will retrieve an object that you can use for
+explicit control over triggering append and remove events.
+
+The owning object and InstrumentedCollectionAttribute are also reachable
+through the adapter, allowing for some very sophisticated behavior.
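+
+As a rough sketch (the class below is illustrative, not part of the library),
+an internally instrumented method might look like::
+
+    from sqlalchemy.orm.collections import collection, collection_adapter
+
+    class AuditedList(list):
+        @collection.internally_instrumented
+        def append(self, item, _sa_initiator=None):
+            # fire the membership event ourselves; the adapter's
+            # fire_append_event honors _sa_initiator=False (no event fired)
+            adapter = collection_adapter(self)
+            if adapter is not None:
+                adapter.fire_append_event(item, _sa_initiator)
+            list.append(self, item)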
+"""
+
+import copy, inspect, sys, weakref
+
+from sqlalchemy import exceptions, schema, util as sautil
+from sqlalchemy.orm import mapper
+
+try:
+ from threading import Lock
+except:
+ from dummy_threading import Lock
+try:
+ from operator import attrgetter
+except:
+ def attrgetter(attribute):
+ return lambda value: getattr(value, attribute)
+
+
+__all__ = ['collection', 'collection_adapter',
+ 'mapped_collection', 'column_mapped_collection',
+ 'attribute_mapped_collection']
+
+def column_mapped_collection(mapping_spec):
+ """A dictionary-based collection type with column-based keying.
+
+ Returns a MappedCollection factory with a keying function generated
+ from mapping_spec, which may be a Column or a sequence of Columns.
+
+ The key value must be immutable for the lifetime of the object. You
+ can not, for example, map on foreign key values if those key values will
+ change during the session, i.e. from None to a database-assigned integer
+ after a session flush.
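+
+    A usage sketch, with hypothetical class and table names::
+
+        relation(Note,
+                 collection_class=column_mapped_collection(notes_table.c.keyword))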
+ """
+
+ if isinstance(mapping_spec, schema.Column):
+ def keyfunc(value):
+ m = mapper.object_mapper(value)
+ return m.get_attr_by_column(value, mapping_spec)
+ else:
+ cols = []
+ for c in mapping_spec:
+ if not isinstance(c, schema.Column):
+ raise exceptions.ArgumentError(
+ "mapping_spec tuple may only contain columns")
+ cols.append(c)
+ mapping_spec = tuple(cols)
+ def keyfunc(value):
+ m = mapper.object_mapper(value)
+ return tuple([m.get_attr_by_column(value, c) for c in mapping_spec])
+ return lambda: MappedCollection(keyfunc)
+
+def attribute_mapped_collection(attr_name):
+ """A dictionary-based collection type with attribute-based keying.
+
+ Returns a MappedCollection factory with a keying based on the
+ 'attr_name' attribute of entities in the collection.
+
+ The key value must be immutable for the lifetime of the object. You
+ can not, for example, map on foreign key values if those key values will
+ change during the session, i.e. from None to a database-assigned integer
+ after a session flush.
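+
+    For example, keying a hypothetical ``Note`` class by its ``keyword``
+    attribute::
+
+        relation(Note,
+                 collection_class=attribute_mapped_collection('keyword'))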
+ """
+
+ return lambda: MappedCollection(attrgetter(attr_name))
+
+
+def mapped_collection(keyfunc):
+ """A dictionary-based collection type with arbitrary keying.
+
+ Returns a MappedCollection factory with a keying function generated
+ from keyfunc, a callable that takes an entity and returns a key value.
+
+ The key value must be immutable for the lifetime of the object. You
+ can not, for example, map on foreign key values if those key values will
+ change during the session, i.e. from None to a database-assigned integer
+ after a session flush.
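+
+    For example, with a hypothetical ``Note`` entity::
+
+        relation(Note,
+                 collection_class=mapped_collection(lambda note: note.text[:10]))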
+ """
+
+ return lambda: MappedCollection(keyfunc)
+
+class collection(object):
+ """Decorators for entity collection classes.
+
+ The decorators fall into two groups: annotations and interception recipes.
+
+ The annotating decorators (appender, remover, iterator,
+ internally_instrumented, on_link) indicate the method's purpose and take no
+ arguments. They are not written with parens::
+
+ @collection.appender
+ def append(self, append): ...
+
+ The recipe decorators all require parens, even those that take no
+ arguments::
+
+        @collection.adds('entity')
+ def insert(self, position, entity): ...
+
+ @collection.removes_return()
+ def popitem(self): ...
+
+ Decorators can be specified in long-hand for Python 2.3, or with
+    the class-level dict attribute '__instrumentation__'; see the source
+ for details.
+ """
+
+ # Bundled as a class solely for ease of use: packaging, doc strings,
+ # importability.
+
+ def appender(cls, fn):
+ """Tag the method as the collection appender.
+
+ The appender method is called with one positional argument: the value
+ to append. The method will be automatically decorated with 'adds(1)'
+ if not already decorated::
+
+ @collection.appender
+ def add(self, append): ...
+
+ # or, equivalently
+ @collection.appender
+ @collection.adds(1)
+ def add(self, append): ...
+
+ # for mapping type, an 'append' may kick out a previous value
+ # that occupies that slot. consider d['a'] = 'foo'- any previous
+ # value in d['a'] is discarded.
+ @collection.appender
+ @collection.replaces(1)
+ def add(self, entity):
+ key = some_key_func(entity)
+ previous = None
+ if key in self:
+ previous = self[key]
+ self[key] = entity
+ return previous
+
+ If the value to append is not allowed in the collection, you may
+ raise an exception. Something to remember is that the appender
+ will be called for each object mapped by a database query. If the
+ database contains rows that violate your collection semantics, you
+ will need to get creative to fix the problem, as access via the
+ collection will not work.
+
+ If the appender method is internally instrumented, you must also
+ receive the keyword argument '_sa_initiator' and ensure its
+ promulgation to collection events.
+ """
+
+ setattr(fn, '_sa_instrument_role', 'appender')
+ return fn
+ appender = classmethod(appender)
+
+ def remover(cls, fn):
+ """Tag the method as the collection remover.
+
+ The remover method is called with one positional argument: the value
+ to remove. The method will be automatically decorated with
+        'removes(1)' if not already decorated::
+
+ @collection.remover
+ def zap(self, entity): ...
+
+ # or, equivalently
+ @collection.remover
+            @collection.removes(1)
+            def zap(self, entity): ...
+
+ If the value to remove is not present in the collection, you may
+ raise an exception or return None to ignore the error.
+
+ If the remove method is internally instrumented, you must also
+ receive the keyword argument '_sa_initiator' and ensure its
+ promulgation to collection events.
+ """
+
+ setattr(fn, '_sa_instrument_role', 'remover')
+ return fn
+ remover = classmethod(remover)
+
+ def iterator(cls, fn):
+ """Tag the method as the collection remover.
+
+ The iterator method is called with no arguments. It is expected to
+ return an iterator over all collection members::
+
+ @collection.iterator
+ def __iter__(self): ...
+ """
+
+ setattr(fn, '_sa_instrument_role', 'iterator')
+ return fn
+ iterator = classmethod(iterator)
+
+ def internally_instrumented(cls, fn):
+ """Tag the method as instrumented.
+
+ This tag will prevent any decoration from being applied to the method.
+ Use this if you are orchestrating your own calls to collection_adapter
+ in one of the basic SQLAlchemy interface methods, or to prevent
+ an automatic ABC method decoration from wrapping your implementation::
+
+ # normally an 'extend' method on a list-like class would be
+ # automatically intercepted and re-implemented in terms of
+ # SQLAlchemy events and append(). your implementation will
+ # never be called, unless:
+ @collection.internally_instrumented
+ def extend(self, items): ...
+ """
+
+ setattr(fn, '_sa_instrumented', True)
+ return fn
+ internally_instrumented = classmethod(internally_instrumented)
+
+ def on_link(cls, fn):
+ """Tag the method as a the "linked to attribute" event handler.
+
+ This optional event handler will be called when the collection class
+ is linked to or unlinked from the InstrumentedAttribute. It is
+ invoked immediately after the '_sa_adapter' property is set on
+ the instance. A single argument is passed: the collection adapter
+ that has been linked, or None if unlinking.
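+
+        For example (the flag maintained here is purely illustrative)::
+
+            @collection.on_link
+            def _linked(self, adapter):
+                self.linked = adapter is not None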
+ """
+
+ setattr(fn, '_sa_instrument_role', 'on_link')
+ return fn
+ on_link = classmethod(on_link)
+
+ def adds(cls, arg):
+ """Mark the method as adding an entity to the collection.
+
+ Adds "add to collection" handling to the method. The decorator argument
+ indicates which method argument holds the SQLAlchemy-relevant value.
+ Arguments can be specified positionally (i.e. integer) or by name::
+
+ @collection.adds(1)
+ def push(self, item): ...
+
+ @collection.adds('entity')
+ def do_stuff(self, thing, entity=None): ...
+ """
+
+ def decorator(fn):
+ setattr(fn, '_sa_instrument_before', ('fire_append_event', arg))
+ return fn
+ return decorator
+ adds = classmethod(adds)
+
+ def replaces(cls, arg):
+ """Mark the method as replacing an entity in the collection.
+
+ Adds "add to collection" and "remove from collection" handling to
+ the method. The decorator argument indicates which method argument
+        holds the SQLAlchemy-relevant value to be added, and the return value,
+        if any, will be considered the value to remove.
+
+ Arguments can be specified positionally (i.e. integer) or by name::
+
+ @collection.replaces(2)
+ def __setitem__(self, index, item): ...
+ """
+
+ def decorator(fn):
+ setattr(fn, '_sa_instrument_before', ('fire_append_event', arg))
+ setattr(fn, '_sa_instrument_after', 'fire_remove_event')
+ return fn
+ return decorator
+ replaces = classmethod(replaces)
+
+ def removes(cls, arg):
+ """Mark the method as removing an entity in the collection.
+
+ Adds "remove from collection" handling to the method. The decorator
+ argument indicates which method argument holds the SQLAlchemy-relevant
+ value to be removed. Arguments can be specified positionally (i.e.
+ integer) or by name::
+
+ @collection.removes(1)
+ def zap(self, item): ...
+
+ For methods where the value to remove is not known at call-time, use
+ collection.removes_return.
+ """
+
+ def decorator(fn):
+ setattr(fn, '_sa_instrument_before', ('fire_remove_event', arg))
+ return fn
+ return decorator
+ removes = classmethod(removes)
+
+ def removes_return(cls):
+ """Mark the method as removing an entity in the collection.
+
+ Adds "remove from collection" handling to the method. The return value
+ of the method, if any, is considered the value to remove. The method
+ arguments are not inspected::
+
+ @collection.removes_return()
+ def pop(self): ...
+
+ For methods where the value to remove is known at call-time, use
+        collection.removes.
+ """
+
+ def decorator(fn):
+ setattr(fn, '_sa_instrument_after', 'fire_remove_event')
+ return fn
+ return decorator
+ removes_return = classmethod(removes_return)
+
+
+# public instrumentation interface for 'internally instrumented'
+# implementations
+def collection_adapter(collection):
+ """Fetch the CollectionAdapter for a collection."""
+
+ return getattr(collection, '_sa_adapter', None)
+
+class CollectionAdapter(object):
+ """Bridges between the ORM and arbitrary Python collections.
+
+ Proxies base-level collection operations (append, remove, iterate)
+ to the underlying Python collection, and emits add/remove events for
+ entities entering or leaving the collection.
+
+    The ORM uses a CollectionAdapter exclusively for interaction with
+ entity collections.
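+
+    A sketch of direct use, where ``obj.addresses`` stands in for any
+    instrumented collection::
+
+        adapter = collection_adapter(obj.addresses)
+        adapter.append_without_event(address)  # no mutation event is fired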
+ """
+
+ def __init__(self, attr, owner, data):
+ self.attr = attr
+ self._owner = weakref.ref(owner)
+ self._data = weakref.ref(data)
+ self.link_to_self(data)
+
+ owner = property(lambda s: s._owner(),
+ doc="The object that owns the entity collection.")
+ data = property(lambda s: s._data(),
+ doc="The entity collection being adapted.")
+
+ def link_to_self(self, data):
+ """Link a collection to this adapter, and fire a link event."""
+
+ setattr(data, '_sa_adapter', self)
+ if hasattr(data, '_sa_on_link'):
+ getattr(data, '_sa_on_link')(self)
+
+ def unlink(self, data):
+ """Unlink a collection from any adapter, and fire a link event."""
+
+ setattr(data, '_sa_adapter', None)
+ if hasattr(data, '_sa_on_link'):
+ getattr(data, '_sa_on_link')(None)
+
+ def append_with_event(self, item, initiator=None):
+ """Add an entity to the collection, firing mutation events."""
+
+ getattr(self._data(), '_sa_appender')(item, _sa_initiator=initiator)
+
+ def append_without_event(self, item):
+ """Add or restore an entity to the collection, firing no events."""
+
+ getattr(self._data(), '_sa_appender')(item, _sa_initiator=False)
+
+ def remove_with_event(self, item, initiator=None):
+ """Remove an entity from the collection, firing mutation events."""
+
+ getattr(self._data(), '_sa_remover')(item, _sa_initiator=initiator)
+
+ def remove_without_event(self, item):
+ """Remove an entity from the collection, firing no events."""
+
+ getattr(self._data(), '_sa_remover')(item, _sa_initiator=False)
+
+ def clear_with_event(self, initiator=None):
+ """Empty the collection, firing a mutation event for each entity."""
+
+ for item in list(self):
+ self.remove_with_event(item, initiator)
+
+ def clear_without_event(self):
+ """Empty the collection, firing no events."""
+
+ for item in list(self):
+ self.remove_without_event(item)
+
+ def __iter__(self):
+ """Iterate over entities in the collection."""
+
+ return getattr(self._data(), '_sa_iterator')()
+
+ def __len__(self):
+ """Count entities in the collection."""
+
+ return len(list(getattr(self._data(), '_sa_iterator')()))
+
+ def __nonzero__(self):
+ return True
+
+ def fire_append_event(self, item, initiator=None):
+ """Notify that a entity has entered the collection.
+
+ Initiator is the InstrumentedAttribute that initiated the membership
+ mutation, and should be left as None unless you are passing along
+ an initiator value from a chained operation.
+ """
+
+ if initiator is not False and item is not None:
+ self.attr.fire_append_event(self._owner(), item, initiator)
+
+ def fire_remove_event(self, item, initiator=None):
+ """Notify that a entity has entered the collection.
+
+ Initiator is the InstrumentedAttribute that initiated the membership
+ mutation, and should be left as None unless you are passing along
+ an initiator value from a chained operation.
+ """
+
+ if initiator is not False and item is not None:
+ self.attr.fire_remove_event(self._owner(), item, initiator)
+
+ def __getstate__(self):
+ return { 'key': self.attr.key,
+ 'owner': self.owner,
+ 'data': self.data }
+
+ def __setstate__(self, d):
+ self.attr = getattr(d['owner'].__class__, d['key'])
+ self._owner = weakref.ref(d['owner'])
+ self._data = weakref.ref(d['data'])
+
+
+__instrumentation_mutex = Lock()
+def _prepare_instrumentation(factory):
+ """Prepare a callable for future use as a collection class factory.
+
+ Given a collection class factory (either a type or no-arg callable),
+ return another factory that will produce compatible instances when
+ called.
+
+ This function is responsible for converting collection_class=list
+ into the run-time behavior of collection_class=InstrumentedList.
+ """
+
+ # Convert a builtin to 'Instrumented*'
+ if factory in __canned_instrumentation:
+ factory = __canned_instrumentation[factory]
+
+ # Create a specimen
+ cls = type(factory())
+
+ # Did factory callable return a builtin?
+ if cls in __canned_instrumentation:
+ # Wrap it so that it returns our 'Instrumented*'
+ factory = __converting_factory(factory)
+ cls = factory()
+
+ # Instrument the class if needed.
+ if __instrumentation_mutex.acquire():
+ try:
+ if getattr(cls, '_sa_instrumented', None) != id(cls):
+ _instrument_class(cls)
+ finally:
+ __instrumentation_mutex.release()
+
+ return factory
+
+def __converting_factory(original_factory):
+ """Convert the type returned by collection factories on the fly.
+
+ Given a collection factory that returns a builtin type (e.g. a list),
+ return a wrapped function that converts that type to one of our
+ instrumented types.
+ """
+
+ def wrapper():
+ collection = original_factory()
+ type_ = type(collection)
+ if type_ in __canned_instrumentation:
+ # return an instrumented type initialized from the factory's
+ # collection
+ return __canned_instrumentation[type_](collection)
+ else:
+ raise exceptions.InvalidRequestError(
+ "Collection class factories must produce instances of a "
+ "single class.")
+ try:
+ # often flawed but better than nothing
+ wrapper.__name__ = "%sWrapper" % original_factory.__name__
+ wrapper.__doc__ = original_factory.__doc__
+ except:
+ pass
+ return wrapper
+
+def _instrument_class(cls):
+ """Modify methods in a class and install instrumentation."""
+
+ # FIXME: more formally document this as a decoratorless/Python 2.3
+ # option for specifying instrumentation. (likely doc'd here in code only,
+ # not in online docs.)
+ #
+ # __instrumentation__ = {
+ # 'rolename': 'methodname', # ...
+ # 'methods': {
+ # 'methodname': ('fire_{append,remove}_event', argspec,
+ # 'fire_{append,remove}_event'),
+ # 'append': ('fire_append_event', 1, None),
+ # '__setitem__': ('fire_append_event', 1, 'fire_remove_event'),
+ # 'pop': (None, None, 'fire_remove_event'),
+ # }
+ # }
+
+ # In the normal call flow, a request for any of the 3 basic collection
+ # types is transformed into one of our trivial subclasses
+ # (e.g. InstrumentedList). Catch anything else that sneaks in here...
+ if cls.__module__ == '__builtin__':
+ raise exceptions.ArgumentError(
+ "Can not instrument a built-in type. Use a "
+ "subclass, even a trivial one.")
+
+ collection_type = sautil.duck_type_collection(cls)
+ if collection_type in __interfaces:
+ roles = __interfaces[collection_type].copy()
+ decorators = roles.pop('_decorators', {})
+ else:
+ roles, decorators = {}, {}
+
+ if hasattr(cls, '__instrumentation__'):
+ roles.update(copy.deepcopy(getattr(cls, '__instrumentation__')))
+
+ methods = roles.pop('methods', {})
+
+ for name in dir(cls):
+ method = getattr(cls, name)
+ if not callable(method):
+ continue
+
+ # note role declarations
+ if hasattr(method, '_sa_instrument_role'):
+ role = method._sa_instrument_role
+ assert role in ('appender', 'remover', 'iterator', 'on_link')
+ roles[role] = name
+
+ # transfer instrumentation requests from decorated function
+ # to the combined queue
+ before, after = None, None
+ if hasattr(method, '_sa_instrument_before'):
+ op, argument = method._sa_instrument_before
+ assert op in ('fire_append_event', 'fire_remove_event')
+ before = op, argument
+ if hasattr(method, '_sa_instrument_after'):
+ op = method._sa_instrument_after
+ assert op in ('fire_append_event', 'fire_remove_event')
+ after = op
+ if before:
+ methods[name] = before[0], before[1], after
+ elif after:
+ methods[name] = None, None, after
+
+ # apply ABC auto-decoration to methods that need it
+ for method, decorator in decorators.items():
+ fn = getattr(cls, method, None)
+ if fn and method not in methods and not hasattr(fn, '_sa_instrumented'):
+ setattr(cls, method, decorator(fn))
+
+ # ensure all roles are present, and apply implicit instrumentation if
+ # needed
+ if 'appender' not in roles or not hasattr(cls, roles['appender']):
+ raise exceptions.ArgumentError(
+ "Type %s must elect an appender method to be "
+ "a collection class" % cls.__name__)
+ elif (roles['appender'] not in methods and
+ not hasattr(getattr(cls, roles['appender']), '_sa_instrumented')):
+ methods[roles['appender']] = ('fire_append_event', 1, None)
+
+ if 'remover' not in roles or not hasattr(cls, roles['remover']):
+ raise exceptions.ArgumentError(
+ "Type %s must elect a remover method to be "
+ "a collection class" % cls.__name__)
+ elif (roles['remover'] not in methods and
+ not hasattr(getattr(cls, roles['remover']), '_sa_instrumented')):
+ methods[roles['remover']] = ('fire_remove_event', 1, None)
+
+ if 'iterator' not in roles or not hasattr(cls, roles['iterator']):
+ raise exceptions.ArgumentError(
+ "Type %s must elect an iterator method to be "
+ "a collection class" % cls.__name__)
+
+ # apply ad-hoc instrumentation from decorators, class-level defaults
+ # and implicit role declarations
+ for method, (before, argument, after) in methods.items():
+ setattr(cls, method,
+ _instrument_membership_mutator(getattr(cls, method),
+ before, argument, after))
+ # intern the role map
+ for role, method in roles.items():
+ setattr(cls, '_sa_%s' % role, getattr(cls, method))
+
+ setattr(cls, '_sa_instrumented', id(cls))
+
+def _instrument_membership_mutator(method, before, argument, after):
+ """Route method args and/or return value through the collection adapter."""
+
+ if type(argument) is int:
+ def wrapper(*args, **kw):
+ if before and len(args) < argument:
+ raise exceptions.ArgumentError(
+ 'Missing argument %i' % argument)
+ initiator = kw.pop('_sa_initiator', None)
+ if initiator is False:
+ executor = None
+ else:
+ executor = getattr(args[0], '_sa_adapter', None)
+
+ if before and executor:
+ getattr(executor, before)(args[argument], initiator)
+
+ if not after or not executor:
+ return method(*args, **kw)
+ else:
+ res = method(*args, **kw)
+ if res is not None:
+ getattr(executor, after)(res, initiator)
+ return res
+ else:
+ def wrapper(*args, **kw):
+ if before:
+ vals = inspect.getargvalues(inspect.currentframe())
+ if argument in kw:
+ value = kw[argument]
+ else:
+ positional = inspect.getargspec(method)[0]
+                    try:
+                        pos = positional.index(argument)
+                    except ValueError:
+                        # list.index raises rather than returning -1
+                        raise exceptions.ArgumentError('Missing argument %s' %
+                                                       argument)
+                    value = args[pos]
+
+ initiator = kw.pop('_sa_initiator', None)
+ if initiator is False:
+ executor = None
+ else:
+ executor = getattr(args[0], '_sa_adapter', None)
+
+ if before and executor:
+ getattr(executor, before)(value, initiator)
+
+ if not after or not executor:
+ return method(*args, **kw)
+ else:
+ res = method(*args, **kw)
+ if res is not None:
+ getattr(executor, after)(res, initiator)
+ return res
+ try:
+ wrapper._sa_instrumented = True
+ wrapper.__name__ = method.__name__
+ wrapper.__doc__ = method.__doc__
+ except:
+ pass
+ return wrapper
+
+def __set(collection, item, _sa_initiator=None):
+ """Run set events, may eventually be inlined into decorators."""
+
+ if _sa_initiator is not False and item is not None:
+ executor = getattr(collection, '_sa_adapter', None)
+ if executor:
+ getattr(executor, 'fire_append_event')(item, _sa_initiator)
+
+def __del(collection, item, _sa_initiator=None):
+ """Run del events, may eventually be inlined into decorators."""
+
+ if _sa_initiator is not False and item is not None:
+ executor = getattr(collection, '_sa_adapter', None)
+ if executor:
+ getattr(executor, 'fire_remove_event')(item, _sa_initiator)
+
+def _list_decorators():
+ """Hand-turned instrumentation wrappers that can decorate any list-like
+ class."""
+
+ def _tidy(fn):
+ setattr(fn, '_sa_instrumented', True)
+ fn.__doc__ = getattr(getattr(list, fn.__name__), '__doc__')
+
+ def append(fn):
+ def append(self, item, _sa_initiator=None):
+ # FIXME: example of fully inlining __set and adapter.fire
+ # for critical path
+ if _sa_initiator is not False and item is not None:
+ executor = getattr(self, '_sa_adapter', None)
+ if executor:
+ executor.attr.fire_append_event(executor._owner(),
+ item, _sa_initiator)
+ fn(self, item)
+ _tidy(append)
+ return append
+
+ def remove(fn):
+ def remove(self, value, _sa_initiator=None):
+ fn(self, value)
+ __del(self, value, _sa_initiator)
+ _tidy(remove)
+ return remove
+
+ def insert(fn):
+ def insert(self, index, value):
+ __set(self, value)
+ fn(self, index, value)
+ _tidy(insert)
+ return insert
+
+ def __setitem__(fn):
+ def __setitem__(self, index, value):
+ if not isinstance(index, slice):
+ existing = self[index]
+ if existing is not None:
+ __del(self, existing)
+ __set(self, value)
+ fn(self, index, value)
+ else:
+ # slice assignment requires __delitem__, insert, __len__
+ if index.stop is None:
+                    stop = len(self)
+ elif index.stop < 0:
+ stop = len(self) + index.stop
+ else:
+ stop = index.stop
+ step = index.step or 1
+ rng = range(index.start or 0, stop, step)
+ if step == 1:
+ for i in rng:
+ del self[index.start]
+ i = index.start
+ for item in value:
+ self.insert(i, item)
+ i += 1
+ else:
+ if len(value) != len(rng):
+ raise ValueError(
+ "attempt to assign sequence of size %s to "
+ "extended slice of size %s" % (len(value),
+ len(rng)))
+ for i, item in zip(rng, value):
+ self.__setitem__(i, item)
+ _tidy(__setitem__)
+ return __setitem__
+
+ def __delitem__(fn):
+ def __delitem__(self, index):
+ if not isinstance(index, slice):
+ item = self[index]
+ __del(self, item)
+ fn(self, index)
+ else:
+ # slice deletion requires __getslice__ and a slice-groking
+ # __getitem__ for stepped deletion
+ # note: not breaking this into atomic dels
+ for item in self[index]:
+ __del(self, item)
+ fn(self, index)
+ _tidy(__delitem__)
+ return __delitem__
+
+ def __setslice__(fn):
+ def __setslice__(self, start, end, values):
+ for value in self[start:end]:
+ __del(self, value)
+ for value in values:
+ __set(self, value)
+ fn(self, start, end, values)
+ _tidy(__setslice__)
+ return __setslice__
+
+ def __delslice__(fn):
+ def __delslice__(self, start, end):
+ for value in self[start:end]:
+ __del(self, value)
+ fn(self, start, end)
+ _tidy(__delslice__)
+ return __delslice__
+
+ def extend(fn):
+ def extend(self, iterable):
+ for value in iterable:
+ self.append(value)
+ _tidy(extend)
+ return extend
+
+ def pop(fn):
+ def pop(self, index=-1):
+ item = fn(self, index)
+ __del(self, item)
+ return item
+ _tidy(pop)
+ return pop
+
+ l = locals().copy()
+ l.pop('_tidy')
+ return l
+
+def _dict_decorators():
+ """Hand-turned instrumentation wrappers that can decorate any dict-like
+ mapping class."""
+
+ def _tidy(fn):
+ setattr(fn, '_sa_instrumented', True)
+ fn.__doc__ = getattr(getattr(dict, fn.__name__), '__doc__')
+
+ Unspecified=object()
+
+ def __setitem__(fn):
+ def __setitem__(self, key, value, _sa_initiator=None):
+ if key in self:
+ __del(self, self[key], _sa_initiator)
+ __set(self, value, _sa_initiator)
+ fn(self, key, value)
+ _tidy(__setitem__)
+ return __setitem__
+
+ def __delitem__(fn):
+ def __delitem__(self, key, _sa_initiator=None):
+ if key in self:
+ __del(self, self[key], _sa_initiator)
+ fn(self, key)
+ _tidy(__delitem__)
+ return __delitem__
+
+ def clear(fn):
+ def clear(self):
+ for key in self:
+ __del(self, self[key])
+ fn(self)
+ _tidy(clear)
+ return clear
+
+ def pop(fn):
+ def pop(self, key, default=Unspecified):
+ if key in self:
+ __del(self, self[key])
+ if default is Unspecified:
+ return fn(self, key)
+ else:
+ return fn(self, key, default)
+ _tidy(pop)
+ return pop
+
+ def popitem(fn):
+ def popitem(self):
+ item = fn(self)
+ __del(self, item[1])
+ return item
+ _tidy(popitem)
+ return popitem
+
+ def setdefault(fn):
+ def setdefault(self, key, default=None):
+ if key not in self:
+ self.__setitem__(key, default)
+ return default
+ else:
+ return self.__getitem__(key)
+ _tidy(setdefault)
+ return setdefault
+
+ if sys.version_info < (2, 4):
+ def update(fn):
+ def update(self, other):
+ for key in other.keys():
+ if not self.has_key(key) or self[key] is not other[key]:
+ self[key] = other[key]
+ _tidy(update)
+ return update
+ else:
+ def update(fn):
+ def update(self, __other=Unspecified, **kw):
+ if __other is not Unspecified:
+ if hasattr(__other, 'keys'):
+ for key in __other.keys():
+ if key not in self or self[key] is not __other[key]:
+ self[key] = __other[key]
+ else:
+ for key, value in __other:
+ if key not in self or self[key] is not value:
+ self[key] = value
+ for key in kw:
+ if key not in self or self[key] is not kw[key]:
+ self[key] = kw[key]
+ _tidy(update)
+ return update
+
+ l = locals().copy()
+ l.pop('_tidy')
+ l.pop('Unspecified')
+ return l
+
+def _set_decorators():
+ """Hand-turned instrumentation wrappers that can decorate any set-like
+    class."""
+
+ def _tidy(fn):
+ setattr(fn, '_sa_instrumented', True)
+ fn.__doc__ = getattr(getattr(set, fn.__name__), '__doc__')
+
+ Unspecified=object()
+
+ def add(fn):
+ def add(self, value, _sa_initiator=None):
+ __set(self, value, _sa_initiator)
+ fn(self, value)
+ _tidy(add)
+ return add
+
+ def discard(fn):
+ def discard(self, value, _sa_initiator=None):
+ if value in self:
+ __del(self, value, _sa_initiator)
+ fn(self, value)
+ _tidy(discard)
+ return discard
+
+ def remove(fn):
+ def remove(self, value, _sa_initiator=None):
+ if value in self:
+ __del(self, value, _sa_initiator)
+ fn(self, value)
+ _tidy(remove)
+ return remove
+
+ def pop(fn):
+ def pop(self):
+ item = fn(self)
+ __del(self, item)
+ return item
+ _tidy(pop)
+ return pop
+
+ def clear(fn):
+ def clear(self):
+ for item in list(self):
+ self.remove(item)
+ _tidy(clear)
+ return clear
+
+ def update(fn):
+ def update(self, value):
+ for item in value:
+ if item not in self:
+ self.add(item)
+ _tidy(update)
+ return update
+ __ior__ = update
+
+ def difference_update(fn):
+ def difference_update(self, value):
+ for item in value:
+ self.discard(item)
+ _tidy(difference_update)
+ return difference_update
+ __isub__ = difference_update
+
+ def intersection_update(fn):
+ def intersection_update(self, other):
+ want, have = self.intersection(other), sautil.Set(self)
+ remove, add = have - want, want - have
+
+ for item in remove:
+ self.remove(item)
+ for item in add:
+ self.add(item)
+ _tidy(intersection_update)
+ return intersection_update
+ __iand__ = intersection_update
+
+ def symmetric_difference_update(fn):
+ def symmetric_difference_update(self, other):
+ want, have = self.symmetric_difference(other), sautil.Set(self)
+ remove, add = have - want, want - have
+
+ for item in remove:
+ self.remove(item)
+ for item in add:
+ self.add(item)
+ _tidy(symmetric_difference_update)
+ return symmetric_difference_update
+ __ixor__ = symmetric_difference_update
+
+ l = locals().copy()
+ l.pop('_tidy')
+ l.pop('Unspecified')
+ return l
+
+
+class InstrumentedList(list):
+ """An instrumented version of the built-in list."""
+
+ __instrumentation__ = {
+ 'appender': 'append',
+ 'remover': 'remove',
+ 'iterator': '__iter__', }
+
+class InstrumentedSet(sautil.Set):
+ """An instrumented version of the built-in set (or Set)."""
+
+ __instrumentation__ = {
+ 'appender': 'add',
+ 'remover': 'remove',
+ 'iterator': '__iter__', }
+
+class InstrumentedDict(dict):
+ """An instrumented version of the built-in dict."""
+
+ __instrumentation__ = {
+ 'iterator': 'itervalues', }
+
+__canned_instrumentation = {
+ list: InstrumentedList,
+ sautil.Set: InstrumentedSet,
+ dict: InstrumentedDict,
+ }
+
+__interfaces = {
+ list: { 'appender': 'append',
+ 'remover': 'remove',
+ 'iterator': '__iter__',
+ '_decorators': _list_decorators(), },
+ sautil.Set: { 'appender': 'add',
+ 'remover': 'remove',
+ 'iterator': '__iter__',
+ '_decorators': _set_decorators(), },
+ # decorators are required for dicts and object collections.
+ dict: { 'iterator': 'itervalues',
+ '_decorators': _dict_decorators(), },
+ # < 0.4 compatible naming, deprecated- use decorators instead.
+ None: { }
+ }
+
+
+class MappedCollection(dict):
+ """A basic dictionary-based collection class.
+
+ Extends dict with the minimal bag semantics that collection classes require.
+ ``set`` and ``remove`` are implemented in terms of a keying function: any
+ callable that takes an object and returns an object for use as a dictionary
+ key.
+ """
+
+ def __init__(self, keyfunc):
+ """Create a new collection with keying provided by keyfunc.
+
+        keyfunc may be any callable that takes an object and
+ returns an object for use as a dictionary key.
+
+ The keyfunc will be called every time the ORM needs to add a member by
+ value-only (such as when loading instances from the database) or remove
+        a member.  The usual cautions about dictionary keying apply:
+ ``keyfunc(object)`` should return the same output for the life of the
+ collection. Keying based on mutable properties can result in
+ unreachable instances "lost" in the collection.
+ """
+ self.keyfunc = keyfunc
+
+ def set(self, value, _sa_initiator=None):
+ """Add an item to the collection, with a key provided by this instance's keyfunc."""
+
+ key = self.keyfunc(value)
+ self.__setitem__(key, value, _sa_initiator)
+ set = collection.internally_instrumented(set)
+ set = collection.appender(set)
+
+ def remove(self, value, _sa_initiator=None):
+ """Remove an item from the collection by value, consulting this instance's keyfunc for the key."""
+
+ key = self.keyfunc(value)
+ # Let self[key] raise if key is not in this collection
+ if self[key] != value:
+ raise exceptions.InvalidRequestError(
+ "Can not remove '%s': collection holds '%s' for key '%s'. "
+ "Possible cause: is the MappedCollection key function "
+ "based on mutable properties or properties that only obtain "
+ "values after flush?" %
+ (value, self[key], key))
+ self.__delitem__(key, _sa_initiator)
+ remove = collection.internally_instrumented(remove)
+ remove = collection.remover(remove)
diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py
index 54b043b32..c06db6963 100644
--- a/lib/sqlalchemy/orm/dependency.py
+++ b/lib/sqlalchemy/orm/dependency.py
@@ -6,8 +6,8 @@
"""Bridge the ``PropertyLoader`` (i.e. a ``relation()``) and the
-``UOWTransaction`` together to allow processing of scalar- and
-list-based dependencies at flush time.
+``UOWTransaction`` together to allow processing of relation()-based
+dependencies at flush time.
"""
from sqlalchemy.orm import sync
@@ -366,7 +366,7 @@ class ManyToManyDP(DependencyProcessor):
if len(secondary_delete):
secondary_delete.sort()
# TODO: precompile the delete/insert queries?
- statement = self.secondary.delete(sql.and_(*[c == sql.bindparam(c.key, type=c.type) for c in self.secondary.c if c.key in associationrow]))
+ statement = self.secondary.delete(sql.and_(*[c == sql.bindparam(c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow]))
result = connection.execute(statement, secondary_delete)
if result.supports_sane_rowcount() and result.rowcount != len(secondary_delete):
raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (result.rowcount, len(secondary_delete)))
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index a9a26b57f..aeb8a23fa 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -5,7 +5,205 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import util, logging
+from sqlalchemy import util, logging, sql
+
+# returned by a MapperExtension method to indicate a "do nothing" response
+EXT_PASS = object()
+
+class MapperExtension(object):
+ """Base implementation for an object that provides overriding
+ behavior to various Mapper functions. For each method in
+ MapperExtension, a result of EXT_PASS indicates the functionality
+ is not overridden.
+ """
+
+
+ def init_instance(self, mapper, class_, instance, args, kwargs):
+ return EXT_PASS
+
+ def init_failed(self, mapper, class_, instance, args, kwargs):
+ return EXT_PASS
+
+ def get_session(self):
+ """Retrieve a contextual Session instance with which to
+ register a new object.
+
+ Note: this is not called if a session is provided with the
+ `__init__` params (i.e. `_sa_session`).
+ """
+
+ return EXT_PASS
+
+ def load(self, query, *args, **kwargs):
+ """Override the `load` method of the Query object.
+
+ The return value of this method is used as the result of
+ ``query.load()`` if the value is anything other than EXT_PASS.
+ """
+
+ return EXT_PASS
+
+ def get(self, query, *args, **kwargs):
+ """Override the `get` method of the Query object.
+
+ The return value of this method is used as the result of
+ ``query.get()`` if the value is anything other than EXT_PASS.
+ """
+
+ return EXT_PASS
+
+ def get_by(self, query, *args, **kwargs):
+ """Override the `get_by` method of the Query object.
+
+ The return value of this method is used as the result of
+ ``query.get_by()`` if the value is anything other than
+ EXT_PASS.
+
+ DEPRECATED.
+ """
+
+ return EXT_PASS
+
+ def select_by(self, query, *args, **kwargs):
+ """Override the `select_by` method of the Query object.
+
+ The return value of this method is used as the result of
+ ``query.select_by()`` if the value is anything other than
+ EXT_PASS.
+
+ DEPRECATED.
+ """
+
+ return EXT_PASS
+
+ def select(self, query, *args, **kwargs):
+ """Override the `select` method of the Query object.
+
+ The return value of this method is used as the result of
+ ``query.select()`` if the value is anything other than
+ EXT_PASS.
+
+ DEPRECATED.
+ """
+
+ return EXT_PASS
+
+
+ def translate_row(self, mapper, context, row):
+ """Perform pre-processing on the given result row and return a
+ new row instance.
+
+ This is called as the very first step in the ``_instance()``
+ method.
+ """
+
+ return EXT_PASS
+
+ def create_instance(self, mapper, selectcontext, row, class_):
+ """Receive a row when a new object instance is about to be
+ created from that row.
+
+ The method can choose to create the instance itself, or it can
+ return None to indicate normal object creation should take
+ place.
+
+ mapper
+ The mapper doing the operation
+
+ selectcontext
+ SelectionContext corresponding to the instances() call
+
+ row
+ The result row from the database
+
+ class\_
+ The class we are mapping.
+ """
+
+ return EXT_PASS
+
+ def append_result(self, mapper, selectcontext, row, instance, result, **flags):
+ """Receive an object instance before that instance is appended
+ to a result list.
+
+ If this method returns EXT_PASS, result appending will proceed
+        normally. If this method returns any other value or None,
+ result appending will not proceed for this instance, giving
+ this extension an opportunity to do the appending itself, if
+ desired.
+
+ mapper
+ The mapper doing the operation.
+
+ selectcontext
+ SelectionContext corresponding to the instances() call.
+
+ row
+ The result row from the database.
+
+ instance
+ The object instance to be appended to the result.
+
+ result
+ List to which results are being appended.
+
+ \**flags
+ extra information about the row, same as criterion in
+ `create_row_processor()` method of [sqlalchemy.orm.interfaces#MapperProperty]
+ """
+
+ return EXT_PASS
+
+ def populate_instance(self, mapper, selectcontext, row, instance, **flags):
+ """Receive a newly-created instance before that instance has
+ its attributes populated.
+
+ The normal population of attributes is according to each
+ attribute's corresponding MapperProperty (which includes
+ column-based attributes as well as relationships to other
+ classes). If this method returns EXT_PASS, instance
+ population will proceed normally. If any other value or None
+ is returned, instance population will not proceed, giving this
+ extension an opportunity to populate the instance itself, if
+ desired.
+ """
+
+ return EXT_PASS
+
+ def before_insert(self, mapper, connection, instance):
+ """Receive an object instance before that instance is INSERTed
+ into its table.
+
+ This is a good place to set up primary key values and such
+ that aren't handled otherwise.
+ """
+
+ return EXT_PASS
+
+ def before_update(self, mapper, connection, instance):
+ """Receive an object instance before that instance is UPDATEed."""
+
+ return EXT_PASS
+
+ def after_update(self, mapper, connection, instance):
+ """Receive an object instance after that instance is UPDATEed."""
+
+ return EXT_PASS
+
+ def after_insert(self, mapper, connection, instance):
+ """Receive an object instance after that instance is INSERTed."""
+
+ return EXT_PASS
+
+ def before_delete(self, mapper, connection, instance):
+ """Receive an object instance before that instance is DELETEed."""
+
+ return EXT_PASS
+
+ def after_delete(self, mapper, connection, instance):
+ """Receive an object instance after that instance is DELETEed."""
+
+ return EXT_PASS
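A minimal sketch of subclassing the relocated MapperExtension; ``AuditExt`` and the ``created_by`` attribute are hypothetical. Any method left returning EXT_PASS keeps the mapper's default behavior:

    from sqlalchemy.orm.interfaces import MapperExtension, EXT_PASS

    class AuditExt(MapperExtension):
        def before_insert(self, mapper, connection, instance):
            # a place to set up values "not handled otherwise", per the
            # before_insert docstring above
            if getattr(instance, 'created_by', None) is None:
                instance.created_by = 'system'
            return EXT_PASS

Such an extension would be installed via ``mapper(..., extension=AuditExt())``.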
class MapperProperty(object):
"""Manage the relationship of a ``Mapper`` to a single class
@@ -15,22 +213,61 @@ class MapperProperty(object):
"""
def setup(self, querycontext, **kwargs):
- """Called when a statement is being constructed."""
+ """Called by Query for the purposes of constructing a SQL statement.
+
+ Each MapperProperty associated with the target mapper processes the
+ statement referenced by the query context, adding columns and/or
+ criterion as appropriate.
+ """
pass
- def execute(self, selectcontext, instance, row, identitykey, isnew):
- """Called when the mapper receives a row.
-
- `instance` is the parent instance corresponding to the `row`.
+ def create_row_processor(self, selectcontext, mapper, row):
+ """return a 2-tuple consiting of a row processing function and an instance post-processing function.
+
+        Input arguments are the query.SelectionContext and the *first*
+        applicable row of a result set obtained within query.Query.instances();
+        this is called only the first time a particular mapper.populate_instance()
+        is invoked for the overall result.
+
+ The settings contained within the SelectionContext as well as the columns present
+ in the row (which will be the same columns present in all rows) are used to determine
+ the behavior of the returned callables. The callables will then be used to process
+ all rows and to post-process all instances, respectively.
+
+        The callables are of the following form::
+
+ def execute(instance, row, **flags):
+ # process incoming instance and given row.
+ # flags is a dictionary containing at least the following attributes:
+ # isnew - indicates if the instance was newly created as a result of reading this row
+ # instancekey - identity key of the instance
+ # optional attribute:
+ # ispostselect - indicates if this row resulted from a 'post' select of additional tables/columns
+
+ def post_execute(instance, **flags):
+ # process instance after all result rows have been processed. this
+ # function should be used to issue additional selections in order to
+ # eagerly load additional properties.
+
+ return (execute, post_execute)
+
+        Either tuple value can also be ``None``, in which case no function is called.
+
"""
-
+
raise NotImplementedError()
-
+
def cascade_iterator(self, type, object, recursive=None, halt_on=None):
+ """return an iterator of objects which are child objects of the given object,
+ as attached to the attribute corresponding to this MapperProperty."""
+
return []
def cascade_callable(self, type, object, callable_, recursive=None, halt_on=None):
+ """run the given callable across all objects which are child objects of
+ the given object, as attached to the attribute corresponding to this MapperProperty."""
+
return []
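To make the new contract concrete, here is a hedged sketch of a hypothetical MapperProperty whose row processor simply counts the rows seen per instance; note that it returns callables rather than handling each row itself:

    from sqlalchemy.orm.interfaces import MapperProperty

    class CounterProperty(MapperProperty):
        def create_row_processor(self, selectcontext, mapper, row):
            def execute(instance, row, **flags):
                # 'isnew' is set the first time this instance is seen
                if flags['isnew']:
                    instance.row_count = 0
                instance.row_count += 1
            # no post-processing pass is needed for this property
            return (execute, None)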
def get_criterion(self, query, key, value):
@@ -60,7 +297,11 @@ class MapperProperty(object):
self.do_init()
def do_init(self):
- """Template method for subclasses."""
+ """Perform subclass-specific initialization steps.
+
+ This is a *template* method called by the
+ ``MapperProperty`` object's init() method."""
+
pass
def register_dependencies(self, *args, **kwargs):
@@ -90,59 +331,81 @@ class MapperProperty(object):
raise NotImplementedError()
- def compare(self, value):
+ def compare(self, operator, value):
"""Return a compare operation for the columns represented by
this ``MapperProperty`` to the given value, which may be a
- column value or an instance.
+ column value or an instance. 'operator' is an operator from
+ the operators module, or from sql.Comparator.
+
+ By default uses the PropComparator attached to this MapperProperty
+ under the attribute name "comparator".
"""
- raise NotImplementedError()
+ return operator(self.comparator, value)
-class SynonymProperty(MapperProperty):
- def __init__(self, name, proxy=False):
- self.name = name
- self.proxy = proxy
-
- def setup(self, querycontext, **kwargs):
- pass
-
- def execute(self, selectcontext, instance, row, identitykey, isnew):
- pass
-
- def do_init(self):
- if not self.proxy:
- return
- class SynonymProp(object):
- def __set__(s, obj, value):
- setattr(obj, self.name, value)
- def __delete__(s, obj):
- delattr(obj, self.name)
- def __get__(s, obj, owner):
- if obj is None:
- return s
- return getattr(obj, self.name)
- setattr(self.parent.class_, self.key, SynonymProp())
+class PropComparator(sql.ColumnOperators):
+ """defines comparison operations for MapperProperty objects"""
+
+ def expression_element(self):
+ return self.clause_element()
+
+ def contains_op(a, b):
+ return a.contains(b)
+ contains_op = staticmethod(contains_op)
+
+ def any_op(a, b, **kwargs):
+ return a.any(b, **kwargs)
+ any_op = staticmethod(any_op)
+
+ def has_op(a, b, **kwargs):
+ return a.has(b, **kwargs)
+ has_op = staticmethod(has_op)
+
+ def __init__(self, prop):
+ self.prop = prop
+
+ def contains(self, other):
+ """return true if this collection contains other"""
+ return self.operate(PropComparator.contains_op, other)
+
+ def any(self, criterion=None, **kwargs):
+ """return true if this collection contains any member that meets the given criterion.
+
+ criterion
+ an optional ClauseElement formulated against the member class' table or attributes.
+
+ \**kwargs
+ key/value pairs corresponding to member class attribute names which will be compared
+ via equality to the corresponding values.
+ """
- def merge(self, session, source, dest, _recursive):
- pass
+ return self.operate(PropComparator.any_op, criterion, **kwargs)
+
+ def has(self, criterion=None, **kwargs):
+ """return true if this element references a member which meets the given criterion.
+
+
+ criterion
+ an optional ClauseElement formulated against the member class' table or attributes.
+
+ \**kwargs
+ key/value pairs corresponding to member class attribute names which will be compared
+ via equality to the corresponding values.
+ """
+ return self.operate(PropComparator.has_op, criterion, **kwargs)
+
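A hedged usage sketch for these comparator operations; ``User``, ``Address``, their mappings, and the ``session`` are assumed to exist and are not part of this changeset:

    # collection membership
    session.query(User).filter(User.addresses.contains(some_address))
    # EXISTS-style criterion against collection members
    session.query(User).filter(User.addresses.any(Address.email == 'ed@example.com'))
    # scalar reference, using the kwargs form described in has()
    session.query(Address).filter(Address.user.has(name='ed'))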
class StrategizedProperty(MapperProperty):
"""A MapperProperty which uses selectable strategies to affect
loading behavior.
- There is a single default strategy selected, and alternate
- strategies can be selected at selection time through the usage of
- ``StrategizedOption`` objects.
+    There is a single default strategy selected. Alternate
+ strategies can be selected at Query time through the usage of
+ ``StrategizedOption`` objects via the Query.options() method.
"""
def _get_context_strategy(self, context):
- try:
- return context.attributes[id(self)]
- except KeyError:
- # cache the located strategy per StrategizedProperty in the given context for faster re-lookup
- ctx_strategy = self._get_strategy(context.attributes.get((LoaderStrategy, self), self.strategy.__class__))
- context.attributes[id(self)] = ctx_strategy
- return ctx_strategy
+ return self._get_strategy(context.attributes.get(("loaderstrategy", self), self.strategy.__class__))
def _get_strategy(self, cls):
try:
@@ -156,11 +419,10 @@ class StrategizedProperty(MapperProperty):
return strategy
def setup(self, querycontext, **kwargs):
-
self._get_context_strategy(querycontext).setup_query(querycontext, **kwargs)
- def execute(self, selectcontext, instance, row, identitykey, isnew):
- self._get_context_strategy(selectcontext).process_row(selectcontext, instance, row, identitykey, isnew)
+ def create_row_processor(self, selectcontext, mapper, row):
+ return self._get_context_strategy(selectcontext).create_row_processor(selectcontext, mapper, row)
def do_init(self):
self._all_strategies = {}
@@ -170,6 +432,31 @@ class StrategizedProperty(MapperProperty):
if self.is_primary():
self.strategy.init_class_attribute()
+class LoaderStack(object):
+ """a stack object used during load operations to track the
+ current position among a chain of mappers to eager loaders."""
+
+ def __init__(self):
+ self.__stack = []
+
+ def push_property(self, key):
+ self.__stack.append(key)
+
+ def push_mapper(self, mapper):
+ self.__stack.append(mapper.base_mapper())
+
+ def pop(self):
+ self.__stack.pop()
+
+ def snapshot(self):
+ """return an 'snapshot' of this stack.
+
+ this is a tuple form of the stack which can be used as a hash key."""
+ return tuple(self.__stack)
+
+ def __str__(self):
+ return "->".join([str(s) for s in self.__stack])
+
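A small illustration of LoaderStack in isolation; the snapshot is a plain tuple and therefore usable as a dictionary key for per-path state:

    from sqlalchemy.orm.interfaces import LoaderStack

    stack = LoaderStack()
    stack.push_property('addresses')
    key = stack.snapshot()              # ('addresses',)
    per_path_state = {key: 'some cached value'}
    stack.pop()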
class OperationContext(object):
"""Serve as a context during a query construction or instance
loading operation.
@@ -200,6 +487,44 @@ class MapperOption(object):
def process_query(self, query):
pass
+class ExtensionOption(MapperOption):
+ """a MapperOption that applies a MapperExtension to a query operation."""
+
+ def __init__(self, ext):
+ self.ext = ext
+
+ def process_query(self, query):
+ query._extension = query._extension.copy()
+ query._extension.append(self.ext)
+
+class SynonymProperty(MapperProperty):
+ def __init__(self, name, proxy=False):
+ self.name = name
+ self.proxy = proxy
+
+ def setup(self, querycontext, **kwargs):
+ pass
+
+ def create_row_processor(self, selectcontext, mapper, row):
+ return (None, None)
+
+ def do_init(self):
+ if not self.proxy:
+ return
+ class SynonymProp(object):
+ def __set__(s, obj, value):
+ setattr(obj, self.name, value)
+ def __delete__(s, obj):
+ delattr(obj, self.name)
+ def __get__(s, obj, owner):
+ if obj is None:
+ return s
+ return getattr(obj, self.name)
+ setattr(self.parent.class_, self.key, SynonymProp())
+
+ def merge(self, session, source, dest, _recursive):
+ pass
+
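SynonymProperty instances are normally created through the ``orm.synonym()`` helper rather than constructed directly; a hedged sketch in which ``User`` and ``users_table`` are hypothetical:

    from sqlalchemy.orm import mapper, synonym

    mapper(User, users_table, properties={
        # with proxy=True, User.username reads and writes User.name
        'username': synonym('name', proxy=True),
    })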
class PropertyOption(MapperOption):
"""A MapperOption that is applied to a property off the mapper or
one of its child mappers, identified by a dot-separated key.
@@ -208,45 +533,72 @@ class PropertyOption(MapperOption):
def __init__(self, key):
self.key = key
- def process_query_property(self, context, property):
+ def process_query_property(self, context, properties):
pass
- def process_selection_property(self, context, property):
+ def process_selection_property(self, context, properties):
pass
def process_query_context(self, context):
- self.process_query_property(context, self._get_property(context))
+ self.process_query_property(context, self._get_properties(context))
def process_selection_context(self, context):
- self.process_selection_property(context, self._get_property(context))
+ self.process_selection_property(context, self._get_properties(context))
- def _get_property(self, context):
+ def _get_properties(self, context):
try:
- prop = self.__prop
+ l = self.__prop
except AttributeError:
+ l = []
mapper = context.mapper
for token in self.key.split('.'):
- prop = mapper.props[token]
- if isinstance(prop, SynonymProperty):
- prop = mapper.props[prop.name]
+ prop = mapper.get_property(token, resolve_synonyms=True)
+ l.append(prop)
mapper = getattr(prop, 'mapper', None)
- self.__prop = prop
- return prop
+ self.__prop = l
+ return l
PropertyOption.logger = logging.class_logger(PropertyOption)
+
+class AttributeExtension(object):
+ """An abstract class which specifies `append`, `delete`, and `set`
+ event handlers to be attached to an object property.
+ """
+
+ def append(self, obj, child, initiator):
+ pass
+
+ def remove(self, obj, child, initiator):
+ pass
+
+ def set(self, obj, child, oldchild, initiator):
+ pass
+
+
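A minimal sketch of an AttributeExtension subclass; the logging behavior is illustrative only. Handlers not overridden remain no-ops:

    from sqlalchemy.orm.interfaces import AttributeExtension

    class AppendLogger(AttributeExtension):
        def append(self, obj, child, initiator):
            # fires whenever `child` is appended to the instrumented collection
            print 'appended %r to %r' % (child, obj)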
class StrategizedOption(PropertyOption):
"""A MapperOption that affects which LoaderStrategy will be used
for an operation by a StrategizedProperty.
"""
- def process_query_property(self, context, property):
+ def is_chained(self):
+ return False
+
+ def process_query_property(self, context, properties):
self.logger.debug("applying option to QueryContext, property key '%s'" % self.key)
- context.attributes[(LoaderStrategy, property)] = self.get_strategy_class()
+ if self.is_chained():
+ for prop in properties:
+ context.attributes[("loaderstrategy", prop)] = self.get_strategy_class()
+ else:
+ context.attributes[("loaderstrategy", properties[-1])] = self.get_strategy_class()
- def process_selection_property(self, context, property):
+ def process_selection_property(self, context, properties):
self.logger.debug("applying option to SelectionContext, property key '%s'" % self.key)
- context.attributes[(LoaderStrategy, property)] = self.get_strategy_class()
+ if self.is_chained():
+ for prop in properties:
+ context.attributes[("loaderstrategy", prop)] = self.get_strategy_class()
+ else:
+ context.attributes[("loaderstrategy", properties[-1])] = self.get_strategy_class()
def get_strategy_class(self):
raise NotImplementedError()
@@ -291,5 +643,13 @@ class LoaderStrategy(object):
def setup_query(self, context, **kwargs):
pass
- def process_row(self, selectcontext, instance, row, identitykey, isnew):
- pass
+ def create_row_processor(self, selectcontext, mapper, row):
+ """return row processing functions which fulfill the contract specified
+ by MapperProperty.create_row_processor.
+
+
+ StrategizedProperty delegates its create_row_processor method
+ directly to this method.
+ """
+
+ raise NotImplementedError()
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 375408926..76cc41289 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -4,14 +4,15 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import sql, schema, util, exceptions, logging
+from sqlalchemy import sql, util, exceptions, logging
from sqlalchemy import sql_util as sqlutil
from sqlalchemy.orm import util as mapperutil
+from sqlalchemy.orm.util import ExtensionCarrier
from sqlalchemy.orm import sync
-from sqlalchemy.orm.interfaces import MapperProperty, MapperOption, OperationContext, SynonymProperty
-import weakref, warnings
+from sqlalchemy.orm.interfaces import MapperProperty, EXT_PASS, MapperExtension, SynonymProperty
+import weakref, warnings, operator
-__all__ = ['Mapper', 'MapperExtension', 'class_mapper', 'object_mapper', 'EXT_PASS', 'mapper_registry', 'ExtensionOption']
+__all__ = ['Mapper', 'class_mapper', 'object_mapper', 'mapper_registry']
# a dictionary mapping classes to their primary mappers
mapper_registry = weakref.WeakKeyDictionary()
@@ -24,12 +25,13 @@ global_extensions = []
# column
NO_ATTRIBUTE = object()
-# returned by a MapperExtension method to indicate a "do nothing" response
-EXT_PASS = object()
-
# lock used to synchronize the "mapper compile" step
_COMPILE_MUTEX = util.threading.Lock()
+# initialize these two lazily
+attribute_manager = None
+ColumnProperty = None
+
class Mapper(object):
"""Define the correlation of class attributes to database table
columns.
@@ -55,6 +57,7 @@ class Mapper(object):
polymorphic_on=None,
_polymorphic_map=None,
polymorphic_identity=None,
+ polymorphic_fetch=None,
concrete=False,
select_table=None,
allow_null_pks=False,
@@ -62,126 +65,8 @@ class Mapper(object):
column_prefix=None):
"""Construct a new mapper.
- All arguments may be sent to the ``sqlalchemy.orm.mapper()``
- function where they are passed through to here.
-
- class\_
- The class to be mapped.
-
- local_table
- The table to which the class is mapped, or None if this
- mapper inherits from another mapper using concrete table
- inheritance.
-
- properties
- A dictionary mapping the string names of object attributes
- to ``MapperProperty`` instances, which define the
- persistence behavior of that attribute. Note that the
- columns in the mapped table are automatically converted into
- ``ColumnProperty`` instances based on the `key` property of
- each ``Column`` (although they can be overridden using this
- dictionary).
-
- primary_key
- A list of ``Column`` objects which define the *primary key*
- to be used against this mapper's selectable unit. This is
- normally simply the primary key of the `local_table`, but
- can be overridden here.
-
- non_primary
- Construct a ``Mapper`` that will define only the selection
- of instances, not their persistence.
-
- inherits
- Another ``Mapper`` for which this ``Mapper`` will have an
- inheritance relationship with.
-
- inherit_condition
- For joined table inheritance, a SQL expression (constructed
- ``ClauseElement``) which will define how the two tables are
- joined; defaults to a natural join between the two tables.
-
- extension
- A ``MapperExtension`` instance or list of
- ``MapperExtension`` instances which will be applied to all
- operations by this ``Mapper``.
-
- order_by
- A single ``Column`` or list of ``Columns`` for which
- selection operations should use as the default ordering for
- entities. Defaults to the OID/ROWID of the table if any, or
- the first primary key column of the table.
-
- allow_column_override
- If True, allows the usage of a ``relation()`` which has the
- same name as a column in the mapped table. The table column
- will no longer be mapped.
-
- entity_name
- A name to be associated with the `class`, to allow alternate
- mappings for a single class.
-
- always_refresh
- If True, all query operations for this mapped class will
- overwrite all data within object instances that already
- exist within the session, erasing any in-memory changes with
- whatever information was loaded from the database.
-
- version_id_col
- A ``Column`` which must have an integer type that will be
- used to keep a running *version id* of mapped entities in
- the database. this is used during save operations to ensure
- that no other thread or process has updated the instance
- during the lifetime of the entity, else a
- ``ConcurrentModificationError`` exception is thrown.
-
- polymorphic_on
- Used with mappers in an inheritance relationship, a ``Column``
- which will identify the class/mapper combination to be used
- with a particular row. requires the polymorphic_identity
- value to be set for all mappers in the inheritance
- hierarchy.
-
- _polymorphic_map
- Used internally to propigate the full map of polymorphic
- identifiers to surrogate mappers.
-
- polymorphic_identity
- A value which will be stored in the Column denoted by
- polymorphic_on, corresponding to the *class identity* of
- this mapper.
-
- concrete
- If True, indicates this mapper should use concrete table
- inheritance with its parent mapper.
-
- select_table
- A ``Table`` or (more commonly) ``Selectable`` which will be
- used to select instances of this mapper's class. usually
- used to provide polymorphic loading among several classes in
- an inheritance hierarchy.
-
- allow_null_pks
- Indicates that composite primary keys where one or more (but
- not all) columns contain NULL is a valid primary key.
- Primary keys which contain NULL values usually indicate that
- a result row does not contain an entity and should be
- skipped.
-
- batch
- Indicates that save operations of multiple entities can be
- batched together for efficiency. setting to False indicates
- that an instance will be fully saved before saving the next
- instance, which includes inserting/updating all table rows
- corresponding to the entity as well as calling all
- ``MapperExtension`` methods corresponding to the save
- operation.
-
- column_prefix
- A string which will be prepended to the `key` name of all
- Columns when creating column-based properties from the given
- Table. Does not affect explicitly specified column-based
- properties
+        Mappers are normally constructed via the [sqlalchemy.orm#mapper()]
+        function.  See that function's documentation for details.
"""
if not issubclass(class_, object):
@@ -227,6 +112,13 @@ class Mapper(object):
# indicates this Mapper should be used to construct the object instance for that row.
self.polymorphic_identity = polymorphic_identity
+ if polymorphic_fetch not in (None, 'union', 'select', 'deferred'):
+ raise exceptions.ArgumentError("Invalid option for 'polymorphic_fetch': '%s'" % polymorphic_fetch)
+ if polymorphic_fetch is None:
+ self.polymorphic_fetch = (self.select_table is None) and 'select' or 'union'
+ else:
+ self.polymorphic_fetch = polymorphic_fetch
+
# a dictionary of 'polymorphic identity' names, associating those names with
# Mappers that will be used to construct object instances upon a select operation.
if _polymorphic_map is None:
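A hedged sketch of the new ``polymorphic_fetch`` argument; ``Employee``/``Manager`` and their tables are hypothetical. Per the validation above, the accepted values are 'union', 'select' and 'deferred'; when unset, it defaults to 'union' if a ``select_table`` is supplied and 'select' otherwise:

    from sqlalchemy.orm import mapper

    mapper(Employee, employees, polymorphic_on=employees.c.type,
           polymorphic_identity='employee')
    mapper(Manager, managers, inherits=Employee,
           polymorphic_identity='manager',
           # defer loading of manager-only columns until accessed
           polymorphic_fetch='deferred')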
@@ -297,20 +189,8 @@ class Mapper(object):
else:
return False
- def _get_props(self):
- self.compile()
- return self.__props
-
- props = property(_get_props, doc="compiles this mapper if needed, and returns the "
- "dictionary of MapperProperty objects associated with this mapper."
- "(Deprecated; use get_property() and iterate_properties)")
-
def get_property(self, key, resolve_synonyms=False, raiseerr=True):
- """return MapperProperty with the given key.
-
- forwards compatible with 0.4.
- """
-
+ """return MapperProperty with the given key."""
self.compile()
prop = self.__props.get(key, None)
if resolve_synonyms:
@@ -319,10 +199,22 @@ class Mapper(object):
if prop is None and raiseerr:
raise exceptions.InvalidRequestError("Mapper '%s' has no property '%s'" % (str(self), key))
return prop
-
- iterate_properties = property(lambda self: self._get_props().itervalues(), doc="returns an iterator of all MapperProperty objects."
- " Forwards compatible with 0.4")
-
+
+ def iterate_properties(self):
+ self.compile()
+ return self.__props.itervalues()
+ iterate_properties = property(iterate_properties, doc="returns an iterator of all MapperProperty objects.")
+
+ def dispose(self):
+ attribute_manager.reset_class_managed(self.class_)
+ if hasattr(self.class_, 'c'):
+ del self.class_.c
+ if hasattr(self.class_, '__init__') and hasattr(self.class_.__init__, '_oldinit'):
+ if self.class_.__init__._oldinit is not None:
+ self.class_.__init__ = self.class_.__init__._oldinit
+ else:
+ delattr(self.class_, '__init__')
+
def compile(self):
"""Compile this mapper into its final internal format.
@@ -403,7 +295,7 @@ class Mapper(object):
for ext_obj in util.to_list(extension):
extlist.add(ext_obj)
- self.extension = _ExtensionCarrier()
+ self.extension = ExtensionCarrier()
for ext in extlist:
self.extension.append(ext)
@@ -452,8 +344,17 @@ class Mapper(object):
self.mapped_table = self.local_table
if self.polymorphic_identity is not None:
self.inherits._add_polymorphic_mapping(self.polymorphic_identity, self)
- if self.polymorphic_on is None and self.inherits.polymorphic_on is not None:
- self.polymorphic_on = self.mapped_table.corresponding_column(self.inherits.polymorphic_on, keys_ok=True, raiseerr=False)
+ if self.polymorphic_on is None:
+ if self.inherits.polymorphic_on is not None:
+ self.polymorphic_on = self.mapped_table.corresponding_column(self.inherits.polymorphic_on, keys_ok=True, raiseerr=False)
+ else:
+ raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity))
+
+ if self.polymorphic_identity is not None and not self.concrete:
+ self._identity_class = self.inherits._identity_class
+ else:
+ self._identity_class = self.class_
+
if self.order_by is False:
self.order_by = self.inherits.order_by
self.polymorphic_map = self.inherits.polymorphic_map
@@ -463,8 +364,11 @@ class Mapper(object):
self._synchronizer = None
self.mapped_table = self.local_table
if self.polymorphic_identity is not None:
+ if self.polymorphic_on is None:
+ raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity))
self._add_polymorphic_mapping(self.polymorphic_identity, self)
-
+ self._identity_class = self.class_
+
if self.mapped_table is None:
raise exceptions.ArgumentError("Mapper '%s' does not have a mapped_table specified. (Are you using the return value of table.create()? It no longer has a return value.)" % str(self))
@@ -503,39 +407,134 @@ class Mapper(object):
# may be a join or other construct
self.tables = sqlutil.TableFinder(self.mapped_table)
- # determine primary key columns, either passed in, or get them from our set of tables
+ # determine primary key columns
self.pks_by_table = {}
+
+ # go through all of our represented tables
+ # and assemble primary key columns
+ for t in self.tables + [self.mapped_table]:
+ try:
+ l = self.pks_by_table[t]
+ except KeyError:
+ l = self.pks_by_table.setdefault(t, util.OrderedSet())
+ for k in t.primary_key:
+ l.add(k)
+
if self.primary_key_argument is not None:
- # determine primary keys using user-given list of primary key columns as a guide
- #
- # TODO: this might not work very well for joined-table and/or polymorphic
- # inheritance mappers since local_table isnt taken into account nor is select_table
- # need to test custom primary key columns used with inheriting mappers
for k in self.primary_key_argument:
self.pks_by_table.setdefault(k.table, util.OrderedSet()).add(k)
- if k.table != self.mapped_table:
- # associate pk cols from subtables to the "main" table
- corr = self.mapped_table.corresponding_column(k, raiseerr=False)
- if corr is not None:
- self.pks_by_table.setdefault(self.mapped_table, util.OrderedSet()).add(corr)
- else:
- # no user-defined primary key columns - go through all of our represented tables
- # and assemble primary key columns
- for t in self.tables + [self.mapped_table]:
- try:
- l = self.pks_by_table[t]
- except KeyError:
- l = self.pks_by_table.setdefault(t, util.OrderedSet())
- for k in t.primary_key:
- #if k.key not in t.c and k._label not in t.c:
- # this is a condition that was occurring when table reflection was doubling up primary keys
- # that were overridden in the Table constructor
- # raise exceptions.AssertionError("Column " + str(k) + " not located in the column set of table " + str(t))
- l.add(k)
-
+
if len(self.pks_by_table[self.mapped_table]) == 0:
raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name))
- self.primary_key = self.pks_by_table[self.mapped_table]
+
+ if self.inherits is not None and not self.concrete and not self.primary_key_argument:
+ self.primary_key = self.inherits.primary_key
+ self._get_clause = self.inherits._get_clause
+ else:
+ # create the "primary_key" for this mapper. this will flatten "equivalent" primary key columns
+ # into one column, where "equivalent" means that one column references the other via foreign key, or
+ # multiple columns that all reference a common parent column. it will also resolve the column
+ # against the "mapped_table" of this mapper.
+ equivalent_columns = self._get_equivalent_columns()
+
+ primary_key = sql.ColumnSet()
+
+ for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]):
+ c = self.mapped_table.corresponding_column(col, raiseerr=False)
+ if c is None:
+ for cc in equivalent_columns[col]:
+ c = self.mapped_table.corresponding_column(cc, raiseerr=False)
+ if c is not None:
+ break
+ else:
+ raise exceptions.ArgumentError("Cant resolve column " + str(col))
+
+ # this step attempts to resolve the column to an equivalent which is not
+ # a foreign key elsewhere. this helps with joined table inheritance
+ # so that PKs are expressed in terms of the base table which is always
+ # present in the initial select
+ # TODO: this is a little hacky right now, the "tried" list is to prevent
+ # endless loops between cyclical FKs, try to make this cleaner/work better/etc.,
+ # perhaps via topological sort (pick the leftmost item)
+ tried = util.Set()
+ while True:
+ if not len(c.foreign_keys) or c in tried:
+ break
+ for cc in c.foreign_keys:
+ cc = cc.column
+ c2 = self.mapped_table.corresponding_column(cc, raiseerr=False)
+ if c2 is not None:
+ c = c2
+ tried.add(c)
+ break
+ else:
+ break
+ primary_key.add(c)
+
+ if len(primary_key) == 0:
+ raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name))
+
+ self.primary_key = primary_key
+ self.__log("Identified primary key columns: " + str(primary_key))
+
+ _get_clause = sql.and_()
+ for primary_key in self.primary_key:
+ _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type_=primary_key.type, unique=True))
+ self._get_clause = _get_clause
+
+ def _get_equivalent_columns(self):
+ """Create a map of all *equivalent* columns, based on
+ the determination of column pairs that are equated to
+ one another either by an established foreign key relationship
+ or by a joined-table inheritance join.
+
+ This is used to determine the minimal set of primary key
+ columns for the mapper, as well as when relating
+ columns to those of a polymorphic selectable (i.e. a UNION of
+ several mapped tables), as that selectable usually only contains
+ one column in its columns clause out of a group of several which
+ are equated to each other.
+
+ The resulting structure is a dictionary of columns mapped
+        to sets of equivalent columns, i.e.
+
+ {
+ tablea.col1:
+ set([tableb.col1, tablec.col1]),
+ tablea.col2:
+ set([tabled.col2])
+ }
+
+        This method is called repeatedly during the compilation process, as
+        the resulting dictionary contains more equivalents as more inheriting
+        mappers are compiled.  The repetition process may be open to some optimization.
+ """
+
+ result = {}
+ def visit_binary(binary):
+ if binary.operator == operator.eq:
+ if binary.left in result:
+ result[binary.left].add(binary.right)
+ else:
+ result[binary.left] = util.Set([binary.right])
+ if binary.right in result:
+ result[binary.right].add(binary.left)
+ else:
+ result[binary.right] = util.Set([binary.left])
+ vis = mapperutil.BinaryVisitor(visit_binary)
+
+ for mapper in self.base_mapper().polymorphic_iterator():
+ if mapper.inherit_condition is not None:
+ vis.traverse(mapper.inherit_condition)
+
+ for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]):
+ if not len(col.foreign_keys):
+ result.setdefault(col, util.Set()).add(col)
+ else:
+ for fk in col.foreign_keys:
+ result.setdefault(fk.column, util.Set()).add(col)
+
+ return result
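A worked illustration of the structure this method returns, assuming a hypothetical joined-table inheritance pair in which ``managers.person_id`` references ``people.id``:

    from sqlalchemy import Table, Column, Integer, ForeignKey, MetaData

    meta = MetaData()
    people = Table('people', meta,
                   Column('id', Integer, primary_key=True))
    managers = Table('managers', meta,
                     Column('person_id', Integer, ForeignKey('people.id'),
                            primary_key=True))

    # the equivalence map for a mapper spanning these tables would
    # conceptually contain:
    #   {people.c.id:          set([managers.c.person_id]),
    #    managers.c.person_id: set([people.c.id])}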
def _compile_properties(self):
"""Inspect the properties dictionary sent to the Mapper's
@@ -548,7 +547,7 @@ class Mapper(object):
"""
# object attribute names mapped to MapperProperty objects
- self.__props = {}
+ self.__props = util.OrderedDict()
# table columns mapped to lists of MapperProperty objects
# using a list allows a single column to be defined as
@@ -574,7 +573,7 @@ class Mapper(object):
self.columns[column.key] = self.select_table.corresponding_column(column, keys_ok=True, raiseerr=True)
column_key = (self.column_prefix or '') + column.key
- prop = self.__props.get(column_key, None)
+ prop = self.__props.get(column.key, None)
if prop is None:
prop = ColumnProperty(column)
self.__props[column_key] = prop
@@ -582,7 +581,7 @@ class Mapper(object):
self.__log("adding ColumnProperty %s" % (column_key))
elif isinstance(prop, ColumnProperty):
if prop.parent is not self:
- prop = ColumnProperty(deferred=prop.deferred, group=prop.group, *prop.columns)
+ prop = prop.copy()
prop.set_parent(self)
self.__props[column_key] = prop
if column in self.primary_key and prop.columns[-1] in self.primary_key:
@@ -597,8 +596,7 @@ class Mapper(object):
# its a ColumnProperty - match the ultimate table columns
# back to the property
- proplist = self.columntoproperty.setdefault(column, [])
- proplist.append(prop)
+ self.columntoproperty.setdefault(column, []).append(prop)
def _initialize_properties(self):
@@ -660,66 +658,43 @@ class Mapper(object):
attribute_manager.reset_class_managed(self.class_)
oldinit = self.class_.__init__
- def init(self, *args, **kwargs):
- entity_name = kwargs.pop('_sa_entity_name', None)
- mapper = mapper_registry.get(ClassKey(self.__class__, entity_name))
- if mapper is not None:
- mapper = mapper.compile()
-
- # this gets the AttributeManager to do some pre-initialization,
- # in order to save on KeyErrors later on
- attribute_manager.init_attr(self)
-
- if kwargs.has_key('_sa_session'):
- session = kwargs.pop('_sa_session')
- else:
- # works for whatever mapper the class is associated with
- if mapper is not None:
- session = mapper.extension.get_session()
- if session is EXT_PASS:
- session = None
- else:
- session = None
- # if a session was found, either via _sa_session or via mapper extension,
- # and we have found a mapper, save() this instance to the session, and give it an associated entity_name.
- # otherwise, this instance will not have a session or mapper association until it is
- # save()d to some session.
- if session is not None and mapper is not None:
- self._entity_name = entity_name
- session._register_pending(self)
+ def init(instance, *args, **kwargs):
+ self.compile()
+ self.extension.init_instance(self, self.class_, instance, args, kwargs)
if oldinit is not None:
try:
- oldinit(self, *args, **kwargs)
+ oldinit(instance, *args, **kwargs)
except:
- def go():
- if session is not None:
- session.expunge(self)
- # convert expunge() exceptions to warnings
- util.warn_exception(go)
+            # call init_failed, but suppress its exceptions into warnings so that
+            # the original __init__ exception is raised
+ util.warn_exception(self.extension.init_failed, self, self.class_, instance, args, kwargs)
raise
-
- # override oldinit, insuring that its not already a Mapper-decorated init method
- if oldinit is None or not hasattr(oldinit, '_sa_mapper_init'):
- init._sa_mapper_init = True
+
+    # override oldinit, ensuring that it's not already a Mapper-decorated init method
+ if oldinit is None or not hasattr(oldinit, '_oldinit'):
try:
init.__name__ = oldinit.__name__
init.__doc__ = oldinit.__doc__
except:
# cant set __name__ in py 2.3 !
pass
+ init._oldinit = oldinit
self.class_.__init__ = init
+
_COMPILE_MUTEX.acquire()
try:
mapper_registry[self.class_key] = self
finally:
_COMPILE_MUTEX.release()
+
if self.entity_name is None:
self.class_.c = self.c
def base_mapper(self):
"""Return the ultimate base mapper in an inheritance chain."""
+ # TODO: calculate this at mapper setup time
if self.inherits is not None:
return self.inherits.base_mapper()
else:
@@ -759,43 +734,6 @@ class Mapper(object):
for m in mapper.polymorphic_iterator():
yield m
- def _get_inherited_column_equivalents(self):
- """Return a map of all *equivalent* columns, based on
- traversing the full set of inherit_conditions across all
- inheriting mappers and determining column pairs that are
- equated to one another.
-
- This is used when relating columns to those of a polymorphic
- selectable, as the selectable usually only contains one of two (or more)
- columns that are equated to one another.
-
- The resulting structure is a dictionary of columns mapped
- to lists of equivalent columns, i.e.
-
- {
- tablea.col1:
- [tableb.col1, tablec.col1],
- tablea.col2:
- [tabled.col2]
- }
- """
-
- result = {}
- def visit_binary(binary):
- if binary.operator == '=':
- if binary.left in result:
- result[binary.left].append(binary.right)
- else:
- result[binary.left] = [binary.right]
- if binary.right in result:
- result[binary.right].append(binary.left)
- else:
- result[binary.right] = [binary.left]
- vis = mapperutil.BinaryVisitor(visit_binary)
- for mapper in self.base_mapper().polymorphic_iterator():
- if mapper.inherit_condition is not None:
- vis.traverse(mapper.inherit_condition)
- return result
def add_properties(self, dict_of_properties):
"""Add the given dictionary of properties to this mapper,
@@ -947,7 +885,7 @@ class Mapper(object):
dictionary corresponding result-set ``ColumnElement``
instances to their values within a row.
"""
- return (self.class_, tuple([row[column] for column in self.pks_by_table[self.mapped_table]]), self.entity_name)
+ return (self._identity_class, tuple([row[column] for column in self.primary_key]), self.entity_name)
def identity_key_from_primary_key(self, primary_key):
"""Return an identity-map key for use in storing/retrieving an
@@ -956,7 +894,7 @@ class Mapper(object):
primary_key
A list of values indicating the identifier.
"""
- return (self.class_, tuple(util.to_list(primary_key)), self.entity_name)
+ return (self._identity_class, tuple(util.to_list(primary_key)), self.entity_name)
def identity_key_from_instance(self, instance):
"""Return the identity key for the given instance, based on
@@ -972,7 +910,7 @@ class Mapper(object):
instance.
"""
- return [self.get_attr_by_column(instance, column) for column in self.pks_by_table[self.mapped_table]]
+ return [self.get_attr_by_column(instance, column) for column in self.primary_key]
def canload(self, instance):
"""return true if this mapper is capable of loading the given instance"""
@@ -981,21 +919,6 @@ class Mapper(object):
else:
return instance.__class__ is self.class_
- def instance_key(self, instance):
- """Deprecated. A synonym for `identity_key_from_instance`."""
-
- return self.identity_key_from_instance(instance)
-
- def identity_key(self, primary_key):
- """Deprecated. A synonym for `identity_key_from_primary_key`."""
-
- return self.identity_key_from_primary_key(primary_key)
-
- def identity(self, instance):
- """Deprecated. A synoynm for `primary_key_from_instance`."""
-
- return self.primary_key_from_instance(instance)
-
def _getpropbycolumn(self, column, raiseerror=True):
try:
prop = self.columntoproperty[column]
@@ -1017,13 +940,12 @@ class Mapper(object):
prop = self._getpropbycolumn(column, raiseerror)
if prop is None:
return NO_ATTRIBUTE
- #print "get column attribute '%s' from instance %s" % (column.key, mapperutil.instance_str(obj))
- return prop.getattr(obj)
+ return prop.getattr(obj, column)
def set_attr_by_column(self, obj, column, value):
"""Set the value of an instance attribute using a Column as the key."""
- self.columntoproperty[column][0].setattr(obj, value)
+ self.columntoproperty[column][0].setattr(obj, value, column)
def save_obj(self, objects, uowtransaction, postupdate=False, post_update_cols=None, single=False):
"""Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects.
@@ -1048,10 +970,15 @@ class Mapper(object):
self.save_obj([obj], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True)
return
- connection = uowtransaction.transaction.connection(self)
-
+ if 'connection_callable' in uowtransaction.mapper_flush_opts:
+ connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
+ tups = [(obj, connection_callable(self, obj)) for obj in objects]
+ else:
+ connection = uowtransaction.transaction.connection(self)
+ tups = [(obj, connection) for obj in objects]
+
if not postupdate:
- for obj in objects:
+ for obj, connection in tups:
if not has_identity(obj):
for mapper in object_mapper(obj).iterate_to_root():
mapper.extension.before_insert(mapper, connection, obj)
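A hedged sketch of the ``connection_callable`` flush option consumed here, which is the hook the new horizontal sharding support (orm/shard.py) builds on. The engines and the ``region`` attribute are illustrative, and installing the option into ``mapper_flush_opts`` is done elsewhere (e.g. by the sharded session):

    from sqlalchemy import create_engine

    east_engine = create_engine('sqlite://')   # illustrative shard engines
    west_engine = create_engine('sqlite://')

    def connection_for(mapper, instance):
        # route each instance to a shard based on one of its attributes;
        # the returned Connection executes that instance's statements
        if getattr(instance, 'region', None) == 'east':
            return east_engine.contextual_connect()
        return west_engine.contextual_connect()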
@@ -1059,12 +986,12 @@ class Mapper(object):
for mapper in object_mapper(obj).iterate_to_root():
mapper.extension.before_update(mapper, connection, obj)
- for obj in objects:
+ for obj, connection in tups:
# detect if we have a "pending" instance (i.e. has no instance_key attached to it),
# and another instance with the same identity key already exists as persistent. convert to an
# UPDATE if so.
mapper = object_mapper(obj)
- instance_key = mapper.instance_key(obj)
+ instance_key = mapper.identity_key_from_instance(obj)
is_row_switch = not postupdate and not has_identity(obj) and instance_key in uowtransaction.uow.identity_map
if is_row_switch:
existing = uowtransaction.uow.identity_map[instance_key]
@@ -1090,11 +1017,11 @@ class Mapper(object):
insert = []
update = []
- for obj in objects:
+ for obj, connection in tups:
mapper = object_mapper(obj)
if table not in mapper.tables or not mapper._has_pks(table):
continue
- instance_key = mapper.instance_key(obj)
+ instance_key = mapper.identity_key_from_instance(obj)
if self.__should_log_debug:
self.__log_debug("save_obj() table '%s' instance %s identity %s" % (table.name, mapperutil.instance_str(obj), str(instance_key)))
@@ -1149,7 +1076,7 @@ class Mapper(object):
if history:
a = history.added_items()
if len(a):
- params[col.key] = a[0]
+ params[col.key] = prop.get_col_value(col, a[0])
hasdata = True
else:
# doing an INSERT, non primary key col ?
@@ -1168,17 +1095,17 @@ class Mapper(object):
if hasdata:
# if none of the attributes changed, dont even
# add the row to be updated.
- update.append((obj, params, mapper))
+ update.append((obj, params, mapper, connection))
else:
- insert.append((obj, params, mapper))
+ insert.append((obj, params, mapper, connection))
if len(update):
mapper = table_to_mapper[table]
clause = sql.and_()
for col in mapper.pks_by_table[table]:
- clause.clauses.append(col == sql.bindparam(col._label, type=col.type, unique=True))
+ clause.clauses.append(col == sql.bindparam(col._label, type_=col.type, unique=True))
if mapper.version_id_col is not None:
- clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type=col.type, unique=True))
+ clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type_=col.type, unique=True))
statement = table.update(clause)
rows = 0
supports_sane_rowcount = True
@@ -1190,11 +1117,11 @@ class Mapper(object):
return 0
update.sort(comparator)
for rec in update:
- (obj, params, mapper) = rec
+ (obj, params, mapper, connection) = rec
c = connection.execute(statement, params)
mapper._postfetch(connection, table, obj, c, c.last_updated_params())
- updated_objects.add(obj)
+ updated_objects.add((obj, connection))
rows += c.rowcount
if c.supports_sane_rowcount() and rows != len(update):
@@ -1206,7 +1133,7 @@ class Mapper(object):
return cmp(a[0]._sa_insert_order, b[0]._sa_insert_order)
insert.sort(comparator)
for rec in insert:
- (obj, params, mapper) = rec
+ (obj, params, mapper, connection) = rec
c = connection.execute(statement, params)
primary_key = c.last_inserted_ids()
if primary_key is not None:
@@ -1228,12 +1155,12 @@ class Mapper(object):
mapper._synchronizer.execute(obj, obj)
sync(mapper)
- inserted_objects.add(obj)
+ inserted_objects.add((obj, connection))
if not postupdate:
- for obj in inserted_objects:
+ for obj, connection in inserted_objects:
for mapper in object_mapper(obj).iterate_to_root():
mapper.extension.after_insert(mapper, connection, obj)
- for obj in updated_objects:
+ for obj, connection in updated_objects:
for mapper in object_mapper(obj).iterate_to_root():
mapper.extension.after_update(mapper, connection, obj)
@@ -1273,9 +1200,14 @@ class Mapper(object):
if self.__should_log_debug:
self.__log_debug("delete_obj() start")
- connection = uowtransaction.transaction.connection(self)
+ if 'connection_callable' in uowtransaction.mapper_flush_opts:
+ connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
+ tups = [(obj, connection_callable(self, obj)) for obj in objects]
+ else:
+ connection = uowtransaction.transaction.connection(self)
+ tups = [(obj, connection) for obj in objects]
- for obj in objects:
+ for (obj, connection) in tups:
for mapper in object_mapper(obj).iterate_to_root():
mapper.extension.before_delete(mapper, connection, obj)
@@ -1286,8 +1218,8 @@ class Mapper(object):
table_to_mapper.setdefault(t, mapper)
for table in sqlutil.TableCollection(list(table_to_mapper.keys())).sort(reverse=True):
- delete = []
- for obj in objects:
+ delete = {}
+ for (obj, connection) in tups:
mapper = object_mapper(obj)
if table not in mapper.tables or not mapper._has_pks(table):
continue
@@ -1296,13 +1228,13 @@ class Mapper(object):
if not hasattr(obj, '_instance_key'):
continue
else:
- delete.append(params)
+ delete.setdefault(connection, []).append(params)
for col in mapper.pks_by_table[table]:
params[col.key] = mapper.get_attr_by_column(obj, col)
if mapper.version_id_col is not None:
params[mapper.version_id_col.key] = mapper.get_attr_by_column(obj, mapper.version_id_col)
- deleted_objects.add(obj)
- if len(delete):
+ deleted_objects.add((obj, connection))
+ for connection, del_objects in delete.iteritems():
mapper = table_to_mapper[table]
def comparator(a, b):
for col in mapper.pks_by_table[table]:
@@ -1310,18 +1242,18 @@ class Mapper(object):
if x != 0:
return x
return 0
- delete.sort(comparator)
+ del_objects.sort(comparator)
clause = sql.and_()
for col in mapper.pks_by_table[table]:
- clause.clauses.append(col == sql.bindparam(col.key, type=col.type, unique=True))
+ clause.clauses.append(col == sql.bindparam(col.key, type_=col.type, unique=True))
if mapper.version_id_col is not None:
- clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type=mapper.version_id_col.type, unique=True))
+ clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type, unique=True))
statement = table.delete(clause)
- c = connection.execute(statement, delete)
- if c.supports_sane_rowcount() and c.rowcount != len(delete):
+ c = connection.execute(statement, del_objects)
+ if c.supports_sane_rowcount() and c.rowcount != len(del_objects):
raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (c.rowcount, len(delete)))
- for obj in deleted_objects:
+ for obj, connection in deleted_objects:
for mapper in object_mapper(obj).iterate_to_root():
mapper.extension.after_delete(mapper, connection, obj)
@@ -1429,15 +1361,17 @@ class Mapper(object):
if discriminator is not None:
mapper = self.polymorphic_map[discriminator]
if mapper is not self:
+ if ('polymorphic_fetch', mapper) not in context.attributes:
+ context.attributes[('polymorphic_fetch', mapper)] = (self, [t for t in mapper.tables if t not in self.tables])
row = self.translate_row(mapper, row)
return mapper._instance(context, row, result=result, skip_polymorphic=True)
-
+
# look in main identity map. if its there, we dont do anything to it,
# including modifying any of its related items lists, as its already
# been exposed to being modified by the application.
- populate_existing = context.populate_existing or self.always_refresh
identitykey = self.identity_key_from_row(row)
+ populate_existing = context.populate_existing or self.always_refresh
if context.session.has_key(identitykey):
instance = context.session._get(identitykey)
if self.__should_log_debug:
@@ -1450,32 +1384,31 @@ class Mapper(object):
if not context.identity_map.has_key(identitykey):
context.identity_map[identitykey] = instance
isnew = True
- if extension.populate_instance(self, context, row, instance, identitykey, isnew) is EXT_PASS:
- self.populate_instance(context, instance, row, identitykey, isnew)
- if extension.append_result(self, context, row, instance, identitykey, result, isnew) is EXT_PASS:
+ if extension.populate_instance(self, context, row, instance, **{'instancekey':identitykey, 'isnew':isnew}) is EXT_PASS:
+ self.populate_instance(context, instance, row, **{'instancekey':identitykey, 'isnew':isnew})
+ if extension.append_result(self, context, row, instance, result, **{'instancekey':identitykey, 'isnew':isnew}) is EXT_PASS:
if result is not None:
result.append(instance)
return instance
else:
if self.__should_log_debug:
- self.__log_debug("_instance(): identity key %s not in session" % str(identitykey) + repr([mapperutil.instance_str(x) for x in context.session]))
+ self.__log_debug("_instance(): identity key %s not in session" % str(identitykey))
# look in result-local identitymap for it.
- exists = context.identity_map.has_key(identitykey)
+ exists = identitykey in context.identity_map
if not exists:
if self.allow_null_pks:
# check if *all* primary key cols in the result are None - this indicates
# an instance of the object is not present in the row.
- for col in self.pks_by_table[self.mapped_table]:
- if row[col] is not None:
+ for x in identitykey[1]:
+ if x is not None:
break
else:
return None
else:
# otherwise, check if *any* primary key cols in the result are None - this indicates
# an instance of the object is not present in the row.
- for col in self.pks_by_table[self.mapped_table]:
- if row[col] is None:
- return None
+ if None in identitykey[1]:
+ return None
# plugin point
instance = extension.create_instance(self, context, row, self.class_)
@@ -1493,9 +1426,10 @@ class Mapper(object):
# call further mapper properties on the row, to pull further
# instances from the row and possibly populate this item.
- if extension.populate_instance(self, context, row, instance, identitykey, isnew) is EXT_PASS:
- self.populate_instance(context, instance, row, identitykey, isnew)
- if extension.append_result(self, context, row, instance, identitykey, result, isnew) is EXT_PASS:
+ flags = {'instancekey':identitykey, 'isnew':isnew}
+ if extension.populate_instance(self, context, row, instance, **flags) is EXT_PASS:
+ self.populate_instance(context, instance, row, **flags)
+ if extension.append_result(self, context, row, instance, result, **flags) is EXT_PASS:
if result is not None:
result.append(instance)
return instance
@@ -1510,6 +1444,24 @@ class Mapper(object):
return obj
+ def _deferred_inheritance_condition(self, needs_tables):
+ cond = self.inherit_condition
+
+ param_names = []
+ def visit_binary(binary):
+ leftcol = binary.left
+ rightcol = binary.right
+ if leftcol is None or rightcol is None:
+ return
+ if leftcol.table not in needs_tables:
+ binary.left = sql.bindparam(leftcol.name, None, type_=binary.right.type, unique=True)
+ param_names.append(leftcol)
+ elif rightcol not in needs_tables:
+ binary.right = sql.bindparam(rightcol.name, None, type_=binary.right.type, unique=True)
+ param_names.append(rightcol)
+ cond = mapperutil.BinaryVisitor(visit_binary).traverse(cond, clone=True)
+ return cond, param_names
+
def translate_row(self, tomapper, row):
"""Translate the column keys of a row into a new or proxied
row that can be understood by another mapper.
@@ -1520,288 +1472,71 @@ class Mapper(object):
newrow = util.DictDecorator(row)
for c in tomapper.mapped_table.c:
- c2 = self.mapped_table.corresponding_column(c, keys_ok=True, raiseerr=True)
- if row.has_key(c2):
+ c2 = self.mapped_table.corresponding_column(c, keys_ok=True, raiseerr=False)
+ if c2 and row.has_key(c2):
newrow[c] = row[c2]
return newrow
- def populate_instance(self, selectcontext, instance, row, identitykey, isnew):
- """populate an instance from a result row.
-
- This method iterates through the list of MapperProperty objects attached to this Mapper
- and calls each properties execute() method."""
- for prop in self.__props.values():
- prop.execute(selectcontext, instance, row, identitykey, isnew)
-
-Mapper.logger = logging.class_logger(Mapper)
-
-
-class MapperExtension(object):
- """Base implementation for an object that provides overriding
- behavior to various Mapper functions. For each method in
- MapperExtension, a result of EXT_PASS indicates the functionality
- is not overridden.
- """
-
- def get_session(self):
- """Retrieve a contextual Session instance with which to
- register a new object.
-
- Note: this is not called if a session is provided with the
- `__init__` params (i.e. `_sa_session`).
- """
-
- return EXT_PASS
-
- def load(self, query, *args, **kwargs):
- """Override the `load` method of the Query object.
-
- The return value of this method is used as the result of
- ``query.load()`` if the value is anything other than EXT_PASS.
- """
-
- return EXT_PASS
-
- def get(self, query, *args, **kwargs):
- """Override the `get` method of the Query object.
-
- The return value of this method is used as the result of
- ``query.get()`` if the value is anything other than EXT_PASS.
- """
-
- return EXT_PASS
-
- def get_by(self, query, *args, **kwargs):
- """Override the `get_by` method of the Query object.
-
- The return value of this method is used as the result of
- ``query.get_by()`` if the value is anything other than
- EXT_PASS.
- """
-
- return EXT_PASS
-
- def select_by(self, query, *args, **kwargs):
- """Override the `select_by` method of the Query object.
-
- The return value of this method is used as the result of
- ``query.select_by()`` if the value is anything other than
- EXT_PASS.
- """
-
- return EXT_PASS
-
- def select(self, query, *args, **kwargs):
- """Override the `select` method of the Query object.
-
- The return value of this method is used as the result of
- ``query.select()`` if the value is anything other than
- EXT_PASS.
- """
-
- return EXT_PASS
-
-
- def translate_row(self, mapper, context, row):
- """Perform pre-processing on the given result row and return a
- new row instance.
-
- This is called as the very first step in the ``_instance()``
- method.
- """
-
- return EXT_PASS
-
- def create_instance(self, mapper, selectcontext, row, class_):
- """Receive a row when a new object instance is about to be
- created from that row.
-
- The method can choose to create the instance itself, or it can
- return None to indicate normal object creation should take
- place.
-
- mapper
- The mapper doing the operation
-
- selectcontext
- SelectionContext corresponding to the instances() call
-
- row
- The result row from the database
-
- class\_
- The class we are mapping.
- """
-
- return EXT_PASS
-
- def append_result(self, mapper, selectcontext, row, instance, identitykey, result, isnew):
- """Receive an object instance before that instance is appended
- to a result list.
-
- If this method returns EXT_PASS, result appending will proceed
- normally. if this method returns any other value or None,
- result appending will not proceed for this instance, giving
- this extension an opportunity to do the appending itself, if
- desired.
-
- mapper
- The mapper doing the operation.
-
- selectcontext
- SelectionContext corresponding to the instances() call.
-
- row
- The result row from the database.
-
- instance
- The object instance to be appended to the result.
-
- identitykey
- The identity key of the instance.
-
- result
- List to which results are being appended.
-
- isnew
- Indicates if this is the first time we have seen this object
- instance in the current result set. if you are selecting
- from a join, such as an eager load, you might see the same
- object instance many times in the same result set.
- """
-
- return EXT_PASS
-
- def populate_instance(self, mapper, selectcontext, row, instance, identitykey, isnew):
- """Receive a newly-created instance before that instance has
- its attributes populated.
-
- The normal population of attributes is according to each
- attribute's corresponding MapperProperty (which includes
- column-based attributes as well as relationships to other
- classes). If this method returns EXT_PASS, instance
- population will proceed normally. If any other value or None
- is returned, instance population will not proceed, giving this
- extension an opportunity to populate the instance itself, if
- desired.
- """
-
- return EXT_PASS
-
- def before_insert(self, mapper, connection, instance):
- """Receive an object instance before that instance is INSERTed
- into its table.
-
- This is a good place to set up primary key values and such
- that aren't handled otherwise.
- """
-
- return EXT_PASS
-
- def before_update(self, mapper, connection, instance):
- """Receive an object instance before that instance is UPDATEed."""
-
- return EXT_PASS
-
- def after_update(self, mapper, connection, instance):
- """Receive an object instance after that instance is UPDATEed."""
-
- return EXT_PASS
-
- def after_insert(self, mapper, connection, instance):
- """Receive an object instance after that instance is INSERTed."""
-
- return EXT_PASS
-
- def before_delete(self, mapper, connection, instance):
- """Receive an object instance before that instance is DELETEed."""
-
- return EXT_PASS
-
- def after_delete(self, mapper, connection, instance):
- """Receive an object instance after that instance is DELETEed."""
-
- return EXT_PASS
-
-class _ExtensionCarrier(MapperExtension):
- def __init__(self):
- self.__elements = []
-
- def __iter__(self):
- return iter(self.__elements)
+ def populate_instance(self, selectcontext, instance, row, ispostselect=None, **flags):
+ """populate an instance from a result row."""
+
+ selectcontext.stack.push_mapper(self)
+ populators = selectcontext.attributes.get(('instance_populators', self, selectcontext.stack.snapshot(), ispostselect), None)
+ if populators is None:
+ populators = []
+ post_processors = []
+ for prop in self.__props.values():
+ (pop, post_proc) = prop.create_row_processor(selectcontext, self, row)
+ if pop is not None:
+ populators.append(pop)
+ if post_proc is not None:
+ post_processors.append(post_proc)
+
+ poly_select_loader = self._get_poly_select_loader(selectcontext, row)
+ if poly_select_loader is not None:
+ post_processors.append(poly_select_loader)
+
+ selectcontext.attributes[('instance_populators', self, selectcontext.stack.snapshot(), ispostselect)] = populators
+ selectcontext.attributes[('post_processors', self, ispostselect)] = post_processors
+
+ for p in populators:
+ p(instance, row, ispostselect=ispostselect, **flags)
- def insert(self, extension):
- """Insert a MapperExtension at the beginning of this ExtensionCarrier's list."""
-
- self.__elements.insert(0, extension)
-
- def append(self, extension):
- """Append a MapperExtension at the end of this ExtensionCarrier's list."""
-
- self.__elements.append(extension)
-
- def get_session(self, *args, **kwargs):
- return self._do('get_session', *args, **kwargs)
-
- def load(self, *args, **kwargs):
- return self._do('load', *args, **kwargs)
-
- def get(self, *args, **kwargs):
- return self._do('get', *args, **kwargs)
-
- def get_by(self, *args, **kwargs):
- return self._do('get_by', *args, **kwargs)
-
- def select_by(self, *args, **kwargs):
- return self._do('select_by', *args, **kwargs)
-
- def select(self, *args, **kwargs):
- return self._do('select', *args, **kwargs)
-
- def translate_row(self, *args, **kwargs):
- return self._do('translate_row', *args, **kwargs)
-
- def create_instance(self, *args, **kwargs):
- return self._do('create_instance', *args, **kwargs)
-
- def append_result(self, *args, **kwargs):
- return self._do('append_result', *args, **kwargs)
-
- def populate_instance(self, *args, **kwargs):
- return self._do('populate_instance', *args, **kwargs)
-
- def before_insert(self, *args, **kwargs):
- return self._do('before_insert', *args, **kwargs)
-
- def before_update(self, *args, **kwargs):
- return self._do('before_update', *args, **kwargs)
-
- def after_update(self, *args, **kwargs):
- return self._do('after_update', *args, **kwargs)
+ selectcontext.stack.pop()
+
+ if self.non_primary:
+ selectcontext.attributes[('populating_mapper', instance)] = self
+
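populate_instance now compiles one row-processor per property the first time a row is seen for a given mapper/path and caches the list in selectcontext.attributes, so later rows in the same result skip property dispatch entirely. A standalone sketch of that memoization (hypothetical names, not the SQLAlchemy API):

    def make_processor(key):
        # stands in for a MapperProperty compiling its row processor once
        return lambda instance, row: setattr(instance, key, row[key])

    def populate(context_cache, keys, instance, row):
        processors = context_cache.get('processors')
        if processors is None:
            processors = [make_processor(k) for k in keys]
            context_cache['processors'] = processors   # built once per result
        for proc in processors:
            proc(instance, row)

    class Obj(object):
        pass

    cache = {}
    for data in [{'id': 1, 'name': 'ed'}, {'id': 2, 'name': 'wendy'}]:
        o = Obj()
        populate(cache, ['id', 'name'], o, data)
        print("%s %s" % (o.id, o.name))
    # 1 ed
    # 2 wendy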
+ def _post_instance(self, selectcontext, instance):
+ post_processors = selectcontext.attributes[('post_processors', self, None)]
+ for p in post_processors:
+ p(instance)
+
+ def _get_poly_select_loader(self, selectcontext, row):
+ # only emit the post-fetch for 'select' polymorphic fetch, or a 'union' fetch where columns were not present; 'deferred' fetch is handled by deferred column loaders
+ (hosted_mapper, needs_tables) = selectcontext.attributes.get(('polymorphic_fetch', self), (None, None))
+ if hosted_mapper is None or len(needs_tables)==0 or hosted_mapper.polymorphic_fetch == 'deferred':
+ return
+
+ cond, param_names = self._deferred_inheritance_condition(needs_tables)
+ statement = sql.select(needs_tables, cond, use_labels=True)
+ def post_execute(instance, **flags):
+ self.__log_debug("Post query loading instance " + mapperutil.instance_str(instance))
- def after_insert(self, *args, **kwargs):
- return self._do('after_insert', *args, **kwargs)
+ identitykey = self.identity_key_from_instance(instance)
- def before_delete(self, *args, **kwargs):
- return self._do('before_delete', *args, **kwargs)
+ params = {}
+ for c in param_names:
+ params[c.name] = self.get_attr_by_column(instance, c)
+ row = selectcontext.session.connection(self).execute(statement, **params).fetchone()
+ self.populate_instance(selectcontext, instance, row, **{'isnew':False, 'instancekey':identitykey, 'ispostselect':True})
- def after_delete(self, *args, **kwargs):
- return self._do('after_delete', *args, **kwargs)
+ return post_execute
+
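_get_poly_select_loader returns a closure that runs after the primary row has populated the instance, issuing a second SELECT against the subclass tables keyed by the instance's identity. A standalone sketch of the closure pattern, with the database access faked:

    ROWS = {1: {'engineer_info': 'knows Python'}}   # fake subclass table

    def make_post_fetch(fetch_cols):
        # returns the closure run per-instance after the primary row loads
        def post_execute(instance):
            row = ROWS[instance.id]     # stands in for the deferred SELECT
            for col in fetch_cols:
                setattr(instance, col, row[col])
        return post_execute

    class Employee(object):
        def __init__(self, id_):
            self.id = id_

    loader = make_post_fetch(['engineer_info'])
    e = Employee(1)
    loader(e)
    print(e.engineer_info)   # knows Python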
+Mapper.logger = logging.class_logger(Mapper)
- def _do(self, funcname, *args, **kwargs):
- for elem in self.__elements:
- ret = getattr(elem, funcname)(*args, **kwargs)
- if ret is not EXT_PASS:
- return ret
- else:
- return EXT_PASS
-class ExtensionOption(MapperOption):
- def __init__(self, ext):
- self.ext = ext
- def process_query(self, query):
- query.extension.append(self.ext)
class ClassKey(object):
"""Key a class and an entity name to a mapper, via the mapper_registry."""
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index a00a35ab6..6ce9fd706 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -15,8 +15,11 @@ from sqlalchemy import sql, schema, util, exceptions, sql_util, logging
from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency
from sqlalchemy.orm import session as sessionlib
from sqlalchemy.orm import util as mapperutil
-import sets, random
-from sqlalchemy.orm.interfaces import *
+import operator
+from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator
+from sqlalchemy.exceptions import ArgumentError
+
+__all__ = ['ColumnProperty', 'CompositeProperty', 'PropertyLoader', 'BackRef']
class ColumnProperty(StrategizedProperty):
"""Describes an object attribute that corresponds to a table column."""
@@ -31,17 +34,27 @@ class ColumnProperty(StrategizedProperty):
self.columns = list(columns)
self.group = kwargs.pop('group', None)
self.deferred = kwargs.pop('deferred', False)
-
+ self.comparator = ColumnProperty.ColumnComparator(self)
+ # sanity check
+ for col in columns:
+ if not hasattr(col, 'name'):
+ if hasattr(col, 'label'):
+ raise ArgumentError('ColumnProperties must be named for the mapper to work with them. Try .label() to fix this')
+ raise ArgumentError('%r is not a valid candidate for ColumnProperty' % col)
+
def create_strategy(self):
if self.deferred:
return strategies.DeferredColumnLoader(self)
else:
return strategies.ColumnLoader(self)
-
- def getattr(self, object):
+
+ def copy(self):
+ return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns)
+
+ def getattr(self, object, column):
return getattr(object, self.key)
- def setattr(self, object, value):
+ def setattr(self, object, value, column):
setattr(object, self.key, value)
def get_history(self, obj, passive=False):
@@ -50,19 +63,69 @@ class ColumnProperty(StrategizedProperty):
def merge(self, session, source, dest, _recursive):
setattr(dest, self.key, getattr(source, self.key, None))
- def compare(self, value):
- return self.columns[0] == value
+ def get_col_value(self, column, value):
+ return value
+
+ class ColumnComparator(PropComparator):
+ def clause_element(self):
+ return self.prop.columns[0]
+
+ def operate(self, op, other):
+ return op(self.prop.columns[0], other)
+
+ def reverse_operate(self, op, other):
+ col = self.prop.columns[0]
+ return op(col._bind_param(other), col)
+
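ColumnComparator funnels Python rich comparisons through operate()/reverse_operate() against the first mapped column. A standalone sketch of the delegation pattern (stand-in classes, string-building in place of real SQL expressions):

    import operator

    OPS = {operator.eq: '=', operator.gt: '>'}

    class Comparator(object):
        def __init__(self, column):
            self.column = column
        def operate(self, op, other):
            # every rich comparison funnels through one place
            return "%s %s :param" % (self.column, OPS[op])
        def __eq__(self, other):
            return self.operate(operator.eq, other)
        def __gt__(self, other):
            return self.operate(operator.gt, other)

    name = Comparator('users.name')
    print(name == 'ed')   # users.name = :param
    print(name > 'a')     # users.name > :param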
ColumnProperty.logger = logging.class_logger(ColumnProperty)
mapper.ColumnProperty = ColumnProperty
+class CompositeProperty(ColumnProperty):
+ """subclasses ColumnProperty to provide composite type support."""
+
+ def __init__(self, class_, *columns, **kwargs):
+ super(CompositeProperty, self).__init__(*columns, **kwargs)
+ self.composite_class = class_
+ self.comparator = kwargs.pop('comparator', CompositeProperty.Comparator(self))
+
+ def copy(self):
+ return CompositeProperty(self.composite_class, deferred=self.deferred, group=self.group, *self.columns)
+
+ def getattr(self, object, column):
+ obj = getattr(object, self.key)
+ return self.get_col_value(column, obj)
+
+ def setattr(self, object, value, column):
+ obj = getattr(object, self.key, None)
+ if obj is None:
+ obj = self.composite_class(*[None for c in self.columns])
+ for a, b in zip(self.columns, value.__colset__()):
+ if a is column:
+ setattr(obj, b, value)
+
+ def get_col_value(self, column, value):
+ for a, b in zip(self.columns, value.__colset__()):
+ if a is column:
+ return b
+
+ class Comparator(PropComparator):
+ def __eq__(self, other):
+ if other is None:
+ return sql.and_(*[a==None for a in self.prop.columns])
+ else:
+ return sql.and_(*[a==b for a, b in zip(self.prop.columns, other.__colset__())])
+
+ def __ne__(self, other):
+ return sql.or_(*[a!=b for a, b in zip(self.prop.columns, other.__colset__())])
+
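A composite class participates by reporting its column values through __colset__() in mapped-column order; the Comparator above then zips those values against prop.columns. A minimal sketch of a conforming value class (Point is hypothetical, not part of this diff):

    class Point(object):
        def __init__(self, x, y):
            self.x, self.y = x, y
        def __colset__(self):
            # column values, in the same order as the mapped columns
            return [self.x, self.y]

    p = Point(3, 4)
    print(p.__colset__())   # [3, 4]
    # the Comparator pairs these values with prop.columns, so comparing
    # a mapped composite to Point(3, 4) renders AND(x_col = 3, y_col = 4)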
class PropertyLoader(StrategizedProperty):
"""Describes an object property that holds a single item or list
of items that correspond to a related database table.
"""
- def __init__(self, argument, secondary, primaryjoin, secondaryjoin, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, remote_side=None, enable_typechecks=True):
+ def __init__(self, argument, secondary=None, primaryjoin=None, secondaryjoin=None, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, remote_side=None, enable_typechecks=True, join_depth=None):
self.uselist = uselist
self.argument = argument
self.entity_name = entity_name
@@ -80,7 +143,9 @@ class PropertyLoader(StrategizedProperty):
self.remote_side = util.to_set(remote_side)
self.enable_typechecks = enable_typechecks
self._parent_join_cache = {}
-
+ self.comparator = PropertyLoader.Comparator(self)
+ self.join_depth = join_depth
+
if cascade is not None:
self.cascade = mapperutil.CascadeOptions(cascade)
else:
@@ -91,7 +156,7 @@ class PropertyLoader(StrategizedProperty):
self.association = association
self.order_by = order_by
- self.attributeext = attributeext
+ self.attributeext=attributeext
if isinstance(backref, str):
# propagate explicitly sent primary/secondary join conditions to the BackRef object if
# just a string was sent
@@ -104,9 +169,96 @@ class PropertyLoader(StrategizedProperty):
self.backref = backref
self.is_backref = is_backref
- def compare(self, value):
- return sql.and_(*[x==y for (x, y) in zip(self.mapper.primary_key, self.mapper.primary_key_from_instance(value))])
-
+ class Comparator(PropComparator):
+ def __eq__(self, other):
+ if other is None:
+ return ~sql.exists([1], self.prop.primaryjoin)
+ elif self.prop.uselist:
+ if not hasattr(other, '__iter__'):
+ raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object.")
+ else:
+ j = self.prop.primaryjoin
+ if self.prop.secondaryjoin:
+ j = j & self.prop.secondaryjoin
+ clauses = []
+ for o in other:
+ clauses.append(
+ sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(o))]))
+ )
+ return sql.and_(*clauses)
+ else:
+ return self.prop._optimized_compare(other)
+
+ def any(self, criterion=None, **kwargs):
+ if not self.prop.uselist:
+ raise exceptions.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().")
+ j = self.prop.primaryjoin
+ if self.prop.secondaryjoin:
+ j = j & self.prop.secondaryjoin
+ for k in kwargs:
+ crit = (getattr(self.prop.mapper.class_, k) == kwargs[k])
+ if criterion is None:
+ criterion = crit
+ else:
+ criterion = criterion & crit
+ return sql.exists([1], j & criterion)
+
+ def has(self, criterion=None, **kwargs):
+ if self.prop.uselist:
+ raise exceptions.InvalidRequestError("'has()' not implemented for collections. Use any().")
+ j = self.prop.primaryjoin
+ if self.prop.secondaryjoin:
+ j = j & self.prop.secondaryjoin
+ for k in kwargs:
+ crit = (getattr(self.prop.mapper.class_, k) == kwargs[k])
+ if criterion is None:
+ criterion = crit
+ else:
+ criterion = criterion & crit
+ return sql.exists([1], j & criterion)
+
+ def contains(self, other):
+ if not self.prop.uselist:
+ raise exceptions.InvalidRequestError("'contains' not implemented for scalar attributes. Use ==")
+ clause = self.prop._optimized_compare(other)
+
+ j = self.prop.primaryjoin
+ if self.prop.secondaryjoin:
+ j = j & self.prop.secondaryjoin
+
+ clause.negation_clause = ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]))
+ return clause
+
+ def __ne__(self, other):
+ if self.prop.uselist and not hasattr(other, '__iter__'):
+ raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object")
+
+ j = self.prop.primaryjoin
+ if self.prop.secondaryjoin:
+ j = j & self.prop.secondaryjoin
+ return ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]))
+
+ def compare(self, op, value, value_is_parent=False):
+ if op == operator.eq:
+ if value is None:
+ return ~sql.exists([1], self.primaryjoin)
+ else:
+ return self._optimized_compare(value, value_is_parent=value_is_parent)
+ else:
+ return op(self.comparator, value)
+
+ def _optimized_compare(self, value, value_is_parent=False):
+ # optimized operation for ==, uses a lazy clause.
+ (criterion, lazybinds, rev) = strategies.LazyLoader._create_lazy_clause(self, reverse_direction=not value_is_parent)
+ bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
+
+ class Visitor(sql.ClauseVisitor):
+ def visit_bindparam(s, bindparam):
+ mapper = value_is_parent and self.parent or self.mapper
+ bindparam.value = mapper.get_attr_by_column(value, bind_to_col[bindparam.key])
+ Visitor().traverse(criterion)
+ return criterion
+
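Taken together, these comparator methods give relationship-bound attributes full query operators. A hedged usage sketch against a hypothetical User/Address mapping, with User.addresses a collection and Address.user a scalar:

    # session.query(Address).filter(Address.user == someuser)        # scalar ==
    # session.query(User).filter(User.addresses == None)             # "empty" via NOT EXISTS
    # session.query(User).filter(User.addresses.any(email='e@x'))    # EXISTS for collections
    # session.query(Address).filter(Address.user.has(name='ed'))     # EXISTS for scalars
    # session.query(User).filter(User.addresses.contains(someaddr))  # collection membership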
private = property(lambda s:s.cascade.delete_orphan)
def create_strategy(self):
@@ -127,12 +279,13 @@ class PropertyLoader(StrategizedProperty):
if childlist is None:
return
if self.uselist:
- # sets a blank list according to the correct list class
- dest_list = getattr(self.parent.class_, self.key).initialize(dest)
+ # sets a blank collection according to the correct list class
+ dest_list = sessionlib.attribute_manager.init_collection(dest, self.key)
for current in list(childlist):
obj = session.merge(current, entity_name=self.mapper.entity_name, _recursive=_recursive)
if obj is not None:
- dest_list.append(obj)
+ # append with events so that backrefs and history fire for the merged item
+ dest_list.append_with_event(obj)
else:
current = list(childlist)[0]
if current is not None:
@@ -267,7 +420,7 @@ class PropertyLoader(StrategizedProperty):
if len(self.foreign_keys):
self._opposite_side = util.Set()
def visit_binary(binary):
- if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
+ if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
return
if binary.left in self.foreign_keys:
self._opposite_side.add(binary.right)
@@ -280,7 +433,7 @@ class PropertyLoader(StrategizedProperty):
self.foreign_keys = util.Set()
self._opposite_side = util.Set()
def visit_binary(binary):
- if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
+ if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
return
# this check is for when the user put the "view_only" flag on and has tables that have nothing
@@ -362,16 +515,13 @@ class PropertyLoader(StrategizedProperty):
"argument." % (str(self)))
def _determine_remote_side(self):
- if len(self.remote_side):
- return
- self.remote_side = util.Set()
+ if not len(self.remote_side):
+ if self.direction is sync.MANYTOONE:
+ self.remote_side = util.Set(self._opposite_side)
+ elif self.direction is sync.ONETOMANY or self.direction is sync.MANYTOMANY:
+ self.remote_side = util.Set(self.foreign_keys)
- if self.direction is sync.MANYTOONE:
- for c in self._opposite_side:
- self.remote_side.add(c)
- elif self.direction is sync.ONETOMANY or self.direction is sync.MANYTOMANY:
- for c in self.foreign_keys:
- self.remote_side.add(c)
+ self.local_side = util.Set(self._opposite_side).union(util.Set(self.foreign_keys)).difference(self.remote_side)
def _create_polymorphic_joins(self):
# get ready to create "polymorphic" primary/secondary join clauses.
@@ -383,27 +533,26 @@ class PropertyLoader(StrategizedProperty):
# as we will be using the polymorphic selectables (i.e. select_table argument to Mapper) to figure this out,
# first create maps of all the "equivalent" columns, since polymorphic selectables will often munge
# several "equivalent" columns (such as parent/child fk cols) into just one column.
- target_equivalents = self.mapper._get_inherited_column_equivalents()
+ target_equivalents = self.mapper._get_equivalent_columns()
+
# if the target mapper loads polymorphically, adapt the clauses to the target's selectable
if self.loads_polymorphic:
if self.secondaryjoin:
- self.polymorphic_secondaryjoin = self.secondaryjoin.copy_container()
- sql_util.ClauseAdapter(self.mapper.select_table).traverse(self.polymorphic_secondaryjoin)
- self.polymorphic_primaryjoin = self.primaryjoin.copy_container()
+ self.polymorphic_secondaryjoin = sql_util.ClauseAdapter(self.mapper.select_table).traverse(self.secondaryjoin, clone=True)
+ self.polymorphic_primaryjoin = self.primaryjoin
else:
- self.polymorphic_primaryjoin = self.primaryjoin.copy_container()
if self.direction is sync.ONETOMANY:
- sql_util.ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents).traverse(self.polymorphic_primaryjoin)
+ self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True)
elif self.direction is sync.MANYTOONE:
- sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.polymorphic_primaryjoin)
+ self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True)
self.polymorphic_secondaryjoin = None
# load "polymorphic" versions of the columns present in "remote_side" - this is
# important for lazy-clause generation which goes off the polymorphic target selectable
for c in list(self.remote_side):
- if self.secondary and c in self.secondary.columns:
+ if self.secondary and self.secondary.columns.contains_column(c):
continue
- for equiv in [c] + (c in target_equivalents and target_equivalents[c] or []):
+ for equiv in [c] + (c in target_equivalents and list(target_equivalents[c]) or []):
corr = self.mapper.select_table.corresponding_column(equiv, raiseerr=False)
if corr:
self.remote_side.add(corr)
@@ -411,8 +560,8 @@ class PropertyLoader(StrategizedProperty):
else:
raise exceptions.AssertionError(str(self) + ": Could not find corresponding column for " + str(c) + " in selectable " + str(self.mapper.select_table))
else:
- self.polymorphic_primaryjoin = self.primaryjoin.copy_container()
- self.polymorphic_secondaryjoin = self.secondaryjoin and self.secondaryjoin.copy_container() or None
+ self.polymorphic_primaryjoin = self.primaryjoin
+ self.polymorphic_secondaryjoin = self.secondaryjoin
def _post_init(self):
if logging.is_info_enabled(self.logger):
@@ -450,22 +599,20 @@ class PropertyLoader(StrategizedProperty):
def _is_self_referential(self):
return self.parent.mapped_table is self.target or self.parent.select_table is self.target
- def get_join(self, parent, primary=True, secondary=True):
+ def get_join(self, parent, primary=True, secondary=True, polymorphic_parent=True):
try:
- return self._parent_join_cache[(parent, primary, secondary)]
+ return self._parent_join_cache[(parent, primary, secondary, polymorphic_parent)]
except KeyError:
- parent_equivalents = parent._get_inherited_column_equivalents()
- primaryjoin = self.polymorphic_primaryjoin.copy_container()
- if self.secondaryjoin is not None:
- secondaryjoin = self.polymorphic_secondaryjoin.copy_container()
- else:
- secondaryjoin = None
- if self.direction is sync.ONETOMANY:
- sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
- elif self.direction is sync.MANYTOONE:
- sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
- elif self.secondaryjoin:
- sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
+ parent_equivalents = parent._get_equivalent_columns()
+ secondaryjoin = self.polymorphic_secondaryjoin
+ if polymorphic_parent:
+ # adapt the "parent" side of our join condition to the "polymorphic" select of the parent
+ if self.direction is sync.ONETOMANY:
+ primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True)
+ elif self.direction is sync.MANYTOONE:
+ primaryjoin = sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True)
+ elif self.secondaryjoin:
+ primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True)
if secondaryjoin is not None:
if secondary and not primary:
@@ -476,7 +623,7 @@ class PropertyLoader(StrategizedProperty):
j = primaryjoin
else:
j = primaryjoin
- self._parent_join_cache[(parent, primary, secondary)] = j
+ self._parent_join_cache[(parent, primary, secondary, polymorphic_parent)] = j
return j
def register_dependencies(self, uowcommit):
@@ -501,7 +648,7 @@ class BackRef(object):
# try to set a LazyLoader on our mapper referencing the parent mapper
mapper = prop.mapper.primary_mapper()
- if not mapper.props.has_key(self.key):
+ if mapper.get_property(self.key, raiseerr=False) is None:
pj = self.kwargs.pop('primaryjoin', None)
sj = self.kwargs.pop('secondaryjoin', None)
# the backref property is set on the primary mapper
@@ -512,26 +659,26 @@ class BackRef(object):
backref=prop.key, is_backref=True,
**self.kwargs)
mapper._compile_property(self.key, relation);
- elif not isinstance(mapper.props[self.key], PropertyLoader):
+ elif not isinstance(mapper.get_property(self.key), PropertyLoader):
raise exceptions.ArgumentError(
"Can't create backref '%s' on mapper '%s'; an incompatible "
"property of that name already exists" % (self.key, str(mapper)))
else:
# else set one of us as the "backreference"
parent = prop.parent.primary_mapper()
- if parent.class_ is not mapper.props[self.key]._get_target_class():
+ if parent.class_ is not mapper.get_property(self.key)._get_target_class():
raise exceptions.ArgumentError(
"Backrefs do not match: backref '%s' expects to connect to %s, "
"but found a backref already connected to %s" %
- (self.key, str(parent.class_), str(mapper.props[self.key].mapper.class_)))
- if not mapper.props[self.key].is_backref:
+ (self.key, str(parent.class_), str(mapper.get_property(self.key).mapper.class_)))
+ if not mapper.get_property(self.key).is_backref:
prop.is_backref=True
if not prop.viewonly:
prop._dependency_processor.is_backref=True
# reverse_property used by dependencies.ManyToManyDP to check
# association table operations
- prop.reverse_property = mapper.props[self.key]
- mapper.props[self.key].reverse_property = prop
+ prop.reverse_property = mapper.get_property(self.key)
+ mapper.get_property(self.key).reverse_property = prop
def get_extension(self):
"""Return an attribute extension to use with this backreference."""
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index d51fd75c3..284653b5c 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -4,89 +4,49 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import sql, util, exceptions, sql_util, logging, schema
-from sqlalchemy.orm import mapper, class_mapper, object_mapper
-from sqlalchemy.orm.interfaces import OperationContext, SynonymProperty
-import random
+from sqlalchemy import sql, util, exceptions, sql_util, logging
+from sqlalchemy.orm import mapper, object_mapper
+from sqlalchemy.orm import util as mapperutil
+from sqlalchemy.orm.interfaces import OperationContext, LoaderStack
+import operator
__all__ = ['Query', 'QueryContext', 'SelectionContext']
class Query(object):
- """Encapsulates the object-fetching operations provided by Mappers.
+ """Encapsulates the object-fetching operations provided by Mappers."""
- Note that this particular version of Query contains the 0.3 API as well as most of the
- 0.4 API for forwards compatibility. A large part of the API here is deprecated (but still present)
- in the 0.4 series.
- """
-
- def __init__(self, class_or_mapper, session=None, entity_name=None, lockmode=None, with_options=None, extension=None, **kwargs):
+ def __init__(self, class_or_mapper, session=None, entity_name=None):
if isinstance(class_or_mapper, type):
self.mapper = mapper.class_mapper(class_or_mapper, entity_name=entity_name)
else:
self.mapper = class_or_mapper.compile()
- self.with_options = with_options or []
self.select_mapper = self.mapper.get_select_mapper().compile()
- self.always_refresh = kwargs.pop('always_refresh', self.mapper.always_refresh)
- self.lockmode = lockmode
- self.extension = mapper._ExtensionCarrier()
- if extension is not None:
- self.extension.append(extension)
- self.extension.append(self.mapper.extension)
- self.is_polymorphic = self.mapper is not self.select_mapper
+
self._session = session
- if not hasattr(self.mapper, '_get_clause'):
- _get_clause = sql.and_()
- for primary_key in self.primary_key_columns:
- _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type=primary_key.type, unique=True))
- self.mapper._get_clause = _get_clause
+ self._with_options = []
+ self._lockmode = None
+ self._extension = self.mapper.extension.copy()
self._entities = []
- self._get_clause = self.mapper._get_clause
-
- self._order_by = kwargs.pop('order_by', False)
- self._group_by = kwargs.pop('group_by', False)
- self._distinct = kwargs.pop('distinct', False)
- self._offset = kwargs.pop('offset', None)
- self._limit = kwargs.pop('limit', None)
- self._criterion = None
+ self._order_by = False
+ self._group_by = False
+ self._distinct = False
+ self._offset = None
+ self._limit = None
+ self._statement = None
self._params = {}
- self._col = None
- self._func = None
+ self._criterion = None
+ self._column_aggregate = None
self._joinpoint = self.mapper
+ self._aliases = None
+ self._alias_ids = {}
self._from_obj = [self.table]
- self._statement = None
-
- for opt in util.flatten_iterator(self.with_options):
- opt.process_query(self)
+ self._populate_existing = False
+ self._version_check = False
def _clone(self):
- # yes, a little embarassing here.
- # go look at 0.4 for the simple version.
q = Query.__new__(Query)
- q.mapper = self.mapper
- q.select_mapper = self.select_mapper
- q._order_by = self._order_by
- q._distinct = self._distinct
- q._entities = list(self._entities)
- q.always_refresh = self.always_refresh
- q.with_options = list(self.with_options)
- q._session = self.session
- q.is_polymorphic = self.is_polymorphic
- q.lockmode = self.lockmode
- q.extension = mapper._ExtensionCarrier()
- for ext in self.extension:
- q.extension.append(ext)
- q._offset = self._offset
- q._limit = self._limit
- q._params = self._params
- q._group_by = self._group_by
- q._get_clause = self._get_clause
- q._from_obj = list(self._from_obj)
- q._joinpoint = self._joinpoint
- q._criterion = self._criterion
- q._statement = self._statement
- q._col = self._col
- q._func = self._func
+ q.__dict__ = self.__dict__.copy()
return q
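The 0.4 _clone collapses the field-by-field copy into a shallow __dict__ copy; every generative method clones, mutates only the clone, and returns it. A standalone sketch of the pattern:

    class Q(object):
        def __init__(self):
            self._criterion = None
        def _clone(self):
            # shallow copy: cheap, and safe as long as each generative
            # method copies any mutable attribute before mutating it
            q = Q.__new__(Q)
            q.__dict__ = self.__dict__.copy()
            return q
        def filter(self, crit):
            q = self._clone()
            if q._criterion is None:
                q._criterion = crit
            else:
                q._criterion = "%s AND %s" % (q._criterion, crit)
            return q

    base = Q()
    q2 = base.filter("x = 1").filter("y = 2")
    print(base._criterion)   # None  (original untouched)
    print(q2._criterion)     # x = 1 AND y = 2

Note that mutable state still needs a per-method copy, which is why params() below copies _params before updating it.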
def _get_session(self):
@@ -96,7 +56,7 @@ class Query(object):
return self._session
table = property(lambda s:s.select_mapper.mapped_table)
- primary_key_columns = property(lambda s:s.select_mapper.pks_by_table[s.select_mapper.mapped_table])
+ primary_key_columns = property(lambda s:s.select_mapper.primary_key)
session = property(_get_session)
def get(self, ident, **kwargs):
@@ -108,13 +68,20 @@ class Query(object):
columns.
"""
- ret = self.extension.get(self, ident, **kwargs)
+ ret = self._extension.get(self, ident, **kwargs)
if ret is not mapper.EXT_PASS:
return ret
- key = self.mapper.identity_key(ident)
+
+ # convert composite types to individual args
+ # TODO: account for the order of columns in the
+ # ColumnProperty it corresponds to
+ if hasattr(ident, '__colset__'):
+ ident = ident.__colset__()
+
+ key = self.mapper.identity_key_from_primary_key(ident)
return self._get(key, ident, **kwargs)
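Since composite identifiers now expand through __colset__(), both of these hypothetical calls resolve to the same identity key (Vertex/Point are illustrative, assuming a composite primary key mapping):

    # session.query(Vertex).get(Point(3, 4))   # composite value object
    # session.query(Vertex).get([3, 4])        # raw primary key values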
- def load(self, ident, **kwargs):
+ def load(self, ident, raiseerr=True, **kwargs):
"""Return an instance of the object based on the given
identifier.
@@ -125,304 +92,14 @@ class Query(object):
columns.
"""
- ret = self.extension.load(self, ident, **kwargs)
+ ret = self._extension.load(self, ident, **kwargs)
if ret is not mapper.EXT_PASS:
return ret
- key = self.mapper.identity_key(ident)
+ key = self.mapper.identity_key_from_primary_key(ident)
instance = self._get(key, ident, reload=True, **kwargs)
- if instance is None:
+ if instance is None and raiseerr:
raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident))
return instance
-
- def get_by(self, *args, **params):
- """Like ``select_by()``, but only return the first
- as a scalar, or None if no object found.
- Synonymous with ``selectfirst_by()``.
-
- The criterion is constructed in the same way as the
- ``select_by()`` method.
-
- this method is deprecated in 0.4.
- """
-
- ret = self.extension.get_by(self, *args, **params)
- if ret is not mapper.EXT_PASS:
- return ret
- x = self.select_whereclause(self.join_by(*args, **params), limit=1)
- if x:
- return x[0]
- else:
- return None
-
- def select_by(self, *args, **params):
- """Return an array of object instances based on the given
- clauses and key/value criterion.
-
- \*args
- a list of zero or more ``ClauseElements`` which will be
- connected by ``AND`` operators.
-
- \**params
- a set of zero or more key/value parameters which
- are converted into ``ClauseElements``. the keys are mapped to
- property or column names mapped by this mapper's Table, and
- the values are coerced into a ``WHERE`` clause separated by
- ``AND`` operators. If the local property/column names dont
- contain the key, a search will be performed against this
- mapper's immediate list of relations as well, forming the
- appropriate join conditions if a matching property is located.
-
- if the located property is a column-based property, the comparison
- value should be a scalar with an appropriate type. If the
- property is a relationship-bound property, the comparison value
- should be an instance of the related class.
-
- E.g.::
-
- result = usermapper.select_by(user_name = 'fred')
-
- this method is deprecated in 0.4.
- """
-
- ret = self.extension.select_by(self, *args, **params)
- if ret is not mapper.EXT_PASS:
- return ret
- return self.select_whereclause(self.join_by(*args, **params))
-
- def join_by(self, *args, **params):
- """Return a ``ClauseElement`` representing the ``WHERE``
- clause that would normally be sent to ``select_whereclause()``
- by ``select_by()``.
-
- The criterion is constructed in the same way as the
- ``select_by()`` method.
-
- this method is deprecated in 0.4.
- """
-
- return self._join_by(args, params)
-
-
- def join_to(self, key):
- """Given the key name of a property, will recursively descend
- through all child properties from this Query's mapper to
- locate the property, and will return a ClauseElement
- representing a join from this Query's mapper to the endmost
- mapper.
-
- this method is deprecated in 0.4.
- """
-
- [keys, p] = self._locate_prop(key)
- return self.join_via(keys)
-
- def join_via(self, keys):
- """Given a list of keys that represents a path from this
- Query's mapper to a related mapper based on names of relations
- from one mapper to the next, return a ClauseElement
- representing a join from this Query's mapper to the endmost
- mapper.
-
- this method is deprecated in 0.4.
- """
-
- mapper = self.mapper
- clause = None
- for key in keys:
- prop = mapper.get_property(key, resolve_synonyms=True)
- if clause is None:
- clause = prop.get_join(mapper)
- else:
- clause &= prop.get_join(mapper)
- mapper = prop.mapper
-
- return clause
-
- def selectfirst_by(self, *args, **params):
- """Like ``select_by()``, but only return the first
- as a scalar, or None if no object found.
- Synonymous with ``get_by()``.
-
- The criterion is constructed in the same way as the
- ``select_by()`` method.
-
- this method is deprecated in 0.4.
- """
-
- return self.get_by(*args, **params)
-
- def selectone_by(self, *args, **params):
- """Like ``selectfirst_by()``, but throws an error if not
- exactly one result was returned.
-
- The criterion is constructed in the same way as the
- ``select_by()`` method.
-
- this method is deprecated in 0.4.
- """
-
- ret = self.select_whereclause(self.join_by(*args, **params), limit=2)
- if len(ret) == 1:
- return ret[0]
- elif len(ret) == 0:
- raise exceptions.InvalidRequestError('No rows returned for selectone_by')
- else:
- raise exceptions.InvalidRequestError('Multiple rows returned for selectone_by')
-
- def count_by(self, *args, **params):
- """Return the count of instances based on the given clauses
- and key/value criterion.
-
- The criterion is constructed in the same way as the
- ``select_by()`` method.
-
- this method is deprecated in 0.4.
- """
-
- return self.count(self.join_by(*args, **params))
-
- def selectfirst(self, arg=None, **kwargs):
- """Query for a single instance using the given criterion.
-
- Arguments are the same as ``select()``. In the case that
- the given criterion represents ``WHERE`` criterion only,
- LIMIT 1 is applied to the fully generated statement.
-
- this method is deprecated in 0.4.
- """
-
- if isinstance(arg, sql.FromClause) and arg.supports_execution():
- ret = self.select_statement(arg, **kwargs)
- else:
- kwargs['limit'] = 1
- ret = self.select_whereclause(whereclause=arg, **kwargs)
- if ret:
- return ret[0]
- else:
- return None
-
- def selectone(self, arg=None, **kwargs):
- """Query for a single instance using the given criterion.
-
- Unlike ``selectfirst``, this method asserts that only one
- row exists. In the case that the given criterion represents
- ``WHERE`` criterion only, LIMIT 2 is applied to the fully
- generated statement.
-
- this method is deprecated in 0.4.
- """
-
- if isinstance(arg, sql.FromClause) and arg.supports_execution():
- ret = self.select_statement(arg, **kwargs)
- else:
- kwargs['limit'] = 2
- ret = self.select_whereclause(whereclause=arg, **kwargs)
- if len(ret) == 1:
- return ret[0]
- elif len(ret) == 0:
- raise exceptions.InvalidRequestError('No rows returned for selectone_by')
- else:
- raise exceptions.InvalidRequestError('Multiple rows returned for selectone')
-
- def select(self, arg=None, **kwargs):
- """Select instances of the object from the database.
-
- `arg` can be any ClauseElement, which will form the criterion
- with which to load the objects.
-
- For more advanced usage, arg can also be a Select statement
- object, which will be executed and its resulting rowset used
- to build new object instances.
-
- In this case, the developer must ensure that an adequate set
- of columns exists in the rowset with which to build new object
- instances.
-
- this method is deprecated in 0.4.
- """
-
- ret = self.extension.select(self, arg=arg, **kwargs)
- if ret is not mapper.EXT_PASS:
- return ret
- if isinstance(arg, sql.FromClause) and arg.supports_execution():
- return self.select_statement(arg, **kwargs)
- else:
- return self.select_whereclause(whereclause=arg, **kwargs)
-
- def select_whereclause(self, whereclause=None, params=None, **kwargs):
- """Given a ``WHERE`` criterion, create a ``SELECT`` statement,
- execute and return the resulting instances.
-
- this method is deprecated in 0.4.
- """
- statement = self.compile(whereclause, **kwargs)
- return self._select_statement(statement, params=params)
-
- def count(self, whereclause=None, params=None, **kwargs):
- """Given a ``WHERE`` criterion, create a ``SELECT COUNT``
- statement, execute and return the resulting count value.
-
- the additional arguments to this method are is deprecated in 0.4.
-
- """
- if self._criterion:
- if whereclause is not None:
- whereclause = sql.and_(self._criterion, whereclause)
- else:
- whereclause = self._criterion
- from_obj = kwargs.pop('from_obj', self._from_obj)
- kwargs.setdefault('distinct', self._distinct)
-
- alltables = []
- for l in [sql_util.TableFinder(x) for x in from_obj]:
- alltables += l
-
- if self.table not in alltables:
- from_obj.append(self.table)
- if self._nestable(**kwargs):
- s = sql.select([self.table], whereclause, from_obj=from_obj, **kwargs).alias('getcount').count()
- else:
- primary_key = self.primary_key_columns
- s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **kwargs)
- return self.session.scalar(self.mapper, s, params=params)
-
- def select_statement(self, statement, **params):
- """Given a ``ClauseElement``-based statement, execute and
- return the resulting instances.
-
- this method is deprecated in 0.4.
- """
-
- return self._select_statement(statement, params=params)
-
- def select_text(self, text, **params):
- """Given a literal string-based statement, execute and return
- the resulting instances.
-
- this method is deprecated in 0.4. use from_statement() instead.
- """
-
- t = sql.text(text)
- return self.execute(t, params=params)
-
- def _with_lazy_criterion(cls, instance, prop, reverse=False):
- """extract query criterion from a LazyLoader strategy given a Mapper,
- source persisted/detached instance and PropertyLoader.
-
- """
-
- from sqlalchemy.orm import strategies
- (criterion, lazybinds, rev) = strategies.LazyLoader._create_lazy_clause(prop, reverse_direction=reverse)
- bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
-
- class Visitor(sql.ClauseVisitor):
- def visit_bindparam(self, bindparam):
- mapper = reverse and prop.mapper or prop.parent
- bindparam.value = mapper.get_attr_by_column(instance, bind_to_col[bindparam.key])
- Visitor().traverse(criterion)
- return criterion
- _with_lazy_criterion = classmethod(_with_lazy_criterion)
-
def query_from_parent(cls, instance, property, **kwargs):
"""return a newly constructed Query object, with criterion corresponding to
@@ -445,9 +122,25 @@ class Query(object):
mapper = object_mapper(instance)
prop = mapper.get_property(property, resolve_synonyms=True)
target = prop.mapper
- criterion = cls._with_lazy_criterion(instance, prop)
+ criterion = prop.compare(operator.eq, instance, value_is_parent=True)
return Query(target, **kwargs).filter(criterion)
query_from_parent = classmethod(query_from_parent)
+
+ def populate_existing(self):
+ """return a Query that will refresh all instances loaded.
+
+ this includes all entities accessed from the database, including
+ secondary entities and eagerly-loaded collection items.
+
+ Any changes pending on entities already present in the session will
+ be discarded, and the entities will all be marked "clean".
+
+ This is essentially the en masse version of load().
+ """
+
+ q = self._clone()
+ q._populate_existing = True
+ return q
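A hedged usage sketch (hypothetical User mapping, active session):

    # for user in session.query(User).populate_existing():
    #     pass   # each User now reflects the current database row; pending
    #            # changes on already-loaded instances are discarded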
def with_parent(self, instance, property=None):
"""add a join criterion corresponding to a relationship to the given parent instance.
@@ -474,9 +167,9 @@ class Query(object):
raise exceptions.InvalidRequestError("Could not locate a property which relates instances of class '%s' to instances of class '%s'" % (self.mapper.class_.__name__, instance.__class__.__name__))
else:
prop = mapper.get_property(property, resolve_synonyms=True)
- return self.filter(Query._with_lazy_criterion(instance, prop))
+ return self.filter(prop.compare(operator.eq, instance, value_is_parent=True))
- def add_entity(self, entity):
+ def add_entity(self, entity, alias=None, id=None):
"""add a mapped entity to the list of result columns to be returned.
This will have the effect of all result-returning methods returning a tuple
@@ -492,12 +185,25 @@ class Query(object):
entity
a class or mapper which will be added to the results.
+ alias
+ a sqlalchemy.sql.Alias object which will be used to select rows. this
+ will match the usage of the given Alias in filter(), order_by(), etc. expressions
+
+ id
+ a string ID matching that given to query.join() or query.outerjoin(); rows will be
+ selected from the aliased join created via those methods.
"""
q = self._clone()
- q._entities.append(entity)
+
+ if isinstance(entity, type):
+ entity = mapper.class_mapper(entity)
+ if alias is not None:
+ alias = mapperutil.AliasedClauses(entity.mapped_table, alias=alias)
+
+ q._entities = q._entities + [(entity, alias, id)]
return q
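With the new alias and id arguments, a secondary entity can be selected from an aliased join. A hedged usage sketch (hypothetical User/Address mapping; adalias is an ordinary sqlalchemy.sql Alias):

    # adalias = addresses_table.alias('ad1')
    # q = session.query(User).add_entity(Address, alias=adalias)
    # q = q.filter(adalias.c.email_address.like('%@example.com'))
    # for user, address in q.all():
    #     ...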
- def add_column(self, column):
+ def add_column(self, column, id=None):
"""add a SQL ColumnElement to the list of result columns to be returned.
This will have the effect of all result-returning methods returning a tuple
@@ -517,51 +223,56 @@ class Query(object):
"""
q = self._clone()
-
+
# alias non-labeled column elements.
- # TODO: make the generation deterministic
if isinstance(column, sql.ColumnElement) and not hasattr(column, '_label'):
- column = column.label("anon_" + hex(random.randint(0, 65535))[2:])
-
- q._entities.append(column)
+ column = column.label(None)
+
+ q._entities = q._entities + [(column, None, id)]
return q
- def options(self, *args, **kwargs):
+ def options(self, *args):
"""Return a new Query object, applying the given list of
MapperOptions.
"""
+
q = self._clone()
- for opt in util.flatten_iterator(args):
- q.with_options.append(opt)
+ opts = [o for o in util.flatten_iterator(args)]
+ q._with_options = q._with_options + opts
+ for opt in opts:
opt.process_query(q)
return q
def with_lockmode(self, mode):
"""Return a new Query object with the specified locking mode."""
q = self._clone()
- q.lockmode = mode
+ q._lockmode = mode
return q
def params(self, **kwargs):
"""add values for bind parameters which may have been specified in filter()."""
-
+
q = self._clone()
q._params = q._params.copy()
q._params.update(kwargs)
return q
-
+
def filter(self, criterion):
"""apply the given filtering criterion to the query and return the newly resulting ``Query``
the criterion is any sql.ClauseElement applicable to the WHERE clause of a select.
"""
-
+
if isinstance(criterion, basestring):
criterion = sql.text(criterion)
-
+
if criterion is not None and not isinstance(criterion, sql.ClauseElement):
raise exceptions.ArgumentError("filter() argument must be of type sqlalchemy.sql.ClauseElement or string")
-
+
+
+ if self._aliases is not None:
+ criterion = self._aliases.adapt_clause(criterion)
+
q = self._clone()
if q._criterion is not None:
q._criterion = q._criterion & criterion
@@ -569,23 +280,36 @@ class Query(object):
q._criterion = criterion
return q
- def filter_by(self, *args, **kwargs):
- """apply the given filtering criterion to the query and return the newly resulting ``Query``
+ def filter_by(self, **kwargs):
+ """apply filtering criterion, given as keyword arguments compared against mapped properties, and return the newly resulting ``Query``."""
- The criterion is constructed in the same way as the
- ``select_by()`` method.
- """
- return self.filter(self._join_by(args, kwargs, start=self._joinpoint))
+
+ alias = None
+ join = None
+ clause = None
+ joinpoint = self._joinpoint
- def _join_to(self, prop, outerjoin=False, start=None):
- if start is None:
- start = self._joinpoint
+ for key, value in kwargs.iteritems():
+ prop = joinpoint.get_property(key, resolve_synonyms=True)
+ c = prop.compare(operator.eq, value)
- if isinstance(prop, list):
- keys = prop
+ if alias is not None:
+ sql_util.ClauseAdapter(alias).traverse(c)
+ if clause is None:
+ clause = c
+ else:
+ clause &= c
+
+ if join is not None:
+ return self.select_from(join).filter(clause)
else:
- [keys,p] = self._locate_prop(prop, start=start)
+ return self.filter(clause)
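filter_by() now folds each keyword into an equality comparison against the property located at the current joinpoint, ANDing the results together. A standalone sketch of that folding (FakeProp stands in for a mapped property; not SQLAlchemy API):

    import operator

    def fold_criteria(getprop, **kwargs):
        # AND together one comparison per keyword, as filter_by() does
        clause = None
        for key in sorted(kwargs):
            c = getprop(key).compare(operator.eq, kwargs[key])
            if clause is None:
                clause = c
            else:
                clause = "%s AND %s" % (clause, c)
        return clause

    class FakeProp(object):
        def __init__(self, col):
            self.col = col
        def compare(self, op, value):
            return "%s = %r" % (self.col, value)

    props = {'name': FakeProp('users.name'), 'id': FakeProp('users.id')}
    print(fold_criteria(props.get, id=5, name='ed'))
    # users.id = 5 AND users.name = 'ed'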
+ def _join_to(self, keys, outerjoin=False, start=None, create_aliases=True):
+ if start is None:
+ start = self._joinpoint
+
clause = self._from_obj[-1]
currenttables = [clause]
@@ -594,101 +318,56 @@ class Query(object):
currenttables.append(join.left)
currenttables.append(join.right)
FindJoinedTables().traverse(clause)
-
+
mapper = start
- for key in keys:
+ alias = self._aliases
+ for key in util.to_list(keys):
prop = mapper.get_property(key, resolve_synonyms=True)
- if prop._is_self_referential():
- raise exceptions.InvalidRequestError("Self-referential query on '%s' property must be constructed manually using an Alias object for the related table." % str(prop))
- # dont re-join to a table already in our from objects
- if prop.select_table not in currenttables:
- if outerjoin:
- if prop.secondary:
- clause = clause.outerjoin(prop.secondary, prop.get_join(mapper, primary=True, secondary=False))
- clause = clause.outerjoin(prop.select_table, prop.get_join(mapper, primary=False))
+ if prop._is_self_referential() and not create_aliases:
+ raise exceptions.InvalidRequestError("Self-referential query on '%s' property requires create_aliases=True argument." % str(prop))
+
+ if prop.select_table not in currenttables or create_aliases:
+ if prop.secondary:
+ if create_aliases:
+ alias = mapperutil.PropertyAliasedClauses(prop,
+ prop.get_join(mapper, primary=True, secondary=False),
+ prop.get_join(mapper, primary=False, secondary=True),
+ alias
+ )
+ clause = clause.join(alias.secondary, alias.primaryjoin, isouter=outerjoin).join(alias.alias, alias.secondaryjoin, isouter=outerjoin)
else:
- clause = clause.outerjoin(prop.select_table, prop.get_join(mapper))
+ clause = clause.join(prop.secondary, prop.get_join(mapper, primary=True, secondary=False), isouter=outerjoin)
+ clause = clause.join(prop.select_table, prop.get_join(mapper, primary=False), isouter=outerjoin)
else:
- if prop.secondary:
- clause = clause.join(prop.secondary, prop.get_join(mapper, primary=True, secondary=False))
- clause = clause.join(prop.select_table, prop.get_join(mapper, primary=False))
+ if create_aliases:
+ alias = mapperutil.PropertyAliasedClauses(prop,
+ prop.get_join(mapper, primary=True, secondary=False),
+ None,
+ alias
+ )
+ clause = clause.join(alias.alias, alias.primaryjoin, isouter=outerjoin)
else:
- clause = clause.join(prop.select_table, prop.get_join(mapper))
- elif prop.secondary is not None and prop.secondary not in currenttables:
+ clause = clause.join(prop.select_table, prop.get_join(mapper), isouter=outerjoin)
+ elif not create_aliases and prop.secondary is not None and prop.secondary not in currenttables:
# TODO: this check is not strong enough for different paths to the same endpoint that
# do not use secondary tables
- raise exceptions.InvalidRequestError("Can't join to property '%s'; a path to this table along a different secondary table already exists. Use explicit `Alias` objects." % prop.key)
+ raise exceptions.InvalidRequestError("Can't join to property '%s'; a path to this table along a different secondary table already exists. Use the `alias=True` argument to `join()`." % prop.key)
mapper = prop.mapper
- return (clause, mapper)
-
- def _join_by(self, args, params, start=None):
- """Return a ``ClauseElement`` representing the ``WHERE``
- clause that would normally be sent to ``select_whereclause()``
- by ``select_by()``.
-
- The criterion is constructed in the same way as the
- ``select_by()`` method.
- """
- import properties
-
- clause = None
- for arg in args:
- if clause is None:
- clause = arg
- else:
- clause &= arg
-
- for key, value in params.iteritems():
- (keys, prop) = self._locate_prop(key, start=start)
- if isinstance(prop, properties.PropertyLoader):
- c = self._with_lazy_criterion(value, prop, True) & self.join_via(keys[:-1])
- else:
- c = prop.compare(value) & self.join_via(keys)
- if clause is None:
- clause = c
- else:
- clause &= c
- return clause
-
- def _locate_prop(self, key, start=None):
- import properties
- keys = []
- seen = util.Set()
- def search_for_prop(mapper_):
- if mapper_ in seen:
- return None
- seen.add(mapper_)
- if mapper_.props.has_key(key):
- prop = mapper_.get_property(key, resolve_synonyms=True)
- if isinstance(prop, properties.PropertyLoader):
- keys.insert(0, prop.key)
- return prop
- else:
- for prop in mapper_.iterate_properties:
- if not isinstance(prop, properties.PropertyLoader):
- continue
- x = search_for_prop(prop.mapper)
- if x:
- keys.insert(0, prop.key)
- return x
- else:
- return None
- p = search_for_prop(start or self.mapper)
- if p is None:
- raise exceptions.InvalidRequestError("Can't locate property named '%s'" % key)
- return [keys, p]
+ if create_aliases:
+ return (clause, mapper, alias)
+ else:
+ return (clause, mapper, None)
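The aliased-join machinery above wraps each relationship hop in fresh alias clauses when aliased=True (and requires it for self-referential properties); filter() then adapts criterion to the current alias via adapt_clause(). A hedged usage sketch against a hypothetical self-referential Node mapping:

    # q = session.query(Node).join('children', aliased=True)
    # q = q.filter(Node.data == 'child node')   # criterion adapted to the alias
    # q = q.join('children', aliased=True, from_joinpoint=True)
    # q = q.filter(Node.data == 'grandchild node')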
def _generative_col_aggregate(self, col, func):
"""apply the given aggregate function to the query and return the newly
resulting ``Query``.
"""
- if self._col is not None or self._func is not None:
+ if self._column_aggregate is not None:
raise exceptions.InvalidRequestError("Query already contains an aggregate column or function")
q = self._clone()
- q._col = col
- q._func = func
+ q._column_aggregate = (col, func)
return q
def apply_min(self, col):
@@ -721,13 +400,13 @@ class Query(object):
For performance, only use subselect if `order_by` attribute is set.
"""
- ops = {'distinct':self._distinct, 'order_by':self._order_by, 'from_obj':self._from_obj}
+ ops = {'distinct':self._distinct, 'order_by':self._order_by or None, 'from_obj':self._from_obj}
if self._order_by is not False:
s1 = sql.select([col], self._criterion, **ops).alias('u')
- return sql.select([func(s1.corresponding_column(col))]).scalar()
+ return self.session.execute(sql.select([func(s1.corresponding_column(col))]), mapper=self.mapper).scalar()
else:
- return sql.select([func(col)], self._criterion, **ops).scalar()
+ return self.session.execute(sql.select([func(col)], self._criterion, **ops), mapper=self.mapper).scalar()
def min(self, col):
"""Execute the SQL ``min()`` function against the given column."""
@@ -756,7 +435,7 @@ class Query(object):
if q._order_by is False:
q._order_by = util.to_list(criterion)
else:
- q._order_by.extend(util.to_list(criterion))
+ q._order_by = q._order_by + util.to_list(criterion)
return q
def group_by(self, criterion):
@@ -766,52 +445,62 @@ class Query(object):
if q._group_by is False:
q._group_by = util.to_list(criterion)
else:
- q._group_by.extend(util.to_list(criterion))
+ q._group_by = q._group_by + util.to_list(criterion)
return q
-
- def join(self, prop):
+
+ def join(self, prop, id=None, aliased=False, from_joinpoint=False):
"""create a join of this ``Query`` object's criterion
to a relationship and return the newly resulting ``Query``.
-
- 'prop' may be a string property name in which it is located
- in the same manner as keyword arguments in ``select_by``, or
- it may be a list of strings in which case the property is located
- by direct traversal of each keyname (i.e. like join_via()).
+
+ 'prop' may be a string property name or a list of string
+ property names.
"""
-
- q = self._clone()
- (clause, mapper) = self._join_to(prop, outerjoin=False)
- q._from_obj = [clause]
- q._joinpoint = mapper
- return q
- def outerjoin(self, prop):
+ return self._join(prop, id=id, outerjoin=False, aliased=aliased, from_joinpoint=from_joinpoint)
+
+ def outerjoin(self, prop, id=None, aliased=False, from_joinpoint=False):
"""create a left outer join of this ``Query`` object's criterion
to a relationship and return the newly resulting ``Query``.
- 'prop' may be a string property name in which it is located
- in the same manner as keyword arguments in ``select_by``, or
- it may be a list of strings in which case the property is located
- by direct traversal of each keyname (i.e. like join_via()).
+ 'prop' may be a string property name or a list of string
+ property names.
"""
+
+ return self._join(prop, id=id, outerjoin=True, aliased=aliased, from_joinpoint=from_joinpoint)
+
+ def _join(self, prop, id, outerjoin, aliased, from_joinpoint):
+ (clause, mapper, aliases) = self._join_to(prop, outerjoin=outerjoin, start=from_joinpoint and self._joinpoint or self.mapper, create_aliases=aliased)
q = self._clone()
- (clause, mapper) = self._join_to(prop, outerjoin=True)
q._from_obj = [clause]
q._joinpoint = mapper
+ q._aliases = aliases
+
+ a = aliases
+ while a is not None:
+ q._alias_ids.setdefault(a.mapper, []).append(a)
+ q._alias_ids.setdefault(a.table, []).append(a)
+ q._alias_ids.setdefault(a.alias, []).append(a)
+ a = a.parentclauses
+
+ if id:
+ q._alias_ids[id] = aliases
return q
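# --- Editorial sketch (not part of the commit): intended use of the new
# join() signature.  "User" and the 'addresses' relation are assumed to be
# mapped elsewhere; aliased=True joins against an alias of the target
# table, and id='adr' tags that alias for later reference.
q = session.query(User).join('addresses', id='adr', aliased=True)
q = q.reset_joinpoint()   # later joins start again from the User mapper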
def reset_joinpoint(self):
"""return a new Query reset the 'joinpoint' of this Query reset
back to the starting mapper. Subsequent generative calls will
be constructed from the new joinpoint.
-
- This is an interim method which will not be needed with new behavior
- to be released in 0.4."""
-
+
+ Note that each call to join() or outerjoin() also starts from
+ the root.
+ """
+
q = self._clone()
q._joinpoint = q.mapper
+ q._aliases = None
return q
+
def select_from(self, from_obj):
"""Set the `from_obj` parameter of the query and return the newly
resulting ``Query``.
@@ -823,20 +512,6 @@ class Query(object):
new._from_obj = list(new._from_obj) + util.to_list(from_obj)
return new
- def __getattr__(self, key):
- if (key.startswith('select_by_')):
- key = key[10:]
- def foo(arg):
- return self.select_by(**{key:arg})
- return foo
- elif (key.startswith('get_by_')):
- key = key[7:]
- def foo(arg):
- return self.get_by(**{key:arg})
- return foo
- else:
- raise AttributeError(key)
-
def __getitem__(self, item):
if isinstance(item, slice):
start = item.start
@@ -884,99 +559,65 @@ class Query(object):
new._distinct = True
return new
- def list(self):
+ def all(self):
"""Return the results represented by this ``Query`` as a list.
This results in an execution of the underlying query.
-
- this method is deprecated in 0.4. use all() instead.
"""
-
return list(self)
-
- def one(self):
- """Return the first result of this ``Query``, raising an exception if more than one row exists.
-
- This results in an execution of the underlying query.
- this method is for forwards-compatibility with 0.4.
- """
-
- if self._col is None or self._func is None:
- ret = list(self[0:2])
-
- if len(ret) == 1:
- return ret[0]
- elif len(ret) == 0:
- raise exceptions.InvalidRequestError('No rows returned for one()')
- else:
- raise exceptions.InvalidRequestError('Multiple rows returned for one()')
- else:
- return self._col_aggregate(self._col, self._func)
-
+
+ def from_statement(self, statement):
+ if isinstance(statement, basestring):
+ statement = sql.text(statement)
+ q = self._clone()
+ q._statement = statement
+ return q
+
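# --- Editorial sketch: from_statement() accepts either a literal SQL
# string or a constructed statement; 'users' is an assumed mapped table.
q = session.query(User).from_statement(
    "SELECT * FROM users WHERE users.name = :name").params(name='ed')
users = q.all()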
def first(self):
"""Return the first result of this ``Query``.
This results in an execution of the underlying query.
-
- this method is for forwards-compatibility with 0.4.
"""
- if self._col is None or self._func is None:
- ret = list(self[0:1])
- if len(ret) > 0:
- return ret[0]
- else:
- return None
+ if self._column_aggregate is not None:
+ return self._col_aggregate(*self._column_aggregate)
+
+ ret = list(self[0:1])
+ if len(ret) > 0:
+ return ret[0]
else:
- return self._col_aggregate(self._col, self._func)
+ return None
- def all(self):
- """Return the results represented by this ``Query`` as a list.
+ def one(self):
+ """Return the first result of this ``Query``, raising an exception if more than one row exists.
This results in an execution of the underlying query.
"""
- return self.list()
-
- def from_statement(self, statement):
- """execute a full select() statement, or literal textual string as a SELECT statement.
-
- this method is for forwards compatibility with 0.4.
- """
- if isinstance(statement, basestring):
- statement = sql.text(statement)
- q = self._clone()
- q._statement = statement
- return q
- def scalar(self):
- """Return the first result of this ``Query``.
+ if self._column_aggregate is not None:
+ return self._col_aggregate(*self._column_aggregate)
- This results in an execution of the underlying query.
-
- this method will be deprecated in 0.4; first() is added for
- forwards-compatibility.
- """
+ ret = list(self[0:2])
- return self.first()
+ if len(ret) == 1:
+ return ret[0]
+ elif len(ret) == 0:
+ raise exceptions.InvalidRequestError('No rows returned for one()')
+ else:
+ raise exceptions.InvalidRequestError('Multiple rows returned for one()')
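# --- Editorial sketch of the retrieval contract implemented above:
# all() returns a list, first() returns an instance or None, and one()
# raises InvalidRequestError for zero rows as well as for multiple rows.
q = session.query(User).filter_by(name='ed')   # User assumed mapped
users = q.all()      # [] when nothing matches
user = q.first()     # None when nothing matches
user = q.one()       # raises unless exactly one row matches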
def __iter__(self):
- return iter(self.select_whereclause())
-
- def execute(self, clauseelement, params=None, *args, **kwargs):
- """Execute the given ClauseElement-based statement against
- this Query's session/mapper, return the resulting list of
- instances.
-
- this method is deprecated in 0.4. Use from_statement() instead.
- """
-
- p = self._params
- if params is not None:
- p.update(params)
- result = self.session.execute(self.mapper, clauseelement, params=p)
+ statement = self.compile()
+ statement.use_labels = True
+ if self.session.autoflush:
+ self.session.flush()
+ return self._execute_and_instances(statement)
+
+ def _execute_and_instances(self, statement):
+ result = self.session.execute(statement, params=self._params, mapper=self.mapper)
try:
- return self.instances(result, **kwargs)
+ return iter(self.instances(result))
finally:
result.close()
@@ -984,50 +625,47 @@ class Query(object):
"""Return a list of mapped instances corresponding to the rows
in a given *cursor* (i.e. ``ResultProxy``).
- \*mappers_or_columns is an optional list containing one or more of
- classes, mappers, strings or sql.ColumnElements which will be
- applied to each row and added horizontally to the result set,
- which becomes a list of tuples. The first element in each tuple
- is the usual result based on the mapper represented by this
- ``Query``. Each additional element in the tuple corresponds to an
- entry in the \*mappers_or_columns list.
-
- For each element in \*mappers_or_columns, if the element is
- a mapper or mapped class, an additional class instance will be
- present in the tuple. If the element is a string or sql.ColumnElement,
- the corresponding result column from each row will be present in the tuple.
-
- Note that when \*mappers_or_columns is present, "uniquing" for the result set
- is *disabled*, so that the resulting tuples contain entities as they actually
- correspond. this indicates that multiple results may be present if this
- option is used.
+ The \*mappers_or_columns and \**kwargs arguments are deprecated.
+ To add instances or columns to the results, use add_entity()
+ and add_column().
"""
self.__log_debug("instances()")
session = self.session
- context = SelectionContext(self.select_mapper, session, self.extension, with_options=self.with_options, **kwargs)
+ kwargs.setdefault('populate_existing', self._populate_existing)
+ kwargs.setdefault('version_check', self._version_check)
+
+ context = SelectionContext(self.select_mapper, session, self._extension, with_options=self._with_options, **kwargs)
process = []
mappers_or_columns = tuple(self._entities) + mappers_or_columns
if mappers_or_columns:
- for m in mappers_or_columns:
+ for tup in mappers_or_columns:
+ if isinstance(tup, tuple):
+ (m, alias, alias_id) = tup
+ clauses = self._get_entity_clauses(tup)
+ else:
+ clauses = alias = alias_id = None
+ m = tup
if isinstance(m, type):
m = mapper.class_mapper(m)
if isinstance(m, mapper.Mapper):
def x(m):
+ row_adapter = clauses is not None and clauses.row_decorator or (lambda row: row)
appender = []
def proc(context, row):
- if not m._instance(context, row, appender):
+ if not m._instance(context, row_adapter(row), appender):
appender.append(None)
process.append((proc, appender))
x(m)
- elif isinstance(m, sql.ColumnElement) or isinstance(m, basestring):
+ elif isinstance(m, (sql.ColumnElement, basestring)):
def y(m):
+ row_adapter = clauses is not None and clauses.row_decorator or (lambda row: row)
res = []
def proc(context, row):
- res.append(row[m])
+ res.append(row_adapter(row)[m])
process.append((proc, res))
y(m)
result = []
@@ -1039,9 +677,12 @@ class Query(object):
for proc in process:
proc[0](context, row)
+ for instance in context.identity_map.values():
+ context.attributes.get(('populating_mapper', instance), object_mapper(instance))._post_instance(context, instance)
+
# store new stuff in the identity map
- for value in context.identity_map.values():
- session._register_persistent(value)
+ for instance in context.identity_map.values():
+ session._register_persistent(instance)
if mappers_or_columns:
return list(util.OrderedSet(zip(*([result] + [o[1] for o in process]))))
@@ -1050,8 +691,8 @@ class Query(object):
def _get(self, key, ident=None, reload=False, lockmode=None):
- lockmode = lockmode or self.lockmode
- if not reload and not self.always_refresh and lockmode is None:
+ lockmode = lockmode or self._lockmode
+ if not reload and not self.mapper.always_refresh and lockmode is None:
try:
return self.session._get(key)
except KeyError:
@@ -1062,21 +703,22 @@ class Query(object):
else:
ident = util.to_list(ident)
params = {}
- try:
- for i, primary_key in enumerate(self.primary_key_columns):
+
+ for i, primary_key in enumerate(self.primary_key_columns):
+ try:
params[primary_key._label] = ident[i]
- except IndexError:
- raise exceptions.InvalidRequestError("Could not find enough values to formulate primary key for query.get(); primary key columns are %s" % ', '.join(["'%s'" % str(c) for c in self.primary_key_columns]))
+ except IndexError:
+ raise exceptions.InvalidRequestError("Could not find enough values to formulate primary key for query.get(); primary key columns are %s" % ', '.join(["'%s'" % str(c) for c in self.primary_key_columns]))
try:
- statement = self.compile(self._get_clause, lockmode=lockmode)
- return self._select_statement(statement, params=params, populate_existing=reload, version_check=(lockmode is not None))[0]
+ q = self
+ if lockmode is not None:
+ q = q.with_lockmode(lockmode)
+ q = q.filter(self.select_mapper._get_clause)
+ q = q.params(**params)._select_context_options(populate_existing=reload, version_check=(lockmode is not None))
+ return q.first()
except IndexError:
return None
- def _select_statement(self, statement, params=None, **kwargs):
- statement.use_labels = True
- return self.execute(statement, params=params, **kwargs)
-
def _should_nest(self, querycontext):
"""Return True if the given statement options indicate that we
should *nest* the generated query as a subquery inside of a
@@ -1094,21 +736,56 @@ class Query(object):
return (kwargs.get('limit') is not None or kwargs.get('offset') is not None or kwargs.get('distinct', False))
- def compile(self, whereclause = None, **kwargs):
- """Given a WHERE criterion, produce a ClauseElement-based
- statement suitable for usage in the execute() method.
+ def count(self, whereclause=None, params=None, **kwargs):
+ """Apply this query's criterion to a SELECT COUNT statement.
+
+ the whereclause, params and \**kwargs arguments are deprecated. use filter()
+ and other generative methods to establish modifiers.
+ """
+
+ q = self
+ if whereclause is not None:
+ q = q.filter(whereclause)
+ if params is not None:
+ q = q.params(**params)
+ q = q._legacy_select_kwargs(**kwargs)
+ return q._count()
- the arguments to this function are deprecated and are removed in version 0.4.
+ def _count(self):
+ """Apply this query's criterion to a SELECT COUNT statement.
+
+ this is the purely generative version which will become
+ the public method in version 0.5.
"""
+ whereclause = self._criterion
+
+ context = QueryContext(self)
+ from_obj = context.from_obj
+
+ alltables = []
+ for l in [sql_util.TableFinder(x) for x in from_obj]:
+ alltables += l
+
+ if self.table not in alltables:
+ from_obj.append(self.table)
+ if self._nestable(**context.select_args()):
+ s = sql.select([self.table], whereclause, from_obj=from_obj, **context.select_args()).alias('getcount').count()
+ else:
+ primary_key = self.primary_key_columns
+ s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **context.select_args())
+ return self.session.scalar(s, params=self._params, mapper=self.mapper)
+
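# --- Editorial sketch: the generative spelling that replaces the
# deprecated keyword arguments to count(); users_table is assumed.
n = session.query(User).filter(users_table.c.name.like('e%')).count()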
+ def compile(self):
+ """compiles and returns a SQL statement based on the criterion and conditions within this Query."""
+
if self._statement:
self._statement.use_labels = True
return self._statement
+
+ whereclause = self._criterion
- if self._criterion:
- whereclause = sql.and_(self._criterion, whereclause)
-
- if whereclause is not None and self.is_polymorphic:
+ if whereclause is not None and (self.mapper is not self.select_mapper):
# adapt the given WHERECLAUSE to adjust instances of this query's mapped
# table to be that of our select_table,
# which may be the "polymorphic" selectable used by our mapper.
@@ -1124,16 +801,10 @@ class Query(object):
# get/create query context. get the ultimate compile arguments
# from there
- context = kwargs.pop('query_context', None)
- if context is None:
- context = QueryContext(self, kwargs)
+ context = QueryContext(self)
order_by = context.order_by
- group_by = context.group_by
from_obj = context.from_obj
lockmode = context.lockmode
- distinct = context.distinct
- limit = context.limit
- offset = context.offset
if order_by is False:
order_by = self.mapper.order_by
if order_by is False:
@@ -1161,31 +832,33 @@ class Query(object):
            # if there's an order by, add those columns to the column list
# of the "rowcount" query we're going to make
if order_by:
- order_by = util.to_list(order_by) or []
+ order_by = [sql._literal_as_text(o) for o in util.to_list(order_by) or []]
cf = sql_util.ColumnFinder()
for o in order_by:
cf.traverse(o)
else:
cf = []
- s2 = sql.select(self.table.primary_key + list(cf), whereclause, use_labels=True, from_obj=from_obj, **context.select_args())
+ s2 = sql.select(self.primary_key_columns + list(cf), whereclause, use_labels=True, from_obj=from_obj, correlate=False, **context.select_args())
if order_by:
- s2.order_by(*util.to_list(order_by))
+ s2 = s2.order_by(*util.to_list(order_by))
s3 = s2.alias('tbl_row_count')
- crit = s3.primary_key==self.table.primary_key
+ crit = s3.primary_key==self.primary_key_columns
statement = sql.select([], crit, use_labels=True, for_update=for_update)
# now for the order by, convert the columns to their corresponding columns
# in the "rowcount" query, and tack that new order by onto the "rowcount" query
if order_by:
- statement.order_by(*sql_util.ClauseAdapter(s3).copy_and_process(order_by))
+ statement.append_order_by(*sql_util.ClauseAdapter(s3).copy_and_process(order_by))
else:
statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **context.select_args())
if order_by:
- statement.order_by(*util.to_list(order_by))
+ statement.append_order_by(*util.to_list(order_by))
+
        # for a DISTINCT query, any columns referenced in "order_by" must be
        # explicitly present in the columns clause. ensure they are in the column criterion (particularly oid).
# TODO: this should be done at the SQL level not the mapper level
- if kwargs.get('distinct', False) and order_by:
+ # TODO: need test coverage for this
+ if context.distinct and order_by:
[statement.append_column(c) for c in util.to_list(order_by)]
context.statement = statement
@@ -1197,20 +870,268 @@ class Query(object):
value.setup(context)
# additional entities/columns, add those to selection criterion
- for m in self._entities:
- if isinstance(m, type):
- m = mapper.class_mapper(m)
+ for tup in self._entities:
+ (m, alias, alias_id) = tup
+ clauses = self._get_entity_clauses(tup)
if isinstance(m, mapper.Mapper):
for value in m.iterate_properties:
- value.setup(context)
+ value.setup(context, parentclauses=clauses)
elif isinstance(m, sql.ColumnElement):
+ if clauses is not None:
+ m = clauses.adapt_clause(m)
statement.append_column(m)
return statement
+ def _get_entity_clauses(self, m):
+ """for tuples added via add_entity() or add_column(), attempt to locate
+ an AliasedClauses object which should be used to formulate the query as well
+ as to process result rows."""
+ (m, alias, alias_id) = m
+ if alias is not None:
+ return alias
+ if alias_id is not None:
+ try:
+ return self._alias_ids[alias_id]
+ except KeyError:
+ raise exceptions.InvalidRequestError("Query has no alias identified by '%s'" % alias_id)
+ if isinstance(m, type):
+ m = mapper.class_mapper(m)
+ if isinstance(m, mapper.Mapper):
+ l = self._alias_ids.get(m)
+ if l:
+ if len(l) > 1:
+ raise exceptions.InvalidRequestError("Ambiguous join for entity '%s'; specify id=<someid> to query.join()/query.add_entity()" % str(m))
+ else:
+ return l[0]
+ else:
+ return None
+ elif isinstance(m, sql.ColumnElement):
+ aliases = []
+ for table in sql_util.TableFinder(m, check_columns=True):
+ for a in self._alias_ids.get(table, []):
+ aliases.append(a)
+ if len(aliases) > 1:
+ raise exceptions.InvalidRequestError("Ambiguous join for entity '%s'; specify id=<someid> to query.join()/query.add_column()" % str(m))
+ elif len(aliases) == 1:
+ return aliases[0]
+ else:
+ return None
+
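# --- Editorial sketch: when the same mapper participates in more than one
# aliased join, the lookup above is ambiguous unless the join and the
# added entity share an explicit id.  Names are assumed mapped elsewhere.
q = session.query(User).join('addresses', id='adr', aliased=True)
q = q.add_entity(Address, id='adr')   # resolved via _alias_ids['adr']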
def __log_debug(self, msg):
self.logger.debug(msg)
+ def __str__(self):
+ return str(self.compile())
+
+ # DEPRECATED LAND !
+
+ def list(self):
+ """DEPRECATED. use all()"""
+
+ return list(self)
+
+ def scalar(self):
+ """DEPRECATED. use first()"""
+
+ return self.first()
+
+ def _legacy_filter_by(self, *args, **kwargs):
+ return self.filter(self._legacy_join_by(args, kwargs, start=self._joinpoint))
+
+ def count_by(self, *args, **params):
+ """DEPRECATED. use query.filter_by(\**params).count()"""
+
+ return self.count(self.join_by(*args, **params))
+
+
+ def select_whereclause(self, whereclause=None, params=None, **kwargs):
+ """DEPRECATED. use query.filter(whereclause).all()"""
+
+ q = self.filter(whereclause)._legacy_select_kwargs(**kwargs)
+ if params is not None:
+ q = q.params(**params)
+ return list(q)
+
+ def _legacy_select_kwargs(self, **kwargs):
+ q = self
+ if "order_by" in kwargs and kwargs['order_by']:
+ q = q.order_by(kwargs['order_by'])
+ if "group_by" in kwargs:
+ q = q.group_by(kwargs['group_by'])
+ if "from_obj" in kwargs:
+ q = q.select_from(kwargs['from_obj'])
+ if "lockmode" in kwargs:
+ q = q.with_lockmode(kwargs['lockmode'])
+ if "distinct" in kwargs:
+ q = q.distinct()
+ if "limit" in kwargs:
+ q = q.limit(kwargs['limit'])
+ if "offset" in kwargs:
+ q = q.offset(kwargs['offset'])
+ return q
+
+
+ def get_by(self, *args, **params):
+ """DEPRECATED. use query.filter_by(\**params).first()"""
+
+ ret = self._extension.get_by(self, *args, **params)
+ if ret is not mapper.EXT_PASS:
+ return ret
+
+ return self._legacy_filter_by(*args, **params).first()
+
+ def select_by(self, *args, **params):
+ """DEPRECATED. use use query.filter_by(\**params).all()."""
+
+ ret = self._extension.select_by(self, *args, **params)
+ if ret is not mapper.EXT_PASS:
+ return ret
+
+ return self._legacy_filter_by(*args, **params).list()
+
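# --- Editorial summary of the 0.3 -> 0.4 spellings handled by the
# deprecation shims in this block:
#   query.get_by(name='ed')     ->  query.filter_by(name='ed').first()
#   query.select_by(name='ed')  ->  query.filter_by(name='ed').all()
#   query.select(whereclause)   ->  query.filter(whereclause).all()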
+ def join_by(self, *args, **params):
+ """DEPRECATED. use join() to construct joins based on attribute names."""
+
+ return self._legacy_join_by(args, params, start=self._joinpoint)
+
+ def _build_select(self, arg=None, params=None, **kwargs):
+ if isinstance(arg, sql.FromClause) and arg.supports_execution():
+ return self.from_statement(arg)
+ else:
+ return self.filter(arg)._legacy_select_kwargs(**kwargs)
+
+ def selectfirst(self, arg=None, **kwargs):
+ """DEPRECATED. use query.filter(whereclause).first()"""
+
+ return self._build_select(arg, **kwargs).first()
+
+ def selectone(self, arg=None, **kwargs):
+ """DEPRECATED. use query.filter(whereclause).one()"""
+
+ return self._build_select(arg, **kwargs).one()
+
+ def select(self, arg=None, **kwargs):
+ """DEPRECATED. use query.filter(whereclause).all(), or query.from_statement(statement).all()"""
+
+ ret = self._extension.select(self, arg=arg, **kwargs)
+ if ret is not mapper.EXT_PASS:
+ return ret
+ return self._build_select(arg, **kwargs).all()
+
+ def execute(self, clauseelement, params=None, *args, **kwargs):
+ """DEPRECATED. use query.from_statement().all()"""
+
+ return self._select_statement(clauseelement, params, **kwargs)
+
+ def select_statement(self, statement, **params):
+ """DEPRECATED. Use query.from_statement(statement)"""
+
+ return self._select_statement(statement, params)
+
+ def select_text(self, text, **params):
+ """DEPRECATED. Use query.from_statement(statement)"""
+
+ return self._select_statement(text, params)
+
+ def _select_statement(self, statement, params=None, **kwargs):
+ q = self.from_statement(statement)
+ if params is not None:
+ q = q.params(**params)
+ q._select_context_options(**kwargs)
+ return list(q)
+
+ def _select_context_options(self, populate_existing=None, version_check=None):
+ if populate_existing is not None:
+ self._populate_existing = populate_existing
+ if version_check is not None:
+ self._version_check = version_check
+ return self
+
+ def join_to(self, key):
+ """DEPRECATED. use join() to create joins based on property names."""
+
+ [keys, p] = self._locate_prop(key)
+ return self.join_via(keys)
+
+ def join_via(self, keys):
+ """DEPRECATED. use join() to create joins based on property names."""
+
+ mapper = self._joinpoint
+ clause = None
+ for key in keys:
+ prop = mapper.get_property(key, resolve_synonyms=True)
+ if clause is None:
+ clause = prop.get_join(mapper)
+ else:
+ clause &= prop.get_join(mapper)
+ mapper = prop.mapper
+
+ return clause
+
+ def _legacy_join_by(self, args, params, start=None):
+ import properties
+
+ clause = None
+ for arg in args:
+ if clause is None:
+ clause = arg
+ else:
+ clause &= arg
+
+ for key, value in params.iteritems():
+ (keys, prop) = self._locate_prop(key, start=start)
+ if isinstance(prop, properties.PropertyLoader):
+ c = prop.compare(operator.eq, value) & self.join_via(keys[:-1])
+ else:
+ c = prop.compare(operator.eq, value) & self.join_via(keys)
+ if clause is None:
+ clause = c
+ else:
+ clause &= c
+ return clause
+
+ def _locate_prop(self, key, start=None):
+ import properties
+ keys = []
+ seen = util.Set()
+ def search_for_prop(mapper_):
+ if mapper_ in seen:
+ return None
+ seen.add(mapper_)
+
+ prop = mapper_.get_property(key, resolve_synonyms=True, raiseerr=False)
+ if prop is not None:
+ if isinstance(prop, properties.PropertyLoader):
+ keys.insert(0, prop.key)
+ return prop
+ else:
+ for prop in mapper_.iterate_properties:
+ if not isinstance(prop, properties.PropertyLoader):
+ continue
+ x = search_for_prop(prop.mapper)
+ if x:
+ keys.insert(0, prop.key)
+ return x
+ else:
+ return None
+ p = search_for_prop(start or self.mapper)
+ if p is None:
+ raise exceptions.InvalidRequestError("Can't locate property named '%s'" % key)
+ return [keys, p]
+
+ def selectfirst_by(self, *args, **params):
+ """DEPRECATED. Use query.filter_by(\**kwargs).first()"""
+
+ return self._legacy_filter_by(*args, **params).first()
+
+ def selectone_by(self, *args, **params):
+ """DEPRECATED. Use query.filter_by(\**kwargs).one()"""
+
+ return self._legacy_filter_by(*args, **params).one()
+
+
+
Query.logger = logging.class_logger(Query)
class QueryContext(OperationContext):
@@ -1219,25 +1140,25 @@ class QueryContext(OperationContext):
in a query construction.
"""
- def __init__(self, query, kwargs):
+ def __init__(self, query):
self.query = query
- self.order_by = kwargs.pop('order_by', query._order_by)
- self.group_by = kwargs.pop('group_by', query._group_by)
- self.from_obj = kwargs.pop('from_obj', query._from_obj)
- self.lockmode = kwargs.pop('lockmode', query.lockmode)
- self.distinct = kwargs.pop('distinct', query._distinct)
- self.limit = kwargs.pop('limit', query._limit)
- self.offset = kwargs.pop('offset', query._offset)
+ self.order_by = query._order_by
+ self.group_by = query._group_by
+ self.from_obj = query._from_obj
+ self.lockmode = query._lockmode
+ self.distinct = query._distinct
+ self.limit = query._limit
+ self.offset = query._offset
self.eager_loaders = util.Set([x for x in query.mapper._eager_loaders])
self.statement = None
- super(QueryContext, self).__init__(query.mapper, query.with_options, **kwargs)
+ super(QueryContext, self).__init__(query.mapper, query._with_options)
def select_args(self):
"""Return a dictionary of attributes from this
``QueryContext`` that can be applied to a ``sql.Select``
statement.
"""
- return {'limit':self.limit, 'offset':self.offset, 'distinct':self.distinct, 'group_by':self.group_by}
+ return {'limit':self.limit, 'offset':self.offset, 'distinct':self.distinct, 'group_by':self.group_by or None}
def accept_option(self, opt):
"""Accept a ``MapperOption`` which will process (modify) the
@@ -1265,8 +1186,10 @@ class SelectionContext(OperationContext):
yet been added as persistent to the Session.
attributes
- A dictionary to store arbitrary data; eager loaders use it to
- store additional result lists.
+ A dictionary to store arbitrary data; mappers, strategies, and
+ options all store various state information here in order
+ to communicate with each other and to themselves.
+
populate_existing
        Indicates if it's OK to overwrite the attributes of instances
@@ -1284,6 +1207,7 @@ class SelectionContext(OperationContext):
self.session = session
self.extension = extension
self.identity_map = {}
+ self.stack = LoaderStack()
super(SelectionContext, self).__init__(mapper, kwargs.pop('with_options', []), **kwargs)
def accept_option(self, opt):
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 4e7453d84..6b5c4a072 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -4,12 +4,12 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+import weakref
+
from sqlalchemy import util, exceptions, sql, engine
-from sqlalchemy.orm import unitofwork, query
+from sqlalchemy.orm import unitofwork, query, util as mapperutil
from sqlalchemy.orm.mapper import object_mapper as _object_mapper
from sqlalchemy.orm.mapper import class_mapper as _class_mapper
-import weakref
-import sqlalchemy
class SessionTransaction(object):
"""Represents a Session-level Transaction.
@@ -21,70 +21,95 @@ class SessionTransaction(object):
The SessionTransaction object is **not** threadsafe.
"""
- def __init__(self, session, parent=None, autoflush=True):
+ def __init__(self, session, parent=None, autoflush=True, nested=False):
self.session = session
- self.connections = {}
- self.parent = parent
+ self.__connections = {}
+ self.__parent = parent
self.autoflush = autoflush
+ self.nested = nested
- def connection(self, mapper_or_class, entity_name=None):
+ def connection(self, mapper_or_class, entity_name=None, **kwargs):
if isinstance(mapper_or_class, type):
mapper_or_class = _class_mapper(mapper_or_class, entity_name=entity_name)
- engine = self.session.get_bind(mapper_or_class)
+ engine = self.session.get_bind(mapper_or_class, **kwargs)
return self.get_or_add(engine)
- def _begin(self):
- return SessionTransaction(self.session, self)
+ def _begin(self, **kwargs):
+ return SessionTransaction(self.session, self, **kwargs)
def add(self, bind):
- if self.parent is not None:
- return self.parent.add(bind)
-
- if self.connections.has_key(bind.engine):
- raise exceptions.InvalidRequestError("Session already has a Connection associated for the given %sEngine" % (isinstance(bind, engine.Connection) and "Connection's " or ""))
+ if self.__parent is not None:
+ return self.__parent.add(bind)
+ if self.__connections.has_key(bind.engine):
+ raise exceptions.InvalidRequestError("Session already has a Connection associated for the given %sEngine" % (isinstance(bind, engine.Connection) and "Connection's " or ""))
return self.get_or_add(bind)
+ def _connection_dict(self):
+ if self.__parent is not None and not self.nested:
+ return self.__parent._connection_dict()
+ else:
+ return self.__connections
+
def get_or_add(self, bind):
- if self.parent is not None:
- return self.parent.get_or_add(bind)
+ if self.__parent is not None:
+ if not self.nested:
+ return self.__parent.get_or_add(bind)
+
+ if self.__connections.has_key(bind):
+ return self.__connections[bind][0]
+
+ if bind in self.__parent._connection_dict():
+ (conn, trans, autoclose) = self.__parent.__connections[bind]
+ self.__connections[conn] = self.__connections[bind.engine] = (conn, conn.begin_nested(), autoclose)
+ return conn
+ elif self.__connections.has_key(bind):
+ return self.__connections[bind][0]
- if self.connections.has_key(bind):
- return self.connections[bind][0]
-
if not isinstance(bind, engine.Connection):
e = bind
c = bind.contextual_connect()
else:
e = bind.engine
c = bind
- if e in self.connections:
+ if e in self.__connections:
raise exceptions.InvalidRequestError("Session already has a Connection associated for the given Connection's Engine")
-
- self.connections[bind] = self.connections[e] = (c, c.begin(), c is not bind)
- return self.connections[bind][0]
+ if self.nested:
+ trans = c.begin_nested()
+ elif self.session.twophase:
+ trans = c.begin_twophase()
+ else:
+ trans = c.begin()
+ self.__connections[c] = self.__connections[e] = (c, trans, c is not bind)
+ return self.__connections[c][0]
def commit(self):
- if self.parent is not None:
- return
+ if self.__parent is not None and not self.nested:
+ return self.__parent
if self.autoflush:
self.session.flush()
- for t in util.Set(self.connections.values()):
+
+ if self.session.twophase:
+ for t in util.Set(self.__connections.values()):
+ t[1].prepare()
+
+ for t in util.Set(self.__connections.values()):
t[1].commit()
self.close()
+ return self.__parent
def rollback(self):
- if self.parent is not None:
- self.parent.rollback()
- return
- for k, t in self.connections.iteritems():
+ if self.__parent is not None and not self.nested:
+ return self.__parent.rollback()
+ for t in util.Set(self.__connections.values()):
t[1].rollback()
self.close()
-
+ return self.__parent
+
def close(self):
- if self.parent is not None:
+ if self.__parent is not None:
return
- for t in self.connections.values():
+ for t in util.Set(self.__connections.values()):
if t[2]:
t[0].close()
self.session.transaction = None
@@ -108,23 +133,24 @@ class Session(object):
of Sessions, see the ``sqlalchemy.ext.sessioncontext`` module.
"""
- def __init__(self, bind=None, bind_to=None, hash_key=None, import_session=None, echo_uow=False, weak_identity_map=False):
- if import_session is not None:
- self.uow = unitofwork.UnitOfWork(identity_map=import_session.uow.identity_map, weak_identity_map=weak_identity_map)
- else:
- self.uow = unitofwork.UnitOfWork(weak_identity_map=weak_identity_map)
+ def __init__(self, bind=None, autoflush=False, transactional=False, twophase=False, echo_uow=False, weak_identity_map=False):
+ self.uow = unitofwork.UnitOfWork(weak_identity_map=weak_identity_map)
- self.bind = bind or bind_to
- self.binds = {}
+ self.bind = bind
+ self.__binds = {}
self.echo_uow = echo_uow
self.weak_identity_map = weak_identity_map
self.transaction = None
- if hash_key is None:
- self.hash_key = id(self)
- else:
- self.hash_key = hash_key
+ self.hash_key = id(self)
+ self.autoflush = autoflush
+ self.transactional = transactional or autoflush
+ self.twophase = twophase
+ self._query_cls = query.Query
+ self._mapper_flush_opts = {}
+ if self.transactional:
+ self.begin()
_sessions[self.hash_key] = self
-
+
def _get_echo_uow(self):
return self.uow.echo
@@ -132,37 +158,39 @@ class Session(object):
self.uow.echo = value
echo_uow = property(_get_echo_uow,_set_echo_uow)
- bind_to = property(lambda self:self.bind)
-
- def create_transaction(self, **kwargs):
- """Return a new ``SessionTransaction`` corresponding to an
- existing or new transaction.
-
- If the transaction is new, the returned ``SessionTransaction``
- will have commit control over the underlying transaction, else
- will have rollback control only.
- """
+ def begin(self, **kwargs):
+ """Begin a transaction on this Session."""
if self.transaction is not None:
- return self.transaction._begin()
+ self.transaction = self.transaction._begin(**kwargs)
else:
self.transaction = SessionTransaction(self, **kwargs)
- return self.transaction
-
- def connect(self, mapper=None, **kwargs):
- """Return a unique connection corresponding to the given mapper.
-
- This connection will not be part of any pre-existing
- transactional context.
- """
-
- return self.get_bind(mapper).connect(**kwargs)
-
- def connection(self, mapper, **kwargs):
- """Return a ``Connection`` corresponding to the given mapper.
+ return self.transaction
+
+ create_transaction = begin
- Used by the ``execute()`` method which performs select
- operations for ``Mapper`` and ``Query``.
+ def begin_nested(self):
+ return self.begin(nested=True)
+
+ def rollback(self):
+ if self.transaction is None:
+ raise exceptions.InvalidRequestError("No transaction is begun.")
+ else:
+ self.transaction = self.transaction.rollback()
+ if self.transaction is None and self.transactional:
+ self.begin()
+
+ def commit(self):
+ if self.transaction is None:
+ raise exceptions.InvalidRequestError("No transaction is begun.")
+ else:
+ self.transaction = self.transaction.commit()
+ if self.transaction is None and self.transactional:
+ self.begin()
+
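# --- Editorial sketch of the transaction methods defined above.  With
# transactional=True a new transaction is begun automatically after each
# commit() or rollback(); begin_nested() maps to a SAVEPOINT via
# Connection.begin_nested().  'engine' and 'someobject' are assumed.
session = Session(bind=engine, transactional=True)
session.save(someobject)
session.begin_nested()    # SAVEPOINT
session.rollback()        # rolls back to the savepoint only
session.commit()          # commits the enclosing transaction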
+ def connection(self, mapper=None, **kwargs):
+ """Return a ``Connection`` corresponding to this session's
+ transactional context, if any.
If this ``Session`` is transactional, the connection will be in
the context of this session's transaction. Otherwise, the
@@ -173,6 +201,9 @@ class Session(object):
The given `**kwargs` will be sent to the engine's
``contextual_connect()`` method, if no transaction is in
progress.
+
+ the "mapper" argument is a class or mapper to which a bound engine
+ will be located; use this when the Session itself is unbound.
"""
if self.transaction is not None:
@@ -180,7 +211,7 @@ class Session(object):
else:
return self.get_bind(mapper).contextual_connect(**kwargs)
- def execute(self, mapper, clause, params, **kwargs):
+ def execute(self, clause, params=None, mapper=None, **kwargs):
"""Using the given mapper to identify the appropriate ``Engine``
or ``Connection`` to be used for statement execution, execute the
given ``ClauseElement`` using the provided parameter dictionary.
@@ -191,12 +222,12 @@ class Session(object):
        then the ``ResultProxy``'s ``close()`` method will release the
        resources of the underlying ``Connection``; otherwise it's a no-op.
"""
- return self.connection(mapper, close_with_result=True).execute(clause, params, **kwargs)
+ return self.connection(mapper, close_with_result=True).execute(clause, params or {}, **kwargs)
- def scalar(self, mapper, clause, params, **kwargs):
+ def scalar(self, clause, params=None, mapper=None, **kwargs):
"""Like execute() but return a scalar result."""
- return self.connection(mapper, close_with_result=True).scalar(clause, params, **kwargs)
+ return self.connection(mapper, close_with_result=True).scalar(clause, params or {}, **kwargs)
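# --- Editorial sketch: execute()/scalar() now take the statement first,
# and the optional mapper= keyword is consulted only to locate a bind.
# users_table and usermapper are assumed defined elsewhere.
result = session.execute(users_table.select(), mapper=usermapper)
rows = result.fetchall()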
def close(self):
"""Close this Session."""
@@ -224,14 +255,17 @@ class Session(object):
return _class_mapper(class_, entity_name = entity_name)
- def bind_mapper(self, mapper, bind):
- """Bind the given `mapper` to the given ``Engine`` or ``Connection``.
+ def bind_mapper(self, mapper, bind, entity_name=None):
+ """Bind the given `mapper` or `class` to the given ``Engine`` or ``Connection``.
All subsequent operations involving this ``Mapper`` will use the
given `bind`.
"""
+
+ if isinstance(mapper, type):
+ mapper = _class_mapper(mapper, entity_name=entity_name)
- self.binds[mapper] = bind
+ self.__binds[mapper] = bind
def bind_table(self, table, bind):
"""Bind the given `table` to the given ``Engine`` or ``Connection``.
@@ -240,7 +274,7 @@ class Session(object):
given `bind`.
"""
- self.binds[table] = bind
+ self.__binds[table] = bind
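# --- Editorial sketch: per-mapper and per-table binds consulted by
# get_bind() below, ahead of the session-wide bind.  User, log_table and
# the engines are assumed defined elsewhere.
session = Session()
session.bind_mapper(User, engine_one)   # classes are accepted as well
session.bind_table(log_table, engine_two)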
def get_bind(self, mapper):
"""Return the ``Engine`` or ``Connection`` which is used to execute
@@ -270,15 +304,18 @@ class Session(object):
"""
if mapper is None:
- return self.bind
- elif self.binds.has_key(mapper):
- return self.binds[mapper]
- elif self.binds.has_key(mapper.mapped_table):
- return self.binds[mapper.mapped_table]
+ if self.bind is not None:
+ return self.bind
+ else:
+ raise exceptions.InvalidRequestError("This session is unbound to any Engine or Connection; specify a mapper to get_bind()")
+ elif self.__binds.has_key(mapper):
+ return self.__binds[mapper]
+ elif self.__binds.has_key(mapper.mapped_table):
+ return self.__binds[mapper.mapped_table]
elif self.bind is not None:
return self.bind
else:
- e = mapper.mapped_table.engine
+ e = mapper.mapped_table.bind
if e is None:
raise exceptions.InvalidRequestError("Could not locate any Engine or Connection bound to mapper '%s'" % str(mapper))
return e
@@ -291,9 +328,9 @@ class Session(object):
entity_name = kwargs.pop('entity_name', None)
if isinstance(mapper_or_class, type):
- q = query.Query(_class_mapper(mapper_or_class, entity_name=entity_name), self, **kwargs)
+ q = self._query_cls(_class_mapper(mapper_or_class, entity_name=entity_name), self, **kwargs)
else:
- q = query.Query(mapper_or_class, self, **kwargs)
+ q = self._query_cls(mapper_or_class, self, **kwargs)
for ent in addtl_entities:
q = q.add_entity(ent)
@@ -499,7 +536,7 @@ class Session(object):
merged = self.get(mapper.class_, key[1])
if merged is None:
raise exceptions.AssertionError("Instance %s has an instance key but is not persisted" % mapperutil.instance_str(object))
- for prop in mapper.props.values():
+ for prop in mapper.iterate_properties:
prop.merge(self, object, merged, _recursive)
if key is None:
self.save(merged, entity_name=mapper.entity_name)
@@ -611,12 +648,12 @@ class Session(object):
def _attach(self, obj):
"""Attach the given object to this ``Session``."""
- if getattr(obj, '_sa_session_id', None) != self.hash_key:
- old = getattr(obj, '_sa_session_id', None)
- if old is not None and _sessions.has_key(old):
+ old_id = getattr(obj, '_sa_session_id', None)
+ if old_id != self.hash_key:
+ if old_id is not None and _sessions.has_key(old_id):
raise exceptions.InvalidRequestError("Object '%s' is already attached "
"to session '%s' (this is '%s')" %
- (repr(obj), old, id(self)))
+ (repr(obj), old_id, id(self)))
# auto-removal from the old session is disabled. but if we decide to
# turn it back on, do it as below: gingerly since _sessions is a WeakValueDict
@@ -695,6 +732,7 @@ def object_session(obj):
return _sessions.get(hashkey)
return None
+# Lazy initialization to avoid circular imports
unitofwork.object_session = object_session
from sqlalchemy.orm import mapper
mapper.attribute_manager = attribute_manager
diff --git a/lib/sqlalchemy/orm/shard.py b/lib/sqlalchemy/orm/shard.py
new file mode 100644
index 000000000..cc13f8c1f
--- /dev/null
+++ b/lib/sqlalchemy/orm/shard.py
@@ -0,0 +1,112 @@
+from sqlalchemy import exceptions  # used by ShardedQuery.load() below
+from sqlalchemy.orm.session import Session
+from sqlalchemy.orm import Query
+
+class ShardedSession(Session):
+ def __init__(self, shard_chooser, id_chooser, query_chooser, **kwargs):
+ """construct a ShardedSession.
+
+ shard_chooser
+      a callable which, passed a Mapper and a mapped instance, returns a
+      shard ID. This ID may be based on the attributes present within the
+      object, or on some round-robin scheme. If the scheme is based on a
+      selection, it should set whatever state is needed on the instance to
+      mark it as participating in that shard in the future.
+
+ id_chooser
+ a callable, passed a tuple of identity values, which should return
+ a list of shard ids where the ID might reside. The databases will
+ be queried in the order of this listing.
+
+ query_chooser
+ for a given Query, returns the list of shard_ids where the query
+      should be issued. Results from all returned shards will be
+      combined into a single listing.
+
+ """
+ super(ShardedSession, self).__init__(**kwargs)
+ self.shard_chooser = shard_chooser
+ self.id_chooser = id_chooser
+ self.query_chooser = query_chooser
+ self.__binds = {}
+ self._mapper_flush_opts = {'connection_callable':self.connection}
+ self._query_cls = ShardedQuery
+
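# --- Editorial sketch of wiring up a ShardedSession; the chooser
# callables, the 'region' attribute and the shard ids ('east', 'west')
# are illustrative only.
def shard_chooser(mapper, instance):
    if instance.region == 'east':   # hypothetical instance attribute
        return 'east'
    return 'west'

def id_chooser(ident):
    return ['east', 'west']    # the identity could live in either shard

def query_chooser(query):
    return ['east', 'west']    # no criterion analysis: query both shards

session = ShardedSession(shard_chooser, id_chooser, query_chooser)
session.bind_shard('east', east_engine)   # engines assumed created earlier
session.bind_shard('west', west_engine)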
+ def connection(self, mapper=None, instance=None, shard_id=None, **kwargs):
+ if shard_id is None:
+ shard_id = self.shard_chooser(mapper, instance)
+
+ if self.transaction is not None:
+ return self.transaction.connection(mapper, shard_id=shard_id)
+ else:
+ return self.get_bind(mapper, shard_id=shard_id, instance=instance).contextual_connect(**kwargs)
+
+ def get_bind(self, mapper, shard_id=None, instance=None):
+ if shard_id is None:
+ shard_id = self.shard_chooser(mapper, instance)
+ return self.__binds[shard_id]
+
+ def bind_shard(self, shard_id, bind):
+ self.__binds[shard_id] = bind
+
+class ShardedQuery(Query):
+ def __init__(self, *args, **kwargs):
+ super(ShardedQuery, self).__init__(*args, **kwargs)
+ self.id_chooser = self.session.id_chooser
+ self.query_chooser = self.session.query_chooser
+ self._shard_id = None
+
+ def _clone(self):
+ q = ShardedQuery.__new__(ShardedQuery)
+ q.__dict__ = self.__dict__.copy()
+ return q
+
+ def set_shard(self, shard_id):
+ """return a new query, limited to a single shard ID.
+
+ all subsequent operations with the returned query will
+ be against the single shard regardless of other state.
+ """
+
+ q = self._clone()
+ q._shard_id = shard_id
+ return q
+
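# --- Editorial sketch: pin every operation of a query to one shard.
users = session.query(User).set_shard('east').all()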
+ def _execute_and_instances(self, statement):
+ if self._shard_id is not None:
+ result = self.session.connection(mapper=self.mapper, shard_id=self._shard_id).execute(statement, **self._params)
+ try:
+ return iter(self.instances(result))
+ finally:
+ result.close()
+ else:
+ partial = []
+ for shard_id in self.query_chooser(self):
+ result = self.session.connection(mapper=self.mapper, shard_id=shard_id).execute(statement, **self._params)
+ try:
+ partial = partial + list(self.instances(result))
+ finally:
+ result.close()
+ # if some kind of in memory 'sorting' were done, this is where it would happen
+ return iter(partial)
+
+ def get(self, ident, **kwargs):
+ if self._shard_id is not None:
+ return super(ShardedQuery, self).get(ident)
+ else:
+ for shard_id in self.id_chooser(ident):
+ o = self.set_shard(shard_id).get(ident, **kwargs)
+ if o is not None:
+ return o
+ else:
+ return None
+
+ def load(self, ident, **kwargs):
+ if self._shard_id is not None:
+ return super(ShardedQuery, self).load(ident)
+ else:
+ for shard_id in self.id_chooser(ident):
+ o = self.set_shard(shard_id).load(ident, raiseerr=False, **kwargs)
+ if o is not None:
+ return o
+ else:
+ raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident))
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 462954f6b..babd6e4c0 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -6,12 +6,11 @@
"""sqlalchemy.orm.interfaces.LoaderStrategy implementations, and related MapperOptions."""
-from sqlalchemy import sql, schema, util, exceptions, sql_util, logging
-from sqlalchemy.orm import mapper, query
-from sqlalchemy.orm.interfaces import *
+from sqlalchemy import sql, util, exceptions, sql_util, logging
+from sqlalchemy.orm import mapper, attributes
+from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, MapperOption, PropertyOption
from sqlalchemy.orm import session as sessionlib
from sqlalchemy.orm import util as mapperutil
-import random
class ColumnLoader(LoaderStrategy):
@@ -19,8 +18,9 @@ class ColumnLoader(LoaderStrategy):
super(ColumnLoader, self).init()
self.columns = self.parent_property.columns
self._should_log_debug = logging.is_debug_enabled(self.logger)
+ self.is_composite = hasattr(self.parent_property, 'composite_class')
- def setup_query(self, context, eagertable=None, parentclauses=None, **kwargs):
+ def setup_query(self, context, parentclauses=None, **kwargs):
for c in self.columns:
if parentclauses is not None:
context.statement.append_column(parentclauses.aliased_column(c))
@@ -28,16 +28,93 @@ class ColumnLoader(LoaderStrategy):
context.statement.append_column(c)
def init_class_attribute(self):
+ if self.is_composite:
+ self._init_composite_attribute()
+ else:
+ self._init_scalar_attribute()
+
+ def _init_composite_attribute(self):
+ self.logger.info("register managed composite attribute %s on class %s" % (self.key, self.parent.class_.__name__))
+ def copy(obj):
+ return self.parent_property.composite_class(*obj.__colset__())
+ def compare(a, b):
+ for col, aprop, bprop in zip(self.columns, a.__colset__(), b.__colset__()):
+ if not col.type.compare_values(aprop, bprop):
+ return False
+ else:
+ return True
+ sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=copy, compare_function=compare, mutable_scalars=True, comparator=self.parent_property.comparator)
+
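# --- Editorial sketch of the composite-type protocol the code above
# assumes: the class exposes its column values via __colset__(), and the
# loader rebuilds instances as composite_class(*[row[c] for c in columns]).
class Point(object):
    def __init__(self, x, y):
        self.x, self.y = x, y
    def __colset__(self):
        return (self.x, self.y)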
+ def _init_scalar_attribute(self):
self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__))
coltype = self.columns[0].type
- sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable())
+ sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator)
+
+ def create_row_processor(self, selectcontext, mapper, row):
+ if self.is_composite:
+ for c in self.columns:
+ if c not in row:
+ break
+ else:
+ def execute(instance, row, isnew, ispostselect=None, **flags):
+ if isnew or ispostselect:
+ if self._should_log_debug:
+ self.logger.debug("populating %s with %s/%s..." % (mapperutil.attribute_str(instance, self.key), row.__class__.__name__, self.columns[0].key))
+ instance.__dict__[self.key] = self.parent_property.composite_class(*[row[c] for c in self.columns])
+ self.logger.debug("Returning active composite column fetcher for %s %s" % (mapper, self.key))
+ return (execute, None)
+
+ elif self.columns[0] in row:
+ def execute(instance, row, isnew, ispostselect=None, **flags):
+ if isnew or ispostselect:
+ if self._should_log_debug:
+ self.logger.debug("populating %s with %s/%s" % (mapperutil.attribute_str(instance, self.key), row.__class__.__name__, self.columns[0].key))
+ instance.__dict__[self.key] = row[self.columns[0]]
+ self.logger.debug("Returning active column fetcher for %s %s" % (mapper, self.key))
+ return (execute, None)
+
+ (hosted_mapper, needs_tables) = selectcontext.attributes.get(('polymorphic_fetch', mapper), (None, None))
+ if hosted_mapper is None:
+ return (None, None)
+
+ if hosted_mapper.polymorphic_fetch == 'deferred':
+ def execute(instance, row, isnew, **flags):
+ if isnew:
+ sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=self._get_deferred_loader(instance, mapper, needs_tables))
+ self.logger.debug("Returning deferred column fetcher for %s %s" % (mapper, self.key))
+ return (execute, None)
+ else:
+ self.logger.debug("Returning no column fetcher for %s %s" % (mapper, self.key))
+ return (None, None)
+
+ def _get_deferred_loader(self, instance, mapper, needs_tables):
+ def load():
+ group = [p for p in mapper.iterate_properties if isinstance(p.strategy, ColumnLoader) and p.columns[0].table in needs_tables]
- def process_row(self, selectcontext, instance, row, identitykey, isnew):
- if isnew:
if self._should_log_debug:
- self.logger.debug("populating %s with %s/%s" % (mapperutil.attribute_str(instance, self.key), row.__class__.__name__, self.columns[0].key))
- instance.__dict__[self.key] = row[self.columns[0]]
-
+ self.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(instance, self.key), group and ','.join([p.key for p in group]) or 'None'))
+
+ session = sessionlib.object_session(instance)
+ if session is None:
+ raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
+
+ cond, param_names = mapper._deferred_inheritance_condition(needs_tables)
+ statement = sql.select(needs_tables, cond, use_labels=True)
+ params = {}
+ for c in param_names:
+ params[c.name] = mapper.get_attr_by_column(instance, c)
+
+ result = session.execute(statement, params, mapper=mapper)
+ try:
+ row = result.fetchone()
+ for prop in group:
+ sessionlib.attribute_manager.get_attribute(instance, prop.key).set_committed_value(instance, row[prop.columns[0]])
+ return attributes.ATTR_WAS_SET
+ finally:
+ result.close()
+
+ return load
+
ColumnLoader.logger = logging.class_logger(ColumnLoader)
class DeferredColumnLoader(LoaderStrategy):
@@ -47,74 +124,86 @@ class DeferredColumnLoader(LoaderStrategy):
This is per-column lazy loading.
"""
+ def create_row_processor(self, selectcontext, mapper, row):
+ if self.group is not None and selectcontext.attributes.get(('undefer', self.group), False):
+ return self.parent_property._get_strategy(ColumnLoader).create_row_processor(selectcontext, mapper, row)
+ elif not self.is_default or len(selectcontext.options):
+ def execute(instance, row, isnew, **flags):
+ if isnew:
+ if self._should_log_debug:
+ self.logger.debug("set deferred callable on %s" % mapperutil.attribute_str(instance, self.key))
+ sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=self.setup_loader(instance))
+ return (execute, None)
+ else:
+ def execute(instance, row, isnew, **flags):
+ if isnew:
+ if self._should_log_debug:
+ self.logger.debug("set deferred callable on %s" % mapperutil.attribute_str(instance, self.key))
+ sessionlib.attribute_manager.reset_instance_attribute(instance, self.key)
+ return (execute, None)
+
def init(self):
super(DeferredColumnLoader, self).init()
+ if hasattr(self.parent_property, 'composite_class'):
+ raise NotImplementedError("Deferred loading for composite types not implemented yet")
self.columns = self.parent_property.columns
self.group = self.parent_property.group
self._should_log_debug = logging.is_debug_enabled(self.logger)
def init_class_attribute(self):
self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__))
- sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, callable_=lambda i:self.setup_loader(i), copy_function=lambda x: self.columns[0].type.copy_value(x), compare_function=lambda x,y:self.columns[0].type.compare_values(x,y), mutable_scalars=self.columns[0].type.is_mutable())
+ sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, callable_=self.setup_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator)
def setup_query(self, context, **kwargs):
- pass
+ if self.group is not None and context.attributes.get(('undefer', self.group), False):
+ self.parent_property._get_strategy(ColumnLoader).setup_query(context, **kwargs)
- def process_row(self, selectcontext, instance, row, identitykey, isnew):
- if isnew:
- if not self.is_default or len(selectcontext.options):
- sessionlib.attribute_manager.init_instance_attribute(instance, self.key, False, callable_=self.setup_loader(instance, selectcontext.options))
- else:
- sessionlib.attribute_manager.reset_instance_attribute(instance, self.key)
-
- def setup_loader(self, instance, options=None):
- if not mapper.has_mapper(instance):
+ def setup_loader(self, instance):
+ localparent = mapper.object_mapper(instance, raiseerror=False)
+ if localparent is None:
return None
- else:
- prop = mapper.object_mapper(instance).props[self.key]
- if prop is not self.parent_property:
- return prop._get_strategy(DeferredColumnLoader).setup_loader(instance)
- def lazyload():
- if self._should_log_debug:
- self.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(instance, self.key), str(self.group)))
+
+ prop = localparent.get_property(self.key)
+ if prop is not self.parent_property:
+ return prop._get_strategy(DeferredColumnLoader).setup_loader(instance)
+ def lazyload():
if not mapper.has_identity(instance):
return None
- try:
- pk = self.parent.pks_by_table[self.columns[0].table]
- except KeyError:
- pk = self.columns[0].table.primary_key
-
- clause = sql.and_()
- for primary_key in pk:
- attr = self.parent.get_attr_by_column(instance, primary_key)
- if not attr:
- return None
- clause.clauses.append(primary_key == attr)
+ if self.group is not None:
+ group = [p for p in localparent.iterate_properties if isinstance(p.strategy, DeferredColumnLoader) and p.group==self.group]
+ else:
+ group = None
+
+ if self._should_log_debug:
+ self.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(instance, self.key), group and ','.join([p.key for p in group]) or 'None'))
session = sessionlib.object_session(instance)
if session is None:
raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
-
- localparent = mapper.object_mapper(instance)
- if self.group is not None:
- groupcols = [p for p in localparent.props.values() if isinstance(p.strategy, DeferredColumnLoader) and p.group==self.group]
- result = session.execute(localparent, sql.select([g.columns[0] for g in groupcols], clause, use_labels=True), None)
+
+ clause = localparent._get_clause
+ ident = instance._instance_key[1]
+ params = {}
+ for i, primary_key in enumerate(localparent.primary_key):
+ params[primary_key._label] = ident[i]
+ if group is not None:
+ statement = sql.select([p.columns[0] for p in group], clause, from_obj=[localparent.mapped_table], use_labels=True)
+ else:
+ statement = sql.select([self.columns[0]], clause, from_obj=[localparent.mapped_table], use_labels=True)
+
+ if group is not None:
+ result = session.execute(statement, params, mapper=localparent)
try:
row = result.fetchone()
- for prop in groupcols:
- if prop is self:
- continue
- # set a scalar object instance directly on the object,
- # bypassing SmartProperty event handlers.
- sessionlib.attribute_manager.init_instance_attribute(instance, prop.key, uselist=False)
- instance.__dict__[prop.key] = row[prop.columns[0]]
- return row[self.columns[0]]
+ for prop in group:
+ sessionlib.attribute_manager.get_attribute(instance, prop.key).set_committed_value(instance, row[prop.columns[0]])
+ return attributes.ATTR_WAS_SET
finally:
result.close()
else:
- return session.scalar(localparent, sql.select([self.columns[0]], clause, use_labels=True),None)
+ return session.scalar(sql.select([self.columns[0]], clause, from_obj=[localparent.mapped_table], use_labels=True),params, mapper=localparent)
return lazyload
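# --- Editorial sketch: deferred columns sharing a group are loaded
# together by the lazyload() closure above on first access.  Book,
# books_table and the deferred() construct are assumed from the mapper
# configuration elsewhere.
mapper(Book, books_table, properties={
    'photo1': deferred(books_table.c.photo1, group='photos'),
    'photo2': deferred(books_table.c.photo2, group='photos'),
})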
@@ -131,6 +220,15 @@ class DeferredOption(StrategizedOption):
else:
return ColumnLoader
+class UndeferGroupOption(MapperOption):
+ def __init__(self, group):
+ self.group = group
+ def process_query_context(self, context):
+ context.attributes[('undefer', self.group)] = True
+
+ def process_selection_context(self, context):
+ context.attributes[('undefer', self.group)] = True
+
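# --- Editorial sketch: applying this option sets ('undefer', group) in
# both contexts, so setup_query()/create_row_processor() above fall back
# to the eager ColumnLoader for the whole group; shown with the option
# class directly for illustration.
q = session.query(Book).options(UndeferGroupOption('photos'))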
class AbstractRelationLoader(LoaderStrategy):
def init(self):
super(AbstractRelationLoader, self).init()
@@ -139,22 +237,26 @@ class AbstractRelationLoader(LoaderStrategy):
self._should_log_debug = logging.is_debug_enabled(self.logger)
def _init_instance_attribute(self, instance, callable_=None):
- return sessionlib.attribute_manager.init_instance_attribute(instance, self.key, self.uselist, cascade=self.cascade, trackparent=True, callable_=callable_)
+ return sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=callable_)
def _register_attribute(self, class_, callable_=None):
self.logger.info("register managed %s attribute %s on class %s" % ((self.uselist and "list-holding" or "scalar"), self.key, self.parent.class_.__name__))
- sessionlib.attribute_manager.register_attribute(class_, self.key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade, trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_)
+ sessionlib.attribute_manager.register_attribute(class_, self.key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade, trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_, comparator=self.parent_property.comparator)
class NoLoader(AbstractRelationLoader):
def init_class_attribute(self):
self._register_attribute(self.parent.class_)
- def process_row(self, selectcontext, instance, row, identitykey, isnew):
- if isnew:
- if not self.is_default or len(selectcontext.options):
- if self._should_log_debug:
- self.logger.debug("set instance-level no loader on %s" % mapperutil.attribute_str(instance, self.key))
- self._init_instance_attribute(instance)
+ def create_row_processor(self, selectcontext, mapper, row):
+ if not self.is_default or len(selectcontext.options):
+ def execute(instance, row, isnew, **flags):
+ if isnew:
+ if self._should_log_debug:
+ self.logger.debug("set instance-level no loader on %s" % mapperutil.attribute_str(instance, self.key))
+ self._init_instance_attribute(instance)
+ return (execute, None)
+ else:
+ return (None, None)
NoLoader.logger = logging.class_logger(NoLoader)
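
The strategies in this module now implement `create_row_processor()` in place of the old per-row `process_row()` hook: it is called once per result set and returns a pair of callables, of which only the first is used by the strategies here (the second slot is left `None`). A sketch of the contract's shape, names hypothetical::

    def create_row_processor(self, selectcontext, mapper, row):
        def execute(instance, row, isnew, **flags):
            # invoked for each result row against this attribute
            pass
        return (execute, None)   # second member unused by these strategies
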
@@ -167,7 +269,8 @@ class LazyLoader(AbstractRelationLoader):
# determine if our "lazywhere" clause is the same as the mapper's
# get() clause. then we can just use mapper.get()
- self.use_get = not self.uselist and query.Query(self.mapper)._get_clause.compare(self.lazywhere)
+ #from sqlalchemy.orm import query
+ self.use_get = not self.uselist and self.mapper._get_clause.compare(self.lazywhere)
if self.use_get:
self.logger.info(str(self.parent_property) + " will use query.get() to optimize instance loads")
@@ -178,7 +281,7 @@ class LazyLoader(AbstractRelationLoader):
if not mapper.has_mapper(instance):
return None
else:
- prop = mapper.object_mapper(instance).props[self.key]
+ prop = mapper.object_mapper(instance).get_property(self.key)
if prop is not self.parent_property:
return prop._get_strategy(LazyLoader).setup_loader(instance)
def lazyload():
@@ -211,20 +314,27 @@ class LazyLoader(AbstractRelationLoader):
# if we have a simple straight-primary key load, use mapper.get()
# to possibly save a DB round trip
+ q = session.query(self.mapper)
if self.use_get:
ident = []
- for primary_key in self.select_mapper.pks_by_table[self.select_mapper.mapped_table]:
+ # TODO: when options are added to allow switching between union-based and non-union
+ # based polymorphic loads on a per-query basis, this code needs to switch between "mapper" and "select_mapper",
+ # probably via the query's own "mapper" property, and also use one of two "lazy" clauses,
+ # one against the "union" the other not
+ for primary_key in self.select_mapper.primary_key:
bind = self.lazyreverse[primary_key]
ident.append(params[bind.key])
- return session.query(self.mapper).get(ident)
+ return q.get(ident)
elif self.order_by is not False:
- order_by = self.order_by
+ q = q.order_by(self.order_by)
elif self.secondary is not None and self.secondary.default_order_by() is not None:
- order_by = self.secondary.default_order_by()
- else:
- order_by = False
- result = session.query(self.mapper, with_options=options).select_whereclause(self.lazywhere, order_by=order_by, params=params)
+ q = q.order_by(self.secondary.default_order_by())
+ if options:
+ q = q.options(*options)
+ q = q.filter(self.lazywhere).params(**params)
+
+ result = q.all()
if self.uselist:
return result
else:
@@ -232,25 +342,37 @@ class LazyLoader(AbstractRelationLoader):
return result[0]
else:
return None
+
+ if self.uselist:
+ return q.all()
+ else:
+ return q.first()
+
return lazyload
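
The lazy loader now assembles its statement through the generative Query methods (`order_by()`, `options()`, `filter()`, `params()`) in place of the removed `select_whereclause()` call. A sketch of the same chaining style in user code, with `Address` and `addresses` as stand-in mapped class and table::

    from sqlalchemy import sql

    q = session.query(Address)
    q = q.order_by(addresses.c.id)        # each call returns a new Query
    q = q.filter(addresses.c.user_id == sql.bindparam('uid')).params(uid=7)
    rows = q.all()                        # full list, or q.first() for a scalar
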
- def process_row(self, selectcontext, instance, row, identitykey, isnew):
- if isnew:
- # new object instance being loaded from a result row
- if not self.is_default or len(selectcontext.options):
- self.logger.debug("set instance-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key))
- # we are not the primary manager for this attribute on this class - set up a per-instance lazyloader,
-                # which will override the class-level behavior
- self._init_instance_attribute(instance, callable_=self.setup_loader(instance, selectcontext.options))
- else:
- self.logger.debug("set class-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key))
- # we are the primary manager for this attribute on this class - reset its per-instance attribute state,
- # so that the class-level lazy loader is executed when next referenced on this instance.
- # this usually is not needed unless the constructor of the object referenced the attribute before we got
- # to load data into it.
- sessionlib.attribute_manager.reset_instance_attribute(instance, self.key)
-
- def _create_lazy_clause(cls, prop, reverse_direction=False):
+ def create_row_processor(self, selectcontext, mapper, row):
+ if not self.is_default or len(selectcontext.options):
+ def execute(instance, row, isnew, **flags):
+ if isnew:
+ if self._should_log_debug:
+ self.logger.debug("set instance-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key))
+ # we are not the primary manager for this attribute on this class - set up a per-instance lazyloader,
+                # which will override the class-level behavior
+ self._init_instance_attribute(instance, callable_=self.setup_loader(instance, selectcontext.options))
+ return (execute, None)
+ else:
+ def execute(instance, row, isnew, **flags):
+ if isnew:
+ if self._should_log_debug:
+ self.logger.debug("set class-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key))
+ # we are the primary manager for this attribute on this class - reset its per-instance attribute state,
+ # so that the class-level lazy loader is executed when next referenced on this instance.
+ # this usually is not needed unless the constructor of the object referenced the attribute before we got
+ # to load data into it.
+ sessionlib.attribute_manager.reset_instance_attribute(instance, self.key)
+ return (execute, None)
+
+ def _create_lazy_clause(cls, prop, reverse_direction=False, op='=='):
(primaryjoin, secondaryjoin, remote_side) = (prop.polymorphic_primaryjoin, prop.polymorphic_secondaryjoin, prop.remote_side)
binds = {}
@@ -272,19 +394,16 @@ class LazyLoader(AbstractRelationLoader):
FindColumnInColumnClause().traverse(expr)
return len(columns) and columns[0] or None
- def bind_label():
- # TODO: make this generation deterministic
- return "lazy_" + hex(random.randint(0, 65535))[2:]
-
def visit_binary(binary):
leftcol = find_column_in_expr(binary.left)
rightcol = find_column_in_expr(binary.right)
if leftcol is None or rightcol is None:
return
+
if should_bind(leftcol, rightcol):
col = leftcol
binary.left = binds.setdefault(leftcol,
- sql.bindparam(bind_label(), None, shortname=leftcol.name, type=binary.right.type, unique=True))
+ sql.bindparam(None, None, shortname=leftcol.name, type_=binary.right.type, unique=True))
reverse[rightcol] = binds[col]
# the "left is not right" compare is to handle part of a join clause that is "table.c.col1==table.c.col1",
@@ -292,21 +411,19 @@ class LazyLoader(AbstractRelationLoader):
if leftcol is not rightcol and should_bind(rightcol, leftcol):
col = rightcol
binary.right = binds.setdefault(rightcol,
- sql.bindparam(bind_label(), None, shortname=rightcol.name, type=binary.left.type, unique=True))
+ sql.bindparam(None, None, shortname=rightcol.name, type_=binary.left.type, unique=True))
reverse[leftcol] = binds[col]
- lazywhere = primaryjoin.copy_container()
+ lazywhere = primaryjoin
li = mapperutil.BinaryVisitor(visit_binary)
if not secondaryjoin or not reverse_direction:
- li.traverse(lazywhere)
+ lazywhere = li.traverse(lazywhere, clone=True)
if secondaryjoin is not None:
- secondaryjoin = secondaryjoin.copy_container()
if reverse_direction:
- li.traverse(secondaryjoin)
+ secondaryjoin = li.traverse(secondaryjoin, clone=True)
lazywhere = sql.and_(lazywhere, secondaryjoin)
-
return (lazywhere, binds, reverse)
_create_lazy_clause = classmethod(_create_lazy_clause)
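
Note the switch above from the random `"lazy_" + hex(...)` labels to `sql.bindparam(None, ..., unique=True)`: a `None` name asks SQLAlchemy to assign an anonymous parameter name, and `unique=True` guarantees distinct names within one statement, so clause generation is now deterministic. A brief sketch::

    from sqlalchemy import sql

    b1 = sql.bindparam(None, None, unique=True)
    b2 = sql.bindparam(None, None, unique=True)
    # each bindparam receives its own anonymous key, so the two can
    # coexist in a single WHERE clause without colliding
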
@@ -318,154 +435,42 @@ class EagerLoader(AbstractRelationLoader):
def init(self):
super(EagerLoader, self).init()
- if self.parent.isa(self.mapper):
- raise exceptions.ArgumentError(
- "Error creating eager relationship '%s' on parent class '%s' "
- "to child class '%s': Cant use eager loading on a self "
- "referential relationship." %
- (self.key, repr(self.parent.class_), repr(self.mapper.class_)))
if self.is_default:
self.parent._eager_loaders.add(self.parent_property)
self.clauses = {}
- self.clauses_by_lead_mapper = {}
-
- class AliasedClauses(object):
- """Defines a set of join conditions and table aliases which
- are aliased on a randomly-generated alias name, corresponding
- to the connection of an optional parent AliasedClauses object
- and a target mapper.
-
- EagerLoader has a distinct AliasedClauses object per parent
- AliasedClauses object, so that all paths from one mapper to
- another across a chain of eagerloaders generates a distinct
- chain of joins. The AliasedClauses objects are generated and
- cached on an as-needed basis.
-
- E.g.::
-
- mapper A -->
- (EagerLoader 'items') -->
- mapper B -->
- (EagerLoader 'keywords') -->
- mapper C
-
- will generate::
-
- EagerLoader 'items' --> {
- None : AliasedClauses(items, None, alias_suffix='AB34') # mappera JOIN mapperb_AB34
- }
-
- EagerLoader 'keywords' --> [
- None : AliasedClauses(keywords, None, alias_suffix='43EF') # mapperb JOIN mapperc_43EF
- AliasedClauses(items, None, alias_suffix='AB34') :
- AliasedClauses(keywords, items, alias_suffix='8F44') # mapperb_AB34 JOIN mapperc_8F44
- ]
- """
-
- def __init__(self, eagerloader, parentclauses=None):
- self.id = (parentclauses is not None and (parentclauses.id + "/") or '') + str(eagerloader.parent_property)
- self.parent = eagerloader
- self.target = eagerloader.select_table
- self.eagertarget = eagerloader.select_table.alias(self._aliashash("/target"))
- self.extra_cols = {}
-
- if eagerloader.secondary:
- self.eagersecondary = eagerloader.secondary.alias(self._aliashash("/secondary"))
- if parentclauses is not None:
- aliasizer = sql_util.ClauseAdapter(self.eagertarget).\
- chain(sql_util.ClauseAdapter(self.eagersecondary)).\
- chain(sql_util.ClauseAdapter(parentclauses.eagertarget))
- else:
- aliasizer = sql_util.ClauseAdapter(self.eagertarget).\
- chain(sql_util.ClauseAdapter(self.eagersecondary))
- self.eagersecondaryjoin = eagerloader.polymorphic_secondaryjoin.copy_container()
- aliasizer.traverse(self.eagersecondaryjoin)
- self.eagerprimary = eagerloader.polymorphic_primaryjoin.copy_container()
- aliasizer.traverse(self.eagerprimary)
- else:
- self.eagerprimary = eagerloader.polymorphic_primaryjoin.copy_container()
- if parentclauses is not None:
- aliasizer = sql_util.ClauseAdapter(self.eagertarget)
- aliasizer.chain(sql_util.ClauseAdapter(parentclauses.eagertarget, exclude=eagerloader.parent_property.remote_side))
- else:
- aliasizer = sql_util.ClauseAdapter(self.eagertarget)
- aliasizer.traverse(self.eagerprimary)
-
- if eagerloader.order_by:
- self.eager_order_by = sql_util.ClauseAdapter(self.eagertarget).copy_and_process(util.to_list(eagerloader.order_by))
- else:
- self.eager_order_by = None
-
- self._row_decorator = self._create_decorator_row()
-
- def aliased_column(self, column):
- """return the aliased version of the given column, creating a new label for it if not already
- present in this AliasedClauses eagertable."""
-
- conv = self.eagertarget.corresponding_column(column, raiseerr=False)
- if conv:
- return conv
-
- if column in self.extra_cols:
- return self.extra_cols[column]
-
- aliased_column = column.copy_container()
- sql_util.ClauseAdapter(self.eagertarget).traverse(aliased_column)
- alias = self._aliashash(column.name)
- aliased_column = aliased_column.label(alias)
- self._row_decorator.map[column] = alias
- self.extra_cols[column] = aliased_column
- return aliased_column
-
- def _aliashash(self, extra):
- """return a deterministic 4 digit hash value for this AliasedClause's id + extra."""
- # use the first 4 digits of an MD5 hash
- return "anon_" + util.hash(self.id + extra)[0:4]
-
- def _create_decorator_row(self):
- class EagerRowAdapter(object):
- def __init__(self, row):
- self.row = row
- def has_key(self, key):
- return map.has_key(key) or self.row.has_key(key)
- def __getitem__(self, key):
- if map.has_key(key):
- key = map[key]
- return self.row[key]
- def keys(self):
- return map.keys()
- map = {}
- for c in self.eagertarget.c:
- parent = self.target.corresponding_column(c)
- map[parent] = c
- map[parent._label] = c
- map[parent.name] = c
- EagerRowAdapter.map = map
- return EagerRowAdapter
-
- def _decorate_row(self, row):
- # adapts a row at row iteration time to transparently
- # convert plain columns into the aliased columns that were actually
- # added to the column clause of the SELECT.
- return self._row_decorator(row)
+ self.join_depth = self.parent_property.join_depth
def init_class_attribute(self):
self.parent_property._get_strategy(LazyLoader).init_class_attribute()
- def setup_query(self, context, eagertable=None, parentclauses=None, parentmapper=None, **kwargs):
+ def setup_query(self, context, parentclauses=None, parentmapper=None, **kwargs):
"""Add a left outer join to the statement thats being constructed."""
+ # build a path as we setup the query. the format of this path
+ # matches that of interfaces.LoaderStack, and will be used in the
+ # row-loading phase to match up AliasedClause objects with the current
+ # LoaderStack position.
+ if parentclauses:
+ path = parentclauses.path + (self.parent.base_mapper(), self.key)
+ else:
+ path = (self.parent.base_mapper(), self.key)
+
+
+ if self.join_depth:
+ if len(path) / 2 > self.join_depth:
+ return
+ else:
+ if self.mapper in path:
+ return
+
+ #print "CREATING EAGER PATH FOR", "->".join([str(s) for s in path])
+
if parentmapper is None:
localparent = context.mapper
else:
localparent = parentmapper
- if self.mapper in context.recursion_stack:
- return
- else:
- context.recursion_stack.add(self.parent)
-
statement = context.statement
if hasattr(statement, '_outerjoin'):
@@ -487,55 +492,57 @@ class EagerLoader(AbstractRelationLoader):
break
else:
raise exceptions.InvalidRequestError("EagerLoader cannot locate a clause with which to outer join to, in query '%s' %s" % (str(statement), localparent.mapped_table))
-
+
try:
- clauses = self.clauses[parentclauses]
+ clauses = self.clauses[path]
except KeyError:
- clauses = EagerLoader.AliasedClauses(self, parentclauses)
- self.clauses[parentclauses] = clauses
-
- if context.mapper not in self.clauses_by_lead_mapper:
- self.clauses_by_lead_mapper[context.mapper] = clauses
-
+ clauses = mapperutil.PropertyAliasedClauses(self.parent_property, self.parent_property.polymorphic_primaryjoin, self.parent_property.polymorphic_secondaryjoin, parentclauses)
+ self.clauses[path] = clauses
+
if self.secondaryjoin is not None:
- statement._outerjoin = sql.outerjoin(towrap, clauses.eagersecondary, clauses.eagerprimary).outerjoin(clauses.eagertarget, clauses.eagersecondaryjoin)
+ statement._outerjoin = sql.outerjoin(towrap, clauses.secondary, clauses.primaryjoin).outerjoin(clauses.alias, clauses.secondaryjoin)
if self.order_by is False and self.secondary.default_order_by() is not None:
- statement.order_by(*clauses.eagersecondary.default_order_by())
+ statement.append_order_by(*clauses.secondary.default_order_by())
else:
- statement._outerjoin = towrap.outerjoin(clauses.eagertarget, clauses.eagerprimary)
- if self.order_by is False and clauses.eagertarget.default_order_by() is not None:
- statement.order_by(*clauses.eagertarget.default_order_by())
+ statement._outerjoin = towrap.outerjoin(clauses.alias, clauses.primaryjoin)
+ if self.order_by is False and clauses.alias.default_order_by() is not None:
+ statement.append_order_by(*clauses.alias.default_order_by())
- if clauses.eager_order_by:
- statement.order_by(*util.to_list(clauses.eager_order_by))
-
+ if clauses.order_by:
+ statement.append_order_by(*util.to_list(clauses.order_by))
+
statement.append_from(statement._outerjoin)
- for value in self.select_mapper.props.values():
- value.setup(context, eagertable=clauses.eagertarget, parentclauses=clauses, parentmapper=self.select_mapper)
- def _create_row_processor(self, selectcontext, row):
- """Create a *row processing* function that will apply eager
+ for value in self.select_mapper.iterate_properties:
+ value.setup(context, parentclauses=clauses, parentmapper=self.select_mapper)
+
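
The path-based check at the top of `setup_query()` replaces the old `recursion_stack`: eager joins now stop either when the target mapper already appears in the path, or when an explicit `join_depth` is exhausted, which is what makes self-referential eager loading possible (the removed code at the top of this class raised ArgumentError for that case outright). A sketch, assuming a self-referential `Node` mapping::

    mapper(Node, nodes, properties={
        # eagerly join 'children' up to two levels deep; past that depth
        # the strategy returns early above and lazy loading takes over
        'children': relation(Node, lazy=False, join_depth=2)
    })
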
+ def _create_row_decorator(self, selectcontext, row, path):
+ """Create a *row decorating* function that will apply eager
aliasing to the row.
Also check that an identity key can be retrieved from the row,
else return None.
"""
+ #print "creating row decorator for path ", "->".join([str(s) for s in path])
+
# check for a user-defined decorator in the SelectContext (which was set up by the contains_eager() option)
- if selectcontext.attributes.has_key((EagerLoader, self.parent_property)):
+ if selectcontext.attributes.has_key(("eager_row_processor", self.parent_property)):
# custom row decoration function, placed in the selectcontext by the
# contains_eager() mapper option
- decorator = selectcontext.attributes[(EagerLoader, self.parent_property)]
+ decorator = selectcontext.attributes[("eager_row_processor", self.parent_property)]
if decorator is None:
decorator = lambda row: row
else:
try:
# decorate the row according to the stored AliasedClauses for this eager load
- clauses = self.clauses_by_lead_mapper[selectcontext.mapper]
- decorator = clauses._row_decorator
+ clauses = self.clauses[path]
+ decorator = clauses.row_decorator
except KeyError, k:
# no stored AliasedClauses: eager loading was not set up in the query and
# AliasedClauses never got initialized
+ if self._should_log_debug:
+ self.logger.debug("Could not locate aliased clauses for key: " + str(path))
return None
try:
@@ -550,81 +557,80 @@ class EagerLoader(AbstractRelationLoader):
self.logger.debug("could not locate identity key from row '%s'; missing column '%s'" % (repr(decorated_row), str(k)))
return None
- def process_row(self, selectcontext, instance, row, identitykey, isnew):
- """Receive a row.
+ def create_row_processor(self, selectcontext, mapper, row):
+ selectcontext.stack.push_property(self.key)
+ path = selectcontext.stack.snapshot()
- Tell our mapper to look for a new object instance in the row,
- and attach it to a list on the parent instance.
- """
-
- if self in selectcontext.recursion_stack:
- return
-
- try:
- # check for row processor
- row_processor = selectcontext.attributes[id(self)]
- except KeyError:
- # create a row processor function and cache it in the context
- row_processor = self._create_row_processor(selectcontext, row)
- selectcontext.attributes[id(self)] = row_processor
-
- if row_processor is not None:
- decorated_row = row_processor(row)
- else:
- # row_processor was None: degrade to a lazy loader
- if self._should_log_debug:
- self.logger.debug("degrade to lazy loader on %s" % mapperutil.attribute_str(instance, self.key))
- self.parent_property._get_strategy(LazyLoader).process_row(selectcontext, instance, row, identitykey, isnew)
- return
-
- # TODO: recursion check a speed hit...? try to get a "termination point" into the AliasedClauses
- # or EagerRowAdapter ?
- selectcontext.recursion_stack.add(self)
- try:
- if not self.uselist:
- if self._should_log_debug:
- self.logger.debug("eagerload scalar instance on %s" % mapperutil.attribute_str(instance, self.key))
- if isnew:
- # set a scalar object instance directly on the parent object,
- # bypassing SmartProperty event handlers.
- instance.__dict__[self.key] = self.mapper._instance(selectcontext, decorated_row, None)
+ row_decorator = self._create_row_decorator(selectcontext, row, path)
+ if row_decorator is not None:
+ def execute(instance, row, isnew, **flags):
+ decorated_row = row_decorator(row)
+
+ selectcontext.stack.push_property(self.key)
+
+ if not self.uselist:
+ if self._should_log_debug:
+ self.logger.debug("eagerload scalar instance on %s" % mapperutil.attribute_str(instance, self.key))
+ if isnew:
+ # set a scalar object instance directly on the
+ # parent object, bypassing InstrumentedAttribute
+ # event handlers.
+ #
+ # FIXME: instead of...
+ sessionlib.attribute_manager.get_attribute(instance, self.key).set_raw_value(instance, self.mapper._instance(selectcontext, decorated_row, None))
+ # bypass and set directly:
+ #instance.__dict__[self.key] = ...
+ else:
+ # call _instance on the row, even though the object has been created,
+ # so that we further descend into properties
+ self.mapper._instance(selectcontext, decorated_row, None)
else:
- # call _instance on the row, even though the object has been created,
- # so that we further descend into properties
- self.mapper._instance(selectcontext, decorated_row, None)
- else:
- if isnew:
+ if isnew:
+ if self._should_log_debug:
+ self.logger.debug("initialize UniqueAppender on %s" % mapperutil.attribute_str(instance, self.key))
+
+ collection = sessionlib.attribute_manager.init_collection(instance, self.key)
+ appender = util.UniqueAppender(collection, 'append_without_event')
+
+ # store it in the "scratch" area, which is local to this load operation.
+ selectcontext.attributes[(instance, self.key)] = appender
+ result_list = selectcontext.attributes[(instance, self.key)]
if self._should_log_debug:
- self.logger.debug("initialize UniqueAppender on %s" % mapperutil.attribute_str(instance, self.key))
- # call the SmartProperty's initialize() method to create a new, blank list
- l = getattr(instance.__class__, self.key).initialize(instance)
-
- # create an appender object which will add set-like semantics to the list
- appender = util.UniqueAppender(l.data)
-
- # store it in the "scratch" area, which is local to this load operation.
- selectcontext.attributes[(instance, self.key)] = appender
- result_list = selectcontext.attributes[(instance, self.key)]
- if self._should_log_debug:
- self.logger.debug("eagerload list instance on %s" % mapperutil.attribute_str(instance, self.key))
- self.select_mapper._instance(selectcontext, decorated_row, result_list)
- finally:
- selectcontext.recursion_stack.remove(self)
+ self.logger.debug("eagerload list instance on %s" % mapperutil.attribute_str(instance, self.key))
+
+ self.select_mapper._instance(selectcontext, decorated_row, result_list)
+ selectcontext.stack.pop()
+ selectcontext.stack.pop()
+ return (execute, None)
+ else:
+ self.logger.debug("eager loader %s degrading to lazy loader" % str(self))
+ selectcontext.stack.pop()
+ return self.parent_property._get_strategy(LazyLoader).create_row_processor(selectcontext, mapper, row)
+
+
+ def __str__(self):
+ return str(self.parent) + "." + self.key
+
EagerLoader.logger = logging.class_logger(EagerLoader)
class EagerLazyOption(StrategizedOption):
- def __init__(self, key, lazy=True):
+ def __init__(self, key, lazy=True, chained=False):
super(EagerLazyOption, self).__init__(key)
self.lazy = lazy
-
- def process_query_property(self, context, prop):
+ self.chained = chained
+
+ def is_chained(self):
+ return not self.lazy and self.chained
+
+ def process_query_property(self, context, properties):
if self.lazy:
- if prop in context.eager_loaders:
- context.eager_loaders.remove(prop)
+ if properties[-1] in context.eager_loaders:
+ context.eager_loaders.remove(properties[-1])
else:
- context.eager_loaders.add(prop)
- super(EagerLazyOption, self).process_query_property(context, prop)
+ for prop in properties:
+ context.eager_loaders.add(prop)
+ super(EagerLazyOption, self).process_query_property(context, properties)
def get_strategy_class(self):
if self.lazy:
@@ -636,24 +642,39 @@ class EagerLazyOption(StrategizedOption):
EagerLazyOption.logger = logging.class_logger(EagerLazyOption)
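
With the new `chained` flag, a single EagerLazyOption can mark every property along a dotted path as eager rather than just the endpoint. A sketch of both forms, assuming `eagerload()` and `eagerload_all()` helpers that construct this option with `chained=False` and `chained=True` respectively::

    from sqlalchemy.orm import eagerload, eagerload_all   # assumed helpers

    # eager load only the final 'addresses' hop
    q = session.query(User).options(eagerload('addresses'))

    # chained: eager load 'orders' and 'orders.items' in one pass
    q = session.query(User).options(eagerload_all('orders.items'))
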
+# TODO: enable FetchMode option. currently
+# this class does nothing. will require Query
+# to switch between using its "polymorphic" selectable
+# and its regular selectable in order to make decisions
+# (therefore might require that the FetchMode operation is performed
+# only as the first operation on a Query.)
+class FetchModeOption(PropertyOption):
+ def __init__(self, key, type):
+ super(FetchModeOption, self).__init__(key)
+ if type not in ('join', 'select'):
+ raise exceptions.ArgumentError("Fetchmode must be one of 'join' or 'select'")
+ self.type = type
+
+ def process_selection_property(self, context, properties):
+ context.attributes[('fetchmode', properties[-1])] = self.type
+
class RowDecorateOption(PropertyOption):
def __init__(self, key, decorator=None, alias=None):
super(RowDecorateOption, self).__init__(key)
self.decorator = decorator
self.alias = alias
- def process_selection_property(self, context, property):
+ def process_selection_property(self, context, properties):
if self.alias is not None and self.decorator is None:
if isinstance(self.alias, basestring):
- self.alias = property.target.alias(self.alias)
+ self.alias = properties[-1].target.alias(self.alias)
def decorate(row):
d = {}
- for c in property.target.columns:
+ for c in properties[-1].target.columns:
d[c] = row[self.alias.corresponding_column(c)]
return d
self.decorator = decorate
- context.attributes[(EagerLoader, property)] = self.decorator
+ context.attributes[("eager_row_processor", properties[-1])] = self.decorator
RowDecorateOption.logger = logging.class_logger(RowDecorateOption)
-
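
RowDecorateOption is the machinery behind `contains_eager()`: given an alias (or a custom decorator), it stores a row-translating function under the `("eager_row_processor", property)` key that `_create_row_decorator()` above looks up. A hypothetical flow, with `adalias` as an alias of an addresses table and query/statement names assumed::

    from sqlalchemy.orm import contains_eager   # assumed constructor for RowDecorateOption

    adalias = addresses.alias('adalias')
    statement = users.outerjoin(adalias).select(use_labels=True)

    # read the 'addresses' eager columns out of adalias in each row
    q = session.query(User).options(contains_eager('addresses', alias=adalias))
    result = q.instances(statement.execute())
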
diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py
index 8c70f8cf8..cf48202b0 100644
--- a/lib/sqlalchemy/orm/sync.py
+++ b/lib/sqlalchemy/orm/sync.py
@@ -4,17 +4,16 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-
-
-from sqlalchemy import sql, schema, exceptions
-from sqlalchemy import logging
-from sqlalchemy.orm import util as mapperutil
-
"""Contains the ClauseSynchronizer class, which is used to map
attributes between two objects in a manner corresponding to a SQL
clause that compares column values.
"""
+from sqlalchemy import sql, schema, exceptions
+from sqlalchemy import logging
+from sqlalchemy.orm import util as mapperutil
+import operator
+
ONETOMANY = 0
MANYTOONE = 1
MANYTOMANY = 2
@@ -44,7 +43,7 @@ class ClauseSynchronizer(object):
def compile_binary(binary):
"""Assemble a SyncRule given a single binary condition."""
- if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
+ if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
return
source_column = None
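
Clause operators are now function objects from the standard `operator` module rather than raw strings, hence the comparison against `operator.eq` above. A two-line sketch, assuming a `users` table::

    import operator
    expr = users.c.id == 5
    assert expr.operator is operator.eq   # was the string '=' in 0.3
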
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index c6b0b2689..f59042810 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -19,15 +19,17 @@ new, dirty, or deleted and provides the capability to flush all those
changes at once.
"""
-from sqlalchemy import util, logging, topological
-from sqlalchemy.orm import attributes
+from sqlalchemy import util, logging, topological, exceptions
+from sqlalchemy.orm import attributes, interfaces
from sqlalchemy.orm import util as mapperutil
-from sqlalchemy.orm.mapper import object_mapper, class_mapper
-from sqlalchemy.exceptions import *
+from sqlalchemy.orm.mapper import object_mapper
import StringIO
import weakref
-class UOWEventHandler(attributes.AttributeExtension):
+# Load lazily
+object_session = None
+
+class UOWEventHandler(interfaces.AttributeExtension):
"""An event handler added to all class attributes which handles
session operations.
"""
@@ -37,52 +39,46 @@ class UOWEventHandler(attributes.AttributeExtension):
self.class_ = class_
self.cascade = cascade
- def append(self, event, obj, item):
+ def append(self, obj, item, initiator):
# process "save_update" cascade rules for when an instance is appended to the list of another instance
sess = object_session(obj)
if sess is not None:
if self.cascade is not None and self.cascade.save_update and item not in sess:
mapper = object_mapper(obj)
- prop = mapper.props[self.key]
+ prop = mapper.get_property(self.key)
ename = prop.mapper.entity_name
sess.save_or_update(item, entity_name=ename)
- def delete(self, event, obj, item):
+ def remove(self, obj, item, initiator):
# currently no cascade rules for removing an item from a list
# (i.e. it stays in the Session)
pass
- def set(self, event, obj, newvalue, oldvalue):
+ def set(self, obj, newvalue, oldvalue, initiator):
# process "save_update" cascade rules for when an instance is attached to another instance
sess = object_session(obj)
if sess is not None:
if newvalue is not None and self.cascade is not None and self.cascade.save_update and newvalue not in sess:
mapper = object_mapper(obj)
- prop = mapper.props[self.key]
+ prop = mapper.get_property(self.key)
ename = prop.mapper.entity_name
sess.save_or_update(newvalue, entity_name=ename)
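
Note the signature change throughout: the old `(self, event, obj, item)` hooks become `(self, obj, item, initiator)`, and `delete()` is renamed `remove()`; the `initiator` token is what lets the attribute system short-circuit backref recursion. A sketch of a custom extension under the new interface, names hypothetical::

    from sqlalchemy.orm import interfaces

    class AuditExtension(interfaces.AttributeExtension):
        # hypothetical extension that logs collection/scalar changes
        def append(self, obj, item, initiator):
            print "append:", item, "->", obj
        def remove(self, obj, item, initiator):
            print "remove:", item, "<-", obj
        def set(self, obj, newvalue, oldvalue, initiator):
            print "set:", oldvalue, "->", newvalue
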
-class UOWProperty(attributes.InstrumentedAttribute):
- """Override ``InstrumentedAttribute`` to provide an extra
- ``AttributeExtension`` to all managed attributes as well as the
- `property` property.
- """
-
- def __init__(self, manager, class_, key, uselist, callable_, typecallable, cascade=None, extension=None, **kwargs):
- extension = util.to_list(extension or [])
- extension.insert(0, UOWEventHandler(key, class_, cascade=cascade))
- super(UOWProperty, self).__init__(manager, key, uselist, callable_, typecallable, extension=extension,**kwargs)
- self.class_ = class_
-
- property = property(lambda s:class_mapper(s.class_).props[s.key], doc="returns the MapperProperty object associated with this property")
class UOWAttributeManager(attributes.AttributeManager):
"""Override ``AttributeManager`` to provide the ``UOWProperty``
instance for all ``InstrumentedAttributes``.
"""
- def create_prop(self, class_, key, uselist, callable_, typecallable, **kwargs):
- return UOWProperty(self, class_, key, uselist, callable_, typecallable, **kwargs)
+ def create_prop(self, class_, key, uselist, callable_, typecallable,
+ cascade=None, extension=None, **kwargs):
+ extension = util.to_list(extension or [])
+ extension.insert(0, UOWEventHandler(key, class_, cascade=cascade))
+
+ return super(UOWAttributeManager, self).create_prop(
+ class_, key, uselist, callable_, typecallable,
+ extension=extension, **kwargs)
+
class UnitOfWork(object):
"""Main UOW object which stores lists of dirty/new/deleted objects.
@@ -122,7 +118,7 @@ class UnitOfWork(object):
def _validate_obj(self, obj):
if (hasattr(obj, '_instance_key') and not self.identity_map.has_key(obj._instance_key)) or \
(not hasattr(obj, '_instance_key') and obj not in self.new):
- raise InvalidRequestError("Instance '%s' is not attached or pending within this session" % repr(obj))
+ raise exceptions.InvalidRequestError("Instance '%s' is not attached or pending within this session" % repr(obj))
def _is_valid(self, obj):
if (hasattr(obj, '_instance_key') and not self.identity_map.has_key(obj._instance_key)) or \
@@ -138,7 +134,7 @@ class UnitOfWork(object):
self.new.remove(obj)
if not hasattr(obj, '_instance_key'):
mapper = object_mapper(obj)
- obj._instance_key = mapper.instance_key(obj)
+ obj._instance_key = mapper.identity_key_from_instance(obj)
if hasattr(obj, '_sa_insert_order'):
delattr(obj, '_sa_insert_order')
self.identity_map[obj._instance_key] = obj
@@ -148,7 +144,7 @@ class UnitOfWork(object):
"""register the given object as 'new' (i.e. unsaved) within this unit of work."""
if hasattr(obj, '_instance_key'):
- raise InvalidRequestError("Object '%s' already has an identity - it can't be registered as new" % repr(obj))
+ raise exceptions.InvalidRequestError("Object '%s' already has an identity - it can't be registered as new" % repr(obj))
if obj not in self.new:
self.new.add(obj)
obj._sa_insert_order = len(self.new)
@@ -204,14 +200,14 @@ class UnitOfWork(object):
for obj in self.deleted.intersection(objset).difference(processed):
flush_context.register_object(obj, isdelete=True)
- trans = session.create_transaction(autoflush=False)
- flush_context.transaction = trans
+ session.create_transaction(autoflush=False)
+ flush_context.transaction = session.transaction
try:
flush_context.execute()
except:
- trans.rollback()
+ session.rollback()
raise
- trans.commit()
+ session.commit()
flush_context.post_exec()
@@ -228,6 +224,7 @@ class UOWTransaction(object):
def __init__(self, uow, session):
self.uow = uow
self.session = session
+ self.mapper_flush_opts = session._mapper_flush_opts
# stores tuples of mapper/dependent mapper pairs,
# representing a partial ordering fed into topological sort
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 3b3b9b7ed..d248c0dd0 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -4,7 +4,8 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import sql, util, exceptions
+from sqlalchemy import sql, util, exceptions, sql_util
+from sqlalchemy.orm.interfaces import MapperExtension, EXT_PASS
all_cascades = util.Set(["delete", "delete-orphan", "all", "merge",
"expunge", "save-update", "refresh-expire", "none"])
@@ -89,8 +90,6 @@ class TranslatingDict(dict):
def __translate_col(self, col):
ourcol = self.selectable.corresponding_column(col, keys_ok=False, raiseerr=False)
-# if col is not ourcol and ourcol is not None:
-# print "TD TRANSLATING ", col, "TO", ourcol
if ourcol is None:
return col
else:
@@ -111,6 +110,56 @@ class TranslatingDict(dict):
def setdefault(self, col, value):
return super(TranslatingDict, self).setdefault(self.__translate_col(col), value)
+class ExtensionCarrier(MapperExtension):
+ def __init__(self, _elements=None):
+ self.__elements = _elements or []
+
+ def copy(self):
+ return ExtensionCarrier(list(self.__elements))
+
+ def __iter__(self):
+ return iter(self.__elements)
+
+ def insert(self, extension):
+ """Insert a MapperExtension at the beginning of this ExtensionCarrier's list."""
+
+ self.__elements.insert(0, extension)
+
+ def append(self, extension):
+ """Append a MapperExtension at the end of this ExtensionCarrier's list."""
+
+ self.__elements.append(extension)
+
+ def _create_do(funcname):
+ def _do(self, *args, **kwargs):
+ for elem in self.__elements:
+ ret = getattr(elem, funcname)(*args, **kwargs)
+ if ret is not EXT_PASS:
+ return ret
+ else:
+ return EXT_PASS
+ return _do
+
+ init_instance = _create_do('init_instance')
+ init_failed = _create_do('init_failed')
+ dispose_class = _create_do('dispose_class')
+ get_session = _create_do('get_session')
+ load = _create_do('load')
+ get = _create_do('get')
+ get_by = _create_do('get_by')
+ select_by = _create_do('select_by')
+ select = _create_do('select')
+ translate_row = _create_do('translate_row')
+ create_instance = _create_do('create_instance')
+ append_result = _create_do('append_result')
+ populate_instance = _create_do('populate_instance')
+ before_insert = _create_do('before_insert')
+ before_update = _create_do('before_update')
+ after_update = _create_do('after_update')
+ after_insert = _create_do('after_insert')
+ before_delete = _create_do('before_delete')
+ after_delete = _create_do('after_delete')
+
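
ExtensionCarrier collapses a list of MapperExtensions behind the single-extension interface: each generated `_do` method calls down the list in order and returns the first result that is not EXT_PASS. A sketch of the dispatch, with two hypothetical extensions::

    carrier = ExtensionCarrier()
    carrier.append(LoggingExtension())    # create_instance returns EXT_PASS
    carrier.append(FactoryExtension())    # create_instance returns an object

    obj = carrier.create_instance(mapper, context, row, SomeClass)
    # LoggingExtension is consulted first; FactoryExtension's non-EXT_PASS
    # return value stops the loop and is handed back to the mapper
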
class BinaryVisitor(sql.ClauseVisitor):
def __init__(self, func):
self.func = func
@@ -118,6 +167,138 @@ class BinaryVisitor(sql.ClauseVisitor):
def visit_binary(self, binary):
self.func(binary)
+class AliasedClauses(object):
+ """Creates aliases of a mapped tables for usage in ORM queries.
+ """
+
+ def __init__(self, mapped_table, alias=None):
+ if alias:
+ self.alias = alias
+ else:
+ self.alias = mapped_table.alias()
+ self.mapped_table = mapped_table
+ self.extra_cols = {}
+ self.row_decorator = self._create_row_adapter()
+
+ def aliased_column(self, column):
+ """return the aliased version of the given column, creating a new label for it if not already
+ present in this AliasedClauses."""
+
+ conv = self.alias.corresponding_column(column, raiseerr=False)
+ if conv:
+ return conv
+
+ if column in self.extra_cols:
+ return self.extra_cols[column]
+
+ aliased_column = column
+ # for column-level subqueries, swap out its selectable with our
+ # eager version as appropriate, and manually build the
+ # "correlation" list of the subquery.
+ class ModifySubquery(sql.ClauseVisitor):
+ def visit_select(s, select):
+ select._should_correlate = False
+ select.append_correlation(self.alias)
+ aliased_column = sql_util.ClauseAdapter(self.alias).chain(ModifySubquery()).traverse(aliased_column, clone=True)
+ aliased_column = aliased_column.label(None)
+ self.row_decorator.map[column] = aliased_column
+ # TODO: this is a little hacky
+ for attr in ('name', '_label'):
+ if hasattr(column, attr):
+ self.row_decorator.map[getattr(column, attr)] = aliased_column
+ self.extra_cols[column] = aliased_column
+ return aliased_column
+
+ def adapt_clause(self, clause):
+ return self.aliased_column(clause)
+# return sql_util.ClauseAdapter(self.alias).traverse(clause, clone=True)
+
+ def _create_row_adapter(self):
+ """Return a callable which,
+ when passed a RowProxy, will return a new dict-like object
+        that translates Column objects to those of this object's Alias before calling upon the row.
+
+ This allows a regular Table to be used to target columns in a row that was in reality generated from an alias
+ of that table, in such a way that the row can be passed to logic which knows nothing about the aliased form
+ of the table.
+ """
+ class AliasedRowAdapter(object):
+ def __init__(self, row):
+ self.row = row
+ def __contains__(self, key):
+ return key in map or key in self.row
+ def has_key(self, key):
+ return key in self
+ def __getitem__(self, key):
+ if key in map:
+ key = map[key]
+ return self.row[key]
+ def keys(self):
+ return map.keys()
+ map = {}
+ for c in self.alias.c:
+ parent = self.mapped_table.corresponding_column(c)
+ map[parent] = c
+ map[parent._label] = c
+ map[parent.name] = c
+ for c in self.extra_cols:
+ map[c] = self.extra_cols[c]
+ # TODO: this is a little hacky
+ for attr in ('name', '_label'):
+ if hasattr(c, attr):
+ map[getattr(c, attr)] = self.extra_cols[c]
+
+ AliasedRowAdapter.map = map
+ return AliasedRowAdapter
+
+
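
The generated AliasedRowAdapter is a translating wrapper: lookups phrased against the original table's columns (or their names and labels) are rewritten to the aliased columns before touching the underlying row. A stripped-down sketch of the same idea with plain strings::

    class TranslatingRow(object):
        # minimal stand-in for AliasedRowAdapter; 'mapping' plays the
        # role of the class-level 'map' dict assembled above
        def __init__(self, row, mapping):
            self.row, self.mapping = row, mapping
        def __getitem__(self, key):
            return self.row[self.mapping.get(key, key)]

    row = TranslatingRow({'anon_1_id': 7}, {'id': 'anon_1_id'})
    assert row['id'] == 7          # caller asks for the plain column name
    assert row['anon_1_id'] == 7   # the aliased label still works directly
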
+class PropertyAliasedClauses(AliasedClauses):
+ """extends AliasedClauses to add support for primary/secondary joins on a relation()."""
+
+ def __init__(self, prop, primaryjoin, secondaryjoin, parentclauses=None):
+ super(PropertyAliasedClauses, self).__init__(prop.select_table)
+
+ self.parentclauses = parentclauses
+ if parentclauses is not None:
+ self.path = parentclauses.path + (prop.parent, prop.key)
+ else:
+ self.path = (prop.parent, prop.key)
+
+ self.prop = prop
+
+ if prop.secondary:
+ self.secondary = prop.secondary.alias()
+ if parentclauses is not None:
+ aliasizer = sql_util.ClauseAdapter(self.alias).\
+ chain(sql_util.ClauseAdapter(self.secondary)).\
+ chain(sql_util.ClauseAdapter(parentclauses.alias))
+ else:
+ aliasizer = sql_util.ClauseAdapter(self.alias).\
+ chain(sql_util.ClauseAdapter(self.secondary))
+ self.secondaryjoin = aliasizer.traverse(secondaryjoin, clone=True)
+ self.primaryjoin = aliasizer.traverse(primaryjoin, clone=True)
+ else:
+ if parentclauses is not None:
+ aliasizer = sql_util.ClauseAdapter(self.alias, exclude=prop.local_side)
+ aliasizer.chain(sql_util.ClauseAdapter(parentclauses.alias, exclude=prop.remote_side))
+ else:
+ aliasizer = sql_util.ClauseAdapter(self.alias, exclude=prop.local_side)
+ self.primaryjoin = aliasizer.traverse(primaryjoin, clone=True)
+ self.secondary = None
+ self.secondaryjoin = None
+
+ if prop.order_by:
+ self.order_by = sql_util.ClauseAdapter(self.alias).copy_and_process(util.to_list(prop.order_by))
+ else:
+ self.order_by = None
+
+ mapper = property(lambda self:self.prop.mapper)
+ table = property(lambda self:self.prop.select_table)
+
+ def __str__(self):
+ return "->".join([str(s) for s in self.path])
+
+
def instance_str(instance):
"""Return a string describing an instance."""
diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py
index 8670464a0..f86e14ab1 100644
--- a/lib/sqlalchemy/pool.py
+++ b/lib/sqlalchemy/pool.py
@@ -13,7 +13,7 @@ automatically, based on module type and connect arguments, simply by
calling regular DBAPI connect() methods.
"""
-import weakref, string, time, sys, traceback
+import weakref, time
try:
import cPickle as pickle
except:
@@ -190,6 +190,7 @@ class _ConnectionRecord(object):
def __init__(self, pool):
self.__pool = pool
self.connection = self.__connect()
+ self.properties = {}
def close(self):
if self.connection is not None:
@@ -207,10 +208,12 @@ class _ConnectionRecord(object):
def get_connection(self):
if self.connection is None:
self.connection = self.__connect()
+ self.properties.clear()
elif (self.__pool._recycle > -1 and time.time() - self.starttime > self.__pool._recycle):
self.__pool.log("Connection %s exceeded timeout; recycling" % repr(self.connection))
self.__close()
self.connection = self.__connect()
+ self.properties.clear()
return self.connection
def __close(self):
@@ -257,6 +260,21 @@ class _ConnectionFairy(object):
_logger = property(lambda self: self._pool.logger)
is_valid = property(lambda self:self.connection is not None)
+
+ def _get_properties(self):
+ """A property collection unique to this DBAPI connection."""
+
+ try:
+ return self._connection_record.properties
+ except AttributeError:
+ if self.connection is None:
+ raise exceptions.InvalidRequestError("This connection is closed")
+ try:
+ return self._detatched_properties
+ except AttributeError:
+ self._detatched_properties = value = {}
+ return value
+ properties = property(_get_properties)
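
The new `properties` dict rides along with the underlying DBAPI connection: it lives on the _ConnectionRecord, is cleared whenever the connection is recycled or re-established, and is copied onto the fairy when a connection is detached from the pool. A sketch of stashing per-connection state, assuming a pooled `engine`::

    conn = engine.connect()
    fairy = conn.connection            # the _ConnectionFairy for this checkout
    fairy.properties['prepared'] = True
    conn.close()
    # a later checkout of the same DBAPI connection sees the flag again,
    # until the pool recycles that connection
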
def invalidate(self, e=None):
"""Mark this connection as invalidated.
@@ -301,6 +319,8 @@ class _ConnectionFairy(object):
if self._connection_record is not None:
self._connection_record.connection = None
self._pool.do_return_conn(self._connection_record)
+ self._detatched_properties = \
+ self._connection_record.properties.copy()
self._connection_record = None
def close_open_cursors(self):
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index f6ad52adc..3faa3b89c 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -17,17 +17,19 @@ objects as well as the visitor interface, so that the schema package
*plugs in* to the SQL package.
"""
-from sqlalchemy import sql, types, exceptions, util, databases
+from sqlalchemy import sql, types, exceptions, util, databases
import sqlalchemy
-import copy, re, string
+import re, string, inspect
__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', 'ForeignKeyConstraint',
'PrimaryKeyConstraint', 'CheckConstraint', 'UniqueConstraint', 'DefaultGenerator', 'Constraint',
- 'MetaData', 'ThreadLocalMetaData', 'BoundMetaData', 'DynamicMetaData', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault']
+ 'MetaData', 'ThreadLocalMetaData', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault']
class SchemaItem(object):
"""Base class for items that define a database schema."""
+ __metaclass__ = sql._FigureVisitName
+
def _init_items(self, *args):
"""Initialize the list of child items for this SchemaItem."""
@@ -69,15 +71,7 @@ class SchemaItem(object):
m = self._derived_metadata()
return m and m.bind or None
- def get_engine(self):
- """Return the engine or raise an error if no engine.
-
- Deprecated. use the "bind" attribute.
- """
-
- return self._get_engine(raiseerr=True)
-
- def _set_casing_strategy(self, name, kwargs, keyname='case_sensitive'):
+ def _set_casing_strategy(self, kwargs, keyname='case_sensitive'):
"""Set the "case_sensitive" argument sent via keywords to the item's constructor.
For the purposes of Table's 'schema' property, the name of the
@@ -85,7 +79,7 @@ class SchemaItem(object):
"""
setattr(self, '_%s_setting' % keyname, kwargs.pop(keyname, None))
- def _determine_case_sensitive(self, name, keyname='case_sensitive'):
+ def _determine_case_sensitive(self, keyname='case_sensitive'):
"""Determine the `case_sensitive` value for this item.
For the purposes of Table's `schema` property, the name of the
@@ -111,16 +105,22 @@ class SchemaItem(object):
return True
def _get_case_sensitive(self):
+ """late-compile the 'case-sensitive' setting when first accessed.
+
+ typically the SchemaItem will be assembled into its final structure
+ of other SchemaItems at this point, whereby it can attain this setting
+ from its containing SchemaItem if not defined locally.
+ """
+
try:
return self.__case_sensitive
except AttributeError:
- self.__case_sensitive = self._determine_case_sensitive(self.name)
+ self.__case_sensitive = self._determine_case_sensitive()
return self.__case_sensitive
case_sensitive = property(_get_case_sensitive)
- engine = property(lambda s:s._get_engine())
metadata = property(lambda s:s._derived_metadata())
- bind = property(lambda s:s.engine)
+ bind = property(lambda s:s._get_engine())
def _get_table_key(name, schema):
if schema is None:
@@ -128,30 +128,16 @@ def _get_table_key(name, schema):
else:
return schema + "." + name
-class _TableSingleton(type):
+class _TableSingleton(sql._FigureVisitName):
"""A metaclass used by the ``Table`` object to provide singleton behavior."""
def __call__(self, name, metadata, *args, **kwargs):
- if isinstance(metadata, sql.Executor):
- # backwards compatibility - get a BoundSchema associated with the engine
- engine = metadata
- if not hasattr(engine, '_legacy_metadata'):
- engine._legacy_metadata = MetaData(engine)
- metadata = engine._legacy_metadata
- elif metadata is not None and not isinstance(metadata, MetaData):
- # they left MetaData out, so assume its another SchemaItem, add it to *args
- args = list(args)
- args.insert(0, metadata)
- metadata = None
-
- if metadata is None:
- metadata = default_metadata
-
schema = kwargs.get('schema', None)
autoload = kwargs.pop('autoload', False)
autoload_with = kwargs.pop('autoload_with', False)
mustexist = kwargs.pop('mustexist', False)
useexisting = kwargs.pop('useexisting', False)
+ include_columns = kwargs.pop('include_columns', None)
key = _get_table_key(name, schema)
try:
table = metadata.tables[key]
@@ -170,9 +156,9 @@ class _TableSingleton(type):
if autoload:
try:
if autoload_with:
- autoload_with.reflecttable(table)
+ autoload_with.reflecttable(table, include_columns=include_columns)
else:
- metadata._get_engine(raiseerr=True).reflecttable(table)
+ metadata._get_engine(raiseerr=True).reflecttable(table, include_columns=include_columns)
except exceptions.NoSuchTableError:
del metadata.tables[key]
raise
@@ -187,7 +173,7 @@ class Table(SchemaItem, sql.TableClause):
This subclasses ``sql.TableClause`` to provide a table that is
associated with an instance of ``MetaData``, which in turn
- may be associated with an instance of ``SQLEngine``.
+ may be associated with an instance of ``Engine``.
Whereas ``TableClause`` represents a table as its used in an SQL
expression, ``Table`` represents a table as it exists in a
@@ -232,16 +218,28 @@ class Table(SchemaItem, sql.TableClause):
options include:
schema
- Defaults to None: the *schema name* for this table, which is
+ The *schema name* for this table, which is
required if the table resides in a schema other than the
default selected schema for the engine's database
- connection.
+ connection. Defaults to ``None``.
autoload
Defaults to False: the Columns for this table should be
reflected from the database. Usually there will be no
Column objects in the constructor if this property is set.
+ autoload_with
+ if autoload==True, this is an optional Engine or Connection
+ instance to be used for the table reflection. If ``None``,
+ the underlying MetaData's bound connectable will be used.
+
+ include_columns
+ A list of strings indicating a subset of columns to be
+          loaded via the ``autoload`` operation; table columns that
+          aren't present in this list will not be represented on the resulting
+          ``Table`` object. Defaults to ``None``, which indicates all
+ columns should be reflected.
+
mustexist
Defaults to False: indicates that this Table must already
have been defined elsewhere in the application, else an
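
Taken together, the `autoload_with` and `include_columns` options described above let reflection target a specific Engine and trim the result to a subset of columns. A sketch, assuming an existing `users` table reachable through `engine`::

    meta = MetaData()
    users = Table('users', meta,
                  autoload=True,
                  autoload_with=engine,             # reflect via this Engine
                  include_columns=['id', 'name'])   # other columns are skipped
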
@@ -293,8 +291,8 @@ class Table(SchemaItem, sql.TableClause):
self.fullname = self.name
self.owner = kwargs.pop('owner', None)
- self._set_casing_strategy(name, kwargs)
- self._set_casing_strategy(self.schema or '', kwargs, keyname='case_sensitive_schema')
+ self._set_casing_strategy(kwargs)
+ self._set_casing_strategy(kwargs, keyname='case_sensitive_schema')
if len([k for k in kwargs if not re.match(r'^(?:%s)_' % '|'.join(databases.__all__), k)]):
raise TypeError("Invalid argument(s) for Table: %s" % repr(kwargs.keys()))
@@ -302,6 +300,8 @@ class Table(SchemaItem, sql.TableClause):
# store extra kwargs, which should only contain db-specific options
self.kwargs = kwargs
+ key = property(lambda self:_get_table_key(self.name, self.schema))
+
def _export_columns(self, columns=None):
# override FromClause's collection initialization logic; TableClause and Table
# implement it differently
@@ -311,7 +311,7 @@ class Table(SchemaItem, sql.TableClause):
try:
return getattr(self, '_case_sensitive_schema')
except AttributeError:
- setattr(self, '_case_sensitive_schema', self._determine_case_sensitive(self.schema or '', keyname='case_sensitive_schema'))
+ setattr(self, '_case_sensitive_schema', self._determine_case_sensitive(keyname='case_sensitive_schema'))
return getattr(self, '_case_sensitive_schema')
case_sensitive_schema = property(_get_case_sensitive_schema)
@@ -361,36 +361,28 @@ class Table(SchemaItem, sql.TableClause):
else:
return []
- def exists(self, bind=None, connectable=None):
+ def exists(self, bind=None):
"""Return True if this table exists."""
- if connectable is not None:
- bind = connectable
-
if bind is None:
bind = self._get_engine(raiseerr=True)
def do(conn):
- e = conn.engine
- return e.dialect.has_table(conn, self.name, schema=self.schema)
+ return conn.dialect.has_table(conn, self.name, schema=self.schema)
return bind.run_callable(do)
- def create(self, bind=None, checkfirst=False, connectable=None):
+ def create(self, bind=None, checkfirst=False):
"""Issue a ``CREATE`` statement for this table.
See also ``metadata.create_all()``."""
- if connectable is not None:
- bind = connectable
self.metadata.create_all(bind=bind, checkfirst=checkfirst, tables=[self])
- def drop(self, bind=None, checkfirst=False, connectable=None):
+ def drop(self, bind=None, checkfirst=False):
"""Issue a ``DROP`` statement for this table.
See also ``metadata.drop_all()``."""
- if connectable is not None:
- bind = connectable
self.metadata.drop_all(bind=bind, checkfirst=checkfirst, tables=[self])
def tometadata(self, metadata, schema=None):
@@ -417,7 +409,7 @@ class Column(SchemaItem, sql._ColumnClause):
``TableClause``/``Table``.
"""
- def __init__(self, name, type, *args, **kwargs):
+ def __init__(self, name, type_, *args, **kwargs):
"""Construct a new ``Column`` object.
Arguments are:
@@ -426,7 +418,7 @@ class Column(SchemaItem, sql._ColumnClause):
The name of this column. This should be the identical name
as it appears, or will appear, in the database.
- type
+ type\_
The ``TypeEngine`` for this column. This can be any
subclass of ``types.AbstractType``, including the
database-agnostic types defined in the types module,
@@ -516,7 +508,7 @@ class Column(SchemaItem, sql._ColumnClause):
identifier contains mixed case.
"""
- super(Column, self).__init__(name, None, type)
+ super(Column, self).__init__(name, None, type_)
self.args = args
self.key = kwargs.pop('key', name)
self._primary_key = kwargs.pop('primary_key', False)
@@ -526,7 +518,7 @@ class Column(SchemaItem, sql._ColumnClause):
self.index = kwargs.pop('index', None)
self.unique = kwargs.pop('unique', None)
self.quote = kwargs.pop('quote', False)
- self._set_casing_strategy(name, kwargs)
+ self._set_casing_strategy(kwargs)
self.onupdate = kwargs.pop('onupdate', None)
self.autoincrement = kwargs.pop('autoincrement', True)
self.constraints = util.Set()
@@ -631,12 +623,13 @@ class Column(SchemaItem, sql._ColumnClause):
This is a copy of this ``Column`` referenced by a different parent
(such as an alias or select statement).
"""
+
fk = [ForeignKey(f._colspec) for f in self.foreign_keys]
c = Column(name or self.name, self.type, self.default, key = name or self.key, primary_key = self.primary_key, nullable = self.nullable, _is_oid = self._is_oid, quote=self.quote, *fk)
c.table = selectable
c.orig_set = self.orig_set
- c._distance = self._distance + 1
c.__originating_column = self.__originating_column
+ c._distance = self._distance + 1
if not c._is_oid:
selectable.columns.add(c)
if self.primary_key:
@@ -749,18 +742,14 @@ class ForeignKey(SchemaItem):
raise exceptions.ArgumentError("Could not create ForeignKey '%s' on table '%s': table '%s' has no column named '%s'" % (self._colspec, parenttable.name, table.name, str(e)))
else:
self._column = self._colspec
+
        # propagate TypeEngine to parent if it didn't have one
- if self.parent.type is types.NULLTYPE:
+ if isinstance(self.parent.type, types.NullType):
self.parent.type = self._column.type
return self._column
column = property(lambda s: s._init_column())
- def accept_visitor(self, visitor):
- """Call the `visit_foreign_key` method on the given visitor."""
-
- visitor.visit_foreign_key(self)
-
def _get_parent(self):
return self.parent
@@ -802,7 +791,7 @@ class DefaultGenerator(SchemaItem):
def execute(self, bind=None, **kwargs):
if bind is None:
bind = self._get_engine(raiseerr=True)
- return bind.execute_default(self, **kwargs)
+ return bind._execute_default(self, **kwargs)
def __repr__(self):
return "DefaultGenerator()"
@@ -814,9 +803,6 @@ class PassiveDefault(DefaultGenerator):
super(PassiveDefault, self).__init__(**kwargs)
self.arg = arg
- def accept_visitor(self, visitor):
- return visitor.visit_passive_default(self)
-
def __repr__(self):
return "PassiveDefault(%s)" % repr(self.arg)
@@ -829,15 +815,26 @@ class ColumnDefault(DefaultGenerator):
def __init__(self, arg, **kwargs):
super(ColumnDefault, self).__init__(**kwargs)
- self.arg = arg
-
- def accept_visitor(self, visitor):
- """Call the visit_column_default method on the given visitor."""
+ if callable(arg):
+ if not inspect.isfunction(arg):
+ self.arg = lambda ctx: arg()
+ else:
+ argspec = inspect.getargspec(arg)
+ if len(argspec[0]) == 0:
+ self.arg = lambda ctx: arg()
+ elif len(argspec[0]) != 1:
+ raise exceptions.ArgumentError("ColumnDefault Python function takes zero or one positional arguments")
+ else:
+ self.arg = arg
+ else:
+ self.arg = arg
+ def _visit_name(self):
if self.for_update:
- return visitor.visit_column_onupdate(self)
+ return "column_onupdate"
else:
- return visitor.visit_column_default(self)
+ return "column_default"
+ __visit_name__ = property(_visit_name)
def __repr__(self):
return "ColumnDefault(%s)" % repr(self.arg)
@@ -852,7 +849,7 @@ class Sequence(DefaultGenerator):
self.increment = increment
self.optional=optional
self.quote = quote
- self._set_casing_strategy(name, kwargs)
+ self._set_casing_strategy(kwargs)
def __repr__(self):
return "Sequence(%s)" % string.join(
@@ -864,20 +861,16 @@ class Sequence(DefaultGenerator):
super(Sequence, self)._set_parent(column)
column.sequence = self
- def create(self, bind=None):
+ def create(self, bind=None, checkfirst=True):
if bind is None:
bind = self._get_engine(raiseerr=True)
- bind.create(self)
+ bind.create(self, checkfirst=checkfirst)
- def drop(self, bind=None):
+ def drop(self, bind=None, checkfirst=True):
if bind is None:
bind = self._get_engine(raiseerr=True)
- bind.drop(self)
-
- def accept_visitor(self, visitor):
- """Call the visit_seauence method on the given visitor."""
+ bind.drop(self, checkfirst=checkfirst)
- return visitor.visit_sequence(self)
class Constraint(SchemaItem):
"""Represent a table-level ``Constraint`` such as a composite primary key, foreign key, or unique constraint.
@@ -891,7 +884,7 @@ class Constraint(SchemaItem):
self.columns = sql.ColumnCollection()
def __contains__(self, x):
- return x in self.columns
+ return self.columns.contains_column(x)
def keys(self):
return self.columns.keys()
@@ -916,11 +909,12 @@ class CheckConstraint(Constraint):
super(CheckConstraint, self).__init__(name)
self.sqltext = sqltext
- def accept_visitor(self, visitor):
+ def _visit_name(self):
if isinstance(self.parent, Table):
- visitor.visit_check_constraint(self)
+ return "check_constraint"
else:
- visitor.visit_column_check_constraint(self)
+ return "column_check_constraint"
+ __visit_name__ = property(_visit_name)
def _set_parent(self, parent):
self.parent = parent
@@ -949,9 +943,6 @@ class ForeignKeyConstraint(Constraint):
for (c, r) in zip(self.__colnames, self.__refcolnames):
self.append_element(c,r)
- def accept_visitor(self, visitor):
- visitor.visit_foreign_key_constraint(self)
-
def append_element(self, col, refcol):
fk = ForeignKey(refcol, constraint=self, name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, use_alter=self.use_alter)
fk._set_parent(self.table.c[col])
@@ -975,9 +966,6 @@ class PrimaryKeyConstraint(Constraint):
for c in self.__colnames:
self.append_column(table.c[c])
- def accept_visitor(self, visitor):
- visitor.visit_primary_key_constraint(self)
-
def add(self, col):
self.append_column(col)
@@ -1009,9 +997,6 @@ class UniqueConstraint(Constraint):
def append_column(self, col):
self.columns.add(col)
- def accept_visitor(self, visitor):
- visitor.visit_unique_constraint(self)
-
def copy(self):
return UniqueConstraint(name=self.name, *self.__colnames)
@@ -1075,22 +1060,19 @@ class Index(SchemaItem):
% (self.name, column))
self.columns.append(column)
- def create(self, connectable=None):
- if connectable is not None:
- connectable.create(self)
+ def create(self, bind=None):
+ if bind is not None:
+ bind.create(self)
else:
self._get_engine(raiseerr=True).create(self)
return self
- def drop(self, connectable=None):
- if connectable is not None:
- connectable.drop(self)
+ def drop(self, bind=None):
+ if bind is not None:
+ bind.drop(self)
else:
self._get_engine(raiseerr=True).drop(self)
- def accept_visitor(self, visitor):
- visitor.visit_index(self)
-
def __str__(self):
return repr(self)
@@ -1103,69 +1085,34 @@ class Index(SchemaItem):
class MetaData(SchemaItem):
"""Represent a collection of Tables and their associated schema constructs."""
- def __init__(self, engine_or_url=None, url=None, bind=None, engine=None, **kwargs):
+ __visit_name__ = 'metadata'
+
+ def __init__(self, bind=None, **kwargs):
"""create a new MetaData object.
bind
an Engine, or a string or URL instance which will be passed
- to create_engine(), along with \**kwargs - this MetaData will
- be bound to the resulting engine.
+ to create_engine(), this MetaData will be bound to the resulting
+ engine.
- engine_or_url
- deprecated; a synonym for 'bind'
-
- url
- deprecated. a string or URL instance which will be passed to
- create_engine(), along with \**kwargs - this MetaData will be
- bound to the resulting engine.
-
- engine
- deprecated. an Engine instance to which this MetaData will
- be bound.
-
case_sensitive
- popped from \**kwargs, indicates default case sensitive
- setting for all contained objects. defaults to True.
+ popped from \**kwargs, indicates default case sensitive setting for
+ all contained objects. defaults to True.
- name
- deprecated, optional name for this MetaData instance.
-
- """
-
- # transition from <= 0.3.8 signature:
- # MetaData(name=None, url=None, engine=None)
- # to 0.4 signature:
- # MetaData(engine_or_url=None)
- name = kwargs.get('name', None)
- if engine_or_url is None:
- engine_or_url = url or bind or engine
- elif 'name' in kwargs:
- engine_or_url = engine_or_url or bind or engine or url
- else:
- import sqlalchemy.engine as engine
- import sqlalchemy.engine.url as url
- if (not isinstance(engine_or_url, url.URL) and
- not isinstance(engine_or_url, engine.Connectable)):
- try:
- url.make_url(engine_or_url)
- except exceptions.ArgumentError:
- # nope, must have been a name as 1st positional
- name, engine_or_url = engine_or_url, (url or engine or bind)
- kwargs.pop('name', None)
+ """
self.tables = {}
- self.name = name
- self._bind = None
- self._set_casing_strategy(name, kwargs)
- if engine_or_url:
- self.connect(engine_or_url, **kwargs)
+ self._set_casing_strategy(kwargs)
+ self.bind = bind
+
+ def __repr__(self):
+ return 'MetaData(%r)' % self.bind
def __getstate__(self):
- return {'tables':self.tables, 'name':self.name, 'casesensitive':self._case_sensitive_setting}
-
+ return {'tables':self.tables, 'casesensitive':self._case_sensitive_setting}
+
def __setstate__(self, state):
self.tables = state['tables']
- self.name = state['name']
self._case_sensitive_setting = state['casesensitive']
self._bind = None
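
The net effect of the simplified constructor: the old tangle of ``name``/``url``/``engine``/``engine_or_url`` collapses into a single ``bind``, which may be an ``Engine``, a URL string, or a ``URL``. A sketch:

    from sqlalchemy import MetaData, create_engine

    meta = MetaData()                           # unbound
    meta.bind = create_engine('sqlite://')      # bind later via assignment
    meta2 = MetaData(bind='sqlite://')          # or pass an Engine/URL at construction
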
@@ -1173,7 +1120,7 @@ class MetaData(SchemaItem):
"""return True if this MetaData is bound to an Engine."""
return self._bind is not None
- def connect(self, bind=None, **kwargs):
+ def connect(self, bind, **kwargs):
"""bind this MetaData to an Engine.
DEPRECATED. use metadata.bind = <engine> or metadata.bind = <url>.
@@ -1184,13 +1131,8 @@ class MetaData(SchemaItem):
produce the engine which to connect to. otherwise connects
directly to the given Engine.
- engine_or_url
- deprecated. synonymous with "bind"
-
"""
- if bind is None:
- bind = kwargs.pop('engine_or_url', None)
from sqlalchemy.engine.url import URL
if isinstance(bind, (basestring, URL)):
self._bind = sqlalchemy.create_engine(bind, **kwargs)
@@ -1202,6 +1144,10 @@ class MetaData(SchemaItem):
def clear(self):
self.tables.clear()
+ def remove(self, table):
+ # TODO: scan all other tables and remove FK _column
+ del self.tables[table.key]
+
def table_iterator(self, reverse=True, tables=None):
import sqlalchemy.sql_util
if tables is None:
@@ -1214,7 +1160,7 @@ class MetaData(SchemaItem):
def _get_parent(self):
return None
- def create_all(self, bind=None, tables=None, checkfirst=True, connectable=None):
+ def create_all(self, bind=None, tables=None, checkfirst=True):
"""Create all tables stored in this metadata.
This will conditionally create tables depending on if they do
@@ -1224,21 +1170,16 @@ class MetaData(SchemaItem):
A ``Connectable`` used to access the database; if None, uses
the existing bind on this ``MetaData``, if any.
- connectable
- deprecated. synonymous with "bind"
-
tables
Optional list of tables, which is a subset of the total
tables in the ``MetaData`` (others are ignored).
"""
- if connectable is not None:
- bind = connectable
if bind is None:
bind = self._get_engine(raiseerr=True)
bind.create(self, checkfirst=checkfirst, tables=tables)
- def drop_all(self, bind=None, tables=None, checkfirst=True, connectable=None):
+ def drop_all(self, bind=None, tables=None, checkfirst=True):
"""Drop all tables stored in this metadata.
This will conditionally drop tables depending on if they
@@ -1248,23 +1189,15 @@ class MetaData(SchemaItem):
A ``Connectable`` used to access the database; if None, uses
the existing bind on this ``MetaData``, if any.
- connectable
- deprecated. synonymous with "bind"
-
tables
Optional list of tables, which is a subset of the total
tables in the ``MetaData`` (others are ignored).
"""
- if connectable is not None:
- bind = connectable
if bind is None:
bind = self._get_engine(raiseerr=True)
bind.drop(self, checkfirst=checkfirst, tables=tables)
- def accept_visitor(self, visitor):
- visitor.visit_metadata(self)
-
def _derived_metadata(self):
return self
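
With the deprecated ``connectable`` keyword gone, ``create_all()``/``drop_all()`` accept only ``bind``, ``tables`` and ``checkfirst``. Assuming an ``engine`` and a ``users`` table:

    meta.create_all(bind=engine)                    # create whatever is missing
    meta.create_all(bind=engine, tables=[users])    # restrict to a subset
    meta.drop_all(bind=engine, checkfirst=True)     # drop only tables that exist
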
@@ -1276,27 +1209,22 @@ class MetaData(SchemaItem):
return None
return self._bind
-
-class BoundMetaData(MetaData):
- """Deprecated. Use ``MetaData``."""
-
- def __init__(self, engine_or_url, name=None, **kwargs):
- super(BoundMetaData, self).__init__(engine_or_url=engine_or_url,
- name=name, **kwargs)
-
-
class ThreadLocalMetaData(MetaData):
- """A ``MetaData`` that binds to multiple ``Engine`` implementations on a thread-local basis."""
+ """Build upon ``MetaData`` to provide the capability to bind to
+multiple ``Engine`` implementations on a dynamically alterable,
+thread-local basis.
+ """
+
+ __visit_name__ = 'metadata'
- def __init__(self, name=None, **kwargs):
+ def __init__(self, **kwargs):
self.context = util.ThreadLocal()
self.__engines = {}
- super(ThreadLocalMetaData, self).__init__(engine_or_url=None,
- name=name, **kwargs)
+ super(ThreadLocalMetaData, self).__init__(**kwargs)
def connect(self, engine_or_url, **kwargs):
from sqlalchemy.engine.url import URL
- if isinstance(engine_or_url, basestring) or isinstance(engine_or_url, URL):
+ if isinstance(engine_or_url, (basestring, URL)):
try:
self.context._engine = self.__engines[engine_or_url]
except KeyError:
@@ -1304,6 +1232,8 @@ class ThreadLocalMetaData(MetaData):
self.__engines[engine_or_url] = e
self.context._engine = e
else:
+            # TODO: this is squirrelly. we shouldn't have to hold onto engines
+ # in a case like this
if not self.__engines.has_key(engine_or_url):
self.__engines[engine_or_url] = engine_or_url
self.context._engine = engine_or_url
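
``ThreadLocalMetaData.connect()`` caches one ``Engine`` per URL and records the current choice in thread-local context, so different threads can point the same metadata at different databases. A sketch (URLs are placeholders):

    import threading
    from sqlalchemy.schema import ThreadLocalMetaData

    meta = ThreadLocalMetaData()

    def worker(url):
        meta.connect(url)   # one Engine per URL, created once and cached
        # this thread now sees its own engine via meta.bind

    threading.Thread(target=worker, args=('sqlite:///a.db',)).start()
    threading.Thread(target=worker, args=('sqlite:///b.db',)).start()
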
@@ -1325,77 +1255,10 @@ class ThreadLocalMetaData(MetaData):
raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine or Connection.")
else:
return None
- engine=property(_get_engine)
bind = property(_get_engine, connect)
-def DynamicMetaData(name=None, threadlocal=True, **kw):
- """Deprecated. Use ``MetaData`` or ``ThreadLocalMetaData``."""
-
- if threadlocal:
- return ThreadLocalMetaData(name=name, **kw)
- else:
- return MetaData(name=name, **kw)
-
class SchemaVisitor(sql.ClauseVisitor):
"""Define the visiting for ``SchemaItem`` objects."""
__traverse_options__ = {'schema_visitor':True}
-
- def visit_schema(self, schema):
- """Visit a generic ``SchemaItem``."""
- pass
-
- def visit_table(self, table):
- """Visit a ``Table``."""
- pass
-
- def visit_column(self, column):
- """Visit a ``Column``."""
- pass
-
- def visit_foreign_key(self, join):
- """Visit a ``ForeignKey``."""
- pass
-
- def visit_index(self, index):
- """Visit an ``Index``."""
- pass
-
- def visit_passive_default(self, default):
- """Visit a passive default."""
- pass
-
- def visit_column_default(self, default):
- """Visit a ``ColumnDefault``."""
- pass
-
- def visit_column_onupdate(self, onupdate):
- """Visit a ``ColumnDefault`` with the `for_update` flag set."""
- pass
-
- def visit_sequence(self, sequence):
- """Visit a ``Sequence``."""
- pass
-
- def visit_primary_key_constraint(self, constraint):
- """Visit a ``PrimaryKeyConstraint``."""
- pass
-
- def visit_foreign_key_constraint(self, constraint):
- """Visit a ``ForeignKeyConstraint``."""
- pass
-
- def visit_unique_constraint(self, constraint):
- """Visit a ``UniqueConstraint``."""
- pass
-
- def visit_check_constraint(self, constraint):
- """Visit a ``CheckConstraint``."""
- pass
-
- def visit_column_check_constraint(self, constraint):
- """Visit a ``CheckConstraint`` on a ``Column``."""
- pass
-
-default_metadata = ThreadLocalMetaData(name='default')
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index 8b454947e..01588e92d 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -24,54 +24,21 @@ are less guaranteed to stay the same in future releases.
"""
-from sqlalchemy import util, exceptions, logging
+from sqlalchemy import util, exceptions
from sqlalchemy import types as sqltypes
-import string, re, random, sets
+import re, operator
-
-__all__ = ['AbstractDialect', 'Alias', 'ClauseElement', 'ClauseParameters',
+__all__ = ['Alias', 'ClauseElement', 'ClauseParameters',
'ClauseVisitor', 'ColumnCollection', 'ColumnElement',
- 'Compiled', 'CompoundSelect', 'Executor', 'FromClause', 'Join',
- 'Select', 'Selectable', 'TableClause', 'alias', 'and_', 'asc',
- 'between_', 'between', 'bindparam', 'case', 'cast', 'column', 'delete',
+ 'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join',
+ 'Select', 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc',
+ 'between', 'bindparam', 'case', 'cast', 'column', 'delete',
'desc', 'distinct', 'except_', 'except_all', 'exists', 'extract', 'func', 'modifier',
'insert', 'intersect', 'intersect_all', 'join', 'literal',
- 'literal_column', 'not_', 'null', 'or_', 'outerjoin', 'select',
+ 'literal_column', 'not_', 'null', 'or_', 'outparam', 'outerjoin', 'select',
'subquery', 'table', 'text', 'union', 'union_all', 'update',]
-# precedence ordering for common operators. if an operator is not present in this list,
-# it will be parenthesized when grouped against other operators
-PRECEDENCE = {
- 'FROM':15,
- '*':7,
- '/':7,
- '%':7,
- '+':6,
- '-':6,
- 'ILIKE':5,
- 'NOT ILIKE':5,
- 'LIKE':5,
- 'NOT LIKE':5,
- 'IN':5,
- 'NOT IN':5,
- 'IS':5,
- 'IS NOT':5,
- '=':5,
- '!=':5,
- '>':5,
- '<':5,
- '>=':5,
- '<=':5,
- 'BETWEEN':5,
- 'NOT':4,
- 'AND':3,
- 'OR':2,
- ',':-1,
- 'AS':-1,
- 'EXISTS':0,
- '_smallest': -1000,
- '_largest': 1000
-}
+BIND_PARAMS = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
def desc(column):
"""Return a descending ``ORDER BY`` clause element.
@@ -141,7 +108,7 @@ def join(left, right, onclause=None, **kwargs):
return Join(left, right, onclause, **kwargs)
-def select(columns=None, whereclause = None, from_obj = [], **kwargs):
+def select(columns=None, whereclause=None, from_obj=[], **kwargs):
"""Returns a ``SELECT`` clause element.
Similar functionality is also available via the ``select()`` method on any
@@ -224,9 +191,6 @@ def select(columns=None, whereclause = None, from_obj = [], **kwargs):
automatically bind to whatever ``Connectable`` instances can be located
within its contained ``ClauseElement`` members.
- engine=None
- deprecated. a synonym for "bind".
-
limit=None
a numerical value which usually compiles to a ``LIMIT`` expression
in the resulting select. Databases that don't support ``LIMIT``
@@ -238,12 +202,8 @@ def select(columns=None, whereclause = None, from_obj = [], **kwargs):
will attempt to provide similar functionality.
scalar=False
- when ``True``, indicates that the resulting ``Select`` object
- is to be used in the "columns" clause of another select statement,
- where the evaluated value of the column is the scalar result of
- this statement. Normally, placing any ``Selectable`` within the
- columns clause of a ``select()`` call will expand the member
- columns of the ``Selectable`` individually.
+ deprecated. use select(...).as_scalar() to create a "scalar column"
+ proxy for an existing Select object.
correlate=True
indicates that this ``Select`` object should have its contained
@@ -254,8 +214,12 @@ def select(columns=None, whereclause = None, from_obj = [], **kwargs):
rendered in the ``FROM`` clause of this select statement.
"""
-
- return Select(columns, whereclause = whereclause, from_obj = from_obj, **kwargs)
+ scalar = kwargs.pop('scalar', False)
+ s = Select(columns, whereclause=whereclause, from_obj=from_obj, **kwargs)
+ if scalar:
+ return s.as_scalar()
+ else:
+ return s
def subquery(alias, *args, **kwargs):
"""Return an [sqlalchemy.sql#Alias] object derived from a [sqlalchemy.sql#Select].
@@ -271,7 +235,7 @@ def subquery(alias, *args, **kwargs):
return Select(*args, **kwargs).alias(alias)
def insert(table, values = None, **kwargs):
- """Return an [sqlalchemy.sql#_Insert] clause element.
+ """Return an [sqlalchemy.sql#Insert] clause element.
Similar functionality is available via the ``insert()``
method on [sqlalchemy.schema#Table].
@@ -304,10 +268,10 @@ def insert(table, values = None, **kwargs):
against the ``INSERT`` statement.
"""
- return _Insert(table, values, **kwargs)
+ return Insert(table, values, **kwargs)
def update(table, whereclause = None, values = None, **kwargs):
- """Return an [sqlalchemy.sql#_Update] clause element.
+ """Return an [sqlalchemy.sql#Update] clause element.
Similar functionality is available via the ``update()``
method on [sqlalchemy.schema#Table].
@@ -344,10 +308,10 @@ def update(table, whereclause = None, values = None, **kwargs):
against the ``UPDATE`` statement.
"""
- return _Update(table, whereclause, values, **kwargs)
+ return Update(table, whereclause, values, **kwargs)
def delete(table, whereclause = None, **kwargs):
- """Return a [sqlalchemy.sql#_Delete] clause element.
+ """Return a [sqlalchemy.sql#Delete] clause element.
Similar functionality is available via the ``delete()``
method on [sqlalchemy.schema#Table].
@@ -361,7 +325,7 @@ def delete(table, whereclause = None, **kwargs):
"""
- return _Delete(table, whereclause, **kwargs)
+ return Delete(table, whereclause, **kwargs)
def and_(*clauses):
"""Join a list of clauses together using the ``AND`` operator.
@@ -371,7 +335,7 @@ def and_(*clauses):
"""
if len(clauses) == 1:
return clauses[0]
- return ClauseList(operator='AND', *clauses)
+ return ClauseList(operator=operator.and_, *clauses)
def or_(*clauses):
"""Join a list of clauses together using the ``OR`` operator.
@@ -382,7 +346,7 @@ def or_(*clauses):
if len(clauses) == 1:
return clauses[0]
- return ClauseList(operator='OR', *clauses)
+ return ClauseList(operator=operator.or_, *clauses)
def not_(clause):
"""Return a negation of the given clause, i.e. ``NOT(clause)``.
@@ -391,7 +355,7 @@ def not_(clause):
subclasses to produce the same result.
"""
- return clause._negate()
+ return operator.inv(clause)
def distinct(expr):
"""return a ``DISTINCT`` clause."""
@@ -407,12 +371,9 @@ def between(ctest, cleft, cright):
provides similar functionality.
"""
- return _BinaryExpression(ctest, ClauseList(_literals_as_binds(cleft, type=ctest.type), _literals_as_binds(cright, type=ctest.type), operator='AND', group=False), 'BETWEEN')
+ ctest = _literal_as_binds(ctest)
+ return _BinaryExpression(ctest, ClauseList(_literal_as_binds(cleft, type_=ctest.type), _literal_as_binds(cright, type_=ctest.type), operator=operator.and_, group=False), ColumnOperators.between_op)
-def between_(*args, **kwargs):
- """synonym for [sqlalchemy.sql#between()] (deprecated)."""
-
- return between(*args, **kwargs)
def case(whens, value=None, else_=None):
"""Produce a ``CASE`` statement.
@@ -435,7 +396,7 @@ def case(whens, value=None, else_=None):
type = list(whenlist[-1])[-1].type
else:
type = None
- cc = _CalculatedClause(None, 'CASE', value, type=type, operator=None, group_contents=False, *whenlist + ['END'])
+ cc = _CalculatedClause(None, 'CASE', value, type_=type, operator=None, group_contents=False, *whenlist + ['END'])
return cc
def cast(clause, totype, **kwargs):
@@ -457,7 +418,7 @@ def cast(clause, totype, **kwargs):
def extract(field, expr):
"""Return the clause ``extract(field FROM expr)``."""
- expr = _BinaryExpression(text(field), expr, "FROM")
+ expr = _BinaryExpression(text(field), expr, Operators.from_)
return func.extract(expr)
def exists(*args, **kwargs):
@@ -587,7 +548,7 @@ def alias(selectable, alias=None):
return Alias(selectable, alias=alias)
-def literal(value, type=None):
+def literal(value, type_=None):
"""Return a literal clause, bound to a bind parameter.
Literal clauses are created automatically when non-
@@ -603,13 +564,13 @@ def literal(value, type=None):
the underlying DBAPI, or is translatable via the given type
argument.
- type
+ type\_
an optional [sqlalchemy.types#TypeEngine] which will provide
bind-parameter translation for this literal.
"""
- return _BindParamClause('literal', value, type=type, unique=True)
+ return _BindParamClause('literal', value, type_=type_, unique=True)
def label(name, obj):
"""Return a [sqlalchemy.sql#_Label] object for the given [sqlalchemy.sql#ColumnElement].
@@ -630,7 +591,7 @@ def label(name, obj):
return _Label(name, obj)
-def column(text, type=None):
+def column(text, type_=None):
"""Return a textual column clause, as would be in the columns
clause of a ``SELECT`` statement.
@@ -644,15 +605,15 @@ def column(text, type=None):
constructs that are not to be quoted, use the [sqlalchemy.sql#literal_column()]
function.
- type
+ type\_
an optional [sqlalchemy.types#TypeEngine] object which will provide
result-set translation for this column.
"""
- return _ColumnClause(text, type=type)
+ return _ColumnClause(text, type_=type_)
-def literal_column(text, type=None):
+def literal_column(text, type_=None):
"""Return a textual column clause, as would be in the columns
clause of a ``SELECT`` statement.
@@ -674,7 +635,7 @@ def literal_column(text, type=None):
"""
- return _ColumnClause(text, type=type, is_literal=True)
+ return _ColumnClause(text, type_=type_, is_literal=True)
def table(name, *columns):
"""Return a [sqlalchemy.sql#Table] object.
@@ -685,7 +646,7 @@ def table(name, *columns):
return TableClause(name, *columns)
-def bindparam(key, value=None, type=None, shortname=None, unique=False):
+def bindparam(key, value=None, type_=None, shortname=None, unique=False):
"""Create a bind parameter clause with the given key.
value
@@ -707,11 +668,22 @@ def bindparam(key, value=None, type=None, shortname=None, unique=False):
"""
if isinstance(key, _ColumnClause):
- return _BindParamClause(key.name, value, type=key.type, shortname=shortname, unique=unique)
+ return _BindParamClause(key.name, value, type_=key.type, shortname=shortname, unique=unique)
else:
- return _BindParamClause(key, value, type=type, shortname=shortname, unique=unique)
+ return _BindParamClause(key, value, type_=type_, shortname=shortname, unique=unique)
-def text(text, bind=None, engine=None, *args, **kwargs):
+def outparam(key, type_=None):
+ """create an 'OUT' parameter for usage in functions (stored procedures), for databases
+    which support them.
+
+ The ``outparam`` can be used like a regular function parameter. The "output" value will
+ be available from the [sqlalchemy.engine#ResultProxy] object via its ``out_parameters``
+ attribute, which returns a dictionary containing the values.
+ """
+
+ return _BindParamClause(key, None, type_=type_, unique=False, isoutparam=True)
+
+def text(text, bind=None, *args, **kwargs):
"""Create literal text to be inserted into a query.
When constructing a query from a ``select()``, ``update()``,
@@ -729,9 +701,6 @@ def text(text, bind=None, engine=None, *args, **kwargs):
bind
An optional connection or engine to be used for this text query.
- engine
- deprecated. a synonym for 'bind'.
-
bindparams
A list of ``bindparam()`` instances which can be used to define
the types and/or initial values for the bind parameters within
@@ -748,7 +717,7 @@ def text(text, bind=None, engine=None, *args, **kwargs):
"""
- return _TextClause(text, engine=engine, bind=bind, *args, **kwargs)
+ return _TextClause(text, bind=bind, *args, **kwargs)
def null():
"""Return a ``_Null`` object, which compiles to ``NULL`` in a sql statement."""
@@ -786,30 +755,44 @@ def _compound_select(keyword, *selects, **kwargs):
def _is_literal(element):
return not isinstance(element, ClauseElement)
-def _literals_as_text(element):
- if _is_literal(element):
+def _literal_as_text(element):
+ if isinstance(element, Operators):
+ return element.expression_element()
+ elif _is_literal(element):
return _TextClause(unicode(element))
else:
return element
-def _literals_as_binds(element, name='literal', type=None):
- if _is_literal(element):
+def _literal_as_column(element):
+ if isinstance(element, Operators):
+ return element.clause_element()
+ elif _is_literal(element):
+ return literal_column(str(element))
+ else:
+ return element
+
+def _literal_as_binds(element, name='literal', type_=None):
+ if isinstance(element, Operators):
+ return element.expression_element()
+ elif _is_literal(element):
if element is None:
return null()
else:
- return _BindParamClause(name, element, shortname=name, type=type, unique=True)
+ return _BindParamClause(name, element, shortname=name, type_=type_, unique=True)
else:
return element
+
+def _selectable(element):
+ if hasattr(element, '__selectable__'):
+ return element.__selectable__()
+ elif isinstance(element, Selectable):
+ return element
+ else:
+ raise exceptions.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element))
def is_column(col):
return isinstance(col, ColumnElement)
-class AbstractDialect(object):
- """Represent the behavior of a particular database.
-
- Used by ``Compiled`` objects."""
- pass
-
class ClauseParameters(object):
"""Represent a dictionary/iterator of bind parameter key names/values.
@@ -822,52 +805,51 @@ class ClauseParameters(object):
def __init__(self, dialect, positional=None):
super(ClauseParameters, self).__init__()
self.dialect = dialect
- self.binds = {}
- self.binds_to_names = {}
- self.binds_to_values = {}
+ self.__binds = {}
self.positional = positional or []
+ def get_parameter(self, key):
+ return self.__binds[key]
+
def set_parameter(self, bindparam, value, name):
- self.binds[bindparam.key] = bindparam
- self.binds[name] = bindparam
- self.binds_to_names[bindparam] = name
- self.binds_to_values[bindparam] = value
+ self.__binds[name] = [bindparam, name, value]
def get_original(self, key):
- """Return the given parameter as it was originally placed in
- this ``ClauseParameters`` object, without any ``Type``
- conversion."""
- return self.binds_to_values[self.binds[key]]
+ return self.__binds[key][2]
+
+ def get_type(self, key):
+ return self.__binds[key][0].type
def get_processed(self, key):
- bind = self.binds[key]
- value = self.binds_to_values[bind]
+ (bind, name, value) = self.__binds[key]
return bind.typeprocess(value, self.dialect)
def keys(self):
- return self.binds_to_names.values()
+ return self.__binds.keys()
+
+ def __iter__(self):
+ return iter(self.keys())
def __getitem__(self, key):
return self.get_processed(key)
def __contains__(self, key):
- return key in self.binds
+ return key in self.__binds
def set_value(self, key, value):
- bind = self.binds[key]
- self.binds_to_values[bind] = value
+ self.__binds[key][2] = value
def get_original_dict(self):
- return dict([(self.binds_to_names[b], self.binds_to_values[b]) for b in self.binds_to_names.keys()])
+ return dict([(name, value) for (b, name, value) in self.__binds.values()])
def get_raw_list(self):
return [self.get_processed(key) for key in self.positional]
- def get_raw_dict(self):
- d = {}
- for k in self.binds_to_names.values():
- d[k] = self.get_processed(k)
- return d
+ def get_raw_dict(self, encode_keys=False):
+ if encode_keys:
+ return dict([(key.encode(self.dialect.encoding), self.get_processed(key)) for key in self.keys()])
+ else:
+ return dict([(key, self.get_processed(key)) for key in self.keys()])
def __repr__(self):
return self.__class__.__name__ + ":" + repr(self.get_original_dict())
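
The refactored ``ClauseParameters`` keeps one private ``__binds`` dict of ``[bindparam, name, value]`` triples keyed by name, replacing three parallel dictionaries. Illustrative accessor behavior (``params`` stands in for a compiled statement's parameters):

    params.get_original('name')     # raw value as originally supplied
    params.get_type('name')         # the TypeEngine of the bind parameter
    params['name']                  # value run through the type's bind conversion
    params.get_raw_dict(encode_keys=True)   # keys encoded to the dialect's encoding
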
@@ -876,8 +858,8 @@ class ClauseVisitor(object):
"""A class that knows how to traverse and visit
``ClauseElements``.
- Each ``ClauseElement``'s accept_visitor() method will call a
- corresponding visit_XXXX() method here. Traversal of a
+    Calls visit_XXX() methods dynamically generated for each particular
+ ``ClauseElement`` subclass encountered. Traversal of a
hierarchy of ``ClauseElements`` is achieved via the
``traverse()`` method, which is passed the lead
``ClauseElement``.
@@ -889,22 +871,44 @@ class ClauseVisitor(object):
these options can indicate modifications to the set of
elements returned, such as to not return column collections
(column_collections=False) or to return Schema-level items
- (schema_visitor=True)."""
+ (schema_visitor=True).
+
+ ``ClauseVisitor`` also supports a simultaneous copy-and-traverse
+ operation, which will produce a copy of a given ``ClauseElement``
+ structure while at the same time allowing ``ClauseVisitor`` subclasses
+ to modify the new structure in-place.
+
+ """
__traverse_options__ = {}
- def traverse(self, obj, stop_on=None):
- stack = [obj]
- traversal = []
- while len(stack) > 0:
- t = stack.pop()
- if stop_on is None or t not in stop_on:
- traversal.insert(0, t)
- for c in t.get_children(**self.__traverse_options__):
- stack.append(c)
- for target in traversal:
- v = self
- while v is not None:
- target.accept_visitor(v)
- v = getattr(v, '_next', None)
+
+ def traverse_single(self, obj, **kwargs):
+ meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
+ if meth:
+ return meth(obj, **kwargs)
+
+ def traverse(self, obj, stop_on=None, clone=False):
+ if clone:
+ obj = obj._clone()
+
+ v = self
+ visitors = []
+ while v is not None:
+ visitors.append(v)
+ v = getattr(v, '_next', None)
+
+ def _trav(obj):
+ if stop_on is not None and obj in stop_on:
+ return
+ if clone:
+ obj._copy_internals()
+ for c in obj.get_children(**self.__traverse_options__):
+ _trav(c)
+
+ for v in visitors:
+ meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
+ if meth:
+ meth(obj)
+ _trav(obj)
return obj
def chain(self, visitor):
@@ -916,78 +920,6 @@ class ClauseVisitor(object):
tail = tail._next
tail._next = visitor
return self
-
- def visit_column(self, column):
- pass
- def visit_table(self, table):
- pass
- def visit_fromclause(self, fromclause):
- pass
- def visit_bindparam(self, bindparam):
- pass
- def visit_textclause(self, textclause):
- pass
- def visit_compound(self, compound):
- pass
- def visit_compound_select(self, compound):
- pass
- def visit_binary(self, binary):
- pass
- def visit_unary(self, unary):
- pass
- def visit_alias(self, alias):
- pass
- def visit_select(self, select):
- pass
- def visit_join(self, join):
- pass
- def visit_null(self, null):
- pass
- def visit_clauselist(self, list):
- pass
- def visit_calculatedclause(self, calcclause):
- pass
- def visit_grouping(self, gr):
- pass
- def visit_function(self, func):
- pass
- def visit_cast(self, cast):
- pass
- def visit_label(self, label):
- pass
- def visit_typeclause(self, typeclause):
- pass
-
-class LoggingClauseVisitor(ClauseVisitor):
- """extends ClauseVisitor to include debug logging of all traversal.
-
- To install this visitor, set logging.DEBUG for
- 'sqlalchemy.sql.ClauseVisitor' **before** you import the
- sqlalchemy.sql module.
- """
-
- def traverse(self, obj, stop_on=None):
- stack = [(obj, "")]
- traversal = []
- while len(stack) > 0:
- (t, indent) = stack.pop()
- if stop_on is None or t not in stop_on:
- traversal.insert(0, (t, indent))
- for c in t.get_children(**self.__traverse_options__):
- stack.append((c, indent + " "))
-
- for (target, indent) in traversal:
- self.logger.debug(indent + repr(target))
- v = self
- while v is not None:
- target.accept_visitor(v)
- v = getattr(v, '_next', None)
- return obj
-
-LoggingClauseVisitor.logger = logging.class_logger(ClauseVisitor)
-
-if logging.is_debug_enabled(LoggingClauseVisitor.logger):
- ClauseVisitor=LoggingClauseVisitor
class NoColumnVisitor(ClauseVisitor):
"""a ClauseVisitor that will not traverse the exported Column
@@ -1000,113 +932,35 @@ class NoColumnVisitor(ClauseVisitor):
"""
__traverse_options__ = {'column_collections':False}
-
-class Executor(object):
- """Interface representing a "thing that can produce Compiled objects
- and execute them"."""
-
- def execute_compiled(self, compiled, parameters, echo=None, **kwargs):
- """Execute a Compiled object."""
-
- raise NotImplementedError()
-
- def compiler(self, statement, parameters, **kwargs):
- """Return a Compiled object for the given statement and parameters."""
-
- raise NotImplementedError()
-
-class Compiled(ClauseVisitor):
- """Represent a compiled SQL expression.
-
- The ``__str__`` method of the ``Compiled`` object should produce
- the actual text of the statement. ``Compiled`` objects are
- specific to their underlying database dialect, and also may
- or may not be specific to the columns referenced within a
- particular set of bind parameters. In no case should the
- ``Compiled`` object be dependent on the actual values of those
- bind parameters, even though it may reference those values as
- defaults.
- """
-
- def __init__(self, dialect, statement, parameters, bind=None, engine=None):
- """Construct a new ``Compiled`` object.
-
- statement
- ``ClauseElement`` to be compiled.
-
- parameters
- Optional dictionary indicating a set of bind parameters
- specified with this ``Compiled`` object. These parameters
- are the *default* values corresponding to the
- ``ClauseElement``'s ``_BindParamClauses`` when the
- ``Compiled`` is executed. In the case of an ``INSERT`` or
- ``UPDATE`` statement, these parameters will also result in
- the creation of new ``_BindParamClause`` objects for each
- key and will also affect the generated column list in an
- ``INSERT`` statement and the ``SET`` clauses of an
- ``UPDATE`` statement. The keys of the parameter dictionary
- can either be the string names of columns or
- ``_ColumnClause`` objects.
-
- bind
- optional engine or connection which will be bound to the
- compiled object.
-
- engine
- deprecated, a synonym for 'bind'
- """
- self.dialect = dialect
- self.statement = statement
- self.parameters = parameters
- self.bind = bind or engine
- self.can_execute = statement.supports_execution()
-
- def compile(self):
- self.traverse(self.statement)
- self.after_compile()
-
- def __str__(self):
- """Return the string text of the generated SQL statement."""
-
- raise NotImplementedError()
- def get_params(self, **params):
- """Deprecated. use construct_params(). (supports unicode names)
- """
- return self.construct_params(params)
-
- def construct_params(self, params):
- """Return the bind params for this compiled object.
-
- Will start with the default parameters specified when this
- ``Compiled`` object was first constructed, and will override
- those values with those sent via `**params`, which are
- key/value pairs. Each key should match one of the
- ``_BindParamClause`` objects compiled into this object; either
- the `key` or `shortname` property of the ``_BindParamClause``.
- """
- raise NotImplementedError()
+class _FigureVisitName(type):
+ def __init__(cls, clsname, bases, dict):
+ if not '__visit_name__' in cls.__dict__:
+ m = re.match(r'_?(\w+?)(?:Expression|Clause|Element|$)', clsname)
+ x = m.group(1)
+ x = re.sub(r'(?!^)[A-Z]', lambda m:'_'+m.group(0).lower(), x)
+ cls.__visit_name__ = x.lower()
+ super(_FigureVisitName, cls).__init__(clsname, bases, dict)
- def execute(self, *multiparams, **params):
- """Execute this compiled object."""
-
- e = self.bind
- if e is None:
- raise exceptions.InvalidRequestError("This Compiled object is not bound to any Engine or Connection.")
- return e.execute_compiled(self, *multiparams, **params)
-
- def scalar(self, *multiparams, **params):
- """Execute this compiled object and return the result's scalar value."""
-
- return self.execute(*multiparams, **params).scalar()
-
class ClauseElement(object):
"""Base class for elements of a programmatically constructed SQL
expression.
"""
+ __metaclass__ = _FigureVisitName
+
+ def _clone(self):
+ """create a shallow copy of this ClauseElement.
+
+ This method may be used by a generative API.
+ Its also used as part of the "deep" copy afforded
+ by a traversal that combines the _copy_internals()
+ method."""
+ c = self.__class__.__new__(self.__class__)
+ c.__dict__ = self.__dict__.copy()
+ return c
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
"""Return objects represented in this ``ClauseElement`` that
should be added to the ``FROM`` list of a query, when this
``ClauseElement`` is placed in the column clause of a
@@ -1115,7 +969,7 @@ class ClauseElement(object):
raise NotImplementedError(repr(self))
- def _hide_froms(self):
+ def _hide_froms(self, **modifiers):
"""Return a list of ``FROM`` clause elements which this
``ClauseElement`` replaces.
"""
@@ -1131,13 +985,14 @@ class ClauseElement(object):
return self is other
- def accept_visitor(self, visitor):
- """Accept a ``ClauseVisitor`` and call the appropriate
- ``visit_xxx`` method.
- """
-
- raise NotImplementedError(repr(self))
-
+ def _copy_internals(self):
+ """reassign internal elements to be clones of themselves.
+
+ called during a copy-and-traverse operation on newly
+ shallow-copied elements to create a deep copy."""
+
+ pass
+
def get_children(self, **kwargs):
"""return immediate child elements of this ``ClauseElement``.
@@ -1160,18 +1015,6 @@ class ClauseElement(object):
return False
- def copy_container(self):
- """Return a copy of this ``ClauseElement``, if this
- ``ClauseElement`` contains other ``ClauseElements``.
-
- If this ``ClauseElement`` is not a container, it should return
- self. This is used to create copies of expression trees that
- still reference the same *leaf nodes*. The new structure can
- then be restructured without affecting the original.
- """
-
- return self
-
def _find_engine(self):
"""Default strategy for locating an engine within the clause element.
@@ -1195,7 +1038,6 @@ class ClauseElement(object):
return None
bind = property(lambda s:s._find_engine(), doc="""Returns the Engine or Connection to which this ClauseElement is bound, or None if none found.""")
- engine = bind
def execute(self, *multiparams, **params):
"""Compile and execute this ``ClauseElement``."""
@@ -1213,7 +1055,7 @@ class ClauseElement(object):
return self.execute(*multiparams, **params).scalar()
- def compile(self, bind=None, engine=None, parameters=None, compiler=None, dialect=None):
+ def compile(self, bind=None, parameters=None, compiler=None, dialect=None):
"""Compile this SQL expression.
Uses the given ``Compiler``, or the given ``AbstractDialect``
@@ -1236,7 +1078,7 @@ class ClauseElement(object):
``SET`` and ``VALUES`` clause of those statements.
"""
- if (isinstance(parameters, list) or isinstance(parameters, tuple)):
+ if isinstance(parameters, (list, tuple)):
parameters = parameters[0]
if compiler is None:
@@ -1244,8 +1086,6 @@ class ClauseElement(object):
compiler = dialect.compiler(self, parameters)
elif bind is not None:
compiler = bind.compiler(self, parameters)
- elif engine is not None:
- compiler = engine.compiler(self, parameters)
elif self.bind is not None:
compiler = self.bind.compiler(self, parameters)
@@ -1268,49 +1108,257 @@ class ClauseElement(object):
return self._negate()
def _negate(self):
- return _UnaryExpression(self.self_group(against="NOT"), operator="NOT", negate=None)
+ if hasattr(self, 'negation_clause'):
+ return self.negation_clause
+ else:
+ return _UnaryExpression(self.self_group(against=operator.inv), operator=operator.inv, negate=None)
+
-class _CompareMixin(object):
- """Defines comparison operations for ``ClauseElement`` instances.
+class Operators(object):
+ def from_():
+ raise NotImplementedError()
+ from_ = staticmethod(from_)
- This is a mixin class that adds the capability to produce ``ClauseElement``
- instances based on regular Python operators.
- These operations are achieved using Python's operator overload methods
- (i.e. ``__eq__()``, ``__ne__()``, etc.
+ def as_():
+ raise NotImplementedError()
+ as_ = staticmethod(as_)
- Overridden operators include all comparison operators (i.e. '==', '!=', '<'),
- math operators ('+', '-', '*', etc), the '&' and '|' operators which evaluate
- to ``AND`` and ``OR`` respectively.
+ def exists():
+ raise NotImplementedError()
+ exists = staticmethod(exists)
- Other methods exist to create additional SQL clauses such as ``IN``, ``LIKE``,
- ``DISTINCT``, etc.
+ def is_():
+ raise NotImplementedError()
+ is_ = staticmethod(is_)
- """
+ def isnot():
+ raise NotImplementedError()
+ isnot = staticmethod(isnot)
+
+ def __and__(self, other):
+ return self.operate(operator.and_, other)
+
+ def __or__(self, other):
+ return self.operate(operator.or_, other)
+
+ def __invert__(self):
+ return self.operate(operator.inv)
+
+ def clause_element(self):
+ raise NotImplementedError()
+
+ def operate(self, op, *other, **kwargs):
+ raise NotImplementedError()
+
+ def reverse_operate(self, op, *other, **kwargs):
+ raise NotImplementedError()
+
+class ColumnOperators(Operators):
+ """defines comparison and math operations"""
+
+ def like_op(a, b):
+ return a.like(b)
+ like_op = staticmethod(like_op)
+
+ def notlike_op(a, b):
+ raise NotImplementedError()
+ notlike_op = staticmethod(notlike_op)
+ def ilike_op(a, b):
+ return a.ilike(b)
+ ilike_op = staticmethod(ilike_op)
+
+ def notilike_op(a, b):
+ raise NotImplementedError()
+ notilike_op = staticmethod(notilike_op)
+
+ def between_op(a, b):
+        return a.between(*b)
+ between_op = staticmethod(between_op)
+
+ def in_op(a, b):
+ return a.in_(*b)
+ in_op = staticmethod(in_op)
+
+ def notin_op(a, b):
+ raise NotImplementedError()
+ notin_op = staticmethod(notin_op)
+
+ def startswith_op(a, b):
+ return a.startswith(b)
+ startswith_op = staticmethod(startswith_op)
+
+ def endswith_op(a, b):
+ return a.endswith(b)
+ endswith_op = staticmethod(endswith_op)
+
+ def comma_op(a, b):
+ raise NotImplementedError()
+ comma_op = staticmethod(comma_op)
+
+ def concat_op(a, b):
+ return a.concat(b)
+ concat_op = staticmethod(concat_op)
+
def __lt__(self, other):
- return self._compare('<', other)
+ return self.operate(operator.lt, other)
def __le__(self, other):
- return self._compare('<=', other)
+ return self.operate(operator.le, other)
def __eq__(self, other):
- return self._compare('=', other)
+ return self.operate(operator.eq, other)
def __ne__(self, other):
- return self._compare('!=', other)
+ return self.operate(operator.ne, other)
def __gt__(self, other):
- return self._compare('>', other)
+ return self.operate(operator.gt, other)
def __ge__(self, other):
- return self._compare('>=', other)
+ return self.operate(operator.ge, other)
+ def concat(self, other):
+ return self.operate(ColumnOperators.concat_op, other)
+
def like(self, other):
- """produce a ``LIKE`` clause."""
- return self._compare('LIKE', other)
+ return self.operate(ColumnOperators.like_op, other)
+
+ def in_(self, *other):
+ return self.operate(ColumnOperators.in_op, other)
+
+ def startswith(self, other):
+ return self.operate(ColumnOperators.startswith_op, other)
+
+ def endswith(self, other):
+ return self.operate(ColumnOperators.endswith_op, other)
+
+ def __radd__(self, other):
+ return self.reverse_operate(operator.add, other)
+
+ def __rsub__(self, other):
+ return self.reverse_operate(operator.sub, other)
+
+ def __rmul__(self, other):
+ return self.reverse_operate(operator.mul, other)
+
+ def __rdiv__(self, other):
+ return self.reverse_operate(operator.div, other)
+
+ def between(self, cleft, cright):
+        return self.operate(ColumnOperators.between_op, (cleft, cright))
+
+ def __add__(self, other):
+ return self.operate(operator.add, other)
+
+ def __sub__(self, other):
+ return self.operate(operator.sub, other)
+
+ def __mul__(self, other):
+ return self.operate(operator.mul, other)
+
+ def __div__(self, other):
+ return self.operate(operator.div, other)
+
+ def __mod__(self, other):
+ return self.operate(operator.mod, other)
+
+ def __truediv__(self, other):
+ return self.operate(operator.truediv, other)
+
+# precedence ordering for common operators. if an operator is not present in this list,
+# it will be parenthesized when grouped against other operators
+_smallest = object()
+_largest = object()
+
+PRECEDENCE = {
+ Operators.from_:15,
+ operator.mul:7,
+ operator.div:7,
+ operator.mod:7,
+ operator.add:6,
+ operator.sub:6,
+ ColumnOperators.concat_op:6,
+ ColumnOperators.ilike_op:5,
+ ColumnOperators.notilike_op:5,
+ ColumnOperators.like_op:5,
+ ColumnOperators.notlike_op:5,
+ ColumnOperators.in_op:5,
+ ColumnOperators.notin_op:5,
+ Operators.is_:5,
+ Operators.isnot:5,
+ operator.eq:5,
+ operator.ne:5,
+ operator.gt:5,
+ operator.lt:5,
+ operator.ge:5,
+ operator.le:5,
+ ColumnOperators.between_op:5,
+ operator.inv:4,
+ operator.and_:3,
+ operator.or_:2,
+ ColumnOperators.comma_op:-1,
+ Operators.as_:-1,
+ Operators.exists:0,
+ _smallest: -1000,
+ _largest: 1000
+}
+
+class _CompareMixin(ColumnOperators):
+ """Defines comparison and math operations for ``ClauseElement`` instances."""
+
+ def __compare(self, op, obj, negate=None):
+ if obj is None or isinstance(obj, _Null):
+ if op == operator.eq:
+ return _BinaryExpression(self.expression_element(), null(), Operators.is_, negate=Operators.isnot)
+ elif op == operator.ne:
+ return _BinaryExpression(self.expression_element(), null(), Operators.isnot, negate=Operators.is_)
+ else:
+ raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL")
+ else:
+ obj = self._check_literal(obj)
+
+
+ return _BinaryExpression(self.expression_element(), obj, op, type_=sqltypes.Boolean, negate=negate)
+
+ def __operate(self, op, obj):
+ obj = self._check_literal(obj)
+
+ type_ = self._compare_type(obj)
+
+ # TODO: generalize operator overloading like this out into the types module
+ if op == operator.add and isinstance(type_, (sqltypes.Concatenable)):
+ op = ColumnOperators.concat_op
+
+ return _BinaryExpression(self.expression_element(), obj, op, type_=type_)
+
+ operators = {
+ operator.add : (__operate,),
+ operator.mul : (__operate,),
+ operator.sub : (__operate,),
+ operator.div : (__operate,),
+ operator.mod : (__operate,),
+ operator.truediv : (__operate,),
+ operator.lt : (__compare, operator.ge),
+ operator.le : (__compare, operator.gt),
+ operator.ne : (__compare, operator.eq),
+ operator.gt : (__compare, operator.le),
+ operator.ge : (__compare, operator.lt),
+ operator.eq : (__compare, operator.ne),
+ ColumnOperators.like_op : (__compare, ColumnOperators.notlike_op),
+ }
+
+ def operate(self, op, other):
+ o = _CompareMixin.operators[op]
+ return o[0](self, op, other, *o[1:])
+
+ def reverse_operate(self, op, other):
+ return self._bind_param(other).operate(op, self)
def in_(self, *other):
- """produce an ``IN`` clause."""
+ return self._in_impl(ColumnOperators.in_op, ColumnOperators.notin_op, *other)
+
+ def _in_impl(self, op, negate_op, *other):
if len(other) == 0:
return _Grouping(case([(self.__eq__(None), text('NULL'))], else_=text('0')).__eq__(text('1')))
elif len(other) == 1:
@@ -1318,8 +1366,8 @@ class _CompareMixin(object):
if _is_literal(o) or isinstance( o, _CompareMixin):
return self.__eq__( o) #single item -> ==
else:
- assert hasattr( o, '_selectable') #better check?
- return self._compare( 'IN', o, negate='NOT IN') #single selectable
+ assert isinstance(o, Selectable)
+ return self.__compare( op, o, negate=negate_op) #single selectable
args = []
for o in other:
@@ -1329,29 +1377,22 @@ class _CompareMixin(object):
else:
o = self._bind_param(o)
args.append(o)
- return self._compare( 'IN', ClauseList(*args).self_group(against='IN'), negate='NOT IN')
+ return self.__compare(op, ClauseList(*args).self_group(against=op), negate=negate_op)
def startswith(self, other):
"""produce the clause ``LIKE '<other>%'``"""
- perc = isinstance(other,(str,unicode)) and '%' or literal('%',type= sqltypes.String)
- return self._compare('LIKE', other + perc)
+
+ perc = isinstance(other,(str,unicode)) and '%' or literal('%',type_= sqltypes.String)
+ return self.__compare(ColumnOperators.like_op, other + perc)
def endswith(self, other):
"""produce the clause ``LIKE '%<other>'``"""
+
if isinstance(other,(str,unicode)): po = '%' + other
else:
- po = literal('%', type= sqltypes.String) + other
- po.type = sqltypes.to_instance( sqltypes.String) #force!
- return self._compare('LIKE', po)
-
- def __radd__(self, other):
- return self._bind_param(other)._operate('+', self)
- def __rsub__(self, other):
- return self._bind_param(other)._operate('-', self)
- def __rmul__(self, other):
- return self._bind_param(other)._operate('*', self)
- def __rdiv__(self, other):
- return self._bind_param(other)._operate('/', self)
+ po = literal('%', type_=sqltypes.String) + other
+ po.type = sqltypes.to_instance(sqltypes.String) #force!
+ return self.__compare(ColumnOperators.like_op, po)
def label(self, name):
"""produce a column label, i.e. ``<columnname> AS <name>``"""
@@ -1363,7 +1404,8 @@ class _CompareMixin(object):
def between(self, cleft, cright):
"""produce a BETWEEN clause, i.e. ``<column> BETWEEN <cleft> AND <cright>``"""
- return _BinaryExpression(self, ClauseList(self._check_literal(cleft), self._check_literal(cright), operator='AND', group=False), 'BETWEEN')
+
+ return _BinaryExpression(self, ClauseList(self._check_literal(cleft), self._check_literal(cright), operator=operator.and_, group=False), ColumnOperators.between_op)
def op(self, operator):
"""produce a generic operator function.
@@ -1382,59 +1424,25 @@ class _CompareMixin(object):
passed to the generated function.
"""
- return lambda other: self._operate(operator, other)
-
- # and here come the math operators:
-
- def __add__(self, other):
- return self._operate('+', other)
-
- def __sub__(self, other):
- return self._operate('-', other)
-
- def __mul__(self, other):
- return self._operate('*', other)
-
- def __div__(self, other):
- return self._operate('/', other)
-
- def __mod__(self, other):
- return self._operate('%', other)
-
- def __truediv__(self, other):
- return self._operate('/', other)
+ return lambda other: self.__operate(operator, other)
def _bind_param(self, obj):
- return _BindParamClause('literal', obj, shortname=None, type=self.type, unique=True)
+ return _BindParamClause('literal', obj, shortname=None, type_=self.type, unique=True)
def _check_literal(self, other):
- if _is_literal(other):
+ if isinstance(other, Operators):
+ return other.expression_element()
+ elif _is_literal(other):
return self._bind_param(other)
else:
return other
+
+ def clause_element(self):
+ """Allow ``_CompareMixins`` to return the underlying ``ClauseElement``, for non-``ClauseElement`` ``_CompareMixins``."""
+ return self
- def _compare(self, operator, obj, negate=None):
- if obj is None or isinstance(obj, _Null):
- if operator == '=':
- return _BinaryExpression(self._compare_self(), null(), 'IS', negate='IS NOT')
- elif operator == '!=':
- return _BinaryExpression(self._compare_self(), null(), 'IS NOT', negate='IS')
- else:
- raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL")
- else:
- obj = self._check_literal(obj)
-
- return _BinaryExpression(self._compare_self(), obj, operator, type=self._compare_type(obj), negate=negate)
-
- def _operate(self, operator, obj):
- if _is_literal(obj):
- obj = self._bind_param(obj)
- return _BinaryExpression(self._compare_self(), obj, operator, type=self._compare_type(obj))
-
- def _compare_self(self):
- """Allow ``ColumnImpl`` to return its ``Column`` object for
- usage in ``ClauseElements``, all others to just return self.
- """
+ def expression_element(self):
+ """Allow ``_CompareMixins`` to return the appropriate object to be used in expressions."""
return self
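
The ``operators`` table above is the whole dispatch story: math operators route to ``__operate`` (with ``+`` promoted to concatenation on ``Concatenable`` types), comparisons route to ``__compare`` along with their stored negations, and reflected operations bind the left-hand literal first. Assuming a ``users`` table, approximately:

    users.c.id + 5        # __operate: users.id + :users_id, roughly
    users.c.name + 'ed'   # String is Concatenable, so + becomes concat (|| on most dialects)
    5 + users.c.id        # __radd__ -> reverse_operate: the literal is bound first
    ~(users.c.id == 7)    # negation uses the stored counterpart, roughly users.id != :users_id
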
@@ -1460,23 +1468,10 @@ class Selectable(ClauseElement):
columns = util.NotImplProperty("""a [sqlalchemy.sql#ColumnCollection] containing ``ColumnElement`` instances.""")
- def _selectable(self):
- return self
-
- def accept_visitor(self, visitor):
- raise NotImplementedError(repr(self))
-
def select(self, whereclauses = None, **params):
return select([self], whereclauses, **params)
- def _group_parenthesized(self):
- """Indicate if this ``Selectable`` requires parenthesis when
- grouped into a compound statement.
- """
- return True
-
-
class ColumnElement(Selectable, _CompareMixin):
"""Represent an element that is useable within the
"column clause" portion of a ``SELECT`` statement.
@@ -1616,8 +1611,10 @@ class ColumnCollection(util.OrderedProperties):
l.append(c==local)
return and_(*l)
- def __contains__(self, col):
- return self.contains_column(col)
+ def __contains__(self, other):
+ if not isinstance(other, basestring):
+ raise exceptions.ArgumentError("__contains__ requires a string argument")
+ return self.has_key(other)
def contains_column(self, col):
# have to use a Set here, because it will compare the identity
@@ -1649,19 +1646,18 @@ class FromClause(Selectable):
clause of a ``SELECT`` statement.
"""
+ __visit_name__ = 'fromclause'
+
def __init__(self, name=None):
self.name = name
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
# this could also be [self], at the moment it doesnt matter to the Select object
return []
def default_order_by(self):
return [self.oid_column]
- def accept_visitor(self, visitor):
- visitor.visit_fromclause(self)
-
def count(self, whereclause=None, **params):
if len(self.primary_key):
col = list(self.primary_key)[0]
@@ -1703,6 +1699,13 @@ class FromClause(Selectable):
FindCols().traverse(self)
return ret
+ def is_derived_from(self, fromclause):
+ """return True if this FromClause is 'derived' from the given FromClause.
+
+        For example, an Alias of a Table is derived from that Table."""
+
+ return False
+
def corresponding_column(self, column, raiseerr=True, keys_ok=False, require_embedded=False):
"""Given a ``ColumnElement``, return the exported
``ColumnElement`` object from this ``Selectable`` which
@@ -1730,9 +1733,10 @@ class FromClause(Selectable):
it merely shares a common anscestor with one of
the exported columns of this ``FromClause``.
"""
- if column in self.c:
+
+ if self.c.contains_column(column):
return column
-
+
if require_embedded and column not in util.Set(self._get_all_embedded_columns()):
if not raiseerr:
return None
@@ -1761,6 +1765,15 @@ class FromClause(Selectable):
self._export_columns()
return getattr(self, name)
+ def _clone_from_clause(self):
+ # delete all the "generated" collections of columns for a newly cloned FromClause,
+ # so that they will be re-derived from the item.
+ # this is because FromClause subclasses, when cloned, need to reestablish new "proxied"
+ # columns that are linked to the new item
+        for attr in ('_columns', '_primary_key', '_foreign_keys', '_orig_cols', '_oid_column'):
+ if hasattr(self, attr):
+ delattr(self, attr)
+
columns = property(lambda s:s._get_exported_attribute('_columns'))
c = property(lambda s:s._get_exported_attribute('_columns'))
primary_key = property(lambda s:s._get_exported_attribute('_primary_key'))
@@ -1791,8 +1804,9 @@ class FromClause(Selectable):
self._primary_key = ColumnSet()
self._foreign_keys = util.Set()
self._orig_cols = {}
+
if columns is None:
- columns = self._adjusted_exportable_columns()
+ columns = self._flatten_exportable_columns()
for co in columns:
cp = self._proxy_column(co)
for ci in cp.orig_set:
@@ -1806,13 +1820,14 @@ class FromClause(Selectable):
for ci in self.oid_column.orig_set:
self._orig_cols[ci] = self.oid_column
- def _adjusted_exportable_columns(self):
+ def _flatten_exportable_columns(self):
"""return the list of ColumnElements represented within this FromClause's _exportable_columns"""
export = self._exportable_columns()
for column in export:
- try:
- s = column._selectable()
- except AttributeError:
+ # TODO: is this conditional needed ?
+ if isinstance(column, Selectable):
+ s = column
+ else:
continue
for co in s.columns:
yield co
@@ -1829,7 +1844,9 @@ class _BindParamClause(ClauseElement, _CompareMixin):
Public constructor is the ``bindparam()`` function.
"""
- def __init__(self, key, value, shortname=None, type=None, unique=False):
+ __visit_name__ = 'bindparam'
+
+ def __init__(self, key, value, shortname=None, type_=None, unique=False, isoutparam=False):
"""Construct a _BindParamClause.
key
@@ -1852,7 +1869,7 @@ class _BindParamClause(ClauseElement, _CompareMixin):
execution may match either the key or the shortname of the
corresponding ``_BindParamClause`` objects.
- type
+ type\_
A ``TypeEngine`` object that will be used to pre-process the
value corresponding to this ``_BindParamClause`` at
execution time.
@@ -1862,23 +1879,34 @@ class _BindParamClause(ClauseElement, _CompareMixin):
modified if another ``_BindParamClause`` of the same
name already has been located within the containing
``ClauseElement``.
+
+ isoutparam
+ if True, the parameter should be treated like a stored procedure "OUT"
+ parameter.
"""
- self.key = key
+ self.key = key or "{ANON %d param}" % id(self)
self.value = value
self.shortname = shortname or key
self.unique = unique
- self.type = sqltypes.to_instance(type)
-
- def accept_visitor(self, visitor):
- visitor.visit_bindparam(self)
-
- def _get_from_objects(self):
+ self.isoutparam = isoutparam
+ type_ = sqltypes.to_instance(type_)
+ if isinstance(type_, sqltypes.NullType) and type(value) in _BindParamClause.type_map:
+ self.type = sqltypes.to_instance(_BindParamClause.type_map[type(value)])
+ else:
+ self.type = type_
+
+ # TODO: move to types module, obviously
+ type_map = {
+ str : sqltypes.String,
+ unicode : sqltypes.Unicode,
+ int : sqltypes.Integer,
+ float : sqltypes.Numeric
+ }
+
+ def _get_from_objects(self, **modifiers):
return []
- def copy_container(self):
- return _BindParamClause(self.key, self.value, self.shortname, self.type, unique=self.unique)
-
def typeprocess(self, value, dialect):
return self.type.dialect_impl(dialect).convert_bind_param(value, dialect)
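
The interim ``type_map`` gives untyped bind parameters a type guessed from the Python value, as the TODO notes, pending a move to the types module. In effect:

    from sqlalchemy import bindparam, String

    bindparam('x', 5).type                  # Integer(), inferred from the int value
    bindparam('x', u'ed').type              # Unicode()
    bindparam('x', 2.5).type                # Numeric()
    bindparam('x', 5, type_=String).type    # an explicit type wins over inference
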
@@ -1893,7 +1921,7 @@ class _BindParamClause(ClauseElement, _CompareMixin):
return isinstance(other, _BindParamClause) and other.type.__class__ == self.type.__class__
def __repr__(self):
- return "_BindParamClause(%s, %s, type=%s)" % (repr(self.key), repr(self.value), repr(self.type))
+ return "_BindParamClause(%s, %s, type_=%s)" % (repr(self.key), repr(self.value), repr(self.type))
class _TypeClause(ClauseElement):
"""Handle a type keyword in a SQL statement.
@@ -1901,13 +1929,12 @@ class _TypeClause(ClauseElement):
Used by the ``Case`` statement.
"""
+ __visit_name__ = 'typeclause'
+
def __init__(self, type):
self.type = type
- def accept_visitor(self, visitor):
- visitor.visit_typeclause(self)
-
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
return []
class _TextClause(ClauseElement):
@@ -1916,8 +1943,10 @@ class _TextClause(ClauseElement):
Public constructor is the ``text()`` function.
"""
- def __init__(self, text = "", bind=None, engine=None, bindparams=None, typemap=None):
- self._bind = bind or engine
+ __visit_name__ = 'textclause'
+
+ def __init__(self, text = "", bind=None, bindparams=None, typemap=None):
+ self._bind = bind
self.bindparams = {}
self.typemap = typemap
if typemap is not None:
@@ -1930,7 +1959,7 @@ class _TextClause(ClauseElement):
# scan the string and search for bind parameter names, add them
# to the list of bindparams
- self.text = re.compile(r'(?<!:):([\w_]+)', re.S).sub(repl, text)
+ self.text = BIND_PARAMS.sub(repl, text)
if bindparams is not None:
for b in bindparams:
self.bindparams[b.key] = b
@@ -1944,13 +1973,13 @@ class _TextClause(ClauseElement):
columns = property(lambda s:[])
+ def _copy_internals(self):
+ # bindparams is a dict keyed by parameter name; clone the values, not the keys
+ self.bindparams = dict((b.key, b._clone()) for b in self.bindparams.values())
+
def get_children(self, **kwargs):
return self.bindparams.values()
- def accept_visitor(self, visitor):
- visitor.visit_textclause(self)
-
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
return []
def supports_execution(self):
@@ -1965,10 +1994,7 @@ class _Null(ColumnElement):
def __init__(self):
self.type = sqltypes.NULLTYPE
- def accept_visitor(self, visitor):
- visitor.visit_null(self)
-
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
return []
class ClauseList(ClauseElement):
@@ -1976,14 +2002,16 @@ class ClauseList(ClauseElement):
By default, is comma-separated, such as a column listing.
"""
-
+ __visit_name__ = 'clauselist'
+
def __init__(self, *clauses, **kwargs):
self.clauses = []
- self.operator = kwargs.pop('operator', ',')
+ self.operator = kwargs.pop('operator', ColumnOperators.comma_op)
self.group = kwargs.pop('group', True)
self.group_contents = kwargs.pop('group_contents', True)
for c in clauses:
- if c is None: continue
+ if c is None:
+ continue
self.append(c)
def __iter__(self):
@@ -1991,32 +2019,28 @@ class ClauseList(ClauseElement):
def __len__(self):
return len(self.clauses)
- def copy_container(self):
- clauses = [clause.copy_container() for clause in self.clauses]
- return ClauseList(operator=self.operator, *clauses)
-
def append(self, clause):
# TODO: not sure if I like the 'group_contents' flag.  need to define the difference between
# a ClauseList of ClauseLists, and a "flattened" ClauseList of ClauseLists.  flatten() method?
if self.group_contents:
- self.clauses.append(_literals_as_text(clause).self_group(against=self.operator))
+ self.clauses.append(_literal_as_text(clause).self_group(against=self.operator))
else:
- self.clauses.append(_literals_as_text(clause))
+ self.clauses.append(_literal_as_text(clause))
+
+ def _copy_internals(self):
+ self.clauses = [clause._clone() for clause in self.clauses]
def get_children(self, **kwargs):
return self.clauses
- def accept_visitor(self, visitor):
- visitor.visit_clauselist(self)
-
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
f = []
for c in self.clauses:
- f += c._get_from_objects()
+ f += c._get_from_objects(**modifiers)
return f
def self_group(self, against=None):
- if self.group and self.operator != against and PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest']):
+ if self.group and self.operator != against and PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest]):
return _Grouping(self)
else:
return self
@@ -2043,40 +2067,45 @@ class _CalculatedClause(ColumnElement):
Extends ``ColumnElement`` to provide column-level comparison
operators.
"""
-
+ __visit_name__ = 'calculatedclause'
+
def __init__(self, name, *clauses, **kwargs):
self.name = name
- self.type = sqltypes.to_instance(kwargs.get('type', None))
- self._bind = kwargs.get('bind', kwargs.get('engine', None))
+ self.type = sqltypes.to_instance(kwargs.get('type_', None))
+ self._bind = kwargs.get('bind', None)
self.group = kwargs.pop('group', True)
- self.clauses = ClauseList(operator=kwargs.get('operator', None), group_contents=kwargs.get('group_contents', True), *clauses)
+ clauses = ClauseList(operator=kwargs.get('operator', None), group_contents=kwargs.get('group_contents', True), *clauses)
if self.group:
- self.clause_expr = self.clauses.self_group()
+ self.clause_expr = clauses.self_group()
else:
- self.clause_expr = self.clauses
+ self.clause_expr = clauses
key = property(lambda self:self.name or "_calc_")
- def copy_container(self):
- clauses = [clause.copy_container() for clause in self.clauses]
- return _CalculatedClause(type=self.type, bind=self._bind, *clauses)
-
+ def _copy_internals(self):
+ self.clause_expr = self.clause_expr._clone()
+
+ def clauses(self):
+ if isinstance(self.clause_expr, _Grouping):
+ return self.clause_expr.elem
+ else:
+ return self.clause_expr
+ clauses = property(clauses)
+
def get_children(self, **kwargs):
return self.clause_expr,
- def accept_visitor(self, visitor):
- visitor.visit_calculatedclause(self)
- def _get_from_objects(self):
- return self.clauses._get_from_objects()
+ def _get_from_objects(self, **modifiers):
+ return self.clauses._get_from_objects(**modifiers)
def _bind_param(self, obj):
- return _BindParamClause(self.name, obj, type=self.type, unique=True)
+ return _BindParamClause(self.name, obj, type_=self.type, unique=True)
def select(self):
return select([self])
def scalar(self):
- return select([self]).scalar()
+ return select([self]).execute().scalar()
def execute(self):
return select([self]).execute()
@@ -2092,28 +2121,26 @@ class _Function(_CalculatedClause, FromClause):
"""
def __init__(self, name, *clauses, **kwargs):
- self.type = sqltypes.to_instance(kwargs.get('type', None))
self.packagenames = kwargs.get('packagenames', None) or []
- kwargs['operator'] = ','
- self._engine = kwargs.get('engine', None)
+ kwargs['operator'] = ColumnOperators.comma_op
_CalculatedClause.__init__(self, name, **kwargs)
for c in clauses:
self.append(c)
key = property(lambda self:self.name)
+ def _copy_internals(self):
+ _CalculatedClause._copy_internals(self)
+ self._clone_from_clause()
- def append(self, clause):
- self.clauses.append(_literals_as_binds(clause, self.name))
-
- def copy_container(self):
- clauses = [clause.copy_container() for clause in self.clauses]
- return _Function(self.name, type=self.type, packagenames=self.packagenames, bind=self._bind, *clauses)
+ def get_children(self, **kwargs):
+ return _CalculatedClause.get_children(self, **kwargs)
- def accept_visitor(self, visitor):
- visitor.visit_function(self)
+ def append(self, clause):
+ self.clauses.append(_literal_as_binds(clause, self.name))
class _Cast(ColumnElement):
+
def __init__(self, clause, totype, **kwargs):
if not hasattr(clause, 'label'):
clause = literal(clause)
@@ -2122,17 +2149,19 @@ class _Cast(ColumnElement):
self.typeclause = _TypeClause(self.type)
self._distance = 0
+ def _copy_internals(self):
+ self.clause = self.clause._clone()
+ self.typeclause = self.typeclause._clone()
+
def get_children(self, **kwargs):
return self.clause, self.typeclause
- def accept_visitor(self, visitor):
- visitor.visit_cast(self)
- def _get_from_objects(self):
- return self.clause._get_from_objects()
+ def _get_from_objects(self, **modifiers):
+ return self.clause._get_from_objects(**modifiers)
def _make_proxy(self, selectable, name=None):
if name is not None:
- co = _ColumnClause(name, selectable, type=self.type)
+ co = _ColumnClause(name, selectable, type_=self.type)
co._distance = self._distance + 1
co.orig_set = self.orig_set
selectable.columns[name]= co
@@ -2142,26 +2171,23 @@ class _Cast(ColumnElement):
class _UnaryExpression(ColumnElement):
- def __init__(self, element, operator=None, modifier=None, type=None, negate=None):
+ def __init__(self, element, operator=None, modifier=None, type_=None, negate=None):
self.operator = operator
self.modifier = modifier
- self.element = _literals_as_text(element).self_group(against=self.operator or self.modifier)
- self.type = sqltypes.to_instance(type)
+ self.element = _literal_as_text(element).self_group(against=self.operator or self.modifier)
+ self.type = sqltypes.to_instance(type_)
self.negate = negate
- def copy_container(self):
- return self.__class__(self.element.copy_container(), operator=self.operator, modifier=self.modifier, type=self.type, negate=self.negate)
+ def _get_from_objects(self, **modifiers):
+ return self.element._get_from_objects(**modifiers)
- def _get_from_objects(self):
- return self.element._get_from_objects()
+ def _copy_internals(self):
+ self.element = self.element._clone()
def get_children(self, **kwargs):
return self.element,
- def accept_visitor(self, visitor):
- visitor.visit_unary(self)
-
def compare(self, other):
"""Compare this ``_UnaryExpression`` against the given ``ClauseElement``."""
@@ -2170,14 +2196,15 @@ class _UnaryExpression(ColumnElement):
self.modifier == other.modifier and
self.element.compare(other.element)
)
+
def _negate(self):
if self.negate is not None:
- return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type=self.type)
+ return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type_=self.type)
else:
return super(_UnaryExpression, self)._negate()
def self_group(self, against):
- if self.operator and PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest']):
+ if self.operator and PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest]):
return _Grouping(self)
else:
return self
@@ -2186,25 +2213,23 @@ class _UnaryExpression(ColumnElement):
class _BinaryExpression(ColumnElement):
"""Represent an expression that is ``LEFT <operator> RIGHT``."""
- def __init__(self, left, right, operator, type=None, negate=None):
- self.left = _literals_as_text(left).self_group(against=operator)
- self.right = _literals_as_text(right).self_group(against=operator)
+ def __init__(self, left, right, operator, type_=None, negate=None):
+ self.left = _literal_as_text(left).self_group(against=operator)
+ self.right = _literal_as_text(right).self_group(against=operator)
self.operator = operator
- self.type = sqltypes.to_instance(type)
+ self.type = sqltypes.to_instance(type_)
self.negate = negate
- def copy_container(self):
- return self.__class__(self.left.copy_container(), self.right.copy_container(), self.operator)
+ def _get_from_objects(self, **modifiers):
+ return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
- def _get_from_objects(self):
- return self.left._get_from_objects() + self.right._get_from_objects()
+ def _copy_internals(self):
+ self.left = self.left._clone()
+ self.right = self.right._clone()
def get_children(self, **kwargs):
return self.left, self.right
- def accept_visitor(self, visitor):
- visitor.visit_binary(self)
-
def compare(self, other):
"""Compare this ``_BinaryExpression`` against the given ``_BinaryExpression``."""
@@ -2213,7 +2238,7 @@ class _BinaryExpression(ColumnElement):
(
self.left.compare(other.left) and self.right.compare(other.right)
or (
- self.operator in ['=', '!=', '+', '*'] and
+ self.operator in [operator.eq, operator.ne, operator.add, operator.mul] and
self.left.compare(other.right) and self.right.compare(other.left)
)
)
@@ -2221,25 +2246,27 @@ class _BinaryExpression(ColumnElement):
def self_group(self, against=None):
# use small/large defaults for comparison so that unknown operators are always parenthesized
- if self.operator != against and (PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest'])):
+ if self.operator != against and (PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest])):
return _Grouping(self)
else:
return self
def _negate(self):
if self.negate is not None:
- return _BinaryExpression(self.left, self.right, self.negate, negate=self.operator, type=self.type)
+ return _BinaryExpression(self.left, self.right, self.negate, negate=self.operator, type_=self.type)
else:
return super(_BinaryExpression, self)._negate()
class _Exists(_UnaryExpression):
+ __visit_name__ = _UnaryExpression.__visit_name__
+
def __init__(self, *args, **kwargs):
kwargs['correlate'] = True
s = select(*args, **kwargs).self_group()
- _UnaryExpression.__init__(self, s, operator="EXISTS")
+ _UnaryExpression.__init__(self, s, operator=Operators.exists)
- def _hide_froms(self):
- return self._get_from_objects()
+ def _hide_froms(self, **modifiers):
+ return self._get_from_objects(**modifiers)
class Join(FromClause):
"""represent a ``JOIN`` construct between two ``FromClause``
@@ -2251,8 +2278,8 @@ class Join(FromClause):
"""
def __init__(self, left, right, onclause=None, isouter = False):
- self.left = left._selectable()
- self.right = right._selectable()
+ self.left = _selectable(left)
+ self.right = _selectable(right).self_group()
if onclause is None:
self.onclause = self._match_primaries(self.left, self.right)
else:
@@ -2265,8 +2292,8 @@ class Join(FromClause):
encodedname = property(lambda s: s.name.encode('ascii', 'backslashreplace'))
def _init_primary_key(self):
- pkcol = util.Set([c for c in self._adjusted_exportable_columns() if c.primary_key])
-
+ pkcol = util.Set([c for c in self._flatten_exportable_columns() if c.primary_key])
+
equivs = {}
def add_equiv(a, b):
for x, y in ((a, b), (b, a)):
@@ -2277,7 +2304,7 @@ class Join(FromClause):
class BinaryVisitor(ClauseVisitor):
def visit_binary(self, binary):
- if binary.operator == '=':
+ if binary.operator == operator.eq:
add_equiv(binary.left, binary.right)
BinaryVisitor().traverse(self.onclause)
@@ -2294,9 +2321,12 @@ class Join(FromClause):
omit.add(p)
p = c
- self.__primary_key = ColumnSet([c for c in self._adjusted_exportable_columns() if c.primary_key and c not in omit])
+ self.__primary_key = ColumnSet([c for c in self._flatten_exportable_columns() if c.primary_key and c not in omit])
primary_key = property(lambda s:s.__primary_key)
+
+ def self_group(self, against=None):
+ return _Grouping(self)
def _locate_oid_column(self):
return self.left.oid_column
@@ -2310,6 +2340,17 @@ class Join(FromClause):
self._foreign_keys.add(f)
return column
+ def _copy_internals(self):
+ self._clone_from_clause()
+ self.left = self.left._clone()
+ self.right = self.right._clone()
+ self.onclause = self.onclause._clone()
+ self.__folded_equivalents = None
+ self._init_primary_key()
+
+ def get_children(self, **kwargs):
+ return self.left, self.right, self.onclause
+
def _match_primaries(self, primary, secondary):
crit = []
constraints = util.Set()
@@ -2338,9 +2379,6 @@ class Join(FromClause):
else:
return and_(*crit)
- def _group_parenthesized(self):
- return True
-
def _get_folded_equivalents(self, equivs=None):
if self.__folded_equivalents is not None:
return self.__folded_equivalents
@@ -2348,7 +2386,7 @@ class Join(FromClause):
equivs = util.Set()
class LocateEquivs(NoColumnVisitor):
def visit_binary(self, binary):
- if binary.operator == '=' and binary.left.name == binary.right.name:
+ if binary.operator == operator.eq and binary.left.name == binary.right.name:
equivs.add(binary.right)
equivs.add(binary.left)
LocateEquivs().traverse(self.onclause)
@@ -2401,13 +2439,7 @@ class Join(FromClause):
return select(collist, whereclause, from_obj=[self], **kwargs)
- def get_children(self, **kwargs):
- return self.left, self.right, self.onclause
-
- def accept_visitor(self, visitor):
- visitor.visit_join(self)
-
- engine = property(lambda s:s.left.engine or s.right.engine)
+ bind = property(lambda s:s.left.bind or s.right.bind)
def alias(self, name=None):
"""Create a ``Select`` out of this ``Join`` clause and return an ``Alias`` of it.
@@ -2417,11 +2449,11 @@ class Join(FromClause):
return self.select(use_labels=True, correlate=False).alias(name)
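For illustration (hypothetical users/addresses tables related by a foreign key): alias() wraps the join in a labeled SELECT rather than aliasing the JOIN construct itself.

    j = users.join(addresses)          # ON clause inferred via _match_primaries()
    a = j.alias('user_addresses')      # an Alias of: SELECT ... FROM users JOIN addresses
    stmt = select([a])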
- def _hide_froms(self):
- return self.left._get_from_objects() + self.right._get_from_objects()
+ def _hide_froms(self, **modifiers):
+ return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
- def _get_from_objects(self):
- return [self] + self.onclause._get_from_objects() + self.left._get_from_objects() + self.right._get_from_objects()
+ def _get_from_objects(self, **modifiers):
+ return [self] + self.onclause._get_from_objects(**modifiers) + self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
class Alias(FromClause):
"""represent an alias, as typically applied to any
@@ -2443,15 +2475,22 @@ class Alias(FromClause):
if alias is None:
if self.original.named_with_column():
alias = getattr(self.original, 'name', None)
- if alias is None:
- alias = 'anon'
- elif len(alias) > 15:
- alias = alias[0:15]
- alias = alias + "_" + hex(random.randint(0, 65535))[2:]
+ alias = '{ANON %d %s}' % (id(self), alias or 'anon')
self.name = alias
self.encodedname = alias.encode('ascii', 'backslashreplace')
self.case_sensitive = getattr(baseselectable, "case_sensitive", True)
+ def is_derived_from(self, fromclause):
+ x = self.selectable
+ while True:
+ if x is fromclause:
+ return True
+ if isinstance(x, Alias):
+ x = x.selectable
+ else:
+ break
+ return False
+
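A quick sketch of the new is_derived_from() check, assuming hypothetical users/addresses tables; nested aliases unwrap back to their base selectable:

    a1 = users.alias()
    a2 = a1.alias()
    assert a2.is_derived_from(users)         # unwraps a2 -> a1 -> users
    assert not a2.is_derived_from(addresses) # not in the alias chain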
def supports_execution(self):
return self.original.supports_execution()
@@ -2468,46 +2507,57 @@ class Alias(FromClause):
#return self.selectable._exportable_columns()
return self.selectable.columns
+ def _copy_internals(self):
+ self._clone_from_clause()
+ self.selectable = self.selectable._clone()
+ baseselectable = self.selectable
+ while isinstance(baseselectable, Alias):
+ baseselectable = baseselectable.selectable
+ self.original = baseselectable
+
def get_children(self, **kwargs):
for c in self.c:
yield c
yield self.selectable
- def accept_visitor(self, visitor):
- visitor.visit_alias(self)
-
def _get_from_objects(self):
return [self]
- def _group_parenthesized(self):
- return False
-
bind = property(lambda s: s.selectable.bind)
- engine = bind
-class _Grouping(ColumnElement):
+class _ColumnElementAdapter(ColumnElement):
+ """adapts a ClauseElement which may or may not be a
+ ColumnElement subclass itself into an object which
+ acts like a ColumnElement.
+ """
+
def __init__(self, elem):
self.elem = elem
self.type = getattr(elem, 'type', None)
-
+ self.orig_set = getattr(elem, 'orig_set', util.Set())
+
key = property(lambda s: s.elem.key)
_label = property(lambda s: s.elem._label)
- orig_set = property(lambda s:s.elem.orig_set)
-
- def copy_container(self):
- return _Grouping(self.elem.copy_container())
-
- def accept_visitor(self, visitor):
- visitor.visit_grouping(self)
+ columns = c = property(lambda s:s.elem.columns)
+
+ def _copy_internals(self):
+ self.elem = self.elem._clone()
+
def get_children(self, **kwargs):
return self.elem,
- def _hide_froms(self):
- return self.elem._hide_froms()
- def _get_from_objects(self):
- return self.elem._get_from_objects()
+
+ def _hide_froms(self, **modifiers):
+ return self.elem._hide_froms(**modifiers)
+
+ def _get_from_objects(self, **modifiers):
+ return self.elem._get_from_objects(**modifiers)
+
def __getattr__(self, attr):
return getattr(self.elem, attr)
-
+
+class _Grouping(_ColumnElementAdapter):
+ pass
+
class _Label(ColumnElement):
"""represent a label, as typically applied to any column-level element
using the ``AS`` sql keyword.
@@ -2518,32 +2568,33 @@ class _Label(ColumnElement):
"""
- def __init__(self, name, obj, type=None):
- self.name = name
+ def __init__(self, name, obj, type_=None):
while isinstance(obj, _Label):
obj = obj.obj
- self.obj = obj.self_group(against='AS')
+ self.name = name or "{ANON %d %s}" % (id(self), getattr(obj, 'name', 'anon'))
+
+ self.obj = obj.self_group(against=Operators.as_)
self.case_sensitive = getattr(obj, "case_sensitive", True)
- self.type = sqltypes.to_instance(type or getattr(obj, 'type', None))
+ self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None))
key = property(lambda s: s.name)
_label = property(lambda s: s.name)
orig_set = property(lambda s:s.obj.orig_set)
- def _compare_self(self):
+ def expression_element(self):
return self.obj
-
+
+ def _copy_internals(self):
+ self.obj = self.obj._clone()
+
def get_children(self, **kwargs):
return self.obj,
- def accept_visitor(self, visitor):
- visitor.visit_label(self)
+ def _get_from_objects(self, **modifiers):
+ return self.obj._get_from_objects(**modifiers)
- def _get_from_objects(self):
- return self.obj._get_from_objects()
-
- def _hide_froms(self):
- return self.obj._hide_froms()
+ def _hide_froms(self, **modifiers):
+ return self.obj._hide_froms(**modifiers)
def _make_proxy(self, selectable, name = None):
if isinstance(self.obj, Selectable):
@@ -2551,8 +2602,6 @@ class _Label(ColumnElement):
else:
return column(self.name)._make_proxy(selectable=selectable)
-legal_characters = util.Set(string.ascii_letters + string.digits + '_')
-
class _ColumnClause(ColumnElement):
"""Represents a generic column expression from any textual string.
This includes columns associated with tables, aliases and select
@@ -2584,17 +2633,21 @@ class _ColumnClause(ColumnElement):
"""
- def __init__(self, text, selectable=None, type=None, _is_oid=False, case_sensitive=True, is_literal=False):
+ def __init__(self, text, selectable=None, type_=None, _is_oid=False, case_sensitive=True, is_literal=False):
self.key = self.name = text
self.encodedname = isinstance(self.name, unicode) and self.name.encode('ascii', 'backslashreplace') or self.name
self.table = selectable
- self.type = sqltypes.to_instance(type)
+ self.type = sqltypes.to_instance(type_)
self._is_oid = _is_oid
self._distance = 0
self.__label = None
self.case_sensitive = case_sensitive
self.is_literal = is_literal
-
+
+ def _clone(self):
+ # ColumnClause is immutable
+ return self
+
def _get_label(self):
"""Generate a 'label' for this column.
@@ -2617,7 +2670,6 @@ class _ColumnClause(ColumnElement):
counter += 1
else:
self.__label = self.name
- self.__label = "".join([x for x in self.__label if x in legal_characters])
return self.__label
is_labeled = property(lambda self:self.name != list(self.orig_set)[0].name)
@@ -2632,23 +2684,20 @@ class _ColumnClause(ColumnElement):
else:
return super(_ColumnClause, self).label(name)
- def accept_visitor(self, visitor):
- visitor.visit_column(self)
-
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
if self.table is not None:
return [self.table]
else:
return []
def _bind_param(self, obj):
- return _BindParamClause(self._label, obj, shortname = self.name, type=self.type, unique=True)
+ return _BindParamClause(self._label, obj, shortname=self.name, type_=self.type, unique=True)
def _make_proxy(self, selectable, name = None):
# propigate the "is_literal" flag only if we are keeping our name,
# otherwise its considered to be a label
is_literal = self.is_literal and (name is None or name == self.name)
- c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type=self.type, is_literal=is_literal)
+ c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type_=self.type, is_literal=is_literal)
c.orig_set = self.orig_set
c._distance = self._distance + 1
if not self._is_oid:
@@ -2658,9 +2707,6 @@ class _ColumnClause(ColumnElement):
def _compare_type(self, obj):
return self.type
- def _group_parenthesized(self):
- return False
-
class TableClause(FromClause):
"""represents a "table" construct.
@@ -2677,6 +2723,10 @@ class TableClause(FromClause):
self._oid_column = _ColumnClause('oid', self, _is_oid=True)
self._export_columns(columns)
+ def _clone(self):
+ # TableClause is immutable
+ return self
+
def named_with_column(self):
return True
@@ -2709,15 +2759,9 @@ class TableClause(FromClause):
else:
return []
- def accept_visitor(self, visitor):
- visitor.visit_table(self)
-
def _exportable_columns(self):
raise NotImplementedError()
- def _group_parenthesized(self):
- return False
-
def count(self, whereclause=None, **params):
if len(self.primary_key):
col = list(self.primary_key)[0]
@@ -2746,68 +2790,120 @@ class TableClause(FromClause):
def delete(self, whereclause = None):
return delete(self, whereclause)
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
return [self]
+
class _SelectBaseMixin(object):
"""Base class for ``Select`` and ``CompoundSelects``."""
+ def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, bind=None):
+ self.use_labels = use_labels
+ self.for_update = for_update
+ self._limit = limit
+ self._offset = offset
+ self._bind = bind
+
+ self.append_order_by(*util.to_list(order_by, []))
+ self.append_group_by(*util.to_list(group_by, []))
+
+ def as_scalar(self):
+ return _ScalarSelect(self)
+
+ def label(self, name):
+ return self.as_scalar().label(name)
+
def supports_execution(self):
return True
+ def _generate(self):
+ s = self._clone()
+ s._clone_from_clause()
+ return s
+
+ def limit(self, limit):
+ s = self._generate()
+ s._limit = limit
+ return s
+
+ def offset(self, offset):
+ s = self._generate()
+ s._offset = offset
+ return s
+
def order_by(self, *clauses):
- if len(clauses) == 1 and clauses[0] is None:
- self.order_by_clause = ClauseList()
- elif getattr(self, 'order_by_clause', None):
- self.order_by_clause = ClauseList(*(list(self.order_by_clause.clauses) + list(clauses)))
- else:
- self.order_by_clause = ClauseList(*clauses)
+ s = self._generate()
+ s.append_order_by(*clauses)
+ return s
def group_by(self, *clauses):
- if len(clauses) == 1 and clauses[0] is None:
- self.group_by_clause = ClauseList()
- elif getattr(self, 'group_by_clause', None):
- self.group_by_clause = ClauseList(*(list(clauses)+list(self.group_by_clause.clauses)))
+ s = self._generate()
+ s.append_group_by(*clauses)
+ return s
+
+ def append_order_by(self, *clauses):
+ if len(clauses) == 1 and clauses[0] is None:
+ self._order_by_clause = ClauseList()
else:
- self.group_by_clause = ClauseList(*clauses)
+ if getattr(self, '_order_by_clause', None):
+ clauses = list(self._order_by_clause) + list(clauses)
+ self._order_by_clause = ClauseList(*clauses)
+ def append_group_by(self, *clauses):
+ if len(clauses) == 1 and clauses[0] is None:
+ self._group_by_clause = ClauseList()
+ else:
+ if getattr(self, '_group_by_clause', None):
+ clauses = list(self._group_by_clause) + list(clauses)
+ self._group_by_clause = ClauseList(*clauses)
+
def select(self, whereclauses = None, **params):
return select([self], whereclauses, **params)
- def _get_from_objects(self):
- if self.is_where or self.is_scalar:
+ def _get_from_objects(self, is_where=False, **modifiers):
+ if is_where:
return []
else:
return [self]
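These methods are now generative: each returns a copy produced by _generate() and leaves the original untouched. A sketch with a hypothetical users table:

    s1 = select([users])
    s2 = s1.order_by(users.c.name).limit(10).offset(20)
    assert s1 is not s2            # s1 still has no ORDER BY, LIMIT or OFFSET
    s3 = s2.order_by(None)         # a copy with the ORDER BY list cleared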
+class _ScalarSelect(_Grouping):
+ __visit_name__ = 'grouping'
+
+ def __init__(self, elem):
+ super(_ScalarSelect, self).__init__(elem)
+ self.type = list(elem.inner_columns)[0].type
+
+ columns = property(lambda self:[self])
+
+ def self_group(self, **kwargs):
+ return self
+
+ def _make_proxy(self, selectable, name):
+ return list(self.inner_columns)[0]._make_proxy(selectable, name)
+
+ def _get_from_objects(self, **modifiers):
+ return []
+
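_ScalarSelect is what as_scalar() returns: the enclosed SELECT acts as a single column expression and contributes no FROM objects to an enclosing statement. Sketch (hypothetical users/addresses tables):

    count = select([func.count(addresses.c.id)],
                   users.c.id == addresses.c.user_id).as_scalar()
    stmt = select([users.c.name, count.label('address_count')])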
class CompoundSelect(_SelectBaseMixin, FromClause):
def __init__(self, keyword, *selects, **kwargs):
- _SelectBaseMixin.__init__(self)
+ self._should_correlate = kwargs.pop('correlate', False)
self.keyword = keyword
- self.use_labels = kwargs.pop('use_labels', False)
- self.should_correlate = kwargs.pop('correlate', False)
- self.for_update = kwargs.pop('for_update', False)
- self.nowait = kwargs.pop('nowait', False)
- self.limit = kwargs.pop('limit', None)
- self.offset = kwargs.pop('offset', None)
- self.is_compound = True
- self.is_where = False
- self.is_scalar = False
- self.is_subquery = False
-
- # unions group from left to right, so don't group first select
- self.selects = [n and select.self_group(self) or select for n,select in enumerate(selects)]
+ self.selects = []
# some DBs do not like ORDER BY in the inner queries of a UNION, etc.
- for s in selects:
- s.order_by(None)
+ for n, s in enumerate(selects):
+ if len(s._order_by_clause):
+ s = s.order_by(None)
+ # unions group from left to right, so don't group first select
+ if n:
+ self.selects.append(s.self_group(self))
+ else:
+ self.selects.append(s)
- self.group_by(*kwargs.pop('group_by', [None]))
- self.order_by(*kwargs.pop('order_by', [None]))
- if len(kwargs):
- raise TypeError("invalid keyword argument(s) for CompoundSelect: %s" % repr(kwargs.keys()))
self._col_map = {}
+ _SelectBaseMixin.__init__(self, **kwargs)
+
name = property(lambda s:s.keyword + " statement")
def self_group(self, against=None):
@@ -2835,12 +2931,18 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
col.orig_set = colset
return col
+ def _copy_internals(self):
+ self._clone_from_clause()
+ self._col_map = {}
+ self.selects = [s._clone() for s in self.selects]
+ for attr in ('_order_by_clause', '_group_by_clause'):
+ if getattr(self, attr) is not None:
+ setattr(self, attr, getattr(self, attr)._clone())
+
def get_children(self, column_collections=True, **kwargs):
return (column_collections and list(self.c) or []) + \
- [self.order_by_clause, self.group_by_clause] + list(self.selects)
- def accept_visitor(self, visitor):
- visitor.visit_compound_select(self)
-
+ [self._order_by_clause, self._group_by_clause] + list(self.selects)
+
def _find_engine(self):
for s in self.selects:
e = s._find_engine()
@@ -2855,160 +2957,287 @@ class Select(_SelectBaseMixin, FromClause):
"""
- def __init__(self, columns=None, whereclause=None, from_obj=[],
- order_by=None, group_by=None, having=None,
- use_labels=False, distinct=False, for_update=False,
- engine=None, bind=None, limit=None, offset=None, scalar=False,
- correlate=True):
+ def __init__(self, columns, whereclause=None, from_obj=None, distinct=False, having=None, correlate=True, prefixes=None, **kwargs):
"""construct a Select object.
The public constructor for Select is the [sqlalchemy.sql#select()] function;
see that function for argument descriptions.
"""
- _SelectBaseMixin.__init__(self)
- self.__froms = util.OrderedSet()
- self.__hide_froms = util.Set([self])
- self.use_labels = use_labels
- self.whereclause = None
- self.having = None
- self._bind = bind or engine
- self.limit = limit
- self.offset = offset
- self.for_update = for_update
- self.is_compound = False
-
- # indicates that this select statement should not expand its columns
- # into the column clause of an enclosing select, and should instead
- # act like a single scalar column
- self.is_scalar = scalar
- if scalar:
- # allow corresponding_column to return None
- self.orig_set = util.Set()
-
- # indicates if this select statement, as a subquery, should automatically correlate
- # its FROM clause to that of an enclosing select, update, or delete statement.
- # note that the "correlate" method can be used to explicitly add a value to be correlated.
- self.should_correlate = correlate
-
- # indicates if this select statement is a subquery inside another query
- self.is_subquery = False
-
- # indicates if this select statement is in the from clause of another query
- self.is_selected_from = False
-
- # indicates if this select statement is a subquery as a criterion
- # inside of a WHERE clause
- self.is_where = False
+
+ self._should_correlate = correlate
+ self._distinct = distinct
- self.distinct = distinct
self._raw_columns = []
- self.__correlated = {}
- self.__correlator = Select._CorrelatedVisitor(self, False)
- self.__wherecorrelator = Select._CorrelatedVisitor(self, True)
- self.__fromvisitor = Select._FromVisitor(self)
-
-
- self.order_by_clause = self.group_by_clause = None
+ self.__correlate = util.Set()
+ self._froms = util.OrderedSet()
+ self._whereclause = None
+ self._having = None
+ self._prefixes = []
if columns is not None:
for c in columns:
self.append_column(c)
- if order_by:
- order_by = util.to_list(order_by)
- if group_by:
- group_by = util.to_list(group_by)
- self.order_by(*(order_by or [None]))
- self.group_by(*(group_by or [None]))
- for c in self.order_by_clause:
- self.__correlator.traverse(c)
- for c in self.group_by_clause:
- self.__correlator.traverse(c)
-
- for f in from_obj:
- self.append_from(f)
-
- # whereclauses must be appended after the columns/FROM, since it affects
- # the correlation of subqueries. see test/sql/select.py SelectTest.testwheresubquery
+ if from_obj is not None:
+ for f in from_obj:
+ self.append_from(f)
+
if whereclause is not None:
self.append_whereclause(whereclause)
+
if having is not None:
self.append_having(having)
+ _SelectBaseMixin.__init__(self, **kwargs)
- class _CorrelatedVisitor(NoColumnVisitor):
- """Visit a clause, locate any ``Select`` clauses, and tell
- them that they should correlate their ``FROM`` list to that of
- their parent.
+ def _get_display_froms(self, correlation_state=None):
+ """return the full list of 'from' clauses to be displayed.
+
+ takes into account an optional 'correlation_state'
+ dictionary which contains information about this Select's
+ correlation to an enclosing select, which may cause some 'from'
+ clauses not to display in this Select's FROM clause.
+ this dictionary is generated at compile time by the
+ _calculate_correlations() method.
+
"""
+ froms = util.OrderedSet()
+ hide_froms = util.Set()
+
+ for col in self._raw_columns:
+ for f in col._hide_froms():
+ hide_froms.add(f)
+ for f in col._get_from_objects():
+ froms.add(f)
+
+ if self._whereclause is not None:
+ for f in self._whereclause._get_from_objects(is_where=True):
+ froms.add(f)
+
+ for elem in self._froms:
+ froms.add(elem)
+ for f in elem._get_from_objects():
+ froms.add(f)
+
+ for elem in froms:
+ for f in elem._hide_froms():
+ hide_froms.add(f)
+
+ froms = froms.difference(hide_froms)
+
+ if len(froms) > 1:
+ corr = self.__correlate
+ if correlation_state is not None:
+ corr = correlation_state[self].get('correlate', util.Set()).union(corr)
+ f = froms.difference(corr)
+ if len(f) == 0:
+ raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate()))
+ return f
+ else:
+ return froms
+
+ froms = property(_get_display_froms, doc="""Return a list of all FromClause elements which will be applied to the FROM clause of the resulting statement.""")
+
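The net effect, sketched with hypothetical users/addresses tables: a subquery whose _should_correlate flag is set drops the enclosing statement's FROM objects from its own display list.

    inner = select([addresses.c.id],
                   users.c.id == addresses.c.user_id)    # correlate=True by default
    outer = select([users.c.name], inner.as_scalar() != None)
    # compiling 'outer' runs _calculate_correlations(), which records 'users'
    # as correlated for 'inner'; inner then renders FROM addresses only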
+ def locate_all_froms(self):
+ froms = util.Set()
+ for col in self._raw_columns:
+ for f in col._get_from_objects():
+ froms.add(f)
+
+ if self._whereclause is not None:
+ for f in self._whereclause._get_from_objects(is_where=True):
+ froms.add(f)
+
+ for elem in self._froms:
+ froms.add(elem)
+ for f in elem._get_from_objects():
+ froms.add(f)
+ return froms
+
+ def _calculate_correlations(self, correlation_state):
+ """generate a 'correlation_state' dictionary used by the _get_display_froms() method.
+
+ The dictionary is passed in initially empty, or already
+ containing the state information added by an enclosing
+ Select construct. The method will traverse through all
+ embedded Select statements and add information about their
+ position and "from" objects to the dictionary. Those Select
+ statements will later consult the 'correlation_state' dictionary
+ when their list of 'FROM' clauses is generated using their
+ _get_display_froms() method.
+ """
+
+ if self not in correlation_state:
+ correlation_state[self] = {}
- def __init__(self, select, is_where):
- NoColumnVisitor.__init__(self)
- self.select = select
- self.is_where = is_where
-
- def visit_compound_select(self, cs):
- self.visit_select(cs)
-
- def visit_column(self, c):
- pass
-
- def visit_table(self, c):
- pass
-
- def visit_select(self, select):
- if select is self.select:
- return
- select.is_where = self.is_where
- select.is_subquery = True
- if not select.should_correlate:
- return
- [select.correlate(x) for x in self.select._Select__froms]
+ display_froms = self._get_display_froms(correlation_state)
+
+ class CorrelatedVisitor(NoColumnVisitor):
+ def __init__(self, is_where=False, is_column=False, is_from=False):
+ self.is_where = is_where
+ self.is_column = is_column
+ self.is_from = is_from
+
+ def visit_compound_select(self, cs):
+ self.visit_select(cs)
- class _FromVisitor(NoColumnVisitor):
- def __init__(self, select):
- NoColumnVisitor.__init__(self)
- self.select = select
+ def visit_select(s, select):
+ if select not in correlation_state:
+ correlation_state[select] = {}
+
+ if select is self:
+ return
+
+ select_state = correlation_state[select]
+ if s.is_from:
+ select_state['is_selected_from'] = True
+ if s.is_where:
+ select_state['is_where'] = True
+ select_state['is_subquery'] = True
+
+ if select._should_correlate:
+ corr = select_state.setdefault('correlate', util.Set())
+ # not crazy about this part. need to be clearer on what elements in the
+ # subquery correspond to elements in the enclosing query.
+ for f in display_froms:
+ corr.add(f)
+ for f2 in f._get_from_objects():
+ corr.add(f2)
+
+ col_vis = CorrelatedVisitor(is_column=True)
+ where_vis = CorrelatedVisitor(is_where=True)
+ from_vis = CorrelatedVisitor(is_from=True)
+
+ for col in self._raw_columns:
+ col_vis.traverse(col)
+ for f in col._get_from_objects():
+ if f is not self:
+ from_vis.traverse(f)
+
+ for col in list(self._order_by_clause) + list(self._group_by_clause):
+ col_vis.traverse(col)
- def visit_select(self, select):
- if select is self.select:
- return
- select.is_selected_from = True
- select.is_subquery = True
+ if self._whereclause is not None:
+ where_vis.traverse(self._whereclause)
+ for f in self._whereclause._get_from_objects(is_where=True):
+ if f is not self:
+ from_vis.traverse(f)
+
+ for elem in self._froms:
+ from_vis.traverse(elem)
+
+ def _get_inner_columns(self):
+ for c in self._raw_columns:
+ if isinstance(c, Selectable):
+ for co in c.columns:
+ yield co
+ else:
+ yield c
+
+ inner_columns = property(_get_inner_columns)
+
+ def _copy_internals(self):
+ self._clone_from_clause()
+ self._raw_columns = [c._clone() for c in self._raw_columns]
+ self._recorrelate_froms([(f, f._clone()) for f in self._froms])
+ for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'):
+ if getattr(self, attr) is not None:
+ setattr(self, attr, getattr(self, attr)._clone())
+ def get_children(self, column_collections=True, **kwargs):
+ return (column_collections and list(self.columns) or []) + \
+ list(self._froms) + \
+ [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None]
+
+ def _recorrelate_froms(self, froms):
+ newcorrelate = util.Set()
+ newfroms = util.Set()
+ oldfroms = util.Set(self._froms)
+ for old, new in froms:
+ if old in self.__correlate:
+ newcorrelate.add(new)
+ self.__correlate.remove(old)
+ if old in oldfroms:
+ newfroms.add(new)
+ oldfroms.remove(old)
+ self.__correlate = self.__correlate.union(newcorrelate)
+ self._froms = [f for f in oldfroms.union(newfroms)]
+
+ def column(self, column):
+ s = self._generate()
+ s.append_column(column)
+ return s
+
+ def where(self, whereclause):
+ s = self._generate()
+ s.append_whereclause(whereclause)
+ return s
+
+ def having(self, having):
+ s = self._generate()
+ s.append_having(having)
+ return s
+
+ def distinct(self):
+ s = self._generate()
+ s._distinct = True
+ return s
+
+ def prefix_with(self, clause):
+ s = self._generate()
+ s.append_prefix(clause)
+ return s
+
+ def select_from(self, fromclause):
+ s = self._generate()
+ s.append_from(fromclause)
+ return s
+
+ def __dont_correlate(self):
+ s = self._generate()
+ s._should_correlate = False
+ return s
+
+ def correlate(self, fromclause):
+ s = self._generate()
+ s._should_correlate = False
+ if fromclause is None:
+ s.__correlate = util.Set()
+ else:
+ s.append_correlation(fromclause)
+ return s
+
+ def append_correlation(self, fromclause):
+ self.__correlate.add(fromclause)
+
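The generative counterparts to the append_*() mutators can be chained; successive where() calls conjoin criteria via and_(). Sketch (hypothetical users table; the prefix is a MySQL-style hint, purely illustrative):

    stmt = (select([users])
            .where(users.c.id > 5)
            .where(users.c.name != None)
            .distinct()
            .prefix_with('SQL_CALC_FOUND_ROWS'))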
def append_column(self, column):
- if _is_literal(column):
- column = literal_column(str(column))
+ column = _literal_as_column(column)
- if isinstance(column, Select) and column.is_scalar:
- column = column.self_group(against=',')
+ if isinstance(column, _ScalarSelect):
+ column = column.self_group(against=ColumnOperators.comma_op)
self._raw_columns.append(column)
-
- if self.is_scalar and not hasattr(self, 'type'):
- self.type = column.type
+
+ def append_prefix(self, clause):
+ clause = _literal_as_text(clause)
+ self._prefixes.append(clause)
- # if the column is a Select statement itself,
- # accept visitor
- self.__correlator.traverse(column)
-
- # visit the FROM objects of the column looking for more Selects
- for f in column._get_from_objects():
- if f is not self:
- self.__correlator.traverse(f)
- self._process_froms(column, False)
-
- def _make_proxy(self, selectable, name):
- if self.is_scalar:
- return self._raw_columns[0]._make_proxy(selectable, name)
+ def append_whereclause(self, whereclause):
+ if self._whereclause is not None:
+ self._whereclause = and_(self._whereclause, _literal_as_text(whereclause))
else:
- raise exceptions.InvalidRequestError("Not a scalar select statement")
-
- def label(self, name):
- if not self.is_scalar:
- raise exceptions.InvalidRequestError("Not a scalar select statement")
+ self._whereclause = _literal_as_text(whereclause)
+
+ def append_having(self, having):
+ if self._having is not None:
+ self._having = and_(self._having, _literal_as_text(having))
else:
- return label(name, self)
+ self._having = _literal_as_text(having)
+
+ def append_from(self, fromclause):
+ if _is_literal(fromclause):
+ fromclause = FromClause(fromclause)
+ self._froms.add(fromclause)
def _exportable_columns(self):
return [c for c in self._raw_columns if isinstance(c, Selectable)]
@@ -3019,53 +3248,13 @@ class Select(_SelectBaseMixin, FromClause):
else:
return column._make_proxy(self)
- def _process_froms(self, elem, asfrom):
- for f in elem._get_from_objects():
- self.__fromvisitor.traverse(f)
- self.__froms.add(f)
- if asfrom:
- self.__froms.add(elem)
- for f in elem._hide_froms():
- self.__hide_froms.add(f)
-
def self_group(self, against=None):
if isinstance(against, CompoundSelect):
return self
return _Grouping(self)
-
- def append_whereclause(self, whereclause):
- self._append_condition('whereclause', whereclause)
-
- def append_having(self, having):
- self._append_condition('having', having)
-
- def _append_condition(self, attribute, condition):
- if isinstance(condition, basestring):
- condition = _TextClause(condition)
- self.__wherecorrelator.traverse(condition)
- self._process_froms(condition, False)
- if getattr(self, attribute) is not None:
- setattr(self, attribute, and_(getattr(self, attribute), condition))
- else:
- setattr(self, attribute, condition)
-
- def correlate(self, from_obj):
- """Given a ``FROM`` object, correlate this ``SELECT`` statement to it.
-
- This basically means the given from object will not come out
- in this select statement's ``FROM`` clause when printed.
- """
-
- self.__correlated[from_obj] = from_obj
-
- def append_from(self, fromclause):
- if isinstance(fromclause, basestring):
- fromclause = FromClause(fromclause)
- self.__correlator.traverse(fromclause)
- self._process_froms(fromclause, True)
def _locate_oid_column(self):
- for f in self.__froms:
+ for f in self.locate_all_froms():
if f is self:
# we might be in our own _froms list if a column with us as the parent is attached,
# which includes textual columns.
@@ -3076,25 +3265,6 @@ class Select(_SelectBaseMixin, FromClause):
else:
return None
- def _calc_froms(self):
- f = self.__froms.difference(self.__hide_froms)
- if (len(f) > 1):
- return f.difference(self.__correlated)
- else:
- return f
-
- froms = property(_calc_froms,
- doc="""A collection containing all elements
- of the ``FROM`` clause.""")
-
- def get_children(self, column_collections=True, **kwargs):
- return (column_collections and list(self.columns) or []) + \
- list(self.froms) + \
- [x for x in (self.whereclause, self.having, self.order_by_clause, self.group_by_clause) if x is not None]
-
- def accept_visitor(self, visitor):
- visitor.visit_select(self)
-
def union(self, other, **kwargs):
return union(self, other, **kwargs)
@@ -3108,7 +3278,7 @@ class Select(_SelectBaseMixin, FromClause):
if self._bind is not None:
return self._bind
- for f in self.__froms:
+ for f in self._froms:
if f is self:
continue
e = f.bind
@@ -3133,20 +3303,24 @@ class _UpdateBase(ClauseElement):
def supports_execution(self):
return True
- class _SelectCorrelator(NoColumnVisitor):
- def __init__(self, table):
- NoColumnVisitor.__init__(self)
- self.table = table
-
- def visit_select(self, select):
- if select.should_correlate:
- select.correlate(self.table)
-
- def _process_whereclause(self, whereclause):
- if whereclause is not None:
- _UpdateBase._SelectCorrelator(self.table).traverse(whereclause)
- return whereclause
+ def _calculate_correlations(self, correlate_state):
+ class SelectCorrelator(NoColumnVisitor):
+ def visit_select(s, select):
+ if select._should_correlate:
+ select_state = correlate_state.setdefault(select, {})
+ corr = select_state.setdefault('correlate', util.Set())
+ corr.add(self.table)
+
+ vis = SelectCorrelator()
+ if self._whereclause is not None:
+ vis.traverse(self._whereclause)
+
+ if getattr(self, 'parameters', None) is not None:
+ for key, value in self.parameters.items():
+ if isinstance(value, ClauseElement):
+ vis.traverse(value)
+
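For example (hypothetical users/addresses tables), an embedded SELECT in either the WHERE clause or the parameter values is told to correlate against the statement's table:

    stmt = users.update(
        users.c.id == 7,
        values={users.c.fullname:
                select([addresses.c.email],
                       addresses.c.user_id == users.c.id).as_scalar()})
    # the inner SELECT renders FROM addresses only; users is correlated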
def _process_colparams(self, parameters):
"""Receive the *values* of an ``INSERT`` or ``UPDATE``
statement and construct appropriate bind parameters.
@@ -3155,7 +3329,7 @@ class _UpdateBase(ClauseElement):
if parameters is None:
return None
- if isinstance(parameters, list) or isinstance(parameters, tuple):
+ if isinstance(parameters, (list, tuple)):
pp = {}
i = 0
for c in self.table.c:
@@ -3163,11 +3337,10 @@ class _UpdateBase(ClauseElement):
i +=1
parameters = pp
- correlator = _UpdateBase._SelectCorrelator(self.table)
for key in parameters.keys():
value = parameters[key]
if isinstance(value, ClauseElement):
- correlator.traverse(value)
+ parameters[key] = value.self_group()
elif _is_literal(value):
if _is_literal(key):
col = self.table.c[key]
@@ -3182,7 +3355,7 @@ class _UpdateBase(ClauseElement):
def _find_engine(self):
return self.table.bind
-class _Insert(_UpdateBase):
+class Insert(_UpdateBase):
def __init__(self, table, values=None):
self.table = table
self.select = None
@@ -3193,32 +3366,41 @@ class _Insert(_UpdateBase):
return self.select,
else:
return ()
- def accept_visitor(self, visitor):
- visitor.visit_insert(self)
-class _Update(_UpdateBase):
+class Update(_UpdateBase):
def __init__(self, table, whereclause, values=None):
self.table = table
- self.whereclause = self._process_whereclause(whereclause)
+ self._whereclause = whereclause
self.parameters = self._process_colparams(values)
def get_children(self, **kwargs):
- if self.whereclause is not None:
- return self.whereclause,
+ if self._whereclause is not None:
+ return self._whereclause,
else:
return ()
- def accept_visitor(self, visitor):
- visitor.visit_update(self)
-class _Delete(_UpdateBase):
+class Delete(_UpdateBase):
def __init__(self, table, whereclause):
self.table = table
- self.whereclause = self._process_whereclause(whereclause)
+ self._whereclause = whereclause
def get_children(self, **kwargs):
- if self.whereclause is not None:
- return self.whereclause,
+ if self._whereclause is not None:
+ return self._whereclause,
else:
return ()
- def accept_visitor(self, visitor):
- visitor.visit_delete(self)
+
+class _IdentifiedClause(ClauseElement):
+ def __init__(self, ident):
+ self.ident = ident
+ def supports_execution(self):
+ return True
+
+class SavepointClause(_IdentifiedClause):
+ pass
+
+class RollbackToSavepointClause(_IdentifiedClause):
+ pass
+
+class ReleaseSavepointClause(_IdentifiedClause):
+ pass
diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql_util.py
index 9235b9c4e..d91fbe4b5 100644
--- a/lib/sqlalchemy/sql_util.py
+++ b/lib/sqlalchemy/sql_util.py
@@ -53,7 +53,7 @@ class TableCollection(object):
for table in self.tables:
vis.traverse(table)
sorter = topological.QueueDependencySorter( tuples, self.tables )
- head = sorter.sort()
+ head = sorter.sort()
sequence = []
def to_sequence( node, seq=sequence):
seq.append( node.item )
@@ -67,12 +67,12 @@ class TableCollection(object):
class TableFinder(TableCollection, sql.NoColumnVisitor):
"""locate all Tables within a clause."""
- def __init__(self, table, check_columns=False, include_aliases=False):
+ def __init__(self, clause, check_columns=False, include_aliases=False):
TableCollection.__init__(self)
self.check_columns = check_columns
self.include_aliases = include_aliases
- if table is not None:
- self.traverse(table)
+ for clause in util.to_list(clause):
+ self.traverse(clause)
def visit_alias(self, alias):
if self.include_aliases:
@@ -83,7 +83,7 @@ class TableFinder(TableCollection, sql.NoColumnVisitor):
def visit_column(self, column):
if self.check_columns:
- self.traverse(column.table)
+ self.tables.append(column.table)
class ColumnFinder(sql.ClauseVisitor):
def __init__(self):
@@ -125,7 +125,7 @@ class AbstractClauseProcessor(sql.NoColumnVisitor):
process the new list.
"""
- list_ = [o.copy_container() for o in list_]
+ list_ = list(list_)
self.process_list(list_)
return list_
@@ -137,7 +137,7 @@ class AbstractClauseProcessor(sql.NoColumnVisitor):
if elem is not None:
list_[i] = elem
else:
- self.traverse(list_[i])
+ list_[i] = self.traverse(list_[i], clone=True)
def visit_grouping(self, grouping):
elem = self.convert_element(grouping.elem)
@@ -162,8 +162,24 @@ class AbstractClauseProcessor(sql.NoColumnVisitor):
elem = self.convert_element(binary.right)
if elem is not None:
binary.right = elem
-
- # TODO: visit_select().
+
+ def visit_select(self, select):
+ fr = util.OrderedSet()
+ for elem in select._froms:
+ n = self.convert_element(elem)
+ if n is not None:
+ fr.add((elem, n))
+ select._recorrelate_froms(fr)
+
+ col = []
+ for elem in select._raw_columns:
+ print "RAW COLUMN", elem
+ n = self.convert_element(elem)
+ if n is None:
+ col.append(elem)
+ else:
+ col.append(n)
+ select._raw_columns = col
class ClauseAdapter(AbstractClauseProcessor):
"""Given a clause (like as in a WHERE criterion), locate columns
@@ -200,6 +216,9 @@ class ClauseAdapter(AbstractClauseProcessor):
self.equivalents = equivalents
def convert_element(self, col):
+ if isinstance(col, sql.FromClause):
+ if self.selectable.is_derived_from(col):
+ return self.selectable
if not isinstance(col, sql.ColumnElement):
return None
if self.include is not None:
@@ -214,4 +233,9 @@ class ClauseAdapter(AbstractClauseProcessor):
newcol = self.selectable.corresponding_column(equiv, raiseerr=False, require_embedded=True, keys_ok=False)
if newcol:
return newcol
+ #if newcol is None:
+ # self.traverse(col)
+ # return col
return newcol
+
+
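A sketch of the intended usage, assuming a hypothetical users table: ClauseAdapter rewrites column references in a criterion to point at an alias, and the new is_derived_from() branch lets whole FROM clauses be swapped as well.

    u1 = users.alias('u1')
    crit = users.c.id == 5
    adapted = ClauseAdapter(u1).traverse(crit, clone=True)
    # adapted renders as: u1.id = :id   (rather than users.id = :id)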
diff --git a/lib/sqlalchemy/topological.py b/lib/sqlalchemy/topological.py
index bad71293e..56c8cb46e 100644
--- a/lib/sqlalchemy/topological.py
+++ b/lib/sqlalchemy/topological.py
@@ -42,7 +42,6 @@ nature - very tricky to reproduce and track down, particularly before
I realized this characteristic of the algorithm.
"""
-import string, StringIO
from sqlalchemy import util
from sqlalchemy.exceptions import CircularDependencyError
@@ -68,7 +67,7 @@ class _Node(object):
str(self.item) + \
(self.cycles is not None and (" (cycles: " + repr([x for x in self.cycles]) + ")") or "") + \
"\n" + \
- string.join([n.safestr(indent + 1) for n in self.children], '')
+ ''.join([n.safestr(indent + 1) for n in self.children])
def __repr__(self):
return "%s" % (str(self.item))
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
index 3cceedae6..ec1459852 100644
--- a/lib/sqlalchemy/types.py
+++ b/lib/sqlalchemy/types.py
@@ -7,34 +7,29 @@
__all__ = [ 'TypeEngine', 'TypeDecorator', 'NullTypeEngine',
'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'TEXT', 'FLOAT', 'DECIMAL',
'TIMESTAMP', 'DATETIME', 'CLOB', 'BLOB', 'BOOLEAN', 'String', 'Integer', 'SmallInteger','Smallinteger',
- 'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'Binary', 'Boolean', 'Unicode', 'PickleType', 'NULLTYPE',
+ 'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'Binary', 'Boolean', 'Unicode', 'PickleType', 'NULLTYPE', 'NullType',
'SMALLINT', 'DATE', 'TIME','Interval'
]
-from sqlalchemy import util, exceptions
-import inspect, weakref
+import inspect
import datetime as dt
+from decimal import Decimal
try:
import cPickle as pickle
except:
import pickle
-_impl_cache = weakref.WeakKeyDictionary()
+from sqlalchemy import exceptions
class AbstractType(object):
- def _get_impl_dict(self):
- try:
- return _impl_cache[self]
- except KeyError:
- return _impl_cache.setdefault(self, {})
-
- impl_dict = property(_get_impl_dict)
-
+ def __init__(self, *args, **kwargs):
+ pass
+
def copy_value(self, value):
return value
def compare_values(self, x, y):
- return x is y
+ return x == y
def is_mutable(self):
return False
@@ -51,15 +46,20 @@ class AbstractType(object):
return "%s(%s)" % (self.__class__.__name__, ",".join(["%s=%s" % (k, getattr(self, k)) for k in inspect.getargspec(self.__init__)[0][1:]]))
class TypeEngine(AbstractType):
- def __init__(self, *args, **params):
- pass
-
def dialect_impl(self, dialect):
try:
- return self.impl_dict[dialect]
+ return self._impl_dict[dialect]
+ except AttributeError:
+ self._impl_dict = {}
+ return self._impl_dict.setdefault(dialect, dialect.type_descriptor(self))
except KeyError:
- return self.impl_dict.setdefault(dialect, dialect.type_descriptor(self))
-
+ return self._impl_dict.setdefault(dialect, dialect.type_descriptor(self))
+
+ def __getstate__(self):
+ d = self.__dict__.copy()
+ d['_impl_dict'] = {}
+ return d
+
def get_col_spec(self):
raise NotImplementedError()
@@ -88,15 +88,19 @@ class TypeDecorator(AbstractType):
def dialect_impl(self, dialect):
try:
- return self.impl_dict[dialect]
- except:
- typedesc = self.load_dialect_impl(dialect)
- tt = self.copy()
- if not isinstance(tt, self.__class__):
- raise exceptions.AssertionError("Type object %s does not properly implement the copy() method, it must return an object of type %s" % (self, self.__class__))
- tt.impl = typedesc
- self.impl_dict[dialect] = tt
- return tt
+ return self._impl_dict[dialect]
+ except AttributeError:
+ self._impl_dict = {}
+ except KeyError:
+ pass
+
+ typedesc = self.load_dialect_impl(dialect)
+ tt = self.copy()
+ if not isinstance(tt, self.__class__):
+ raise exceptions.AssertionError("Type object %s does not properly implement the copy() method, it must return an object of type %s" % (self, self.__class__))
+ tt.impl = typedesc
+ self._impl_dict[dialect] = tt
+ return tt
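A minimal (hypothetical) decorator type under the new caching scheme; copy() must return an instance of the same class, per the assertion above:

    class LowerCaseString(TypeDecorator):
        impl = String

        def convert_bind_param(self, value, dialect):
            if value is not None:
                value = value.lower()
            return value

        def copy(self):
            return LowerCaseString()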
def load_dialect_impl(self, dialect):
"""loads the dialect-specific implementation of this type.
@@ -179,7 +183,7 @@ def adapt_type(typeobj, colspecs):
return typeobj
return typeobj.adapt(impltype)
-class NullTypeEngine(TypeEngine):
+class NullType(TypeEngine):
def get_col_spec(self):
raise NotImplementedError()
@@ -188,8 +192,13 @@ class NullTypeEngine(TypeEngine):
def convert_result_value(self, value, dialect):
return value
+NullTypeEngine = NullType
-class String(TypeEngine):
+class Concatenable(object):
+ """marks a type as supporting 'concatenation'"""
+ pass
+
+class String(TypeEngine, Concatenable):
def __init__(self, length=None, convert_unicode=False):
self.length = length
self.convert_unicode = convert_unicode
@@ -219,9 +228,6 @@ class String(TypeEngine):
def get_dbapi_type(self, dbapi):
return dbapi.STRING
- def compare_values(self, x, y):
- return x == y
-
class Unicode(String):
def __init__(self, length=None, **kwargs):
kwargs['convert_unicode'] = True
@@ -241,22 +247,36 @@ class SmallInteger(Integer):
Smallinteger = SmallInteger
class Numeric(TypeEngine):
- def __init__(self, precision = 10, length = 2):
+ def __init__(self, precision = 10, length = 2, asdecimal=True):
self.precision = precision
self.length = length
+ self.asdecimal = asdecimal
def adapt(self, impltype):
- return impltype(precision=self.precision, length=self.length)
+ return impltype(precision=self.precision, length=self.length, asdecimal=self.asdecimal)
def get_dbapi_type(self, dbapi):
return dbapi.NUMBER
+ def convert_bind_param(self, value, dialect):
+ if value is not None:
+ return float(value)
+ else:
+ return value
+
+ def convert_result_value(self, value, dialect):
+ if value is not None and self.asdecimal:
+ return Decimal(str(value))
+ else:
+ return value
+
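The dialect argument is unused by these base implementations, so the round trip can be sketched directly:

    n = Numeric()                             # asdecimal=True by default
    assert n.convert_bind_param(Decimal('10.5'), None) == 10.5
    assert n.convert_result_value(10.5, None) == Decimal('10.5')
    assert Numeric(asdecimal=False).convert_result_value(10.5, None) == 10.5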
class Float(Numeric):
- def __init__(self, precision = 10):
+ def __init__(self, precision = 10, asdecimal=False, **kwargs):
self.precision = precision
-
+ self.asdecimal = asdecimal
+
def adapt(self, impltype):
- return impltype(precision=self.precision)
+ return impltype(precision=self.precision, asdecimal=self.asdecimal)
class DateTime(TypeEngine):
"""Implement a type for ``datetime.datetime()`` objects."""
@@ -416,4 +436,4 @@ class NCHAR(Unicode):pass
class BLOB(Binary): pass
class BOOLEAN(Boolean): pass
-NULLTYPE = NullTypeEngine()
+NULLTYPE = NullType()
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py
index b47822d61..e711de3a3 100644
--- a/lib/sqlalchemy/util.py
+++ b/lib/sqlalchemy/util.py
@@ -14,7 +14,6 @@ from sqlalchemy import exceptions
import md5
import sys
import warnings
-
import __builtin__
try:
@@ -33,10 +32,35 @@ except:
i -= 1
raise StopIteration()
-def to_list(x):
+if sys.version_info >= (2, 5):
+ class PopulateDict(dict):
+ """a dict which populates missing values via a creation function.
+
+ note the creation function takes a key, unlike collections.defaultdict.
+ """
+
+ def __init__(self, creator):
+ self.creator = creator
+ def __missing__(self, key):
+ self[key] = val = self.creator(key)
+ return val
+else:
+ class PopulateDict(dict):
+ """a dict which populates missing values via a creation function."""
+
+ def __init__(self, creator):
+ self.creator = creator
+ def __getitem__(self, key):
+ try:
+ return dict.__getitem__(self, key)
+ except KeyError:
+ self[key] = value = self.creator(key)
+ return value
+
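+
+# [editor's note] PopulateDict differs from collections.defaultdict (itself new
+# in 2.5) in that the creator receives the missing key; the 2.5+ branch hooks
+# dict.__missing__ so successful lookups run at native speed, while the
+# fallback overrides __getitem__ wholesale. Both behave the same from outside:
+#
+#     quoted = PopulateDict(lambda key: '"%s"' % key)
+#     quoted['users']      # -> '"users"', built and stored on first access
+#     'users' in quoted    # -> True; later lookups never call the creator
+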
+def to_list(x, default=None):
if x is None:
- return None
- if not isinstance(x, list) and not isinstance(x, tuple):
+ return default
+ if not isinstance(x, (list, tuple)):
return [x]
else:
return x
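
[editor's note] to_list() now takes a `default` for the None case, so callers can get an empty list (or any sentinel) back instead of testing the return value; tuples, as before, pass through unwrapped:

    to_list('x')          # -> ['x']
    to_list(('x', 'y'))   # -> ('x', 'y'), tuples are returned as-is
    to_list(None)         # -> None, the unchanged default
    to_list(None, [])     # -> [], via the new default argument
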
@@ -113,19 +137,25 @@ def coerce_kw_type(kw, key, type_, flexi_bool=True):
else:
kw[key] = type_(kw[key])
-def duck_type_collection(col, default=None):
+def duck_type_collection(specimen, default=None):
"""Given an instance or class, guess if it is or is acting as one of
the basic collection types: list, set and dict. If the __emulates__
property is present, return that preferentially.
"""
- if hasattr(col, '__emulates__'):
- return getattr(col, '__emulates__')
- elif hasattr(col, 'append'):
+ if hasattr(specimen, '__emulates__'):
+ return specimen.__emulates__
+
+ isa = isinstance(specimen, type) and issubclass or isinstance
+ if isa(specimen, list): return list
+ if isa(specimen, Set): return Set
+ if isa(specimen, dict): return dict
+
+ if hasattr(specimen, 'append'):
return list
- elif hasattr(col, 'add'):
+ elif hasattr(specimen, 'add'):
return Set
- elif hasattr(col, 'set'):
+ elif hasattr(specimen, 'set'):
return dict
else:
return default
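
[editor's note] The renamed function also grows real class support: `isa` is the pre-ternary and/or idiom, evaluating to issubclass when the specimen is itself a class and to isinstance otherwise, so actual list/Set/dict lineage is checked before falling back to the append/add/set method sniffing (Set here is util's own set alias). For example:

    class ListLike(object):
        def append(self, item):
            pass

    class SetEmulator(object):
        __emulates__ = Set                # wins over any method sniffing

    duck_type_collection(dict)            # -> dict, via issubclass on the class
    duck_type_collection(ListLike())      # -> list, via the append() heuristic
    duck_type_collection(SetEmulator())   # -> Set
    duck_type_collection(42, 'oops')      # -> 'oops', the default
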
@@ -138,11 +168,11 @@ def assert_arg_type(arg, argtype, name):
raise exceptions.ArgumentError("Argument '%s' is expected to be one of type %s, got '%s'" % (name, ' or '.join(["'%s'" % str(a) for a in argtype]), str(type(arg))))
else:
raise exceptions.ArgumentError("Argument '%s' is expected to be of type '%s', got '%s'" % (name, str(argtype), str(type(arg))))
-
-def warn_exception(func):
+
+def warn_exception(func, *args, **kwargs):
"""executes the given function, catches all exceptions and converts to a warning."""
try:
- return func()
+ return func(*args, **kwargs)
except:
warnings.warn(RuntimeWarning("%s('%s') ignored" % sys.exc_info()[0:2]))
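
[editor's note] Forwarding *args/**kwargs means call sites no longer need a closure just to defer arguments:

    def flaky(path):
        raise IOError(path)

    warn_exception(flaky, '/tmp/scratch')   # warns instead of raising
    # previously: warn_exception(lambda: flaky('/tmp/scratch'))
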
@@ -246,12 +276,12 @@ class OrderedProperties(object):
class OrderedDict(dict):
"""A Dictionary that returns keys/values/items in the order they were added."""
- def __init__(self, d=None, **kwargs):
+ def __init__(self, ____sequence=None, **kwargs):
self._list = []
- if d is None:
+ if ____sequence is None:
self.update(**kwargs)
else:
- self.update(d, **kwargs)
+ self.update(____sequence, **kwargs)
def clear(self):
self._list = []
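
[editor's note] The quadruple-underscore rename looks odd but fixes a real collision: with the parameter named `d`, a caller storing a key literally called 'd' via keywords bound it to the positional slot, and self.update(1) then raised. Now:

    OrderedDict(d=1)   # -> {'d': 1}; previously 1 landed in the positional
                       #    parameter and the update() call failed
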
@@ -347,7 +377,13 @@ class DictDecorator(dict):
return dict.__getitem__(self, key)
except KeyError:
return self.decorate[key]
-
+
+ def __contains__(self, key):
+ return dict.__contains__(self, key) or key in self.decorate
+
+ def has_key(self, key):
+ return key in self
+
def __repr__(self):
return dict.__repr__(self) + repr(self.decorate)
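
[editor's note] These two methods bring containment in line with __getitem__'s fallback: a key that lives only in the wrapped `decorate` dict now also answers `in` and the (then still current) has_key(). Roughly, assuming the decorator is constructed around a fallback mapping, as its __getitem__ implies:

    fallback = {'default_schema': None}
    d = DictDecorator(fallback)        # assumed constructor signature
    d['echo'] = True
    'echo' in d                        # -> True, found locally
    'default_schema' in d              # -> True, now found via the fallback too
    d.has_key('echo')                  # -> True, same answer as `in`
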
@@ -442,19 +478,28 @@ class OrderedSet(Set):
__isub__ = difference_update
class UniqueAppender(object):
- def __init__(self, data):
+ """appends items to a collection such that only unique items
+ are added."""
+
+ def __init__(self, data, via=None):
self.data = data
- if hasattr(data, 'append'):
+ self._unique = Set()
+ if via:
+ self._data_appender = getattr(data, via)
+ elif hasattr(data, 'append'):
self._data_appender = data.append
elif hasattr(data, 'add'):
+            # TODO: we think it's a set here; bypass the unneeded uniquing logic?
self._data_appender = data.add
- self.set = Set()
-
+
def append(self, item):
- if item not in self.set:
- self.set.add(item)
+ if item not in self._unique:
self._data_appender(item)
-
+ self._unique.add(item)
+
+ def __iter__(self):
+ return iter(self.data)
+
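+
+# [editor's note] UniqueAppender gains three things here: a docstring, a `via`
+# name so the target method need not be append/add, and an __iter__ exposing
+# the underlying collection (the uniquing Set becomes the private _unique).
+# In use:
+#
+#     seen = UniqueAppender([])
+#     for item in ('a', 'b', 'a'):
+#         seen.append(item)
+#     list(seen)                 # -> ['a', 'b']: duplicates dropped, order kept
+#
+#     class Inbox(object):       # hypothetical target with a custom method name
+#         def __init__(self):
+#             self.items = []
+#         def receive(self, item):
+#             self.items.append(item)
+#
+#     box = Inbox()
+#     ua = UniqueAppender(box, via='receive')
+#     ua.append('x')
+#     ua.append('x')
+#     box.items                  # -> ['x']
+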
class ScopedRegistry(object):
"""A Registry that can store one or multiple instances of a single
class on a per-thread scoped basis, or on a customized scope.