summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/scoping.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2008-05-09 16:34:10 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2008-05-09 16:34:10 +0000
commit4a6afd469fad170868554bf28578849bf3dfd5dd (patch)
treeb396edc33d567ae19dd244e87137296450467725 /lib/sqlalchemy/orm/scoping.py
parent46b7c9dc57a38d5b9e44a4723dad2ad8ec57baca (diff)
downloadsqlalchemy-4a6afd469fad170868554bf28578849bf3dfd5dd.tar.gz
r4695 merged to trunk; trunk now becomes 0.5.
0.4 development continues at /sqlalchemy/branches/rel_0_4
Diffstat (limited to 'lib/sqlalchemy/orm/scoping.py')
-rw-r--r--lib/sqlalchemy/orm/scoping.py56
1 files changed, 39 insertions, 17 deletions
diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py
index 479b2f737..c1d3db9f1 100644
--- a/lib/sqlalchemy/orm/scoping.py
+++ b/lib/sqlalchemy/orm/scoping.py
@@ -1,8 +1,17 @@
+# scoping.py
+# Copyright (C) the SQLAlchemy authors and contributors
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+import inspect
+import types
+
+import sqlalchemy.exceptions as sa_exc
from sqlalchemy.util import ScopedRegistry, to_list, get_cls_kwargs
-from sqlalchemy.orm import MapperExtension, EXT_CONTINUE, object_session, class_mapper
+from sqlalchemy.orm import MapperExtension, EXT_CONTINUE, object_session, \
+ class_mapper
from sqlalchemy.orm.session import Session
-from sqlalchemy import exceptions
-import types
__all__ = ['ScopedSession']
@@ -33,7 +42,7 @@ class ScopedSession(object):
scope = kwargs.pop('scope', False)
if scope is not None:
if self.registry.has():
- raise exceptions.InvalidRequestError("Scoped session is already present; no new arguments may be specified.")
+ raise sa_exc.InvalidRequestError("Scoped session is already present; no new arguments may be specified.")
else:
sess = self.session_factory(**kwargs)
self.registry.set(sess)
@@ -53,7 +62,7 @@ class ScopedSession(object):
from sqlalchemy.orm import mapper
- extension_args = dict([(arg,kwargs.pop(arg))
+ extension_args = dict([(arg, kwargs.pop(arg))
for arg in get_cls_kwargs(_ScopedExt)
if arg in kwargs])
@@ -110,10 +119,10 @@ for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map'):
setattr(ScopedSession, prop, makeprop(prop))
def clslevel(name):
- def do(cls, *args,**kwargs):
+ def do(cls, *args, **kwargs):
return getattr(Session, name)(*args, **kwargs)
return classmethod(do)
-for prop in ('close_all','object_session', 'identity_key'):
+for prop in ('close_all', 'object_session', 'identity_key'):
setattr(ScopedSession, prop, clslevel(prop))
class _ScopedExt(MapperExtension):
@@ -121,6 +130,7 @@ class _ScopedExt(MapperExtension):
self.context = context
self.validate = validate
self.save_on_init = save_on_init
+ self.set_kwargs_on_init = None
def validating(self):
return _ScopedExt(self.context, validate=True)
@@ -128,37 +138,49 @@ class _ScopedExt(MapperExtension):
def configure(self, **kwargs):
return _ScopedExt(self.context, **kwargs)
- def get_session(self):
- return self.context.registry()
-
def instrument_class(self, mapper, class_):
class query(object):
def __getattr__(s, key):
return getattr(self.context.registry().query(class_), key)
def __call__(s):
return self.context.registry().query(class_)
-
+ def __get__(self, instance, cls):
+ return self
+
if not 'query' in class_.__dict__:
class_.query = query()
-
+
+ if self.set_kwargs_on_init is None:
+ self.set_kwargs_on_init = class_.__init__ is object.__init__
+ if self.set_kwargs_on_init:
+ def __init__(self, **kwargs):
+ pass
+ class_.__init__ = __init__
+
def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
if self.save_on_init:
entity_name = kwargs.pop('_sa_entity_name', None)
session = kwargs.pop('_sa_session', None)
- if not isinstance(oldinit, types.MethodType):
+
+ if self.set_kwargs_on_init:
for key, value in kwargs.items():
if self.validate:
- if not mapper.get_property(key, resolve_synonyms=False, raiseerr=False):
- raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
+ if not mapper.get_property(key, resolve_synonyms=False,
+ raiseerr=False):
+ raise sa_exc.ArgumentError(
+ "Invalid __init__ argument: '%s'" % key)
setattr(instance, key, value)
kwargs.clear()
+
if self.save_on_init:
session = session or self.context.registry()
- session._save_impl(instance, entity_name=entity_name)
+ session._save_without_cascade(instance, entity_name=entity_name)
return EXT_CONTINUE
def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
- object_session(instance).expunge(instance)
+ sess = object_session(instance)
+ if sess:
+ sess.expunge(instance)
return EXT_CONTINUE
def dispose_class(self, mapper, class_):