diff options
-rw-r--r-- | lib/sqlalchemy/ext/index.py | 185 | ||||
-rw-r--r-- | test/ext/test_index.py | 313 |
2 files changed, 498 insertions, 0 deletions
diff --git a/lib/sqlalchemy/ext/index.py b/lib/sqlalchemy/ext/index.py new file mode 100644 index 000000000..d99f315ff --- /dev/null +++ b/lib/sqlalchemy/ext/index.py @@ -0,0 +1,185 @@ +# ext/index.py +# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Property interface implementations for Indexable columns. + +This is a private module. + +""" +from __future__ import absolute_import + +from ..sql.sqltypes import TypeEngine, JSON +from ..dialects import postgresql +from ..orm.attributes import flag_modified +from ..util.langhelpers import public_factory + + +__all__ = ['index_property', 'json_property'] + + +class IndexPropertyInterface(object): + """Describes an object attribute that corresponds to an indexable column. + + Public constructors are the :func:`.orm.index_property` and + :func:`.orm.json_property` function + + """ + + __slots__ = ( + 'attr_name', 'index', 'default', 'use_column_default_for_none', + 'cast_type') + + class IndexPropertyDefault(object): + def __init__(self, arg): + self.arg = arg + + column_index_mappers = { + JSON: lambda key: str(key), + postgresql.JSON: lambda key: str(key), + postgresql.ARRAY: lambda index: index + 1, # 1-based index in pg + postgresql.HSTORE: lambda key: str(key), + } + + def __init__(self, attr_name, index, **kwargs): + """Provide a sub-column property ingredients for Indexable typed columns. + + An index property subscribe an index of a column with Indexable type. + Use this function to concentrate more on each index of indexable columns. + + See `index_property` or `json_property` for actual properties. + + :param attr_name: + An attritube name of a `Indexable` typed column. + + :param index: + An index with matching type for column's type. + + :param default: + When given, accessing will returns the value if IndexError or + KeyError is raised while accessing by the index. + + :param use_column_default_for_none: + When ``True``, the subscribing column will be automatically set to + its default value. + """ + if 'default' in kwargs: + default = self.IndexPropertyDefault(kwargs.pop('default')) + else: + default = None + use_column_default_for_none = kwargs.pop('use_column_default_for_none', + False) + + if kwargs: + raise TypeError('Unknown parameter(s) for index property: %s' + % kwargs.keys()) + + self.attr_name = attr_name + self.index = index + self.default = default + self.use_column_default_for_none = use_column_default_for_none + + def fget(self, instance): + attr_name = self.attr_name + column_value = getattr(instance, attr_name) + if column_value is None: + if self.use_column_default_for_none and\ + (self.use_column_default_for_none is True or + 'getter' in self.use_column_default_for_none): + column = getattr(instance.__class__, attr_name) + column_value = column.default.arg + elif self.default: + return self.default.arg + try: + return column_value[self.index] + except (KeyError, IndexError): + if self.default: + return self.default.arg + raise + + def fset(self, instance, value): + attr_name = self.attr_name + column_value = getattr(instance, attr_name) + if column_value is None: + if self.use_column_default_for_none and\ + (self.use_column_default_for_none is True or + 'setter' in self.use_column_default_for_none): + column = getattr(instance.__class__, attr_name) + column_value = column.default.arg + column_value[self.index] = value + setattr(instance, attr_name, column_value) + flag_modified(instance, attr_name) + + def fdel(self, instance): + attr_name = self.attr_name + column_value = getattr(instance, attr_name) + if column_value is None: + if self.use_column_default_for_none and \ + (self.use_column_default_for_none is True or + 'deleter' in self.use_column_default_for_none): + column = getattr(self.__class__, attr_name) + column_value = column.default.arg + del column_value[self.index] + setattr(instance, attr_name, column_value) + flag_modified(instance, attr_name) + + def expr(self, model): + column = getattr(model, self.attr_name) + + index = self.index + column_type = type(column.type) + column_index_mapper = self.column_index_mappers.get(column_type, None) + if column_index_mapper: + index = column_index_mapper(index) + expr = column[index] + return expr + + def json_expr(self, model): + expr = self.expr(model) + if self.cast_type is not None: + expr = expr.astext.cast(self.cast_type) + return expr + + @classmethod + def property(cls, column, index, **kwargs): + from sqlalchemy.ext.hybrid import hybrid_property + + mutable = kwargs.pop('mutable', False) + interface = cls(column, index, **kwargs) + if mutable: + property = hybrid_property(interface.fget, interface.fset, + interface.fdel, interface.expr) + else: + property = hybrid_property(interface.fget, None, None, + interface.expr) + return property + + @classmethod + def json_property(cls, column, index, **kwargs): + from sqlalchemy.ext.hybrid import hybrid_property + + cast_type = kwargs.pop('cast_type', None) + if cast_type is not None: + if not (isinstance(cast_type, TypeEngine) or callable(cast_type)): + raise TypeError("'cast_type' must be a schema type but '%s' " + "found" % cast_type) + + mutable = kwargs.pop('mutable', False) + interface = cls(column, index, **kwargs) + interface.cast_type = cast_type + if mutable: + property = hybrid_property(interface.fget, interface.fset, + interface.fdel, interface.json_expr) + else: + property = hybrid_property(interface.fget, None, None, + interface.json_expr) + return property + + +index_property = public_factory(IndexPropertyInterface.property, + ".ext.index.index_property") +json_property = public_factory(IndexPropertyInterface.json_property, + ".ext.index.json_property") diff --git a/test/ext/test_index.py b/test/ext/test_index.py new file mode 100644 index 000000000..557542283 --- /dev/null +++ b/test/ext/test_index.py @@ -0,0 +1,313 @@ +from sqlalchemy.testing import assert_raises +import sqlalchemy as sa +from sqlalchemy import testing +from sqlalchemy import Integer, Text +from sqlalchemy.sql.sqltypes import ARRAY, JSON +from sqlalchemy.testing.schema import Column +from sqlalchemy.orm import Session +from sqlalchemy.testing import fixtures +from sqlalchemy.ext.index import index_property, json_property + + +class IndexPropertyTest(fixtures.TestBase): + + def test_array(self): + from sqlalchemy.ext.declarative import declarative_base + Base = declarative_base() + + class A(Base): + __tablename__ = 'a' + id = Column('id', Integer, primary_key=True) + array = Column('_array', ARRAY(Integer), + default=[]) + first = index_property('array', 0, mutable=True) + + a = A(array=[1, 2, 3]) + assert a.first == 1 + a.first = 100 + assert a.first == 100 + assert a.array == [100, 2, 3] + + def test_json(self): + from sqlalchemy.ext.declarative import declarative_base + Base = declarative_base() + + class J(Base): + __tablename__ = 'j' + id = Column('id', Integer, primary_key=True) + json = Column('_json', JSON, default={}) + field = index_property('json', 'field', default=None, mutable=True) + + j = J(json={'a': 1, 'b': 2}) + assert j.field is None + j.field = 'test' + assert j.field == 'test' + assert j.json == {'a': 1, 'b': 2, 'field': 'test'} + + def test_column_key(self): + from sqlalchemy.ext.declarative import declarative_base + Base = declarative_base() + + class A(Base): + __tablename__ = 'a' + id = Column('id', Integer, primary_key=True) + array = Column('_array', ARRAY(Integer), + default=[]) + first = index_property('array', 0, mutable=True) + + a = A(array=[1, 2, 3]) + assert a.first == 1 + a.first = 100 + assert a.first == 100 + assert a.array == [100, 2, 3] + + def test_get_no_column_default(self): + from sqlalchemy.ext.declarative import declarative_base + Base = declarative_base() + + class A(Base): + __tablename__ = 'a' + id = Column('id', Integer, primary_key=True) + array = Column('_array', ARRAY(Integer)) + first = index_property('array', 0) + + a = A() + assert_raises(TypeError, lambda: a.first) + + def test_get_no_index_default(self): + from sqlalchemy.ext.declarative import declarative_base + Base = declarative_base() + + class A(Base): + __tablename__ = 'a' + id = Column('id', Integer, primary_key=True) + array = Column('_array', ARRAY(Integer)) + first = index_property('array', 0) + + a = A(array=[]) + assert_raises(IndexError, lambda: a.first) + + def test_get_index_default(self): + from sqlalchemy.ext.declarative import declarative_base + Base = declarative_base() + + class A(Base): + __tablename__ = 'a' + id = Column(Integer, primary_key=True) + array = Column(ARRAY(Integer)) + first = index_property('array', 0, default=5) + + a = A() + assert a.first == 5 + + def test_set_immutable(self): + from sqlalchemy.ext.declarative import declarative_base + Base = declarative_base() + + class A(Base): + __tablename__ = 'a' + id = Column(Integer, primary_key=True) + array = Column(ARRAY(Integer)) + first = index_property('array', 0) + + a = A() + + def set(): + a.first = 10 + assert_raises(AttributeError, set) + + def test_set_mutable_dict(self): + from sqlalchemy.ext.declarative import declarative_base + Base = declarative_base() + + class J(Base): + __tablename__ = 'j' + id = Column(Integer, primary_key=True) + json = Column(JSON, default={}) + field = index_property('json', 'field', default=None, mutable=True) + + j = J() + + def set(): + j.field = 10 + + assert_raises(TypeError, set) + + j.json = {} + assert j.field is None + set() + assert j.field == 10 + + def test_set_without_column_default(self): + from sqlalchemy.ext.declarative import declarative_base + Base = declarative_base() + + class A(Base): + __tablename__ = 'a' + id = Column(Integer, primary_key=True) + array = Column(ARRAY(Integer)) + first = index_property('array', 0, mutable=True) + + a = A() + + def set(): + a.first = 10 + assert_raises(TypeError, set) + + a.array = [] + assert_raises(IndexError, set) + + a.array = [42] + assert a.first == 42 + set() + assert a.first == 10 + + +class IndexPropertyPostgresqlTest(fixtures.DeclarativeMappedTest): + + __only_on__ = 'postgresql' + __backend__ = True + + @classmethod + def setup_classes(cls): + from sqlalchemy.dialects.postgresql import ARRAY, JSON + + Base = cls.DeclarativeBasic + + class Array(fixtures.ComparableEntity, Base): + __tablename__ = "array" + + id = Column(sa.Integer, primary_key=True, + test_needs_autoincrement=True) + array = Column(ARRAY(Integer), default=[]) + first = index_property('array', 0, default=None) + mutable = index_property('array', 0, default=None, mutable=True) + + class Json(fixtures.ComparableEntity, Base): + __tablename__ = "json" + + id = Column(sa.Integer, primary_key=True, + test_needs_autoincrement=True) + json = Column(JSON, default={}) + field = index_property('json', 'field', default=None) + json_field = json_property('json', 'field') + int_field = json_property('json', 'field', cast_type=Integer) + text_field = json_property('json', 'field', cast_type=Text) + other = index_property('json', 'other', mutable=True, + use_column_default_for_none=True) + + Base.metadata.drop_all() + Base.metadata.create_all() + + def test_query_array(self): + Array = self.classes.Array + s = Session(testing.db) + + s.add_all([ + Array(), + Array(array=[1, 2, 3]), + Array(array=[4, 5, 6])]) + s.commit() + + a1 = s.query(Array).filter(Array.array == [1, 2, 3]).one() + a2 = s.query(Array).filter(Array.first == 1).one() + assert a1.id == a2.id + a3 = s.query(Array).filter(Array.first == 4).one() + assert a1.id != a3.id + + def test_query_json(self): + Json = self.classes.Json + s = Session(testing.db) + + s.add_all([ + Json(), + Json(json={'field': 10}), + Json(json={'field': 20})]) + s.commit() + + a1 = s.query(Json).filter(Json.json['field'].astext.cast(Integer) == 10)\ + .one() + a2 = s.query(Json).filter(Json.field.astext == '10').one() + assert a1.id == a2.id + a3 = s.query(Json).filter(Json.field.astext == '20').one() + assert a1.id != a3.id + + def test_mutable_array(self): + Array = self.classes.Array + s = Session(testing.db) + + a = Array(array=[1, 2, 3]) + s.add(a) + s.commit() + + a.mutable = 42 + assert a.first == 42 + s.commit() + assert a.first == 42 + + def test_mutable_json(self): + Json = self.classes.Json + s = Session(testing.db) + + j = Json(json={}) + s.add(j) + s.commit() + + j.other = 42 + assert j.other == 42 + s.commit() + assert j.other == 42 + + def test_set_column_default(self): + Json = self.classes.Json + j = Json() + + assert_raises(KeyError, lambda: j.other) + j.other = 42 + assert j.other == 42 + assert j.json == {'other': 42} + + def test_modified(self): + from sqlalchemy import inspect + + Array = self.classes.Array + s = Session(testing.db) + + a = Array(array=[1, 2, 3]) + s.add(a) + s.commit() + + i = inspect(a) + assert not i.modified + assert 'array' in i.unmodified + + a.mutable = 10 + + assert i.modified + assert 'array' not in i.unmodified + + def test_json_type(self): + Json = self.classes.Json + s = Session(testing.db) + + j = Json(json={'field': 10}) + s.add(j) + s.commit() + + jq = s.query(Json).filter(Json.int_field == 10).one() + assert j.id == jq.id + + jq = s.query(Json).filter(Json.text_field == '10').one() + assert j.id == jq.id + + jq = s.query(Json).filter(Json.json_field.astext == '10').one() + assert j.id == jq.id + + jq = s.query(Json).filter(Json.text_field == 'wrong').first() + assert jq is None + + j.json = {'field': True} + s.commit() + + jq = s.query(Json).filter(Json.text_field == 'true').one() + assert j.id == jq.id |