diff options
author | Jeong YunWon <jeong@youknowone.org> | 2016-02-13 19:20:12 +0900 |
---|---|---|
committer | Jeong YunWon <jeong@youknowone.org> | 2016-02-13 21:28:50 +0900 |
commit | f7354b43e4c993e9070ad104fcaab424ad2df276 (patch) | |
tree | cca962a32be0245cb30baf6536d64fbeaf18a7b1 | |
parent | 1b6422a603a131c4a400853e472e5958a760f99e (diff) | |
download | sqlalchemy-f7354b43e4c993e9070ad104fcaab424ad2df276.tar.gz |
Add `sqlalchemy.ext.mutable.MutableSet`pr/236
from https://bitbucket.org/zzzeek/sqlalchemy/issues/3297
-rw-r--r-- | lib/sqlalchemy/ext/mutable.py | 65 | ||||
-rw-r--r-- | test/ext/test_mutable.py | 196 |
2 files changed, 260 insertions, 1 deletions
diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py index 0081cf720..aa5be57ff 100644 --- a/lib/sqlalchemy/ext/mutable.py +++ b/lib/sqlalchemy/ext/mutable.py @@ -778,3 +778,68 @@ class MutableList(Mutable, list): def __setstate__(self, state): self[:] = state + + +class MutableSet(Mutable, set): + """A set type that implements :class:`.Mutable`. + + The :class:`.MutableSet` object implements a list that will + emit change events to the underlying mapping when the contents of + the set are altered, including when values are added or removed. + """ + + def update(self, *arg): + set.update(self, *arg) + self.changed() + + def intersection_update(self, *arg): + set.intersection_update(self, *arg) + self.changed() + + def difference_update(self, *arg): + set.difference_update(self, *arg) + self.changed() + + def symmetric_difference_update(self, *arg): + set.symmetric_difference_update(self, *arg) + self.changed() + + def add(self, elem): + set.add(self, elem) + self.changed() + + def remove(self, elem): + set.remove(self, elem) + self.changed() + + def discard(self, elem): + set.discard(self, elem) + self.changed() + + def pop(self, *arg): + result = set.pop(self, *arg) + self.changed() + return result + + def clear(self): + set.clear(self) + self.changed() + + @classmethod + def coerce(cls, index, value): + """Convert plain set to instance of this class.""" + if not isinstance(value, cls): + if isinstance(value, set): + return cls(value) + return Mutable.coerce(index, value) + else: + return value + + def __getstate__(self): + return set(self) + + def __setstate__(self, state): + self.update(state) + + def __reduce_ex__(self, proto): + return (self.__class__, (list(self), )) diff --git a/test/ext/test_mutable.py b/test/ext/test_mutable.py index 7cdf9f12b..1e1a75e7e 100644 --- a/test/ext/test_mutable.py +++ b/test/ext/test_mutable.py @@ -8,7 +8,7 @@ from sqlalchemy.testing import eq_, assert_raises_message, assert_raises from sqlalchemy.testing.util import picklers from sqlalchemy.testing import fixtures from sqlalchemy.ext.mutable import MutableComposite -from sqlalchemy.ext.mutable import MutableDict, MutableList +from sqlalchemy.ext.mutable import MutableDict, MutableList, MutableSet class Foo(fixtures.BasicEntity): @@ -461,6 +461,183 @@ class _MutableListTestBase(_MutableListTestFixture): eq_(f1.data[0], 3) +class _MutableSetTestFixture(object): + @classmethod + def _type_fixture(cls): + return MutableSet + + def teardown(self): + # clear out mapper events + Mapper.dispatch._clear() + ClassManager.dispatch._clear() + super(_MutableSetTestFixture, self).teardown() + + +class _MutableSetTestBase(_MutableSetTestFixture): + run_define_tables = 'each' + + def setup_mappers(cls): + foo = cls.tables.foo + + mapper(Foo, foo) + + def test_coerce_none(self): + sess = Session() + f1 = Foo(data=None) + sess.add(f1) + sess.commit() + eq_(f1.data, None) + + def test_coerce_raise(self): + assert_raises_message( + ValueError, + "Attribute 'data' does not accept objects of type", + Foo, data=[1, 2, 3] + ) + + def test_clear(self): + sess = Session() + + f1 = Foo(data=set([1, 2])) + sess.add(f1) + sess.commit() + + f1.data.clear() + sess.commit() + + eq_(f1.data, set()) + + def test_pop(self): + sess = Session() + + f1 = Foo(data=set([1])) + sess.add(f1) + sess.commit() + + eq_(f1.data.pop(), 1) + sess.commit() + + assert_raises(KeyError, f1.data.pop) + + eq_(f1.data, set()) + + def test_add(self): + sess = Session() + + f1 = Foo(data=set([1, 2])) + sess.add(f1) + sess.commit() + + f1.data.add(5) + sess.commit() + + eq_(f1.data, set([1, 2, 5])) + + def test_update(self): + sess = Session() + + f1 = Foo(data=set([1, 2])) + sess.add(f1) + sess.commit() + + f1.data.update(set([2, 5])) + sess.commit() + + eq_(f1.data, set([1, 2, 5])) + + def test_intersection_update(self): + sess = Session() + + f1 = Foo(data=set([1, 2])) + sess.add(f1) + sess.commit() + + f1.data.intersection_update(set([2, 5])) + sess.commit() + + eq_(f1.data, set([2])) + + def test_difference_update(self): + sess = Session() + + f1 = Foo(data=set([1, 2])) + sess.add(f1) + sess.commit() + + f1.data.difference_update(set([2, 5])) + sess.commit() + + eq_(f1.data, set([1])) + + def test_symmetric_difference_update(self): + sess = Session() + + f1 = Foo(data=set([1, 2])) + sess.add(f1) + sess.commit() + + f1.data.symmetric_difference_update(set([2, 5])) + sess.commit() + + eq_(f1.data, set([1, 5])) + + def test_remove(self): + sess = Session() + + f1 = Foo(data=set([1, 2, 3])) + sess.add(f1) + sess.commit() + + f1.data.remove(2) + sess.commit() + + eq_(f1.data, set([1, 3])) + + def test_discard(self): + sess = Session() + + f1 = Foo(data=set([1, 2, 3])) + sess.add(f1) + sess.commit() + + f1.data.discard(2) + sess.commit() + + eq_(f1.data, set([1, 3])) + + f1.data.discard(2) + sess.commit() + + eq_(f1.data, set([1, 3])) + + def test_pickle_parent(self): + sess = Session() + + f1 = Foo(data=set([1, 2])) + sess.add(f1) + sess.commit() + f1.data + sess.close() + + for loads, dumps in picklers(): + sess = Session() + f2 = loads(dumps(f1)) + sess.add(f2) + f2.data.add(3) + assert f2 in sess.dirty + + def test_unrelated_flush(self): + sess = Session() + f1 = Foo(data=set([1, 2]), unrelated_data="unrelated") + sess.add(f1) + sess.flush() + f1.unrelated_data = "unrelated 2" + sess.flush() + f1.data.add(3) + sess.commit() + eq_(f1.data, set([1, 2, 3])) + + class MutableColumnDefaultTest(_MutableDictTestFixture, fixtures.MappedTest): @classmethod def define_tables(cls, metadata): @@ -566,6 +743,23 @@ class MutableListWithScalarPickleTest(_MutableListTestBase, fixtures.MappedTest) ) +class MutableSetWithScalarPickleTest(_MutableSetTestBase, fixtures.MappedTest): + + @classmethod + def define_tables(cls, metadata): + MutableSet = cls._type_fixture() + + mutable_pickle = MutableSet.as_mutable(PickleType) + Table('foo', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('skip', mutable_pickle), + Column('data', mutable_pickle), + Column('non_mutable_data', PickleType), + Column('unrelated_data', String(50)) + ) + + class MutableAssocWithAttrInheritTest(_MutableDictTestBase, fixtures.MappedTest): |