diff options
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 410 |
1 files changed, 41 insertions, 369 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 6796d7edb..50ce30aaf 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -1,48 +1,32 @@ # sql/util.py -# Copyright (C) 2005-2013 the SQLAlchemy authors and contributors <see AUTHORS file> +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file> # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from .. import exc, schema, util, sql -from ..util import topological -from . import expression, operators, visitors +"""High level utilities which build upon other modules here. + +""" + +from .. import exc, util +from .base import _from_objects, ColumnSet +from . import operators, visitors from itertools import chain from collections import deque -"""Utility functions that build upon SQL and Schema constructs.""" +from .elements import BindParameter, ColumnClause, ColumnElement, \ + Null, UnaryExpression, literal_column, Label +from .selectable import ScalarSelect, Join, FromClause, FromGrouping +from .schema import Column +join_condition = util.langhelpers.public_factory( + Join._join_condition, + ".sql.util.join_condition") -def sort_tables(tables, skip_fn=None, extra_dependencies=None): - """sort a collection of Table objects in order of - their foreign-key dependency.""" - - tables = list(tables) - tuples = [] - if extra_dependencies is not None: - tuples.extend(extra_dependencies) - - def visit_foreign_key(fkey): - if fkey.use_alter: - return - elif skip_fn and skip_fn(fkey): - return - parent_table = fkey.column.table - if parent_table in tables: - child_table = fkey.parent.table - if parent_table is not child_table: - tuples.append((parent_table, child_table)) - - for table in tables: - visitors.traverse(table, - {'schema_visitor': True}, - {'foreign_key': visit_foreign_key}) - - tuples.extend( - [parent, table] for parent in table._extra_dependencies - ) - - return list(topological.sort(tuples, tables)) +# names that are still being imported from the outside +from .annotation import _shallow_annotate, _deep_annotate, _deep_deannotate +from .elements import _find_columns +from .ddl import sort_tables def find_join_source(clauses, join_to): @@ -62,7 +46,7 @@ def find_join_source(clauses, join_to): """ - selectables = list(expression._from_objects(join_to)) + selectables = list(_from_objects(join_to)) for i, f in enumerate(clauses): for s in selectables: if f.is_derived_from(s): @@ -109,7 +93,7 @@ def visit_binary_product(fn, expr): stack = [] def visit(element): - if isinstance(element, (expression.ScalarSelect)): + if isinstance(element, ScalarSelect): # we dont want to dig into correlated subqueries, # those are just column elements by themselves yield element @@ -123,7 +107,7 @@ def visit_binary_product(fn, expr): for elem in element.get_children(): visit(elem) else: - if isinstance(element, expression.ColumnClause): + if isinstance(element, ColumnClause): yield element for elem in element.get_children(): for e in visit(elem): @@ -163,13 +147,6 @@ def find_tables(clause, check_columns=False, return tables -def find_columns(clause): - """locate Column objects within the given expression.""" - - cols = util.column_set() - visitors.traverse(clause, {}, {'column': cols.add}) - return cols - def unwrap_order_by(clause): """Break up an 'order by' expression into individual column-expressions, @@ -179,9 +156,9 @@ def unwrap_order_by(clause): stack = deque([clause]) while stack: t = stack.popleft() - if isinstance(t, expression.ColumnElement) and \ + if isinstance(t, ColumnElement) and \ ( - not isinstance(t, expression.UnaryExpression) or \ + not isinstance(t, UnaryExpression) or \ not operators.is_ordering_modifier(t.modifier) ): cols.add(t) @@ -211,9 +188,9 @@ def surface_selectables(clause): while stack: elem = stack.pop() yield elem - if isinstance(elem, expression.Join): + if isinstance(elem, Join): stack.extend((elem.left, elem.right)) - elif isinstance(elem, expression.FromGrouping): + elif isinstance(elem, FromGrouping): stack.append(elem.element) def selectables_overlap(left, right): @@ -277,27 +254,6 @@ class _repr_params(object): return repr(self.params) -def expression_as_ddl(clause): - """Given a SQL expression, convert for usage in DDL, such as - CREATE INDEX and CHECK CONSTRAINT. - - Converts bind params into quoted literals, column identifiers - into detached column constructs so that the parent table - identifier is not included. - - """ - def repl(element): - if isinstance(element, expression.BindParameter): - return expression.literal_column(_quote_ddl_expr(element.value)) - elif isinstance(element, expression.ColumnClause) and \ - element.table is not None: - col = expression.column(element.name) - col.quote = element.quote - return col - else: - return None - - return visitors.replacement_traverse(clause, {}, repl) def adapt_criterion_to_null(crit, nulls): @@ -307,308 +263,22 @@ def adapt_criterion_to_null(crit, nulls): """ def visit_binary(binary): - if isinstance(binary.left, expression.BindParameter) \ + if isinstance(binary.left, BindParameter) \ and binary.left._identifying_key in nulls: # reverse order if the NULL is on the left side binary.left = binary.right - binary.right = expression.null() + binary.right = Null() binary.operator = operators.is_ binary.negate = operators.isnot - elif isinstance(binary.right, expression.BindParameter) \ + elif isinstance(binary.right, BindParameter) \ and binary.right._identifying_key in nulls: - binary.right = expression.null() + binary.right = Null() binary.operator = operators.is_ binary.negate = operators.isnot return visitors.cloned_traverse(crit, {}, {'binary': visit_binary}) -def join_condition(a, b, ignore_nonexistent_tables=False, - a_subset=None, - consider_as_foreign_keys=None): - """create a join condition between two tables or selectables. - - e.g.:: - - join_condition(tablea, tableb) - - would produce an expression along the lines of:: - - tablea.c.id==tableb.c.tablea_id - - The join is determined based on the foreign key relationships - between the two selectables. If there are multiple ways - to join, or no way to join, an error is raised. - - :param ignore_nonexistent_tables: Deprecated - this - flag is no longer used. Only resolution errors regarding - the two given tables are propagated. - - :param a_subset: An optional expression that is a sub-component - of ``a``. An attempt will be made to join to just this sub-component - first before looking at the full ``a`` construct, and if found - will be successful even if there are other ways to join to ``a``. - This allows the "right side" of a join to be passed thereby - providing a "natural join". - - """ - crit = [] - constraints = set() - - for left in (a_subset, a): - if left is None: - continue - for fk in sorted( - b.foreign_keys, - key=lambda fk: fk.parent._creation_order): - if consider_as_foreign_keys is not None and \ - fk.parent not in consider_as_foreign_keys: - continue - try: - col = fk.get_referent(left) - except exc.NoReferenceError as nrte: - if nrte.table_name == left.name: - raise - else: - continue - - if col is not None: - crit.append(col == fk.parent) - constraints.add(fk.constraint) - if left is not b: - for fk in sorted( - left.foreign_keys, - key=lambda fk: fk.parent._creation_order): - if consider_as_foreign_keys is not None and \ - fk.parent not in consider_as_foreign_keys: - continue - try: - col = fk.get_referent(b) - except exc.NoReferenceError as nrte: - if nrte.table_name == b.name: - raise - else: - # this is totally covered. can't get - # coverage to mark it. - continue - - if col is not None: - crit.append(col == fk.parent) - constraints.add(fk.constraint) - if crit: - break - - if len(crit) == 0: - if isinstance(b, expression.FromGrouping): - hint = " Perhaps you meant to convert the right side to a "\ - "subquery using alias()?" - else: - hint = "" - raise exc.NoForeignKeysError( - "Can't find any foreign key relationships " - "between '%s' and '%s'.%s" % (a.description, b.description, hint)) - elif len(constraints) > 1: - raise exc.AmbiguousForeignKeysError( - "Can't determine join between '%s' and '%s'; " - "tables have more than one foreign key " - "constraint relationship between them. " - "Please specify the 'onclause' of this " - "join explicitly." % (a.description, b.description)) - elif len(crit) == 1: - return (crit[0]) - else: - return sql.and_(*crit) - - -class Annotated(object): - """clones a ClauseElement and applies an 'annotations' dictionary. - - Unlike regular clones, this clone also mimics __hash__() and - __cmp__() of the original element so that it takes its place - in hashed collections. - - A reference to the original element is maintained, for the important - reason of keeping its hash value current. When GC'ed, the - hash value may be reused, causing conflicts. - - """ - - def __new__(cls, *args): - if not args: - # clone constructor - return object.__new__(cls) - else: - element, values = args - # pull appropriate subclass from registry of annotated - # classes - try: - cls = annotated_classes[element.__class__] - except KeyError: - cls = annotated_classes[element.__class__] = type.__new__(type, - "Annotated%s" % element.__class__.__name__, - (cls, element.__class__), {}) - return object.__new__(cls) - - def __init__(self, element, values): - # force FromClause to generate their internal - # collections into __dict__ - if isinstance(element, expression.FromClause): - element.c - - self.__dict__ = element.__dict__.copy() - expression.ColumnElement.comparator._reset(self) - self.__element = element - self._annotations = values - - def _annotate(self, values): - _values = self._annotations.copy() - _values.update(values) - return self._with_annotations(_values) - - def _with_annotations(self, values): - clone = self.__class__.__new__(self.__class__) - clone.__dict__ = self.__dict__.copy() - expression.ColumnElement.comparator._reset(clone) - clone._annotations = values - return clone - - def _deannotate(self, values=None, clone=True): - if values is None: - return self.__element - else: - _values = self._annotations.copy() - for v in values: - _values.pop(v, None) - return self._with_annotations(_values) - - def _compiler_dispatch(self, visitor, **kw): - return self.__element.__class__._compiler_dispatch(self, visitor, **kw) - - @property - def _constructor(self): - return self.__element._constructor - - def _clone(self): - clone = self.__element._clone() - if clone is self.__element: - # detect immutable, don't change anything - return self - else: - # update the clone with any changes that have occurred - # to this object's __dict__. - clone.__dict__.update(self.__dict__) - return self.__class__(clone, self._annotations) - - def __hash__(self): - return hash(self.__element) - - def __eq__(self, other): - if isinstance(self.__element, expression.ColumnOperators): - return self.__element.__class__.__eq__(self, other) - else: - return hash(other) == hash(self) - - -class AnnotatedColumnElement(Annotated): - def __init__(self, element, values): - Annotated.__init__(self, element, values) - for attr in ('name', 'key'): - if self.__dict__.get(attr, False) is None: - self.__dict__.pop(attr) - - @util.memoized_property - def name(self): - """pull 'name' from parent, if not present""" - return self._Annotated__element.name - - @util.memoized_property - def key(self): - """pull 'key' from parent, if not present""" - return self._Annotated__element.key - - @util.memoized_property - def info(self): - return self._Annotated__element.info - -# hard-generate Annotated subclasses. this technique -# is used instead of on-the-fly types (i.e. type.__new__()) -# so that the resulting objects are pickleable. -annotated_classes = {} - -for cls in list(expression.__dict__.values()) + [schema.Column, schema.Table]: - if isinstance(cls, type) and issubclass(cls, expression.ClauseElement): - if issubclass(cls, expression.ColumnElement): - annotation_cls = "AnnotatedColumnElement" - else: - annotation_cls = "Annotated" - exec("class Annotated%s(%s, cls):\n" \ - " pass" % (cls.__name__, annotation_cls), locals()) - exec("annotated_classes[cls] = Annotated%s" % (cls.__name__,)) - - -def _deep_annotate(element, annotations, exclude=None): - """Deep copy the given ClauseElement, annotating each element - with the given annotations dictionary. - - Elements within the exclude collection will be cloned but not annotated. - - """ - def clone(elem): - if exclude and \ - hasattr(elem, 'proxy_set') and \ - elem.proxy_set.intersection(exclude): - newelem = elem._clone() - elif annotations != elem._annotations: - newelem = elem._annotate(annotations) - else: - newelem = elem - newelem._copy_internals(clone=clone) - return newelem - - if element is not None: - element = clone(element) - return element - - -def _deep_deannotate(element, values=None): - """Deep copy the given element, removing annotations.""" - - cloned = util.column_dict() - - def clone(elem): - # if a values dict is given, - # the elem must be cloned each time it appears, - # as there may be different annotations in source - # elements that are remaining. if totally - # removing all annotations, can assume the same - # slate... - if values or elem not in cloned: - newelem = elem._deannotate(values=values, clone=True) - newelem._copy_internals(clone=clone) - if not values: - cloned[elem] = newelem - return newelem - else: - return cloned[elem] - - if element is not None: - element = clone(element) - return element - - -def _shallow_annotate(element, annotations): - """Annotate the given ClauseElement and copy its internals so that - internal objects refer to the new annotated object. - - Basically used to apply a "dont traverse" annotation to a - selectable, without digging throughout the whole - structure wasting time. - """ - element = element._annotate(annotations) - element._copy_internals() - return element - - def splice_joins(left, right, stop_on=None): if left is None: return right @@ -619,7 +289,7 @@ def splice_joins(left, right, stop_on=None): ret = None while stack: (right, prevright) = stack.pop() - if isinstance(right, expression.Join) and right is not stop_on: + if isinstance(right, Join) and right is not stop_on: right = right._clone() right._reset_exported() right.onclause = adapter.traverse(right.onclause) @@ -703,7 +373,7 @@ def reduce_columns(columns, *clauses, **kw): if clause is not None: visitors.traverse(clause, {}, {'binary': visit_binary}) - return expression.ColumnSet(columns.difference(omit)) + return ColumnSet(columns.difference(omit)) def criterion_as_pairs(expression, consider_as_foreign_keys=None, @@ -722,8 +392,8 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, def visit_binary(binary): if not any_operator and binary.operator is not operators.eq: return - if not isinstance(binary.left, sql.ColumnElement) or \ - not isinstance(binary.right, sql.ColumnElement): + if not isinstance(binary.left, ColumnElement) or \ + not isinstance(binary.right, ColumnElement): return if consider_as_foreign_keys: @@ -745,8 +415,8 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, binary.left not in consider_as_referenced_keys): pairs.append((binary.right, binary.left)) else: - if isinstance(binary.left, schema.Column) and \ - isinstance(binary.right, schema.Column): + if isinstance(binary.left, Column) and \ + isinstance(binary.right, Column): if binary.left.references(binary.right): pairs.append((binary.right, binary.left)) elif binary.right.references(binary.left): @@ -756,6 +426,7 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, return pairs + class AliasedRow(object): """Wrap a RowProxy with a translation map. @@ -848,10 +519,10 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): magic_flag = False def replace(self, col): - if not self.magic_flag and isinstance(col, expression.FromClause) and \ + if not self.magic_flag and isinstance(col, FromClause) and \ self.selectable.is_derived_from(col): return self.selectable - elif not isinstance(col, expression.ColumnElement): + elif not isinstance(col, ColumnElement): return None elif self.include_fn and not self.include_fn(col): return None @@ -903,7 +574,7 @@ class ColumnAdapter(ClauseAdapter): c = self.adapt_clause(col) # anonymize labels in case they have a hardcoded name - if isinstance(c, expression.Label): + if isinstance(c, Label): c = c.label(None) # adapt_required used by eager loading to indicate that @@ -927,3 +598,4 @@ class ColumnAdapter(ClauseAdapter): def __setstate__(self, state): self.__dict__.update(state) self.columns = util.PopulateDict(self._locate_col) + |