summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES4
-rw-r--r--lib/sqlalchemy/orm/mapper.py4
-rw-r--r--lib/sqlalchemy/orm/properties.py9
-rw-r--r--test/orm/inheritance/basic.py70
4 files changed, 81 insertions, 6 deletions
diff --git a/CHANGES b/CHANGES
index 6154bc3d1..5a9893368 100644
--- a/CHANGES
+++ b/CHANGES
@@ -87,6 +87,10 @@ CHANGES
- also with dynamic, implemented correct count() behavior as well
as other helper methods.
+ - fix to cascades on polymorphic relations, such that cascades
+ from an object to a polymorphic collection continue cascading
+ along the set of attributes specific to each element in the collection.
+
- query.get() and query.load() do not take existing filter or other
criterion into account; these methods *always* look up the given id
in the database or return the current instance from the identity map,
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 0f5dbaaf5..e9fe41fdc 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -1538,8 +1538,8 @@ def has_mapper(object):
return hasattr(object, '_entity_name')
-def _state_mapper(state):
- return state.class_._class_state.mappers[state.dict.get('_entity_name', None)]
+def _state_mapper(state, entity_name=None):
+ return state.class_._class_state.mappers[state.dict.get('_entity_name', entity_name)]
def object_mapper(object, entity_name=None, raiseerror=True):
"""Given an object, return the primary Mapper associated with the object instance.
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 9394e9aea..4d41556a0 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -13,7 +13,7 @@ to handle flush-time dependency sorting and processing.
from sqlalchemy import sql, schema, util, exceptions, logging
from sqlalchemy.sql import util as sql_util, visitors, operators, ColumnElement
-from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency
+from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency, object_mapper
from sqlalchemy.orm import session as sessionlib
from sqlalchemy.orm import util as mapperutil
from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty
@@ -365,8 +365,11 @@ class PropertyLoader(StrategizedProperty):
if not isinstance(c, self.mapper.class_):
raise exceptions.AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__)))
recursive.add(c)
- yield (c, mapper)
- for (c2, m) in mapper.cascade_iterator(type, c._state, recursive):
+
+ # cascade using the mapper local to this object, so that its individual properties are located
+ instance_mapper = object_mapper(c, entity_name=mapper.entity_name)
+ yield (c, instance_mapper)
+ for (c2, m) in instance_mapper.cascade_iterator(type, c._state, recursive):
yield (c2, m)
def _get_target_class(self):
diff --git a/test/orm/inheritance/basic.py b/test/orm/inheritance/basic.py
index 05603ac86..2ef76b6d8 100644
--- a/test/orm/inheritance/basic.py
+++ b/test/orm/inheritance/basic.py
@@ -9,7 +9,6 @@ class O2MTest(ORMTest):
"""deals with inheritance and one-to-many relationships"""
def define_tables(self, metadata):
global foo, bar, blub
- # the 'data' columns are to appease SQLite which cant handle a blank INSERT
foo = Table('foo', metadata,
Column('id', Integer, Sequence('foo_seq', optional=True),
primary_key=True),
@@ -65,7 +64,76 @@ class O2MTest(ORMTest):
self.assert_(compare == result)
self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1')
+class CascadeTest(ORMTest):
+ """that cascades on polymorphic relations continue
+ cascading along the path of the instance's mapper, not
+ the base mapper."""
+
+ def define_tables(self, metadata):
+ global t1, t2, t3, t4
+ t1= Table('t1', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(30))
+ )
+
+ t2 = Table('t2', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('t1id', Integer, ForeignKey('t1.id')),
+ Column('type', String(30)),
+ Column('data', String(30))
+ )
+ t3 = Table('t3', metadata,
+ Column('id', Integer, ForeignKey('t2.id'), primary_key=True),
+ Column('moredata', String(30)))
+
+ t4 = Table('t4', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('t3id', Integer, ForeignKey('t3.id')),
+ Column('data', String(30)))
+
+ def test_cascade(self):
+ class T1(fixtures.Base):
+ pass
+ class T2(fixtures.Base):
+ pass
+ class T3(T2):
+ pass
+ class T4(fixtures.Base):
+ pass
+
+ mapper(T1, t1, properties={
+ 't2s':relation(T2, cascade="all")
+ })
+ mapper(T2, t2, polymorphic_on=t2.c.type, polymorphic_identity='t2')
+ mapper(T3, t3, inherits=T2, polymorphic_identity='t3', properties={
+ 't4s':relation(T4, cascade="all")
+ })
+ mapper(T4, t4)
+
+ sess = create_session()
+ t1_1 = T1(data='t1')
+
+ t3_1 = T3(data ='t3', moredata='t3')
+ t2_1 = T2(data='t2')
+
+ t1_1.t2s.append(t2_1)
+ t1_1.t2s.append(t3_1)
+
+ t4_1 = T4(data='t4')
+ t3_1.t4s.append(t4_1)
+
+ sess.save(t1_1)
+
+ assert t4_1 in sess.new
+ sess.flush()
+
+ sess.delete(t1_1)
+ assert t4_1 in sess.deleted
+ sess.flush()
+
+
+
class GetTest(ORMTest):
def define_tables(self, metadata):
global foo, bar, blub