diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-04-04 00:21:28 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-04-04 00:21:28 +0000 |
| commit | 1dbed0b2b4446408f14a87d94f9c0c6b3356fcf2 (patch) | |
| tree | 2dcd585d8472170cdcfd8bdbd658f60a4728a282 /lib/sqlalchemy/orm | |
| parent | 9dd01e52e2e0755bbaf6e08b048e9a48b55879d1 (diff) | |
| download | sqlalchemy-1dbed0b2b4446408f14a87d94f9c0c6b3356fcf2.tar.gz | |
- merged sync_simplify branch
- The methodology behind "primaryjoin"/"secondaryjoin" has
been refactored. Behavior should be slightly more
intelligent, primarily in terms of error messages which
have been pared down to be more readable. In a slight
number of scenarios it can better resolve the correct
foreign key than before.
- moved collections unit test from relationships.py to collection.py
- PropertyLoader now has "synchronize_pairs" and "equated_pairs"
collections which allow easy access to the source/destination
parent/child relation between columns (might change names)
- factored out ClauseSynchronizer (finally)
- added many more tests for priamryjoin/secondaryjoin
error checks
Diffstat (limited to 'lib/sqlalchemy/orm')
| -rw-r--r-- | lib/sqlalchemy/orm/dependency.py | 62 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/interfaces.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 49 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/properties.py | 169 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 67 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/sync.py | 227 |
6 files changed, 240 insertions, 338 deletions
diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index 8519d2260..c667460a7 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -11,8 +11,8 @@ """ from sqlalchemy.orm import sync -from sqlalchemy.orm.sync import ONETOMANY,MANYTOONE,MANYTOMANY from sqlalchemy import sql, util, exceptions +from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY def create_dependency_processor(prop): @@ -43,8 +43,8 @@ class DependencyProcessor(object): self.passive_updates = prop.passive_updates self.enable_typechecks = prop.enable_typechecks self.key = prop.key - - self._compile_synchronizers() + if not self.prop.synchronize_pairs: + raise exceptions.ArgumentError("Can't build a DependencyProcessor for relation %s. No target attributes to populate between parent and child are present" % self.prop) def _get_instrumented_attribute(self): """Return the ``InstrumentedAttribute`` handled by this @@ -121,20 +121,6 @@ class DependencyProcessor(object): raise NotImplementedError() - def _compile_synchronizers(self): - """Assemble a list of *synchronization rules*. - - These are fired to populate attributes from one side - of a relation to another. - """ - - self.syncrules = sync.ClauseSynchronizer(self.parent, self.mapper, self.direction) - if self.direction == sync.MANYTOMANY: - self.syncrules.compile(self.prop.primaryjoin, issecondary=False, foreign_keys=self.foreign_keys) - self.syncrules.compile(self.prop.secondaryjoin, issecondary=True, foreign_keys=self.foreign_keys) - else: - self.syncrules.compile(self.prop.primaryjoin, foreign_keys=self.foreign_keys) - def _conditional_post_update(self, state, uowcommit, related): """Execute a post_update call. @@ -153,11 +139,11 @@ class DependencyProcessor(object): if state is not None and self.post_update: for x in related: if x is not None: - uowcommit.register_object(state, postupdate=True, post_update_cols=self.syncrules.dest_columns()) + uowcommit.register_object(state, postupdate=True, post_update_cols=[r for l, r in self.prop.synchronize_pairs]) break def _pks_changed(self, uowcommit, state): - return self.syncrules.source_changes(uowcommit, state) + raise NotImplementedError() def __str__(self): return "%s(%s)" % (self.__class__.__name__, str(self.prop)) @@ -259,7 +245,13 @@ class OneToManyDP(DependencyProcessor): if dest is None or (not self.post_update and uowcommit.is_deleted(dest)): return self._verify_canload(child) - self.syncrules.execute(source, dest, source, child, clearkeys) + if clearkeys: + sync.clear(dest, self.mapper, self.prop.synchronize_pairs) + else: + sync.populate(source, self.parent, dest, self.mapper, self.prop.synchronize_pairs) + + def _pks_changed(self, uowcommit, state): + return sync.source_changes(uowcommit, state, self.parent, self.prop.synchronize_pairs) class DetectKeySwitch(DependencyProcessor): """a special DP that works for many-to-one relations, fires off for @@ -298,7 +290,11 @@ class DetectKeySwitch(DependencyProcessor): elem.dict[self.key]._state in switchers ]: uowcommit.register_object(s, listonly=self.passive_updates) - self.syncrules.execute(s.dict[self.key]._state, s, None, None, False) + sync.populate(s.dict[self.key]._state, self.mapper, s, self.parent, self.prop.synchronize_pairs) + #self.syncrules.execute(s.dict[self.key]._state, s, None, None, False) + + def _pks_changed(self, uowcommit, state): + return sync.source_changes(uowcommit, state, self.mapper, self.prop.synchronize_pairs) class ManyToOneDP(DependencyProcessor): def __init__(self, prop): @@ -368,12 +364,14 @@ class ManyToOneDP(DependencyProcessor): def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): - source = child - dest = state - if dest is None or (not self.post_update and uowcommit.is_deleted(dest)): + if state is None or (not self.post_update and uowcommit.is_deleted(state)): return - self._verify_canload(child) - self.syncrules.execute(source, dest, dest, child, clearkeys) + + if clearkeys or child is None: + sync.clear(state, self.parent, self.prop.synchronize_pairs) + else: + self._verify_canload(child) + sync.populate(child, self.mapper, state, self.parent, self.prop.synchronize_pairs) class ManyToManyDP(DependencyProcessor): def register_dependencies(self, uowcommit): @@ -433,7 +431,10 @@ class ManyToManyDP(DependencyProcessor): if not self.passive_updates and unchanged and self._pks_changed(uowcommit, state): for child in unchanged: associationrow = {} - self.syncrules.update(associationrow, state, child, "old_") + sync.update(state, self.parent, associationrow, "old_", self.prop.synchronize_pairs) + sync.update(child, self.mapper, associationrow, "old_", self.prop.secondary_synchronize_pairs) + + #self.syncrules.update(associationrow, state, child, "old_") secondary_update.append(associationrow) if secondary_delete: @@ -470,7 +471,12 @@ class ManyToManyDP(DependencyProcessor): if associationrow is None: return self._verify_canload(child) - self.syncrules.execute(None, associationrow, state, child, clearkeys) + + sync.populate_dict(state, self.parent, associationrow, self.prop.synchronize_pairs) + sync.populate_dict(child, self.mapper, associationrow, self.prop.secondary_synchronize_pairs) + + def _pks_changed(self, uowcommit, state): + return sync.source_changes(uowcommit, state, self.parent, self.prop.synchronize_pairs) class AssociationDP(OneToManyDP): def __init__(self, *args, **kwargs): diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index f00c42421..d61ebe960 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -27,6 +27,10 @@ __all__ = ['EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS', 'MapperExtension', EXT_CONTINUE = EXT_PASS = util.symbol('EXT_CONTINUE') EXT_STOP = util.symbol('EXT_STOP') +ONETOMANY = util.symbol('ONETOMANY') +MANYTOONE = util.symbol('MANYTOONE') +MANYTOMANY = util.symbol('MANYTOMANY') + class MapperExtension(object): """Base implementation for customizing Mapper behavior. diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 12e7d03a9..22f5678d6 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -111,7 +111,8 @@ class Mapper(object): self._dependency_processors = [] self._clause_adapter = None self._requires_row_aliasing = False - + self.__inherits_equated_pairs = None + if not issubclass(class_, object): raise exceptions.ArgumentError("Class '%s' is not a new-style class" % class_.__name__) @@ -171,11 +172,11 @@ class Mapper(object): self.__should_log_info = logging.is_info_enabled(self.logger) self.__should_log_debug = logging.is_debug_enabled(self.logger) - self._compile_class() - self._compile_inheritance() - self._compile_extensions() - self._compile_properties() - self._compile_pks() + self.__compile_class() + self.__compile_inheritance() + self.__compile_extensions() + self.__compile_properties() + self.__compile_pks() global __new_mappers __new_mappers = True self.__log("constructed") @@ -352,17 +353,17 @@ class Mapper(object): to execute once all mappers have been constructed. """ - self.__log("_initialize_properties() started") + self.__log("__initialize_properties() started") l = [(key, prop) for key, prop in self.__props.iteritems()] for key, prop in l: self.__log("initialize prop " + key) if getattr(prop, 'key', None) is None: prop.init(key, self) - self.__log("_initialize_properties() complete") + self.__log("__initialize_properties() complete") self.__props_init = True - def _compile_extensions(self): + def __compile_extensions(self): """Go through the global_extensions list as well as the list of ``MapperExtensions`` specified for this ``Mapper`` and creates a linked list of those extensions. @@ -393,7 +394,7 @@ class Mapper(object): for ext in extlist: self.extension.append(ext) - def _compile_inheritance(self): + def __compile_inheritance(self): """Configure settings related to inherting and/or inherited mappers being present.""" if self.inherits: @@ -412,7 +413,6 @@ class Mapper(object): self.single = True if not self.local_table is self.inherits.local_table: if self.concrete: - self._synchronizer = None self.mapped_table = self.local_table for mapper in self.iterate_to_root(): if mapper.polymorphic_on: @@ -424,17 +424,10 @@ class Mapper(object): # stuff we dont want (allows test/inheritance.InheritTest4 to pass) self.inherit_condition = sql.join(self.inherits.local_table, self.local_table).onclause self.mapped_table = sql.join(self.inherits.mapped_table, self.local_table, self.inherit_condition) - # generate sync rules. similarly to creating the on clause, specify a - # stricter set of tables to create "sync rules" by,based on the immediate - # inherited table, rather than all inherited tables - self._synchronizer = sync.ClauseSynchronizer(self, self, sync.ONETOMANY) - if self.inherit_foreign_keys: - fks = util.Set(self.inherit_foreign_keys) - else: - fks = None - self._synchronizer.compile(self.mapped_table.onclause, foreign_keys=fks) + + fks = util.to_set(self.inherit_foreign_keys) + self.__inherits_equated_pairs = sqlutil.criterion_as_pairs(self.mapped_table.onclause, consider_as_foreign_keys=fks) else: - self._synchronizer = None self.mapped_table = self.local_table if self.polymorphic_identity is not None: self.inherits.polymorphic_map[self.polymorphic_identity] = self @@ -470,7 +463,6 @@ class Mapper(object): else: self._all_tables = util.Set() self.base_mapper = self - self._synchronizer = None self.mapped_table = self.local_table if self.polymorphic_identity: if self.polymorphic_on is None: @@ -481,7 +473,7 @@ class Mapper(object): if self.mapped_table is None: raise exceptions.ArgumentError("Mapper '%s' does not have a mapped_table specified. (Are you using the return value of table.create()? It no longer has a return value.)" % str(self)) - def _compile_pks(self): + def __compile_pks(self): self.tables = sqlutil.find_tables(self.mapped_table) @@ -634,7 +626,7 @@ class Mapper(object): return getattr(getattr(cls, clskey), key) - def _compile_properties(self): + def __compile_properties(self): # object attribute names mapped to MapperProperty objects self.__props = util.OrderedDict() @@ -770,7 +762,7 @@ class Mapper(object): for mapper in self._inheriting_mappers: mapper._adapt_inherited_property(key, prop) - def _compile_class(self): + def __compile_class(self): """If this mapper is to be a primary mapper (i.e. the non_primary flag is not set), associate this Mapper with the given class_ and entity name. @@ -1169,8 +1161,8 @@ class Mapper(object): # TODO: this fires off more than needed, try to organize syncrules # per table for m in util.reversed(list(mapper.iterate_to_root())): - if m._synchronizer: - m._synchronizer.execute(state, state) + if m.__inherits_equated_pairs: + m._synchronize_inherited(state) # testlib.pragma exempt:__hash__ inserted_objects.add((state, connection)) @@ -1186,6 +1178,9 @@ class Mapper(object): if 'after_update' in mapper.extension.methods: mapper.extension.after_update(mapper, connection, state.obj()) + def _synchronize_inherited(self, state): + sync.populate(state, self, state, self, self.__inherits_equated_pairs) + def _postfetch(self, uowtransaction, connection, table, state, resultproxy, params, value_params): """After an ``INSERT`` or ``UPDATE``, assemble newly generated values on an instance. For columns which are marked as being generated diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 9f8e852f1..fb10357cf 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -11,16 +11,15 @@ invidual ORM-mapped attributes. """ from sqlalchemy import sql, schema, util, exceptions, logging -from sqlalchemy.sql.util import ClauseAdapter +from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, find_columns from sqlalchemy.sql import visitors, operators, ColumnElement from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency, object_mapper from sqlalchemy.orm import session as sessionlib from sqlalchemy.orm.mapper import _class_to_mapper from sqlalchemy.orm.util import CascadeOptions, PropertyAliasedClauses -from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty +from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty, ONETOMANY, MANYTOONE, MANYTOMANY from sqlalchemy.exceptions import ArgumentError - __all__ = ('ColumnProperty', 'CompositeProperty', 'SynonymProperty', 'ComparableProperty', 'PropertyLoader', 'BackRef') @@ -288,7 +287,7 @@ class PropertyLoader(StrategizedProperty): def __eq__(self, other): if other is None: - if self.prop.direction == sync.ONETOMANY: + if self.prop.direction == ONETOMANY: return ~sql.exists([1], self.prop.primaryjoin) else: return self.prop._optimized_compare(None) @@ -377,7 +376,7 @@ class PropertyLoader(StrategizedProperty): def __ne__(self, other): if other is None: - if self.prop.direction == sync.MANYTOONE: + if self.prop.direction == MANYTOONE: return sql.or_(*[x!=None for x in self.prop.foreign_keys]) elif self.prop.uselist: return self.any() @@ -475,14 +474,14 @@ class PropertyLoader(StrategizedProperty): return self.argument.class_ def do_init(self): - self._determine_targets() - self._determine_joins() - self._determine_fks() - self._determine_direction() - self._determine_remote_side() + self.__determine_targets() + self.__determine_joins() + self.__determine_fks() + self.__determine_direction() + self.__determine_remote_side() self._post_init() - def _determine_targets(self): + def __determine_targets(self): if isinstance(self.argument, type): self.mapper = mapper.class_mapper(self.argument, entity_name=self.entity_name, compile=False) elif isinstance(self.argument, mapper.Mapper): @@ -507,10 +506,12 @@ class PropertyLoader(StrategizedProperty): if self.cascade.delete_orphan: if self.parent.class_ is self.mapper.class_: - raise exceptions.ArgumentError("In relationship '%s', can't establish 'delete-orphan' cascade rule on a self-referential relationship. You probably want cascade='all', which includes delete cascading but not orphan detection." %(str(self))) + raise exceptions.ArgumentError("In relationship '%s', can't establish 'delete-orphan' cascade " + "rule on a self-referential relationship. " + "You probably want cascade='all', which includes delete cascading but not orphan detection." %(str(self))) self.mapper.primary_mapper().delete_orphans.append((self.key, self.parent.class_)) - def _determine_joins(self): + def __determine_joins(self): if self.secondaryjoin is not None and self.secondary is None: raise exceptions.ArgumentError("Property '" + self.key + "' specified with secondary join condition but no secondary argument") # if join conditions were not specified, figure them out based on foreign keys @@ -535,10 +536,11 @@ class PropertyLoader(StrategizedProperty): if self.primaryjoin is None: self.primaryjoin = _search_for_join(self.parent, self.target).onclause except exceptions.ArgumentError, e: - raise exceptions.ArgumentError("""Error determining primary and/or secondary join for relationship '%s'. If the underlying error cannot be corrected, you should specify the 'primaryjoin' (and 'secondaryjoin', if there is an association table present) keyword arguments to the relation() function (or for backrefs, by specifying the backref using the backref() function with keyword arguments) to explicitly specify the join conditions. Nested error is \"%s\"""" % (str(self), str(e))) + raise exceptions.ArgumentError("Could not determine join condition between parent/child tables on relation %s. " + "Specify a 'primaryjoin' expression. If this is a many-to-many relation, 'secondaryjoin' is needed as well." % (self)) - def _col_is_part_of_mappings(self, column): + def __col_is_part_of_mappings(self, column): if self.secondary is None: return self.parent.mapped_table.c.contains_column(column) or \ self.target.c.contains_column(column) @@ -547,61 +549,77 @@ class PropertyLoader(StrategizedProperty): self.target.c.contains_column(column) or \ self.secondary.c.contains_column(column) is not None - def _determine_fks(self): + def __determine_fks(self): if self._legacy_foreignkey and not self._refers_to_parent_table(): self.foreign_keys = self._legacy_foreignkey - self._opposite_side = util.Set() + arg_foreign_keys = self.foreign_keys + + eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=self.viewonly) + eq_pairs = [(l, r) for l, r in eq_pairs if self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)] + + if not eq_pairs: + if not self.viewonly and criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=True): + raise exceptions.ArgumentError("Could not locate any equated column pairs for primaryjoin condition '%s' on relation %s. " + "If no equated pairs exist, the relation must be marked as viewonly=True." % (self.primaryjoin, self) + ) + else: + raise exceptions.ArgumentError("Could not determine relation direction for primaryjoin condition '%s', on relation %s. " + "Specify the foreign_keys argument to indicate which columns on the relation are foreign." % (self.primaryjoin, self)) + + self.foreign_keys = util.OrderedSet([r for l, r in eq_pairs]) + self._opposite_side = util.OrderedSet([l for l, r in eq_pairs]) + self.synchronize_pairs = eq_pairs + + if self.secondaryjoin: + sq_pairs = criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=arg_foreign_keys) + sq_pairs = [(l, r) for l, r in sq_pairs if self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)] + + if not sq_pairs: + if not self.viewonly and criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=True): + raise exceptions.ArgumentError("Could not locate any equated column pairs for secondaryjoin condition '%s' on relation %s. " + "If no equated pairs exist, the relation must be marked as viewonly=True." % (self.secondaryjoin, self) + ) + else: + raise exceptions.ArgumentError("Could not determine relation direction for secondaryjoin condition '%s', on relation %s. " + "Specify the foreign_keys argument to indicate which columns on the relation are foreign." % (self.secondaryjoin, self)) - if self.foreign_keys: - def visit_binary(binary): - if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): - return - if binary.left in self.foreign_keys: - self._opposite_side.add(binary.right) - if binary.right in self.foreign_keys: - self._opposite_side.add(binary.left) + self.foreign_keys.update([r for l, r in sq_pairs]) + self._opposite_side.update([l for l, r in sq_pairs]) + self.secondary_synchronize_pairs = sq_pairs else: - self.foreign_keys = util.Set() - def visit_binary(binary): - if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): - return - - # this check is for when the user put the "view_only" flag on and has tables that have nothing - # to do with the relationship's parent/child mappings in the join conditions. we dont want cols - # or clauses related to those external tables dealt with. see orm.relationships.ViewOnlyTest - if not self._col_is_part_of_mappings(binary.left) or not self._col_is_part_of_mappings(binary.right): - return - - for f in binary.left.foreign_keys: - if f.references(binary.right.table): - self.foreign_keys.add(binary.left) - self._opposite_side.add(binary.right) - for f in binary.right.foreign_keys: - if f.references(binary.left.table): - self.foreign_keys.add(binary.right) - self._opposite_side.add(binary.left) - - visitors.traverse(self.primaryjoin, visit_binary=visit_binary) - - if not self.foreign_keys: - raise exceptions.ArgumentError( - "Can't locate any foreign key columns in primary join " - "condition '%s' for relationship '%s'. Specify " - "'foreign_keys' argument to indicate which columns in " - "the join condition are foreign." %(str(self.primaryjoin), str(self))) - - if self.secondaryjoin is not None: - visitors.traverse(self.secondaryjoin, visit_binary=visit_binary) + self.secondary_synchronize_pairs = None + + def equated_pairs(self): + return zip(self.local_side, self.remote_side) + equated_pairs = property(equated_pairs) + + def __determine_remote_side(self): + if self.remote_side: + if self.direction is MANYTOONE: + eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_referenced_keys=self.remote_side, any_operator=True) + else: + eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=self.remote_side, any_operator=True) + if self.secondaryjoin: + sq_pairs = criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=self.foreign_keys, any_operator=True) + sq_pairs = [(l, r) for l, r in sq_pairs if self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)] + eq_pairs += sq_pairs + else: + eq_pairs = zip(self._opposite_side, self.foreign_keys) - def _determine_direction(self): + if self.direction is MANYTOONE: + self.remote_side, self.local_side = [util.OrderedSet(s) for s in zip(*eq_pairs)] + else: + self.local_side, self.remote_side = [util.OrderedSet(s) for s in zip(*eq_pairs)] + + def __determine_direction(self): """Determine our *direction*, i.e. do we represent one to many, many to many, etc. """ if self.secondaryjoin is not None: - self.direction = sync.MANYTOMANY + self.direction = MANYTOMANY elif self._refers_to_parent_table(): # for a self referential mapper, if the "foreignkey" is a single or composite primary key, # then we are "many to one", since the remote site of the relationship identifies a singular entity. @@ -609,19 +627,19 @@ class PropertyLoader(StrategizedProperty): if self._legacy_foreignkey: for f in self._legacy_foreignkey: if not f.primary_key: - self.direction = sync.ONETOMANY + self.direction = ONETOMANY else: - self.direction = sync.MANYTOONE + self.direction = MANYTOONE elif self.remote_side: for f in self.foreign_keys: if f in self.remote_side: - self.direction = sync.ONETOMANY + self.direction = ONETOMANY return else: - self.direction = sync.MANYTOONE + self.direction = MANYTOONE else: - self.direction = sync.ONETOMANY + self.direction = ONETOMANY else: for mappedtable, parenttable in [(self.mapper.mapped_table, self.parent.mapped_table), (self.mapper.local_table, self.parent.local_table)]: onetomany = [c for c in self.foreign_keys if mappedtable.c.contains_column(c)] @@ -635,10 +653,10 @@ class PropertyLoader(StrategizedProperty): elif onetomany and manytoone: continue elif onetomany: - self.direction = sync.ONETOMANY + self.direction = ONETOMANY break elif manytoone: - self.direction = sync.MANYTOONE + self.direction = MANYTOONE break else: raise exceptions.ArgumentError( @@ -647,24 +665,15 @@ class PropertyLoader(StrategizedProperty): "the child's mapped tables. Specify 'foreign_keys' " "argument." % (str(self))) - def _determine_remote_side(self): - if not self.remote_side: - if self.direction is sync.MANYTOONE: - self.remote_side = util.Set(self._opposite_side) - elif self.direction is sync.ONETOMANY or self.direction is sync.MANYTOMANY: - self.remote_side = util.Set(self.foreign_keys) - - self.local_side = util.Set(self._opposite_side).union(util.Set(self.foreign_keys)).difference(self.remote_side) - def _post_init(self): if logging.is_info_enabled(self.logger): self.logger.info(str(self) + " setup primary join " + str(self.primaryjoin)) self.logger.info(str(self) + " setup secondary join " + str(self.secondaryjoin)) - self.logger.info(str(self) + " foreign keys " + str([str(c) for c in self.foreign_keys])) - self.logger.info(str(self) + " remote columns " + str([str(c) for c in self.remote_side])) - self.logger.info(str(self) + " relation direction " + (self.direction is sync.ONETOMANY and "one-to-many" or (self.direction is sync.MANYTOONE and "many-to-one" or "many-to-many"))) + self.logger.info(str(self) + " synchronize pairs " + ",".join(["(%s => %s)" % (l, r) for l, r in self.synchronize_pairs])) + self.logger.info(str(self) + " equated pairs " + ",".join(["(%s == %s)" % (l, r) for l, r in self.equated_pairs])) + self.logger.info(str(self) + " relation direction " + (self.direction is ONETOMANY and "one-to-many" or (self.direction is MANYTOONE and "many-to-one" or "many-to-many"))) - if self.uselist is None and self.direction is sync.MANYTOONE: + if self.uselist is None and self.direction is MANYTOONE: self.uselist = False if self.uselist is None: @@ -712,9 +721,9 @@ class PropertyLoader(StrategizedProperty): primaryjoin = self.primaryjoin if fromselectable is not frommapper.local_table: - if self.direction is sync.ONETOMANY: + if self.direction is ONETOMANY: primaryjoin = ClauseAdapter(fromselectable, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin) - elif self.direction is sync.MANYTOONE: + elif self.direction is MANYTOONE: primaryjoin = ClauseAdapter(fromselectable, include=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin) elif self.secondaryjoin: primaryjoin = ClauseAdapter(fromselectable, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin) diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 4028eed6a..6bd0d530f 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -255,14 +255,14 @@ NoLoader.logger = logging.class_logger(NoLoader) class LazyLoader(AbstractRelationLoader): def init(self): super(LazyLoader, self).init() - (self.lazywhere, self.lazybinds, self.equated_columns) = self._create_lazy_clause(self.parent_property) + (self.__lazywhere, self.__bind_to_col, self._equated_columns) = self.__create_lazy_clause(self.parent_property) - self.logger.info(str(self.parent_property) + " lazy loading clause " + str(self.lazywhere)) + self.logger.info(str(self.parent_property) + " lazy loading clause " + str(self.__lazywhere)) # determine if our "lazywhere" clause is the same as the mapper's # get() clause. then we can just use mapper.get() #from sqlalchemy.orm import query - self.use_get = not self.uselist and self.mapper._get_clause[0].compare(self.lazywhere) + self.use_get = not self.uselist and self.mapper._get_clause[0].compare(self.__lazywhere) if self.use_get: self.logger.info(str(self.parent_property) + " will use query.get() to optimize instance loads") @@ -275,10 +275,9 @@ class LazyLoader(AbstractRelationLoader): return self._lazy_none_clause(reverse_direction) if not reverse_direction: - (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.equated_columns) + (criterion, bind_to_col, rev) = (self.__lazywhere, self.__bind_to_col, self._equated_columns) else: - (criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction) - bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds]) + (criterion, bind_to_col, rev) = LazyLoader.__create_lazy_clause(self.parent_property, reverse_direction=reverse_direction) def visit_bindparam(bindparam): mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent @@ -291,10 +290,9 @@ class LazyLoader(AbstractRelationLoader): def _lazy_none_clause(self, reverse_direction=False): if not reverse_direction: - (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.equated_columns) + (criterion, bind_to_col, rev) = (self.__lazywhere, self.__bind_to_col, self._equated_columns) else: - (criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction) - bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds]) + (criterion, bind_to_col, rev) = LazyLoader.__create_lazy_clause(self.parent_property, reverse_direction=reverse_direction) def visit_binary(binary): mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent @@ -351,24 +349,20 @@ class LazyLoader(AbstractRelationLoader): instance._state.reset(self.key) return (new_execute, None, None) - def _create_lazy_clause(cls, prop, reverse_direction=False): - (primaryjoin, secondaryjoin, remote_side) = (prop.primaryjoin, prop.secondaryjoin, prop.remote_side) - + def __create_lazy_clause(cls, prop, reverse_direction=False): binds = {} equated_columns = {} + secondaryjoin = prop.secondaryjoin + equated = dict(prop.equated_pairs) + def should_bind(targetcol, othercol): - if not prop._col_is_part_of_mappings(targetcol): - return False - if reverse_direction and not secondaryjoin: - return targetcol in remote_side + return othercol in equated else: - return othercol in remote_side + return targetcol in equated def visit_binary(binary): - if not isinstance(binary.left, sql.ColumnElement) or not isinstance(binary.right, sql.ColumnElement): - return leftcol = binary.left rightcol = binary.right @@ -376,31 +370,28 @@ class LazyLoader(AbstractRelationLoader): equated_columns[leftcol] = rightcol if should_bind(leftcol, rightcol): - if leftcol in binds: - binary.left = binds[leftcol] - else: - binary.left = binds[leftcol] = sql.bindparam(None, None, type_=binary.right.type) + if leftcol not in binds: + binds[leftcol] = sql.bindparam(None, None, type_=binary.right.type) + binary.left = binds[leftcol] + elif should_bind(rightcol, leftcol): + if rightcol not in binds: + binds[rightcol] = sql.bindparam(None, None, type_=binary.left.type) + binary.right = binds[rightcol] - # the "left is not right" compare is to handle part of a join clause that is "table.c.col1==table.c.col1", - # which can happen in rare cases (test/orm/relationships.py RelationTest2) - if leftcol is not rightcol and should_bind(rightcol, leftcol): - if rightcol in binds: - binary.right = binds[rightcol] - else: - binary.right = binds[rightcol] = sql.bindparam(None, None, type_=binary.left.type) - - - lazywhere = primaryjoin + lazywhere = prop.primaryjoin - if not secondaryjoin or not reverse_direction: + if not prop.secondaryjoin or not reverse_direction: lazywhere = visitors.traverse(lazywhere, clone=True, visit_binary=visit_binary) - if secondaryjoin is not None: + if prop.secondaryjoin is not None: if reverse_direction: secondaryjoin = visitors.traverse(secondaryjoin, clone=True, visit_binary=visit_binary) lazywhere = sql.and_(lazywhere, secondaryjoin) - return (lazywhere, binds, equated_columns) - _create_lazy_clause = classmethod(_create_lazy_clause) + + bind_to_col = dict([(binds[col].key, col) for col in binds]) + + return (lazywhere, bind_to_col, equated_columns) + __create_lazy_clause = classmethod(__create_lazy_clause) LazyLoader.logger = logging.class_logger(LazyLoader) @@ -452,7 +443,7 @@ class LoadLazyAttribute(object): ident = [] allnulls = True for primary_key in prop.mapper.primary_key: - val = instance_mapper._get_committed_attr_by_column(instance, strategy.equated_columns[primary_key]) + val = instance_mapper._get_committed_attr_by_column(instance, strategy._equated_columns[primary_key]) allnulls = allnulls and val is None ident.append(val) if allnulls: diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index d95009a47..39a7b5044 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -4,186 +4,83 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""Contains the ClauseSynchronizer class, which is used to map -attributes between two objects in a manner corresponding to a SQL -clause that compares column values. +"""private module containing functions used for copying data between instances +based on join conditions. """ from sqlalchemy import schema, exceptions, util -from sqlalchemy.sql import visitors, operators +from sqlalchemy.sql import visitors, operators, util as sqlutil from sqlalchemy import logging from sqlalchemy.orm import util as mapperutil +from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY # legacy -ONETOMANY = 0 -MANYTOONE = 1 -MANYTOMANY = 2 - -class ClauseSynchronizer(object): - """Given a SQL clause, usually a series of one or more binary - expressions between columns, and a set of 'source' and - 'destination' mappers, compiles a set of SyncRules corresponding - to that information. - - The ClauseSynchronizer can then be executed given a set of - parent/child objects or destination dictionary, which will iterate - through each of its SyncRules and execute them. Each SyncRule - will copy the value of a single attribute from the parent to the - child, corresponding to the pair of columns in a particular binary - expression, using the source and destination mappers to map those - two columns to object attributes within parent and child. - """ - - def __init__(self, parent_mapper, child_mapper, direction): - self.parent_mapper = parent_mapper - self.child_mapper = child_mapper - self.direction = direction - self.syncrules = [] - - def compile(self, sqlclause, foreign_keys=None, issecondary=None): - def compile_binary(binary): - """Assemble a SyncRule given a single binary condition.""" - - if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): - return - - source_column = None - dest_column = None - - if foreign_keys is None: - if binary.left.table == binary.right.table: - raise exceptions.ArgumentError("need foreign_keys argument for self-referential sync") - - if binary.left in util.Set([f.column for f in binary.right.foreign_keys]): - dest_column = binary.right - source_column = binary.left - elif binary.right in util.Set([f.column for f in binary.left.foreign_keys]): - dest_column = binary.left - source_column = binary.right - else: - if binary.left in foreign_keys: - source_column = binary.right - dest_column = binary.left - elif binary.right in foreign_keys: - source_column = binary.left - dest_column = binary.right +def populate(source, source_mapper, dest, dest_mapper, synchronize_pairs): + for l, r in synchronize_pairs: + try: + value = source_mapper._get_state_attr_by_column(source, l) + except exceptions.UnmappedColumnError: + _raise_col_to_prop(False, source_mapper, l, dest_mapper, r) - if source_column and dest_column: - if self.direction == ONETOMANY: - self.syncrules.append(SyncRule(self.parent_mapper, source_column, dest_column, dest_mapper=self.child_mapper)) - elif self.direction == MANYTOONE: - self.syncrules.append(SyncRule(self.child_mapper, source_column, dest_column, dest_mapper=self.parent_mapper)) - else: - if not issecondary: - self.syncrules.append(SyncRule(self.parent_mapper, source_column, dest_column, dest_mapper=self.child_mapper, issecondary=issecondary)) - else: - self.syncrules.append(SyncRule(self.child_mapper, source_column, dest_column, dest_mapper=self.parent_mapper, issecondary=issecondary)) + try: + dest_mapper._set_state_attr_by_column(dest, r, value) + except exceptions.UnmappedColumnError: + self._raise_col_to_prop(True, source_mapper, l, dest_mapper, r) - rules_added = len(self.syncrules) - visitors.traverse(sqlclause, visit_binary=compile_binary) - if len(self.syncrules) == rules_added: - raise exceptions.ArgumentError("No syncrules generated for join criterion " + str(sqlclause)) +def clear(dest, dest_mapper, synchronize_pairs): + for l, r in synchronize_pairs: + if r.primary_key: + raise exceptions.AssertionError("Dependency rule tried to blank-out primary key column '%s' on instance '%s'" % (r, mapperutil.state_str(dest))) + try: + dest_mapper._set_state_attr_by_column(dest, r, None) + except exceptions.UnmappedColumnError: + _raise_col_to_prop(True, None, l, dest_mapper, r) - def dest_columns(self): - return [r.dest_column for r in self.syncrules if r.dest_column is not None] +def update(source, source_mapper, dest, old_prefix, synchronize_pairs): + for l, r in synchronize_pairs: + try: + oldvalue = source_mapper._get_committed_attr_by_column(source.obj(), l) + value = source_mapper._get_state_attr_by_column(source, l) + except exceptions.UnmappedColumnError: + self._raise_col_to_prop(False, source_mapper, l, None, r) + dest[r.key] = value + dest[old_prefix + r.key] = oldvalue - def update(self, dest, parent, child, old_prefix): - for rule in self.syncrules: - rule.update(dest, parent, child, old_prefix) - - def execute(self, source, dest, obj=None, child=None, clearkeys=None): - for rule in self.syncrules: - rule.execute(source, dest, obj, child, clearkeys) - - def source_changes(self, uowcommit, source): - for rule in self.syncrules: - if rule.source_changes(uowcommit, source): - return True - else: - return False +def populate_dict(source, source_mapper, dict_, synchronize_pairs): + for l, r in synchronize_pairs: + try: + value = source_mapper._get_state_attr_by_column(source, l) + except exceptions.UnmappedColumnError: + _raise_col_to_prop(False, source_mapper, l, None, r) -class SyncRule(object): - """An instruction indicating how to populate the objects on each - side of a relationship. - - E.g. if table1 column A is joined against table2 column - B, and we are a one-to-many from table1 to table2, a syncrule - would say *take the A attribute from object1 and assign it to the - B attribute on object2*. - """ + dict_[r.key] = value - def __init__(self, source_mapper, source_column, dest_column, dest_mapper=None, issecondary=None): - self.source_mapper = source_mapper - self.source_column = source_column - self.issecondary = issecondary - self.dest_mapper = dest_mapper - self.dest_column = dest_column - - #print "SyncRule", source_mapper, source_column, dest_column, dest_mapper - - def dest_primary_key(self): - # late-evaluating boolean since some syncs are created - # before the mapper has assembled pks - try: - return self._dest_primary_key - except AttributeError: - self._dest_primary_key = self.dest_mapper is not None and self.dest_column in self.dest_mapper._pks_by_table[self.dest_column.table] and not self.dest_mapper.allow_null_pks - return self._dest_primary_key - - def _raise_col_to_prop(self, isdest): - if isdest: - raise exceptions.UnmappedColumnError("Can't execute sync rule for destination column '%s'; mapper '%s' does not map this column. Try using an explicit `foreign_keys` collection which does not include this column (or use a viewonly=True relation)." % (self.dest_column, self.dest_mapper)) - else: - raise exceptions.UnmappedColumnError("Can't execute sync rule for source column '%s'; mapper '%s' does not map this column. Try using an explicit `foreign_keys` collection which does not include destination column '%s' (or use a viewonly=True relation)." % (self.source_column, self.source_mapper, self.dest_column)) - - def source_changes(self, uowcommit, source): +def source_changes(uowcommit, source, source_mapper, synchronize_pairs): + for l, r in synchronize_pairs: try: - prop = self.source_mapper._get_col_to_prop(self.source_column) + prop = source_mapper._get_col_to_prop(l) except exceptions.UnmappedColumnError: - self._raise_col_to_prop(False) + _raise_col_to_prop(False, source_mapper, l, None, r) (added, unchanged, deleted) = uowcommit.get_attribute_history(source, prop.key, passive=True) - return bool(added and deleted) - - def update(self, dest, parent, child, old_prefix): - if self.issecondary is False: - source = parent - elif self.issecondary is True: - source = child + if added and deleted: + return True + else: + return False + +def dest_changes(uowcommit, dest, dest_mapper, synchronize_pairs): + for l, r in synchronize_pairs: try: - oldvalue = self.source_mapper._get_committed_attr_by_column(source.obj(), self.source_column) - value = self.source_mapper._get_state_attr_by_column(source, self.source_column) + prop = dest_mapper._get_col_to_prop(r) except exceptions.UnmappedColumnError: - self._raise_col_to_prop(False) - dest[self.dest_column.key] = value - dest[old_prefix + self.dest_column.key] = oldvalue + _raise_col_to_prop(True, None, l, dest_mapper, r) + (added, unchanged, deleted) = uowcommit.get_attribute_history(dest, prop.key, passive=True) + if added and deleted: + return True + else: + return False + +def _raise_col_to_prop(isdest, source_mapper, source_column, dest_mapper, dest_column): + if isdest: + raise exceptions.UnmappedColumnError("Can't execute sync rule for destination column '%s'; mapper '%s' does not map this column. Try using an explicit `foreign_keys` collection which does not include this column (or use a viewonly=True relation)." % (dest_column, source_mapper)) + else: + raise exceptions.UnmappedColumnError("Can't execute sync rule for source column '%s'; mapper '%s' does not map this column. Try using an explicit `foreign_keys` collection which does not include destination column '%s' (or use a viewonly=True relation)." % (source_column, source_mapper, dest_column)) - def execute(self, source, dest, parent, child, clearkeys): - # TODO: break the "dictionary" case into a separate method like 'update' above, - # reduce conditionals - if source is None: - if self.issecondary is False: - source = parent - elif self.issecondary is True: - source = child - if clearkeys or source is None: - value = None - clearkeys = True - else: - try: - value = self.source_mapper._get_state_attr_by_column(source, self.source_column) - except exceptions.UnmappedColumnError: - self._raise_col_to_prop(False) - if isinstance(dest, dict): - dest[self.dest_column.key] = value - else: - if clearkeys and self.dest_primary_key(): - raise exceptions.AssertionError("Dependency rule tried to blank-out primary key column '%s' on instance '%s'" % (str(self.dest_column), mapperutil.state_str(dest))) - - if logging.is_debug_enabled(self.logger): - self.logger.debug("execute() instances: %s(%s)->%s(%s) ('%s')" % (mapperutil.state_str(source), str(self.source_column), mapperutil.state_str(dest), str(self.dest_column), value)) - try: - self.dest_mapper._set_state_attr_by_column(dest, self.dest_column, value) - except exceptions.UnmappedColumnError: - self._raise_col_to_prop(True) - -SyncRule.logger = logging.class_logger(SyncRule) - |
