summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/scoping.py8
-rw-r--r--lib/sqlalchemy/util.py11
2 files changed, 8 insertions, 11 deletions
diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py
index f00b30849..a8ed9c910 100644
--- a/lib/sqlalchemy/orm/scoping.py
+++ b/lib/sqlalchemy/orm/scoping.py
@@ -5,7 +5,8 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import sqlalchemy.exceptions as sa_exc
-from sqlalchemy.util import ScopedRegistry, to_list, get_cls_kwargs, deprecated
+from sqlalchemy.util import ScopedRegistry, ThreadLocalRegistry, \
+ to_list, get_cls_kwargs, deprecated
from sqlalchemy.orm import (
EXT_CONTINUE, MapperExtension, class_mapper, object_session
)
@@ -29,7 +30,10 @@ class ScopedSession(object):
def __init__(self, session_factory, scopefunc=None):
self.session_factory = session_factory
- self.registry = ScopedRegistry(session_factory, scopefunc)
+ if scopefunc:
+ self.registry = ScopedRegistry(session_factory, scopefunc)
+ else:
+ self.registry = ThreadLocalRegistry(session_factory)
self.extension = _ScopedExt(self)
def __call__(self, **kwargs):
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py
index 8f0b5583d..da426cbd8 100644
--- a/lib/sqlalchemy/util.py
+++ b/lib/sqlalchemy/util.py
@@ -1180,14 +1180,7 @@ class ScopedRegistry(object):
scopefunc
a callable that will return a key to store/retrieve an object.
- If None, ScopedRegistry uses a threading.local object instead.
-
"""
- def __new__(cls, createfunc, scopefunc=None):
- if not scopefunc:
- return object.__new__(_TLocalRegistry)
- else:
- return object.__new__(cls)
def __init__(self, createfunc, scopefunc):
self.createfunc = createfunc
@@ -1213,8 +1206,8 @@ class ScopedRegistry(object):
except KeyError:
pass
-class _TLocalRegistry(ScopedRegistry):
- def __init__(self, createfunc, scopefunc=None):
+class ThreadLocalRegistry(ScopedRegistry):
+ def __init__(self, createfunc):
self.createfunc = createfunc
self.registry = threading.local()