summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2008-05-09 16:34:10 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2008-05-09 16:34:10 +0000
commit4a6afd469fad170868554bf28578849bf3dfd5dd (patch)
treeb396edc33d567ae19dd244e87137296450467725 /lib/sqlalchemy/ext
parent46b7c9dc57a38d5b9e44a4723dad2ad8ec57baca (diff)
downloadsqlalchemy-4a6afd469fad170868554bf28578849bf3dfd5dd.tar.gz
r4695 merged to trunk; trunk now becomes 0.5.
0.4 development continues at /sqlalchemy/branches/rel_0_4
Diffstat (limited to 'lib/sqlalchemy/ext')
-rw-r--r--lib/sqlalchemy/ext/activemapper.py298
-rw-r--r--lib/sqlalchemy/ext/assignmapper.py72
-rw-r--r--lib/sqlalchemy/ext/associationproxy.py77
-rw-r--r--lib/sqlalchemy/ext/declarative.py4
-rw-r--r--lib/sqlalchemy/ext/orderinglist.py22
-rw-r--r--lib/sqlalchemy/ext/selectresults.py28
-rw-r--r--lib/sqlalchemy/ext/sessioncontext.py50
-rw-r--r--lib/sqlalchemy/ext/sqlsoup.py34
8 files changed, 96 insertions, 489 deletions
diff --git a/lib/sqlalchemy/ext/activemapper.py b/lib/sqlalchemy/ext/activemapper.py
deleted file mode 100644
index 02f4b5b35..000000000
--- a/lib/sqlalchemy/ext/activemapper.py
+++ /dev/null
@@ -1,298 +0,0 @@
-from sqlalchemy import ThreadLocalMetaData, util, Integer
-from sqlalchemy import Table, Column, ForeignKey
-from sqlalchemy.orm import class_mapper, relation, scoped_session
-from sqlalchemy.orm import sessionmaker
-
-from sqlalchemy.orm import backref as create_backref
-
-import inspect
-import sys
-
-#
-# the "proxy" to the database engine... this can be swapped out at runtime
-#
-metadata = ThreadLocalMetaData()
-Objectstore = scoped_session
-objectstore = scoped_session(sessionmaker(autoflush=True, transactional=False))
-
-#
-# declarative column declaration - this is so that we can infer the colname
-#
-class column(object):
- def __init__(self, coltype, colname=None, foreign_key=None,
- primary_key=False, *args, **kwargs):
- if isinstance(foreign_key, basestring):
- foreign_key = ForeignKey(foreign_key)
-
- self.coltype = coltype
- self.colname = colname
- self.foreign_key = foreign_key
- self.primary_key = primary_key
- self.kwargs = kwargs
- self.args = args
-
-#
-# declarative relationship declaration
-#
-class relationship(object):
- def __init__(self, classname, colname=None, backref=None, private=False,
- lazy=True, uselist=True, secondary=None, order_by=False, viewonly=False):
- self.classname = classname
- self.colname = colname
- self.backref = backref
- self.private = private
- self.lazy = lazy
- self.uselist = uselist
- self.secondary = secondary
- self.order_by = order_by
- self.viewonly = viewonly
-
- def process(self, klass, propname, relations):
- relclass = ActiveMapperMeta.classes[self.classname]
-
- if isinstance(self.order_by, str):
- self.order_by = [ self.order_by ]
-
- if isinstance(self.order_by, list):
- for itemno in range(len(self.order_by)):
- if isinstance(self.order_by[itemno], str):
- self.order_by[itemno] = \
- getattr(relclass.c, self.order_by[itemno])
-
- backref = self.create_backref(klass)
- relations[propname] = relation(relclass.mapper,
- secondary=self.secondary,
- backref=backref,
- private=self.private,
- lazy=self.lazy,
- uselist=self.uselist,
- order_by=self.order_by,
- viewonly=self.viewonly)
-
- def create_backref(self, klass):
- if self.backref is None:
- return None
-
- relclass = ActiveMapperMeta.classes[self.classname]
-
- if klass.__name__ == self.classname:
- class_mapper(relclass).compile()
- br_fkey = relclass.c[self.colname]
- else:
- br_fkey = None
-
- return create_backref(self.backref, remote_side=br_fkey)
-
-
-class one_to_many(relationship):
- def __init__(self, *args, **kwargs):
- kwargs['uselist'] = True
- relationship.__init__(self, *args, **kwargs)
-
-class one_to_one(relationship):
- def __init__(self, *args, **kwargs):
- kwargs['uselist'] = False
- relationship.__init__(self, *args, **kwargs)
-
- def create_backref(self, klass):
- if self.backref is None:
- return None
-
- relclass = ActiveMapperMeta.classes[self.classname]
-
- if klass.__name__ == self.classname:
- br_fkey = getattr(relclass.c, self.colname)
- else:
- br_fkey = None
-
- return create_backref(self.backref, foreignkey=br_fkey, uselist=False)
-
-
-class many_to_many(relationship):
- def __init__(self, classname, secondary, backref=None, lazy=True,
- order_by=False):
- relationship.__init__(self, classname, None, backref, False, lazy,
- uselist=True, secondary=secondary,
- order_by=order_by)
-
-
-#
-# SQLAlchemy metaclass and superclass that can be used to do SQLAlchemy
-# mapping in a declarative way, along with a function to process the
-# relationships between dependent objects as they come in, without blowing
-# up if the classes aren't specified in a proper order
-#
-
-__deferred_classes__ = {}
-__processed_classes__ = {}
-def process_relationships(klass, was_deferred=False):
- # first, we loop through all of the relationships defined on the
- # class, and make sure that the related class already has been
- # completely processed and defer processing if it has not
- defer = False
- for propname, reldesc in klass.relations.items():
- found = (reldesc.classname == klass.__name__ or reldesc.classname in __processed_classes__)
- if not found:
- defer = True
- break
-
- # next, we loop through all the columns looking for foreign keys
- # and make sure that we can find the related tables (they do not
- # have to be processed yet, just defined), and we defer if we are
- # not able to find any of the related tables
- if not defer:
- for col in klass.columns:
- if col.foreign_keys:
- found = False
- cn = col.foreign_keys[0]._colspec
- table_name = cn[:cn.rindex('.')]
- for other_klass in ActiveMapperMeta.classes.values():
- if other_klass.table.fullname.lower() == table_name.lower():
- found = True
-
- if not found:
- defer = True
- break
-
- if defer and not was_deferred:
- __deferred_classes__[klass.__name__] = klass
-
- # if we are able to find all related and referred to tables, then
- # we can go ahead and assign the relationships to the class
- if not defer:
- relations = {}
- for propname, reldesc in klass.relations.items():
- reldesc.process(klass, propname, relations)
-
- class_mapper(klass).add_properties(relations)
- if klass.__name__ in __deferred_classes__:
- del __deferred_classes__[klass.__name__]
- __processed_classes__[klass.__name__] = klass
-
- # finally, loop through the deferred classes and attempt to process
- # relationships for them
- if not was_deferred:
- # loop through the list of deferred classes, processing the
- # relationships, until we can make no more progress
- last_count = len(__deferred_classes__) + 1
- while last_count > len(__deferred_classes__):
- last_count = len(__deferred_classes__)
- deferred = __deferred_classes__.copy()
- for deferred_class in deferred.values():
- process_relationships(deferred_class, was_deferred=True)
-
-
-class ActiveMapperMeta(type):
- classes = {}
- metadatas = util.Set()
- def __init__(cls, clsname, bases, dict):
- table_name = clsname.lower()
- columns = []
- relations = {}
- autoload = False
- _metadata = getattr(sys.modules[cls.__module__],
- "__metadata__", metadata)
- version_id_col = None
- version_id_col_object = None
- table_opts = {}
-
- if 'mapping' in dict:
- found_pk = False
-
- members = inspect.getmembers(dict.get('mapping'))
- for name, value in members:
- if name == '__table__':
- table_name = value
- continue
-
- if '__metadata__' == name:
- _metadata= value
- continue
-
- if '__autoload__' == name:
- autoload = True
- continue
-
- if '__version_id_col__' == name:
- version_id_col = value
-
- if '__table_opts__' == name:
- table_opts = value
-
- if name.startswith('__'): continue
-
- if isinstance(value, column):
- if value.primary_key == True: found_pk = True
-
- if value.foreign_key:
- col = Column(value.colname or name,
- value.coltype,
- value.foreign_key,
- primary_key=value.primary_key,
- *value.args, **value.kwargs)
- else:
- col = Column(value.colname or name,
- value.coltype,
- primary_key=value.primary_key,
- *value.args, **value.kwargs)
- columns.append(col)
- continue
-
- if isinstance(value, relationship):
- relations[name] = value
-
- if not found_pk and not autoload:
- col = Column('id', Integer, primary_key=True)
- cls.mapping.id = col
- columns.append(col)
-
- assert _metadata is not None, "No MetaData specified"
-
- ActiveMapperMeta.metadatas.add(_metadata)
-
- if not autoload:
- cls.table = Table(table_name, _metadata, *columns, **table_opts)
- cls.columns = columns
- else:
- cls.table = Table(table_name, _metadata, autoload=True, **table_opts)
- cls.columns = cls.table._columns
-
- if version_id_col is not None:
- version_id_col_object = getattr(cls.table.c, version_id_col, None)
- assert(version_id_col_object is not None, "version_id_col (%s) does not exist." % version_id_col)
-
- # check for inheritence
- if hasattr(bases[0], "mapping"):
- cls._base_mapper= bases[0].mapper
- cls.mapper = objectstore.mapper(cls, cls.table,
- inherits=cls._base_mapper, version_id_col=version_id_col_object)
- else:
- cls.mapper = objectstore.mapper(cls, cls.table, version_id_col=version_id_col_object)
- cls.relations = relations
- ActiveMapperMeta.classes[clsname] = cls
-
- process_relationships(cls)
-
- super(ActiveMapperMeta, cls).__init__(clsname, bases, dict)
-
-
-
-class ActiveMapper(object):
- __metaclass__ = ActiveMapperMeta
-
- def set(self, **kwargs):
- for key, value in kwargs.items():
- setattr(self, key, value)
-
-
-#
-# a utility function to create all tables for all ActiveMapper classes
-#
-
-def create_tables():
- for metadata in ActiveMapperMeta.metadatas:
- metadata.create_all()
-
-def drop_tables():
- for metadata in ActiveMapperMeta.metadatas:
- metadata.drop_all()
diff --git a/lib/sqlalchemy/ext/assignmapper.py b/lib/sqlalchemy/ext/assignmapper.py
deleted file mode 100644
index 5a28fbe68..000000000
--- a/lib/sqlalchemy/ext/assignmapper.py
+++ /dev/null
@@ -1,72 +0,0 @@
-from sqlalchemy import util, exceptions
-import types
-from sqlalchemy.orm import mapper, Query
-
-def _monkeypatch_query_method(name, ctx, class_):
- def do(self, *args, **kwargs):
- query = Query(class_, session=ctx.current)
- util.warn_deprecated('Query methods on the class are deprecated; use %s.query.%s instead' % (class_.__name__, name))
- return getattr(query, name)(*args, **kwargs)
- try:
- do.__name__ = name
- except:
- pass
- if not hasattr(class_, name):
- setattr(class_, name, classmethod(do))
-
-def _monkeypatch_session_method(name, ctx, class_):
- def do(self, *args, **kwargs):
- session = ctx.current
- return getattr(session, name)(self, *args, **kwargs)
- try:
- do.__name__ = name
- except:
- pass
- if not hasattr(class_, name):
- setattr(class_, name, do)
-
-def assign_mapper(ctx, class_, *args, **kwargs):
- extension = kwargs.pop('extension', None)
- if extension is not None:
- extension = util.to_list(extension)
- extension.append(ctx.mapper_extension)
- else:
- extension = ctx.mapper_extension
-
- validate = kwargs.pop('validate', False)
-
- if not isinstance(getattr(class_, '__init__'), types.MethodType):
- def __init__(self, **kwargs):
- for key, value in kwargs.items():
- if validate:
- if not self.mapper.get_property(key,
- resolve_synonyms=False,
- raiseerr=False):
- raise exceptions.ArgumentError(
- "Invalid __init__ argument: '%s'" % key)
- setattr(self, key, value)
- class_.__init__ = __init__
-
- class query(object):
- def __getattr__(self, key):
- return getattr(ctx.current.query(class_), key)
- def __call__(self):
- return ctx.current.query(class_)
-
- if not hasattr(class_, 'query'):
- class_.query = query()
-
- for name in ('get', 'filter', 'filter_by', 'select', 'select_by',
- 'selectfirst', 'selectfirst_by', 'selectone', 'selectone_by',
- 'get_by', 'join_to', 'join_via', 'count', 'count_by',
- 'options', 'instances'):
- _monkeypatch_query_method(name, ctx, class_)
- for name in ('refresh', 'expire', 'delete', 'expunge', 'update'):
- _monkeypatch_session_method(name, ctx, class_)
-
- m = mapper(class_, extension=extension, *args, **kwargs)
- class_.mapper = m
- return m
-
-assign_mapper = util.deprecated(
- "assign_mapper is deprecated. Use scoped_session() instead.")(assign_mapper)
diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py
index d878f7b9b..4d54f6072 100644
--- a/lib/sqlalchemy/ext/associationproxy.py
+++ b/lib/sqlalchemy/ext/associationproxy.py
@@ -406,13 +406,26 @@ class _AssociationList(object):
def clear(self):
del self.col[0:len(self.col)]
- def __eq__(self, other): return list(self) == other
- def __ne__(self, other): return list(self) != other
- def __lt__(self, other): return list(self) < other
- def __le__(self, other): return list(self) <= other
- def __gt__(self, other): return list(self) > other
- def __ge__(self, other): return list(self) >= other
- def __cmp__(self, other): return cmp(list(self), other)
+ def __eq__(self, other):
+ return list(self) == other
+
+ def __ne__(self, other):
+ return list(self) != other
+
+ def __lt__(self, other):
+ return list(self) < other
+
+ def __le__(self, other):
+ return list(self) <= other
+
+ def __gt__(self, other):
+ return list(self) > other
+
+ def __ge__(self, other):
+ return list(self) >= other
+
+ def __cmp__(self, other):
+ return cmp(list(self), other)
def __add__(self, iterable):
try:
@@ -534,13 +547,26 @@ class _AssociationDict(object):
def clear(self):
self.col.clear()
- def __eq__(self, other): return dict(self) == other
- def __ne__(self, other): return dict(self) != other
- def __lt__(self, other): return dict(self) < other
- def __le__(self, other): return dict(self) <= other
- def __gt__(self, other): return dict(self) > other
- def __ge__(self, other): return dict(self) >= other
- def __cmp__(self, other): return cmp(dict(self), other)
+ def __eq__(self, other):
+ return dict(self) == other
+
+ def __ne__(self, other):
+ return dict(self) != other
+
+ def __lt__(self, other):
+ return dict(self) < other
+
+ def __le__(self, other):
+ return dict(self) <= other
+
+ def __gt__(self, other):
+ return dict(self) > other
+
+ def __ge__(self, other):
+ return dict(self) >= other
+
+ def __cmp__(self, other):
+ return cmp(dict(self), other)
def __repr__(self):
return repr(dict(self.items()))
@@ -802,12 +828,23 @@ class _AssociationSet(object):
def copy(self):
return util.Set(self)
- def __eq__(self, other): return util.Set(self) == other
- def __ne__(self, other): return util.Set(self) != other
- def __lt__(self, other): return util.Set(self) < other
- def __le__(self, other): return util.Set(self) <= other
- def __gt__(self, other): return util.Set(self) > other
- def __ge__(self, other): return util.Set(self) >= other
+ def __eq__(self, other):
+ return util.Set(self) == other
+
+ def __ne__(self, other):
+ return util.Set(self) != other
+
+ def __lt__(self, other):
+ return util.Set(self) < other
+
+ def __le__(self, other):
+ return util.Set(self) <= other
+
+ def __gt__(self, other):
+ return util.Set(self) > other
+
+ def __ge__(self, other):
+ return util.Set(self) >= other
def __repr__(self):
return repr(util.Set(self))
diff --git a/lib/sqlalchemy/ext/declarative.py b/lib/sqlalchemy/ext/declarative.py
index d736736e9..f06f16059 100644
--- a/lib/sqlalchemy/ext/declarative.py
+++ b/lib/sqlalchemy/ext/declarative.py
@@ -213,6 +213,9 @@ class DeclarativeMeta(type):
continue
prop = _deferred_relation(cls, value)
our_stuff[k] = prop
+
+ # set up attributes in the order they were created
+ our_stuff.sort(lambda x, y: cmp(our_stuff[x]._creation_order, our_stuff[y]._creation_order))
table = None
if '__table__' not in cls.__dict__:
@@ -254,6 +257,7 @@ class DeclarativeMeta(type):
mapper_cls = util.unbound_method_to_callable(cls.__mapper_cls__)
else:
mapper_cls = mapper
+
cls.__mapper__ = mapper_cls(cls, table, properties=our_stuff, **mapper_args)
return type.__init__(cls, classname, bases, dict_)
diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py
index e7464b0bd..21adc85a8 100644
--- a/lib/sqlalchemy/ext/orderinglist.py
+++ b/lib/sqlalchemy/ext/orderinglist.py
@@ -34,7 +34,7 @@ which have a user-defined, serialized order::
u = User()
u.topten.append(Blurb('Number one!'))
u.topten.append(Blurb('Number two!'))
-
+
# Like magic.
assert [blurb.position for blurb in u.topten] == [0, 1]
@@ -60,7 +60,7 @@ __all__ = [ 'ordering_list' ]
def ordering_list(attr, count_from=None, **kw):
"""Prepares an OrderingList factory for use in mapper definitions.
-
+
Returns an object suitable for use as an argument to a Mapper relation's
``collection_class`` option. Arguments are:
@@ -73,7 +73,7 @@ def ordering_list(attr, count_from=None, **kw):
example, ``ordering_list('pos', count_from=1)`` would create a 1-based
list in SQL, storing the value in the 'pos' column. Ignored if
``ordering_func`` is supplied.
-
+
Passes along any keyword arguments to ``OrderingList`` constructor.
"""
@@ -108,7 +108,7 @@ def _unsugar_count_from(**kw):
Keyword argument filter, prepares a simple ``ordering_func`` from a
``count_from`` argument, otherwise passes ``ordering_func`` on unchanged.
"""
-
+
count_from = kw.pop('count_from', None)
if kw.get('ordering_func', None) is None and count_from is not None:
if count_from == 0:
@@ -126,11 +126,11 @@ class OrderingList(list):
``ordering_list`` function is used to configure ``OrderingList``
collections in ``mapper`` relation definitions.
"""
-
+
def __init__(self, ordering_attr=None, ordering_func=None,
reorder_on_append=False):
"""A custom list that manages position information for its children.
-
+
``OrderingList`` is a ``collection_class`` list implementation that
syncs position in a Python list with a position attribute on the
mapped objects.
@@ -148,7 +148,7 @@ class OrderingList(list):
An ``ordering_func`` is called with two positional parameters: the
index of the element in the list, and the list itself.
-
+
If omitted, Python list indexes are used for the attribute values.
Two basic pre-built numbering functions are provided in this module:
``count_from_0`` and ``count_from_1``. For more exotic examples
@@ -194,7 +194,7 @@ class OrderingList(list):
def _reorder(self):
"""Sweep through the list and ensure that each object has accurate
ordering information set."""
-
+
for index, entity in enumerate(self):
self._order_entity(index, entity, True)
@@ -206,7 +206,7 @@ class OrderingList(list):
return
should_be = self.ordering_func(index, self)
- if have <> should_be:
+ if have != should_be:
self._set_order_value(entity, should_be)
def append(self, entity):
@@ -229,7 +229,7 @@ class OrderingList(list):
entity = super(OrderingList, self).pop(index)
self._reorder()
return entity
-
+
def __setitem__(self, index, entity):
if isinstance(index, slice):
for i in range(index.start or 0, index.stop or 0, index.step or 1):
@@ -237,7 +237,7 @@ class OrderingList(list):
else:
self._order_entity(index, entity, True)
super(OrderingList, self).__setitem__(index, entity)
-
+
def __delitem__(self, index):
super(OrderingList, self).__delitem__(index)
self._reorder()
diff --git a/lib/sqlalchemy/ext/selectresults.py b/lib/sqlalchemy/ext/selectresults.py
deleted file mode 100644
index 446228254..000000000
--- a/lib/sqlalchemy/ext/selectresults.py
+++ /dev/null
@@ -1,28 +0,0 @@
-"""SelectResults has been rolled into Query. This class is now just a placeholder."""
-
-import sqlalchemy.sql as sql
-import sqlalchemy.orm as orm
-
-class SelectResultsExt(orm.MapperExtension):
- """a MapperExtension that provides SelectResults functionality for the
- results of query.select_by() and query.select()"""
-
- def select_by(self, query, *args, **params):
- q = query
- for a in args:
- q = q.filter(a)
- return q.filter_by(**params)
-
- def select(self, query, arg=None, **kwargs):
- if isinstance(arg, sql.FromClause) and arg.supports_execution():
- return orm.EXT_CONTINUE
- else:
- if arg is not None:
- query = query.filter(arg)
- return query._legacy_select_kwargs(**kwargs)
-
-def SelectResults(query, clause=None, ops={}):
- if clause is not None:
- query = query.filter(clause)
- query = query.options(orm.extension(SelectResultsExt()))
- return query._legacy_select_kwargs(**ops)
diff --git a/lib/sqlalchemy/ext/sessioncontext.py b/lib/sqlalchemy/ext/sessioncontext.py
deleted file mode 100644
index 5ac8acb40..000000000
--- a/lib/sqlalchemy/ext/sessioncontext.py
+++ /dev/null
@@ -1,50 +0,0 @@
-from sqlalchemy.orm.scoping import ScopedSession, _ScopedExt
-from sqlalchemy.util import warn_deprecated
-from sqlalchemy.orm import create_session
-
-__all__ = ['SessionContext', 'SessionContextExt']
-
-
-class SessionContext(ScopedSession):
- """Provides thread-local management of Sessions.
-
- Usage::
-
- context = SessionContext(sessionmaker(autoflush=True))
-
- """
-
- def __init__(self, session_factory=None, scopefunc=None):
- warn_deprecated("SessionContext is deprecated. Use scoped_session().")
- if session_factory is None:
- session_factory=create_session
- super(SessionContext, self).__init__(session_factory, scopefunc=scopefunc)
-
- def get_current(self):
- return self.registry()
-
- def set_current(self, session):
- self.registry.set(session)
-
- def del_current(self):
- self.registry.clear()
-
- current = property(get_current, set_current, del_current,
- """Property used to get/set/del the session in the current scope.""")
-
- def _get_mapper_extension(self):
- try:
- return self._extension
- except AttributeError:
- self._extension = ext = SessionContextExt(self)
- return ext
-
- mapper_extension = property(_get_mapper_extension,
- doc="""Get a mapper extension that implements `get_session` using this context. Deprecated.""")
-
-
-class SessionContextExt(_ScopedExt):
- def __init__(self, *args, **kwargs):
- warn_deprecated("SessionContextExt is deprecated. Use ScopedSession(enhance_classes=True)")
- super(SessionContextExt, self).__init__(*args, **kwargs)
-
diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py
index bad9ba5a8..95971f787 100644
--- a/lib/sqlalchemy/ext/sqlsoup.py
+++ b/lib/sqlalchemy/ext/sqlsoup.py
@@ -210,7 +210,7 @@ Advanced Use
Accessing the Session
---------------------
-SqlSoup uses a SessionContext to provide thread-local sessions. You
+SqlSoup uses a ScopedSession to provide thread-local sessions. You
can get a reference to the current one like this::
>>> from sqlalchemy.ext.sqlsoup import objectstore
@@ -325,7 +325,7 @@ Boring tests here. Nothing of real expository value.
from sqlalchemy import *
from sqlalchemy import schema, sql
from sqlalchemy.orm import *
-from sqlalchemy.ext.sessioncontext import SessionContext
+from sqlalchemy.orm.scoping import ScopedSession
from sqlalchemy.exceptions import *
from sqlalchemy.sql import expression
@@ -379,15 +379,24 @@ __all__ = ['PKNotFoundError', 'SqlSoup']
#
# thread local SessionContext
#
-class Objectstore(SessionContext):
+class Objectstore(ScopedSession):
def __getattr__(self, key):
- return getattr(self.current, key)
+ if key.startswith('__'): # dont trip the registry for module-level sweeps of things
+ # like '__bases__'. the session gets bound to the
+ # module which is interfered with by other unit tests.
+ # (removal of mapper.get_session() revealed the issue)
+ raise AttributeError()
+ return getattr(self.registry(), key)
+ def current(self):
+ return self.registry()
+ current = property(current)
def get_session(self):
- return self.current
+ return self.registry()
objectstore = Objectstore(create_session)
-class PKNotFoundError(SQLAlchemyError): pass
+class PKNotFoundError(SQLAlchemyError):
+ pass
def _ddl_error(cls):
msg = 'SQLSoup can only modify mapped Tables (found: %s)' \
@@ -439,7 +448,7 @@ def _is_outer_join(selectable):
def _selectable_name(selectable):
if isinstance(selectable, sql.Alias):
- return _selectable_name(selectable.selectable)
+ return _selectable_name(selectable.element)
elif isinstance(selectable, sql.Select):
return ''.join([_selectable_name(s) for s in selectable.froms])
elif isinstance(selectable, schema.Table):
@@ -457,7 +466,7 @@ def class_for_table(selectable, **mapper_kwargs):
klass = TableClassType(mapname, (object,), {})
else:
klass = SelectableClassType(mapname, (object,), {})
-
+
def __cmp__(self, o):
L = self.__class__.c.keys()
L.sort()
@@ -482,12 +491,17 @@ def class_for_table(selectable, **mapper_kwargs):
for m in ['__cmp__', '__repr__']:
setattr(klass, m, eval(m))
klass._table = selectable
+ klass.c = expression.ColumnCollection()
mappr = mapper(klass,
selectable,
- extension=objectstore.mapper_extension,
+ extension=objectstore.extension,
allow_null_pks=_is_outer_join(selectable),
**mapper_kwargs)
- klass._query = Query(mappr)
+
+ for k in mappr.iterate_properties:
+ klass.c[k.key] = k.columns[0]
+
+ klass._query = objectstore.query_property()
return klass
class SqlSoup: