diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-02-08 10:14:36 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-02-08 10:14:36 -0500 |
| commit | d1414ad20524c421aa78272c03dce5f839a0aab6 (patch) | |
| tree | 3e1ce8014fe934a5cab201073076ca8b302623b8 /lib | |
| parent | f39c43083a612fdb77dbb3eb2c297cde9662fe81 (diff) | |
| download | sqlalchemy-d1414ad20524c421aa78272c03dce5f839a0aab6.tar.gz | |
simplify remote annotation significantly, and also
catch the actual remote columns more accurately.
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/sqlalchemy/orm/properties.py | 27 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/relationships.py | 237 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/operators.py | 5 |
4 files changed, 148 insertions, 125 deletions
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 9bab0c2f4..953430162 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -949,6 +949,7 @@ class RelationshipProperty(StrategizedProperty): assert self.jc.direction is self.direction assert self.jc.remote_side == self.remote_side assert self.jc.local_remote_pairs == self.local_remote_pairs + pass def _check_conflicts(self): """Test that this relationship is legal, warn about @@ -1510,6 +1511,7 @@ class RelationshipProperty(StrategizedProperty): return strategy.use_get def _refers_to_parent_table(self): + alt = self._alt_refers_to_parent_table() pt = self.parent.mapped_table mt = self.mapper.mapped_table for c, f in self.synchronize_pairs: @@ -1519,10 +1521,35 @@ class RelationshipProperty(StrategizedProperty): mt.is_derived_from(c.table) and \ mt.is_derived_from(f.table) ): + assert alt return True else: + assert not alt return False + def _alt_refers_to_parent_table(self): + pt = self.parent.mapped_table + mt = self.mapper.mapped_table + result = [False] + def visit_binary(binary): + c, f = binary.left, binary.right + if ( + isinstance(c, expression.ColumnClause) and \ + isinstance(f, expression.ColumnClause) and \ + pt.is_derived_from(c.table) and \ + pt.is_derived_from(f.table) and \ + mt.is_derived_from(c.table) and \ + mt.is_derived_from(f.table) + ): + result[0] = True + + visitors.traverse( + self.primaryjoin, + {}, + {"binary":visit_binary} + ) + return result[0] + @util.memoized_property def _is_self_referential(self): return self.mapper.common_parent(self.parent) diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 02eab9c2d..cb07f234a 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -19,6 +19,27 @@ from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, \ from sqlalchemy.sql import operators, expression, visitors from sqlalchemy.orm.interfaces import MANYTOMANY, MANYTOONE, ONETOMANY +def remote(expr): + return _annotate_columns(expr, {"remote":True}) + +def foreign(expr): + return _annotate_columns(expr, {"foreign":True}) + +def remote_foreign(expr): + return _annotate_columns(expr, {"foreign":True, + "remote":True}) + +def _annotate_columns(element, annotations): + def clone(elem): + if isinstance(elem, expression.ColumnClause): + elem = elem._annotate(annotations.copy()) + elem._copy_internals(clone=clone) + return elem + + if element is not None: + element = clone(element) + return element + class JoinCondition(object): def __init__(self, parent_selectable, @@ -55,7 +76,8 @@ class JoinCondition(object): self.support_sync = support_sync self.can_be_synced_fn = can_be_synced_fn self._determine_joins() - self._parse_joins() + self._annotate_fks() + self._annotate_remote() self._determine_direction() def _determine_joins(self): @@ -106,13 +128,7 @@ class JoinCondition(object): "'secondaryjoin' is needed as well." % self.prop) - def _parse_joins(self): - """Apply 'remote', 'local' and 'foreign' annotations - to the primary and secondary join conditions. - - """ - parentcols = util.column_set(self.parent_selectable.c) - targetcols = util.column_set(self.child_selectable.c) + def _annotate_fks(self): if self.secondary is not None: secondarycols = util.column_set(self.secondary.c) else: @@ -121,20 +137,6 @@ class JoinCondition(object): def col_is(a, b): return a.compare(b) - def refers_to_parent_table(binary): - pt = self.parent_selectable - mt = self.child_selectable - c, f = binary.left, binary.right - if ( - pt.is_derived_from(c.table) and \ - pt.is_derived_from(f.table) and \ - mt.is_derived_from(c.table) and \ - mt.is_derived_from(f.table) - ): - return True - else: - return False - def is_foreign(a, b): if self.consider_as_foreign_keys: if a in self.consider_as_foreign_keys and ( @@ -161,32 +163,19 @@ class JoinCondition(object): elif b in secondarycols and a not in secondarycols: return b - def _run_w_switch(binary, fn): - binary.left, binary.right = fn(binary, binary.left, binary.right) - binary.right, binary.left = fn(binary, binary.right, binary.left) - def _annotate_fk(binary, left, right): can_be_synced = self.can_be_synced_fn(left) left = left._annotate({ - "equated":binary.operator is operators.eq, + #"equated":binary.operator is operators.eq, "can_be_synced":can_be_synced and \ binary.operator is operators.eq }) right = right._annotate({ - "equated":binary.operator is operators.eq, + #"equated":binary.operator is operators.eq, "referent":True }) return left, right - def _annotate_remote(binary, left, right): - left = left._annotate( - {"remote":True}) - if right in parentcols or \ - right in targetcols: - right = right._annotate( - {"local":True}) - return left, right - def visit_binary(binary): if not isinstance(binary.left, sql.ColumnElement) or \ not isinstance(binary.right, sql.ColumnElement): @@ -204,41 +193,12 @@ class JoinCondition(object): {"foreign":True}) # TODO: when the two cols are the same. - has_foreign = False if "foreign" in binary.left._annotations: binary.left, binary.right = _annotate_fk( binary, binary.left, binary.right) - has_foreign = True if "foreign" in binary.right._annotations: binary.right, binary.left = _annotate_fk( binary, binary.right, binary.left) - has_foreign = True - - if "remote" not in binary.left._annotations and \ - "remote" not in binary.right._annotations: - - def go(binary, left, right): - if self._local_remote_pairs: - raise NotImplementedError() - elif self._remote_side: - if left in self._remote_side: - return _annotate_remote(binary, left, right) - elif refers_to_parent_table(binary): - # assume one to many - FKs are "remote" - if "foreign" in left._annotations: - return _annotate_remote(binary, left, right) - elif secondarycols: - if left in secondarycols: - return _annotate_remote(binary, left, right) - else: - # TODO: to support the X->Y->Z case - # we might need to look at parentcols - # and annotate "local" separately... - if left in targetcols and has_foreign \ - and right in parentcols or right in secondarycols: - return _annotate_remote(binary, left, right) - return left, right - _run_w_switch(binary, go) self.primaryjoin = visitors.cloned_traverse( self.primaryjoin, @@ -257,11 +217,63 @@ class JoinCondition(object): self._check_foreign_cols( self.secondaryjoin, False) + def _refers_to_parent_table(self): + pt = self.parent_selectable + mt = self.child_selectable + result = [False] + def visit_binary(binary): + c, f = binary.left, binary.right + if ( + isinstance(c, expression.ColumnClause) and \ + isinstance(f, expression.ColumnClause) and \ + pt.is_derived_from(c.table) and \ + pt.is_derived_from(f.table) and \ + mt.is_derived_from(c.table) and \ + mt.is_derived_from(f.table) + ): + result[0] = True + + visitors.traverse( + self.primaryjoin, + {}, + {"binary":visit_binary} + ) + return result[0] + + def _annotate_remote(self): + for col in visitors.iterate(self.primaryjoin, {}): + if "remote" in col._annotations: + return + + if self._local_remote_pairs: + raise NotImplementedError() + elif self._remote_side: + def repl(element): + if element in self._remote_side: + return element._annotate({"remote":True}) + elif self.secondary is not None: + def repl(element): + if self.secondary.c.contains_column(element): + return element._annotate({"remote":True}) + elif self._refers_to_parent_table(): + def repl(element): + # assume one to many - FKs are "remote" + if "foreign" in element._annotations: + return element._annotate({"remote":True}) + else: + def repl(element): + if self.child_selectable.c.contains_column(element): + return element._annotate({"remote":True}) + + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, repl) + if self.secondaryjoin is not None: + self.secondaryjoin = visitors.replacement_traverse( + self.secondaryjoin, {}, repl) + def _check_foreign_cols(self, join_condition, primary): """Check the foreign key columns collected and emit error messages.""" - # TODO: don't worry, we can simplify this once we - # encourage configuration via direct annotation can_sync = False @@ -284,66 +296,30 @@ class JoinCondition(object): # to report. Check for a join condition using any operator # (not just ==), perhaps they need to turn on "viewonly=True". if self.support_sync and has_foreign and not can_sync: - - err = "Could not locate any "\ - "foreign-key-equated, locally mapped column "\ - "pairs for %s "\ - "condition '%s' on relationship %s." % ( + err = "Could not locate any simple equality expressions "\ + "involving foreign key columns for %s join condition "\ + "'%s' on relationship %s." % ( primary and 'primaryjoin' or 'secondaryjoin', join_condition, self.prop ) - - # TODO: this needs to be changed to detect that - # annotations were present and whatnot. the future - # foreignkey(col) annotation will cover establishing - # the col as foreign to it's mate - if not self.consider_as_foreign_keys: - err += " Ensure that the "\ - "referencing Column objects have a "\ - "ForeignKey present, or are otherwise part "\ - "of a ForeignKeyConstraint on their parent "\ - "Table, or specify the foreign_keys parameter "\ - "to this relationship." - - err += " For more "\ - "relaxed rules on join conditions, the "\ - "relationship may be marked as viewonly=True." + err += " Ensure that referencing columns are associated with a "\ + "ForeignKey or ForeignKeyConstraint, or are annotated "\ + "in the join condition with the foreign() annotation. "\ + "To allow comparison operators other than '==', "\ + "the relationship can be marked as viewonly=True." raise sa_exc.ArgumentError(err) else: - if self.consider_as_foreign_keys: - raise sa_exc.ArgumentError("Could not determine " - "relationship direction for %s condition " - "'%s', on relationship %s, using manual " - "'foreign_keys' setting. Do the columns " - "in 'foreign_keys' represent all, and " - "only, the 'foreign' columns in this join " - "condition? Does the %s Table already " - "have adequate ForeignKey and/or " - "ForeignKeyConstraint objects established " - "(in which case 'foreign_keys' is usually " - "unnecessary)?" - % ( - primary and 'primaryjoin' or 'secondaryjoin', - join_condition, - self.prop, - primary and 'mapped' or 'secondary' - )) - else: - raise sa_exc.ArgumentError("Could not determine " - "relationship direction for %s condition " - "'%s', on relationship %s. Ensure that the " - "referencing Column objects have a " - "ForeignKey present, or are otherwise part " - "of a ForeignKeyConstraint on their parent " - "Table, or specify the foreign_keys parameter " - "to this relationship." - % ( - primary and 'primaryjoin' or 'secondaryjoin', - join_condition, - self.prop - )) + err = "Could not locate any relevant foreign key columns "\ + "for %s join condition '%s' on relationship %s." % ( + primary and 'primaryjoin' or 'secondaryjoin', + join_condition, + self.prop + ) + err += "Ensure that referencing columns are associated with a "\ + "a ForeignKey or ForeignKeyConstraint, or are annotated "\ + "in the join condition with the foreign() annotation." def _determine_direction(self): """Determine if this relationship is one to many, many to one, @@ -399,14 +375,21 @@ class JoinCondition(object): "nor the child's mapped tables" % self.prop) @util.memoized_property - def remote_columns(self): + def liberal_remote_columns(self): + # this is temporary until we figure out + # which version of "remote" to use return self._gather_join_annotations("remote") + @util.memoized_property + def remote_columns(self): + return set([r for l, r in self.local_remote_pairs]) + #return self._gather_join_annotations("remote") + remote_side = remote_columns @util.memoized_property def local_columns(self): - return self._gather_join_annotations("local") + return set([l for l, r in self.local_remote_pairs]) @util.memoized_property def foreign_key_columns(self): @@ -440,10 +423,14 @@ class JoinCondition(object): lrp = util.OrderedSet() def visit_binary(binary): if "remote" in binary.right._annotations and \ - "local" in binary.left._annotations: + "remote" not in binary.left._annotations and \ + isinstance(binary.left, expression.ColumnClause) and \ + self.can_be_synced_fn(binary.left): lrp.add((binary.left, binary.right)) elif "remote" in binary.left._annotations and \ - "local" in binary.right._annotations: + "remote" not in binary.right._annotations and \ + isinstance(binary.right, expression.ColumnClause) and \ + self.can_be_synced_fn(binary.right): lrp.add((binary.right, binary.left)) visitors.traverse(self.primaryjoin, {}, {"binary":visit_binary}) if self.secondaryjoin is not None: diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 30e19bc68..72099a5f5 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -3385,6 +3385,10 @@ class _BinaryExpression(ColumnElement): raise TypeError("Boolean value of this clause is not defined") @property + def is_comparison(self): + return operators.is_comparison(self.operator) + + @property def _from_objects(self): return self.left._from_objects + self.right._from_objects diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 89f0aaee1..b86b50db4 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -521,6 +521,11 @@ def nullslast_op(a): _commutative = set([eq, ne, add, mul]) +_comparison = set([eq, ne, lt, gt, ge, le]) + +def is_comparison(op): + return op in _comparison + def is_commutative(op): return op in _commutative |
