diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/orm/properties.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 65 | ||||
| -rw-r--r-- | lib/sqlalchemy/types.py | 109 |
3 files changed, 135 insertions, 41 deletions
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 56a03fc7c..e900b0cab 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -134,7 +134,7 @@ class ColumnProperty(StrategizedProperty): def reverse_operate(self, op, other, **kwargs): col = self.__clause_element__() - return op(col._bind_param(other), col, **kwargs) + return op(col._bind_param(op, other), col, **kwargs) # TODO: legacy..do we need this ? (0.5) ColumnComparator = Comparator diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 1c3961f1f..c559f3850 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1443,7 +1443,7 @@ class _CompareMixin(ColumnOperators): else: raise exc.ArgumentError("Only '='/'!=' operators can be used with NULL") else: - obj = self._check_literal(obj) + obj = self._check_literal(op, obj) if reverse: return _BinaryExpression(obj, @@ -1459,7 +1459,7 @@ class _CompareMixin(ColumnOperators): negate=negate, modifiers=kwargs) def __operate(self, op, obj, reverse=False): - obj = self._check_literal(obj) + obj = self._check_literal(op, obj) if reverse: left, right = obj, self @@ -1532,7 +1532,7 @@ class _CompareMixin(ColumnOperators): "in() function accepts either a list of non-selectable values, " "or a selectable: %r" % o) else: - o = self._bind_param(o) + o = self._bind_param(op, o) args.append(o) if len(args) == 0: @@ -1558,7 +1558,9 @@ class _CompareMixin(ColumnOperators): # use __radd__ to force string concat behavior return self.__compare( operators.like_op, - literal_column("'%'", type_=sqltypes.String).__radd__(self._check_literal(other)), + literal_column("'%'", type_=sqltypes.String).__radd__( + self._check_literal(operators.like_op, other) + ), escape=escape) def endswith(self, other, escape=None): @@ -1566,7 +1568,8 @@ class _CompareMixin(ColumnOperators): return self.__compare( operators.like_op, - literal_column("'%'", type_=sqltypes.String) + self._check_literal(other), + literal_column("'%'", type_=sqltypes.String) + + self._check_literal(operators.like_op, other), escape=escape) def contains(self, other, escape=None): @@ -1575,7 +1578,7 @@ class _CompareMixin(ColumnOperators): return self.__compare( operators.like_op, literal_column("'%'", type_=sqltypes.String) + - self._check_literal(other) + + self._check_literal(operators.like_op, other) + literal_column("'%'", type_=sqltypes.String), escape=escape) @@ -1585,7 +1588,7 @@ class _CompareMixin(ColumnOperators): The allowed contents of ``other`` are database backend specific. """ - return self.__compare(operators.match_op, self._check_literal(other)) + return self.__compare(operators.match_op, self._check_literal(operators.match_op, other)) def label(self, name): """Produce a column label, i.e. ``<columnname> AS <name>``. @@ -1615,8 +1618,8 @@ class _CompareMixin(ColumnOperators): return _BinaryExpression( self, ClauseList( - self._check_literal(cleft), - self._check_literal(cright), + self._check_literal(operators.and_, cleft), + self._check_literal(operators.and_, cright), operator=operators.and_, group=False), operators.between_op) @@ -1651,17 +1654,23 @@ class _CompareMixin(ColumnOperators): """ return lambda other: self.__operate(operator, other) - def _bind_param(self, obj): - return _BindParamClause(None, obj, _fallback_type=self.type, unique=True) - - def _check_literal(self, other): - if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType): + def _bind_param(self, operator, obj): + return _BindParamClause(None, obj, + _compared_to_operator=operator, + _compared_to_type=self.type, unique=True) + + def _check_literal(self, operator, other): + if isinstance(other, _BindParamClause) and \ + isinstance(other.type, sqltypes.NullType): + # TODO: perhaps we should not mutate the incoming bindparam() + # here and instead make a copy of it. this might + # be the only place that we're mutating an incoming construct. other.type = self.type return other elif hasattr(other, '__clause_element__'): return other.__clause_element__() elif not isinstance(other, ClauseElement): - return self._bind_param(other) + return self._bind_param(operator, other) elif isinstance(other, (_SelectBaseMixin, Alias)): return other.as_scalar() else: @@ -2108,7 +2117,8 @@ class _BindParamClause(ColumnElement): def __init__(self, key, value, type_=None, unique=False, isoutparam=False, required=False, - _fallback_type=None): + _compared_to_operator=None, + _compared_to_type=None): """Construct a _BindParamClause. key @@ -2154,9 +2164,10 @@ class _BindParamClause(ColumnElement): self.required = required if type_ is None: - self.type = sqltypes.type_map.get(type(value), _fallback_type or sqltypes.NULLTYPE) - if _fallback_type and _fallback_type._type_affinity == self.type._type_affinity: - self.type = _fallback_type + if _compared_to_type is not None: + self.type = _compared_to_type._coerce_compared_value(_compared_to_operator, value) + else: + self.type = sqltypes.type_map.get(type(value), sqltypes.NULLTYPE) elif isinstance(type_, type): self.type = type_() else: @@ -2434,9 +2445,9 @@ class _Tuple(ClauseList, ColumnElement): def _select_iterable(self): return (self, ) - def _bind_param(self, obj): + def _bind_param(self, operator, obj): return _Tuple(*[ - _BindParamClause(None, o, _fallback_type=self.type, unique=True) + _BindParamClause(None, o, _compared_to_operator=operator, _compared_to_type=self.type, unique=True) for o in obj ]).self_group() @@ -2538,8 +2549,8 @@ class FunctionElement(Executable, ColumnElement, FromClause): def execute(self): return self.select().execute() - def _bind_param(self, obj): - return _BindParamClause(None, obj, _fallback_type=self.type, unique=True) + def _bind_param(self, operator, obj): + return _BindParamClause(None, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True) class Function(FunctionElement): @@ -2555,8 +2566,8 @@ class Function(FunctionElement): FunctionElement.__init__(self, *clauses, **kw) - def _bind_param(self, obj): - return _BindParamClause(self.name, obj, _fallback_type=self.type, unique=True) + def _bind_param(self, operator, obj): + return _BindParamClause(self.name, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True) class _Cast(ColumnElement): @@ -3165,8 +3176,8 @@ class ColumnClause(_Immutable, ColumnElement): else: return [] - def _bind_param(self, obj): - return _BindParamClause(self.name, obj, _fallback_type=self.type, unique=True) + def _bind_param(self, operator, obj): + return _BindParamClause(self.name, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True) def _make_proxy(self, selectable, name=None, attach=True): # propagate the "is_literal" flag only if we are keeping our name, diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index cdbf7927e..5c4e2ca3f 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -116,6 +116,13 @@ class AbstractType(Visitable): typ = t else: return self.__class__ + + def _coerce_compared_value(self, op, value): + _coerced_type = type_map.get(type(value), NULLTYPE) + if _coerced_type is NULLTYPE or _coerced_type._type_affinity is self._type_affinity: + return self + else: + return _coerced_type def _compare_type_affinity(self, other): return self._type_affinity is other._type_affinity @@ -229,17 +236,22 @@ class UserDefinedType(TypeEngine): class TypeDecorator(AbstractType): """Allows the creation of types which add additional functionality to an existing type. + + This method is preferred to direct subclassing of SQLAlchemy's + built-in types as it ensures that all required functionality of + the underlying type is kept in place. Typical usage:: import sqlalchemy.types as types class MyType(types.TypeDecorator): - # Prefixes Unicode values with "PREFIX:" on the way in and - # strips it off on the way out. + '''Prefixes Unicode values with "PREFIX:" on the way in and + strips it off on the way out. + ''' impl = types.Unicode - + def process_bind_param(self, value, dialect): return "PREFIX:" + value @@ -255,14 +267,49 @@ class TypeDecorator(AbstractType): given; in this case, the "impl" variable can reference ``TypeEngine`` as a placeholder. - The reason that type behavior is modified using class decoration - instead of subclassing is due to the way dialect specific types - are used. Such as with the example above, when using the mysql - dialect, the actual type in use will be a - ``sqlalchemy.databases.mysql.MSString`` instance. - ``TypeDecorator`` handles the mechanics of passing the values - between user-defined ``process_`` methods and the current - dialect-specific type in use. + Types that receive a Python type that isn't similar to the + ultimate type used may want to define the :meth:`TypeDecorator.coerce_compared_value` + method. This is used to give the expression system a hint + when coercing Python objects into bind parameters within expressions. + Consider this expression:: + + mytable.c.somecol + datetime.date(2009, 5, 15) + + Above, if "somecol" is an ``Integer`` variant, it makes sense that + we're doing date arithmetic, where above is usually interpreted + by databases as adding a number of days to the given date. + The expression system does the right thing by not attempting to + coerce the "date()" value into an integer-oriented bind parameter. + + However, in the case of ``TypeDecorator``, we are usually changing + an incoming Python type to something new - ``TypeDecorator`` by + default will "coerce" the non-typed side to be the same type as itself. + Such as below, we define an "epoch" type that stores a date value as an integer:: + + class MyEpochType(types.TypeDecorator): + impl = types.Integer + + epoch = datetime.date(1970, 1, 1) + + def process_bind_param(self, value, dialect): + return (value - self.epoch).days + + def process_result_value(self, value, dialect): + return self.epoch + timedelta(days=value) + + Our expression of ``somecol + date`` with the above type will coerce the + "date" on the right side to also be treated as ``MyEpochType``. + + This behavior can be overridden via the :meth:`~TypeDecorator.coerce_compared_value` + method, which returns a type that should be used for the value of the expression. + Below we set it such that an integer value will be treated as an ``Integer``, + and any other value is assumed to be a date and will be treated as a ``MyEpochType``:: + + def coerce_compared_value(self, op, value): + if isinstance(value, int): + return Integer() + else: + return self """ @@ -365,7 +412,28 @@ class TypeDecorator(AbstractType): return process else: return self.impl.result_processor(dialect, coltype) + + def coerce_compared_value(self, op, value): + """Suggest a type for a 'coerced' Python value in an expression. + + By default, returns self. This method is called by + the expression system when an object using this type is + on the left or right side of an expression against a plain Python + object which does not yet have a SQLAlchemy type assigned:: + + expr = table.c.somecolumn + 35 + + Where above, if ``somecolumn`` uses this type, this method will + be called with the value ``operator.add`` + and ``35``. The return value is whatever SQLAlchemy type should + be used for ``35`` for this particular operation. + + """ + return self + def _coerce_compared_value(self, op, value): + return self.coerce_compared_value(op, value) + def copy(self): instance = self.__class__.__new__(self.__class__) instance.__dict__.update(self.__dict__) @@ -384,6 +452,11 @@ class TypeDecorator(AbstractType): def is_mutable(self): return self.impl.is_mutable() + def _adapt_expression(self, op, othertype): + return self.impl._adapt_expression(op, othertype) + + + class MutableType(object): """A mixin that marks a Type as holding a mutable object. @@ -461,7 +534,7 @@ class Concatenable(object): """A mixin that marks a type as supporting 'concatenation', typically strings.""" def _adapt_expression(self, op, othertype): - if op is operators.add and isinstance(othertype, (Concatenable, NullType)): + if op is operators.add and issubclass(othertype._type_affinity, (Concatenable, NullType)): return operators.concat_op, self else: return op, self @@ -1393,6 +1466,13 @@ class Interval(_DateAffinity, TypeDecorator): value is stored as a date which is relative to the "epoch" (Jan. 1, 1970). + Note that the ``Interval`` type does not currently provide + date arithmetic operations on platforms which do not support + interval types natively. Such operations usually require + transformation of both sides of the expression (such as, conversion + of both sides into integer epoch values first) which currently + is a manual procedure (such as via :attr:`~sqlalchemy.sql.expression.func`). + """ impl = DateTime @@ -1421,7 +1501,7 @@ class Interval(_DateAffinity, TypeDecorator): self.native = native self.second_precision = second_precision self.day_precision = day_precision - + def adapt(self, cls): if self.native: return cls._adapt_from_generic_interval(self) @@ -1488,6 +1568,9 @@ class Interval(_DateAffinity, TypeDecorator): def _type_affinity(self): return Interval + def _coerce_compared_value(self, op, value): + return self.impl._coerce_compared_value(op, value) + class FLOAT(Float): """The SQL FLOAT type.""" |
