summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/properties.py2
-rw-r--r--lib/sqlalchemy/sql/expression.py65
-rw-r--r--lib/sqlalchemy/types.py109
3 files changed, 135 insertions, 41 deletions
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 56a03fc7c..e900b0cab 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -134,7 +134,7 @@ class ColumnProperty(StrategizedProperty):
def reverse_operate(self, op, other, **kwargs):
col = self.__clause_element__()
- return op(col._bind_param(other), col, **kwargs)
+ return op(col._bind_param(op, other), col, **kwargs)
# TODO: legacy..do we need this ? (0.5)
ColumnComparator = Comparator
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 1c3961f1f..c559f3850 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -1443,7 +1443,7 @@ class _CompareMixin(ColumnOperators):
else:
raise exc.ArgumentError("Only '='/'!=' operators can be used with NULL")
else:
- obj = self._check_literal(obj)
+ obj = self._check_literal(op, obj)
if reverse:
return _BinaryExpression(obj,
@@ -1459,7 +1459,7 @@ class _CompareMixin(ColumnOperators):
negate=negate, modifiers=kwargs)
def __operate(self, op, obj, reverse=False):
- obj = self._check_literal(obj)
+ obj = self._check_literal(op, obj)
if reverse:
left, right = obj, self
@@ -1532,7 +1532,7 @@ class _CompareMixin(ColumnOperators):
"in() function accepts either a list of non-selectable values, "
"or a selectable: %r" % o)
else:
- o = self._bind_param(o)
+ o = self._bind_param(op, o)
args.append(o)
if len(args) == 0:
@@ -1558,7 +1558,9 @@ class _CompareMixin(ColumnOperators):
# use __radd__ to force string concat behavior
return self.__compare(
operators.like_op,
- literal_column("'%'", type_=sqltypes.String).__radd__(self._check_literal(other)),
+ literal_column("'%'", type_=sqltypes.String).__radd__(
+ self._check_literal(operators.like_op, other)
+ ),
escape=escape)
def endswith(self, other, escape=None):
@@ -1566,7 +1568,8 @@ class _CompareMixin(ColumnOperators):
return self.__compare(
operators.like_op,
- literal_column("'%'", type_=sqltypes.String) + self._check_literal(other),
+ literal_column("'%'", type_=sqltypes.String) +
+ self._check_literal(operators.like_op, other),
escape=escape)
def contains(self, other, escape=None):
@@ -1575,7 +1578,7 @@ class _CompareMixin(ColumnOperators):
return self.__compare(
operators.like_op,
literal_column("'%'", type_=sqltypes.String) +
- self._check_literal(other) +
+ self._check_literal(operators.like_op, other) +
literal_column("'%'", type_=sqltypes.String),
escape=escape)
@@ -1585,7 +1588,7 @@ class _CompareMixin(ColumnOperators):
The allowed contents of ``other`` are database backend specific.
"""
- return self.__compare(operators.match_op, self._check_literal(other))
+ return self.__compare(operators.match_op, self._check_literal(operators.match_op, other))
def label(self, name):
"""Produce a column label, i.e. ``<columnname> AS <name>``.
@@ -1615,8 +1618,8 @@ class _CompareMixin(ColumnOperators):
return _BinaryExpression(
self,
ClauseList(
- self._check_literal(cleft),
- self._check_literal(cright),
+ self._check_literal(operators.and_, cleft),
+ self._check_literal(operators.and_, cright),
operator=operators.and_,
group=False),
operators.between_op)
@@ -1651,17 +1654,23 @@ class _CompareMixin(ColumnOperators):
"""
return lambda other: self.__operate(operator, other)
- def _bind_param(self, obj):
- return _BindParamClause(None, obj, _fallback_type=self.type, unique=True)
-
- def _check_literal(self, other):
- if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType):
+ def _bind_param(self, operator, obj):
+ return _BindParamClause(None, obj,
+ _compared_to_operator=operator,
+ _compared_to_type=self.type, unique=True)
+
+ def _check_literal(self, operator, other):
+ if isinstance(other, _BindParamClause) and \
+ isinstance(other.type, sqltypes.NullType):
+ # TODO: perhaps we should not mutate the incoming bindparam()
+ # here and instead make a copy of it. this might
+ # be the only place that we're mutating an incoming construct.
other.type = self.type
return other
elif hasattr(other, '__clause_element__'):
return other.__clause_element__()
elif not isinstance(other, ClauseElement):
- return self._bind_param(other)
+ return self._bind_param(operator, other)
elif isinstance(other, (_SelectBaseMixin, Alias)):
return other.as_scalar()
else:
@@ -2108,7 +2117,8 @@ class _BindParamClause(ColumnElement):
def __init__(self, key, value, type_=None, unique=False,
isoutparam=False, required=False,
- _fallback_type=None):
+ _compared_to_operator=None,
+ _compared_to_type=None):
"""Construct a _BindParamClause.
key
@@ -2154,9 +2164,10 @@ class _BindParamClause(ColumnElement):
self.required = required
if type_ is None:
- self.type = sqltypes.type_map.get(type(value), _fallback_type or sqltypes.NULLTYPE)
- if _fallback_type and _fallback_type._type_affinity == self.type._type_affinity:
- self.type = _fallback_type
+ if _compared_to_type is not None:
+ self.type = _compared_to_type._coerce_compared_value(_compared_to_operator, value)
+ else:
+ self.type = sqltypes.type_map.get(type(value), sqltypes.NULLTYPE)
elif isinstance(type_, type):
self.type = type_()
else:
@@ -2434,9 +2445,9 @@ class _Tuple(ClauseList, ColumnElement):
def _select_iterable(self):
return (self, )
- def _bind_param(self, obj):
+ def _bind_param(self, operator, obj):
return _Tuple(*[
- _BindParamClause(None, o, _fallback_type=self.type, unique=True)
+ _BindParamClause(None, o, _compared_to_operator=operator, _compared_to_type=self.type, unique=True)
for o in obj
]).self_group()
@@ -2538,8 +2549,8 @@ class FunctionElement(Executable, ColumnElement, FromClause):
def execute(self):
return self.select().execute()
- def _bind_param(self, obj):
- return _BindParamClause(None, obj, _fallback_type=self.type, unique=True)
+ def _bind_param(self, operator, obj):
+ return _BindParamClause(None, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True)
class Function(FunctionElement):
@@ -2555,8 +2566,8 @@ class Function(FunctionElement):
FunctionElement.__init__(self, *clauses, **kw)
- def _bind_param(self, obj):
- return _BindParamClause(self.name, obj, _fallback_type=self.type, unique=True)
+ def _bind_param(self, operator, obj):
+ return _BindParamClause(self.name, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True)
class _Cast(ColumnElement):
@@ -3165,8 +3176,8 @@ class ColumnClause(_Immutable, ColumnElement):
else:
return []
- def _bind_param(self, obj):
- return _BindParamClause(self.name, obj, _fallback_type=self.type, unique=True)
+ def _bind_param(self, operator, obj):
+ return _BindParamClause(self.name, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True)
def _make_proxy(self, selectable, name=None, attach=True):
# propagate the "is_literal" flag only if we are keeping our name,
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
index cdbf7927e..5c4e2ca3f 100644
--- a/lib/sqlalchemy/types.py
+++ b/lib/sqlalchemy/types.py
@@ -116,6 +116,13 @@ class AbstractType(Visitable):
typ = t
else:
return self.__class__
+
+ def _coerce_compared_value(self, op, value):
+ _coerced_type = type_map.get(type(value), NULLTYPE)
+ if _coerced_type is NULLTYPE or _coerced_type._type_affinity is self._type_affinity:
+ return self
+ else:
+ return _coerced_type
def _compare_type_affinity(self, other):
return self._type_affinity is other._type_affinity
@@ -229,17 +236,22 @@ class UserDefinedType(TypeEngine):
class TypeDecorator(AbstractType):
"""Allows the creation of types which add additional functionality
to an existing type.
+
+ This method is preferred to direct subclassing of SQLAlchemy's
+ built-in types as it ensures that all required functionality of
+ the underlying type is kept in place.
Typical usage::
import sqlalchemy.types as types
class MyType(types.TypeDecorator):
- # Prefixes Unicode values with "PREFIX:" on the way in and
- # strips it off on the way out.
+ '''Prefixes Unicode values with "PREFIX:" on the way in and
+ strips it off on the way out.
+ '''
impl = types.Unicode
-
+
def process_bind_param(self, value, dialect):
return "PREFIX:" + value
@@ -255,14 +267,49 @@ class TypeDecorator(AbstractType):
given; in this case, the "impl" variable can reference
``TypeEngine`` as a placeholder.
- The reason that type behavior is modified using class decoration
- instead of subclassing is due to the way dialect specific types
- are used. Such as with the example above, when using the mysql
- dialect, the actual type in use will be a
- ``sqlalchemy.databases.mysql.MSString`` instance.
- ``TypeDecorator`` handles the mechanics of passing the values
- between user-defined ``process_`` methods and the current
- dialect-specific type in use.
+ Types that receive a Python type that isn't similar to the
+ ultimate type used may want to define the :meth:`TypeDecorator.coerce_compared_value`
+ method. This is used to give the expression system a hint
+ when coercing Python objects into bind parameters within expressions.
+ Consider this expression::
+
+ mytable.c.somecol + datetime.date(2009, 5, 15)
+
+ Above, if "somecol" is an ``Integer`` variant, it makes sense that
+ we're doing date arithmetic, where above is usually interpreted
+ by databases as adding a number of days to the given date.
+ The expression system does the right thing by not attempting to
+ coerce the "date()" value into an integer-oriented bind parameter.
+
+ However, in the case of ``TypeDecorator``, we are usually changing
+ an incoming Python type to something new - ``TypeDecorator`` by
+ default will "coerce" the non-typed side to be the same type as itself.
+ Such as below, we define an "epoch" type that stores a date value as an integer::
+
+ class MyEpochType(types.TypeDecorator):
+ impl = types.Integer
+
+ epoch = datetime.date(1970, 1, 1)
+
+ def process_bind_param(self, value, dialect):
+ return (value - self.epoch).days
+
+ def process_result_value(self, value, dialect):
+ return self.epoch + timedelta(days=value)
+
+ Our expression of ``somecol + date`` with the above type will coerce the
+ "date" on the right side to also be treated as ``MyEpochType``.
+
+ This behavior can be overridden via the :meth:`~TypeDecorator.coerce_compared_value`
+ method, which returns a type that should be used for the value of the expression.
+ Below we set it such that an integer value will be treated as an ``Integer``,
+ and any other value is assumed to be a date and will be treated as a ``MyEpochType``::
+
+ def coerce_compared_value(self, op, value):
+ if isinstance(value, int):
+ return Integer()
+ else:
+ return self
"""
@@ -365,7 +412,28 @@ class TypeDecorator(AbstractType):
return process
else:
return self.impl.result_processor(dialect, coltype)
+
+ def coerce_compared_value(self, op, value):
+ """Suggest a type for a 'coerced' Python value in an expression.
+
+ By default, returns self. This method is called by
+ the expression system when an object using this type is
+ on the left or right side of an expression against a plain Python
+ object which does not yet have a SQLAlchemy type assigned::
+
+ expr = table.c.somecolumn + 35
+
+ Where above, if ``somecolumn`` uses this type, this method will
+ be called with the value ``operator.add``
+ and ``35``. The return value is whatever SQLAlchemy type should
+ be used for ``35`` for this particular operation.
+
+ """
+ return self
+ def _coerce_compared_value(self, op, value):
+ return self.coerce_compared_value(op, value)
+
def copy(self):
instance = self.__class__.__new__(self.__class__)
instance.__dict__.update(self.__dict__)
@@ -384,6 +452,11 @@ class TypeDecorator(AbstractType):
def is_mutable(self):
return self.impl.is_mutable()
+ def _adapt_expression(self, op, othertype):
+ return self.impl._adapt_expression(op, othertype)
+
+
+
class MutableType(object):
"""A mixin that marks a Type as holding a mutable object.
@@ -461,7 +534,7 @@ class Concatenable(object):
"""A mixin that marks a type as supporting 'concatenation', typically strings."""
def _adapt_expression(self, op, othertype):
- if op is operators.add and isinstance(othertype, (Concatenable, NullType)):
+ if op is operators.add and issubclass(othertype._type_affinity, (Concatenable, NullType)):
return operators.concat_op, self
else:
return op, self
@@ -1393,6 +1466,13 @@ class Interval(_DateAffinity, TypeDecorator):
value is stored as a date which is relative to the "epoch"
(Jan. 1, 1970).
+ Note that the ``Interval`` type does not currently provide
+ date arithmetic operations on platforms which do not support
+ interval types natively. Such operations usually require
+ transformation of both sides of the expression (such as, conversion
+ of both sides into integer epoch values first) which currently
+ is a manual procedure (such as via :attr:`~sqlalchemy.sql.expression.func`).
+
"""
impl = DateTime
@@ -1421,7 +1501,7 @@ class Interval(_DateAffinity, TypeDecorator):
self.native = native
self.second_precision = second_precision
self.day_precision = day_precision
-
+
def adapt(self, cls):
if self.native:
return cls._adapt_from_generic_interval(self)
@@ -1488,6 +1568,9 @@ class Interval(_DateAffinity, TypeDecorator):
def _type_affinity(self):
return Interval
+ def _coerce_compared_value(self, op, value):
+ return self.impl._coerce_compared_value(op, value)
+
class FLOAT(Float):
"""The SQL FLOAT type."""