summaryrefslogtreecommitdiff
path: root/test/orm/test_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/orm/test_utils.py')
-rw-r--r--test/orm/test_utils.py239
1 files changed, 239 insertions, 0 deletions
diff --git a/test/orm/test_utils.py b/test/orm/test_utils.py
new file mode 100644
index 000000000..06533a243
--- /dev/null
+++ b/test/orm/test_utils.py
@@ -0,0 +1,239 @@
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+from sqlalchemy.orm import interfaces, util
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy import Table
+from sqlalchemy.orm import aliased
+from sqlalchemy.orm import mapper, create_session
+
+
+from sqlalchemy.test import TestBase, testing
+
+from test.orm import _fixtures
+from sqlalchemy.test.testing import eq_
+
+
+class ExtensionCarrierTest(TestBase):
+ def test_basic(self):
+ carrier = util.ExtensionCarrier()
+
+ assert 'translate_row' not in carrier
+ assert carrier.translate_row() is interfaces.EXT_CONTINUE
+ assert 'translate_row' not in carrier
+
+ assert_raises(AttributeError, lambda: carrier.snickysnack)
+
+ class Partial(object):
+ def __init__(self, marker):
+ self.marker = marker
+ def translate_row(self, row):
+ return self.marker
+
+ carrier.append(Partial('end'))
+ assert 'translate_row' in carrier
+ assert carrier.translate_row(None) == 'end'
+
+ carrier.push(Partial('front'))
+ assert carrier.translate_row(None) == 'front'
+
+ assert 'populate_instance' not in carrier
+ carrier.append(interfaces.MapperExtension)
+ assert 'populate_instance' in carrier
+
+ assert carrier.interface
+ for m in carrier.interface:
+ assert getattr(interfaces.MapperExtension, m)
+
+class AliasedClassTest(TestBase):
+ def point_map(self, cls):
+ table = Table('point', MetaData(),
+ Column('id', Integer(), primary_key=True),
+ Column('x', Integer),
+ Column('y', Integer))
+ mapper(cls, table)
+ return table
+
+ def test_simple(self):
+ class Point(object):
+ pass
+ table = self.point_map(Point)
+
+ alias = aliased(Point)
+
+ assert alias.id
+ assert alias.x
+ assert alias.y
+
+ assert Point.id.__clause_element__().table is table
+ assert alias.id.__clause_element__().table is not table
+
+ def test_notcallable(self):
+ class Point(object):
+ pass
+ table = self.point_map(Point)
+ alias = aliased(Point)
+
+ assert_raises(TypeError, alias)
+
+ def test_instancemethods(self):
+ class Point(object):
+ def zero(self):
+ self.x, self.y = 0, 0
+
+ table = self.point_map(Point)
+ alias = aliased(Point)
+
+ assert Point.zero
+ assert not getattr(alias, 'zero')
+
+ def test_classmethods(self):
+ class Point(object):
+ @classmethod
+ def max_x(cls):
+ return 100
+
+ table = self.point_map(Point)
+ alias = aliased(Point)
+
+ assert Point.max_x
+ assert alias.max_x
+ assert Point.max_x() == alias.max_x()
+
+ def test_simpleproperties(self):
+ class Point(object):
+ @property
+ def max_x(self):
+ return 100
+
+ table = self.point_map(Point)
+ alias = aliased(Point)
+
+ assert Point.max_x
+ assert Point.max_x != 100
+ assert alias.max_x
+ assert Point.max_x is alias.max_x
+
+ def test_descriptors(self):
+ class descriptor(object):
+ """Tortured..."""
+ def __init__(self, fn):
+ self.fn = fn
+ def __get__(self, obj, owner):
+ if obj is not None:
+ return self.fn(obj, obj)
+ else:
+ return self
+ def method(self):
+ return 'method'
+
+ class Point(object):
+ center = (0, 0)
+ @descriptor
+ def thing(self, arg):
+ return arg.center
+
+ table = self.point_map(Point)
+ alias = aliased(Point)
+
+ assert Point.thing != (0, 0)
+ assert Point().thing == (0, 0)
+ assert Point.thing.method() == 'method'
+
+ assert alias.thing != (0, 0)
+ assert alias.thing.method() == 'method'
+
+ def test_hybrid_descriptors(self):
+ from sqlalchemy import Column # override testlib's override
+ import types
+
+ class MethodDescriptor(object):
+ def __init__(self, func):
+ self.func = func
+ def __get__(self, instance, owner):
+ if instance is None:
+ args = (self.func, owner, owner.__class__)
+ else:
+ args = (self.func, instance, owner)
+ return types.MethodType(*args)
+
+ class PropertyDescriptor(object):
+ def __init__(self, fget, fset, fdel):
+ self.fget = fget
+ self.fset = fset
+ self.fdel = fdel
+ def __get__(self, instance, owner):
+ if instance is None:
+ return self.fget(owner)
+ else:
+ return self.fget(instance)
+ def __set__(self, instance, value):
+ self.fset(instance, value)
+ def __delete__(self, instance):
+ self.fdel(instance)
+ hybrid = MethodDescriptor
+ def hybrid_property(fget, fset=None, fdel=None):
+ return PropertyDescriptor(fget, fset, fdel)
+
+ def assert_table(expr, table):
+ for child in expr.get_children():
+ if isinstance(child, Column):
+ assert child.table is table
+
+ class Point(object):
+ def __init__(self, x, y):
+ self.x, self.y = x, y
+ @hybrid
+ def left_of(self, other):
+ return self.x < other.x
+
+ double_x = hybrid_property(lambda self: self.x * 2)
+
+ table = self.point_map(Point)
+ alias = aliased(Point)
+ alias_table = alias.x.__clause_element__().table
+ assert table is not alias_table
+
+ p1 = Point(-10, -10)
+ p2 = Point(20, 20)
+
+ assert p1.left_of(p2)
+ assert p1.double_x == -20
+
+ assert_table(Point.double_x, table)
+ assert_table(alias.double_x, alias_table)
+
+ assert_table(Point.left_of(p2), table)
+ assert_table(alias.left_of(p2), alias_table)
+
+class IdentityKeyTest(_fixtures.FixtureTest):
+ run_inserts = None
+
+ @testing.resolve_artifact_names
+ def test_identity_key_1(self):
+ mapper(User, users)
+
+ key = util.identity_key(User, 1)
+ eq_(key, (User, (1,)))
+ key = util.identity_key(User, ident=1)
+ eq_(key, (User, (1,)))
+
+ @testing.resolve_artifact_names
+ def test_identity_key_2(self):
+ mapper(User, users)
+ s = create_session()
+ u = User(name='u1')
+ s.add(u)
+ s.flush()
+ key = util.identity_key(instance=u)
+ eq_(key, (User, (u.id,)))
+
+ @testing.resolve_artifact_names
+ def test_identity_key_3(self):
+ mapper(User, users)
+
+ row = {users.c.id: 1, users.c.name: "Frank"}
+ key = util.identity_key(User, row=row)
+ eq_(key, (User, (1,)))
+
+