diff options
| -rw-r--r-- | lib/sqlalchemy/orm/properties.py | 185 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 53 | ||||
| -rw-r--r-- | test/orm/test_mapper.py | 33 |
3 files changed, 122 insertions, 149 deletions
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index d1b725c8e..f16376ce6 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -162,80 +162,6 @@ class ColumnProperty(StrategizedProperty): log.class_logger(ColumnProperty) -class CompositeProperty(ColumnProperty): - """subclasses ColumnProperty to provide composite type support.""" - - def __init__(self, class_, *columns, **kwargs): - super(CompositeProperty, self).__init__(*columns, **kwargs) - self._col_position_map = util.column_dict( - (c, i) for i, c - in enumerate(columns)) - self.composite_class = class_ - self.strategy_class = strategies.CompositeColumnLoader - - def copy(self): - return CompositeProperty( - deferred=self.deferred, - group=self.group, - composite_class=self.composite_class, - active_history=self.active_history, - *self.columns) - - def do_init(self): - # skip over ColumnProperty's do_init(), - # which issues assertions that do not apply to CompositeColumnProperty - super(ColumnProperty, self).do_init() - - def _getcommitted(self, state, dict_, column, passive=False): - # TODO: no coverage here - obj = state.get_impl(self.key).\ - get_committed_value(state, dict_, passive=passive) - return self.get_col_value(column, obj) - - def set_col_value(self, state, dict_, value, column): - obj = state.get_impl(self.key).get(state, dict_) - if obj is None: - obj = self.composite_class(*[None for c in self.columns]) - state.get_impl(self.key).set(state, state.dict, obj, None) - - if hasattr(obj, '__set_composite_values__'): - values = list(obj.__composite_values__()) - values[self._col_position_map[column]] = value - obj.__set_composite_values__(*values) - else: - setattr(obj, column.key, value) - - def get_col_value(self, column, value): - if value is None: - return None - for a, b in zip(self.columns, value.__composite_values__()): - if a is column: - return b - - class Comparator(PropComparator): - def __clause_element__(self): - if self.adapter: - # TODO: test coverage for adapted composite comparison - return expression.ClauseList( - *[self.adapter(x) for x in self.prop.columns]) - else: - return expression.ClauseList(*self.prop.columns) - - __hash__ = None - - def __eq__(self, other): - if other is None: - values = [None] * len(self.prop.columns) - else: - values = other.__composite_values__() - return sql.and_( - *[a==b for a, b in zip(self.prop.columns, values)]) - - def __ne__(self, other): - return sql.not_(self.__eq__(other)) - - def __str__(self): - return str(self.parent.class_.__name__) + "." + self.key class DescriptorProperty(MapperProperty): """:class:`MapperProperty` which proxies access to a @@ -243,7 +169,9 @@ class DescriptorProperty(MapperProperty): def instrument_class(self, mapper): from sqlalchemy.ext import hybrid - + + prop = self + # hackety hack hack class _ProxyImpl(object): accepts_scalar_loader = False @@ -251,7 +179,11 @@ class DescriptorProperty(MapperProperty): def __init__(self, key): self.key = key - + + if hasattr(prop, 'get_history'): + def get_history(self, state, dict_, **kw): + return prop.get_history(state, dict_, **kw) + if self.descriptor is None: desc = getattr(mapper.class_, self.key, None) if mapper._is_userland_descriptor(desc): @@ -296,7 +228,7 @@ class DescriptorProperty(MapperProperty): descriptor.expr = get_comparator descriptor.impl = _ProxyImpl(self.key) mapper.class_manager.instrument_attribute(self.key, descriptor) - + def setup(self, context, entity, path, adapter, **kwargs): pass @@ -307,6 +239,105 @@ class DescriptorProperty(MapperProperty): dest_state, dest_dict, load, _recursive): pass +class CompositeProperty(DescriptorProperty): + + def __init__(self, class_, *columns, **kwargs): + self.columns = columns + self.composite_class = class_ + self.active_history = kwargs.get('active_history', False) + self.deferred = kwargs.get('deferred', False) + self.group = kwargs.get('group', None) + + prop = self + def fget(instance): + return prop.composite_class( + *[getattr(instance, prop.parent._columntoproperty[col].key) + for col in prop.columns] + ) + def fset(instance, value): + if value is None: + fdel(instance) + else: + for col, value in zip(prop.columns, value.__composite_values__()): + setattr(instance, prop.parent._columntoproperty[col].key, value) + + def fdel(instance): + for col in prop.columns: + setattr(instance, prop.parent._columntoproperty[col].key, None) + self.descriptor = property(fget, fset, fdel) + + def get_history(self, state, dict_, **kw): + """Provided for userland code that uses attributes.get_history().""" + + added = [] + deleted = [] + + has_history = False + for col in self.columns: + key = self.parent._columntoproperty[col].key + hist = state.manager[key].impl.get_history(state, dict_) + if hist.has_changes(): + has_history = True + + added.extend(hist.non_deleted()) + if hist.deleted: + deleted.extend(hist.deleted) + else: + deleted.append(None) + + if has_history: + return attributes.History( + [self.composite_class(*added)], + (), + [self.composite_class(*deleted)] + ) + else: + return attributes.History( + (),[self.composite_class(*added)], () + ) + + def do_init(self): + for col in self.columns: + prop = self.parent._columntoproperty[col] + prop.active_history = self.active_history + if self.deferred: + prop.deferred = self.deferred + prop.strategy_class = strategies.DeferredColumnLoader + prop.group = self.group + # strategies ... + + def _comparator_factory(self, mapper): + return CompositeProperty.Comparator(self) + + class Comparator(PropComparator): + def __init__(self, prop, adapter=None): + self.prop = prop + self.adapter = adapter + + def __clause_element__(self): + if self.adapter: + # TODO: test coverage for adapted composite comparison + return expression.ClauseList( + *[self.adapter(x) for x in self.prop.columns]) + else: + return expression.ClauseList(*self.prop.columns) + + __hash__ = None + + def __eq__(self, other): + if other is None: + values = [None] * len(self.prop.columns) + else: + values = other.__composite_values__() + return sql.and_( + *[a==b for a, b in zip(self.prop.columns, values)]) + + def __ne__(self, other): + return sql.not_(self.__eq__(other)) + + def __str__(self): + return str(self.parent.class_.__name__) + "." + self.key + class ConcreteInheritedProperty(DescriptorProperty): """A 'do nothing' :class:`MapperProperty` that disables an attribute on a concrete subclass that is only present diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index c4619d3a7..21f22ef50 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -138,59 +138,6 @@ class ColumnLoader(LoaderStrategy): log.class_logger(ColumnLoader) -class CompositeColumnLoader(ColumnLoader): - """Strategize the loading of a composite column-based MapperProperty.""" - - def init_class_attribute(self, mapper): - self.is_class_level = True - self.logger.info("%s register managed composite attribute", self) - - def copy(obj): - if obj is None: - return None - return self.parent_property.\ - composite_class(*obj.__composite_values__()) - - def compare(a, b): - if a is None or b is None: - return a is b - - for col, aprop, bprop in zip(self.columns, - a.__composite_values__(), - b.__composite_values__()): - if not col.type.compare_values(aprop, bprop): - return False - else: - return True - - _register_attribute(self, mapper, useobject=False, - compare_function=compare, - copy_function=copy, - mutable_scalars=True, - active_history=self.parent_property.active_history, - ) - - def create_row_processor(self, selectcontext, path, mapper, - row, adapter): - key = self.key - columns = self.columns - composite_class = self.parent_property.composite_class - if adapter: - columns = [adapter.columns[c] for c in columns] - - for c in columns: - if c not in row: - def new_execute(state, dict_, row): - state.expire_attribute_pre_commit(dict_, key) - break - else: - def new_execute(state, dict_, row): - dict_[key] = composite_class(*[row[c] for c in columns]) - - return new_execute, None, None - -log.class_logger(CompositeColumnLoader) - class DeferredColumnLoader(LoaderStrategy): """Strategize the loading of a deferred column-based MapperProperty.""" diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index c94ef9b3f..621e5f47c 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -11,7 +11,7 @@ from sqlalchemy.orm import mapper, relationship, backref, \ validates, aliased, Mapper from sqlalchemy.orm import defer, deferred, synonym, attributes, \ column_property, composite, relationship, dynamic_loader, \ - comparable_property, AttributeExtension + comparable_property, AttributeExtension, Session from sqlalchemy.orm.instrumentation import ClassManager from test.lib.testing import eq_, AssertsCompiledSQL from test.orm import _base, _fixtures @@ -2213,39 +2213,34 @@ class CompositeTypesTest(_base.MappedTest): 'end': sa.orm.composite(Point, edges.c.x2, edges.c.y2) }) - sess = create_session() + sess = Session() g = Graph() g.id = 1 g.version_id=1 g.edges.append(Edge(Point(3, 4), Point(5, 6))) g.edges.append(Edge(Point(14, 5), Point(2, 7))) sess.add(g) - sess.flush() + sess.commit() - sess.expunge_all() g2 = sess.query(Graph).get([g.id, g.version_id]) for e1, e2 in zip(g.edges, g2.edges): eq_(e1.start, e2.start) eq_(e1.end, e2.end) g2.edges[1].end = Point(18, 4) - sess.flush() - sess.expunge_all() + sess.commit() + e = sess.query(Edge).get(g2.edges[1].id) eq_(e.end, Point(18, 4)) - - e.end.x = 19 - e.end.y = 5 - sess.flush() - sess.expunge_all() - eq_(sess.query(Edge).get(g2.edges[1].id).end, Point(19, 5)) - - g.edges[1].end = Point(19, 5) - + + e.end = Point(19, 5) + sess.commit() + g.id, g.version_id, g.edges sess.expunge_all() + def go(): - g2 = (sess.query(Graph). - options(sa.orm.joinedload('edges'))).get([g.id, g.version_id]) + g2 = sess.query(Graph).\ + options(sa.orm.joinedload('edges')).get([g.id, g.version_id]) for e1, e2 in zip(g.edges, g2.edges): eq_(e1.start, e2.start) eq_(e1.end, e2.end) @@ -2261,9 +2256,9 @@ class CompositeTypesTest(_base.MappedTest): # query by columns eq_(sess.query(Edge.start, Edge.end).all(), [(3, 4, 5, 6), (14, 5, 19, 5)]) - + e = g.edges[1] - e.end.x = e.end.y = None + del e.end sess.flush() eq_(sess.query(Edge.start, Edge.end).all(), [(3, 4, 5, 6), (14, 5, None, None)]) |
