summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-08-03 19:31:38 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-08-03 19:31:38 +0000
commite7c83bb37133af7b0deaef2fbc0d0fae8a179dfc (patch)
treeb4aec5209e546128ace3719f874be05950027abe
parentfdc58f4141b094f06f87416def9edffa59ab17d9 (diff)
downloadsqlalchemy-e7c83bb37133af7b0deaef2fbc0d0fae8a179dfc.tar.gz
- removed enhance_classes from scoped_session, replaced with
scoped_session(...).mapper. 'mapper' essentially does the same thing as assign_mapper less verbosely. - adapted assignmapper unit tests into scoped_session tests
-rw-r--r--lib/sqlalchemy/orm/scoping.py54
-rw-r--r--test/orm/session.py86
-rw-r--r--test/orm/unitofwork.py5
3 files changed, 118 insertions, 27 deletions
diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py
index 96d9a23fc..5d11a99a4 100644
--- a/lib/sqlalchemy/orm/scoping.py
+++ b/lib/sqlalchemy/orm/scoping.py
@@ -1,4 +1,4 @@
-from sqlalchemy.util import ScopedRegistry, warn_deprecated
+from sqlalchemy.util import ScopedRegistry, warn_deprecated, to_list
from sqlalchemy.orm import MapperExtension, EXT_CONTINUE
from sqlalchemy.orm.session import Session
from sqlalchemy.orm.mapper import global_extensions
@@ -13,16 +13,21 @@ class ScopedSession(object):
Usage::
- Session = scoped_session(sessionmaker(autoflush=True), enhance_classes=True)
+ Session = scoped_session(sessionmaker(autoflush=True))
+
+ To map classes so that new instances are saved in the current
+ Session automatically, as well as to provide session-aware
+ class attributes such as "query":
+
+ mapper = Session.mapper
+ mapper(Class, table, ...)
"""
- def __init__(self, session_factory, scopefunc=None, enhance_classes=False):
+ def __init__(self, session_factory, scopefunc=None):
self.session_factory = session_factory
- self.enhance_classes = enhance_classes
self.registry = ScopedRegistry(session_factory, scopefunc)
- if self.enhance_classes:
- global_extensions.append(_ScopedExt(self))
+ self.extension = _ScopedExt(self)
def __call__(self, **kwargs):
if kwargs:
@@ -39,15 +44,28 @@ class ScopedSession(object):
else:
return self.registry()
+ def mapper(self, *args, **kwargs):
+ """return a mapper() function which associates this ScopedSession with the Mapper."""
+
+ from sqlalchemy.orm import mapper
+ validate = kwargs.pop('validate', False)
+ extension = to_list(kwargs.setdefault('extension', []))
+ if validate:
+ extension.append(self.extension.validating())
+ else:
+ extension.append(self.extension)
+ return mapper(*args, **kwargs)
+
def configure(self, **kwargs):
- """reconfigure the sessionmaker used by this SessionContext"""
+ """reconfigure the sessionmaker used by this ScopedSession."""
+
self.session_factory.configure(**kwargs)
def instrument(name):
def do(self, *args, **kwargs):
return getattr(self.registry(), name)(*args, **kwargs)
return do
-for meth in ('get', 'close', 'save', 'commit', 'update', 'flush', 'query', 'delete'):
+for meth in ('get', 'close', 'save', 'commit', 'update', 'flush', 'query', 'delete', 'clear'):
setattr(ScopedSession, meth, instrument(meth))
def makeprop(name):
@@ -67,18 +85,22 @@ for prop in ('close_all',):
setattr(ScopedSession, prop, clslevel(prop))
class _ScopedExt(MapperExtension):
- def __init__(self, context):
+ def __init__(self, context, validate=False):
self.context = context
+ self.validate = validate
+
+ def validating(self):
+ return _ScopedExt(self.context, validate=True)
def get_session(self):
return self.context.registry()
def instrument_class(self, mapper, class_):
class query(object):
- def __getattr__(self, key):
- return getattr(registry().query(class_), key)
- def __call__(self):
- return registry().query(class_)
+ def __getattr__(s, key):
+ return getattr(self.context.registry().query(class_), key)
+ def __call__(s):
+ return self.context.registry().query(class_)
if not hasattr(class_, 'query'):
class_.query = query()
@@ -87,9 +109,9 @@ class _ScopedExt(MapperExtension):
session = kwargs.pop('_sa_session', self.context.registry())
if not isinstance(oldinit, types.MethodType):
for key, value in kwargs.items():
- #if validate:
- # if not self.mapper.get_property(key, resolve_synonyms=False, raiseerr=False):
- # raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
+ if self.validate:
+ if not mapper.get_property(key, resolve_synonyms=False, raiseerr=False):
+ raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
setattr(instance, key, value)
session._save_impl(instance, entity_name=kwargs.pop('_sa_entity_name', None))
return EXT_CONTINUE
diff --git a/test/orm/session.py b/test/orm/session.py
index d3eed5c57..0b56b84d4 100644
--- a/test/orm/session.py
+++ b/test/orm/session.py
@@ -4,7 +4,6 @@ from sqlalchemy.orm import *
from testlib import *
from testlib.tables import *
import testlib.tables as tables
-from sqlalchemy.orm.session import Session
class SessionTest(AssertMixin):
def setUpAll(self):
@@ -98,7 +97,7 @@ class SessionTest(AssertMixin):
conn1 = testbase.db.connect()
conn2 = testbase.db.connect()
- sess = Session(bind=conn1, transactional=True, autoflush=True)
+ sess = create_session(bind=conn1, transactional=True, autoflush=True)
u = User()
u.user_name='ed'
sess.save(u)
@@ -116,7 +115,7 @@ class SessionTest(AssertMixin):
mapper(User, users)
try:
- sess = Session(transactional=True, autoflush=True)
+ sess = create_session(transactional=True, autoflush=True)
u = User()
u.user_name='ed'
sess.save(u)
@@ -137,7 +136,7 @@ class SessionTest(AssertMixin):
conn1 = testbase.db.connect()
conn2 = testbase.db.connect()
- sess = Session(bind=conn1, transactional=True, autoflush=True)
+ sess = create_session(bind=conn1, transactional=True, autoflush=True)
u = User()
u.user_name='ed'
sess.save(u)
@@ -153,7 +152,7 @@ class SessionTest(AssertMixin):
'addresses':relation(Address)
})
- sess = Session(transactional=True, autoflush=True)
+ sess = create_session(transactional=True, autoflush=True)
u = sess.query(User).get(8)
newad = Address()
newad.email_address == 'something new'
@@ -173,7 +172,7 @@ class SessionTest(AssertMixin):
mapper(User, users)
conn = testbase.db.connect()
trans = conn.begin()
- sess = Session(conn, transactional=True, autoflush=True)
+ sess = create_session(bind=conn, transactional=True, autoflush=True)
sess.begin()
u = User()
sess.save(u)
@@ -189,7 +188,7 @@ class SessionTest(AssertMixin):
try:
conn = testbase.db.connect()
trans = conn.begin()
- sess = Session(conn, transactional=True, autoflush=True)
+ sess = create_session(bind=conn, transactional=True, autoflush=True)
u1 = User()
sess.save(u1)
sess.flush()
@@ -217,7 +216,7 @@ class SessionTest(AssertMixin):
mapper(Address, addresses)
engine2 = create_engine(testbase.db.url)
- sess = Session(transactional=False, autoflush=False, twophase=True)
+ sess = create_session(transactional=False, autoflush=False, twophase=True)
sess.bind_mapper(User, testbase.db)
sess.bind_mapper(Address, engine2)
sess.begin()
@@ -234,7 +233,7 @@ class SessionTest(AssertMixin):
def test_joined_transaction(self):
class User(object):pass
mapper(User, users)
- sess = Session(transactional=True, autoflush=True)
+ sess = create_session(transactional=True, autoflush=True)
sess.begin()
u = User()
sess.save(u)
@@ -440,6 +439,75 @@ class SessionTest(AssertMixin):
key = s.identity_key(User, row=row, entity_name="en")
self._assert_key(key, (User, (1,), "en"))
+class ScopedSessionTest(PersistTest):
+ def setUpAll(self):
+ global metadata, table, table2
+ metadata = MetaData(testbase.db)
+ table = Table('sometable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(30)))
+ table2 = Table('someothertable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('someid', None, ForeignKey('sometable.id'))
+ )
+ metadata.create_all()
+
+ def setUp(self):
+ global SomeObject, SomeOtherObject
+ class SomeObject(object):pass
+ class SomeOtherObject(object):pass
+ global Session
+
+ Session = scoped_session(create_session)
+ Session.mapper(SomeObject, table, properties={
+ 'options':relation(SomeOtherObject)
+ })
+ Session.mapper(SomeOtherObject, table2)
+
+ s = SomeObject()
+ s.id = 1
+ s.data = 'hello'
+ sso = SomeOtherObject()
+ s.options.append(sso)
+ Session.flush()
+ Session.clear()
+
+ def tearDownAll(self):
+ metadata.drop_all()
+
+ def tearDown(self):
+ for table in metadata.table_iterator(reverse=True):
+ table.delete().execute()
+ clear_mappers()
+
+ def test_query(self):
+ sso = SomeOtherObject.query().first()
+ assert SomeObject.query.filter_by(id=1).one().options[0].id == sso.id
+
+ def test_validating_constructor(self):
+ s2 = SomeObject(someid=12)
+ s3 = SomeOtherObject(someid=123, bogus=345)
+
+ class ValidatedOtherObject(object):pass
+ Session.mapper(ValidatedOtherObject, table2, validate=True)
+
+ v1 = ValidatedOtherObject(someid=12)
+ try:
+ v2 = ValidatedOtherObject(someid=12, bogus=345)
+ assert False
+ except exceptions.ArgumentError:
+ pass
+
+ def test_dont_clobber_methods(self):
+ class MyClass(object):
+ def expunge(self):
+ return "an expunge !"
+
+ Session.mapper(MyClass, table2)
+
+ assert MyClass().expunge() == "an expunge !"
+
+
if __name__ == "__main__":
testbase.main()
diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py
index f28e428ed..f065267f7 100644
--- a/test/orm/unitofwork.py
+++ b/test/orm/unitofwork.py
@@ -12,8 +12,9 @@ from testlib import tables
class UnitOfWorkTest(AssertMixin):
def setUpAll(self):
- global Session
- Session = scoped_session(sessionmaker(autoflush=True, transactional=True), enhance_classes=True)
+ global Session, mapper
+ Session = scoped_session(sessionmaker(autoflush=True, transactional=True))
+ mapper = Session.mapper
def tearDownAll(self):
global_extensions[:] = []
def tearDown(self):