diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-07-27 04:08:53 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-07-27 04:08:53 +0000 |
| commit | ed4fc64bb0ac61c27bc4af32962fb129e74a36bf (patch) | |
| tree | c1cf2fb7b1cafced82a8898e23d2a0bf5ced8526 /lib/sqlalchemy | |
| parent | 3a8e235af64e36b3b711df1f069d32359fe6c967 (diff) | |
| download | sqlalchemy-ed4fc64bb0ac61c27bc4af32962fb129e74a36bf.tar.gz | |
merging 0.4 branch to trunk. see CHANGES for details. 0.3 moves to maintenance branch in branches/rel_0_3.
Diffstat (limited to 'lib/sqlalchemy')
47 files changed, 7159 insertions, 5413 deletions
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index ad2615131..6e95fd7e1 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -7,10 +7,8 @@ from sqlalchemy.types import * from sqlalchemy.sql import * from sqlalchemy.schema import * -from sqlalchemy.orm import * from sqlalchemy.engine import create_engine -from sqlalchemy.schema import default_metadata def __figure_version(): try: @@ -25,8 +23,6 @@ def __figure_version(): return '(not installed)' except: return '(not installed)' - + __version__ = __figure_version() - -def global_connect(*args, **kwargs): - default_metadata.connect(*args, **kwargs) + diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 9994d5288..22227d56a 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -10,9 +10,10 @@ Contains default implementations for the abstract objects in the sql module. """ -from sqlalchemy import schema, sql, engine, util, sql_util, exceptions +import string, re, sets, operator + +from sqlalchemy import schema, sql, engine, util, exceptions from sqlalchemy.engine import default -import string, re, sets, weakref, random ANSI_FUNCS = sets.ImmutableSet(['CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP', 'CURRENT_USER', 'LOCALTIME', 'LOCALTIMESTAMP', @@ -40,6 +41,41 @@ RESERVED_WORDS = util.Set(['all', 'analyse', 'analyze', 'and', 'any', 'array', LEGAL_CHARACTERS = util.Set(string.ascii_lowercase + string.ascii_uppercase + string.digits + '_$') ILLEGAL_INITIAL_CHARACTERS = util.Set(string.digits + '$') +BIND_PARAMS = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE) +BIND_PARAMS_ESC = re.compile(r'\x5c(:\w+)(?!:)', re.UNICODE) + +OPERATORS = { + operator.and_ : 'AND', + operator.or_ : 'OR', + operator.inv : 'NOT', + operator.add : '+', + operator.mul : '*', + operator.sub : '-', + operator.div : '/', + operator.mod : '%', + operator.truediv : '/', + operator.lt : '<', + operator.le : '<=', + operator.ne : '!=', + operator.gt : '>', + operator.ge : '>=', + operator.eq : '=', + sql.ColumnOperators.concat_op : '||', + sql.ColumnOperators.like_op : 'LIKE', + sql.ColumnOperators.notlike_op : 'NOT LIKE', + sql.ColumnOperators.ilike_op : 'ILIKE', + sql.ColumnOperators.notilike_op : 'NOT ILIKE', + sql.ColumnOperators.between_op : 'BETWEEN', + sql.ColumnOperators.in_op : 'IN', + sql.ColumnOperators.notin_op : 'NOT IN', + sql.ColumnOperators.comma_op : ', ', + sql.Operators.from_ : 'FROM', + sql.Operators.as_ : 'AS', + sql.Operators.exists : 'EXISTS', + sql.Operators.is_ : 'IS', + sql.Operators.isnot : 'IS NOT' +} + class ANSIDialect(default.DefaultDialect): def __init__(self, cache_identifiers=True, **kwargs): super(ANSIDialect,self).__init__(**kwargs) @@ -66,14 +102,16 @@ class ANSIDialect(default.DefaultDialect): """ return ANSIIdentifierPreparer(self) -class ANSICompiler(sql.Compiled): +class ANSICompiler(engine.Compiled, sql.ClauseVisitor): """Default implementation of Compiled. Compiles ClauseElements into ANSI-compliant SQL strings. """ - __traverse_options__ = {'column_collections':False} + __traverse_options__ = {'column_collections':False, 'entry':True} + operators = OPERATORS + def __init__(self, dialect, statement, parameters=None, **kwargs): """Construct a new ``ANSICompiler`` object. @@ -92,7 +130,7 @@ class ANSICompiler(sql.Compiled): correspond to the keys present in the parameters. """ - sql.Compiled.__init__(self, dialect, statement, parameters, **kwargs) + super(ANSICompiler, self).__init__(dialect, statement, parameters, **kwargs) # if we are insert/update. set to true when we visit an INSERT or UPDATE self.isinsert = self.isupdate = False @@ -104,21 +142,6 @@ class ANSICompiler(sql.Compiled): # actually present in the generated SQL self.bind_names = {} - # a dictionary which stores the string representation for every ClauseElement - # processed by this compiler. - self.strings = {} - - # a dictionary which stores the string representation for ClauseElements - # processed by this compiler, which are to be used in the FROM clause - # of a select. items are often placed in "froms" as well as "strings" - # and sometimes with different representations. - self.froms = {} - - # slightly hacky. maps FROM clauses to WHERE clauses, and used in select - # generation to modify the WHERE clause of the select. currently a hack - # used by the oracle module. - self.wheres = {} - # when the compiler visits a SELECT statement, the clause object is appended # to this stack. various visit operations will check this stack to determine # additional choices (TODO: it seems to be all typemap stuff. shouldnt this only @@ -137,12 +160,6 @@ class ANSICompiler(sql.Compiled): # for aliases self.generated_ids = {} - # True if this compiled represents an INSERT - self.isinsert = False - - # True if this compiled represents an UPDATE - self.isupdate = False - # default formatting style for bind parameters self.bindtemplate = ":%s" @@ -158,64 +175,76 @@ class ANSICompiler(sql.Compiled): # an ANSIIdentifierPreparer that formats the quoting of identifiers self.preparer = dialect.identifier_preparer - + + # a dictionary containing attributes about all select() + # elements located within the clause, regarding which are subqueries, which are + # selected from, and which elements should be correlated to an enclosing select. + # used mostly to determine the list of FROM elements for each select statement, as well + # as some dialect-specific rules regarding subqueries. + self.correlate_state = {} + # for UPDATE and INSERT statements, a set of columns whos values are being set # from a SQL expression (i.e., not one of the bind parameter values). if present, # default-value logic in the Dialect knows not to fire off column defaults # and also knows postfetching will be needed to get the values represented by these # parameters. self.inline_params = None - + def after_compile(self): # this re will search for params like :param # it has a negative lookbehind for an extra ':' so that it doesnt match # postgres '::text' tokens - match = re.compile(r'(?<!:):([\w_]+)', re.UNICODE) + text = self.string + if ':' not in text: + return + if self.paramstyle=='pyformat': - self.strings[self.statement] = match.sub(lambda m:'%(' + m.group(1) +')s', self.strings[self.statement]) + text = BIND_PARAMS.sub(lambda m:'%(' + m.group(1) +')s', text) elif self.positional: - params = match.finditer(self.strings[self.statement]) + params = BIND_PARAMS.finditer(text) for p in params: self.positiontup.append(p.group(1)) if self.paramstyle=='qmark': - self.strings[self.statement] = match.sub('?', self.strings[self.statement]) + text = BIND_PARAMS.sub('?', text) elif self.paramstyle=='format': - self.strings[self.statement] = match.sub('%s', self.strings[self.statement]) + text = BIND_PARAMS.sub('%s', text) elif self.paramstyle=='numeric': i = [0] def getnum(x): i[0] += 1 return str(i[0]) - self.strings[self.statement] = match.sub(getnum, self.strings[self.statement]) - - def get_from_text(self, obj): - return self.froms.get(obj, None) - - def get_str(self, obj): - return self.strings[obj] - + text = BIND_PARAMS.sub(getnum, text) + # un-escape any \:params + text = BIND_PARAMS_ESC.sub(lambda m: m.group(1), text) + self.string = text + + def compile(self): + self.string = self.process(self.statement) + self.after_compile() + + def process(self, obj, **kwargs): + return self.traverse_single(obj, **kwargs) + + def is_subquery(self, select): + return self.correlate_state[select].get('is_subquery', False) + def get_whereclause(self, obj): - return self.wheres.get(obj, None) + """given a FROM clause, return an additional WHERE condition that should be + applied to a SELECT. + + Currently used by Oracle to provide WHERE criterion for JOIN and OUTER JOIN + constructs in non-ansi mode. + """ + + return None def construct_params(self, params): - """Return a structure of bind parameters for this compiled object. - - This includes bind parameters that might be compiled in via - the `values` argument of an ``Insert`` or ``Update`` statement - object, and also the given `**params`. The keys inside of - `**params` can be any key that matches the - ``BindParameterClause`` objects compiled within this object. - - The output is dependent on the paramstyle of the DBAPI being - used; if a named style, the return result will be a dictionary - with keynames matching the compiled statement. If a - positional style, the output will be a list, with an iterator - that will return parameter values in an order corresponding to - the bind positions in the compiled statement. - - For an executemany style of call, this method should be called - for each element in the list of parameter groups that will - ultimately be executed. + """Return a sql.ClauseParameters object. + + Combines the given bind parameter dictionary (string keys to object values) + with the _BindParamClause objects stored within this Compiled object + to produce a ClauseParameters structure, representing the bind arguments + for a single statement execution, or one element of an executemany execution. """ if self.parameters is not None: @@ -225,7 +254,7 @@ class ANSICompiler(sql.Compiled): bindparams.update(params) d = sql.ClauseParameters(self.dialect, self.positiontup) for b in self.binds.values(): - name = self.bind_names.get(b, b.key) + name = self.bind_names[b] d.set_parameter(b, b.value, name) for key, value in bindparams.iteritems(): @@ -233,7 +262,7 @@ class ANSICompiler(sql.Compiled): b = self.binds[key] except KeyError: continue - name = self.bind_names.get(b, b.key) + name = self.bind_names[b] d.set_parameter(b, value, name) return d @@ -246,8 +275,8 @@ class ANSICompiler(sql.Compiled): return "" - def visit_grouping(self, grouping): - self.strings[grouping] = "(" + self.strings[grouping.elem] + ")" + def visit_grouping(self, grouping, **kwargs): + return "(" + self.process(grouping.elem) + ")" def visit_label(self, label): labelname = self._truncated_identifier("colident", label.name) @@ -256,9 +285,10 @@ class ANSICompiler(sql.Compiled): self.typemap.setdefault(labelname.lower(), label.obj.type) if isinstance(label.obj, sql._ColumnClause): self.column_labels[label.obj._label] = labelname - self.strings[label] = self.strings[label.obj] + " AS " + self.preparer.format_label(label, labelname) + self.column_labels[label.name] = labelname + return " ".join([self.process(label.obj), self.operator_string(sql.ColumnOperators.as_), self.preparer.format_label(label, labelname)]) - def visit_column(self, column): + def visit_column(self, column, **kwargs): # there is actually somewhat of a ruleset when you would *not* necessarily # want to truncate a column identifier, if its mapped to the name of a # physical column. but thats very hard to identify at this point, and @@ -269,107 +299,110 @@ class ANSICompiler(sql.Compiled): else: name = column.name + if len(self.select_stack): + # if we are within a visit to a Select, set up the "typemap" + # for this column which is used to translate result set values + self.typemap.setdefault(name.lower(), column.type) + self.column_labels.setdefault(column._label, name.lower()) + if column.table is None or not column.table.named_with_column(): - self.strings[column] = self.preparer.format_column(column, name=name) + return self.preparer.format_column(column, name=name) else: if column.table.oid_column is column: n = self.dialect.oid_column_name(column) if n is not None: - self.strings[column] = "%s.%s" % (self.preparer.format_table(column.table, use_schema=False), n) + return "%s.%s" % (self.preparer.format_table(column.table, use_schema=False, name=self._anonymize(column.table.name)), n) elif len(column.table.primary_key) != 0: pk = list(column.table.primary_key)[0] pkname = (pk.is_literal and name or self._truncated_identifier("colident", pk.name)) - self.strings[column] = self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname) + return self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname, table_name=self._anonymize(column.table.name)) else: - self.strings[column] = None + return None else: - self.strings[column] = self.preparer.format_column_with_table(column, column_name=name) + return self.preparer.format_column_with_table(column, column_name=name, table_name=self._anonymize(column.table.name)) - if len(self.select_stack): - # if we are within a visit to a Select, set up the "typemap" - # for this column which is used to translate result set values - self.typemap.setdefault(name.lower(), column.type) - self.column_labels.setdefault(column._label, name.lower()) - def visit_fromclause(self, fromclause): - self.froms[fromclause] = fromclause.name + def visit_fromclause(self, fromclause, **kwargs): + return fromclause.name - def visit_index(self, index): - self.strings[index] = index.name + def visit_index(self, index, **kwargs): + return index.name - def visit_typeclause(self, typeclause): - self.strings[typeclause] = typeclause.type.dialect_impl(self.dialect).get_col_spec() + def visit_typeclause(self, typeclause, **kwargs): + return typeclause.type.dialect_impl(self.dialect).get_col_spec() - def visit_textclause(self, textclause): - self.strings[textclause] = textclause.text - self.froms[textclause] = textclause.text + def visit_textclause(self, textclause, **kwargs): + for bind in textclause.bindparams.values(): + self.process(bind) if textclause.typemap is not None: self.typemap.update(textclause.typemap) + return textclause.text - def visit_null(self, null): - self.strings[null] = 'NULL' + def visit_null(self, null, **kwargs): + return 'NULL' - def visit_clauselist(self, list): - sep = list.operator - if sep == ',': - sep = ', ' - elif sep is None or sep == " ": + def visit_clauselist(self, clauselist, **kwargs): + sep = clauselist.operator + if sep is None: sep = " " + elif sep == sql.ColumnOperators.comma_op: + sep = ', ' else: - sep = " " + sep + " " - self.strings[list] = string.join([s for s in [self.get_str(c) for c in list.clauses] if s is not None], sep) + sep = " " + self.operator_string(clauselist.operator) + " " + return string.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None], sep) def apply_function_parens(self, func): return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0 - def visit_calculatedclause(self, clause): - self.strings[clause] = self.get_str(clause.clause_expr) + def visit_calculatedclause(self, clause, **kwargs): + return self.process(clause.clause_expr) - def visit_cast(self, cast): + def visit_cast(self, cast, **kwargs): if len(self.select_stack): # not sure if we want to set the typemap here... self.typemap.setdefault("CAST", cast.type) - self.strings[cast] = "CAST(%s AS %s)" % (self.strings[cast.clause],self.strings[cast.typeclause]) + return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause)) - def visit_function(self, func): + def visit_function(self, func, **kwargs): if len(self.select_stack): self.typemap.setdefault(func.name, func.type) if not self.apply_function_parens(func): - self.strings[func] = ".".join(func.packagenames + [func.name]) - self.froms[func] = self.strings[func] + return ".".join(func.packagenames + [func.name]) else: - self.strings[func] = ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.get_str(func.clause_expr) - self.froms[func] = self.strings[func] + return ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.process(func.clause_expr) - def visit_compound_select(self, cs): - text = string.join([self.get_str(c) for c in cs.selects], " " + cs.keyword + " ") - group_by = self.get_str(cs.group_by_clause) + def visit_compound_select(self, cs, asfrom=False, **kwargs): + text = string.join([self.process(c) for c in cs.selects], " " + cs.keyword + " ") + group_by = self.process(cs._group_by_clause) if group_by: text += " GROUP BY " + group_by text += self.order_by_clause(cs) - text += self.visit_select_postclauses(cs) - self.strings[cs] = text - self.froms[cs] = "(" + text + ")" + text += (cs._limit or cs._offset) and self.limit_clause(cs) or "" + + if asfrom: + return "(" + text + ")" + else: + return text - def visit_unary(self, unary): - s = self.get_str(unary.element) + def visit_unary(self, unary, **kwargs): + s = self.process(unary.element) if unary.operator: - s = unary.operator + " " + s + s = self.operator_string(unary.operator) + " " + s if unary.modifier: s = s + " " + unary.modifier - self.strings[unary] = s + return s - def visit_binary(self, binary): - result = self.get_str(binary.left) - if binary.operator is not None: - result += " " + self.binary_operator_string(binary) - result += " " + self.get_str(binary.right) - self.strings[binary] = result - - def binary_operator_string(self, binary): - return binary.operator + def visit_binary(self, binary, **kwargs): + op = self.operator_string(binary.operator) + if callable(op): + return op(self.process(binary.left), self.process(binary.right)) + else: + return self.process(binary.left) + " " + op + " " + self.process(binary.right) + + def operator_string(self, operator): + return self.operators.get(operator, str(operator)) - def visit_bindparam(self, bindparam): + def visit_bindparam(self, bindparam, **kwargs): # apply truncation to the ultimate generated name if bindparam.shortname != bindparam.key: @@ -378,7 +411,6 @@ class ANSICompiler(sql.Compiled): if bindparam.unique: count = 1 key = bindparam.key - # redefine the generated name of the bind param in the case # that we have multiple conflicting bind parameters. while self.binds.setdefault(key, bindparam) is not bindparam: @@ -386,164 +418,167 @@ class ANSICompiler(sql.Compiled): key = bindparam.key + tag count += 1 bindparam.key = key - self.strings[bindparam] = self.bindparam_string(self._truncate_bindparam(bindparam)) + return self.bindparam_string(self._truncate_bindparam(bindparam)) else: existing = self.binds.get(bindparam.key) if existing is not None and existing.unique: raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key) - self.strings[bindparam] = self.bindparam_string(self._truncate_bindparam(bindparam)) self.binds[bindparam.key] = bindparam + return self.bindparam_string(self._truncate_bindparam(bindparam)) def _truncate_bindparam(self, bindparam): if bindparam in self.bind_names: return self.bind_names[bindparam] bind_name = bindparam.key - if len(bind_name) > self.dialect.max_identifier_length(): - bind_name = self._truncated_identifier("bindparam", bind_name) - # add to bind_names for translation - self.bind_names[bindparam] = bind_name + bind_name = self._truncated_identifier("bindparam", bind_name) + # add to bind_names for translation + self.bind_names[bindparam] = bind_name + return bind_name def _truncated_identifier(self, ident_class, name): if (ident_class, name) in self.generated_ids: return self.generated_ids[(ident_class, name)] - if len(name) > self.dialect.max_identifier_length(): + + anonname = self._anonymize(name) + if len(anonname) > self.dialect.max_identifier_length(): counter = self.generated_ids.get(ident_class, 1) truncname = name[0:self.dialect.max_identifier_length() - 6] + "_" + hex(counter)[2:] self.generated_ids[ident_class] = counter + 1 else: - truncname = name + truncname = anonname self.generated_ids[(ident_class, name)] = truncname return truncname + + def _anonymize(self, name): + def anon(match): + (ident, derived) = match.group(1,2) + if ('anonymous', ident) in self.generated_ids: + return self.generated_ids[('anonymous', ident)] + else: + anonymous_counter = self.generated_ids.get('anonymous', 1) + newname = derived + "_" + str(anonymous_counter) + self.generated_ids['anonymous'] = anonymous_counter + 1 + self.generated_ids[('anonymous', ident)] = newname + return newname + return re.sub(r'{ANON (-?\d+) (.*)}', anon, name) def bindparam_string(self, name): return self.bindtemplate % name - def visit_alias(self, alias): - self.froms[alias] = self.get_from_text(alias.original) + " AS " + self.preparer.format_alias(alias) - self.strings[alias] = self.get_str(alias.original) + def visit_alias(self, alias, asfrom=False, **kwargs): + if asfrom: + return self.process(alias.original, asfrom=True, **kwargs) + " AS " + self.preparer.format_alias(alias, self._anonymize(alias.name)) + else: + return self.process(alias.original, **kwargs) - def visit_select(self, select): - # the actual list of columns to print in the SELECT column list. - inner_columns = util.OrderedDict() + def label_select_column(self, select, column): + """convert a column from a select's "columns" clause. + + given a select() and a column element from its inner_columns collection, return a + Label object if this column should be labeled in the columns clause. Otherwise, + return None and the column will be used as-is. + + The calling method will traverse the returned label to acquire its string + representation. + """ + + # SQLite doesnt like selecting from a subquery where the column + # names look like table.colname. so if column is in a "selected from" + # subquery, label it synoymously with its column name + if \ + self.correlate_state[select].get('is_selected_from', False) and \ + isinstance(column, sql._ColumnClause) and \ + not column.is_literal and \ + column.table is not None and \ + not isinstance(column.table, sql.Select): + return column.label(column.name) + else: + return None + + def visit_select(self, select, asfrom=False, **kwargs): + select._calculate_correlations(self.correlate_state) self.select_stack.append(select) - for c in select._raw_columns: - if hasattr(c, '_selectable'): - s = c._selectable() - else: - self.traverse(c) - inner_columns[self.get_str(c)] = c - continue - for co in s.columns: - if select.use_labels: - labelname = co._label - if labelname is not None: - l = co.label(labelname) - self.traverse(l) - inner_columns[labelname] = l - else: - self.traverse(co) - inner_columns[self.get_str(co)] = co - # TODO: figure this out, a ColumnClause with a select as a parent - # is different from any other kind of parent - elif select.is_selected_from and isinstance(co, sql._ColumnClause) and not co.is_literal and co.table is not None and not isinstance(co.table, sql.Select): - # SQLite doesnt like selecting from a subquery where the column - # names look like table.colname, so add a label synonomous with - # the column name - l = co.label(co.name) - self.traverse(l) - inner_columns[self.get_str(l.obj)] = l + + # the actual list of columns to print in the SELECT column list. + inner_columns = util.OrderedSet() + + froms = select._get_display_froms(self.correlate_state) + + for co in select.inner_columns: + if select.use_labels: + labelname = co._label + if labelname is not None: + l = co.label(labelname) + inner_columns.add(self.process(l)) else: self.traverse(co) - inner_columns[self.get_str(co)] = co + inner_columns.add(self.process(co)) + else: + l = self.label_select_column(select, co) + if l is not None: + inner_columns.add(self.process(l)) + else: + inner_columns.add(self.process(co)) + self.select_stack.pop(-1) - collist = string.join([self.get_str(v) for v in inner_columns.values()], ', ') + collist = string.join(inner_columns.difference(util.Set([None])), ', ') - text = "SELECT " - text += self.visit_select_precolumns(select) + text = " ".join(["SELECT"] + [self.process(x) for x in select._prefixes]) + " " + text += self.get_select_precolumns(select) text += collist - whereclause = select.whereclause - - froms = [] - for f in select.froms: - - if self.parameters is not None: - # TODO: whack this feature in 0.4 - # look at our own parameters, see if they - # are all present in the form of BindParamClauses. if - # not, then append to the above whereclause column conditions - # matching those keys - for c in f.columns: - if sql.is_column(c) and self.parameters.has_key(c.key) and not self.binds.has_key(c.key): - value = self.parameters[c.key] - else: - continue - clause = c==value - if whereclause is not None: - whereclause = self.traverse(sql.and_(clause, whereclause), stop_on=util.Set([whereclause])) - else: - whereclause = clause - self.traverse(whereclause) - - # special thingy used by oracle to redefine a join + whereclause = select._whereclause + + from_strings = [] + for f in froms: + from_strings.append(self.process(f, asfrom=True)) + w = self.get_whereclause(f) if w is not None: - # TODO: move this more into the oracle module if whereclause is not None: - whereclause = self.traverse(sql.and_(w, whereclause), stop_on=util.Set([whereclause, w])) + whereclause = sql.and_(w, whereclause) else: whereclause = w - t = self.get_from_text(f) - if t is not None: - froms.append(t) - if len(froms): text += " \nFROM " - text += string.join(froms, ', ') + text += string.join(from_strings, ', ') else: text += self.default_from() if whereclause is not None: - t = self.get_str(whereclause) + t = self.process(whereclause) if t: text += " \nWHERE " + t - group_by = self.get_str(select.group_by_clause) + group_by = self.process(select._group_by_clause) if group_by: text += " GROUP BY " + group_by - if select.having is not None: - t = self.get_str(select.having) + if select._having is not None: + t = self.process(select._having) if t: text += " \nHAVING " + t text += self.order_by_clause(select) - text += self.visit_select_postclauses(select) + text += (select._limit or select._offset) and self.limit_clause(select) or "" text += self.for_update_clause(select) - self.strings[select] = text - self.froms[select] = "(" + text + ")" + if asfrom: + return "(" + text + ")" + else: + return text - def visit_select_precolumns(self, select): + def get_select_precolumns(self, select): """Called when building a ``SELECT`` statement, position is just before column list.""" - - return select.distinct and "DISTINCT " or "" - - def visit_select_postclauses(self, select): - """Called when building a ``SELECT`` statement, position is after all other ``SELECT`` clauses. - - Most DB syntaxes put ``LIMIT``/``OFFSET`` here. - """ - - return (select.limit or select.offset) and self.limit_clause(select) or "" + return select._distinct and "DISTINCT " or "" def order_by_clause(self, select): - order_by = self.get_str(select.order_by_clause) + order_by = self.process(select._order_by_clause) if order_by: return " ORDER BY " + order_by else: @@ -557,175 +592,103 @@ class ANSICompiler(sql.Compiled): def limit_clause(self, select): text = "" - if select.limit is not None: - text += " \n LIMIT " + str(select.limit) - if select.offset is not None: - if select.limit is None: + if select._limit is not None: + text += " \n LIMIT " + str(select._limit) + if select._offset is not None: + if select._limit is None: text += " \n LIMIT -1" - text += " OFFSET " + str(select.offset) + text += " OFFSET " + str(select._offset) return text - def visit_table(self, table): - self.froms[table] = self.preparer.format_table(table) - self.strings[table] = "" - - def visit_join(self, join): - righttext = self.get_from_text(join.right) - if join.right._group_parenthesized(): - righttext = "(" + righttext + ")" - if join.isouter: - self.froms[join] = (self.get_from_text(join.left) + " LEFT OUTER JOIN " + righttext + - " ON " + self.get_str(join.onclause)) + def visit_table(self, table, asfrom=False, **kwargs): + if asfrom: + return self.preparer.format_table(table) else: - self.froms[join] = (self.get_from_text(join.left) + " JOIN " + righttext + - " ON " + self.get_str(join.onclause)) - self.strings[join] = self.froms[join] - - def visit_insert_column_default(self, column, default, parameters): - """Called when visiting an ``Insert`` statement. - - For each column in the table that contains a ``ColumnDefault`` - object, add a blank *placeholder* parameter so the ``Insert`` - gets compiled with this column's name in its column and - ``VALUES`` clauses. - """ - - parameters.setdefault(column.key, None) - - def visit_update_column_default(self, column, default, parameters): - """Called when visiting an ``Update`` statement. - - For each column in the table that contains a ``ColumnDefault`` - object as an onupdate, add a blank *placeholder* parameter so - the ``Update`` gets compiled with this column's name as one of - its ``SET`` clauses. - """ - - parameters.setdefault(column.key, None) - - def visit_insert_sequence(self, column, sequence, parameters): - """Called when visiting an ``Insert`` statement. - - This may be overridden compilers that support sequences to - place a blank *placeholder* parameter for each column in the - table that contains a Sequence object, so the Insert gets - compiled with this column's name in its column and ``VALUES`` - clauses. - """ - - pass - - def visit_insert_column(self, column, parameters): - """Called when visiting an ``Insert`` statement. - - This may be overridden by compilers who disallow NULL columns - being set in an ``Insert`` where there is a default value on - the column (i.e. postgres), to remove the column for which - there is a NULL insert from the parameter list. - """ + return "" - pass + def visit_join(self, join, asfrom=False, **kwargs): + return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \ + self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause)) + def uses_sequences_for_inserts(self): + return False + def visit_insert(self, insert_stmt): - # scan the table's columns for defaults that have to be pre-set for an INSERT - # add these columns to the parameter list via visit_insert_XXX methods - default_params = {} + + # search for columns who will be required to have an explicit bound value. + # for inserts, this includes Python-side defaults, columns with sequences for dialects + # that support sequences, and primary key columns for dialects that explicitly insert + # pre-generated primary key values + required_cols = util.Set() class DefaultVisitor(schema.SchemaVisitor): - def visit_column(s, c): - self.visit_insert_column(c, default_params) + def visit_column(s, cd): + if c.primary_key and self.uses_sequences_for_inserts(): + required_cols.add(c) def visit_column_default(s, cd): - self.visit_insert_column_default(c, cd, default_params) + required_cols.add(c) def visit_sequence(s, seq): - self.visit_insert_sequence(c, seq, default_params) + if self.uses_sequences_for_inserts(): + required_cols.add(c) vis = DefaultVisitor() for c in insert_stmt.table.c: if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)): vis.traverse(c) self.isinsert = True - colparams = self._get_colparams(insert_stmt, default_params) + colparams = self._get_colparams(insert_stmt, required_cols) - self.inline_params = util.Set() - def create_param(col, p): - if isinstance(p, sql._BindParamClause): - self.binds[p.key] = p - if p.shortname is not None: - self.binds[p.shortname] = p - return self.bindparam_string(self._truncate_bindparam(p)) - else: - self.inline_params.add(col) - self.traverse(p) - if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement): - return "(" + self.get_str(p) + ")" - else: - return self.get_str(p) - - text = ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" + - " VALUES (" + string.join([create_param(*c) for c in colparams], ', ') + ")") - - self.strings[insert_stmt] = text + return ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" + + " VALUES (" + string.join([c[1] for c in colparams], ', ') + ")") def visit_update(self, update_stmt): - # scan the table's columns for onupdates that have to be pre-set for an UPDATE - # add these columns to the parameter list via visit_update_XXX methods - default_params = {} + update_stmt._calculate_correlations(self.correlate_state) + + # search for columns who will be required to have an explicit bound value. + # for updates, this includes Python-side "onupdate" defaults. + required_cols = util.Set() class OnUpdateVisitor(schema.SchemaVisitor): def visit_column_onupdate(s, cd): - self.visit_update_column_default(c, cd, default_params) + required_cols.add(c) vis = OnUpdateVisitor() for c in update_stmt.table.c: if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)): vis.traverse(c) self.isupdate = True - colparams = self._get_colparams(update_stmt, default_params) - - self.inline_params = util.Set() - def create_param(col, p): - if isinstance(p, sql._BindParamClause): - self.binds[p.key] = p - self.binds[p.shortname] = p - return self.bindparam_string(self._truncate_bindparam(p)) - else: - self.traverse(p) - self.inline_params.add(col) - if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement): - return "(" + self.get_str(p) + ")" - else: - return self.get_str(p) - - text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), create_param(*c)) for c in colparams], ', ') - - if update_stmt.whereclause: - text += " WHERE " + self.get_str(update_stmt.whereclause) + colparams = self._get_colparams(update_stmt, required_cols) - self.strings[update_stmt] = text + text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams], ', ') + if update_stmt._whereclause: + text += " WHERE " + self.process(update_stmt._whereclause) - def _get_colparams(self, stmt, default_params): - """Organize ``UPDATE``/``INSERT`` ``SET``/``VALUES`` parameters into a list of tuples. - - Each tuple will contain the ``Column`` and a ``ClauseElement`` - representing the value to be set (usually a ``_BindParamClause``, - but could also be other SQL expressions.) - - The list of tuples will determine the columns that are - actually rendered into the ``SET``/``VALUES`` clause of the - rendered ``UPDATE``/``INSERT`` statement. It will also - determine how to generate the list/dictionary of bind - parameters at execution time (i.e. ``get_params()``). + return text - This list takes into account the `values` keyword specified - to the statement, the parameters sent to this Compiled - instance, and the default bind parameter values corresponding - to the dialect's behavior for otherwise unspecified primary - key columns. + def _get_colparams(self, stmt, required_cols): + """create a set of tuples representing column/string pairs for use + in an INSERT or UPDATE statement. + + This method may generate new bind params within this compiled + based on the given set of "required columns", which are required + to have a value set in the statement. """ + def create_bind_param(col, value): + bindparam = sql.bindparam(col.key, value, type_=col.type, unique=True) + self.binds[col.key] = bindparam + return self.bindparam_string(self._truncate_bindparam(bindparam)) + # no parameters in the statement, no parameters in the # compiled params - return binds for all columns if self.parameters is None and stmt.parameters is None: - return [(c, sql.bindparam(c.key, type=c.type)) for c in stmt.table.columns] + return [(c, create_bind_param(c, None)) for c in stmt.table.columns] + + def create_clause_param(col, value): + self.traverse(value) + self.inline_params.add(col) + return self.process(value) + + self.inline_params = util.Set() def to_col(key): if not isinstance(key, sql._ColumnClause): @@ -744,29 +707,43 @@ class ANSICompiler(sql.Compiled): for k, v in stmt.parameters.iteritems(): parameters.setdefault(to_col(k), v) - for k, v in default_params.iteritems(): - parameters.setdefault(to_col(k), v) + for col in required_cols: + parameters.setdefault(col, None) # create a list of column assignment clauses as tuples values = [] for c in stmt.table.columns: - if parameters.has_key(c): + if c in parameters: value = parameters[c] if sql._is_literal(value): - value = sql.bindparam(c.key, value, type=c.type, unique=True) + value = create_bind_param(c, value) + else: + value = create_clause_param(c, value) values.append((c, value)) + return values def visit_delete(self, delete_stmt): + delete_stmt._calculate_correlations(self.correlate_state) + text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table) - if delete_stmt.whereclause: - text += " WHERE " + self.get_str(delete_stmt.whereclause) + if delete_stmt._whereclause: + text += " WHERE " + self.process(delete_stmt._whereclause) - self.strings[delete_stmt] = text + return text + + def visit_savepoint(self, savepoint_stmt): + return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident) + def visit_rollback_to_savepoint(self, savepoint_stmt): + return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident) + + def visit_release_savepoint(self, savepoint_stmt): + return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident) + def __str__(self): - return self.get_str(self.statement) + return self.string class ANSISchemaBase(engine.SchemaIterator): def find_alterables(self, tables): @@ -795,7 +772,7 @@ class ANSISchemaGenerator(ANSISchemaBase): def visit_metadata(self, metadata): collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))] for table in collection: - table.accept_visitor(self) + self.traverse_single(table) if self.dialect.supports_alter(): for alterable in self.find_alterables(collection): self.add_foreignkey(alterable) @@ -803,9 +780,7 @@ class ANSISchemaGenerator(ANSISchemaBase): def visit_table(self, table): for column in table.columns: if column.default is not None: - column.default.accept_visitor(self) - #if column.onupdate is not None: - # column.onupdate.accept_visitor(visitor) + self.traverse_single(column.default) self.append("\nCREATE TABLE " + self.preparer.format_table(table) + " (") @@ -820,20 +795,20 @@ class ANSISchemaGenerator(ANSISchemaBase): if column.primary_key: first_pk = True for constraint in column.constraints: - constraint.accept_visitor(self) + self.traverse_single(constraint) # On some DB order is significant: visit PK first, then the # other constraints (engine.ReflectionTest.testbasic failed on FB2) if len(table.primary_key): - table.primary_key.accept_visitor(self) + self.traverse_single(table.primary_key) for constraint in [c for c in table.constraints if c is not table.primary_key]: - constraint.accept_visitor(self) + self.traverse_single(constraint) self.append("\n)%s\n\n" % self.post_create_table(table)) self.execute() if hasattr(table, 'indexes'): for index in table.indexes: - index.accept_visitor(self) + self.traverse_single(index) def post_create_table(self, table): return '' @@ -870,7 +845,7 @@ class ANSISchemaGenerator(ANSISchemaBase): if constraint.name is not None: self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint)) self.append("PRIMARY KEY ") - self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', '))) + self.append("(%s)" % ', '.join([self.preparer.format_column(c) for c in constraint])) def visit_foreign_key_constraint(self, constraint): if constraint.use_alter and self.dialect.supports_alter(): @@ -889,9 +864,9 @@ class ANSISchemaGenerator(ANSISchemaBase): self.append("CONSTRAINT %s " % preparer.format_constraint(constraint)) self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % ( - string.join([preparer.format_column(f.parent) for f in constraint.elements], ', '), + ', '.join([preparer.format_column(f.parent) for f in constraint.elements]), preparer.format_table(list(constraint.elements)[0].column.table), - string.join([preparer.format_column(f.column) for f in constraint.elements], ', ') + ', '.join([preparer.format_column(f.column) for f in constraint.elements]) )) if constraint.ondelete is not None: self.append(" ON DELETE %s" % constraint.ondelete) @@ -903,17 +878,17 @@ class ANSISchemaGenerator(ANSISchemaBase): if constraint.name is not None: self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint)) - self.append(" UNIQUE (%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', '))) + self.append(" UNIQUE (%s)" % (', '.join([self.preparer.format_column(c) for c in constraint]))) def visit_column(self, column): pass def visit_index(self, index): - preparer = self.preparer - self.append('CREATE ') + preparer = self.preparer + self.append("CREATE ") if index.unique: - self.append('UNIQUE ') - self.append('INDEX %s ON %s (%s)' \ + self.append("UNIQUE ") + self.append("INDEX %s ON %s (%s)" \ % (preparer.format_index(index), preparer.format_table(index.table), string.join([preparer.format_column(c) for c in index.columns], ', '))) @@ -933,7 +908,7 @@ class ANSISchemaDropper(ANSISchemaBase): for alterable in self.find_alterables(collection): self.drop_foreignkey(alterable) for table in collection: - table.accept_visitor(self) + self.traverse_single(table) def visit_index(self, index): self.append("\nDROP INDEX " + self.preparer.format_index(index)) @@ -948,7 +923,7 @@ class ANSISchemaDropper(ANSISchemaBase): def visit_table(self, table): for column in table.columns: if column.default is not None: - column.default.accept_visitor(self) + self.traverse_single(column.default) self.append("\nDROP TABLE " + self.preparer.format_table(table)) self.execute() @@ -1048,17 +1023,17 @@ class ANSIIdentifierPreparer(object): def should_quote(self, object): return object.quote or self._requires_quotes(object.name, object.case_sensitive) - def is_natural_case(self, object): - return object.quote or self._requires_quotes(object.name, object.case_sensitive) - def format_sequence(self, sequence): return self.__generic_obj_format(sequence, sequence.name) def format_label(self, label, name=None): return self.__generic_obj_format(label, name or label.name) - def format_alias(self, alias): - return self.__generic_obj_format(alias, alias.name) + def format_alias(self, alias, name=None): + return self.__generic_obj_format(alias, name or alias.name) + + def format_savepoint(self, savepoint): + return self.__generic_obj_format(savepoint, savepoint) def format_constraint(self, constraint): return self.__generic_obj_format(constraint, constraint.name) @@ -1076,25 +1051,25 @@ class ANSIIdentifierPreparer(object): result = self.__generic_obj_format(table, table.schema) + "." + result return result - def format_column(self, column, use_table=False, name=None): + def format_column(self, column, use_table=False, name=None, table_name=None): """Prepare a quoted column name.""" if name is None: name = column.name if not getattr(column, 'is_literal', False): if use_table: - return self.format_table(column.table, use_schema=False) + "." + self.__generic_obj_format(column, name) + return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.__generic_obj_format(column, name) else: return self.__generic_obj_format(column, name) else: # literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted if use_table: - return self.format_table(column.table, use_schema=False) + "." + name + return self.format_table(column.table, use_schema=False, name=table_name) + "." + name else: return name - def format_column_with_table(self, column, column_name=None): + def format_column_with_table(self, column, column_name=None, table_name=None): """Prepare a quoted column name with table name.""" - return self.format_column(column, use_table=True, name=column_name) + return self.format_column(column, use_table=True, name=column_name, table_name=table_name) dialect = ANSIDialect diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index a02781c84..07f07644f 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -5,15 +5,11 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php -import sys, StringIO, string, types +import warnings -from sqlalchemy import util +from sqlalchemy import util, sql, schema, ansisql, exceptions import sqlalchemy.engine.default as default -import sqlalchemy.sql as sql -import sqlalchemy.schema as schema -import sqlalchemy.ansisql as ansisql import sqlalchemy.types as sqltypes -import sqlalchemy.exceptions as exceptions _initialized_kb = False @@ -176,7 +172,7 @@ class FBDialect(ansisql.ANSIDialect): else: return False - def reflecttable(self, connection, table): + def reflecttable(self, connection, table, include_columns): #TODO: map these better column_func = { 14 : lambda r: sqltypes.String(r['FLEN']), # TEXT @@ -254,11 +250,20 @@ class FBDialect(ansisql.ANSIDialect): while row: name = row['FNAME'] - args = [lower_if_possible(name)] + python_name = lower_if_possible(name) + if include_columns and python_name not in include_columns: + continue + args = [python_name] kw = {} # get the data types and lengths - args.append(column_func[row['FTYPE']](row)) + coltype = column_func.get(row['FTYPE'], None) + if coltype is None: + warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (str(row['FTYPE']), name))) + coltype = sqltypes.NULLTYPE + else: + coltype = coltype(row) + args.append(coltype) # is it a primary key? kw['primary_key'] = name in pkfields @@ -301,39 +306,39 @@ class FBDialect(ansisql.ANSIDialect): class FBCompiler(ansisql.ANSICompiler): """Firebird specific idiosincrasies""" - def visit_alias(self, alias): + def visit_alias(self, alias, asfrom=False, **kwargs): # Override to not use the AS keyword which FB 1.5 does not like - self.froms[alias] = self.get_from_text(alias.original) + " " + self.preparer.format_alias(alias) - self.strings[alias] = self.get_str(alias.original) + if asfrom: + return self.process(alias.original, asfrom=True) + " " + self.preparer.format_alias(alias) + else: + return self.process(alias.original, asfrom=True) def visit_function(self, func): if len(func.clauses): - super(FBCompiler, self).visit_function(func) + return super(FBCompiler, self).visit_function(func) else: - self.strings[func] = func.name + return func.name - def visit_insert_column(self, column, parameters): - # all column primary key inserts must be explicitly present - if column.primary_key: - parameters[column.key] = None + def uses_sequences_for_inserts(self): + return True - def visit_select_precolumns(self, select): + def get_select_precolumns(self, select): """Called when building a ``SELECT`` statement, position is just before column list Firebird puts the limit and offset right after the ``SELECT``... """ result = "" - if select.limit: - result += " FIRST %d " % select.limit - if select.offset: - result +=" SKIP %d " % select.offset - if select.distinct: + if select._limit: + result += " FIRST %d " % select._limit + if select._offset: + result +=" SKIP %d " % select._offset + if select._distinct: result += " DISTINCT " return result def limit_clause(self, select): - """Already taken care of in the `visit_select_precolumns` method.""" + """Already taken care of in the `get_select_precolumns` method.""" return "" @@ -364,7 +369,7 @@ class FBSchemaDropper(ansisql.ANSISchemaDropper): class FBDefaultRunner(ansisql.ANSIDefaultRunner): def exec_default_sql(self, default): - c = sql.select([default.arg], from_obj=["rdb$database"]).compile(engine=self.connection) + c = sql.select([default.arg], from_obj=["rdb$database"]).compile(bind=self.connection) return self.connection.execute_compiled(c).scalar() def visit_sequence(self, seq): diff --git a/lib/sqlalchemy/databases/information_schema.py b/lib/sqlalchemy/databases/information_schema.py index 81c44dcaa..93f47de15 100644 --- a/lib/sqlalchemy/databases/information_schema.py +++ b/lib/sqlalchemy/databases/information_schema.py @@ -1,4 +1,6 @@ -from sqlalchemy import sql, schema, exceptions, select, MetaData, Table, Column, String, Integer +import sqlalchemy.sql as sql +import sqlalchemy.exceptions as exceptions +from sqlalchemy import select, MetaData, Table, Column, String, Integer from sqlalchemy.schema import PassiveDefault, ForeignKeyConstraint ischema = MetaData() @@ -96,8 +98,7 @@ class ISchema(object): return self.cache[name] -def reflecttable(connection, table, ischema_names): - +def reflecttable(connection, table, include_columns, ischema_names): key_constraints = pg_key_constraints if table.schema is not None: @@ -128,7 +129,9 @@ def reflecttable(connection, table, ischema_names): row[columns.c.numeric_scale], row[columns.c.column_default] ) - + if include_columns and name not in include_columns: + continue + args = [] for a in (charlen, numericprec, numericscale): if a is not None: @@ -139,7 +142,7 @@ def reflecttable(connection, table, ischema_names): colargs= [] if default is not None: colargs.append(PassiveDefault(sql.text(default))) - table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs)) + table.append_column(Column(name, coltype, nullable=nullable, *colargs)) if not found_table: raise exceptions.NoSuchTableError(table.name) diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py index 2fb508280..f3a6cf60e 100644 --- a/lib/sqlalchemy/databases/informix.py +++ b/lib/sqlalchemy/databases/informix.py @@ -5,20 +5,11 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +import datetime, warnings -import sys, StringIO, string , random -import datetime -from decimal import Decimal - -import sqlalchemy.util as util -import sqlalchemy.sql as sql -import sqlalchemy.engine as engine +from sqlalchemy import sql, schema, ansisql, exceptions, pool import sqlalchemy.engine.default as default -import sqlalchemy.schema as schema -import sqlalchemy.ansisql as ansisql import sqlalchemy.types as sqltypes -import sqlalchemy.exceptions as exceptions -import sqlalchemy.pool as pool # for offset @@ -128,7 +119,7 @@ class InfoBoolean(sqltypes.Boolean): elif value is None: return None else: - return value and True or False + return value and True or False colspecs = { @@ -262,7 +253,7 @@ class InfoDialect(ansisql.ANSIDialect): cursor = connection.execute("""select tabname from systables where tabname=?""", table_name.lower() ) return bool( cursor.fetchone() is not None ) - def reflecttable(self, connection, table): + def reflecttable(self, connection, table, include_columns): c = connection.execute ("select distinct OWNER from systables where tabname=?", table.name.lower() ) rows = c.fetchall() if not rows : @@ -289,6 +280,10 @@ class InfoDialect(ansisql.ANSIDialect): raise exceptions.NoSuchTableError(table.name) for name , colattr , collength , default , colno in rows: + name = name.lower() + if include_columns and name not in include_columns: + continue + # in 7.31, coltype = 0x000 # ^^-- column type # ^-- 1 not null , 0 null @@ -306,14 +301,16 @@ class InfoDialect(ansisql.ANSIDialect): scale = 0 coltype = InfoNumeric(precision, scale) else: - coltype = ischema_names.get(coltype) + try: + coltype = ischema_names[coltype] + except KeyError: + warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (coltype, name))) + coltype = sqltypes.NULLTYPE colargs = [] if default is not None: colargs.append(schema.PassiveDefault(sql.text(default))) - name = name.lower() - table.append_column(schema.Column(name, coltype, nullable = (nullable == 0), *colargs)) # FK @@ -372,20 +369,20 @@ class InfoCompiler(ansisql.ANSICompiler): def default_from(self): return " from systables where tabname = 'systables' " - def visit_select_precolumns( self , select ): - s = select.distinct and "DISTINCT " or "" + def get_select_precolumns( self , select ): + s = select._distinct and "DISTINCT " or "" # only has limit - if select.limit: - off = select.offset or 0 - s += " FIRST %s " % ( select.limit + off ) + if select._limit: + off = select._offset or 0 + s += " FIRST %s " % ( select._limit + off ) else: s += "" return s def visit_select(self, select): - if select.offset: - self.offset = select.offset - self.limit = select.limit or 0 + if select._offset: + self.offset = select._offset + self.limit = select._limit or 0 # the column in order by clause must in select too def __label( c ): @@ -393,13 +390,14 @@ class InfoCompiler(ansisql.ANSICompiler): return c._label.lower() except: return '' - + + # TODO: dont modify the original select, generate a new one a = [ __label(c) for c in select._raw_columns ] for c in select.order_by_clause.clauses: if ( __label(c) not in a ) and getattr( c , 'name' , '' ) != 'oid': select.append_column( c ) - ansisql.ANSICompiler.visit_select(self, select) + return ansisql.ANSICompiler.visit_select(self, select) def limit_clause(self, select): return "" @@ -414,23 +412,20 @@ class InfoCompiler(ansisql.ANSICompiler): def visit_function( self , func ): if func.name.lower() == 'current_date': - self.strings[func] = "today" + return "today" elif func.name.lower() == 'current_time': - self.strings[func] = "CURRENT HOUR TO SECOND" + return "CURRENT HOUR TO SECOND" elif func.name.lower() in ( 'current_timestamp' , 'now' ): - self.strings[func] = "CURRENT YEAR TO SECOND" + return "CURRENT YEAR TO SECOND" else: - ansisql.ANSICompiler.visit_function( self , func ) + return ansisql.ANSICompiler.visit_function( self , func ) def visit_clauselist(self, list): try: li = [ c for c in list.clauses if c.name != 'oid' ] except: li = [ c for c in list.clauses ] - if list.parens: - self.strings[list] = "(" + string.join([s for s in [self.get_str(c) for c in li] if s is not None ], ', ') + ")" - else: - self.strings[list] = string.join([s for s in [self.get_str(c) for c in li] if s is not None], ', ') + return ', '.join([s for s in [self.process(c) for c in li] if s is not None]) class InfoSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, first_pk=False): diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index ba1c0fd9d..206291404 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -25,7 +25,7 @@ * Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on ``INSERT`` -* ``select.limit`` implemented as ``SELECT TOP n`` +* ``select._limit`` implemented as ``SELECT TOP n`` Known issues / TODO: @@ -39,16 +39,11 @@ Known issues / TODO: """ -import sys, StringIO, string, types, re, datetime, random +import datetime, random, warnings -import sqlalchemy.sql as sql -import sqlalchemy.engine as engine -import sqlalchemy.engine.default as default -import sqlalchemy.schema as schema -import sqlalchemy.ansisql as ansisql +from sqlalchemy import sql, schema, ansisql, exceptions import sqlalchemy.types as sqltypes -import sqlalchemy.exceptions as exceptions - +from sqlalchemy.engine import default class MSNumeric(sqltypes.Numeric): def convert_result_value(self, value, dialect): @@ -500,7 +495,7 @@ class MSSQLDialect(ansisql.ANSIDialect): row = c.fetchone() return row is not None - def reflecttable(self, connection, table): + def reflecttable(self, connection, table, include_columns): import sqlalchemy.databases.information_schema as ischema # Get base columns @@ -532,16 +527,22 @@ class MSSQLDialect(ansisql.ANSIDialect): row[columns.c.numeric_scale], row[columns.c.column_default] ) + if include_columns and name not in include_columns: + continue args = [] for a in (charlen, numericprec, numericscale): if a is not None: args.append(a) - coltype = self.ischema_names[type] + coltype = self.ischema_names.get(type, None) if coltype == MSString and charlen == -1: coltype = MSText() else: - if coltype == MSNVarchar and charlen == -1: + if coltype is None: + warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (type, name))) + coltype = sqltypes.NULLTYPE + + elif coltype == MSNVarchar and charlen == -1: charlen = None coltype = coltype(*args) colargs= [] @@ -812,12 +813,12 @@ class MSSQLCompiler(ansisql.ANSICompiler): super(MSSQLCompiler, self).__init__(dialect, statement, parameters, **kwargs) self.tablealiases = {} - def visit_select_precolumns(self, select): + def get_select_precolumns(self, select): """ MS-SQL puts TOP, it's version of LIMIT here """ - s = select.distinct and "DISTINCT " or "" - if select.limit: - s += "TOP %s " % (select.limit,) - if select.offset: + s = select._distinct and "DISTINCT " or "" + if select._limit: + s += "TOP %s " % (select._limit,) + if select._offset: raise exceptions.InvalidRequestError('MSSQL does not support LIMIT with an offset') return s @@ -825,49 +826,50 @@ class MSSQLCompiler(ansisql.ANSICompiler): # Limit in mssql is after the select keyword return "" - def visit_table(self, table): + def _schema_aliased_table(self, table): + if getattr(table, 'schema', None) is not None: + if not self.tablealiases.has_key(table): + self.tablealiases[table] = table.alias() + return self.tablealiases[table] + else: + return None + + def visit_table(self, table, mssql_aliased=False, **kwargs): + if mssql_aliased: + return super(MSSQLCompiler, self).visit_table(table, **kwargs) + # alias schema-qualified tables - if getattr(table, 'schema', None) is not None and not self.tablealiases.has_key(table): - alias = table.alias() - self.tablealiases[table] = alias - self.traverse(alias) - self.froms[('alias', table)] = self.froms[table] - for c in alias.c: - self.traverse(c) - self.traverse(alias.oid_column) - self.tablealiases[alias] = self.froms[table] - self.froms[table] = self.froms[alias] + alias = self._schema_aliased_table(table) + if alias is not None: + return self.process(alias, mssql_aliased=True, **kwargs) else: - super(MSSQLCompiler, self).visit_table(table) + return super(MSSQLCompiler, self).visit_table(table, **kwargs) - def visit_alias(self, alias): + def visit_alias(self, alias, **kwargs): # translate for schema-qualified table aliases - if self.froms.has_key(('alias', alias.original)): - self.froms[alias] = self.froms[('alias', alias.original)] + " AS " + alias.name - self.strings[alias] = "" - else: - super(MSSQLCompiler, self).visit_alias(alias) + self.tablealiases[alias.original] = alias + return super(MSSQLCompiler, self).visit_alias(alias, **kwargs) def visit_column(self, column): - # translate for schema-qualified table aliases - super(MSSQLCompiler, self).visit_column(column) - if column.table is not None and self.tablealiases.has_key(column.table): - self.strings[column] = \ - self.strings[self.tablealiases[column.table].corresponding_column(column)] + if column.table is not None: + # translate for schema-qualified table aliases + t = self._schema_aliased_table(column.table) + if t is not None: + return self.process(t.corresponding_column(column)) + return super(MSSQLCompiler, self).visit_column(column) def visit_binary(self, binary): """Move bind parameters to the right-hand side of an operator, where possible.""" - if isinstance(binary.left, sql._BindParamClause) and binary.operator == '=': - binary.left, binary.right = binary.right, binary.left - super(MSSQLCompiler, self).visit_binary(binary) - - def visit_select(self, select): - # label function calls, so they return a name in cursor.description - for i,c in enumerate(select._raw_columns): - if isinstance(c, sql._Function): - select._raw_columns[i] = c.label(c.name + "_" + hex(random.randint(0, 65535))[2:]) + if isinstance(binary.left, sql._BindParamClause) and binary.operator == operator.eq: + return self.process(sql._BinaryExpression(binary.right, binary.left, binary.operator)) + else: + return super(MSSQLCompiler, self).visit_binary(binary) - super(MSSQLCompiler, self).visit_select(select) + def label_select_column(self, select, column): + if isinstance(column, sql._Function): + return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:]) + else: + return super(MSSQLCompiler, self).label_select_column(select, column) function_rewrites = {'current_date': 'getdate', 'length': 'len', @@ -881,10 +883,10 @@ class MSSQLCompiler(ansisql.ANSICompiler): return '' def order_by_clause(self, select): - order_by = self.get_str(select.order_by_clause) + order_by = self.process(select._order_by_clause) # MSSQL only allows ORDER BY in subqueries if there is a LIMIT - if order_by and (not select.is_subquery or select.limit): + if order_by and (not self.is_subquery(select) or select._limit): return " ORDER BY " + order_by else: return "" @@ -916,10 +918,12 @@ class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator): class MSSQLSchemaDropper(ansisql.ANSISchemaDropper): def visit_index(self, index): self.append("\nDROP INDEX %s.%s" % ( - self.preparer.quote_identifier(index.table.name), - self.preparer.quote_identifier(index.name))) + self.preparer.quote_identifier(index.table.name), + self.preparer.quote_identifier(index.name) + )) self.execute() + class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner): # TODO: does ms-sql have standalone sequences ? pass @@ -940,4 +944,3 @@ dialect = MSSQLDialect - diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index bac0e5e12..26800e32b 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -4,7 +4,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import re, datetime, inspect, warnings, weakref +import re, datetime, inspect, warnings, weakref, operator from sqlalchemy import sql, schema, ansisql from sqlalchemy.engine import default @@ -12,13 +12,13 @@ import sqlalchemy.types as sqltypes import sqlalchemy.exceptions as exceptions import sqlalchemy.util as util from array import array as _array +from decimal import Decimal try: from threading import Lock except ImportError: from dummy_threading import Lock - RESERVED_WORDS = util.Set( ['accessible', 'add', 'all', 'alter', 'analyze','and', 'as', 'asc', 'asensitive', 'before', 'between', 'bigint', 'binary', 'blob', 'both', @@ -60,7 +60,6 @@ RESERVED_WORDS = util.Set( 'accessible', 'linear', 'master_ssl_verify_server_cert', 'range', 'read_only', 'read_write', # 5.1 ]) - _per_connection_mutex = Lock() class _NumericType(object): @@ -137,7 +136,7 @@ class _StringType(object): class MSNumeric(sqltypes.Numeric, _NumericType): """MySQL NUMERIC type""" - def __init__(self, precision = 10, length = 2, **kw): + def __init__(self, precision = 10, length = 2, asdecimal=True, **kw): """Construct a NUMERIC. precision @@ -157,18 +156,27 @@ class MSNumeric(sqltypes.Numeric, _NumericType): """ _NumericType.__init__(self, **kw) - sqltypes.Numeric.__init__(self, precision, length) - + sqltypes.Numeric.__init__(self, precision, length, asdecimal=asdecimal) + def get_col_spec(self): if self.precision is None: return self._extend("NUMERIC") else: return self._extend("NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}) + def convert_bind_param(self, value, dialect): + return value + + def convert_result_value(self, value, dialect): + if not self.asdecimal and isinstance(value, Decimal): + return float(value) + else: + return value + class MSDecimal(MSNumeric): """MySQL DECIMAL type""" - def __init__(self, precision=10, length=2, **kw): + def __init__(self, precision=10, length=2, asdecimal=True, **kw): """Construct a DECIMAL. precision @@ -187,7 +195,7 @@ class MSDecimal(MSNumeric): underlying database API, which continue to be numeric. """ - super(MSDecimal, self).__init__(precision, length, **kw) + super(MSDecimal, self).__init__(precision, length, asdecimal=asdecimal, **kw) def get_col_spec(self): if self.precision is None: @@ -200,7 +208,7 @@ class MSDecimal(MSNumeric): class MSDouble(MSNumeric): """MySQL DOUBLE type""" - def __init__(self, precision=10, length=2, **kw): + def __init__(self, precision=10, length=2, asdecimal=True, **kw): """Construct a DOUBLE. precision @@ -222,7 +230,7 @@ class MSDouble(MSNumeric): if ((precision is None and length is not None) or (precision is not None and length is None)): raise exceptions.ArgumentError("You must specify both precision and length or omit both altogether.") - super(MSDouble, self).__init__(precision, length, **kw) + super(MSDouble, self).__init__(precision, length, asdecimal=asdecimal, **kw) def get_col_spec(self): if self.precision is not None and self.length is not None: @@ -235,7 +243,7 @@ class MSDouble(MSNumeric): class MSFloat(sqltypes.Float, _NumericType): """MySQL FLOAT type""" - def __init__(self, precision=10, length=None, **kw): + def __init__(self, precision=10, length=None, asdecimal=False, **kw): """Construct a FLOAT. precision @@ -257,7 +265,7 @@ class MSFloat(sqltypes.Float, _NumericType): if length is not None: self.length=length _NumericType.__init__(self, **kw) - sqltypes.Float.__init__(self, precision) + sqltypes.Float.__init__(self, precision, asdecimal=asdecimal) def get_col_spec(self): if hasattr(self, 'length') and self.length is not None: @@ -267,6 +275,10 @@ class MSFloat(sqltypes.Float, _NumericType): else: return self._extend("FLOAT") + def convert_bind_param(self, value, dialect): + return value + + class MSInteger(sqltypes.Integer, _NumericType): """MySQL INTEGER type""" @@ -955,7 +967,10 @@ class MySQLExecutionContext(default.DefaultExecutionContext): if self.compiled.isinsert: if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:] - + + def is_select(self): + return re.match(r'SELECT|SHOW|DESCRIBE|XA RECOVER', self.statement.lstrip(), re.I) is not None + class MySQLDialect(ansisql.ANSIDialect): def __init__(self, **kwargs): ansisql.ANSIDialect.__init__(self, default_paramstyle='format', **kwargs) @@ -1044,6 +1059,27 @@ class MySQLDialect(ansisql.ANSIDialect): except: pass + def do_begin_twophase(self, connection, xid): + connection.execute(sql.text("XA BEGIN :xid", bindparams=[sql.bindparam('xid',xid)])) + + def do_prepare_twophase(self, connection, xid): + connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)])) + connection.execute(sql.text("XA PREPARE :xid", bindparams=[sql.bindparam('xid',xid)])) + + def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): + if not is_prepared: + connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)])) + connection.execute(sql.text("XA ROLLBACK :xid", bindparams=[sql.bindparam('xid',xid)])) + + def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): + if not is_prepared: + self.do_prepare_twophase(connection, xid) + connection.execute(sql.text("XA COMMIT :xid", bindparams=[sql.bindparam('xid',xid)])) + + def do_recover_twophase(self, connection): + resultset = connection.execute(sql.text("XA RECOVER")) + return [row['data'][0:row['gtrid_length']] for row in resultset] + def is_disconnect(self, e): return isinstance(e, self.dbapi.OperationalError) and e.args[0] in (2006, 2013, 2014, 2045, 2055) @@ -1088,7 +1124,7 @@ class MySQLDialect(ansisql.ANSIDialect): version.append(n) return tuple(version) - def reflecttable(self, connection, table): + def reflecttable(self, connection, table, include_columns): """Load column definitions from the server.""" decode_from = self._detect_charset(connection) @@ -1111,6 +1147,9 @@ class MySQLDialect(ansisql.ANSIDialect): # leave column names as unicode name = name.decode(decode_from) + + if include_columns and name not in include_columns: + continue match = re.match(r'(\w+)(\(.*?\))?\s*(\w+)?\s*(\w+)?', type) col_type = match.group(1) @@ -1118,7 +1157,11 @@ class MySQLDialect(ansisql.ANSIDialect): extra_1 = match.group(3) extra_2 = match.group(4) - coltype = ischema_names.get(col_type, MSString) + try: + coltype = ischema_names[col_type] + except KeyError: + warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (col_type, name))) + coltype = sqltypes.NULLTYPE kw = {} if extra_1 is not None: @@ -1156,7 +1199,6 @@ class MySQLDialect(ansisql.ANSIDialect): if not row: raise exceptions.NoSuchTableError(table.fullname) desc = row[1].strip() - row.close() tabletype = '' lastparen = re.search(r'\)[^\)]*\Z', desc) @@ -1223,7 +1265,6 @@ class MySQLDialect(ansisql.ANSIDialect): cs = True else: cs = row[1] in ('0', 'OFF' 'off') - row.close() cache['lower_case_table_names'] = cs self.per_connection[raw_connection] = cache return cache.get('lower_case_table_names') @@ -1266,14 +1307,21 @@ class _MySQLPythonRowProxy(object): class MySQLCompiler(ansisql.ANSICompiler): - def visit_cast(self, cast): - + operators = ansisql.ANSICompiler.operators.copy() + operators.update( + { + sql.ColumnOperators.concat_op : lambda x, y:"concat(%s, %s)" % (x, y), + operator.mod : '%%' + } + ) + + def visit_cast(self, cast, **kwargs): if isinstance(cast.type, (sqltypes.Date, sqltypes.Time, sqltypes.DateTime)): - return super(MySQLCompiler, self).visit_cast(cast) + return super(MySQLCompiler, self).visit_cast(cast, **kwargs) else: # so just skip the CAST altogether for now. # TODO: put whatever MySQL does for CAST here. - self.strings[cast] = self.strings[cast.clause] + return self.process(cast.clause) def for_update_clause(self, select): if select.for_update == 'read': @@ -1283,20 +1331,15 @@ class MySQLCompiler(ansisql.ANSICompiler): def limit_clause(self, select): text = "" - if select.limit is not None: - text += " \n LIMIT " + str(select.limit) - if select.offset is not None: - if select.limit is None: - # striaght from the MySQL docs, I kid you not + if select._limit is not None: + text += " \n LIMIT " + str(select._limit) + if select._offset is not None: + if select._limit is None: + # straight from the MySQL docs, I kid you not text += " \n LIMIT 18446744073709551615" - text += " OFFSET " + str(select.offset) + text += " OFFSET " + str(select._offset) return text - def binary_operator_string(self, binary): - if binary.operator == '%': - return '%%' - else: - return ansisql.ANSICompiler.binary_operator_string(self, binary) class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, first_pk=False): diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 9d7d6a112..d3aa2e268 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -5,9 +5,9 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php -import sys, StringIO, string, re, warnings +import re, warnings, operator -from sqlalchemy import util, sql, engine, schema, ansisql, exceptions, logging +from sqlalchemy import util, sql, schema, ansisql, exceptions, logging from sqlalchemy.engine import default, base import sqlalchemy.types as sqltypes @@ -88,8 +88,11 @@ class OracleText(sqltypes.TEXT): def convert_result_value(self, value, dialect): if value is None: return None - else: + elif hasattr(value, 'read'): + # cx_oracle doesnt seem to be consistent with CLOB returning LOB or str return super(OracleText, self).convert_result_value(value.read(), dialect) + else: + return super(OracleText, self).convert_result_value(value, dialect) class OracleRaw(sqltypes.Binary): @@ -178,25 +181,31 @@ class OracleExecutionContext(default.DefaultExecutionContext): super(OracleExecutionContext, self).pre_exec() if self.dialect.auto_setinputsizes: self.set_input_sizes() + if self.compiled_parameters is not None and not isinstance(self.compiled_parameters, list): + for key in self.compiled_parameters: + (bindparam, name, value) = self.compiled_parameters.get_parameter(key) + if bindparam.isoutparam: + dbtype = bindparam.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) + if not hasattr(self, 'out_parameters'): + self.out_parameters = {} + self.out_parameters[name] = self.cursor.var(dbtype) + self.parameters[name] = self.out_parameters[name] def get_result_proxy(self): + if hasattr(self, 'out_parameters'): + if self.compiled_parameters is not None: + for k in self.out_parameters: + type = self.compiled_parameters.get_type(k) + self.out_parameters[k] = type.dialect_impl(self.dialect).convert_result_value(self.out_parameters[k].getvalue(), self.dialect) + else: + for k in self.out_parameters: + self.out_parameters[k] = self.out_parameters[k].getvalue() + if self.cursor.description is not None: - if self.dialect.auto_convert_lobs and self.typemap is None: - typemap = {} - binary = False - for column in self.cursor.description: - type_code = column[1] - if type_code in self.dialect.ORACLE_BINARY_TYPES: - binary = True - typemap[column[0].lower()] = OracleBinary() - self.typemap = typemap - if binary: + for column in self.cursor.description: + type_code = column[1] + if type_code in self.dialect.ORACLE_BINARY_TYPES: return base.BufferedColumnResultProxy(self) - else: - for column in self.cursor.description: - type_code = column[1] - if type_code in self.dialect.ORACLE_BINARY_TYPES: - return base.BufferedColumnResultProxy(self) return base.ResultProxy(self) @@ -208,11 +217,26 @@ class OracleDialect(ansisql.ANSIDialect): self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' ) self.auto_setinputsizes = auto_setinputsizes self.auto_convert_lobs = auto_convert_lobs + if self.dbapi is not None: self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB", "LONG_BINARY", "LONG_STRING"] if hasattr(self.dbapi, k)] else: self.ORACLE_BINARY_TYPES = [] + def dbapi_type_map(self): + if self.dbapi is None or not self.auto_convert_lobs: + return {} + else: + return { + self.dbapi.NUMBER: OracleInteger(), + self.dbapi.CLOB: OracleText(), + self.dbapi.BLOB: OracleBinary(), + self.dbapi.STRING: OracleString(), + self.dbapi.TIMESTAMP: OracleTimestamp(), + self.dbapi.BINARY: OracleRaw(), + datetime.datetime: OracleDate() + } + def dbapi(cls): import cx_Oracle return cx_Oracle @@ -251,7 +275,7 @@ class OracleDialect(ansisql.ANSIDialect): return 30 def oid_column_name(self, column): - if not isinstance(column.table, sql.TableClause) and not isinstance(column.table, sql.Select): + if not isinstance(column.table, (sql.TableClause, sql.Select)): return None else: return "rowid" @@ -341,7 +365,7 @@ class OracleDialect(ansisql.ANSIDialect): return name, owner, dblink raise - def reflecttable(self, connection, table): + def reflecttable(self, connection, table, include_columns): preparer = self.identifier_preparer if not preparer.should_quote(table): name = table.name.upper() @@ -363,6 +387,13 @@ class OracleDialect(ansisql.ANSIDialect): #print "ROW:" , row (colname, coltype, length, precision, scale, nullable, default) = (row[0], row[1], row[2], row[3], row[4], row[5]=='Y', row[6]) + # if name comes back as all upper, assume its case folded + if (colname.upper() == colname): + colname = colname.lower() + + if include_columns and colname not in include_columns: + continue + # INTEGER if the scale is 0 and precision is null # NUMBER if the scale and precision are both null # NUMBER(9,2) if the precision is 9 and the scale is 2 @@ -382,16 +413,13 @@ class OracleDialect(ansisql.ANSIDialect): try: coltype = ischema_names[coltype] except KeyError: - raise exceptions.AssertionError("Can't get coltype for type '%s' on colname '%s'" % (coltype, colname)) + warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (coltype, colname))) + coltype = sqltypes.NULLTYPE colargs = [] if default is not None: colargs.append(schema.PassiveDefault(sql.text(default))) - # if name comes back as all upper, assume its case folded - if (colname.upper() == colname): - colname = colname.lower() - table.append_column(schema.Column(colname, coltype, nullable=nullable, *colargs)) if not len(table.columns): @@ -458,16 +486,27 @@ class OracleDialect(ansisql.ANSIDialect): OracleDialect.logger = logging.class_logger(OracleDialect) +class _OuterJoinColumn(sql.ClauseElement): + __visit_name__ = 'outer_join_column' + def __init__(self, column): + self.column = column + class OracleCompiler(ansisql.ANSICompiler): """Oracle compiler modifies the lexical structure of Select statements to work under non-ANSI configured Oracle databases, if the use_ansi flag is False. """ + operators = ansisql.ANSICompiler.operators.copy() + operators.update( + { + operator.mod : lambda x, y:"mod(%s, %s)" % (x, y) + } + ) + def __init__(self, *args, **kwargs): super(OracleCompiler, self).__init__(*args, **kwargs) - # we have to modify SELECT objects a little bit, so store state here - self._select_state = {} + self.__wheres = {} def default_from(self): """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended. @@ -480,49 +519,46 @@ class OracleCompiler(ansisql.ANSICompiler): def apply_function_parens(self, func): return len(func.clauses) > 0 - def visit_join(self, join): + def visit_join(self, join, **kwargs): if self.dialect.use_ansi: - return ansisql.ANSICompiler.visit_join(self, join) - - self.froms[join] = self.get_from_text(join.left) + ", " + self.get_from_text(join.right) - where = self.wheres.get(join.left, None) + return ansisql.ANSICompiler.visit_join(self, join, **kwargs) + + (where, parentjoin) = self.__wheres.get(join, (None, None)) + + class VisitOn(sql.ClauseVisitor): + def visit_binary(s, binary): + if binary.operator == operator.eq: + if binary.left.table is join.right: + binary.left = _OuterJoinColumn(binary.left) + elif binary.right.table is join.right: + binary.right = _OuterJoinColumn(binary.right) + if where is not None: - self.wheres[join] = sql.and_(where, join.onclause) + self.__wheres[join.left] = self.__wheres[parentjoin] = (sql.and_(VisitOn().traverse(join.onclause, clone=True), where), parentjoin) else: - self.wheres[join] = join.onclause -# self.wheres[join] = sql.and_(self.wheres.get(join.left, None), join.onclause) - self.strings[join] = self.froms[join] - - if join.isouter: - # if outer join, push on the right side table as the current "outertable" - self._outertable = join.right - - # now re-visit the onclause, which will be used as a where clause - # (the first visit occured via the Join object itself right before it called visit_join()) - self.traverse(join.onclause) - - self._outertable = None - - self.wheres[join].accept_visitor(self) + self.__wheres[join.left] = self.__wheres[join] = (VisitOn().traverse(join.onclause, clone=True), join) - def visit_insert_sequence(self, column, sequence, parameters): - """This is the `sequence` equivalent to ``ANSICompiler``'s - `visit_insert_column_default` which ensures that the column is - present in the generated column list. - """ - - parameters.setdefault(column.key, None) + return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True) + + def get_whereclause(self, f): + if f in self.__wheres: + return self.__wheres[f][0] + else: + return None + + def visit_outer_join_column(self, vc): + return self.process(vc.column) + "(+)" + + def uses_sequences_for_inserts(self): + return True - def visit_alias(self, alias): + def visit_alias(self, alias, asfrom=False, **kwargs): """Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??""" - - self.froms[alias] = self.get_from_text(alias.original) + " " + alias.name - self.strings[alias] = self.get_str(alias.original) - - def visit_column(self, column): - ansisql.ANSICompiler.visit_column(self, column) - if not self.dialect.use_ansi and getattr(self, '_outertable', None) is not None and column.table is self._outertable: - self.strings[column] = self.strings[column] + "(+)" + + if asfrom: + return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + alias.name + else: + return self.process(alias.original, **kwargs) def visit_insert(self, insert): """``INSERT`` s are required to have the primary keys be explicitly present. @@ -539,76 +575,35 @@ class OracleCompiler(ansisql.ANSICompiler): def _TODO_visit_compound_select(self, select): """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle.""" + pass - if getattr(select, '_oracle_visit', False): - # cancel out the compiled order_by on the select - if hasattr(select, "order_by_clause"): - self.strings[select.order_by_clause] = "" - ansisql.ANSICompiler.visit_compound_select(self, select) - return - - if select.limit is not None or select.offset is not None: - select._oracle_visit = True - # to use ROW_NUMBER(), an ORDER BY is required. - orderby = self.strings[select.order_by_clause] - if not orderby: - orderby = select.oid_column - self.traverse(orderby) - orderby = self.strings[orderby] - class SelectVisitor(sql.NoColumnVisitor): - def visit_select(self, select): - select.append_column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")) - SelectVisitor().traverse(select) - limitselect = sql.select([c for c in select.c if c.key!='ora_rn']) - if select.offset is not None: - limitselect.append_whereclause("ora_rn>%d" % select.offset) - if select.limit is not None: - limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset)) - else: - limitselect.append_whereclause("ora_rn<=%d" % select.limit) - self.traverse(limitselect) - self.strings[select] = self.strings[limitselect] - self.froms[select] = self.froms[limitselect] - else: - ansisql.ANSICompiler.visit_compound_select(self, select) - - def visit_select(self, select): + def visit_select(self, select, **kwargs): """Look for ``LIMIT`` and OFFSET in a select statement, and if so tries to wrap it in a subquery with ``row_number()`` criterion. """ - # TODO: put a real copy-container on Select and copy, or somehow make this - # not modify the Select statement - if self._select_state.get((select, 'visit'), False): - # cancel out the compiled order_by on the select - if hasattr(select, "order_by_clause"): - self.strings[select.order_by_clause] = "" - ansisql.ANSICompiler.visit_select(self, select) - return - - if select.limit is not None or select.offset is not None: - self._select_state[(select, 'visit')] = True + if not getattr(select, '_oracle_visit', None) and (select._limit is not None or select._offset is not None): # to use ROW_NUMBER(), an ORDER BY is required. - orderby = self.strings[select.order_by_clause] + orderby = self.process(select._order_by_clause) if not orderby: orderby = select.oid_column self.traverse(orderby) - orderby = self.strings[orderby] - if not hasattr(select, '_oracle_visit'): - select.append_column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")) - select._oracle_visit = True + orderby = self.process(orderby) + + oldselect = select + select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")).order_by(None) + select._oracle_visit = True + limitselect = sql.select([c for c in select.c if c.key!='ora_rn']) - if select.offset is not None: - limitselect.append_whereclause("ora_rn>%d" % select.offset) - if select.limit is not None: - limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset)) + if select._offset is not None: + limitselect.append_whereclause("ora_rn>%d" % select._offset) + if select._limit is not None: + limitselect.append_whereclause("ora_rn<=%d" % (select._limit + select._offset)) else: - limitselect.append_whereclause("ora_rn<=%d" % select.limit) - self.traverse(limitselect) - self.strings[select] = self.strings[limitselect] - self.froms[select] = self.froms[limitselect] + limitselect.append_whereclause("ora_rn<=%d" % select._limit) + return self.process(limitselect) else: - ansisql.ANSICompiler.visit_select(self, select) + return ansisql.ANSICompiler.visit_select(self, select, **kwargs) def limit_clause(self, select): return "" @@ -619,12 +614,6 @@ class OracleCompiler(ansisql.ANSICompiler): else: return super(OracleCompiler, self).for_update_clause(select) - def visit_binary(self, binary): - if binary.operator == '%': - self.strings[binary] = ("MOD(%s,%s)"%(self.get_str(binary.left), self.get_str(binary.right))) - else: - return ansisql.ANSICompiler.visit_binary(self, binary) - class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, **kwargs): @@ -639,22 +628,22 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): return colspec def visit_sequence(self, sequence): - if not self.dialect.has_sequence(self.connection, sequence.name): + if not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name): self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() class OracleSchemaDropper(ansisql.ANSISchemaDropper): def visit_sequence(self, sequence): - if self.dialect.has_sequence(self.connection, sequence.name): + if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name): self.append("DROP SEQUENCE %s" % sequence.name) self.execute() class OracleDefaultRunner(ansisql.ANSIDefaultRunner): def exec_default_sql(self, default): - c = sql.select([default.arg], from_obj=["DUAL"]).compile(engine=self.connection) - return self.connection.execute_compiled(c).scalar() + c = sql.select([default.arg], from_obj=["DUAL"]).compile(bind=self.connection) + return self.connection.execute(c).scalar() def visit_sequence(self, seq): - return self.connection.execute_text("SELECT " + seq.name + ".nextval FROM DUAL").scalar() + return self.connection.execute("SELECT " + seq.name + ".nextval FROM DUAL").scalar() dialect = OracleDialect diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index d3726fc1f..b192c4778 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -4,12 +4,13 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import datetime, string, types, re, random, warnings +import re, random, warnings, operator -from sqlalchemy import util, sql, schema, ansisql, exceptions +from sqlalchemy import sql, schema, ansisql, exceptions from sqlalchemy.engine import base, default import sqlalchemy.types as sqltypes from sqlalchemy.databases import information_schema as ischema +from decimal import Decimal try: import mx.DateTime.DateTime as mxDateTime @@ -28,6 +29,15 @@ class PGNumeric(sqltypes.Numeric): else: return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length} + def convert_bind_param(self, value, dialect): + return value + + def convert_result_value(self, value, dialect): + if not self.asdecimal and isinstance(value, Decimal): + return float(value) + else: + return value + class PGFloat(sqltypes.Float): def get_col_spec(self): if not self.precision: @@ -35,6 +45,7 @@ class PGFloat(sqltypes.Float): else: return "FLOAT(%(precision)s)" % {'precision': self.precision} + class PGInteger(sqltypes.Integer): def get_col_spec(self): return "INTEGER" @@ -47,74 +58,15 @@ class PGBigInteger(PGInteger): def get_col_spec(self): return "BIGINT" -class PG2DateTime(sqltypes.DateTime): +class PGDateTime(sqltypes.DateTime): def get_col_spec(self): return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" -class PG1DateTime(sqltypes.DateTime): - def convert_bind_param(self, value, dialect): - if value is not None: - if isinstance(value, datetime.datetime): - seconds = float(str(value.second) + "." - + str(value.microsecond)) - mx_datetime = mxDateTime(value.year, value.month, value.day, - value.hour, value.minute, - seconds) - return dialect.dbapi.TimestampFromMx(mx_datetime) - return dialect.dbapi.TimestampFromMx(value) - else: - return None - - def convert_result_value(self, value, dialect): - if value is None: - return None - second_parts = str(value.second).split(".") - seconds = int(second_parts[0]) - microseconds = int(float(second_parts[1])) - return datetime.datetime(value.year, value.month, value.day, - value.hour, value.minute, seconds, - microseconds) - - def get_col_spec(self): - return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" - -class PG2Date(sqltypes.Date): - def get_col_spec(self): - return "DATE" - -class PG1Date(sqltypes.Date): - def convert_bind_param(self, value, dialect): - # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime - # this one doesnt seem to work with the "emulation" mode - if value is not None: - return dialect.dbapi.DateFromMx(value) - else: - return None - - def convert_result_value(self, value, dialect): - # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime - return value - +class PGDate(sqltypes.Date): def get_col_spec(self): return "DATE" -class PG2Time(sqltypes.Time): - def get_col_spec(self): - return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" - -class PG1Time(sqltypes.Time): - def convert_bind_param(self, value, dialect): - # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime - # this one doesnt seem to work with the "emulation" mode - if value is not None: - return psycopg.TimeFromMx(value) - else: - return None - - def convert_result_value(self, value, dialect): - # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime - return value - +class PGTime(sqltypes.Time): def get_col_spec(self): return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" @@ -142,28 +94,55 @@ class PGBoolean(sqltypes.Boolean): def get_col_spec(self): return "BOOLEAN" -pg2_colspecs = { +class PGArray(sqltypes.TypeEngine, sqltypes.Concatenable): + def __init__(self, item_type): + if isinstance(item_type, type): + item_type = item_type() + self.item_type = item_type + + def dialect_impl(self, dialect): + impl = self.__class__.__new__(self.__class__) + impl.__dict__.update(self.__dict__) + impl.item_type = self.item_type.dialect_impl(dialect) + return impl + def convert_bind_param(self, value, dialect): + if value is None: + return value + def convert_item(item): + if isinstance(item, (list,tuple)): + return [convert_item(child) for child in item] + else: + return self.item_type.convert_bind_param(item, dialect) + return [convert_item(item) for item in value] + def convert_result_value(self, value, dialect): + if value is None: + return value + def convert_item(item): + if isinstance(item, list): + return [convert_item(child) for child in item] + else: + return self.item_type.convert_result_value(item, dialect) + # Could specialcase when item_type.convert_result_value is the default identity func + return [convert_item(item) for item in value] + def get_col_spec(self): + return self.item_type.get_col_spec() + '[]' + +colspecs = { sqltypes.Integer : PGInteger, sqltypes.Smallinteger : PGSmallInteger, sqltypes.Numeric : PGNumeric, sqltypes.Float : PGFloat, - sqltypes.DateTime : PG2DateTime, - sqltypes.Date : PG2Date, - sqltypes.Time : PG2Time, + sqltypes.DateTime : PGDateTime, + sqltypes.Date : PGDate, + sqltypes.Time : PGTime, sqltypes.String : PGString, sqltypes.Binary : PGBinary, sqltypes.Boolean : PGBoolean, sqltypes.TEXT : PGText, sqltypes.CHAR: PGChar, } -pg1_colspecs = pg2_colspecs.copy() -pg1_colspecs.update({ - sqltypes.DateTime : PG1DateTime, - sqltypes.Date : PG1Date, - sqltypes.Time : PG1Time - }) - -pg2_ischema_names = { + +ischema_names = { 'integer' : PGInteger, 'bigint' : PGBigInteger, 'smallint' : PGSmallInteger, @@ -175,24 +154,17 @@ pg2_ischema_names = { 'real' : PGFloat, 'inet': PGInet, 'double precision' : PGFloat, - 'timestamp' : PG2DateTime, - 'timestamp with time zone' : PG2DateTime, - 'timestamp without time zone' : PG2DateTime, - 'time with time zone' : PG2Time, - 'time without time zone' : PG2Time, - 'date' : PG2Date, - 'time': PG2Time, + 'timestamp' : PGDateTime, + 'timestamp with time zone' : PGDateTime, + 'timestamp without time zone' : PGDateTime, + 'time with time zone' : PGTime, + 'time without time zone' : PGTime, + 'date' : PGDate, + 'time': PGTime, 'bytea' : PGBinary, 'boolean' : PGBoolean, 'interval':PGInterval, } -pg1_ischema_names = pg2_ischema_names.copy() -pg1_ischema_names.update({ - 'timestamp with time zone' : PG1DateTime, - 'timestamp without time zone' : PG1DateTime, - 'date' : PG1Date, - 'time' : PG1Time - }) def descriptor(): return {'name':'postgres', @@ -206,11 +178,11 @@ def descriptor(): class PGExecutionContext(default.DefaultExecutionContext): - def is_select(self): - return re.match(r'SELECT', self.statement.lstrip(), re.I) and not re.search(r'FOR UPDATE\s*$', self.statement, re.I) - + def _is_server_side(self): + return self.dialect.server_side_cursors and self.is_select() and not re.search(r'FOR UPDATE(?: NOWAIT)?\s*$', self.statement, re.I) + def create_cursor(self): - if self.dialect.server_side_cursors and self.is_select(): + if self._is_server_side(): # use server-side cursors: # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html ident = "c" + hex(random.randint(0, 65535))[2:] @@ -219,7 +191,7 @@ class PGExecutionContext(default.DefaultExecutionContext): return self.connection.connection.cursor() def get_result_proxy(self): - if self.dialect.server_side_cursors and self.is_select(): + if self._is_server_side(): return base.BufferedRowResultProxy(self) else: return base.ResultProxy(self) @@ -242,31 +214,18 @@ class PGDialect(ansisql.ANSIDialect): ansisql.ANSIDialect.__init__(self, default_paramstyle='pyformat', **kwargs) self.use_oids = use_oids self.server_side_cursors = server_side_cursors - if self.dbapi is None or not hasattr(self.dbapi, '__version__') or self.dbapi.__version__.startswith('2'): - self.version = 2 - else: - self.version = 1 self.use_information_schema = use_information_schema self.paramstyle = 'pyformat' def dbapi(cls): - try: - import psycopg2 as psycopg - except ImportError, e: - try: - import psycopg - except ImportError, e2: - raise e + import psycopg2 as psycopg return psycopg dbapi = classmethod(dbapi) def create_connect_args(self, url): opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port']) if opts.has_key('port'): - if self.version == 2: - opts['port'] = int(opts['port']) - else: - opts['port'] = str(opts['port']) + opts['port'] = int(opts['port']) opts.update(url.query) return ([], opts) @@ -278,10 +237,7 @@ class PGDialect(ansisql.ANSIDialect): return 63 def type_descriptor(self, typeobj): - if self.version == 2: - return sqltypes.adapt_type(typeobj, pg2_colspecs) - else: - return sqltypes.adapt_type(typeobj, pg1_colspecs) + return sqltypes.adapt_type(typeobj, colspecs) def compiler(self, statement, bindparams, **kwargs): return PGCompiler(self, statement, bindparams, **kwargs) @@ -292,8 +248,36 @@ class PGDialect(ansisql.ANSIDialect): def schemadropper(self, *args, **kwargs): return PGSchemaDropper(self, *args, **kwargs) - def defaultrunner(self, connection, **kwargs): - return PGDefaultRunner(connection, **kwargs) + def do_begin_twophase(self, connection, xid): + self.do_begin(connection.connection) + + def do_prepare_twophase(self, connection, xid): + connection.execute(sql.text("PREPARE TRANSACTION %(tid)s", bindparams=[sql.bindparam('tid', xid)])) + + def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): + if is_prepared: + if recover: + #FIXME: ugly hack to get out of transaction context when commiting recoverable transactions + # Must find out a way how to make the dbapi not open a transaction. + connection.execute(sql.text("ROLLBACK")) + connection.execute(sql.text("ROLLBACK PREPARED %(tid)s", bindparams=[sql.bindparam('tid', xid)])) + else: + self.do_rollback(connection.connection) + + def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): + if is_prepared: + if recover: + connection.execute(sql.text("ROLLBACK")) + connection.execute(sql.text("COMMIT PREPARED %(tid)s", bindparams=[sql.bindparam('tid', xid)])) + else: + self.do_commit(connection.connection) + + def do_recover_twophase(self, connection): + resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts")) + return [row[0] for row in resultset] + + def defaultrunner(self, context, **kwargs): + return PGDefaultRunner(context, **kwargs) def preparer(self): return PGIdentifierPreparer(self) @@ -351,14 +335,9 @@ class PGDialect(ansisql.ANSIDialect): else: return False - def reflecttable(self, connection, table): - if self.version == 2: - ischema_names = pg2_ischema_names - else: - ischema_names = pg1_ischema_names - + def reflecttable(self, connection, table, include_columns): if self.use_information_schema: - ischema.reflecttable(connection, table, ischema_names) + ischema.reflecttable(connection, table, include_columns, ischema_names) else: preparer = self.identifier_preparer if table.schema is not None: @@ -387,7 +366,7 @@ class PGDialect(ansisql.ANSIDialect): ORDER BY a.attnum """ % schema_where_clause - s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type=sqltypes.Unicode), sql.bindparam('schema', type=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode}) + s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode}) c = connection.execute(s, table_name=table.name, schema=table.schema) rows = c.fetchall() @@ -398,9 +377,13 @@ class PGDialect(ansisql.ANSIDialect): domains = self._load_domains(connection) for name, format_type, default, notnull, attnum, table_oid in rows: + if include_columns and name not in include_columns: + continue + ## strip (30) from character varying(30) - attype = re.search('([^\(]+)', format_type).group(1) + attype = re.search('([^\([]+)', format_type).group(1) nullable = not notnull + is_array = format_type.endswith('[]') try: charlen = re.search('\(([\d,]+)\)', format_type).group(1) @@ -453,6 +436,8 @@ class PGDialect(ansisql.ANSIDialect): if coltype: coltype = coltype(*args, **kwargs) + if is_array: + coltype = PGArray(coltype) else: warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (attype, name))) coltype = sqltypes.NULLTYPE @@ -517,7 +502,6 @@ class PGDialect(ansisql.ANSIDialect): table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname)) def _load_domains(self, connection): - ## Load data types for domains: SQL_DOMAINS = """ SELECT t.typname as "name", @@ -554,49 +538,46 @@ class PGDialect(ansisql.ANSIDialect): - class PGCompiler(ansisql.ANSICompiler): - def visit_insert_column(self, column, parameters): - # all column primary key inserts must be explicitly present - if column.primary_key: - parameters[column.key] = None + operators = ansisql.ANSICompiler.operators.copy() + operators.update( + { + operator.mod : '%%' + } + ) - def visit_insert_sequence(self, column, sequence, parameters): - """this is the 'sequence' equivalent to ANSICompiler's 'visit_insert_column_default' which ensures - that the column is present in the generated column list""" - parameters.setdefault(column.key, None) + def uses_sequences_for_inserts(self): + return True def limit_clause(self, select): text = "" - if select.limit is not None: - text += " \n LIMIT " + str(select.limit) - if select.offset is not None: - if select.limit is None: + if select._limit is not None: + text += " \n LIMIT " + str(select._limit) + if select._offset is not None: + if select._limit is None: text += " \n LIMIT ALL" - text += " OFFSET " + str(select.offset) + text += " OFFSET " + str(select._offset) return text - def visit_select_precolumns(self, select): - if select.distinct: - if type(select.distinct) == bool: + def get_select_precolumns(self, select): + if select._distinct: + if type(select._distinct) == bool: return "DISTINCT " - if type(select.distinct) == list: + if type(select._distinct) == list: dist_set = "DISTINCT ON (" - for col in select.distinct: + for col in select._distinct: dist_set += self.strings[col] + ", " dist_set = dist_set[:-2] + ") " return dist_set - return "DISTINCT ON (" + str(select.distinct) + ") " + return "DISTINCT ON (" + str(select._distinct) + ") " else: return "" - def binary_operator_string(self, binary): - if isinstance(binary.type, sqltypes.String) and binary.operator == '+': - return '||' - elif binary.operator == '%': - return '%%' + def for_update_clause(self, select): + if select.for_update == 'nowait': + return " FOR UPDATE NOWAIT" else: - return ansisql.ANSICompiler.binary_operator_string(self, binary) + return super(PGCompiler, self).for_update_clause(select) class PGSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, **kwargs): @@ -617,13 +598,13 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator): return colspec def visit_sequence(self, sequence): - if not sequence.optional and (not self.dialect.has_sequence(self.connection, sequence.name)): + if not sequence.optional and (not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name)): self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() class PGSchemaDropper(ansisql.ANSISchemaDropper): def visit_sequence(self, sequence): - if not sequence.optional and (self.dialect.has_sequence(self.connection, sequence.name)): + if not sequence.optional and (not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name)): self.append("DROP SEQUENCE %s" % sequence.name) self.execute() @@ -632,7 +613,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): if column.primary_key: # passive defaults on primary keys have to be overridden if isinstance(column.default, schema.PassiveDefault): - return self.connection.execute_text("select %s" % column.default.arg).scalar() + return self.connection.execute("select %s" % column.default.arg).scalar() elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): sch = column.table.schema # TODO: this has to build into the Sequence object so we can get the quoting @@ -641,7 +622,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name) else: exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name) - return self.connection.execute_text(exc).scalar() + return self.connection.execute(exc).scalar() return super(ansisql.ANSIDefaultRunner, self).get_column_default(column) diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 816b1b76a..725ea23e2 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -5,9 +5,9 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php -import sys, StringIO, string, types, re +import re -from sqlalchemy import sql, engine, schema, ansisql, exceptions, pool, PassiveDefault +from sqlalchemy import schema, ansisql, exceptions, pool, PassiveDefault import sqlalchemy.engine.default as default import sqlalchemy.types as sqltypes import datetime,time, warnings @@ -126,6 +126,7 @@ colspecs = { pragma_names = { 'INTEGER' : SLInteger, + 'INT' : SLInteger, 'SMALLINT' : SLSmallInteger, 'VARCHAR' : SLString, 'CHAR' : SLChar, @@ -150,8 +151,9 @@ class SQLiteExecutionContext(default.DefaultExecutionContext): if self.compiled.isinsert: if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:] - - super(SQLiteExecutionContext, self).post_exec() + + def is_select(self): + return re.match(r'SELECT|PRAGMA', self.statement.lstrip(), re.I) is not None class SQLiteDialect(ansisql.ANSIDialect): @@ -233,7 +235,7 @@ class SQLiteDialect(ansisql.ANSIDialect): return (row is not None) - def reflecttable(self, connection, table): + def reflecttable(self, connection, table, include_columns): c = connection.execute("PRAGMA table_info(%s)" % self.preparer().format_table(table), {}) found_table = False while True: @@ -244,6 +246,8 @@ class SQLiteDialect(ansisql.ANSIDialect): found_table = True (name, type, nullable, has_default, primary_key) = (row[1], row[2].upper(), not row[3], row[4] is not None, row[5]) name = re.sub(r'^\"|\"$', '', name) + if include_columns and name not in include_columns: + continue match = re.match(r'(\w+)(\(.*?\))?', type) if match: coltype = match.group(1) @@ -253,7 +257,12 @@ class SQLiteDialect(ansisql.ANSIDialect): args = '' #print "coltype: " + repr(coltype) + " args: " + repr(args) - coltype = pragma_names.get(coltype, SLString) + try: + coltype = pragma_names[coltype] + except KeyError: + warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (coltype, name))) + coltype = sqltypes.NULLTYPE + if args is not None: args = re.findall(r'(\d+)', args) #print "args! " +repr(args) @@ -318,21 +327,21 @@ class SQLiteDialect(ansisql.ANSIDialect): class SQLiteCompiler(ansisql.ANSICompiler): def visit_cast(self, cast): if self.dialect.supports_cast: - super(SQLiteCompiler, self).visit_cast(cast) + return super(SQLiteCompiler, self).visit_cast(cast) else: if len(self.select_stack): # not sure if we want to set the typemap here... self.typemap.setdefault("CAST", cast.type) - self.strings[cast] = self.strings[cast.clause] + return self.process(cast.clause) def limit_clause(self, select): text = "" - if select.limit is not None: - text += " \n LIMIT " + str(select.limit) - if select.offset is not None: - if select.limit is None: + if select._limit is not None: + text += " \n LIMIT " + str(select._limit) + if select._offset is not None: + if select._limit is None: text += " \n LIMIT -1" - text += " OFFSET " + str(select.offset) + text += " OFFSET " + str(select._offset) else: text += " OFFSET 0" return text @@ -341,12 +350,6 @@ class SQLiteCompiler(ansisql.ANSICompiler): # sqlite has no "FOR UPDATE" AFAICT return '' - def binary_operator_string(self, binary): - if isinstance(binary.type, sqltypes.String) and binary.operator == '+': - return '||' - else: - return ansisql.ANSICompiler.binary_operator_string(self, binary) - class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, **kwargs): diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index 5a4865de0..50d03ea91 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -48,7 +48,6 @@ The package is represented among several individual modules, including: from sqlalchemy import databases from sqlalchemy.engine.base import * from sqlalchemy.engine import strategies -import re def engine_descriptors(): """Provide a listing of all the database implementations supported. diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index d0ca36515..fc4433a47 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -9,33 +9,10 @@ higher-level statement-construction, connection-management, execution and result contexts.""" from sqlalchemy import exceptions, sql, schema, util, types, logging -import StringIO, sys, re +import StringIO, sys, re, random -class ConnectionProvider(object): - """Define an interface that returns raw Connection objects (or compatible).""" - - def get_connection(self): - """Return a Connection or compatible object from a DBAPI which also contains a close() method. - - It is not defined what context this connection belongs to. It - may be newly connected, returned from a pool, part of some - other kind of context such as thread-local, or can be a fixed - member of this object. - """ - - raise NotImplementedError() - - def dispose(self): - """Release all resources corresponding to this ConnectionProvider. - - This includes any underlying connection pools. - """ - - raise NotImplementedError() - - -class Dialect(sql.AbstractDialect): +class Dialect(object): """Define the behavior of a specific database/DBAPI. Any aspect of metadata definition, SQL query generation, execution, @@ -70,11 +47,14 @@ class Dialect(sql.AbstractDialect): raise NotImplementedError() - def convert_compiled_params(self, parameters): - """Build DBAPI execute arguments from a [sqlalchemy.sql#ClauseParameters] instance. - - Returns an array or dictionary suitable to pass directly to this ``Dialect`` instance's DBAPI's - execute method. + def dbapi_type_map(self): + """return a mapping of DBAPI type objects present in this Dialect's DBAPI + mapped to TypeEngine implementations used by the dialect. + + This is used to apply types to result sets based on the DBAPI types + present in cursor.description; it only takes effect for result sets against + textual statements where no explicit typemap was present. Constructed SQL statements + always have type information explicitly embedded. """ raise NotImplementedError() @@ -149,11 +129,11 @@ class Dialect(sql.AbstractDialect): raise NotImplementedError() - def defaultrunner(self, connection, **kwargs): + def defaultrunner(self, execution_context): """Return a [sqlalchemy.schema#SchemaVisitor] instance that can execute defaults. - connection - a [sqlalchemy.engine#Connection] to use for statement execution + execution_context + a [sqlalchemy.engine#ExecutionContext] to use for statement execution """ @@ -168,11 +148,12 @@ class Dialect(sql.AbstractDialect): raise NotImplementedError() - def reflecttable(self, connection, table): + def reflecttable(self, connection, table, include_columns=None): """Load table description from the database. Given a [sqlalchemy.engine#Connection] and a [sqlalchemy.schema#Table] object, reflect its - columns and properties from the database. + columns and properties from the database. If include_columns (a list or set) is specified, limit the autoload + to the given column names. """ raise NotImplementedError() @@ -222,6 +203,46 @@ class Dialect(sql.AbstractDialect): raise NotImplementedError() + def do_savepoint(self, connection, name): + """Create a savepoint with the given name on a SQLAlchemy connection.""" + + raise NotImplementedError() + + def do_rollback_to_savepoint(self, connection, name): + """Rollback a SQL Alchemy connection to the named savepoint.""" + + raise NotImplementedError() + + def do_release_savepoint(self, connection, name): + """Release the named savepoint on a SQL Alchemy connection.""" + + raise NotImplementedError() + + def do_begin_twophase(self, connection, xid): + """Begin a two phase transaction on the given connection.""" + + raise NotImplementedError() + + def do_prepare_twophase(self, connection, xid): + """Prepare a two phase transaction on the given connection.""" + + raise NotImplementedError() + + def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): + """Rollback a two phase transaction on the given connection.""" + + raise NotImplementedError() + + def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): + """Commit a two phase transaction on the given connection.""" + + raise NotImplementedError() + + def do_recover_twophase(self, connection): + """Recover list of uncommited prepared two phase transaction identifiers on the given connection.""" + + raise NotImplementedError() + def do_executemany(self, cursor, statement, parameters): """Provide an implementation of *cursor.executemany(statement, parameters)*.""" @@ -266,19 +287,18 @@ class ExecutionContext(object): compiled if passed to constructor, sql.Compiled object being executed - compiled_parameters - if passed to constructor, sql.ClauseParameters object - statement string version of the statement to be executed. Is either passed to the constructor, or must be created from the sql.Compiled object by the time pre_exec() has completed. parameters - "raw" parameters suitable for direct execution by the - dialect. Either passed to the constructor, or must be - created from the sql.ClauseParameters object by the time - pre_exec() has completed. + bind parameters passed to the execute() method. for + compiled statements, this is a dictionary or list + of dictionaries. for textual statements, it should + be in a format suitable for the dialect's paramstyle + (i.e. dict or list of dicts for non positional, + list or list of lists/tuples for positional). The Dialect should provide an ExecutionContext via the @@ -288,24 +308,28 @@ class ExecutionContext(object): """ def create_cursor(self): - """Return a new cursor generated this ExecutionContext's connection.""" + """Return a new cursor generated from this ExecutionContext's connection. + + Some dialects may wish to change the behavior of connection.cursor(), + such as postgres which may return a PG "server side" cursor. + """ raise NotImplementedError() - def pre_exec(self): + def pre_execution(self): """Called before an execution of a compiled statement. - If compiled and compiled_parameters were passed to this + If a compiled statement was passed to this ExecutionContext, the `statement` and `parameters` datamembers must be initialized after this statement is complete. """ raise NotImplementedError() - def post_exec(self): + def post_execution(self): """Called after the execution of a compiled statement. - If compiled was passed to this ExecutionContext, + If a compiled statement was passed to this ExecutionContext, the `last_insert_ids`, `last_inserted_params`, etc. datamembers should be available after this method completes. @@ -313,8 +337,11 @@ class ExecutionContext(object): raise NotImplementedError() - def get_result_proxy(self): - """return a ResultProxy corresponding to this ExecutionContext.""" + def result(self): + """return a result object corresponding to this ExecutionContext. + + Returns a ResultProxy.""" + raise NotImplementedError() def get_rowcount(self): @@ -361,8 +388,88 @@ class ExecutionContext(object): raise NotImplementedError() +class Compiled(object): + """Represent a compiled SQL expression. + + The ``__str__`` method of the ``Compiled`` object should produce + the actual text of the statement. ``Compiled`` objects are + specific to their underlying database dialect, and also may + or may not be specific to the columns referenced within a + particular set of bind parameters. In no case should the + ``Compiled`` object be dependent on the actual values of those + bind parameters, even though it may reference those values as + defaults. + """ + + def __init__(self, dialect, statement, parameters, bind=None): + """Construct a new ``Compiled`` object. + + statement + ``ClauseElement`` to be compiled. + + parameters + Optional dictionary indicating a set of bind parameters + specified with this ``Compiled`` object. These parameters + are the *default* values corresponding to the + ``ClauseElement``'s ``_BindParamClauses`` when the + ``Compiled`` is executed. In the case of an ``INSERT`` or + ``UPDATE`` statement, these parameters will also result in + the creation of new ``_BindParamClause`` objects for each + key and will also affect the generated column list in an + ``INSERT`` statement and the ``SET`` clauses of an + ``UPDATE`` statement. The keys of the parameter dictionary + can either be the string names of columns or + ``_ColumnClause`` objects. + + bind + Optional Engine or Connection to compile this statement against. + """ + self.dialect = dialect + self.statement = statement + self.parameters = parameters + self.bind = bind + self.can_execute = statement.supports_execution() + + def compile(self): + """Produce the internal string representation of this element.""" + + raise NotImplementedError() + + def __str__(self): + """Return the string text of the generated SQL statement.""" + + raise NotImplementedError() + + def get_params(self, **params): + """Deprecated. use construct_params(). (supports unicode names) + """ + + return self.construct_params(params) + + def construct_params(self, params): + """Return the bind params for this compiled object. + + params is a dict of string/object pairs whos + values will override bind values compiled in + to the statement. + """ + raise NotImplementedError() + + def execute(self, *multiparams, **params): + """Execute this compiled object.""" + + e = self.bind + if e is None: + raise exceptions.InvalidRequestError("This Compiled object is not bound to any Engine or Connection.") + return e._execute_compiled(self, multiparams, params) + + def scalar(self, *multiparams, **params): + """Execute this compiled object and return the result's scalar value.""" + + return self.execute(*multiparams, **params).scalar() + -class Connectable(sql.Executor): +class Connectable(object): """Interface for an object that can provide an Engine and a Connection object which correponds to that Engine.""" def contextual_connect(self): @@ -401,6 +508,7 @@ class Connection(Connectable): self.__connection = connection or engine.raw_connection() self.__transaction = None self.__close_with_result = close_with_result + self.__savepoint_seq = 0 def _get_connection(self): try: @@ -408,13 +516,18 @@ class Connection(Connectable): except AttributeError: raise exceptions.InvalidRequestError("This Connection is closed") + def _branch(self): + """return a new Connection which references this Connection's + engine and connection; but does not have close_with_result enabled.""" + + return Connection(self.__engine, self.__connection) + engine = property(lambda s:s.__engine, doc="The Engine with which this Connection is associated.") dialect = property(lambda s:s.__engine.dialect, doc="Dialect used by this Connection.") connection = property(_get_connection, doc="The underlying DBAPI connection managed by this Connection.") should_close_with_result = property(lambda s:s.__close_with_result, doc="Indicates if this Connection should be closed when a corresponding ResultProxy is closed; this is essentially an auto-release mode.") - - def _create_transaction(self, parent): - return Transaction(self, parent) + properties = property(lambda s: s._get_connection().properties, + doc="A set of per-DBAPI connection properties.") def connect(self): """connect() is implemented to return self so that an incoming Engine or Connection object can be treated similarly.""" @@ -448,12 +561,34 @@ class Connection(Connectable): self.__connection.detach() - def begin(self): + def begin(self, nested=False): if self.__transaction is None: - self.__transaction = self._create_transaction(None) - return self.__transaction + self.__transaction = RootTransaction(self) + elif nested: + self.__transaction = NestedTransaction(self, self.__transaction) else: - return self._create_transaction(self.__transaction) + return Transaction(self, self.__transaction) + return self.__transaction + + def begin_nested(self): + return self.begin(nested=True) + + def begin_twophase(self, xid=None): + if self.__transaction is not None: + raise exceptions.InvalidRequestError("Cannot start a two phase transaction when a transaction is already started.") + if xid is None: + xid = "_sa_%032x" % random.randint(0,2**128) + self.__transaction = TwoPhaseTransaction(self, xid) + return self.__transaction + + def recover_twophase(self): + return self.__engine.dialect.do_recover_twophase(self) + + def rollback_prepared(self, xid, recover=False): + self.__engine.dialect.do_rollback_twophase(self, xid, recover=recover) + + def commit_prepared(self, xid, recover=False): + self.__engine.dialect.do_commit_twophase(self, xid, recover=recover) def in_transaction(self): return self.__transaction is not None @@ -485,6 +620,45 @@ class Connection(Connectable): raise exceptions.SQLError(None, None, e) self.__transaction = None + def _savepoint_impl(self, name=None): + if name is None: + self.__savepoint_seq += 1 + name = '__sa_savepoint_%s' % self.__savepoint_seq + if self.__connection.is_valid: + self.__engine.dialect.do_savepoint(self, name) + return name + + def _rollback_to_savepoint_impl(self, name, context): + if self.__connection.is_valid: + self.__engine.dialect.do_rollback_to_savepoint(self, name) + self.__transaction = context + + def _release_savepoint_impl(self, name, context): + if self.__connection.is_valid: + self.__engine.dialect.do_release_savepoint(self, name) + self.__transaction = context + + def _begin_twophase_impl(self, xid): + if self.__connection.is_valid: + self.__engine.dialect.do_begin_twophase(self, xid) + + def _prepare_twophase_impl(self, xid): + if self.__connection.is_valid: + assert isinstance(self.__transaction, TwoPhaseTransaction) + self.__engine.dialect.do_prepare_twophase(self, xid) + + def _rollback_twophase_impl(self, xid, is_prepared): + if self.__connection.is_valid: + assert isinstance(self.__transaction, TwoPhaseTransaction) + self.__engine.dialect.do_rollback_twophase(self, xid, is_prepared) + self.__transaction = None + + def _commit_twophase_impl(self, xid, is_prepared): + if self.__connection.is_valid: + assert isinstance(self.__transaction, TwoPhaseTransaction) + self.__engine.dialect.do_commit_twophase(self, xid, is_prepared) + self.__transaction = None + def _autocommit(self, statement): """When no Transaction is present, this is called after executions to provide "autocommit" behavior.""" # TODO: have the dialect determine if autocommit can be set on the connection directly without this @@ -495,7 +669,7 @@ class Connection(Connectable): def _autorollback(self): if not self.in_transaction(): self._rollback_impl() - + def close(self): try: c = self.__connection @@ -514,74 +688,66 @@ class Connection(Connectable): def execute(self, object, *multiparams, **params): for c in type(object).__mro__: if c in Connection.executors: - return Connection.executors[c](self, object, *multiparams, **params) + return Connection.executors[c](self, object, multiparams, params) else: raise exceptions.InvalidRequestError("Unexecuteable object type: " + str(type(object))) - def execute_default(self, default, **kwargs): - return default.accept_visitor(self.__engine.dialect.defaultrunner(self)) + def _execute_default(self, default, multiparams=None, params=None): + return self.__engine.dialect.defaultrunner(self.__create_execution_context()).traverse_single(default) + + def _execute_text(self, statement, multiparams, params): + parameters = self.__distill_params(multiparams, params) + context = self.__create_execution_context(statement=statement, parameters=parameters) + self.__execute_raw(context) + return context.result() - def execute_text(self, statement, *multiparams, **params): - if len(multiparams) == 0: + def __distill_params(self, multiparams, params): + if multiparams is None or len(multiparams) == 0: parameters = params or None - elif len(multiparams) == 1 and (isinstance(multiparams[0], list) or isinstance(multiparams[0], tuple) or isinstance(multiparams[0], dict)): + elif len(multiparams) == 1 and isinstance(multiparams[0], (list, tuple, dict)): parameters = multiparams[0] else: parameters = list(multiparams) - context = self._create_execution_context(statement=statement, parameters=parameters) - self._execute_raw(context) - return context.get_result_proxy() - - def _params_to_listofdicts(self, *multiparams, **params): - if len(multiparams) == 0: - return [params] - elif len(multiparams) == 1: - if multiparams[0] == None: - return [{}] - elif isinstance (multiparams[0], list) or isinstance (multiparams[0], tuple): - return multiparams[0] - else: - return [multiparams[0]] - else: - return multiparams - - def execute_function(self, func, *multiparams, **params): - return self.execute_clauseelement(func.select(), *multiparams, **params) + return parameters + + def _execute_function(self, func, multiparams, params): + return self._execute_clauseelement(func.select(), multiparams, params) - def execute_clauseelement(self, elem, *multiparams, **params): - executemany = len(multiparams) > 0 + def _execute_clauseelement(self, elem, multiparams=None, params=None): + executemany = multiparams is not None and len(multiparams) > 0 if executemany: param = multiparams[0] else: param = params - return self.execute_compiled(elem.compile(dialect=self.dialect, parameters=param), *multiparams, **params) + return self._execute_compiled(elem.compile(dialect=self.dialect, parameters=param), multiparams, params) - def execute_compiled(self, compiled, *multiparams, **params): + def _execute_compiled(self, compiled, multiparams=None, params=None): """Execute a sql.Compiled object.""" if not compiled.can_execute: raise exceptions.ArgumentError("Not an executeable clause: %s" % (str(compiled))) - parameters = [compiled.construct_params(m) for m in self._params_to_listofdicts(*multiparams, **params)] - if len(parameters) == 1: - parameters = parameters[0] - context = self._create_execution_context(compiled=compiled, compiled_parameters=parameters) - context.pre_exec() - self._execute_raw(context) - context.post_exec() - return context.get_result_proxy() - - def _create_execution_context(self, **kwargs): + + params = self.__distill_params(multiparams, params) + context = self.__create_execution_context(compiled=compiled, parameters=params) + + context.pre_execution() + self.__execute_raw(context) + context.post_execution() + return context.result() + + def __create_execution_context(self, **kwargs): return self.__engine.dialect.create_execution_context(connection=self, **kwargs) - def _execute_raw(self, context): - self.__engine.logger.info(context.statement) - self.__engine.logger.info(repr(context.parameters)) - if context.parameters is not None and isinstance(context.parameters, list) and len(context.parameters) > 0 and (isinstance(context.parameters[0], list) or isinstance(context.parameters[0], tuple) or isinstance(context.parameters[0], dict)): - self._executemany(context) + def __execute_raw(self, context): + if logging.is_info_enabled(self.__engine.logger): + self.__engine.logger.info(context.statement) + self.__engine.logger.info(repr(context.parameters)) + if context.parameters is not None and isinstance(context.parameters, list) and len(context.parameters) > 0 and isinstance(context.parameters[0], (list, tuple, dict)): + self.__executemany(context) else: - self._execute(context) + self.__execute(context) self._autocommit(context.statement) - def _execute(self, context): + def __execute(self, context): if context.parameters is None: if context.dialect.positional: context.parameters = () @@ -592,19 +758,19 @@ class Connection(Connectable): except Exception, e: if self.dialect.is_disconnect(e): self.__connection.invalidate(e=e) - self.engine.connection_provider.dispose() + self.engine.dispose() self._autorollback() if self.__close_with_result: self.close() raise exceptions.SQLError(context.statement, context.parameters, e) - def _executemany(self, context): + def __executemany(self, context): try: context.dialect.do_executemany(context.cursor, context.statement, context.parameters, context=context) except Exception, e: if self.dialect.is_disconnect(e): self.__connection.invalidate(e=e) - self.engine.connection_provider.dispose() + self.engine.dispose() self._autorollback() if self.__close_with_result: self.close() @@ -612,11 +778,11 @@ class Connection(Connectable): # poor man's multimethod/generic function thingy executors = { - sql._Function : execute_function, - sql.ClauseElement : execute_clauseelement, - sql.ClauseVisitor : execute_compiled, - schema.SchemaItem:execute_default, - str.__mro__[-2] : execute_text + sql._Function : _execute_function, + sql.ClauseElement : _execute_clauseelement, + sql.ClauseVisitor : _execute_compiled, + schema.SchemaItem:_execute_default, + str.__mro__[-2] : _execute_text } def create(self, entity, **kwargs): @@ -629,10 +795,10 @@ class Connection(Connectable): return self.__engine.drop(entity, connection=self, **kwargs) - def reflecttable(self, table, **kwargs): + def reflecttable(self, table, include_columns=None): """Reflect the columns in the given string table name from the database.""" - return self.__engine.reflecttable(table, connection=self, **kwargs) + return self.__engine.reflecttable(table, self, include_columns) def default_schema_name(self): return self.__engine.dialect.get_default_schema_name(self) @@ -647,39 +813,90 @@ class Transaction(object): """ def __init__(self, connection, parent): - self.__connection = connection - self.__parent = parent or self - self.__is_active = True - if self.__parent is self: - self.__connection._begin_impl() + self._connection = connection + self._parent = parent or self + self._is_active = True - connection = property(lambda s:s.__connection, doc="The Connection object referenced by this Transaction") - is_active = property(lambda s:s.__is_active) + connection = property(lambda s:s._connection, doc="The Connection object referenced by this Transaction") + is_active = property(lambda s:s._is_active) def rollback(self): - if not self.__parent.__is_active: + if not self._parent._is_active: return - if self.__parent is self: - self.__connection._rollback_impl() - self.__is_active = False - else: - self.__parent.rollback() + self._is_active = False + self._do_rollback() + + def _do_rollback(self): + self._parent.rollback() def commit(self): - if not self.__parent.__is_active: + if not self._parent._is_active: raise exceptions.InvalidRequestError("This transaction is inactive") - if self.__parent is self: - self.__connection._commit_impl() - self.__is_active = False + self._is_active = False + self._do_commit() + + def _do_commit(self): + pass + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + if type is None and self._is_active: + self.commit() + else: + self.rollback() + +class RootTransaction(Transaction): + def __init__(self, connection): + super(RootTransaction, self).__init__(connection, None) + self._connection._begin_impl() + + def _do_rollback(self): + self._connection._rollback_impl() + + def _do_commit(self): + self._connection._commit_impl() + +class NestedTransaction(Transaction): + def __init__(self, connection, parent): + super(NestedTransaction, self).__init__(connection, parent) + self._savepoint = self._connection._savepoint_impl() + + def _do_rollback(self): + self._connection._rollback_to_savepoint_impl(self._savepoint, self._parent) + + def _do_commit(self): + self._connection._release_savepoint_impl(self._savepoint, self._parent) + +class TwoPhaseTransaction(Transaction): + def __init__(self, connection, xid): + super(TwoPhaseTransaction, self).__init__(connection, None) + self._is_prepared = False + self.xid = xid + self._connection._begin_twophase_impl(self.xid) + + def prepare(self): + if not self._parent._is_active: + raise exceptions.InvalidRequestError("This transaction is inactive") + self._connection._prepare_twophase_impl(self.xid) + self._is_prepared = True + + def _do_rollback(self): + self._connection._rollback_twophase_impl(self.xid, self._is_prepared) + + def commit(self): + self._connection._commit_twophase_impl(self.xid, self._is_prepared) class Engine(Connectable): """ - Connects a ConnectionProvider, a Dialect and a CompilerFactory together to + Connects a Pool, a Dialect and a CompilerFactory together to provide a default implementation of SchemaEngine. """ - def __init__(self, connection_provider, dialect, echo=None): - self.connection_provider = connection_provider + def __init__(self, pool, dialect, url, echo=None): + self.pool = pool + self.url = url self._dialect=dialect self.echo = echo self.logger = logging.instance_logger(self) @@ -688,10 +905,13 @@ class Engine(Connectable): engine = property(lambda s:s) dialect = property(lambda s:s._dialect, doc="the [sqlalchemy.engine#Dialect] in use by this engine.") echo = logging.echo_property() - url = property(lambda s:s.connection_provider.url, doc="The [sqlalchemy.engine.url#URL] object representing this ``Engine`` object's datasource.") + + def __repr__(self): + return 'Engine(%s)' % str(self.url) def dispose(self): - self.connection_provider.dispose() + self.pool.dispose() + self.pool = self.pool.recreate() def create(self, entity, connection=None, **kwargs): """Create a table or index within this engine's database connection given a schema.Table object.""" @@ -703,22 +923,22 @@ class Engine(Connectable): self._run_visitor(self.dialect.schemadropper, entity, connection=connection, **kwargs) - def execute_default(self, default, **kwargs): + def _execute_default(self, default): connection = self.contextual_connect() try: - return connection.execute_default(default, **kwargs) + return connection._execute_default(default) finally: connection.close() def _func(self): - return sql._FunctionGenerator(engine=self) + return sql._FunctionGenerator(bind=self) func = property(_func) def text(self, text, *args, **kwargs): """Return a sql.text() object for performing literal queries.""" - return sql.text(text, engine=self, *args, **kwargs) + return sql.text(text, bind=self, *args, **kwargs) def _run_visitor(self, visitorcallable, element, connection=None, **kwargs): if connection is None: @@ -726,7 +946,7 @@ class Engine(Connectable): else: conn = connection try: - element.accept_visitor(visitorcallable(conn, **kwargs)) + visitorcallable(conn, **kwargs).traverse(element) finally: if connection is None: conn.close() @@ -775,12 +995,12 @@ class Engine(Connectable): def scalar(self, statement, *multiparams, **params): return self.execute(statement, *multiparams, **params).scalar() - def execute_compiled(self, compiled, *multiparams, **params): + def _execute_compiled(self, compiled, multiparams, params): connection = self.contextual_connect(close_with_result=True) - return connection.execute_compiled(compiled, *multiparams, **params) + return connection._execute_compiled(compiled, multiparams, params) def compiler(self, statement, parameters, **kwargs): - return self.dialect.compiler(statement, parameters, engine=self, **kwargs) + return self.dialect.compiler(statement, parameters, bind=self, **kwargs) def connect(self, **kwargs): """Return a newly allocated Connection object.""" @@ -795,7 +1015,7 @@ class Engine(Connectable): return Connection(self, close_with_result=close_with_result, **kwargs) - def reflecttable(self, table, connection=None): + def reflecttable(self, table, connection=None, include_columns=None): """Given a Table object, reflects its columns and properties from the database.""" if connection is None: @@ -803,7 +1023,7 @@ class Engine(Connectable): else: conn = connection try: - self.dialect.reflecttable(conn, table) + self.dialect.reflecttable(conn, table, include_columns) finally: if connection is None: conn.close() @@ -814,7 +1034,7 @@ class Engine(Connectable): def raw_connection(self): """Return a DBAPI connection.""" - return self.connection_provider.get_connection() + return self.pool.connect() def log(self, msg): """Log a message using this SQLEngine's logger stream.""" @@ -858,28 +1078,42 @@ class ResultProxy(object): self.closed = False self.cursor = context.cursor self.__echo = logging.is_debug_enabled(context.engine.logger) - self._init_metadata() - - rowcount = property(lambda s:s.context.get_rowcount()) - connection = property(lambda s:s.context.connection) + if context.is_select(): + self._init_metadata() + self._rowcount = None + else: + self._rowcount = context.get_rowcount() + self.close() + + connection = property(lambda self:self.context.connection) + def _get_rowcount(self): + if self._rowcount is not None: + return self._rowcount + else: + return self.context.get_rowcount() + rowcount = property(_get_rowcount) lastrowid = property(lambda s:s.cursor.lastrowid) + out_parameters = property(lambda s:s.context.out_parameters) def _init_metadata(self): if hasattr(self, '_ResultProxy__props'): return - self.__key_cache = {} self.__props = {} + self._key_cache = self._create_key_cache() self.__keys = [] metadata = self.cursor.description if metadata is not None: + typemap = self.dialect.dbapi_type_map() + for i, item in enumerate(metadata): # sqlite possibly prepending table name to colnames so strip - colname = item[0].split('.')[-1] + colname = self.dialect.decode_result_columnname(item[0].split('.')[-1]) if self.context.typemap is not None: - type = self.context.typemap.get(colname.lower(), types.NULLTYPE) + type = self.context.typemap.get(colname.lower(), typemap.get(item[1], types.NULLTYPE)) else: - type = types.NULLTYPE + type = typemap.get(item[1], types.NULLTYPE) + rec = (type, type.dialect_impl(self.dialect), i) if rec[0] is None: @@ -889,6 +1123,33 @@ class ResultProxy(object): self.__keys.append(colname) self.__props[i] = rec + if self.__echo: + self.context.engine.logger.debug("Col " + repr(tuple([x[0] for x in metadata]))) + + def _create_key_cache(self): + # local copies to avoid circular ref against 'self' + props = self.__props + context = self.context + def lookup_key(key): + """Given a key, which could be a ColumnElement, string, etc., + matches it to the appropriate key we got from the result set's + metadata; then cache it locally for quick re-access.""" + + if isinstance(key, int) and key in props: + rec = props[key] + elif isinstance(key, basestring) and key.lower() in props: + rec = props[key.lower()] + elif isinstance(key, sql.ColumnElement): + label = context.column_labels.get(key._label, key.name).lower() + if label in props: + rec = props[label] + + if not "rec" in locals(): + raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % (str(key))) + + return rec + return util.PopulateDict(lookup_key) + def close(self): """Close this ResultProxy, and the underlying DBAPI cursor corresponding to the execution. @@ -904,38 +1165,12 @@ class ResultProxy(object): self.cursor.close() if self.connection.should_close_with_result: self.connection.close() - - def _convert_key(self, key): - """Convert and cache a key. - - Given a key, which could be a ColumnElement, string, etc., - matches it to the appropriate key we got from the result set's - metadata; then cache it locally for quick re-access. - """ - - if key in self.__key_cache: - return self.__key_cache[key] - else: - if isinstance(key, int) and key in self.__props: - rec = self.__props[key] - elif isinstance(key, basestring) and key.lower() in self.__props: - rec = self.__props[key.lower()] - elif isinstance(key, sql.ColumnElement): - label = self.context.column_labels.get(key._label, key.name).lower() - if label in self.__props: - rec = self.__props[label] - - if not "rec" in locals(): - raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % (str(key))) - - self.__key_cache[key] = rec - return rec keys = property(lambda s:s.__keys) def _has_key(self, row, key): try: - self._convert_key(key) + self._key_cache[key] return True except KeyError: return False @@ -989,7 +1224,7 @@ class ResultProxy(object): return self.context.supports_sane_rowcount() def _get_col(self, row, key): - rec = self._convert_key(key) + rec = self._key_cache[key] return rec[1].convert_result_value(row[rec[2]], self.dialect) def _fetchone_impl(self): @@ -1101,7 +1336,7 @@ class BufferedColumnResultProxy(ResultProxy): """ def _get_col(self, row, key): - rec = self._convert_key(key) + rec = self._key_cache[key] return row[rec[2]] def _process_row(self, row): @@ -1152,6 +1387,9 @@ class RowProxy(object): self.__parent.close() + def __contains__(self, key): + return self.__parent._has_key(self.__row, key) + def __iter__(self): for i in range(0, len(self.__row)): yield self.__parent._get_col(self.__row, i) @@ -1168,7 +1406,11 @@ class RowProxy(object): return self.__parent._has_key(self.__row, key) def __getitem__(self, key): - return self.__parent._get_col(self.__row, key) + if isinstance(key, slice): + indices = key.indices(len(self)) + return tuple([self.__parent._get_col(self.__row, i) for i in range(*indices)]) + else: + return self.__parent._get_col(self.__row, key) def __getattr__(self, name): try: @@ -1226,19 +1468,22 @@ class DefaultRunner(schema.SchemaVisitor): DefaultRunner to allow database-specific behavior. """ - def __init__(self, connection): - self.connection = connection - self.dialect = connection.dialect + def __init__(self, context): + self.context = context + # branch the connection so it doesnt close after result + self.connection = context.connection._branch() + dialect = property(lambda self:self.context.dialect) + def get_column_default(self, column): if column.default is not None: - return column.default.accept_visitor(self) + return self.traverse_single(column.default) else: return None def get_column_onupdate(self, column): if column.onupdate is not None: - return column.onupdate.accept_visitor(self) + return self.traverse_single(column.onupdate) else: return None @@ -1260,14 +1505,14 @@ class DefaultRunner(schema.SchemaVisitor): return None def exec_default_sql(self, default): - c = sql.select([default.arg]).compile(engine=self.connection) - return self.connection.execute_compiled(c).scalar() + c = sql.select([default.arg]).compile(bind=self.connection) + return self.connection._execute_compiled(c).scalar() def visit_column_onupdate(self, onupdate): if isinstance(onupdate.arg, sql.ClauseElement): return self.exec_default_sql(onupdate) elif callable(onupdate.arg): - return onupdate.arg() + return onupdate.arg(self.context) else: return onupdate.arg @@ -1275,6 +1520,6 @@ class DefaultRunner(schema.SchemaVisitor): if isinstance(default.arg, sql.ClauseElement): return self.exec_default_sql(default) elif callable(default.arg): - return default.arg() + return default.arg(self.context) else: return default.arg diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 95f6566e3..962e2ab60 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -4,25 +4,13 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +"""Provide default implementations of per-dialect sqlalchemy.engine classes""" -from sqlalchemy import schema, exceptions, util, sql, types -import StringIO, sys, re +from sqlalchemy import schema, exceptions, sql, types +import sys, re from sqlalchemy.engine import base -"""Provide default implementations of the engine interfaces""" -class PoolConnectionProvider(base.ConnectionProvider): - def __init__(self, url, pool): - self.url = url - self._pool = pool - - def get_connection(self): - return self._pool.connect() - - def dispose(self): - self._pool.dispose() - self._pool = self._pool.recreate() - class DefaultDialect(base.Dialect): """Default implementation of Dialect""" @@ -33,7 +21,18 @@ class DefaultDialect(base.Dialect): self._ischema = None self.dbapi = dbapi self._figure_paramstyle(paramstyle=paramstyle, default=default_paramstyle) - + + def decode_result_columnname(self, name): + """decode a name found in cursor.description to a unicode object.""" + + return name.decode(self.encoding) + + def dbapi_type_map(self): + # most DBAPIs have problems with this (such as, psycocpg2 types + # are unhashable). So far Oracle can return it. + + return {} + def create_execution_context(self, **kwargs): return DefaultExecutionContext(self, **kwargs) @@ -88,6 +87,15 @@ class DefaultDialect(base.Dialect): #print "ENGINE COMMIT ON ", connection.connection connection.commit() + + def do_savepoint(self, connection, name): + connection.execute(sql.SavepointClause(name)) + + def do_rollback_to_savepoint(self, connection, name): + connection.execute(sql.RollbackToSavepointClause(name)) + + def do_release_savepoint(self, connection, name): + connection.execute(sql.ReleaseSavepointClause(name)) def do_executemany(self, cursor, statement, parameters, **kwargs): cursor.executemany(statement, parameters) @@ -95,8 +103,8 @@ class DefaultDialect(base.Dialect): def do_execute(self, cursor, statement, parameters, **kwargs): cursor.execute(statement, parameters) - def defaultrunner(self, connection): - return base.DefaultRunner(connection) + def defaultrunner(self, context): + return base.DefaultRunner(context) def is_disconnect(self, e): return False @@ -107,23 +115,6 @@ class DefaultDialect(base.Dialect): paramstyle = property(lambda s:s._paramstyle, _set_paramstyle) - def convert_compiled_params(self, parameters): - executemany = parameters is not None and isinstance(parameters, list) - # the bind params are a CompiledParams object. but all the DBAPI's hate - # that object (or similar). so convert it to a clean - # dictionary/list/tuple of dictionary/tuple of list - if parameters is not None: - if self.positional: - if executemany: - parameters = [p.get_raw_list() for p in parameters] - else: - parameters = parameters.get_raw_list() - else: - if executemany: - parameters = [p.get_raw_dict() for p in parameters] - else: - parameters = parameters.get_raw_dict() - return parameters def _figure_paramstyle(self, paramstyle=None, default='named'): if paramstyle is not None: @@ -152,29 +143,38 @@ class DefaultDialect(base.Dialect): ischema = property(_get_ischema, doc="""returns an ISchema object for this engine, which allows access to information_schema tables (if supported)""") class DefaultExecutionContext(base.ExecutionContext): - def __init__(self, dialect, connection, compiled=None, compiled_parameters=None, statement=None, parameters=None): + def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None): self.dialect = dialect self.connection = connection self.compiled = compiled - self.compiled_parameters = compiled_parameters if compiled is not None: self.typemap = compiled.typemap self.column_labels = compiled.column_labels self.statement = unicode(compiled) - else: + if parameters is None: + self.compiled_parameters = compiled.construct_params({}) + elif not isinstance(parameters, (list, tuple)): + self.compiled_parameters = compiled.construct_params(parameters) + else: + self.compiled_parameters = [compiled.construct_params(m or {}) for m in parameters] + if len(self.compiled_parameters) == 1: + self.compiled_parameters = self.compiled_parameters[0] + elif statement is not None: self.typemap = self.column_labels = None - self.parameters = self._encode_param_keys(parameters) + self.parameters = self.__encode_param_keys(parameters) self.statement = statement - - if not dialect.supports_unicode_statements(): + else: + self.statement = None + + if self.statement is not None and not dialect.supports_unicode_statements(): self.statement = self.statement.encode(self.dialect.encoding) self.cursor = self.create_cursor() engine = property(lambda s:s.connection.engine) - def _encode_param_keys(self, params): + def __encode_param_keys(self, params): """apply string encoding to the keys of dictionary-based bind parameters""" if self.dialect.positional or self.dialect.supports_unicode_statements(): return params @@ -189,16 +189,46 @@ class DefaultExecutionContext(base.ExecutionContext): return [proc(d) for d in params] else: return proc(params) + + def __convert_compiled_params(self, parameters): + executemany = parameters is not None and isinstance(parameters, list) + encode = not self.dialect.supports_unicode_statements() + # the bind params are a CompiledParams object. but all the DBAPI's hate + # that object (or similar). so convert it to a clean + # dictionary/list/tuple of dictionary/tuple of list + if parameters is not None: + if self.dialect.positional: + if executemany: + parameters = [p.get_raw_list() for p in parameters] + else: + parameters = parameters.get_raw_list() + else: + if executemany: + parameters = [p.get_raw_dict(encode_keys=encode) for p in parameters] + else: + parameters = parameters.get_raw_dict(encode_keys=encode) + return parameters def is_select(self): - return re.match(r'SELECT', self.statement.lstrip(), re.I) + """return TRUE if the statement is expected to have result rows.""" + + return re.match(r'SELECT', self.statement.lstrip(), re.I) is not None def create_cursor(self): return self.connection.connection.cursor() - + + def pre_execution(self): + self.pre_exec() + + def post_execution(self): + self.post_exec() + + def result(self): + return self.get_result_proxy() + def pre_exec(self): self._process_defaults() - self.parameters = self._encode_param_keys(self.dialect.convert_compiled_params(self.compiled_parameters)) + self.parameters = self.__convert_compiled_params(self.compiled_parameters) def post_exec(self): pass @@ -241,7 +271,7 @@ class DefaultExecutionContext(base.ExecutionContext): inputsizes = [] for params in plist[0:1]: for key in params.positional: - typeengine = params.binds[key].type + typeengine = params.get_type(key) dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) if dbtype is not None: inputsizes.append(dbtype) @@ -250,36 +280,23 @@ class DefaultExecutionContext(base.ExecutionContext): inputsizes = {} for params in plist[0:1]: for key in params.keys(): - typeengine = params.binds[key].type + typeengine = params.get_type(key) dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) if dbtype is not None: inputsizes[key] = dbtype self.cursor.setinputsizes(**inputsizes) def _process_defaults(self): - """``INSERT`` and ``UPDATE`` statements, when compiled, may - have additional columns added to their ``VALUES`` and ``SET`` - lists corresponding to column defaults/onupdates that are - present on the ``Table`` object (i.e. ``ColumnDefault``, - ``Sequence``, ``PassiveDefault``). This method pre-execs - those ``DefaultGenerator`` objects that require pre-execution - and sets their values within the parameter list, and flags this - ExecutionContext about ``PassiveDefault`` objects that may - require post-fetching the row after it is inserted/updated. - - This method relies upon logic within the ``ANSISQLCompiler`` - in its `visit_insert` and `visit_update` methods that add the - appropriate column clauses to the statement when its being - compiled, so that these parameters can be bound to the - statement. - """ + """generate default values for compiled insert/update statements, + and generate last_inserted_ids() collection.""" + # TODO: cleanup if self.compiled.isinsert: if isinstance(self.compiled_parameters, list): plist = self.compiled_parameters else: plist = [self.compiled_parameters] - drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection)) + drunner = self.dialect.defaultrunner(self) self._lastrow_has_defaults = False for param in plist: last_inserted_ids = [] @@ -319,7 +336,7 @@ class DefaultExecutionContext(base.ExecutionContext): plist = self.compiled_parameters else: plist = [self.compiled_parameters] - drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection)) + drunner = self.dialect.defaultrunner(self) self._lastrow_has_defaults = False for param in plist: # check the "onupdate" status of each column in the table diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index 7d85de9ad..0c59ee8eb 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -4,13 +4,11 @@ By default there are two, one which is the "thread-local" strategy, one which is the "plain" strategy. New strategies can be added via constructing a new EngineStrategy -object which will add itself to the list of available strategies here, -or replace one of the existing name. this can be accomplished via a -mod; see the sqlalchemy/mods package for details. +object which will add itself to the list of available strategies. """ -from sqlalchemy.engine import base, default, threadlocal, url +from sqlalchemy.engine import base, threadlocal, url from sqlalchemy import util, exceptions from sqlalchemy import pool as poollib @@ -92,8 +90,6 @@ class DefaultEngineStrategy(EngineStrategy): else: pool = pool - provider = self.get_pool_provider(u, pool) - # create engine. engineclass = self.get_engine_cls() engine_args = {} @@ -105,14 +101,11 @@ class DefaultEngineStrategy(EngineStrategy): if len(kwargs): raise TypeError("Invalid argument(s) %s sent to create_engine(), using configuration %s/%s/%s. Please check that the keyword arguments are appropriate for this combination of components." % (','.join(["'%s'" % k for k in kwargs]), dialect.__class__.__name__, pool.__class__.__name__, engineclass.__name__)) - return engineclass(provider, dialect, **engine_args) + return engineclass(pool, dialect, u, **engine_args) def pool_threadlocal(self): raise NotImplementedError() - def get_pool_provider(self, url, pool): - raise NotImplementedError() - def get_engine_cls(self): raise NotImplementedError() @@ -123,9 +116,6 @@ class PlainEngineStrategy(DefaultEngineStrategy): def pool_threadlocal(self): return False - def get_pool_provider(self, url, pool): - return default.PoolConnectionProvider(url, pool) - def get_engine_cls(self): return base.Engine @@ -138,9 +128,6 @@ class ThreadLocalEngineStrategy(DefaultEngineStrategy): def pool_threadlocal(self): return True - def get_pool_provider(self, url, pool): - return threadlocal.TLocalConnectionProvider(url, pool) - def get_engine_cls(self): return threadlocal.TLEngine @@ -195,4 +182,4 @@ class MockEngineStrategy(EngineStrategy): def execute(self, object, *multiparams, **params): raise NotImplementedError() -MockEngineStrategy()
\ No newline at end of file +MockEngineStrategy() diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py index 2bbb1ed43..b6ba54ea5 100644 --- a/lib/sqlalchemy/engine/threadlocal.py +++ b/lib/sqlalchemy/engine/threadlocal.py @@ -1,8 +1,7 @@ -from sqlalchemy import schema, exceptions, util, sql, types -import StringIO, sys, re -from sqlalchemy.engine import base, default +from sqlalchemy import util +from sqlalchemy.engine import base -"""Provide a thread-local transactional wrapper around the basic ComposedSQLEngine. +"""Provide a thread-local transactional wrapper around the root Engine class. Multiple calls to engine.connect() will return the same connection for the same thread. also provides begin/commit methods on the engine @@ -70,11 +69,8 @@ class TLConnection(base.Connection): self.__opencount += 1 return self - def _create_transaction(self, parent): - return TLTransaction(self, parent) - def _begin(self): - return base.Connection.begin(self) + return TLTransaction(self) def in_transaction(self): return self.session.in_transaction() @@ -91,7 +87,7 @@ class TLConnection(base.Connection): self.__opencount = 0 base.Connection.close(self) -class TLTransaction(base.Transaction): +class TLTransaction(base.RootTransaction): def _commit_impl(self): base.Transaction.commit(self) @@ -112,7 +108,7 @@ class TLEngine(base.Engine): """ def __init__(self, *args, **kwargs): - """The TLEngine relies upon the ConnectionProvider having + """The TLEngine relies upon the Pool having "threadlocal" behavior, so that once a connection is checked out for the current thread, you get that same connection repeatedly. @@ -124,7 +120,7 @@ class TLEngine(base.Engine): def raw_connection(self): """Return a DBAPI connection.""" - return self.connection_provider.get_connection() + return self.pool.connect() def connect(self, **kwargs): """Return a Connection that is not thread-locally scoped. @@ -133,7 +129,7 @@ class TLEngine(base.Engine): ComposedSQLEngine. """ - return base.Connection(self, self.connection_provider.unique_connection()) + return base.Connection(self, self.pool.unique_connection()) def _session(self): if not hasattr(self.context, 'session'): @@ -156,6 +152,3 @@ class TLEngine(base.Engine): def rollback(self): self.session.rollback() -class TLocalConnectionProvider(default.PoolConnectionProvider): - def unique_connection(self): - return self._pool.unique_connection() diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index c5ad90ee9..1da76d7b2 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -1,10 +1,8 @@ -import re -import cgi -import sys -import urllib +"""Provide the URL object as well as the make_url parsing function.""" + +import re, cgi, sys, urllib from sqlalchemy import exceptions -"""Provide the URL object as well as the make_url parsing function.""" class URL(object): """Represent the components of a URL used to connect to a database. diff --git a/lib/sqlalchemy/ext/activemapper.py b/lib/sqlalchemy/ext/activemapper.py index 2fcf44f61..fa32a5fc3 100644 --- a/lib/sqlalchemy/ext/activemapper.py +++ b/lib/sqlalchemy/ext/activemapper.py @@ -1,11 +1,10 @@ -from sqlalchemy import create_session, relation, mapper, \ - join, ThreadLocalMetaData, class_mapper, \ - util, Integer -from sqlalchemy import and_, or_ +from sqlalchemy import ThreadLocalMetaData, util, Integer from sqlalchemy import Table, Column, ForeignKey +from sqlalchemy.orm import class_mapper, relation, create_session + from sqlalchemy.ext.sessioncontext import SessionContext from sqlalchemy.ext.assignmapper import assign_mapper -from sqlalchemy import backref as create_backref +from sqlalchemy.orm import backref as create_backref import sqlalchemy import inspect @@ -14,7 +13,7 @@ import sys # # the "proxy" to the database engine... this can be swapped out at runtime # -metadata = ThreadLocalMetaData("activemapper") +metadata = ThreadLocalMetaData() try: objectstore = sqlalchemy.objectstore diff --git a/lib/sqlalchemy/ext/assignmapper.py b/lib/sqlalchemy/ext/assignmapper.py index 4708afd8d..238041702 100644 --- a/lib/sqlalchemy/ext/assignmapper.py +++ b/lib/sqlalchemy/ext/assignmapper.py @@ -1,51 +1,50 @@ -from sqlalchemy import mapper, util, Query, exceptions +from sqlalchemy import util, exceptions import types - -def monkeypatch_query_method(ctx, class_, name): - def do(self, *args, **kwargs): - query = Query(class_, session=ctx.current) - return getattr(query, name)(*args, **kwargs) - try: - do.__name__ = name - except: - pass - setattr(class_, name, classmethod(do)) - -def monkeypatch_objectstore_method(ctx, class_, name): +from sqlalchemy.orm import mapper + +def _monkeypatch_session_method(name, ctx, class_): def do(self, *args, **kwargs): session = ctx.current - if name == "flush": - # flush expects a list of objects - self = [self] return getattr(session, name)(self, *args, **kwargs) try: do.__name__ = name except: pass - setattr(class_, name, do) - + if not hasattr(class_, name): + setattr(class_, name, do) + def assign_mapper(ctx, class_, *args, **kwargs): + extension = kwargs.pop('extension', None) + if extension is not None: + extension = util.to_list(extension) + extension.append(ctx.mapper_extension) + else: + extension = ctx.mapper_extension + validate = kwargs.pop('validate', False) + if not isinstance(getattr(class_, '__init__'), types.MethodType): def __init__(self, **kwargs): for key, value in kwargs.items(): if validate: - if not key in self.mapper.props: + if not self.mapper.get_property(key, resolve_synonyms=False, raiseerr=False): raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key) setattr(self, key, value) class_.__init__ = __init__ - extension = kwargs.pop('extension', None) - if extension is not None: - extension = util.to_list(extension) - extension.append(ctx.mapper_extension) - else: - extension = ctx.mapper_extension + + class query(object): + def __getattr__(self, key): + return getattr(ctx.current.query(class_), key) + def __call__(self): + return ctx.current.query(class_) + + if not hasattr(class_, 'query'): + class_.query = query() + + for name in ['refresh', 'expire', 'delete', 'expunge', 'update']: + _monkeypatch_session_method(name, ctx, class_) + m = mapper(class_, extension=extension, *args, **kwargs) class_.mapper = m - class_.query = classmethod(lambda cls: Query(class_, session=ctx.current)) - for name in ['get', 'filter', 'filter_by', 'select', 'select_by', 'selectfirst', 'selectfirst_by', 'selectone', 'selectone_by', 'get_by', 'join_to', 'join_via', 'count', 'count_by', 'options', 'instances']: - monkeypatch_query_method(ctx, class_, name) - for name in ['flush', 'delete', 'expire', 'refresh', 'expunge', 'merge', 'save', 'update', 'save_or_update']: - monkeypatch_objectstore_method(ctx, class_, name) return m diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index cdb814702..2dd807222 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -6,11 +6,10 @@ transparent proxied access to the endpoint of an association object. See the example ``examples/association/proxied_association.py``. """ -from sqlalchemy.orm.attributes import InstrumentedList +import weakref, itertools import sqlalchemy.exceptions as exceptions import sqlalchemy.orm as orm import sqlalchemy.util as util -import weakref def association_proxy(targetcollection, attr, **kw): """Convenience function for use in mapped classes. Implements a Python @@ -109,7 +108,7 @@ class AssociationProxy(object): self.collection_class = None def _get_property(self): - return orm.class_mapper(self.owning_class).props[self.target_collection] + return orm.class_mapper(self.owning_class).get_property(self.target_collection) def _target_class(self): return self._get_property().mapper.class_ @@ -168,15 +167,7 @@ class AssociationProxy(object): def _new(self, lazy_collection): creator = self.creator and self.creator or self.target_class - - # Prefer class typing here to spot dicts with the required append() - # method. - collection = lazy_collection() - if isinstance(collection.data, dict): - self.collection_class = dict - else: - self.collection_class = util.duck_type_collection(collection.data) - del collection + self.collection_class = util.duck_type_collection(lazy_collection()) if self.proxy_factory: return self.proxy_factory(lazy_collection, creator, self.value_attr) @@ -269,7 +260,33 @@ class _AssociationList(object): return self._get(self.col[index]) def __setitem__(self, index, value): - self._set(self.col[index], value) + if not isinstance(index, slice): + self._set(self.col[index], value) + else: + if index.stop is None: + stop = len(self) + elif index.stop < 0: + stop = len(self) + index.stop + else: + stop = index.stop + step = index.step or 1 + + rng = range(index.start or 0, stop, step) + if step == 1: + for i in rng: + del self[index.start] + i = index.start + for item in value: + self.insert(i, item) + i += 1 + else: + if len(value) != len(rng): + raise ValueError( + "attempt to assign sequence of size %s to " + "extended slice of size %s" % (len(value), + len(rng))) + for i, item in zip(rng, value): + self._set(self.col[i], item) def __delitem__(self, index): del self.col[index] @@ -291,9 +308,13 @@ class _AssociationList(object): del self.col[start:end] def __iter__(self): - """Iterate over proxied values. For the actual domain objects, - iterate over .col instead or just use the underlying collection - directly from its property on the parent.""" + """Iterate over proxied values. + + For the actual domain objects, iterate over .col instead or + just use the underlying collection directly from its property + on the parent. + """ + for member in self.col: yield self._get(member) raise StopIteration @@ -304,6 +325,10 @@ class _AssociationList(object): item = self._create(value, **kw) self.col.append(item) + def count(self, value): + return sum([1 for _ in + itertools.ifilter(lambda v: v == value, iter(self))]) + def extend(self, values): for v in values: self.append(v) @@ -311,6 +336,26 @@ class _AssociationList(object): def insert(self, index, value): self.col[index:index] = [self._create(value)] + def pop(self, index=-1): + return self.getter(self.col.pop(index)) + + def remove(self, value): + for i, val in enumerate(self): + if val == value: + del self.col[i] + return + raise ValueError("value not in list") + + def reverse(self): + """Not supported, use reversed(mylist)""" + + raise NotImplementedError + + def sort(self): + """Not supported, use sorted(mylist)""" + + raise NotImplementedError + def clear(self): del self.col[0:len(self.col)] @@ -545,9 +590,7 @@ class _AssociationSet(object): def add(self, value): if value not in self: - # must shove this through InstrumentedList.append() which will - # eventually call the collection_class .add() - self.col.append(self._create(value)) + self.col.add(self._create(value)) # for discard and remove, choosing a more expensive check strategy rather # than call self.creator() @@ -567,12 +610,7 @@ class _AssociationSet(object): def pop(self): if not self.col: raise KeyError('pop from an empty set') - # grumble, pop() is borked on InstrumentedList (#548) - if isinstance(self.col, InstrumentedList): - member = list(self.col)[0] - self.col.remove(member) - else: - member = self.col.pop() + member = self.col.pop() return self._get(member) def update(self, other): diff --git a/lib/sqlalchemy/ext/proxy.py b/lib/sqlalchemy/ext/proxy.py deleted file mode 100644 index b81702fc4..000000000 --- a/lib/sqlalchemy/ext/proxy.py +++ /dev/null @@ -1,113 +0,0 @@ -try: - from threading import local -except ImportError: - from sqlalchemy.util import ThreadLocal as local - -from sqlalchemy import sql -from sqlalchemy.engine import create_engine, Engine - -__all__ = ['BaseProxyEngine', 'AutoConnectEngine', 'ProxyEngine'] - -class BaseProxyEngine(sql.Executor): - """Basis for all proxy engines.""" - - def get_engine(self): - raise NotImplementedError - - def set_engine(self, engine): - raise NotImplementedError - - engine = property(lambda s:s.get_engine(), lambda s,e:s.set_engine(e)) - - def execute_compiled(self, *args, **kwargs): - """Override superclass behaviour. - - This method is required to be present as it overrides the - `execute_compiled` present in ``sql.Engine``. - """ - - return self.get_engine().execute_compiled(*args, **kwargs) - - def compiler(self, *args, **kwargs): - """Override superclass behaviour. - - This method is required to be present as it overrides the - `compiler` method present in ``sql.Engine``. - """ - - return self.get_engine().compiler(*args, **kwargs) - - def __getattr__(self, attr): - """Provide proxying for methods that are not otherwise present on this ``BaseProxyEngine``. - - Note that methods which are present on the base class - ``sql.Engine`` will **not** be proxied through this, and must - be explicit on this class. - """ - - # call get_engine() to give subclasses a chance to change - # connection establishment behavior - e = self.get_engine() - if e is not None: - return getattr(e, attr) - raise AttributeError("No connection established in ProxyEngine: " - " no access to %s" % attr) - -class AutoConnectEngine(BaseProxyEngine): - """An SQLEngine proxy that automatically connects when necessary.""" - - def __init__(self, dburi, **kwargs): - BaseProxyEngine.__init__(self) - self.dburi = dburi - self.kwargs = kwargs - self._engine = None - - def get_engine(self): - if self._engine is None: - if callable(self.dburi): - dburi = self.dburi() - else: - dburi = self.dburi - self._engine = create_engine(dburi, **self.kwargs) - return self._engine - - -class ProxyEngine(BaseProxyEngine): - """Engine proxy for lazy and late initialization. - - This engine will delegate access to a real engine set with connect(). - """ - - def __init__(self, **kwargs): - BaseProxyEngine.__init__(self) - # create the local storage for uri->engine map and current engine - self.storage = local() - self.kwargs = kwargs - - def connect(self, *args, **kwargs): - """Establish connection to a real engine.""" - - kwargs.update(self.kwargs) - if not kwargs: - key = repr(args) - else: - key = "%s, %s" % (repr(args), repr(sorted(kwargs.items()))) - try: - map = self.storage.connection - except AttributeError: - self.storage.connection = {} - self.storage.engine = None - map = self.storage.connection - try: - self.storage.engine = map[key] - except KeyError: - map[key] = create_engine(*args, **kwargs) - self.storage.engine = map[key] - - def get_engine(self): - if not hasattr(self.storage, 'engine') or self.storage.engine is None: - raise AttributeError("No connection established") - return self.storage.engine - - def set_engine(self, engine): - self.storage.engine = engine diff --git a/lib/sqlalchemy/ext/selectresults.py b/lib/sqlalchemy/ext/selectresults.py index 68538f3cb..1920b6f92 100644 --- a/lib/sqlalchemy/ext/selectresults.py +++ b/lib/sqlalchemy/ext/selectresults.py @@ -1,212 +1,28 @@ +"""SelectResults has been rolled into Query. This class is now just a placeholder.""" + import sqlalchemy.sql as sql import sqlalchemy.orm as orm class SelectResultsExt(orm.MapperExtension): """a MapperExtension that provides SelectResults functionality for the results of query.select_by() and query.select()""" + def select_by(self, query, *args, **params): - return SelectResults(query, query.join_by(*args, **params)) + q = query + for a in args: + q = q.filter(a) + return q.filter_by(**params) + def select(self, query, arg=None, **kwargs): if isinstance(arg, sql.FromClause) and arg.supports_execution(): return orm.EXT_PASS else: - return SelectResults(query, arg, ops=kwargs) - -class SelectResults(object): - """Build a query one component at a time via separate method - calls, each call transforming the previous ``SelectResults`` - instance into a new ``SelectResults`` instance with further - limiting criterion added. When interpreted in an iterator context - (such as via calling ``list(selectresults)``), executes the query. - """ - - def __init__(self, query, clause=None, ops={}, joinpoint=None): - """Construct a new ``SelectResults`` using the given ``Query`` - object and optional ``WHERE`` clause. `ops` is an optional - dictionary of bind parameter values. - """ - - self._query = query - self._clause = clause - self._ops = {} - self._ops.update(ops) - self._joinpoint = joinpoint or (self._query.table, self._query.mapper) - - def options(self,*args, **kwargs): - """Apply mapper options to the underlying query. - - See also ``Query.options``. - """ - - new = self.clone() - new._query = new._query.options(*args, **kwargs) - return new - - def count(self): - """Execute the SQL ``count()`` function against the ``SelectResults`` criterion.""" - - return self._query.count(self._clause, **self._ops) - - def _col_aggregate(self, col, func): - """Execute ``func()`` function against the given column. - - For performance, only use subselect if `order_by` attribute is set. - """ - - if self._ops.get('order_by'): - s1 = sql.select([col], self._clause, **self._ops).alias('u') - return sql.select([func(s1.corresponding_column(col))]).scalar() - else: - return sql.select([func(col)], self._clause, **self._ops).scalar() - - def min(self, col): - """Execute the SQL ``min()`` function against the given column.""" - - return self._col_aggregate(col, sql.func.min) - - def max(self, col): - """Execute the SQL ``max()`` function against the given column.""" - - return self._col_aggregate(col, sql.func.max) - - def sum(self, col): - """Execute the SQL ``sum()`` function against the given column.""" - - return self._col_aggregate(col, sql.func.sum) - - def avg(self, col): - """Execute the SQL ``avg()`` function against the given column.""" - - return self._col_aggregate(col, sql.func.avg) - - def clone(self): - """Create a copy of this ``SelectResults``.""" - - return SelectResults(self._query, self._clause, self._ops.copy(), self._joinpoint) - - def filter(self, clause): - """Apply an additional ``WHERE`` clause against the query.""" - - new = self.clone() - new._clause = sql.and_(self._clause, clause) - return new - - def select(self, clause): - return self.filter(clause) - - def select_by(self, *args, **kwargs): - return self.filter(self._query._join_by(args, kwargs, start=self._joinpoint[1])) - - def order_by(self, order_by): - """Apply an ``ORDER BY`` to the query.""" - - new = self.clone() - new._ops['order_by'] = order_by - return new - - def limit(self, limit): - """Apply a ``LIMIT`` to the query.""" - - return self[:limit] - - def offset(self, offset): - """Apply an ``OFFSET`` to the query.""" - - return self[offset:] - - def distinct(self): - """Apply a ``DISTINCT`` to the query.""" - - new = self.clone() - new._ops['distinct'] = True - return new - - def list(self): - """Return the results represented by this ``SelectResults`` as a list. - - This results in an execution of the underlying query. - """ - - return list(self) - - def select_from(self, from_obj): - """Set the `from_obj` parameter of the query. - - `from_obj` is a list of one or more tables. - """ - - new = self.clone() - new._ops['from_obj'] = from_obj - return new - - def join_to(self, prop): - """Join the table of this ``SelectResults`` to the table located against the given property name. - - Subsequent calls to join_to or outerjoin_to will join against - the rightmost table located from the previous `join_to` or - `outerjoin_to` call, searching for the property starting with - the rightmost mapper last located. - """ - - new = self.clone() - (clause, mapper) = self._join_to(prop, outerjoin=False) - new._ops['from_obj'] = [clause] - new._joinpoint = (clause, mapper) - return new - - def outerjoin_to(self, prop): - """Outer join the table of this ``SelectResults`` to the - table located against the given property name. - - Subsequent calls to join_to or outerjoin_to will join against - the rightmost table located from the previous ``join_to`` or - ``outerjoin_to`` call, searching for the property starting with - the rightmost mapper last located. - """ - - new = self.clone() - (clause, mapper) = self._join_to(prop, outerjoin=True) - new._ops['from_obj'] = [clause] - new._joinpoint = (clause, mapper) - return new - - def _join_to(self, prop, outerjoin=False): - [keys,p] = self._query._locate_prop(prop, start=self._joinpoint[1]) - clause = self._joinpoint[0] - mapper = self._joinpoint[1] - for key in keys: - prop = mapper.props[key] - if outerjoin: - clause = clause.outerjoin(prop.select_table, prop.get_join(mapper)) - else: - clause = clause.join(prop.select_table, prop.get_join(mapper)) - mapper = prop.mapper - return (clause, mapper) - - def compile(self): - return self._query.compile(self._clause, **self._ops) - - def __getitem__(self, item): - if isinstance(item, slice): - start = item.start - stop = item.stop - if (isinstance(start, int) and start < 0) or \ - (isinstance(stop, int) and stop < 0): - return list(self)[item] - else: - res = self.clone() - if start is not None and stop is not None: - res._ops.update(dict(offset=self._ops.get('offset', 0)+start, limit=stop-start)) - elif start is None and stop is not None: - res._ops.update(dict(limit=stop)) - elif start is not None and stop is None: - res._ops.update(dict(offset=self._ops.get('offset', 0)+start)) - if item.step is not None: - return list(res)[None:None:item.step] - else: - return res - else: - return list(self[item:item+1])[0] - - def __iter__(self): - return iter(self._query.select_whereclause(self._clause, **self._ops)) + if arg is not None: + query = query.filter(arg) + return query._legacy_select_kwargs(**kwargs) + +def SelectResults(query, clause=None, ops={}): + if clause is not None: + query = query.filter(clause) + query = query.options(orm.extension(SelectResultsExt())) + return query._legacy_select_kwargs(**ops) diff --git a/lib/sqlalchemy/ext/sessioncontext.py b/lib/sqlalchemy/ext/sessioncontext.py index 2f81e55d2..fcbf29c3f 100644 --- a/lib/sqlalchemy/ext/sessioncontext.py +++ b/lib/sqlalchemy/ext/sessioncontext.py @@ -1,5 +1,5 @@ from sqlalchemy.util import ScopedRegistry -from sqlalchemy.orm.mapper import MapperExtension +from sqlalchemy.orm import create_session, object_session, MapperExtension, EXT_PASS __all__ = ['SessionContext', 'SessionContextExt'] @@ -15,16 +15,18 @@ class SessionContext(object): engine = create_engine(...) def session_factory(): - return Session(bind_to=engine) + return Session(bind=engine) context = SessionContext(session_factory) s = context.current # get thread-local session - context.current = Session(bind_to=other_engine) # set current session + context.current = Session(bind=other_engine) # set current session del context.current # discard the thread-local session (a new one will # be created on the next call to context.current) """ - def __init__(self, session_factory, scopefunc=None): + def __init__(self, session_factory=None, scopefunc=None): + if session_factory is None: + session_factory = create_session self.registry = ScopedRegistry(session_factory, scopefunc) super(SessionContext, self).__init__() @@ -60,3 +62,21 @@ class SessionContextExt(MapperExtension): def get_session(self): return self.context.current + + def init_instance(self, mapper, class_, instance, args, kwargs): + session = kwargs.pop('_sa_session', self.context.current) + session._save_impl(instance, entity_name=kwargs.pop('_sa_entity_name', None)) + return EXT_PASS + + def init_failed(self, mapper, class_, instance, args, kwargs): + object_session(instance).expunge(instance) + return EXT_PASS + + def dispose_class(self, mapper, class_): + if hasattr(class_, '__init__') and hasattr(class_.__init__, '_oldinit'): + if class_.__init__._oldinit is not None: + class_.__init__ = class_.__init__._oldinit + else: + delattr(class_, '__init__') + + diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py index 04e5b49f7..756b5e1e7 100644 --- a/lib/sqlalchemy/ext/sqlsoup.py +++ b/lib/sqlalchemy/ext/sqlsoup.py @@ -310,8 +310,8 @@ Boring tests here. Nothing of real expository value. """ from sqlalchemy import * +from sqlalchemy.orm import * from sqlalchemy.ext.sessioncontext import SessionContext -from sqlalchemy.ext.assignmapper import assign_mapper from sqlalchemy.exceptions import * @@ -392,7 +392,7 @@ class SelectableClassType(type): def update(cls, whereclause=None, values=None, **kwargs): _ddl_error(cls) - def _selectable(cls): + def __selectable__(cls): return cls._table def __getattr__(cls, attr): @@ -434,9 +434,7 @@ def _selectable_name(selectable): return x def class_for_table(selectable, **mapper_kwargs): - if not hasattr(selectable, '_selectable') \ - or selectable._selectable() != selectable: - raise ArgumentError('class_for_table requires a selectable as its argument') + selectable = sql._selectable(selectable) mapname = 'Mapped' + _selectable_name(selectable) if isinstance(selectable, Table): klass = TableClassType(mapname, (object,), {}) @@ -520,7 +518,7 @@ class SqlSoup: def with_labels(self, item): # TODO give meaningful aliases - return self.map(item._selectable().select(use_labels=True).alias('foo')) + return self.map(sql._selectable(item).select(use_labels=True).alias('foo')) def join(self, *args, **kwargs): j = join(*args, **kwargs) @@ -539,6 +537,9 @@ class SqlSoup: t = None self._cache[attr] = t return t + + def __repr__(self): + return 'SqlSoup(%r)' % self._metadata if __name__ == '__main__': import doctest diff --git a/lib/sqlalchemy/mods/legacy_session.py b/lib/sqlalchemy/mods/legacy_session.py deleted file mode 100644 index e21a5634b..000000000 --- a/lib/sqlalchemy/mods/legacy_session.py +++ /dev/null @@ -1,176 +0,0 @@ -"""A plugin that emulates 0.1 Session behavior.""" - -import sqlalchemy.orm.objectstore as objectstore -import sqlalchemy.orm.unitofwork as unitofwork -import sqlalchemy.util as util -import sqlalchemy - -import sqlalchemy.mods.threadlocal - -class LegacySession(objectstore.Session): - def __init__(self, nest_on=None, hash_key=None, **kwargs): - super(LegacySession, self).__init__(**kwargs) - self.parent_uow = None - self.begin_count = 0 - self.nest_on = util.to_list(nest_on) - self.__pushed_count = 0 - - def was_pushed(self): - if self.nest_on is None: - return - self.__pushed_count += 1 - if self.__pushed_count == 1: - for n in self.nest_on: - n.push_session() - - def was_popped(self): - if self.nest_on is None or self.__pushed_count == 0: - return - self.__pushed_count -= 1 - if self.__pushed_count == 0: - for n in self.nest_on: - n.pop_session() - - class SessionTrans(object): - """Returned by ``Session.begin()``, denotes a - transactionalized UnitOfWork instance. Call ``commit()` - on this to commit the transaction. - """ - - def __init__(self, parent, uow, isactive): - self.__parent = parent - self.__isactive = isactive - self.__uow = uow - - isactive = property(lambda s:s.__isactive, doc="True if this SessionTrans is the 'active' transaction marker, else its a no-op.") - parent = property(lambda s:s.__parent, doc="The parent Session of this SessionTrans object.") - uow = property(lambda s:s.__uow, doc="The parent UnitOfWork corresponding to this transaction.") - - def begin(self): - """Call ``begin()`` on the underlying ``Session`` object, - returning a new no-op ``SessionTrans`` object. - """ - - if self.parent.uow is not self.uow: - raise InvalidRequestError("This SessionTrans is no longer valid") - return self.parent.begin() - - def commit(self): - """Commit the transaction noted by this ``SessionTrans`` object.""" - - self.__parent._trans_commit(self) - self.__isactive = False - - def rollback(self): - """Roll back the current UnitOfWork transaction, in the - case that ``begin()`` has been called. - - The changes logged since the begin() call are discarded. - """ - - self.__parent._trans_rollback(self) - self.__isactive = False - - def begin(self): - """Begin a new UnitOfWork transaction and return a - transaction-holding object. - - ``commit()`` or ``rollback()`` should be called on the returned object. - - ``commit()`` on the ``Session`` will do nothing while a - transaction is pending, and further calls to ``begin()`` will - return no-op transactional objects. - """ - - if self.parent_uow is not None: - return LegacySession.SessionTrans(self, self.uow, False) - self.parent_uow = self.uow - self.uow = unitofwork.UnitOfWork(identity_map = self.uow.identity_map) - return LegacySession.SessionTrans(self, self.uow, True) - - def commit(self, *objects): - """Commit the current UnitOfWork transaction. - - Called with no arguments, this is only used for *implicit* - transactions when there was no ``begin()``. - - If individual objects are submitted, then only those objects - are committed, and the begin/commit cycle is not affected. - """ - - # if an object list is given, commit just those but dont - # change begin/commit status - if len(objects): - self._commit_uow(*objects) - self.uow.flush(self, *objects) - return - if self.parent_uow is None: - self._commit_uow() - - def _trans_commit(self, trans): - if trans.uow is self.uow and trans.isactive: - try: - self._commit_uow() - finally: - self.uow = self.parent_uow - self.parent_uow = None - - def _trans_rollback(self, trans): - if trans.uow is self.uow: - self.uow = self.parent_uow - self.parent_uow = None - - def _commit_uow(self, *obj): - self.was_pushed() - try: - self.uow.flush(self, *obj) - finally: - self.was_popped() - -def begin(): - """Deprecated. Use ``s = Session(new_imap=False)``.""" - - return objectstore.get_session().begin() - -def commit(*obj): - """Deprecated. Use ``flush(*obj)``.""" - - objectstore.get_session().flush(*obj) - -def uow(): - return objectstore.get_session() - -def push_session(sess): - old = get_session() - if getattr(sess, '_previous', None) is not None: - raise InvalidRequestError("Given Session is already pushed onto some thread's stack") - sess._previous = old - session_registry.set(sess) - sess.was_pushed() - -def pop_session(): - sess = get_session() - old = sess._previous - sess._previous = None - session_registry.set(old) - sess.was_popped() - return old - -def using_session(sess, func): - push_session(sess) - try: - return func() - finally: - pop_session() - -def install_plugin(): - objectstore.Session = LegacySession - objectstore.session_registry = util.ScopedRegistry(objectstore.Session) - objectstore.begin = begin - objectstore.commit = commit - objectstore.uow = uow - objectstore.push_session = push_session - objectstore.pop_session = pop_session - objectstore.using_session = using_session - -install_plugin() diff --git a/lib/sqlalchemy/mods/selectresults.py b/lib/sqlalchemy/mods/selectresults.py index ac8de9b06..25bfa2840 100644 --- a/lib/sqlalchemy/mods/selectresults.py +++ b/lib/sqlalchemy/mods/selectresults.py @@ -1,4 +1,4 @@ -from sqlalchemy.ext.selectresults import * +from sqlalchemy.ext.selectresults import SelectResultsExt from sqlalchemy.orm.mapper import global_extensions def install_plugin(): diff --git a/lib/sqlalchemy/mods/threadlocal.py b/lib/sqlalchemy/mods/threadlocal.py deleted file mode 100644 index c8043bc62..000000000 --- a/lib/sqlalchemy/mods/threadlocal.py +++ /dev/null @@ -1,53 +0,0 @@ -"""This plugin installs thread-local behavior at the ``Engine`` and ``Session`` level. - -The default ``Engine`` strategy will be *threadlocal*, producing -``TLocalEngine`` instances for create_engine by default. - -With this engine, ``connect()`` method will return the same connection -on the same thread, if it is already checked out from the pool. This -greatly helps functions that call multiple statements to be able to -easily use just one connection without explicit ``close`` statements -on result handles. - -On the ``Session`` side, module-level methods will be installed within -the objectstore module, such as ``flush()``, ``delete()``, etc. which -call this method on the thread-local session. - -Note: this mod creates a global, thread-local session context named -``sqlalchemy.objectstore``. All mappers created while this mod is -installed will reference this global context when creating new mapped -object instances. -""" - -from sqlalchemy import util, engine, mapper -from sqlalchemy.ext.sessioncontext import SessionContext -import sqlalchemy.ext.assignmapper as assignmapper -from sqlalchemy.orm.mapper import global_extensions -from sqlalchemy.orm.session import Session -import sqlalchemy -import sys, types - -__all__ = ['Objectstore', 'assign_mapper'] - -class Objectstore(object): - def __init__(self, *args, **kwargs): - self.context = SessionContext(*args, **kwargs) - def __getattr__(self, name): - return getattr(self.context.current, name) - session = property(lambda s:s.context.current) - -def assign_mapper(class_, *args, **kwargs): - assignmapper.assign_mapper(objectstore.context, class_, *args, **kwargs) - -objectstore = Objectstore(Session) -def install_plugin(): - sqlalchemy.objectstore = objectstore - global_extensions.append(objectstore.context.mapper_extension) - engine.default_strategy = 'threadlocal' - sqlalchemy.assign_mapper = assign_mapper - -def uninstall_plugin(): - engine.default_strategy = 'plain' - global_extensions.remove(objectstore.context.mapper_extension) - -install_plugin() diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 7ef2da897..1982a94f7 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -11,58 +11,229 @@ packages and tying operations to class properties and constructors. from sqlalchemy import exceptions from sqlalchemy import util as sautil -from sqlalchemy.orm.mapper import * +from sqlalchemy.orm.mapper import Mapper, object_mapper, class_mapper, mapper_registry +from sqlalchemy.orm.interfaces import SynonymProperty, MapperExtension, EXT_PASS, ExtensionOption, PropComparator +from sqlalchemy.orm.properties import PropertyLoader, ColumnProperty, CompositeProperty, BackRef from sqlalchemy.orm import mapper as mapperlib +from sqlalchemy.orm import collections, strategies from sqlalchemy.orm.query import Query from sqlalchemy.orm.util import polymorphic_union -from sqlalchemy.orm import properties, strategies, interfaces from sqlalchemy.orm.session import Session as create_session from sqlalchemy.orm.session import object_session, attribute_manager -__all__ = ['relation', 'column_property', 'backref', 'eagerload', 'lazyload', 'noload', 'deferred', 'defer', 'undefer', 'extension', - 'mapper', 'clear_mappers', 'compile_mappers', 'clear_mapper', 'class_mapper', 'object_mapper', 'MapperExtension', 'Query', - 'cascade_mappers', 'polymorphic_union', 'create_session', 'synonym', 'contains_alias', 'contains_eager', 'EXT_PASS', 'object_session' - ] +__all__ = ['relation', 'column_property', 'composite', 'backref', 'eagerload', + 'eagerload_all', 'lazyload', 'noload', 'deferred', 'defer', 'undefer', + 'undefer_group', 'extension', 'mapper', 'clear_mappers', + 'compile_mappers', 'class_mapper', 'object_mapper', + 'MapperExtension', 'Query', 'polymorphic_union', 'create_session', + 'synonym', 'contains_alias', 'contains_eager', 'EXT_PASS', + 'object_session', 'PropComparator' + ] -def relation(*args, **kwargs): +def relation(argument, secondary=None, **kwargs): """Provide a relationship of a primary Mapper to a secondary Mapper. - This corresponds to a parent-child or associative table relationship. + This corresponds to a parent-child or associative table relationship. + The constructed class is an instance of [sqlalchemy.orm.properties#PropertyLoader]. + + argument + a class or Mapper instance, representing the target of the relation. + + secondary + for a many-to-many relationship, specifies the intermediary table. The + ``secondary`` keyword argument should generally only be used for a table + that is not otherwise expressed in any class mapping. In particular, + using the Association Object Pattern is + generally mutually exclusive against using the ``secondary`` keyword + argument. + + \**kwargs follow: + + association + Deprecated; as of version 0.3.0 the association keyword is synonomous + with applying the "all, delete-orphan" cascade to a "one-to-many" + relationship. SA can now automatically reconcile a "delete" and + "insert" operation of two objects with the same "identity" in a flush() + operation into a single "update" statement, which is the pattern that + "association" used to indicate. + + backref + indicates the name of a property to be placed on the related mapper's + class that will handle this relationship in the other direction, + including synchronizing the object attributes on both sides of the + relation. Can also point to a ``backref()`` construct for more + configurability. + + cascade + a string list of cascade rules which determines how persistence + operations should be "cascaded" from parent to child. + + collection_class + a class or function that returns a new list-holding object. will be + used in place of a plain list for storing elements. + + foreign_keys + a list of columns which are to be used as "foreign key" columns. + this parameter should be used in conjunction with explicit + ``primaryjoin`` and ``secondaryjoin`` (if needed) arguments, and the + columns within the ``foreign_keys`` list should be present within + those join conditions. Normally, ``relation()`` will inspect the + columns within the join conditions to determine which columns are + the "foreign key" columns, based on information in the ``Table`` + metadata. Use this argument when no ForeignKey's are present in the + join condition, or to override the table-defined foreign keys. + + foreignkey + deprecated. use the ``foreign_keys`` argument for foreign key + specification, or ``remote_side`` for "directional" logic. + + lazy=True + specifies how the related items should be loaded. a value of True + indicates they should be loaded lazily when the property is first + accessed. A value of False indicates they should be loaded by joining + against the parent object query, so parent and child are loaded in one + round trip (i.e. eagerly). A value of None indicates the related items + are not loaded by the mapper in any case; the application will manually + insert items into the list in some other way. In all cases, items added + or removed to the parent object's collection (or scalar attribute) will + cause the appropriate updates and deletes upon flush(), i.e. this + option only affects load operations, not save operations. + + order_by + indicates the ordering that should be applied when loading these items. + + passive_deletes=False + Indicates if lazy-loaders should not be executed during the ``flush()`` + process, which normally occurs in order to locate all existing child + items when a parent item is to be deleted. Setting this flag to True is + appropriate when ``ON DELETE CASCADE`` rules have been set up on the + actual tables so that the database may handle cascading deletes + automatically. This strategy is useful particularly for handling the + deletion of objects that have very large (and/or deep) child-object + collections. + + post_update + this indicates that the relationship should be handled by a second + UPDATE statement after an INSERT or before a DELETE. Currently, it also + will issue an UPDATE after the instance was UPDATEd as well, although + this technically should be improved. This flag is used to handle saving + bi-directional dependencies between two individual rows (i.e. each row + references the other), where it would otherwise be impossible to INSERT + or DELETE both rows fully since one row exists before the other. Use + this flag when a particular mapping arrangement will incur two rows + that are dependent on each other, such as a table that has a + one-to-many relationship to a set of child rows, and also has a column + that references a single child row within that list (i.e. both tables + contain a foreign key to each other). If a ``flush()`` operation returns + an error that a "cyclical dependency" was detected, this is a cue that + you might want to use ``post_update`` to "break" the cycle. + + primaryjoin + a ClauseElement that will be used as the primary join of this child + object against the parent object, or in a many-to-many relationship the + join of the primary object to the association table. By default, this + value is computed based on the foreign key relationships of the parent + and child tables (or association table). + + private=False + deprecated. setting ``private=True`` is the equivalent of setting + ``cascade="all, delete-orphan"``, and indicates the lifecycle of child + objects should be contained within that of the parent. + + remote_side + used for self-referential relationships, indicates the column or list + of columns that form the "remote side" of the relationship. + + secondaryjoin + a ClauseElement that will be used as the join of an association table + to the child object. By default, this value is computed based on the + foreign key relationships of the association and child tables. + + uselist=(True|False) + a boolean that indicates if this property should be loaded as a list or + a scalar. In most cases, this value is determined automatically by + ``relation()``, based on the type and direction of the relationship - one + to many forms a list, many to one forms a scalar, many to many is a + list. If a scalar is desired where normally a list would be present, + such as a bi-directional one-to-one relationship, set uselist to False. + + viewonly=False + when set to True, the relation is used only for loading objects within + the relationship, and has no effect on the unit-of-work flush process. + Relations with viewonly can specify any kind of join conditions to + provide additional views of related objects onto a parent object. Note + that the functionality of a viewonly relationship has its limits - + complicated join conditions may not compile into eager or lazy loaders + properly. If this is the case, use an alternative method. + """ - if len(args) > 1 and isinstance(args[0], type): - raise exceptions.ArgumentError("relation(class, table, **kwargs) is deprecated. Please use relation(class, **kwargs) or relation(mapper, **kwargs).") - return _relation_loader(*args, **kwargs) + return PropertyLoader(argument, secondary=secondary, **kwargs) + +# return _relation_loader(argument, secondary=secondary, **kwargs) + +#def _relation_loader(mapper, secondary=None, primaryjoin=None, secondaryjoin=None, lazy=True, **kwargs): def column_property(*args, **kwargs): """Provide a column-level property for use with a Mapper. + + Column-based properties can normally be applied to the mapper's + ``properties`` dictionary using the ``schema.Column`` element directly. + Use this function when the given column is not directly present within + the mapper's selectable; examples include SQL expressions, functions, + and scalar SELECT queries. + + Columns that arent present in the mapper's selectable won't be persisted + by the mapper and are effectively "read-only" attributes. + + \*cols + list of Column objects to be mapped. - Normally, custom column-level properties that represent columns - directly or indirectly present within the mapped selectable - can just be added to the ``properties`` dictionary directly, - in which case this function's usage is not necessary. - - In the case of a ``ColumnElement`` directly present within the - ``properties`` dictionary, the given column is converted to be the exact column - located within the mapped selectable, in the case that the mapped selectable - is not the exact parent selectable of the given column, but shares a common - base table relationship with that column. + group + a group name for this property when marked as deferred. + + deferred + when True, the column property is "deferred", meaning that + it does not load immediately, and is instead loaded when the + attribute is first accessed on an instance. See also + [sqlalchemy.orm#deferred()]. + + """ - Use this function when the column expression being added does not - correspond to any single column within the mapped selectable, - such as a labeled function or scalar-returning subquery, to force the element - to become a mapped property regardless of it not being present within the - mapped selectable. + return ColumnProperty(*args, **kwargs) + +def composite(class_, *cols, **kwargs): + """Return a composite column-based property for use with a Mapper. + + This is very much like a column-based property except the given class + is used to construct values composed of one or more columns. The class must + implement a constructor with positional arguments matching the order of + columns given, as well as a __colset__() method which returns its attributes + in column order. - Note that persistence of instances is driven from the collection of columns - within the mapped selectable, so column properties attached to a Mapper which have - no direct correspondence to the mapped selectable will effectively be non-persisted - attributes. + class\_ + the "composite type" class. + + \*cols + list of Column objects to be mapped. + + group + a group name for this property when marked as deferred. + + deferred + when True, the column property is "deferred", meaning that + it does not load immediately, and is instead loaded when the + attribute is first accessed on an instance. See also + [sqlalchemy.orm#deferred()]. + + comparator + an optional instance of [sqlalchemy.orm#PropComparator] which + provides SQL expression generation functions for this composite + type. """ - return properties.ColumnProperty(*args, **kwargs) -def _relation_loader(mapper, secondary=None, primaryjoin=None, secondaryjoin=None, lazy=True, **kwargs): - return properties.PropertyLoader(mapper, secondary, primaryjoin, secondaryjoin, lazy=lazy, **kwargs) + return CompositeProperty(class_, *cols, **kwargs) + def backref(name, **kwargs): """Create a BackRef object with explicit arguments, which are the same arguments one @@ -72,7 +243,7 @@ def backref(name, **kwargs): place of a string argument. """ - return properties.BackRef(name, **kwargs) + return BackRef(name, **kwargs) def deferred(*columns, **kwargs): """Return a ``DeferredColumnProperty``, which indicates this @@ -82,15 +253,141 @@ def deferred(*columns, **kwargs): Used with the `properties` dictionary sent to ``mapper()``. """ - return properties.ColumnProperty(deferred=True, *columns, **kwargs) - -def mapper(class_, table=None, *args, **params): - """Return a new ``Mapper`` object. - - See the ``Mapper`` class for a description of arguments. + return ColumnProperty(deferred=True, *columns, **kwargs) + +def mapper(class_, local_table=None, *args, **params): + """Return a new [sqlalchemy.orm#Mapper] object. + + class\_ + The class to be mapped. + + local_table + The table to which the class is mapped, or None if this + mapper inherits from another mapper using concrete table + inheritance. + + entity_name + A name to be associated with the `class`, to allow alternate + mappings for a single class. + + always_refresh + If True, all query operations for this mapped class will + overwrite all data within object instances that already + exist within the session, erasing any in-memory changes with + whatever information was loaded from the database. Usage + of this flag is highly discouraged; as an alternative, + see the method `populate_existing()` on [sqlalchemy.orm.query#Query]. + + allow_column_override + If True, allows the usage of a ``relation()`` which has the + same name as a column in the mapped table. The table column + will no longer be mapped. + + allow_null_pks + Indicates that composite primary keys where one or more (but + not all) columns contain NULL is a valid primary key. + Primary keys which contain NULL values usually indicate that + a result row does not contain an entity and should be + skipped. + + batch + Indicates that save operations of multiple entities can be + batched together for efficiency. setting to False indicates + that an instance will be fully saved before saving the next + instance, which includes inserting/updating all table rows + corresponding to the entity as well as calling all + ``MapperExtension`` methods corresponding to the save + operation. + + column_prefix + A string which will be prepended to the `key` name of all + Columns when creating column-based properties from the given + Table. Does not affect explicitly specified column-based + properties + + concrete + If True, indicates this mapper should use concrete table + inheritance with its parent mapper. + + extension + A [sqlalchemy.orm#MapperExtension] instance or list of + ``MapperExtension`` instances which will be applied to all + operations by this ``Mapper``. + + inherits + Another ``Mapper`` for which this ``Mapper`` will have an + inheritance relationship with. + + inherit_condition + For joined table inheritance, a SQL expression (constructed + ``ClauseElement``) which will define how the two tables are + joined; defaults to a natural join between the two tables. + + order_by + A single ``Column`` or list of ``Columns`` for which + selection operations should use as the default ordering for + entities. Defaults to the OID/ROWID of the table if any, or + the first primary key column of the table. + + non_primary + Construct a ``Mapper`` that will define only the selection + of instances, not their persistence. Any number of non_primary + mappers may be created for a particular class. + + polymorphic_on + Used with mappers in an inheritance relationship, a ``Column`` + which will identify the class/mapper combination to be used + with a particular row. requires the polymorphic_identity + value to be set for all mappers in the inheritance + hierarchy. + + _polymorphic_map + Used internally to propigate the full map of polymorphic + identifiers to surrogate mappers. + + polymorphic_identity + A value which will be stored in the Column denoted by + polymorphic_on, corresponding to the *class identity* of + this mapper. + + polymorphic_fetch + specifies how subclasses mapped through joined-table + inheritance will be fetched. options are 'union', + 'select', and 'deferred'. if the select_table argument + is present, defaults to 'union', otherwise defaults to + 'select'. + + properties + A dictionary mapping the string names of object attributes + to ``MapperProperty`` instances, which define the + persistence behavior of that attribute. Note that the + columns in the mapped table are automatically converted into + ``ColumnProperty`` instances based on the `key` property of + each ``Column`` (although they can be overridden using this + dictionary). + + primary_key + A list of ``Column`` objects which define the *primary key* + to be used against this mapper's selectable unit. This is + normally simply the primary key of the `local_table`, but + can be overridden here. + + select_table + A [sqlalchemy.schema#Table] or any [sqlalchemy.sql#Selectable] + which will be used to select instances of this mapper's class. + usually used to provide polymorphic loading among several + classes in an inheritance hierarchy. + + version_id_col + A ``Column`` which must have an integer type that will be + used to keep a running *version id* of mapped entities in + the database. this is used during save operations to ensure + that no other thread or process has updated the instance + during the lifetime of the entity, else a + ``ConcurrentModificationError`` exception is thrown. """ - return Mapper(class_, table, *args, **params) + return Mapper(class_, local_table, *args, **params) def synonym(name, proxy=False): """Set up `name` as a synonym to another ``MapperProperty``. @@ -98,7 +395,7 @@ def synonym(name, proxy=False): Used with the `properties` dictionary sent to ``mapper()``. """ - return interfaces.SynonymProperty(name, proxy=proxy) + return SynonymProperty(name, proxy=proxy) def compile_mappers(): """Compile all mappers that have been defined. @@ -120,32 +417,13 @@ def clear_mappers(): mapperlib._COMPILE_MUTEX.acquire() try: for mapper in mapper_registry.values(): - attribute_manager.reset_class_managed(mapper.class_) - if hasattr(mapper.class_, 'c'): - del mapper.class_.c + mapper.dispose() mapper_registry.clear() # TODO: either dont use ArgSingleton, or # find a way to clear only ClassKey instances from it sautil.ArgSingleton.instances.clear() finally: mapperlib._COMPILE_MUTEX.release() - -def clear_mapper(m): - """Remove the given mapper from the storage of mappers. - - When a new mapper is created for the previous mapper's class, it - will be used as that classes' new primary mapper. - """ - - mapperlib._COMPILE_MUTEX.acquire() - try: - del mapper_registry[m.class_key] - attribute_manager.reset_class_managed(m.class_) - if hasattr(m.class_, 'c'): - del m.class_.c - m.class_key.dispose() - finally: - mapperlib._COMPILE_MUTEX.release() def extension(ext): """Return a ``MapperOption`` that will insert the given @@ -166,6 +444,22 @@ def eagerload(name): return strategies.EagerLazyOption(name, lazy=False) +def eagerload_all(name): + """Return a ``MapperOption`` that will convert all + properties along the given dot-separated path into an + eager load. + + e.g:: + query.options(eagerload_all('orders.items.keywords'))... + + will set all of 'orders', 'orders.items', and 'orders.items.keywords' + to load in one eager load. + + Used with ``query.options()``. + """ + + return strategies.EagerLazyOption(name, lazy=False, chained=True) + def lazyload(name): """Return a ``MapperOption`` that will convert the property of the given name into a lazy load. @@ -175,6 +469,9 @@ def lazyload(name): return strategies.EagerLazyOption(name, lazy=True) +def fetchmode(name, type): + return strategies.FetchModeOption(name, type) + def noload(name): """Return a ``MapperOption`` that will convert the property of the given name into a non-load. @@ -250,64 +547,11 @@ def undefer(name): return strategies.DeferredOption(name, defer=False) +def undefer_group(name): + """Return a ``MapperOption`` that will convert the given + group of deferred column properties into a non-deferred (regular column) load. -def cascade_mappers(*classes_or_mappers): - """Attempt to create a series of ``relations()`` between mappers - automatically, via introspecting the foreign key relationships of - the underlying tables. - - Given a list of classes and/or mappers, identify the foreign key - relationships between the given mappers or corresponding class - mappers, and create ``relation()`` objects representing those - relationships, including a backreference. Attempt to find the - *secondary* table in a many-to-many relationship as well. - - The names of the relations will be a lowercase version of the - related class. In the case of one-to-many or many-to-many, the - name will be *pluralized*, which currently is based on the English - language (i.e. an 's' or 'es' added to it). - - NOTE: this method usually works poorly, and its usage is generally - not advised. + Used with ``query.options()``. """ - - table_to_mapper = {} - for item in classes_or_mappers: - if isinstance(item, Mapper): - m = item - else: - klass = item - m = class_mapper(klass) - table_to_mapper[m.mapped_table] = m - - def pluralize(name): - # oh crap, do we need locale stuff now - if name[-1] == 's': - return name + "es" - else: - return name + "s" - - for table,mapper in table_to_mapper.iteritems(): - for fk in table.foreign_keys: - if fk.column.table is table: - continue - secondary = None - try: - m2 = table_to_mapper[fk.column.table] - except KeyError: - if len(fk.column.table.primary_key): - continue - for sfk in fk.column.table.foreign_keys: - if sfk.column.table is table: - continue - m2 = table_to_mapper.get(sfk.column.table) - secondary = fk.column.table - if m2 is None: - continue - if secondary: - propname = pluralize(m2.class_.__name__.lower()) - propname2 = pluralize(mapper.class_.__name__.lower()) - else: - propname = m2.class_.__name__.lower() - propname2 = pluralize(mapper.class_.__name__.lower()) - mapper.add_property(propname, relation(m2, secondary=secondary, backref=propname2)) + return strategies.UndeferGroupOption(name) + diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 9f8a04db8..47ff26085 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -4,37 +4,67 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from sqlalchemy import util -from sqlalchemy.orm import util as orm_util -from sqlalchemy import logging, exceptions import weakref -class InstrumentedAttribute(object): - """A property object that instruments attribute access on object instances. +from sqlalchemy import util +from sqlalchemy.orm import util as orm_util, interfaces, collections +from sqlalchemy.orm.mapper import class_mapper +from sqlalchemy import logging, exceptions - All methods correspond to a single attribute on a particular - class. - """ - PASSIVE_NORESULT = object() +PASSIVE_NORESULT = object() +ATTR_WAS_SET = object() - def __init__(self, manager, key, uselist, callable_, typecallable, trackparent=False, extension=None, copy_function=None, compare_function=None, mutable_scalars=False, **kwargs): +class InstrumentedAttribute(interfaces.PropComparator): + """attribute access for instrumented classes.""" + + def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, compare_function=None, mutable_scalars=False, comparator=None, **kwargs): + """Construct an InstrumentedAttribute. + + class_ + the class to be instrumented. + + manager + AttributeManager managing this class + + key + string name of the attribute + + callable_ + optional function which generates a callable based on a parent + instance, which produces the "default" values for a scalar or + collection attribute when it's first accessed, if not present already. + + trackparent + if True, attempt to track if an instance has a parent attached to it + via this attribute + + extension + an AttributeExtension object which will receive + set/delete/append/remove/etc. events + + compare_function + a function that compares two values which are normally assignable to this + attribute + + mutable_scalars + if True, the values which are normally assignable to this attribute can mutate, + and need to be compared against a copy of their original contents in order to + detect changes on the parent instance + + comparator + a sql.Comparator to which class-level compare/math events will be sent + + """ + + self.class_ = class_ self.manager = manager self.key = key - self.uselist = uselist self.callable_ = callable_ - self.typecallable= typecallable self.trackparent = trackparent self.mutable_scalars = mutable_scalars - if copy_function is None: - if uselist: - self.copy = lambda x:[y for y in x] - else: - # scalar values are assumed to be immutable unless a copy function - # is passed - self.copy = lambda x:x - else: - self.copy = lambda x:copy_function(x) + self.comparator = comparator + self.copy = None if compare_function is None: self.is_equal = lambda x,y: x == y else: @@ -42,7 +72,7 @@ class InstrumentedAttribute(object): self.extensions = util.to_list(extension or []) def __set__(self, obj, value): - self.set(None, obj, value) + self.set(obj, value, None) def __delete__(self, obj): self.delete(None, obj) @@ -52,17 +82,18 @@ class InstrumentedAttribute(object): return self return self.get(obj) - def check_mutable_modified(self, obj): - if self.mutable_scalars: - h = self.get_history(obj, passive=True) - if h is not None and h.is_modified(): - obj._state['modified'] = True - return True - else: - return False - else: - return False + def clause_element(self): + return self.comparator.clause_element() + + def expression_element(self): + return self.comparator.expression_element() + + def operate(self, op, other, **kwargs): + return op(self.comparator, other, **kwargs) + def reverse_operate(self, op, other, **kwargs): + return op(other, self.comparator, **kwargs) + def hasparent(self, item, optimistic=False): """Return the boolean value of a `hasparent` flag attached to the given item. @@ -98,8 +129,8 @@ class InstrumentedAttribute(object): # get the current state. this may trigger a lazy load if # passive is False. - current = self.get(obj, passive=passive, raiseerr=False) - if current is InstrumentedAttribute.PASSIVE_NORESULT: + current = self.get(obj, passive=passive) + if current is PASSIVE_NORESULT: return None return AttributeHistory(self, obj, current, passive=passive) @@ -123,6 +154,14 @@ class InstrumentedAttribute(object): else: obj._state[('callable', self)] = callable_ + def _get_callable(self, obj): + if ('callable', self) in obj._state: + return obj._state[('callable', self)] + elif self.callable_ is not None: + return self.callable_(obj) + else: + return None + def reset(self, obj): """Remove any per-instance callable functions corresponding to this ``InstrumentedAttribute``'s attribute from the given @@ -148,43 +187,21 @@ class InstrumentedAttribute(object): except KeyError: pass - def _get_callable(self, obj): - if obj._state.has_key(('callable', self)): - return obj._state[('callable', self)] - elif self.callable_ is not None: - return self.callable_(obj) - else: - return None - - def _blank_list(self): - if self.typecallable is not None: - return self.typecallable() - else: - return [] + def check_mutable_modified(self, obj): + return False def initialize(self, obj): - """Initialize this attribute on the given object instance. + """Initialize this attribute on the given object instance with an empty value.""" - If this is a list-based attribute, a new, blank list will be - created. if a scalar attribute, the value will be initialized - to None. - """ - - if self.uselist: - l = InstrumentedList(self, obj, self._blank_list()) - obj.__dict__[self.key] = l - return l - else: - obj.__dict__[self.key] = None - return None + obj.__dict__[self.key] = None + return None - def get(self, obj, passive=False, raiseerr=True): + def get(self, obj, passive=False): """Retrieve a value from the given object. If a callable is assembled on this object's attribute, and passive is False, the callable will be executed and the - resulting value will be set as the new value for this - attribute. + resulting value will be set as the new value for this attribute. """ try: @@ -193,441 +210,301 @@ class InstrumentedAttribute(object): state = obj._state # if an instance-wide "trigger" was set, call that # and start again - if state.has_key('trigger'): + if 'trigger' in state: trig = state['trigger'] del state['trigger'] trig() - return self.get(obj, passive=passive, raiseerr=raiseerr) - - if self.uselist: - callable_ = self._get_callable(obj) - if callable_ is not None: - if passive: - return InstrumentedAttribute.PASSIVE_NORESULT - self.logger.debug("Executing lazy callable on %s.%s" % (orm_util.instance_str(obj), self.key)) - values = callable_() - l = InstrumentedList(self, obj, values, init=False) - - # if a callable was executed, then its part of the "committed state" - # if any, so commit the newly loaded data - orig = state.get('original', None) - if orig is not None: - orig.commit_attribute(self, obj, l) - + return self.get(obj, passive=passive) + + callable_ = self._get_callable(obj) + if callable_ is not None: + if passive: + return PASSIVE_NORESULT + self.logger.debug("Executing lazy callable on %s.%s" % + (orm_util.instance_str(obj), self.key)) + value = callable_() + if value is not ATTR_WAS_SET: + return self.set_committed_value(obj, value) else: - # note that we arent raising AttributeErrors, just creating a new - # blank list and setting it. - # this might be a good thing to be changeable by options. - l = InstrumentedList(self, obj, self._blank_list(), init=False) - obj.__dict__[self.key] = l - return l - else: - callable_ = self._get_callable(obj) - if callable_ is not None: - if passive: - return InstrumentedAttribute.PASSIVE_NORESULT - self.logger.debug("Executing lazy callable on %s.%s" % (orm_util.instance_str(obj), self.key)) - value = callable_() - obj.__dict__[self.key] = value - - # if a callable was executed, then its part of the "committed state" - # if any, so commit the newly loaded data - orig = state.get('original', None) - if orig is not None: - orig.commit_attribute(self, obj) - return value - else: - # note that we arent raising AttributeErrors, just returning None. - # this might be a good thing to be changeable by options. - return None - - def set(self, event, obj, value): - """Set a value on the given object. - - `event` is the ``InstrumentedAttribute`` that initiated the - ``set()` operation and is used to control the depth of a - circular setter operation. - """ - - if event is not self: - state = obj._state - # if an instance-wide "trigger" was set, call that - if state.has_key('trigger'): - trig = state['trigger'] - del state['trigger'] - trig() - if self.uselist: - value = InstrumentedList(self, obj, value) - old = self.get(obj) - obj.__dict__[self.key] = value - state['modified'] = True - if not self.uselist: - if self.trackparent: - if value is not None: - self.sethasparent(value, True) - if old is not None: - self.sethasparent(old, False) - for ext in self.extensions: - ext.set(event or self, obj, value, old) + return obj.__dict__[self.key] else: - # mark all the old elements as detached from the parent - old.list_replaced() + # Return a new, empty value + return self.initialize(obj) - def delete(self, event, obj): - """Delete a value from the given object. + def append(self, obj, value, initiator): + self.set(obj, value, initiator) - `event` is the ``InstrumentedAttribute`` that initiated the - ``delete()`` operation and is used to control the depth of a - circular delete operation. - """ - - if event is not self: - try: - if not self.uselist and (self.trackparent or len(self.extensions)): - old = self.get(obj) - del obj.__dict__[self.key] - except KeyError: - # TODO: raise this? not consistent with get() ? - raise AttributeError(self.key) - obj._state['modified'] = True - if not self.uselist: - if self.trackparent: - if old is not None: - self.sethasparent(old, False) - for ext in self.extensions: - ext.delete(event or self, obj, old) - - def append(self, event, obj, value): - """Append an element to a list based element or sets a scalar - based element to the given value. - - Used by ``GenericBackrefExtension`` to *append* an item - independent of list/scalar semantics. - - `event` is the ``InstrumentedAttribute`` that initiated the - ``append()`` operation and is used to control the depth of a - circular append operation. - """ + def remove(self, obj, value, initiator): + self.set(obj, None, initiator) - if self.uselist: - if event is not self: - self.get(obj).append_with_event(value, event) - else: - self.set(event, obj, value) - - def remove(self, event, obj, value): - """Remove an element from a list based element or sets a - scalar based element to None. - - Used by ``GenericBackrefExtension`` to *remove* an item - independent of list/scalar semantics. + def set(self, obj, value, initiator): + raise NotImplementedError() - `event` is the ``InstrumentedAttribute`` that initiated the - ``remove()`` operation and is used to control the depth of a - circular remove operation. + def set_committed_value(self, obj, value): + """set an attribute value on the given instance and 'commit' it. + + this indicates that the given value is the "persisted" value, + and history will be logged only if a newly set value is not + equal to this value. + + this is typically used by deferred/lazy attribute loaders + to set object attributes after the initial load. """ - if self.uselist: - if event is not self: - self.get(obj).remove_with_event(value, event) - else: - self.set(event, obj, None) + state = obj._state + orig = state.get('original', None) + if orig is not None: + orig.commit_attribute(self, obj, value) + # remove per-instance callable, if any + state.pop(('callable', self), None) + obj.__dict__[self.key] = value + return value - def append_event(self, event, obj, value): - """Called by ``InstrumentedList`` when an item is appended.""" + def set_raw_value(self, obj, value): + obj.__dict__[self.key] = value + return value + def fire_append_event(self, obj, value, initiator): obj._state['modified'] = True if self.trackparent and value is not None: self.sethasparent(value, True) for ext in self.extensions: - ext.append(event or self, obj, value) - - def remove_event(self, event, obj, value): - """Called by ``InstrumentedList`` when an item is removed.""" + ext.append(obj, value, initiator or self) + def fire_remove_event(self, obj, value, initiator): obj._state['modified'] = True if self.trackparent and value is not None: self.sethasparent(value, False) for ext in self.extensions: - ext.delete(event or self, obj, value) + ext.remove(obj, value, initiator or self) + + def fire_replace_event(self, obj, value, previous, initiator): + obj._state['modified'] = True + if self.trackparent: + if value is not None: + self.sethasparent(value, True) + if previous is not None: + self.sethasparent(previous, False) + for ext in self.extensions: + ext.set(obj, value, previous, initiator or self) + + property = property(lambda s: class_mapper(s.class_).get_property(s.key), + doc="the MapperProperty object associated with this attribute") InstrumentedAttribute.logger = logging.class_logger(InstrumentedAttribute) + +class InstrumentedScalarAttribute(InstrumentedAttribute): + """represents a scalar-holding InstrumentedAttribute.""" -class InstrumentedList(object): - """Instrument a list-based attribute. - - All mutator operations (i.e. append, remove, etc.) will fire off - events to the ``InstrumentedAttribute`` that manages the object's - attribute. Those events in turn trigger things like backref - operations and whatever is implemented by - ``do_list_value_changed`` on ``InstrumentedAttribute``. - - Note that this list does a lot less than earlier versions of SA - list-based attributes, which used ``HistoryArraySet``. This list - wrapper does **not** maintain setlike semantics, meaning you can add - as many duplicates as you want (which can break a lot of SQL), and - also does not do anything related to history tracking. - - Please see ticket #213 for information on the future of this - class, where it will be broken out into more collection-specific - subtypes. - """ + def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, copy_function=None, compare_function=None, mutable_scalars=False, **kwargs): + super(InstrumentedScalarAttribute, self).__init__(class_, manager, key, + callable_, trackparent=trackparent, extension=extension, + compare_function=compare_function, **kwargs) + self.mutable_scalars = mutable_scalars - def __init__(self, attr, obj, data, init=True): - self.attr = attr - # this weakref is to prevent circular references between the parent object - # and the list attribute, which interferes with immediate garbage collection. - self.__obj = weakref.ref(obj) - self.key = attr.key - - # adapt to lists or sets - # TODO: make three subclasses of InstrumentedList that come off from a - # metaclass, based on the type of data sent in - if attr.typecallable is not None: - self.data = attr.typecallable() - else: - self.data = data or attr._blank_list() - - if isinstance(self.data, list): - self._data_appender = self.data.append - self._clear_data = self._clear_list - elif isinstance(self.data, util.Set): - self._data_appender = self.data.add - self._clear_data = self._clear_set - elif isinstance(self.data, dict): - if hasattr(self.data, 'append'): - self._data_appender = self.data.append - else: - raise exceptions.ArgumentError("Dictionary collection class '%s' must implement an append() method" % type(self.data).__name__) - self._clear_data = self._clear_dict - else: - if hasattr(self.data, 'append'): - self._data_appender = self.data.append - elif hasattr(self.data, 'add'): - self._data_appender = self.data.add - else: - raise exceptions.ArgumentError("Collection class '%s' is not of type 'list', 'set', or 'dict' and has no append() or add() method" % type(self.data).__name__) + if copy_function is None: + copy_function = self.__copy + self.copy = copy_function - if hasattr(self.data, 'clear'): - self._clear_data = self._clear_set - else: - raise exceptions.ArgumentError("Collection class '%s' is not of type 'list', 'set', or 'dict' and has no clear() method" % type(self.data).__name__) - - if data is not None and data is not self.data: - for elem in data: - self._data_appender(elem) - + def __copy(self, item): + # scalar values are assumed to be immutable unless a copy function + # is passed + return item - if init: - for x in self.data: - self.__setrecord(x) + def __delete__(self, obj): + old = self.get(obj) + del obj.__dict__[self.key] + self.fire_remove_event(obj, old, self) - def list_replaced(self): - """Fire off delete event handlers for each item in the list - but doesnt affect the original data list. - """ + def check_mutable_modified(self, obj): + if self.mutable_scalars: + h = self.get_history(obj, passive=True) + if h is not None and h.is_modified(): + obj._state['modified'] = True + return True + else: + return False + else: + return False - [self.__delrecord(x) for x in self.data] + def set(self, obj, value, initiator): + """Set a value on the given object. - def clear(self): - """Clear all items in this InstrumentedList and fires off - delete event handlers for each item. + `initiator` is the ``InstrumentedAttribute`` that initiated the + ``set()` operation and is used to control the depth of a circular + setter operation. """ - self._clear_data() - - def _clear_dict(self): - [self.__delrecord(x) for x in self.data.values()] - self.data.clear() - - def _clear_set(self): - [self.__delrecord(x) for x in self.data] - self.data.clear() - - def _clear_list(self): - self[:] = [] - - def __getstate__(self): - """Implemented to allow pickling, since `__obj` is a weakref, - also the ``InstrumentedAttribute`` has callables attached to - it. - """ + if initiator is self: + return - return {'key':self.key, 'obj':self.obj, 'data':self.data} + state = obj._state + # if an instance-wide "trigger" was set, call that + if 'trigger' in state: + trig = state['trigger'] + del state['trigger'] + trig() - def __setstate__(self, d): - """Implemented to allow pickling, since `__obj` is a weakref, - also the ``InstrumentedAttribute`` has callables attached to it. - """ + old = self.get(obj) + obj.__dict__[self.key] = value + self.fire_replace_event(obj, value, old, initiator) - self.key = d['key'] - self.__obj = weakref.ref(d['obj']) - self.data = d['data'] - self.attr = getattr(d['obj'].__class__, self.key) + type = property(lambda self: self.property.columns[0].type) - obj = property(lambda s:s.__obj()) + +class InstrumentedCollectionAttribute(InstrumentedAttribute): + """A collection-holding attribute that instruments changes in membership. - def unchanged_items(self): - """Deprecated.""" + InstrumentedCollectionAttribute holds an arbitrary, user-specified + container object (defaulting to a list) and brokers access to the + CollectionAdapter, a "view" onto that object that presents consistent + bag semantics to the orm layer independent of the user data implementation. + """ + + def __init__(self, class_, manager, key, callable_, typecallable=None, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs): + super(InstrumentedCollectionAttribute, self).__init__(class_, manager, + key, callable_, trackparent=trackparent, extension=extension, + compare_function=compare_function, **kwargs) - return self.attr.get_history(self.obj).unchanged_items + if copy_function is None: + copy_function = self.__copy + self.copy = copy_function - def added_items(self): - """Deprecated.""" + if typecallable is None: + typecallable = list + self.collection_factory = \ + collections._prepare_instrumentation(typecallable) + self.collection_interface = \ + util.duck_type_collection(self.collection_factory()) - return self.attr.get_history(self.obj).added_items + def __copy(self, item): + return [y for y in list(collections.collection_adapter(item))] - def deleted_items(self): - """Deprecated.""" + def __set__(self, obj, value): + """Replace the current collection with a new one.""" - return self.attr.get_history(self.obj).deleted_items + setting_type = util.duck_type_collection(value) - def __iter__(self): - return iter(self.data) + if value is None or setting_type != self.collection_interface: + raise exceptions.ArgumentError( + "Incompatible collection type on assignment: %s is not %s-like" % + (type(value).__name__, self.collection_interface.__name__)) - def __repr__(self): - return repr(self.data) + if hasattr(value, '_sa_adapter'): + self.set(obj, list(getattr(value, '_sa_adapter')), None) + elif setting_type == dict: + self.set(obj, value.values(), None) + else: + self.set(obj, value, None) - def __getattr__(self, attr): - """Proxy unknown methods and attributes to the underlying - data array. This allows custom list classes to be used. - """ + def __delete__(self, obj): + if self.key not in obj.__dict__: + return - return getattr(self.data, attr) + obj._state['modified'] = True - def __setrecord(self, item, event=None): - self.attr.append_event(event, self.obj, item) - return True + collection = self._get_collection(obj) + collection.clear_with_event() + del obj.__dict__[self.key] - def __delrecord(self, item, event=None): - self.attr.remove_event(event, self.obj, item) - return True + def initialize(self, obj): + """Initialize this attribute on the given object instance with an empty collection.""" - def append_with_event(self, item, event): - self.__setrecord(item, event) - self._data_appender(item) + _, user_data = self._build_collection(obj) + obj.__dict__[self.key] = user_data + return user_data - def append_without_event(self, item): - self._data_appender(item) + def append(self, obj, value, initiator): + if initiator is self: + return + collection = self._get_collection(obj) + collection.append_with_event(value, initiator) - def remove_with_event(self, item, event): - self.__delrecord(item, event) - self.data.remove(item) + def remove(self, obj, value, initiator): + if initiator is self: + return + collection = self._get_collection(obj) + collection.remove_with_event(value, initiator) - def append(self, item, _mapper_nohistory=False): - """Fire off dependent events, and appends the given item to the underlying list. + def set(self, obj, value, initiator): + """Set a value on the given object. - `_mapper_nohistory` is a backwards compatibility hack; call - ``append_without_event`` instead. + `initiator` is the ``InstrumentedAttribute`` that initiated the + ``set()` operation and is used to control the depth of a circular + setter operation. """ - if _mapper_nohistory: - self.append_without_event(item) - else: - self.__setrecord(item) - self._data_appender(item) - - def __getitem__(self, i): - return self.data[i] - - def __setitem__(self, i, item): - if isinstance(i, slice): - self.__setslice__(i.start, i.stop, item) - else: - self.__setrecord(item) - self.data[i] = item - - def __delitem__(self, i): - if isinstance(i, slice): - self.__delslice__(i.start, i.stop) - else: - self.__delrecord(self.data[i], None) - del self.data[i] - - def __lt__(self, other): return self.data < self.__cast(other) - - def __le__(self, other): return self.data <= self.__cast(other) + if initiator is self: + return - def __eq__(self, other): return self.data == self.__cast(other) + state = obj._state + # if an instance-wide "trigger" was set, call that + if 'trigger' in state: + trig = state['trigger'] + del state['trigger'] + trig() - def __ne__(self, other): return self.data != self.__cast(other) + old = self.get(obj) + old_collection = self._get_collection(obj, old) - def __gt__(self, other): return self.data > self.__cast(other) + new_collection, user_data = self._build_collection(obj) + self._load_collection(obj, value or [], emit_events=True, + collection=new_collection) - def __ge__(self, other): return self.data >= self.__cast(other) + obj.__dict__[self.key] = user_data + state['modified'] = True - def __cast(self, other): - if isinstance(other, InstrumentedList): return other.data - else: return other + # mark all the old elements as detached from the parent + if old_collection: + old_collection.clear_with_event() + old_collection.unlink(old) - def __cmp__(self, other): - return cmp(self.data, self.__cast(other)) + def set_committed_value(self, obj, value): + """Set an attribute value on the given instance and 'commit' it.""" + + state = obj._state + orig = state.get('original', None) - def __contains__(self, item): return item in self.data + collection, user_data = self._build_collection(obj) + self._load_collection(obj, value or [], emit_events=False, + collection=collection) + value = user_data - def __len__(self): + if orig is not None: + orig.commit_attribute(self, obj, value) + # remove per-instance callable, if any + state.pop(('callable', self), None) + obj.__dict__[self.key] = value + return value + + def _build_collection(self, obj): + user_data = self.collection_factory() + collection = collections.CollectionAdapter(self, obj, user_data) + return collection, user_data + + def _load_collection(self, obj, values, emit_events=True, collection=None): + collection = collection or self._get_collection(obj) + if values is None: + return + elif emit_events: + for item in values: + collection.append_with_event(item) + else: + for item in values: + collection.append_without_event(item) + + def _get_collection(self, obj, user_data=None): + if user_data is None: + user_data = self.get(obj) try: - return len(self.data) - except TypeError: - return len(list(self.data)) - - def __setslice__(self, i, j, other): - [self.__delrecord(x) for x in self.data[i:j]] - g = [a for a in list(other) if self.__setrecord(a)] - self.data[i:j] = g - - def __delslice__(self, i, j): - for a in self.data[i:j]: - self.__delrecord(a) - del self.data[i:j] - - def insert(self, i, item): - if self.__setrecord(item): - self.data.insert(i, item) - - def pop(self, i=-1): - item = self.data[i] - self.__delrecord(item) - return self.data.pop(i) - - def remove(self, item): - self.__delrecord(item) - self.data.remove(item) - - def discard(self, item): - if item in self.data: - self.__delrecord(item) - self.data.remove(item) - - def extend(self, item_list): - for item in item_list: - self.append(item) - - def __add__(self, other): - raise NotImplementedError() - - def __radd__(self, other): - raise NotImplementedError() - - def __iadd__(self, other): - raise NotImplementedError() - -class AttributeExtension(object): - """An abstract class which specifies `append`, `delete`, and `set` - event handlers to be attached to an object property. - """ - - def append(self, event, obj, child): - pass - - def delete(self, event, obj, child): - pass + return getattr(user_data, '_sa_adapter') + except AttributeError: + collections.CollectionAdapter(self, obj, user_data) + return getattr(user_data, '_sa_adapter') - def set(self, event, obj, child, oldchild): - pass -class GenericBackrefExtension(AttributeExtension): +class GenericBackrefExtension(interfaces.AttributeExtension): """An extension which synchronizes a two-way relationship. A typical two-way relationship is a parent object containing a @@ -639,19 +516,19 @@ class GenericBackrefExtension(AttributeExtension): def __init__(self, key): self.key = key - def set(self, event, obj, child, oldchild): + def set(self, obj, child, oldchild, initiator): if oldchild is child: return if oldchild is not None: - getattr(oldchild.__class__, self.key).remove(event, oldchild, obj) + getattr(oldchild.__class__, self.key).remove(oldchild, obj, initiator) if child is not None: - getattr(child.__class__, self.key).append(event, child, obj) + getattr(child.__class__, self.key).append(child, obj, initiator) - def append(self, event, obj, child): - getattr(child.__class__, self.key).append(event, child, obj) + def append(self, obj, child, initiator): + getattr(child.__class__, self.key).append(child, obj, initiator) - def delete(self, event, obj, child): - getattr(child.__class__, self.key).remove(event, child, obj) + def remove(self, obj, child, initiator): + getattr(child.__class__, self.key).remove(child, obj, initiator) class CommittedState(object): """Store the original state of an object when the ``commit()` @@ -673,7 +550,7 @@ class CommittedState(object): """ if value is CommittedState.NO_VALUE: - if obj.__dict__.has_key(attr.key): + if attr.key in obj.__dict__: value = obj.__dict__[attr.key] if value is not CommittedState.NO_VALUE: self.data[attr.key] = attr.copy(value) @@ -690,10 +567,13 @@ class CommittedState(object): def rollback(self, manager, obj): for attr in manager.managed_attributes(obj.__class__): if self.data.has_key(attr.key): - if attr.uselist: - obj.__dict__[attr.key][:] = self.data[attr.key] - else: + if not isinstance(attr, InstrumentedCollectionAttribute): obj.__dict__[attr.key] = self.data[attr.key] + else: + collection = attr._get_collection(obj) + collection.clear_without_event() + for item in self.data[attr.key]: + collection.append_without_event(item) else: del obj.__dict__[attr.key] @@ -718,17 +598,15 @@ class AttributeHistory(object): else: original = None - if attr.uselist: + if isinstance(attr, InstrumentedCollectionAttribute): self._current = current - else: - self._current = [current] - if attr.uselist: s = util.Set(original or []) self._added_items = [] self._unchanged_items = [] self._deleted_items = [] if current: - for a in current: + collection = attr._get_collection(obj, current) + for a in collection: if a in s: self._unchanged_items.append(a) else: @@ -737,6 +615,7 @@ class AttributeHistory(object): if a not in self._unchanged_items: self._deleted_items.append(a) else: + self._current = [current] if attr.is_equal(current, original): self._unchanged_items = [current] self._added_items = [] @@ -748,7 +627,6 @@ class AttributeHistory(object): else: self._deleted_items = [] self._unchanged_items = [] - #print "key", attr.key, "orig", original, "current", current, "added", self._added_items, "unchanged", self._unchanged_items, "deleted", self._deleted_items def __iter__(self): return iter(self._current) @@ -766,24 +644,13 @@ class AttributeHistory(object): return self._deleted_items def hasparent(self, obj): - """Deprecated. This should be called directly from the - appropriate ``InstrumentedAttribute`` object. + """Deprecated. This should be called directly from the appropriate ``InstrumentedAttribute`` object. """ return self.attr.hasparent(obj) class AttributeManager(object): - """Allow the instrumentation of object attributes. - - ``AttributeManager`` is stateless, but can be overridden by - subclasses to redefine some of its factory operations. Also be - aware ``AttributeManager`` will cache attributes for a given - class, allowing not to determine those for each objects (used in - ``managed_attributes()`` and - ``noninherited_managed_attributes()``). This cache is cleared for - a given class while calling ``register_attribute()``, and can be - cleared using ``clear_attribute_cache()``. - """ + """Allow the instrumentation of object attributes.""" def __init__(self): # will cache attributes, indexed by class objects @@ -827,7 +694,7 @@ class AttributeManager(object): o._state['modified'] = False def managed_attributes(self, class_): - """Return an iterator of all ``InstrumentedAttribute`` objects + """Return a list of all ``InstrumentedAttribute`` objects associated with the given class. """ @@ -878,7 +745,7 @@ class AttributeManager(object): """Return an attribute of the given name from the given object. If the attribute is a scalar, return it as a single-item list, - otherwise return the list based attribute. + otherwise return a collection based attribute. If the attribute's value is to be produced by an unexecuted callable, the callable will only be executed if the given @@ -887,10 +754,10 @@ class AttributeManager(object): attr = getattr(obj.__class__, key) x = attr.get(obj, passive=passive) - if x is InstrumentedAttribute.PASSIVE_NORESULT: + if x is PASSIVE_NORESULT: return [] - elif attr.uselist: - return x + elif isinstance(attr, InstrumentedCollectionAttribute): + return list(attr._get_collection(obj, x)) else: return [x] @@ -921,7 +788,7 @@ class AttributeManager(object): by ``trigger_history()``. """ - return obj._state.has_key('trigger') + return 'trigger' in obj._state def reset_instance_attribute(self, obj, key): """Remove any per-instance callable functions corresponding to @@ -946,10 +813,9 @@ class AttributeManager(object): """Return True if the given `key` correponds to an instrumented property on the given class. """ - return hasattr(class_, key) and isinstance(getattr(class_, key), InstrumentedAttribute) - def init_instance_attribute(self, obj, key, uselist, callable_=None, **kwargs): + def init_instance_attribute(self, obj, key, callable_=None): """Initialize an attribute on an instance to either a blank value, cancelling out any class- or instance-level callables that were present, or if a `callable` is supplied set the @@ -964,7 +830,24 @@ class AttributeManager(object): events back to this ``AttributeManager``. """ - return InstrumentedAttribute(self, key, uselist, callable_, typecallable, **kwargs) + if uselist: + return InstrumentedCollectionAttribute(class_, self, key, + callable_, + typecallable, + **kwargs) + else: + return InstrumentedScalarAttribute(class_, self, key, callable_, + **kwargs) + + def get_attribute(self, obj_or_cls, key): + """Register an attribute at the class level to be instrumented + for all instances of the class. + """ + + if isinstance(obj_or_cls, type): + return getattr(obj_or_cls, key) + else: + return getattr(obj_or_cls.__class__, key) def register_attribute(self, class_, key, uselist, callable_=None, **kwargs): """Register an attribute at the class level to be instrumented @@ -973,10 +856,9 @@ class AttributeManager(object): # firt invalidate the cache for the given class # (will be reconstituted as needed, while getting managed attributes) - self._inherited_attribute_cache.pop(class_,None) - self._noninherited_attribute_cache.pop(class_,None) + self._inherited_attribute_cache.pop(class_, None) + self._noninherited_attribute_cache.pop(class_, None) - #print self, "register attribute", key, "for class", class_ if not hasattr(class_, '_state'): def _get_state(self): if not hasattr(self, '_sa_attr_state'): @@ -987,4 +869,12 @@ class AttributeManager(object): typecallable = kwargs.pop('typecallable', None) if isinstance(typecallable, InstrumentedAttribute): typecallable = None - setattr(class_, key, self.create_prop(class_, key, uselist, callable_, typecallable=typecallable, **kwargs)) + setattr(class_, key, self.create_prop(class_, key, uselist, callable_, + typecallable=typecallable, **kwargs)) + + def init_collection(self, instance, key): + """Initialize a collection attribute and return the collection adapter.""" + + attr = self.get_attribute(instance, key) + user_data = attr.initialize(instance) + return attr._get_collection(instance, user_data) diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py new file mode 100644 index 000000000..7ade882f5 --- /dev/null +++ b/lib/sqlalchemy/orm/collections.py @@ -0,0 +1,1182 @@ +"""Support for collections of mapped entities. + +The collections package supplies the machinery used to inform the ORM of +collection membership changes. An instrumentation via decoration approach is +used, allowing arbitrary types (including built-ins) to be used as entity +collections without requiring inheritance from a base class. + +Instrumentation decoration relays membership change events to the +``InstrumentedCollectionAttribute`` that is currently managing the collection. +The decorators observe function call arguments and return values, tracking +entities entering or leaving the collection. Two decorator approaches are +provided. One is a bundle of generic decorators that map function arguments +and return values to events:: + + from sqlalchemy.orm.collections import collection + class MyClass(object): + # ... + + @collection.adds(1) + def store(self, item): + self.data.append(item) + + @collection.removes_return() + def pop(self): + return self.data.pop() + + +The second approach is a bundle of targeted decorators that wrap appropriate +append and remove notifiers around the mutation methods present in the +standard Python ``list``, ``set`` and ``dict`` interfaces. These could be +specified in terms of generic decorator recipes, but are instead hand-tooled for +increased efficiency. The targeted decorators occasionally implement +adapter-like behavior, such as mapping bulk-set methods (``extend``, ``update``, +``__setslice``, etc.) into the series of atomic mutation events that the ORM +requires. + +The targeted decorators are used internally for automatic instrumentation of +entity collection classes. Every collection class goes through a +transformation process roughly like so: + +1. If the class is a built-in, substitute a trivial sub-class +2. Is this class already instrumented? +3. Add in generic decorators +4. Sniff out the collection interface through duck-typing +5. Add targeted decoration to any undecorated interface method + +This process modifies the class at runtime, decorating methods and adding some +bookkeeping properties. This isn't possible (or desirable) for built-in +classes like ``list``, so trivial sub-classes are substituted to hold +decoration:: + + class InstrumentedList(list): + pass + +Collection classes can be specified in ``relation(collection_class=)`` as +types or a function that returns an instance. Collection classes are +inspected and instrumented during the mapper compilation phase. The +collection_class callable will be executed once to produce a specimen +instance, and the type of that specimen will be instrumented. Functions that +return built-in types like ``lists`` will be adapted to produce instrumented +instances. + +When extending a known type like ``list``, additional decorations are not +generally not needed. Odds are, the extension method will delegate to a +method that's already instrumented. For example:: + + class QueueIsh(list): + def push(self, item): + self.append(item) + def shift(self): + return self.pop(0) + +There's no need to decorate these methods. ``append`` and ``pop`` are already +instrumented as part of the ``list`` interface. Decorating them would fire +duplicate events, which should be avoided. + +The targeted decoration tries not to rely on other methods in the underlying +collection class, but some are unavoidable. Many depend on 'read' methods +being present to properly instrument a 'write', for example, ``__setitem__`` +needs ``__getitem__``. "Bulk" methods like ``update`` and ``extend`` may also +reimplemented in terms of atomic appends and removes, so the ``extend`` +decoration will actually perform many ``append`` operations and not call the +underlying method at all. + +Tight control over bulk operation and the firing of events is also possible by +implementing the instrumentation internally in your methods. The basic +instrumentation package works under the general assumption that collection +mutation will not raise unusual exceptions. If you want to closely +orchestrate append and remove events with exception management, internal +instrumentation may be the answer. Within your method, +``collection_adapter(self)`` will retrieve an object that you can use for +explicit control over triggering append and remove events. + +The owning object and InstrumentedCollectionAttribute are also reachable +through the adapter, allowing for some very sophisticated behavior. +""" + +import copy, inspect, sys, weakref + +from sqlalchemy import exceptions, schema, util as sautil +from sqlalchemy.orm import mapper + +try: + from threading import Lock +except: + from dummy_threading import Lock +try: + from operator import attrgetter +except: + def attrgetter(attribute): + return lambda value: getattr(value, attribute) + + +__all__ = ['collection', 'collection_adapter', + 'mapped_collection', 'column_mapped_collection', + 'attribute_mapped_collection'] + +def column_mapped_collection(mapping_spec): + """A dictionary-based collection type with column-based keying. + + Returns a MappedCollection factory with a keying function generated + from mapping_spec, which may be a Column or a sequence of Columns. + + The key value must be immutable for the lifetime of the object. You + can not, for example, map on foreign key values if those key values will + change during the session, i.e. from None to a database-assigned integer + after a session flush. + """ + + if isinstance(mapping_spec, schema.Column): + def keyfunc(value): + m = mapper.object_mapper(value) + return m.get_attr_by_column(value, mapping_spec) + else: + cols = [] + for c in mapping_spec: + if not isinstance(c, schema.Column): + raise exceptions.ArgumentError( + "mapping_spec tuple may only contain columns") + cols.append(c) + mapping_spec = tuple(cols) + def keyfunc(value): + m = mapper.object_mapper(value) + return tuple([m.get_attr_by_column(value, c) for c in mapping_spec]) + return lambda: MappedCollection(keyfunc) + +def attribute_mapped_collection(attr_name): + """A dictionary-based collection type with attribute-based keying. + + Returns a MappedCollection factory with a keying based on the + 'attr_name' attribute of entities in the collection. + + The key value must be immutable for the lifetime of the object. You + can not, for example, map on foreign key values if those key values will + change during the session, i.e. from None to a database-assigned integer + after a session flush. + """ + + return lambda: MappedCollection(attrgetter(attr_name)) + + +def mapped_collection(keyfunc): + """A dictionary-based collection type with arbitrary keying. + + Returns a MappedCollection factory with a keying function generated + from keyfunc, a callable that takes an entity and returns a key value. + + The key value must be immutable for the lifetime of the object. You + can not, for example, map on foreign key values if those key values will + change during the session, i.e. from None to a database-assigned integer + after a session flush. + """ + + return lambda: MappedCollection(keyfunc) + +class collection(object): + """Decorators for entity collection classes. + + The decorators fall into two groups: annotations and interception recipes. + + The annotating decorators (appender, remover, iterator, + internally_instrumented, on_link) indicate the method's purpose and take no + arguments. They are not written with parens:: + + @collection.appender + def append(self, append): ... + + The recipe decorators all require parens, even those that take no + arguments:: + + @collection.adds('entity'): + def insert(self, position, entity): ... + + @collection.removes_return() + def popitem(self): ... + + Decorators can be specified in long-hand for Python 2.3, or with + the class-level dict attribute '__instrumentation__'- see the source + for details. + """ + + # Bundled as a class solely for ease of use: packaging, doc strings, + # importability. + + def appender(cls, fn): + """Tag the method as the collection appender. + + The appender method is called with one positional argument: the value + to append. The method will be automatically decorated with 'adds(1)' + if not already decorated:: + + @collection.appender + def add(self, append): ... + + # or, equivalently + @collection.appender + @collection.adds(1) + def add(self, append): ... + + # for mapping type, an 'append' may kick out a previous value + # that occupies that slot. consider d['a'] = 'foo'- any previous + # value in d['a'] is discarded. + @collection.appender + @collection.replaces(1) + def add(self, entity): + key = some_key_func(entity) + previous = None + if key in self: + previous = self[key] + self[key] = entity + return previous + + If the value to append is not allowed in the collection, you may + raise an exception. Something to remember is that the appender + will be called for each object mapped by a database query. If the + database contains rows that violate your collection semantics, you + will need to get creative to fix the problem, as access via the + collection will not work. + + If the appender method is internally instrumented, you must also + receive the keyword argument '_sa_initiator' and ensure its + promulgation to collection events. + """ + + setattr(fn, '_sa_instrument_role', 'appender') + return fn + appender = classmethod(appender) + + def remover(cls, fn): + """Tag the method as the collection remover. + + The remover method is called with one positional argument: the value + to remove. The method will be automatically decorated with + 'removes_return()' if not already decorated:: + + @collection.remover + def zap(self, entity): ... + + # or, equivalently + @collection.remover + @collection.removes_return() + def zap(self, ): ... + + If the value to remove is not present in the collection, you may + raise an exception or return None to ignore the error. + + If the remove method is internally instrumented, you must also + receive the keyword argument '_sa_initiator' and ensure its + promulgation to collection events. + """ + + setattr(fn, '_sa_instrument_role', 'remover') + return fn + remover = classmethod(remover) + + def iterator(cls, fn): + """Tag the method as the collection remover. + + The iterator method is called with no arguments. It is expected to + return an iterator over all collection members:: + + @collection.iterator + def __iter__(self): ... + """ + + setattr(fn, '_sa_instrument_role', 'iterator') + return fn + iterator = classmethod(iterator) + + def internally_instrumented(cls, fn): + """Tag the method as instrumented. + + This tag will prevent any decoration from being applied to the method. + Use this if you are orchestrating your own calls to collection_adapter + in one of the basic SQLAlchemy interface methods, or to prevent + an automatic ABC method decoration from wrapping your implementation:: + + # normally an 'extend' method on a list-like class would be + # automatically intercepted and re-implemented in terms of + # SQLAlchemy events and append(). your implementation will + # never be called, unless: + @collection.internally_instrumented + def extend(self, items): ... + """ + + setattr(fn, '_sa_instrumented', True) + return fn + internally_instrumented = classmethod(internally_instrumented) + + def on_link(cls, fn): + """Tag the method as a the "linked to attribute" event handler. + + This optional event handler will be called when the collection class + is linked to or unlinked from the InstrumentedAttribute. It is + invoked immediately after the '_sa_adapter' property is set on + the instance. A single argument is passed: the collection adapter + that has been linked, or None if unlinking. + """ + + setattr(fn, '_sa_instrument_role', 'on_link') + return fn + on_link = classmethod(on_link) + + def adds(cls, arg): + """Mark the method as adding an entity to the collection. + + Adds "add to collection" handling to the method. The decorator argument + indicates which method argument holds the SQLAlchemy-relevant value. + Arguments can be specified positionally (i.e. integer) or by name:: + + @collection.adds(1) + def push(self, item): ... + + @collection.adds('entity') + def do_stuff(self, thing, entity=None): ... + """ + + def decorator(fn): + setattr(fn, '_sa_instrument_before', ('fire_append_event', arg)) + return fn + return decorator + adds = classmethod(adds) + + def replaces(cls, arg): + """Mark the method as replacing an entity in the collection. + + Adds "add to collection" and "remove from collection" handling to + the method. The decorator argument indicates which method argument + holds the SQLAlchemy-relevant value to be added, and return value, if + any will be considered the value to remove. + + Arguments can be specified positionally (i.e. integer) or by name:: + + @collection.replaces(2) + def __setitem__(self, index, item): ... + """ + + def decorator(fn): + setattr(fn, '_sa_instrument_before', ('fire_append_event', arg)) + setattr(fn, '_sa_instrument_after', 'fire_remove_event') + return fn + return decorator + replaces = classmethod(replaces) + + def removes(cls, arg): + """Mark the method as removing an entity in the collection. + + Adds "remove from collection" handling to the method. The decorator + argument indicates which method argument holds the SQLAlchemy-relevant + value to be removed. Arguments can be specified positionally (i.e. + integer) or by name:: + + @collection.removes(1) + def zap(self, item): ... + + For methods where the value to remove is not known at call-time, use + collection.removes_return. + """ + + def decorator(fn): + setattr(fn, '_sa_instrument_before', ('fire_remove_event', arg)) + return fn + return decorator + removes = classmethod(removes) + + def removes_return(cls): + """Mark the method as removing an entity in the collection. + + Adds "remove from collection" handling to the method. The return value + of the method, if any, is considered the value to remove. The method + arguments are not inspected:: + + @collection.removes_return() + def pop(self): ... + + For methods where the value to remove is known at call-time, use + collection.remove. + """ + + def decorator(fn): + setattr(fn, '_sa_instrument_after', 'fire_remove_event') + return fn + return decorator + removes_return = classmethod(removes_return) + + +# public instrumentation interface for 'internally instrumented' +# implementations +def collection_adapter(collection): + """Fetch the CollectionAdapter for a collection.""" + + return getattr(collection, '_sa_adapter', None) + +class CollectionAdapter(object): + """Bridges between the ORM and arbitrary Python collections. + + Proxies base-level collection operations (append, remove, iterate) + to the underlying Python collection, and emits add/remove events for + entities entering or leaving the collection. + + The ORM uses an CollectionAdapter exclusively for interaction with + entity collections. + """ + + def __init__(self, attr, owner, data): + self.attr = attr + self._owner = weakref.ref(owner) + self._data = weakref.ref(data) + self.link_to_self(data) + + owner = property(lambda s: s._owner(), + doc="The object that owns the entity collection.") + data = property(lambda s: s._data(), + doc="The entity collection being adapted.") + + def link_to_self(self, data): + """Link a collection to this adapter, and fire a link event.""" + + setattr(data, '_sa_adapter', self) + if hasattr(data, '_sa_on_link'): + getattr(data, '_sa_on_link')(self) + + def unlink(self, data): + """Unlink a collection from any adapter, and fire a link event.""" + + setattr(data, '_sa_adapter', None) + if hasattr(data, '_sa_on_link'): + getattr(data, '_sa_on_link')(None) + + def append_with_event(self, item, initiator=None): + """Add an entity to the collection, firing mutation events.""" + + getattr(self._data(), '_sa_appender')(item, _sa_initiator=initiator) + + def append_without_event(self, item): + """Add or restore an entity to the collection, firing no events.""" + + getattr(self._data(), '_sa_appender')(item, _sa_initiator=False) + + def remove_with_event(self, item, initiator=None): + """Remove an entity from the collection, firing mutation events.""" + + getattr(self._data(), '_sa_remover')(item, _sa_initiator=initiator) + + def remove_without_event(self, item): + """Remove an entity from the collection, firing no events.""" + + getattr(self._data(), '_sa_remover')(item, _sa_initiator=False) + + def clear_with_event(self, initiator=None): + """Empty the collection, firing a mutation event for each entity.""" + + for item in list(self): + self.remove_with_event(item, initiator) + + def clear_without_event(self): + """Empty the collection, firing no events.""" + + for item in list(self): + self.remove_without_event(item) + + def __iter__(self): + """Iterate over entities in the collection.""" + + return getattr(self._data(), '_sa_iterator')() + + def __len__(self): + """Count entities in the collection.""" + + return len(list(getattr(self._data(), '_sa_iterator')())) + + def __nonzero__(self): + return True + + def fire_append_event(self, item, initiator=None): + """Notify that a entity has entered the collection. + + Initiator is the InstrumentedAttribute that initiated the membership + mutation, and should be left as None unless you are passing along + an initiator value from a chained operation. + """ + + if initiator is not False and item is not None: + self.attr.fire_append_event(self._owner(), item, initiator) + + def fire_remove_event(self, item, initiator=None): + """Notify that a entity has entered the collection. + + Initiator is the InstrumentedAttribute that initiated the membership + mutation, and should be left as None unless you are passing along + an initiator value from a chained operation. + """ + + if initiator is not False and item is not None: + self.attr.fire_remove_event(self._owner(), item, initiator) + + def __getstate__(self): + return { 'key': self.attr.key, + 'owner': self.owner, + 'data': self.data } + + def __setstate__(self, d): + self.attr = getattr(d['owner'].__class__, d['key']) + self._owner = weakref.ref(d['owner']) + self._data = weakref.ref(d['data']) + + +__instrumentation_mutex = Lock() +def _prepare_instrumentation(factory): + """Prepare a callable for future use as a collection class factory. + + Given a collection class factory (either a type or no-arg callable), + return another factory that will produce compatible instances when + called. + + This function is responsible for converting collection_class=list + into the run-time behavior of collection_class=InstrumentedList. + """ + + # Convert a builtin to 'Instrumented*' + if factory in __canned_instrumentation: + factory = __canned_instrumentation[factory] + + # Create a specimen + cls = type(factory()) + + # Did factory callable return a builtin? + if cls in __canned_instrumentation: + # Wrap it so that it returns our 'Instrumented*' + factory = __converting_factory(factory) + cls = factory() + + # Instrument the class if needed. + if __instrumentation_mutex.acquire(): + try: + if getattr(cls, '_sa_instrumented', None) != id(cls): + _instrument_class(cls) + finally: + __instrumentation_mutex.release() + + return factory + +def __converting_factory(original_factory): + """Convert the type returned by collection factories on the fly. + + Given a collection factory that returns a builtin type (e.g. a list), + return a wrapped function that converts that type to one of our + instrumented types. + """ + + def wrapper(): + collection = original_factory() + type_ = type(collection) + if type_ in __canned_instrumentation: + # return an instrumented type initialized from the factory's + # collection + return __canned_instrumentation[type_](collection) + else: + raise exceptions.InvalidRequestError( + "Collection class factories must produce instances of a " + "single class.") + try: + # often flawed but better than nothing + wrapper.__name__ = "%sWrapper" % original_factory.__name__ + wrapper.__doc__ = original_factory.__doc__ + except: + pass + return wrapper + +def _instrument_class(cls): + """Modify methods in a class and install instrumentation.""" + + # FIXME: more formally document this as a decoratorless/Python 2.3 + # option for specifying instrumentation. (likely doc'd here in code only, + # not in online docs.) + # + # __instrumentation__ = { + # 'rolename': 'methodname', # ... + # 'methods': { + # 'methodname': ('fire_{append,remove}_event', argspec, + # 'fire_{append,remove}_event'), + # 'append': ('fire_append_event', 1, None), + # '__setitem__': ('fire_append_event', 1, 'fire_remove_event'), + # 'pop': (None, None, 'fire_remove_event'), + # } + # } + + # In the normal call flow, a request for any of the 3 basic collection + # types is transformed into one of our trivial subclasses + # (e.g. InstrumentedList). Catch anything else that sneaks in here... + if cls.__module__ == '__builtin__': + raise exceptions.ArgumentError( + "Can not instrument a built-in type. Use a " + "subclass, even a trivial one.") + + collection_type = sautil.duck_type_collection(cls) + if collection_type in __interfaces: + roles = __interfaces[collection_type].copy() + decorators = roles.pop('_decorators', {}) + else: + roles, decorators = {}, {} + + if hasattr(cls, '__instrumentation__'): + roles.update(copy.deepcopy(getattr(cls, '__instrumentation__'))) + + methods = roles.pop('methods', {}) + + for name in dir(cls): + method = getattr(cls, name) + if not callable(method): + continue + + # note role declarations + if hasattr(method, '_sa_instrument_role'): + role = method._sa_instrument_role + assert role in ('appender', 'remover', 'iterator', 'on_link') + roles[role] = name + + # transfer instrumentation requests from decorated function + # to the combined queue + before, after = None, None + if hasattr(method, '_sa_instrument_before'): + op, argument = method._sa_instrument_before + assert op in ('fire_append_event', 'fire_remove_event') + before = op, argument + if hasattr(method, '_sa_instrument_after'): + op = method._sa_instrument_after + assert op in ('fire_append_event', 'fire_remove_event') + after = op + if before: + methods[name] = before[0], before[1], after + elif after: + methods[name] = None, None, after + + # apply ABC auto-decoration to methods that need it + for method, decorator in decorators.items(): + fn = getattr(cls, method, None) + if fn and method not in methods and not hasattr(fn, '_sa_instrumented'): + setattr(cls, method, decorator(fn)) + + # ensure all roles are present, and apply implicit instrumentation if + # needed + if 'appender' not in roles or not hasattr(cls, roles['appender']): + raise exceptions.ArgumentError( + "Type %s must elect an appender method to be " + "a collection class" % cls.__name__) + elif (roles['appender'] not in methods and + not hasattr(getattr(cls, roles['appender']), '_sa_instrumented')): + methods[roles['appender']] = ('fire_append_event', 1, None) + + if 'remover' not in roles or not hasattr(cls, roles['remover']): + raise exceptions.ArgumentError( + "Type %s must elect a remover method to be " + "a collection class" % cls.__name__) + elif (roles['remover'] not in methods and + not hasattr(getattr(cls, roles['remover']), '_sa_instrumented')): + methods[roles['remover']] = ('fire_remove_event', 1, None) + + if 'iterator' not in roles or not hasattr(cls, roles['iterator']): + raise exceptions.ArgumentError( + "Type %s must elect an iterator method to be " + "a collection class" % cls.__name__) + + # apply ad-hoc instrumentation from decorators, class-level defaults + # and implicit role declarations + for method, (before, argument, after) in methods.items(): + setattr(cls, method, + _instrument_membership_mutator(getattr(cls, method), + before, argument, after)) + # intern the role map + for role, method in roles.items(): + setattr(cls, '_sa_%s' % role, getattr(cls, method)) + + setattr(cls, '_sa_instrumented', id(cls)) + +def _instrument_membership_mutator(method, before, argument, after): + """Route method args and/or return value through the collection adapter.""" + + if type(argument) is int: + def wrapper(*args, **kw): + if before and len(args) < argument: + raise exceptions.ArgumentError( + 'Missing argument %i' % argument) + initiator = kw.pop('_sa_initiator', None) + if initiator is False: + executor = None + else: + executor = getattr(args[0], '_sa_adapter', None) + + if before and executor: + getattr(executor, before)(args[argument], initiator) + + if not after or not executor: + return method(*args, **kw) + else: + res = method(*args, **kw) + if res is not None: + getattr(executor, after)(res, initiator) + return res + else: + def wrapper(*args, **kw): + if before: + vals = inspect.getargvalues(inspect.currentframe()) + if argument in kw: + value = kw[argument] + else: + positional = inspect.getargspec(method)[0] + pos = positional.index(argument) + if pos == -1: + raise exceptions.ArgumentError('Missing argument %s' % + argument) + else: + value = args[pos] + + initiator = kw.pop('_sa_initiator', None) + if initiator is False: + executor = None + else: + executor = getattr(args[0], '_sa_adapter', None) + + if before and executor: + getattr(executor, before)(value, initiator) + + if not after or not executor: + return method(*args, **kw) + else: + res = method(*args, **kw) + if res is not None: + getattr(executor, after)(res, initiator) + return res + try: + wrapper._sa_instrumented = True + wrapper.__name__ = method.__name__ + wrapper.__doc__ = method.__doc__ + except: + pass + return wrapper + +def __set(collection, item, _sa_initiator=None): + """Run set events, may eventually be inlined into decorators.""" + + if _sa_initiator is not False and item is not None: + executor = getattr(collection, '_sa_adapter', None) + if executor: + getattr(executor, 'fire_append_event')(item, _sa_initiator) + +def __del(collection, item, _sa_initiator=None): + """Run del events, may eventually be inlined into decorators.""" + + if _sa_initiator is not False and item is not None: + executor = getattr(collection, '_sa_adapter', None) + if executor: + getattr(executor, 'fire_remove_event')(item, _sa_initiator) + +def _list_decorators(): + """Hand-turned instrumentation wrappers that can decorate any list-like + class.""" + + def _tidy(fn): + setattr(fn, '_sa_instrumented', True) + fn.__doc__ = getattr(getattr(list, fn.__name__), '__doc__') + + def append(fn): + def append(self, item, _sa_initiator=None): + # FIXME: example of fully inlining __set and adapter.fire + # for critical path + if _sa_initiator is not False and item is not None: + executor = getattr(self, '_sa_adapter', None) + if executor: + executor.attr.fire_append_event(executor._owner(), + item, _sa_initiator) + fn(self, item) + _tidy(append) + return append + + def remove(fn): + def remove(self, value, _sa_initiator=None): + fn(self, value) + __del(self, value, _sa_initiator) + _tidy(remove) + return remove + + def insert(fn): + def insert(self, index, value): + __set(self, value) + fn(self, index, value) + _tidy(insert) + return insert + + def __setitem__(fn): + def __setitem__(self, index, value): + if not isinstance(index, slice): + existing = self[index] + if existing is not None: + __del(self, existing) + __set(self, value) + fn(self, index, value) + else: + # slice assignment requires __delitem__, insert, __len__ + if index.stop is None: + stop = 0 + elif index.stop < 0: + stop = len(self) + index.stop + else: + stop = index.stop + step = index.step or 1 + rng = range(index.start or 0, stop, step) + if step == 1: + for i in rng: + del self[index.start] + i = index.start + for item in value: + self.insert(i, item) + i += 1 + else: + if len(value) != len(rng): + raise ValueError( + "attempt to assign sequence of size %s to " + "extended slice of size %s" % (len(value), + len(rng))) + for i, item in zip(rng, value): + self.__setitem__(i, item) + _tidy(__setitem__) + return __setitem__ + + def __delitem__(fn): + def __delitem__(self, index): + if not isinstance(index, slice): + item = self[index] + __del(self, item) + fn(self, index) + else: + # slice deletion requires __getslice__ and a slice-groking + # __getitem__ for stepped deletion + # note: not breaking this into atomic dels + for item in self[index]: + __del(self, item) + fn(self, index) + _tidy(__delitem__) + return __delitem__ + + def __setslice__(fn): + def __setslice__(self, start, end, values): + for value in self[start:end]: + __del(self, value) + for value in values: + __set(self, value) + fn(self, start, end, values) + _tidy(__setslice__) + return __setslice__ + + def __delslice__(fn): + def __delslice__(self, start, end): + for value in self[start:end]: + __del(self, value) + fn(self, start, end) + _tidy(__delslice__) + return __delslice__ + + def extend(fn): + def extend(self, iterable): + for value in iterable: + self.append(value) + _tidy(extend) + return extend + + def pop(fn): + def pop(self, index=-1): + item = fn(self, index) + __del(self, item) + return item + _tidy(pop) + return pop + + l = locals().copy() + l.pop('_tidy') + return l + +def _dict_decorators(): + """Hand-turned instrumentation wrappers that can decorate any dict-like + mapping class.""" + + def _tidy(fn): + setattr(fn, '_sa_instrumented', True) + fn.__doc__ = getattr(getattr(dict, fn.__name__), '__doc__') + + Unspecified=object() + + def __setitem__(fn): + def __setitem__(self, key, value, _sa_initiator=None): + if key in self: + __del(self, self[key], _sa_initiator) + __set(self, value, _sa_initiator) + fn(self, key, value) + _tidy(__setitem__) + return __setitem__ + + def __delitem__(fn): + def __delitem__(self, key, _sa_initiator=None): + if key in self: + __del(self, self[key], _sa_initiator) + fn(self, key) + _tidy(__delitem__) + return __delitem__ + + def clear(fn): + def clear(self): + for key in self: + __del(self, self[key]) + fn(self) + _tidy(clear) + return clear + + def pop(fn): + def pop(self, key, default=Unspecified): + if key in self: + __del(self, self[key]) + if default is Unspecified: + return fn(self, key) + else: + return fn(self, key, default) + _tidy(pop) + return pop + + def popitem(fn): + def popitem(self): + item = fn(self) + __del(self, item[1]) + return item + _tidy(popitem) + return popitem + + def setdefault(fn): + def setdefault(self, key, default=None): + if key not in self: + self.__setitem__(key, default) + return default + else: + return self.__getitem__(key) + _tidy(setdefault) + return setdefault + + if sys.version_info < (2, 4): + def update(fn): + def update(self, other): + for key in other.keys(): + if not self.has_key(key) or self[key] is not other[key]: + self[key] = other[key] + _tidy(update) + return update + else: + def update(fn): + def update(self, __other=Unspecified, **kw): + if __other is not Unspecified: + if hasattr(__other, 'keys'): + for key in __other.keys(): + if key not in self or self[key] is not __other[key]: + self[key] = __other[key] + else: + for key, value in __other: + if key not in self or self[key] is not value: + self[key] = value + for key in kw: + if key not in self or self[key] is not kw[key]: + self[key] = kw[key] + _tidy(update) + return update + + l = locals().copy() + l.pop('_tidy') + l.pop('Unspecified') + return l + +def _set_decorators(): + """Hand-turned instrumentation wrappers that can decorate any set-like + sequence class.""" + + def _tidy(fn): + setattr(fn, '_sa_instrumented', True) + fn.__doc__ = getattr(getattr(set, fn.__name__), '__doc__') + + Unspecified=object() + + def add(fn): + def add(self, value, _sa_initiator=None): + __set(self, value, _sa_initiator) + fn(self, value) + _tidy(add) + return add + + def discard(fn): + def discard(self, value, _sa_initiator=None): + if value in self: + __del(self, value, _sa_initiator) + fn(self, value) + _tidy(discard) + return discard + + def remove(fn): + def remove(self, value, _sa_initiator=None): + if value in self: + __del(self, value, _sa_initiator) + fn(self, value) + _tidy(remove) + return remove + + def pop(fn): + def pop(self): + item = fn(self) + __del(self, item) + return item + _tidy(pop) + return pop + + def clear(fn): + def clear(self): + for item in list(self): + self.remove(item) + _tidy(clear) + return clear + + def update(fn): + def update(self, value): + for item in value: + if item not in self: + self.add(item) + _tidy(update) + return update + __ior__ = update + + def difference_update(fn): + def difference_update(self, value): + for item in value: + self.discard(item) + _tidy(difference_update) + return difference_update + __isub__ = difference_update + + def intersection_update(fn): + def intersection_update(self, other): + want, have = self.intersection(other), sautil.Set(self) + remove, add = have - want, want - have + + for item in remove: + self.remove(item) + for item in add: + self.add(item) + _tidy(intersection_update) + return intersection_update + __iand__ = intersection_update + + def symmetric_difference_update(fn): + def symmetric_difference_update(self, other): + want, have = self.symmetric_difference(other), sautil.Set(self) + remove, add = have - want, want - have + + for item in remove: + self.remove(item) + for item in add: + self.add(item) + _tidy(symmetric_difference_update) + return symmetric_difference_update + __ixor__ = symmetric_difference_update + + l = locals().copy() + l.pop('_tidy') + l.pop('Unspecified') + return l + + +class InstrumentedList(list): + """An instrumented version of the built-in list.""" + + __instrumentation__ = { + 'appender': 'append', + 'remover': 'remove', + 'iterator': '__iter__', } + +class InstrumentedSet(sautil.Set): + """An instrumented version of the built-in set (or Set).""" + + __instrumentation__ = { + 'appender': 'add', + 'remover': 'remove', + 'iterator': '__iter__', } + +class InstrumentedDict(dict): + """An instrumented version of the built-in dict.""" + + __instrumentation__ = { + 'iterator': 'itervalues', } + +__canned_instrumentation = { + list: InstrumentedList, + sautil.Set: InstrumentedSet, + dict: InstrumentedDict, + } + +__interfaces = { + list: { 'appender': 'append', + 'remover': 'remove', + 'iterator': '__iter__', + '_decorators': _list_decorators(), }, + sautil.Set: { 'appender': 'add', + 'remover': 'remove', + 'iterator': '__iter__', + '_decorators': _set_decorators(), }, + # decorators are required for dicts and object collections. + dict: { 'iterator': 'itervalues', + '_decorators': _dict_decorators(), }, + # < 0.4 compatible naming, deprecated- use decorators instead. + None: { } + } + + +class MappedCollection(dict): + """A basic dictionary-based collection class. + + Extends dict with the minimal bag semantics that collection classes require. + ``set`` and ``remove`` are implemented in terms of a keying function: any + callable that takes an object and returns an object for use as a dictionary + key. + """ + + def __init__(self, keyfunc): + """Create a new collection with keying provided by keyfunc. + + keyfunc may be any callable any callable that takes an object and + returns an object for use as a dictionary key. + + The keyfunc will be called every time the ORM needs to add a member by + value-only (such as when loading instances from the database) or remove + a member. The usual cautions about dictionary keying apply- + ``keyfunc(object)`` should return the same output for the life of the + collection. Keying based on mutable properties can result in + unreachable instances "lost" in the collection. + """ + self.keyfunc = keyfunc + + def set(self, value, _sa_initiator=None): + """Add an item to the collection, with a key provided by this instance's keyfunc.""" + + key = self.keyfunc(value) + self.__setitem__(key, value, _sa_initiator) + set = collection.internally_instrumented(set) + set = collection.appender(set) + + def remove(self, value, _sa_initiator=None): + """Remove an item from the collection by value, consulting this instance's keyfunc for the key.""" + + key = self.keyfunc(value) + # Let self[key] raise if key is not in this collection + if self[key] != value: + raise exceptions.InvalidRequestError( + "Can not remove '%s': collection holds '%s' for key '%s'. " + "Possible cause: is the MappedCollection key function " + "based on mutable properties or properties that only obtain " + "values after flush?" % + (value, self[key], key)) + self.__delitem__(key, _sa_initiator) + remove = collection.internally_instrumented(remove) + remove = collection.remover(remove) diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index 54b043b32..c06db6963 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -6,8 +6,8 @@ """Bridge the ``PropertyLoader`` (i.e. a ``relation()``) and the -``UOWTransaction`` together to allow processing of scalar- and -list-based dependencies at flush time. +``UOWTransaction`` together to allow processing of relation()-based + dependencies at flush time. """ from sqlalchemy.orm import sync @@ -366,7 +366,7 @@ class ManyToManyDP(DependencyProcessor): if len(secondary_delete): secondary_delete.sort() # TODO: precompile the delete/insert queries? - statement = self.secondary.delete(sql.and_(*[c == sql.bindparam(c.key, type=c.type) for c in self.secondary.c if c.key in associationrow])) + statement = self.secondary.delete(sql.and_(*[c == sql.bindparam(c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow])) result = connection.execute(statement, secondary_delete) if result.supports_sane_rowcount() and result.rowcount != len(secondary_delete): raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (result.rowcount, len(secondary_delete))) diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index a9a26b57f..aeb8a23fa 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -5,7 +5,205 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php -from sqlalchemy import util, logging +from sqlalchemy import util, logging, sql + +# returned by a MapperExtension method to indicate a "do nothing" response +EXT_PASS = object() + +class MapperExtension(object): + """Base implementation for an object that provides overriding + behavior to various Mapper functions. For each method in + MapperExtension, a result of EXT_PASS indicates the functionality + is not overridden. + """ + + + def init_instance(self, mapper, class_, instance, args, kwargs): + return EXT_PASS + + def init_failed(self, mapper, class_, instance, args, kwargs): + return EXT_PASS + + def get_session(self): + """Retrieve a contextual Session instance with which to + register a new object. + + Note: this is not called if a session is provided with the + `__init__` params (i.e. `_sa_session`). + """ + + return EXT_PASS + + def load(self, query, *args, **kwargs): + """Override the `load` method of the Query object. + + The return value of this method is used as the result of + ``query.load()`` if the value is anything other than EXT_PASS. + """ + + return EXT_PASS + + def get(self, query, *args, **kwargs): + """Override the `get` method of the Query object. + + The return value of this method is used as the result of + ``query.get()`` if the value is anything other than EXT_PASS. + """ + + return EXT_PASS + + def get_by(self, query, *args, **kwargs): + """Override the `get_by` method of the Query object. + + The return value of this method is used as the result of + ``query.get_by()`` if the value is anything other than + EXT_PASS. + + DEPRECATED. + """ + + return EXT_PASS + + def select_by(self, query, *args, **kwargs): + """Override the `select_by` method of the Query object. + + The return value of this method is used as the result of + ``query.select_by()`` if the value is anything other than + EXT_PASS. + + DEPRECATED. + """ + + return EXT_PASS + + def select(self, query, *args, **kwargs): + """Override the `select` method of the Query object. + + The return value of this method is used as the result of + ``query.select()`` if the value is anything other than + EXT_PASS. + + DEPRECATED. + """ + + return EXT_PASS + + + def translate_row(self, mapper, context, row): + """Perform pre-processing on the given result row and return a + new row instance. + + This is called as the very first step in the ``_instance()`` + method. + """ + + return EXT_PASS + + def create_instance(self, mapper, selectcontext, row, class_): + """Receive a row when a new object instance is about to be + created from that row. + + The method can choose to create the instance itself, or it can + return None to indicate normal object creation should take + place. + + mapper + The mapper doing the operation + + selectcontext + SelectionContext corresponding to the instances() call + + row + The result row from the database + + class\_ + The class we are mapping. + """ + + return EXT_PASS + + def append_result(self, mapper, selectcontext, row, instance, result, **flags): + """Receive an object instance before that instance is appended + to a result list. + + If this method returns EXT_PASS, result appending will proceed + normally. if this method returns any other value or None, + result appending will not proceed for this instance, giving + this extension an opportunity to do the appending itself, if + desired. + + mapper + The mapper doing the operation. + + selectcontext + SelectionContext corresponding to the instances() call. + + row + The result row from the database. + + instance + The object instance to be appended to the result. + + result + List to which results are being appended. + + \**flags + extra information about the row, same as criterion in + `create_row_processor()` method of [sqlalchemy.orm.interfaces#MapperProperty] + """ + + return EXT_PASS + + def populate_instance(self, mapper, selectcontext, row, instance, **flags): + """Receive a newly-created instance before that instance has + its attributes populated. + + The normal population of attributes is according to each + attribute's corresponding MapperProperty (which includes + column-based attributes as well as relationships to other + classes). If this method returns EXT_PASS, instance + population will proceed normally. If any other value or None + is returned, instance population will not proceed, giving this + extension an opportunity to populate the instance itself, if + desired. + """ + + return EXT_PASS + + def before_insert(self, mapper, connection, instance): + """Receive an object instance before that instance is INSERTed + into its table. + + This is a good place to set up primary key values and such + that aren't handled otherwise. + """ + + return EXT_PASS + + def before_update(self, mapper, connection, instance): + """Receive an object instance before that instance is UPDATEed.""" + + return EXT_PASS + + def after_update(self, mapper, connection, instance): + """Receive an object instance after that instance is UPDATEed.""" + + return EXT_PASS + + def after_insert(self, mapper, connection, instance): + """Receive an object instance after that instance is INSERTed.""" + + return EXT_PASS + + def before_delete(self, mapper, connection, instance): + """Receive an object instance before that instance is DELETEed.""" + + return EXT_PASS + + def after_delete(self, mapper, connection, instance): + """Receive an object instance after that instance is DELETEed.""" + + return EXT_PASS class MapperProperty(object): """Manage the relationship of a ``Mapper`` to a single class @@ -15,22 +213,61 @@ class MapperProperty(object): """ def setup(self, querycontext, **kwargs): - """Called when a statement is being constructed.""" + """Called by Query for the purposes of constructing a SQL statement. + + Each MapperProperty associated with the target mapper processes the + statement referenced by the query context, adding columns and/or + criterion as appropriate. + """ pass - def execute(self, selectcontext, instance, row, identitykey, isnew): - """Called when the mapper receives a row. - - `instance` is the parent instance corresponding to the `row`. + def create_row_processor(self, selectcontext, mapper, row): + """return a 2-tuple consiting of a row processing function and an instance post-processing function. + + Input arguments are the query.SelectionContext and the *first* + applicable row of a result set obtained within query.Query.instances(), called + only the first time a particular mapper.populate_instance() is invoked for the + overal result. + + The settings contained within the SelectionContext as well as the columns present + in the row (which will be the same columns present in all rows) are used to determine + the behavior of the returned callables. The callables will then be used to process + all rows and to post-process all instances, respectively. + + callables are of the following form:: + + def execute(instance, row, **flags): + # process incoming instance and given row. + # flags is a dictionary containing at least the following attributes: + # isnew - indicates if the instance was newly created as a result of reading this row + # instancekey - identity key of the instance + # optional attribute: + # ispostselect - indicates if this row resulted from a 'post' select of additional tables/columns + + def post_execute(instance, **flags): + # process instance after all result rows have been processed. this + # function should be used to issue additional selections in order to + # eagerly load additional properties. + + return (execute, post_execute) + + either tuple value can also be ``None`` in which case no function is called. + """ - + raise NotImplementedError() - + def cascade_iterator(self, type, object, recursive=None, halt_on=None): + """return an iterator of objects which are child objects of the given object, + as attached to the attribute corresponding to this MapperProperty.""" + return [] def cascade_callable(self, type, object, callable_, recursive=None, halt_on=None): + """run the given callable across all objects which are child objects of + the given object, as attached to the attribute corresponding to this MapperProperty.""" + return [] def get_criterion(self, query, key, value): @@ -60,7 +297,11 @@ class MapperProperty(object): self.do_init() def do_init(self): - """Template method for subclasses.""" + """Perform subclass-specific initialization steps. + + This is a *template* method called by the + ``MapperProperty`` object's init() method.""" + pass def register_dependencies(self, *args, **kwargs): @@ -90,59 +331,81 @@ class MapperProperty(object): raise NotImplementedError() - def compare(self, value): + def compare(self, operator, value): """Return a compare operation for the columns represented by this ``MapperProperty`` to the given value, which may be a - column value or an instance. + column value or an instance. 'operator' is an operator from + the operators module, or from sql.Comparator. + + By default uses the PropComparator attached to this MapperProperty + under the attribute name "comparator". """ - raise NotImplementedError() + return operator(self.comparator, value) -class SynonymProperty(MapperProperty): - def __init__(self, name, proxy=False): - self.name = name - self.proxy = proxy - - def setup(self, querycontext, **kwargs): - pass - - def execute(self, selectcontext, instance, row, identitykey, isnew): - pass - - def do_init(self): - if not self.proxy: - return - class SynonymProp(object): - def __set__(s, obj, value): - setattr(obj, self.name, value) - def __delete__(s, obj): - delattr(obj, self.name) - def __get__(s, obj, owner): - if obj is None: - return s - return getattr(obj, self.name) - setattr(self.parent.class_, self.key, SynonymProp()) +class PropComparator(sql.ColumnOperators): + """defines comparison operations for MapperProperty objects""" + + def expression_element(self): + return self.clause_element() + + def contains_op(a, b): + return a.contains(b) + contains_op = staticmethod(contains_op) + + def any_op(a, b, **kwargs): + return a.any(b, **kwargs) + any_op = staticmethod(any_op) + + def has_op(a, b, **kwargs): + return a.has(b, **kwargs) + has_op = staticmethod(has_op) + + def __init__(self, prop): + self.prop = prop + + def contains(self, other): + """return true if this collection contains other""" + return self.operate(PropComparator.contains_op, other) + + def any(self, criterion=None, **kwargs): + """return true if this collection contains any member that meets the given criterion. + + criterion + an optional ClauseElement formulated against the member class' table or attributes. + + \**kwargs + key/value pairs corresponding to member class attribute names which will be compared + via equality to the corresponding values. + """ - def merge(self, session, source, dest, _recursive): - pass + return self.operate(PropComparator.any_op, criterion, **kwargs) + + def has(self, criterion=None, **kwargs): + """return true if this element references a member which meets the given criterion. + + + criterion + an optional ClauseElement formulated against the member class' table or attributes. + + \**kwargs + key/value pairs corresponding to member class attribute names which will be compared + via equality to the corresponding values. + """ + return self.operate(PropComparator.has_op, criterion, **kwargs) + class StrategizedProperty(MapperProperty): """A MapperProperty which uses selectable strategies to affect loading behavior. - There is a single default strategy selected, and alternate - strategies can be selected at selection time through the usage of - ``StrategizedOption`` objects. + There is a single default strategy selected by default. Alternate + strategies can be selected at Query time through the usage of + ``StrategizedOption`` objects via the Query.options() method. """ def _get_context_strategy(self, context): - try: - return context.attributes[id(self)] - except KeyError: - # cache the located strategy per StrategizedProperty in the given context for faster re-lookup - ctx_strategy = self._get_strategy(context.attributes.get((LoaderStrategy, self), self.strategy.__class__)) - context.attributes[id(self)] = ctx_strategy - return ctx_strategy + return self._get_strategy(context.attributes.get(("loaderstrategy", self), self.strategy.__class__)) def _get_strategy(self, cls): try: @@ -156,11 +419,10 @@ class StrategizedProperty(MapperProperty): return strategy def setup(self, querycontext, **kwargs): - self._get_context_strategy(querycontext).setup_query(querycontext, **kwargs) - def execute(self, selectcontext, instance, row, identitykey, isnew): - self._get_context_strategy(selectcontext).process_row(selectcontext, instance, row, identitykey, isnew) + def create_row_processor(self, selectcontext, mapper, row): + return self._get_context_strategy(selectcontext).create_row_processor(selectcontext, mapper, row) def do_init(self): self._all_strategies = {} @@ -170,6 +432,31 @@ class StrategizedProperty(MapperProperty): if self.is_primary(): self.strategy.init_class_attribute() +class LoaderStack(object): + """a stack object used during load operations to track the + current position among a chain of mappers to eager loaders.""" + + def __init__(self): + self.__stack = [] + + def push_property(self, key): + self.__stack.append(key) + + def push_mapper(self, mapper): + self.__stack.append(mapper.base_mapper()) + + def pop(self): + self.__stack.pop() + + def snapshot(self): + """return an 'snapshot' of this stack. + + this is a tuple form of the stack which can be used as a hash key.""" + return tuple(self.__stack) + + def __str__(self): + return "->".join([str(s) for s in self.__stack]) + class OperationContext(object): """Serve as a context during a query construction or instance loading operation. @@ -200,6 +487,44 @@ class MapperOption(object): def process_query(self, query): pass +class ExtensionOption(MapperOption): + """a MapperOption that applies a MapperExtension to a query operation.""" + + def __init__(self, ext): + self.ext = ext + + def process_query(self, query): + query._extension = query._extension.copy() + query._extension.append(self.ext) + +class SynonymProperty(MapperProperty): + def __init__(self, name, proxy=False): + self.name = name + self.proxy = proxy + + def setup(self, querycontext, **kwargs): + pass + + def create_row_processor(self, selectcontext, mapper, row): + return (None, None) + + def do_init(self): + if not self.proxy: + return + class SynonymProp(object): + def __set__(s, obj, value): + setattr(obj, self.name, value) + def __delete__(s, obj): + delattr(obj, self.name) + def __get__(s, obj, owner): + if obj is None: + return s + return getattr(obj, self.name) + setattr(self.parent.class_, self.key, SynonymProp()) + + def merge(self, session, source, dest, _recursive): + pass + class PropertyOption(MapperOption): """A MapperOption that is applied to a property off the mapper or one of its child mappers, identified by a dot-separated key. @@ -208,45 +533,72 @@ class PropertyOption(MapperOption): def __init__(self, key): self.key = key - def process_query_property(self, context, property): + def process_query_property(self, context, properties): pass - def process_selection_property(self, context, property): + def process_selection_property(self, context, properties): pass def process_query_context(self, context): - self.process_query_property(context, self._get_property(context)) + self.process_query_property(context, self._get_properties(context)) def process_selection_context(self, context): - self.process_selection_property(context, self._get_property(context)) + self.process_selection_property(context, self._get_properties(context)) - def _get_property(self, context): + def _get_properties(self, context): try: - prop = self.__prop + l = self.__prop except AttributeError: + l = [] mapper = context.mapper for token in self.key.split('.'): - prop = mapper.props[token] - if isinstance(prop, SynonymProperty): - prop = mapper.props[prop.name] + prop = mapper.get_property(token, resolve_synonyms=True) + l.append(prop) mapper = getattr(prop, 'mapper', None) - self.__prop = prop - return prop + self.__prop = l + return l PropertyOption.logger = logging.class_logger(PropertyOption) + +class AttributeExtension(object): + """An abstract class which specifies `append`, `delete`, and `set` + event handlers to be attached to an object property. + """ + + def append(self, obj, child, initiator): + pass + + def remove(self, obj, child, initiator): + pass + + def set(self, obj, child, oldchild, initiator): + pass + + class StrategizedOption(PropertyOption): """A MapperOption that affects which LoaderStrategy will be used for an operation by a StrategizedProperty. """ - def process_query_property(self, context, property): + def is_chained(self): + return False + + def process_query_property(self, context, properties): self.logger.debug("applying option to QueryContext, property key '%s'" % self.key) - context.attributes[(LoaderStrategy, property)] = self.get_strategy_class() + if self.is_chained(): + for prop in properties: + context.attributes[("loaderstrategy", prop)] = self.get_strategy_class() + else: + context.attributes[("loaderstrategy", properties[-1])] = self.get_strategy_class() - def process_selection_property(self, context, property): + def process_selection_property(self, context, properties): self.logger.debug("applying option to SelectionContext, property key '%s'" % self.key) - context.attributes[(LoaderStrategy, property)] = self.get_strategy_class() + if self.is_chained(): + for prop in properties: + context.attributes[("loaderstrategy", prop)] = self.get_strategy_class() + else: + context.attributes[("loaderstrategy", properties[-1])] = self.get_strategy_class() def get_strategy_class(self): raise NotImplementedError() @@ -291,5 +643,13 @@ class LoaderStrategy(object): def setup_query(self, context, **kwargs): pass - def process_row(self, selectcontext, instance, row, identitykey, isnew): - pass + def create_row_processor(self, selectcontext, mapper, row): + """return row processing functions which fulfill the contract specified + by MapperProperty.create_row_processor. + + + StrategizedProperty delegates its create_row_processor method + directly to this method. + """ + + raise NotImplementedError() diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 375408926..76cc41289 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -4,14 +4,15 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from sqlalchemy import sql, schema, util, exceptions, logging +from sqlalchemy import sql, util, exceptions, logging from sqlalchemy import sql_util as sqlutil from sqlalchemy.orm import util as mapperutil +from sqlalchemy.orm.util import ExtensionCarrier from sqlalchemy.orm import sync -from sqlalchemy.orm.interfaces import MapperProperty, MapperOption, OperationContext, SynonymProperty -import weakref, warnings +from sqlalchemy.orm.interfaces import MapperProperty, EXT_PASS, MapperExtension, SynonymProperty +import weakref, warnings, operator -__all__ = ['Mapper', 'MapperExtension', 'class_mapper', 'object_mapper', 'EXT_PASS', 'mapper_registry', 'ExtensionOption'] +__all__ = ['Mapper', 'class_mapper', 'object_mapper', 'mapper_registry'] # a dictionary mapping classes to their primary mappers mapper_registry = weakref.WeakKeyDictionary() @@ -24,12 +25,13 @@ global_extensions = [] # column NO_ATTRIBUTE = object() -# returned by a MapperExtension method to indicate a "do nothing" response -EXT_PASS = object() - # lock used to synchronize the "mapper compile" step _COMPILE_MUTEX = util.threading.Lock() +# initialize these two lazily +attribute_manager = None +ColumnProperty = None + class Mapper(object): """Define the correlation of class attributes to database table columns. @@ -55,6 +57,7 @@ class Mapper(object): polymorphic_on=None, _polymorphic_map=None, polymorphic_identity=None, + polymorphic_fetch=None, concrete=False, select_table=None, allow_null_pks=False, @@ -62,126 +65,8 @@ class Mapper(object): column_prefix=None): """Construct a new mapper. - All arguments may be sent to the ``sqlalchemy.orm.mapper()`` - function where they are passed through to here. - - class\_ - The class to be mapped. - - local_table - The table to which the class is mapped, or None if this - mapper inherits from another mapper using concrete table - inheritance. - - properties - A dictionary mapping the string names of object attributes - to ``MapperProperty`` instances, which define the - persistence behavior of that attribute. Note that the - columns in the mapped table are automatically converted into - ``ColumnProperty`` instances based on the `key` property of - each ``Column`` (although they can be overridden using this - dictionary). - - primary_key - A list of ``Column`` objects which define the *primary key* - to be used against this mapper's selectable unit. This is - normally simply the primary key of the `local_table`, but - can be overridden here. - - non_primary - Construct a ``Mapper`` that will define only the selection - of instances, not their persistence. - - inherits - Another ``Mapper`` for which this ``Mapper`` will have an - inheritance relationship with. - - inherit_condition - For joined table inheritance, a SQL expression (constructed - ``ClauseElement``) which will define how the two tables are - joined; defaults to a natural join between the two tables. - - extension - A ``MapperExtension`` instance or list of - ``MapperExtension`` instances which will be applied to all - operations by this ``Mapper``. - - order_by - A single ``Column`` or list of ``Columns`` for which - selection operations should use as the default ordering for - entities. Defaults to the OID/ROWID of the table if any, or - the first primary key column of the table. - - allow_column_override - If True, allows the usage of a ``relation()`` which has the - same name as a column in the mapped table. The table column - will no longer be mapped. - - entity_name - A name to be associated with the `class`, to allow alternate - mappings for a single class. - - always_refresh - If True, all query operations for this mapped class will - overwrite all data within object instances that already - exist within the session, erasing any in-memory changes with - whatever information was loaded from the database. - - version_id_col - A ``Column`` which must have an integer type that will be - used to keep a running *version id* of mapped entities in - the database. this is used during save operations to ensure - that no other thread or process has updated the instance - during the lifetime of the entity, else a - ``ConcurrentModificationError`` exception is thrown. - - polymorphic_on - Used with mappers in an inheritance relationship, a ``Column`` - which will identify the class/mapper combination to be used - with a particular row. requires the polymorphic_identity - value to be set for all mappers in the inheritance - hierarchy. - - _polymorphic_map - Used internally to propigate the full map of polymorphic - identifiers to surrogate mappers. - - polymorphic_identity - A value which will be stored in the Column denoted by - polymorphic_on, corresponding to the *class identity* of - this mapper. - - concrete - If True, indicates this mapper should use concrete table - inheritance with its parent mapper. - - select_table - A ``Table`` or (more commonly) ``Selectable`` which will be - used to select instances of this mapper's class. usually - used to provide polymorphic loading among several classes in - an inheritance hierarchy. - - allow_null_pks - Indicates that composite primary keys where one or more (but - not all) columns contain NULL is a valid primary key. - Primary keys which contain NULL values usually indicate that - a result row does not contain an entity and should be - skipped. - - batch - Indicates that save operations of multiple entities can be - batched together for efficiency. setting to False indicates - that an instance will be fully saved before saving the next - instance, which includes inserting/updating all table rows - corresponding to the entity as well as calling all - ``MapperExtension`` methods corresponding to the save - operation. - - column_prefix - A string which will be prepended to the `key` name of all - Columns when creating column-based properties from the given - Table. Does not affect explicitly specified column-based - properties + Mappers are normally constructed via the [sqlalchemy.orm#mapper()] + function. See for details. """ if not issubclass(class_, object): @@ -227,6 +112,13 @@ class Mapper(object): # indicates this Mapper should be used to construct the object instance for that row. self.polymorphic_identity = polymorphic_identity + if polymorphic_fetch not in (None, 'union', 'select', 'deferred'): + raise exceptions.ArgumentError("Invalid option for 'polymorphic_fetch': '%s'" % polymorphic_fetch) + if polymorphic_fetch is None: + self.polymorphic_fetch = (self.select_table is None) and 'select' or 'union' + else: + self.polymorphic_fetch = polymorphic_fetch + # a dictionary of 'polymorphic identity' names, associating those names with # Mappers that will be used to construct object instances upon a select operation. if _polymorphic_map is None: @@ -297,20 +189,8 @@ class Mapper(object): else: return False - def _get_props(self): - self.compile() - return self.__props - - props = property(_get_props, doc="compiles this mapper if needed, and returns the " - "dictionary of MapperProperty objects associated with this mapper." - "(Deprecated; use get_property() and iterate_properties)") - def get_property(self, key, resolve_synonyms=False, raiseerr=True): - """return MapperProperty with the given key. - - forwards compatible with 0.4. - """ - + """return MapperProperty with the given key.""" self.compile() prop = self.__props.get(key, None) if resolve_synonyms: @@ -319,10 +199,22 @@ class Mapper(object): if prop is None and raiseerr: raise exceptions.InvalidRequestError("Mapper '%s' has no property '%s'" % (str(self), key)) return prop - - iterate_properties = property(lambda self: self._get_props().itervalues(), doc="returns an iterator of all MapperProperty objects." - " Forwards compatible with 0.4") - + + def iterate_properties(self): + self.compile() + return self.__props.itervalues() + iterate_properties = property(iterate_properties, doc="returns an iterator of all MapperProperty objects.") + + def dispose(self): + attribute_manager.reset_class_managed(self.class_) + if hasattr(self.class_, 'c'): + del self.class_.c + if hasattr(self.class_, '__init__') and hasattr(self.class_.__init__, '_oldinit'): + if self.class_.__init__._oldinit is not None: + self.class_.__init__ = self.class_.__init__._oldinit + else: + delattr(self.class_, '__init__') + def compile(self): """Compile this mapper into its final internal format. @@ -403,7 +295,7 @@ class Mapper(object): for ext_obj in util.to_list(extension): extlist.add(ext_obj) - self.extension = _ExtensionCarrier() + self.extension = ExtensionCarrier() for ext in extlist: self.extension.append(ext) @@ -452,8 +344,17 @@ class Mapper(object): self.mapped_table = self.local_table if self.polymorphic_identity is not None: self.inherits._add_polymorphic_mapping(self.polymorphic_identity, self) - if self.polymorphic_on is None and self.inherits.polymorphic_on is not None: - self.polymorphic_on = self.mapped_table.corresponding_column(self.inherits.polymorphic_on, keys_ok=True, raiseerr=False) + if self.polymorphic_on is None: + if self.inherits.polymorphic_on is not None: + self.polymorphic_on = self.mapped_table.corresponding_column(self.inherits.polymorphic_on, keys_ok=True, raiseerr=False) + else: + raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity)) + + if self.polymorphic_identity is not None and not self.concrete: + self._identity_class = self.inherits._identity_class + else: + self._identity_class = self.class_ + if self.order_by is False: self.order_by = self.inherits.order_by self.polymorphic_map = self.inherits.polymorphic_map @@ -463,8 +364,11 @@ class Mapper(object): self._synchronizer = None self.mapped_table = self.local_table if self.polymorphic_identity is not None: + if self.polymorphic_on is None: + raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity)) self._add_polymorphic_mapping(self.polymorphic_identity, self) - + self._identity_class = self.class_ + if self.mapped_table is None: raise exceptions.ArgumentError("Mapper '%s' does not have a mapped_table specified. (Are you using the return value of table.create()? It no longer has a return value.)" % str(self)) @@ -503,39 +407,134 @@ class Mapper(object): # may be a join or other construct self.tables = sqlutil.TableFinder(self.mapped_table) - # determine primary key columns, either passed in, or get them from our set of tables + # determine primary key columns self.pks_by_table = {} + + # go through all of our represented tables + # and assemble primary key columns + for t in self.tables + [self.mapped_table]: + try: + l = self.pks_by_table[t] + except KeyError: + l = self.pks_by_table.setdefault(t, util.OrderedSet()) + for k in t.primary_key: + l.add(k) + if self.primary_key_argument is not None: - # determine primary keys using user-given list of primary key columns as a guide - # - # TODO: this might not work very well for joined-table and/or polymorphic - # inheritance mappers since local_table isnt taken into account nor is select_table - # need to test custom primary key columns used with inheriting mappers for k in self.primary_key_argument: self.pks_by_table.setdefault(k.table, util.OrderedSet()).add(k) - if k.table != self.mapped_table: - # associate pk cols from subtables to the "main" table - corr = self.mapped_table.corresponding_column(k, raiseerr=False) - if corr is not None: - self.pks_by_table.setdefault(self.mapped_table, util.OrderedSet()).add(corr) - else: - # no user-defined primary key columns - go through all of our represented tables - # and assemble primary key columns - for t in self.tables + [self.mapped_table]: - try: - l = self.pks_by_table[t] - except KeyError: - l = self.pks_by_table.setdefault(t, util.OrderedSet()) - for k in t.primary_key: - #if k.key not in t.c and k._label not in t.c: - # this is a condition that was occurring when table reflection was doubling up primary keys - # that were overridden in the Table constructor - # raise exceptions.AssertionError("Column " + str(k) + " not located in the column set of table " + str(t)) - l.add(k) - + if len(self.pks_by_table[self.mapped_table]) == 0: raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name)) - self.primary_key = self.pks_by_table[self.mapped_table] + + if self.inherits is not None and not self.concrete and not self.primary_key_argument: + self.primary_key = self.inherits.primary_key + self._get_clause = self.inherits._get_clause + else: + # create the "primary_key" for this mapper. this will flatten "equivalent" primary key columns + # into one column, where "equivalent" means that one column references the other via foreign key, or + # multiple columns that all reference a common parent column. it will also resolve the column + # against the "mapped_table" of this mapper. + equivalent_columns = self._get_equivalent_columns() + + primary_key = sql.ColumnSet() + + for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]): + c = self.mapped_table.corresponding_column(col, raiseerr=False) + if c is None: + for cc in equivalent_columns[col]: + c = self.mapped_table.corresponding_column(cc, raiseerr=False) + if c is not None: + break + else: + raise exceptions.ArgumentError("Cant resolve column " + str(col)) + + # this step attempts to resolve the column to an equivalent which is not + # a foreign key elsewhere. this helps with joined table inheritance + # so that PKs are expressed in terms of the base table which is always + # present in the initial select + # TODO: this is a little hacky right now, the "tried" list is to prevent + # endless loops between cyclical FKs, try to make this cleaner/work better/etc., + # perhaps via topological sort (pick the leftmost item) + tried = util.Set() + while True: + if not len(c.foreign_keys) or c in tried: + break + for cc in c.foreign_keys: + cc = cc.column + c2 = self.mapped_table.corresponding_column(cc, raiseerr=False) + if c2 is not None: + c = c2 + tried.add(c) + break + else: + break + primary_key.add(c) + + if len(primary_key) == 0: + raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name)) + + self.primary_key = primary_key + self.__log("Identified primary key columns: " + str(primary_key)) + + _get_clause = sql.and_() + for primary_key in self.primary_key: + _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type_=primary_key.type, unique=True)) + self._get_clause = _get_clause + + def _get_equivalent_columns(self): + """Create a map of all *equivalent* columns, based on + the determination of column pairs that are equated to + one another either by an established foreign key relationship + or by a joined-table inheritance join. + + This is used to determine the minimal set of primary key + columns for the mapper, as well as when relating + columns to those of a polymorphic selectable (i.e. a UNION of + several mapped tables), as that selectable usually only contains + one column in its columns clause out of a group of several which + are equated to each other. + + The resulting structure is a dictionary of columns mapped + to lists of equivalent columns, i.e. + + { + tablea.col1: + set([tableb.col1, tablec.col1]), + tablea.col2: + set([tabled.col2]) + } + + this method is called repeatedly during the compilation process as + the resulting dictionary contains more equivalents as more inheriting + mappers are compiled. the repetition process may be open to some optimization. + """ + + result = {} + def visit_binary(binary): + if binary.operator == operator.eq: + if binary.left in result: + result[binary.left].add(binary.right) + else: + result[binary.left] = util.Set([binary.right]) + if binary.right in result: + result[binary.right].add(binary.left) + else: + result[binary.right] = util.Set([binary.left]) + vis = mapperutil.BinaryVisitor(visit_binary) + + for mapper in self.base_mapper().polymorphic_iterator(): + if mapper.inherit_condition is not None: + vis.traverse(mapper.inherit_condition) + + for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]): + if not len(col.foreign_keys): + result.setdefault(col, util.Set()).add(col) + else: + for fk in col.foreign_keys: + result.setdefault(fk.column, util.Set()).add(col) + + return result def _compile_properties(self): """Inspect the properties dictionary sent to the Mapper's @@ -548,7 +547,7 @@ class Mapper(object): """ # object attribute names mapped to MapperProperty objects - self.__props = {} + self.__props = util.OrderedDict() # table columns mapped to lists of MapperProperty objects # using a list allows a single column to be defined as @@ -574,7 +573,7 @@ class Mapper(object): self.columns[column.key] = self.select_table.corresponding_column(column, keys_ok=True, raiseerr=True) column_key = (self.column_prefix or '') + column.key - prop = self.__props.get(column_key, None) + prop = self.__props.get(column.key, None) if prop is None: prop = ColumnProperty(column) self.__props[column_key] = prop @@ -582,7 +581,7 @@ class Mapper(object): self.__log("adding ColumnProperty %s" % (column_key)) elif isinstance(prop, ColumnProperty): if prop.parent is not self: - prop = ColumnProperty(deferred=prop.deferred, group=prop.group, *prop.columns) + prop = prop.copy() prop.set_parent(self) self.__props[column_key] = prop if column in self.primary_key and prop.columns[-1] in self.primary_key: @@ -597,8 +596,7 @@ class Mapper(object): # its a ColumnProperty - match the ultimate table columns # back to the property - proplist = self.columntoproperty.setdefault(column, []) - proplist.append(prop) + self.columntoproperty.setdefault(column, []).append(prop) def _initialize_properties(self): @@ -660,66 +658,43 @@ class Mapper(object): attribute_manager.reset_class_managed(self.class_) oldinit = self.class_.__init__ - def init(self, *args, **kwargs): - entity_name = kwargs.pop('_sa_entity_name', None) - mapper = mapper_registry.get(ClassKey(self.__class__, entity_name)) - if mapper is not None: - mapper = mapper.compile() - - # this gets the AttributeManager to do some pre-initialization, - # in order to save on KeyErrors later on - attribute_manager.init_attr(self) - - if kwargs.has_key('_sa_session'): - session = kwargs.pop('_sa_session') - else: - # works for whatever mapper the class is associated with - if mapper is not None: - session = mapper.extension.get_session() - if session is EXT_PASS: - session = None - else: - session = None - # if a session was found, either via _sa_session or via mapper extension, - # and we have found a mapper, save() this instance to the session, and give it an associated entity_name. - # otherwise, this instance will not have a session or mapper association until it is - # save()d to some session. - if session is not None and mapper is not None: - self._entity_name = entity_name - session._register_pending(self) + def init(instance, *args, **kwargs): + self.compile() + self.extension.init_instance(self, self.class_, instance, args, kwargs) if oldinit is not None: try: - oldinit(self, *args, **kwargs) + oldinit(instance, *args, **kwargs) except: - def go(): - if session is not None: - session.expunge(self) - # convert expunge() exceptions to warnings - util.warn_exception(go) + # call init_failed but suppress exceptions into warnings so that original __init__ + # exception is raised + util.warn_exception(self.extension.init_failed, self, self.class_, instance, args, kwargs) raise - - # override oldinit, insuring that its not already a Mapper-decorated init method - if oldinit is None or not hasattr(oldinit, '_sa_mapper_init'): - init._sa_mapper_init = True + + # override oldinit, ensuring that its not already a Mapper-decorated init method + if oldinit is None or not hasattr(oldinit, '_oldinit'): try: init.__name__ = oldinit.__name__ init.__doc__ = oldinit.__doc__ except: # cant set __name__ in py 2.3 ! pass + init._oldinit = oldinit self.class_.__init__ = init + _COMPILE_MUTEX.acquire() try: mapper_registry[self.class_key] = self finally: _COMPILE_MUTEX.release() + if self.entity_name is None: self.class_.c = self.c def base_mapper(self): """Return the ultimate base mapper in an inheritance chain.""" + # TODO: calculate this at mapper setup time if self.inherits is not None: return self.inherits.base_mapper() else: @@ -759,43 +734,6 @@ class Mapper(object): for m in mapper.polymorphic_iterator(): yield m - def _get_inherited_column_equivalents(self): - """Return a map of all *equivalent* columns, based on - traversing the full set of inherit_conditions across all - inheriting mappers and determining column pairs that are - equated to one another. - - This is used when relating columns to those of a polymorphic - selectable, as the selectable usually only contains one of two (or more) - columns that are equated to one another. - - The resulting structure is a dictionary of columns mapped - to lists of equivalent columns, i.e. - - { - tablea.col1: - [tableb.col1, tablec.col1], - tablea.col2: - [tabled.col2] - } - """ - - result = {} - def visit_binary(binary): - if binary.operator == '=': - if binary.left in result: - result[binary.left].append(binary.right) - else: - result[binary.left] = [binary.right] - if binary.right in result: - result[binary.right].append(binary.left) - else: - result[binary.right] = [binary.left] - vis = mapperutil.BinaryVisitor(visit_binary) - for mapper in self.base_mapper().polymorphic_iterator(): - if mapper.inherit_condition is not None: - vis.traverse(mapper.inherit_condition) - return result def add_properties(self, dict_of_properties): """Add the given dictionary of properties to this mapper, @@ -947,7 +885,7 @@ class Mapper(object): dictionary corresponding result-set ``ColumnElement`` instances to their values within a row. """ - return (self.class_, tuple([row[column] for column in self.pks_by_table[self.mapped_table]]), self.entity_name) + return (self._identity_class, tuple([row[column] for column in self.primary_key]), self.entity_name) def identity_key_from_primary_key(self, primary_key): """Return an identity-map key for use in storing/retrieving an @@ -956,7 +894,7 @@ class Mapper(object): primary_key A list of values indicating the identifier. """ - return (self.class_, tuple(util.to_list(primary_key)), self.entity_name) + return (self._identity_class, tuple(util.to_list(primary_key)), self.entity_name) def identity_key_from_instance(self, instance): """Return the identity key for the given instance, based on @@ -972,7 +910,7 @@ class Mapper(object): instance. """ - return [self.get_attr_by_column(instance, column) for column in self.pks_by_table[self.mapped_table]] + return [self.get_attr_by_column(instance, column) for column in self.primary_key] def canload(self, instance): """return true if this mapper is capable of loading the given instance""" @@ -981,21 +919,6 @@ class Mapper(object): else: return instance.__class__ is self.class_ - def instance_key(self, instance): - """Deprecated. A synonym for `identity_key_from_instance`.""" - - return self.identity_key_from_instance(instance) - - def identity_key(self, primary_key): - """Deprecated. A synonym for `identity_key_from_primary_key`.""" - - return self.identity_key_from_primary_key(primary_key) - - def identity(self, instance): - """Deprecated. A synoynm for `primary_key_from_instance`.""" - - return self.primary_key_from_instance(instance) - def _getpropbycolumn(self, column, raiseerror=True): try: prop = self.columntoproperty[column] @@ -1017,13 +940,12 @@ class Mapper(object): prop = self._getpropbycolumn(column, raiseerror) if prop is None: return NO_ATTRIBUTE - #print "get column attribute '%s' from instance %s" % (column.key, mapperutil.instance_str(obj)) - return prop.getattr(obj) + return prop.getattr(obj, column) def set_attr_by_column(self, obj, column, value): """Set the value of an instance attribute using a Column as the key.""" - self.columntoproperty[column][0].setattr(obj, value) + self.columntoproperty[column][0].setattr(obj, value, column) def save_obj(self, objects, uowtransaction, postupdate=False, post_update_cols=None, single=False): """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects. @@ -1048,10 +970,15 @@ class Mapper(object): self.save_obj([obj], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True) return - connection = uowtransaction.transaction.connection(self) - + if 'connection_callable' in uowtransaction.mapper_flush_opts: + connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] + tups = [(obj, connection_callable(self, obj)) for obj in objects] + else: + connection = uowtransaction.transaction.connection(self) + tups = [(obj, connection) for obj in objects] + if not postupdate: - for obj in objects: + for obj, connection in tups: if not has_identity(obj): for mapper in object_mapper(obj).iterate_to_root(): mapper.extension.before_insert(mapper, connection, obj) @@ -1059,12 +986,12 @@ class Mapper(object): for mapper in object_mapper(obj).iterate_to_root(): mapper.extension.before_update(mapper, connection, obj) - for obj in objects: + for obj, connection in tups: # detect if we have a "pending" instance (i.e. has no instance_key attached to it), # and another instance with the same identity key already exists as persistent. convert to an # UPDATE if so. mapper = object_mapper(obj) - instance_key = mapper.instance_key(obj) + instance_key = mapper.identity_key_from_instance(obj) is_row_switch = not postupdate and not has_identity(obj) and instance_key in uowtransaction.uow.identity_map if is_row_switch: existing = uowtransaction.uow.identity_map[instance_key] @@ -1090,11 +1017,11 @@ class Mapper(object): insert = [] update = [] - for obj in objects: + for obj, connection in tups: mapper = object_mapper(obj) if table not in mapper.tables or not mapper._has_pks(table): continue - instance_key = mapper.instance_key(obj) + instance_key = mapper.identity_key_from_instance(obj) if self.__should_log_debug: self.__log_debug("save_obj() table '%s' instance %s identity %s" % (table.name, mapperutil.instance_str(obj), str(instance_key))) @@ -1149,7 +1076,7 @@ class Mapper(object): if history: a = history.added_items() if len(a): - params[col.key] = a[0] + params[col.key] = prop.get_col_value(col, a[0]) hasdata = True else: # doing an INSERT, non primary key col ? @@ -1168,17 +1095,17 @@ class Mapper(object): if hasdata: # if none of the attributes changed, dont even # add the row to be updated. - update.append((obj, params, mapper)) + update.append((obj, params, mapper, connection)) else: - insert.append((obj, params, mapper)) + insert.append((obj, params, mapper, connection)) if len(update): mapper = table_to_mapper[table] clause = sql.and_() for col in mapper.pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col._label, type=col.type, unique=True)) + clause.clauses.append(col == sql.bindparam(col._label, type_=col.type, unique=True)) if mapper.version_id_col is not None: - clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type=col.type, unique=True)) + clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type_=col.type, unique=True)) statement = table.update(clause) rows = 0 supports_sane_rowcount = True @@ -1190,11 +1117,11 @@ class Mapper(object): return 0 update.sort(comparator) for rec in update: - (obj, params, mapper) = rec + (obj, params, mapper, connection) = rec c = connection.execute(statement, params) mapper._postfetch(connection, table, obj, c, c.last_updated_params()) - updated_objects.add(obj) + updated_objects.add((obj, connection)) rows += c.rowcount if c.supports_sane_rowcount() and rows != len(update): @@ -1206,7 +1133,7 @@ class Mapper(object): return cmp(a[0]._sa_insert_order, b[0]._sa_insert_order) insert.sort(comparator) for rec in insert: - (obj, params, mapper) = rec + (obj, params, mapper, connection) = rec c = connection.execute(statement, params) primary_key = c.last_inserted_ids() if primary_key is not None: @@ -1228,12 +1155,12 @@ class Mapper(object): mapper._synchronizer.execute(obj, obj) sync(mapper) - inserted_objects.add(obj) + inserted_objects.add((obj, connection)) if not postupdate: - for obj in inserted_objects: + for obj, connection in inserted_objects: for mapper in object_mapper(obj).iterate_to_root(): mapper.extension.after_insert(mapper, connection, obj) - for obj in updated_objects: + for obj, connection in updated_objects: for mapper in object_mapper(obj).iterate_to_root(): mapper.extension.after_update(mapper, connection, obj) @@ -1273,9 +1200,14 @@ class Mapper(object): if self.__should_log_debug: self.__log_debug("delete_obj() start") - connection = uowtransaction.transaction.connection(self) + if 'connection_callable' in uowtransaction.mapper_flush_opts: + connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] + tups = [(obj, connection_callable(self, obj)) for obj in objects] + else: + connection = uowtransaction.transaction.connection(self) + tups = [(obj, connection) for obj in objects] - for obj in objects: + for (obj, connection) in tups: for mapper in object_mapper(obj).iterate_to_root(): mapper.extension.before_delete(mapper, connection, obj) @@ -1286,8 +1218,8 @@ class Mapper(object): table_to_mapper.setdefault(t, mapper) for table in sqlutil.TableCollection(list(table_to_mapper.keys())).sort(reverse=True): - delete = [] - for obj in objects: + delete = {} + for (obj, connection) in tups: mapper = object_mapper(obj) if table not in mapper.tables or not mapper._has_pks(table): continue @@ -1296,13 +1228,13 @@ class Mapper(object): if not hasattr(obj, '_instance_key'): continue else: - delete.append(params) + delete.setdefault(connection, []).append(params) for col in mapper.pks_by_table[table]: params[col.key] = mapper.get_attr_by_column(obj, col) if mapper.version_id_col is not None: params[mapper.version_id_col.key] = mapper.get_attr_by_column(obj, mapper.version_id_col) - deleted_objects.add(obj) - if len(delete): + deleted_objects.add((obj, connection)) + for connection, del_objects in delete.iteritems(): mapper = table_to_mapper[table] def comparator(a, b): for col in mapper.pks_by_table[table]: @@ -1310,18 +1242,18 @@ class Mapper(object): if x != 0: return x return 0 - delete.sort(comparator) + del_objects.sort(comparator) clause = sql.and_() for col in mapper.pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col.key, type=col.type, unique=True)) + clause.clauses.append(col == sql.bindparam(col.key, type_=col.type, unique=True)) if mapper.version_id_col is not None: - clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type=mapper.version_id_col.type, unique=True)) + clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type, unique=True)) statement = table.delete(clause) - c = connection.execute(statement, delete) - if c.supports_sane_rowcount() and c.rowcount != len(delete): + c = connection.execute(statement, del_objects) + if c.supports_sane_rowcount() and c.rowcount != len(del_objects): raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (c.rowcount, len(delete))) - for obj in deleted_objects: + for obj, connection in deleted_objects: for mapper in object_mapper(obj).iterate_to_root(): mapper.extension.after_delete(mapper, connection, obj) @@ -1429,15 +1361,17 @@ class Mapper(object): if discriminator is not None: mapper = self.polymorphic_map[discriminator] if mapper is not self: + if ('polymorphic_fetch', mapper) not in context.attributes: + context.attributes[('polymorphic_fetch', mapper)] = (self, [t for t in mapper.tables if t not in self.tables]) row = self.translate_row(mapper, row) return mapper._instance(context, row, result=result, skip_polymorphic=True) - + # look in main identity map. if its there, we dont do anything to it, # including modifying any of its related items lists, as its already # been exposed to being modified by the application. - populate_existing = context.populate_existing or self.always_refresh identitykey = self.identity_key_from_row(row) + populate_existing = context.populate_existing or self.always_refresh if context.session.has_key(identitykey): instance = context.session._get(identitykey) if self.__should_log_debug: @@ -1450,32 +1384,31 @@ class Mapper(object): if not context.identity_map.has_key(identitykey): context.identity_map[identitykey] = instance isnew = True - if extension.populate_instance(self, context, row, instance, identitykey, isnew) is EXT_PASS: - self.populate_instance(context, instance, row, identitykey, isnew) - if extension.append_result(self, context, row, instance, identitykey, result, isnew) is EXT_PASS: + if extension.populate_instance(self, context, row, instance, **{'instancekey':identitykey, 'isnew':isnew}) is EXT_PASS: + self.populate_instance(context, instance, row, **{'instancekey':identitykey, 'isnew':isnew}) + if extension.append_result(self, context, row, instance, result, **{'instancekey':identitykey, 'isnew':isnew}) is EXT_PASS: if result is not None: result.append(instance) return instance else: if self.__should_log_debug: - self.__log_debug("_instance(): identity key %s not in session" % str(identitykey) + repr([mapperutil.instance_str(x) for x in context.session])) + self.__log_debug("_instance(): identity key %s not in session" % str(identitykey)) # look in result-local identitymap for it. - exists = context.identity_map.has_key(identitykey) + exists = identitykey in context.identity_map if not exists: if self.allow_null_pks: # check if *all* primary key cols in the result are None - this indicates # an instance of the object is not present in the row. - for col in self.pks_by_table[self.mapped_table]: - if row[col] is not None: + for x in identitykey[1]: + if x is not None: break else: return None else: # otherwise, check if *any* primary key cols in the result are None - this indicates # an instance of the object is not present in the row. - for col in self.pks_by_table[self.mapped_table]: - if row[col] is None: - return None + if None in identitykey[1]: + return None # plugin point instance = extension.create_instance(self, context, row, self.class_) @@ -1493,9 +1426,10 @@ class Mapper(object): # call further mapper properties on the row, to pull further # instances from the row and possibly populate this item. - if extension.populate_instance(self, context, row, instance, identitykey, isnew) is EXT_PASS: - self.populate_instance(context, instance, row, identitykey, isnew) - if extension.append_result(self, context, row, instance, identitykey, result, isnew) is EXT_PASS: + flags = {'instancekey':identitykey, 'isnew':isnew} + if extension.populate_instance(self, context, row, instance, **flags) is EXT_PASS: + self.populate_instance(context, instance, row, **flags) + if extension.append_result(self, context, row, instance, result, **flags) is EXT_PASS: if result is not None: result.append(instance) return instance @@ -1510,6 +1444,24 @@ class Mapper(object): return obj + def _deferred_inheritance_condition(self, needs_tables): + cond = self.inherit_condition + + param_names = [] + def visit_binary(binary): + leftcol = binary.left + rightcol = binary.right + if leftcol is None or rightcol is None: + return + if leftcol.table not in needs_tables: + binary.left = sql.bindparam(leftcol.name, None, type_=binary.right.type, unique=True) + param_names.append(leftcol) + elif rightcol not in needs_tables: + binary.right = sql.bindparam(rightcol.name, None, type_=binary.right.type, unique=True) + param_names.append(rightcol) + cond = mapperutil.BinaryVisitor(visit_binary).traverse(cond, clone=True) + return cond, param_names + def translate_row(self, tomapper, row): """Translate the column keys of a row into a new or proxied row that can be understood by another mapper. @@ -1520,288 +1472,71 @@ class Mapper(object): newrow = util.DictDecorator(row) for c in tomapper.mapped_table.c: - c2 = self.mapped_table.corresponding_column(c, keys_ok=True, raiseerr=True) - if row.has_key(c2): + c2 = self.mapped_table.corresponding_column(c, keys_ok=True, raiseerr=False) + if c2 and row.has_key(c2): newrow[c] = row[c2] return newrow - def populate_instance(self, selectcontext, instance, row, identitykey, isnew): - """populate an instance from a result row. - - This method iterates through the list of MapperProperty objects attached to this Mapper - and calls each properties execute() method.""" - for prop in self.__props.values(): - prop.execute(selectcontext, instance, row, identitykey, isnew) - -Mapper.logger = logging.class_logger(Mapper) - - -class MapperExtension(object): - """Base implementation for an object that provides overriding - behavior to various Mapper functions. For each method in - MapperExtension, a result of EXT_PASS indicates the functionality - is not overridden. - """ - - def get_session(self): - """Retrieve a contextual Session instance with which to - register a new object. - - Note: this is not called if a session is provided with the - `__init__` params (i.e. `_sa_session`). - """ - - return EXT_PASS - - def load(self, query, *args, **kwargs): - """Override the `load` method of the Query object. - - The return value of this method is used as the result of - ``query.load()`` if the value is anything other than EXT_PASS. - """ - - return EXT_PASS - - def get(self, query, *args, **kwargs): - """Override the `get` method of the Query object. - - The return value of this method is used as the result of - ``query.get()`` if the value is anything other than EXT_PASS. - """ - - return EXT_PASS - - def get_by(self, query, *args, **kwargs): - """Override the `get_by` method of the Query object. - - The return value of this method is used as the result of - ``query.get_by()`` if the value is anything other than - EXT_PASS. - """ - - return EXT_PASS - - def select_by(self, query, *args, **kwargs): - """Override the `select_by` method of the Query object. - - The return value of this method is used as the result of - ``query.select_by()`` if the value is anything other than - EXT_PASS. - """ - - return EXT_PASS - - def select(self, query, *args, **kwargs): - """Override the `select` method of the Query object. - - The return value of this method is used as the result of - ``query.select()`` if the value is anything other than - EXT_PASS. - """ - - return EXT_PASS - - - def translate_row(self, mapper, context, row): - """Perform pre-processing on the given result row and return a - new row instance. - - This is called as the very first step in the ``_instance()`` - method. - """ - - return EXT_PASS - - def create_instance(self, mapper, selectcontext, row, class_): - """Receive a row when a new object instance is about to be - created from that row. - - The method can choose to create the instance itself, or it can - return None to indicate normal object creation should take - place. - - mapper - The mapper doing the operation - - selectcontext - SelectionContext corresponding to the instances() call - - row - The result row from the database - - class\_ - The class we are mapping. - """ - - return EXT_PASS - - def append_result(self, mapper, selectcontext, row, instance, identitykey, result, isnew): - """Receive an object instance before that instance is appended - to a result list. - - If this method returns EXT_PASS, result appending will proceed - normally. if this method returns any other value or None, - result appending will not proceed for this instance, giving - this extension an opportunity to do the appending itself, if - desired. - - mapper - The mapper doing the operation. - - selectcontext - SelectionContext corresponding to the instances() call. - - row - The result row from the database. - - instance - The object instance to be appended to the result. - - identitykey - The identity key of the instance. - - result - List to which results are being appended. - - isnew - Indicates if this is the first time we have seen this object - instance in the current result set. if you are selecting - from a join, such as an eager load, you might see the same - object instance many times in the same result set. - """ - - return EXT_PASS - - def populate_instance(self, mapper, selectcontext, row, instance, identitykey, isnew): - """Receive a newly-created instance before that instance has - its attributes populated. - - The normal population of attributes is according to each - attribute's corresponding MapperProperty (which includes - column-based attributes as well as relationships to other - classes). If this method returns EXT_PASS, instance - population will proceed normally. If any other value or None - is returned, instance population will not proceed, giving this - extension an opportunity to populate the instance itself, if - desired. - """ - - return EXT_PASS - - def before_insert(self, mapper, connection, instance): - """Receive an object instance before that instance is INSERTed - into its table. - - This is a good place to set up primary key values and such - that aren't handled otherwise. - """ - - return EXT_PASS - - def before_update(self, mapper, connection, instance): - """Receive an object instance before that instance is UPDATEed.""" - - return EXT_PASS - - def after_update(self, mapper, connection, instance): - """Receive an object instance after that instance is UPDATEed.""" - - return EXT_PASS - - def after_insert(self, mapper, connection, instance): - """Receive an object instance after that instance is INSERTed.""" - - return EXT_PASS - - def before_delete(self, mapper, connection, instance): - """Receive an object instance before that instance is DELETEed.""" - - return EXT_PASS - - def after_delete(self, mapper, connection, instance): - """Receive an object instance after that instance is DELETEed.""" - - return EXT_PASS - -class _ExtensionCarrier(MapperExtension): - def __init__(self): - self.__elements = [] - - def __iter__(self): - return iter(self.__elements) + def populate_instance(self, selectcontext, instance, row, ispostselect=None, **flags): + """populate an instance from a result row.""" + + selectcontext.stack.push_mapper(self) + populators = selectcontext.attributes.get(('instance_populators', self, selectcontext.stack.snapshot(), ispostselect), None) + if populators is None: + populators = [] + post_processors = [] + for prop in self.__props.values(): + (pop, post_proc) = prop.create_row_processor(selectcontext, self, row) + if pop is not None: + populators.append(pop) + if post_proc is not None: + post_processors.append(post_proc) + + poly_select_loader = self._get_poly_select_loader(selectcontext, row) + if poly_select_loader is not None: + post_processors.append(poly_select_loader) + + selectcontext.attributes[('instance_populators', self, selectcontext.stack.snapshot(), ispostselect)] = populators + selectcontext.attributes[('post_processors', self, ispostselect)] = post_processors + + for p in populators: + p(instance, row, ispostselect=ispostselect, **flags) - def insert(self, extension): - """Insert a MapperExtension at the beginning of this ExtensionCarrier's list.""" - - self.__elements.insert(0, extension) - - def append(self, extension): - """Append a MapperExtension at the end of this ExtensionCarrier's list.""" - - self.__elements.append(extension) - - def get_session(self, *args, **kwargs): - return self._do('get_session', *args, **kwargs) - - def load(self, *args, **kwargs): - return self._do('load', *args, **kwargs) - - def get(self, *args, **kwargs): - return self._do('get', *args, **kwargs) - - def get_by(self, *args, **kwargs): - return self._do('get_by', *args, **kwargs) - - def select_by(self, *args, **kwargs): - return self._do('select_by', *args, **kwargs) - - def select(self, *args, **kwargs): - return self._do('select', *args, **kwargs) - - def translate_row(self, *args, **kwargs): - return self._do('translate_row', *args, **kwargs) - - def create_instance(self, *args, **kwargs): - return self._do('create_instance', *args, **kwargs) - - def append_result(self, *args, **kwargs): - return self._do('append_result', *args, **kwargs) - - def populate_instance(self, *args, **kwargs): - return self._do('populate_instance', *args, **kwargs) - - def before_insert(self, *args, **kwargs): - return self._do('before_insert', *args, **kwargs) - - def before_update(self, *args, **kwargs): - return self._do('before_update', *args, **kwargs) - - def after_update(self, *args, **kwargs): - return self._do('after_update', *args, **kwargs) + selectcontext.stack.pop() + + if self.non_primary: + selectcontext.attributes[('populating_mapper', instance)] = self + + def _post_instance(self, selectcontext, instance): + post_processors = selectcontext.attributes[('post_processors', self, None)] + for p in post_processors: + p(instance) + + def _get_poly_select_loader(self, selectcontext, row): + # 'select' or 'union'+col not present + (hosted_mapper, needs_tables) = selectcontext.attributes.get(('polymorphic_fetch', self), (None, None)) + if hosted_mapper is None or len(needs_tables)==0 or hosted_mapper.polymorphic_fetch == 'deferred': + return + + cond, param_names = self._deferred_inheritance_condition(needs_tables) + statement = sql.select(needs_tables, cond, use_labels=True) + def post_execute(instance, **flags): + self.__log_debug("Post query loading instance " + mapperutil.instance_str(instance)) - def after_insert(self, *args, **kwargs): - return self._do('after_insert', *args, **kwargs) + identitykey = self.identity_key_from_instance(instance) - def before_delete(self, *args, **kwargs): - return self._do('before_delete', *args, **kwargs) + params = {} + for c in param_names: + params[c.name] = self.get_attr_by_column(instance, c) + row = selectcontext.session.connection(self).execute(statement, **params).fetchone() + self.populate_instance(selectcontext, instance, row, **{'isnew':False, 'instancekey':identitykey, 'ispostselect':True}) - def after_delete(self, *args, **kwargs): - return self._do('after_delete', *args, **kwargs) + return post_execute + +Mapper.logger = logging.class_logger(Mapper) - def _do(self, funcname, *args, **kwargs): - for elem in self.__elements: - ret = getattr(elem, funcname)(*args, **kwargs) - if ret is not EXT_PASS: - return ret - else: - return EXT_PASS -class ExtensionOption(MapperOption): - def __init__(self, ext): - self.ext = ext - def process_query(self, query): - query.extension.append(self.ext) class ClassKey(object): """Key a class and an entity name to a mapper, via the mapper_registry.""" diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index a00a35ab6..6ce9fd706 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -15,8 +15,11 @@ from sqlalchemy import sql, schema, util, exceptions, sql_util, logging from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency from sqlalchemy.orm import session as sessionlib from sqlalchemy.orm import util as mapperutil -import sets, random -from sqlalchemy.orm.interfaces import * +import operator +from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator +from sqlalchemy.exceptions import ArgumentError + +__all__ = ['ColumnProperty', 'CompositeProperty', 'PropertyLoader', 'BackRef'] class ColumnProperty(StrategizedProperty): """Describes an object attribute that corresponds to a table column.""" @@ -31,17 +34,27 @@ class ColumnProperty(StrategizedProperty): self.columns = list(columns) self.group = kwargs.pop('group', None) self.deferred = kwargs.pop('deferred', False) - + self.comparator = ColumnProperty.ColumnComparator(self) + # sanity check + for col in columns: + if not hasattr(col, 'name'): + if hasattr(col, 'label'): + raise ArgumentError('ColumnProperties must be named for the mapper to work with them. Try .label() to fix this') + raise ArgumentError('%r is not a valid candidate for ColumnProperty' % col) + def create_strategy(self): if self.deferred: return strategies.DeferredColumnLoader(self) else: return strategies.ColumnLoader(self) - - def getattr(self, object): + + def copy(self): + return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns) + + def getattr(self, object, column): return getattr(object, self.key) - def setattr(self, object, value): + def setattr(self, object, value, column): setattr(object, self.key, value) def get_history(self, obj, passive=False): @@ -50,19 +63,69 @@ class ColumnProperty(StrategizedProperty): def merge(self, session, source, dest, _recursive): setattr(dest, self.key, getattr(source, self.key, None)) - def compare(self, value): - return self.columns[0] == value + def get_col_value(self, column, value): + return value + + class ColumnComparator(PropComparator): + def clause_element(self): + return self.prop.columns[0] + + def operate(self, op, other): + return op(self.prop.columns[0], other) + + def reverse_operate(self, op, other): + col = self.prop.columns[0] + return op(col._bind_param(other), col) + ColumnProperty.logger = logging.class_logger(ColumnProperty) mapper.ColumnProperty = ColumnProperty +class CompositeProperty(ColumnProperty): + """subclasses ColumnProperty to provide composite type support.""" + + def __init__(self, class_, *columns, **kwargs): + super(CompositeProperty, self).__init__(*columns, **kwargs) + self.composite_class = class_ + self.comparator = kwargs.pop('comparator', CompositeProperty.Comparator(self)) + + def copy(self): + return CompositeProperty(deferred=self.deferred, group=self.group, composite_class=self.composite_class, *self.columns) + + def getattr(self, object, column): + obj = getattr(object, self.key) + return self.get_col_value(column, obj) + + def setattr(self, object, value, column): + obj = getattr(object, self.key, None) + if obj is None: + obj = self.composite_class(*[None for c in self.columns]) + for a, b in zip(self.columns, value.__colset__()): + if a is column: + setattr(obj, b, value) + + def get_col_value(self, column, value): + for a, b in zip(self.columns, value.__colset__()): + if a is column: + return b + + class Comparator(PropComparator): + def __eq__(self, other): + if other is None: + return sql.and_(*[a==None for a in self.prop.columns]) + else: + return sql.and_(*[a==b for a, b in zip(self.prop.columns, other.__colset__())]) + + def __ne__(self, other): + return sql.or_(*[a!=b for a, b in zip(self.prop.columns, other.__colset__())]) + class PropertyLoader(StrategizedProperty): """Describes an object property that holds a single item or list of items that correspond to a related database table. """ - def __init__(self, argument, secondary, primaryjoin, secondaryjoin, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, remote_side=None, enable_typechecks=True): + def __init__(self, argument, secondary=None, primaryjoin=None, secondaryjoin=None, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, remote_side=None, enable_typechecks=True, join_depth=None): self.uselist = uselist self.argument = argument self.entity_name = entity_name @@ -80,7 +143,9 @@ class PropertyLoader(StrategizedProperty): self.remote_side = util.to_set(remote_side) self.enable_typechecks = enable_typechecks self._parent_join_cache = {} - + self.comparator = PropertyLoader.Comparator(self) + self.join_depth = join_depth + if cascade is not None: self.cascade = mapperutil.CascadeOptions(cascade) else: @@ -91,7 +156,7 @@ class PropertyLoader(StrategizedProperty): self.association = association self.order_by = order_by - self.attributeext = attributeext + self.attributeext=attributeext if isinstance(backref, str): # propigate explicitly sent primary/secondary join conditions to the BackRef object if # just a string was sent @@ -104,9 +169,96 @@ class PropertyLoader(StrategizedProperty): self.backref = backref self.is_backref = is_backref - def compare(self, value): - return sql.and_(*[x==y for (x, y) in zip(self.mapper.primary_key, self.mapper.primary_key_from_instance(value))]) - + class Comparator(PropComparator): + def __eq__(self, other): + if other is None: + return ~sql.exists([1], self.prop.primaryjoin) + elif self.prop.uselist: + if not hasattr(other, '__iter__'): + raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object.") + else: + j = self.prop.primaryjoin + if self.prop.secondaryjoin: + j = j & self.prop.secondaryjoin + clauses = [] + for o in other: + clauses.append( + sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(o))])) + ) + return sql.and_(*clauses) + else: + return self.prop._optimized_compare(other) + + def any(self, criterion=None, **kwargs): + if not self.prop.uselist: + raise exceptions.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().") + j = self.prop.primaryjoin + if self.prop.secondaryjoin: + j = j & self.prop.secondaryjoin + for k in kwargs: + crit = (getattr(self.prop.mapper.class_, k) == kwargs[k]) + if criterion is None: + criterion = crit + else: + criterion = criterion & crit + return sql.exists([1], j & criterion) + + def has(self, criterion=None, **kwargs): + if self.prop.uselist: + raise exceptions.InvalidRequestError("'has()' not implemented for collections. Use any().") + j = self.prop.primaryjoin + if self.prop.secondaryjoin: + j = j & self.prop.secondaryjoin + for k in kwargs: + crit = (getattr(self.prop.mapper.class_, k) == kwargs[k]) + if criterion is None: + criterion = crit + else: + criterion = criterion & crit + return sql.exists([1], j & criterion) + + def contains(self, other): + if not self.prop.uselist: + raise exceptions.InvalidRequestError("'contains' not implemented for scalar attributes. Use ==") + clause = self.prop._optimized_compare(other) + + j = self.prop.primaryjoin + if self.prop.secondaryjoin: + j = j & self.prop.secondaryjoin + + clause.negation_clause = ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))])) + return clause + + def __ne__(self, other): + if self.prop.uselist and not hasattr(other, '__iter__'): + raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object") + + j = self.prop.primaryjoin + if self.prop.secondaryjoin: + j = j & self.prop.secondaryjoin + return ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))])) + + def compare(self, op, value, value_is_parent=False): + if op == operator.eq: + if value is None: + return ~sql.exists([1], self.prop.mapper.mapped_table, self.prop.primaryjoin) + else: + return self._optimized_compare(value, value_is_parent=value_is_parent) + else: + return op(self.comparator, value) + + def _optimized_compare(self, value, value_is_parent=False): + # optimized operation for ==, uses a lazy clause. + (criterion, lazybinds, rev) = strategies.LazyLoader._create_lazy_clause(self, reverse_direction=not value_is_parent) + bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds]) + + class Visitor(sql.ClauseVisitor): + def visit_bindparam(s, bindparam): + mapper = value_is_parent and self.parent or self.mapper + bindparam.value = mapper.get_attr_by_column(value, bind_to_col[bindparam.key]) + Visitor().traverse(criterion) + return criterion + private = property(lambda s:s.cascade.delete_orphan) def create_strategy(self): @@ -127,12 +279,13 @@ class PropertyLoader(StrategizedProperty): if childlist is None: return if self.uselist: - # sets a blank list according to the correct list class - dest_list = getattr(self.parent.class_, self.key).initialize(dest) + # sets a blank collection according to the correct list class + dest_list = sessionlib.attribute_manager.init_collection(dest, self.key) for current in list(childlist): obj = session.merge(current, entity_name=self.mapper.entity_name, _recursive=_recursive) if obj is not None: - dest_list.append(obj) + #dest_list.append_without_event(obj) + dest_list.append_with_event(obj) else: current = list(childlist)[0] if current is not None: @@ -267,7 +420,7 @@ class PropertyLoader(StrategizedProperty): if len(self.foreign_keys): self._opposite_side = util.Set() def visit_binary(binary): - if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): + if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): return if binary.left in self.foreign_keys: self._opposite_side.add(binary.right) @@ -280,7 +433,7 @@ class PropertyLoader(StrategizedProperty): self.foreign_keys = util.Set() self._opposite_side = util.Set() def visit_binary(binary): - if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): + if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): return # this check is for when the user put the "view_only" flag on and has tables that have nothing @@ -362,16 +515,13 @@ class PropertyLoader(StrategizedProperty): "argument." % (str(self))) def _determine_remote_side(self): - if len(self.remote_side): - return - self.remote_side = util.Set() + if not len(self.remote_side): + if self.direction is sync.MANYTOONE: + self.remote_side = util.Set(self._opposite_side) + elif self.direction is sync.ONETOMANY or self.direction is sync.MANYTOMANY: + self.remote_side = util.Set(self.foreign_keys) - if self.direction is sync.MANYTOONE: - for c in self._opposite_side: - self.remote_side.add(c) - elif self.direction is sync.ONETOMANY or self.direction is sync.MANYTOMANY: - for c in self.foreign_keys: - self.remote_side.add(c) + self.local_side = util.Set(self._opposite_side).union(util.Set(self.foreign_keys)).difference(self.remote_side) def _create_polymorphic_joins(self): # get ready to create "polymorphic" primary/secondary join clauses. @@ -383,27 +533,26 @@ class PropertyLoader(StrategizedProperty): # as we will be using the polymorphic selectables (i.e. select_table argument to Mapper) to figure this out, # first create maps of all the "equivalent" columns, since polymorphic selectables will often munge # several "equivalent" columns (such as parent/child fk cols) into just one column. - target_equivalents = self.mapper._get_inherited_column_equivalents() + target_equivalents = self.mapper._get_equivalent_columns() + # if the target mapper loads polymorphically, adapt the clauses to the target's selectable if self.loads_polymorphic: if self.secondaryjoin: - self.polymorphic_secondaryjoin = self.secondaryjoin.copy_container() - sql_util.ClauseAdapter(self.mapper.select_table).traverse(self.polymorphic_secondaryjoin) - self.polymorphic_primaryjoin = self.primaryjoin.copy_container() + self.polymorphic_secondaryjoin = sql_util.ClauseAdapter(self.mapper.select_table).traverse(self.secondaryjoin, clone=True) + self.polymorphic_primaryjoin = self.primaryjoin else: - self.polymorphic_primaryjoin = self.primaryjoin.copy_container() if self.direction is sync.ONETOMANY: - sql_util.ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents).traverse(self.polymorphic_primaryjoin) + self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True) elif self.direction is sync.MANYTOONE: - sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.polymorphic_primaryjoin) + self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True) self.polymorphic_secondaryjoin = None # load "polymorphic" versions of the columns present in "remote_side" - this is # important for lazy-clause generation which goes off the polymorphic target selectable for c in list(self.remote_side): - if self.secondary and c in self.secondary.columns: + if self.secondary and self.secondary.columns.contains_column(c): continue - for equiv in [c] + (c in target_equivalents and target_equivalents[c] or []): + for equiv in [c] + (c in target_equivalents and list(target_equivalents[c]) or []): corr = self.mapper.select_table.corresponding_column(equiv, raiseerr=False) if corr: self.remote_side.add(corr) @@ -411,8 +560,8 @@ class PropertyLoader(StrategizedProperty): else: raise exceptions.AssertionError(str(self) + ": Could not find corresponding column for " + str(c) + " in selectable " + str(self.mapper.select_table)) else: - self.polymorphic_primaryjoin = self.primaryjoin.copy_container() - self.polymorphic_secondaryjoin = self.secondaryjoin and self.secondaryjoin.copy_container() or None + self.polymorphic_primaryjoin = self.primaryjoin + self.polymorphic_secondaryjoin = self.secondaryjoin def _post_init(self): if logging.is_info_enabled(self.logger): @@ -450,22 +599,20 @@ class PropertyLoader(StrategizedProperty): def _is_self_referential(self): return self.parent.mapped_table is self.target or self.parent.select_table is self.target - def get_join(self, parent, primary=True, secondary=True): + def get_join(self, parent, primary=True, secondary=True, polymorphic_parent=True): try: - return self._parent_join_cache[(parent, primary, secondary)] + return self._parent_join_cache[(parent, primary, secondary, polymorphic_parent)] except KeyError: - parent_equivalents = parent._get_inherited_column_equivalents() - primaryjoin = self.polymorphic_primaryjoin.copy_container() - if self.secondaryjoin is not None: - secondaryjoin = self.polymorphic_secondaryjoin.copy_container() - else: - secondaryjoin = None - if self.direction is sync.ONETOMANY: - sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin) - elif self.direction is sync.MANYTOONE: - sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin) - elif self.secondaryjoin: - sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin) + parent_equivalents = parent._get_equivalent_columns() + secondaryjoin = self.polymorphic_secondaryjoin + if polymorphic_parent: + # adapt the "parent" side of our join condition to the "polymorphic" select of the parent + if self.direction is sync.ONETOMANY: + primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) + elif self.direction is sync.MANYTOONE: + primaryjoin = sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) + elif self.secondaryjoin: + primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) if secondaryjoin is not None: if secondary and not primary: @@ -476,7 +623,7 @@ class PropertyLoader(StrategizedProperty): j = primaryjoin else: j = primaryjoin - self._parent_join_cache[(parent, primary, secondary)] = j + self._parent_join_cache[(parent, primary, secondary, polymorphic_parent)] = j return j def register_dependencies(self, uowcommit): @@ -501,7 +648,7 @@ class BackRef(object): # try to set a LazyLoader on our mapper referencing the parent mapper mapper = prop.mapper.primary_mapper() - if not mapper.props.has_key(self.key): + if not mapper.get_property(self.key, raiseerr=False) is not None: pj = self.kwargs.pop('primaryjoin', None) sj = self.kwargs.pop('secondaryjoin', None) # the backref property is set on the primary mapper @@ -512,26 +659,26 @@ class BackRef(object): backref=prop.key, is_backref=True, **self.kwargs) mapper._compile_property(self.key, relation); - elif not isinstance(mapper.props[self.key], PropertyLoader): + elif not isinstance(mapper.get_property(self.key), PropertyLoader): raise exceptions.ArgumentError( "Can't create backref '%s' on mapper '%s'; an incompatible " "property of that name already exists" % (self.key, str(mapper))) else: # else set one of us as the "backreference" parent = prop.parent.primary_mapper() - if parent.class_ is not mapper.props[self.key]._get_target_class(): + if parent.class_ is not mapper.get_property(self.key)._get_target_class(): raise exceptions.ArgumentError( "Backrefs do not match: backref '%s' expects to connect to %s, " "but found a backref already connected to %s" % - (self.key, str(parent.class_), str(mapper.props[self.key].mapper.class_))) - if not mapper.props[self.key].is_backref: + (self.key, str(parent.class_), str(mapper.get_property(self.key).mapper.class_))) + if not mapper.get_property(self.key).is_backref: prop.is_backref=True if not prop.viewonly: prop._dependency_processor.is_backref=True # reverse_property used by dependencies.ManyToManyDP to check # association table operations - prop.reverse_property = mapper.props[self.key] - mapper.props[self.key].reverse_property = prop + prop.reverse_property = mapper.get_property(self.key) + mapper.get_property(self.key).reverse_property = prop def get_extension(self): """Return an attribute extension to use with this backreference.""" diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index d51fd75c3..284653b5c 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -4,89 +4,49 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from sqlalchemy import sql, util, exceptions, sql_util, logging, schema -from sqlalchemy.orm import mapper, class_mapper, object_mapper -from sqlalchemy.orm.interfaces import OperationContext, SynonymProperty -import random +from sqlalchemy import sql, util, exceptions, sql_util, logging +from sqlalchemy.orm import mapper, object_mapper +from sqlalchemy.orm import util as mapperutil +from sqlalchemy.orm.interfaces import OperationContext, LoaderStack +import operator __all__ = ['Query', 'QueryContext', 'SelectionContext'] class Query(object): - """Encapsulates the object-fetching operations provided by Mappers. + """Encapsulates the object-fetching operations provided by Mappers.""" - Note that this particular version of Query contains the 0.3 API as well as most of the - 0.4 API for forwards compatibility. A large part of the API here is deprecated (but still present) - in the 0.4 series. - """ - - def __init__(self, class_or_mapper, session=None, entity_name=None, lockmode=None, with_options=None, extension=None, **kwargs): + def __init__(self, class_or_mapper, session=None, entity_name=None): if isinstance(class_or_mapper, type): self.mapper = mapper.class_mapper(class_or_mapper, entity_name=entity_name) else: self.mapper = class_or_mapper.compile() - self.with_options = with_options or [] self.select_mapper = self.mapper.get_select_mapper().compile() - self.always_refresh = kwargs.pop('always_refresh', self.mapper.always_refresh) - self.lockmode = lockmode - self.extension = mapper._ExtensionCarrier() - if extension is not None: - self.extension.append(extension) - self.extension.append(self.mapper.extension) - self.is_polymorphic = self.mapper is not self.select_mapper + self._session = session - if not hasattr(self.mapper, '_get_clause'): - _get_clause = sql.and_() - for primary_key in self.primary_key_columns: - _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type=primary_key.type, unique=True)) - self.mapper._get_clause = _get_clause + self._with_options = [] + self._lockmode = None + self._extension = self.mapper.extension.copy() self._entities = [] - self._get_clause = self.mapper._get_clause - - self._order_by = kwargs.pop('order_by', False) - self._group_by = kwargs.pop('group_by', False) - self._distinct = kwargs.pop('distinct', False) - self._offset = kwargs.pop('offset', None) - self._limit = kwargs.pop('limit', None) - self._criterion = None + self._order_by = False + self._group_by = False + self._distinct = False + self._offset = None + self._limit = None + self._statement = None self._params = {} - self._col = None - self._func = None + self._criterion = None + self._column_aggregate = None self._joinpoint = self.mapper + self._aliases = None + self._alias_ids = {} self._from_obj = [self.table] - self._statement = None - - for opt in util.flatten_iterator(self.with_options): - opt.process_query(self) + self._populate_existing = False + self._version_check = False def _clone(self): - # yes, a little embarassing here. - # go look at 0.4 for the simple version. q = Query.__new__(Query) - q.mapper = self.mapper - q.select_mapper = self.select_mapper - q._order_by = self._order_by - q._distinct = self._distinct - q._entities = list(self._entities) - q.always_refresh = self.always_refresh - q.with_options = list(self.with_options) - q._session = self.session - q.is_polymorphic = self.is_polymorphic - q.lockmode = self.lockmode - q.extension = mapper._ExtensionCarrier() - for ext in self.extension: - q.extension.append(ext) - q._offset = self._offset - q._limit = self._limit - q._params = self._params - q._group_by = self._group_by - q._get_clause = self._get_clause - q._from_obj = list(self._from_obj) - q._joinpoint = self._joinpoint - q._criterion = self._criterion - q._statement = self._statement - q._col = self._col - q._func = self._func + q.__dict__ = self.__dict__.copy() return q def _get_session(self): @@ -96,7 +56,7 @@ class Query(object): return self._session table = property(lambda s:s.select_mapper.mapped_table) - primary_key_columns = property(lambda s:s.select_mapper.pks_by_table[s.select_mapper.mapped_table]) + primary_key_columns = property(lambda s:s.select_mapper.primary_key) session = property(_get_session) def get(self, ident, **kwargs): @@ -108,13 +68,20 @@ class Query(object): columns. """ - ret = self.extension.get(self, ident, **kwargs) + ret = self._extension.get(self, ident, **kwargs) if ret is not mapper.EXT_PASS: return ret - key = self.mapper.identity_key(ident) + + # convert composite types to individual args + # TODO: account for the order of columns in the + # ColumnProperty it corresponds to + if hasattr(ident, '__colset__'): + ident = ident.__colset__() + + key = self.mapper.identity_key_from_primary_key(ident) return self._get(key, ident, **kwargs) - def load(self, ident, **kwargs): + def load(self, ident, raiseerr=True, **kwargs): """Return an instance of the object based on the given identifier. @@ -125,304 +92,14 @@ class Query(object): columns. """ - ret = self.extension.load(self, ident, **kwargs) + ret = self._extension.load(self, ident, **kwargs) if ret is not mapper.EXT_PASS: return ret - key = self.mapper.identity_key(ident) + key = self.mapper.identity_key_from_primary_key(ident) instance = self._get(key, ident, reload=True, **kwargs) - if instance is None: + if instance is None and raiseerr: raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident)) return instance - - def get_by(self, *args, **params): - """Like ``select_by()``, but only return the first - as a scalar, or None if no object found. - Synonymous with ``selectfirst_by()``. - - The criterion is constructed in the same way as the - ``select_by()`` method. - - this method is deprecated in 0.4. - """ - - ret = self.extension.get_by(self, *args, **params) - if ret is not mapper.EXT_PASS: - return ret - x = self.select_whereclause(self.join_by(*args, **params), limit=1) - if x: - return x[0] - else: - return None - - def select_by(self, *args, **params): - """Return an array of object instances based on the given - clauses and key/value criterion. - - \*args - a list of zero or more ``ClauseElements`` which will be - connected by ``AND`` operators. - - \**params - a set of zero or more key/value parameters which - are converted into ``ClauseElements``. the keys are mapped to - property or column names mapped by this mapper's Table, and - the values are coerced into a ``WHERE`` clause separated by - ``AND`` operators. If the local property/column names dont - contain the key, a search will be performed against this - mapper's immediate list of relations as well, forming the - appropriate join conditions if a matching property is located. - - if the located property is a column-based property, the comparison - value should be a scalar with an appropriate type. If the - property is a relationship-bound property, the comparison value - should be an instance of the related class. - - E.g.:: - - result = usermapper.select_by(user_name = 'fred') - - this method is deprecated in 0.4. - """ - - ret = self.extension.select_by(self, *args, **params) - if ret is not mapper.EXT_PASS: - return ret - return self.select_whereclause(self.join_by(*args, **params)) - - def join_by(self, *args, **params): - """Return a ``ClauseElement`` representing the ``WHERE`` - clause that would normally be sent to ``select_whereclause()`` - by ``select_by()``. - - The criterion is constructed in the same way as the - ``select_by()`` method. - - this method is deprecated in 0.4. - """ - - return self._join_by(args, params) - - - def join_to(self, key): - """Given the key name of a property, will recursively descend - through all child properties from this Query's mapper to - locate the property, and will return a ClauseElement - representing a join from this Query's mapper to the endmost - mapper. - - this method is deprecated in 0.4. - """ - - [keys, p] = self._locate_prop(key) - return self.join_via(keys) - - def join_via(self, keys): - """Given a list of keys that represents a path from this - Query's mapper to a related mapper based on names of relations - from one mapper to the next, return a ClauseElement - representing a join from this Query's mapper to the endmost - mapper. - - this method is deprecated in 0.4. - """ - - mapper = self.mapper - clause = None - for key in keys: - prop = mapper.get_property(key, resolve_synonyms=True) - if clause is None: - clause = prop.get_join(mapper) - else: - clause &= prop.get_join(mapper) - mapper = prop.mapper - - return clause - - def selectfirst_by(self, *args, **params): - """Like ``select_by()``, but only return the first - as a scalar, or None if no object found. - Synonymous with ``get_by()``. - - The criterion is constructed in the same way as the - ``select_by()`` method. - - this method is deprecated in 0.4. - """ - - return self.get_by(*args, **params) - - def selectone_by(self, *args, **params): - """Like ``selectfirst_by()``, but throws an error if not - exactly one result was returned. - - The criterion is constructed in the same way as the - ``select_by()`` method. - - this method is deprecated in 0.4. - """ - - ret = self.select_whereclause(self.join_by(*args, **params), limit=2) - if len(ret) == 1: - return ret[0] - elif len(ret) == 0: - raise exceptions.InvalidRequestError('No rows returned for selectone_by') - else: - raise exceptions.InvalidRequestError('Multiple rows returned for selectone_by') - - def count_by(self, *args, **params): - """Return the count of instances based on the given clauses - and key/value criterion. - - The criterion is constructed in the same way as the - ``select_by()`` method. - - this method is deprecated in 0.4. - """ - - return self.count(self.join_by(*args, **params)) - - def selectfirst(self, arg=None, **kwargs): - """Query for a single instance using the given criterion. - - Arguments are the same as ``select()``. In the case that - the given criterion represents ``WHERE`` criterion only, - LIMIT 1 is applied to the fully generated statement. - - this method is deprecated in 0.4. - """ - - if isinstance(arg, sql.FromClause) and arg.supports_execution(): - ret = self.select_statement(arg, **kwargs) - else: - kwargs['limit'] = 1 - ret = self.select_whereclause(whereclause=arg, **kwargs) - if ret: - return ret[0] - else: - return None - - def selectone(self, arg=None, **kwargs): - """Query for a single instance using the given criterion. - - Unlike ``selectfirst``, this method asserts that only one - row exists. In the case that the given criterion represents - ``WHERE`` criterion only, LIMIT 2 is applied to the fully - generated statement. - - this method is deprecated in 0.4. - """ - - if isinstance(arg, sql.FromClause) and arg.supports_execution(): - ret = self.select_statement(arg, **kwargs) - else: - kwargs['limit'] = 2 - ret = self.select_whereclause(whereclause=arg, **kwargs) - if len(ret) == 1: - return ret[0] - elif len(ret) == 0: - raise exceptions.InvalidRequestError('No rows returned for selectone_by') - else: - raise exceptions.InvalidRequestError('Multiple rows returned for selectone') - - def select(self, arg=None, **kwargs): - """Select instances of the object from the database. - - `arg` can be any ClauseElement, which will form the criterion - with which to load the objects. - - For more advanced usage, arg can also be a Select statement - object, which will be executed and its resulting rowset used - to build new object instances. - - In this case, the developer must ensure that an adequate set - of columns exists in the rowset with which to build new object - instances. - - this method is deprecated in 0.4. - """ - - ret = self.extension.select(self, arg=arg, **kwargs) - if ret is not mapper.EXT_PASS: - return ret - if isinstance(arg, sql.FromClause) and arg.supports_execution(): - return self.select_statement(arg, **kwargs) - else: - return self.select_whereclause(whereclause=arg, **kwargs) - - def select_whereclause(self, whereclause=None, params=None, **kwargs): - """Given a ``WHERE`` criterion, create a ``SELECT`` statement, - execute and return the resulting instances. - - this method is deprecated in 0.4. - """ - statement = self.compile(whereclause, **kwargs) - return self._select_statement(statement, params=params) - - def count(self, whereclause=None, params=None, **kwargs): - """Given a ``WHERE`` criterion, create a ``SELECT COUNT`` - statement, execute and return the resulting count value. - - the additional arguments to this method are is deprecated in 0.4. - - """ - if self._criterion: - if whereclause is not None: - whereclause = sql.and_(self._criterion, whereclause) - else: - whereclause = self._criterion - from_obj = kwargs.pop('from_obj', self._from_obj) - kwargs.setdefault('distinct', self._distinct) - - alltables = [] - for l in [sql_util.TableFinder(x) for x in from_obj]: - alltables += l - - if self.table not in alltables: - from_obj.append(self.table) - if self._nestable(**kwargs): - s = sql.select([self.table], whereclause, from_obj=from_obj, **kwargs).alias('getcount').count() - else: - primary_key = self.primary_key_columns - s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **kwargs) - return self.session.scalar(self.mapper, s, params=params) - - def select_statement(self, statement, **params): - """Given a ``ClauseElement``-based statement, execute and - return the resulting instances. - - this method is deprecated in 0.4. - """ - - return self._select_statement(statement, params=params) - - def select_text(self, text, **params): - """Given a literal string-based statement, execute and return - the resulting instances. - - this method is deprecated in 0.4. use from_statement() instead. - """ - - t = sql.text(text) - return self.execute(t, params=params) - - def _with_lazy_criterion(cls, instance, prop, reverse=False): - """extract query criterion from a LazyLoader strategy given a Mapper, - source persisted/detached instance and PropertyLoader. - - """ - - from sqlalchemy.orm import strategies - (criterion, lazybinds, rev) = strategies.LazyLoader._create_lazy_clause(prop, reverse_direction=reverse) - bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds]) - - class Visitor(sql.ClauseVisitor): - def visit_bindparam(self, bindparam): - mapper = reverse and prop.mapper or prop.parent - bindparam.value = mapper.get_attr_by_column(instance, bind_to_col[bindparam.key]) - Visitor().traverse(criterion) - return criterion - _with_lazy_criterion = classmethod(_with_lazy_criterion) - def query_from_parent(cls, instance, property, **kwargs): """return a newly constructed Query object, with criterion corresponding to @@ -445,9 +122,25 @@ class Query(object): mapper = object_mapper(instance) prop = mapper.get_property(property, resolve_synonyms=True) target = prop.mapper - criterion = cls._with_lazy_criterion(instance, prop) + criterion = prop.compare(operator.eq, instance, value_is_parent=True) return Query(target, **kwargs).filter(criterion) query_from_parent = classmethod(query_from_parent) + + def populate_existing(self): + """return a Query that will refresh all instances loaded. + + this includes all entities accessed from the database, including + secondary entities, eagerly-loaded collection items. + + All changes present on entities which are already present in the session will + be reset and the entities will all be marked "clean". + + This is essentially the en-masse version of load(). + """ + + q = self._clone() + q._populate_existing = True + return q def with_parent(self, instance, property=None): """add a join criterion corresponding to a relationship to the given parent instance. @@ -474,9 +167,9 @@ class Query(object): raise exceptions.InvalidRequestError("Could not locate a property which relates instances of class '%s' to instances of class '%s'" % (self.mapper.class_.__name__, instance.__class__.__name__)) else: prop = mapper.get_property(property, resolve_synonyms=True) - return self.filter(Query._with_lazy_criterion(instance, prop)) + return self.filter(prop.compare(operator.eq, instance, value_is_parent=True)) - def add_entity(self, entity): + def add_entity(self, entity, alias=None, id=None): """add a mapped entity to the list of result columns to be returned. This will have the effect of all result-returning methods returning a tuple @@ -492,12 +185,25 @@ class Query(object): entity a class or mapper which will be added to the results. + alias + a sqlalchemy.sql.Alias object which will be used to select rows. this + will match the usage of the given Alias in filter(), order_by(), etc. expressions + + id + a string ID matching that given to query.join() or query.outerjoin(); rows will be + selected from the aliased join created via those methods. """ q = self._clone() - q._entities.append(entity) + + if isinstance(entity, type): + entity = mapper.class_mapper(entity) + if alias is not None: + alias = mapperutil.AliasedClauses(entity.mapped_table, alias=alias) + + q._entities = q._entities + [(entity, alias, id)] return q - def add_column(self, column): + def add_column(self, column, id=None): """add a SQL ColumnElement to the list of result columns to be returned. This will have the effect of all result-returning methods returning a tuple @@ -517,51 +223,56 @@ class Query(object): """ q = self._clone() - + # alias non-labeled column elements. - # TODO: make the generation deterministic if isinstance(column, sql.ColumnElement) and not hasattr(column, '_label'): - column = column.label("anon_" + hex(random.randint(0, 65535))[2:]) - - q._entities.append(column) + column = column.label(None) + + q._entities = q._entities + [(column, None, id)] return q - def options(self, *args, **kwargs): + def options(self, *args): """Return a new Query object, applying the given list of MapperOptions. """ + q = self._clone() - for opt in util.flatten_iterator(args): - q.with_options.append(opt) + opts = [o for o in util.flatten_iterator(args)] + q._with_options = q._with_options + opts + for opt in opts: opt.process_query(q) return q def with_lockmode(self, mode): """Return a new Query object with the specified locking mode.""" q = self._clone() - q.lockmode = mode + q._lockmode = mode return q def params(self, **kwargs): """add values for bind parameters which may have been specified in filter().""" - + q = self._clone() q._params = q._params.copy() q._params.update(kwargs) return q - + def filter(self, criterion): """apply the given filtering criterion to the query and return the newly resulting ``Query`` the criterion is any sql.ClauseElement applicable to the WHERE clause of a select. """ - + if isinstance(criterion, basestring): criterion = sql.text(criterion) - + if criterion is not None and not isinstance(criterion, sql.ClauseElement): raise exceptions.ArgumentError("filter() argument must be of type sqlalchemy.sql.ClauseElement or string") - + + + if self._aliases is not None: + criterion = self._aliases.adapt_clause(criterion) + q = self._clone() if q._criterion is not None: q._criterion = q._criterion & criterion @@ -569,23 +280,36 @@ class Query(object): q._criterion = criterion return q - def filter_by(self, *args, **kwargs): - """apply the given filtering criterion to the query and return the newly resulting ``Query`` + def filter_by(self, **kwargs): + """apply the given filtering criterion to the query and return the newly resulting ``Query``.""" - The criterion is constructed in the same way as the - ``select_by()`` method. - """ - return self.filter(self._join_by(args, kwargs, start=self._joinpoint)) + #import properties + + alias = None + join = None + clause = None + joinpoint = self._joinpoint - def _join_to(self, prop, outerjoin=False, start=None): - if start is None: - start = self._joinpoint + for key, value in kwargs.iteritems(): + prop = joinpoint.get_property(key, resolve_synonyms=True) + c = prop.compare(operator.eq, value) - if isinstance(prop, list): - keys = prop + if alias is not None: + sql_util.ClauseAdapter(alias).traverse(c) + if clause is None: + clause = c + else: + clause &= c + + if join is not None: + return self.select_from(join).filter(clause) else: - [keys,p] = self._locate_prop(prop, start=start) + return self.filter(clause) + def _join_to(self, keys, outerjoin=False, start=None, create_aliases=True): + if start is None: + start = self._joinpoint + clause = self._from_obj[-1] currenttables = [clause] @@ -594,101 +318,56 @@ class Query(object): currenttables.append(join.left) currenttables.append(join.right) FindJoinedTables().traverse(clause) - + mapper = start - for key in keys: + alias = self._aliases + for key in util.to_list(keys): prop = mapper.get_property(key, resolve_synonyms=True) - if prop._is_self_referential(): - raise exceptions.InvalidRequestError("Self-referential query on '%s' property must be constructed manually using an Alias object for the related table." % str(prop)) - # dont re-join to a table already in our from objects - if prop.select_table not in currenttables: - if outerjoin: - if prop.secondary: - clause = clause.outerjoin(prop.secondary, prop.get_join(mapper, primary=True, secondary=False)) - clause = clause.outerjoin(prop.select_table, prop.get_join(mapper, primary=False)) + if prop._is_self_referential() and not create_aliases: + raise exceptions.InvalidRequestError("Self-referential query on '%s' property requires create_aliases=True argument." % str(prop)) + + if prop.select_table not in currenttables or create_aliases: + if prop.secondary: + if create_aliases: + alias = mapperutil.PropertyAliasedClauses(prop, + prop.get_join(mapper, primary=True, secondary=False), + prop.get_join(mapper, primary=False, secondary=True), + alias + ) + clause = clause.join(alias.secondary, alias.primaryjoin, isouter=outerjoin).join(alias.alias, alias.secondaryjoin, isouter=outerjoin) else: - clause = clause.outerjoin(prop.select_table, prop.get_join(mapper)) + clause = clause.join(prop.secondary, prop.get_join(mapper, primary=True, secondary=False), isouter=outerjoin) + clause = clause.join(prop.select_table, prop.get_join(mapper, primary=False), isouter=outerjoin) else: - if prop.secondary: - clause = clause.join(prop.secondary, prop.get_join(mapper, primary=True, secondary=False)) - clause = clause.join(prop.select_table, prop.get_join(mapper, primary=False)) + if create_aliases: + alias = mapperutil.PropertyAliasedClauses(prop, + prop.get_join(mapper, primary=True, secondary=False), + None, + alias + ) + clause = clause.join(alias.alias, alias.primaryjoin, isouter=outerjoin) else: - clause = clause.join(prop.select_table, prop.get_join(mapper)) - elif prop.secondary is not None and prop.secondary not in currenttables: + clause = clause.join(prop.select_table, prop.get_join(mapper), isouter=outerjoin) + elif not create_aliases and prop.secondary is not None and prop.secondary not in currenttables: # TODO: this check is not strong enough for different paths to the same endpoint which # does not use secondary tables - raise exceptions.InvalidRequestError("Can't join to property '%s'; a path to this table along a different secondary table already exists. Use explicit `Alias` objects." % prop.key) + raise exceptions.InvalidRequestError("Can't join to property '%s'; a path to this table along a different secondary table already exists. Use the `alias=True` argument to `join()`." % prop.key) mapper = prop.mapper - return (clause, mapper) - - def _join_by(self, args, params, start=None): - """Return a ``ClauseElement`` representing the ``WHERE`` - clause that would normally be sent to ``select_whereclause()`` - by ``select_by()``. - - The criterion is constructed in the same way as the - ``select_by()`` method. - """ - import properties - - clause = None - for arg in args: - if clause is None: - clause = arg - else: - clause &= arg - - for key, value in params.iteritems(): - (keys, prop) = self._locate_prop(key, start=start) - if isinstance(prop, properties.PropertyLoader): - c = self._with_lazy_criterion(value, prop, True) & self.join_via(keys[:-1]) - else: - c = prop.compare(value) & self.join_via(keys) - if clause is None: - clause = c - else: - clause &= c - return clause - - def _locate_prop(self, key, start=None): - import properties - keys = [] - seen = util.Set() - def search_for_prop(mapper_): - if mapper_ in seen: - return None - seen.add(mapper_) - if mapper_.props.has_key(key): - prop = mapper_.get_property(key, resolve_synonyms=True) - if isinstance(prop, properties.PropertyLoader): - keys.insert(0, prop.key) - return prop - else: - for prop in mapper_.iterate_properties: - if not isinstance(prop, properties.PropertyLoader): - continue - x = search_for_prop(prop.mapper) - if x: - keys.insert(0, prop.key) - return x - else: - return None - p = search_for_prop(start or self.mapper) - if p is None: - raise exceptions.InvalidRequestError("Can't locate property named '%s'" % key) - return [keys, p] + if create_aliases: + return (clause, mapper, alias) + else: + return (clause, mapper, None) def _generative_col_aggregate(self, col, func): """apply the given aggregate function to the query and return the newly resulting ``Query``. """ - if self._col is not None or self._func is not None: + if self._column_aggregate is not None: raise exceptions.InvalidRequestError("Query already contains an aggregate column or function") q = self._clone() - q._col = col - q._func = func + q._column_aggregate = (col, func) return q def apply_min(self, col): @@ -721,13 +400,13 @@ class Query(object): For performance, only use subselect if `order_by` attribute is set. """ - ops = {'distinct':self._distinct, 'order_by':self._order_by, 'from_obj':self._from_obj} + ops = {'distinct':self._distinct, 'order_by':self._order_by or None, 'from_obj':self._from_obj} if self._order_by is not False: s1 = sql.select([col], self._criterion, **ops).alias('u') - return sql.select([func(s1.corresponding_column(col))]).scalar() + return self.session.execute(sql.select([func(s1.corresponding_column(col))]), mapper=self.mapper).scalar() else: - return sql.select([func(col)], self._criterion, **ops).scalar() + return self.session.execute(sql.select([func(col)], self._criterion, **ops), mapper=self.mapper).scalar() def min(self, col): """Execute the SQL ``min()`` function against the given column.""" @@ -756,7 +435,7 @@ class Query(object): if q._order_by is False: q._order_by = util.to_list(criterion) else: - q._order_by.extend(util.to_list(criterion)) + q._order_by = q._order_by + util.to_list(criterion) return q def group_by(self, criterion): @@ -766,52 +445,62 @@ class Query(object): if q._group_by is False: q._group_by = util.to_list(criterion) else: - q._group_by.extend(util.to_list(criterion)) + q._group_by = q._group_by + util.to_list(criterion) return q - - def join(self, prop): + + def join(self, prop, id=None, aliased=False, from_joinpoint=False): """create a join of this ``Query`` object's criterion to a relationship and return the newly resulting ``Query``. - - 'prop' may be a string property name in which it is located - in the same manner as keyword arguments in ``select_by``, or - it may be a list of strings in which case the property is located - by direct traversal of each keyname (i.e. like join_via()). + + 'prop' may be a string property name or a list of string + property names. """ - - q = self._clone() - (clause, mapper) = self._join_to(prop, outerjoin=False) - q._from_obj = [clause] - q._joinpoint = mapper - return q - def outerjoin(self, prop): + return self._join(prop, id=id, outerjoin=False, aliased=aliased, from_joinpoint=from_joinpoint) + + def outerjoin(self, prop, id=None, aliased=False, from_joinpoint=False): """create a left outer join of this ``Query`` object's criterion to a relationship and return the newly resulting ``Query``. - 'prop' may be a string property name in which it is located - in the same manner as keyword arguments in ``select_by``, or - it may be a list of strings in which case the property is located - by direct traversal of each keyname (i.e. like join_via()). + 'prop' may be a string property name or a list of string + property names. """ + + return self._join(prop, id=id, outerjoin=True, aliased=aliased, from_joinpoint=from_joinpoint) + + def _join(self, prop, id, outerjoin, aliased, from_joinpoint): + (clause, mapper, aliases) = self._join_to(prop, outerjoin=outerjoin, start=from_joinpoint and self._joinpoint or self.mapper, create_aliases=aliased) q = self._clone() - (clause, mapper) = self._join_to(prop, outerjoin=True) q._from_obj = [clause] q._joinpoint = mapper + q._aliases = aliases + + a = aliases + while a is not None: + q._alias_ids.setdefault(a.mapper, []).append(a) + q._alias_ids.setdefault(a.table, []).append(a) + q._alias_ids.setdefault(a.alias, []).append(a) + a = a.parentclauses + + if id: + q._alias_ids[id] = aliases return q def reset_joinpoint(self): """return a new Query reset the 'joinpoint' of this Query reset back to the starting mapper. Subsequent generative calls will be constructed from the new joinpoint. - - This is an interim method which will not be needed with new behavior - to be released in 0.4.""" - + + Note that each call to join() or outerjoin() also starts from + the root. + """ + q = self._clone() q._joinpoint = q.mapper + q._aliases = None return q + def select_from(self, from_obj): """Set the `from_obj` parameter of the query and return the newly resulting ``Query``. @@ -823,20 +512,6 @@ class Query(object): new._from_obj = list(new._from_obj) + util.to_list(from_obj) return new - def __getattr__(self, key): - if (key.startswith('select_by_')): - key = key[10:] - def foo(arg): - return self.select_by(**{key:arg}) - return foo - elif (key.startswith('get_by_')): - key = key[7:] - def foo(arg): - return self.get_by(**{key:arg}) - return foo - else: - raise AttributeError(key) - def __getitem__(self, item): if isinstance(item, slice): start = item.start @@ -884,99 +559,65 @@ class Query(object): new._distinct = True return new - def list(self): + def all(self): """Return the results represented by this ``Query`` as a list. This results in an execution of the underlying query. - - this method is deprecated in 0.4. use all() instead. """ - return list(self) - - def one(self): - """Return the first result of this ``Query``, raising an exception if more than one row exists. - - This results in an execution of the underlying query. - this method is for forwards-compatibility with 0.4. - """ - - if self._col is None or self._func is None: - ret = list(self[0:2]) - - if len(ret) == 1: - return ret[0] - elif len(ret) == 0: - raise exceptions.InvalidRequestError('No rows returned for one()') - else: - raise exceptions.InvalidRequestError('Multiple rows returned for one()') - else: - return self._col_aggregate(self._col, self._func) - + + def from_statement(self, statement): + if isinstance(statement, basestring): + statement = sql.text(statement) + q = self._clone() + q._statement = statement + return q + def first(self): """Return the first result of this ``Query``. This results in an execution of the underlying query. - - this method is for forwards-compatibility with 0.4. """ - if self._col is None or self._func is None: - ret = list(self[0:1]) - if len(ret) > 0: - return ret[0] - else: - return None + if self._column_aggregate is not None: + return self._col_aggregate(*self._column_aggregate) + + ret = list(self[0:1]) + if len(ret) > 0: + return ret[0] else: - return self._col_aggregate(self._col, self._func) + return None - def all(self): - """Return the results represented by this ``Query`` as a list. + def one(self): + """Return the first result of this ``Query``, raising an exception if more than one row exists. This results in an execution of the underlying query. """ - return self.list() - - def from_statement(self, statement): - """execute a full select() statement, or literal textual string as a SELECT statement. - - this method is for forwards compatibility with 0.4. - """ - if isinstance(statement, basestring): - statement = sql.text(statement) - q = self._clone() - q._statement = statement - return q - def scalar(self): - """Return the first result of this ``Query``. + if self._column_aggregate is not None: + return self._col_aggregate(*self._column_aggregate) - This results in an execution of the underlying query. - - this method will be deprecated in 0.4; first() is added for - forwards-compatibility. - """ + ret = list(self[0:2]) - return self.first() + if len(ret) == 1: + return ret[0] + elif len(ret) == 0: + raise exceptions.InvalidRequestError('No rows returned for one()') + else: + raise exceptions.InvalidRequestError('Multiple rows returned for one()') def __iter__(self): - return iter(self.select_whereclause()) - - def execute(self, clauseelement, params=None, *args, **kwargs): - """Execute the given ClauseElement-based statement against - this Query's session/mapper, return the resulting list of - instances. - - this method is deprecated in 0.4. Use from_statement() instead. - """ - - p = self._params - if params is not None: - p.update(params) - result = self.session.execute(self.mapper, clauseelement, params=p) + statement = self.compile() + statement.use_labels = True + if self.session.autoflush: + self.session.flush() + return self._execute_and_instances(statement) + + def _execute_and_instances(self, statement): + result = self.session.execute(statement, params=self._params, mapper=self.mapper) try: - return self.instances(result, **kwargs) + return iter(self.instances(result)) finally: result.close() @@ -984,50 +625,47 @@ class Query(object): """Return a list of mapped instances corresponding to the rows in a given *cursor* (i.e. ``ResultProxy``). - \*mappers_or_columns is an optional list containing one or more of - classes, mappers, strings or sql.ColumnElements which will be - applied to each row and added horizontally to the result set, - which becomes a list of tuples. The first element in each tuple - is the usual result based on the mapper represented by this - ``Query``. Each additional element in the tuple corresponds to an - entry in the \*mappers_or_columns list. - - For each element in \*mappers_or_columns, if the element is - a mapper or mapped class, an additional class instance will be - present in the tuple. If the element is a string or sql.ColumnElement, - the corresponding result column from each row will be present in the tuple. - - Note that when \*mappers_or_columns is present, "uniquing" for the result set - is *disabled*, so that the resulting tuples contain entities as they actually - correspond. this indicates that multiple results may be present if this - option is used. + The \*mappers_or_columns and \**kwargs arguments are deprecated. + To add instances or columns to the results, use add_entity() + and add_column(). """ self.__log_debug("instances()") session = self.session - context = SelectionContext(self.select_mapper, session, self.extension, with_options=self.with_options, **kwargs) + kwargs.setdefault('populate_existing', self._populate_existing) + kwargs.setdefault('version_check', self._version_check) + + context = SelectionContext(self.select_mapper, session, self._extension, with_options=self._with_options, **kwargs) process = [] mappers_or_columns = tuple(self._entities) + mappers_or_columns if mappers_or_columns: - for m in mappers_or_columns: + for tup in mappers_or_columns: + if isinstance(tup, tuple): + (m, alias, alias_id) = tup + clauses = self._get_entity_clauses(tup) + else: + clauses = alias = alias_id = None + m = tup if isinstance(m, type): m = mapper.class_mapper(m) if isinstance(m, mapper.Mapper): def x(m): + row_adapter = clauses is not None and clauses.row_decorator or (lambda row: row) appender = [] def proc(context, row): - if not m._instance(context, row, appender): + if not m._instance(context, row_adapter(row), appender): appender.append(None) process.append((proc, appender)) x(m) - elif isinstance(m, sql.ColumnElement) or isinstance(m, basestring): + elif isinstance(m, (sql.ColumnElement, basestring)): def y(m): + row_adapter = clauses is not None and clauses.row_decorator or (lambda row: row) res = [] def proc(context, row): - res.append(row[m]) + res.append(row_adapter(row)[m]) process.append((proc, res)) y(m) result = [] @@ -1039,9 +677,12 @@ class Query(object): for proc in process: proc[0](context, row) + for instance in context.identity_map.values(): + context.attributes.get(('populating_mapper', instance), object_mapper(instance))._post_instance(context, instance) + # store new stuff in the identity map - for value in context.identity_map.values(): - session._register_persistent(value) + for instance in context.identity_map.values(): + session._register_persistent(instance) if mappers_or_columns: return list(util.OrderedSet(zip(*([result] + [o[1] for o in process])))) @@ -1050,8 +691,8 @@ class Query(object): def _get(self, key, ident=None, reload=False, lockmode=None): - lockmode = lockmode or self.lockmode - if not reload and not self.always_refresh and lockmode is None: + lockmode = lockmode or self._lockmode + if not reload and not self.mapper.always_refresh and lockmode is None: try: return self.session._get(key) except KeyError: @@ -1062,21 +703,22 @@ class Query(object): else: ident = util.to_list(ident) params = {} - try: - for i, primary_key in enumerate(self.primary_key_columns): + + for i, primary_key in enumerate(self.primary_key_columns): + try: params[primary_key._label] = ident[i] - except IndexError: - raise exceptions.InvalidRequestError("Could not find enough values to formulate primary key for query.get(); primary key columns are %s" % ', '.join(["'%s'" % str(c) for c in self.primary_key_columns])) + except IndexError: + raise exceptions.InvalidRequestError("Could not find enough values to formulate primary key for query.get(); primary key columns are %s" % ', '.join(["'%s'" % str(c) for c in self.primary_key_columns])) try: - statement = self.compile(self._get_clause, lockmode=lockmode) - return self._select_statement(statement, params=params, populate_existing=reload, version_check=(lockmode is not None))[0] + q = self + if lockmode is not None: + q = q.with_lockmode(lockmode) + q = q.filter(self.select_mapper._get_clause) + q = q.params(**params)._select_context_options(populate_existing=reload, version_check=(lockmode is not None)) + return q.first() except IndexError: return None - def _select_statement(self, statement, params=None, **kwargs): - statement.use_labels = True - return self.execute(statement, params=params, **kwargs) - def _should_nest(self, querycontext): """Return True if the given statement options indicate that we should *nest* the generated query as a subquery inside of a @@ -1094,21 +736,56 @@ class Query(object): return (kwargs.get('limit') is not None or kwargs.get('offset') is not None or kwargs.get('distinct', False)) - def compile(self, whereclause = None, **kwargs): - """Given a WHERE criterion, produce a ClauseElement-based - statement suitable for usage in the execute() method. + def count(self, whereclause=None, params=None, **kwargs): + """Apply this query's criterion to a SELECT COUNT statement. + + the whereclause, params and \**kwargs arguments are deprecated. use filter() + and other generative methods to establish modifiers. + """ + + q = self + if whereclause is not None: + q = q.filter(whereclause) + if params is not None: + q = q.params(**params) + q = q._legacy_select_kwargs(**kwargs) + return q._count() - the arguments to this function are deprecated and are removed in version 0.4. + def _count(self): + """Apply this query's criterion to a SELECT COUNT statement. + + this is the purely generative version which will become + the public method in version 0.5. """ + whereclause = self._criterion + + context = QueryContext(self) + from_obj = context.from_obj + + alltables = [] + for l in [sql_util.TableFinder(x) for x in from_obj]: + alltables += l + + if self.table not in alltables: + from_obj.append(self.table) + if self._nestable(**context.select_args()): + s = sql.select([self.table], whereclause, from_obj=from_obj, **context.select_args()).alias('getcount').count() + else: + primary_key = self.primary_key_columns + s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **context.select_args()) + return self.session.scalar(s, params=self._params, mapper=self.mapper) + + def compile(self): + """compiles and returns a SQL statement based on the criterion and conditions within this Query.""" + if self._statement: self._statement.use_labels = True return self._statement + + whereclause = self._criterion - if self._criterion: - whereclause = sql.and_(self._criterion, whereclause) - - if whereclause is not None and self.is_polymorphic: + if whereclause is not None and (self.mapper is not self.select_mapper): # adapt the given WHERECLAUSE to adjust instances of this query's mapped # table to be that of our select_table, # which may be the "polymorphic" selectable used by our mapper. @@ -1124,16 +801,10 @@ class Query(object): # get/create query context. get the ultimate compile arguments # from there - context = kwargs.pop('query_context', None) - if context is None: - context = QueryContext(self, kwargs) + context = QueryContext(self) order_by = context.order_by - group_by = context.group_by from_obj = context.from_obj lockmode = context.lockmode - distinct = context.distinct - limit = context.limit - offset = context.offset if order_by is False: order_by = self.mapper.order_by if order_by is False: @@ -1161,31 +832,33 @@ class Query(object): # if theres an order by, add those columns to the column list # of the "rowcount" query we're going to make if order_by: - order_by = util.to_list(order_by) or [] + order_by = [sql._literal_as_text(o) for o in util.to_list(order_by) or []] cf = sql_util.ColumnFinder() for o in order_by: cf.traverse(o) else: cf = [] - s2 = sql.select(self.table.primary_key + list(cf), whereclause, use_labels=True, from_obj=from_obj, **context.select_args()) + s2 = sql.select(self.primary_key_columns + list(cf), whereclause, use_labels=True, from_obj=from_obj, correlate=False, **context.select_args()) if order_by: - s2.order_by(*util.to_list(order_by)) + s2 = s2.order_by(*util.to_list(order_by)) s3 = s2.alias('tbl_row_count') - crit = s3.primary_key==self.table.primary_key + crit = s3.primary_key==self.primary_key_columns statement = sql.select([], crit, use_labels=True, for_update=for_update) # now for the order by, convert the columns to their corresponding columns # in the "rowcount" query, and tack that new order by onto the "rowcount" query if order_by: - statement.order_by(*sql_util.ClauseAdapter(s3).copy_and_process(order_by)) + statement.append_order_by(*sql_util.ClauseAdapter(s3).copy_and_process(order_by)) else: statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **context.select_args()) if order_by: - statement.order_by(*util.to_list(order_by)) + statement.append_order_by(*util.to_list(order_by)) + # for a DISTINCT query, you need the columns explicitly specified in order # to use it in "order_by". ensure they are in the column criterion (particularly oid). # TODO: this should be done at the SQL level not the mapper level - if kwargs.get('distinct', False) and order_by: + # TODO: need test coverage for this + if context.distinct and order_by: [statement.append_column(c) for c in util.to_list(order_by)] context.statement = statement @@ -1197,20 +870,268 @@ class Query(object): value.setup(context) # additional entities/columns, add those to selection criterion - for m in self._entities: - if isinstance(m, type): - m = mapper.class_mapper(m) + for tup in self._entities: + (m, alias, alias_id) = tup + clauses = self._get_entity_clauses(tup) if isinstance(m, mapper.Mapper): for value in m.iterate_properties: - value.setup(context) + value.setup(context, parentclauses=clauses) elif isinstance(m, sql.ColumnElement): + if clauses is not None: + m = clauses.adapt_clause(m) statement.append_column(m) return statement + def _get_entity_clauses(self, m): + """for tuples added via add_entity() or add_column(), attempt to locate + an AliasedClauses object which should be used to formulate the query as well + as to process result rows.""" + (m, alias, alias_id) = m + if alias is not None: + return alias + if alias_id is not None: + try: + return self._alias_ids[alias_id] + except KeyError: + raise exceptions.InvalidRequestError("Query has no alias identified by '%s'" % alias_id) + if isinstance(m, type): + m = mapper.class_mapper(m) + if isinstance(m, mapper.Mapper): + l = self._alias_ids.get(m) + if l: + if len(l) > 1: + raise exceptions.InvalidRequestError("Ambiguous join for entity '%s'; specify id=<someid> to query.join()/query.add_entity()" % str(m)) + else: + return l[0] + else: + return None + elif isinstance(m, sql.ColumnElement): + aliases = [] + for table in sql_util.TableFinder(m, check_columns=True): + for a in self._alias_ids.get(table, []): + aliases.append(a) + if len(aliases) > 1: + raise exceptions.InvalidRequestError("Ambiguous join for entity '%s'; specify id=<someid> to query.join()/query.add_column()" % str(m)) + elif len(aliases) == 1: + return aliases[0] + else: + return None + def __log_debug(self, msg): self.logger.debug(msg) + def __str__(self): + return str(self.compile()) + + # DEPRECATED LAND ! + + def list(self): + """DEPRECATED. use all()""" + + return list(self) + + def scalar(self): + """DEPRECATED. use first()""" + + return self.first() + + def _legacy_filter_by(self, *args, **kwargs): + return self.filter(self._legacy_join_by(args, kwargs, start=self._joinpoint)) + + def count_by(self, *args, **params): + """DEPRECATED. use query.filter_by(\**params).count()""" + + return self.count(self.join_by(*args, **params)) + + + def select_whereclause(self, whereclause=None, params=None, **kwargs): + """DEPRECATED. use query.filter(whereclause).all()""" + + q = self.filter(whereclause)._legacy_select_kwargs(**kwargs) + if params is not None: + q = q.params(**params) + return list(q) + + def _legacy_select_kwargs(self, **kwargs): + q = self + if "order_by" in kwargs and kwargs['order_by']: + q = q.order_by(kwargs['order_by']) + if "group_by" in kwargs: + q = q.group_by(kwargs['group_by']) + if "from_obj" in kwargs: + q = q.select_from(kwargs['from_obj']) + if "lockmode" in kwargs: + q = q.with_lockmode(kwargs['lockmode']) + if "distinct" in kwargs: + q = q.distinct() + if "limit" in kwargs: + q = q.limit(kwargs['limit']) + if "offset" in kwargs: + q = q.offset(kwargs['offset']) + return q + + + def get_by(self, *args, **params): + """DEPRECATED. use query.filter_by(\**params).first()""" + + ret = self._extension.get_by(self, *args, **params) + if ret is not mapper.EXT_PASS: + return ret + + return self._legacy_filter_by(*args, **params).first() + + def select_by(self, *args, **params): + """DEPRECATED. use use query.filter_by(\**params).all().""" + + ret = self._extension.select_by(self, *args, **params) + if ret is not mapper.EXT_PASS: + return ret + + return self._legacy_filter_by(*args, **params).list() + + def join_by(self, *args, **params): + """DEPRECATED. use join() to construct joins based on attribute names.""" + + return self._legacy_join_by(args, params, start=self._joinpoint) + + def _build_select(self, arg=None, params=None, **kwargs): + if isinstance(arg, sql.FromClause) and arg.supports_execution(): + return self.from_statement(arg) + else: + return self.filter(arg)._legacy_select_kwargs(**kwargs) + + def selectfirst(self, arg=None, **kwargs): + """DEPRECATED. use query.filter(whereclause).first()""" + + return self._build_select(arg, **kwargs).first() + + def selectone(self, arg=None, **kwargs): + """DEPRECATED. use query.filter(whereclause).one()""" + + return self._build_select(arg, **kwargs).one() + + def select(self, arg=None, **kwargs): + """DEPRECATED. use query.filter(whereclause).all(), or query.from_statement(statement).all()""" + + ret = self._extension.select(self, arg=arg, **kwargs) + if ret is not mapper.EXT_PASS: + return ret + return self._build_select(arg, **kwargs).all() + + def execute(self, clauseelement, params=None, *args, **kwargs): + """DEPRECATED. use query.from_statement().all()""" + + return self._select_statement(clauseelement, params, **kwargs) + + def select_statement(self, statement, **params): + """DEPRECATED. Use query.from_statement(statement)""" + + return self._select_statement(statement, params) + + def select_text(self, text, **params): + """DEPRECATED. Use query.from_statement(statement)""" + + return self._select_statement(text, params) + + def _select_statement(self, statement, params=None, **kwargs): + q = self.from_statement(statement) + if params is not None: + q = q.params(**params) + q._select_context_options(**kwargs) + return list(q) + + def _select_context_options(self, populate_existing=None, version_check=None): + if populate_existing is not None: + self._populate_existing = populate_existing + if version_check is not None: + self._version_check = version_check + return self + + def join_to(self, key): + """DEPRECATED. use join() to create joins based on property names.""" + + [keys, p] = self._locate_prop(key) + return self.join_via(keys) + + def join_via(self, keys): + """DEPRECATED. use join() to create joins based on property names.""" + + mapper = self._joinpoint + clause = None + for key in keys: + prop = mapper.get_property(key, resolve_synonyms=True) + if clause is None: + clause = prop.get_join(mapper) + else: + clause &= prop.get_join(mapper) + mapper = prop.mapper + + return clause + + def _legacy_join_by(self, args, params, start=None): + import properties + + clause = None + for arg in args: + if clause is None: + clause = arg + else: + clause &= arg + + for key, value in params.iteritems(): + (keys, prop) = self._locate_prop(key, start=start) + if isinstance(prop, properties.PropertyLoader): + c = prop.compare(operator.eq, value) & self.join_via(keys[:-1]) + else: + c = prop.compare(operator.eq, value) & self.join_via(keys) + if clause is None: + clause = c + else: + clause &= c + return clause + + def _locate_prop(self, key, start=None): + import properties + keys = [] + seen = util.Set() + def search_for_prop(mapper_): + if mapper_ in seen: + return None + seen.add(mapper_) + + prop = mapper_.get_property(key, resolve_synonyms=True, raiseerr=False) + if prop is not None: + if isinstance(prop, properties.PropertyLoader): + keys.insert(0, prop.key) + return prop + else: + for prop in mapper_.iterate_properties: + if not isinstance(prop, properties.PropertyLoader): + continue + x = search_for_prop(prop.mapper) + if x: + keys.insert(0, prop.key) + return x + else: + return None + p = search_for_prop(start or self.mapper) + if p is None: + raise exceptions.InvalidRequestError("Can't locate property named '%s'" % key) + return [keys, p] + + def selectfirst_by(self, *args, **params): + """DEPRECATED. Use query.filter_by(\**kwargs).first()""" + + return self._legacy_filter_by(*args, **params).first() + + def selectone_by(self, *args, **params): + """DEPRECATED. Use query.filter_by(\**kwargs).one()""" + + return self._legacy_filter_by(*args, **params).one() + + + Query.logger = logging.class_logger(Query) class QueryContext(OperationContext): @@ -1219,25 +1140,25 @@ class QueryContext(OperationContext): in a query construction. """ - def __init__(self, query, kwargs): + def __init__(self, query): self.query = query - self.order_by = kwargs.pop('order_by', query._order_by) - self.group_by = kwargs.pop('group_by', query._group_by) - self.from_obj = kwargs.pop('from_obj', query._from_obj) - self.lockmode = kwargs.pop('lockmode', query.lockmode) - self.distinct = kwargs.pop('distinct', query._distinct) - self.limit = kwargs.pop('limit', query._limit) - self.offset = kwargs.pop('offset', query._offset) + self.order_by = query._order_by + self.group_by = query._group_by + self.from_obj = query._from_obj + self.lockmode = query._lockmode + self.distinct = query._distinct + self.limit = query._limit + self.offset = query._offset self.eager_loaders = util.Set([x for x in query.mapper._eager_loaders]) self.statement = None - super(QueryContext, self).__init__(query.mapper, query.with_options, **kwargs) + super(QueryContext, self).__init__(query.mapper, query._with_options) def select_args(self): """Return a dictionary of attributes from this ``QueryContext`` that can be applied to a ``sql.Select`` statement. """ - return {'limit':self.limit, 'offset':self.offset, 'distinct':self.distinct, 'group_by':self.group_by} + return {'limit':self.limit, 'offset':self.offset, 'distinct':self.distinct, 'group_by':self.group_by or None} def accept_option(self, opt): """Accept a ``MapperOption`` which will process (modify) the @@ -1265,8 +1186,10 @@ class SelectionContext(OperationContext): yet been added as persistent to the Session. attributes - A dictionary to store arbitrary data; eager loaders use it to - store additional result lists. + A dictionary to store arbitrary data; mappers, strategies, and + options all store various state information here in order + to communicate with each other and to themselves. + populate_existing Indicates if its OK to overwrite the attributes of instances @@ -1284,6 +1207,7 @@ class SelectionContext(OperationContext): self.session = session self.extension = extension self.identity_map = {} + self.stack = LoaderStack() super(SelectionContext, self).__init__(mapper, kwargs.pop('with_options', []), **kwargs) def accept_option(self, opt): diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 4e7453d84..6b5c4a072 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -4,12 +4,12 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +import weakref + from sqlalchemy import util, exceptions, sql, engine -from sqlalchemy.orm import unitofwork, query +from sqlalchemy.orm import unitofwork, query, util as mapperutil from sqlalchemy.orm.mapper import object_mapper as _object_mapper from sqlalchemy.orm.mapper import class_mapper as _class_mapper -import weakref -import sqlalchemy class SessionTransaction(object): """Represents a Session-level Transaction. @@ -21,70 +21,95 @@ class SessionTransaction(object): The SessionTransaction object is **not** threadsafe. """ - def __init__(self, session, parent=None, autoflush=True): + def __init__(self, session, parent=None, autoflush=True, nested=False): self.session = session - self.connections = {} - self.parent = parent + self.__connections = {} + self.__parent = parent self.autoflush = autoflush + self.nested = nested - def connection(self, mapper_or_class, entity_name=None): + def connection(self, mapper_or_class, entity_name=None, **kwargs): if isinstance(mapper_or_class, type): mapper_or_class = _class_mapper(mapper_or_class, entity_name=entity_name) - engine = self.session.get_bind(mapper_or_class) + engine = self.session.get_bind(mapper_or_class, **kwargs) return self.get_or_add(engine) - def _begin(self): - return SessionTransaction(self.session, self) + def _begin(self, **kwargs): + return SessionTransaction(self.session, self, **kwargs) def add(self, bind): - if self.parent is not None: - return self.parent.add(bind) - - if self.connections.has_key(bind.engine): - raise exceptions.InvalidRequestError("Session already has a Connection associated for the given %sEngine" % (isinstance(bind, engine.Connection) and "Connection's " or "")) + if self.__parent is not None: + return self.__parent.add(bind) + if self.__connections.has_key(bind.engine): + raise exceptions.InvalidRequestError("Session already has a Connection associated for the given %sEngine" % (isinstance(bind, engine.Connection) and "Connection's " or "")) return self.get_or_add(bind) + def _connection_dict(self): + if self.__parent is not None and not self.nested: + return self.__parent._connection_dict() + else: + return self.__connections + def get_or_add(self, bind): - if self.parent is not None: - return self.parent.get_or_add(bind) + if self.__parent is not None: + if not self.nested: + return self.__parent.get_or_add(bind) + + if self.__connections.has_key(bind): + return self.__connections[bind][0] + + if bind in self.__parent._connection_dict(): + (conn, trans, autoclose) = self.__parent.__connections[bind] + self.__connections[conn] = self.__connections[bind.engine] = (conn, conn.begin_nested(), autoclose) + return conn + elif self.__connections.has_key(bind): + return self.__connections[bind][0] - if self.connections.has_key(bind): - return self.connections[bind][0] - if not isinstance(bind, engine.Connection): e = bind c = bind.contextual_connect() else: e = bind.engine c = bind - if e in self.connections: + if e in self.__connections: raise exceptions.InvalidRequestError("Session already has a Connection associated for the given Connection's Engine") - - self.connections[bind] = self.connections[e] = (c, c.begin(), c is not bind) - return self.connections[bind][0] + if self.nested: + trans = c.begin_nested() + elif self.session.twophase: + trans = c.begin_twophase() + else: + trans = c.begin() + self.__connections[c] = self.__connections[e] = (c, trans, c is not bind) + return self.__connections[c][0] def commit(self): - if self.parent is not None: - return + if self.__parent is not None and not self.nested: + return self.__parent if self.autoflush: self.session.flush() - for t in util.Set(self.connections.values()): + + if self.session.twophase: + for t in util.Set(self.__connections.values()): + t[1].prepare() + + for t in util.Set(self.__connections.values()): t[1].commit() self.close() + return self.__parent def rollback(self): - if self.parent is not None: - self.parent.rollback() - return - for k, t in self.connections.iteritems(): + if self.__parent is not None and not self.nested: + return self.__parent.rollback() + for t in util.Set(self.__connections.values()): t[1].rollback() self.close() - + return self.__parent + def close(self): - if self.parent is not None: + if self.__parent is not None: return - for t in self.connections.values(): + for t in util.Set(self.__connections.values()): if t[2]: t[0].close() self.session.transaction = None @@ -108,23 +133,24 @@ class Session(object): of Sessions, see the ``sqlalchemy.ext.sessioncontext`` module. """ - def __init__(self, bind=None, bind_to=None, hash_key=None, import_session=None, echo_uow=False, weak_identity_map=False): - if import_session is not None: - self.uow = unitofwork.UnitOfWork(identity_map=import_session.uow.identity_map, weak_identity_map=weak_identity_map) - else: - self.uow = unitofwork.UnitOfWork(weak_identity_map=weak_identity_map) + def __init__(self, bind=None, autoflush=False, transactional=False, twophase=False, echo_uow=False, weak_identity_map=False): + self.uow = unitofwork.UnitOfWork(weak_identity_map=weak_identity_map) - self.bind = bind or bind_to - self.binds = {} + self.bind = bind + self.__binds = {} self.echo_uow = echo_uow self.weak_identity_map = weak_identity_map self.transaction = None - if hash_key is None: - self.hash_key = id(self) - else: - self.hash_key = hash_key + self.hash_key = id(self) + self.autoflush = autoflush + self.transactional = transactional or autoflush + self.twophase = twophase + self._query_cls = query.Query + self._mapper_flush_opts = {} + if self.transactional: + self.begin() _sessions[self.hash_key] = self - + def _get_echo_uow(self): return self.uow.echo @@ -132,37 +158,39 @@ class Session(object): self.uow.echo = value echo_uow = property(_get_echo_uow,_set_echo_uow) - bind_to = property(lambda self:self.bind) - - def create_transaction(self, **kwargs): - """Return a new ``SessionTransaction`` corresponding to an - existing or new transaction. - - If the transaction is new, the returned ``SessionTransaction`` - will have commit control over the underlying transaction, else - will have rollback control only. - """ + def begin(self, **kwargs): + """Begin a transaction on this Session.""" if self.transaction is not None: - return self.transaction._begin() + self.transaction = self.transaction._begin(**kwargs) else: self.transaction = SessionTransaction(self, **kwargs) - return self.transaction - - def connect(self, mapper=None, **kwargs): - """Return a unique connection corresponding to the given mapper. - - This connection will not be part of any pre-existing - transactional context. - """ - - return self.get_bind(mapper).connect(**kwargs) - - def connection(self, mapper, **kwargs): - """Return a ``Connection`` corresponding to the given mapper. + return self.transaction + + create_transaction = begin - Used by the ``execute()`` method which performs select - operations for ``Mapper`` and ``Query``. + def begin_nested(self): + return self.begin(nested=True) + + def rollback(self): + if self.transaction is None: + raise exceptions.InvalidRequestError("No transaction is begun.") + else: + self.transaction = self.transaction.rollback() + if self.transaction is None and self.transactional: + self.begin() + + def commit(self): + if self.transaction is None: + raise exceptions.InvalidRequestError("No transaction is begun.") + else: + self.transaction = self.transaction.commit() + if self.transaction is None and self.transactional: + self.begin() + + def connection(self, mapper=None, **kwargs): + """Return a ``Connection`` corresponding to this session's + transactional context, if any. If this ``Session`` is transactional, the connection will be in the context of this session's transaction. Otherwise, the @@ -173,6 +201,9 @@ class Session(object): The given `**kwargs` will be sent to the engine's ``contextual_connect()`` method, if no transaction is in progress. + + the "mapper" argument is a class or mapper to which a bound engine + will be located; use this when the Session itself is unbound. """ if self.transaction is not None: @@ -180,7 +211,7 @@ class Session(object): else: return self.get_bind(mapper).contextual_connect(**kwargs) - def execute(self, mapper, clause, params, **kwargs): + def execute(self, clause, params=None, mapper=None, **kwargs): """Using the given mapper to identify the appropriate ``Engine`` or ``Connection`` to be used for statement execution, execute the given ``ClauseElement`` using the provided parameter dictionary. @@ -191,12 +222,12 @@ class Session(object): then the ``ResultProxy`` 's ``close()`` method will release the resources of the underlying ``Connection``, otherwise its a no-op. """ - return self.connection(mapper, close_with_result=True).execute(clause, params, **kwargs) + return self.connection(mapper, close_with_result=True).execute(clause, params or {}, **kwargs) - def scalar(self, mapper, clause, params, **kwargs): + def scalar(self, clause, params=None, mapper=None, **kwargs): """Like execute() but return a scalar result.""" - return self.connection(mapper, close_with_result=True).scalar(clause, params, **kwargs) + return self.connection(mapper, close_with_result=True).scalar(clause, params or {}, **kwargs) def close(self): """Close this Session.""" @@ -224,14 +255,17 @@ class Session(object): return _class_mapper(class_, entity_name = entity_name) - def bind_mapper(self, mapper, bind): - """Bind the given `mapper` to the given ``Engine`` or ``Connection``. + def bind_mapper(self, mapper, bind, entity_name=None): + """Bind the given `mapper` or `class` to the given ``Engine`` or ``Connection``. All subsequent operations involving this ``Mapper`` will use the given `bind`. """ + + if isinstance(mapper, type): + mapper = _class_mapper(mapper, entity_name=entity_name) - self.binds[mapper] = bind + self.__binds[mapper] = bind def bind_table(self, table, bind): """Bind the given `table` to the given ``Engine`` or ``Connection``. @@ -240,7 +274,7 @@ class Session(object): given `bind`. """ - self.binds[table] = bind + self.__binds[table] = bind def get_bind(self, mapper): """Return the ``Engine`` or ``Connection`` which is used to execute @@ -270,15 +304,18 @@ class Session(object): """ if mapper is None: - return self.bind - elif self.binds.has_key(mapper): - return self.binds[mapper] - elif self.binds.has_key(mapper.mapped_table): - return self.binds[mapper.mapped_table] + if self.bind is not None: + return self.bind + else: + raise exceptions.InvalidRequestError("This session is unbound to any Engine or Connection; specify a mapper to get_bind()") + elif self.__binds.has_key(mapper): + return self.__binds[mapper] + elif self.__binds.has_key(mapper.mapped_table): + return self.__binds[mapper.mapped_table] elif self.bind is not None: return self.bind else: - e = mapper.mapped_table.engine + e = mapper.mapped_table.bind if e is None: raise exceptions.InvalidRequestError("Could not locate any Engine or Connection bound to mapper '%s'" % str(mapper)) return e @@ -291,9 +328,9 @@ class Session(object): entity_name = kwargs.pop('entity_name', None) if isinstance(mapper_or_class, type): - q = query.Query(_class_mapper(mapper_or_class, entity_name=entity_name), self, **kwargs) + q = self._query_cls(_class_mapper(mapper_or_class, entity_name=entity_name), self, **kwargs) else: - q = query.Query(mapper_or_class, self, **kwargs) + q = self._query_cls(mapper_or_class, self, **kwargs) for ent in addtl_entities: q = q.add_entity(ent) @@ -499,7 +536,7 @@ class Session(object): merged = self.get(mapper.class_, key[1]) if merged is None: raise exceptions.AssertionError("Instance %s has an instance key but is not persisted" % mapperutil.instance_str(object)) - for prop in mapper.props.values(): + for prop in mapper.iterate_properties: prop.merge(self, object, merged, _recursive) if key is None: self.save(merged, entity_name=mapper.entity_name) @@ -611,12 +648,12 @@ class Session(object): def _attach(self, obj): """Attach the given object to this ``Session``.""" - if getattr(obj, '_sa_session_id', None) != self.hash_key: - old = getattr(obj, '_sa_session_id', None) - if old is not None and _sessions.has_key(old): + old_id = getattr(obj, '_sa_session_id', None) + if old_id != self.hash_key: + if old_id is not None and _sessions.has_key(old_id): raise exceptions.InvalidRequestError("Object '%s' is already attached " "to session '%s' (this is '%s')" % - (repr(obj), old, id(self))) + (repr(obj), old_id, id(self))) # auto-removal from the old session is disabled. but if we decide to # turn it back on, do it as below: gingerly since _sessions is a WeakValueDict @@ -695,6 +732,7 @@ def object_session(obj): return _sessions.get(hashkey) return None +# Lazy initialization to avoid circular imports unitofwork.object_session = object_session from sqlalchemy.orm import mapper mapper.attribute_manager = attribute_manager diff --git a/lib/sqlalchemy/orm/shard.py b/lib/sqlalchemy/orm/shard.py new file mode 100644 index 000000000..cc13f8c1f --- /dev/null +++ b/lib/sqlalchemy/orm/shard.py @@ -0,0 +1,112 @@ +from sqlalchemy.orm.session import Session +from sqlalchemy.orm import Query + +class ShardedSession(Session): + def __init__(self, shard_chooser, id_chooser, query_chooser, **kwargs): + """construct a ShardedSession. + + shard_chooser + a callable which, passed a Mapper and a mapped instance, returns a + shard ID. this id may be based off of the attributes present within the + object, or on some round-robin scheme. If the scheme is based on a + selection, it should set whatever state on the instance to mark it in + the future as participating in that shard. + + id_chooser + a callable, passed a tuple of identity values, which should return + a list of shard ids where the ID might reside. The databases will + be queried in the order of this listing. + + query_chooser + for a given Query, returns the list of shard_ids where the query + should be issued. Results from all shards returned will be + combined together into a single listing. + + """ + super(ShardedSession, self).__init__(**kwargs) + self.shard_chooser = shard_chooser + self.id_chooser = id_chooser + self.query_chooser = query_chooser + self.__binds = {} + self._mapper_flush_opts = {'connection_callable':self.connection} + self._query_cls = ShardedQuery + + def connection(self, mapper=None, instance=None, shard_id=None, **kwargs): + if shard_id is None: + shard_id = self.shard_chooser(mapper, instance) + + if self.transaction is not None: + return self.transaction.connection(mapper, shard_id=shard_id) + else: + return self.get_bind(mapper, shard_id=shard_id, instance=instance).contextual_connect(**kwargs) + + def get_bind(self, mapper, shard_id=None, instance=None): + if shard_id is None: + shard_id = self.shard_chooser(mapper, instance) + return self.__binds[shard_id] + + def bind_shard(self, shard_id, bind): + self.__binds[shard_id] = bind + +class ShardedQuery(Query): + def __init__(self, *args, **kwargs): + super(ShardedQuery, self).__init__(*args, **kwargs) + self.id_chooser = self.session.id_chooser + self.query_chooser = self.session.query_chooser + self._shard_id = None + + def _clone(self): + q = ShardedQuery.__new__(ShardedQuery) + q.__dict__ = self.__dict__.copy() + return q + + def set_shard(self, shard_id): + """return a new query, limited to a single shard ID. + + all subsequent operations with the returned query will + be against the single shard regardless of other state. + """ + + q = self._clone() + q._shard_id = shard_id + return q + + def _execute_and_instances(self, statement): + if self._shard_id is not None: + result = self.session.connection(mapper=self.mapper, shard_id=self._shard_id).execute(statement, **self._params) + try: + return iter(self.instances(result)) + finally: + result.close() + else: + partial = [] + for shard_id in self.query_chooser(self): + result = self.session.connection(mapper=self.mapper, shard_id=shard_id).execute(statement, **self._params) + try: + partial = partial + list(self.instances(result)) + finally: + result.close() + # if some kind of in memory 'sorting' were done, this is where it would happen + return iter(partial) + + def get(self, ident, **kwargs): + if self._shard_id is not None: + return super(ShardedQuery, self).get(ident) + else: + for shard_id in self.id_chooser(ident): + o = self.set_shard(shard_id).get(ident, **kwargs) + if o is not None: + return o + else: + return None + + def load(self, ident, **kwargs): + if self._shard_id is not None: + return super(ShardedQuery, self).load(ident) + else: + for shard_id in self.id_chooser(ident): + o = self.set_shard(shard_id).load(ident, raiseerr=False, **kwargs) + if o is not None: + return o + else: + raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident)) diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 462954f6b..babd6e4c0 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -6,12 +6,11 @@ """sqlalchemy.orm.interfaces.LoaderStrategy implementations, and related MapperOptions.""" -from sqlalchemy import sql, schema, util, exceptions, sql_util, logging -from sqlalchemy.orm import mapper, query -from sqlalchemy.orm.interfaces import * +from sqlalchemy import sql, util, exceptions, sql_util, logging +from sqlalchemy.orm import mapper, attributes +from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, MapperOption, PropertyOption from sqlalchemy.orm import session as sessionlib from sqlalchemy.orm import util as mapperutil -import random class ColumnLoader(LoaderStrategy): @@ -19,8 +18,9 @@ class ColumnLoader(LoaderStrategy): super(ColumnLoader, self).init() self.columns = self.parent_property.columns self._should_log_debug = logging.is_debug_enabled(self.logger) + self.is_composite = hasattr(self.parent_property, 'composite_class') - def setup_query(self, context, eagertable=None, parentclauses=None, **kwargs): + def setup_query(self, context, parentclauses=None, **kwargs): for c in self.columns: if parentclauses is not None: context.statement.append_column(parentclauses.aliased_column(c)) @@ -28,16 +28,93 @@ class ColumnLoader(LoaderStrategy): context.statement.append_column(c) def init_class_attribute(self): + if self.is_composite: + self._init_composite_attribute() + else: + self._init_scalar_attribute() + + def _init_composite_attribute(self): + self.logger.info("register managed composite attribute %s on class %s" % (self.key, self.parent.class_.__name__)) + def copy(obj): + return self.parent_property.composite_class(*obj.__colset__()) + def compare(a, b): + for col, aprop, bprop in zip(self.columns, a.__colset__(), b.__colset__()): + if not col.type.compare_values(aprop, bprop): + return False + else: + return True + sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=copy, compare_function=compare, mutable_scalars=True, comparator=self.parent_property.comparator) + + def _init_scalar_attribute(self): self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__)) coltype = self.columns[0].type - sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable()) + sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator) + + def create_row_processor(self, selectcontext, mapper, row): + if self.is_composite: + for c in self.columns: + if c not in row: + break + else: + def execute(instance, row, isnew, ispostselect=None, **flags): + if isnew or ispostselect: + if self._should_log_debug: + self.logger.debug("populating %s with %s/%s..." % (mapperutil.attribute_str(instance, self.key), row.__class__.__name__, self.columns[0].key)) + instance.__dict__[self.key] = self.parent_property.composite_class(*[row[c] for c in self.columns]) + self.logger.debug("Returning active composite column fetcher for %s %s" % (mapper, self.key)) + return (execute, None) + + elif self.columns[0] in row: + def execute(instance, row, isnew, ispostselect=None, **flags): + if isnew or ispostselect: + if self._should_log_debug: + self.logger.debug("populating %s with %s/%s" % (mapperutil.attribute_str(instance, self.key), row.__class__.__name__, self.columns[0].key)) + instance.__dict__[self.key] = row[self.columns[0]] + self.logger.debug("Returning active column fetcher for %s %s" % (mapper, self.key)) + return (execute, None) + + (hosted_mapper, needs_tables) = selectcontext.attributes.get(('polymorphic_fetch', mapper), (None, None)) + if hosted_mapper is None: + return (None, None) + + if hosted_mapper.polymorphic_fetch == 'deferred': + def execute(instance, row, isnew, **flags): + if isnew: + sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=self._get_deferred_loader(instance, mapper, needs_tables)) + self.logger.debug("Returning deferred column fetcher for %s %s" % (mapper, self.key)) + return (execute, None) + else: + self.logger.debug("Returning no column fetcher for %s %s" % (mapper, self.key)) + return (None, None) + + def _get_deferred_loader(self, instance, mapper, needs_tables): + def load(): + group = [p for p in mapper.iterate_properties if isinstance(p.strategy, ColumnLoader) and p.columns[0].table in needs_tables] - def process_row(self, selectcontext, instance, row, identitykey, isnew): - if isnew: if self._should_log_debug: - self.logger.debug("populating %s with %s/%s" % (mapperutil.attribute_str(instance, self.key), row.__class__.__name__, self.columns[0].key)) - instance.__dict__[self.key] = row[self.columns[0]] - + self.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(instance, self.key), group and ','.join([p.key for p in group]) or 'None')) + + session = sessionlib.object_session(instance) + if session is None: + raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key)) + + cond, param_names = mapper._deferred_inheritance_condition(needs_tables) + statement = sql.select(needs_tables, cond, use_labels=True) + params = {} + for c in param_names: + params[c.name] = mapper.get_attr_by_column(instance, c) + + result = session.execute(statement, params, mapper=mapper) + try: + row = result.fetchone() + for prop in group: + sessionlib.attribute_manager.get_attribute(instance, prop.key).set_committed_value(instance, row[prop.columns[0]]) + return attributes.ATTR_WAS_SET + finally: + result.close() + + return load + ColumnLoader.logger = logging.class_logger(ColumnLoader) class DeferredColumnLoader(LoaderStrategy): @@ -47,74 +124,86 @@ class DeferredColumnLoader(LoaderStrategy): This is per-column lazy loading. """ + def create_row_processor(self, selectcontext, mapper, row): + if self.group is not None and selectcontext.attributes.get(('undefer', self.group), False): + return self.parent_property._get_strategy(ColumnLoader).create_row_processor(selectcontext, mapper, row) + elif not self.is_default or len(selectcontext.options): + def execute(instance, row, isnew, **flags): + if isnew: + if self._should_log_debug: + self.logger.debug("set deferred callable on %s" % mapperutil.attribute_str(instance, self.key)) + sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=self.setup_loader(instance)) + return (execute, None) + else: + def execute(instance, row, isnew, **flags): + if isnew: + if self._should_log_debug: + self.logger.debug("set deferred callable on %s" % mapperutil.attribute_str(instance, self.key)) + sessionlib.attribute_manager.reset_instance_attribute(instance, self.key) + return (execute, None) + def init(self): super(DeferredColumnLoader, self).init() + if hasattr(self.parent_property, 'composite_class'): + raise NotImplementedError("Deferred loading for composite types not implemented yet") self.columns = self.parent_property.columns self.group = self.parent_property.group self._should_log_debug = logging.is_debug_enabled(self.logger) def init_class_attribute(self): self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__)) - sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, callable_=lambda i:self.setup_loader(i), copy_function=lambda x: self.columns[0].type.copy_value(x), compare_function=lambda x,y:self.columns[0].type.compare_values(x,y), mutable_scalars=self.columns[0].type.is_mutable()) + sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, callable_=self.setup_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator) def setup_query(self, context, **kwargs): - pass + if self.group is not None and context.attributes.get(('undefer', self.group), False): + self.parent_property._get_strategy(ColumnLoader).setup_query(context, **kwargs) - def process_row(self, selectcontext, instance, row, identitykey, isnew): - if isnew: - if not self.is_default or len(selectcontext.options): - sessionlib.attribute_manager.init_instance_attribute(instance, self.key, False, callable_=self.setup_loader(instance, selectcontext.options)) - else: - sessionlib.attribute_manager.reset_instance_attribute(instance, self.key) - - def setup_loader(self, instance, options=None): - if not mapper.has_mapper(instance): + def setup_loader(self, instance): + localparent = mapper.object_mapper(instance, raiseerror=False) + if localparent is None: return None - else: - prop = mapper.object_mapper(instance).props[self.key] - if prop is not self.parent_property: - return prop._get_strategy(DeferredColumnLoader).setup_loader(instance) - def lazyload(): - if self._should_log_debug: - self.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(instance, self.key), str(self.group))) + + prop = localparent.get_property(self.key) + if prop is not self.parent_property: + return prop._get_strategy(DeferredColumnLoader).setup_loader(instance) + def lazyload(): if not mapper.has_identity(instance): return None - try: - pk = self.parent.pks_by_table[self.columns[0].table] - except KeyError: - pk = self.columns[0].table.primary_key - - clause = sql.and_() - for primary_key in pk: - attr = self.parent.get_attr_by_column(instance, primary_key) - if not attr: - return None - clause.clauses.append(primary_key == attr) + if self.group is not None: + group = [p for p in localparent.iterate_properties if isinstance(p.strategy, DeferredColumnLoader) and p.group==self.group] + else: + group = None + + if self._should_log_debug: + self.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(instance, self.key), group and ','.join([p.key for p in group]) or 'None')) session = sessionlib.object_session(instance) if session is None: raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key)) - - localparent = mapper.object_mapper(instance) - if self.group is not None: - groupcols = [p for p in localparent.props.values() if isinstance(p.strategy, DeferredColumnLoader) and p.group==self.group] - result = session.execute(localparent, sql.select([g.columns[0] for g in groupcols], clause, use_labels=True), None) + + clause = localparent._get_clause + ident = instance._instance_key[1] + params = {} + for i, primary_key in enumerate(localparent.primary_key): + params[primary_key._label] = ident[i] + if group is not None: + statement = sql.select([p.columns[0] for p in group], clause, from_obj=[localparent.mapped_table], use_labels=True) + else: + statement = sql.select([self.columns[0]], clause, from_obj=[localparent.mapped_table], use_labels=True) + + if group is not None: + result = session.execute(statement, params, mapper=localparent) try: row = result.fetchone() - for prop in groupcols: - if prop is self: - continue - # set a scalar object instance directly on the object, - # bypassing SmartProperty event handlers. - sessionlib.attribute_manager.init_instance_attribute(instance, prop.key, uselist=False) - instance.__dict__[prop.key] = row[prop.columns[0]] - return row[self.columns[0]] + for prop in group: + sessionlib.attribute_manager.get_attribute(instance, prop.key).set_committed_value(instance, row[prop.columns[0]]) + return attributes.ATTR_WAS_SET finally: result.close() else: - return session.scalar(localparent, sql.select([self.columns[0]], clause, use_labels=True),None) + return session.scalar(sql.select([self.columns[0]], clause, from_obj=[localparent.mapped_table], use_labels=True),params, mapper=localparent) return lazyload @@ -131,6 +220,15 @@ class DeferredOption(StrategizedOption): else: return ColumnLoader +class UndeferGroupOption(MapperOption): + def __init__(self, group): + self.group = group + def process_query_context(self, context): + context.attributes[('undefer', self.group)] = True + + def process_selection_context(self, context): + context.attributes[('undefer', self.group)] = True + class AbstractRelationLoader(LoaderStrategy): def init(self): super(AbstractRelationLoader, self).init() @@ -139,22 +237,26 @@ class AbstractRelationLoader(LoaderStrategy): self._should_log_debug = logging.is_debug_enabled(self.logger) def _init_instance_attribute(self, instance, callable_=None): - return sessionlib.attribute_manager.init_instance_attribute(instance, self.key, self.uselist, cascade=self.cascade, trackparent=True, callable_=callable_) + return sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=callable_) def _register_attribute(self, class_, callable_=None): self.logger.info("register managed %s attribute %s on class %s" % ((self.uselist and "list-holding" or "scalar"), self.key, self.parent.class_.__name__)) - sessionlib.attribute_manager.register_attribute(class_, self.key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade, trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_) + sessionlib.attribute_manager.register_attribute(class_, self.key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade, trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_, comparator=self.parent_property.comparator) class NoLoader(AbstractRelationLoader): def init_class_attribute(self): self._register_attribute(self.parent.class_) - def process_row(self, selectcontext, instance, row, identitykey, isnew): - if isnew: - if not self.is_default or len(selectcontext.options): - if self._should_log_debug: - self.logger.debug("set instance-level no loader on %s" % mapperutil.attribute_str(instance, self.key)) - self._init_instance_attribute(instance) + def create_row_processor(self, selectcontext, mapper, row): + if not self.is_default or len(selectcontext.options): + def execute(instance, row, isnew, **flags): + if isnew: + if self._should_log_debug: + self.logger.debug("set instance-level no loader on %s" % mapperutil.attribute_str(instance, self.key)) + self._init_instance_attribute(instance) + return (execute, None) + else: + return (None, None) NoLoader.logger = logging.class_logger(NoLoader) @@ -167,7 +269,8 @@ class LazyLoader(AbstractRelationLoader): # determine if our "lazywhere" clause is the same as the mapper's # get() clause. then we can just use mapper.get() - self.use_get = not self.uselist and query.Query(self.mapper)._get_clause.compare(self.lazywhere) + #from sqlalchemy.orm import query + self.use_get = not self.uselist and self.mapper._get_clause.compare(self.lazywhere) if self.use_get: self.logger.info(str(self.parent_property) + " will use query.get() to optimize instance loads") @@ -178,7 +281,7 @@ class LazyLoader(AbstractRelationLoader): if not mapper.has_mapper(instance): return None else: - prop = mapper.object_mapper(instance).props[self.key] + prop = mapper.object_mapper(instance).get_property(self.key) if prop is not self.parent_property: return prop._get_strategy(LazyLoader).setup_loader(instance) def lazyload(): @@ -211,20 +314,27 @@ class LazyLoader(AbstractRelationLoader): # if we have a simple straight-primary key load, use mapper.get() # to possibly save a DB round trip + q = session.query(self.mapper) if self.use_get: ident = [] - for primary_key in self.select_mapper.pks_by_table[self.select_mapper.mapped_table]: + # TODO: when options are added to allow switching between union-based and non-union + # based polymorphic loads on a per-query basis, this code needs to switch between "mapper" and "select_mapper", + # probably via the query's own "mapper" property, and also use one of two "lazy" clauses, + # one against the "union" the other not + for primary_key in self.select_mapper.primary_key: bind = self.lazyreverse[primary_key] ident.append(params[bind.key]) - return session.query(self.mapper).get(ident) + return q.get(ident) elif self.order_by is not False: - order_by = self.order_by + q = q.order_by(self.order_by) elif self.secondary is not None and self.secondary.default_order_by() is not None: - order_by = self.secondary.default_order_by() - else: - order_by = False - result = session.query(self.mapper, with_options=options).select_whereclause(self.lazywhere, order_by=order_by, params=params) + q = q.order_by(self.secondary.default_order_by()) + if options: + q = q.options(*options) + q = q.filter(self.lazywhere).params(**params) + + result = q.all() if self.uselist: return result else: @@ -232,25 +342,37 @@ class LazyLoader(AbstractRelationLoader): return result[0] else: return None + + if self.uselist: + return q.all() + else: + return q.first() + return lazyload - def process_row(self, selectcontext, instance, row, identitykey, isnew): - if isnew: - # new object instance being loaded from a result row - if not self.is_default or len(selectcontext.options): - self.logger.debug("set instance-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key)) - # we are not the primary manager for this attribute on this class - set up a per-instance lazyloader, - # which will override the clareset_instance_attributess-level behavior - self._init_instance_attribute(instance, callable_=self.setup_loader(instance, selectcontext.options)) - else: - self.logger.debug("set class-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key)) - # we are the primary manager for this attribute on this class - reset its per-instance attribute state, - # so that the class-level lazy loader is executed when next referenced on this instance. - # this usually is not needed unless the constructor of the object referenced the attribute before we got - # to load data into it. - sessionlib.attribute_manager.reset_instance_attribute(instance, self.key) - - def _create_lazy_clause(cls, prop, reverse_direction=False): + def create_row_processor(self, selectcontext, mapper, row): + if not self.is_default or len(selectcontext.options): + def execute(instance, row, isnew, **flags): + if isnew: + if self._should_log_debug: + self.logger.debug("set instance-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key)) + # we are not the primary manager for this attribute on this class - set up a per-instance lazyloader, + # which will override the clareset_instance_attributess-level behavior + self._init_instance_attribute(instance, callable_=self.setup_loader(instance, selectcontext.options)) + return (execute, None) + else: + def execute(instance, row, isnew, **flags): + if isnew: + if self._should_log_debug: + self.logger.debug("set class-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key)) + # we are the primary manager for this attribute on this class - reset its per-instance attribute state, + # so that the class-level lazy loader is executed when next referenced on this instance. + # this usually is not needed unless the constructor of the object referenced the attribute before we got + # to load data into it. + sessionlib.attribute_manager.reset_instance_attribute(instance, self.key) + return (execute, None) + + def _create_lazy_clause(cls, prop, reverse_direction=False, op='=='): (primaryjoin, secondaryjoin, remote_side) = (prop.polymorphic_primaryjoin, prop.polymorphic_secondaryjoin, prop.remote_side) binds = {} @@ -272,19 +394,16 @@ class LazyLoader(AbstractRelationLoader): FindColumnInColumnClause().traverse(expr) return len(columns) and columns[0] or None - def bind_label(): - # TODO: make this generation deterministic - return "lazy_" + hex(random.randint(0, 65535))[2:] - def visit_binary(binary): leftcol = find_column_in_expr(binary.left) rightcol = find_column_in_expr(binary.right) if leftcol is None or rightcol is None: return + if should_bind(leftcol, rightcol): col = leftcol binary.left = binds.setdefault(leftcol, - sql.bindparam(bind_label(), None, shortname=leftcol.name, type=binary.right.type, unique=True)) + sql.bindparam(None, None, shortname=leftcol.name, type_=binary.right.type, unique=True)) reverse[rightcol] = binds[col] # the "left is not right" compare is to handle part of a join clause that is "table.c.col1==table.c.col1", @@ -292,21 +411,19 @@ class LazyLoader(AbstractRelationLoader): if leftcol is not rightcol and should_bind(rightcol, leftcol): col = rightcol binary.right = binds.setdefault(rightcol, - sql.bindparam(bind_label(), None, shortname=rightcol.name, type=binary.left.type, unique=True)) + sql.bindparam(None, None, shortname=rightcol.name, type_=binary.left.type, unique=True)) reverse[leftcol] = binds[col] - lazywhere = primaryjoin.copy_container() + lazywhere = primaryjoin li = mapperutil.BinaryVisitor(visit_binary) if not secondaryjoin or not reverse_direction: - li.traverse(lazywhere) + lazywhere = li.traverse(lazywhere, clone=True) if secondaryjoin is not None: - secondaryjoin = secondaryjoin.copy_container() if reverse_direction: - li.traverse(secondaryjoin) + secondaryjoin = li.traverse(secondaryjoin, clone=True) lazywhere = sql.and_(lazywhere, secondaryjoin) - return (lazywhere, binds, reverse) _create_lazy_clause = classmethod(_create_lazy_clause) @@ -318,154 +435,42 @@ class EagerLoader(AbstractRelationLoader): def init(self): super(EagerLoader, self).init() - if self.parent.isa(self.mapper): - raise exceptions.ArgumentError( - "Error creating eager relationship '%s' on parent class '%s' " - "to child class '%s': Cant use eager loading on a self " - "referential relationship." % - (self.key, repr(self.parent.class_), repr(self.mapper.class_))) if self.is_default: self.parent._eager_loaders.add(self.parent_property) self.clauses = {} - self.clauses_by_lead_mapper = {} - - class AliasedClauses(object): - """Defines a set of join conditions and table aliases which - are aliased on a randomly-generated alias name, corresponding - to the connection of an optional parent AliasedClauses object - and a target mapper. - - EagerLoader has a distinct AliasedClauses object per parent - AliasedClauses object, so that all paths from one mapper to - another across a chain of eagerloaders generates a distinct - chain of joins. The AliasedClauses objects are generated and - cached on an as-needed basis. - - E.g.:: - - mapper A --> - (EagerLoader 'items') --> - mapper B --> - (EagerLoader 'keywords') --> - mapper C - - will generate:: - - EagerLoader 'items' --> { - None : AliasedClauses(items, None, alias_suffix='AB34') # mappera JOIN mapperb_AB34 - } - - EagerLoader 'keywords' --> [ - None : AliasedClauses(keywords, None, alias_suffix='43EF') # mapperb JOIN mapperc_43EF - AliasedClauses(items, None, alias_suffix='AB34') : - AliasedClauses(keywords, items, alias_suffix='8F44') # mapperb_AB34 JOIN mapperc_8F44 - ] - """ - - def __init__(self, eagerloader, parentclauses=None): - self.id = (parentclauses is not None and (parentclauses.id + "/") or '') + str(eagerloader.parent_property) - self.parent = eagerloader - self.target = eagerloader.select_table - self.eagertarget = eagerloader.select_table.alias(self._aliashash("/target")) - self.extra_cols = {} - - if eagerloader.secondary: - self.eagersecondary = eagerloader.secondary.alias(self._aliashash("/secondary")) - if parentclauses is not None: - aliasizer = sql_util.ClauseAdapter(self.eagertarget).\ - chain(sql_util.ClauseAdapter(self.eagersecondary)).\ - chain(sql_util.ClauseAdapter(parentclauses.eagertarget)) - else: - aliasizer = sql_util.ClauseAdapter(self.eagertarget).\ - chain(sql_util.ClauseAdapter(self.eagersecondary)) - self.eagersecondaryjoin = eagerloader.polymorphic_secondaryjoin.copy_container() - aliasizer.traverse(self.eagersecondaryjoin) - self.eagerprimary = eagerloader.polymorphic_primaryjoin.copy_container() - aliasizer.traverse(self.eagerprimary) - else: - self.eagerprimary = eagerloader.polymorphic_primaryjoin.copy_container() - if parentclauses is not None: - aliasizer = sql_util.ClauseAdapter(self.eagertarget) - aliasizer.chain(sql_util.ClauseAdapter(parentclauses.eagertarget, exclude=eagerloader.parent_property.remote_side)) - else: - aliasizer = sql_util.ClauseAdapter(self.eagertarget) - aliasizer.traverse(self.eagerprimary) - - if eagerloader.order_by: - self.eager_order_by = sql_util.ClauseAdapter(self.eagertarget).copy_and_process(util.to_list(eagerloader.order_by)) - else: - self.eager_order_by = None - - self._row_decorator = self._create_decorator_row() - - def aliased_column(self, column): - """return the aliased version of the given column, creating a new label for it if not already - present in this AliasedClauses eagertable.""" - - conv = self.eagertarget.corresponding_column(column, raiseerr=False) - if conv: - return conv - - if column in self.extra_cols: - return self.extra_cols[column] - - aliased_column = column.copy_container() - sql_util.ClauseAdapter(self.eagertarget).traverse(aliased_column) - alias = self._aliashash(column.name) - aliased_column = aliased_column.label(alias) - self._row_decorator.map[column] = alias - self.extra_cols[column] = aliased_column - return aliased_column - - def _aliashash(self, extra): - """return a deterministic 4 digit hash value for this AliasedClause's id + extra.""" - # use the first 4 digits of an MD5 hash - return "anon_" + util.hash(self.id + extra)[0:4] - - def _create_decorator_row(self): - class EagerRowAdapter(object): - def __init__(self, row): - self.row = row - def has_key(self, key): - return map.has_key(key) or self.row.has_key(key) - def __getitem__(self, key): - if map.has_key(key): - key = map[key] - return self.row[key] - def keys(self): - return map.keys() - map = {} - for c in self.eagertarget.c: - parent = self.target.corresponding_column(c) - map[parent] = c - map[parent._label] = c - map[parent.name] = c - EagerRowAdapter.map = map - return EagerRowAdapter - - def _decorate_row(self, row): - # adapts a row at row iteration time to transparently - # convert plain columns into the aliased columns that were actually - # added to the column clause of the SELECT. - return self._row_decorator(row) + self.join_depth = self.parent_property.join_depth def init_class_attribute(self): self.parent_property._get_strategy(LazyLoader).init_class_attribute() - def setup_query(self, context, eagertable=None, parentclauses=None, parentmapper=None, **kwargs): + def setup_query(self, context, parentclauses=None, parentmapper=None, **kwargs): """Add a left outer join to the statement thats being constructed.""" + # build a path as we setup the query. the format of this path + # matches that of interfaces.LoaderStack, and will be used in the + # row-loading phase to match up AliasedClause objects with the current + # LoaderStack position. + if parentclauses: + path = parentclauses.path + (self.parent.base_mapper(), self.key) + else: + path = (self.parent.base_mapper(), self.key) + + + if self.join_depth: + if len(path) / 2 > self.join_depth: + return + else: + if self.mapper in path: + return + + #print "CREATING EAGER PATH FOR", "->".join([str(s) for s in path]) + if parentmapper is None: localparent = context.mapper else: localparent = parentmapper - if self.mapper in context.recursion_stack: - return - else: - context.recursion_stack.add(self.parent) - statement = context.statement if hasattr(statement, '_outerjoin'): @@ -487,55 +492,57 @@ class EagerLoader(AbstractRelationLoader): break else: raise exceptions.InvalidRequestError("EagerLoader cannot locate a clause with which to outer join to, in query '%s' %s" % (str(statement), localparent.mapped_table)) - + try: - clauses = self.clauses[parentclauses] + clauses = self.clauses[path] except KeyError: - clauses = EagerLoader.AliasedClauses(self, parentclauses) - self.clauses[parentclauses] = clauses - - if context.mapper not in self.clauses_by_lead_mapper: - self.clauses_by_lead_mapper[context.mapper] = clauses - + clauses = mapperutil.PropertyAliasedClauses(self.parent_property, self.parent_property.polymorphic_primaryjoin, self.parent_property.polymorphic_secondaryjoin, parentclauses) + self.clauses[path] = clauses + if self.secondaryjoin is not None: - statement._outerjoin = sql.outerjoin(towrap, clauses.eagersecondary, clauses.eagerprimary).outerjoin(clauses.eagertarget, clauses.eagersecondaryjoin) + statement._outerjoin = sql.outerjoin(towrap, clauses.secondary, clauses.primaryjoin).outerjoin(clauses.alias, clauses.secondaryjoin) if self.order_by is False and self.secondary.default_order_by() is not None: - statement.order_by(*clauses.eagersecondary.default_order_by()) + statement.append_order_by(*clauses.secondary.default_order_by()) else: - statement._outerjoin = towrap.outerjoin(clauses.eagertarget, clauses.eagerprimary) - if self.order_by is False and clauses.eagertarget.default_order_by() is not None: - statement.order_by(*clauses.eagertarget.default_order_by()) + statement._outerjoin = towrap.outerjoin(clauses.alias, clauses.primaryjoin) + if self.order_by is False and clauses.alias.default_order_by() is not None: + statement.append_order_by(*clauses.alias.default_order_by()) - if clauses.eager_order_by: - statement.order_by(*util.to_list(clauses.eager_order_by)) - + if clauses.order_by: + statement.append_order_by(*util.to_list(clauses.order_by)) + statement.append_from(statement._outerjoin) - for value in self.select_mapper.props.values(): - value.setup(context, eagertable=clauses.eagertarget, parentclauses=clauses, parentmapper=self.select_mapper) - def _create_row_processor(self, selectcontext, row): - """Create a *row processing* function that will apply eager + for value in self.select_mapper.iterate_properties: + value.setup(context, parentclauses=clauses, parentmapper=self.select_mapper) + + def _create_row_decorator(self, selectcontext, row, path): + """Create a *row decorating* function that will apply eager aliasing to the row. Also check that an identity key can be retrieved from the row, else return None. """ + #print "creating row decorator for path ", "->".join([str(s) for s in path]) + # check for a user-defined decorator in the SelectContext (which was set up by the contains_eager() option) - if selectcontext.attributes.has_key((EagerLoader, self.parent_property)): + if selectcontext.attributes.has_key(("eager_row_processor", self.parent_property)): # custom row decoration function, placed in the selectcontext by the # contains_eager() mapper option - decorator = selectcontext.attributes[(EagerLoader, self.parent_property)] + decorator = selectcontext.attributes[("eager_row_processor", self.parent_property)] if decorator is None: decorator = lambda row: row else: try: # decorate the row according to the stored AliasedClauses for this eager load - clauses = self.clauses_by_lead_mapper[selectcontext.mapper] - decorator = clauses._row_decorator + clauses = self.clauses[path] + decorator = clauses.row_decorator except KeyError, k: # no stored AliasedClauses: eager loading was not set up in the query and # AliasedClauses never got initialized + if self._should_log_debug: + self.logger.debug("Could not locate aliased clauses for key: " + str(path)) return None try: @@ -550,81 +557,80 @@ class EagerLoader(AbstractRelationLoader): self.logger.debug("could not locate identity key from row '%s'; missing column '%s'" % (repr(decorated_row), str(k))) return None - def process_row(self, selectcontext, instance, row, identitykey, isnew): - """Receive a row. + def create_row_processor(self, selectcontext, mapper, row): + selectcontext.stack.push_property(self.key) + path = selectcontext.stack.snapshot() - Tell our mapper to look for a new object instance in the row, - and attach it to a list on the parent instance. - """ - - if self in selectcontext.recursion_stack: - return - - try: - # check for row processor - row_processor = selectcontext.attributes[id(self)] - except KeyError: - # create a row processor function and cache it in the context - row_processor = self._create_row_processor(selectcontext, row) - selectcontext.attributes[id(self)] = row_processor - - if row_processor is not None: - decorated_row = row_processor(row) - else: - # row_processor was None: degrade to a lazy loader - if self._should_log_debug: - self.logger.debug("degrade to lazy loader on %s" % mapperutil.attribute_str(instance, self.key)) - self.parent_property._get_strategy(LazyLoader).process_row(selectcontext, instance, row, identitykey, isnew) - return - - # TODO: recursion check a speed hit...? try to get a "termination point" into the AliasedClauses - # or EagerRowAdapter ? - selectcontext.recursion_stack.add(self) - try: - if not self.uselist: - if self._should_log_debug: - self.logger.debug("eagerload scalar instance on %s" % mapperutil.attribute_str(instance, self.key)) - if isnew: - # set a scalar object instance directly on the parent object, - # bypassing SmartProperty event handlers. - instance.__dict__[self.key] = self.mapper._instance(selectcontext, decorated_row, None) + row_decorator = self._create_row_decorator(selectcontext, row, path) + if row_decorator is not None: + def execute(instance, row, isnew, **flags): + decorated_row = row_decorator(row) + + selectcontext.stack.push_property(self.key) + + if not self.uselist: + if self._should_log_debug: + self.logger.debug("eagerload scalar instance on %s" % mapperutil.attribute_str(instance, self.key)) + if isnew: + # set a scalar object instance directly on the + # parent object, bypassing InstrumentedAttribute + # event handlers. + # + # FIXME: instead of... + sessionlib.attribute_manager.get_attribute(instance, self.key).set_raw_value(instance, self.mapper._instance(selectcontext, decorated_row, None)) + # bypass and set directly: + #instance.__dict__[self.key] = ... + else: + # call _instance on the row, even though the object has been created, + # so that we further descend into properties + self.mapper._instance(selectcontext, decorated_row, None) else: - # call _instance on the row, even though the object has been created, - # so that we further descend into properties - self.mapper._instance(selectcontext, decorated_row, None) - else: - if isnew: + if isnew: + if self._should_log_debug: + self.logger.debug("initialize UniqueAppender on %s" % mapperutil.attribute_str(instance, self.key)) + + collection = sessionlib.attribute_manager.init_collection(instance, self.key) + appender = util.UniqueAppender(collection, 'append_without_event') + + # store it in the "scratch" area, which is local to this load operation. + selectcontext.attributes[(instance, self.key)] = appender + result_list = selectcontext.attributes[(instance, self.key)] if self._should_log_debug: - self.logger.debug("initialize UniqueAppender on %s" % mapperutil.attribute_str(instance, self.key)) - # call the SmartProperty's initialize() method to create a new, blank list - l = getattr(instance.__class__, self.key).initialize(instance) - - # create an appender object which will add set-like semantics to the list - appender = util.UniqueAppender(l.data) - - # store it in the "scratch" area, which is local to this load operation. - selectcontext.attributes[(instance, self.key)] = appender - result_list = selectcontext.attributes[(instance, self.key)] - if self._should_log_debug: - self.logger.debug("eagerload list instance on %s" % mapperutil.attribute_str(instance, self.key)) - self.select_mapper._instance(selectcontext, decorated_row, result_list) - finally: - selectcontext.recursion_stack.remove(self) + self.logger.debug("eagerload list instance on %s" % mapperutil.attribute_str(instance, self.key)) + + self.select_mapper._instance(selectcontext, decorated_row, result_list) + selectcontext.stack.pop() + selectcontext.stack.pop() + return (execute, None) + else: + self.logger.debug("eager loader %s degrading to lazy loader" % str(self)) + selectcontext.stack.pop() + return self.parent_property._get_strategy(LazyLoader).create_row_processor(selectcontext, mapper, row) + + + def __str__(self): + return str(self.parent) + "." + self.key + EagerLoader.logger = logging.class_logger(EagerLoader) class EagerLazyOption(StrategizedOption): - def __init__(self, key, lazy=True): + def __init__(self, key, lazy=True, chained=False): super(EagerLazyOption, self).__init__(key) self.lazy = lazy - - def process_query_property(self, context, prop): + self.chained = chained + + def is_chained(self): + return not self.lazy and self.chained + + def process_query_property(self, context, properties): if self.lazy: - if prop in context.eager_loaders: - context.eager_loaders.remove(prop) + if properties[-1] in context.eager_loaders: + context.eager_loaders.remove(properties[-1]) else: - context.eager_loaders.add(prop) - super(EagerLazyOption, self).process_query_property(context, prop) + for prop in properties: + context.eager_loaders.add(prop) + super(EagerLazyOption, self).process_query_property(context, properties) def get_strategy_class(self): if self.lazy: @@ -636,24 +642,39 @@ class EagerLazyOption(StrategizedOption): EagerLazyOption.logger = logging.class_logger(EagerLazyOption) +# TODO: enable FetchMode option. currently +# this class does nothing. will require Query +# to swich between using its "polymorphic" selectable +# and its regular selectable in order to make decisions +# (therefore might require that FetchModeOperation is performed +# only as the first operation on a Query.) +class FetchModeOption(PropertyOption): + def __init__(self, key, type): + super(FetchModeOption, self).__init__(key) + if type not in ('join', 'select'): + raise exceptions.ArgumentError("Fetchmode must be one of 'join' or 'select'") + self.type = type + + def process_selection_property(self, context, properties): + context.attributes[('fetchmode', properties[-1])] = self.type + class RowDecorateOption(PropertyOption): def __init__(self, key, decorator=None, alias=None): super(RowDecorateOption, self).__init__(key) self.decorator = decorator self.alias = alias - def process_selection_property(self, context, property): + def process_selection_property(self, context, properties): if self.alias is not None and self.decorator is None: if isinstance(self.alias, basestring): - self.alias = property.target.alias(self.alias) + self.alias = properties[-1].target.alias(self.alias) def decorate(row): d = {} - for c in property.target.columns: + for c in properties[-1].target.columns: d[c] = row[self.alias.corresponding_column(c)] return d self.decorator = decorate - context.attributes[(EagerLoader, property)] = self.decorator + context.attributes[("eager_row_processor", properties[-1])] = self.decorator RowDecorateOption.logger = logging.class_logger(RowDecorateOption) - diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index 8c70f8cf8..cf48202b0 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -4,17 +4,16 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php - - -from sqlalchemy import sql, schema, exceptions -from sqlalchemy import logging -from sqlalchemy.orm import util as mapperutil - """Contains the ClauseSynchronizer class, which is used to map attributes between two objects in a manner corresponding to a SQL clause that compares column values. """ +from sqlalchemy import sql, schema, exceptions +from sqlalchemy import logging +from sqlalchemy.orm import util as mapperutil +import operator + ONETOMANY = 0 MANYTOONE = 1 MANYTOMANY = 2 @@ -44,7 +43,7 @@ class ClauseSynchronizer(object): def compile_binary(binary): """Assemble a SyncRule given a single binary condition.""" - if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): + if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): return source_column = None diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index c6b0b2689..f59042810 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -19,15 +19,17 @@ new, dirty, or deleted and provides the capability to flush all those changes at once. """ -from sqlalchemy import util, logging, topological -from sqlalchemy.orm import attributes +from sqlalchemy import util, logging, topological, exceptions +from sqlalchemy.orm import attributes, interfaces from sqlalchemy.orm import util as mapperutil -from sqlalchemy.orm.mapper import object_mapper, class_mapper -from sqlalchemy.exceptions import * +from sqlalchemy.orm.mapper import object_mapper import StringIO import weakref -class UOWEventHandler(attributes.AttributeExtension): +# Load lazily +object_session = None + +class UOWEventHandler(interfaces.AttributeExtension): """An event handler added to all class attributes which handles session operations. """ @@ -37,52 +39,46 @@ class UOWEventHandler(attributes.AttributeExtension): self.class_ = class_ self.cascade = cascade - def append(self, event, obj, item): + def append(self, obj, item, initiator): # process "save_update" cascade rules for when an instance is appended to the list of another instance sess = object_session(obj) if sess is not None: if self.cascade is not None and self.cascade.save_update and item not in sess: mapper = object_mapper(obj) - prop = mapper.props[self.key] + prop = mapper.get_property(self.key) ename = prop.mapper.entity_name sess.save_or_update(item, entity_name=ename) - def delete(self, event, obj, item): + def remove(self, obj, item, initiator): # currently no cascade rules for removing an item from a list # (i.e. it stays in the Session) pass - def set(self, event, obj, newvalue, oldvalue): + def set(self, obj, newvalue, oldvalue, initiator): # process "save_update" cascade rules for when an instance is attached to another instance sess = object_session(obj) if sess is not None: if newvalue is not None and self.cascade is not None and self.cascade.save_update and newvalue not in sess: mapper = object_mapper(obj) - prop = mapper.props[self.key] + prop = mapper.get_property(self.key) ename = prop.mapper.entity_name sess.save_or_update(newvalue, entity_name=ename) -class UOWProperty(attributes.InstrumentedAttribute): - """Override ``InstrumentedAttribute`` to provide an extra - ``AttributeExtension`` to all managed attributes as well as the - `property` property. - """ - - def __init__(self, manager, class_, key, uselist, callable_, typecallable, cascade=None, extension=None, **kwargs): - extension = util.to_list(extension or []) - extension.insert(0, UOWEventHandler(key, class_, cascade=cascade)) - super(UOWProperty, self).__init__(manager, key, uselist, callable_, typecallable, extension=extension,**kwargs) - self.class_ = class_ - - property = property(lambda s:class_mapper(s.class_).props[s.key], doc="returns the MapperProperty object associated with this property") class UOWAttributeManager(attributes.AttributeManager): """Override ``AttributeManager`` to provide the ``UOWProperty`` instance for all ``InstrumentedAttributes``. """ - def create_prop(self, class_, key, uselist, callable_, typecallable, **kwargs): - return UOWProperty(self, class_, key, uselist, callable_, typecallable, **kwargs) + def create_prop(self, class_, key, uselist, callable_, typecallable, + cascade=None, extension=None, **kwargs): + extension = util.to_list(extension or []) + extension.insert(0, UOWEventHandler(key, class_, cascade=cascade)) + + return super(UOWAttributeManager, self).create_prop( + class_, key, uselist, callable_, typecallable, + extension=extension, **kwargs) + class UnitOfWork(object): """Main UOW object which stores lists of dirty/new/deleted objects. @@ -122,7 +118,7 @@ class UnitOfWork(object): def _validate_obj(self, obj): if (hasattr(obj, '_instance_key') and not self.identity_map.has_key(obj._instance_key)) or \ (not hasattr(obj, '_instance_key') and obj not in self.new): - raise InvalidRequestError("Instance '%s' is not attached or pending within this session" % repr(obj)) + raise exceptions.InvalidRequestError("Instance '%s' is not attached or pending within this session" % repr(obj)) def _is_valid(self, obj): if (hasattr(obj, '_instance_key') and not self.identity_map.has_key(obj._instance_key)) or \ @@ -138,7 +134,7 @@ class UnitOfWork(object): self.new.remove(obj) if not hasattr(obj, '_instance_key'): mapper = object_mapper(obj) - obj._instance_key = mapper.instance_key(obj) + obj._instance_key = mapper.identity_key_from_instance(obj) if hasattr(obj, '_sa_insert_order'): delattr(obj, '_sa_insert_order') self.identity_map[obj._instance_key] = obj @@ -148,7 +144,7 @@ class UnitOfWork(object): """register the given object as 'new' (i.e. unsaved) within this unit of work.""" if hasattr(obj, '_instance_key'): - raise InvalidRequestError("Object '%s' already has an identity - it can't be registered as new" % repr(obj)) + raise exceptions.InvalidRequestError("Object '%s' already has an identity - it can't be registered as new" % repr(obj)) if obj not in self.new: self.new.add(obj) obj._sa_insert_order = len(self.new) @@ -204,14 +200,14 @@ class UnitOfWork(object): for obj in self.deleted.intersection(objset).difference(processed): flush_context.register_object(obj, isdelete=True) - trans = session.create_transaction(autoflush=False) - flush_context.transaction = trans + session.create_transaction(autoflush=False) + flush_context.transaction = session.transaction try: flush_context.execute() except: - trans.rollback() + session.rollback() raise - trans.commit() + session.commit() flush_context.post_exec() @@ -228,6 +224,7 @@ class UOWTransaction(object): def __init__(self, uow, session): self.uow = uow self.session = session + self.mapper_flush_opts = session._mapper_flush_opts # stores tuples of mapper/dependent mapper pairs, # representing a partial ordering fed into topological sort diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 3b3b9b7ed..d248c0dd0 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -4,7 +4,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from sqlalchemy import sql, util, exceptions +from sqlalchemy import sql, util, exceptions, sql_util +from sqlalchemy.orm.interfaces import MapperExtension, EXT_PASS all_cascades = util.Set(["delete", "delete-orphan", "all", "merge", "expunge", "save-update", "refresh-expire", "none"]) @@ -89,8 +90,6 @@ class TranslatingDict(dict): def __translate_col(self, col): ourcol = self.selectable.corresponding_column(col, keys_ok=False, raiseerr=False) -# if col is not ourcol and ourcol is not None: -# print "TD TRANSLATING ", col, "TO", ourcol if ourcol is None: return col else: @@ -111,6 +110,56 @@ class TranslatingDict(dict): def setdefault(self, col, value): return super(TranslatingDict, self).setdefault(self.__translate_col(col), value) +class ExtensionCarrier(MapperExtension): + def __init__(self, _elements=None): + self.__elements = _elements or [] + + def copy(self): + return ExtensionCarrier(list(self.__elements)) + + def __iter__(self): + return iter(self.__elements) + + def insert(self, extension): + """Insert a MapperExtension at the beginning of this ExtensionCarrier's list.""" + + self.__elements.insert(0, extension) + + def append(self, extension): + """Append a MapperExtension at the end of this ExtensionCarrier's list.""" + + self.__elements.append(extension) + + def _create_do(funcname): + def _do(self, *args, **kwargs): + for elem in self.__elements: + ret = getattr(elem, funcname)(*args, **kwargs) + if ret is not EXT_PASS: + return ret + else: + return EXT_PASS + return _do + + init_instance = _create_do('init_instance') + init_failed = _create_do('init_failed') + dispose_class = _create_do('dispose_class') + get_session = _create_do('get_session') + load = _create_do('load') + get = _create_do('get') + get_by = _create_do('get_by') + select_by = _create_do('select_by') + select = _create_do('select') + translate_row = _create_do('translate_row') + create_instance = _create_do('create_instance') + append_result = _create_do('append_result') + populate_instance = _create_do('populate_instance') + before_insert = _create_do('before_insert') + before_update = _create_do('before_update') + after_update = _create_do('after_update') + after_insert = _create_do('after_insert') + before_delete = _create_do('before_delete') + after_delete = _create_do('after_delete') + class BinaryVisitor(sql.ClauseVisitor): def __init__(self, func): self.func = func @@ -118,6 +167,138 @@ class BinaryVisitor(sql.ClauseVisitor): def visit_binary(self, binary): self.func(binary) +class AliasedClauses(object): + """Creates aliases of a mapped tables for usage in ORM queries. + """ + + def __init__(self, mapped_table, alias=None): + if alias: + self.alias = alias + else: + self.alias = mapped_table.alias() + self.mapped_table = mapped_table + self.extra_cols = {} + self.row_decorator = self._create_row_adapter() + + def aliased_column(self, column): + """return the aliased version of the given column, creating a new label for it if not already + present in this AliasedClauses.""" + + conv = self.alias.corresponding_column(column, raiseerr=False) + if conv: + return conv + + if column in self.extra_cols: + return self.extra_cols[column] + + aliased_column = column + # for column-level subqueries, swap out its selectable with our + # eager version as appropriate, and manually build the + # "correlation" list of the subquery. + class ModifySubquery(sql.ClauseVisitor): + def visit_select(s, select): + select._should_correlate = False + select.append_correlation(self.alias) + aliased_column = sql_util.ClauseAdapter(self.alias).chain(ModifySubquery()).traverse(aliased_column, clone=True) + aliased_column = aliased_column.label(None) + self.row_decorator.map[column] = aliased_column + # TODO: this is a little hacky + for attr in ('name', '_label'): + if hasattr(column, attr): + self.row_decorator.map[getattr(column, attr)] = aliased_column + self.extra_cols[column] = aliased_column + return aliased_column + + def adapt_clause(self, clause): + return self.aliased_column(clause) +# return sql_util.ClauseAdapter(self.alias).traverse(clause, clone=True) + + def _create_row_adapter(self): + """Return a callable which, + when passed a RowProxy, will return a new dict-like object + that translates Column objects to that of this object's Alias before calling upon the row. + + This allows a regular Table to be used to target columns in a row that was in reality generated from an alias + of that table, in such a way that the row can be passed to logic which knows nothing about the aliased form + of the table. + """ + class AliasedRowAdapter(object): + def __init__(self, row): + self.row = row + def __contains__(self, key): + return key in map or key in self.row + def has_key(self, key): + return key in self + def __getitem__(self, key): + if key in map: + key = map[key] + return self.row[key] + def keys(self): + return map.keys() + map = {} + for c in self.alias.c: + parent = self.mapped_table.corresponding_column(c) + map[parent] = c + map[parent._label] = c + map[parent.name] = c + for c in self.extra_cols: + map[c] = self.extra_cols[c] + # TODO: this is a little hacky + for attr in ('name', '_label'): + if hasattr(c, attr): + map[getattr(c, attr)] = self.extra_cols[c] + + AliasedRowAdapter.map = map + return AliasedRowAdapter + + +class PropertyAliasedClauses(AliasedClauses): + """extends AliasedClauses to add support for primary/secondary joins on a relation().""" + + def __init__(self, prop, primaryjoin, secondaryjoin, parentclauses=None): + super(PropertyAliasedClauses, self).__init__(prop.select_table) + + self.parentclauses = parentclauses + if parentclauses is not None: + self.path = parentclauses.path + (prop.parent, prop.key) + else: + self.path = (prop.parent, prop.key) + + self.prop = prop + + if prop.secondary: + self.secondary = prop.secondary.alias() + if parentclauses is not None: + aliasizer = sql_util.ClauseAdapter(self.alias).\ + chain(sql_util.ClauseAdapter(self.secondary)).\ + chain(sql_util.ClauseAdapter(parentclauses.alias)) + else: + aliasizer = sql_util.ClauseAdapter(self.alias).\ + chain(sql_util.ClauseAdapter(self.secondary)) + self.secondaryjoin = aliasizer.traverse(secondaryjoin, clone=True) + self.primaryjoin = aliasizer.traverse(primaryjoin, clone=True) + else: + if parentclauses is not None: + aliasizer = sql_util.ClauseAdapter(self.alias, exclude=prop.local_side) + aliasizer.chain(sql_util.ClauseAdapter(parentclauses.alias, exclude=prop.remote_side)) + else: + aliasizer = sql_util.ClauseAdapter(self.alias, exclude=prop.local_side) + self.primaryjoin = aliasizer.traverse(primaryjoin, clone=True) + self.secondary = None + self.secondaryjoin = None + + if prop.order_by: + self.order_by = sql_util.ClauseAdapter(self.alias).copy_and_process(util.to_list(prop.order_by)) + else: + self.order_by = None + + mapper = property(lambda self:self.prop.mapper) + table = property(lambda self:self.prop.select_table) + + def __str__(self): + return "->".join([str(s) for s in self.path]) + + def instance_str(instance): """Return a string describing an instance.""" diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py index 8670464a0..f86e14ab1 100644 --- a/lib/sqlalchemy/pool.py +++ b/lib/sqlalchemy/pool.py @@ -13,7 +13,7 @@ automatically, based on module type and connect arguments, simply by calling regular DBAPI connect() methods. """ -import weakref, string, time, sys, traceback +import weakref, time try: import cPickle as pickle except: @@ -190,6 +190,7 @@ class _ConnectionRecord(object): def __init__(self, pool): self.__pool = pool self.connection = self.__connect() + self.properties = {} def close(self): if self.connection is not None: @@ -207,10 +208,12 @@ class _ConnectionRecord(object): def get_connection(self): if self.connection is None: self.connection = self.__connect() + self.properties.clear() elif (self.__pool._recycle > -1 and time.time() - self.starttime > self.__pool._recycle): self.__pool.log("Connection %s exceeded timeout; recycling" % repr(self.connection)) self.__close() self.connection = self.__connect() + self.properties.clear() return self.connection def __close(self): @@ -257,6 +260,21 @@ class _ConnectionFairy(object): _logger = property(lambda self: self._pool.logger) is_valid = property(lambda self:self.connection is not None) + + def _get_properties(self): + """A property collection unique to this DBAPI connection.""" + + try: + return self._connection_record.properties + except AttributeError: + if self.connection is None: + raise exceptions.InvalidRequestError("This connection is closed") + try: + return self._detatched_properties + except AttributeError: + self._detatched_properties = value = {} + return value + properties = property(_get_properties) def invalidate(self, e=None): """Mark this connection as invalidated. @@ -301,6 +319,8 @@ class _ConnectionFairy(object): if self._connection_record is not None: self._connection_record.connection = None self._pool.do_return_conn(self._connection_record) + self._detatched_properties = \ + self._connection_record.properties.copy() self._connection_record = None def close_open_cursors(self): diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index f6ad52adc..3faa3b89c 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -17,17 +17,19 @@ objects as well as the visitor interface, so that the schema package *plugs in* to the SQL package. """ -from sqlalchemy import sql, types, exceptions, util, databases +from sqlalchemy import sql, types, exceptions,util, databases import sqlalchemy -import copy, re, string +import re, string, inspect __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', 'ForeignKeyConstraint', 'PrimaryKeyConstraint', 'CheckConstraint', 'UniqueConstraint', 'DefaultGenerator', 'Constraint', - 'MetaData', 'ThreadLocalMetaData', 'BoundMetaData', 'DynamicMetaData', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault'] + 'MetaData', 'ThreadLocalMetaData', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault'] class SchemaItem(object): """Base class for items that define a database schema.""" + __metaclass__ = sql._FigureVisitName + def _init_items(self, *args): """Initialize the list of child items for this SchemaItem.""" @@ -69,15 +71,7 @@ class SchemaItem(object): m = self._derived_metadata() return m and m.bind or None - def get_engine(self): - """Return the engine or raise an error if no engine. - - Deprecated. use the "bind" attribute. - """ - - return self._get_engine(raiseerr=True) - - def _set_casing_strategy(self, name, kwargs, keyname='case_sensitive'): + def _set_casing_strategy(self, kwargs, keyname='case_sensitive'): """Set the "case_sensitive" argument sent via keywords to the item's constructor. For the purposes of Table's 'schema' property, the name of the @@ -85,7 +79,7 @@ class SchemaItem(object): """ setattr(self, '_%s_setting' % keyname, kwargs.pop(keyname, None)) - def _determine_case_sensitive(self, name, keyname='case_sensitive'): + def _determine_case_sensitive(self, keyname='case_sensitive'): """Determine the `case_sensitive` value for this item. For the purposes of Table's `schema` property, the name of the @@ -111,16 +105,22 @@ class SchemaItem(object): return True def _get_case_sensitive(self): + """late-compile the 'case-sensitive' setting when first accessed. + + typically the SchemaItem will be assembled into its final structure + of other SchemaItems at this point, whereby it can attain this setting + from its containing SchemaItem if not defined locally. + """ + try: return self.__case_sensitive except AttributeError: - self.__case_sensitive = self._determine_case_sensitive(self.name) + self.__case_sensitive = self._determine_case_sensitive() return self.__case_sensitive case_sensitive = property(_get_case_sensitive) - engine = property(lambda s:s._get_engine()) metadata = property(lambda s:s._derived_metadata()) - bind = property(lambda s:s.engine) + bind = property(lambda s:s._get_engine()) def _get_table_key(name, schema): if schema is None: @@ -128,30 +128,16 @@ def _get_table_key(name, schema): else: return schema + "." + name -class _TableSingleton(type): +class _TableSingleton(sql._FigureVisitName): """A metaclass used by the ``Table`` object to provide singleton behavior.""" def __call__(self, name, metadata, *args, **kwargs): - if isinstance(metadata, sql.Executor): - # backwards compatibility - get a BoundSchema associated with the engine - engine = metadata - if not hasattr(engine, '_legacy_metadata'): - engine._legacy_metadata = MetaData(engine) - metadata = engine._legacy_metadata - elif metadata is not None and not isinstance(metadata, MetaData): - # they left MetaData out, so assume its another SchemaItem, add it to *args - args = list(args) - args.insert(0, metadata) - metadata = None - - if metadata is None: - metadata = default_metadata - schema = kwargs.get('schema', None) autoload = kwargs.pop('autoload', False) autoload_with = kwargs.pop('autoload_with', False) mustexist = kwargs.pop('mustexist', False) useexisting = kwargs.pop('useexisting', False) + include_columns = kwargs.pop('include_columns', None) key = _get_table_key(name, schema) try: table = metadata.tables[key] @@ -170,9 +156,9 @@ class _TableSingleton(type): if autoload: try: if autoload_with: - autoload_with.reflecttable(table) + autoload_with.reflecttable(table, include_columns=include_columns) else: - metadata._get_engine(raiseerr=True).reflecttable(table) + metadata._get_engine(raiseerr=True).reflecttable(table, include_columns=include_columns) except exceptions.NoSuchTableError: del metadata.tables[key] raise @@ -187,7 +173,7 @@ class Table(SchemaItem, sql.TableClause): This subclasses ``sql.TableClause`` to provide a table that is associated with an instance of ``MetaData``, which in turn - may be associated with an instance of ``SQLEngine``. + may be associated with an instance of ``Engine``. Whereas ``TableClause`` represents a table as its used in an SQL expression, ``Table`` represents a table as it exists in a @@ -232,16 +218,28 @@ class Table(SchemaItem, sql.TableClause): options include: schema - Defaults to None: the *schema name* for this table, which is + The *schema name* for this table, which is required if the table resides in a schema other than the default selected schema for the engine's database - connection. + connection. Defaults to ``None``. autoload Defaults to False: the Columns for this table should be reflected from the database. Usually there will be no Column objects in the constructor if this property is set. + autoload_with + if autoload==True, this is an optional Engine or Connection + instance to be used for the table reflection. If ``None``, + the underlying MetaData's bound connectable will be used. + + include_columns + A list of strings indicating a subset of columns to be + loaded via the ``autoload`` operation; table columns who + aren't present in this list will not be represented on the resulting + ``Table`` object. Defaults to ``None`` which indicates all + columns should be reflected. + mustexist Defaults to False: indicates that this Table must already have been defined elsewhere in the application, else an @@ -293,8 +291,8 @@ class Table(SchemaItem, sql.TableClause): self.fullname = self.name self.owner = kwargs.pop('owner', None) - self._set_casing_strategy(name, kwargs) - self._set_casing_strategy(self.schema or '', kwargs, keyname='case_sensitive_schema') + self._set_casing_strategy(kwargs) + self._set_casing_strategy(kwargs, keyname='case_sensitive_schema') if len([k for k in kwargs if not re.match(r'^(?:%s)_' % '|'.join(databases.__all__), k)]): raise TypeError("Invalid argument(s) for Table: %s" % repr(kwargs.keys())) @@ -302,6 +300,8 @@ class Table(SchemaItem, sql.TableClause): # store extra kwargs, which should only contain db-specific options self.kwargs = kwargs + key = property(lambda self:_get_table_key(self.name, self.schema)) + def _export_columns(self, columns=None): # override FromClause's collection initialization logic; TableClause and Table # implement it differently @@ -311,7 +311,7 @@ class Table(SchemaItem, sql.TableClause): try: return getattr(self, '_case_sensitive_schema') except AttributeError: - setattr(self, '_case_sensitive_schema', self._determine_case_sensitive(self.schema or '', keyname='case_sensitive_schema')) + setattr(self, '_case_sensitive_schema', self._determine_case_sensitive(keyname='case_sensitive_schema')) return getattr(self, '_case_sensitive_schema') case_sensitive_schema = property(_get_case_sensitive_schema) @@ -361,36 +361,28 @@ class Table(SchemaItem, sql.TableClause): else: return [] - def exists(self, bind=None, connectable=None): + def exists(self, bind=None): """Return True if this table exists.""" - if connectable is not None: - bind = connectable - if bind is None: bind = self._get_engine(raiseerr=True) def do(conn): - e = conn.engine - return e.dialect.has_table(conn, self.name, schema=self.schema) + return conn.dialect.has_table(conn, self.name, schema=self.schema) return bind.run_callable(do) - def create(self, bind=None, checkfirst=False, connectable=None): + def create(self, bind=None, checkfirst=False): """Issue a ``CREATE`` statement for this table. See also ``metadata.create_all()``.""" - if connectable is not None: - bind = connectable self.metadata.create_all(bind=bind, checkfirst=checkfirst, tables=[self]) - def drop(self, bind=None, checkfirst=False, connectable=None): + def drop(self, bind=None, checkfirst=False): """Issue a ``DROP`` statement for this table. See also ``metadata.drop_all()``.""" - if connectable is not None: - bind = connectable self.metadata.drop_all(bind=bind, checkfirst=checkfirst, tables=[self]) def tometadata(self, metadata, schema=None): @@ -417,7 +409,7 @@ class Column(SchemaItem, sql._ColumnClause): ``TableClause``/``Table``. """ - def __init__(self, name, type, *args, **kwargs): + def __init__(self, name, type_, *args, **kwargs): """Construct a new ``Column`` object. Arguments are: @@ -426,7 +418,7 @@ class Column(SchemaItem, sql._ColumnClause): The name of this column. This should be the identical name as it appears, or will appear, in the database. - type + type\_ The ``TypeEngine`` for this column. This can be any subclass of ``types.AbstractType``, including the database-agnostic types defined in the types module, @@ -516,7 +508,7 @@ class Column(SchemaItem, sql._ColumnClause): identifier contains mixed case. """ - super(Column, self).__init__(name, None, type) + super(Column, self).__init__(name, None, type_) self.args = args self.key = kwargs.pop('key', name) self._primary_key = kwargs.pop('primary_key', False) @@ -526,7 +518,7 @@ class Column(SchemaItem, sql._ColumnClause): self.index = kwargs.pop('index', None) self.unique = kwargs.pop('unique', None) self.quote = kwargs.pop('quote', False) - self._set_casing_strategy(name, kwargs) + self._set_casing_strategy(kwargs) self.onupdate = kwargs.pop('onupdate', None) self.autoincrement = kwargs.pop('autoincrement', True) self.constraints = util.Set() @@ -631,12 +623,13 @@ class Column(SchemaItem, sql._ColumnClause): This is a copy of this ``Column`` referenced by a different parent (such as an alias or select statement). """ + fk = [ForeignKey(f._colspec) for f in self.foreign_keys] c = Column(name or self.name, self.type, self.default, key = name or self.key, primary_key = self.primary_key, nullable = self.nullable, _is_oid = self._is_oid, quote=self.quote, *fk) c.table = selectable c.orig_set = self.orig_set - c._distance = self._distance + 1 c.__originating_column = self.__originating_column + c._distance = self._distance + 1 if not c._is_oid: selectable.columns.add(c) if self.primary_key: @@ -749,18 +742,14 @@ class ForeignKey(SchemaItem): raise exceptions.ArgumentError("Could not create ForeignKey '%s' on table '%s': table '%s' has no column named '%s'" % (self._colspec, parenttable.name, table.name, str(e))) else: self._column = self._colspec + # propigate TypeEngine to parent if it didnt have one - if self.parent.type is types.NULLTYPE: + if isinstance(self.parent.type, types.NullType): self.parent.type = self._column.type return self._column column = property(lambda s: s._init_column()) - def accept_visitor(self, visitor): - """Call the `visit_foreign_key` method on the given visitor.""" - - visitor.visit_foreign_key(self) - def _get_parent(self): return self.parent @@ -802,7 +791,7 @@ class DefaultGenerator(SchemaItem): def execute(self, bind=None, **kwargs): if bind is None: bind = self._get_engine(raiseerr=True) - return bind.execute_default(self, **kwargs) + return bind._execute_default(self, **kwargs) def __repr__(self): return "DefaultGenerator()" @@ -814,9 +803,6 @@ class PassiveDefault(DefaultGenerator): super(PassiveDefault, self).__init__(**kwargs) self.arg = arg - def accept_visitor(self, visitor): - return visitor.visit_passive_default(self) - def __repr__(self): return "PassiveDefault(%s)" % repr(self.arg) @@ -829,15 +815,26 @@ class ColumnDefault(DefaultGenerator): def __init__(self, arg, **kwargs): super(ColumnDefault, self).__init__(**kwargs) - self.arg = arg - - def accept_visitor(self, visitor): - """Call the visit_column_default method on the given visitor.""" + if callable(arg): + if not inspect.isfunction(arg): + self.arg = lambda ctx: arg() + else: + argspec = inspect.getargspec(arg) + if len(argspec[0]) == 0: + self.arg = lambda ctx: arg() + elif len(argspec[0]) != 1: + raise exceptions.ArgumentError("ColumnDefault Python function takes zero or one positional arguments") + else: + self.arg = arg + else: + self.arg = arg + def _visit_name(self): if self.for_update: - return visitor.visit_column_onupdate(self) + return "column_onupdate" else: - return visitor.visit_column_default(self) + return "column_default" + __visit_name__ = property(_visit_name) def __repr__(self): return "ColumnDefault(%s)" % repr(self.arg) @@ -852,7 +849,7 @@ class Sequence(DefaultGenerator): self.increment = increment self.optional=optional self.quote = quote - self._set_casing_strategy(name, kwargs) + self._set_casing_strategy(kwargs) def __repr__(self): return "Sequence(%s)" % string.join( @@ -864,20 +861,16 @@ class Sequence(DefaultGenerator): super(Sequence, self)._set_parent(column) column.sequence = self - def create(self, bind=None): + def create(self, bind=None, checkfirst=True): if bind is None: bind = self._get_engine(raiseerr=True) - bind.create(self) + bind.create(self, checkfirst=checkfirst) - def drop(self, bind=None): + def drop(self, bind=None, checkfirst=True): if bind is None: bind = self._get_engine(raiseerr=True) - bind.drop(self) - - def accept_visitor(self, visitor): - """Call the visit_seauence method on the given visitor.""" + bind.drop(self, checkfirst=checkfirst) - return visitor.visit_sequence(self) class Constraint(SchemaItem): """Represent a table-level ``Constraint`` such as a composite primary key, foreign key, or unique constraint. @@ -891,7 +884,7 @@ class Constraint(SchemaItem): self.columns = sql.ColumnCollection() def __contains__(self, x): - return x in self.columns + return self.columns.contains_column(x) def keys(self): return self.columns.keys() @@ -916,11 +909,12 @@ class CheckConstraint(Constraint): super(CheckConstraint, self).__init__(name) self.sqltext = sqltext - def accept_visitor(self, visitor): + def _visit_name(self): if isinstance(self.parent, Table): - visitor.visit_check_constraint(self) + return "check_constraint" else: - visitor.visit_column_check_constraint(self) + return "column_check_constraint" + __visit_name__ = property(_visit_name) def _set_parent(self, parent): self.parent = parent @@ -949,9 +943,6 @@ class ForeignKeyConstraint(Constraint): for (c, r) in zip(self.__colnames, self.__refcolnames): self.append_element(c,r) - def accept_visitor(self, visitor): - visitor.visit_foreign_key_constraint(self) - def append_element(self, col, refcol): fk = ForeignKey(refcol, constraint=self, name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, use_alter=self.use_alter) fk._set_parent(self.table.c[col]) @@ -975,9 +966,6 @@ class PrimaryKeyConstraint(Constraint): for c in self.__colnames: self.append_column(table.c[c]) - def accept_visitor(self, visitor): - visitor.visit_primary_key_constraint(self) - def add(self, col): self.append_column(col) @@ -1009,9 +997,6 @@ class UniqueConstraint(Constraint): def append_column(self, col): self.columns.add(col) - def accept_visitor(self, visitor): - visitor.visit_unique_constraint(self) - def copy(self): return UniqueConstraint(name=self.name, *self.__colnames) @@ -1075,22 +1060,19 @@ class Index(SchemaItem): % (self.name, column)) self.columns.append(column) - def create(self, connectable=None): - if connectable is not None: - connectable.create(self) + def create(self, bind=None): + if bind is not None: + bind.create(self) else: self._get_engine(raiseerr=True).create(self) return self - def drop(self, connectable=None): - if connectable is not None: - connectable.drop(self) + def drop(self, bind=None): + if bind is not None: + bind.drop(self) else: self._get_engine(raiseerr=True).drop(self) - def accept_visitor(self, visitor): - visitor.visit_index(self) - def __str__(self): return repr(self) @@ -1103,69 +1085,34 @@ class Index(SchemaItem): class MetaData(SchemaItem): """Represent a collection of Tables and their associated schema constructs.""" - def __init__(self, engine_or_url=None, url=None, bind=None, engine=None, **kwargs): + __visit_name__ = 'metadata' + + def __init__(self, bind=None, **kwargs): """create a new MetaData object. bind an Engine, or a string or URL instance which will be passed - to create_engine(), along with \**kwargs - this MetaData will - be bound to the resulting engine. + to create_engine(), this MetaData will be bound to the resulting + engine. - engine_or_url - deprecated; a synonym for 'bind' - - url - deprecated. a string or URL instance which will be passed to - create_engine(), along with \**kwargs - this MetaData will be - bound to the resulting engine. - - engine - deprecated. an Engine instance to which this MetaData will - be bound. - case_sensitive - popped from \**kwargs, indicates default case sensitive - setting for all contained objects. defaults to True. + popped from \**kwargs, indicates default case sensitive setting for + all contained objects. defaults to True. - name - deprecated, optional name for this MetaData instance. - - """ - - # transition from <= 0.3.8 signature: - # MetaData(name=None, url=None, engine=None) - # to 0.4 signature: - # MetaData(engine_or_url=None) - name = kwargs.get('name', None) - if engine_or_url is None: - engine_or_url = url or bind or engine - elif 'name' in kwargs: - engine_or_url = engine_or_url or bind or engine or url - else: - import sqlalchemy.engine as engine - import sqlalchemy.engine.url as url - if (not isinstance(engine_or_url, url.URL) and - not isinstance(engine_or_url, engine.Connectable)): - try: - url.make_url(engine_or_url) - except exceptions.ArgumentError: - # nope, must have been a name as 1st positional - name, engine_or_url = engine_or_url, (url or engine or bind) - kwargs.pop('name', None) + """ self.tables = {} - self.name = name - self._bind = None - self._set_casing_strategy(name, kwargs) - if engine_or_url: - self.connect(engine_or_url, **kwargs) + self._set_casing_strategy(kwargs) + self.bind = bind + + def __repr__(self): + return 'MetaData(%r)' % self.bind def __getstate__(self): - return {'tables':self.tables, 'name':self.name, 'casesensitive':self._case_sensitive_setting} - + return {'tables':self.tables, 'casesensitive':self._case_sensitive_setting} + def __setstate__(self, state): self.tables = state['tables'] - self.name = state['name'] self._case_sensitive_setting = state['casesensitive'] self._bind = None @@ -1173,7 +1120,7 @@ class MetaData(SchemaItem): """return True if this MetaData is bound to an Engine.""" return self._bind is not None - def connect(self, bind=None, **kwargs): + def connect(self, bind, **kwargs): """bind this MetaData to an Engine. DEPRECATED. use metadata.bind = <engine> or metadata.bind = <url>. @@ -1184,13 +1131,8 @@ class MetaData(SchemaItem): produce the engine which to connect to. otherwise connects directly to the given Engine. - engine_or_url - deprecated. synonymous with "bind" - """ - if bind is None: - bind = kwargs.pop('engine_or_url', None) from sqlalchemy.engine.url import URL if isinstance(bind, (basestring, URL)): self._bind = sqlalchemy.create_engine(bind, **kwargs) @@ -1202,6 +1144,10 @@ class MetaData(SchemaItem): def clear(self): self.tables.clear() + def remove(self, table): + # TODO: scan all other tables and remove FK _column + del self.tables[table.key] + def table_iterator(self, reverse=True, tables=None): import sqlalchemy.sql_util if tables is None: @@ -1214,7 +1160,7 @@ class MetaData(SchemaItem): def _get_parent(self): return None - def create_all(self, bind=None, tables=None, checkfirst=True, connectable=None): + def create_all(self, bind=None, tables=None, checkfirst=True): """Create all tables stored in this metadata. This will conditionally create tables depending on if they do @@ -1224,21 +1170,16 @@ class MetaData(SchemaItem): A ``Connectable`` used to access the database; if None, uses the existing bind on this ``MetaData``, if any. - connectable - deprecated. synonymous with "bind" - tables Optional list of tables, which is a subset of the total tables in the ``MetaData`` (others are ignored). """ - if connectable is not None: - bind = connectable if bind is None: bind = self._get_engine(raiseerr=True) bind.create(self, checkfirst=checkfirst, tables=tables) - def drop_all(self, bind=None, tables=None, checkfirst=True, connectable=None): + def drop_all(self, bind=None, tables=None, checkfirst=True): """Drop all tables stored in this metadata. This will conditionally drop tables depending on if they @@ -1248,23 +1189,15 @@ class MetaData(SchemaItem): A ``Connectable`` used to access the database; if None, uses the existing bind on this ``MetaData``, if any. - connectable - deprecated. synonymous with "bind" - tables Optional list of tables, which is a subset of the total tables in the ``MetaData`` (others are ignored). """ - if connectable is not None: - bind = connectable if bind is None: bind = self._get_engine(raiseerr=True) bind.drop(self, checkfirst=checkfirst, tables=tables) - def accept_visitor(self, visitor): - visitor.visit_metadata(self) - def _derived_metadata(self): return self @@ -1276,27 +1209,22 @@ class MetaData(SchemaItem): return None return self._bind - -class BoundMetaData(MetaData): - """Deprecated. Use ``MetaData``.""" - - def __init__(self, engine_or_url, name=None, **kwargs): - super(BoundMetaData, self).__init__(engine_or_url=engine_or_url, - name=name, **kwargs) - - class ThreadLocalMetaData(MetaData): - """A ``MetaData`` that binds to multiple ``Engine`` implementations on a thread-local basis.""" + """Build upon ``MetaData`` to provide the capability to bind to +multiple ``Engine`` implementations on a dynamically alterable, +thread-local basis. + """ + + __visit_name__ = 'metadata' - def __init__(self, name=None, **kwargs): + def __init__(self, **kwargs): self.context = util.ThreadLocal() self.__engines = {} - super(ThreadLocalMetaData, self).__init__(engine_or_url=None, - name=name, **kwargs) + super(ThreadLocalMetaData, self).__init__(**kwargs) def connect(self, engine_or_url, **kwargs): from sqlalchemy.engine.url import URL - if isinstance(engine_or_url, basestring) or isinstance(engine_or_url, URL): + if isinstance(engine_or_url, (basestring, URL)): try: self.context._engine = self.__engines[engine_or_url] except KeyError: @@ -1304,6 +1232,8 @@ class ThreadLocalMetaData(MetaData): self.__engines[engine_or_url] = e self.context._engine = e else: + # TODO: this is squirrely. we shouldnt have to hold onto engines + # in a case like this if not self.__engines.has_key(engine_or_url): self.__engines[engine_or_url] = engine_or_url self.context._engine = engine_or_url @@ -1325,77 +1255,10 @@ class ThreadLocalMetaData(MetaData): raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine or Connection.") else: return None - engine=property(_get_engine) bind = property(_get_engine, connect) -def DynamicMetaData(name=None, threadlocal=True, **kw): - """Deprecated. Use ``MetaData`` or ``ThreadLocalMetaData``.""" - - if threadlocal: - return ThreadLocalMetaData(name=name, **kw) - else: - return MetaData(name=name, **kw) - class SchemaVisitor(sql.ClauseVisitor): """Define the visiting for ``SchemaItem`` objects.""" __traverse_options__ = {'schema_visitor':True} - - def visit_schema(self, schema): - """Visit a generic ``SchemaItem``.""" - pass - - def visit_table(self, table): - """Visit a ``Table``.""" - pass - - def visit_column(self, column): - """Visit a ``Column``.""" - pass - - def visit_foreign_key(self, join): - """Visit a ``ForeignKey``.""" - pass - - def visit_index(self, index): - """Visit an ``Index``.""" - pass - - def visit_passive_default(self, default): - """Visit a passive default.""" - pass - - def visit_column_default(self, default): - """Visit a ``ColumnDefault``.""" - pass - - def visit_column_onupdate(self, onupdate): - """Visit a ``ColumnDefault`` with the `for_update` flag set.""" - pass - - def visit_sequence(self, sequence): - """Visit a ``Sequence``.""" - pass - - def visit_primary_key_constraint(self, constraint): - """Visit a ``PrimaryKeyConstraint``.""" - pass - - def visit_foreign_key_constraint(self, constraint): - """Visit a ``ForeignKeyConstraint``.""" - pass - - def visit_unique_constraint(self, constraint): - """Visit a ``UniqueConstraint``.""" - pass - - def visit_check_constraint(self, constraint): - """Visit a ``CheckConstraint``.""" - pass - - def visit_column_check_constraint(self, constraint): - """Visit a ``CheckConstraint`` on a ``Column``.""" - pass - -default_metadata = ThreadLocalMetaData(name='default') diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 8b454947e..01588e92d 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -24,54 +24,21 @@ are less guaranteed to stay the same in future releases. """ -from sqlalchemy import util, exceptions, logging +from sqlalchemy import util, exceptions from sqlalchemy import types as sqltypes -import string, re, random, sets +import re, operator - -__all__ = ['AbstractDialect', 'Alias', 'ClauseElement', 'ClauseParameters', +__all__ = ['Alias', 'ClauseElement', 'ClauseParameters', 'ClauseVisitor', 'ColumnCollection', 'ColumnElement', - 'Compiled', 'CompoundSelect', 'Executor', 'FromClause', 'Join', - 'Select', 'Selectable', 'TableClause', 'alias', 'and_', 'asc', - 'between_', 'between', 'bindparam', 'case', 'cast', 'column', 'delete', + 'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join', + 'Select', 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc', + 'between', 'bindparam', 'case', 'cast', 'column', 'delete', 'desc', 'distinct', 'except_', 'except_all', 'exists', 'extract', 'func', 'modifier', 'insert', 'intersect', 'intersect_all', 'join', 'literal', - 'literal_column', 'not_', 'null', 'or_', 'outerjoin', 'select', + 'literal_column', 'not_', 'null', 'or_', 'outparam', 'outerjoin', 'select', 'subquery', 'table', 'text', 'union', 'union_all', 'update',] -# precedence ordering for common operators. if an operator is not present in this list, -# it will be parenthesized when grouped against other operators -PRECEDENCE = { - 'FROM':15, - '*':7, - '/':7, - '%':7, - '+':6, - '-':6, - 'ILIKE':5, - 'NOT ILIKE':5, - 'LIKE':5, - 'NOT LIKE':5, - 'IN':5, - 'NOT IN':5, - 'IS':5, - 'IS NOT':5, - '=':5, - '!=':5, - '>':5, - '<':5, - '>=':5, - '<=':5, - 'BETWEEN':5, - 'NOT':4, - 'AND':3, - 'OR':2, - ',':-1, - 'AS':-1, - 'EXISTS':0, - '_smallest': -1000, - '_largest': 1000 -} +BIND_PARAMS = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE) def desc(column): """Return a descending ``ORDER BY`` clause element. @@ -141,7 +108,7 @@ def join(left, right, onclause=None, **kwargs): return Join(left, right, onclause, **kwargs) -def select(columns=None, whereclause = None, from_obj = [], **kwargs): +def select(columns=None, whereclause=None, from_obj=[], **kwargs): """Returns a ``SELECT`` clause element. Similar functionality is also available via the ``select()`` method on any @@ -224,9 +191,6 @@ def select(columns=None, whereclause = None, from_obj = [], **kwargs): automatically bind to whatever ``Connectable`` instances can be located within its contained ``ClauseElement`` members. - engine=None - deprecated. a synonym for "bind". - limit=None a numerical value which usually compiles to a ``LIMIT`` expression in the resulting select. Databases that don't support ``LIMIT`` @@ -238,12 +202,8 @@ def select(columns=None, whereclause = None, from_obj = [], **kwargs): will attempt to provide similar functionality. scalar=False - when ``True``, indicates that the resulting ``Select`` object - is to be used in the "columns" clause of another select statement, - where the evaluated value of the column is the scalar result of - this statement. Normally, placing any ``Selectable`` within the - columns clause of a ``select()`` call will expand the member - columns of the ``Selectable`` individually. + deprecated. use select(...).as_scalar() to create a "scalar column" + proxy for an existing Select object. correlate=True indicates that this ``Select`` object should have its contained @@ -254,8 +214,12 @@ def select(columns=None, whereclause = None, from_obj = [], **kwargs): rendered in the ``FROM`` clause of this select statement. """ - - return Select(columns, whereclause = whereclause, from_obj = from_obj, **kwargs) + scalar = kwargs.pop('scalar', False) + s = Select(columns, whereclause=whereclause, from_obj=from_obj, **kwargs) + if scalar: + return s.as_scalar() + else: + return s def subquery(alias, *args, **kwargs): """Return an [sqlalchemy.sql#Alias] object derived from a [sqlalchemy.sql#Select]. @@ -271,7 +235,7 @@ def subquery(alias, *args, **kwargs): return Select(*args, **kwargs).alias(alias) def insert(table, values = None, **kwargs): - """Return an [sqlalchemy.sql#_Insert] clause element. + """Return an [sqlalchemy.sql#Insert] clause element. Similar functionality is available via the ``insert()`` method on [sqlalchemy.schema#Table]. @@ -304,10 +268,10 @@ def insert(table, values = None, **kwargs): against the ``INSERT`` statement. """ - return _Insert(table, values, **kwargs) + return Insert(table, values, **kwargs) def update(table, whereclause = None, values = None, **kwargs): - """Return an [sqlalchemy.sql#_Update] clause element. + """Return an [sqlalchemy.sql#Update] clause element. Similar functionality is available via the ``update()`` method on [sqlalchemy.schema#Table]. @@ -344,10 +308,10 @@ def update(table, whereclause = None, values = None, **kwargs): against the ``UPDATE`` statement. """ - return _Update(table, whereclause, values, **kwargs) + return Update(table, whereclause, values, **kwargs) def delete(table, whereclause = None, **kwargs): - """Return a [sqlalchemy.sql#_Delete] clause element. + """Return a [sqlalchemy.sql#Delete] clause element. Similar functionality is available via the ``delete()`` method on [sqlalchemy.schema#Table]. @@ -361,7 +325,7 @@ def delete(table, whereclause = None, **kwargs): """ - return _Delete(table, whereclause, **kwargs) + return Delete(table, whereclause, **kwargs) def and_(*clauses): """Join a list of clauses together using the ``AND`` operator. @@ -371,7 +335,7 @@ def and_(*clauses): """ if len(clauses) == 1: return clauses[0] - return ClauseList(operator='AND', *clauses) + return ClauseList(operator=operator.and_, *clauses) def or_(*clauses): """Join a list of clauses together using the ``OR`` operator. @@ -382,7 +346,7 @@ def or_(*clauses): if len(clauses) == 1: return clauses[0] - return ClauseList(operator='OR', *clauses) + return ClauseList(operator=operator.or_, *clauses) def not_(clause): """Return a negation of the given clause, i.e. ``NOT(clause)``. @@ -391,7 +355,7 @@ def not_(clause): subclasses to produce the same result. """ - return clause._negate() + return operator.inv(clause) def distinct(expr): """return a ``DISTINCT`` clause.""" @@ -407,12 +371,9 @@ def between(ctest, cleft, cright): provides similar functionality. """ - return _BinaryExpression(ctest, ClauseList(_literals_as_binds(cleft, type=ctest.type), _literals_as_binds(cright, type=ctest.type), operator='AND', group=False), 'BETWEEN') + ctest = _literal_as_binds(ctest) + return _BinaryExpression(ctest, ClauseList(_literal_as_binds(cleft, type_=ctest.type), _literal_as_binds(cright, type_=ctest.type), operator=operator.and_, group=False), ColumnOperators.between_op) -def between_(*args, **kwargs): - """synonym for [sqlalchemy.sql#between()] (deprecated).""" - - return between(*args, **kwargs) def case(whens, value=None, else_=None): """Produce a ``CASE`` statement. @@ -435,7 +396,7 @@ def case(whens, value=None, else_=None): type = list(whenlist[-1])[-1].type else: type = None - cc = _CalculatedClause(None, 'CASE', value, type=type, operator=None, group_contents=False, *whenlist + ['END']) + cc = _CalculatedClause(None, 'CASE', value, type_=type, operator=None, group_contents=False, *whenlist + ['END']) return cc def cast(clause, totype, **kwargs): @@ -457,7 +418,7 @@ def cast(clause, totype, **kwargs): def extract(field, expr): """Return the clause ``extract(field FROM expr)``.""" - expr = _BinaryExpression(text(field), expr, "FROM") + expr = _BinaryExpression(text(field), expr, Operators.from_) return func.extract(expr) def exists(*args, **kwargs): @@ -587,7 +548,7 @@ def alias(selectable, alias=None): return Alias(selectable, alias=alias) -def literal(value, type=None): +def literal(value, type_=None): """Return a literal clause, bound to a bind parameter. Literal clauses are created automatically when non- @@ -603,13 +564,13 @@ def literal(value, type=None): the underlying DBAPI, or is translatable via the given type argument. - type + type\_ an optional [sqlalchemy.types#TypeEngine] which will provide bind-parameter translation for this literal. """ - return _BindParamClause('literal', value, type=type, unique=True) + return _BindParamClause('literal', value, type_=type_, unique=True) def label(name, obj): """Return a [sqlalchemy.sql#_Label] object for the given [sqlalchemy.sql#ColumnElement]. @@ -630,7 +591,7 @@ def label(name, obj): return _Label(name, obj) -def column(text, type=None): +def column(text, type_=None): """Return a textual column clause, as would be in the columns clause of a ``SELECT`` statement. @@ -644,15 +605,15 @@ def column(text, type=None): constructs that are not to be quoted, use the [sqlalchemy.sql#literal_column()] function. - type + type\_ an optional [sqlalchemy.types#TypeEngine] object which will provide result-set translation for this column. """ - return _ColumnClause(text, type=type) + return _ColumnClause(text, type_=type_) -def literal_column(text, type=None): +def literal_column(text, type_=None): """Return a textual column clause, as would be in the columns clause of a ``SELECT`` statement. @@ -674,7 +635,7 @@ def literal_column(text, type=None): """ - return _ColumnClause(text, type=type, is_literal=True) + return _ColumnClause(text, type_=type_, is_literal=True) def table(name, *columns): """Return a [sqlalchemy.sql#Table] object. @@ -685,7 +646,7 @@ def table(name, *columns): return TableClause(name, *columns) -def bindparam(key, value=None, type=None, shortname=None, unique=False): +def bindparam(key, value=None, type_=None, shortname=None, unique=False): """Create a bind parameter clause with the given key. value @@ -707,11 +668,22 @@ def bindparam(key, value=None, type=None, shortname=None, unique=False): """ if isinstance(key, _ColumnClause): - return _BindParamClause(key.name, value, type=key.type, shortname=shortname, unique=unique) + return _BindParamClause(key.name, value, type_=key.type, shortname=shortname, unique=unique) else: - return _BindParamClause(key, value, type=type, shortname=shortname, unique=unique) + return _BindParamClause(key, value, type_=type_, shortname=shortname, unique=unique) -def text(text, bind=None, engine=None, *args, **kwargs): +def outparam(key, type_=None): + """create an 'OUT' parameter for usage in functions (stored procedures), for databases + whith support them. + + The ``outparam`` can be used like a regular function parameter. The "output" value will + be available from the [sqlalchemy.engine#ResultProxy] object via its ``out_parameters`` + attribute, which returns a dictionary containing the values. + """ + + return _BindParamClause(key, None, type_=type_, unique=False, isoutparam=True) + +def text(text, bind=None, *args, **kwargs): """Create literal text to be inserted into a query. When constructing a query from a ``select()``, ``update()``, @@ -729,9 +701,6 @@ def text(text, bind=None, engine=None, *args, **kwargs): bind An optional connection or engine to be used for this text query. - engine - deprecated. a synonym for 'bind'. - bindparams A list of ``bindparam()`` instances which can be used to define the types and/or initial values for the bind parameters within @@ -748,7 +717,7 @@ def text(text, bind=None, engine=None, *args, **kwargs): """ - return _TextClause(text, engine=engine, bind=bind, *args, **kwargs) + return _TextClause(text, bind=bind, *args, **kwargs) def null(): """Return a ``_Null`` object, which compiles to ``NULL`` in a sql statement.""" @@ -786,30 +755,44 @@ def _compound_select(keyword, *selects, **kwargs): def _is_literal(element): return not isinstance(element, ClauseElement) -def _literals_as_text(element): - if _is_literal(element): +def _literal_as_text(element): + if isinstance(element, Operators): + return element.expression_element() + elif _is_literal(element): return _TextClause(unicode(element)) else: return element -def _literals_as_binds(element, name='literal', type=None): - if _is_literal(element): +def _literal_as_column(element): + if isinstance(element, Operators): + return element.clause_element() + elif _is_literal(element): + return literal_column(str(element)) + else: + return element + +def _literal_as_binds(element, name='literal', type_=None): + if isinstance(element, Operators): + return element.expression_element() + elif _is_literal(element): if element is None: return null() else: - return _BindParamClause(name, element, shortname=name, type=type, unique=True) + return _BindParamClause(name, element, shortname=name, type_=type_, unique=True) else: return element + +def _selectable(element): + if hasattr(element, '__selectable__'): + return element.__selectable__() + elif isinstance(element, Selectable): + return element + else: + raise exceptions.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element)) def is_column(col): return isinstance(col, ColumnElement) -class AbstractDialect(object): - """Represent the behavior of a particular database. - - Used by ``Compiled`` objects.""" - pass - class ClauseParameters(object): """Represent a dictionary/iterator of bind parameter key names/values. @@ -822,52 +805,51 @@ class ClauseParameters(object): def __init__(self, dialect, positional=None): super(ClauseParameters, self).__init__() self.dialect = dialect - self.binds = {} - self.binds_to_names = {} - self.binds_to_values = {} + self.__binds = {} self.positional = positional or [] + def get_parameter(self, key): + return self.__binds[key] + def set_parameter(self, bindparam, value, name): - self.binds[bindparam.key] = bindparam - self.binds[name] = bindparam - self.binds_to_names[bindparam] = name - self.binds_to_values[bindparam] = value + self.__binds[name] = [bindparam, name, value] def get_original(self, key): - """Return the given parameter as it was originally placed in - this ``ClauseParameters`` object, without any ``Type`` - conversion.""" - return self.binds_to_values[self.binds[key]] + return self.__binds[key][2] + + def get_type(self, key): + return self.__binds[key][0].type def get_processed(self, key): - bind = self.binds[key] - value = self.binds_to_values[bind] + (bind, name, value) = self.__binds[key] return bind.typeprocess(value, self.dialect) def keys(self): - return self.binds_to_names.values() + return self.__binds.keys() + + def __iter__(self): + return iter(self.keys()) def __getitem__(self, key): return self.get_processed(key) def __contains__(self, key): - return key in self.binds + return key in self.__binds def set_value(self, key, value): - bind = self.binds[key] - self.binds_to_values[bind] = value + self.__binds[key][2] = value def get_original_dict(self): - return dict([(self.binds_to_names[b], self.binds_to_values[b]) for b in self.binds_to_names.keys()]) + return dict([(name, value) for (b, name, value) in self.__binds.values()]) def get_raw_list(self): return [self.get_processed(key) for key in self.positional] - def get_raw_dict(self): - d = {} - for k in self.binds_to_names.values(): - d[k] = self.get_processed(k) - return d + def get_raw_dict(self, encode_keys=False): + if encode_keys: + return dict([(key.encode(self.dialect.encoding), self.get_processed(key)) for key in self.keys()]) + else: + return dict([(key, self.get_processed(key)) for key in self.keys()]) def __repr__(self): return self.__class__.__name__ + ":" + repr(self.get_original_dict()) @@ -876,8 +858,8 @@ class ClauseVisitor(object): """A class that knows how to traverse and visit ``ClauseElements``. - Each ``ClauseElement``'s accept_visitor() method will call a - corresponding visit_XXXX() method here. Traversal of a + Calls visit_XXX() methods dynamically generated for each particualr + ``ClauseElement`` subclass encountered. Traversal of a hierarchy of ``ClauseElements`` is achieved via the ``traverse()`` method, which is passed the lead ``ClauseElement``. @@ -889,22 +871,44 @@ class ClauseVisitor(object): these options can indicate modifications to the set of elements returned, such as to not return column collections (column_collections=False) or to return Schema-level items - (schema_visitor=True).""" + (schema_visitor=True). + + ``ClauseVisitor`` also supports a simultaneous copy-and-traverse + operation, which will produce a copy of a given ``ClauseElement`` + structure while at the same time allowing ``ClauseVisitor`` subclasses + to modify the new structure in-place. + + """ __traverse_options__ = {} - def traverse(self, obj, stop_on=None): - stack = [obj] - traversal = [] - while len(stack) > 0: - t = stack.pop() - if stop_on is None or t not in stop_on: - traversal.insert(0, t) - for c in t.get_children(**self.__traverse_options__): - stack.append(c) - for target in traversal: - v = self - while v is not None: - target.accept_visitor(v) - v = getattr(v, '_next', None) + + def traverse_single(self, obj, **kwargs): + meth = getattr(self, "visit_%s" % obj.__visit_name__, None) + if meth: + return meth(obj, **kwargs) + + def traverse(self, obj, stop_on=None, clone=False): + if clone: + obj = obj._clone() + + v = self + visitors = [] + while v is not None: + visitors.append(v) + v = getattr(v, '_next', None) + + def _trav(obj): + if stop_on is not None and obj in stop_on: + return + if clone: + obj._copy_internals() + for c in obj.get_children(**self.__traverse_options__): + _trav(c) + + for v in visitors: + meth = getattr(v, "visit_%s" % obj.__visit_name__, None) + if meth: + meth(obj) + _trav(obj) return obj def chain(self, visitor): @@ -916,78 +920,6 @@ class ClauseVisitor(object): tail = tail._next tail._next = visitor return self - - def visit_column(self, column): - pass - def visit_table(self, table): - pass - def visit_fromclause(self, fromclause): - pass - def visit_bindparam(self, bindparam): - pass - def visit_textclause(self, textclause): - pass - def visit_compound(self, compound): - pass - def visit_compound_select(self, compound): - pass - def visit_binary(self, binary): - pass - def visit_unary(self, unary): - pass - def visit_alias(self, alias): - pass - def visit_select(self, select): - pass - def visit_join(self, join): - pass - def visit_null(self, null): - pass - def visit_clauselist(self, list): - pass - def visit_calculatedclause(self, calcclause): - pass - def visit_grouping(self, gr): - pass - def visit_function(self, func): - pass - def visit_cast(self, cast): - pass - def visit_label(self, label): - pass - def visit_typeclause(self, typeclause): - pass - -class LoggingClauseVisitor(ClauseVisitor): - """extends ClauseVisitor to include debug logging of all traversal. - - To install this visitor, set logging.DEBUG for - 'sqlalchemy.sql.ClauseVisitor' **before** you import the - sqlalchemy.sql module. - """ - - def traverse(self, obj, stop_on=None): - stack = [(obj, "")] - traversal = [] - while len(stack) > 0: - (t, indent) = stack.pop() - if stop_on is None or t not in stop_on: - traversal.insert(0, (t, indent)) - for c in t.get_children(**self.__traverse_options__): - stack.append((c, indent + " ")) - - for (target, indent) in traversal: - self.logger.debug(indent + repr(target)) - v = self - while v is not None: - target.accept_visitor(v) - v = getattr(v, '_next', None) - return obj - -LoggingClauseVisitor.logger = logging.class_logger(ClauseVisitor) - -if logging.is_debug_enabled(LoggingClauseVisitor.logger): - ClauseVisitor=LoggingClauseVisitor class NoColumnVisitor(ClauseVisitor): """a ClauseVisitor that will not traverse the exported Column @@ -1000,113 +932,35 @@ class NoColumnVisitor(ClauseVisitor): """ __traverse_options__ = {'column_collections':False} - -class Executor(object): - """Interface representing a "thing that can produce Compiled objects - and execute them".""" - - def execute_compiled(self, compiled, parameters, echo=None, **kwargs): - """Execute a Compiled object.""" - - raise NotImplementedError() - - def compiler(self, statement, parameters, **kwargs): - """Return a Compiled object for the given statement and parameters.""" - - raise NotImplementedError() - -class Compiled(ClauseVisitor): - """Represent a compiled SQL expression. - - The ``__str__`` method of the ``Compiled`` object should produce - the actual text of the statement. ``Compiled`` objects are - specific to their underlying database dialect, and also may - or may not be specific to the columns referenced within a - particular set of bind parameters. In no case should the - ``Compiled`` object be dependent on the actual values of those - bind parameters, even though it may reference those values as - defaults. - """ - - def __init__(self, dialect, statement, parameters, bind=None, engine=None): - """Construct a new ``Compiled`` object. - - statement - ``ClauseElement`` to be compiled. - - parameters - Optional dictionary indicating a set of bind parameters - specified with this ``Compiled`` object. These parameters - are the *default* values corresponding to the - ``ClauseElement``'s ``_BindParamClauses`` when the - ``Compiled`` is executed. In the case of an ``INSERT`` or - ``UPDATE`` statement, these parameters will also result in - the creation of new ``_BindParamClause`` objects for each - key and will also affect the generated column list in an - ``INSERT`` statement and the ``SET`` clauses of an - ``UPDATE`` statement. The keys of the parameter dictionary - can either be the string names of columns or - ``_ColumnClause`` objects. - - bind - optional engine or connection which will be bound to the - compiled object. - - engine - deprecated, a synonym for 'bind' - """ - self.dialect = dialect - self.statement = statement - self.parameters = parameters - self.bind = bind or engine - self.can_execute = statement.supports_execution() - - def compile(self): - self.traverse(self.statement) - self.after_compile() - - def __str__(self): - """Return the string text of the generated SQL statement.""" - - raise NotImplementedError() - def get_params(self, **params): - """Deprecated. use construct_params(). (supports unicode names) - """ - return self.construct_params(params) - - def construct_params(self, params): - """Return the bind params for this compiled object. - - Will start with the default parameters specified when this - ``Compiled`` object was first constructed, and will override - those values with those sent via `**params`, which are - key/value pairs. Each key should match one of the - ``_BindParamClause`` objects compiled into this object; either - the `key` or `shortname` property of the ``_BindParamClause``. - """ - raise NotImplementedError() +class _FigureVisitName(type): + def __init__(cls, clsname, bases, dict): + if not '__visit_name__' in cls.__dict__: + m = re.match(r'_?(\w+?)(?:Expression|Clause|Element|$)', clsname) + x = m.group(1) + x = re.sub(r'(?!^)[A-Z]', lambda m:'_'+m.group(0).lower(), x) + cls.__visit_name__ = x.lower() + super(_FigureVisitName, cls).__init__(clsname, bases, dict) - def execute(self, *multiparams, **params): - """Execute this compiled object.""" - - e = self.bind - if e is None: - raise exceptions.InvalidRequestError("This Compiled object is not bound to any Engine or Connection.") - return e.execute_compiled(self, *multiparams, **params) - - def scalar(self, *multiparams, **params): - """Execute this compiled object and return the result's scalar value.""" - - return self.execute(*multiparams, **params).scalar() - class ClauseElement(object): """Base class for elements of a programmatically constructed SQL expression. """ + __metaclass__ = _FigureVisitName + + def _clone(self): + """create a shallow copy of this ClauseElement. + + This method may be used by a generative API. + Its also used as part of the "deep" copy afforded + by a traversal that combines the _copy_internals() + method.""" + c = self.__class__.__new__(self.__class__) + c.__dict__ = self.__dict__.copy() + return c - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): """Return objects represented in this ``ClauseElement`` that should be added to the ``FROM`` list of a query, when this ``ClauseElement`` is placed in the column clause of a @@ -1115,7 +969,7 @@ class ClauseElement(object): raise NotImplementedError(repr(self)) - def _hide_froms(self): + def _hide_froms(self, **modifiers): """Return a list of ``FROM`` clause elements which this ``ClauseElement`` replaces. """ @@ -1131,13 +985,14 @@ class ClauseElement(object): return self is other - def accept_visitor(self, visitor): - """Accept a ``ClauseVisitor`` and call the appropriate - ``visit_xxx`` method. - """ - - raise NotImplementedError(repr(self)) - + def _copy_internals(self): + """reassign internal elements to be clones of themselves. + + called during a copy-and-traverse operation on newly + shallow-copied elements to create a deep copy.""" + + pass + def get_children(self, **kwargs): """return immediate child elements of this ``ClauseElement``. @@ -1160,18 +1015,6 @@ class ClauseElement(object): return False - def copy_container(self): - """Return a copy of this ``ClauseElement``, if this - ``ClauseElement`` contains other ``ClauseElements``. - - If this ``ClauseElement`` is not a container, it should return - self. This is used to create copies of expression trees that - still reference the same *leaf nodes*. The new structure can - then be restructured without affecting the original. - """ - - return self - def _find_engine(self): """Default strategy for locating an engine within the clause element. @@ -1195,7 +1038,6 @@ class ClauseElement(object): return None bind = property(lambda s:s._find_engine(), doc="""Returns the Engine or Connection to which this ClauseElement is bound, or None if none found.""") - engine = bind def execute(self, *multiparams, **params): """Compile and execute this ``ClauseElement``.""" @@ -1213,7 +1055,7 @@ class ClauseElement(object): return self.execute(*multiparams, **params).scalar() - def compile(self, bind=None, engine=None, parameters=None, compiler=None, dialect=None): + def compile(self, bind=None, parameters=None, compiler=None, dialect=None): """Compile this SQL expression. Uses the given ``Compiler``, or the given ``AbstractDialect`` @@ -1236,7 +1078,7 @@ class ClauseElement(object): ``SET`` and ``VALUES`` clause of those statements. """ - if (isinstance(parameters, list) or isinstance(parameters, tuple)): + if isinstance(parameters, (list, tuple)): parameters = parameters[0] if compiler is None: @@ -1244,8 +1086,6 @@ class ClauseElement(object): compiler = dialect.compiler(self, parameters) elif bind is not None: compiler = bind.compiler(self, parameters) - elif engine is not None: - compiler = engine.compiler(self, parameters) elif self.bind is not None: compiler = self.bind.compiler(self, parameters) @@ -1268,49 +1108,257 @@ class ClauseElement(object): return self._negate() def _negate(self): - return _UnaryExpression(self.self_group(against="NOT"), operator="NOT", negate=None) + if hasattr(self, 'negation_clause'): + return self.negation_clause + else: + return _UnaryExpression(self.self_group(against=operator.inv), operator=operator.inv, negate=None) + -class _CompareMixin(object): - """Defines comparison operations for ``ClauseElement`` instances. +class Operators(object): + def from_(): + raise NotImplementedError() + from_ = staticmethod(from_) - This is a mixin class that adds the capability to produce ``ClauseElement`` - instances based on regular Python operators. - These operations are achieved using Python's operator overload methods - (i.e. ``__eq__()``, ``__ne__()``, etc. + def as_(): + raise NotImplementedError() + as_ = staticmethod(as_) - Overridden operators include all comparison operators (i.e. '==', '!=', '<'), - math operators ('+', '-', '*', etc), the '&' and '|' operators which evaluate - to ``AND`` and ``OR`` respectively. + def exists(): + raise NotImplementedError() + exists = staticmethod(exists) - Other methods exist to create additional SQL clauses such as ``IN``, ``LIKE``, - ``DISTINCT``, etc. + def is_(): + raise NotImplementedError() + is_ = staticmethod(is_) - """ + def isnot(): + raise NotImplementedError() + isnot = staticmethod(isnot) + + def __and__(self, other): + return self.operate(operator.and_, other) + + def __or__(self, other): + return self.operate(operator.or_, other) + + def __invert__(self): + return self.operate(operator.inv) + + def clause_element(self): + raise NotImplementedError() + + def operate(self, op, *other, **kwargs): + raise NotImplementedError() + + def reverse_operate(self, op, *other, **kwargs): + raise NotImplementedError() + +class ColumnOperators(Operators): + """defines comparison and math operations""" + + def like_op(a, b): + return a.like(b) + like_op = staticmethod(like_op) + + def notlike_op(a, b): + raise NotImplementedError() + notlike_op = staticmethod(notlike_op) + def ilike_op(a, b): + return a.ilike(b) + ilike_op = staticmethod(ilike_op) + + def notilike_op(a, b): + raise NotImplementedError() + notilike_op = staticmethod(notilike_op) + + def between_op(a, b): + return a.between(b) + between_op = staticmethod(between_op) + + def in_op(a, b): + return a.in_(*b) + in_op = staticmethod(in_op) + + def notin_op(a, b): + raise NotImplementedError() + notin_op = staticmethod(notin_op) + + def startswith_op(a, b): + return a.startswith(b) + startswith_op = staticmethod(startswith_op) + + def endswith_op(a, b): + return a.endswith(b) + endswith_op = staticmethod(endswith_op) + + def comma_op(a, b): + raise NotImplementedError() + comma_op = staticmethod(comma_op) + + def concat_op(a, b): + return a.concat(b) + concat_op = staticmethod(concat_op) + def __lt__(self, other): - return self._compare('<', other) + return self.operate(operator.lt, other) def __le__(self, other): - return self._compare('<=', other) + return self.operate(operator.le, other) def __eq__(self, other): - return self._compare('=', other) + return self.operate(operator.eq, other) def __ne__(self, other): - return self._compare('!=', other) + return self.operate(operator.ne, other) def __gt__(self, other): - return self._compare('>', other) + return self.operate(operator.gt, other) def __ge__(self, other): - return self._compare('>=', other) + return self.operate(operator.ge, other) + def concat(self, other): + return self.operate(ColumnOperators.concat_op, other) + def like(self, other): - """produce a ``LIKE`` clause.""" - return self._compare('LIKE', other) + return self.operate(ColumnOperators.like_op, other) + + def in_(self, *other): + return self.operate(ColumnOperators.in_op, other) + + def startswith(self, other): + return self.operate(ColumnOperators.startswith_op, other) + + def endswith(self, other): + return self.operate(ColumnOperators.endswith_op, other) + + def __radd__(self, other): + return self.reverse_operate(operator.add, other) + + def __rsub__(self, other): + return self.reverse_operate(operator.sub, other) + + def __rmul__(self, other): + return self.reverse_operate(operator.mul, other) + + def __rdiv__(self, other): + return self.reverse_operate(operator.div, other) + + def between(self, cleft, cright): + return self.operate(Operators.between_op, (cleft, cright)) + + def __add__(self, other): + return self.operate(operator.add, other) + + def __sub__(self, other): + return self.operate(operator.sub, other) + + def __mul__(self, other): + return self.operate(operator.mul, other) + + def __div__(self, other): + return self.operate(operator.div, other) + + def __mod__(self, other): + return self.operate(operator.mod, other) + + def __truediv__(self, other): + return self.operate(operator.truediv, other) + +# precedence ordering for common operators. if an operator is not present in this list, +# it will be parenthesized when grouped against other operators +_smallest = object() +_largest = object() + +PRECEDENCE = { + Operators.from_:15, + operator.mul:7, + operator.div:7, + operator.mod:7, + operator.add:6, + operator.sub:6, + ColumnOperators.concat_op:6, + ColumnOperators.ilike_op:5, + ColumnOperators.notilike_op:5, + ColumnOperators.like_op:5, + ColumnOperators.notlike_op:5, + ColumnOperators.in_op:5, + ColumnOperators.notin_op:5, + Operators.is_:5, + Operators.isnot:5, + operator.eq:5, + operator.ne:5, + operator.gt:5, + operator.lt:5, + operator.ge:5, + operator.le:5, + ColumnOperators.between_op:5, + operator.inv:4, + operator.and_:3, + operator.or_:2, + ColumnOperators.comma_op:-1, + Operators.as_:-1, + Operators.exists:0, + _smallest: -1000, + _largest: 1000 +} + +class _CompareMixin(ColumnOperators): + """Defines comparison and math operations for ``ClauseElement`` instances.""" + + def __compare(self, op, obj, negate=None): + if obj is None or isinstance(obj, _Null): + if op == operator.eq: + return _BinaryExpression(self.expression_element(), null(), Operators.is_, negate=Operators.isnot) + elif op == operator.ne: + return _BinaryExpression(self.expression_element(), null(), Operators.isnot, negate=Operators.is_) + else: + raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL") + else: + obj = self._check_literal(obj) + + + return _BinaryExpression(self.expression_element(), obj, op, type_=sqltypes.Boolean, negate=negate) + + def __operate(self, op, obj): + obj = self._check_literal(obj) + + type_ = self._compare_type(obj) + + # TODO: generalize operator overloading like this out into the types module + if op == operator.add and isinstance(type_, (sqltypes.Concatenable)): + op = ColumnOperators.concat_op + + return _BinaryExpression(self.expression_element(), obj, op, type_=type_) + + operators = { + operator.add : (__operate,), + operator.mul : (__operate,), + operator.sub : (__operate,), + operator.div : (__operate,), + operator.mod : (__operate,), + operator.truediv : (__operate,), + operator.lt : (__compare, operator.ge), + operator.le : (__compare, operator.gt), + operator.ne : (__compare, operator.eq), + operator.gt : (__compare, operator.le), + operator.ge : (__compare, operator.lt), + operator.eq : (__compare, operator.ne), + ColumnOperators.like_op : (__compare, ColumnOperators.notlike_op), + } + + def operate(self, op, other): + o = _CompareMixin.operators[op] + return o[0](self, op, other, *o[1:]) + + def reverse_operate(self, op, other): + return self._bind_param(other).operate(op, self) def in_(self, *other): - """produce an ``IN`` clause.""" + return self._in_impl(ColumnOperators.in_op, ColumnOperators.notin_op, *other) + + def _in_impl(self, op, negate_op, *other): if len(other) == 0: return _Grouping(case([(self.__eq__(None), text('NULL'))], else_=text('0')).__eq__(text('1'))) elif len(other) == 1: @@ -1318,8 +1366,8 @@ class _CompareMixin(object): if _is_literal(o) or isinstance( o, _CompareMixin): return self.__eq__( o) #single item -> == else: - assert hasattr( o, '_selectable') #better check? - return self._compare( 'IN', o, negate='NOT IN') #single selectable + assert isinstance(o, Selectable) + return self.__compare( op, o, negate=negate_op) #single selectable args = [] for o in other: @@ -1329,29 +1377,22 @@ class _CompareMixin(object): else: o = self._bind_param(o) args.append(o) - return self._compare( 'IN', ClauseList(*args).self_group(against='IN'), negate='NOT IN') + return self.__compare(op, ClauseList(*args).self_group(against=op), negate=negate_op) def startswith(self, other): """produce the clause ``LIKE '<other>%'``""" - perc = isinstance(other,(str,unicode)) and '%' or literal('%',type= sqltypes.String) - return self._compare('LIKE', other + perc) + + perc = isinstance(other,(str,unicode)) and '%' or literal('%',type_= sqltypes.String) + return self.__compare(ColumnOperators.like_op, other + perc) def endswith(self, other): """produce the clause ``LIKE '%<other>'``""" + if isinstance(other,(str,unicode)): po = '%' + other else: - po = literal('%', type= sqltypes.String) + other - po.type = sqltypes.to_instance( sqltypes.String) #force! - return self._compare('LIKE', po) - - def __radd__(self, other): - return self._bind_param(other)._operate('+', self) - def __rsub__(self, other): - return self._bind_param(other)._operate('-', self) - def __rmul__(self, other): - return self._bind_param(other)._operate('*', self) - def __rdiv__(self, other): - return self._bind_param(other)._operate('/', self) + po = literal('%', type_=sqltypes.String) + other + po.type = sqltypes.to_instance(sqltypes.String) #force! + return self.__compare(ColumnOperators.like_op, po) def label(self, name): """produce a column label, i.e. ``<columnname> AS <name>``""" @@ -1363,7 +1404,8 @@ class _CompareMixin(object): def between(self, cleft, cright): """produce a BETWEEN clause, i.e. ``<column> BETWEEN <cleft> AND <cright>``""" - return _BinaryExpression(self, ClauseList(self._check_literal(cleft), self._check_literal(cright), operator='AND', group=False), 'BETWEEN') + + return _BinaryExpression(self, ClauseList(self._check_literal(cleft), self._check_literal(cright), operator=operator.and_, group=False), ColumnOperators.between_op) def op(self, operator): """produce a generic operator function. @@ -1382,59 +1424,25 @@ class _CompareMixin(object): passed to the generated function. """ - return lambda other: self._operate(operator, other) - - # and here come the math operators: - - def __add__(self, other): - return self._operate('+', other) - - def __sub__(self, other): - return self._operate('-', other) - - def __mul__(self, other): - return self._operate('*', other) - - def __div__(self, other): - return self._operate('/', other) - - def __mod__(self, other): - return self._operate('%', other) - - def __truediv__(self, other): - return self._operate('/', other) + return lambda other: self.__operate(operator, other) def _bind_param(self, obj): - return _BindParamClause('literal', obj, shortname=None, type=self.type, unique=True) + return _BindParamClause('literal', obj, shortname=None, type_=self.type, unique=True) def _check_literal(self, other): - if _is_literal(other): + if isinstance(other, Operators): + return other.expression_element() + elif _is_literal(other): return self._bind_param(other) else: return other + + def clause_element(self): + """Allow ``_CompareMixins`` to return the underlying ``ClauseElement``, for non-``ClauseElement`` ``_CompareMixins``.""" + return self - def _compare(self, operator, obj, negate=None): - if obj is None or isinstance(obj, _Null): - if operator == '=': - return _BinaryExpression(self._compare_self(), null(), 'IS', negate='IS NOT') - elif operator == '!=': - return _BinaryExpression(self._compare_self(), null(), 'IS NOT', negate='IS') - else: - raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL") - else: - obj = self._check_literal(obj) - - return _BinaryExpression(self._compare_self(), obj, operator, type=self._compare_type(obj), negate=negate) - - def _operate(self, operator, obj): - if _is_literal(obj): - obj = self._bind_param(obj) - return _BinaryExpression(self._compare_self(), obj, operator, type=self._compare_type(obj)) - - def _compare_self(self): - """Allow ``ColumnImpl`` to return its ``Column`` object for - usage in ``ClauseElements``, all others to just return self. - """ + def expression_element(self): + """Allow ``_CompareMixins`` to return the appropriate object to be used in expressions.""" return self @@ -1460,23 +1468,10 @@ class Selectable(ClauseElement): columns = util.NotImplProperty("""a [sqlalchemy.sql#ColumnCollection] containing ``ColumnElement`` instances.""") - def _selectable(self): - return self - - def accept_visitor(self, visitor): - raise NotImplementedError(repr(self)) - def select(self, whereclauses = None, **params): return select([self], whereclauses, **params) - def _group_parenthesized(self): - """Indicate if this ``Selectable`` requires parenthesis when - grouped into a compound statement. - """ - return True - - class ColumnElement(Selectable, _CompareMixin): """Represent an element that is useable within the "column clause" portion of a ``SELECT`` statement. @@ -1616,8 +1611,10 @@ class ColumnCollection(util.OrderedProperties): l.append(c==local) return and_(*l) - def __contains__(self, col): - return self.contains_column(col) + def __contains__(self, other): + if not isinstance(other, basestring): + raise exceptions.ArgumentError("__contains__ requires a string argument") + return self.has_key(other) def contains_column(self, col): # have to use a Set here, because it will compare the identity @@ -1649,19 +1646,18 @@ class FromClause(Selectable): clause of a ``SELECT`` statement. """ + __visit_name__ = 'fromclause' + def __init__(self, name=None): self.name = name - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): # this could also be [self], at the moment it doesnt matter to the Select object return [] def default_order_by(self): return [self.oid_column] - def accept_visitor(self, visitor): - visitor.visit_fromclause(self) - def count(self, whereclause=None, **params): if len(self.primary_key): col = list(self.primary_key)[0] @@ -1703,6 +1699,13 @@ class FromClause(Selectable): FindCols().traverse(self) return ret + def is_derived_from(self, fromclause): + """return True if this FromClause is 'derived' from the given FromClause. + + An example would be an Alias of a Table is derived from that Table.""" + + return False + def corresponding_column(self, column, raiseerr=True, keys_ok=False, require_embedded=False): """Given a ``ColumnElement``, return the exported ``ColumnElement`` object from this ``Selectable`` which @@ -1730,9 +1733,10 @@ class FromClause(Selectable): it merely shares a common anscestor with one of the exported columns of this ``FromClause``. """ - if column in self.c: + + if self.c.contains_column(column): return column - + if require_embedded and column not in util.Set(self._get_all_embedded_columns()): if not raiseerr: return None @@ -1761,6 +1765,15 @@ class FromClause(Selectable): self._export_columns() return getattr(self, name) + def _clone_from_clause(self): + # delete all the "generated" collections of columns for a newly cloned FromClause, + # so that they will be re-derived from the item. + # this is because FromClause subclasses, when cloned, need to reestablish new "proxied" + # columns that are linked to the new item + for attr in ('_columns', '_primary_key' '_foreign_keys', '_orig_cols', '_oid_column'): + if hasattr(self, attr): + delattr(self, attr) + columns = property(lambda s:s._get_exported_attribute('_columns')) c = property(lambda s:s._get_exported_attribute('_columns')) primary_key = property(lambda s:s._get_exported_attribute('_primary_key')) @@ -1791,8 +1804,9 @@ class FromClause(Selectable): self._primary_key = ColumnSet() self._foreign_keys = util.Set() self._orig_cols = {} + if columns is None: - columns = self._adjusted_exportable_columns() + columns = self._flatten_exportable_columns() for co in columns: cp = self._proxy_column(co) for ci in cp.orig_set: @@ -1806,13 +1820,14 @@ class FromClause(Selectable): for ci in self.oid_column.orig_set: self._orig_cols[ci] = self.oid_column - def _adjusted_exportable_columns(self): + def _flatten_exportable_columns(self): """return the list of ColumnElements represented within this FromClause's _exportable_columns""" export = self._exportable_columns() for column in export: - try: - s = column._selectable() - except AttributeError: + # TODO: is this conditional needed ? + if isinstance(column, Selectable): + s = column + else: continue for co in s.columns: yield co @@ -1829,7 +1844,9 @@ class _BindParamClause(ClauseElement, _CompareMixin): Public constructor is the ``bindparam()`` function. """ - def __init__(self, key, value, shortname=None, type=None, unique=False): + __visit_name__ = 'bindparam' + + def __init__(self, key, value, shortname=None, type_=None, unique=False, isoutparam=False): """Construct a _BindParamClause. key @@ -1852,7 +1869,7 @@ class _BindParamClause(ClauseElement, _CompareMixin): execution may match either the key or the shortname of the corresponding ``_BindParamClause`` objects. - type + type\_ A ``TypeEngine`` object that will be used to pre-process the value corresponding to this ``_BindParamClause`` at execution time. @@ -1862,23 +1879,34 @@ class _BindParamClause(ClauseElement, _CompareMixin): modified if another ``_BindParamClause`` of the same name already has been located within the containing ``ClauseElement``. + + isoutparam + if True, the parameter should be treated like a stored procedure "OUT" + parameter. """ - self.key = key + self.key = key or "{ANON %d param}" % id(self) self.value = value self.shortname = shortname or key self.unique = unique - self.type = sqltypes.to_instance(type) - - def accept_visitor(self, visitor): - visitor.visit_bindparam(self) - - def _get_from_objects(self): + self.isoutparam = isoutparam + type_ = sqltypes.to_instance(type_) + if isinstance(type_, sqltypes.NullType) and type(value) in _BindParamClause.type_map: + self.type = sqltypes.to_instance(_BindParamClause.type_map[type(value)]) + else: + self.type = type_ + + # TODO: move to types module, obviously + type_map = { + str : sqltypes.String, + unicode : sqltypes.Unicode, + int : sqltypes.Integer, + float : sqltypes.Numeric + } + + def _get_from_objects(self, **modifiers): return [] - def copy_container(self): - return _BindParamClause(self.key, self.value, self.shortname, self.type, unique=self.unique) - def typeprocess(self, value, dialect): return self.type.dialect_impl(dialect).convert_bind_param(value, dialect) @@ -1893,7 +1921,7 @@ class _BindParamClause(ClauseElement, _CompareMixin): return isinstance(other, _BindParamClause) and other.type.__class__ == self.type.__class__ def __repr__(self): - return "_BindParamClause(%s, %s, type=%s)" % (repr(self.key), repr(self.value), repr(self.type)) + return "_BindParamClause(%s, %s, type_=%s)" % (repr(self.key), repr(self.value), repr(self.type)) class _TypeClause(ClauseElement): """Handle a type keyword in a SQL statement. @@ -1901,13 +1929,12 @@ class _TypeClause(ClauseElement): Used by the ``Case`` statement. """ + __visit_name__ = 'typeclause' + def __init__(self, type): self.type = type - def accept_visitor(self, visitor): - visitor.visit_typeclause(self) - - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): return [] class _TextClause(ClauseElement): @@ -1916,8 +1943,10 @@ class _TextClause(ClauseElement): Public constructor is the ``text()`` function. """ - def __init__(self, text = "", bind=None, engine=None, bindparams=None, typemap=None): - self._bind = bind or engine + __visit_name__ = 'textclause' + + def __init__(self, text = "", bind=None, bindparams=None, typemap=None): + self._bind = bind self.bindparams = {} self.typemap = typemap if typemap is not None: @@ -1930,7 +1959,7 @@ class _TextClause(ClauseElement): # scan the string and search for bind parameter names, add them # to the list of bindparams - self.text = re.compile(r'(?<!:):([\w_]+)', re.S).sub(repl, text) + self.text = BIND_PARAMS.sub(repl, text) if bindparams is not None: for b in bindparams: self.bindparams[b.key] = b @@ -1944,13 +1973,13 @@ class _TextClause(ClauseElement): columns = property(lambda s:[]) + def _copy_internals(self): + self.bindparams = [b._clone() for b in self.bindparams] + def get_children(self, **kwargs): return self.bindparams.values() - def accept_visitor(self, visitor): - visitor.visit_textclause(self) - - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): return [] def supports_execution(self): @@ -1965,10 +1994,7 @@ class _Null(ColumnElement): def __init__(self): self.type = sqltypes.NULLTYPE - def accept_visitor(self, visitor): - visitor.visit_null(self) - - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): return [] class ClauseList(ClauseElement): @@ -1976,14 +2002,16 @@ class ClauseList(ClauseElement): By default, is comma-separated, such as a column listing. """ - + __visit_name__ = 'clauselist' + def __init__(self, *clauses, **kwargs): self.clauses = [] - self.operator = kwargs.pop('operator', ',') + self.operator = kwargs.pop('operator', ColumnOperators.comma_op) self.group = kwargs.pop('group', True) self.group_contents = kwargs.pop('group_contents', True) for c in clauses: - if c is None: continue + if c is None: + continue self.append(c) def __iter__(self): @@ -1991,32 +2019,28 @@ class ClauseList(ClauseElement): def __len__(self): return len(self.clauses) - def copy_container(self): - clauses = [clause.copy_container() for clause in self.clauses] - return ClauseList(operator=self.operator, *clauses) - def append(self, clause): # TODO: not sure if i like the 'group_contents' flag. need to define the difference between # a ClauseList of ClauseLists, and a "flattened" ClauseList of ClauseLists. flatten() method ? if self.group_contents: - self.clauses.append(_literals_as_text(clause).self_group(against=self.operator)) + self.clauses.append(_literal_as_text(clause).self_group(against=self.operator)) else: - self.clauses.append(_literals_as_text(clause)) + self.clauses.append(_literal_as_text(clause)) + + def _copy_internals(self): + self.clauses = [clause._clone() for clause in self.clauses] def get_children(self, **kwargs): return self.clauses - def accept_visitor(self, visitor): - visitor.visit_clauselist(self) - - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): f = [] for c in self.clauses: - f += c._get_from_objects() + f += c._get_from_objects(**modifiers) return f def self_group(self, against=None): - if self.group and self.operator != against and PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest']): + if self.group and self.operator != against and PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest]): return _Grouping(self) else: return self @@ -2043,40 +2067,45 @@ class _CalculatedClause(ColumnElement): Extends ``ColumnElement`` to provide column-level comparison operators. """ - + __visit_name__ = 'calculatedclause' + def __init__(self, name, *clauses, **kwargs): self.name = name - self.type = sqltypes.to_instance(kwargs.get('type', None)) - self._bind = kwargs.get('bind', kwargs.get('engine', None)) + self.type = sqltypes.to_instance(kwargs.get('type_', None)) + self._bind = kwargs.get('bind', None) self.group = kwargs.pop('group', True) - self.clauses = ClauseList(operator=kwargs.get('operator', None), group_contents=kwargs.get('group_contents', True), *clauses) + clauses = ClauseList(operator=kwargs.get('operator', None), group_contents=kwargs.get('group_contents', True), *clauses) if self.group: - self.clause_expr = self.clauses.self_group() + self.clause_expr = clauses.self_group() else: - self.clause_expr = self.clauses + self.clause_expr = clauses key = property(lambda self:self.name or "_calc_") - def copy_container(self): - clauses = [clause.copy_container() for clause in self.clauses] - return _CalculatedClause(type=self.type, bind=self._bind, *clauses) - + def _copy_internals(self): + self.clause_expr = self.clause_expr._clone() + + def clauses(self): + if isinstance(self.clause_expr, _Grouping): + return self.clause_expr.elem + else: + return self.clause_expr + clauses = property(clauses) + def get_children(self, **kwargs): return self.clause_expr, - def accept_visitor(self, visitor): - visitor.visit_calculatedclause(self) - def _get_from_objects(self): - return self.clauses._get_from_objects() + def _get_from_objects(self, **modifiers): + return self.clauses._get_from_objects(**modifiers) def _bind_param(self, obj): - return _BindParamClause(self.name, obj, type=self.type, unique=True) + return _BindParamClause(self.name, obj, type_=self.type, unique=True) def select(self): return select([self]) def scalar(self): - return select([self]).scalar() + return select([self]).execute().scalar() def execute(self): return select([self]).execute() @@ -2092,28 +2121,26 @@ class _Function(_CalculatedClause, FromClause): """ def __init__(self, name, *clauses, **kwargs): - self.type = sqltypes.to_instance(kwargs.get('type', None)) self.packagenames = kwargs.get('packagenames', None) or [] - kwargs['operator'] = ',' - self._engine = kwargs.get('engine', None) + kwargs['operator'] = ColumnOperators.comma_op _CalculatedClause.__init__(self, name, **kwargs) for c in clauses: self.append(c) key = property(lambda self:self.name) + def _copy_internals(self): + _CalculatedClause._copy_internals(self) + self._clone_from_clause() - def append(self, clause): - self.clauses.append(_literals_as_binds(clause, self.name)) - - def copy_container(self): - clauses = [clause.copy_container() for clause in self.clauses] - return _Function(self.name, type=self.type, packagenames=self.packagenames, bind=self._bind, *clauses) + def get_children(self, **kwargs): + return _CalculatedClause.get_children(self, **kwargs) - def accept_visitor(self, visitor): - visitor.visit_function(self) + def append(self, clause): + self.clauses.append(_literal_as_binds(clause, self.name)) class _Cast(ColumnElement): + def __init__(self, clause, totype, **kwargs): if not hasattr(clause, 'label'): clause = literal(clause) @@ -2122,17 +2149,19 @@ class _Cast(ColumnElement): self.typeclause = _TypeClause(self.type) self._distance = 0 + def _copy_internals(self): + self.clause = self.clause._clone() + self.typeclause = self.typeclause._clone() + def get_children(self, **kwargs): return self.clause, self.typeclause - def accept_visitor(self, visitor): - visitor.visit_cast(self) - def _get_from_objects(self): - return self.clause._get_from_objects() + def _get_from_objects(self, **modifiers): + return self.clause._get_from_objects(**modifiers) def _make_proxy(self, selectable, name=None): if name is not None: - co = _ColumnClause(name, selectable, type=self.type) + co = _ColumnClause(name, selectable, type_=self.type) co._distance = self._distance + 1 co.orig_set = self.orig_set selectable.columns[name]= co @@ -2142,26 +2171,23 @@ class _Cast(ColumnElement): class _UnaryExpression(ColumnElement): - def __init__(self, element, operator=None, modifier=None, type=None, negate=None): + def __init__(self, element, operator=None, modifier=None, type_=None, negate=None): self.operator = operator self.modifier = modifier - self.element = _literals_as_text(element).self_group(against=self.operator or self.modifier) - self.type = sqltypes.to_instance(type) + self.element = _literal_as_text(element).self_group(against=self.operator or self.modifier) + self.type = sqltypes.to_instance(type_) self.negate = negate - def copy_container(self): - return self.__class__(self.element.copy_container(), operator=self.operator, modifier=self.modifier, type=self.type, negate=self.negate) + def _get_from_objects(self, **modifiers): + return self.element._get_from_objects(**modifiers) - def _get_from_objects(self): - return self.element._get_from_objects() + def _copy_internals(self): + self.element = self.element._clone() def get_children(self, **kwargs): return self.element, - def accept_visitor(self, visitor): - visitor.visit_unary(self) - def compare(self, other): """Compare this ``_UnaryExpression`` against the given ``ClauseElement``.""" @@ -2170,14 +2196,15 @@ class _UnaryExpression(ColumnElement): self.modifier == other.modifier and self.element.compare(other.element) ) + def _negate(self): if self.negate is not None: - return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type=self.type) + return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type_=self.type) else: return super(_UnaryExpression, self)._negate() def self_group(self, against): - if self.operator and PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest']): + if self.operator and PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest]): return _Grouping(self) else: return self @@ -2186,25 +2213,23 @@ class _UnaryExpression(ColumnElement): class _BinaryExpression(ColumnElement): """Represent an expression that is ``LEFT <operator> RIGHT``.""" - def __init__(self, left, right, operator, type=None, negate=None): - self.left = _literals_as_text(left).self_group(against=operator) - self.right = _literals_as_text(right).self_group(against=operator) + def __init__(self, left, right, operator, type_=None, negate=None): + self.left = _literal_as_text(left).self_group(against=operator) + self.right = _literal_as_text(right).self_group(against=operator) self.operator = operator - self.type = sqltypes.to_instance(type) + self.type = sqltypes.to_instance(type_) self.negate = negate - def copy_container(self): - return self.__class__(self.left.copy_container(), self.right.copy_container(), self.operator) + def _get_from_objects(self, **modifiers): + return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers) - def _get_from_objects(self): - return self.left._get_from_objects() + self.right._get_from_objects() + def _copy_internals(self): + self.left = self.left._clone() + self.right = self.right._clone() def get_children(self, **kwargs): return self.left, self.right - def accept_visitor(self, visitor): - visitor.visit_binary(self) - def compare(self, other): """Compare this ``_BinaryExpression`` against the given ``_BinaryExpression``.""" @@ -2213,7 +2238,7 @@ class _BinaryExpression(ColumnElement): ( self.left.compare(other.left) and self.right.compare(other.right) or ( - self.operator in ['=', '!=', '+', '*'] and + self.operator in [operator.eq, operator.ne, operator.add, operator.mul] and self.left.compare(other.right) and self.right.compare(other.left) ) ) @@ -2221,25 +2246,27 @@ class _BinaryExpression(ColumnElement): def self_group(self, against=None): # use small/large defaults for comparison so that unknown operators are always parenthesized - if self.operator != against and (PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest'])): + if self.operator != against and (PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest])): return _Grouping(self) else: return self def _negate(self): if self.negate is not None: - return _BinaryExpression(self.left, self.right, self.negate, negate=self.operator, type=self.type) + return _BinaryExpression(self.left, self.right, self.negate, negate=self.operator, type_=self.type) else: return super(_BinaryExpression, self)._negate() class _Exists(_UnaryExpression): + __visit_name__ = _UnaryExpression.__visit_name__ + def __init__(self, *args, **kwargs): kwargs['correlate'] = True s = select(*args, **kwargs).self_group() - _UnaryExpression.__init__(self, s, operator="EXISTS") + _UnaryExpression.__init__(self, s, operator=Operators.exists) - def _hide_froms(self): - return self._get_from_objects() + def _hide_froms(self, **modifiers): + return self._get_from_objects(**modifiers) class Join(FromClause): """represent a ``JOIN`` construct between two ``FromClause`` @@ -2251,8 +2278,8 @@ class Join(FromClause): """ def __init__(self, left, right, onclause=None, isouter = False): - self.left = left._selectable() - self.right = right._selectable() + self.left = _selectable(left) + self.right = _selectable(right).self_group() if onclause is None: self.onclause = self._match_primaries(self.left, self.right) else: @@ -2265,8 +2292,8 @@ class Join(FromClause): encodedname = property(lambda s: s.name.encode('ascii', 'backslashreplace')) def _init_primary_key(self): - pkcol = util.Set([c for c in self._adjusted_exportable_columns() if c.primary_key]) - + pkcol = util.Set([c for c in self._flatten_exportable_columns() if c.primary_key]) + equivs = {} def add_equiv(a, b): for x, y in ((a, b), (b, a)): @@ -2277,7 +2304,7 @@ class Join(FromClause): class BinaryVisitor(ClauseVisitor): def visit_binary(self, binary): - if binary.operator == '=': + if binary.operator == operator.eq: add_equiv(binary.left, binary.right) BinaryVisitor().traverse(self.onclause) @@ -2294,9 +2321,12 @@ class Join(FromClause): omit.add(p) p = c - self.__primary_key = ColumnSet([c for c in self._adjusted_exportable_columns() if c.primary_key and c not in omit]) + self.__primary_key = ColumnSet([c for c in self._flatten_exportable_columns() if c.primary_key and c not in omit]) primary_key = property(lambda s:s.__primary_key) + + def self_group(self, against=None): + return _Grouping(self) def _locate_oid_column(self): return self.left.oid_column @@ -2310,6 +2340,17 @@ class Join(FromClause): self._foreign_keys.add(f) return column + def _copy_internals(self): + self._clone_from_clause() + self.left = self.left._clone() + self.right = self.right._clone() + self.onclause = self.onclause._clone() + self.__folded_equivalents = None + self._init_primary_key() + + def get_children(self, **kwargs): + return self.left, self.right, self.onclause + def _match_primaries(self, primary, secondary): crit = [] constraints = util.Set() @@ -2338,9 +2379,6 @@ class Join(FromClause): else: return and_(*crit) - def _group_parenthesized(self): - return True - def _get_folded_equivalents(self, equivs=None): if self.__folded_equivalents is not None: return self.__folded_equivalents @@ -2348,7 +2386,7 @@ class Join(FromClause): equivs = util.Set() class LocateEquivs(NoColumnVisitor): def visit_binary(self, binary): - if binary.operator == '=' and binary.left.name == binary.right.name: + if binary.operator == operator.eq and binary.left.name == binary.right.name: equivs.add(binary.right) equivs.add(binary.left) LocateEquivs().traverse(self.onclause) @@ -2401,13 +2439,7 @@ class Join(FromClause): return select(collist, whereclause, from_obj=[self], **kwargs) - def get_children(self, **kwargs): - return self.left, self.right, self.onclause - - def accept_visitor(self, visitor): - visitor.visit_join(self) - - engine = property(lambda s:s.left.engine or s.right.engine) + bind = property(lambda s:s.left.bind or s.right.bind) def alias(self, name=None): """Create a ``Select`` out of this ``Join`` clause and return an ``Alias`` of it. @@ -2417,11 +2449,11 @@ class Join(FromClause): return self.select(use_labels=True, correlate=False).alias(name) - def _hide_froms(self): - return self.left._get_from_objects() + self.right._get_from_objects() + def _hide_froms(self, **modifiers): + return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers) - def _get_from_objects(self): - return [self] + self.onclause._get_from_objects() + self.left._get_from_objects() + self.right._get_from_objects() + def _get_from_objects(self, **modifiers): + return [self] + self.onclause._get_from_objects(**modifiers) + self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers) class Alias(FromClause): """represent an alias, as typically applied to any @@ -2443,15 +2475,22 @@ class Alias(FromClause): if alias is None: if self.original.named_with_column(): alias = getattr(self.original, 'name', None) - if alias is None: - alias = 'anon' - elif len(alias) > 15: - alias = alias[0:15] - alias = alias + "_" + hex(random.randint(0, 65535))[2:] + alias = '{ANON %d %s}' % (id(self), alias or 'anon') self.name = alias self.encodedname = alias.encode('ascii', 'backslashreplace') self.case_sensitive = getattr(baseselectable, "case_sensitive", True) + def is_derived_from(self, fromclause): + x = self.selectable + while True: + if x is fromclause: + return True + if isinstance(x, Alias): + x = x.selectable + else: + break + return False + def supports_execution(self): return self.original.supports_execution() @@ -2468,46 +2507,57 @@ class Alias(FromClause): #return self.selectable._exportable_columns() return self.selectable.columns + def _copy_internals(self): + self._clone_from_clause() + self.selectable = self.selectable._clone() + baseselectable = self.selectable + while isinstance(baseselectable, Alias): + baseselectable = baseselectable.selectable + self.original = baseselectable + def get_children(self, **kwargs): for c in self.c: yield c yield self.selectable - def accept_visitor(self, visitor): - visitor.visit_alias(self) - def _get_from_objects(self): return [self] - def _group_parenthesized(self): - return False - bind = property(lambda s: s.selectable.bind) - engine = bind -class _Grouping(ColumnElement): +class _ColumnElementAdapter(ColumnElement): + """adapts a ClauseElement which may or may not be a + ColumnElement subclass itself into an object which + acts like a ColumnElement. + """ + def __init__(self, elem): self.elem = elem self.type = getattr(elem, 'type', None) - + self.orig_set = getattr(elem, 'orig_set', util.Set()) + key = property(lambda s: s.elem.key) _label = property(lambda s: s.elem._label) - orig_set = property(lambda s:s.elem.orig_set) - - def copy_container(self): - return _Grouping(self.elem.copy_container()) - - def accept_visitor(self, visitor): - visitor.visit_grouping(self) + columns = c = property(lambda s:s.elem.columns) + + def _copy_internals(self): + self.elem = self.elem._clone() + def get_children(self, **kwargs): return self.elem, - def _hide_froms(self): - return self.elem._hide_froms() - def _get_from_objects(self): - return self.elem._get_from_objects() + + def _hide_froms(self, **modifiers): + return self.elem._hide_froms(**modifiers) + + def _get_from_objects(self, **modifiers): + return self.elem._get_from_objects(**modifiers) + def __getattr__(self, attr): return getattr(self.elem, attr) - + +class _Grouping(_ColumnElementAdapter): + pass + class _Label(ColumnElement): """represent a label, as typically applied to any column-level element using the ``AS`` sql keyword. @@ -2518,32 +2568,33 @@ class _Label(ColumnElement): """ - def __init__(self, name, obj, type=None): - self.name = name + def __init__(self, name, obj, type_=None): while isinstance(obj, _Label): obj = obj.obj - self.obj = obj.self_group(against='AS') + self.name = name or "{ANON %d %s}" % (id(self), getattr(obj, 'name', 'anon')) + + self.obj = obj.self_group(against=Operators.as_) self.case_sensitive = getattr(obj, "case_sensitive", True) - self.type = sqltypes.to_instance(type or getattr(obj, 'type', None)) + self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None)) key = property(lambda s: s.name) _label = property(lambda s: s.name) orig_set = property(lambda s:s.obj.orig_set) - def _compare_self(self): + def expression_element(self): return self.obj - + + def _copy_internals(self): + self.obj = self.obj._clone() + def get_children(self, **kwargs): return self.obj, - def accept_visitor(self, visitor): - visitor.visit_label(self) + def _get_from_objects(self, **modifiers): + return self.obj._get_from_objects(**modifiers) - def _get_from_objects(self): - return self.obj._get_from_objects() - - def _hide_froms(self): - return self.obj._hide_froms() + def _hide_froms(self, **modifiers): + return self.obj._hide_froms(**modifiers) def _make_proxy(self, selectable, name = None): if isinstance(self.obj, Selectable): @@ -2551,8 +2602,6 @@ class _Label(ColumnElement): else: return column(self.name)._make_proxy(selectable=selectable) -legal_characters = util.Set(string.ascii_letters + string.digits + '_') - class _ColumnClause(ColumnElement): """Represents a generic column expression from any textual string. This includes columns associated with tables, aliases and select @@ -2584,17 +2633,21 @@ class _ColumnClause(ColumnElement): """ - def __init__(self, text, selectable=None, type=None, _is_oid=False, case_sensitive=True, is_literal=False): + def __init__(self, text, selectable=None, type_=None, _is_oid=False, case_sensitive=True, is_literal=False): self.key = self.name = text self.encodedname = isinstance(self.name, unicode) and self.name.encode('ascii', 'backslashreplace') or self.name self.table = selectable - self.type = sqltypes.to_instance(type) + self.type = sqltypes.to_instance(type_) self._is_oid = _is_oid self._distance = 0 self.__label = None self.case_sensitive = case_sensitive self.is_literal = is_literal - + + def _clone(self): + # ColumnClause is immutable + return self + def _get_label(self): """Generate a 'label' for this column. @@ -2617,7 +2670,6 @@ class _ColumnClause(ColumnElement): counter += 1 else: self.__label = self.name - self.__label = "".join([x for x in self.__label if x in legal_characters]) return self.__label is_labeled = property(lambda self:self.name != list(self.orig_set)[0].name) @@ -2632,23 +2684,20 @@ class _ColumnClause(ColumnElement): else: return super(_ColumnClause, self).label(name) - def accept_visitor(self, visitor): - visitor.visit_column(self) - - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): if self.table is not None: return [self.table] else: return [] def _bind_param(self, obj): - return _BindParamClause(self._label, obj, shortname = self.name, type=self.type, unique=True) + return _BindParamClause(self._label, obj, shortname=self.name, type_=self.type, unique=True) def _make_proxy(self, selectable, name = None): # propigate the "is_literal" flag only if we are keeping our name, # otherwise its considered to be a label is_literal = self.is_literal and (name is None or name == self.name) - c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type=self.type, is_literal=is_literal) + c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type_=self.type, is_literal=is_literal) c.orig_set = self.orig_set c._distance = self._distance + 1 if not self._is_oid: @@ -2658,9 +2707,6 @@ class _ColumnClause(ColumnElement): def _compare_type(self, obj): return self.type - def _group_parenthesized(self): - return False - class TableClause(FromClause): """represents a "table" construct. @@ -2677,6 +2723,10 @@ class TableClause(FromClause): self._oid_column = _ColumnClause('oid', self, _is_oid=True) self._export_columns(columns) + def _clone(self): + # TableClause is immutable + return self + def named_with_column(self): return True @@ -2709,15 +2759,9 @@ class TableClause(FromClause): else: return [] - def accept_visitor(self, visitor): - visitor.visit_table(self) - def _exportable_columns(self): raise NotImplementedError() - def _group_parenthesized(self): - return False - def count(self, whereclause=None, **params): if len(self.primary_key): col = list(self.primary_key)[0] @@ -2746,68 +2790,120 @@ class TableClause(FromClause): def delete(self, whereclause = None): return delete(self, whereclause) - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): return [self] + class _SelectBaseMixin(object): """Base class for ``Select`` and ``CompoundSelects``.""" + def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, bind=None): + self.use_labels = use_labels + self.for_update = for_update + self._limit = limit + self._offset = offset + self._bind = bind + + self.append_order_by(*util.to_list(order_by, [])) + self.append_group_by(*util.to_list(group_by, [])) + + def as_scalar(self): + return _ScalarSelect(self) + + def label(self, name): + return self.as_scalar().label(name) + def supports_execution(self): return True + def _generate(self): + s = self._clone() + s._clone_from_clause() + return s + + def limit(self, limit): + s = self._generate() + s._limit = limit + return s + + def offset(self, offset): + s = self._generate() + s._offset = offset + return s + def order_by(self, *clauses): - if len(clauses) == 1 and clauses[0] is None: - self.order_by_clause = ClauseList() - elif getattr(self, 'order_by_clause', None): - self.order_by_clause = ClauseList(*(list(self.order_by_clause.clauses) + list(clauses))) - else: - self.order_by_clause = ClauseList(*clauses) + s = self._generate() + s.append_order_by(*clauses) + return s def group_by(self, *clauses): - if len(clauses) == 1 and clauses[0] is None: - self.group_by_clause = ClauseList() - elif getattr(self, 'group_by_clause', None): - self.group_by_clause = ClauseList(*(list(clauses)+list(self.group_by_clause.clauses))) + s = self._generate() + s.append_group_by(*clauses) + return s + + def append_order_by(self, *clauses): + if clauses == [None]: + self._order_by_clause = ClauseList() else: - self.group_by_clause = ClauseList(*clauses) + if getattr(self, '_order_by_clause', None): + clauses = list(self._order_by_clause) + list(clauses) + self._order_by_clause = ClauseList(*clauses) + def append_group_by(self, *clauses): + if clauses == [None]: + self._group_by_clause = ClauseList() + else: + if getattr(self, '_group_by_clause', None): + clauses = list(self._group_by_clause) + list(clauses) + self._group_by_clause = ClauseList(*clauses) + def select(self, whereclauses = None, **params): return select([self], whereclauses, **params) - def _get_from_objects(self): - if self.is_where or self.is_scalar: + def _get_from_objects(self, is_where=False, **modifiers): + if is_where: return [] else: return [self] +class _ScalarSelect(_Grouping): + __visit_name__ = 'grouping' + + def __init__(self, elem): + super(_ScalarSelect, self).__init__(elem) + self.type = list(elem.inner_columns)[0].type + + columns = property(lambda self:[self]) + + def self_group(self, **kwargs): + return self + + def _make_proxy(self, selectable, name): + return list(self.inner_columns)[0]._make_proxy(selectable, name) + + def _get_from_objects(self, **modifiers): + return [] + class CompoundSelect(_SelectBaseMixin, FromClause): def __init__(self, keyword, *selects, **kwargs): - _SelectBaseMixin.__init__(self) + self._should_correlate = kwargs.pop('correlate', False) self.keyword = keyword - self.use_labels = kwargs.pop('use_labels', False) - self.should_correlate = kwargs.pop('correlate', False) - self.for_update = kwargs.pop('for_update', False) - self.nowait = kwargs.pop('nowait', False) - self.limit = kwargs.pop('limit', None) - self.offset = kwargs.pop('offset', None) - self.is_compound = True - self.is_where = False - self.is_scalar = False - self.is_subquery = False - - # unions group from left to right, so don't group first select - self.selects = [n and select.self_group(self) or select for n,select in enumerate(selects)] + self.selects = [] # some DBs do not like ORDER BY in the inner queries of a UNION, etc. - for s in selects: - s.order_by(None) + for n, s in enumerate(selects): + if len(s._order_by_clause): + s = s.order_by(None) + # unions group from left to right, so don't group first select + if n: + self.selects.append(s.self_group(self)) + else: + self.selects.append(s) - self.group_by(*kwargs.pop('group_by', [None])) - self.order_by(*kwargs.pop('order_by', [None])) - if len(kwargs): - raise TypeError("invalid keyword argument(s) for CompoundSelect: %s" % repr(kwargs.keys())) self._col_map = {} + _SelectBaseMixin.__init__(self, **kwargs) + name = property(lambda s:s.keyword + " statement") def self_group(self, against=None): @@ -2835,12 +2931,18 @@ class CompoundSelect(_SelectBaseMixin, FromClause): col.orig_set = colset return col + def _copy_internals(self): + self._clone_from_clause() + self._col_map = {} + self.selects = [s._clone() for s in self.selects] + for attr in ('_order_by_clause', '_group_by_clause'): + if getattr(self, attr) is not None: + setattr(self, attr, getattr(self, attr)._clone()) + def get_children(self, column_collections=True, **kwargs): return (column_collections and list(self.c) or []) + \ - [self.order_by_clause, self.group_by_clause] + list(self.selects) - def accept_visitor(self, visitor): - visitor.visit_compound_select(self) - + [self._order_by_clause, self._group_by_clause] + list(self.selects) + def _find_engine(self): for s in self.selects: e = s._find_engine() @@ -2855,160 +2957,287 @@ class Select(_SelectBaseMixin, FromClause): """ - def __init__(self, columns=None, whereclause=None, from_obj=[], - order_by=None, group_by=None, having=None, - use_labels=False, distinct=False, for_update=False, - engine=None, bind=None, limit=None, offset=None, scalar=False, - correlate=True): + def __init__(self, columns, whereclause=None, from_obj=None, distinct=False, having=None, correlate=True, prefixes=None, **kwargs): """construct a Select object. The public constructor for Select is the [sqlalchemy.sql#select()] function; see that function for argument descriptions. """ - _SelectBaseMixin.__init__(self) - self.__froms = util.OrderedSet() - self.__hide_froms = util.Set([self]) - self.use_labels = use_labels - self.whereclause = None - self.having = None - self._bind = bind or engine - self.limit = limit - self.offset = offset - self.for_update = for_update - self.is_compound = False - - # indicates that this select statement should not expand its columns - # into the column clause of an enclosing select, and should instead - # act like a single scalar column - self.is_scalar = scalar - if scalar: - # allow corresponding_column to return None - self.orig_set = util.Set() - - # indicates if this select statement, as a subquery, should automatically correlate - # its FROM clause to that of an enclosing select, update, or delete statement. - # note that the "correlate" method can be used to explicitly add a value to be correlated. - self.should_correlate = correlate - - # indicates if this select statement is a subquery inside another query - self.is_subquery = False - - # indicates if this select statement is in the from clause of another query - self.is_selected_from = False - - # indicates if this select statement is a subquery as a criterion - # inside of a WHERE clause - self.is_where = False + + self._should_correlate = correlate + self._distinct = distinct - self.distinct = distinct self._raw_columns = [] - self.__correlated = {} - self.__correlator = Select._CorrelatedVisitor(self, False) - self.__wherecorrelator = Select._CorrelatedVisitor(self, True) - self.__fromvisitor = Select._FromVisitor(self) - - - self.order_by_clause = self.group_by_clause = None + self.__correlate = util.Set() + self._froms = util.OrderedSet() + self._whereclause = None + self._having = None + self._prefixes = [] if columns is not None: for c in columns: self.append_column(c) - if order_by: - order_by = util.to_list(order_by) - if group_by: - group_by = util.to_list(group_by) - self.order_by(*(order_by or [None])) - self.group_by(*(group_by or [None])) - for c in self.order_by_clause: - self.__correlator.traverse(c) - for c in self.group_by_clause: - self.__correlator.traverse(c) - - for f in from_obj: - self.append_from(f) - - # whereclauses must be appended after the columns/FROM, since it affects - # the correlation of subqueries. see test/sql/select.py SelectTest.testwheresubquery + if from_obj is not None: + for f in from_obj: + self.append_from(f) + if whereclause is not None: self.append_whereclause(whereclause) + if having is not None: self.append_having(having) + _SelectBaseMixin.__init__(self, **kwargs) - class _CorrelatedVisitor(NoColumnVisitor): - """Visit a clause, locate any ``Select`` clauses, and tell - them that they should correlate their ``FROM`` list to that of - their parent. + def _get_display_froms(self, correlation_state=None): + """return the full list of 'from' clauses to be displayed. + + takes into account an optional 'correlation_state' + dictionary which contains information about this Select's + correlation to an enclosing select, which may cause some 'from' + clauses to not display in this Select's FROM clause. + this dictionary is generated during compile time by the + _calculate_correlations() method. + """ + froms = util.OrderedSet() + hide_froms = util.Set() + + for col in self._raw_columns: + for f in col._hide_froms(): + hide_froms.add(f) + for f in col._get_from_objects(): + froms.add(f) + + if self._whereclause is not None: + for f in self._whereclause._get_from_objects(is_where=True): + froms.add(f) + + for elem in self._froms: + froms.add(elem) + for f in elem._get_from_objects(): + froms.add(f) + + for elem in froms: + for f in elem._hide_froms(): + hide_froms.add(f) + + froms = froms.difference(hide_froms) + + if len(froms) > 1: + corr = self.__correlate + if correlation_state is not None: + corr = correlation_state[self].get('correlate', util.Set()).union(corr) + f = froms.difference(corr) + if len(f) == 0: + raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate())) + return f + else: + return froms + + froms = property(_get_display_froms, doc="""Return a list of all FromClause elements which will be applied to the FROM clause of the resulting statement.""") + + def locate_all_froms(self): + froms = util.Set() + for col in self._raw_columns: + for f in col._get_from_objects(): + froms.add(f) + + if self._whereclause is not None: + for f in self._whereclause._get_from_objects(is_where=True): + froms.add(f) + + for elem in self._froms: + froms.add(elem) + for f in elem._get_from_objects(): + froms.add(f) + return froms + + def _calculate_correlations(self, correlation_state): + """generate a 'correlation_state' dictionary used by the _get_display_froms() method. + + The dictionary is passed in initially empty, or already + containing the state information added by an enclosing + Select construct. The method will traverse through all + embedded Select statements and add information about their + position and "from" objects to the dictionary. Those Select + statements will later consult the 'correlation_state' dictionary + when their list of 'FROM' clauses are generated using their + _get_display_froms() method. + """ + + if self not in correlation_state: + correlation_state[self] = {} - def __init__(self, select, is_where): - NoColumnVisitor.__init__(self) - self.select = select - self.is_where = is_where - - def visit_compound_select(self, cs): - self.visit_select(cs) - - def visit_column(self, c): - pass - - def visit_table(self, c): - pass - - def visit_select(self, select): - if select is self.select: - return - select.is_where = self.is_where - select.is_subquery = True - if not select.should_correlate: - return - [select.correlate(x) for x in self.select._Select__froms] + display_froms = self._get_display_froms(correlation_state) + + class CorrelatedVisitor(NoColumnVisitor): + def __init__(self, is_where=False, is_column=False, is_from=False): + self.is_where = is_where + self.is_column = is_column + self.is_from = is_from + + def visit_compound_select(self, cs): + self.visit_select(cs) - class _FromVisitor(NoColumnVisitor): - def __init__(self, select): - NoColumnVisitor.__init__(self) - self.select = select + def visit_select(s, select): + if select not in correlation_state: + correlation_state[select] = {} + + if select is self: + return + + select_state = correlation_state[select] + if s.is_from: + select_state['is_selected_from'] = True + if s.is_where: + select_state['is_where'] = True + select_state['is_subquery'] = True + + if select._should_correlate: + corr = select_state.setdefault('correlate', util.Set()) + # not crazy about this part. need to be clearer on what elements in the + # subquery correspond to elements in the enclosing query. + for f in display_froms: + corr.add(f) + for f2 in f._get_from_objects(): + corr.add(f2) + + col_vis = CorrelatedVisitor(is_column=True) + where_vis = CorrelatedVisitor(is_where=True) + from_vis = CorrelatedVisitor(is_from=True) + + for col in self._raw_columns: + col_vis.traverse(col) + for f in col._get_from_objects(): + if f is not self: + from_vis.traverse(f) + + for col in list(self._order_by_clause) + list(self._group_by_clause): + col_vis.traverse(col) - def visit_select(self, select): - if select is self.select: - return - select.is_selected_from = True - select.is_subquery = True + if self._whereclause is not None: + where_vis.traverse(self._whereclause) + for f in self._whereclause._get_from_objects(is_where=True): + if f is not self: + from_vis.traverse(f) + + for elem in self._froms: + from_vis.traverse(elem) + + def _get_inner_columns(self): + for c in self._raw_columns: + if isinstance(c, Selectable): + for co in c.columns: + yield co + else: + yield c + + inner_columns = property(_get_inner_columns) + + def _copy_internals(self): + self._clone_from_clause() + self._raw_columns = [c._clone() for c in self._raw_columns] + self._recorrelate_froms([(f, f._clone()) for f in self._froms]) + for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'): + if getattr(self, attr) is not None: + setattr(self, attr, getattr(self, attr)._clone()) + def get_children(self, column_collections=True, **kwargs): + return (column_collections and list(self.columns) or []) + \ + list(self._froms) + \ + [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None] + + def _recorrelate_froms(self, froms): + newcorrelate = util.Set() + newfroms = util.Set() + oldfroms = util.Set(self._froms) + for old, new in froms: + if old in self.__correlate: + newcorrelate.add(new) + self.__correlate.remove(old) + if old in oldfroms: + newfroms.add(new) + oldfroms.remove(old) + self.__correlate = self.__correlate.union(newcorrelate) + self._froms = [f for f in oldfroms.union(newfroms)] + + def column(self, column): + s = self._generate() + s.append_column(column) + return s + + def where(self, whereclause): + s = self._generate() + s.append_whereclause(whereclause) + return s + + def having(self, having): + s = self._generate() + s.append_having(having) + return s + + def distinct(self): + s = self._generate() + s.distinct = True + return s + + def prefix_with(self, clause): + s = self._generate() + s.append_prefix(clause) + return s + + def select_from(self, fromclause): + s = self._generate() + s.append_from(fromclause) + return s + + def __dont_correlate(self): + s = self._generate() + s._should_correlate = False + return s + + def correlate(self, fromclause): + s = self._generate() + s._should_correlate=False + if fromclause is None: + s.__correlate = util.Set() + else: + s.append_correlation(fromclause) + return s + + def append_correlation(self, fromclause): + self.__correlate.add(fromclause) + def append_column(self, column): - if _is_literal(column): - column = literal_column(str(column)) + column = _literal_as_column(column) - if isinstance(column, Select) and column.is_scalar: - column = column.self_group(against=',') + if isinstance(column, _ScalarSelect): + column = column.self_group(against=ColumnOperators.comma_op) self._raw_columns.append(column) - - if self.is_scalar and not hasattr(self, 'type'): - self.type = column.type + + def append_prefix(self, clause): + clause = _literal_as_text(clause) + self._prefixes.append(clause) - # if the column is a Select statement itself, - # accept visitor - self.__correlator.traverse(column) - - # visit the FROM objects of the column looking for more Selects - for f in column._get_from_objects(): - if f is not self: - self.__correlator.traverse(f) - self._process_froms(column, False) - - def _make_proxy(self, selectable, name): - if self.is_scalar: - return self._raw_columns[0]._make_proxy(selectable, name) + def append_whereclause(self, whereclause): + if self._whereclause is not None: + self._whereclause = and_(self._whereclause, _literal_as_text(whereclause)) else: - raise exceptions.InvalidRequestError("Not a scalar select statement") - - def label(self, name): - if not self.is_scalar: - raise exceptions.InvalidRequestError("Not a scalar select statement") + self._whereclause = _literal_as_text(whereclause) + + def append_having(self, having): + if self._having is not None: + self._having = and_(self._having, _literal_as_text(having)) else: - return label(name, self) + self._having = _literal_as_text(having) + + def append_from(self, fromclause): + if _is_literal(fromclause): + fromclause = FromClause(fromclause) + self._froms.add(fromclause) def _exportable_columns(self): return [c for c in self._raw_columns if isinstance(c, Selectable)] @@ -3019,53 +3248,13 @@ class Select(_SelectBaseMixin, FromClause): else: return column._make_proxy(self) - def _process_froms(self, elem, asfrom): - for f in elem._get_from_objects(): - self.__fromvisitor.traverse(f) - self.__froms.add(f) - if asfrom: - self.__froms.add(elem) - for f in elem._hide_froms(): - self.__hide_froms.add(f) - def self_group(self, against=None): if isinstance(against, CompoundSelect): return self return _Grouping(self) - - def append_whereclause(self, whereclause): - self._append_condition('whereclause', whereclause) - - def append_having(self, having): - self._append_condition('having', having) - - def _append_condition(self, attribute, condition): - if isinstance(condition, basestring): - condition = _TextClause(condition) - self.__wherecorrelator.traverse(condition) - self._process_froms(condition, False) - if getattr(self, attribute) is not None: - setattr(self, attribute, and_(getattr(self, attribute), condition)) - else: - setattr(self, attribute, condition) - - def correlate(self, from_obj): - """Given a ``FROM`` object, correlate this ``SELECT`` statement to it. - - This basically means the given from object will not come out - in this select statement's ``FROM`` clause when printed. - """ - - self.__correlated[from_obj] = from_obj - - def append_from(self, fromclause): - if isinstance(fromclause, basestring): - fromclause = FromClause(fromclause) - self.__correlator.traverse(fromclause) - self._process_froms(fromclause, True) def _locate_oid_column(self): - for f in self.__froms: + for f in self.locate_all_froms(): if f is self: # we might be in our own _froms list if a column with us as the parent is attached, # which includes textual columns. @@ -3076,25 +3265,6 @@ class Select(_SelectBaseMixin, FromClause): else: return None - def _calc_froms(self): - f = self.__froms.difference(self.__hide_froms) - if (len(f) > 1): - return f.difference(self.__correlated) - else: - return f - - froms = property(_calc_froms, - doc="""A collection containing all elements - of the ``FROM`` clause.""") - - def get_children(self, column_collections=True, **kwargs): - return (column_collections and list(self.columns) or []) + \ - list(self.froms) + \ - [x for x in (self.whereclause, self.having, self.order_by_clause, self.group_by_clause) if x is not None] - - def accept_visitor(self, visitor): - visitor.visit_select(self) - def union(self, other, **kwargs): return union(self, other, **kwargs) @@ -3108,7 +3278,7 @@ class Select(_SelectBaseMixin, FromClause): if self._bind is not None: return self._bind - for f in self.__froms: + for f in self._froms: if f is self: continue e = f.bind @@ -3133,20 +3303,24 @@ class _UpdateBase(ClauseElement): def supports_execution(self): return True - class _SelectCorrelator(NoColumnVisitor): - def __init__(self, table): - NoColumnVisitor.__init__(self) - self.table = table - - def visit_select(self, select): - if select.should_correlate: - select.correlate(self.table) - - def _process_whereclause(self, whereclause): - if whereclause is not None: - _UpdateBase._SelectCorrelator(self.table).traverse(whereclause) - return whereclause + def _calculate_correlations(self, correlate_state): + class SelectCorrelator(NoColumnVisitor): + def visit_select(s, select): + if select._should_correlate: + select_state = correlate_state.setdefault(select, {}) + corr = select_state.setdefault('correlate', util.Set()) + corr.add(self.table) + + vis = SelectCorrelator() + if self._whereclause is not None: + vis.traverse(self._whereclause) + + if getattr(self, 'parameters', None) is not None: + for key, value in self.parameters.items(): + if isinstance(value, ClauseElement): + vis.traverse(value) + def _process_colparams(self, parameters): """Receive the *values* of an ``INSERT`` or ``UPDATE`` statement and construct appropriate bind parameters. @@ -3155,7 +3329,7 @@ class _UpdateBase(ClauseElement): if parameters is None: return None - if isinstance(parameters, list) or isinstance(parameters, tuple): + if isinstance(parameters, (list, tuple)): pp = {} i = 0 for c in self.table.c: @@ -3163,11 +3337,10 @@ class _UpdateBase(ClauseElement): i +=1 parameters = pp - correlator = _UpdateBase._SelectCorrelator(self.table) for key in parameters.keys(): value = parameters[key] if isinstance(value, ClauseElement): - correlator.traverse(value) + parameters[key] = value.self_group() elif _is_literal(value): if _is_literal(key): col = self.table.c[key] @@ -3182,7 +3355,7 @@ class _UpdateBase(ClauseElement): def _find_engine(self): return self.table.bind -class _Insert(_UpdateBase): +class Insert(_UpdateBase): def __init__(self, table, values=None): self.table = table self.select = None @@ -3193,32 +3366,41 @@ class _Insert(_UpdateBase): return self.select, else: return () - def accept_visitor(self, visitor): - visitor.visit_insert(self) -class _Update(_UpdateBase): +class Update(_UpdateBase): def __init__(self, table, whereclause, values=None): self.table = table - self.whereclause = self._process_whereclause(whereclause) + self._whereclause = whereclause self.parameters = self._process_colparams(values) def get_children(self, **kwargs): - if self.whereclause is not None: - return self.whereclause, + if self._whereclause is not None: + return self._whereclause, else: return () - def accept_visitor(self, visitor): - visitor.visit_update(self) -class _Delete(_UpdateBase): +class Delete(_UpdateBase): def __init__(self, table, whereclause): self.table = table - self.whereclause = self._process_whereclause(whereclause) + self._whereclause = whereclause def get_children(self, **kwargs): - if self.whereclause is not None: - return self.whereclause, + if self._whereclause is not None: + return self._whereclause, else: return () - def accept_visitor(self, visitor): - visitor.visit_delete(self) + +class _IdentifiedClause(ClauseElement): + def __init__(self, ident): + self.ident = ident + def supports_execution(self): + return True + +class SavepointClause(_IdentifiedClause): + pass + +class RollbackToSavepointClause(_IdentifiedClause): + pass + +class ReleaseSavepointClause(_IdentifiedClause): + pass diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql_util.py index 9235b9c4e..d91fbe4b5 100644 --- a/lib/sqlalchemy/sql_util.py +++ b/lib/sqlalchemy/sql_util.py @@ -53,7 +53,7 @@ class TableCollection(object): for table in self.tables: vis.traverse(table) sorter = topological.QueueDependencySorter( tuples, self.tables ) - head = sorter.sort() + head = sorter.sort() sequence = [] def to_sequence( node, seq=sequence): seq.append( node.item ) @@ -67,12 +67,12 @@ class TableCollection(object): class TableFinder(TableCollection, sql.NoColumnVisitor): """locate all Tables within a clause.""" - def __init__(self, table, check_columns=False, include_aliases=False): + def __init__(self, clause, check_columns=False, include_aliases=False): TableCollection.__init__(self) self.check_columns = check_columns self.include_aliases = include_aliases - if table is not None: - self.traverse(table) + for clause in util.to_list(clause): + self.traverse(clause) def visit_alias(self, alias): if self.include_aliases: @@ -83,7 +83,7 @@ class TableFinder(TableCollection, sql.NoColumnVisitor): def visit_column(self, column): if self.check_columns: - self.traverse(column.table) + self.tables.append(column.table) class ColumnFinder(sql.ClauseVisitor): def __init__(self): @@ -125,7 +125,7 @@ class AbstractClauseProcessor(sql.NoColumnVisitor): process the new list. """ - list_ = [o.copy_container() for o in list_] + list_ = list(list_) self.process_list(list_) return list_ @@ -137,7 +137,7 @@ class AbstractClauseProcessor(sql.NoColumnVisitor): if elem is not None: list_[i] = elem else: - self.traverse(list_[i]) + list_[i] = self.traverse(list_[i], clone=True) def visit_grouping(self, grouping): elem = self.convert_element(grouping.elem) @@ -162,8 +162,24 @@ class AbstractClauseProcessor(sql.NoColumnVisitor): elem = self.convert_element(binary.right) if elem is not None: binary.right = elem - - # TODO: visit_select(). + + def visit_select(self, select): + fr = util.OrderedSet() + for elem in select._froms: + n = self.convert_element(elem) + if n is not None: + fr.add((elem, n)) + select._recorrelate_froms(fr) + + col = [] + for elem in select._raw_columns: + print "RAW COLUMN", elem + n = self.convert_element(elem) + if n is None: + col.append(elem) + else: + col.append(n) + select._raw_columns = col class ClauseAdapter(AbstractClauseProcessor): """Given a clause (like as in a WHERE criterion), locate columns @@ -200,6 +216,9 @@ class ClauseAdapter(AbstractClauseProcessor): self.equivalents = equivalents def convert_element(self, col): + if isinstance(col, sql.FromClause): + if self.selectable.is_derived_from(col): + return self.selectable if not isinstance(col, sql.ColumnElement): return None if self.include is not None: @@ -214,4 +233,9 @@ class ClauseAdapter(AbstractClauseProcessor): newcol = self.selectable.corresponding_column(equiv, raiseerr=False, require_embedded=True, keys_ok=False) if newcol: return newcol + #if newcol is None: + # self.traverse(col) + # return col return newcol + + diff --git a/lib/sqlalchemy/topological.py b/lib/sqlalchemy/topological.py index bad71293e..56c8cb46e 100644 --- a/lib/sqlalchemy/topological.py +++ b/lib/sqlalchemy/topological.py @@ -42,7 +42,6 @@ nature - very tricky to reproduce and track down, particularly before I realized this characteristic of the algorithm. """ -import string, StringIO from sqlalchemy import util from sqlalchemy.exceptions import CircularDependencyError @@ -68,7 +67,7 @@ class _Node(object): str(self.item) + \ (self.cycles is not None and (" (cycles: " + repr([x for x in self.cycles]) + ")") or "") + \ "\n" + \ - string.join([n.safestr(indent + 1) for n in self.children], '') + ''.join([n.safestr(indent + 1) for n in self.children]) def __repr__(self): return "%s" % (str(self.item)) diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 3cceedae6..ec1459852 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -7,34 +7,29 @@ __all__ = [ 'TypeEngine', 'TypeDecorator', 'NullTypeEngine', 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'TEXT', 'FLOAT', 'DECIMAL', 'TIMESTAMP', 'DATETIME', 'CLOB', 'BLOB', 'BOOLEAN', 'String', 'Integer', 'SmallInteger','Smallinteger', - 'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'Binary', 'Boolean', 'Unicode', 'PickleType', 'NULLTYPE', + 'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'Binary', 'Boolean', 'Unicode', 'PickleType', 'NULLTYPE', 'NullType', 'SMALLINT', 'DATE', 'TIME','Interval' ] -from sqlalchemy import util, exceptions -import inspect, weakref +import inspect import datetime as dt +from decimal import Decimal try: import cPickle as pickle except: import pickle -_impl_cache = weakref.WeakKeyDictionary() +from sqlalchemy import exceptions class AbstractType(object): - def _get_impl_dict(self): - try: - return _impl_cache[self] - except KeyError: - return _impl_cache.setdefault(self, {}) - - impl_dict = property(_get_impl_dict) - + def __init__(self, *args, **kwargs): + pass + def copy_value(self, value): return value def compare_values(self, x, y): - return x is y + return x == y def is_mutable(self): return False @@ -51,15 +46,20 @@ class AbstractType(object): return "%s(%s)" % (self.__class__.__name__, ",".join(["%s=%s" % (k, getattr(self, k)) for k in inspect.getargspec(self.__init__)[0][1:]])) class TypeEngine(AbstractType): - def __init__(self, *args, **params): - pass - def dialect_impl(self, dialect): try: - return self.impl_dict[dialect] + return self._impl_dict[dialect] + except AttributeError: + self._impl_dict = {} + return self._impl_dict.setdefault(dialect, dialect.type_descriptor(self)) except KeyError: - return self.impl_dict.setdefault(dialect, dialect.type_descriptor(self)) - + return self._impl_dict.setdefault(dialect, dialect.type_descriptor(self)) + + def __getstate__(self): + d = self.__dict__.copy() + d['_impl_dict'] = {} + return d + def get_col_spec(self): raise NotImplementedError() @@ -88,15 +88,19 @@ class TypeDecorator(AbstractType): def dialect_impl(self, dialect): try: - return self.impl_dict[dialect] - except: - typedesc = self.load_dialect_impl(dialect) - tt = self.copy() - if not isinstance(tt, self.__class__): - raise exceptions.AssertionError("Type object %s does not properly implement the copy() method, it must return an object of type %s" % (self, self.__class__)) - tt.impl = typedesc - self.impl_dict[dialect] = tt - return tt + return self._impl_dict[dialect] + except AttributeError: + self._impl_dict = {} + except KeyError: + pass + + typedesc = self.load_dialect_impl(dialect) + tt = self.copy() + if not isinstance(tt, self.__class__): + raise exceptions.AssertionError("Type object %s does not properly implement the copy() method, it must return an object of type %s" % (self, self.__class__)) + tt.impl = typedesc + self._impl_dict[dialect] = tt + return tt def load_dialect_impl(self, dialect): """loads the dialect-specific implementation of this type. @@ -179,7 +183,7 @@ def adapt_type(typeobj, colspecs): return typeobj return typeobj.adapt(impltype) -class NullTypeEngine(TypeEngine): +class NullType(TypeEngine): def get_col_spec(self): raise NotImplementedError() @@ -188,8 +192,13 @@ class NullTypeEngine(TypeEngine): def convert_result_value(self, value, dialect): return value +NullTypeEngine = NullType -class String(TypeEngine): +class Concatenable(object): + """marks a type as supporting 'concatenation'""" + pass + +class String(TypeEngine, Concatenable): def __init__(self, length=None, convert_unicode=False): self.length = length self.convert_unicode = convert_unicode @@ -219,9 +228,6 @@ class String(TypeEngine): def get_dbapi_type(self, dbapi): return dbapi.STRING - def compare_values(self, x, y): - return x == y - class Unicode(String): def __init__(self, length=None, **kwargs): kwargs['convert_unicode'] = True @@ -241,22 +247,36 @@ class SmallInteger(Integer): Smallinteger = SmallInteger class Numeric(TypeEngine): - def __init__(self, precision = 10, length = 2): + def __init__(self, precision = 10, length = 2, asdecimal=True): self.precision = precision self.length = length + self.asdecimal = asdecimal def adapt(self, impltype): - return impltype(precision=self.precision, length=self.length) + return impltype(precision=self.precision, length=self.length, asdecimal=self.asdecimal) def get_dbapi_type(self, dbapi): return dbapi.NUMBER + def convert_bind_param(self, value, dialect): + if value is not None: + return float(value) + else: + return value + + def convert_result_value(self, value, dialect): + if value is not None and self.asdecimal: + return Decimal(str(value)) + else: + return value + class Float(Numeric): - def __init__(self, precision = 10): + def __init__(self, precision = 10, asdecimal=False, **kwargs): self.precision = precision - + self.asdecimal = asdecimal + def adapt(self, impltype): - return impltype(precision=self.precision) + return impltype(precision=self.precision, asdecimal=self.asdecimal) class DateTime(TypeEngine): """Implement a type for ``datetime.datetime()`` objects.""" @@ -416,4 +436,4 @@ class NCHAR(Unicode):pass class BLOB(Binary): pass class BOOLEAN(Boolean): pass -NULLTYPE = NullTypeEngine() +NULLTYPE = NullType() diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index b47822d61..e711de3a3 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -14,7 +14,6 @@ from sqlalchemy import exceptions import md5 import sys import warnings - import __builtin__ try: @@ -33,10 +32,35 @@ except: i -= 1 raise StopIteration() -def to_list(x): +if sys.version_info >= (2, 5): + class PopulateDict(dict): + """a dict which populates missing values via a creation function. + + note the creation function takes a key, unlike collections.defaultdict. + """ + + def __init__(self, creator): + self.creator = creator + def __missing__(self, key): + self[key] = val = self.creator(key) + return val +else: + class PopulateDict(dict): + """a dict which populates missing values via a creation function.""" + + def __init__(self, creator): + self.creator = creator + def __getitem__(self, key): + try: + return dict.__getitem__(self, key) + except KeyError: + self[key] = value = self.creator(key) + return value + +def to_list(x, default=None): if x is None: - return None - if not isinstance(x, list) and not isinstance(x, tuple): + return default + if not isinstance(x, (list, tuple)): return [x] else: return x @@ -113,19 +137,25 @@ def coerce_kw_type(kw, key, type_, flexi_bool=True): else: kw[key] = type_(kw[key]) -def duck_type_collection(col, default=None): +def duck_type_collection(specimen, default=None): """Given an instance or class, guess if it is or is acting as one of the basic collection types: list, set and dict. If the __emulates__ property is present, return that preferentially. """ - if hasattr(col, '__emulates__'): - return getattr(col, '__emulates__') - elif hasattr(col, 'append'): + if hasattr(specimen, '__emulates__'): + return specimen.__emulates__ + + isa = isinstance(specimen, type) and issubclass or isinstance + if isa(specimen, list): return list + if isa(specimen, Set): return Set + if isa(specimen, dict): return dict + + if hasattr(specimen, 'append'): return list - elif hasattr(col, 'add'): + elif hasattr(specimen, 'add'): return Set - elif hasattr(col, 'set'): + elif hasattr(specimen, 'set'): return dict else: return default @@ -138,11 +168,11 @@ def assert_arg_type(arg, argtype, name): raise exceptions.ArgumentError("Argument '%s' is expected to be one of type %s, got '%s'" % (name, ' or '.join(["'%s'" % str(a) for a in argtype]), str(type(arg)))) else: raise exceptions.ArgumentError("Argument '%s' is expected to be of type '%s', got '%s'" % (name, str(argtype), str(type(arg)))) - -def warn_exception(func): + +def warn_exception(func, *args, **kwargs): """executes the given function, catches all exceptions and converts to a warning.""" try: - return func() + return func(*args, **kwargs) except: warnings.warn(RuntimeWarning("%s('%s') ignored" % sys.exc_info()[0:2])) @@ -246,12 +276,12 @@ class OrderedProperties(object): class OrderedDict(dict): """A Dictionary that returns keys/values/items in the order they were added.""" - def __init__(self, d=None, **kwargs): + def __init__(self, ____sequence=None, **kwargs): self._list = [] - if d is None: + if ____sequence is None: self.update(**kwargs) else: - self.update(d, **kwargs) + self.update(____sequence, **kwargs) def clear(self): self._list = [] @@ -347,7 +377,13 @@ class DictDecorator(dict): return dict.__getitem__(self, key) except KeyError: return self.decorate[key] - + + def __contains__(self, key): + return dict.__contains__(self, key) or key in self.decorate + + def has_key(self, key): + return key in self + def __repr__(self): return dict.__repr__(self) + repr(self.decorate) @@ -442,19 +478,28 @@ class OrderedSet(Set): __isub__ = difference_update class UniqueAppender(object): - def __init__(self, data): + """appends items to a collection such that only unique items + are added.""" + + def __init__(self, data, via=None): self.data = data - if hasattr(data, 'append'): + self._unique = Set() + if via: + self._data_appender = getattr(data, via) + elif hasattr(data, 'append'): self._data_appender = data.append elif hasattr(data, 'add'): + # TODO: we think its a set here. bypass unneeded uniquing logic ? self._data_appender = data.add - self.set = Set() - + def append(self, item): - if item not in self.set: - self.set.add(item) + if item not in self._unique: self._data_appender(item) - + self._unique.add(item) + + def __iter__(self): + return iter(self.data) + class ScopedRegistry(object): """A Registry that can store one or multiple instances of a single class on a per-thread scoped basis, or on a customized scope. |
