summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/ext/index.py185
-rw-r--r--test/ext/test_index.py313
2 files changed, 498 insertions, 0 deletions
diff --git a/lib/sqlalchemy/ext/index.py b/lib/sqlalchemy/ext/index.py
new file mode 100644
index 000000000..d99f315ff
--- /dev/null
+++ b/lib/sqlalchemy/ext/index.py
@@ -0,0 +1,185 @@
+# ext/index.py
+# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""Property interface implementations for Indexable columns.
+
+This is a private module.
+
+"""
+from __future__ import absolute_import
+
+from ..sql.sqltypes import TypeEngine, JSON
+from ..dialects import postgresql
+from ..orm.attributes import flag_modified
+from ..util.langhelpers import public_factory
+
+
+__all__ = ['index_property', 'json_property']
+
+
+class IndexPropertyInterface(object):
+ """Describes an object attribute that corresponds to an indexable column.
+
+ Public constructors are the :func:`.orm.index_property` and
+ :func:`.orm.json_property` function
+
+ """
+
+ __slots__ = (
+ 'attr_name', 'index', 'default', 'use_column_default_for_none',
+ 'cast_type')
+
+ class IndexPropertyDefault(object):
+ def __init__(self, arg):
+ self.arg = arg
+
+ column_index_mappers = {
+ JSON: lambda key: str(key),
+ postgresql.JSON: lambda key: str(key),
+ postgresql.ARRAY: lambda index: index + 1, # 1-based index in pg
+ postgresql.HSTORE: lambda key: str(key),
+ }
+
+ def __init__(self, attr_name, index, **kwargs):
+ """Provide a sub-column property ingredients for Indexable typed columns.
+
+ An index property subscribe an index of a column with Indexable type.
+ Use this function to concentrate more on each index of indexable columns.
+
+ See `index_property` or `json_property` for actual properties.
+
+ :param attr_name:
+ An attritube name of a `Indexable` typed column.
+
+ :param index:
+ An index with matching type for column's type.
+
+ :param default:
+ When given, accessing will returns the value if IndexError or
+ KeyError is raised while accessing by the index.
+
+ :param use_column_default_for_none:
+ When ``True``, the subscribing column will be automatically set to
+ its default value.
+ """
+ if 'default' in kwargs:
+ default = self.IndexPropertyDefault(kwargs.pop('default'))
+ else:
+ default = None
+ use_column_default_for_none = kwargs.pop('use_column_default_for_none',
+ False)
+
+ if kwargs:
+ raise TypeError('Unknown parameter(s) for index property: %s'
+ % kwargs.keys())
+
+ self.attr_name = attr_name
+ self.index = index
+ self.default = default
+ self.use_column_default_for_none = use_column_default_for_none
+
+ def fget(self, instance):
+ attr_name = self.attr_name
+ column_value = getattr(instance, attr_name)
+ if column_value is None:
+ if self.use_column_default_for_none and\
+ (self.use_column_default_for_none is True or
+ 'getter' in self.use_column_default_for_none):
+ column = getattr(instance.__class__, attr_name)
+ column_value = column.default.arg
+ elif self.default:
+ return self.default.arg
+ try:
+ return column_value[self.index]
+ except (KeyError, IndexError):
+ if self.default:
+ return self.default.arg
+ raise
+
+ def fset(self, instance, value):
+ attr_name = self.attr_name
+ column_value = getattr(instance, attr_name)
+ if column_value is None:
+ if self.use_column_default_for_none and\
+ (self.use_column_default_for_none is True or
+ 'setter' in self.use_column_default_for_none):
+ column = getattr(instance.__class__, attr_name)
+ column_value = column.default.arg
+ column_value[self.index] = value
+ setattr(instance, attr_name, column_value)
+ flag_modified(instance, attr_name)
+
+ def fdel(self, instance):
+ attr_name = self.attr_name
+ column_value = getattr(instance, attr_name)
+ if column_value is None:
+ if self.use_column_default_for_none and \
+ (self.use_column_default_for_none is True or
+ 'deleter' in self.use_column_default_for_none):
+ column = getattr(self.__class__, attr_name)
+ column_value = column.default.arg
+ del column_value[self.index]
+ setattr(instance, attr_name, column_value)
+ flag_modified(instance, attr_name)
+
+ def expr(self, model):
+ column = getattr(model, self.attr_name)
+
+ index = self.index
+ column_type = type(column.type)
+ column_index_mapper = self.column_index_mappers.get(column_type, None)
+ if column_index_mapper:
+ index = column_index_mapper(index)
+ expr = column[index]
+ return expr
+
+ def json_expr(self, model):
+ expr = self.expr(model)
+ if self.cast_type is not None:
+ expr = expr.astext.cast(self.cast_type)
+ return expr
+
+ @classmethod
+ def property(cls, column, index, **kwargs):
+ from sqlalchemy.ext.hybrid import hybrid_property
+
+ mutable = kwargs.pop('mutable', False)
+ interface = cls(column, index, **kwargs)
+ if mutable:
+ property = hybrid_property(interface.fget, interface.fset,
+ interface.fdel, interface.expr)
+ else:
+ property = hybrid_property(interface.fget, None, None,
+ interface.expr)
+ return property
+
+ @classmethod
+ def json_property(cls, column, index, **kwargs):
+ from sqlalchemy.ext.hybrid import hybrid_property
+
+ cast_type = kwargs.pop('cast_type', None)
+ if cast_type is not None:
+ if not (isinstance(cast_type, TypeEngine) or callable(cast_type)):
+ raise TypeError("'cast_type' must be a schema type but '%s' "
+ "found" % cast_type)
+
+ mutable = kwargs.pop('mutable', False)
+ interface = cls(column, index, **kwargs)
+ interface.cast_type = cast_type
+ if mutable:
+ property = hybrid_property(interface.fget, interface.fset,
+ interface.fdel, interface.json_expr)
+ else:
+ property = hybrid_property(interface.fget, None, None,
+ interface.json_expr)
+ return property
+
+
+index_property = public_factory(IndexPropertyInterface.property,
+ ".ext.index.index_property")
+json_property = public_factory(IndexPropertyInterface.json_property,
+ ".ext.index.json_property")
diff --git a/test/ext/test_index.py b/test/ext/test_index.py
new file mode 100644
index 000000000..557542283
--- /dev/null
+++ b/test/ext/test_index.py
@@ -0,0 +1,313 @@
+from sqlalchemy.testing import assert_raises
+import sqlalchemy as sa
+from sqlalchemy import testing
+from sqlalchemy import Integer, Text
+from sqlalchemy.sql.sqltypes import ARRAY, JSON
+from sqlalchemy.testing.schema import Column
+from sqlalchemy.orm import Session
+from sqlalchemy.testing import fixtures
+from sqlalchemy.ext.index import index_property, json_property
+
+
+class IndexPropertyTest(fixtures.TestBase):
+
+ def test_array(self):
+ from sqlalchemy.ext.declarative import declarative_base
+ Base = declarative_base()
+
+ class A(Base):
+ __tablename__ = 'a'
+ id = Column('id', Integer, primary_key=True)
+ array = Column('_array', ARRAY(Integer),
+ default=[])
+ first = index_property('array', 0, mutable=True)
+
+ a = A(array=[1, 2, 3])
+ assert a.first == 1
+ a.first = 100
+ assert a.first == 100
+ assert a.array == [100, 2, 3]
+
+ def test_json(self):
+ from sqlalchemy.ext.declarative import declarative_base
+ Base = declarative_base()
+
+ class J(Base):
+ __tablename__ = 'j'
+ id = Column('id', Integer, primary_key=True)
+ json = Column('_json', JSON, default={})
+ field = index_property('json', 'field', default=None, mutable=True)
+
+ j = J(json={'a': 1, 'b': 2})
+ assert j.field is None
+ j.field = 'test'
+ assert j.field == 'test'
+ assert j.json == {'a': 1, 'b': 2, 'field': 'test'}
+
+ def test_column_key(self):
+ from sqlalchemy.ext.declarative import declarative_base
+ Base = declarative_base()
+
+ class A(Base):
+ __tablename__ = 'a'
+ id = Column('id', Integer, primary_key=True)
+ array = Column('_array', ARRAY(Integer),
+ default=[])
+ first = index_property('array', 0, mutable=True)
+
+ a = A(array=[1, 2, 3])
+ assert a.first == 1
+ a.first = 100
+ assert a.first == 100
+ assert a.array == [100, 2, 3]
+
+ def test_get_no_column_default(self):
+ from sqlalchemy.ext.declarative import declarative_base
+ Base = declarative_base()
+
+ class A(Base):
+ __tablename__ = 'a'
+ id = Column('id', Integer, primary_key=True)
+ array = Column('_array', ARRAY(Integer))
+ first = index_property('array', 0)
+
+ a = A()
+ assert_raises(TypeError, lambda: a.first)
+
+ def test_get_no_index_default(self):
+ from sqlalchemy.ext.declarative import declarative_base
+ Base = declarative_base()
+
+ class A(Base):
+ __tablename__ = 'a'
+ id = Column('id', Integer, primary_key=True)
+ array = Column('_array', ARRAY(Integer))
+ first = index_property('array', 0)
+
+ a = A(array=[])
+ assert_raises(IndexError, lambda: a.first)
+
+ def test_get_index_default(self):
+ from sqlalchemy.ext.declarative import declarative_base
+ Base = declarative_base()
+
+ class A(Base):
+ __tablename__ = 'a'
+ id = Column(Integer, primary_key=True)
+ array = Column(ARRAY(Integer))
+ first = index_property('array', 0, default=5)
+
+ a = A()
+ assert a.first == 5
+
+ def test_set_immutable(self):
+ from sqlalchemy.ext.declarative import declarative_base
+ Base = declarative_base()
+
+ class A(Base):
+ __tablename__ = 'a'
+ id = Column(Integer, primary_key=True)
+ array = Column(ARRAY(Integer))
+ first = index_property('array', 0)
+
+ a = A()
+
+ def set():
+ a.first = 10
+ assert_raises(AttributeError, set)
+
+ def test_set_mutable_dict(self):
+ from sqlalchemy.ext.declarative import declarative_base
+ Base = declarative_base()
+
+ class J(Base):
+ __tablename__ = 'j'
+ id = Column(Integer, primary_key=True)
+ json = Column(JSON, default={})
+ field = index_property('json', 'field', default=None, mutable=True)
+
+ j = J()
+
+ def set():
+ j.field = 10
+
+ assert_raises(TypeError, set)
+
+ j.json = {}
+ assert j.field is None
+ set()
+ assert j.field == 10
+
+ def test_set_without_column_default(self):
+ from sqlalchemy.ext.declarative import declarative_base
+ Base = declarative_base()
+
+ class A(Base):
+ __tablename__ = 'a'
+ id = Column(Integer, primary_key=True)
+ array = Column(ARRAY(Integer))
+ first = index_property('array', 0, mutable=True)
+
+ a = A()
+
+ def set():
+ a.first = 10
+ assert_raises(TypeError, set)
+
+ a.array = []
+ assert_raises(IndexError, set)
+
+ a.array = [42]
+ assert a.first == 42
+ set()
+ assert a.first == 10
+
+
+class IndexPropertyPostgresqlTest(fixtures.DeclarativeMappedTest):
+
+ __only_on__ = 'postgresql'
+ __backend__ = True
+
+ @classmethod
+ def setup_classes(cls):
+ from sqlalchemy.dialects.postgresql import ARRAY, JSON
+
+ Base = cls.DeclarativeBasic
+
+ class Array(fixtures.ComparableEntity, Base):
+ __tablename__ = "array"
+
+ id = Column(sa.Integer, primary_key=True,
+ test_needs_autoincrement=True)
+ array = Column(ARRAY(Integer), default=[])
+ first = index_property('array', 0, default=None)
+ mutable = index_property('array', 0, default=None, mutable=True)
+
+ class Json(fixtures.ComparableEntity, Base):
+ __tablename__ = "json"
+
+ id = Column(sa.Integer, primary_key=True,
+ test_needs_autoincrement=True)
+ json = Column(JSON, default={})
+ field = index_property('json', 'field', default=None)
+ json_field = json_property('json', 'field')
+ int_field = json_property('json', 'field', cast_type=Integer)
+ text_field = json_property('json', 'field', cast_type=Text)
+ other = index_property('json', 'other', mutable=True,
+ use_column_default_for_none=True)
+
+ Base.metadata.drop_all()
+ Base.metadata.create_all()
+
+ def test_query_array(self):
+ Array = self.classes.Array
+ s = Session(testing.db)
+
+ s.add_all([
+ Array(),
+ Array(array=[1, 2, 3]),
+ Array(array=[4, 5, 6])])
+ s.commit()
+
+ a1 = s.query(Array).filter(Array.array == [1, 2, 3]).one()
+ a2 = s.query(Array).filter(Array.first == 1).one()
+ assert a1.id == a2.id
+ a3 = s.query(Array).filter(Array.first == 4).one()
+ assert a1.id != a3.id
+
+ def test_query_json(self):
+ Json = self.classes.Json
+ s = Session(testing.db)
+
+ s.add_all([
+ Json(),
+ Json(json={'field': 10}),
+ Json(json={'field': 20})])
+ s.commit()
+
+ a1 = s.query(Json).filter(Json.json['field'].astext.cast(Integer) == 10)\
+ .one()
+ a2 = s.query(Json).filter(Json.field.astext == '10').one()
+ assert a1.id == a2.id
+ a3 = s.query(Json).filter(Json.field.astext == '20').one()
+ assert a1.id != a3.id
+
+ def test_mutable_array(self):
+ Array = self.classes.Array
+ s = Session(testing.db)
+
+ a = Array(array=[1, 2, 3])
+ s.add(a)
+ s.commit()
+
+ a.mutable = 42
+ assert a.first == 42
+ s.commit()
+ assert a.first == 42
+
+ def test_mutable_json(self):
+ Json = self.classes.Json
+ s = Session(testing.db)
+
+ j = Json(json={})
+ s.add(j)
+ s.commit()
+
+ j.other = 42
+ assert j.other == 42
+ s.commit()
+ assert j.other == 42
+
+ def test_set_column_default(self):
+ Json = self.classes.Json
+ j = Json()
+
+ assert_raises(KeyError, lambda: j.other)
+ j.other = 42
+ assert j.other == 42
+ assert j.json == {'other': 42}
+
+ def test_modified(self):
+ from sqlalchemy import inspect
+
+ Array = self.classes.Array
+ s = Session(testing.db)
+
+ a = Array(array=[1, 2, 3])
+ s.add(a)
+ s.commit()
+
+ i = inspect(a)
+ assert not i.modified
+ assert 'array' in i.unmodified
+
+ a.mutable = 10
+
+ assert i.modified
+ assert 'array' not in i.unmodified
+
+ def test_json_type(self):
+ Json = self.classes.Json
+ s = Session(testing.db)
+
+ j = Json(json={'field': 10})
+ s.add(j)
+ s.commit()
+
+ jq = s.query(Json).filter(Json.int_field == 10).one()
+ assert j.id == jq.id
+
+ jq = s.query(Json).filter(Json.text_field == '10').one()
+ assert j.id == jq.id
+
+ jq = s.query(Json).filter(Json.json_field.astext == '10').one()
+ assert j.id == jq.id
+
+ jq = s.query(Json).filter(Json.text_field == 'wrong').first()
+ assert jq is None
+
+ j.json = {'field': True}
+ s.commit()
+
+ jq = s.query(Json).filter(Json.text_field == 'true').one()
+ assert j.id == jq.id