summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/mods
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-05-25 14:20:23 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-05-25 14:20:23 +0000
commitbb79e2e871d0a4585164c1a6ed626d96d0231975 (patch)
tree6d457ba6c36c408b45db24ec3c29e147fe7504ff /lib/sqlalchemy/mods
parent4fc3a0648699c2b441251ba4e1d37a9107bd1986 (diff)
downloadsqlalchemy-bb79e2e871d0a4585164c1a6ed626d96d0231975.tar.gz
merged 0.2 branch into trunk; 0.1 now in sqlalchemy/branches/rel_0_1
Diffstat (limited to 'lib/sqlalchemy/mods')
-rw-r--r--lib/sqlalchemy/mods/__init__.py7
-rw-r--r--lib/sqlalchemy/mods/legacy_session.py139
-rw-r--r--lib/sqlalchemy/mods/selectresults.py87
-rw-r--r--lib/sqlalchemy/mods/threadlocal.py46
4 files changed, 189 insertions, 90 deletions
diff --git a/lib/sqlalchemy/mods/__init__.py b/lib/sqlalchemy/mods/__init__.py
index 328df3c56..e69de29bb 100644
--- a/lib/sqlalchemy/mods/__init__.py
+++ b/lib/sqlalchemy/mods/__init__.py
@@ -1,7 +0,0 @@
-def install_mods(*mods):
- for mod in mods:
- if isinstance(mod, str):
- _mod = getattr(__import__('sqlalchemy.mods.%s' % mod).mods, mod)
- _mod.install_plugin()
- else:
- mod.install_plugin() \ No newline at end of file
diff --git a/lib/sqlalchemy/mods/legacy_session.py b/lib/sqlalchemy/mods/legacy_session.py
new file mode 100644
index 000000000..7dbeda924
--- /dev/null
+++ b/lib/sqlalchemy/mods/legacy_session.py
@@ -0,0 +1,139 @@
+
+import sqlalchemy.orm.objectstore as objectstore
+import sqlalchemy.orm.unitofwork as unitofwork
+import sqlalchemy.util as util
+import sqlalchemy
+
+import sqlalchemy.mods.threadlocal
+
+class LegacySession(objectstore.Session):
+ def __init__(self, nest_on=None, hash_key=None, **kwargs):
+ super(LegacySession, self).__init__(**kwargs)
+ self.parent_uow = None
+ self.begin_count = 0
+ self.nest_on = util.to_list(nest_on)
+ self.__pushed_count = 0
+ def was_pushed(self):
+ if self.nest_on is None:
+ return
+ self.__pushed_count += 1
+ if self.__pushed_count == 1:
+ for n in self.nest_on:
+ n.push_session()
+ def was_popped(self):
+ if self.nest_on is None or self.__pushed_count == 0:
+ return
+ self.__pushed_count -= 1
+ if self.__pushed_count == 0:
+ for n in self.nest_on:
+ n.pop_session()
+ class SessionTrans(object):
+ """returned by Session.begin(), denotes a transactionalized UnitOfWork instance.
+ call commit() on this to commit the transaction."""
+ def __init__(self, parent, uow, isactive):
+ self.__parent = parent
+ self.__isactive = isactive
+ self.__uow = uow
+ isactive = property(lambda s:s.__isactive, doc="True if this SessionTrans is the 'active' transaction marker, else its a no-op.")
+ parent = property(lambda s:s.__parent, doc="returns the parent Session of this SessionTrans object.")
+ uow = property(lambda s:s.__uow, doc="returns the parent UnitOfWork corresponding to this transaction.")
+ def begin(self):
+ """calls begin() on the underlying Session object, returning a new no-op SessionTrans object."""
+ if self.parent.uow is not self.uow:
+ raise InvalidRequestError("This SessionTrans is no longer valid")
+ return self.parent.begin()
+ def commit(self):
+ """commits the transaction noted by this SessionTrans object."""
+ self.__parent._trans_commit(self)
+ self.__isactive = False
+ def rollback(self):
+ """rolls back the current UnitOfWork transaction, in the case that begin()
+ has been called. The changes logged since the begin() call are discarded."""
+ self.__parent._trans_rollback(self)
+ self.__isactive = False
+ def begin(self):
+ """begins a new UnitOfWork transaction and returns a tranasaction-holding
+ object. commit() or rollback() should be called on the returned object.
+ commit() on the Session will do nothing while a transaction is pending, and further
+ calls to begin() will return no-op transactional objects."""
+ if self.parent_uow is not None:
+ return LegacySession.SessionTrans(self, self.uow, False)
+ self.parent_uow = self.uow
+ self.uow = unitofwork.UnitOfWork(identity_map = self.uow.identity_map)
+ return LegacySession.SessionTrans(self, self.uow, True)
+ def commit(self, *objects):
+ """commits the current UnitOfWork transaction. called with
+ no arguments, this is only used
+ for "implicit" transactions when there was no begin().
+ if individual objects are submitted, then only those objects are committed, and the
+ begin/commit cycle is not affected."""
+ # if an object list is given, commit just those but dont
+ # change begin/commit status
+ if len(objects):
+ self._commit_uow(*objects)
+ self.uow.flush(self, *objects)
+ return
+ if self.parent_uow is None:
+ self._commit_uow()
+ def _trans_commit(self, trans):
+ if trans.uow is self.uow and trans.isactive:
+ try:
+ self._commit_uow()
+ finally:
+ self.uow = self.parent_uow
+ self.parent_uow = None
+ def _trans_rollback(self, trans):
+ if trans.uow is self.uow:
+ self.uow = self.parent_uow
+ self.parent_uow = None
+ def _commit_uow(self, *obj):
+ self.was_pushed()
+ try:
+ self.uow.flush(self, *obj)
+ finally:
+ self.was_popped()
+
+def begin():
+ """deprecated. use s = Session(new_imap=False)."""
+ return objectstore.get_session().begin()
+
+def commit(*obj):
+ """deprecated; use flush(*obj)"""
+ objectstore.get_session().flush(*obj)
+
+def uow():
+ return objectstore.get_session()
+
+def push_session(sess):
+ old = get_session()
+ if getattr(sess, '_previous', None) is not None:
+ raise InvalidRequestError("Given Session is already pushed onto some thread's stack")
+ sess._previous = old
+ session_registry.set(sess)
+ sess.was_pushed()
+
+def pop_session():
+ sess = get_session()
+ old = sess._previous
+ sess._previous = None
+ session_registry.set(old)
+ sess.was_popped()
+ return old
+
+def using_session(sess, func):
+ push_session(sess)
+ try:
+ return func()
+ finally:
+ pop_session()
+
+def install_plugin():
+ objectstore.Session = LegacySession
+ objectstore.session_registry = util.ScopedRegistry(objectstore.Session)
+ objectstore.begin = begin
+ objectstore.commit = commit
+ objectstore.uow = uow
+ objectstore.push_session = push_session
+ objectstore.pop_session = pop_session
+ objectstore.using_session = using_session
+install_plugin()
diff --git a/lib/sqlalchemy/mods/selectresults.py b/lib/sqlalchemy/mods/selectresults.py
index bff436ace..51ed6e4a5 100644
--- a/lib/sqlalchemy/mods/selectresults.py
+++ b/lib/sqlalchemy/mods/selectresults.py
@@ -1,86 +1,7 @@
-import sqlalchemy.sql as sql
+from sqlalchemy.ext.selectresults import *
+from sqlalchemy.orm.mapper import global_extensions
-import sqlalchemy.mapping as mapping
def install_plugin():
- mapping.global_extensions.append(SelectResultsExt)
-
-class SelectResultsExt(mapping.MapperExtension):
- def select_by(self, query, *args, **params):
- return SelectResults(query, query._by_clause(*args, **params))
- def select(self, query, arg=None, **kwargs):
- if arg is not None and isinstance(arg, sql.Selectable):
- return mapping.EXT_PASS
- else:
- return SelectResults(query, arg, ops=kwargs)
-
-MapperExtension = SelectResultsExt
-
-class SelectResults(object):
- def __init__(self, query, clause=None, ops={}):
- self._query = query
- self._clause = clause
- self._ops = {}
- self._ops.update(ops)
-
- def count(self):
- return self._query.count(self._clause)
-
- def min(self, col):
- return sql.select([sql.func.min(col)], self._clause, **self._ops).scalar()
-
- def max(self, col):
- return sql.select([sql.func.max(col)], self._clause, **self._ops).scalar()
-
- def sum(self, col):
- return sql.select([sql.func.sum(col)], self._clause, **self._ops).scalar()
-
- def avg(self, col):
- return sql.select([sql.func.avg(col)], self._clause, **self._ops).scalar()
-
- def clone(self):
- return SelectResults(self._query, self._clause, self._ops.copy())
-
- def filter(self, clause):
- new = self.clone()
- new._clause = sql.and_(self._clause, clause)
- return new
-
- def order_by(self, order_by):
- new = self.clone()
- new._ops['order_by'] = order_by
- return new
-
- def limit(self, limit):
- return self[:limit]
-
- def offset(self, offset):
- return self[offset:]
-
- def list(self):
- return list(self)
-
- def __getitem__(self, item):
- if isinstance(item, slice):
- start = item.start
- stop = item.stop
- if (isinstance(start, int) and start < 0) or \
- (isinstance(stop, int) and stop < 0):
- return list(self)[item]
- else:
- res = self.clone()
- if start is not None and stop is not None:
- res._ops.update(dict(offset=self._ops.get('offset', 0)+start, limit=stop-start))
- elif start is None and stop is not None:
- res._ops.update(dict(limit=stop))
- elif start is not None and stop is None:
- res._ops.update(dict(offset=self._ops.get('offset', 0)+start))
- if item.step is not None:
- return list(res)[None:None:item.step]
- else:
- return res
- else:
- return list(self[item:item+1])[0]
-
- def __iter__(self):
- return iter(self._query.select_whereclause(self._clause, **self._ops))
+ global_extensions.append(SelectResultsExt)
+install_plugin()
diff --git a/lib/sqlalchemy/mods/threadlocal.py b/lib/sqlalchemy/mods/threadlocal.py
new file mode 100644
index 000000000..b67329612
--- /dev/null
+++ b/lib/sqlalchemy/mods/threadlocal.py
@@ -0,0 +1,46 @@
+from sqlalchemy import util, engine, mapper
+from sqlalchemy.ext.sessioncontext import SessionContext
+import sqlalchemy.ext.assignmapper as assignmapper
+from sqlalchemy.orm.mapper import global_extensions
+from sqlalchemy.orm.session import Session
+import sqlalchemy
+import sys, types
+
+"""this plugin installs thread-local behavior at the Engine and Session level.
+
+The default Engine strategy will be "threadlocal", producing TLocalEngine instances for create_engine by default.
+With this engine, connect() method will return the same connection on the same thread, if it is already checked out
+from the pool. this greatly helps functions that call multiple statements to be able to easily use just one connection
+without explicit "close" statements on result handles.
+
+on the Session side, module-level methods will be installed within the objectstore module, such as flush(), delete(), etc.
+which call this method on the thread-local session.
+
+Note: this mod creates a global, thread-local session context named sqlalchemy.objectstore. All mappers created
+while this mod is installed will reference this global context when creating new mapped object instances.
+"""
+
+class Objectstore(SessionContext):
+ def __getattr__(self, key):
+ return getattr(self.current, key)
+ def get_session(self):
+ return self.current
+
+def assign_mapper(class_, *args, **kwargs):
+ assignmapper.assign_mapper(objectstore, class_, *args, **kwargs)
+
+def _mapper_extension():
+ return SessionContext._get_mapper_extension(objectstore)
+
+objectstore = Objectstore(Session)
+def install_plugin():
+ sqlalchemy.objectstore = objectstore
+ global_extensions.append(_mapper_extension)
+ engine.default_strategy = 'threadlocal'
+ sqlalchemy.assign_mapper = assign_mapper
+
+def uninstall_plugin():
+ engine.default_strategy = 'plain'
+ global_extensions.remove(_mapper_extension)
+
+install_plugin()