Diffstat (limited to 'lib/sqlalchemy/ext')
-rw-r--r--  lib/sqlalchemy/ext/activemapper.py   | 132
-rw-r--r--  lib/sqlalchemy/ext/assignmapper.py   |  34
-rw-r--r--  lib/sqlalchemy/ext/proxy.py          |  98
-rw-r--r--  lib/sqlalchemy/ext/selectresults.py  |  82
-rw-r--r--  lib/sqlalchemy/ext/sessioncontext.py |  55
-rw-r--r--  lib/sqlalchemy/ext/sqlsoup.py        | 254
6 files changed, 343 insertions, 312 deletions
diff --git a/lib/sqlalchemy/ext/activemapper.py b/lib/sqlalchemy/ext/activemapper.py
index f875b30b3..74f4df349 100644
--- a/lib/sqlalchemy/ext/activemapper.py
+++ b/lib/sqlalchemy/ext/activemapper.py
@@ -1,8 +1,30 @@
-from sqlalchemy import assign_mapper, relation, exceptions
+from sqlalchemy import create_session, relation, mapper, join, DynamicMetaData, class_mapper
+from sqlalchemy import and_, or_
from sqlalchemy import Table, Column, ForeignKey
+from sqlalchemy.ext.sessioncontext import SessionContext
+from sqlalchemy.ext.assignmapper import assign_mapper
+from sqlalchemy import backref as create_backref
import inspect
import sys
+import sets
+
+#
+# the "proxy" to the database engine... this can be swapped out at runtime
+#
+metadata = DynamicMetaData("activemapper")
+
+#
+# thread local SessionContext
+#
+class Objectstore(SessionContext):
+ def __getattr__(self, key):
+ return getattr(self.current, key)
+ def get_session(self):
+ return self.current
+
+objectstore = Objectstore(create_session)
+
#
# declarative column declaration - this is so that we can infer the colname
@@ -40,7 +62,7 @@ class one_to_many(relationship):
class one_to_one(relationship):
def __init__(self, classname, colname=None, backref=None, private=False, lazy=True):
- relationship.__init__(self, classname, colname, backref, private, lazy, uselist=False)
+ relationship.__init__(self, classname, colname, create_backref(backref, uselist=False), private, lazy, uselist=False)
class many_to_many(relationship):
def __init__(self, classname, secondary, backref=None, lazy=True):
@@ -56,43 +78,15 @@ class many_to_many(relationship):
#
__deferred_classes__ = []
-__processed_classes__ = []
-
-def check_relationships(klass):
- #Check the class for foreign_keys recursively. If some foreign table is not found, the processing of the table
- #must be defered.
- for keyname in klass.table._foreign_keys:
- xtable = keyname._colspec[:keyname._colspec.find('.')]
- tablefound = False
- for xclass in ActiveMapperMeta.classes:
- if ActiveMapperMeta.classes[xclass].table.from_name == xtable:
- tablefound = True
- break
- if tablefound==False:
- #The refered table has not yet been created.
- return False
-
- return True
-
-
-def process_relationships(klass):
+def process_relationships(klass, was_deferred=False):
defer = False
for propname, reldesc in klass.relations.items():
- # we require that every related table has been processed first
- if not reldesc.classname in __processed_classes__:
- if not klass._classname in __deferred_classes__: __deferred_classes__.append(klass._classname)
- defer = True
-
- # check every column item to see if it points to an existing table
- # if it does not, defer...
- if not defer:
- if not check_relationships(klass):
- if not klass._classname in __deferred_classes__: __deferred_classes__.append(klass._classname)
+ if not reldesc.classname in ActiveMapperMeta.classes:
+ if not was_deferred: __deferred_classes__.append(klass)
defer = True
if not defer:
relations = {}
-
for propname, reldesc in klass.relations.items():
relclass = ActiveMapperMeta.classes[reldesc.classname]
relations[propname] = relation(relclass.mapper,
@@ -101,40 +95,39 @@ def process_relationships(klass):
private=reldesc.private,
lazy=reldesc.lazy,
uselist=reldesc.uselist)
- if len(relations) > 0:
- assign_ok = True
- try:
- assign_mapper(klass, klass.table, properties=relations)
- except exceptions.ArgumentError:
- assign_ok = False
-
- if assign_ok:
- __processed_classes__.append(klass._classname)
- if klass._classname in __deferred_classes__: __deferred_classes__.remove(klass._classname)
- else:
- __processed_classes__.append(klass._classname)
-
+ class_mapper(klass).add_properties(relations)
+ #assign_mapper(objectstore, klass, klass.table, properties=relations,
+ # inherits=getattr(klass, "_base_mapper", None))
+ if was_deferred: __deferred_classes__.remove(klass)
+
+ if not was_deferred:
for deferred_class in __deferred_classes__:
- process_relationships(ActiveMapperMeta.classes[deferred_class])
+ process_relationships(deferred_class, was_deferred=True)
+
class ActiveMapperMeta(type):
classes = {}
-
+ metadatas = sets.Set()
def __init__(cls, clsname, bases, dict):
table_name = clsname.lower()
columns = []
relations = {}
-
+ _metadata = getattr( sys.modules[cls.__module__], "__metadata__", metadata )
+
if 'mapping' in dict:
members = inspect.getmembers(dict.get('mapping'))
for name, value in members:
if name == '__table__':
table_name = value
continue
-
+
+ if '__metadata__' == name:
+ _metadata= value
+ continue
+
if name.startswith('__'): continue
-
+
if isinstance(value, column):
if value.foreign_key:
col = Column(value.colname or name,
@@ -149,29 +142,29 @@ class ActiveMapperMeta(type):
*value.args, **value.kwargs)
columns.append(col)
continue
-
+
if isinstance(value, relationship):
relations[name] = value
-
- cls.table = Table(table_name, redefine=True, *columns)
-
+ assert _metadata is not None, "No MetaData specified"
+ ActiveMapperMeta.metadatas.add(_metadata)
+ cls.table = Table(table_name, _metadata, *columns)
# check for inheritance
- if hasattr(bases[0], "mapping"):
- cls._base_mapper = bases[0].mapper
- assign_mapper(cls, cls.table, inherits=cls._base_mapper)
- elif len(relations) == 0:
- assign_mapper(cls, cls.table)
+ if hasattr( bases[0], "mapping" ):
+ cls._base_mapper= bases[0].mapper
+ assign_mapper(objectstore, cls, cls.table, inherits=cls._base_mapper)
+ else:
+ assign_mapper(objectstore, cls, cls.table)
cls.relations = relations
- cls._classname = clsname
ActiveMapperMeta.classes[clsname] = cls
+
process_relationships(cls)
-
+
super(ActiveMapperMeta, cls).__init__(clsname, bases, dict)
class ActiveMapper(object):
__metaclass__ = ActiveMapperMeta
-
+
def set(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
@@ -182,12 +175,9 @@ class ActiveMapper(object):
#
def create_tables():
- for klass in ActiveMapperMeta.classes.values():
- klass.table.create()
-
-#
-# a utility function to drop all tables for all ActiveMapper classes
-#
+ for metadata in ActiveMapperMeta.metadatas:
+ metadata.create_all()
def drop_tables():
- for klass in ActiveMapperMeta.classes.values():
- klass.table.drop()
\ No newline at end of file
+ for metadata in ActiveMapperMeta.metadatas:
+ metadata.drop_all()
+
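
The rewritten process_relationships() above defers a class whose related classes have not been registered yet, and retries the deferred list each time another class is registered. Below is a minimal sketch of that deferral pattern using only the standard library; the register()/_process() names and the flat dicts are illustrative stand-ins, not the activemapper API.

    _classes = {}    # name -> relation targets, like ActiveMapperMeta.classes
    _deferred = []   # names whose related classes are not registered yet
    _wired = {}      # name -> resolved relations

    def _process(name, was_deferred=False):
        targets = _classes[name]
        if all(t in _classes for t in targets.values()):
            _wired[name] = dict(targets)        # every related class is known
            if was_deferred:
                _deferred.remove(name)
        elif not was_deferred:
            _deferred.append(name)              # retry on the next register()
        if not was_deferred:
            for other in list(_deferred):
                _process(other, was_deferred=True)

    def register(name, **relations):
        _classes[name] = relations
        _process(name)

    register("Address", user="User")   # 'User' unknown yet -> deferred
    register("User")                   # registering User wires Address too
    print(sorted(_wired))              # -> ['Address', 'User']
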
diff --git a/lib/sqlalchemy/ext/assignmapper.py b/lib/sqlalchemy/ext/assignmapper.py
new file mode 100644
index 000000000..b8a676b75
--- /dev/null
+++ b/lib/sqlalchemy/ext/assignmapper.py
@@ -0,0 +1,34 @@
+from sqlalchemy import mapper, util
+import types
+
+def monkeypatch_query_method(ctx, class_, name):
+ def do(self, *args, **kwargs):
+ query = class_.mapper.query(session=ctx.current)
+ return getattr(query, name)(*args, **kwargs)
+ setattr(class_, name, classmethod(do))
+
+def monkeypatch_objectstore_method(ctx, class_, name):
+ def do(self, *args, **kwargs):
+ session = ctx.current
+ return getattr(session, name)(self, *args, **kwargs)
+ setattr(class_, name, do)
+
+def assign_mapper(ctx, class_, *args, **kwargs):
+ kwargs.setdefault("is_primary", True)
+ if not isinstance(getattr(class_, '__init__'), types.MethodType):
+ def __init__(self, **kwargs):
+ for key, value in kwargs.items():
+ setattr(self, key, value)
+ class_.__init__ = __init__
+ extension = kwargs.pop('extension', None)
+ if extension is not None:
+ extension = util.to_list(extension)
+ extension.append(ctx.mapper_extension)
+ else:
+ extension = ctx.mapper_extension
+ m = mapper(class_, extension=extension, *args, **kwargs)
+ class_.mapper = m
+ for name in ['get', 'select', 'select_by', 'selectone', 'get_by', 'join_to', 'join_via']:
+ monkeypatch_query_method(ctx, class_, name)
+ for name in ['flush', 'delete', 'expire', 'refresh', 'expunge', 'merge', 'update', 'save_or_update']:
+ monkeypatch_objectstore_method(ctx, class_, name)
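
The new assign_mapper() works by monkeypatching two kinds of helpers onto the mapped class: class-level methods that delegate to the context's current query/session, and instance-level methods that pass self through to it. Here is a self-contained sketch of that pattern with hypothetical Store/Context stand-ins rather than real SQLAlchemy query or session objects.

    class Store(object):
        def get(self, ident):
            return "row %s" % ident
        def save(self, obj):
            return "saved %r" % (obj,)

    class Context(object):
        def __init__(self, factory):
            self.factory = factory
            self._current = None
        @property
        def current(self):
            if self._current is None:          # lazily create, like ctx.current
                self._current = self.factory()
            return self._current

    def monkeypatch_query_method(ctx, class_, name):
        def do(cls, *args, **kwargs):
            return getattr(ctx.current, name)(*args, **kwargs)
        setattr(class_, name, classmethod(do))

    def monkeypatch_store_method(ctx, class_, name):
        def do(self, *args, **kwargs):
            return getattr(ctx.current, name)(self, *args, **kwargs)
        setattr(class_, name, do)

    class User(object):
        pass

    ctx = Context(Store)
    monkeypatch_query_method(ctx, User, 'get')
    monkeypatch_store_method(ctx, User, 'save')
    print(User.get(7))       # class-level: routed through ctx.current.get
    print(User().save())     # instance-level: ctx.current.save(self)
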
diff --git a/lib/sqlalchemy/ext/proxy.py b/lib/sqlalchemy/ext/proxy.py
index a24f089e9..deced55b4 100644
--- a/lib/sqlalchemy/ext/proxy.py
+++ b/lib/sqlalchemy/ext/proxy.py
@@ -5,14 +5,11 @@ except ImportError:
from sqlalchemy import sql
from sqlalchemy.engine import create_engine
-from sqlalchemy.types import TypeEngine
-import sqlalchemy.schema as schema
-import thread, weakref
-class BaseProxyEngine(schema.SchemaEngine):
- '''
- Basis for all proxy engines
- '''
+__all__ = ['BaseProxyEngine', 'AutoConnectEngine', 'ProxyEngine']
+
+class BaseProxyEngine(sql.Engine):
+ """Basis for all proxy engines."""
def get_engine(self):
raise NotImplementedError
@@ -21,66 +18,50 @@ class BaseProxyEngine(schema.SchemaEngine):
raise NotImplementedError
engine = property(lambda s:s.get_engine(), lambda s,e:s.set_engine(e))
-
- def reflecttable(self, table):
- return self.get_engine().reflecttable(table)
+
def execute_compiled(self, *args, **kwargs):
- return self.get_engine().execute_compiled(*args, **kwargs)
- def compiler(self, *args, **kwargs):
- return self.get_engine().compiler(*args, **kwargs)
- def schemagenerator(self, *args, **kwargs):
- return self.get_engine().schemagenerator(*args, **kwargs)
- def schemadropper(self, *args, **kwargs):
- return self.get_engine().schemadropper(*args, **kwargs)
-
- def hash_key(self):
- return "%s(%s)" % (self.__class__.__name__, id(self))
+ """this method is required to be present as it overrides the execute_compiled present in sql.Engine"""
+ return self.get_engine().execute_compiled(*args, **kwargs)
+ def compiler(self, *args, **kwargs):
+ """this method is required to be present as it overrides the compiler method present in sql.Engine"""
+ return self.get_engine().compiler(*args, **kwargs)
- def oid_column_name(self):
- # oid_column should not be requested before the engine is connected.
- # it should ideally only be called at query compilation time.
- e= self.get_engine()
- if e is None:
- return None
- return e.oid_column_name()
-
def __getattr__(self, attr):
+ """provides proxying for methods that are not otherwise present on this BaseProxyEngine. Note
+ that methods which are present on the base class sql.Engine will *not* be proxied through this,
+ and must be explicit on this class."""
# call get_engine() to give subclasses a chance to change
# connection establishment behavior
- e= self.get_engine()
+ e = self.get_engine()
if e is not None:
return getattr(e, attr)
- raise AttributeError('No connection established in ProxyEngine: '
- ' no access to %s' % attr)
+ raise AttributeError("No connection established in ProxyEngine: "
+ " no access to %s" % attr)
+
class AutoConnectEngine(BaseProxyEngine):
- '''
- An SQLEngine proxy that automatically connects when necessary.
- '''
+ """An SQLEngine proxy that automatically connects when necessary."""
- def __init__(self, dburi, opts=None, **kwargs):
+ def __init__(self, dburi, **kwargs):
BaseProxyEngine.__init__(self)
- self.dburi= dburi
- self.opts= opts
- self.kwargs= kwargs
- self._engine= None
+ self.dburi = dburi
+ self.kwargs = kwargs
+ self._engine = None
def get_engine(self):
if self._engine is None:
if callable(self.dburi):
- dburi= self.dburi()
+ dburi = self.dburi()
else:
- dburi= self.dburi
- self._engine= create_engine( dburi, self.opts, **self.kwargs )
+ dburi = self.dburi
+ self._engine = create_engine(dburi, **self.kwargs)
return self._engine
-
class ProxyEngine(BaseProxyEngine):
- """
- SQLEngine proxy. Supports lazy and late initialization by
- delegating to a real engine (set with connect()), and using proxy
- classes for TypeEngine.
+ """Engine proxy for lazy and late initialization.
+
+ This engine will delegate access to a real engine set with connect().
"""
def __init__(self, **kwargs):
@@ -90,14 +71,15 @@ class ProxyEngine(BaseProxyEngine):
self.storage.connection = {}
self.storage.engine = None
self.kwargs = kwargs
-
- def connect(self, uri, opts=None, **kwargs):
- """Establish connection to a real engine.
- """
- kw = self.kwargs.copy()
- kw.update(kwargs)
- kwargs = kw
- key = "%s(%s,%s)" % (uri, repr(opts), repr(kwargs))
+
+ def connect(self, *args, **kwargs):
+ """Establish connection to a real engine."""
+
+ kwargs.update(self.kwargs)
+ if not kwargs:
+ key = repr(args)
+ else:
+ key = "%s, %s" % (repr(args), repr(sorted(kwargs.items())))
try:
map = self.storage.connection
except AttributeError:
@@ -107,15 +89,13 @@ class ProxyEngine(BaseProxyEngine):
try:
self.engine = map[key]
except KeyError:
- map[key] = create_engine(uri, opts, **kwargs)
+ map[key] = create_engine(*args, **kwargs)
self.storage.engine = map[key]
def get_engine(self):
if self.storage.engine is None:
- raise AttributeError('No connection established')
+ raise AttributeError("No connection established")
return self.storage.engine
def set_engine(self, engine):
self.storage.engine = engine
-
-
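
The proxy engines above rely on __getattr__ for delegation, which only fires for attributes that are not found through normal lookup; that is why execute_compiled() and compiler() are overridden explicitly, as their docstrings note. A small stand-alone illustration of that rule, with hypothetical RealEngine/LazyProxy classes in place of real engines:

    class RealEngine(object):
        def execute(self, stmt):
            return "executed: %s" % stmt
        def compiler(self):
            return "real compiler"

    class EngineBase(object):
        def compiler(self):                    # would shadow the real engine
            return "base compiler"

    class LazyProxy(EngineBase):
        def __init__(self, factory):
            self._factory = factory
            self._engine = None
        def get_engine(self):
            if self._engine is None:           # connect on first use
                self._engine = self._factory()
            return self._engine
        def compiler(self):                    # explicit override, as above
            return self.get_engine().compiler()
        def __getattr__(self, attr):           # everything else is proxied
            return getattr(self.get_engine(), attr)

    proxy = LazyProxy(RealEngine)
    print(proxy.execute("select 1"))   # proxied via __getattr__
    print(proxy.compiler())            # reaches RealEngine only via the override
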
diff --git a/lib/sqlalchemy/ext/selectresults.py b/lib/sqlalchemy/ext/selectresults.py
new file mode 100644
index 000000000..5ba9153dd
--- /dev/null
+++ b/lib/sqlalchemy/ext/selectresults.py
@@ -0,0 +1,82 @@
+import sqlalchemy.sql as sql
+
+import sqlalchemy.orm as orm
+
+
+class SelectResultsExt(orm.MapperExtension):
+ def select_by(self, query, *args, **params):
+ return SelectResults(query, query.join_by(*args, **params))
+ def select(self, query, arg=None, **kwargs):
+ if arg is not None and isinstance(arg, sql.Selectable):
+ return orm.EXT_PASS
+ else:
+ return SelectResults(query, arg, ops=kwargs)
+
+class SelectResults(object):
+ def __init__(self, query, clause=None, ops={}):
+ self._query = query
+ self._clause = clause
+ self._ops = {}
+ self._ops.update(ops)
+
+ def count(self):
+ return self._query.count(self._clause)
+
+ def min(self, col):
+ return sql.select([sql.func.min(col)], self._clause, **self._ops).scalar()
+
+ def max(self, col):
+ return sql.select([sql.func.max(col)], self._clause, **self._ops).scalar()
+
+ def sum(self, col):
+ return sql.select([sql.func.sum(col)], self._clause, **self._ops).scalar()
+
+ def avg(self, col):
+ return sql.select([sql.func.avg(col)], self._clause, **self._ops).scalar()
+
+ def clone(self):
+ return SelectResults(self._query, self._clause, self._ops.copy())
+
+ def filter(self, clause):
+ new = self.clone()
+ new._clause = sql.and_(self._clause, clause)
+ return new
+
+ def order_by(self, order_by):
+ new = self.clone()
+ new._ops['order_by'] = order_by
+ return new
+
+ def limit(self, limit):
+ return self[:limit]
+
+ def offset(self, offset):
+ return self[offset:]
+
+ def list(self):
+ return list(self)
+
+ def __getitem__(self, item):
+ if isinstance(item, slice):
+ start = item.start
+ stop = item.stop
+ if (isinstance(start, int) and start < 0) or \
+ (isinstance(stop, int) and stop < 0):
+ return list(self)[item]
+ else:
+ res = self.clone()
+ if start is not None and stop is not None:
+ res._ops.update(dict(offset=self._ops.get('offset', 0)+start, limit=stop-start))
+ elif start is None and stop is not None:
+ res._ops.update(dict(limit=stop))
+ elif start is not None and stop is None:
+ res._ops.update(dict(offset=self._ops.get('offset', 0)+start))
+ if item.step is not None:
+ return list(res)[None:None:item.step]
+ else:
+ return res
+ else:
+ return list(self[item:item+1])[0]
+
+ def __iter__(self):
+ return iter(self._query.select_whereclause(self._clause, **self._ops))
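
The slice handling in SelectResults.__getitem__ folds a non-negative slice into the offset/limit entries of the generative options dict (negative bounds fall back to fetching the whole result). A stand-alone rendering of just that arithmetic; apply_slice is an illustrative helper, not part of the extension, and ops mirrors the _ops dict above.

    def apply_slice(ops, item):
        start, stop = item.start, item.stop
        ops = dict(ops)
        if start is not None and stop is not None:
            ops.update(offset=ops.get('offset', 0) + start, limit=stop - start)
        elif stop is not None:
            ops.update(limit=stop)
        elif start is not None:
            ops.update(offset=ops.get('offset', 0) + start)
        return ops

    print(apply_slice({}, slice(10, 30)))             # offset 10, limit 20
    print(apply_slice({'offset': 5}, slice(10, 30)))  # offsets accumulate: 15
    print(apply_slice({}, slice(None, 7)))            # limit 7 only
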
diff --git a/lib/sqlalchemy/ext/sessioncontext.py b/lib/sqlalchemy/ext/sessioncontext.py
new file mode 100644
index 000000000..f431f87c7
--- /dev/null
+++ b/lib/sqlalchemy/ext/sessioncontext.py
@@ -0,0 +1,55 @@
+from sqlalchemy.util import ScopedRegistry
+from sqlalchemy.orm.mapper import MapperExtension
+
+__all__ = ['SessionContext', 'SessionContextExt']
+
+class SessionContext(object):
+ """A simple wrapper for ScopedRegistry that provides a "current" property
+ which can be used to get, set, or remove the session in the current scope.
+
+ By default this object provides thread-local scoping, which is the default
+ scope provided by sqlalchemy.util.ScopedRegistry.
+
+ Usage:
+ engine = create_engine(...)
+ def session_factory():
+ return Session(bind_to=engine)
+ context = SessionContext(session_factory)
+
+ s = context.current # get thread-local session
+ context.current = Session(bind_to=other_engine) # set current session
+ del context.current # discard the thread-local session (a new one will
+ # be created on the next call to context.current)
+ """
+ def __init__(self, session_factory, scopefunc=None):
+ self.registry = ScopedRegistry(session_factory, scopefunc)
+ super(SessionContext, self).__init__()
+
+ def get_current(self):
+ return self.registry()
+ def set_current(self, session):
+ self.registry.set(session)
+ def del_current(self):
+ self.registry.clear()
+ current = property(get_current, set_current, del_current,
+ """Property used to get/set/del the session in the current scope""")
+
+ def _get_mapper_extension(self):
+ try:
+ return self._extension
+ except AttributeError:
+ self._extension = ext = SessionContextExt(self)
+ return ext
+ mapper_extension = property(_get_mapper_extension,
+ doc="""get a mapper extension that implements get_session using this context""")
+
+
+class SessionContextExt(MapperExtension):
+ """a mapper extionsion that provides sessions to a mapper using SessionContext"""
+
+ def __init__(self, context):
+ MapperExtension.__init__(self)
+ self.context = context
+
+ def get_session(self):
+ return self.context.current
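
SessionContext leans on sqlalchemy.util.ScopedRegistry for its thread-local behaviour: one lazily created session per thread, with set/clear to replace or discard it. The sketch below imitates that behaviour with threading.local; it is a stand-in for illustration, not ScopedRegistry's actual implementation.

    import threading

    class ThreadLocalRegistry(object):
        def __init__(self, factory):
            self.factory = factory
            self.local = threading.local()
        def __call__(self):
            try:
                return self.local.value
            except AttributeError:             # first call in this thread
                self.local.value = value = self.factory()
                return value
        def set(self, value):
            self.local.value = value
        def clear(self):
            try:
                del self.local.value
            except AttributeError:
                pass

    registry = ThreadLocalRegistry(dict)
    registry()['user'] = 'ed'                  # created on first call here

    seen = []
    t = threading.Thread(target=lambda: seen.append('user' in registry()))
    t.start(); t.join()
    print(registry()['user'], seen)            # -> ed [False]
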
diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py
index b1fb0b889..043abc38b 100644
--- a/lib/sqlalchemy/ext/sqlsoup.py
+++ b/lib/sqlalchemy/ext/sqlsoup.py
@@ -1,182 +1,72 @@
-from sqlalchemy import *
-
-"""
-SqlSoup provides a convenient way to access database tables without having
-to declare table or mapper classes ahead of time.
-
-Suppose we have a database with users, books, and loans tables
-(corresponding to the PyWebOff dataset, if you're curious).
-For testing purposes, we can create this db as follows:
-
->>> from sqlalchemy import create_engine
->>> e = create_engine('sqlite://filename=:memory:')
->>> for sql in _testsql: e.execute(sql)
-...
-
-Creating a SqlSoup gateway is just like creating an SqlAlchemy engine:
->>> from sqlalchemy.ext.sqlsoup import SqlSoup
->>> soup = SqlSoup('sqlite://filename=:memory:')
-
-or, you can re-use an existing engine:
->>> soup = SqlSoup(e)
-
-Loading objects is as easy as this:
->>> users = soup.users.select()
->>> users.sort()
->>> users
-[Class_Users(name='Bhargan Basepair',email='basepair@example.edu',password='basepair',classname=None,admin=1), Class_Users(name='Joe Student',email='student@example.edu',password='student',classname=None,admin=0)]
-
-Of course, letting the database do the sort is better (".c" is short for ".columns"):
->>> soup.users.select(order_by=[soup.users.c.name])
-[Class_Users(name='Bhargan Basepair',email='basepair@example.edu',password='basepair',classname=None,admin=1),
- Class_Users(name='Joe Student',email='student@example.edu',password='student',classname=None,admin=0)]
-
-Field access is intuitive:
->>> users[0].email
-u'basepair@example.edu'
-
-Of course, you don't want to load all users very often. The common case is to
-select by a key or other field:
->>> soup.users.selectone_by(name='Bhargan Basepair')
-Class_Users(name='Bhargan Basepair',email='basepair@example.edu',password='basepair',classname=None,admin=1)
-
-All the SqlAlchemy mapper select variants (select, select_by, selectone, selectone_by, selectfirst, selectfirst_by)
-are available. See the SqlAlchemy documentation for details:
-http://www.sqlalchemy.org/docs/sqlconstruction.myt
-
-Modifying objects is intuitive:
->>> user = _
->>> user.email = 'basepair+nospam@example.edu'
->>> soup.commit()
-
-(SqlSoup leverages the sophisticated SqlAlchemy unit-of-work code, so
-multiple updates to a single object will be turned into a single UPDATE
-statement when you commit.)
-
-Finally, insert and delete. Let's insert a new loan, then delete it:
->>> soup.loans.insert(book_id=soup.books.selectfirst().id, user_name=user.name)
-Class_Loans(book_id=1,user_name='Bhargan Basepair',loan_date=None)
->>> soup.commit()
-
->>> loan = soup.loans.selectone_by(book_id=1, user_name='Bhargan Basepair')
->>> soup.delete(loan)
->>> soup.commit()
-"""
-
-_testsql = """
-CREATE TABLE books (
- id integer PRIMARY KEY, -- auto-SERIAL in sqlite
- title text NOT NULL,
- published_year char(4) NOT NULL,
- authors text NOT NULL
-);
-
-CREATE TABLE users (
- name varchar(32) PRIMARY KEY,
- email varchar(128) NOT NULL,
- password varchar(128) NOT NULL,
- classname text,
- admin int NOT NULL -- 0 = false
-);
-
-CREATE TABLE loans (
- book_id int PRIMARY KEY REFERENCES books(id),
- user_name varchar(32) references users(name)
- ON DELETE SET NULL ON UPDATE CASCADE,
- loan_date date NOT NULL DEFAULT current_timestamp
-);
-
-insert into users(name, email, password, admin)
-values('Bhargan Basepair', 'basepair@example.edu', 'basepair', 1);
-insert into users(name, email, password, admin)
-values('Joe Student', 'student@example.edu', 'student', 0);
-
-insert into books(title, published_year, authors)
-values('Mustards I Have Known', '1989', 'Jones');
-insert into books(title, published_year, authors)
-values('Regional Variation in Moss', '1971', 'Flim and Flam');
-
-insert into loans(book_id, user_name)
-values (
- (select min(id) from books),
- (select name from users where name like 'Joe%'))
-;
-""".split(';')
-
-__all__ = ['NoSuchTableError', 'SqlSoup']
-
-class NoSuchTableError(SQLAlchemyError): pass
-
-# metaclass is necessary to expose class methods with getattr, e.g.
-# we want to pass db.users.select through to users._mapper.select
-class TableClassType(type):
- def insert(cls, **kwargs):
- o = cls()
- o.__dict__.update(kwargs)
- return o
- def __getattr__(cls, attr):
- if attr == '_mapper':
- # called during mapper init
- raise AttributeError()
- return getattr(cls._mapper, attr)
-
-def class_for_table(table):
- klass = TableClassType('Class_' + table.name.capitalize(), (object,), {})
- def __repr__(self):
- import locale
- encoding = locale.getdefaultlocale()[1]
- L = []
- for k in self.__class__.c.keys():
- value = getattr(self, k, '')
- if isinstance(value, unicode):
- value = value.encode(encoding)
- L.append("%s=%r" % (k, value))
- return '%s(%s)' % (self.__class__.__name__, ','.join(L))
- klass.__repr__ = __repr__
- klass._mapper = mapper(klass, table)
- return klass
-
-class SqlSoup:
- def __init__(self, *args, **kwargs):
- """
- args may either be an SQLEngine or a set of arguments suitable
- for passing to create_engine
- """
- from sqlalchemy.engine import SQLEngine
- # meh, sometimes having method overloading instead of kwargs would be easier
- if isinstance(args[0], SQLEngine):
- args = list(args)
- engine = args.pop(0)
- if args or kwargs:
- raise ArgumentError('Extra arguments not allowed when engine is given')
- else:
- engine = create_engine(*args, **kwargs)
- self._engine = engine
- self._cache = {}
- def delete(self, *args, **kwargs):
- objectstore.delete(*args, **kwargs)
- def commit(self):
- objectstore.get_session().commit()
- def rollback(self):
- objectstore.clear()
- def _reset(self):
- # for debugging
- self._cache = {}
- self.rollback()
- def __getattr__(self, attr):
- try:
- t = self._cache[attr]
- except KeyError:
- table = Table(attr, self._engine, autoload=True)
- if table.columns:
- t = class_for_table(table)
- else:
- t = None
- self._cache[attr] = t
- if not t:
- raise NoSuchTableError('%s does not exist' % attr)
- return t
-
-if __name__ == '__main__':
- import doctest
- doctest.testmod()
+from sqlalchemy import *
+
+class NoSuchTableError(SQLAlchemyError): pass
+
+# metaclass is necessary to expose class methods with getattr, e.g.
+# we want to pass db.users.select through to users._mapper.select
+class TableClassType(type):
+ def insert(cls, **kwargs):
+ o = cls()
+ o.__dict__.update(kwargs)
+ return o
+ def __getattr__(cls, attr):
+ if attr == '_mapper':
+ # called during mapper init
+ raise AttributeError()
+ return getattr(cls._mapper, attr)
+
+def class_for_table(table):
+ klass = TableClassType('Class_' + table.name.capitalize(), (object,), {})
+ def __repr__(self):
+ import locale
+ encoding = locale.getdefaultlocale()[1]
+ L = []
+ for k in self.__class__.c.keys():
+ value = getattr(self, k, '')
+ if isinstance(value, unicode):
+ value = value.encode(encoding)
+ L.append("%s=%r" % (k, value))
+ return '%s(%s)' % (self.__class__.__name__, ','.join(L))
+ klass.__repr__ = __repr__
+ klass._mapper = mapper(klass, table)
+ return klass
+
+class SqlSoup:
+ def __init__(self, *args, **kwargs):
+ """
+ args may either be an SQLEngine or a set of arguments suitable
+ for passing to create_engine
+ """
+ from sqlalchemy.sql import Engine
+ # meh, sometimes having method overloading instead of kwargs would be easier
+ if isinstance(args[0], Engine):
+ engine = args.pop(0)
+ if args or kwargs:
+ raise ArgumentError('Extra arguments not allowed when engine is given')
+ else:
+ engine = create_engine(*args, **kwargs)
+ self._engine = engine
+ self._cache = {}
+ def delete(self, *args, **kwargs):
+ objectstore.delete(*args, **kwargs)
+ def commit(self):
+ objectstore.get_session().commit()
+ def rollback(self):
+ objectstore.clear()
+ def _reset(self):
+ # for debugging
+ self._cache = {}
+ self.rollback()
+ def __getattr__(self, attr):
+ try:
+ t = self._cache[attr]
+ except KeyError:
+ table = Table(attr, self._engine, autoload=True)
+ if table.columns:
+ t = class_for_table(table)
+ else:
+ t = None
+ self._cache[attr] = t
+ if not t:
+ raise NoSuchTableError('%s does not exist' % attr)
+ return t
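
SqlSoup combines two __getattr__ hooks: TableClassType forwards unknown class-level attributes to the class's mapper, and SqlSoup.__getattr__ builds and caches one generated class per table name. A compact sketch of those two hooks, with a plain dict standing in for the mapper and no database involved:

    class DelegatingType(type):
        def __getattr__(cls, attr):
            if attr == '_delegate':            # mirrors the '_mapper' guard above
                raise AttributeError(attr)
            return cls._delegate[attr]

    def class_for_name(name):
        delegate = {'select': lambda: "select * from %s" % name}
        return DelegatingType('Class_' + name.capitalize(), (object,),
                              {'_delegate': delegate})

    class Gateway(object):
        def __init__(self):
            self._cache = {}
        def __getattr__(self, attr):
            try:
                return self._cache[attr]
            except KeyError:
                cls = self._cache[attr] = class_for_name(attr)
                return cls

    db = Gateway()
    print(db.users.select())       # -> select * from users
    print(db.users is db.users)    # -> True, the generated class is cached
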