summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeong YunWon <jeong@youknowone.org>2016-02-13 19:20:12 +0900
committerJeong YunWon <jeong@youknowone.org>2016-02-13 21:28:50 +0900
commitf7354b43e4c993e9070ad104fcaab424ad2df276 (patch)
treecca962a32be0245cb30baf6536d64fbeaf18a7b1
parent1b6422a603a131c4a400853e472e5958a760f99e (diff)
downloadsqlalchemy-f7354b43e4c993e9070ad104fcaab424ad2df276.tar.gz
Add `sqlalchemy.ext.mutable.MutableSet`pr/236
from https://bitbucket.org/zzzeek/sqlalchemy/issues/3297
-rw-r--r--lib/sqlalchemy/ext/mutable.py65
-rw-r--r--test/ext/test_mutable.py196
2 files changed, 260 insertions, 1 deletions
diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py
index 0081cf720..aa5be57ff 100644
--- a/lib/sqlalchemy/ext/mutable.py
+++ b/lib/sqlalchemy/ext/mutable.py
@@ -778,3 +778,68 @@ class MutableList(Mutable, list):
def __setstate__(self, state):
self[:] = state
+
+
+class MutableSet(Mutable, set):
+ """A set type that implements :class:`.Mutable`.
+
+ The :class:`.MutableSet` object implements a list that will
+ emit change events to the underlying mapping when the contents of
+ the set are altered, including when values are added or removed.
+ """
+
+ def update(self, *arg):
+ set.update(self, *arg)
+ self.changed()
+
+ def intersection_update(self, *arg):
+ set.intersection_update(self, *arg)
+ self.changed()
+
+ def difference_update(self, *arg):
+ set.difference_update(self, *arg)
+ self.changed()
+
+ def symmetric_difference_update(self, *arg):
+ set.symmetric_difference_update(self, *arg)
+ self.changed()
+
+ def add(self, elem):
+ set.add(self, elem)
+ self.changed()
+
+ def remove(self, elem):
+ set.remove(self, elem)
+ self.changed()
+
+ def discard(self, elem):
+ set.discard(self, elem)
+ self.changed()
+
+ def pop(self, *arg):
+ result = set.pop(self, *arg)
+ self.changed()
+ return result
+
+ def clear(self):
+ set.clear(self)
+ self.changed()
+
+ @classmethod
+ def coerce(cls, index, value):
+ """Convert plain set to instance of this class."""
+ if not isinstance(value, cls):
+ if isinstance(value, set):
+ return cls(value)
+ return Mutable.coerce(index, value)
+ else:
+ return value
+
+ def __getstate__(self):
+ return set(self)
+
+ def __setstate__(self, state):
+ self.update(state)
+
+ def __reduce_ex__(self, proto):
+ return (self.__class__, (list(self), ))
diff --git a/test/ext/test_mutable.py b/test/ext/test_mutable.py
index 7cdf9f12b..1e1a75e7e 100644
--- a/test/ext/test_mutable.py
+++ b/test/ext/test_mutable.py
@@ -8,7 +8,7 @@ from sqlalchemy.testing import eq_, assert_raises_message, assert_raises
from sqlalchemy.testing.util import picklers
from sqlalchemy.testing import fixtures
from sqlalchemy.ext.mutable import MutableComposite
-from sqlalchemy.ext.mutable import MutableDict, MutableList
+from sqlalchemy.ext.mutable import MutableDict, MutableList, MutableSet
class Foo(fixtures.BasicEntity):
@@ -461,6 +461,183 @@ class _MutableListTestBase(_MutableListTestFixture):
eq_(f1.data[0], 3)
+class _MutableSetTestFixture(object):
+ @classmethod
+ def _type_fixture(cls):
+ return MutableSet
+
+ def teardown(self):
+ # clear out mapper events
+ Mapper.dispatch._clear()
+ ClassManager.dispatch._clear()
+ super(_MutableSetTestFixture, self).teardown()
+
+
+class _MutableSetTestBase(_MutableSetTestFixture):
+ run_define_tables = 'each'
+
+ def setup_mappers(cls):
+ foo = cls.tables.foo
+
+ mapper(Foo, foo)
+
+ def test_coerce_none(self):
+ sess = Session()
+ f1 = Foo(data=None)
+ sess.add(f1)
+ sess.commit()
+ eq_(f1.data, None)
+
+ def test_coerce_raise(self):
+ assert_raises_message(
+ ValueError,
+ "Attribute 'data' does not accept objects of type",
+ Foo, data=[1, 2, 3]
+ )
+
+ def test_clear(self):
+ sess = Session()
+
+ f1 = Foo(data=set([1, 2]))
+ sess.add(f1)
+ sess.commit()
+
+ f1.data.clear()
+ sess.commit()
+
+ eq_(f1.data, set())
+
+ def test_pop(self):
+ sess = Session()
+
+ f1 = Foo(data=set([1]))
+ sess.add(f1)
+ sess.commit()
+
+ eq_(f1.data.pop(), 1)
+ sess.commit()
+
+ assert_raises(KeyError, f1.data.pop)
+
+ eq_(f1.data, set())
+
+ def test_add(self):
+ sess = Session()
+
+ f1 = Foo(data=set([1, 2]))
+ sess.add(f1)
+ sess.commit()
+
+ f1.data.add(5)
+ sess.commit()
+
+ eq_(f1.data, set([1, 2, 5]))
+
+ def test_update(self):
+ sess = Session()
+
+ f1 = Foo(data=set([1, 2]))
+ sess.add(f1)
+ sess.commit()
+
+ f1.data.update(set([2, 5]))
+ sess.commit()
+
+ eq_(f1.data, set([1, 2, 5]))
+
+ def test_intersection_update(self):
+ sess = Session()
+
+ f1 = Foo(data=set([1, 2]))
+ sess.add(f1)
+ sess.commit()
+
+ f1.data.intersection_update(set([2, 5]))
+ sess.commit()
+
+ eq_(f1.data, set([2]))
+
+ def test_difference_update(self):
+ sess = Session()
+
+ f1 = Foo(data=set([1, 2]))
+ sess.add(f1)
+ sess.commit()
+
+ f1.data.difference_update(set([2, 5]))
+ sess.commit()
+
+ eq_(f1.data, set([1]))
+
+ def test_symmetric_difference_update(self):
+ sess = Session()
+
+ f1 = Foo(data=set([1, 2]))
+ sess.add(f1)
+ sess.commit()
+
+ f1.data.symmetric_difference_update(set([2, 5]))
+ sess.commit()
+
+ eq_(f1.data, set([1, 5]))
+
+ def test_remove(self):
+ sess = Session()
+
+ f1 = Foo(data=set([1, 2, 3]))
+ sess.add(f1)
+ sess.commit()
+
+ f1.data.remove(2)
+ sess.commit()
+
+ eq_(f1.data, set([1, 3]))
+
+ def test_discard(self):
+ sess = Session()
+
+ f1 = Foo(data=set([1, 2, 3]))
+ sess.add(f1)
+ sess.commit()
+
+ f1.data.discard(2)
+ sess.commit()
+
+ eq_(f1.data, set([1, 3]))
+
+ f1.data.discard(2)
+ sess.commit()
+
+ eq_(f1.data, set([1, 3]))
+
+ def test_pickle_parent(self):
+ sess = Session()
+
+ f1 = Foo(data=set([1, 2]))
+ sess.add(f1)
+ sess.commit()
+ f1.data
+ sess.close()
+
+ for loads, dumps in picklers():
+ sess = Session()
+ f2 = loads(dumps(f1))
+ sess.add(f2)
+ f2.data.add(3)
+ assert f2 in sess.dirty
+
+ def test_unrelated_flush(self):
+ sess = Session()
+ f1 = Foo(data=set([1, 2]), unrelated_data="unrelated")
+ sess.add(f1)
+ sess.flush()
+ f1.unrelated_data = "unrelated 2"
+ sess.flush()
+ f1.data.add(3)
+ sess.commit()
+ eq_(f1.data, set([1, 2, 3]))
+
+
class MutableColumnDefaultTest(_MutableDictTestFixture, fixtures.MappedTest):
@classmethod
def define_tables(cls, metadata):
@@ -566,6 +743,23 @@ class MutableListWithScalarPickleTest(_MutableListTestBase, fixtures.MappedTest)
)
+class MutableSetWithScalarPickleTest(_MutableSetTestBase, fixtures.MappedTest):
+
+ @classmethod
+ def define_tables(cls, metadata):
+ MutableSet = cls._type_fixture()
+
+ mutable_pickle = MutableSet.as_mutable(PickleType)
+ Table('foo', metadata,
+ Column('id', Integer, primary_key=True,
+ test_needs_autoincrement=True),
+ Column('skip', mutable_pickle),
+ Column('data', mutable_pickle),
+ Column('non_mutable_data', PickleType),
+ Column('unrelated_data', String(50))
+ )
+
+
class MutableAssocWithAttrInheritTest(_MutableDictTestBase,
fixtures.MappedTest):