diff options
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 81 | ||||
| -rw-r--r-- | lib/sqlalchemy/types.py | 143 | ||||
| -rw-r--r-- | test/sql/test_operators.py | 58 | ||||
| -rw-r--r-- | test/sql/test_types.py | 1 |
5 files changed, 117 insertions, 173 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index a9ff988e8..36da14d33 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -708,9 +708,10 @@ class PGCompiler(compiler.SQLCompiler): affinity = None casts = { - sqltypes.Date:'date', - sqltypes.DateTime:'timestamp', - sqltypes.Interval:'interval', sqltypes.Time:'time' + sqltypes.Date: 'date', + sqltypes.DateTime: 'timestamp', + sqltypes.Interval: 'interval', + sqltypes.Time: 'time' } cast = casts.get(affinity, None) if isinstance(extract.expr, sql.ColumnElement) and cast is not None: diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 844293c73..63fa23c15 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1875,7 +1875,7 @@ class Immutable(object): return self -class _DefaultColumnComparator(object): +class _DefaultColumnComparator(operators.ColumnOperators): """Defines comparison and math operations. See :class:`.ColumnOperators` and :class:`.Operators` for descriptions @@ -1883,6 +1883,45 @@ class _DefaultColumnComparator(object): """ + @util.memoized_property + def type(self): + return self.expr.type + + def operate(self, op, *other, **kwargs): + o = self.operators[op.__name__] + return o[0](self, self.expr, op, *(other + o[1:]), **kwargs) + + def reverse_operate(self, op, other, **kwargs): + o = self.operators[op.__name__] + return o[0](self, self.expr, op, other, reverse=True, *o[1:], **kwargs) + + def _adapt_expression(self, op, other_comparator): + """evaluate the return type of <self> <op> <othertype>, + and apply any adaptations to the given operator. + + This method determines the type of a resulting binary expression + given two source types and an operator. For example, two + :class:`.Column` objects, both of the type :class:`.Integer`, will + produce a :class:`.BinaryExpression` that also has the type + :class:`.Integer` when compared via the addition (``+``) operator. + However, using the addition operator with an :class:`.Integer` + and a :class:`.Date` object will produce a :class:`.Date`, assuming + "days delta" behavior by the database (in reality, most databases + other than Postgresql don't accept this particular operation). + + The method returns a tuple of the form <operator>, <type>. + The resulting operator and type will be those applied to the + resulting :class:`.BinaryExpression` as the final operator and the + right-hand side of the expression. + + Note that only a subset of operators make usage of + :meth:`._adapt_expression`, + including math operators and user-defined operators, but not + boolean comparison or special SQL keywords like MATCH or BETWEEN. + + """ + return op, other_comparator.type + def _boolean_compare(self, expr, op, obj, negate=None, reverse=False, **kwargs ): @@ -1912,7 +1951,7 @@ class _DefaultColumnComparator(object): type_=sqltypes.BOOLEANTYPE, negate=negate, modifiers=kwargs) - def _binary_operate(self, expr, op, obj, result_type, reverse=False): + def _binary_operate(self, expr, op, obj, reverse=False): obj = self._check_literal(expr, op, obj) if reverse: @@ -1920,6 +1959,8 @@ class _DefaultColumnComparator(object): else: left, right = expr, obj + op, result_type = left.comparator._adapt_expression(op, right.comparator) + return BinaryExpression(left, right, op, type_=result_type) def _scalar(self, expr, op, fn, **kw): @@ -1986,7 +2027,8 @@ class _DefaultColumnComparator(object): expr, operators.like_op, literal_column("'%'", type_=sqltypes.String).__radd__( - self._check_literal(expr, operators.like_op, other) + self._check_literal(expr, + operators.like_op, other) ), escape=escape) @@ -2068,21 +2110,16 @@ class _DefaultColumnComparator(object): "neg": (_neg_impl,), } - def operate(self, expr, op, *other, **kwargs): - o = self.operators[op.__name__] - return o[0](self, expr, op, *(other + o[1:]), **kwargs) - - def reverse_operate(self, expr, op, other, **kwargs): - o = self.operators[op.__name__] - return o[0](self, expr, op, other, reverse=True, *o[1:], **kwargs) def _check_literal(self, expr, operator, other): - if isinstance(other, BindParameter) and \ - isinstance(other.type, sqltypes.NullType): - # TODO: perhaps we should not mutate the incoming bindparam() - # here and instead make a copy of it. this might - # be the only place that we're mutating an incoming construct. - other.type = expr.type + if isinstance(other, (ColumnElement, TextClause)): + if isinstance(other, BindParameter) and \ + isinstance(other.type, sqltypes.NullType): + # TODO: perhaps we should not mutate the incoming + # bindparam() here and instead make a copy of it. + # this might be the only place that we're mutating + # an incoming construct. + other.type = expr.type return other elif hasattr(other, '__clause_element__'): other = other.__clause_element__() @@ -2096,8 +2133,6 @@ class _DefaultColumnComparator(object): else: return other -_DEFAULT_COMPARATOR = _DefaultColumnComparator() - class ColumnElement(ClauseElement, ColumnOperators): """Represent an element that is usable within the "column clause" portion @@ -2155,11 +2190,7 @@ class ColumnElement(ClauseElement, ColumnOperators): def comparator(self): return self.type.comparator_factory(self) - #def _assert_comparator(self): - # assert self.comparator.expr is self - def __getattr__(self, key): - #self._assert_comparator() try: return getattr(self.comparator, key) except AttributeError: @@ -2171,11 +2202,9 @@ class ColumnElement(ClauseElement, ColumnOperators): ) def operate(self, op, *other, **kwargs): - #self._assert_comparator() return op(self.comparator, *other, **kwargs) def reverse_operate(self, op, other, **kwargs): - #self._assert_comparator() return op(other, self.comparator, **kwargs) def _bind_param(self, operator, obj): @@ -3090,6 +3119,10 @@ class TextClause(Executable, ClauseElement): else: return sqltypes.NULLTYPE + @property + def comparator(self): + return self.type.comparator_factory(self) + def self_group(self, against=None): if against is operators.in_op: return Grouping(self) diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index d4dbd648c..bbeebf5d3 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -11,21 +11,21 @@ types. For more information see the SQLAlchemy documentation on types. """ -__all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType', 'UserDefinedType', - 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'NVARCHAR','TEXT', 'Text', +__all__ = ['TypeEngine', 'TypeDecorator', 'AbstractType', 'UserDefinedType', + 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'NVARCHAR', 'TEXT', 'Text', 'FLOAT', 'NUMERIC', 'REAL', 'DECIMAL', 'TIMESTAMP', 'DATETIME', 'CLOB', 'BLOB', 'BINARY', 'VARBINARY', 'BOOLEAN', 'BIGINT', 'SMALLINT', 'INTEGER', 'DATE', 'TIME', 'String', 'Integer', 'SmallInteger', 'BigInteger', 'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'LargeBinary', 'Binary', 'Boolean', 'Unicode', 'Concatenable', - 'UnicodeText','PickleType', 'Interval', 'Enum' ] + 'UnicodeText', 'PickleType', 'Interval', 'Enum'] import datetime as dt import codecs from . import exc, schema, util, processors, events, event from .sql import operators -from .sql.expression import _DEFAULT_COMPARATOR +from .sql.expression import _DefaultColumnComparator from .util import pickle from .util.compat import decimal from .sql.visitors import Visitable @@ -42,7 +42,7 @@ class AbstractType(Visitable): class TypeEngine(AbstractType): """Base for built-in types.""" - class Comparator(operators.ColumnOperators): + class Comparator(_DefaultColumnComparator): """Base class for custom comparison operations defined at the type level. See :attr:`.TypeEngine.comparator_factory`. @@ -54,24 +54,6 @@ class TypeEngine(AbstractType): def __reduce__(self): return _reconstitute_comparator, (self.expr, ) - def operate(self, op, *other, **kwargs): - if len(other) == 1: - obj = other[0] - obj = _DEFAULT_COMPARATOR._check_literal(self.expr, op, obj) - op, adapt_type = self.expr.type._adapt_expression(op, - obj.type) - kwargs['result_type'] = adapt_type - - return _DEFAULT_COMPARATOR.operate(self.expr, op, *other, **kwargs) - - def reverse_operate(self, op, other, **kwargs): - - obj = _DEFAULT_COMPARATOR._check_literal(self.expr, op, other) - op, adapt_type = obj.type._adapt_expression(op, self.expr.type) - kwargs['result_type'] = adapt_type - - return _DEFAULT_COMPARATOR.reverse_operate(self.expr, op, obj, - **kwargs) comparator_factory = Comparator """A :class:`.TypeEngine.Comparator` class which will apply @@ -143,11 +125,6 @@ class TypeEngine(AbstractType): >>> (c1 == c2).type Boolean() - The propagation of :class:`.TypeEngine.Comparator` throughout an expression - will follow with how the :class:`.TypeEngine` itself is propagated. To - customize the behavior of most operators in this regard, see the - :meth:`._adapt_expression` method. - .. versionadded:: 0.8 The expression system was reworked to support user-defined comparator objects specified at the type level. @@ -247,34 +224,7 @@ class TypeEngine(AbstractType): .. versionadded:: 0.7.2 """ - return Variant(self, {dialect_name:type_}) - - def _adapt_expression(self, op, othertype): - """evaluate the return type of <self> <op> <othertype>, - and apply any adaptations to the given operator. - - This method determines the type of a resulting binary expression - given two source types and an operator. For example, two - :class:`.Column` objects, both of the type :class:`.Integer`, will - produce a :class:`.BinaryExpression` that also has the type - :class:`.Integer` when compared via the addition (``+``) operator. - However, using the addition operator with an :class:`.Integer` - and a :class:`.Date` object will produce a :class:`.Date`, assuming - "days delta" behavior by the database (in reality, most databases - other than Postgresql don't accept this particular operation). - - The method returns a tuple of the form <operator>, <type>. - The resulting operator and type will be those applied to the - resulting :class:`.BinaryExpression` as the final operator and the - right-hand side of the expression. - - Note that only a subset of operators make usage of - :meth:`._adapt_expression`, - including math operators and user-defined operators, but not - boolean comparison or special SQL keywords like MATCH or BETWEEN. - - """ - return op, self + return Variant(self, {dialect_name: type_}) @util.memoized_property def _type_affinity(self): @@ -334,7 +284,7 @@ class TypeEngine(AbstractType): impl = self.adapt(type(self)) # this can't be self, else we create a cycle assert impl is not self - dialect._type_memos[self] = d = {'impl':impl} + dialect._type_memos[self] = d = {'impl': impl} return d def _gen_dialect_impl(self, dialect): @@ -461,22 +411,21 @@ class UserDefinedType(TypeEngine): """ __visit_name__ = "user_defined" - def _adapt_expression(self, op, othertype): - """evaluate the return type of <self> <op> <othertype>, - and apply any adaptations to the given operator. - - """ - return self.adapt_operator(op), self - - def adapt_operator(self, op): - """A hook which allows the given operator to be adapted - to something new. + class Comparator(TypeEngine.Comparator): + def _adapt_expression(self, op, other_comparator): + if hasattr(self.type, 'adapt_operator'): + util.warn_deprecated( + "UserDefinedType.adapt_operator is deprecated. Create " + "a UserDefinedType.Comparator subclass instead which " + "generates the desired expression constructs, given a " + "particular operator." + ) + return self.type.adapt_operator(op), self.type + else: + return op, self.type - See also UserDefinedType._adapt_expression(), an as-yet- - semi-public method with greater capability in this regard. + comparator_factory = Comparator - """ - return op class TypeDecorator(TypeEngine): """Allows the creation of types which add additional functionality @@ -837,13 +786,6 @@ class TypeDecorator(TypeEngine): """ return self.impl.compare_values(x, y) - def _adapt_expression(self, op, othertype): - op, typ = self.impl._adapt_expression(op, othertype) - typ = to_instance(typ) - if typ._compare_type_affinity(self.impl): - return op, self - else: - return op, typ class Variant(TypeDecorator): """A wrapping type that selects among a variety of @@ -926,8 +868,6 @@ def adapt_type(typeobj, colspecs): return typeobj.adapt(impltype) - - class NullType(TypeEngine): """An unknown type. @@ -943,11 +883,14 @@ class NullType(TypeEngine): """ __visit_name__ = 'null' - def _adapt_expression(self, op, othertype): - if isinstance(othertype, NullType) or not operators.is_commutative(op): - return op, self - else: - return othertype._adapt_expression(op, self) + class Comparator(TypeEngine.Comparator): + def _adapt_expression(self, op, other_comparator): + if isinstance(other_comparator, NullType.Comparator) or \ + not operators.is_commutative(op): + return op, self.expr.type + else: + return other_comparator._adapt_expression(op, self) + comparator_factory = Comparator NullTypeEngine = NullType @@ -955,12 +898,16 @@ class Concatenable(object): """A mixin that marks a type as supporting 'concatenation', typically strings.""" - def _adapt_expression(self, op, othertype): - if op is operators.add and issubclass(othertype._type_affinity, - (Concatenable, NullType)): - return operators.concat_op, self - else: - return op, self + class Comparator(TypeEngine.Comparator): + def _adapt_expression(self, op, other_comparator): + if op is operators.add and isinstance(other_comparator, + (Concatenable.Comparator, NullType.Comparator)): + return operators.concat_op, self.expr.type + else: + return op, self.expr.type + + comparator_factory = Comparator + class _DateAffinity(object): """Mixin date/time specific expression adaptations. @@ -975,12 +922,14 @@ class _DateAffinity(object): def _expression_adaptations(self): raise NotImplementedError() - _blank_dict = util.immutabledict() - def _adapt_expression(self, op, othertype): - othertype = othertype._type_affinity - return op, \ - self._expression_adaptations.get(op, self._blank_dict).\ - get(othertype, NULLTYPE) + class Comparator(TypeEngine.Comparator): + _blank_dict = util.immutabledict() + def _adapt_expression(self, op, other_comparator): + othertype = other_comparator.type._type_affinity + return op, \ + self.type._expression_adaptations.get(op, self._blank_dict).\ + get(othertype, NULLTYPE) + comparator_factory = Comparator class String(Concatenable, TypeEngine): """The base for all string and character types. diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index c38f95a01..05de8c9ef 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -12,18 +12,16 @@ from sqlalchemy.types import Integer, TypeEngine, TypeDecorator class DefaultColumnComparatorTest(fixtures.TestBase): def _do_scalar_test(self, operator, compare_to): - cc = _DefaultColumnComparator() left = column('left') - assert cc.operate(left, operator).compare( + assert left.comparator.operate(operator).compare( compare_to(left) ) def _do_operate_test(self, operator): - cc = _DefaultColumnComparator() left = column('left') right = column('right') - assert cc.operate(left, operator, right, result_type=Integer).compare( + assert left.comparator.operate(operator, right).compare( BinaryExpression(left, right, operator) ) @@ -37,9 +35,8 @@ class DefaultColumnComparatorTest(fixtures.TestBase): self._do_operate_test(operators.add) def test_in(self): - cc = _DefaultColumnComparator() left = column('left') - assert cc.operate(left, operators.in_op, [1, 2, 3]).compare( + assert left.comparator.operate(operators.in_op, [1, 2, 3]).compare( BinaryExpression( left, Grouping(ClauseList( @@ -50,10 +47,9 @@ class DefaultColumnComparatorTest(fixtures.TestBase): ) def test_collate(self): - cc = _DefaultColumnComparator() left = column('left') right = "some collation" - cc.operate(left, operators.collate, right).compare( + left.comparator.operate(operators.collate, right).compare( collate(left, right) ) @@ -144,12 +140,8 @@ class _CustomComparatorTests(object): self._assert_add_override(6 - c1) def test_binary_multi_propagate(self): - c1 = Column('foo', self._add_override_factory(True)) - self._assert_add_override((c1 - 6) + 5) - - def test_no_binary_multi_propagate_wo_adapt(self): c1 = Column('foo', self._add_override_factory()) - self._assert_not_add_override((c1 - 6) + 5) + self._assert_add_override((c1 - 6) + 5) def test_no_boolean_propagate(self): c1 = Column('foo', self._add_override_factory()) @@ -166,7 +158,7 @@ class _CustomComparatorTests(object): ) class CustomComparatorTest(_CustomComparatorTests, fixtures.TestBase): - def _add_override_factory(self, include_adapt=False): + def _add_override_factory(self): class MyInteger(Integer): class comparator_factory(TypeEngine.Comparator): @@ -176,19 +168,12 @@ class CustomComparatorTest(_CustomComparatorTests, fixtures.TestBase): def __add__(self, other): return self.expr.op("goofy")(other) - if include_adapt: - def _adapt_expression(self, op, othertype): - if op.__name__ == 'custom_op': - return op, self - else: - return super(MyInteger, self)._adapt_expression( - op, othertype) return MyInteger class TypeDecoratorComparatorTest(_CustomComparatorTests, fixtures.TestBase): - def _add_override_factory(self, include_adapt=False): + def _add_override_factory(self): class MyInteger(TypeDecorator): impl = Integer @@ -200,19 +185,12 @@ class TypeDecoratorComparatorTest(_CustomComparatorTests, fixtures.TestBase): def __add__(self, other): return self.expr.op("goofy")(other) - if include_adapt: - def _adapt_expression(self, op, othertype): - if op.__name__ == 'custom_op': - return op, self - else: - return super(MyInteger, self)._adapt_expression( - op, othertype) return MyInteger class CustomEmbeddedinTypeDecoratorTest(_CustomComparatorTests, fixtures.TestBase): - def _add_override_factory(self, include_adapt=False): + def _add_override_factory(self): class MyInteger(Integer): class comparator_factory(TypeEngine.Comparator): def __init__(self, expr): @@ -221,13 +199,6 @@ class CustomEmbeddedinTypeDecoratorTest(_CustomComparatorTests, fixtures.TestBas def __add__(self, other): return self.expr.op("goofy")(other) - if include_adapt: - def _adapt_expression(self, op, othertype): - if op.__name__ == 'custom_op': - return op, self - else: - return super(MyInteger, self)._adapt_expression( - op, othertype) class MyDecInteger(TypeDecorator): impl = MyInteger @@ -235,7 +206,7 @@ class CustomEmbeddedinTypeDecoratorTest(_CustomComparatorTests, fixtures.TestBas return MyDecInteger class NewOperatorTest(_CustomComparatorTests, fixtures.TestBase): - def _add_override_factory(self, include_adapt=False): + def _add_override_factory(self): class MyInteger(Integer): class comparator_factory(TypeEngine.Comparator): def __init__(self, expr): @@ -243,15 +214,6 @@ class NewOperatorTest(_CustomComparatorTests, fixtures.TestBase): def foob(self, other): return self.expr.op("foob")(other) - - if include_adapt: - def _adapt_expression(self, op, othertype): - if op.__name__ == 'custom_op': - return op, self - else: - return super(MyInteger, self)._adapt_expression( - op, othertype) - return MyInteger def _assert_add_override(self, expr): @@ -262,5 +224,3 @@ class NewOperatorTest(_CustomComparatorTests, fixtures.TestBase): def _assert_not_add_override(self, expr): assert not hasattr(expr, "foob") - def test_no_binary_multi_propagate_wo_adapt(self): - pass
\ No newline at end of file diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 91bf17175..279ae36a0 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -1222,6 +1222,7 @@ class ExpressionTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiled eq_(expr.right.type.__class__, CHAR) + @testing.uses_deprecated @testing.fails_on('firebird', 'Data type unknown on the parameter') @testing.fails_on('mssql', 'int is unsigned ? not clear') def test_operator_adapt(self): |
