summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-07-09 19:48:02 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-07-09 19:48:02 +0000
commit5c7993080a93dc142057a4ac05501bcb3291731d (patch)
tree01658bf66ef61bd8426d97e7759de1e3c85b8c5b
parent1b0ac7e43b715b9050cb5f45d4a180bfc203c8a9 (diff)
downloadsqlalchemy-5c7993080a93dc142057a4ac05501bcb3291731d.tar.gz
some refactorings to activemapper, made relationship() class have some polymorphic behavior for initializing its real relation, added support + unittest for self-referential relationship
-rw-r--r--CHANGES3
-rw-r--r--lib/sqlalchemy/ext/activemapper.py98
-rw-r--r--test/ext/activemapper.py38
3 files changed, 92 insertions, 47 deletions
diff --git a/CHANGES b/CHANGES
index 63f905236..f6e4e63ab 100644
--- a/CHANGES
+++ b/CHANGES
@@ -1,3 +1,6 @@
+0.2.6
+- tweaks to ActiveMapper, supports self-referential relationships
+
0.2.5
- fixed endless loop bug in select_by(), if the traversal hit
two mappers that referenced each other
diff --git a/lib/sqlalchemy/ext/activemapper.py b/lib/sqlalchemy/ext/activemapper.py
index 32efc5a05..a0984f46c 100644
--- a/lib/sqlalchemy/ext/activemapper.py
+++ b/lib/sqlalchemy/ext/activemapper.py
@@ -9,7 +9,6 @@ from sqlalchemy import backref as create_backref
import inspect
import sys
-import sets
#
# the "proxy" to the database engine... this can be swapped out at runtime
@@ -59,7 +58,31 @@ class relationship(object):
self.uselist = uselist
self.secondary = secondary
self.order_by = order_by
-
+ def process(self, klass, propname, relations):
+ relclass = ActiveMapperMeta.classes[self.classname]
+ if isinstance(self.order_by, str):
+ self.order_by = [ self.order_by ]
+ if isinstance(self.order_by, list):
+ for itemno in range(len(self.order_by)):
+ if isinstance(self.order_by[itemno], str):
+ self.order_by[itemno] = \
+ getattr(relclass.c, self.order_by[itemno])
+ backref = self.create_backref(klass)
+ relations[propname] = relation(relclass.mapper,
+ secondary=self.secondary,
+ backref=backref,
+ private=self.private,
+ lazy=self.lazy,
+ uselist=self.uselist,
+ order_by=self.order_by)
+ def create_backref(self, klass):
+ relclass = ActiveMapperMeta.classes[self.classname]
+ if klass.__name__ == self.classname:
+ br_fkey = getattr(relclass.c, self.colname)
+ else:
+ br_fkey = None
+ return create_backref(self.backref, foreignkey=br_fkey)
+
class one_to_many(relationship):
def __init__(self, classname, colname=None, backref=None, private=False,
lazy=True, order_by=False):
@@ -69,10 +92,15 @@ class one_to_many(relationship):
class one_to_one(relationship):
def __init__(self, classname, colname=None, backref=None, private=False,
lazy=True, order_by=False):
- if backref is not None:
- backref = create_backref(backref, uselist=False)
relationship.__init__(self, classname, colname, backref, private,
lazy, uselist=False, order_by=order_by)
+ def create_backref(self, klass):
+ relclass = ActiveMapperMeta.classes[self.classname]
+ if klass.__name__ == self.classname:
+ br_fkey = getattr(relclass.c, self.colname)
+ else:
+ br_fkey = None
+ return create_backref(self.backref, foreignkey=br_fkey, uselist=False)
class many_to_many(relationship):
def __init__(self, classname, secondary, backref=None, lazy=True,
@@ -81,7 +109,6 @@ class many_to_many(relationship):
uselist=True, secondary=secondary,
order_by=order_by)
-
#
# SQLAlchemy metaclass and superclass that can be used to do SQLAlchemy
# mapping in a declarative way, along with a function to process the
@@ -89,22 +116,16 @@ class many_to_many(relationship):
# up if the classes aren't specified in a proper order
#
-__deferred_classes__ = set()
-__processed_classes__ = set()
+__deferred_classes__ = {}
+__processed_classes__ = {}
def process_relationships(klass, was_deferred=False):
# first, we loop through all of the relationships defined on the
# class, and make sure that the related class already has been
# completely processed and defer processing if it has not
defer = False
for propname, reldesc in klass.relations.items():
- found = False
- for other_klass in __processed_classes__:
- if reldesc.classname == other_klass.__name__:
- found = True
- break
-
+ found = (reldesc.classname == klass.__name__ or reldesc.classname in __processed_classes__)
if not found:
- if not was_deferred: __deferred_classes__.add(klass)
defer = True
break
@@ -112,44 +133,33 @@ def process_relationships(klass, was_deferred=False):
# and make sure that we can find the related tables (they do not
# have to be processed yet, just defined), and we defer if we are
# not able to find any of the related tables
- for col in klass.columns:
- if col.foreign_key is not None:
- found = False
- for other_klass in ActiveMapperMeta.classes.values():
+ if not defer:
+ for col in klass.columns:
+ if col.foreign_key is not None:
+ found = False
table_name = col.foreign_key._colspec.rsplit('.', 1)[0]
- if other_klass.table.fullname.lower() == table_name.lower():
- found = True
+ for other_klass in ActiveMapperMeta.classes.values():
+ if other_klass.table.fullname.lower() == table_name.lower():
+ found = True
- if not found:
- if not was_deferred: __deferred_classes__.add(klass)
- defer = True
- break
-
+ if not found:
+ defer = True
+ break
+
+ if defer and not was_deferred:
+ __deferred_classes__[klass.__name__] = klass
+
# if we are able to find all related and referred to tables, then
# we can go ahead and assign the relationships to the class
if not defer:
relations = {}
for propname, reldesc in klass.relations.items():
- relclass = ActiveMapperMeta.classes[reldesc.classname]
- if isinstance(reldesc.order_by, str):
- reldesc.order_by = [ reldesc.order_by ]
- if isinstance(reldesc.order_by, list):
- for itemno in range(len(reldesc.order_by)):
- if isinstance(reldesc.order_by[itemno], str):
- reldesc.order_by[itemno] = \
- getattr(relclass.c, reldesc.order_by[itemno])
- relations[propname] = relation(relclass.mapper,
- secondary=reldesc.secondary,
- backref=reldesc.backref,
- private=reldesc.private,
- lazy=reldesc.lazy,
- uselist=reldesc.uselist,
- order_by=reldesc.order_by)
+ reldesc.process(klass, propname, relations)
class_mapper(klass).add_properties(relations)
- if klass in __deferred_classes__:
- __deferred_classes__.remove(klass)
- __processed_classes__.add(klass)
+ if klass.__name__ in __deferred_classes__:
+ del __deferred_classes__[klass.__name__]
+ __processed_classes__[klass.__name__] = klass
# finally, loop through the deferred classes and attempt to process
# relationships for them
@@ -160,7 +170,7 @@ def process_relationships(klass, was_deferred=False):
while last_count > len(__deferred_classes__):
last_count = len(__deferred_classes__)
deferred = __deferred_classes__.copy()
- for deferred_class in deferred:
+ for deferred_class in deferred.values():
process_relationships(deferred_class, was_deferred=True)
diff --git a/test/ext/activemapper.py b/test/ext/activemapper.py
index 1bb93dd63..2a44f8e5b 100644
--- a/test/ext/activemapper.py
+++ b/test/ext/activemapper.py
@@ -1,6 +1,6 @@
import testbase
from sqlalchemy.ext.activemapper import ActiveMapper, column, one_to_many, one_to_one, objectstore
-from sqlalchemy import and_, or_, clear_mappers
+from sqlalchemy import and_, or_, clear_mappers, backref
from sqlalchemy import ForeignKey, String, Integer, DateTime
from datetime import datetime
@@ -218,6 +218,38 @@ class testcase(testbase.PersistTest):
)
self.assertEquals(len(results), 1)
-
+class testselfreferential(testbase.PersistTest):
+ def setUpAll(self):
+ global TreeNode
+ class TreeNode(activemapper.ActiveMapper):
+ class mapping:
+ id = column(Integer, primary_key=True)
+ name = column(String(30))
+ parent_id = column(Integer, foreign_key=ForeignKey('treenode.id'))
+ children = one_to_many('TreeNode', colname='id', backref='parent')
+
+ activemapper.metadata.connect(testbase.db)
+ activemapper.create_tables()
+ def tearDownAll(self):
+ clear_mappers()
+ activemapper.drop_tables()
+
+ def testbasic(self):
+ t = TreeNode(name='node1')
+ t.children.append(TreeNode(name='node2'))
+ t.children.append(TreeNode(name='node3'))
+ objectstore.flush()
+ objectstore.clear()
+
+ t = TreeNode.get_by(name='node1')
+ assert (t.name == 'node1')
+ assert (t.children[0].name == 'node2')
+ assert (t.children[1].name == 'node3')
+ assert (t.children[1].parent is t)
+
+ objectstore.clear()
+ t = TreeNode.get_by(name='node3')
+ assert (t.parent is TreeNode.get_by(name='node1'))
+
if __name__ == '__main__':
- unittest.main()
+ testbase.main()