summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2012-03-05 10:24:15 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2012-03-05 10:24:15 -0500
commit8007834c97e938912dfd54b342d10e4b9c0a6095 (patch)
tree43050b3d5e8d1f7e5f413217c0d70710328820c5 /lib/sqlalchemy/orm
parent66377aaeafee34767c34e14d9e354aa3bd41372f (diff)
downloadsqlalchemy-8007834c97e938912dfd54b342d10e4b9c0a6095.tar.gz
- [bug] Fixed bug whereby objects using
attribute_mapped_collection or column_mapped_collection could not be pickled. [ticket:2409]
Diffstat (limited to 'lib/sqlalchemy/orm')
-rw-r--r--lib/sqlalchemy/orm/collections.py53
1 files changed, 38 insertions, 15 deletions
diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py
index 2eebfbca2..160fac8be 100644
--- a/lib/sqlalchemy/orm/collections.py
+++ b/lib/sqlalchemy/orm/collections.py
@@ -112,12 +112,32 @@ from sqlalchemy.sql import expression
from sqlalchemy import schema, util, exc as sa_exc
+
__all__ = ['collection', 'collection_adapter',
'mapped_collection', 'column_mapped_collection',
'attribute_mapped_collection']
__instrumentation_mutex = util.threading.Lock()
+class _SerializableColumnGetter(object):
+ def __init__(self, colkeys):
+ self.colkeys = colkeys
+ self.composite = len(colkeys) > 1
+
+ def __reduce__(self):
+ return _SerializableColumnGetter, (self.colkeys,)
+
+ def __call__(self, value):
+ state = instance_state(value)
+ m = _state_mapper(state)
+ key = [m._get_state_attr_by_column(
+ state, state.dict,
+ m.mapped_table.columns[k])
+ for k in self.colkeys]
+ if self.composite:
+ return tuple(key)
+ else:
+ return key[0]
def column_mapped_collection(mapping_spec):
"""A dictionary-based collection type with column-based keying.
@@ -131,25 +151,27 @@ def column_mapped_collection(mapping_spec):
after a session flush.
"""
+ global _state_mapper, instance_state
from sqlalchemy.orm.util import _state_mapper
from sqlalchemy.orm.attributes import instance_state
- cols = [expression._only_column_elements(q, "mapping_spec")
- for q in util.to_list(mapping_spec)]
- if len(cols) == 1:
- def keyfunc(value):
- state = instance_state(value)
- m = _state_mapper(state)
- return m._get_state_attr_by_column(state, state.dict, cols[0])
- else:
- mapping_spec = tuple(cols)
- def keyfunc(value):
- state = instance_state(value)
- m = _state_mapper(state)
- return tuple(m._get_state_attr_by_column(state, state.dict, c)
- for c in mapping_spec)
+ cols = [c.key for c in [
+ expression._only_column_elements(q, "mapping_spec")
+ for q in util.to_list(mapping_spec)]]
+ keyfunc = _SerializableColumnGetter(cols)
return lambda: MappedCollection(keyfunc)
+class _SerializableAttrGetter(object):
+ def __init__(self, name):
+ self.name = name
+ self.getter = operator.attrgetter(name)
+
+ def __call__(self, target):
+ return self.getter(target)
+
+ def __reduce__(self):
+ return _SerializableAttrGetter, (self.name, )
+
def attribute_mapped_collection(attr_name):
"""A dictionary-based collection type with attribute-based keying.
@@ -163,7 +185,8 @@ def attribute_mapped_collection(attr_name):
after a session flush.
"""
- return lambda: MappedCollection(operator.attrgetter(attr_name))
+ getter = _SerializableAttrGetter(attr_name)
+ return lambda: MappedCollection(getter)
def mapped_collection(keyfunc):