diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 13 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 88 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 86 | ||||
| -rw-r--r-- | lib/sqlalchemy/types.py | 204 |
4 files changed, 221 insertions, 170 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index cfbef69e8..7d4cbbbd8 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -349,9 +349,16 @@ class PGCompiler(compiler.SQLCompiler): def visit_extract(self, extract, **kwargs): field = self.extract_map.get(extract.field, extract.field) - affinity = sql_util.determine_date_affinity(extract.expr) - - casts = {sqltypes.Date:'date', sqltypes.DateTime:'timestamp', sqltypes.Interval:'interval', sqltypes.Time:'time'} + if extract.expr.type: + affinity = extract.expr.type._type_affinity + else: + affinity = None + + casts = { + sqltypes.Date:'date', + sqltypes.DateTime:'timestamp', + sqltypes.Interval:'interval', sqltypes.Time:'time' + } cast = casts.get(affinity, None) if isinstance(extract.expr, sql.ColumnElement) and cast is not None: expr = extract.expr.op('::')(sql.literal_column(cast)) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 878b0d826..1ae706999 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -407,14 +407,7 @@ def between(ctest, cleft, cright): """ ctest = _literal_as_binds(ctest) - return _BinaryExpression( - ctest, - ClauseList( - _literal_as_binds(cleft, type_=ctest.type), - _literal_as_binds(cright, type_=ctest.type), - operator=operators.and_, - group=False), - operators.between_op) + return ctest.between(cleft, cright) def case(whens, value=None, else_=None): @@ -1453,19 +1446,35 @@ class _CompareMixin(ColumnOperators): obj = self._check_literal(obj) if reverse: - return _BinaryExpression(obj, self, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs) + return _BinaryExpression(obj, + self, + op, + type_=sqltypes.BOOLEANTYPE, + negate=negate, modifiers=kwargs) else: - return _BinaryExpression(self, obj, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs) + return _BinaryExpression(self, + obj, + op, + type_=sqltypes.BOOLEANTYPE, + negate=negate, modifiers=kwargs) def __operate(self, op, obj, reverse=False): obj = self._check_literal(obj) - - type_ = self._compare_type(obj) - + if reverse: - return _BinaryExpression(obj, self, type_.adapt_operator(op), type_=type_) + left, right = obj, self else: - return _BinaryExpression(self, obj, type_.adapt_operator(op), type_=type_) + left, right = self, obj + + if left.type is None: + op, result_type = sqltypes.NULLTYPE._adapt_expression(op, right.type) + elif right.type is None: + op, result_type = left.type._adapt_expression(op, sqltypes.NULLTYPE) + else: + op, result_type = left.type._adapt_expression(op, right.type) + + return _BinaryExpression(left, right, op, type_=result_type) + # a mapping of operators with the method they use, along with their negated # operator for comparison operators @@ -1643,7 +1652,7 @@ class _CompareMixin(ColumnOperators): return lambda other: self.__operate(operator, other) def _bind_param(self, obj): - return _BindParamClause(None, obj, type_=self.type, unique=True) + return _BindParamClause(None, obj, _fallback_type=self.type, unique=True) def _check_literal(self, other): if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType): @@ -1658,14 +1667,6 @@ class _CompareMixin(ColumnOperators): else: return other - def _compare_type(self, obj): - """Allow subclasses to override the type used in constructing - :class:`_BinaryExpression` objects. - - Default return value is the type of the given object. - - """ - return obj.type class ColumnElement(ClauseElement, _CompareMixin): """Represent an element that is usable within the "column clause" portion of a ``SELECT`` statement. @@ -2105,7 +2106,9 @@ class _BindParamClause(ColumnElement): __visit_name__ = 'bindparam' quote = None - def __init__(self, key, value, type_=None, unique=False, isoutparam=False, required=False): + def __init__(self, key, value, type_=None, unique=False, + isoutparam=False, required=False, + _fallback_type=None): """Construct a _BindParamClause. key @@ -2151,12 +2154,12 @@ class _BindParamClause(ColumnElement): self.required = required if type_ is None: - self.type = sqltypes.type_map.get(type(value), sqltypes.NullType)() + self.type = sqltypes.type_map.get(type(value), _fallback_type or sqltypes.NULLTYPE) elif isinstance(type_, type): self.type = type_() else: self.type = type_ - + def _clone(self): c = ClauseElement._clone(self) if self.unique: @@ -2171,12 +2174,6 @@ class _BindParamClause(ColumnElement): def bind_processor(self, dialect): return self.type.dialect_impl(dialect).bind_processor(dialect) - def _compare_type(self, obj): - if not isinstance(self.type, sqltypes.NullType): - return self.type - else: - return obj.type - def compare(self, other, **kw): """Compare this :class:`_BindParamClause` to the given clause.""" @@ -2342,7 +2339,14 @@ class ClauseList(ClauseElement): self.clauses = [ _literal_as_text(clause) for clause in clauses if clause is not None] - + + @util.memoized_property + def type(self): + if self.clauses: + return self.clauses[0].type + else: + return sqltypes.NULLTYPE + def __iter__(self): return iter(self.clauses) @@ -2419,7 +2423,7 @@ class _Tuple(ClauseList, ColumnElement): def _bind_param(self, obj): return _Tuple(*[ - _BindParamClause(None, o, type_=self.type, unique=True) + _BindParamClause(None, o, _fallback_type=self.type, unique=True) for o in obj ]).self_group() @@ -2518,11 +2522,8 @@ class FunctionElement(ColumnElement, FromClause): def execute(self): return select([self]).execute() - def _compare_type(self, obj): - return self.type - def _bind_param(self, obj): - return _BindParamClause(None, obj, type_=self.type, unique=True) + return _BindParamClause(None, obj, _fallback_type=self.type, unique=True) class Function(FunctionElement): @@ -2539,7 +2540,7 @@ class Function(FunctionElement): FunctionElement.__init__(self, *clauses, **kw) def _bind_param(self, obj): - return _BindParamClause(self.name, obj, type_=self.type, unique=True) + return _BindParamClause(self.name, obj, _fallback_type=self.type, unique=True) class _Cast(ColumnElement): @@ -2698,7 +2699,7 @@ class _BinaryExpression(ColumnElement): self.right, self.negate, negate=self.operator, - type_=self.type, + type_=sqltypes.BOOLEANTYPE, modifiers=self.modifiers) else: return super(_BinaryExpression, self)._negate() @@ -3149,7 +3150,7 @@ class ColumnClause(_Immutable, ColumnElement): return [] def _bind_param(self, obj): - return _BindParamClause(self.name, obj, type_=self.type, unique=True) + return _BindParamClause(self.name, obj, _fallback_type=self.type, unique=True) def _make_proxy(self, selectable, name=None, attach=True): # propagate the "is_literal" flag only if we are keeping our name, @@ -3166,9 +3167,6 @@ class ColumnClause(_Immutable, ColumnElement): selectable.columns[c.name] = c return c - def _compare_type(self, obj): - return self.type - class TableClause(_Immutable, FromClause): """Represents a "table" construct. diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 821b3a3d1..43673eaec 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -46,92 +46,6 @@ def find_join_source(clauses, join_to): else: return None, None -_date_affinities = None -def determine_date_affinity(expr): - """Given an expression, determine if it returns 'interval', 'date', or 'datetime'. - - the PG dialect uses this to generate the extract() function. - - It's less than ideal since it basically needs to duplicate PG's - date arithmetic rules. - - Rules are based on http://www.postgresql.org/docs/current/static/functions-datetime.html. - - Returns None if operators other than + or - are detected as well as types - outside of those above. - - """ - - global _date_affinities - if _date_affinities is None: - Date, DateTime, Integer, \ - Numeric, Interval, Time = \ - sqltypes.Date, sqltypes.DateTime,\ - sqltypes.Integer, sqltypes.Numeric,\ - sqltypes.Interval, sqltypes.Time - - _date_affinities = { - operators.add:{ - (Date, Integer):Date, - (Date, Interval):DateTime, - (Date, Time):DateTime, - (Interval, Interval):Interval, - (DateTime, Interval):DateTime, - (Interval, Time):Time, - }, - operators.sub:{ - (Date, Integer):Date, - (Date, Interval):DateTime, - (Time, Time):Interval, - (Time, Interval):Time, - (DateTime, Interval):DateTime, - (Interval, Interval):Interval, - (DateTime, DateTime):Interval, - }, - operators.mul:{ - (Integer, Interval):Interval, - (Interval, Numeric):Interval, - }, - operators.div: { - (Interval, Numeric):Interval - } - } - - if isinstance(expr, expression._BinaryExpression): - if expr.operator not in _date_affinities: - return None - - left_affin, right_affin = \ - determine_date_affinity(expr.left), \ - determine_date_affinity(expr.right) - - if left_affin is None or right_affin is None: - return None - - if operators.is_commutative(expr.operator): - key = tuple(sorted([left_affin, right_affin], key=lambda cls:cls.__name__)) - else: - key = (left_affin, right_affin) - - lookup = _date_affinities[expr.operator] - return lookup.get(key, None) - - # work around the fact that expressions put the wrong type - # on generated bind params when its "datetime + timedelta" - # and similar - if isinstance(expr, expression._BindParamClause): - type_ = sqltypes.type_map.get(type(expr.value), sqltypes.NullType)() - else: - type_ = expr.type - - affinities = set([sqltypes.Date, sqltypes.DateTime, - sqltypes.Interval, sqltypes.Time, sqltypes.Integer]) - - if type_ is not None and type_._type_affinity in affinities: - return type_._type_affinity - else: - return None - def find_tables(clause, check_columns=False, diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index f4d94c918..465454df9 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -26,12 +26,13 @@ from decimal import Decimal as _python_Decimal import codecs from sqlalchemy import exc, schema -from sqlalchemy.sql import expression +from sqlalchemy.sql import expression, operators import sys schema.types = expression.sqltypes =sys.modules['sqlalchemy.types'] from sqlalchemy.util import pickle from sqlalchemy.sql.visitors import Visitable from sqlalchemy import util + NoneType = type(None) if util.jython: import array @@ -95,22 +96,23 @@ class AbstractType(Visitable): """ return None - def adapt_operator(self, op): - """Given an operator from the sqlalchemy.sql.operators package, - translate it to a new operator based on the semantics of this type. - - By default, returns the operator unchanged. - + def _adapt_expression(self, op, othertype): + """evaluate the return type of <self> <op> <othertype>, + and apply any adaptations to the given operator. + """ - return op + return op, self @util.memoized_property def _type_affinity(self): """Return a rudimental 'affinity' value expressing the general class of type.""" - - for i, t in enumerate(self.__class__.__mro__): + + typ = None + for t in self.__class__.__mro__: if t is TypeEngine or t is UserDefinedType: - return self.__class__.__mro__[i - 1] + return typ + elif issubclass(t, TypeEngine): + typ = t else: return self.__class__ @@ -206,6 +208,23 @@ class UserDefinedType(TypeEngine): """ __visit_name__ = "user_defined" + def _adapt_expression(self, op, othertype): + """evaluate the return type of <self> <op> <othertype>, + and apply any adaptations to the given operator. + + """ + return self.adapt_operator(op), self + + def adapt_operator(self, op): + """A hook which allows the given operator to be adapted + to something new. + + See also UserDefinedType._adapt_expression(), an as-yet- + semi-public method with greater capability in this regard. + + """ + return op + class TypeDecorator(AbstractType): """Allows the creation of types which add additional functionality to an existing type. @@ -429,18 +448,41 @@ class NullType(TypeEngine): """ __visit_name__ = 'null' + def _adapt_expression(self, op, othertype): + if othertype is NullType or not operators.is_commutative(op): + return op, self + else: + return othertype._adapt_expression(op, self) + NullTypeEngine = NullType class Concatenable(object): """A mixin that marks a type as supporting 'concatenation', typically strings.""" - def adapt_operator(self, op): - """Converts an add operator to concat.""" - from sqlalchemy.sql import operators - if op is operators.add: - return operators.concat_op + def _adapt_expression(self, op, othertype): + if op is operators.add and isinstance(othertype, (Concatenable, NullType)): + return operators.concat_op, self else: - return op + return op, self + +class _DateAffinity(object): + """Mixin date/time specific expression adaptations. + + Rules are implemented within Date,Time,Interval,DateTime, Numeric, Integer. + Based on http://www.postgresql.org/docs/current/static/functions-datetime.html. + + """ + + @property + def _expression_adaptations(self): + raise NotImplementedError() + + _blank_dict = util.frozendict() + def _adapt_expression(self, op, othertype): + othertype = othertype._type_affinity + return op, \ + self._expression_adaptations.get(op, self._blank_dict).\ + get(othertype, NULLTYPE) class String(Concatenable, TypeEngine): """The base for all string and character types. @@ -673,14 +715,24 @@ class UnicodeText(Text): super(UnicodeText, self).__init__(length=length, **kwargs) -class Integer(TypeEngine): +class Integer(_DateAffinity, TypeEngine): """A type for ``int`` integers.""" __visit_name__ = 'integer' def get_dbapi_type(self, dbapi): return dbapi.NUMBER - + + @util.memoized_property + def _expression_adaptations(self): + return { + operators.add:{ + Date:Date, + }, + operators.mul:{ + Interval:Interval + }, + } class SmallInteger(Integer): """A type for smaller ``int`` integers. @@ -702,7 +754,7 @@ class BigInteger(Integer): __visit_name__ = 'big_integer' -class Numeric(TypeEngine): +class Numeric(_DateAffinity, TypeEngine): """A type for fixed precision numbers. Typically generates DECIMAL or NUMERIC. Returns @@ -776,6 +828,14 @@ class Numeric(TypeEngine): else: return None + @util.memoized_property + def _expression_adaptations(self): + return { + operators.mul:{ + Interval:Interval + }, + } + class Float(Numeric): """A type for ``float`` numbers. @@ -804,7 +864,7 @@ class Float(Numeric): return impltype(precision=self.precision, asdecimal=self.asdecimal) -class DateTime(TypeEngine): +class DateTime(_DateAffinity, TypeEngine): """A type for ``datetime.datetime()`` objects. Date and time types return objects from the Python ``datetime`` @@ -826,8 +886,20 @@ class DateTime(TypeEngine): def get_dbapi_type(self, dbapi): return dbapi.DATETIME + @util.memoized_property + def _expression_adaptations(self): + return { + operators.add:{ + Interval:DateTime, + }, + operators.sub:{ + Interval:DateTime, + DateTime:Interval, + }, + } + -class Date(TypeEngine): +class Date(_DateAffinity,TypeEngine): """A type for ``datetime.date()`` objects.""" __visit_name__ = 'date' @@ -835,8 +907,32 @@ class Date(TypeEngine): def get_dbapi_type(self, dbapi): return dbapi.DATETIME - -class Time(TypeEngine): + @util.memoized_property + def _expression_adaptations(self): + return { + operators.add:{ + Integer:Date, + Interval:DateTime, + Time:DateTime, + }, + operators.sub:{ + # date - integer = date + Integer:Date, + + # date - date = integer. + Date:Integer, + + Interval:DateTime, + + # date - datetime = interval, + # this one is not in the PG docs + # but works + DateTime:Interval, + }, + } + + +class Time(_DateAffinity,TypeEngine): """A type for ``datetime.time()`` objects.""" __visit_name__ = 'time' @@ -850,6 +946,20 @@ class Time(TypeEngine): def get_dbapi_type(self, dbapi): return dbapi.DATETIME + @util.memoized_property + def _expression_adaptations(self): + return { + operators.add:{ + Date:DateTime, + Interval:Time + }, + operators.sub:{ + Time:Interval, + Interval:Time, + }, + } + + class _Binary(TypeEngine): """Define base behavior for binary types.""" @@ -1245,7 +1355,7 @@ class Boolean(TypeEngine, SchemaType): return value and True or False return process -class Interval(TypeDecorator): +class Interval(_DateAffinity, TypeDecorator): """A type for ``datetime.timedelta()`` objects. The Interval type deals with ``datetime.timedelta`` objects. In @@ -1319,10 +1429,31 @@ class Interval(TypeDecorator): return value - epoch return process + @util.memoized_property + def _expression_adaptations(self): + return { + operators.add:{ + Date:DateTime, + Interval:Interval, + DateTime:DateTime, + Time:Time, + }, + operators.sub:{ + Interval:Interval + }, + operators.mul:{ + Numeric:Interval + }, + operators.div: { + Numeric:Interval + } + } + @property def _type_affinity(self): return Interval + class FLOAT(Float): """The SQL FLOAT type.""" @@ -1440,22 +1571,23 @@ class BOOLEAN(Boolean): __visit_name__ = 'BOOLEAN' NULLTYPE = NullType() +BOOLEANTYPE = Boolean() # using VARCHAR/NCHAR so that we dont get the genericized "String" # type which usually resolves to TEXT/CLOB type_map = { - str: String, + str: String(), # Py2K - unicode : String, + unicode : String(), # end Py2K - int : Integer, - float : Numeric, - bool: Boolean, - _python_Decimal : Numeric, - dt.date : Date, - dt.datetime : DateTime, - dt.time : Time, - dt.timedelta : Interval, - NoneType: NullType + int : Integer(), + float : Numeric(), + bool: BOOLEANTYPE, + _python_Decimal : Numeric(), + dt.date : Date(), + dt.datetime : DateTime(), + dt.time : Time(), + dt.timedelta : Interval(), + NoneType: NULLTYPE } |
