summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-02-14 00:30:30 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-02-14 00:30:30 +0000
commit791e2f7f7da88bd13a1002540755f920e6703711 (patch)
treed4adbf5923feb0b776705006d257956c274a0f27
parent8a3c00bc5a705132f8c03263e330acbf373a73bf (diff)
downloadsqlalchemy-791e2f7f7da88bd13a1002540755f920e6703711.tar.gz
latest reorgnanization of the objectstore, the Session is a simpler object that just maintains begin/commit state
-rw-r--r--lib/sqlalchemy/mapping/mapper.py10
-rw-r--r--lib/sqlalchemy/mapping/objectstore.py174
-rw-r--r--lib/sqlalchemy/util.py2
-rw-r--r--test/objectstore.py41
-rw-r--r--test/tables.py7
5 files changed, 143 insertions, 91 deletions
diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/mapping/mapper.py
index 3a900569c..24bf11fd8 100644
--- a/lib/sqlalchemy/mapping/mapper.py
+++ b/lib/sqlalchemy/mapping/mapper.py
@@ -204,7 +204,7 @@ class Mapper(object):
oldinit = self.class_.__init__
def init(self, *args, **kwargs):
nohist = kwargs.pop('_mapper_nohistory', False)
- session = kwargs.pop('_sa_session', objectstore.session())
+ session = kwargs.pop('_sa_session', objectstore.get_session())
if oldinit is not None:
try:
oldinit(self, *args, **kwargs)
@@ -244,7 +244,7 @@ class Mapper(object):
# store new stuff in the identity map
for value in imap.values():
- objectstore.session().register_clean(value)
+ objectstore.get_session().register_clean(value)
if len(mappers):
return [result] + otherresults
@@ -261,7 +261,7 @@ class Mapper(object):
def _get(self, key, ident=None):
try:
- return objectstore.session()._get(key)
+ return objectstore.get_session()._get(key)
except KeyError:
if ident is None:
ident = key[2]
@@ -688,8 +688,8 @@ class Mapper(object):
# including modifying any of its related items lists, as its already
# been exposed to being modified by the application.
identitykey = self._identity_key(row)
- if objectstore.session().has_key(identitykey):
- instance = objectstore.session()._get(identitykey)
+ if objectstore.get_session().has_key(identitykey):
+ instance = objectstore.get_session()._get(identitykey)
isnew = False
if populate_existing:
diff --git a/lib/sqlalchemy/mapping/objectstore.py b/lib/sqlalchemy/mapping/objectstore.py
index 078c5a179..c1549ffb7 100644
--- a/lib/sqlalchemy/mapping/objectstore.py
+++ b/lib/sqlalchemy/mapping/objectstore.py
@@ -32,31 +32,25 @@ class Session(object):
The registry is capable of maintaining object instances on a thread-local,
per-application, or custom user-defined basis."""
- def __init__(self, scope="application", getter=None, hash_key=None, keyfunc=None):
+ def __init__(self, nest_transactions=False, hash_key=None):
"""Initialize the objectstore with a UnitOfWork registry. If called
with no arguments, creates a single UnitOfWork for all operations.
- scope - "application" or "thread", the two default scopes
- getter - a callable that takes this Session as an argument and returns a
- new UnitOfWork.
+ nest_transactions - indicates begin/commit statements can be executed in a
+ "nested", defaults to False which indicates "only commit on the outermost begin/commit"
hash_key - the hash_key used to identify objects against this session, which
defaults to the id of the Session instance.
- keyfunc - allows custom scopes by providing a callable to return the "key"
- identifying the desired UnitOfWork.
"""
- if keyfunc is None:
- if scope=="thread":
- keyfunc = thread.get_ident
- elif scope=="application":
- keyfunc = lambda: True
- if getter is None:
- def createfunc():
- return UnitOfWork(self)
+ self.uow = UnitOfWork()
+ self.parent_uow = None
+ self.begin_count = 0
+ self.nest_transactions = nest_transactions
+ if hash_key is None:
+ self.hash_key = id(self)
else:
- createfunc = lambda: getter(self)
- self.registry = util.ScopedRegistry(createfunc, keyfunc)
- self._hash_key = hash_key
-
+ self.hash_key = hash_key
+ _sessions[self.hash_key] = self
+
def get_id_key(ident, class_, table):
"""returns an identity-map key for use in storing/retrieving an item from the identity
map, given a tuple of the object's primary key values.
@@ -92,29 +86,69 @@ class Session(object):
return (class_, table.hash_key(), tuple([row[column] for column in primary_key]))
get_row_key = staticmethod(get_row_key)
- def _set_uow(self, uow):
- self.registry.set(uow)
- uow = property(lambda s:s.registry(), _set_uow, doc="Returns a scope-specific UnitOfWork object for this session.")
-
- hash_key = property(lambda s:s._hash_key or id(s))
+ def begin(self):
+ """begins a new UnitOfWork transaction. the next commit will affect only
+ objects that are created, modified, or deleted following the begin statement."""
+ self.begin_count += 1
+ if self.parent_uow is not None:
+ return
+ self.parent_uow = self.uow
+ self.uow = UnitOfWork(identity_map = self.uow.identity_map)
+
+ def commit(self, *objects):
+ """commits the current UnitOfWork transaction. if a transaction was begun
+ via begin(), commits only those objects that were created, modified, or deleted
+ since that begin statement. otherwise commits all objects that have been
+ changed.
+ if individual objects are submitted, then only those objects are committed, and the
+ begin/commit cycle is not affected."""
+ # if an object list is given, commit just those but dont
+ # change begin/commit status
+ if len(objects):
+ self.uow.commit(*objects)
+ return
+ if self.parent_uow is not None:
+ self.begin_count -= 1
+ if self.begin_count > 0:
+ return
+ self.uow.commit()
+ if self.parent_uow is not None:
+ self.uow = self.parent_uow
+ self.parent_uow = None
+
+ def rollback(self):
+ """rolls back the current UnitOfWork transaction, in the case that begin()
+ has been called. The changes logged since the begin() call are discarded."""
+ if self.parent_uow is None:
+ raise "UOW transaction is not begun"
+ self.uow = self.parent_uow
+ self.parent_uow = None
+ self.begin_count = 0
+
+ def register_clean(self, obj):
+ self._bind_to(obj)
+ self.uow.register_clean(obj)
+
+ def register_new(self, obj):
+ self._bind_to(obj)
+ self.uow.register_new(obj)
- def bind_to(self, obj):
+ def _bind_to(self, obj):
"""given an object, binds it to this session. changes on the object will affect
the currently scoped UnitOfWork maintained by this session."""
obj._sa_session_id = self.hash_key
def __getattr__(self, key):
"""proxy other methods to our underlying UnitOfWork"""
- return getattr(self.registry(), key)
+ return getattr(self.uow, key)
def clear(self):
- self.registry.clear()
+ self.uow = UnitOfWork()
- def delete(*obj):
+ def delete(self, *obj):
"""registers the given objects as to be deleted upon the next commit"""
- u = registry()
for o in obj:
- u.register_deleted(o)
+ self.uow.register_deleted(o)
def import_instance(self, instance):
"""places the given instance in the current thread's unit of work context,
@@ -130,7 +164,7 @@ class Session(object):
key = getattr(instance, '_instance_key', None)
mapper = object_mapper(instance)
key = (key[0], mapper.table.hash_key(), key[2])
- u = self.registry()
+ u = self.uow
if key is not None:
if u.identity_map.has_key(key):
return u.identity_map[key]
@@ -141,7 +175,6 @@ class Session(object):
else:
u.register_new(instance)
return instance
-
def get_id_key(ident, class_, table):
return Session.get_id_key(ident, class_, table)
@@ -152,53 +185,54 @@ def get_row_key(row, class_, table, primary_key):
def begin():
"""begins a new UnitOfWork transaction. the next commit will affect only
objects that are created, modified, or deleted following the begin statement."""
- session().begin()
+ get_session().begin()
def commit(*obj):
"""commits the current UnitOfWork transaction. if a transaction was begun
via begin(), commits only those objects that were created, modified, or deleted
since that begin statement. otherwise commits all objects that have been
- changed."""
- session().commit(*obj)
+ changed.
+
+ if individual objects are submitted, then only those objects are committed, and the
+ begin/commit cycle is not affected."""
+ get_session().commit(*obj)
def clear():
"""removes all current UnitOfWorks and IdentityMaps for this thread and
establishes a new one. It is probably a good idea to discard all
current mapped object instances, as they are no longer in the Identity Map."""
- session().clear()
+ get_session().clear()
def delete(*obj):
"""registers the given objects as to be deleted upon the next commit"""
- s = session()
- for o in obj:
- s.register_deleted(o)
+ s = get_session().delete(*obj)
def has_key(key):
"""returns True if the current thread-local IdentityMap contains the given instance key"""
- return session().has_key(key)
+ return get_session().has_key(key)
def has_instance(instance):
"""returns True if the current thread-local IdentityMap contains the given instance"""
- return session().has_instance(instance)
+ return get_session().has_instance(instance)
def is_dirty(obj):
"""returns True if the given object is in the current UnitOfWork's new or dirty list,
or if its a modified list attribute on an object."""
- return session().is_dirty(obj)
+ return get_session().is_dirty(obj)
def instance_key(instance):
"""returns the IdentityMap key for the given instance"""
- return session().instance_key(instance)
+ return get_session().instance_key(instance)
def import_instance(instance):
- return session().import_instance(instance)
+ return get_session().import_instance(instance)
class UOWListElement(attributes.ListElement):
def __init__(self, obj, key, data=None, deleteremoved=False, **kwargs):
attributes.ListElement.__init__(self, obj, key, data=data, **kwargs)
self.deleteremoved = deleteremoved
def list_value_changed(self, obj, key, item, listval, isdelete):
- sess = session(obj)
+ sess = get_session(obj)
if not isdelete and sess.deleted.contains(item):
raise "re-inserting a deleted value into a list"
sess.modified_lists.append(self)
@@ -216,23 +250,17 @@ class UOWAttributeManager(attributes.AttributeManager):
def value_changed(self, obj, key, value):
if hasattr(obj, '_instance_key'):
- session(obj).register_dirty(obj)
+ get_session(obj).register_dirty(obj)
else:
- session(obj).register_new(obj)
+ get_session(obj).register_new(obj)
def create_list(self, obj, key, list_, **kwargs):
return UOWListElement(obj, key, list_, **kwargs)
class UnitOfWork(object):
- def __init__(self, session, parent=None, is_begun=False):
- self.session = session
- self.is_begun = is_begun
- if is_begun:
- self.begin_count = 1
- else:
- self.begin_count = 0
- if parent is not None:
- self.identity_map = parent.identity_map
+ def __init__(self, identity_map=None):
+ if identity_map is not None:
+ self.identity_map = identity_map
else:
self.identity_map = weakref.WeakValueDictionary()
@@ -241,7 +269,6 @@ class UnitOfWork(object):
self.dirty = util.HashSet()
self.modified_lists = util.HashSet()
self.deleted = util.HashSet()
- self.parent = parent
def get(self, class_, *id):
"""given a class and a list of primary key values in their table-order, locates the mapper
@@ -305,12 +332,10 @@ class UnitOfWork(object):
if not hasattr(obj, '_instance_key'):
mapper = object_mapper(obj)
obj._instance_key = mapper.instance_key(obj)
- self.session.bind_to(obj)
self._put(obj._instance_key, obj)
self.attributes.commit(obj)
def register_new(self, obj):
- self.session.bind_to(obj)
self.new.append(obj)
def register_dirty(self, obj):
@@ -335,19 +360,7 @@ class UnitOfWork(object):
except KeyError:
pass
- # TODO: tie in register_new/register_dirty with table transaction begins ?
- def begin(self):
- if self.is_begun:
- self.begin_count += 1
- return
- u = UnitOfWork(self.session, parent=self, is_begun=True)
- self.session.registry.set(u)
-
def commit(self, *objects):
- if self.is_begun:
- self.begin_count -= 1
- if self.begin_count > 0:
- return
commit_context = UOWTransaction(self)
if len(objects):
@@ -394,16 +407,12 @@ class UnitOfWork(object):
except:
for e in engines:
e.rollback()
- if self.parent:
- self.session.registry.set(self.parent)
raise
for e in engines:
e.commit()
commit_context.post_exec()
- if self.parent:
- self.session.registry.set(self.parent)
def rollback_object(self, obj):
"""'rolls back' the attributes that have been changed on an object instance."""
@@ -975,13 +984,11 @@ def object_mapper(obj):
global_attributes = UOWAttributeManager()
-global_session = Session(scope="thread", hash_key='thread')
-uow = global_session.registry # Note: this is not a UnitOfWork, it is a ScopedRegistry that manages UnitOfWork objects
-_sessions = weakref.WeakValueDictionary()
-_sessions[global_session.hash_key] = global_session
+session_registry = util.ScopedRegistry(Session) # Default session registry
+_sessions = weakref.WeakValueDictionary() # all referenced sessions (including user-created)
-def session(obj=None):
+def get_session(obj=None):
# object-specific session ?
if obj is not None:
# does it have a hash key ?
@@ -993,12 +1000,9 @@ def session(obj=None):
except KeyError:
raise "Session '%s' referenced by object '%s' no longer exists" % (hashkey, repr(obj))
- try:
- # have a thread-locally defined session (via using_session) ?
- return _sessions[thread.get_ident()]
- except KeyError:
- # nope, return the regular session
- return global_session
+ return session_registry()
+
+uow = get_session # deprecated
def push_session(sess):
old = _sessions.get(thread.get_ident(), None)
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py
index 217210646..633091dd3 100644
--- a/lib/sqlalchemy/util.py
+++ b/lib/sqlalchemy/util.py
@@ -377,7 +377,7 @@ class ScopedRegistry(object):
def __init__(self, createfunc, scopefunc=None):
self.createfunc = createfunc
if scopefunc is None:
- scopefunc = thread.get_ident
+ self.scopefunc = thread.get_ident
else:
self.scopefunc = scopefunc
self.registry = {}
diff --git a/test/objectstore.py b/test/objectstore.py
index 855160392..bc90ec538 100644
--- a/test/objectstore.py
+++ b/test/objectstore.py
@@ -72,6 +72,47 @@ class HistoryTest(AssertMixin):
u = m.select()[0]
print u.addresses[0].user
+class SessionTest(AssertMixin):
+ def setUpAll(self):
+ db.echo = False
+ users.create()
+ tables.user_data()
+ db.echo = testbase.echo
+ def tearDownAll(self):
+ db.echo = False
+ users.drop()
+ db.echo = testbase.echo
+ def setUp(self):
+ objectstore.get_session().clear()
+ clear_mappers()
+
+ def test_nested_begin_commit(self):
+ """test nested session.begin/commit"""
+ class User(object):pass
+ m = mapper(User, users)
+ def name_of(id):
+ return users.select(users.c.user_id == id).execute().fetchone().user_name
+ name1 = "Oliver Twist"
+ name2 = 'Mr. Bumble'
+ self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
+ self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
+ s = objectstore.get_session()
+ s.begin()
+ s.begin()
+ m.get(7).user_name = name1
+ s.begin()
+ m.get(8).user_name = name2
+ s.commit()
+ self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
+ self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
+ s.commit()
+ self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
+ self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
+ s.commit()
+ self.assert_(name_of(7) == name1, msg="user_name should be %s" % name1)
+ self.assert_(name_of(8) == name2, msg="user_name should be %s" % name2)
+
+
class PKTest(AssertMixin):
def setUpAll(self):
db.echo = False
diff --git a/test/tables.py b/test/tables.py
index 00f946af4..fecd86bc4 100644
--- a/test/tables.py
+++ b/test/tables.py
@@ -71,6 +71,13 @@ def delete():
users.delete().execute()
db.commit()
+def user_data():
+ users.insert().execute(
+ dict(user_id = 7, user_name = 'jack'),
+ dict(user_id = 8, user_name = 'ed'),
+ dict(user_id = 9, user_name = 'fred')
+ )
+
def data():
delete()