diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-05-09 16:34:10 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-05-09 16:34:10 +0000 |
| commit | 4a6afd469fad170868554bf28578849bf3dfd5dd (patch) | |
| tree | b396edc33d567ae19dd244e87137296450467725 /lib/sqlalchemy/orm/scoping.py | |
| parent | 46b7c9dc57a38d5b9e44a4723dad2ad8ec57baca (diff) | |
| download | sqlalchemy-4a6afd469fad170868554bf28578849bf3dfd5dd.tar.gz | |
r4695 merged to trunk; trunk now becomes 0.5.
0.4 development continues at /sqlalchemy/branches/rel_0_4
Diffstat (limited to 'lib/sqlalchemy/orm/scoping.py')
| -rw-r--r-- | lib/sqlalchemy/orm/scoping.py | 56 |
1 files changed, 39 insertions, 17 deletions
diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 479b2f737..c1d3db9f1 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -1,8 +1,17 @@ +# scoping.py +# Copyright (C) the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import inspect +import types + +import sqlalchemy.exceptions as sa_exc from sqlalchemy.util import ScopedRegistry, to_list, get_cls_kwargs -from sqlalchemy.orm import MapperExtension, EXT_CONTINUE, object_session, class_mapper +from sqlalchemy.orm import MapperExtension, EXT_CONTINUE, object_session, \ + class_mapper from sqlalchemy.orm.session import Session -from sqlalchemy import exceptions -import types __all__ = ['ScopedSession'] @@ -33,7 +42,7 @@ class ScopedSession(object): scope = kwargs.pop('scope', False) if scope is not None: if self.registry.has(): - raise exceptions.InvalidRequestError("Scoped session is already present; no new arguments may be specified.") + raise sa_exc.InvalidRequestError("Scoped session is already present; no new arguments may be specified.") else: sess = self.session_factory(**kwargs) self.registry.set(sess) @@ -53,7 +62,7 @@ class ScopedSession(object): from sqlalchemy.orm import mapper - extension_args = dict([(arg,kwargs.pop(arg)) + extension_args = dict([(arg, kwargs.pop(arg)) for arg in get_cls_kwargs(_ScopedExt) if arg in kwargs]) @@ -110,10 +119,10 @@ for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map'): setattr(ScopedSession, prop, makeprop(prop)) def clslevel(name): - def do(cls, *args,**kwargs): + def do(cls, *args, **kwargs): return getattr(Session, name)(*args, **kwargs) return classmethod(do) -for prop in ('close_all','object_session', 'identity_key'): +for prop in ('close_all', 'object_session', 'identity_key'): setattr(ScopedSession, prop, clslevel(prop)) class _ScopedExt(MapperExtension): @@ -121,6 +130,7 @@ class _ScopedExt(MapperExtension): self.context = context self.validate = validate self.save_on_init = save_on_init + self.set_kwargs_on_init = None def validating(self): return _ScopedExt(self.context, validate=True) @@ -128,37 +138,49 @@ class _ScopedExt(MapperExtension): def configure(self, **kwargs): return _ScopedExt(self.context, **kwargs) - def get_session(self): - return self.context.registry() - def instrument_class(self, mapper, class_): class query(object): def __getattr__(s, key): return getattr(self.context.registry().query(class_), key) def __call__(s): return self.context.registry().query(class_) - + def __get__(self, instance, cls): + return self + if not 'query' in class_.__dict__: class_.query = query() - + + if self.set_kwargs_on_init is None: + self.set_kwargs_on_init = class_.__init__ is object.__init__ + if self.set_kwargs_on_init: + def __init__(self, **kwargs): + pass + class_.__init__ = __init__ + def init_instance(self, mapper, class_, oldinit, instance, args, kwargs): if self.save_on_init: entity_name = kwargs.pop('_sa_entity_name', None) session = kwargs.pop('_sa_session', None) - if not isinstance(oldinit, types.MethodType): + + if self.set_kwargs_on_init: for key, value in kwargs.items(): if self.validate: - if not mapper.get_property(key, resolve_synonyms=False, raiseerr=False): - raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key) + if not mapper.get_property(key, resolve_synonyms=False, + raiseerr=False): + raise sa_exc.ArgumentError( + "Invalid __init__ argument: '%s'" % key) setattr(instance, key, value) kwargs.clear() + if self.save_on_init: session = session or self.context.registry() - session._save_impl(instance, entity_name=entity_name) + session._save_without_cascade(instance, entity_name=entity_name) return EXT_CONTINUE def init_failed(self, mapper, class_, oldinit, instance, args, kwargs): - object_session(instance).expunge(instance) + sess = object_session(instance) + if sess: + sess.expunge(instance) return EXT_CONTINUE def dispose_class(self, mapper, class_): |
