From 4a6afd469fad170868554bf28578849bf3dfd5dd Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 9 May 2008 16:34:10 +0000 Subject: r4695 merged to trunk; trunk now becomes 0.5. 0.4 development continues at /sqlalchemy/branches/rel_0_4 --- lib/sqlalchemy/sql/expression.py | 490 +++++++++++++++++---------------------- 1 file changed, 219 insertions(+), 271 deletions(-) (limited to 'lib/sqlalchemy/sql/expression.py') diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 867fdd69c..7ce637701 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -26,12 +26,12 @@ to stay the same in future releases. """ import itertools, re -from sqlalchemy import util, exceptions +from sqlalchemy import util, exc from sqlalchemy.sql import operators, visitors from sqlalchemy import types as sqltypes functions, schema, sql_util = None, None, None -DefaultDialect, ClauseAdapter = None, None +DefaultDialect, ClauseAdapter, Annotated = None, None, None __all__ = [ 'Alias', 'ClauseElement', @@ -503,15 +503,21 @@ def collate(expression, collation): def exists(*args, **kwargs): """Return an ``EXISTS`` clause as applied to a [sqlalchemy.sql.expression#Select] object. + + Calling styles are of the following forms:: + + # use on an existing select() + s = select([]).where() + s = exists(s) + + # construct a select() at once + exists(['*'], **select_arguments).where() + + # columns argument is optional, generates "EXISTS (SELECT *)" + # by default. + exists().where() - The resulting [sqlalchemy.sql.expression#_Exists] object can be executed by - itself or used as a subquery within an enclosing select. - - \*args, \**kwargs - all arguments are sent directly to the [sqlalchemy.sql.expression#select()] - function to produce a ``SELECT`` statement. """ - return _Exists(*args, **kwargs) def union(*selects, **kwargs): @@ -872,27 +878,36 @@ def _compound_select(keyword, *selects, **kwargs): return CompoundSelect(keyword, *selects, **kwargs) def _is_literal(element): - return not isinstance(element, ClauseElement) + return not isinstance(element, (ClauseElement, Operators)) + +def _from_objects(*elements, **kwargs): + return itertools.chain(*[element._get_from_objects(**kwargs) for element in elements]) +def _labeled(element): + if not hasattr(element, 'name'): + return element.label(None) + else: + return element + def _literal_as_text(element): - if isinstance(element, Operators): - return element.expression_element() + if hasattr(element, '__clause_element__'): + return element.__clause_element__() elif _is_literal(element): return _TextClause(unicode(element)) else: return element def _literal_as_column(element): - if isinstance(element, Operators): - return element.clause_element() + if hasattr(element, '__clause_element__'): + return element.__clause_element__() elif _is_literal(element): return literal_column(str(element)) else: return element def _literal_as_binds(element, name=None, type_=None): - if isinstance(element, Operators): - return element.expression_element() + if hasattr(element, '__clause_element__'): + return element.__clause_element__() elif _is_literal(element): if element is None: return null() @@ -902,17 +917,17 @@ def _literal_as_binds(element, name=None, type_=None): return element def _no_literals(element): - if isinstance(element, Operators): - return element.expression_element() + if hasattr(element, '__clause_element__'): + return element.__clause_element__() elif _is_literal(element): - raise exceptions.ArgumentError("Ambiguous literal: %r. Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element) + raise exc.ArgumentError("Ambiguous literal: %r. Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element) else: return element def _corresponding_column_or_error(fromclause, column, require_embedded=False): c = fromclause.corresponding_column(column, require_embedded=require_embedded) if not c: - raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description)) + raise exc.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description)) return c def _selectable(element): @@ -921,9 +936,8 @@ def _selectable(element): elif isinstance(element, Selectable): return element else: - raise exceptions.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element)) + raise exc.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element)) - def is_column(col): """True if ``col`` is an instance of ``ColumnElement``.""" return isinstance(col, ColumnElement) @@ -941,7 +955,9 @@ class _FigureVisitName(type): class ClauseElement(object): """Base class for elements of a programmatically constructed SQL expression.""" __metaclass__ = _FigureVisitName - + _annotations = {} + supports_execution = False + def _clone(self): """Create a shallow copy of this ClauseElement. @@ -976,6 +992,14 @@ class ClauseElement(object): """ raise NotImplementedError(repr(self)) + + def _annotate(self, values): + """return a copy of this ClauseElement with the given annotations dictionary.""" + + global Annotated + if Annotated is None: + from sqlalchemy.sql.util import Annotated + return Annotated(self, values) def unique_params(self, *optionaldict, **kwargs): """Return a copy with ``bindparam()`` elments replaced. @@ -1006,14 +1030,14 @@ class ClauseElement(object): if len(optionaldict) == 1: kwargs.update(optionaldict[0]) elif len(optionaldict) > 1: - raise exceptions.ArgumentError("params() takes zero or one positional dictionary argument") + raise exc.ArgumentError("params() takes zero or one positional dictionary argument") def visit_bindparam(bind): if bind.key in kwargs: bind.value = kwargs[bind.key] if unique: bind._convert_to_unique() - return visitors.traverse(self, visit_bindparam=visit_bindparam, clone=True) + return visitors.cloned_traverse(self, {}, {'bindparam':visit_bindparam}) def compare(self, other): """Compare this ClauseElement to the given ClauseElement. @@ -1049,11 +1073,6 @@ class ClauseElement(object): def self_group(self, against=None): return self - def supports_execution(self): - """Return True if this clause element represents a complete executable statement.""" - - return False - def bind(self): """Returns the Engine or Connection to which this ClauseElement is bound, or None if none found.""" @@ -1062,7 +1081,7 @@ class ClauseElement(object): return self._bind except AttributeError: pass - for f in self._get_from_objects(): + for f in _from_objects(self): if f is self: continue engine = f.bind @@ -1083,7 +1102,7 @@ class ClauseElement(object): 'Engine for execution. Or, assign a bind to the statement ' 'or the Metadata of its underlying tables to enable ' 'implicit execution via this method.' % label) - raise exceptions.UnboundExecutionError(msg) + raise exc.UnboundExecutionError(msg) return e.execute_clauseelement(self, multiparams, params) def scalar(self, *multiparams, **params): @@ -1159,6 +1178,12 @@ class ClauseElement(object): self.__module__, self.__class__.__name__, id(self), friendly) +class _Immutable(object): + """mark a ClauseElement as 'immutable' when expressions are cloned.""" + + def _clone(self): + return self + class Operators(object): def __and__(self, other): return self.operate(operators.and_, other) @@ -1174,9 +1199,6 @@ class Operators(object): return self.operate(operators.op, opstring, b) return op - def clause_element(self): - raise NotImplementedError() - def operate(self, op, *other, **kwargs): raise NotImplementedError() @@ -1216,7 +1238,7 @@ class ColumnOperators(Operators): def ilike(self, other, escape=None): return self.operate(operators.ilike_op, other, escape=escape) - def in_(self, *other): + def in_(self, other): return self.operate(operators.in_op, other) def startswith(self, other, **kwargs): @@ -1279,18 +1301,18 @@ class _CompareMixin(ColumnOperators): def __compare(self, op, obj, negate=None, reverse=False, **kwargs): if obj is None or isinstance(obj, _Null): if op == operators.eq: - return _BinaryExpression(self.expression_element(), null(), operators.is_, negate=operators.isnot) + return _BinaryExpression(self, null(), operators.is_, negate=operators.isnot) elif op == operators.ne: - return _BinaryExpression(self.expression_element(), null(), operators.isnot, negate=operators.is_) + return _BinaryExpression(self, null(), operators.isnot, negate=operators.is_) else: - raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL") + raise exc.ArgumentError("Only '='/'!=' operators can be used with NULL") else: obj = self._check_literal(obj) if reverse: - return _BinaryExpression(obj, self.expression_element(), op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs) + return _BinaryExpression(obj, self, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs) else: - return _BinaryExpression(self.expression_element(), obj, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs) + return _BinaryExpression(self, obj, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs) def __operate(self, op, obj, reverse=False): obj = self._check_literal(obj) @@ -1298,9 +1320,9 @@ class _CompareMixin(ColumnOperators): type_ = self._compare_type(obj) if reverse: - return _BinaryExpression(obj, self.expression_element(), type_.adapt_operator(op), type_=type_) + return _BinaryExpression(obj, self, type_.adapt_operator(op), type_=type_) else: - return _BinaryExpression(self.expression_element(), obj, type_.adapt_operator(op), type_=type_) + return _BinaryExpression(self, obj, type_.adapt_operator(op), type_=type_) # a mapping of operators with the method they use, along with their negated # operator for comparison operators @@ -1329,17 +1351,10 @@ class _CompareMixin(ColumnOperators): o = _CompareMixin.operators[op] return o[0](self, op, other, reverse=True, *o[1:], **kwargs) - def in_(self, *other): - return self._in_impl(operators.in_op, operators.notin_op, *other) - - def _in_impl(self, op, negate_op, *other): - # Handle old style *args argument passing - if len(other) != 1 or not isinstance(other[0], Selectable) and (not hasattr(other[0], '__iter__') or isinstance(other[0], basestring)): - util.warn_deprecated('passing in_ arguments as varargs is deprecated, in_ takes a single argument that is a sequence or a selectable') - seq_or_selectable = other - else: - seq_or_selectable = other[0] + def in_(self, other): + return self._in_impl(operators.in_op, operators.notin_op, other) + def _in_impl(self, op, negate_op, seq_or_selectable): if isinstance(seq_or_selectable, Selectable): return self.__compare( op, seq_or_selectable, negate=negate_op) @@ -1348,7 +1363,7 @@ class _CompareMixin(ColumnOperators): for o in seq_or_selectable: if not _is_literal(o): if not isinstance( o, _CompareMixin): - raise exceptions.InvalidRequestError( "in() function accepts either a list of non-selectable values, or a selectable: "+repr(o) ) + raise exc.InvalidRequestError( "in() function accepts either a list of non-selectable values, or a selectable: "+repr(o) ) else: o = self._bind_param(o) args.append(o) @@ -1433,22 +1448,13 @@ class _CompareMixin(ColumnOperators): if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType): other.type = self.type return other - elif isinstance(other, Operators): - return other.expression_element() + elif hasattr(other, '__clause_element__'): + return other.__clause_element__() elif _is_literal(other): return self._bind_param(other) else: return other - def clause_element(self): - """Allow ``_CompareMixins`` to return the underlying ``ClauseElement``, for non-``ClauseElement`` ``_CompareMixins``.""" - return self - - def expression_element(self): - """Allow ``_CompareMixins`` to return the appropriate object to be used in expressions.""" - - return self - def _compare_type(self, obj): """Allow subclasses to override the type used in constructing ``_BinaryExpression`` objects. @@ -1480,23 +1486,22 @@ class ColumnElement(ClauseElement, _CompareMixin): primary_key = False foreign_keys = [] - + quote = None + def base_columns(self): - if hasattr(self, '_base_columns'): - return self._base_columns - self._base_columns = util.Set([c for c in self.proxy_set if not hasattr(c, 'proxies')]) + if not hasattr(self, '_base_columns'): + self._base_columns = util.Set([c for c in self.proxy_set if not hasattr(c, 'proxies')]) return self._base_columns base_columns = property(base_columns) def proxy_set(self): - if hasattr(self, '_proxy_set'): - return self._proxy_set - s = util.Set([self]) - if hasattr(self, 'proxies'): - for c in self.proxies: - s = s.union(c.proxy_set) - self._proxy_set = s - return s + if not hasattr(self, '_proxy_set'): + s = util.Set([self]) + if hasattr(self, 'proxies'): + for c in self.proxies: + s.update(c.proxy_set) + self._proxy_set = s + return self._proxy_set proxy_set = property(proxy_set) def shares_lineage(self, othercolumn): @@ -1518,7 +1523,7 @@ class ColumnElement(ClauseElement, _CompareMixin): co = _ColumnClause(self.anon_label, selectable, type_=getattr(self, 'type', None)) co.proxies = [self] - selectable.columns[name]= co + selectable.columns[name] = co return co def anon_label(self): @@ -1613,7 +1618,7 @@ class ColumnCollection(util.OrderedProperties): def __contains__(self, other): if not isinstance(other, basestring): - raise exceptions.ArgumentError("__contains__ requires a string argument") + raise exc.ArgumentError("__contains__ requires a string argument") return util.OrderedProperties.__contains__(self, other) def contains_column(self, col): @@ -1641,6 +1646,9 @@ class ColumnSet(util.OrderedSet): l.append(c==local) return and_(*l) + def __hash__(self): + return hash(tuple(self._list)) + class Selectable(ClauseElement): """mark a class as being selectable""" @@ -1648,8 +1656,9 @@ class FromClause(Selectable): """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement.""" __visit_name__ = 'fromclause' - named_with_column=False + named_with_column = False _hide_froms = [] + quote = None def _get_from_objects(self, **modifiers): return [] @@ -1694,12 +1703,12 @@ class FromClause(Selectable): return fromclause in util.Set(self._cloned_set) def replace_selectable(self, old, alias): - """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``.""" + """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``.""" - global ClauseAdapter - if ClauseAdapter is None: - from sqlalchemy.sql.util import ClauseAdapter - return ClauseAdapter(alias).traverse(self, clone=True) + global ClauseAdapter + if ClauseAdapter is None: + from sqlalchemy.sql.util import ClauseAdapter + return ClauseAdapter(alias).traverse(self) def correspond_on_equivalents(self, column, equivalents): col = self.corresponding_column(column, require_embedded=True) @@ -1859,7 +1868,7 @@ class _BindParamClause(ClauseElement, _CompareMixin): def _convert_to_unique(self): if not self.unique: - self.unique=True + self.unique = True self.key = "{ANON %d %s}" % (id(self), self._orig_key or 'param') def _get_from_objects(self, **modifiers): @@ -1910,6 +1919,7 @@ class _TextClause(ClauseElement): __visit_name__ = 'textclause' _bind_params_regex = re.compile(r'(? 1 or self._correlate: if self._correlate: - froms.difference_update(_cloned_intersection(froms, self._correlate)) + froms = froms.difference(_cloned_intersection(froms, self._correlate)) if self._should_correlate and existing_froms: - froms.difference_update(_cloned_intersection(froms, existing_froms)) + froms = froms.difference(_cloned_intersection(froms, existing_froms)) if not len(froms): - raise exceptions.InvalidRequestError("Select statement '%s' returned no FROM clauses due to auto-correlation; specify correlate() to control correlation manually." % self) + raise exc.InvalidRequestError("Select statement '%s' returned no FROM clauses due to auto-correlation; specify correlate() to control correlation manually." % self) return froms froms = property(_get_display_froms, doc="""Return a list of all FromClause elements which will be applied to the FROM clause of the resulting statement.""") def type(self): - raise exceptions.InvalidRequestError("Select objects don't have a type. Call as_scalar() on this Select object to return a 'scalar' version of this Select.") + raise exc.InvalidRequestError("Select objects don't have a type. Call as_scalar() on this Select object to return a 'scalar' version of this Select.") type = property(type) def locate_all_froms(self): @@ -3059,22 +3025,10 @@ class Select(_SelectBaseMixin, FromClause): is specifically for those FromClause elements that would actually be rendered. """ - if hasattr(self, '_all_froms'): - return self._all_froms - - froms = util.Set( - itertools.chain(* - [self._froms] + - [f._get_from_objects() for f in self._froms] + - [col._get_from_objects() for col in self._raw_columns] - ) - ) + if not hasattr(self, '_all_froms'): + self._all_froms = self._froms.union(_from_objects(*list(self._froms))) - if self._whereclause: - froms.update(self._whereclause._get_from_objects(is_where=True)) - - self._all_froms = froms - return froms + return self._all_froms def inner_columns(self): """an iteratorof all ColumnElement expressions which would @@ -3092,7 +3046,7 @@ class Select(_SelectBaseMixin, FromClause): def is_derived_from(self, fromclause): if self in util.Set(fromclause._cloned_set): return True - + for f in self.locate_all_froms(): if f.is_derived_from(fromclause): return True @@ -3112,7 +3066,7 @@ class Select(_SelectBaseMixin, FromClause): """return child elements as per the ClauseElement specification.""" return (column_collections and list(self.columns) or []) + \ - list(self.locate_all_froms()) + \ + list(self._froms) + \ [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None] def column(self, column): @@ -3125,6 +3079,7 @@ class Select(_SelectBaseMixin, FromClause): column = column.self_group(against=operators.comma_op) s._raw_columns = s._raw_columns + [column] + s._froms = s._froms.union(_from_objects(column)) return s def where(self, whereclause): @@ -3185,7 +3140,7 @@ class Select(_SelectBaseMixin, FromClause): """ s = self._generate() - s._should_correlate=False + s._should_correlate = False if fromclauses == (None,): s._correlate = util.Set() else: @@ -3195,7 +3150,7 @@ class Select(_SelectBaseMixin, FromClause): def append_correlation(self, fromclause): """append the given correlation expression to this select() construct.""" - self._should_correlate=False + self._should_correlate = False self._correlate = self._correlate.union([fromclause]) def append_column(self, column): @@ -3207,6 +3162,7 @@ class Select(_SelectBaseMixin, FromClause): column = column.self_group(against=operators.comma_op) self._raw_columns = self._raw_columns + [column] + self._froms = self._froms.union(_from_objects(column)) self._reset_exported() def append_prefix(self, clause): @@ -3221,10 +3177,13 @@ class Select(_SelectBaseMixin, FromClause): The expression will be joined to existing WHERE criterion via AND. """ + whereclause = _literal_as_text(whereclause) + self._froms = self._froms.union(_from_objects(whereclause, is_where=True)) + if self._whereclause is not None: - self._whereclause = and_(self._whereclause, _literal_as_text(whereclause)) + self._whereclause = and_(self._whereclause, whereclause) else: - self._whereclause = _literal_as_text(whereclause) + self._whereclause = whereclause def append_having(self, having): """append the given expression to this select() construct's HAVING criterion. @@ -3311,31 +3270,23 @@ class Select(_SelectBaseMixin, FromClause): return intersect_all(self, other, **kwargs) - def _table_iterator(self): - for t in visitors.NoColumnVisitor().iterate(self): - if isinstance(t, TableClause): - yield t - def bind(self): if self._bind: return self._bind - for f in self._froms: - if f is self: - continue - e = f.bind - if e: - self._bind = e - return e - # look through the columns (largely synomous with looking - # through the FROMs except in the case of _CalculatedClause/_Function) - for c in self._raw_columns: - if getattr(c, 'table', None) is self: - continue - e = c.bind + if not self._froms: + for c in self._raw_columns: + e = c.bind + if e: + self._bind = e + return e + else: + e = list(self._froms)[0].bind if e: self._bind = e return e + return None + def _set_bind(self, bind): self._bind = bind bind = property(bind, _set_bind) @@ -3343,11 +3294,7 @@ class Select(_SelectBaseMixin, FromClause): class _UpdateBase(ClauseElement): """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.""" - def supports_execution(self): - return True - - def _table_iterator(self): - return iter([self.table]) + supports_execution = True def _generate(self): s = self.__class__.__new__(self.__class__) @@ -3407,7 +3354,7 @@ class Insert(_ValuesBase): self._bind = bind self.table = table self.select = None - self.inline=inline + self.inline = inline if prefixes: self._prefixes = [_literal_as_text(p) for p in prefixes] else: @@ -3502,10 +3449,11 @@ class Delete(_UpdateBase): self._whereclause = clone(self._whereclause) class _IdentifiedClause(ClauseElement): + supports_execution = True + quote = None + def __init__(self, ident): self.ident = ident - def supports_execution(self): - return True class SavepointClause(_IdentifiedClause): pass -- cgit v1.2.1