summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-07-10 06:51:58 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-07-10 06:51:58 +0000
commitbe29010e292739ca3545315eb2e6a9243aa53e1a (patch)
tree58b19350841644bef4a4550e313526dd3b23fef0
parent5cb66ee718ee15e91e6036d573aaec67d4c43fe6 (diff)
downloadsqlalchemy-be29010e292739ca3545315eb2e6a9243aa53e1a.tar.gz
more "column targeting" enhancements..columns have a "depth" from their ultimate source column so that corresponding_column() can find the column that is "closest" (i.e. fewest levels of proxying) to the requested column
-rw-r--r--CHANGES3
-rw-r--r--lib/sqlalchemy/orm/util.py4
-rw-r--r--lib/sqlalchemy/schema.py2
-rw-r--r--lib/sqlalchemy/sql.py26
-rw-r--r--test/orm/eagertest3.py96
-rw-r--r--test/orm/mapper.py2
-rwxr-xr-xtest/sql/selectable.py29
7 files changed, 149 insertions, 13 deletions
diff --git a/CHANGES b/CHANGES
index 606d9baba..b281a8fb6 100644
--- a/CHANGES
+++ b/CHANGES
@@ -39,6 +39,9 @@
- DynamicMetaData has been renamed to ThreadLocalMetaData. the
DynamicMetaData name is deprecated and is an alias for ThreadLocalMetaData
or a regular MetaData if threadlocal=False
+ - some enhancements to "column targeting", the ability to match a column
+ to a "corresponding" column in another selectable. this affects mostly
+ ORM ability to map to complex joins
- MetaData and all SchemaItems are safe to use with pickle. slow
table reflections can be dumped into a pickled file to be reused later.
Just reconnect the engine to the metadata after unpickling. [ticket:619]
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 923dd6797..3b3b9b7ed 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -89,8 +89,8 @@ class TranslatingDict(dict):
def __translate_col(self, col):
ourcol = self.selectable.corresponding_column(col, keys_ok=False, raiseerr=False)
- #if col is not ourcol:
- # print "TD TRANSLATING ", col, "TO", ourcol
+# if col is not ourcol and ourcol is not None:
+# print "TD TRANSLATING ", col, "TO", ourcol
if ourcol is None:
return col
else:
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index eb7eb8c1d..d7d728f2b 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -600,11 +600,11 @@ class Column(SchemaItem, sql._ColumnClause):
This is a copy of this ``Column`` referenced by a different parent
(such as an alias or select statement).
"""
-
fk = [ForeignKey(f._colspec) for f in self.foreign_keys]
c = Column(name or self.name, self.type, self.default, key = name or self.key, primary_key = self.primary_key, nullable = self.nullable, _is_oid = self._is_oid, quote=self.quote, *fk)
c.table = selectable
c.orig_set = self.orig_set
+ c._source_column = self
c.__originating_column = self.__originating_column
if not c._is_oid:
selectable.columns.add(c)
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index c86fc561a..2a22a40c1 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -1542,7 +1542,17 @@ class ColumnElement(Selectable, _CompareMixin):
return True
else:
return False
-
+
+ def _distance(self, othercolumn):
+ c = othercolumn
+ count = 0
+ while c is not self:
+ c = c._source_column
+ if c is None:
+ return -1
+ count += 1
+ return count
+
def _make_proxy(self, selectable, name=None):
"""Create a new ``ColumnElement`` representing this
``ColumnElement`` as it appears in the select list of a
@@ -1695,7 +1705,7 @@ class FromClause(Selectable):
"""
if column in self.c:
return column
-
+
if require_embedded and column not in util.Set(self._get_all_embedded_columns()):
if not raiseerr:
return None
@@ -1757,9 +1767,9 @@ class FromClause(Selectable):
for co in self._adjusted_exportable_columns():
cp = self._proxy_column(co)
for ci in cp.orig_set:
- # note that some ambiguity is raised here, whereby a selectable might have more than
- # one column that maps to an "original" column. examples include unions and joins
- self._orig_cols[ci] = cp
+ cx = self._orig_cols.get(ci)
+ if cx is None or ci._distance(cp) < ci._distance(cx):
+ self._orig_cols[ci] = cp
if self.oid_column is not None:
for ci in self.oid_column.orig_set:
self._orig_cols[ci] = self.oid_column
@@ -2078,7 +2088,8 @@ class _Cast(ColumnElement):
self.type = sqltypes.to_instance(totype)
self.clause = clause
self.typeclause = _TypeClause(self.type)
-
+ self._source_column = None
+
def get_children(self, **kwargs):
return self.clause, self.typeclause
def accept_visitor(self, visitor):
@@ -2090,6 +2101,7 @@ class _Cast(ColumnElement):
def _make_proxy(self, selectable, name=None):
if name is not None:
co = _ColumnClause(name, selectable, type=self.type)
+ co._source_column = self
co.orig_set = self.orig_set
selectable.columns[name]= co
return co
@@ -2512,6 +2524,7 @@ class _ColumnClause(ColumnElement):
self.table = selectable
self.type = sqltypes.to_instance(type)
self._is_oid = _is_oid
+ self._source_column = None
self.__label = None
self.case_sensitive = case_sensitive
self.is_literal = is_literal
@@ -2571,6 +2584,7 @@ class _ColumnClause(ColumnElement):
is_literal = self.is_literal and (name is None or name == self.name)
c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type=self.type, is_literal=is_literal)
c.orig_set = self.orig_set
+ c._source_column = self
if not self._is_oid:
selectable.columns[c.name] = c
return c
diff --git a/test/orm/eagertest3.py b/test/orm/eagertest3.py
index a731581d5..8e7735812 100644
--- a/test/orm/eagertest3.py
+++ b/test/orm/eagertest3.py
@@ -415,7 +415,101 @@ class EagerTest5(testbase.ORMTest):
# object is not in the session; therefore the lazy load cant trigger here,
# eager load had to succeed
assert len([c for c in d2.comments]) == 1
+
+class EagerTest6(testbase.ORMTest):
+ def define_tables(self, metadata):
+ global project_t, task_t, task_status_t, task_type_t, message_t, message_type_t
-
+ project_t = Table('prj', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('created', DateTime , ),
+ Column('title', Unicode(100)),
+ )
+
+ task_t = Table('task', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('status_id', Integer, ForeignKey('task_status.id'), nullable=False),
+ Column('title', Unicode(100)),
+ Column('task_type_id', Integer , ForeignKey('task_type.id'), nullable=False),
+ Column('prj_id', Integer , ForeignKey('prj.id'), nullable=False),
+ )
+
+ task_status_t = Table('task_status', metadata,
+ Column('id', Integer, primary_key=True),
+ )
+
+ task_type_t = Table('task_type', metadata,
+ Column('id', Integer, primary_key=True),
+ )
+
+ message_t = Table('msg', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('posted', DateTime, index=True,),
+ Column('type_id', Integer, ForeignKey('msg_type.id')),
+ Column('task_id', Integer, ForeignKey('task.id')),
+ )
+
+ message_type_t = Table('msg_type', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('name', Unicode(20)),
+ Column('display_name', Unicode(20)),
+ )
+
+ def setUp(self):
+ testbase.db.execute("INSERT INTO prj (title) values('project 1');")
+ testbase.db.execute("INSERT INTO task_status (id) values(1);")
+ testbase.db.execute("INSERT INTO task_type(id) values(1);")
+ testbase.db.execute("INSERT INTO task (title, task_type_id, status_id, prj_id) values('task 1',1,1,1);")
+
+ def test_nested_joins(self):
+ # this is testing some subtle column resolution stuff,
+ # concerning corresponding_column() being extremely accurate
+ # as well as how mapper sets up its column properties
+
+ class Task(object):pass
+ class Task_Type(object):pass
+ class Message(object):pass
+ class Message_Type(object):pass
+
+ tsk_cnt_join = outerjoin(project_t, task_t, task_t.c.prj_id==project_t.c.id)
+
+ ss = select([project_t.c.id.label('prj_id'), func.count(task_t.c.id).label('tasks_number')],
+ from_obj=[tsk_cnt_join], group_by=[project_t.c.id]).alias('prj_tsk_cnt_s')
+ j = join(project_t, ss, project_t.c.id == ss.c.prj_id)
+
+ mapper(Task_Type, task_type_t)
+
+ mapper( Task, task_t,
+ properties=dict(type=relation(Task_Type, lazy=False),
+ ))
+
+ mapper(Message_Type, message_type_t)
+
+ mapper(Message, message_t,
+ properties=dict(type=relation(Message_Type, lazy=False, uselist=False),
+ ))
+
+ tsk_cnt_join = outerjoin(project_t, task_t, task_t.c.prj_id==project_t.c.id)
+ ss = select([project_t.c.id.label('prj_id'), func.count(task_t.c.id).label('tasks_number')],
+ from_obj=[tsk_cnt_join], group_by=[project_t.c.id]).alias('prj_tsk_cnt_s')
+ j = join(project_t, ss, project_t.c.id == ss.c.prj_id)
+
+ j = outerjoin( task_t, message_t, task_t.c.id==message_t.c.task_id)
+ jj = select([ task_t.c.id.label('task_id'),
+ func.count(message_t.c.id).label('props_cnt')],
+ from_obj=[j], group_by=[task_t.c.id]).alias('prop_c_s')
+ jjj = join(task_t, jj, task_t.c.id == jj.c.task_id)
+
+ class cls(object):pass
+
+ props =dict(type=relation(Task_Type, lazy=False))
+ print [c.key for c in jjj.c]
+ cls.mapper = mapper( cls, jjj, properties=props)
+
+ session = create_session()
+
+ for t in session.query(cls.mapper).limit(10).offset(0).list():
+ print t.id, t.title, t.props_cnt
+
if __name__ == "__main__":
testbase.main()
diff --git a/test/orm/mapper.py b/test/orm/mapper.py
index 63af53b96..197f988b7 100644
--- a/test/orm/mapper.py
+++ b/test/orm/mapper.py
@@ -497,7 +497,7 @@ class MapperTest(MapperSuperTest):
class_mapper(User)
except exceptions.ArgumentError, e:
assert str(e) == "Column '%s' is not represented in mapper's table. Use the `column_property()` function to force this column to be mapped as a read-only attribute." % str(f)
- clear_mappers()
+ clear_mappers()
mapper(User, users, properties={
'concat': column_property(f),
diff --git a/test/sql/selectable.py b/test/sql/selectable.py
index 57ad99886..340c55837 100755
--- a/test/sql/selectable.py
+++ b/test/sql/selectable.py
@@ -27,17 +27,42 @@ table2 = Table('table2', db,
)
class SelectableTest(testbase.AssertMixin):
+ def testdistance(self):
+ s = select([table.c.col1.label('c2'), table.c.col1, table.c.col1.label('c1')])
+
+ # didnt do this yet...col.label().make_proxy() has same "distance" as col.make_proxy() so far
+ #assert s.corresponding_column(table.c.col1) is s.c.col1
+ assert s.corresponding_column(s.c.col1) is s.c.col1
+ assert s.corresponding_column(s.c.c1) is s.c.c1
+
def testjoinagainstself(self):
jj = select([table.c.col1.label('bar_col1')])
jjj = join(table, jj, table.c.col1==jj.c.bar_col1)
+
+ # test column directly agaisnt itself
assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1
+ assert jjj.corresponding_column(jj.c.bar_col1) is jjj.c.bar_col1
+
+ # test alias of the join, targets the column with the least
+ # "distance" between the requested column and the returned column
+ # (i.e. there is less indirection between j2.c.table1_col1 and table.c.col1, than
+ # there is from j2.c.bar_col1 to table.c.col1)
+ j2 = jjj.alias('foo')
+ assert j2.corresponding_column(table.c.col1) is j2.c.table1_col1
+
+
def testjoinagainstjoin(self):
j = outerjoin(table, table2, table.c.col1==table2.c.col2)
jj = select([ table.c.col1.label('bar_col1')],from_obj=[j]).alias('foo')
jjj = join(table, jj, table.c.col1==jj.c.bar_col1)
assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1
+
+ j2 = jjj.alias('foo')
+ print j2.corresponding_column(jjj.c.table1_col1)
+ assert j2.corresponding_column(jjj.c.table1_col1) is j2.c.table1_col1
+ assert jjj.corresponding_column(jj.c.bar_col1) is jj.c.bar_col1
def testtablealias(self):
a = table.alias('a')
@@ -110,8 +135,8 @@ class SelectableTest(testbase.AssertMixin):
j = join(a, table2)
criterion = a.c.col1 == table2.c.col2
- print
- print str(j)
+ print criterion
+ print j.onclause
self.assert_(criterion.compare(j.onclause))
def testselectlabels(self):