diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-02-06 12:20:15 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-02-06 12:20:15 -0500 |
| commit | 73f734bf80166c7dfce4892941752d7569a17524 (patch) | |
| tree | 337f48354f72d2c1ef75f0d9724a395b71e7b50c | |
| parent | 2dbeeff50b7ccc6f47b2816a59f99f051fdabc8c (diff) | |
| download | sqlalchemy-73f734bf80166c7dfce4892941752d7569a17524.tar.gz | |
initial annotations approach to join conditions. all tests pass, plus additional tests in #1401 pass.
would now like to reorganize RelationshipProperty more around the annotations concept.
| -rw-r--r-- | lib/sqlalchemy/orm/properties.py | 40 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 48 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 6 | ||||
| -rw-r--r-- | test/orm/test_joins.py | 12 | ||||
| -rw-r--r-- | test/orm/test_query.py | 8 | ||||
| -rw-r--r-- | test/orm/test_relationships.py | 2 | ||||
| -rw-r--r-- | test/sql/test_selectable.py | 19 |
8 files changed, 110 insertions, 27 deletions
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 59c4cb3dc..a590ad706 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -14,7 +14,7 @@ mapped attributes. from sqlalchemy import sql, util, log, exc as sa_exc from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, \ join_condition, _shallow_annotate -from sqlalchemy.sql import operators, expression +from sqlalchemy.sql import operators, expression, visitors from sqlalchemy.orm import attributes, dependency, mapper, \ object_mapper, strategies, configure_mappers from sqlalchemy.orm.util import CascadeOptions, _class_to_mapper, \ @@ -444,6 +444,7 @@ class RelationshipProperty(StrategizedProperty): else: j = _orm_annotate(pj, exclude=self.property.remote_side) + # MARKMARK if criterion is not None and target_adapter: # limit this adapter to annotated only? criterion = target_adapter.traverse(criterion) @@ -1376,6 +1377,34 @@ class RelationshipProperty(StrategizedProperty): "argument to indicate which column lazy join " "condition should bind." % (col, self.mapper)) + count = [0] + def clone(elem): + if set(['local', 'remote']).intersection(elem._annotations): + return None + elif elem in self.local_side and elem in self.remote_side: + # TODO: OK this still sucks. this is basically, + # refuse, refuse, refuse the temptation to guess! + # but crap we really have to guess don't we. we + # might want to traverse here with cloned_traverse + # so we can see the binary exprs and do it at that + # level.... + if count[0] % 2 == 0: + elem = elem._annotate({'local':True}) + else: + elem = elem._annotate({'remote':True}) + count[0] += 1 + elif elem in self.local_side: + elem = elem._annotate({'local':True}) + elif elem in self.remote_side: + elem = elem._annotate({'remote':True}) + else: + elem = None + return elem + + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, clone + ) + def _generate_backref(self): if not self.is_primary(): return @@ -1539,17 +1568,20 @@ class RelationshipProperty(StrategizedProperty): secondary_aliasizer.traverse(secondaryjoin) else: primary_aliasizer = ClauseAdapter(dest_selectable, - exclude=self.local_side, + #exclude=self.local_side, + exclude_fn=lambda c: "local" in c._annotations, equivalents=self.mapper._equivalent_columns) if source_selectable is not None: primary_aliasizer.chain( ClauseAdapter(source_selectable, - exclude=self.remote_side, + #exclude=self.remote_side, + exclude_fn=lambda c: "remote" in c._annotations, equivalents=self.parent._equivalent_columns)) secondary_aliasizer = None + primaryjoin = primary_aliasizer.traverse(primaryjoin) target_adapter = secondary_aliasizer or primary_aliasizer - target_adapter.include = target_adapter.exclude = None + target_adapter.include = target_adapter.exclude = target_adapter.exclude_fn = None else: target_adapter = None if source_selectable is None: diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index b11e5ad42..30e19bc68 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -2184,7 +2184,7 @@ class ColumnElement(ClauseElement, _CompareMixin): for oth in to_compare: if use_proxies and self.shares_lineage(oth): return True - elif oth is self: + elif hash(oth) == hash(self): return True else: return False diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 97975441e..f0509c16f 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -225,7 +225,8 @@ def adapt_criterion_to_null(crit, nulls): return visitors.cloned_traverse(crit, {}, {'binary':visit_binary}) -def join_condition(a, b, ignore_nonexistent_tables=False, a_subset=None): +def join_condition(a, b, ignore_nonexistent_tables=False, + a_subset=None): """create a join condition between two tables or selectables. e.g.:: @@ -535,6 +536,10 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, "'consider_as_foreign_keys' or " "'consider_as_referenced_keys'") + def col_is(a, b): + #return a is b + return a.compare(b) + def visit_binary(binary): if not any_operator and binary.operator is not operators.eq: return @@ -544,20 +549,20 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, if consider_as_foreign_keys: if binary.left in consider_as_foreign_keys and \ - (binary.right is binary.left or + (col_is(binary.right, binary.left) or binary.right not in consider_as_foreign_keys): pairs.append((binary.right, binary.left)) elif binary.right in consider_as_foreign_keys and \ - (binary.left is binary.right or + (col_is(binary.left, binary.right) or binary.left not in consider_as_foreign_keys): pairs.append((binary.left, binary.right)) elif consider_as_referenced_keys: if binary.left in consider_as_referenced_keys and \ - (binary.right is binary.left or + (col_is(binary.right, binary.left) or binary.right not in consider_as_referenced_keys): pairs.append((binary.left, binary.right)) elif binary.right in consider_as_referenced_keys and \ - (binary.left is binary.right or + (col_is(binary.left, binary.right) or binary.left not in consider_as_referenced_keys): pairs.append((binary.right, binary.left)) else: @@ -669,11 +674,22 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): s.c.col1 == table2.c.col1 """ - def __init__(self, selectable, equivalents=None, include=None, exclude=None, adapt_on_names=False): + def __init__(self, selectable, equivalents=None, + include=None, exclude=None, + include_fn=None, exclude_fn=None, + adapt_on_names=False): self.__traverse_options__ = {'stop_on':[selectable]} self.selectable = selectable - self.include = include - self.exclude = exclude + if include: + assert not include_fn + self.include_fn = lambda e: e in include + else: + self.include_fn = include_fn + if exclude: + assert not exclude_fn + self.exclude_fn = lambda e: e in exclude + else: + self.exclude_fn = exclude_fn self.equivalents = util.column_dict(equivalents or {}) self.adapt_on_names = adapt_on_names @@ -693,19 +709,17 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): return newcol def replace(self, col): - if isinstance(col, expression.FromClause): - if self.selectable.is_derived_from(col): + if isinstance(col, expression.FromClause) and \ + self.selectable.is_derived_from(col): return self.selectable - - if not isinstance(col, expression.ColumnElement): + elif not isinstance(col, expression.ColumnElement): return None - - if self.include and col not in self.include: + elif self.include_fn and not self.include_fn(col): return None - elif self.exclude and col in self.exclude: + elif self.exclude_fn and self.exclude_fn(col): return None - - return self._corresponding_column(col, True) + else: + return self._corresponding_column(col, True) class ColumnAdapter(ClauseAdapter): """Extends ClauseAdapter with extra utility functions. diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index cdcf40aa8..75e099f0d 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -240,16 +240,16 @@ def replacement_traverse(obj, opts, replace): replacement by a given replacement function.""" cloned = util.column_dict() - stop_on = util.column_set(opts.get('stop_on', [])) + stop_on = util.column_set([id(x) for x in opts.get('stop_on', [])]) def clone(elem, **kw): - if elem in stop_on or \ + if id(elem) in stop_on or \ 'no_replacement_traverse' in elem._annotations: return elem else: newelem = replace(elem) if newelem is not None: - stop_on.add(newelem) + stop_on.add(id(newelem)) return newelem else: if elem not in cloned: diff --git a/test/orm/test_joins.py b/test/orm/test_joins.py index db7c78cdd..6c43a2f39 100644 --- a/test/orm/test_joins.py +++ b/test/orm/test_joins.py @@ -1700,21 +1700,29 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): sess.flush() sess.close() - def test_join(self): + def test_join_1(self): Node = self.classes.Node - sess = create_session() node = sess.query(Node).join('children', aliased=True).filter_by(data='n122').first() assert node.data=='n12' + def test_join_2(self): + Node = self.classes.Node + sess = create_session() ret = sess.query(Node.data).join(Node.children, aliased=True).filter_by(data='n122').all() assert ret == [('n12',)] + def test_join_3(self): + Node = self.classes.Node + sess = create_session() node = sess.query(Node).join('children', 'children', aliased=True).filter_by(data='n122').first() assert node.data=='n1' + def test_join_4(self): + Node = self.classes.Node + sess = create_session() node = sess.query(Node).filter_by(data='n122').join('parent', aliased=True).filter_by(data='n12').\ join('parent', aliased=True, from_joinpoint=True).filter_by(data='n1').first() assert node.data == 'n122' diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 24974ae7e..155f7c68d 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -622,6 +622,14 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): self._test(Address.user == None, "addresses.user_id IS NULL") self._test(Address.user != None, "addresses.user_id IS NOT NULL") + + def test_foo(self): + Node = self.classes.Node + nalias = aliased(Node) + self._test( + nalias.parent.has(Node.data=='some data'), + "EXISTS (SELECT 1 FROM nodes WHERE nodes.id = nodes_1.parent_id AND nodes.data = :data_1)" + ) def test_selfref_relationship(self): Node = self.classes.Node diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py index 6781d7104..2049088af 100644 --- a/test/orm/test_relationships.py +++ b/test/orm/test_relationships.py @@ -249,6 +249,8 @@ class CompositeSelfRefFKTest(fixtures.MappedTest): def _test(self): Employee, Company = self.classes.Employee, self.classes.Company +# employee_t = self.tables.employee_t +# assert Employee.reports_to.property.local_remote_pairs == [(employee_t.c.reports_to_id, employee_t.c.emp_id), (employee_t.c.company_id, employee_t.c.company_id)] sess = create_session() c1 = Company() diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index 8f599f1d6..6d85f7c4f 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -1023,6 +1023,25 @@ class AnnotationsTest(fixtures.TestBase): annot = obj._annotate({}) eq_(set([obj]), set([annot])) + def test_compare(self): + t = table('t', column('x'), column('y')) + x_a = t.c.x._annotate({}) + assert t.c.x.compare(x_a) + assert x_a.compare(t.c.x) + assert not x_a.compare(t.c.y) + assert not t.c.y.compare(x_a) + assert (t.c.x == 5).compare(x_a == 5) + assert not (t.c.y == 5).compare(x_a == 5) + + s = select([t]) + x_p = s.c.x + assert not x_a.compare(x_p) + assert not t.c.x.compare(x_p) + x_p_a = x_p._annotate({}) + assert x_p_a.compare(x_p) + assert x_p.compare(x_p_a) + assert not x_p_a.compare(x_a) + def test_custom_constructions(self): from sqlalchemy.schema import Column class MyColumn(Column): |
