| field | value |
|---|---|
| author | Mike Bayer <mike_mp@zzzcomputing.com>, 2006-05-25 14:20:23 +0000 |
| committer | Mike Bayer <mike_mp@zzzcomputing.com>, 2006-05-25 14:20:23 +0000 |
| commit | bb79e2e871d0a4585164c1a6ed626d96d0231975 (patch) |
| tree | 6d457ba6c36c408b45db24ec3c29e147fe7504ff /lib |
| parent | 4fc3a0648699c2b441251ba4e1d37a9107bd1986 (diff) |
| download | sqlalchemy-bb79e2e871d0a4585164c1a6ed626d96d0231975.tar.gz |
merged 0.2 branch into trunk; 0.1 now in sqlalchemy/branches/rel_0_1
Diffstat (limited to 'lib')
46 files changed, 4685 insertions(+), 3451 deletions(-)
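The dominant change running through the diff below is the 0.1-to-0.2 engine rework: each database's monolithic SQLEngine subclass becomes a Dialect plugged into a shared ComposedSQLEngine, and connection parameters arrive as a single URL handed to the dialect's create_connect_args() rather than as an options dict. A rough before/after sketch; the URL and the exact 0.1 call form are illustrative assumptions, not taken from this diff:

    import sqlalchemy

    # 0.1 style (approximate, from memory): driver name plus an option dict
    db = sqlalchemy.create_engine('postgres', {'host': 'localhost', 'database': 'test'})

    # 0.2 style per this merge: one URL; the dialect's create_connect_args()
    # turns it into DBAPI connect() arguments
    db = sqlalchemy.create_engine('postgres://scott:tiger@localhost/test')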
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py
index 94a0fcb6d..acbacafa4 100644
--- a/lib/sqlalchemy/__init__.py
+++ b/lib/sqlalchemy/__init__.py
@@ -4,20 +4,14 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-from engine import *
 from types import *
 from sql import *
 from schema import *
-from exceptions import *
-import sqlalchemy.sql
-import sqlalchemy.mapping as mapping
-from sqlalchemy.mapping import *
-import sqlalchemy.schema
-import sqlalchemy.ext.proxy
-sqlalchemy.schema.default_engine = sqlalchemy.ext.proxy.ProxyEngine()
+from sqlalchemy.orm import *
 
-from sqlalchemy.mods import install_mods
+from sqlalchemy.engine import create_engine
+from sqlalchemy.schema import default_metadata
 
 def global_connect(*args, **kwargs):
-    sqlalchemy.schema.default_engine.connect(*args, **kwargs)
+    default_metadata.connect(*args, **kwargs)
\ No newline at end of file
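The __init__.py change above retires the implicit module-level engine (sqlalchemy.schema.default_engine, a ProxyEngine) in favor of a default MetaData, and re-exports create_engine at the package level. A minimal sketch of the resulting entry points, assuming the 0.2 API as implied here (the sqlite URL is only an example):

    import sqlalchemy

    # global_connect() now simply forwards to default_metadata.connect()
    sqlalchemy.global_connect('sqlite://')

    # explicit engines come from sqlalchemy.engine.create_engine,
    # re-exported as sqlalchemy.create_engine
    engine = sqlalchemy.create_engine('sqlite://')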
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index df3f8fa59..6956c5379 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -4,18 +4,14 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-"""defines ANSI SQL operations."""
+"""defines ANSI SQL operations. Contains default implementations for the abstract objects
+in the sql module."""
 
-import sqlalchemy.schema as schema
-
-from sqlalchemy.schema import *
-import sqlalchemy.sql as sql
-import sqlalchemy.engine
-from sqlalchemy.sql import *
-from sqlalchemy.util import *
+from sqlalchemy import schema, sql, engine, util
+import sqlalchemy.engine.default as default
 import string, re
 
-ANSI_FUNCS = HashSet([
+ANSI_FUNCS = util.HashSet([
     'CURRENT_TIME',
     'CURRENT_TIMESTAMP',
     'CURRENT_DATE',
@@ -27,32 +23,32 @@ ANSI_FUNCS = HashSet([
     ])
 
 
-def engine(**params):
-    return ANSISQLEngine(**params)
-
-class ANSISQLEngine(sqlalchemy.engine.SQLEngine):
-
-    def schemagenerator(self, **params):
-        return ANSISchemaGenerator(self, **params)
-
-    def schemadropper(self, **params):
-        return ANSISchemaDropper(self, **params)
-
-    def compiler(self, statement, parameters, **kwargs):
-        return ANSICompiler(statement, parameters, engine=self, **kwargs)
+def create_engine():
+    return engine.ComposedSQLEngine(None, ANSIDialect())
 
+class ANSIDialect(default.DefaultDialect):
     def connect_args(self):
         return ([],{})
 
     def dbapi(self):
         return None
 
+    def schemagenerator(self, *args, **params):
+        return ANSISchemaGenerator(*args, **params)
+
+    def schemadropper(self, *args, **params):
+        return ANSISchemaDropper(*args, **params)
+
+    def compiler(self, statement, parameters, **kwargs):
+        return ANSICompiler(self, statement, parameters, **kwargs)
+
+
 class ANSICompiler(sql.Compiled):
     """default implementation of Compiled, which compiles ClauseElements into ANSI-compliant
     SQL strings."""
-    def __init__(self, statement, parameters=None, typemap=None, engine=None, positional=None, paramstyle=None, **kwargs):
+    def __init__(self, dialect, statement, parameters=None, **kwargs):
        """constructs a new ANSICompiler object.
- engine - SQLEngine to compile against + dialect - Dialect to be used statement - ClauseElement to be compiled @@ -61,22 +57,18 @@ class ANSICompiler(sql.Compiled): key/value pairs when the Compiled is executed, and also may affect the actual compilation, as in the case of an INSERT where the actual columns inserted will correspond to the keys present in the parameters.""" - sql.Compiled.__init__(self, statement, parameters, engine=engine) + sql.Compiled.__init__(self, dialect, statement, parameters, **kwargs) self.binds = {} self.froms = {} self.wheres = {} self.strings = {} self.select_stack = [] - self.typemap = typemap or {} + self.typemap = {} self.isinsert = False self.isupdate = False self.bindtemplate = ":%s" - if engine is not None: - self.paramstyle = engine.paramstyle - self.positional = engine.positional - else: - self.positional = False - self.paramstyle = 'named' + self.paramstyle = dialect.paramstyle + self.positional = dialect.positional def after_compile(self): # this re will search for params like :param @@ -130,7 +122,7 @@ class ANSICompiler(sql.Compiled): bindparams = {} bindparams.update(params) - d = sql.ClauseParameters(self.engine) + d = sql.ClauseParameters(self.dialect) if self.positional: for k in self.positiontup: b = self.binds[k] @@ -177,10 +169,19 @@ class ANSICompiler(sql.Compiled): # if we are within a visit to a Select, set up the "typemap" # for this column which is used to translate result set values self.typemap.setdefault(column.key.lower(), column.type) - if column.table is None or column.table.name is None: + if column.table is None or not column.table.named_with_column(): self.strings[column] = column.name else: - self.strings[column] = "%s.%s" % (column.table.name, column.name) + if column.table.oid_column is column: + n = self.dialect.oid_column_name() + if n is not None: + self.strings[column] = "%s.%s" % (column.table.name, n) + elif len(column.table.primary_key) != 0: + self.strings[column] = "%s.%s" % (column.table.name, column.table.primary_key[0].name) + else: + self.strings[column] = None + else: + self.strings[column] = "%s.%s" % (column.table.name, column.name) def visit_fromclause(self, fromclause): @@ -190,7 +191,7 @@ class ANSICompiler(sql.Compiled): self.strings[index] = index.name def visit_typeclause(self, typeclause): - self.strings[typeclause] = typeclause.type.engine_impl(self.engine).get_col_spec() + self.strings[typeclause] = typeclause.type.dialect_impl(self.dialect).get_col_spec() def visit_textclause(self, textclause): if textclause.parens and len(textclause.text): @@ -218,9 +219,9 @@ class ANSICompiler(sql.Compiled): def visit_clauselist(self, list): if list.parens: - self.strings[list] = "(" + string.join([self.get_str(c) for c in list.clauses], ', ') + ")" + self.strings[list] = "(" + string.join([s for s in [self.get_str(c) for c in list.clauses] if s is not None], ', ') + ")" else: - self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ') + self.strings[list] = string.join([s for s in [self.get_str(c) for c in list.clauses] if s is not None], ', ') def apply_function_parens(self, func): return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0 @@ -294,7 +295,7 @@ class ANSICompiler(sql.Compiled): # the actual list of columns to print in the SELECT column list. # its an ordered dictionary to insure that the actual labeled column name # is unique. 
- inner_columns = OrderedDict() + inner_columns = util.OrderedDict() self.select_stack.append(select) for c in select._raw_columns: @@ -314,7 +315,7 @@ class ANSICompiler(sql.Compiled): # SQLite doesnt like selecting from a subquery where the column # names look like table.colname, so add a label synonomous with # the column name - l = co.label(co.text) + l = co.label(co.name) l.accept_visitor(self) inner_columns[self.get_str(l.obj)] = l else: @@ -385,7 +386,7 @@ class ANSICompiler(sql.Compiled): order_by = self.get_str(select.order_by_clause) if order_by: text += " ORDER BY " + order_by - + text += self.visit_select_postclauses(select) if select.for_update: @@ -545,7 +546,7 @@ class ANSICompiler(sql.Compiled): # case one: no parameters in the statement, no parameters in the # compiled params - just return binds for all the table columns if self.parameters is None and stmt.parameters is None: - return [(c, bindparam(c.name, type=c.type)) for c in stmt.table.columns] + return [(c, sql.bindparam(c.name, type=c.type)) for c in stmt.table.columns] # if we have statement parameters - set defaults in the # compiled params @@ -578,7 +579,7 @@ class ANSICompiler(sql.Compiled): if d.has_key(c): value = d[c] if sql._is_literal(value): - value = bindparam(c.name, value, type=c.type) + value = sql.bindparam(c.name, value, type=c.type) values.append((c, value)) return values @@ -594,7 +595,7 @@ class ANSICompiler(sql.Compiled): return self.get_str(self.statement) -class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator): +class ANSISchemaGenerator(engine.SchemaIterator): def get_column_specification(self, column, override_pk=False, first_pk=False): raise NotImplementedError() @@ -631,10 +632,15 @@ class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator): if isinstance(column.default.arg, str): return repr(column.default.arg) else: - return str(column.default.arg.compile(self.engine)) + return str(self._compile(column.default.arg, None)) else: return None + def _compile(self, tocompile, parameters): + compiler = self.engine.dialect.compiler(tocompile, parameters) + compiler.compile() + return compiler + def visit_column(self, column): pass @@ -648,7 +654,7 @@ class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator): self.execute() -class ANSISchemaDropper(sqlalchemy.engine.SchemaIterator): +class ANSISchemaDropper(engine.SchemaIterator): def visit_index(self, index): self.append("\nDROP INDEX " + index.name) self.execute() @@ -660,5 +666,5 @@ class ANSISchemaDropper(sqlalchemy.engine.SchemaIterator): self.execute() -class ANSIDefaultRunner(sqlalchemy.engine.DefaultRunner): +class ANSIDefaultRunner(engine.DefaultRunner): pass diff --git a/lib/sqlalchemy/attributes.py b/lib/sqlalchemy/attributes.py index 627cac4b6..c285ea50c 100644 --- a/lib/sqlalchemy/attributes.py +++ b/lib/sqlalchemy/attributes.py @@ -37,11 +37,12 @@ class SmartProperty(object): create_prop method on AttributeManger, which can be overridden to provide subclasses of SmartProperty. 
""" - def __init__(self, manager, key, uselist, callable_, **kwargs): + def __init__(self, manager, key, uselist, callable_, typecallable, **kwargs): self.manager = manager self.key = key self.uselist = uselist self.callable_ = callable_ + self.typecallable= typecallable self.kwargs = kwargs def init(self, obj, attrhist=None): """creates an appropriate ManagedAttribute for the given object and establishes @@ -50,7 +51,7 @@ class SmartProperty(object): func = self.callable_(obj) else: func = None - return self.manager.create_managed_attribute(obj, self.key, self.uselist, callable_=func, attrdict=attrhist, **self.kwargs) + return self.manager.create_managed_attribute(obj, self.key, self.uselist, callable_=func, attrdict=attrhist, typecallable=self.typecallable, **self.kwargs) def __set__(self, obj, value): self.manager.set_attribute(obj, self.key, value) def __delete__(self, obj): @@ -86,11 +87,19 @@ class ManagedAttribute(object): self.key = d['key'] self.__obj = weakref.ref(d['obj']) obj = property(lambda s:s.__obj()) + def value_changed(self, *args, **kwargs): + self.obj._managed_value_changed = True + self.do_value_changed(*args, **kwargs) def history(self, **kwargs): return self def plain_init(self, *args, **kwargs): pass - + def hasparent(self, item): + return item.__class__._attribute_manager.attribute_history(item).get('_hasparent_' + self.key) + def sethasparent(self, item, value): + if item is not None: + item.__class__._attribute_manager.attribute_history(item)['_hasparent_' + self.key] = value + class ScalarAttribute(ManagedAttribute): """Used by AttributeManager to track the history of a scalar attribute on an object instance. This is the "scalar history container" object. @@ -98,10 +107,11 @@ class ScalarAttribute(ManagedAttribute): so that the two objects can be called upon largely interchangeably.""" # make our own NONE to distinguish from "None" NONE = object() - def __init__(self, obj, key, extension=None, **kwargs): + def __init__(self, obj, key, extension=None, trackparent=False, **kwargs): ManagedAttribute.__init__(self, obj, key) self.orig = ScalarAttribute.NONE self.extension = extension + self.trackparent = trackparent def clear(self): del self.obj.__dict__[self.key] def history_contains(self, obj): @@ -121,15 +131,24 @@ class ScalarAttribute(ManagedAttribute): if self.orig is ScalarAttribute.NONE: self.orig = orig self.obj.__dict__[self.key] = value + if self.trackparent: + if value is not None: + self.sethasparent(value, True) + if orig is not None: + self.sethasparent(orig, False) if self.extension is not None: self.extension.set(self.obj, value, orig) + self.value_changed(orig, value) def delattr(self, **kwargs): orig = self.obj.__dict__.get(self.key, None) if self.orig is ScalarAttribute.NONE: self.orig = orig self.obj.__dict__[self.key] = None + if self.trackparent: + self.sethasparent(orig, False) if self.extension is not None: self.extension.set(self.obj, None, orig) + self.value_changed(orig, None) def append(self, obj): self.setattr(obj) def remove(self, obj): @@ -140,9 +159,11 @@ class ScalarAttribute(ManagedAttribute): self.orig = ScalarAttribute.NONE def commit(self): self.orig = ScalarAttribute.NONE + def do_value_changed(self, oldvalue, newvalue): + pass def added_items(self): if self.orig is not ScalarAttribute.NONE: - return [self.obj.__dict__[self.key]] + return [self.obj.__dict__.get(self.key)] else: return [] def deleted_items(self): @@ -152,7 +173,7 @@ class ScalarAttribute(ManagedAttribute): return [] def unchanged_items(self): if self.orig is 
ScalarAttribute.NONE: - return [self.obj.__dict__[self.key]] + return [self.obj.__dict__.get(self.key)] else: return [] @@ -161,9 +182,10 @@ class ListAttribute(util.HistoryArraySet, ManagedAttribute): This is the "list history container" object. Subclasses util.HistoryArraySet to provide "onchange" event handling as well as a plugin point for BackrefExtension objects.""" - def __init__(self, obj, key, data=None, extension=None, **kwargs): + def __init__(self, obj, key, data=None, extension=None, trackparent=False, typecallable=None, **kwargs): ManagedAttribute.__init__(self, obj, key) self.extension = extension + self.trackparent = trackparent # if we are given a list, try to behave nicely with an existing # list that might be set on the object already try: @@ -176,36 +198,32 @@ class ListAttribute(util.HistoryArraySet, ManagedAttribute): except KeyError: if data is not None: list_ = data + elif typecallable is not None: + list_ = typecallable() else: list_ = [] - obj.__dict__[key] = [] - + obj.__dict__[key] = list_ util.HistoryArraySet.__init__(self, list_, readonly=kwargs.get('readonly', False)) - def list_value_changed(self, obj, key, item, listval, isdelete): + def do_value_changed(self, obj, key, item, listval, isdelete): pass def setattr(self, value, **kwargs): self.obj.__dict__[self.key] = value self.set_data(value) def delattr(self, value, **kwargs): pass - def _setrecord(self, item): - res = util.HistoryArraySet._setrecord(self, item) - if res: - self.list_value_changed(self.obj, self.key, item, self, False) - if self.extension is not None: - self.extension.append(self.obj, item) - return res - def _delrecord(self, item): - res = util.HistoryArraySet._delrecord(self, item) - if res: - self.list_value_changed(self.obj, self.key, item, self, True) - if self.extension is not None: - self.extension.delete(self.obj, item) - return res + def do_value_appended(self, item): + if self.trackparent: + self.sethasparent(item, True) + self.value_changed(self.obj, self.key, item, self, False) + if self.extension is not None: + self.extension.append(self.obj, item) + def do_value_deleted(self, item): + if self.trackparent: + self.sethasparent(item, False) + self.value_changed(self.obj, self.key, item, self, True) + if self.extension is not None: + self.extension.delete(self.obj, item) -# deprecated -class ListElement(ListAttribute):pass - class TriggeredAttribute(ManagedAttribute): """Used by AttributeManager to allow the attaching of a callable item, representing the future value of a particular attribute on a particular object instance, as the current attribute on an object. 
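The trackparent flag threaded through ScalarAttribute and ListAttribute above drives one small piece of bookkeeping: whenever an item is attached to or detached from a parent, a '_hasparent_<key>' entry is written into that item's attribute-history dictionary, which the ORM can later consult for orphan detection. The core of the mechanism restated outside the ManagedAttribute machinery; the plain dict stands in for attribute_history(item):

    def sethasparent(history, key, value):
        # mirrors ManagedAttribute.sethasparent() in the diff above
        history['_hasparent_' + key] = value

    def hasparent(history, key):
        # mirrors ManagedAttribute.hasparent()
        return history.get('_hasparent_' + key)

    h = {}
    sethasparent(h, 'addresses', True)    # as done by do_value_appended()
    assert hasparent(h, 'addresses') is True
    sethasparent(h, 'addresses', False)   # as done by do_value_deleted()
    assert hasparent(h, 'addresses') is False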
@@ -225,7 +243,7 @@ class TriggeredAttribute(ManagedAttribute): def plain_init(self, attrhist): if not self.uselist: - p = ScalarAttribute(self.obj, self.key, **self.kwargs) + p = self.manager.create_scalar(self.obj, self.key, **self.kwargs) self.obj.__dict__[self.key] = None else: p = self.manager.create_list(self.obj, self.key, None, **self.kwargs) @@ -251,7 +269,7 @@ class TriggeredAttribute(ManagedAttribute): raise AssertionError("AttributeError caught in callable prop:" + str(e.args)) self.obj.__dict__[self.key] = value - p = ScalarAttribute(self.obj, self.key, **self.kwargs) + p = self.manager.create_scalar(self.obj, self.key, **self.kwargs) else: if not self.obj.__dict__.has_key(self.key) or len(self.obj.__dict__[self.key]) == 0: if passive: @@ -315,20 +333,21 @@ class AttributeManager(object): def __init__(self): pass - def value_changed(self, obj, key, value): + def do_value_changed(self, obj, key, value): """subclasses override this method to provide functionality that is triggered upon an attribute change of value.""" pass - def create_prop(self, class_, key, uselist, callable_, **kwargs): + def create_prop(self, class_, key, uselist, callable_, typecallable, **kwargs): """creates a scalar property object, defaulting to SmartProperty, which will communicate change events back to this AttributeManager.""" - return SmartProperty(self, key, uselist, callable_, **kwargs) - - def create_list(self, obj, key, list_, **kwargs): + return SmartProperty(self, key, uselist, callable_, typecallable, **kwargs) + def create_scalar(self, obj, key, **kwargs): + return ScalarAttribute(obj, key, **kwargs) + def create_list(self, obj, key, list_, typecallable=None, **kwargs): """creates a history-aware list property, defaulting to a ListAttribute which is a subclass of HistoryArrayList.""" - return ListAttribute(obj, key, list_, **kwargs) + return ListAttribute(obj, key, list_, typecallable=typecallable, **kwargs) def create_callable(self, obj, key, func, uselist, **kwargs): """creates a callable container that will invoke a function the first time an object property is accessed. 
The return value of the function @@ -352,12 +371,10 @@ class AttributeManager(object): def set_attribute(self, obj, key, value, **kwargs): """sets the value of an object's attribute.""" self.get_unexec_history(obj, key).setattr(value, **kwargs) - self.value_changed(obj, key, value) def delete_attribute(self, obj, key, **kwargs): """deletes the value from an object's attribute.""" self.get_unexec_history(obj, key).delattr(**kwargs) - self.value_changed(obj, key, None) def rollback(self, *obj): """rolls back all attribute changes on the given list of objects, @@ -366,9 +383,11 @@ class AttributeManager(object): try: attributes = self.attribute_history(o) for hist in attributes.values(): - hist.rollback() + if isinstance(hist, ManagedAttribute): + hist.rollback() except KeyError: pass + o._managed_value_changed = False def commit(self, *obj): """commits all attribute changes on the given list of objects, @@ -377,10 +396,15 @@ class AttributeManager(object): try: attributes = self.attribute_history(o) for hist in attributes.values(): - hist.commit() + if isinstance(hist, ManagedAttribute): + hist.commit() except KeyError: pass - + o._managed_value_changed = False + + def is_modified(self, object): + return getattr(object, '_managed_value_changed', False) + def remove(self, obj): """called when an object is totally being removed from memory""" # currently a no-op since the state of the object is attached to the object itself @@ -471,15 +495,15 @@ class AttributeManager(object): def is_class_managed(self, class_, key): return hasattr(class_, key) and isinstance(getattr(class_, key), SmartProperty) - def create_managed_attribute(self, obj, key, uselist, callable_=None, attrdict=None, **kwargs): + def create_managed_attribute(self, obj, key, uselist, callable_=None, attrdict=None, typecallable=None, **kwargs): """creates a new ManagedAttribute corresponding to the given attribute key on the given object instance, and installs it in the attribute dictionary attached to the object.""" if callable_ is not None: - prop = self.create_callable(obj, key, callable_, uselist=uselist, **kwargs) + prop = self.create_callable(obj, key, callable_, uselist=uselist, typecallable=typecallable, **kwargs) elif not uselist: - prop = ScalarAttribute(obj, key, **kwargs) + prop = self.create_scalar(obj, key, **kwargs) else: - prop = self.create_list(obj, key, None, **kwargs) + prop = self.create_list(obj, key, None, typecallable=typecallable, **kwargs) if attrdict is None: attrdict = self.attribute_history(obj) attrdict[key] = prop @@ -500,4 +524,9 @@ class AttributeManager(object): will be passed along to newly created ManagedAttribute.""" if not hasattr(class_, '_attribute_manager'): class_._attribute_manager = self - setattr(class_, key, self.create_prop(class_, key, uselist, callable_, **kwargs)) + typecallable = getattr(class_, key, None) + # TODO: look at existing properties on the class, and adapt them to the SmartProperty + if isinstance(typecallable, SmartProperty): + typecallable = None + setattr(class_, key, self.create_prop(class_, key, uselist, callable_, typecallable=typecallable, **kwargs)) + diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index 99ef9eb9f..7dc48a54a 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -5,7 +5,7 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php -import sys, StringIO, string, datetime +import sys, StringIO, string import sqlalchemy.sql as sql import sqlalchemy.schema 
as schema @@ -30,16 +30,6 @@ class FBSmallInteger(sqltypes.Smallinteger): class FBDateTime(sqltypes.DateTime): def get_col_spec(self): return "DATE" - def convert_bind_param(self, value, engine): - if value is not None: - if isinstance(value, datetime.datetime): - seconds = float(str(value.second) + "." - + str(value.microsecond)) - return kinterbasdb.date_conv_out((value.year, value.month, value.day, - value.hour, value.minute, seconds)) - return kinterbasdb.timestamp_conv_in(value) - else: - return None class FBText(sqltypes.TEXT): def get_col_spec(self): return "BLOB SUB_TYPE 2" @@ -84,12 +74,13 @@ def descriptor(): ]} class FBSQLEngine(ansisql.ANSISQLEngine): - def __init__(self, opts, module=None, use_oids=False, **params): + def __init__(self, opts, use_ansi = True, module = None, **params): + self._use_ansi = use_ansi + self.opts = opts or {} if module is None: self.module = kinterbasdb else: self.module = module - self.opts = self._translate_connect_args(('host', 'database', 'user', 'password'), opts) ansisql.ANSISQLEngine.__init__(self, **params) def do_commit(self, connection): @@ -111,7 +102,7 @@ class FBSQLEngine(ansisql.ANSISQLEngine): return self.context.last_inserted_ids def compiler(self, statement, bindparams, **kwargs): - return FBCompiler(statement, bindparams, engine=self, **kwargs) + return FBCompiler(statement, bindparams, engine=self, use_ansi=self._use_ansi, **kwargs) def schemagenerator(self, **params): return FBSchemaGenerator(self, **params) @@ -197,6 +188,21 @@ class FBSQLEngine(ansisql.ANSISQLEngine): class FBCompiler(ansisql.ANSICompiler): """firebird compiler modifies the lexical structure of Select statements to work under non-ANSI configured Firebird databases, if the use_ansi flag is False.""" + + def __init__(self, engine, statement, parameters, use_ansi = True, **kwargs): + self._outertable = None + self._use_ansi = use_ansi + ansisql.ANSICompiler.__init__(self, engine, statement, parameters, **kwargs) + + def visit_column(self, column): + if self._use_ansi: + return ansisql.ANSICompiler.visit_column(self, column) + + if column.table is self._outertable: + self.strings[column] = "%s.%s(+)" % (column.table.name, column.name) + else: + self.strings[column] = "%s.%s" % (column.table.name, column.name) + def visit_function(self, func): if len(func.clauses): super(FBCompiler, self).visit_function(func) @@ -217,11 +223,10 @@ class FBCompiler(ansisql.ANSICompiler): """ called when building a SELECT statment, position is just before column list Firebird puts the limit and offset right after the select...thanks for adding the visit_select_precolumns!!!""" - result = '' if select.offset: - result +=" FIRST %s " % select.offset + result +=" FIRST " + select.offset if select.limit: - result += " SKIP %s " % select.limit + result += " SKIP " + select.limit if select.distinct: result += " DISTINCT " return result @@ -229,8 +234,6 @@ class FBCompiler(ansisql.ANSICompiler): def limit_clause(self, select): """Already taken care of in the visit_select_precolumns method.""" return "" - def default_from(self): - return ' from RDB$DATABASE ' class FBSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, **kwargs): diff --git a/lib/sqlalchemy/databases/information_schema.py b/lib/sqlalchemy/databases/information_schema.py index 468e9a548..55c522558 100644 --- a/lib/sqlalchemy/databases/information_schema.py +++ b/lib/sqlalchemy/databases/information_schema.py @@ -7,22 +7,22 @@ from sqlalchemy.exceptions import * from 
sqlalchemy import * from sqlalchemy.ansisql import * -generic_engine = ansisql.engine() +ischema = MetaData() -gen_schemata = schema.Table("schemata", generic_engine, +schemata = schema.Table("schemata", ischema, Column("catalog_name", String), Column("schema_name", String), Column("schema_owner", String), schema="information_schema") -gen_tables = schema.Table("tables", generic_engine, +tables = schema.Table("tables", ischema, Column("table_catalog", String), Column("table_schema", String), Column("table_name", String), Column("table_type", String), schema="information_schema") -gen_columns = schema.Table("columns", generic_engine, +columns = schema.Table("columns", ischema, Column("table_schema", String), Column("table_name", String), Column("column_name", String), @@ -35,28 +35,40 @@ gen_columns = schema.Table("columns", generic_engine, Column("column_default", Integer), schema="information_schema") -gen_constraints = schema.Table("table_constraints", generic_engine, +constraints = schema.Table("table_constraints", ischema, Column("table_schema", String), Column("table_name", String), Column("constraint_name", String), Column("constraint_type", String), schema="information_schema") -gen_column_constraints = schema.Table("constraint_column_usage", generic_engine, +column_constraints = schema.Table("constraint_column_usage", ischema, Column("table_schema", String), Column("table_name", String), Column("column_name", String), Column("constraint_name", String), schema="information_schema") -gen_key_constraints = schema.Table("key_column_usage", generic_engine, +pg_key_constraints = schema.Table("key_column_usage", ischema, Column("table_schema", String), Column("table_name", String), Column("column_name", String), Column("constraint_name", String), schema="information_schema") -gen_ref_constraints = schema.Table("referential_constraints", generic_engine, +mysql_key_constraints = schema.Table("key_column_usage", ischema, + Column("table_schema", String), + Column("table_name", String), + Column("column_name", String), + Column("constraint_name", String), + Column("referenced_table_schema", String), + Column("referenced_table_name", String), + Column("referenced_column_name", String), + schema="information_schema") + +key_constraints = pg_key_constraints + +ref_constraints = schema.Table("referential_constraints", ischema, Column("constraint_catalog", String), Column("constraint_schema", String), Column("constraint_name", String), @@ -88,37 +100,25 @@ class ISchema(object): return self.cache[name] -def reflecttable(engine, table, ischema_names, use_mysql=False): - columns = gen_columns.toengine(engine) - constraints = gen_constraints.toengine(engine) +def reflecttable(connection, table, ischema_names, use_mysql=False): if use_mysql: # no idea which INFORMATION_SCHEMA spec is correct, mysql or postgres - key_constraints = schema.Table("key_column_usage", engine, - Column("table_schema", String), - Column("table_name", String), - Column("column_name", String), - Column("constraint_name", String), - Column("referenced_table_schema", String), - Column("referenced_table_name", String), - Column("referenced_column_name", String), - schema="information_schema", useexisting=True) + key_constraints = mysql_key_constraints else: - column_constraints = gen_column_constraints.toengine(engine) - key_constraints = gen_key_constraints.toengine(engine) - - + key_constraints = pg_key_constraints + if table.schema is not None: current_schema = table.schema else: - current_schema = 
engine.get_default_schema_name() + current_schema = connection.default_schema_name() s = select([columns], sql.and_(columns.c.table_name==table.name, columns.c.table_schema==current_schema), order_by=[columns.c.ordinal_position]) - c = s.execute() + c = connection.execute(s) while True: row = c.fetchone() if row is None: @@ -160,7 +160,7 @@ def reflecttable(engine, table, ischema_names, use_mysql=False): s.append_whereclause(constraints.c.table_name==table.name) s.append_whereclause(constraints.c.table_schema==current_schema) colmap = [constraints.c.constraint_type, key_constraints.c.column_name, key_constraints.c.referenced_table_schema, key_constraints.c.referenced_table_name, key_constraints.c.referenced_column_name] - c = s.execute() + c = connection.execute(s) while True: row = c.fetchone() @@ -178,6 +178,8 @@ def reflecttable(engine, table, ischema_names, use_mysql=False): if type=='PRIMARY KEY': table.c[constrained_column]._set_primary_key() elif type=='FOREIGN KEY': - remotetable = Table(referred_table, engine, autoload = True, schema=referred_schema) + if current_schema == referred_schema: + referred_schema = table.schema + remotetable = Table(referred_table, table.metadata, autoload=True, autoload_with=connection, schema=referred_schema) table.c[constrained_column].append_item(schema.ForeignKey(remotetable.c[referred_column])) diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 6a7ef91b3..a8124537a 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -455,7 +455,7 @@ class MSSQLCompiler(ansisql.ANSICompiler): super(MSSQLCompiler, self).visit_column(column) if column.table is not None and self.tablealiases.has_key(column.table): self.strings[column] = \ - self.strings[self.tablealiases[column.table]._get_col_by_original(column.original)] + self.strings[self.tablealiases[column.table].corresponding_column(column.original)] class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator): diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 60435f220..0a480ec11 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -6,14 +6,11 @@ import sys, StringIO, string, types, re, datetime -import sqlalchemy.sql as sql -import sqlalchemy.engine as engine -import sqlalchemy.schema as schema -import sqlalchemy.ansisql as ansisql +from sqlalchemy import sql,engine,schema,ansisql +from sqlalchemy.engine import default import sqlalchemy.types as sqltypes -from sqlalchemy import * import sqlalchemy.databases.information_schema as ischema -from sqlalchemy.exceptions import * +import sqlalchemy.exceptions as exceptions try: import MySQLdb as mysql @@ -26,7 +23,7 @@ class MSNumeric(sqltypes.Numeric): class MSDouble(sqltypes.Numeric): def __init__(self, precision = None, length = None): if (precision is None and length is not None) or (precision is not None and length is None): - raise ArgumentError("You must specify both precision and length or omit both altogether.") + raise exceptions.ArgumentError("You must specify both precision and length or omit both altogether.") super(MSDouble, self).__init__(precision, length) def get_col_spec(self): if self.precision is not None and self.length is not None: @@ -56,7 +53,7 @@ class MSDate(sqltypes.Date): class MSTime(sqltypes.Time): def get_col_spec(self): return "TIME" - def convert_result_value(self, value, engine): + def convert_result_value(self, value, dialect): # convert from a timedelta value if value is not None: 
return datetime.time(value.seconds/60/60, value.seconds/60%60, value.seconds - (value.seconds/60*60)) @@ -129,72 +126,66 @@ def descriptor(): return {'name':'mysql', 'description':'MySQL', 'arguments':[ - ('user',"Database Username",None), - ('passwd',"Database Password",None), - ('db',"Database Name",None), + ('username',"Database Username",None), + ('password',"Database Password",None), + ('database',"Database Name",None), ('host',"Hostname", None), ]} -class MySQLEngine(ansisql.ANSISQLEngine): - def __init__(self, opts, module = None, **params): + +class MySQLExecutionContext(default.DefaultExecutionContext): + def post_exec(self, engine, proxy, compiled, parameters, **kwargs): + if getattr(compiled, "isinsert", False): + self._last_inserted_ids = [proxy().lastrowid] + +class MySQLDialect(ansisql.ANSIDialect): + def __init__(self, module = None, **kwargs): if module is None: self.module = mysql - self.opts = self._translate_connect_args(('host', 'db', 'user', 'passwd'), opts) - ansisql.ANSISQLEngine.__init__(self, **params) + ansisql.ANSIDialect.__init__(self, **kwargs) + + def create_connect_args(self, url): + opts = url.translate_connect_args(['host', 'db', 'user', 'passwd', 'port']) + return [[], opts] - def connect_args(self): - return [[], self.opts] + def create_execution_context(self): + return MySQLExecutionContext(self) def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) - def last_inserted_ids(self): - return self.context.last_inserted_ids def supports_sane_rowcount(self): return False def compiler(self, statement, bindparams, **kwargs): - return MySQLCompiler(statement, bindparams, engine=self, **kwargs) + return MySQLCompiler(self, statement, bindparams, **kwargs) - def schemagenerator(self, **params): - return MySQLSchemaGenerator(self, **params) + def schemagenerator(self, *args, **kwargs): + return MySQLSchemaGenerator(*args, **kwargs) - def schemadropper(self, **params): - return MySQLSchemaDropper(self, **params) + def schemadropper(self, *args, **kwargs): + return MySQLSchemaDropper(*args, **kwargs) def get_default_schema_name(self): if not hasattr(self, '_default_schema_name'): self._default_schema_name = text("select database()", self).scalar() return self._default_schema_name - - def last_inserted_ids(self): - return self.context.last_inserted_ids - - def post_exec(self, proxy, compiled, parameters, **kwargs): - if getattr(compiled, "isinsert", False): - self.context.last_inserted_ids = [proxy().lastrowid] - # executemany just runs normally, since we arent using rowcount at all with mysql -# def _executemany(self, c, statement, parameters): - # """we need accurate rowcounts for updates, inserts and deletes. 
mysql is *also* is not nice enough - # to produce this correctly for an executemany, so we do our own executemany here.""" - # rowcount = 0 - # for param in parameters: - # c.execute(statement, param) - # rowcount += c.rowcount - # self.context.rowcount = rowcount - def dbapi(self): return self.module - def reflecttable(self, table): + def has_table(self, connection, table_name): + cursor = connection.execute("show table status like '" + table_name + "'") + return bool( not not cursor.rowcount ) + + def reflecttable(self, connection, table): # to use information_schema: #ischema.reflecttable(self, table, ischema_names, use_mysql=True) - tabletype, foreignkeyD = self.moretableinfo(table=table) + tabletype, foreignkeyD = self.moretableinfo(connection, table=table) table.kwargs['mysql_engine'] = tabletype - c = self.execute("describe " + table.name, {}) + c = connection.execute("describe " + table.name, {}) while True: row = c.fetchone() if row is None: @@ -224,7 +215,7 @@ class MySQLEngine(ansisql.ANSISQLEngine): default=default ))) - def moretableinfo(self, table): + def moretableinfo(self, connection, table): """Return (tabletype, {colname:foreignkey,...}) execute(SHOW CREATE TABLE child) => CREATE TABLE `child` ( @@ -233,7 +224,7 @@ class MySQLEngine(ansisql.ANSISQLEngine): KEY `par_ind` (`parent_id`), CONSTRAINT `child_ibfk_1` FOREIGN KEY (`parent_id`) REFERENCES `parent` (`id`) ON DELETE CASCADE\n) TYPE=InnoDB """ - c = self.execute("SHOW CREATE TABLE " + table.name, {}) + c = connection.execute("SHOW CREATE TABLE " + table.name, {}) desc = c.fetchone()[1].strip() tabletype = '' lastparen = re.search(r'\)[^\)]*\Z', desc) @@ -277,7 +268,7 @@ class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): if column.primary_key: if not override_pk: colspec += " PRIMARY KEY" - if not column.foreign_key and first_pk and isinstance(column.type, types.Integer): + if not column.foreign_key and first_pk and isinstance(column.type, sqltypes.Integer): colspec += " AUTO_INCREMENT" if column.foreign_key: colspec += ", FOREIGN KEY (%s) REFERENCES %s(%s)" % (column.name, column.foreign_key.column.table.name, column.foreign_key.column.name) @@ -294,3 +285,5 @@ class MySQLSchemaDropper(ansisql.ANSISchemaDropper): def visit_index(self, index): self.append("\nDROP INDEX " + index.name + " ON " + index.table.name) self.execute() + +dialect = MySQLDialect
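With this merge each database module ends by exporting a module-level dialect class ("dialect = MySQLDialect" just above), and URL handling moves onto the dialect through create_connect_args(). A sketch of that contract; the make_url helper and the example URL are assumptions about the new 0.2 engine package, not something shown in this diff:

    from sqlalchemy.databases.mysql import MySQLDialect
    from sqlalchemy.engine import url            # assumed 0.2 module path

    u = url.make_url('mysql://scott:tiger@localhost/test')
    args, kwargs = MySQLDialect().create_connect_args(u)
    # given translate_connect_args(['host', 'db', 'user', 'passwd', 'port']):
    #   args   -> []
    #   kwargs -> {'host': 'localhost', 'db': 'test',
    #              'user': 'scott', 'passwd': 'tiger'}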
\ No newline at end of file diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 16c6cb218..b27f87dd0 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -7,11 +7,14 @@ import sys, StringIO, string +import sqlalchemy.util as util import sqlalchemy.sql as sql +import sqlalchemy.engine as engine +import sqlalchemy.engine.default as default import sqlalchemy.schema as schema import sqlalchemy.ansisql as ansisql -from sqlalchemy import * import sqlalchemy.types as sqltypes +import sqlalchemy.exceptions as exceptions try: import cx_Oracle @@ -93,8 +96,6 @@ AND ac.r_constraint_name = rem.constraint_name(+) -- order multiple primary keys correctly ORDER BY ac.constraint_name, loc.position""" -def engine(*args, **params): - return OracleSQLEngine(*args, **params) def descriptor(): return {'name':'oracle', @@ -104,45 +105,53 @@ def descriptor(): ('user', 'Username', None), ('password', 'Password', None) ]} + +class OracleExecutionContext(default.DefaultExecutionContext): + pass -class OracleSQLEngine(ansisql.ANSISQLEngine): - def __init__(self, opts, use_ansi = True, module = None, threaded=False, **params): - self._use_ansi = use_ansi - self.opts = self._translate_connect_args((None, 'dsn', 'user', 'password'), opts) - self.opts['threaded'] = threaded +class OracleDialect(ansisql.ANSIDialect): + def __init__(self, use_ansi=True, module=None, threaded=True, **kwargs): + self.use_ansi = use_ansi + self.threaded = threaded if module is None: self.module = cx_Oracle else: self.module = module - ansisql.ANSISQLEngine.__init__(self, **params) + ansisql.ANSIDialect.__init__(self, **kwargs) def dbapi(self): return self.module - def connect_args(self): - return [[], self.opts] + def create_connect_args(self, url): + opts = url.translate_connect_args([None, 'dsn', 'user', 'password']) + opts['threaded'] = self.threaded + return ([], opts) def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) - def last_inserted_ids(self): - return self.context.last_inserted_ids - def oid_column_name(self): return "rowid" + def create_execution_context(self): + return OracleExecutionContext(self) + def compiler(self, statement, bindparams, **kwargs): - return OracleCompiler(self, statement, bindparams, use_ansi=self._use_ansi, **kwargs) - - def schemagenerator(self, **params): - return OracleSchemaGenerator(self, **params) - def schemadropper(self, **params): - return OracleSchemaDropper(self, **params) - def defaultrunner(self, proxy): - return OracleDefaultRunner(self, proxy) + return OracleCompiler(self, statement, bindparams, **kwargs) + def schemagenerator(self, *args, **kwargs): + return OracleSchemaGenerator(*args, **kwargs) + def schemadropper(self, *args, **kwargs): + return OracleSchemaDropper(*args, **kwargs) + def defaultrunner(self, engine, proxy): + return OracleDefaultRunner(engine, proxy) + + + def has_table(self, connection, table_name): + cursor = connection.execute("""select table_name from all_tables where table_name=:name""", {'name':table_name.upper()}) + return bool( cursor.fetchone() is not None ) - def reflecttable(self, table): - c = self.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS where TABLE_NAME = :table_name", {'table_name':table.name.upper()}) + def reflecttable(self, connection, table): + c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT 
from ALL_TAB_COLUMNS where TABLE_NAME = :table_name", {'table_name':table.name.upper()}) while True: row = c.fetchone() @@ -171,14 +180,14 @@ class OracleSQLEngine(ansisql.ANSISQLEngine): colargs = [] if default is not None: - colargs.append(PassiveDefault(sql.text(default))) + colargs.append(schema.PassiveDefault(sql.text(default))) name = name.lower() table.append_item (schema.Column(name, coltype, nullable=nullable, *colargs)) - c = self.execute(constraintSQL, {'table_name' : table.name.upper()}) + c = connection.execute(constraintSQL, {'table_name' : table.name.upper()}) while True: row = c.fetchone() if row is None: @@ -189,34 +198,24 @@ class OracleSQLEngine(ansisql.ANSISQLEngine): table.c[local_column]._set_primary_key() elif cons_type == 'R': table.c[local_column].append_item( - schema.ForeignKey(Table(remote_table, - self, + schema.ForeignKey(schema.Table(remote_table, + table.metadata, autoload=True).c[remote_column] ) ) - def last_inserted_ids(self): - return self.context.last_inserted_ids - - def pre_exec(self, proxy, compiled, parameters, **kwargs): - pass - - def _executemany(self, c, statement, parameters): + def do_executemany(self, c, statement, parameters, context=None): rowcount = 0 for param in parameters: c.execute(statement, param) rowcount += c.rowcount - self.context.rowcount = rowcount + if context is not None: + context._rowcount = rowcount class OracleCompiler(ansisql.ANSICompiler): """oracle compiler modifies the lexical structure of Select statements to work under non-ANSI configured Oracle databases, if the use_ansi flag is False.""" - def __init__(self, engine, statement, parameters, use_ansi = True, **kwargs): - self._outertable = None - self._use_ansi = use_ansi - ansisql.ANSICompiler.__init__(self, statement, parameters, engine=engine, **kwargs) - def default_from(self): """called when a SELECT statement has no froms, and no FROM clause is to be appended. gives Oracle a chance to tack on a "FROM DUAL" to the string output. """ @@ -226,7 +225,7 @@ class OracleCompiler(ansisql.ANSICompiler): return len(func.clauses) > 0 def visit_join(self, join): - if self._use_ansi: + if self.dialect.use_ansi: return ansisql.ANSICompiler.visit_join(self, join) self.froms[join] = self.get_from_text(join.left) + ", " + self.get_from_text(join.right) @@ -251,7 +250,7 @@ class OracleCompiler(ansisql.ANSICompiler): def visit_column(self, column): ansisql.ANSICompiler.visit_column(self, column) - if not self._use_ansi and self._outertable is not None and column.table is self._outertable: + if not self.dialect.use_ansi and getattr(self, '_outertable', None) is not None and column.table is self._outertable: self.strings[column] = self.strings[column] + "(+)" def visit_insert(self, insert): @@ -275,12 +274,15 @@ class OracleCompiler(ansisql.ANSICompiler): self.strings[select.order_by_clause] = "" ansisql.ANSICompiler.visit_select(self, select) return + if select.limit is not None or select.offset is not None: select._oracle_visit = True # to use ROW_NUMBER(), an ORDER BY is required. 
orderby = self.strings[select.order_by_clause] if not orderby: orderby = select.oid_column + orderby.accept_visitor(self) + orderby = self.strings[orderby] select.append_column(sql.ColumnClause("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")) limitselect = sql.select([c for c in select.c if c.key!='ora_rn']) if select.offset is not None: @@ -330,3 +332,5 @@ class OracleDefaultRunner(ansisql.ANSIDefaultRunner): def visit_sequence(self, seq): return self.proxy("SELECT " + seq.name + ".nextval FROM DUAL").fetchone()[0] + +dialect = OracleDialect diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index a92cb340d..b6917c035 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -9,11 +9,11 @@ import datetime, sys, StringIO, string, types, re import sqlalchemy.util as util import sqlalchemy.sql as sql import sqlalchemy.engine as engine +import sqlalchemy.engine.default as default import sqlalchemy.schema as schema import sqlalchemy.ansisql as ansisql import sqlalchemy.types as sqltypes -from sqlalchemy.exceptions import * -from sqlalchemy import * +import sqlalchemy.exceptions as exceptions import information_schema as ischema try: @@ -47,7 +47,7 @@ class PG2DateTime(sqltypes.DateTime): def get_col_spec(self): return "TIMESTAMP" class PG1DateTime(sqltypes.DateTime): - def convert_bind_param(self, value, engine): + def convert_bind_param(self, value, dialect): if value is not None: if isinstance(value, datetime.datetime): seconds = float(str(value.second) + "." @@ -59,7 +59,7 @@ class PG1DateTime(sqltypes.DateTime): return psycopg.TimestampFromMx(value) else: return None - def convert_result_value(self, value, engine): + def convert_result_value(self, value, dialect): if value is None: return None second_parts = str(value.second).split(".") @@ -68,21 +68,20 @@ class PG1DateTime(sqltypes.DateTime): return datetime.datetime(value.year, value.month, value.day, value.hour, value.minute, seconds, microseconds) - def get_col_spec(self): return "TIMESTAMP" class PG2Date(sqltypes.Date): def get_col_spec(self): return "DATE" class PG1Date(sqltypes.Date): - def convert_bind_param(self, value, engine): + def convert_bind_param(self, value, dialect): # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime # this one doesnt seem to work with the "emulation" mode if value is not None: return psycopg.DateFromMx(value) else: return None - def convert_result_value(self, value, engine): + def convert_result_value(self, value, dialect): # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime return value def get_col_spec(self): @@ -91,14 +90,14 @@ class PG2Time(sqltypes.Time): def get_col_spec(self): return "TIME" class PG1Time(sqltypes.Time): - def convert_bind_param(self, value, engine): + def convert_bind_param(self, value, dialect): # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime # this one doesnt seem to work with the "emulation" mode if value is not None: return psycopg.TimeFromMx(value) else: return None - def convert_result_value(self, value, engine): + def convert_result_value(self, value, dialect): # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime return value def get_col_spec(self): @@ -175,18 +174,35 @@ def descriptor(): return {'name':'postgres', 'description':'PostGres', 'arguments':[ - ('user',"Database Username",None), + ('username',"Database Username",None), ('password',"Database 
Password",None), ('database',"Database Name",None), ('host',"Hostname", None), ]} -class PGSQLEngine(ansisql.ANSISQLEngine): - def __init__(self, opts, module=None, use_oids=False, **params): +class PGExecutionContext(default.DefaultExecutionContext): + + def post_exec(self, engine, proxy, compiled, parameters, **kwargs): + if getattr(compiled, "isinsert", False) and self.last_inserted_ids is None: + if not engine.dialect.use_oids: + pass + # will raise invalid error when they go to get them + else: + table = compiled.statement.table + cursor = proxy() + if cursor.lastrowid is not None and table is not None and len(table.primary_key): + s = sql.select(table.primary_key, table.oid_column == cursor.lastrowid) + c = s.compile(engine=engine) + cursor = proxy(str(c), c.get_params()) + row = cursor.fetchone() + self._last_inserted_ids = [v for v in row] + +class PGDialect(ansisql.ANSIDialect): + def __init__(self, module=None, use_oids=False, **params): self.use_oids = use_oids if module is None: - if psycopg is None: - raise ArgumentError("Couldnt locate psycopg1 or psycopg2: specify postgres module argument") + #if psycopg is None: + # raise exceptions.ArgumentError("Couldnt locate psycopg1 or psycopg2: specify postgres module argument") self.module = psycopg else: self.module = module @@ -198,17 +214,19 @@ class PGSQLEngine(ansisql.ANSISQLEngine): self.version = 1 except: self.version = 1 - self.opts = self._translate_connect_args(('host', 'database', 'user', 'password'), opts) - if self.opts.has_key('port'): + ansisql.ANSIDialect.__init__(self, **params) + + def create_connect_args(self, url): + opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port']) + if opts.has_key('port'): if self.version == 2: - self.opts['port'] = int(self.opts['port']) + opts['port'] = int(opts['port']) else: - self.opts['port'] = str(self.opts['port']) - - ansisql.ANSISQLEngine.__init__(self, **params) - - def connect_args(self): - return [[], self.opts] + opts['port'] = str(opts['port']) + return ([], opts) + + def create_execution_context(self): + return PGExecutionContext(self) def type_descriptor(self, typeobj): if self.version == 2: @@ -217,25 +235,22 @@ class PGSQLEngine(ansisql.ANSISQLEngine): return sqltypes.adapt_type(typeobj, pg1_colspecs) def compiler(self, statement, bindparams, **kwargs): - return PGCompiler(statement, bindparams, engine=self, **kwargs) - - def schemagenerator(self, **params): - return PGSchemaGenerator(self, **params) - - def schemadropper(self, **params): - return PGSchemaDropper(self, **params) - - def defaultrunner(self, proxy=None): - return PGDefaultRunner(self, proxy) + return PGCompiler(self, statement, bindparams, **kwargs) + def schemagenerator(self, *args, **kwargs): + return PGSchemaGenerator(*args, **kwargs) + def schemadropper(self, *args, **kwargs): + return PGSchemaDropper(*args, **kwargs) + def defaultrunner(self, engine, proxy): + return PGDefaultRunner(engine, proxy) - def get_default_schema_name(self): + def get_default_schema_name(self, connection): if not hasattr(self, '_default_schema_name'): - self._default_schema_name = text("select current_schema()", self).scalar() + self._default_schema_name = connection.scalar("select current_schema()", None) return self._default_schema_name def last_inserted_ids(self): if self.context.last_inserted_ids is None: - raise InvalidRequestError("no INSERT executed, or cant use cursor.lastrowid without Postgres OIDs enabled") + raise exceptions.InvalidRequestError("no INSERT executed, or cant use 
cursor.lastrowid without Postgres OIDs enabled") else: return self.context.last_inserted_ids @@ -245,51 +260,32 @@ class PGSQLEngine(ansisql.ANSISQLEngine): else: return None - def pre_exec(self, proxy, statement, parameters, **kwargs): - return - - def post_exec(self, proxy, compiled, parameters, **kwargs): - if getattr(compiled, "isinsert", False) and self.context.last_inserted_ids is None: - if not self.use_oids: - pass - # will raise invalid error when they go to get them - else: - table = compiled.statement.table - cursor = proxy() - if cursor.lastrowid is not None and table is not None and len(table.primary_key): - s = sql.select(table.primary_key, table.oid_column == cursor.lastrowid) - c = s.compile() - cursor = proxy(str(c), c.get_params()) - row = cursor.fetchone() - self.context.last_inserted_ids = [v for v in row] - - def _executemany(self, c, statement, parameters): + def do_executemany(self, c, statement, parameters, context=None): """we need accurate rowcounts for updates, inserts and deletes. psycopg2 is not nice enough to produce this correctly for an executemany, so we do our own executemany here.""" rowcount = 0 for param in parameters: - try: - c.execute(statement, param) - except Exception, e: - raise exceptions.SQLError(statement, param, e) + c.execute(statement, param) rowcount += c.rowcount - self.context.rowcount = rowcount + if context is not None: + context._rowcount = rowcount def dbapi(self): return self.module - def reflecttable(self, table): + def has_table(self, connection, table_name): + cursor = connection.execute("""select relname from pg_class where lower(relname) = %(name)s""", {'name':table_name.lower()}) + return bool( not not cursor.rowcount ) + + def reflecttable(self, connection, table): if self.version == 2: ischema_names = pg2_ischema_names else: ischema_names = pg1_ischema_names - # give ischema the given table's engine with which to look up - # other tables, not 'self', since it could be a ProxyEngine - ischema.reflecttable(table.engine, table, ischema_names) + ischema.reflecttable(connection, table, ischema_names) class PGCompiler(ansisql.ANSICompiler): - def visit_insert_column(self, column, parameters): # Postgres advises against OID usage and turns it off in 8.1, @@ -322,7 +318,7 @@ class PGCompiler(ansisql.ANSICompiler): return "DISTINCT ON (" + str(select.distinct) + ") " else: return "" - + def binary_operator_string(self, binary): if isinstance(binary.type, sqltypes.String) and binary.operator == '+': return '||' @@ -333,7 +329,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, **kwargs): colspec = column.name - if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): + if column.primary_key and isinstance(column.type, sqltypes.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): colspec += " SERIAL" else: colspec += " " + column.type.engine_impl(self.engine).get_col_spec() @@ -367,7 +363,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): if isinstance(column.default, schema.PassiveDefault): c = self.proxy("select %s" % column.default.arg) return c.fetchone()[0] - elif isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): + elif isinstance(column.type, sqltypes.Integer) and (column.default is None 
or (isinstance(column.default, schema.Sequence) and column.default.optional)): sch = column.table.schema if sch is not None: exc = "select nextval('%s.%s_%s_seq')" % (sch, column.table.name, column.name) @@ -386,3 +382,5 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): return c.fetchone()[0] else: return None + +dialect = PGDialect diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index a7536ee4e..4d9f562ae 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -7,13 +7,9 @@ import sys, StringIO, string, types, re -import sqlalchemy.sql as sql -import sqlalchemy.engine as engine -import sqlalchemy.schema as schema -import sqlalchemy.ansisql as ansisql +from sqlalchemy import sql, engine, schema, ansisql, exceptions, pool +import sqlalchemy.engine.default as default import sqlalchemy.types as sqltypes -from sqlalchemy.exceptions import * -from sqlalchemy.ansisql import * import datetime,time pysqlite2_timesupport = False # Change this if the init.d guys ever get around to supporting time cols @@ -38,12 +34,12 @@ class SLSmallInteger(sqltypes.Smallinteger): class SLDateTime(sqltypes.DateTime): def get_col_spec(self): return "TIMESTAMP" - def convert_bind_param(self, value, engine): + def convert_bind_param(self, value, dialect): if value is not None: return str(value) else: return None - def _cvt(self, value, engine, fmt): + def _cvt(self, value, dialect, fmt): if value is None: return None parts = value.split('.') @@ -53,20 +49,20 @@ class SLDateTime(sqltypes.DateTime): except ValueError: (value, microsecond) = (value, 0) return time.strptime(value, fmt)[0:6] + (microsecond,) - def convert_result_value(self, value, engine): - tup = self._cvt(value, engine, "%Y-%m-%d %H:%M:%S") + def convert_result_value(self, value, dialect): + tup = self._cvt(value, dialect, "%Y-%m-%d %H:%M:%S") return tup and datetime.datetime(*tup) class SLDate(SLDateTime): def get_col_spec(self): return "DATE" - def convert_result_value(self, value, engine): - tup = self._cvt(value, engine, "%Y-%m-%d") + def convert_result_value(self, value, dialect): + tup = self._cvt(value, dialect, "%Y-%m-%d") return tup and datetime.date(*tup[0:3]) class SLTime(SLDateTime): def get_col_spec(self): return "TIME" - def convert_result_value(self, value, engine): - tup = self._cvt(value, engine, "%H:%M:%S") + def convert_result_value(self, value, dialect): + tup = self._cvt(value, dialect, "%H:%M:%S") return tup and datetime.time(*tup[4:7]) class SLText(sqltypes.TEXT): def get_col_spec(self): @@ -115,33 +111,32 @@ pragma_names = { if pysqlite2_timesupport: colspecs.update({sqltypes.Time : SLTime}) pragma_names.update({'TIME' : SLTime}) - -def engine(opts, **params): - return SQLiteSQLEngine(opts, **params) def descriptor(): return {'name':'sqlite', 'description':'SQLite', 'arguments':[ - ('filename', "Database Filename",None) + ('database', "Database Filename",None) ]} - -class SQLiteSQLEngine(ansisql.ANSISQLEngine): - def __init__(self, opts, **params): - if sqlite is None: - raise ArgumentError("Couldn't import sqlite or pysqlite2") - self.filename = opts.pop('filename', ':memory:') - self.opts = opts or {} - params['poolclass'] = sqlalchemy.pool.SingletonThreadPool - ansisql.ANSISQLEngine.__init__(self, **params) - def post_exec(self, proxy, compiled, parameters, **kwargs): - if getattr(compiled, "isinsert", False): - self.context.last_inserted_ids = [proxy().lastrowid] +class SQLiteExecutionContext(default.DefaultExecutionContext): + def post_exec(self, 
engine, proxy, compiled, parameters, **kwargs): + if getattr(compiled, "isinsert", False): + self._last_inserted_ids = [proxy().lastrowid] + +class SQLiteDialect(ansisql.ANSIDialect): + def compiler(self, statement, bindparams, **kwargs): + return SQLiteCompiler(self, statement, bindparams, **kwargs) + def schemagenerator(self, *args, **kwargs): + return SQLiteSchemaGenerator(*args, **kwargs) + def create_connect_args(self, url): + filename = url.database or ':memory:' + return ([filename], {}) def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) - + def create_execution_context(self): + return SQLiteExecutionContext(self) def last_inserted_ids(self): return self.context.last_inserted_ids @@ -151,20 +146,21 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine): def connect_args(self): return ([self.filename], self.opts) - def compiler(self, statement, bindparams, **kwargs): - return SQLiteCompiler(statement, bindparams, engine=self, **kwargs) - def dbapi(self): + if sqlite is None: + raise ArgumentError("Couldn't import sqlite or pysqlite2") return sqlite def push_session(self): raise InvalidRequestError("SQLite doesn't support nested sessions") - def schemagenerator(self, **params): - return SQLiteSchemaGenerator(self, **params) + def has_table(self, connection, table_name): + cursor = connection.execute("PRAGMA table_info(" + table_name + ")", {}) + row = cursor.fetchone() + return (row is not None) - def reflecttable(self, table): - c = self.execute("PRAGMA table_info(" + table.name + ")", {}) + def reflecttable(self, connection, table): + c = connection.execute("PRAGMA table_info(" + table.name + ")", {}) while True: row = c.fetchone() if row is None: @@ -183,7 +179,7 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine): #print "args! " +repr(args) coltype = coltype(*[int(a) for a in args]) table.append_item(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable)) - c = self.execute("PRAGMA foreign_key_list(" + table.name + ")", {}) + c = connection.execute("PRAGMA foreign_key_list(" + table.name + ")", {}) while True: row = c.fetchone() if row is None: @@ -192,10 +188,10 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine): #print "row! 
" + repr(row) # look up the table based on the given table's engine, not 'self', # since it could be a ProxyEngine - remotetable = Table(tablename, table.engine, autoload = True) + remotetable = schema.Table(tablename, table.metadata, autoload=True, autoload_with=connection) table.c[localcol].append_item(schema.ForeignKey(remotetable.c[remotecol])) # check for UNIQUE indexes - c = self.execute("PRAGMA index_list(" + table.name + ")", {}) + c = connection.execute("PRAGMA index_list(" + table.name + ")", {}) unique_indexes = [] while True: row = c.fetchone() @@ -205,7 +201,7 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine): unique_indexes.append(row[1]) # loop thru unique indexes for one that includes the primary key for idx in unique_indexes: - c = self.execute("PRAGMA index_info(" + idx + ")", {}) + c = connection.execute("PRAGMA index_info(" + idx + ")", {}) cols = [] while True: row = c.fetchone() @@ -219,9 +215,6 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine): table.columns[col]._set_primary_key() class SQLiteCompiler(ansisql.ANSICompiler): - def __init__(self, *args, **params): - params.setdefault('paramstyle', 'named') - ansisql.ANSICompiler.__init__(self, *args, **params) def limit_clause(self, select): text = "" if select.limit is not None: @@ -238,7 +231,7 @@ class SQLiteCompiler(ansisql.ANSICompiler): return '||' else: return ansisql.ANSICompiler.binary_operator_string(self, binary) - + class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, **kwargs): colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec() @@ -277,4 +270,5 @@ class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): for index in table.indexes: self.visit_index(index) - +dialect = SQLiteDialect +poolclass = pool.SingletonThreadPool diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py deleted file mode 100644 index f1bb76057..000000000 --- a/lib/sqlalchemy/engine.py +++ /dev/null @@ -1,878 +0,0 @@ -# engine.py -# Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php - -"""Defines the SQLEngine class, which serves as the primary "database" object -used throughout the sql construction and object-relational mapper packages. -A SQLEngine is a facade around a single connection pool corresponding to a -particular set of connection parameters, and provides thread-local transactional -methods and statement execution methods for Connection objects. It also provides -a facade around a Cursor object to allow richer column selection for result rows -as well as type conversion operations, known as a ResultProxy. - -A SQLEngine is provided to an application as a subclass that is specific to a particular type -of DBAPI, and is the central switching point for abstracting different kinds of database -behavior into a consistent set of behaviors. It provides a variety of factory methods -to produce everything specific to a certain kind of database, including a Compiler, -schema creation/dropping objects. - -The term "database-specific" will be used to describe any object or function that has behavior -corresponding to a particular vendor, such as mysql-specific, sqlite-specific, etc. 
-""" - -import sqlalchemy.pool -import schema -import exceptions -import util -import sql -import sqlalchemy.databases -import sqlalchemy.types as types -import StringIO, sys, re -from cgi import parse_qsl - -__all__ = ['create_engine', 'engine_descriptors'] - -def create_engine(name, opts=None,**kwargs): - """creates a new SQLEngine instance. There are two forms of calling this method. - - In the first, the "name" argument is the type of engine to load, i.e. 'sqlite', 'postgres', - 'oracle', 'mysql'. "opts" is a dictionary of options to be sent to the underlying DBAPI module - to create a connection, usually including a hostname, username, password, etc. - - In the second, the "name" argument is a URL in the form <enginename>://opt1=val1&opt2=val2. - Where <enginename> is the name as above, and the contents of the option dictionary are - spelled out as a URL encoded string. The "opts" argument is not used. - - In both cases, **kwargs represents options to be sent to the SQLEngine itself. A possibly - partial listing of those options is as follows: - - pool=None : an instance of sqlalchemy.pool.DBProxy or sqlalchemy.pool.Pool to be used as the - underlying source for connections (DBProxy/Pool is described in the previous section). If None, - a default DBProxy will be created using the engine's own database module with the given - arguments. - - echo=False : if True, the SQLEngine will log all statements as well as a repr() of their - parameter lists to the engines logger, which defaults to sys.stdout. A SQLEngine instances' - "echo" data member can be modified at any time to turn logging on and off. If set to the string - 'debug', result rows will be printed to the standard output as well. - - logger=None : a file-like object where logging output can be sent, if echo is set to True. - This defaults to sys.stdout. - - module=None : used by Oracle and Postgres, this is a reference to a DBAPI2 module to be used - instead of the engine's default module. For Postgres, the default is psycopg2, or psycopg1 if - 2 cannot be found. For Oracle, its cx_Oracle. For mysql, MySQLdb. - - use_ansi=True : used only by Oracle; when False, the Oracle driver attempts to support a - particular "quirk" of some Oracle databases, that the LEFT OUTER JOIN SQL syntax is not - supported, and the "Oracle join" syntax of using <column1>(+)=<column2> must be used - in order to achieve a LEFT OUTER JOIN. Its advised that the Oracle database be configured to - have full ANSI support instead of using this feature. - - """ - m = re.match(r'(\w+)://(.*)', name) - if m is not None: - (name, args) = m.group(1, 2) - opts = dict( parse_qsl( args ) ) - module = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name) - return module.engine(opts, **kwargs) - -def engine_descriptors(): - """provides a listing of all the database implementations supported. this data - is provided as a list of dictionaries, where each dictionary contains the following - key/value pairs: - - name : the name of the engine, suitable for use in the create_engine function - - description: a plain description of the engine. - - arguments : a dictionary describing the name and description of each parameter - used to connect to this engine's underlying DBAPI. - - This function is meant for usage in automated configuration tools that wish to - query the user for database and connection information. 
- """ - result = [] - for module in sqlalchemy.databases.__all__: - module = getattr(__import__('sqlalchemy.databases.%s' % module).databases, module) - result.append(module.descriptor()) - return result - -class SchemaIterator(schema.SchemaVisitor): - """a visitor that can gather text into a buffer and execute the contents of the buffer.""" - def __init__(self, engine, **params): - """initializes this SchemaIterator and initializes its buffer. - - sqlproxy - a callable function returned by SQLEngine.proxy(), which executes a - statement plus optional parameters. - """ - self.engine = engine - self.buffer = StringIO.StringIO() - - def append(self, s): - """appends content to the SchemaIterator's query buffer.""" - self.buffer.write(s) - - def execute(self): - """executes the contents of the SchemaIterator's buffer using its sql proxy and - clears out the buffer.""" - try: - return self.engine.execute(self.buffer.getvalue(), None) - finally: - self.buffer.truncate(0) - -class DefaultRunner(schema.SchemaVisitor): - def __init__(self, engine, proxy): - self.proxy = proxy - self.engine = engine - - def get_column_default(self, column): - if column.default is not None: - return column.default.accept_schema_visitor(self) - else: - return None - - def get_column_onupdate(self, column): - if column.onupdate is not None: - return column.onupdate.accept_schema_visitor(self) - else: - return None - - def visit_passive_default(self, default): - """passive defaults by definition return None on the app side, - and are post-fetched to get the DB-side value""" - return None - - def visit_sequence(self, seq): - """sequences are not supported by default""" - return None - - def exec_default_sql(self, default): - c = sql.select([default.arg], engine=self.engine).compile() - return self.proxy(str(c), c.get_params()).fetchone()[0] - - def visit_column_onupdate(self, onupdate): - if isinstance(onupdate.arg, sql.ClauseElement): - return self.exec_default_sql(onupdate) - elif callable(onupdate.arg): - return onupdate.arg() - else: - return onupdate.arg - - def visit_column_default(self, default): - if isinstance(default.arg, sql.ClauseElement): - return self.exec_default_sql(default) - elif callable(default.arg): - return default.arg() - else: - return default.arg - -class SQLSession(object): - """represents a a handle to the SQLEngine's connection pool. the default SQLSession maintains a distinct connection during transactions, otherwise returns connections newly retrieved from the pool each time. the Pool is usually configured to have use_threadlocal=True so if a particular connection is already checked out, youll get that same connection in the same thread. There can also be a "unique" SQLSession pushed onto the engine, which returns a connection via the unique_connection() method on Pool; this allows nested transactions to take place, or other operations upon more than one connection at a time.`""" - def __init__(self, engine, parent=None): - self.engine = engine - self.parent = parent - # if we have a parent SQLSession, then use a unique connection. - # else we use the default connection returned by the pool. 
- if parent is not None: - self.__connection = self.engine._pool.unique_connection() - self.__tcount = 0 - def pop(self): - self.engine.pop_session(self) - def _connection(self): - try: - return self.__transaction - except AttributeError: - try: - return self.__connection - except AttributeError: - return self.engine._pool.connect() - connection = property(_connection, doc="the connection represented by this SQLSession. The connection is late-connecting, meaning the call to the connection pool only occurs when it is first called (and the pool will typically only connect the first time it is called as well)") - - def begin(self): - """begins a transaction on this SQLSession's connection. repeated calls to begin() will increment a counter that must be decreased by corresponding commit() statements before an actual commit occurs. this is to provide "nested" behavior of transactions so that different functions in a particular call stack can call begin()/commit() independently of each other without knowledge of an existing transaction. """ - if self.__tcount == 0: - self.__transaction = self.connection - self.engine.do_begin(self.connection) - self.__tcount += 1 - def rollback(self): - """rolls back the transaction on this SQLSession's connection. this can be called regardless of the "begin" counter value, i.e. can be called from anywhere inside a callstack. the "begin" counter is cleared.""" - if self.__tcount > 0: - try: - self.engine.do_rollback(self.connection) - finally: - del self.__transaction - self.__tcount = 0 - def commit(self): - """commits the transaction started by begin(). If begin() was called multiple times, a counter will be decreased for each call to commit(), with the actual commit operation occuring when the counter reaches zero. this is to provide "nested" behavior of transactions so that different functions in a particular call stack can call begin()/commit() independently of each other without knowledge of an existing transaction.""" - if self.__tcount == 1: - try: - self.engine.do_commit(self.connection) - finally: - del self.__transaction - self.__tcount = 0 - elif self.__tcount > 1: - self.__tcount -= 1 - def is_begun(self): - return self.__tcount > 0 - -class SQLEngine(schema.SchemaEngine): - """ - The central "database" object used by an application. Subclasses of this object is used - by the schema and SQL construction packages to provide database-specific behaviors, - as well as an execution and thread-local transaction context. - - SQLEngines are constructed via the create_engine() function inside this package. - """ - - def __init__(self, pool=None, echo=False, logger=None, default_ordering=False, echo_pool=False, echo_uow=False, convert_unicode=False, encoding='utf-8', **params): - """constructs a new SQLEngine. 
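The counter-based begin()/commit() nesting documented above can be pictured with a short sketch (assumes an engine constructed elsewhere; only the outermost commit() issues a real COMMIT, and rollback() clears the counter from any depth):

    def outer(engine):
        engine.begin()         # counter 0 -> 1, real BEGIN issued
        try:
            inner(engine)
            engine.commit()    # counter 1 -> 0, real COMMIT issued
        except:
            engine.rollback()  # rolls back regardless of counter depth
            raise

    def inner(engine):
        engine.begin()         # counter 1 -> 2, no database-level effect
        engine.execute("update t set x=1", {})
        engine.commit()        # counter 2 -> 1, no COMMIT yet
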
SQLEngines should be constructed via the create_engine() - function which will construct the appropriate subclass of SQLEngine.""" - # get a handle on the connection pool via the connect arguments - # this insures the SQLEngine instance integrates with the pool referenced - # by direct usage of pool.manager(<module>).connect(*args, **params) - schema.SchemaEngine.__init__(self) - (cargs, cparams) = self.connect_args() - if pool is None: - params['echo'] = echo_pool - params['use_threadlocal'] = True - self._pool = sqlalchemy.pool.manage(self.dbapi(), **params).get_pool(*cargs, **cparams) - elif isinstance(pool, sqlalchemy.pool.DBProxy): - self._pool = pool.get_pool(*cargs, **cparams) - else: - self._pool = pool - self.default_ordering=default_ordering - self.echo = echo - self.echo_uow = echo_uow - self.convert_unicode = convert_unicode - self.encoding = encoding - self.context = util.ThreadLocal() - self._ischema = None - self._figure_paramstyle() - self.logger = logger or util.Logger(origin='engine') - - def _translate_connect_args(self, names, args): - """translates a dictionary of connection arguments to those used by a specific dbapi. - the names parameter is a tuple of argument names in the form ('host', 'database', 'user', 'password') - where the given strings match the corresponding argument names for the dbapi. Will return a dictionary - with the dbapi-specific parameters, the generic ones removed, and any additional parameters still remaining, - from the dictionary represented by args. Will return a blank dictionary if args is null.""" - if args is None: - return {} - a = args.copy() - standard_names = [('host','hostname'), ('database', 'dbname'), ('user', 'username'), ('password', 'passwd', 'pw')] - for n in names: - sname = standard_names.pop(0) - if n is None: - continue - for sn in sname: - if sn != n and a.has_key(sn): - a[n] = a[sn] - del a[sn] - return a - def _get_ischema(self): - # We use a property for ischema so that the accessor - # creation only happens as needed, since otherwise we - # have a circularity problem with the generic - # ansisql.engine() - if self._ischema is None: - import sqlalchemy.databases.information_schema as ischema - self._ischema = ischema.ISchema(self) - return self._ischema - ischema = property(_get_ischema, doc="""returns an ISchema object for this engine, which allows access to information_schema tables (if supported)""") - - def hash_key(self): - return "%s(%s)" % (self.__class__.__name__, repr(self.connect_args())) - - def _get_name(self): - return sys.modules[self.__module__].descriptor()['name'] - name = property(_get_name) - - def dispose(self): - """disposes of the underlying pool manager for this SQLEngine.""" - (cargs, cparams) = self.connect_args() - sqlalchemy.pool.manage(self.dbapi()).dispose(*cargs, **cparams) - self._pool = None - - def _set_paramstyle(self, style): - self._paramstyle = style - self._figure_paramstyle(style) - paramstyle = property(lambda s:s._paramstyle, _set_paramstyle) - - def _figure_paramstyle(self, paramstyle=None): - db = self.dbapi() - if paramstyle is not None: - self._paramstyle = paramstyle - elif db is not None: - self._paramstyle = db.paramstyle - else: - self._paramstyle = 'named' - - if self._paramstyle == 'named': - self.positional=False - elif self._paramstyle == 'pyformat': - self.positional=False - elif self._paramstyle == 'qmark' or self._paramstyle == 'format' or self._paramstyle == 'numeric': - # for positional, use pyformat internally, ANSICompiler will convert - # to appropriate 
character upon compilation - self.positional = True - else: - raise DBAPIError("Unsupported paramstyle '%s'" % self._paramstyle) - - def type_descriptor(self, typeobj): - """provides a database-specific TypeEngine object, given the generic object - which comes from the types module. Subclasses will usually use the adapt_type() - method in the types module to make this job easy.""" - if type(typeobj) is type: - typeobj = typeobj() - return typeobj - - def _func(self): - return sql.FunctionGenerator(self) - func = property(_func) - - def text(self, text, *args, **kwargs): - """returns a sql.text() object for performing literal queries.""" - return sql.text(text, engine=self, *args, **kwargs) - - def schemagenerator(self, **params): - """returns a schema.SchemaVisitor instance that can generate schemas, when it is - invoked to traverse a set of schema objects. - - schemagenerator is called via the create() method. - """ - raise NotImplementedError() - - def schemadropper(self, **params): - """returns a schema.SchemaVisitor instance that can drop schemas, when it is - invoked to traverse a set of schema objects. - - schemagenerator is called via the drop() method. - """ - raise NotImplementedError() - - def defaultrunner(self, proxy=None): - """Returns a schema.SchemaVisitor instance that can execute the default values on a column. - The base class for this visitor is the DefaultRunner class inside this module. - This visitor will typically only receive schema.DefaultGenerator schema objects. The given - proxy is a callable that takes a string statement and a dictionary of bind parameters - to be executed. For engines that require positional arguments, the dictionary should - be an instance of OrderedDict which returns its bind parameters in the proper order. - - defaultrunner is called within the context of the execute_compiled() method.""" - return DefaultRunner(self, proxy) - - def compiler(self, statement, parameters): - """returns a sql.ClauseVisitor which will produce a string representation of the given - ClauseElement and parameter dictionary. This object is usually a subclass of - ansisql.ANSICompiler. - - compiler is called within the context of the compile() method.""" - raise NotImplementedError() - - def oid_column_name(self): - """returns the oid column name for this engine, or None if the engine cant/wont support OID/ROWID.""" - return None - - def supports_sane_rowcount(self): - """Provided to indicate when MySQL is being used, which does not have standard behavior - for the "rowcount" function on a statement handle. 
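Restating the _figure_paramstyle() branches above in one place: 'named' and 'pyformat' leave the engine in dictionary-bind mode, while 'qmark', 'format' and 'numeric' switch it to positional (compiled internally as pyformat and converted at compile time, per the inline comment); anything else raises DBAPIError:

    PARAMSTYLE_POSITIONAL = {
        'named': False, 'pyformat': False,
        'qmark': True, 'format': True, 'numeric': True,
    }
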
""" - return True - - def create(self, entity, **params): - """creates a table or index within this engine's database connection given a schema.Table object.""" - entity.accept_schema_visitor(self.schemagenerator(**params)) - return entity - - def drop(self, entity, **params): - """drops a table or index within this engine's database connection given a schema.Table object.""" - entity.accept_schema_visitor(self.schemadropper(**params)) - - def compile(self, statement, parameters, **kwargs): - """given a sql.ClauseElement statement plus optional bind parameters, creates a new - instance of this engine's SQLCompiler, compiles the ClauseElement, and returns the - newly compiled object.""" - compiler = self.compiler(statement, parameters, **kwargs) - compiler.compile() - return compiler - - def reflecttable(self, table): - """given a Table object, reflects its columns and properties from the database.""" - raise NotImplementedError() - - def get_default_schema_name(self): - """returns the currently selected schema in the current connection.""" - return None - - def last_inserted_ids(self): - """returns a thread-local list of the primary key values for the last insert statement executed. - This does not apply to straight textual clauses; only to sql.Insert objects compiled against - a schema.Table object, which are executed via statement.execute(). The order of items in the - list is the same as that of the Table's 'primary_key' attribute. - - In some cases, this method may invoke a query back to the database to retrieve the data, based on - the "lastrowid" value in the cursor.""" - raise NotImplementedError() - - def connect_args(self): - """subclasses override this method to provide a two-item tuple containing the *args - and **kwargs used to establish a connection.""" - raise NotImplementedError() - - def dbapi(self): - """subclasses override this method to provide the DBAPI module used to establish - connections.""" - raise NotImplementedError() - - def do_begin(self, connection): - """implementations might want to put logic here for turning autocommit on/off, - etc.""" - pass - def do_rollback(self, connection): - """implementations might want to put logic here for turning autocommit on/off, - etc.""" - #print "ENGINE ROLLBACK ON ", connection.connection - connection.rollback() - def do_commit(self, connection): - """implementations might want to put logic here for turning autocommit on/off, etc.""" - #print "ENGINE COMMIT ON ", connection.connection - connection.commit() - - def _session(self): - if not hasattr(self.context, 'session'): - self.context.session = SQLSession(self) - return self.context.session - session = property(_session, doc="returns the current thread's SQLSession") - - def push_session(self): - """pushes a new SQLSession onto this engine, temporarily replacing the previous one for the current thread. The previous session can be restored by calling pop_session(). this allows the usage of a new connection and possibly transaction within a particular block, superceding the existing one, including any transactions that are in progress. Returns the new SQLSession object.""" - sess = SQLSession(self, self.context.session) - self.context.session = sess - return sess - def pop_session(self, s = None): - """restores the current thread's SQLSession to that before the last push_session. Returns the restored SQLSession object. 
Raises an exception if there is no SQLSession pushed onto the stack.""" - sess = self.context.session.parent - if sess is None: - raise exceptions.InvalidRequestError("No SQLSession is pushed onto the stack.") - elif s is not None and s is not self.context.session: - raise exceptions.InvalidRequestError("Given SQLSession is not the current session on the stack") - self.context.session = sess - return sess - - def connection(self): - """returns a managed DBAPI connection from this SQLEngine's connection pool.""" - return self.session.connection - - def unique_connection(self): - """returns a DBAPI connection from this SQLEngine's connection pool that is distinct from the current thread's connection.""" - return self._pool.unique_connection() - - def multi_transaction(self, tables, func): - """provides a transaction boundary across tables which may be in multiple databases. - If you have three tables, and a function that operates upon them, providing the tables as a - list and the function will result in a begin()/commit() pair invoked for each distinct engine - represented within those tables, and the function executed within the context of that transaction. - any exceptions will result in a rollback(). - - clearly, this approach only goes so far, such as if database A commits, then database B commits - and fails, A is already committed. Any failure conditions have to be raised before anyone - commits for this to be useful.""" - engines = util.HashSet() - for table in tables: - engines.append(table.engine) - for engine in engines: - engine.begin() - try: - func() - except: - for engine in engines: - engine.rollback() - raise - for engine in engines: - engine.commit() - - def transaction(self, func, *args, **kwargs): - """executes the given function within a transaction boundary. this is a shortcut for - explicitly calling begin() and commit() and optionally rollback() when execptions are raised. - The given *args and **kwargs will be passed to the function as well, which could be handy - in constructing decorators.""" - self.begin() - try: - func(*args, **kwargs) - except: - self.rollback() - raise - self.commit() - - def begin(self): - """ begins a transaction on the current thread SQLSession. """ - self.session.begin() - - def rollback(self): - """rolls back the transaction on the current thread's SQLSession.""" - self.session.rollback() - - def commit(self): - self.session.commit() - - def _process_defaults(self, proxy, compiled, parameters, **kwargs): - """INSERT and UPDATE statements, when compiled, may have additional columns added to their - VALUES and SET lists corresponding to column defaults/onupdates that are present on the - Table object (i.e. ColumnDefault, Sequence, PassiveDefault). This method pre-execs those - DefaultGenerator objects that require pre-execution and sets their values within the - parameter list, and flags the thread-local state about - PassiveDefault objects that may require post-fetching the row after it is inserted/updated. 
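The transaction() shortcut documented above wraps a callable in begin()/commit() with rollback() on error, passing extra positional and keyword arguments through to the callable. A minimal sketch (assumes an engine and a table t; the pyformat bind style shown depends on the DBAPI):

    def bump(amount):
        engine.execute("update t set x = x + %(amount)s", {'amount': amount})

    engine.transaction(bump, 5)
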
- This method relies upon logic within the ANSISQLCompiler in its visit_insert and - visit_update methods that add the appropriate column clauses to the statement when its - being compiled, so that these parameters can be bound to the statement.""" - if compiled is None: return - if getattr(compiled, "isinsert", False): - if isinstance(parameters, list): - plist = parameters - else: - plist = [parameters] - drunner = self.defaultrunner(proxy) - self.context.lastrow_has_defaults = False - for param in plist: - last_inserted_ids = [] - need_lastrowid=False - for c in compiled.statement.table.c: - if not param.has_key(c.name) or param[c.name] is None: - if isinstance(c.default, schema.PassiveDefault): - self.context.lastrow_has_defaults = True - newid = drunner.get_column_default(c) - if newid is not None: - param[c.name] = newid - if c.primary_key: - last_inserted_ids.append(param[c.name]) - elif c.primary_key: - need_lastrowid = True - elif c.primary_key: - last_inserted_ids.append(param[c.name]) - if need_lastrowid: - self.context.last_inserted_ids = None - else: - self.context.last_inserted_ids = last_inserted_ids - self.context.last_inserted_params = param - elif getattr(compiled, 'isupdate', False): - if isinstance(parameters, list): - plist = parameters - else: - plist = [parameters] - drunner = self.defaultrunner(proxy) - self.context.lastrow_has_defaults = False - for param in plist: - for c in compiled.statement.table.c: - if c.onupdate is not None and (not param.has_key(c.name) or param[c.name] is None): - value = drunner.get_column_onupdate(c) - if value is not None: - param[c.name] = value - self.context.last_updated_params = param - - def last_inserted_params(self): - """returns a dictionary of the full parameter dictionary for the last compiled INSERT statement, - including any ColumnDefaults or Sequences that were pre-executed. this value is thread-local.""" - return self.context.last_inserted_params - def last_updated_params(self): - """returns a dictionary of the full parameter dictionary for the last compiled UPDATE statement, - including any ColumnDefaults that were pre-executed. this value is thread-local.""" - return self.context.last_updated_params - def lastrow_has_defaults(self): - """returns True if the last row INSERTED via a compiled insert statement contained PassiveDefaults, - indicating that the database inserted data beyond that which we gave it. this value is thread-local.""" - return self.context.lastrow_has_defaults - - def pre_exec(self, proxy, compiled, parameters, **kwargs): - """called by execute_compiled before the compiled statement is executed.""" - pass - - def post_exec(self, proxy, compiled, parameters, **kwargs): - """called by execute_compiled after the compiled statement is executed.""" - pass - - def execute_compiled(self, compiled, parameters, connection=None, cursor=None, echo=None, **kwargs): - """executes the given compiled statement object with the given parameters. - - The parameters can be a dictionary of key/value pairs, or a list of dictionaries for an - executemany() style of execution. Engines that use positional parameters will convert - the parameters to a list before execution. - - If the current thread has specified a transaction begin() for this engine, the - statement will be executed in the context of the current transactional connection. - Otherwise, a commit() will be performed immediately after execution, since the local - pooled connection is returned to the pool after execution without a transaction set - up. 
- - In all error cases, a rollback() is immediately performed on the connection before - propigating the exception outwards. - - Other options include: - - connection - a DBAPI connection to use for the execute. If None, a connection is - pulled from this engine's connection pool. - - echo - enables echo for this execution, which causes all SQL and parameters - to be dumped to the engine's logging output before execution. - - typemap - a map of column names mapped to sqlalchemy.types.TypeEngine objects. - These will be passed to the created ResultProxy to perform - post-processing on result-set values. - - commit - if True, will automatically commit the statement after completion. """ - - if connection is None: - connection = self.connection() - - if cursor is None: - cursor = connection.cursor() - - executemany = parameters is not None and (isinstance(parameters, list) or isinstance(parameters, tuple)) - if executemany: - parameters = [compiled.get_params(**m) for m in parameters] - else: - parameters = compiled.get_params(**parameters) - def proxy(statement=None, parameters=None): - if statement is None: - return cursor - - parameters = self._convert_compiled_params(parameters) - self.execute(statement, parameters, connection=connection, cursor=cursor, return_raw=True) - return cursor - - self.pre_exec(proxy, compiled, parameters, **kwargs) - self._process_defaults(proxy, compiled, parameters, **kwargs) - proxy(str(compiled), parameters) - self.post_exec(proxy, compiled, parameters, **kwargs) - return ResultProxy(cursor, self, typemap=compiled.typemap) - - def execute(self, statement, parameters=None, connection=None, cursor=None, echo=None, typemap=None, commit=False, return_raw=False, **kwargs): - """ executes the given string-based SQL statement with the given parameters. - - The parameters can be a dictionary or a list, or a list of dictionaries or lists, depending - on the paramstyle of the DBAPI. - - If the current thread has specified a transaction begin() for this engine, the - statement will be executed in the context of the current transactional connection. - Otherwise, a commit() will be performed immediately after execution, since the local - pooled connection is returned to the pool after execution without a transaction set - up. - - In all error cases, a rollback() is immediately performed on the connection before - propagating the exception outwards. - - Other options include: - - connection - a DBAPI connection to use for the execute. If None, a connection is - pulled from this engine's connection pool. - - echo - enables echo for this execution, which causes all SQL and parameters - to be dumped to the engine's logging output before execution. - - typemap - a map of column names mapped to sqlalchemy.types.TypeEngine objects. - These will be passed to the created ResultProxy to perform - post-processing on result-set values. - - commit - if True, will automatically commit the statement after completion. 
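Putting the option list above together, a hedged example of the string-execution API (the bind format depends on the DBAPI's paramstyle; pyformat shown, and named access on the row relies on the ResultProxy behavior documented below):

    result = engine.execute(
        "select user_name from users where user_id = %(user_id)s",
        {'user_id': 3}, echo=True)
    for row in result.fetchall():
        print row['user_name']
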
""" - - if connection is None: - connection = self.connection() - - if cursor is None: - cursor = connection.cursor() - - try: - if echo is True or self.echo is not False: - self.log(statement) - self.log(repr(parameters)) - if parameters is not None and isinstance(parameters, list) and len(parameters) > 0 and (isinstance(parameters[0], list) or isinstance(parameters[0], dict)): - self._executemany(cursor, statement, parameters) - else: - self._execute(cursor, statement, parameters) - if not self.session.is_begun(): - self.do_commit(connection) - except: - self.do_rollback(connection) - raise - if return_raw: - return cursor - else: - return ResultProxy(cursor, self, typemap=typemap) - - def _execute(self, c, statement, parameters): - if parameters is None: - if self.positional: - parameters = () - else: - parameters = {} - try: - c.execute(statement, parameters) - except Exception, e: - raise exceptions.SQLError(statement, parameters, e) - self.context.rowcount = c.rowcount - def _executemany(self, c, statement, parameters): - c.executemany(statement, parameters) - self.context.rowcount = c.rowcount - - def _convert_compiled_params(self, parameters): - executemany = parameters is not None and isinstance(parameters, list) - # the bind params are a CompiledParams object. but all the DBAPI's hate - # that object (or similar). so convert it to a clean - # dictionary/list/tuple of dictionary/tuple of list - if parameters is not None: - if self.positional: - if executemany: - parameters = [p.values() for p in parameters] - else: - parameters = parameters.values() - else: - if executemany: - parameters = [p.get_raw_dict() for p in parameters] - else: - parameters = parameters.get_raw_dict() - return parameters - - def proxy(self, statement=None, parameters=None): - """returns a callable which will execute the given statement string and parameter object. - the parameter object is expected to be the result of a call to compiled.get_params(). - This callable is a generic version of a connection/cursor-specific callable that - is produced within the execute_compiled method, and is used for objects that require - this style of proxy when outside of an execute_compiled method, primarily the DefaultRunner.""" - parameters = self._convert_compiled_params(parameters) - return self.execute(statement, parameters) - - def log(self, msg): - """logs a message using this SQLEngine's logger stream.""" - self.logger.write(msg) - - -class ResultProxy: - """wraps a DBAPI cursor object to provide access to row columns based on integer - position, case-insensitive column name, or by schema.Column object. e.g.: - - row = fetchone() - - col1 = row[0] # access via integer position - - col2 = row['col2'] # access via name - - col3 = row[mytable.c.mycol] # access via Column object. - - ResultProxy also contains a map of TypeEngine objects and will invoke the appropriate - convert_result_value() method before returning columns. - """ - class AmbiguousColumn(object): - def __init__(self, key): - self.key = key - def convert_result_value(self, arg, engine): - raise InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." 
% (self.key)) - - def __init__(self, cursor, engine, typemap = None): - """ResultProxy objects are constructed via the execute() method on SQLEngine.""" - self.cursor = cursor - self.engine = engine - self.echo = engine.echo=="debug" - self.rowcount = engine.context.rowcount - metadata = cursor.description - self.props = {} - self.keys = [] - i = 0 - if metadata is not None: - for item in metadata: - # sqlite possibly prepending table name to colnames so strip - colname = item[0].split('.')[-1].lower() - if typemap is not None: - rec = (typemap.get(colname, types.NULLTYPE), i) - else: - rec = (types.NULLTYPE, i) - if rec[0] is None: - raise DBAPIError("None for metadata " + colname) - if self.props.setdefault(colname, rec) is not rec: - self.props[colname] = (ResultProxy.AmbiguousColumn(colname), 0) - self.keys.append(colname) - #print "COLNAME", colname - self.props[i] = rec - i+=1 - - def _get_col(self, row, key): - if isinstance(key, schema.Column) or isinstance(key, sql.ColumnElement): - try: - rec = self.props[key._label.lower()] - #print "GOT IT FROM LABEL FOR ", key._label - except KeyError: - try: - rec = self.props[key.key.lower()] - except KeyError: - rec = self.props[key.name.lower()] - elif isinstance(key, str): - rec = self.props[key.lower()] - else: - rec = self.props[key] - return rec[0].engine_impl(self.engine).convert_result_value(row[rec[1]], self.engine) - - def __iter__(self): - while True: - row = self.fetchone() - if row is None: - raise StopIteration - else: - yield row - - def last_inserted_ids(self): - return self.engine.last_inserted_ids() - def last_updated_params(self): - return self.engine.last_updated_params() - def last_inserted_params(self): - return self.engine.last_inserted_params() - def lastrow_has_defaults(self): - return self.engine.lastrow_has_defaults() - def supports_sane_rowcount(self): - return self.engine.supports_sane_rowcount() - - def fetchall(self): - """fetches all rows, just like DBAPI cursor.fetchall().""" - l = [] - while True: - v = self.fetchone() - if v is None: - return l - l.append(v) - - def fetchone(self): - """fetches one row, just like DBAPI cursor.fetchone().""" - row = self.cursor.fetchone() - if row is not None: - if self.echo: self.engine.log(repr(row)) - return RowProxy(self, row) - else: - return None - -class RowProxy: - """proxies a single cursor row for a parent ResultProxy.""" - def __init__(self, parent, row): - """RowProxy objects are constructed by ResultProxy objects.""" - self.__parent = parent - self.__row = row - def __iter__(self): - for i in range(0, len(self.__row)): - yield self.__parent._get_col(self.__row, i) - def __eq__(self, other): - return (other is self) or (other == tuple([self.__parent._get_col(self.__row, key) for key in range(0, len(self.__row))])) - def __repr__(self): - return repr(tuple([self.__parent._get_col(self.__row, key) for key in range(0, len(self.__row))])) - def __getitem__(self, key): - return self.__parent._get_col(self.__row, key) - def __getattr__(self, name): - try: - return self.__parent._get_col(self.__row, name) - except KeyError: - raise AttributeError - def items(self): - return [(key, getattr(self, key)) for key in self.keys()] - def keys(self): - return self.__parent.keys - def values(self): - return list(self) - def __len__(self): - return len(self.__row) diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py new file mode 100644 index 000000000..2cb94a90d --- /dev/null +++ b/lib/sqlalchemy/engine/__init__.py @@ -0,0 +1,92 @@ +# 
engine/__init__.py +# Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import sqlalchemy.databases + +from base import * +import strategies +import re + +def engine_descriptors(): + """provides a listing of all the database implementations supported. this data + is provided as a list of dictionaries, where each dictionary contains the following + key/value pairs: + + name : the name of the engine, suitable for use in the create_engine function + + description: a plain description of the engine. + + arguments : a dictionary describing the name and description of each parameter + used to connect to this engine's underlying DBAPI. + + This function is meant for usage in automated configuration tools that wish to + query the user for database and connection information. + """ + result = [] + #for module in sqlalchemy.databases.__all__: + for module in ['sqlite', 'postgres', 'mysql']: + module = getattr(__import__('sqlalchemy.databases.%s' % module).databases, module) + result.append(module.descriptor()) + return result + +default_strategy = 'plain' +def create_engine(*args, **kwargs): + """creates a new Engine instance. Using the given strategy name, + locates that strategy and invokes its create() method to produce the Engine. + The strategies themselves are instances of EngineStrategy, and the built in + ones are present in the sqlalchemy.engine.strategies module. Current implementations + include "plain" and "threadlocal". The default used by this function is "threadlocal". + + "plain" provides support for a Connection object which can be used to execute SQL queries + with a specific underlying DBAPI connection. + + "threadlocal" is similar to "plain" except that it adds support for a thread-local connection and + transaction context, which allows a group of engine operations to participate using the same + connection and transaction without the need for explicit passing of a Connection object. + + The standard method of specifying the engine is via URL as the first positional + argument, to indicate the appropriate database dialect and connection arguments, with additional + keyword arguments sent as options to the dialect and resulting Engine. + + The URL is in the form <dialect>://opt1=val1&opt2=val2. + Where <dialect> is a name such as "mysql", "oracle", "postgres", and the options indicate + username, password, database, etc. Supported keynames include "username", "user", "password", + "pw", "db", "database", "host", "filename". + + **kwargs represents options to be sent to the Engine itself as well as the components of the Engine, + including the Dialect, the ConnectionProvider, and the Pool. A list of common options is as follows: + + pool=None : an instance of sqlalchemy.pool.DBProxy or sqlalchemy.pool.Pool to be used as the + underlying source for connections (DBProxy/Pool is described in the previous section). If None, + a default DBProxy will be created using the engine's own database module with the given + arguments. + + echo=False : if True, the Engine will log all statements as well as a repr() of their + parameter lists to the engines logger, which defaults to sys.stdout. A Engine instances' + "echo" data member can be modified at any time to turn logging on and off. If set to the string + 'debug', result rows will be printed to the standard output as well. 
+ + logger=None : a file-like object where logging output can be sent, if echo is set to True. + This defaults to sys.stdout. + + encoding='utf-8' : the encoding to be used when encoding/decoding Unicode strings + + convert_unicode=False : True if unicode conversion should be applied to all str types + + module=None : used by Oracle and Postgres, this is a reference to a DBAPI2 module to be used + instead of the engine's default module. For Postgres, the default is psycopg2, or psycopg1 if + 2 cannot be found. For Oracle, its cx_Oracle. For mysql, MySQLdb. + + use_ansi=True : used only by Oracle; when False, the Oracle driver attempts to support a + particular "quirk" of some Oracle databases, that the LEFT OUTER JOIN SQL syntax is not + supported, and the "Oracle join" syntax of using <column1>(+)=<column2> must be used + in order to achieve a LEFT OUTER JOIN. Its advised that the Oracle database be configured to + have full ANSI support instead of using this feature. + + """ + strategy = kwargs.pop('strategy', default_strategy) + strategy = strategies.strategies[strategy] + return strategy.create(*args, **kwargs) diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py new file mode 100644 index 000000000..bf7b1c20d --- /dev/null +++ b/lib/sqlalchemy/engine/base.py @@ -0,0 +1,687 @@ +from sqlalchemy import exceptions, sql, schema, util, types +import StringIO, sys, re + +class ConnectionProvider(object): + """defines an interface that returns raw Connection objects (or compatible).""" + def get_connection(self): + """this method should return a Connection or compatible object from a DBAPI which + also contains a close() method. + It is not defined what context this connection belongs to. It may be newly connected, + returned from a pool, part of some other kind of context such as thread-local, + or can be a fixed member of this object.""" + raise NotImplementedError() + def dispose(self): + """releases all resources corresponding to this ConnectionProvider, such + as any underlying connection pools.""" + raise NotImplementedError() + +class Dialect(sql.AbstractDialect): + """Adds behavior to the execution of queries to provide + support for column defaults, differences between paramstyles, quirks between post-execution behavior, + and a general consistentization of the behavior of various DBAPIs. + + The Dialect should also implement the following two attributes: + + positional - True if the paramstyle for this Dialect is positional + + paramstyle - the paramstyle to be used (some DBAPIs support multiple paramstyles) + + supports_autoclose_results - usually True; if False, indicates that rows returned by fetchone() + might not be just plain tuples, and may be "live" proxy objects which still require the cursor + to be open in order to be read (such as pyPgSQL which has active filehandles for BLOBs). in that + case, an auto-closing ResultProxy cannot automatically close itself after results are consumed. + + convert_unicode - True if unicode conversion should be applied to all str types + + encoding - type of encoding to use for unicode, usually defaults to 'utf-8' + """ + def create_connect_args(self, opts): + """given a dictionary of key-valued connect parameters, returns a tuple + consisting of a *args/**kwargs suitable to send directly to the dbapi's connect function. 
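A hedged sketch of the new create_engine() entrypoint documented in engine/__init__.py above, using its <dialect>://opt=val URL form; the exact option keynames follow the docstring's list and may vary, and the strategy keyword selects among the EngineStrategy implementations ("plain" or "threadlocal"):

    from sqlalchemy.engine import create_engine
    db = create_engine('sqlite://database=:memory:', echo=True, strategy='plain')
    conn = db.contextual_connect()
    print conn.execute("select 1").fetchone()
    conn.close()
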
+ The connect args will have any number of the following keynames: host, hostname, database, dbanme, + user,username, password, pw, passwd, filename.""" + raise NotImplementedError() + def convert_compiled_params(self, parameters): + """given a sql.ClauseParameters object, returns an array or dictionary suitable to pass + directly to this Dialect's DBAPI's execute method.""" + def type_descriptor(self, typeobj): + """provides a database-specific TypeEngine object, given the generic object + which comes from the types module. Subclasses will usually use the adapt_type() + method in the types module to make this job easy.""" + raise NotImplementedError() + def oid_column_name(self): + """returns the oid column name for this dialect, or None if the dialect cant/wont support OID/ROWID.""" + raise NotImplementedError() + def supports_sane_rowcount(self): + """Provided to indicate when MySQL is being used, which does not have standard behavior + for the "rowcount" function on a statement handle. """ + raise NotImplementedError() + def schemagenerator(self, engine, proxy, **params): + """returns a schema.SchemaVisitor instance that can generate schemas, when it is + invoked to traverse a set of schema objects. + + schemagenerator is called via the create() method on Table, Index, and others. + """ + raise NotImplementedError() + def schemadropper(self, engine, proxy, **params): + """returns a schema.SchemaVisitor instance that can drop schemas, when it is + invoked to traverse a set of schema objects. + + schemagenerator is called via the drop() method on Table, Index, and others. + """ + raise NotImplementedError() + def defaultrunner(self, engine, proxy, **params): + """returns a schema.SchemaVisitor instances that can execute defaults.""" + raise NotImplementedError() + def compiler(self, statement, parameters): + """returns a sql.ClauseVisitor which will produce a string representation of the given + ClauseElement and parameter dictionary. This object is usually a subclass of + ansisql.ANSICompiler. + + compiler is called within the context of the compile() method.""" + raise NotImplementedError() + def reflecttable(self, connection, table): + """given an Connection and a Table object, reflects its columns and properties from the database.""" + raise NotImplementedError() + def has_table(self, connection, table_name): + raise NotImplementedError() + def dbapi(self): + """subclasses override this method to provide the DBAPI module used to establish + connections.""" + raise NotImplementedError() + def get_default_schema_name(self, connection): + """returns the currently selected schema given an connection""" + raise NotImplementedError() + def execution_context(self): + """returns a new ExecutionContext object.""" + raise NotImplementedError() + def do_begin(self, connection): + """provides an implementation of connection.begin()""" + raise NotImplementedError() + def do_rollback(self, connection): + """provides an implementation of connection.rollback()""" + raise NotImplementedError() + def do_commit(self, connection): + """provides an implementation of connection.commit()""" + raise NotImplementedError() + def do_executemany(self, cursor, statement, parameters): + raise NotImplementedError() + def do_execute(self, cursor, statement, parameters): + raise NotImplementedError() + +class ExecutionContext(object): + """a messenger object for a Dialect that corresponds to a single execution. The Dialect + should provide an ExecutionContext via the create_execution_context() method. 
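Against the Dialect contract laid out above, a minimal concrete dialect reduces to a handful of overrides. This is illustrative only; SQLiteDialect and PGDialect earlier in this diff are the real examples, and mydbapi here is a hypothetical DBAPI module:

    from sqlalchemy import ansisql

    class MyDialect(ansisql.ANSIDialect):
        def create_connect_args(self, url):
            # (args, kwargs) handed straight to dbapi().connect()
            return ([url.database], {})
        def dbapi(self):
            import mydbapi   # hypothetical DBAPI 2.0 module
            return mydbapi
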
+ The pre_exec and post_exec methods will be called for compiled statements, afterwhich + it is expected that the various methods last_inserted_ids, last_inserted_params, etc. + will contain appropriate values, if applicable.""" + def pre_exec(self, engine, proxy, compiled, parameters): + """called before an execution of a compiled statement. proxy is a callable that + takes a string statement and a bind parameter list/dictionary.""" + raise NotImplementedError() + def post_exec(self, engine, proxy, compiled, parameters): + """called after the execution of a compiled statement. proxy is a callable that + takes a string statement and a bind parameter list/dictionary.""" + raise NotImplementedError() + def get_rowcount(self, cursor): + """returns the count of rows updated/deleted for an UPDATE/DELETE statement""" + raise NotImplementedError() + def supports_sane_rowcount(self): + """Provided to indicate when MySQL is being used, which does not have standard behavior + for the "rowcount" function on a statement handle. """ + raise NotImplementedError() + def last_inserted_ids(self): + """returns the list of the primary key values for the last insert statement executed. + This does not apply to straight textual clauses; only to sql.Insert objects compiled against + a schema.Table object, which are executed via statement.execute(). The order of items in the + list is the same as that of the Table's 'primary_key' attribute. + + In some cases, this method may invoke a query back to the database to retrieve the data, based on + the "lastrowid" value in the cursor.""" + raise NotImplementedError() + def last_inserted_params(self): + """returns a dictionary of the full parameter dictionary for the last compiled INSERT statement, + including any ColumnDefaults or Sequences that were pre-executed. this value is thread-local.""" + raise NotImplementedError() + def last_updated_params(self): + """returns a dictionary of the full parameter dictionary for the last compiled UPDATE statement, + including any ColumnDefaults that were pre-executed. this value is thread-local.""" + raise NotImplementedError() + def lastrow_has_defaults(self): + """returns True if the last row INSERTED via a compiled insert statement contained PassiveDefaults, + indicating that the database inserted data beyond that which we gave it. this value is thread-local.""" + raise NotImplementedError() + +class Connectable(object): + """interface for an object that can provide an Engine and a Connection object which correponds to that Engine.""" + def contextual_connect(self): + """returns a Connection object which may be part of an ongoing context.""" + raise NotImplementedError() + def create(self, entity, **kwargs): + """creates a table or index given an appropriate schema object.""" + raise NotImplementedError() + def drop(self, entity, **kwargs): + raise NotImplementedError() + def execute(self, object, *multiparams, **params): + raise NotImplementedError() + def _not_impl(self): + raise NotImplementedError() + engine = property(_not_impl, doc="returns the Engine which this Connectable is associated with.") + +class Connection(Connectable): + """represents a single DBAPI connection returned from the underlying connection pool. Provides + execution support for string-based SQL statements as well as ClauseElement, Compiled and DefaultGenerator objects. 
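The Connection described above pairs with the Transaction object defined later in this hunk; a sketch of the intended usage, built only from the constructor signature and begin() shown below (engines also hand these out via contextual_connect()):

    conn = Connection(engine)
    trans = conn.begin()
    try:
        conn.execute("insert into t (x) values (1)")
        trans.commit()
    except:
        trans.rollback()
        raise
    conn.close()
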
+ provides a begin method to return Transaction objects.""" + def __init__(self, engine, connection=None, close_with_result=False): + self.__engine = engine + self.__connection = connection or engine.raw_connection() + self.__transaction = None + self.__close_with_result = close_with_result + engine = property(lambda s:s.__engine, doc="The Engine with which this Connection is associated (read only)") + connection = property(lambda s:s.__connection, doc="The underlying DBAPI connection managed by this Connection.") + should_close_with_result = property(lambda s:s.__close_with_result, doc="Indicates if this Connection should be closed when a corresponding ResultProxy is closed; this is essentially an auto-release mode.") + def _create_transaction(self, parent): + return Transaction(self, parent) + def connect(self): + """connect() is implemented to return self so that an incoming Engine or Connection object can be treated similarly.""" + return self + def contextual_connect(self, **kwargs): + """contextual_connect() is implemented to return self so that an incoming Engine or Connection object can be treated similarly.""" + return self + def begin(self): + if self.__transaction is None: + self.__transaction = self._create_transaction(None) + return self.__transaction + else: + return self._create_transaction(self.__transaction) + def _begin_impl(self): + if self.__engine.echo: + self.__engine.log("BEGIN") + self.__engine.dialect.do_begin(self.__connection) + def _rollback_impl(self): + if self.__engine.echo: + self.__engine.log("ROLLBACK") + self.__engine.dialect.do_rollback(self.__connection) + def _commit_impl(self): + if self.__engine.echo: + self.__engine.log("COMMIT") + self.__engine.dialect.do_commit(self.__connection) + def _autocommit(self, statement): + """when no Transaction is present, this is called after executions to provide "autocommit" behavior.""" + # TODO: have the dialect determine if autocommit can be set on the connection directly without this + # extra step + if self.__transaction is None and re.match(r'UPDATE|INSERT|CREATE|DELETE|DROP', statement.lstrip().upper()): + self._commit_impl() + def close(self): + if self.__connection is not None: + self.__connection.close() + self.__connection = None + def scalar(self, object, parameters, **kwargs): + row = self.execute(object, parameters, **kwargs).fetchone() + if row is not None: + return row[0] + else: + return None + def execute(self, object, *multiparams, **params): + return Connection.executors[type(object).__mro__[-2]](self, object, *multiparams, **params) + def execute_default(self, default, **kwargs): + return default.accept_schema_visitor(self.__engine.dialect.defaultrunner(self.__engine, self.proxy, **kwargs)) + def execute_text(self, statement, parameters=None): + cursor = self._execute_raw(statement, parameters) + return ResultProxy(self.__engine, self, cursor) + def _params_to_listofdicts(self, *multiparams, **params): + if len(multiparams) == 0: + return [params] + elif len(multiparams) == 1: + if multiparams[0] == None: + return [{}] + elif isinstance (multiparams[0], list) or isinstance (multiparams[0], tuple): + return multiparams[0] + else: + return [multiparams[0]] + else: + return multiparams + def execute_clauseelement(self, elem, *multiparams, **params): + executemany = len(multiparams) > 0 + if executemany: + param = multiparams[0] + else: + param = params + return self.execute_compiled(elem.compile(engine=self.__engine, parameters=param), *multiparams, **params) + def execute_compiled(self, compiled, 
*multiparams, **params): + """executes a sql.Compiled object.""" + cursor = self.__connection.cursor() + parameters = [compiled.get_params(**m) for m in self._params_to_listofdicts(*multiparams, **params)] + if len(parameters) == 1: + parameters = parameters[0] + def proxy(statement=None, parameters=None): + if statement is None: + return cursor + + parameters = self.__engine.dialect.convert_compiled_params(parameters) + self._execute_raw(statement, parameters, cursor=cursor, context=context) + return cursor + context = self.__engine.dialect.create_execution_context() + context.pre_exec(self.__engine, proxy, compiled, parameters) + proxy(str(compiled), parameters) + context.post_exec(self.__engine, proxy, compiled, parameters) + return ResultProxy(self.__engine, self, cursor, context, typemap=compiled.typemap) + + # poor man's multimethod/generic function thingy + executors = { + sql.ClauseElement : execute_clauseelement, + sql.Compiled : execute_compiled, + schema.SchemaItem:execute_default, + str.__mro__[-2] : execute_text + } + + def create(self, entity, **kwargs): + """creates a table or index given an appropriate schema object.""" + return self.__engine.create(entity, connection=self, **kwargs) + def drop(self, entity, **kwargs): + """drops a table or index given an appropriate schema object.""" + return self.__engine.drop(entity, connection=self, **kwargs) + def reflecttable(self, table, **kwargs): + """reflects the columns in the given table from the database.""" + return self.__engine.reflecttable(table, connection=self, **kwargs) + def default_schema_name(self): + return self.__engine.dialect.get_default_schema_name(self) + def run_callable(self, callable_): + callable_(self) + def _execute_raw(self, statement, parameters=None, cursor=None, echo=None, context=None, **kwargs): + if cursor is None: + cursor = self.__connection.cursor() + try: + if echo is True or self.__engine.echo is not False: + self.__engine.log(statement) + self.__engine.log(repr(parameters)) + if parameters is not None and isinstance(parameters, list) and len(parameters) > 0 and (isinstance(parameters[0], list) or isinstance(parameters[0], dict)): + self._executemany(cursor, statement, parameters, context=context) + else: + self._execute(cursor, statement, parameters, context=context) + self._autocommit(statement) + except: + raise + return cursor + + def _execute(self, c, statement, parameters, context=None): + if parameters is None: + if self.__engine.dialect.positional: + parameters = () + else: + parameters = {} + try: + self.__engine.dialect.do_execute(c, statement, parameters, context=context) + except Exception, e: + self._rollback_impl() + if self.__close_with_result: + self.close() + raise exceptions.SQLError(statement, parameters, e) + def _executemany(self, c, statement, parameters, context=None): + try: + self.__engine.dialect.do_executemany(c, statement, parameters, context=context) + except Exception, e: + self._rollback_impl() + if self.__close_with_result: + self.close() + raise exceptions.SQLError(statement, parameters, e) + def proxy(self, statement=None, parameters=None): + """executes the given statement string and parameter object. + the parameter object is expected to be the result of a call to compiled.get_params(). 
+ This callable is a generic version of a connection/cursor-specific callable that + is produced within the execute_compiled method, and is used for objects that require + this style of proxy when outside of an execute_compiled method, primarily the DefaultRunner.""" + parameters = self.__engine.dialect.convert_compiled_params(parameters) + return self._execute_raw(statement, parameters) + +class Transaction(object): + """represents a Transaction in progress""" + def __init__(self, connection, parent): + self.__connection = connection + self.__parent = parent or self + self.__is_active = True + if self.__parent is self: + self.__connection._begin_impl() + connection = property(lambda s:s.__connection, doc="The Connection object referenced by this Transaction") + def rollback(self): + if not self.__parent.__is_active: + raise exceptions.InvalidRequestError("This transaction is inactive") + if self.__parent is self: + self.__connection._rollback_impl() + self.__is_active = False + else: + self.__parent.rollback() + def commit(self): + if not self.__parent.__is_active: + raise exceptions.InvalidRequestError("This transaction is inactive") + if self.__parent is self: + self.__connection._commit_impl() + self.__is_active = False + +class ComposedSQLEngine(sql.Engine, Connectable): + """ + Connects a ConnectionProvider, a Dialect and a CompilerFactory together to + provide a default implementation of SchemaEngine. + """ + def __init__(self, connection_provider, dialect, echo=False, logger=None, **kwargs): + self.connection_provider = connection_provider + self.dialect=dialect + self.echo = echo + self.logger = logger or util.Logger(origin='engine') + + name = property(lambda s:sys.modules[s.dialect.__module__].descriptor()['name']) + engine = property(lambda s:s) + + def dispose(self): + self.connection_provider.dispose() + def create(self, entity, connection=None, **kwargs): + """creates a table or index within this engine's database connection given a schema.Table object.""" + self._run_visitor(self.dialect.schemagenerator, entity, connection=connection, **kwargs) + def drop(self, entity, connection=None, **kwargs): + """drops a table or index within this engine's database connection given a schema.Table object.""" + self._run_visitor(self.dialect.schemadropper, entity, connection=connection, **kwargs) + def execute_default(self, default, **kwargs): + connection = self.contextual_connect() + try: + return connection.execute_default(default, **kwargs) + finally: + connection.close() + + def _func(self): + return sql.FunctionGenerator(self) + func = property(_func) + def text(self, text, *args, **kwargs): + """returns a sql.text() object for performing literal queries.""" + return sql.text(text, engine=self, *args, **kwargs) + + def _run_visitor(self, visitorcallable, element, connection=None, **kwargs): + if connection is None: + conn = self.contextual_connect() + else: + conn = connection + try: + element.accept_schema_visitor(visitorcallable(self, conn.proxy, **kwargs)) + finally: + if connection is None: + conn.close() + + def transaction(self, callable_, connection=None, *args, **kwargs): + if connection is None: + conn = self.contextual_connect() + else: + conn = connection + try: + trans = conn.begin() + try: + ret = callable_(conn, *args, **kwargs) + trans.commit() + return ret + except: + trans.rollback() + raise + finally: + if connection is None: + conn.close() + + def run_callable(self, callable_, connection=None, *args, **kwargs): + if connection is None: + conn = 
self.contextual_connect() + else: + conn = connection + try: + return callable_(conn, *args, **kwargs) + finally: + if connection is None: + conn.close() + + def execute(self, statement, *multiparams, **params): + connection = self.contextual_connect(close_with_result=True) + return connection.execute(statement, *multiparams, **params) + + def execute_compiled(self, compiled, *multiparams, **params): + connection = self.contextual_connect(close_with_result=True) + return connection.execute_compiled(compiled, *multiparams, **params) + + def compiler(self, statement, parameters, **kwargs): + return self.dialect.compiler(statement, parameters, engine=self, **kwargs) + + def connect(self, **kwargs): + """returns a newly allocated Connection object.""" + return Connection(self, **kwargs) + + def contextual_connect(self, close_with_result=False, **kwargs): + """returns a Connection object which may be newly allocated, or may be part of some + ongoing context. This Connection is meant to be used by the various "auto-connecting" operations.""" + return Connection(self, close_with_result=close_with_result, **kwargs) + + def reflecttable(self, table, connection=None): + """given a Table object, reflects its columns and properties from the database.""" + if connection is None: + conn = self.contextual_connect() + else: + conn = connection + try: + self.dialect.reflecttable(conn, table) + finally: + if connection is None: + conn.close() + def has_table(self, table_name): + return self.run_callable(lambda c: self.dialect.has_table(c, table_name)) + + def raw_connection(self): + """returns a DBAPI connection.""" + return self.connection_provider.get_connection() + + def log(self, msg): + """logs a message using this SQLEngine's logger stream.""" + self.logger.write(msg) + +class ResultProxy: + """wraps a DBAPI cursor object to provide access to row columns based on integer + position, case-insensitive column name, or by schema.Column object. e.g.: + + row = fetchone() + + col1 = row[0] # access via integer position + + col2 = row['col2'] # access via name + + col3 = row[mytable.c.mycol] # access via Column object. + + ResultProxy also contains a map of TypeEngine objects and will invoke the appropriate + convert_result_value() method before returning columns. + """ + class AmbiguousColumn(object): + def __init__(self, key): + self.key = key + def convert_result_value(self, arg, engine): + raise InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." 
% (self.key)) + + def __init__(self, engine, connection, cursor, executioncontext=None, typemap=None): + """ResultProxy objects are constructed via the execute() method on SQLEngine.""" + self.connection = connection + self.dialect = engine.dialect + self.cursor = cursor + self.engine = engine + self.closed = False + self.executioncontext = executioncontext + self.echo = engine.echo=="debug" + if executioncontext: + self.rowcount = executioncontext.get_rowcount(cursor) + else: + self.rowcount = cursor.rowcount + metadata = cursor.description + self.props = {} + self.keys = [] + i = 0 + if metadata is not None: + for item in metadata: + # sqlite possibly prepending table name to colnames so strip + colname = item[0].split('.')[-1].lower() + if typemap is not None: + rec = (typemap.get(colname, types.NULLTYPE), i) + else: + rec = (types.NULLTYPE, i) + if rec[0] is None: + raise DBAPIError("None for metadata " + colname) + if self.props.setdefault(colname, rec) is not rec: + self.props[colname] = (ResultProxy.AmbiguousColumn(colname), 0) + self.keys.append(colname) + self.props[i] = rec + i+=1 + def close(self): + if not self.closed: + self.closed = True + if self.connection.should_close_with_result and self.dialect.supports_autoclose_results: + self.connection.close() + def _get_col(self, row, key): + if isinstance(key, sql.ColumnElement): + try: + rec = self.props[key._label.lower()] + except KeyError: + try: + rec = self.props[key.key.lower()] + except KeyError: + rec = self.props[key.name.lower()] + elif isinstance(key, str): + rec = self.props[key.lower()] + else: + rec = self.props[key] + return rec[0].dialect_impl(self.dialect).convert_result_value(row[rec[1]], self.dialect) + + def __iter__(self): + while True: + row = self.fetchone() + if row is None: + raise StopIteration + else: + yield row + + def last_inserted_ids(self): + return self.executioncontext.last_inserted_ids() + def last_updated_params(self): + return self.executioncontext.last_updated_params() + def last_inserted_params(self): + return self.executioncontext.last_inserted_params() + def lastrow_has_defaults(self): + return self.executioncontext.lastrow_has_defaults() + def supports_sane_rowcount(self): + return self.executioncontext.supports_sane_rowcount() + + def fetchall(self): + """fetches all rows, just like DBAPI cursor.fetchall().""" + l = [] + while True: + v = self.fetchone() + if v is None: + return l + l.append(v) + + def fetchone(self): + """fetches one row, just like DBAPI cursor.fetchone().""" + row = self.cursor.fetchone() + if row is not None: + if self.echo: self.engine.log(repr(row)) + return RowProxy(self, row) + else: + # controversy! can we auto-close the cursor after results are consumed ? + # what if the returned rows are still hanging around, and are "live" objects + # and not just plain tuples ? 
+ self.close() + return None + +class RowProxy: + """proxies a single cursor row for a parent ResultProxy.""" + def __init__(self, parent, row): + """RowProxy objects are constructed by ResultProxy objects.""" + self.__parent = parent + self.__row = row + def close(self): + self.__parent.close() + def __iter__(self): + for i in range(0, len(self.__row)): + yield self.__parent._get_col(self.__row, i) + def __eq__(self, other): + return (other is self) or (other == tuple([self.__parent._get_col(self.__row, key) for key in range(0, len(self.__row))])) + def __repr__(self): + return repr(tuple([self.__parent._get_col(self.__row, key) for key in range(0, len(self.__row))])) + def __getitem__(self, key): + return self.__parent._get_col(self.__row, key) + def __getattr__(self, name): + try: + return self.__parent._get_col(self.__row, name) + except KeyError: + raise AttributeError + def items(self): + return [(key, getattr(self, key)) for key in self.keys()] + def keys(self): + return self.__parent.keys + def values(self): + return list(self) + def __len__(self): + return len(self.__row) + +class SchemaIterator(schema.SchemaVisitor): + """a visitor that can gather text into a buffer and execute the contents of the buffer.""" + def __init__(self, engine, proxy, **params): + self.proxy = proxy + self.engine = engine + self.buffer = StringIO.StringIO() + + def append(self, s): + """appends content to the SchemaIterator's query buffer.""" + self.buffer.write(s) + + def execute(self): + """executes the contents of the SchemaIterator's buffer using its sql proxy and + clears out the buffer.""" + try: + return self.proxy(self.buffer.getvalue(), None) + finally: + self.buffer.truncate(0) + +class DefaultRunner(schema.SchemaVisitor): + def __init__(self, engine, proxy): + self.proxy = proxy + self.engine = engine + + def get_column_default(self, column): + if column.default is not None: + return column.default.accept_schema_visitor(self) + else: + return None + + def get_column_onupdate(self, column): + if column.onupdate is not None: + return column.onupdate.accept_schema_visitor(self) + else: + return None + + def visit_passive_default(self, default): + """passive defaults by definition return None on the app side, + and are post-fetched to get the DB-side value""" + return None + + def visit_sequence(self, seq): + """sequences are not supported by default""" + return None + + def exec_default_sql(self, default): + c = sql.select([default.arg], engine=self.engine).compile() + return self.proxy(str(c), c.get_params()).fetchone()[0] + + def visit_column_onupdate(self, onupdate): + if isinstance(onupdate.arg, sql.ClauseElement): + return self.exec_default_sql(onupdate) + elif callable(onupdate.arg): + return onupdate.arg() + else: + return onupdate.arg + + def visit_column_default(self, default): + if isinstance(default.arg, sql.ClauseElement): + return self.exec_default_sql(default) + elif callable(default.arg): + return default.arg() + else: + return default.arg diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py new file mode 100644 index 000000000..40978204a --- /dev/null +++ b/lib/sqlalchemy/engine/default.py @@ -0,0 +1,213 @@ +# engine/default.py +# Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + + +from sqlalchemy import schema, exceptions, util, sql, types +import sqlalchemy.pool +import StringIO, sys, re +import base + 
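The Connection and Transaction classes in engine/base.py above form the new explicit-connection API. A minimal, hedged usage sketch: the URL, the users table and its column are illustrative only, with the URL form following the RFC-1738 parser added in engine/url.py later in this changeset.

```python
# Sketch of the explicit Connection/Transaction API defined above
# (engine/base.py); all names here are hypothetical, not part of the diff.
from sqlalchemy import create_engine

engine = create_engine('postgres://scott:tiger@localhost/test')
conn = engine.connect()        # a Connection wrapping a pooled DBAPI connection
trans = conn.begin()           # outermost Transaction; nested begin()s stack on it
try:
    conn.execute("INSERT INTO users (user_name) VALUES ('ed')")
    trans.commit()             # only the outermost Transaction actually commits
except:
    trans.rollback()           # marks the whole transaction nesting inactive
    raise

row = conn.execute("SELECT user_name FROM users").fetchone()
print row['user_name']         # ResultProxy/RowProxy name-based column access
conn.close()
```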
+"""provides default implementations of the engine interfaces"""
+
+
+class PoolConnectionProvider(base.ConnectionProvider):
+    def __init__(self, dialect, url, poolclass=None, pool=None, **kwargs):
+        (cargs, cparams) = dialect.create_connect_args(url)
+        if pool is None:
+            kwargs.setdefault('echo', False)
+            kwargs.setdefault('use_threadlocal', True)
+            if poolclass is None:
+                poolclass = sqlalchemy.pool.QueuePool
+            dbapi = dialect.dbapi()
+            if dbapi is None:
+                raise exceptions.InvalidRequestError("Can't get DBAPI module for dialect '%s'" % dialect)
+            self._pool = poolclass(lambda: dbapi.connect(*cargs, **cparams), **kwargs)
+        else:
+            if isinstance(pool, sqlalchemy.pool.DBProxy):
+                self._pool = pool.get_pool(*cargs, **cparams)
+            else:
+                self._pool = pool
+    def get_connection(self):
+        return self._pool.connect()
+    def dispose(self):
+        self._pool.dispose()
+        if hasattr(self, '_dbproxy'):
+            self._dbproxy.dispose()
+
+class DefaultDialect(base.Dialect):
+    """default implementation of Dialect"""
+    def __init__(self, convert_unicode=False, encoding='utf-8', **kwargs):
+        self.convert_unicode = convert_unicode
+        self.supports_autoclose_results = True
+        self.encoding = encoding
+        self.positional = False
+        self.paramstyle = 'named'
+        self._ischema = None
+        self._figure_paramstyle()
+    def create_execution_context(self):
+        return DefaultExecutionContext(self)
+    def type_descriptor(self, typeobj):
+        """provides a database-specific TypeEngine object, given the generic object
+        which comes from the types module. Subclasses will usually use the adapt_type()
+        method in the types module to make this job easy."""
+        if type(typeobj) is type:
+            typeobj = typeobj()
+        return typeobj
+    def oid_column_name(self):
+        return None
+    def supports_sane_rowcount(self):
+        return True
+    def do_begin(self, connection):
+        """implementations might want to put logic here for turning autocommit on/off,
+        etc."""
+        pass
+    def do_rollback(self, connection):
+        """implementations might want to put logic here for turning autocommit on/off,
+        etc."""
+        #print "ENGINE ROLLBACK ON ", connection.connection
+        connection.rollback()
+    def do_commit(self, connection):
+        """implementations might want to put logic here for turning autocommit on/off, etc."""
+        #print "ENGINE COMMIT ON ", connection.connection
+        connection.commit()
+    def do_executemany(self, cursor, statement, parameters, **kwargs):
+        cursor.executemany(statement, parameters)
+    def do_execute(self, cursor, statement, parameters, **kwargs):
+        cursor.execute(statement, parameters)
+    def defaultrunner(self, engine, proxy):
+        return base.DefaultRunner(engine, proxy)
+
+    def _set_paramstyle(self, style):
+        self._paramstyle = style
+        self._figure_paramstyle(style)
+    paramstyle = property(lambda s:s._paramstyle, _set_paramstyle)
+
+    def convert_compiled_params(self, parameters):
+        executemany = parameters is not None and isinstance(parameters, list)
+        # the bind params are a CompiledParams object. but all the DBAPI's hate
+        # that object (or similar). so convert it to a clean
+        # dictionary/list/tuple of dictionary/tuple of list
+        if parameters is not None:
+            if self.positional:
+                if executemany:
+                    parameters = [p.values() for p in parameters]
+                else:
+                    parameters = parameters.values()
+            else:
+                if executemany:
+                    parameters = [p.get_raw_dict() for p in parameters]
+                else:
+                    parameters = parameters.get_raw_dict()
+        return parameters
+
+    def _figure_paramstyle(self, paramstyle=None):
+        db = self.dbapi()
+        if paramstyle is not None:
+            self._paramstyle = paramstyle
+        elif db is not None:
+            self._paramstyle = db.paramstyle
+        else:
+            self._paramstyle = 'named'
+
+        if self._paramstyle == 'named':
+            self.positional = False
+        elif self._paramstyle == 'pyformat':
+            self.positional = False
+        elif self._paramstyle == 'qmark' or self._paramstyle == 'format' or self._paramstyle == 'numeric':
+            # for positional, use pyformat internally, ANSICompiler will convert
+            # to appropriate character upon compilation
+            self.positional = True
+        else:
+            raise exceptions.DBAPIError("Unsupported paramstyle '%s'" % self._paramstyle)
+
+    def _get_ischema(self):
+        # We use a property for ischema so that the accessor
+        # creation only happens as needed, since otherwise we
+        # have a circularity problem with the generic
+        # ansisql.engine()
+        if self._ischema is None:
+            import sqlalchemy.databases.information_schema as ischema
+            self._ischema = ischema.ISchema(self)
+        return self._ischema
+    ischema = property(_get_ischema, doc="""returns an ISchema object for this engine, which allows access to information_schema tables (if supported)""")
+
+class DefaultExecutionContext(base.ExecutionContext):
+    def __init__(self, dialect):
+        self.dialect = dialect
+    def pre_exec(self, engine, proxy, compiled, parameters):
+        self._process_defaults(engine, proxy, compiled, parameters)
+    def post_exec(self, engine, proxy, compiled, parameters):
+        pass
+    def get_rowcount(self, cursor):
+        if hasattr(self, '_rowcount'):
+            return self._rowcount
+        else:
+            return cursor.rowcount
+    def supports_sane_rowcount(self):
+        return self.dialect.supports_sane_rowcount()
+    def last_inserted_ids(self):
+        return self._last_inserted_ids
+    def last_inserted_params(self):
+        return self._last_inserted_params
+    def last_updated_params(self):
+        return self._last_updated_params
+    def lastrow_has_defaults(self):
+        return self._lastrow_has_defaults
+    def _process_defaults(self, engine, proxy, compiled, parameters):
+        """INSERT and UPDATE statements, when compiled, may have additional columns added to their
+        VALUES and SET lists corresponding to column defaults/onupdates that are present on the
+        Table object (i.e. ColumnDefault, Sequence, PassiveDefault). This method pre-executes those
+        DefaultGenerator objects that require pre-execution and sets their values within the
+        parameter list, and flags the thread-local state about
+        PassiveDefault objects that may require post-fetching the row after it is inserted/updated.
+        This method relies upon logic within the ANSISQLCompiler in its visit_insert and
+        visit_update methods that add the appropriate column clauses to the statement when it's
+        being compiled, so that these parameters can be bound to the statement."""
+        if compiled is None: return
+        if getattr(compiled, "isinsert", False):
+            if isinstance(parameters, list):
+                plist = parameters
+            else:
+                plist = [parameters]
+            drunner = self.dialect.defaultrunner(engine, proxy)
+            self._lastrow_has_defaults = False
+            for param in plist:
+                last_inserted_ids = []
+                need_lastrowid = False
+                for c in compiled.statement.table.c:
+                    if not param.has_key(c.name) or param[c.name] is None:
+                        if isinstance(c.default, schema.PassiveDefault):
+                            self._lastrow_has_defaults = True
+                        newid = drunner.get_column_default(c)
+                        if newid is not None:
+                            param[c.name] = newid
+                            if c.primary_key:
+                                last_inserted_ids.append(param[c.name])
+                        elif c.primary_key:
+                            need_lastrowid = True
+                    elif c.primary_key:
+                        last_inserted_ids.append(param[c.name])
+                if need_lastrowid:
+                    self._last_inserted_ids = None
+                else:
+                    self._last_inserted_ids = last_inserted_ids
+                self._last_inserted_params = param
+        elif getattr(compiled, 'isupdate', False):
+            if isinstance(parameters, list):
+                plist = parameters
+            else:
+                plist = [parameters]
+            drunner = self.dialect.defaultrunner(engine, proxy)
+            self._lastrow_has_defaults = False
+            for param in plist:
+                for c in compiled.statement.table.c:
+                    if c.onupdate is not None and (not param.has_key(c.name) or param[c.name] is None):
+                        value = drunner.get_column_onupdate(c)
+                        if value is not None:
+                            param[c.name] = value
+                self._last_updated_params = param
+
+
diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py
new file mode 100644
index 000000000..a4f406502
--- /dev/null
+++ b/lib/sqlalchemy/engine/strategies.py
@@ -0,0 +1,70 @@
+"""defines different strategies for creating new instances of sql.Engine.
+by default there are two, one which is the "thread-local" strategy, one which is the "plain" strategy.
+new strategies can be added via constructing a new EngineStrategy object which will add itself to the
+list of available strategies here, or replace one of the existing names.
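The _process_defaults() method above is what makes column defaults and sequences work transparently at execution time. A hedged sketch at the schema level; the table, URL and the BoundMetaData binding are illustrative assumptions, not taken from this diff.

```python
# Illustrative use of the default pre-execution implemented by
# _process_defaults() and DefaultRunner above; all names are hypothetical.
import datetime
from sqlalchemy import (BoundMetaData, Table, Column, Integer, String,
                        DateTime, Sequence)

metadata = BoundMetaData('postgres://scott:tiger@localhost/test')
users = Table('users', metadata,
    Column('user_id', Integer, Sequence('user_id_seq'), primary_key=True),
    Column('user_name', String(40)),
    Column('created', DateTime, default=datetime.datetime.now),    # pre-executed on INSERT
    Column('updated', DateTime, onupdate=datetime.datetime.now))   # pre-executed on UPDATE

r = users.insert().execute(user_name='ed')
r.last_inserted_ids()       # primary key values gathered during pre-execution
r.lastrow_has_defaults()    # True when a PassiveDefault requires a post-fetch
```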
+this can be accomplished via a mod; see the sqlalchemy/mods package for details."""
+
+
+from sqlalchemy.engine import base, default, threadlocal, url
+
+strategies = {}
+
+class EngineStrategy(object):
+    """defines a function that receives input arguments and produces an instance of sql.Engine, typically
+    an instance of sqlalchemy.engine.base.ComposedSQLEngine or a subclass."""
+    def __init__(self, name):
+        """constructs a new EngineStrategy object and sets it in the list of available strategies
+        under this name."""
+        self.name = name
+        strategies[self.name] = self
+    def create(self, *args, **kwargs):
+        """given arguments, returns a new sql.Engine instance."""
+        raise NotImplementedError()
+
+
+class PlainEngineStrategy(EngineStrategy):
+    def __init__(self):
+        EngineStrategy.__init__(self, 'plain')
+    def create(self, name_or_url, **kwargs):
+        u = url.make_url(name_or_url)
+        module = u.get_module()
+
+        dialect = module.dialect(**kwargs)
+
+        poolargs = {}
+        for key in (('echo', 'echo_pool'), ('pool_size', 'pool_size'), ('max_overflow', 'max_overflow'), ('poolclass', 'poolclass'), ('pool_timeout','timeout')):
+            if kwargs.has_key(key[0]):
+                poolargs[key[1]] = kwargs[key[0]]
+        poolclass = getattr(module, 'poolclass', None)
+        if poolclass is not None:
+            poolargs.setdefault('poolclass', poolclass)
+        poolargs['use_threadlocal'] = False
+        provider = default.PoolConnectionProvider(dialect, u, **poolargs)
+
+        return base.ComposedSQLEngine(provider, dialect, **kwargs)
+PlainEngineStrategy()
+
+class ThreadLocalEngineStrategy(EngineStrategy):
+    def __init__(self):
+        EngineStrategy.__init__(self, 'threadlocal')
+    def create(self, name_or_url, **kwargs):
+        u = url.make_url(name_or_url)
+        module = u.get_module()
+
+        dialect = module.dialect(**kwargs)
+
+        poolargs = {}
+        for key in (('echo', 'echo_pool'), ('pool_size', 'pool_size'), ('max_overflow', 'max_overflow'), ('poolclass', 'poolclass'), ('pool_timeout','timeout')):
+            if kwargs.has_key(key[0]):
+                poolargs[key[1]] = kwargs[key[0]]
+        poolclass = getattr(module, 'poolclass', None)
+        if poolclass is not None:
+            poolargs.setdefault('poolclass', poolclass)
+        poolargs['use_threadlocal'] = True
+        provider = threadlocal.TLocalConnectionProvider(dialect, u, **poolargs)
+
+        return threadlocal.TLEngine(provider, dialect, **kwargs)
+ThreadLocalEngineStrategy()
+
+
+
diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py
new file mode 100644
index 000000000..85628c208
--- /dev/null
+++ b/lib/sqlalchemy/engine/threadlocal.py
@@ -0,0 +1,84 @@
+from sqlalchemy import schema, exceptions, util, sql, types
+import StringIO, sys, re
+import base, default
+
+"""provides a thread-local transactional wrapper around the basic ComposedSQLEngine. multiple calls to engine.connect()
+will return the same connection for the same thread. also provides begin/commit methods on the engine itself
+which correspond to a thread-local transaction."""
+
+class TLTransaction(base.Transaction):
+    def rollback(self):
+        try:
+            base.Transaction.rollback(self)
+        finally:
+            try:
+                del self.connection.engine.context.transaction
+            except AttributeError:
+                pass
+    def commit(self):
+        try:
+            base.Transaction.commit(self)
+            stack = self.connection.engine.context.transaction
+            stack.pop()
+            if len(stack) == 0:
+                del self.connection.engine.context.transaction
+        except:
+            try:
+                del self.connection.engine.context.transaction
+            except AttributeError:
+                pass
+            raise
+
+class TLConnection(base.Connection):
+    def _create_transaction(self, parent):
+        return TLTransaction(self, parent)
+    def begin(self):
+        t = base.Connection.begin(self)
+        if not hasattr(self.engine.context, 'transaction'):
+            self.engine.context.transaction = []
+        self.engine.context.transaction.append(t)
+        return t
+
+class TLEngine(base.ComposedSQLEngine):
+    """a ComposedSQLEngine that includes support for thread-local managed transactions. This engine
+    is better suited to be used with a threadlocal Pool object."""
+    def __init__(self, *args, **kwargs):
+        """the TLEngine relies upon the ConnectionProvider having "threadlocal" behavior,
+        so that once a connection is checked out for the current thread, you get that same connection
+        repeatedly."""
+        base.ComposedSQLEngine.__init__(self, *args, **kwargs)
+        self.context = util.ThreadLocal()
+    def raw_connection(self):
+        """returns a DBAPI connection."""
+        return self.connection_provider.get_connection()
+    def connect(self, **kwargs):
+        """returns a Connection that is not thread-locally scoped. this is the equivalent to calling
+        "connect()" on a ComposedSQLEngine."""
+        return base.Connection(self, self.connection_provider.unique_connection())
+    def contextual_connect(self, **kwargs):
+        """returns a TLConnection which is thread-locally scoped."""
+        return TLConnection(self, **kwargs)
+    def begin(self):
+        return self.connect().begin()
+    def commit(self):
+        if hasattr(self.context, 'transaction'):
+            self.context.transaction[-1].commit()
+    def rollback(self):
+        if hasattr(self.context, 'transaction'):
+            self.context.transaction[-1].rollback()
+    def transaction(self, func, *args, **kwargs):
+        """executes the given function within a transaction boundary. this is a shortcut for
+        explicitly calling begin() and commit() and optionally rollback() when exceptions are raised.
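TLEngine above keeps a per-thread transaction stack on engine.context. A hedged sketch of the resulting engine-level transaction style; the strategy keyword to create_engine() is an assumption about how the 'threadlocal' strategy registered in engine/strategies.py gets selected, and the URL is illustrative.

```python
# Assumed wiring: create_engine() picks the 'threadlocal' strategy by name.
from sqlalchemy import create_engine

engine = create_engine('postgres://scott:tiger@localhost/test',
                       strategy='threadlocal')
engine.begin()                 # opens a transaction bound to the current thread
try:
    engine.execute("UPDATE users SET user_name='jack'")
    engine.commit()            # commits this thread's transaction
except:
    engine.rollback()
    raise
```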
+ The given *args and **kwargs will be passed to the function as well, which could be handy + in constructing decorators.""" + trans = self.begin() + try: + func(*args, **kwargs) + except: + trans.rollback() + raise + trans.commit() + +class TLocalConnectionProvider(default.PoolConnectionProvider): + def unique_connection(self): + return self._pool.unique_connection() diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py new file mode 100644 index 000000000..d79213c68 --- /dev/null +++ b/lib/sqlalchemy/engine/url.py @@ -0,0 +1,81 @@ +import re +import cgi + +class URL(object): + def __init__(self, drivername, username=None, password=None, host=None, port=None, database=None): + self.drivername = drivername + self.username = username + self.password = password + self.host = host + self.port = port + self.database= database + def __str__(self): + s = self.drivername + "://" + if self.username is not None: + s += self.username + if self.password is not None: + s += ':' + self.password + s += "@" + if self.host is not None: + s += self.host + if self.port is not None: + s += ':' + self.port + if self.database is not None: + s += '/' + self.database + return s + def get_module(self): + return getattr(__import__('sqlalchemy.databases.%s' % self.drivername).databases, self.drivername) + def translate_connect_args(self, names): + """translates this URL's attributes into a dictionary of connection arguments used by a specific dbapi. + the names parameter is a list of argument names in the form ('host', 'database', 'user', 'password', 'port') + where the given strings match the corresponding argument names for the dbapi. Will return a dictionary + with the dbapi-specific parameters.""" + a = {} + attribute_names = ['host', 'database', 'username', 'password', 'port'] + for n in names: + sname = attribute_names.pop(0) + if n is None: + continue + if getattr(self, sname, None) is not None: + a[n] = getattr(self, sname) + return a + + +def make_url(name_or_url): + if isinstance(name_or_url, str): + return _parse_rfc1738_args(name_or_url) + else: + return name_or_url + +def _parse_rfc1738_args(name): + pattern = re.compile(r''' + (\w+):// + (?: + ([^:]*) + (?::(.*))? + @)? + (?: + ([^/:]*) + (?::([^/]*))? + )? + (?:/(.*))? + ''' + , re.X) + + m = pattern.match(name) + if m is not None: + (name, username, password, host, port, database) = m.group(1, 2, 3, 4, 5, 6) + opts = {'username':username,'password':password,'host':host,'port':port,'database':database} + return URL(name, **opts) + else: + return None + +def _parse_keyvalue_args(name): + m = re.match( r'(\w+)://(.*)', name) + if m is not None: + (name, args) = m.group(1, 2) + opts = dict( cgi.parse_qsl( args ) ) + return URL(name, *opts) + else: + return None + diff --git a/lib/sqlalchemy/exceptions.py b/lib/sqlalchemy/exceptions.py index e270225d8..c942ab4c2 100644 --- a/lib/sqlalchemy/exceptions.py +++ b/lib/sqlalchemy/exceptions.py @@ -25,8 +25,8 @@ class ArgumentError(SQLAlchemyError): objects. 
This error generally corresponds to construction time state errors.""" pass -class CommitError(SQLAlchemyError): - """raised when an invalid condition is detected upon a commit()""" +class FlushError(SQLAlchemyError): + """raised when an invalid condition is detected upon a flush()""" pass class InvalidRequestError(SQLAlchemyError): diff --git a/lib/sqlalchemy/ext/activemapper.py b/lib/sqlalchemy/ext/activemapper.py index f875b30b3..74f4df349 100644 --- a/lib/sqlalchemy/ext/activemapper.py +++ b/lib/sqlalchemy/ext/activemapper.py @@ -1,8 +1,30 @@ -from sqlalchemy import assign_mapper, relation, exceptions +from sqlalchemy import create_session, relation, mapper, join, DynamicMetaData, class_mapper +from sqlalchemy import and_, or_ from sqlalchemy import Table, Column, ForeignKey +from sqlalchemy.ext.sessioncontext import SessionContext +from sqlalchemy.ext.assignmapper import assign_mapper +from sqlalchemy import backref as create_backref import inspect import sys +import sets + +# +# the "proxy" to the database engine... this can be swapped out at runtime +# +metadata = DynamicMetaData("activemapper") + +# +# thread local SessionContext +# +class Objectstore(SessionContext): + def __getattr__(self, key): + return getattr(self.current, key) + def get_session(self): + return self.current + +objectstore = Objectstore(create_session) + # # declarative column declaration - this is so that we can infer the colname @@ -40,7 +62,7 @@ class one_to_many(relationship): class one_to_one(relationship): def __init__(self, classname, colname=None, backref=None, private=False, lazy=True): - relationship.__init__(self, classname, colname, backref, private, lazy, uselist=False) + relationship.__init__(self, classname, colname, create_backref(backref, uselist=False), private, lazy, uselist=False) class many_to_many(relationship): def __init__(self, classname, secondary, backref=None, lazy=True): @@ -56,43 +78,15 @@ class many_to_many(relationship): # __deferred_classes__ = [] -__processed_classes__ = [] - -def check_relationships(klass): - #Check the class for foreign_keys recursively. If some foreign table is not found, the processing of the table - #must be defered. - for keyname in klass.table._foreign_keys: - xtable = keyname._colspec[:keyname._colspec.find('.')] - tablefound = False - for xclass in ActiveMapperMeta.classes: - if ActiveMapperMeta.classes[xclass].table.from_name == xtable: - tablefound = True - break - if tablefound==False: - #The refered table has not yet been created. - return False - - return True - - -def process_relationships(klass): +def process_relationships(klass, was_deferred=False): defer = False for propname, reldesc in klass.relations.items(): - # we require that every related table has been processed first - if not reldesc.classname in __processed_classes__: - if not klass._classname in __deferred_classes__: __deferred_classes__.append(klass._classname) - defer = True - - # check every column item to see if it points to an existing table - # if it does not, defer... 
- if not defer: - if not check_relationships(klass): - if not klass._classname in __deferred_classes__: __deferred_classes__.append(klass._classname) + if not reldesc.classname in ActiveMapperMeta.classes: + if not was_deferred: __deferred_classes__.append(klass) defer = True if not defer: relations = {} - for propname, reldesc in klass.relations.items(): relclass = ActiveMapperMeta.classes[reldesc.classname] relations[propname] = relation(relclass.mapper, @@ -101,40 +95,39 @@ def process_relationships(klass): private=reldesc.private, lazy=reldesc.lazy, uselist=reldesc.uselist) - if len(relations) > 0: - assign_ok = True - try: - assign_mapper(klass, klass.table, properties=relations) - except exceptions.ArgumentError: - assign_ok = False - - if assign_ok: - __processed_classes__.append(klass._classname) - if klass._classname in __deferred_classes__: __deferred_classes__.remove(klass._classname) - else: - __processed_classes__.append(klass._classname) - + class_mapper(klass).add_properties(relations) + #assign_mapper(objectstore, klass, klass.table, properties=relations, + # inherits=getattr(klass, "_base_mapper", None)) + if was_deferred: __deferred_classes__.remove(klass) + + if not was_deferred: for deferred_class in __deferred_classes__: - process_relationships(ActiveMapperMeta.classes[deferred_class]) + process_relationships(deferred_class, was_deferred=True) + class ActiveMapperMeta(type): classes = {} - + metadatas = sets.Set() def __init__(cls, clsname, bases, dict): table_name = clsname.lower() columns = [] relations = {} - + _metadata = getattr( sys.modules[cls.__module__], "__metadata__", metadata ) + if 'mapping' in dict: members = inspect.getmembers(dict.get('mapping')) for name, value in members: if name == '__table__': table_name = value continue - + + if '__metadata__' == name: + _metadata= value + continue + if name.startswith('__'): continue - + if isinstance(value, column): if value.foreign_key: col = Column(value.colname or name, @@ -149,29 +142,29 @@ class ActiveMapperMeta(type): *value.args, **value.kwargs) columns.append(col) continue - + if isinstance(value, relationship): relations[name] = value - - cls.table = Table(table_name, redefine=True, *columns) - + assert _metadata is not None, "No MetaData specified" + ActiveMapperMeta.metadatas.add(_metadata) + cls.table = Table(table_name, _metadata, *columns) # check for inheritence - if hasattr(bases[0], "mapping"): - cls._base_mapper = bases[0].mapper - assign_mapper(cls, cls.table, inherits=cls._base_mapper) - elif len(relations) == 0: - assign_mapper(cls, cls.table) + if hasattr( bases[0], "mapping" ): + cls._base_mapper= bases[0].mapper + assign_mapper(objectstore, cls, cls.table, inherits=cls._base_mapper) + else: + assign_mapper(objectstore, cls, cls.table) cls.relations = relations - cls._classname = clsname ActiveMapperMeta.classes[clsname] = cls + process_relationships(cls) - + super(ActiveMapperMeta, cls).__init__(clsname, bases, dict) class ActiveMapper(object): __metaclass__ = ActiveMapperMeta - + def set(self, **kwargs): for key, value in kwargs.items(): setattr(self, key, value) @@ -182,12 +175,9 @@ class ActiveMapper(object): # def create_tables(): - for klass in ActiveMapperMeta.classes.values(): - klass.table.create() - -# -# a utility function to drop all tables for all ActiveMapper classes -# + for metadata in ActiveMapperMeta.metadatas: + metadata.create_all() def drop_tables(): - for klass in ActiveMapperMeta.classes.values(): - klass.table.drop()
\ No newline at end of file + for metadata in ActiveMapperMeta.metadatas: + metadata.drop_all() + diff --git a/lib/sqlalchemy/ext/assignmapper.py b/lib/sqlalchemy/ext/assignmapper.py new file mode 100644 index 000000000..b8a676b75 --- /dev/null +++ b/lib/sqlalchemy/ext/assignmapper.py @@ -0,0 +1,34 @@ +from sqlalchemy import mapper, util +import types + +def monkeypatch_query_method(ctx, class_, name): + def do(self, *args, **kwargs): + query = class_.mapper.query(session=ctx.current) + return getattr(query, name)(*args, **kwargs) + setattr(class_, name, classmethod(do)) + +def monkeypatch_objectstore_method(ctx, class_, name): + def do(self, *args, **kwargs): + session = ctx.current + return getattr(session, name)(self, *args, **kwargs) + setattr(class_, name, do) + +def assign_mapper(ctx, class_, *args, **kwargs): + kwargs.setdefault("is_primary", True) + if not isinstance(getattr(class_, '__init__'), types.MethodType): + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + class_.__init__ = __init__ + extension = kwargs.pop('extension', None) + if extension is not None: + extension = util.to_list(extension) + extension.append(ctx.mapper_extension) + else: + extension = ctx.mapper_extension + m = mapper(class_, extension=extension, *args, **kwargs) + class_.mapper = m + for name in ['get', 'select', 'select_by', 'selectone', 'get_by', 'join_to', 'join_via']: + monkeypatch_query_method(ctx, class_, name) + for name in ['flush', 'delete', 'expire', 'refresh', 'expunge', 'merge', 'update', 'save_or_update']: + monkeypatch_objectstore_method(ctx, class_, name) diff --git a/lib/sqlalchemy/ext/proxy.py b/lib/sqlalchemy/ext/proxy.py index a24f089e9..deced55b4 100644 --- a/lib/sqlalchemy/ext/proxy.py +++ b/lib/sqlalchemy/ext/proxy.py @@ -5,14 +5,11 @@ except ImportError: from sqlalchemy import sql from sqlalchemy.engine import create_engine -from sqlalchemy.types import TypeEngine -import sqlalchemy.schema as schema -import thread, weakref -class BaseProxyEngine(schema.SchemaEngine): - ''' - Basis for all proxy engines - ''' +__all__ = ['BaseProxyEngine', 'AutoConnectEngine', 'ProxyEngine'] + +class BaseProxyEngine(sql.Engine): + """Basis for all proxy engines.""" def get_engine(self): raise NotImplementedError @@ -21,66 +18,50 @@ class BaseProxyEngine(schema.SchemaEngine): raise NotImplementedError engine = property(lambda s:s.get_engine(), lambda s,e:s.set_engine(e)) - - def reflecttable(self, table): - return self.get_engine().reflecttable(table) + def execute_compiled(self, *args, **kwargs): - return self.get_engine().execute_compiled(*args, **kwargs) - def compiler(self, *args, **kwargs): - return self.get_engine().compiler(*args, **kwargs) - def schemagenerator(self, *args, **kwargs): - return self.get_engine().schemagenerator(*args, **kwargs) - def schemadropper(self, *args, **kwargs): - return self.get_engine().schemadropper(*args, **kwargs) - - def hash_key(self): - return "%s(%s)" % (self.__class__.__name__, id(self)) + """this method is required to be present as it overrides the execute_compiled present in sql.Engine""" + return self.get_engine().execute_compiled(*args, **kwargs) + def compiler(self, *args, **kwargs): + """this method is required to be present as it overrides the compiler method present in sql.Engine""" + return self.get_engine().compiler(*args, **kwargs) - def oid_column_name(self): - # oid_column should not be requested before the engine is connected. - # it should ideally only be called at query compilation time. 
- e= self.get_engine() - if e is None: - return None - return e.oid_column_name() - def __getattr__(self, attr): + """provides proxying for methods that are not otherwise present on this BaseProxyEngine. Note + that methods which are present on the base class sql.Engine will *not* be proxied through this, + and must be explicit on this class.""" # call get_engine() to give subclasses a chance to change # connection establishment behavior - e= self.get_engine() + e = self.get_engine() if e is not None: return getattr(e, attr) - raise AttributeError('No connection established in ProxyEngine: ' - ' no access to %s' % attr) + raise AttributeError("No connection established in ProxyEngine: " + " no access to %s" % attr) + class AutoConnectEngine(BaseProxyEngine): - ''' - An SQLEngine proxy that automatically connects when necessary. - ''' + """An SQLEngine proxy that automatically connects when necessary.""" - def __init__(self, dburi, opts=None, **kwargs): + def __init__(self, dburi, **kwargs): BaseProxyEngine.__init__(self) - self.dburi= dburi - self.opts= opts - self.kwargs= kwargs - self._engine= None + self.dburi = dburi + self.kwargs = kwargs + self._engine = None def get_engine(self): if self._engine is None: if callable(self.dburi): - dburi= self.dburi() + dburi = self.dburi() else: - dburi= self.dburi - self._engine= create_engine( dburi, self.opts, **self.kwargs ) + dburi = self.dburi + self._engine = create_engine(dburi, **self.kwargs) return self._engine - class ProxyEngine(BaseProxyEngine): - """ - SQLEngine proxy. Supports lazy and late initialization by - delegating to a real engine (set with connect()), and using proxy - classes for TypeEngine. + """Engine proxy for lazy and late initialization. + + This engine will delegate access to a real engine set with connect(). """ def __init__(self, **kwargs): @@ -90,14 +71,15 @@ class ProxyEngine(BaseProxyEngine): self.storage.connection = {} self.storage.engine = None self.kwargs = kwargs - - def connect(self, uri, opts=None, **kwargs): - """Establish connection to a real engine. 
- """ - kw = self.kwargs.copy() - kw.update(kwargs) - kwargs = kw - key = "%s(%s,%s)" % (uri, repr(opts), repr(kwargs)) + + def connect(self, *args, **kwargs): + """Establish connection to a real engine.""" + + kwargs.update(self.kwargs) + if not kwargs: + key = repr(args) + else: + key = "%s, %s" % (repr(args), repr(sorted(kwargs.items()))) try: map = self.storage.connection except AttributeError: @@ -107,15 +89,13 @@ class ProxyEngine(BaseProxyEngine): try: self.engine = map[key] except KeyError: - map[key] = create_engine(uri, opts, **kwargs) + map[key] = create_engine(*args, **kwargs) self.storage.engine = map[key] def get_engine(self): if self.storage.engine is None: - raise AttributeError('No connection established') + raise AttributeError("No connection established") return self.storage.engine def set_engine(self, engine): self.storage.engine = engine - - diff --git a/lib/sqlalchemy/ext/selectresults.py b/lib/sqlalchemy/ext/selectresults.py new file mode 100644 index 000000000..5ba9153dd --- /dev/null +++ b/lib/sqlalchemy/ext/selectresults.py @@ -0,0 +1,82 @@ +import sqlalchemy.sql as sql + +import sqlalchemy.orm as orm + + +class SelectResultsExt(orm.MapperExtension): + def select_by(self, query, *args, **params): + return SelectResults(query, query.join_by(*args, **params)) + def select(self, query, arg=None, **kwargs): + if arg is not None and isinstance(arg, sql.Selectable): + return orm.EXT_PASS + else: + return SelectResults(query, arg, ops=kwargs) + +class SelectResults(object): + def __init__(self, query, clause=None, ops={}): + self._query = query + self._clause = clause + self._ops = {} + self._ops.update(ops) + + def count(self): + return self._query.count(self._clause) + + def min(self, col): + return sql.select([sql.func.min(col)], self._clause, **self._ops).scalar() + + def max(self, col): + return sql.select([sql.func.max(col)], self._clause, **self._ops).scalar() + + def sum(self, col): + return sql.select([sql.func.sum(col)], self._clause, **self._ops).scalar() + + def avg(self, col): + return sql.select([sql.func.avg(col)], self._clause, **self._ops).scalar() + + def clone(self): + return SelectResults(self._query, self._clause, self._ops.copy()) + + def filter(self, clause): + new = self.clone() + new._clause = sql.and_(self._clause, clause) + return new + + def order_by(self, order_by): + new = self.clone() + new._ops['order_by'] = order_by + return new + + def limit(self, limit): + return self[:limit] + + def offset(self, offset): + return self[offset:] + + def list(self): + return list(self) + + def __getitem__(self, item): + if isinstance(item, slice): + start = item.start + stop = item.stop + if (isinstance(start, int) and start < 0) or \ + (isinstance(stop, int) and stop < 0): + return list(self)[item] + else: + res = self.clone() + if start is not None and stop is not None: + res._ops.update(dict(offset=self._ops.get('offset', 0)+start, limit=stop-start)) + elif start is None and stop is not None: + res._ops.update(dict(limit=stop)) + elif start is not None and stop is None: + res._ops.update(dict(offset=self._ops.get('offset', 0)+start)) + if item.step is not None: + return list(res)[None:None:item.step] + else: + return res + else: + return list(self[item:item+1])[0] + + def __iter__(self): + return iter(self._query.select_whereclause(self._clause, **self._ops)) diff --git a/lib/sqlalchemy/ext/sessioncontext.py b/lib/sqlalchemy/ext/sessioncontext.py new file mode 100644 index 000000000..f431f87c7 --- /dev/null +++ b/lib/sqlalchemy/ext/sessioncontext.py @@ 
-0,0 +1,55 @@
+from sqlalchemy.util import ScopedRegistry
+from sqlalchemy.orm.mapper import MapperExtension
+
+__all__ = ['SessionContext', 'SessionContextExt']
+
+class SessionContext(object):
+    """A simple wrapper for ScopedRegistry that provides a "current" property
+    which can be used to get, set, or remove the session in the current scope.
+
+    By default this object provides thread-local scoping, which is the default
+    scope provided by sqlalchemy.util.ScopedRegistry.
+
+    Usage:
+        engine = create_engine(...)
+        def session_factory():
+            return Session(bind_to=engine)
+        context = SessionContext(session_factory)
+
+        s = context.current    # get thread-local session
+        context.current = Session(bind_to=other_engine)    # set current session
+        del context.current    # discard the thread-local session (a new one will
+                               # be created on the next call to context.current)
+    """
+    def __init__(self, session_factory, scopefunc=None):
+        self.registry = ScopedRegistry(session_factory, scopefunc)
+        super(SessionContext, self).__init__()
+
+    def get_current(self):
+        return self.registry()
+    def set_current(self, session):
+        self.registry.set(session)
+    def del_current(self):
+        self.registry.clear()
+    current = property(get_current, set_current, del_current,
+            """Property used to get/set/del the session in the current scope""")
+
+    def _get_mapper_extension(self):
+        try:
+            return self._extension
+        except AttributeError:
+            self._extension = ext = SessionContextExt(self)
+        return ext
+    mapper_extension = property(_get_mapper_extension,
+            doc="""get a mapper extension that implements get_session using this context""")
+
+
+class SessionContextExt(MapperExtension):
+    """a mapper extension that provides sessions to a mapper using SessionContext"""
+
+    def __init__(self, context):
+        MapperExtension.__init__(self)
+        self.context = context
+
+    def get_session(self):
+        return self.context.current
diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py
index b1fb0b889..043abc38b 100644
--- a/lib/sqlalchemy/ext/sqlsoup.py
+++ b/lib/sqlalchemy/ext/sqlsoup.py
@@ -1,182 +1,72 @@
-from sqlalchemy import *
-
-"""
-SqlSoup provides a convenient way to access database tables without having
-to declare table or mapper classes ahead of time.
-
-Suppose we have a database with users, books, and loans tables
-(corresponding to the PyWebOff dataset, if you're curious).
-For testing purposes, we can create this db as follows:
-
->>> from sqlalchemy import create_engine
->>> e = create_engine('sqlite://filename=:memory:')
->>> for sql in _testsql: e.execute(sql)
-...
- -Creating a SqlSoup gateway is just like creating an SqlAlchemy engine: ->>> from sqlalchemy.ext.sqlsoup import SqlSoup ->>> soup = SqlSoup('sqlite://filename=:memory:') - -or, you can re-use an existing engine: ->>> soup = SqlSoup(e) - -Loading objects is as easy as this: ->>> users = soup.users.select() ->>> users.sort() ->>> users -[Class_Users(name='Bhargan Basepair',email='basepair@example.edu',password='basepair',classname=None,admin=1), Class_Users(name='Joe Student',email='student@example.edu',password='student',classname=None,admin=0)] - -Of course, letting the database do the sort is better (".c" is short for ".columns"): ->>> soup.users.select(order_by=[soup.users.c.name]) -[Class_Users(name='Bhargan Basepair',email='basepair@example.edu',password='basepair',classname=None,admin=1), - Class_Users(name='Joe Student',email='student@example.edu',password='student',classname=None,admin=0)] - -Field access is intuitive: ->>> users[0].email -u'basepair@example.edu' - -Of course, you don't want to load all users very often. The common case is to -select by a key or other field: ->>> soup.users.selectone_by(name='Bhargan Basepair') -Class_Users(name='Bhargan Basepair',email='basepair@example.edu',password='basepair',classname=None,admin=1) - -All the SqlAlchemy mapper select variants (select, select_by, selectone, selectone_by, selectfirst, selectfirst_by) -are available. See the SqlAlchemy documentation for details: -http://www.sqlalchemy.org/docs/sqlconstruction.myt - -Modifying objects is intuitive: ->>> user = _ ->>> user.email = 'basepair+nospam@example.edu' ->>> soup.commit() - -(SqlSoup leverages the sophisticated SqlAlchemy unit-of-work code, so -multiple updates to a single object will be turned into a single UPDATE -statement when you commit.) - -Finally, insert and delete. Let's insert a new loan, then delete it: ->>> soup.loans.insert(book_id=soup.books.selectfirst().id, user_name=user.name) -Class_Loans(book_id=1,user_name='Bhargan Basepair',loan_date=None) ->>> soup.commit() - ->>> loan = soup.loans.selectone_by(book_id=1, user_name='Bhargan Basepair') ->>> soup.delete(loan) ->>> soup.commit() -""" - -_testsql = """ -CREATE TABLE books ( - id integer PRIMARY KEY, -- auto-SERIAL in sqlite - title text NOT NULL, - published_year char(4) NOT NULL, - authors text NOT NULL -); - -CREATE TABLE users ( - name varchar(32) PRIMARY KEY, - email varchar(128) NOT NULL, - password varchar(128) NOT NULL, - classname text, - admin int NOT NULL -- 0 = false -); - -CREATE TABLE loans ( - book_id int PRIMARY KEY REFERENCES books(id), - user_name varchar(32) references users(name) - ON DELETE SET NULL ON UPDATE CASCADE, - loan_date date NOT NULL DEFAULT current_timestamp -); - -insert into users(name, email, password, admin) -values('Bhargan Basepair', 'basepair@example.edu', 'basepair', 1); -insert into users(name, email, password, admin) -values('Joe Student', 'student@example.edu', 'student', 0); - -insert into books(title, published_year, authors) -values('Mustards I Have Known', '1989', 'Jones'); -insert into books(title, published_year, authors) -values('Regional Variation in Moss', '1971', 'Flim and Flam'); - -insert into loans(book_id, user_name) -values ( - (select min(id) from books), - (select name from users where name like 'Joe%')) -; -""".split(';') - -__all__ = ['NoSuchTableError', 'SqlSoup'] - -class NoSuchTableError(SQLAlchemyError): pass - -# metaclass is necessary to expose class methods with getattr, e.g. 
-# we want to pass db.users.select through to users._mapper.select -class TableClassType(type): - def insert(cls, **kwargs): - o = cls() - o.__dict__.update(kwargs) - return o - def __getattr__(cls, attr): - if attr == '_mapper': - # called during mapper init - raise AttributeError() - return getattr(cls._mapper, attr) - -def class_for_table(table): - klass = TableClassType('Class_' + table.name.capitalize(), (object,), {}) - def __repr__(self): - import locale - encoding = locale.getdefaultlocale()[1] - L = [] - for k in self.__class__.c.keys(): - value = getattr(self, k, '') - if isinstance(value, unicode): - value = value.encode(encoding) - L.append("%s=%r" % (k, value)) - return '%s(%s)' % (self.__class__.__name__, ','.join(L)) - klass.__repr__ = __repr__ - klass._mapper = mapper(klass, table) - return klass - -class SqlSoup: - def __init__(self, *args, **kwargs): - """ - args may either be an SQLEngine or a set of arguments suitable - for passing to create_engine - """ - from sqlalchemy.engine import SQLEngine - # meh, sometimes having method overloading instead of kwargs would be easier - if isinstance(args[0], SQLEngine): - args = list(args) - engine = args.pop(0) - if args or kwargs: - raise ArgumentError('Extra arguments not allowed when engine is given') - else: - engine = create_engine(*args, **kwargs) - self._engine = engine - self._cache = {} - def delete(self, *args, **kwargs): - objectstore.delete(*args, **kwargs) - def commit(self): - objectstore.get_session().commit() - def rollback(self): - objectstore.clear() - def _reset(self): - # for debugging - self._cache = {} - self.rollback() - def __getattr__(self, attr): - try: - t = self._cache[attr] - except KeyError: - table = Table(attr, self._engine, autoload=True) - if table.columns: - t = class_for_table(table) - else: - t = None - self._cache[attr] = t - if not t: - raise NoSuchTableError('%s does not exist' % attr) - return t - -if __name__ == '__main__': - import doctest - doctest.testmod() +from sqlalchemy import *
+from sqlalchemy.exceptions import SQLAlchemyError, ArgumentError
+
+class NoSuchTableError(SQLAlchemyError): pass
+
+# metaclass is necessary to expose class methods with getattr, e.g.
+# we want to pass db.users.select through to users._mapper.select
+class TableClassType(type):
+ def insert(cls, **kwargs):
+ o = cls()
+ o.__dict__.update(kwargs)
+ return o
+ def __getattr__(cls, attr):
+ if attr == '_mapper':
+ # called during mapper init
+ raise AttributeError()
+ return getattr(cls._mapper, attr)
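The comments above explain why TableClassType must be a metaclass: ordinary class attributes cannot fall through to cls._mapper otherwise. A self-contained illustration of that delegation pattern with a stand-in mapper (every name here is hypothetical; Python 2 syntax, matching this codebase):

```python
# Attribute lookups that miss on the class fall through to a backing object,
# which is how db.users.select reaches users._mapper.select above.
class _FakeMapper(object):
    def select_by(self, **kw):
        return "would SELECT by %r" % (kw,)

class DelegatingType(type):
    def __getattr__(cls, attr):
        # only reached for names not found on the class itself
        return getattr(cls._mapper, attr)

class Users(object):
    __metaclass__ = DelegatingType
    _mapper = _FakeMapper()

print Users.select_by(name='ed')   # -> "would SELECT by {'name': 'ed'}"
```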
+
+def class_for_table(table):
+ klass = TableClassType('Class_' + table.name.capitalize(), (object,), {})
+ def __repr__(self):
+ import locale
+ encoding = locale.getdefaultlocale()[1]
+ L = []
+ for k in self.__class__.c.keys():
+ value = getattr(self, k, '')
+ if isinstance(value, unicode):
+ value = value.encode(encoding)
+ L.append("%s=%r" % (k, value))
+ return '%s(%s)' % (self.__class__.__name__, ','.join(L))
+ klass.__repr__ = __repr__
+ klass._mapper = mapper(klass, table)
+ return klass
+
+class SqlSoup:
+ def __init__(self, *args, **kwargs):
+ """
+        args may either be an Engine instance or a set of arguments suitable
+        for passing to create_engine
+ """
+ from sqlalchemy.sql import Engine
+ # meh, sometimes having method overloading instead of kwargs would be easier
+ if isinstance(args[0], Engine):
+            # args arrives as a tuple; copy it to a list so pop() works
+            args = list(args)
+            engine = args.pop(0)
+ if args or kwargs:
+ raise ArgumentError('Extra arguments not allowed when engine is given')
+ else:
+ engine = create_engine(*args, **kwargs)
+ self._engine = engine
+ self._cache = {}
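+    # the methods below proxy to the thread-local objectstore session:
+    # commit() flushes pending changes, while rollback() discards them by
+    # clearing the session outright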
+ def delete(self, *args, **kwargs):
+ objectstore.delete(*args, **kwargs)
+ def commit(self):
+ objectstore.get_session().commit()
+ def rollback(self):
+ objectstore.clear()
+ def _reset(self):
+ # for debugging
+ self._cache = {}
+ self.rollback()
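+    # attribute access reflects the named table from the database on first
+    # use and caches the generated class; unknown names raise NoSuchTableError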
+ def __getattr__(self, attr):
+ try:
+ t = self._cache[attr]
+ except KeyError:
+ table = Table(attr, self._engine, autoload=True)
+ if table.columns:
+ t = class_for_table(table)
+ else:
+ t = None
+ self._cache[attr] = t
+ if not t:
+ raise NoSuchTableError('%s does not exist' % attr)
+ return t
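The rewrite drops the module's doctest walkthrough. As a rough sketch of the intended usage under the reorganized 0.2 layout (the connect string and the sample data are taken from the removed docstring above; whether create_engine still accepts that exact URL form in 0.2 is an assumption):

    from sqlalchemy.ext.sqlsoup import SqlSoup

    soup = SqlSoup('sqlite://filename=:memory:')   # or SqlSoup(existing_engine)
    # first access of soup.users reflects the 'users' table and builds its class
    users = soup.users.select(order_by=[soup.users.c.name])
    user = soup.users.selectone_by(name='Bhargan Basepair')
    user.email = 'basepair+nospam@example.edu'     # plain attribute mutation
    soup.commit()                                  # flushed via the thread-local unit of work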
diff --git a/lib/sqlalchemy/mapping/objectstore.py b/lib/sqlalchemy/mapping/objectstore.py deleted file mode 100644 index 91f947085..000000000 --- a/lib/sqlalchemy/mapping/objectstore.py +++ /dev/null @@ -1,358 +0,0 @@ -# objectstore.py -# Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php - -"""provides the Session object and a function-oriented convenience interface. This is the -"front-end" to the Unit of Work system in unitofwork.py. Issues of "scope" are dealt with here, -primarily through an important function "get_session()", which is where mappers and units of work go to get a handle on the current threa-local context. """ - -from sqlalchemy import util -from sqlalchemy.exceptions import * -import unitofwork -import weakref -import sqlalchemy - -class Session(object): - """Maintains a UnitOfWork instance, including transaction state.""" - - def __init__(self, hash_key=None, new_imap=True, import_session=None): - """Initialize the objectstore with a UnitOfWork registry. If called - with no arguments, creates a single UnitOfWork for all operations. - - nest_transactions - indicates begin/commit statements can be executed in a - "nested", defaults to False which indicates "only commit on the outermost begin/commit" - hash_key - the hash_key used to identify objects against this session, which - defaults to the id of the Session instance. - """ - if import_session is not None: - self.uow = unitofwork.UnitOfWork(identity_map=import_session.uow.identity_map) - elif new_imap is False: - self.uow = unitofwork.UnitOfWork(identity_map=objectstore.get_session().uow.identity_map) - else: - self.uow = unitofwork.UnitOfWork() - - self.binds = {} - if hash_key is None: - self.hash_key = id(self) - else: - self.hash_key = hash_key - _sessions[self.hash_key] = self - - def bind_table(self, table, bindto): - self.binds[table] = bindto - - def get_id_key(ident, class_, entity_name=None): - """returns an identity-map key for use in storing/retrieving an item from the identity - map, given a tuple of the object's primary key values. - - ident - a tuple of primary key values corresponding to the object to be stored. these - values should be in the same order as the primary keys of the table - - class_ - a reference to the object's class - - entity_name - optional string name to further qualify the class - """ - return (class_, tuple(ident), entity_name) - get_id_key = staticmethod(get_id_key) - - def get_row_key(row, class_, primary_key, entity_name=None): - """returns an identity-map key for use in storing/retrieving an item from the identity - map, given a result set row. - - row - a sqlalchemy.dbengine.RowProxy instance or other map corresponding result-set - column names to their values within a row. - - class_ - a reference to the object's class - - primary_key - a list of column objects that will target the primary key values - in the given row. 
- - entity_name - optional string name to further qualify the class - """ - return (class_, tuple([row[column] for column in primary_key]), entity_name) - get_row_key = staticmethod(get_row_key) - - def engines(self, mapper): - return [t.engine for t in mapper.tables] - - def flush(self, *obj): - self.uow.flush(self, *obj) - - def refresh(self, *obj): - """reloads the attributes for the given objects from the database, clears - any changes made.""" - for o in obj: - self.uow.refresh(o) - - def expire(self, *obj): - """invalidates the data in the given objects and sets them to refresh themselves - the next time they are requested.""" - for o in obj: - self.uow.expire(o) - - def expunge(self, *obj): - for o in obj: - self.uow.expunge(o) - - def register_clean(self, obj): - self._bind_to(obj) - self.uow.register_clean(obj) - - def register_new(self, obj): - self._bind_to(obj) - self.uow.register_new(obj) - - def _bind_to(self, obj): - """given an object, binds it to this session. changes on the object will affect - the currently scoped UnitOfWork maintained by this session.""" - obj._sa_session_id = self.hash_key - - def __getattr__(self, key): - """proxy other methods to our underlying UnitOfWork""" - return getattr(self.uow, key) - - def clear(self): - self.uow = unitofwork.UnitOfWork() - - def delete(self, *obj): - """registers the given objects as to be deleted upon the next commit""" - for o in obj: - self.uow.register_deleted(o) - - def import_instance(self, instance): - """places the given instance in the current thread's unit of work context, - either in the current IdentityMap or marked as "new". Returns either the object - or the current corresponding version in the Identity Map. - - this method should be used for any object instance that is coming from a serialized - storage, from another thread (assuming the regular threaded unit of work model), or any - case where the instance was loaded/created corresponding to a different base unitofwork - than the current one.""" - if instance is None: - return None - key = getattr(instance, '_instance_key', None) - mapper = object_mapper(instance) - u = self.uow - if key is not None: - if u.identity_map.has_key(key): - return u.identity_map[key] - else: - instance._instance_key = key - u.identity_map[key] = instance - self._bind_to(instance) - else: - u.register_new(instance) - return instance - -class LegacySession(Session): - def __init__(self, nest_on=None, hash_key=None, **kwargs): - super(LegacySession, self).__init__(**kwargs) - self.parent_uow = None - self.begin_count = 0 - self.nest_on = util.to_list(nest_on) - self.__pushed_count = 0 - def was_pushed(self): - if self.nest_on is None: - return - self.__pushed_count += 1 - if self.__pushed_count == 1: - for n in self.nest_on: - n.push_session() - def was_popped(self): - if self.nest_on is None or self.__pushed_count == 0: - return - self.__pushed_count -= 1 - if self.__pushed_count == 0: - for n in self.nest_on: - n.pop_session() - class SessionTrans(object): - """returned by Session.begin(), denotes a transactionalized UnitOfWork instance. 
- call commit() on this to commit the transaction.""" - def __init__(self, parent, uow, isactive): - self.__parent = parent - self.__isactive = isactive - self.__uow = uow - isactive = property(lambda s:s.__isactive, doc="True if this SessionTrans is the 'active' transaction marker, else its a no-op.") - parent = property(lambda s:s.__parent, doc="returns the parent Session of this SessionTrans object.") - uow = property(lambda s:s.__uow, doc="returns the parent UnitOfWork corresponding to this transaction.") - def begin(self): - """calls begin() on the underlying Session object, returning a new no-op SessionTrans object.""" - if self.parent.uow is not self.uow: - raise InvalidRequestError("This SessionTrans is no longer valid") - return self.parent.begin() - def commit(self): - """commits the transaction noted by this SessionTrans object.""" - self.__parent._trans_commit(self) - self.__isactive = False - def rollback(self): - """rolls back the current UnitOfWork transaction, in the case that begin() - has been called. The changes logged since the begin() call are discarded.""" - self.__parent._trans_rollback(self) - self.__isactive = False - def begin(self): - """begins a new UnitOfWork transaction and returns a tranasaction-holding - object. commit() or rollback() should be called on the returned object. - commit() on the Session will do nothing while a transaction is pending, and further - calls to begin() will return no-op transactional objects.""" - if self.parent_uow is not None: - return Session.SessionTrans(self, self.uow, False) - self.parent_uow = self.uow - self.uow = unitofwork.UnitOfWork(identity_map = self.uow.identity_map) - return Session.SessionTrans(self, self.uow, True) - def commit(self, *objects): - """commits the current UnitOfWork transaction. called with - no arguments, this is only used - for "implicit" transactions when there was no begin(). - if individual objects are submitted, then only those objects are committed, and the - begin/commit cycle is not affected.""" - # if an object list is given, commit just those but dont - # change begin/commit status - if len(objects): - self._commit_uow(*objects) - self.uow.flush(self, *objects) - return - if self.parent_uow is None: - self._commit_uow() - def _trans_commit(self, trans): - if trans.uow is self.uow and trans.isactive: - try: - self._commit_uow() - finally: - self.uow = self.parent_uow - self.parent_uow = None - def _trans_rollback(self, trans): - if trans.uow is self.uow: - self.uow = self.parent_uow - self.parent_uow = None - def _commit_uow(self, *obj): - self.was_pushed() - try: - self.uow.flush(self, *obj) - finally: - self.was_popped() - -Session = LegacySession - -def get_id_key(ident, class_, entity_name=None): - return Session.get_id_key(ident, class_, entity_name) - -def get_row_key(row, class_, primary_key, entity_name=None): - return Session.get_row_key(row, class_, primary_key, entity_name) - -def begin(): - """deprecated. use s = Session(new_imap=False).""" - return get_session().begin() - -def commit(*obj): - """deprecated; use flush(*obj)""" - get_session().flush(*obj) - -def flush(*obj): - """flushes the current UnitOfWork transaction. if a transaction was begun - via begin(), flushes only those objects that were created, modified, or deleted - since that begin statement. otherwise flushes all objects that have been - changed. 
- - if individual objects are submitted, then only those objects are committed, and the - begin/commit cycle is not affected.""" - get_session().flush(*obj) - -def clear(): - """removes all current UnitOfWorks and IdentityMaps for this thread and - establishes a new one. It is probably a good idea to discard all - current mapped object instances, as they are no longer in the Identity Map.""" - get_session().clear() - -def refresh(*obj): - """reloads the state of this object from the database, and cancels any in-memory - changes.""" - get_session().refresh(*obj) - -def expire(*obj): - """invalidates the data in the given objects and sets them to refresh themselves - the next time they are requested.""" - get_session().expire(*obj) - -def expunge(*obj): - get_session().expunge(*obj) - -def delete(*obj): - """registers the given objects as to be deleted upon the next commit""" - s = get_session().delete(*obj) - -def has_key(key): - """returns True if the current thread-local IdentityMap contains the given instance key""" - return get_session().has_key(key) - -def has_instance(instance): - """returns True if the current thread-local IdentityMap contains the given instance""" - return get_session().has_instance(instance) - -def is_dirty(obj): - """returns True if the given object is in the current UnitOfWork's new or dirty list, - or if its a modified list attribute on an object.""" - return get_session().is_dirty(obj) - -def instance_key(instance): - """returns the IdentityMap key for the given instance""" - return get_session().instance_key(instance) - -def import_instance(instance): - return get_session().import_instance(instance) - -def mapper(*args, **params): - return sqlalchemy.mapping.mapper(*args, **params) - -def object_mapper(obj): - return sqlalchemy.mapping.object_mapper(obj) - -def class_mapper(class_): - return sqlalchemy.mapping.class_mapper(class_) - -global_attributes = unitofwork.global_attributes - -session_registry = util.ScopedRegistry(Session) # Default session registry -_sessions = weakref.WeakValueDictionary() # all referenced sessions (including user-created) - -def get_session(obj=None): - # object-specific session ? - if obj is not None: - # does it have a hash key ? 
- hashkey = getattr(obj, '_sa_session_id', None) - if hashkey is not None: - # ok, return that - try: - return _sessions[hashkey] - except KeyError: - raise InvalidRequestError("Session '%s' referenced by object '%s' no longer exists" % (hashkey, repr(obj))) - - return session_registry() - -unitofwork.get_session = get_session -uow = get_session # deprecated - -def push_session(sess): - old = get_session() - if getattr(sess, '_previous', None) is not None: - raise InvalidRequestError("Given Session is already pushed onto some thread's stack") - sess._previous = old - session_registry.set(sess) - sess.was_pushed() - -def pop_session(): - sess = get_session() - old = sess._previous - sess._previous = None - session_registry.set(old) - sess.was_popped() - return old - -def using_session(sess, func): - push_session(sess) - try: - return func() - finally: - pop_session() - diff --git a/lib/sqlalchemy/mapping/util.py b/lib/sqlalchemy/mapping/util.py deleted file mode 100644 index 4d957241b..000000000 --- a/lib/sqlalchemy/mapping/util.py +++ /dev/null @@ -1,31 +0,0 @@ -# mapper/util.py -# Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php - - -import sqlalchemy.sql as sql - -class TableFinder(sql.ClauseVisitor): - """given a Clause, locates all the Tables within it into a list.""" - def __init__(self, table, check_columns=False): - self.tables = [] - self.check_columns = check_columns - if table is not None: - table.accept_visitor(self) - def visit_table(self, table): - self.tables.append(table) - def __len__(self): - return len(self.tables) - def __getitem__(self, i): - return self.tables[i] - def __iter__(self): - return iter(self.tables) - def __contains__(self, obj): - return obj in self.tables - def __add__(self, obj): - return self.tables + list(obj) - def visit_column(self, column): - if self.check_columns: - column.table.accept_visitor(self) diff --git a/lib/sqlalchemy/mods/__init__.py b/lib/sqlalchemy/mods/__init__.py index 328df3c56..e69de29bb 100644 --- a/lib/sqlalchemy/mods/__init__.py +++ b/lib/sqlalchemy/mods/__init__.py @@ -1,7 +0,0 @@ -def install_mods(*mods): - for mod in mods: - if isinstance(mod, str): - _mod = getattr(__import__('sqlalchemy.mods.%s' % mod).mods, mod) - _mod.install_plugin() - else: - mod.install_plugin()
\ No newline at end of file
diff --git a/lib/sqlalchemy/mods/legacy_session.py b/lib/sqlalchemy/mods/legacy_session.py
new file mode 100644
index 000000000..7dbeda924
--- /dev/null
+++ b/lib/sqlalchemy/mods/legacy_session.py
@@ -0,0 +1,139 @@
+
+import sqlalchemy.orm.objectstore as objectstore
+import sqlalchemy.orm.unitofwork as unitofwork
+import sqlalchemy.util as util
+from sqlalchemy.exceptions import InvalidRequestError
+import sqlalchemy
+
+import sqlalchemy.mods.threadlocal
+
+class LegacySession(objectstore.Session):
+    def __init__(self, nest_on=None, hash_key=None, **kwargs):
+        super(LegacySession, self).__init__(**kwargs)
+        self.parent_uow = None
+        self.begin_count = 0
+        self.nest_on = util.to_list(nest_on)
+        self.__pushed_count = 0
+    def was_pushed(self):
+        if self.nest_on is None:
+            return
+        self.__pushed_count += 1
+        if self.__pushed_count == 1:
+            for n in self.nest_on:
+                n.push_session()
+    def was_popped(self):
+        if self.nest_on is None or self.__pushed_count == 0:
+            return
+        self.__pushed_count -= 1
+        if self.__pushed_count == 0:
+            for n in self.nest_on:
+                n.pop_session()
+    class SessionTrans(object):
+        """returned by Session.begin(), denotes a transactionalized UnitOfWork instance.
+        call commit() on this to commit the transaction."""
+        def __init__(self, parent, uow, isactive):
+            self.__parent = parent
+            self.__isactive = isactive
+            self.__uow = uow
+        isactive = property(lambda s:s.__isactive, doc="True if this SessionTrans is the 'active' transaction marker, else it's a no-op.")
+        parent = property(lambda s:s.__parent, doc="returns the parent Session of this SessionTrans object.")
+        uow = property(lambda s:s.__uow, doc="returns the parent UnitOfWork corresponding to this transaction.")
+        def begin(self):
+            """calls begin() on the underlying Session object, returning a new no-op SessionTrans object."""
+            if self.parent.uow is not self.uow:
+                raise InvalidRequestError("This SessionTrans is no longer valid")
+            return self.parent.begin()
+        def commit(self):
+            """commits the transaction noted by this SessionTrans object."""
+            self.__parent._trans_commit(self)
+            self.__isactive = False
+        def rollback(self):
+            """rolls back the current UnitOfWork transaction, in the case that begin()
+            has been called. The changes logged since the begin() call are discarded."""
+            self.__parent._trans_rollback(self)
+            self.__isactive = False
+    def begin(self):
+        """begins a new UnitOfWork transaction and returns a transaction-holding
+        object. commit() or rollback() should be called on the returned object.
+        commit() on the Session will do nothing while a transaction is pending, and further
+        calls to begin() will return no-op transactional objects."""
+        if self.parent_uow is not None:
+            return LegacySession.SessionTrans(self, self.uow, False)
+        self.parent_uow = self.uow
+        self.uow = unitofwork.UnitOfWork(identity_map = self.uow.identity_map)
+        return LegacySession.SessionTrans(self, self.uow, True)
+    def commit(self, *objects):
+        """commits the current UnitOfWork transaction. Called with no arguments,
+        this is only used for "implicit" transactions when there was no begin().
+        If individual objects are submitted, then only those objects are committed, and the
+        begin/commit cycle is not affected."""
+        # if an object list is given, commit just those but don't
+        # change begin/commit status
+        if len(objects):
+            self._commit_uow(*objects)
+            self.uow.flush(self, *objects)
+            return
+        if self.parent_uow is None:
+            self._commit_uow()
+    def _trans_commit(self, trans):
+        if trans.uow is self.uow and trans.isactive:
+            try:
+                self._commit_uow()
+            finally:
+                self.uow = self.parent_uow
+                self.parent_uow = None
+    def _trans_rollback(self, trans):
+        if trans.uow is self.uow:
+            self.uow = self.parent_uow
+            self.parent_uow = None
+    def _commit_uow(self, *obj):
+        self.was_pushed()
+        try:
+            self.uow.flush(self, *obj)
+        finally:
+            self.was_popped()
+
+def begin():
+    """deprecated. use s = Session(new_imap=False)."""
+    return objectstore.get_session().begin()
+
+def commit(*obj):
+    """deprecated; use flush(*obj)"""
+    objectstore.get_session().flush(*obj)
+
+def uow():
+    return objectstore.get_session()
+
+def push_session(sess):
+    # get_session and session_registry live on the objectstore module
+    old = objectstore.get_session()
+    if getattr(sess, '_previous', None) is not None:
+        raise InvalidRequestError("Given Session is already pushed onto some thread's stack")
+    sess._previous = old
+    objectstore.session_registry.set(sess)
+    sess.was_pushed()
+
+def pop_session():
+    sess = objectstore.get_session()
+    old = sess._previous
+    sess._previous = None
+    objectstore.session_registry.set(old)
+    sess.was_popped()
+    return old
+
+def using_session(sess, func):
+    push_session(sess)
+    try:
+        return func()
+    finally:
+        pop_session()
+
+def install_plugin():
+    objectstore.Session = LegacySession
+    objectstore.session_registry = util.ScopedRegistry(objectstore.Session)
+    objectstore.begin = begin
+    objectstore.commit = commit
+    objectstore.uow = uow
+    objectstore.push_session = push_session
+    objectstore.pop_session = pop_session
+    objectstore.using_session = using_session
+install_plugin()
diff --git a/lib/sqlalchemy/mods/selectresults.py b/lib/sqlalchemy/mods/selectresults.py
index bff436ace..51ed6e4a5 100644
--- a/lib/sqlalchemy/mods/selectresults.py
+++ b/lib/sqlalchemy/mods/selectresults.py
@@ -1,86 +1,7 @@
-import sqlalchemy.sql as sql
+from sqlalchemy.ext.selectresults import *
+from sqlalchemy.orm.mapper import global_extensions
-
-import sqlalchemy.mapping as mapping
 def install_plugin():
-    mapping.global_extensions.append(SelectResultsExt)
-
-class SelectResultsExt(mapping.MapperExtension):
-    def select_by(self, query, *args, **params):
-        return SelectResults(query, query._by_clause(*args, **params))
-    def select(self, query, arg=None, **kwargs):
-        if arg is not None and isinstance(arg, sql.Selectable):
-            return mapping.EXT_PASS
-        else:
-            return SelectResults(query, arg, ops=kwargs)
-
-MapperExtension = SelectResultsExt
-
-class SelectResults(object):
-    def __init__(self, query, clause=None, ops={}):
-        self._query = query
-        self._clause = clause
-        self._ops = {}
-        self._ops.update(ops)
-
-    def count(self):
-        return self._query.count(self._clause)
-
-    def min(self, col):
-        return sql.select([sql.func.min(col)], self._clause, **self._ops).scalar()
-
-    def max(self, col):
-        return sql.select([sql.func.max(col)], self._clause, **self._ops).scalar()
-
-    def sum(self, col):
-        return sql.select([sql.func.sum(col)], self._clause, **self._ops).scalar()
-
-    def avg(self, col):
-        return sql.select([sql.func.avg(col)], self._clause, **self._ops).scalar()
-
-    def clone(self):
-        return SelectResults(self._query, self._clause, self._ops.copy())
-
-    def filter(self, clause):
-        new = self.clone()
-        new._clause = 
sql.and_(self._clause, clause) - return new - - def order_by(self, order_by): - new = self.clone() - new._ops['order_by'] = order_by - return new - - def limit(self, limit): - return self[:limit] - - def offset(self, offset): - return self[offset:] - - def list(self): - return list(self) - - def __getitem__(self, item): - if isinstance(item, slice): - start = item.start - stop = item.stop - if (isinstance(start, int) and start < 0) or \ - (isinstance(stop, int) and stop < 0): - return list(self)[item] - else: - res = self.clone() - if start is not None and stop is not None: - res._ops.update(dict(offset=self._ops.get('offset', 0)+start, limit=stop-start)) - elif start is None and stop is not None: - res._ops.update(dict(limit=stop)) - elif start is not None and stop is None: - res._ops.update(dict(offset=self._ops.get('offset', 0)+start)) - if item.step is not None: - return list(res)[None:None:item.step] - else: - return res - else: - return list(self[item:item+1])[0] - - def __iter__(self): - return iter(self._query.select_whereclause(self._clause, **self._ops)) + global_extensions.append(SelectResultsExt) +install_plugin() diff --git a/lib/sqlalchemy/mods/threadlocal.py b/lib/sqlalchemy/mods/threadlocal.py new file mode 100644 index 000000000..b67329612 --- /dev/null +++ b/lib/sqlalchemy/mods/threadlocal.py @@ -0,0 +1,46 @@ +from sqlalchemy import util, engine, mapper +from sqlalchemy.ext.sessioncontext import SessionContext +import sqlalchemy.ext.assignmapper as assignmapper +from sqlalchemy.orm.mapper import global_extensions +from sqlalchemy.orm.session import Session +import sqlalchemy +import sys, types + +"""this plugin installs thread-local behavior at the Engine and Session level. + +The default Engine strategy will be "threadlocal", producing TLocalEngine instances for create_engine by default. +With this engine, connect() method will return the same connection on the same thread, if it is already checked out +from the pool. this greatly helps functions that call multiple statements to be able to easily use just one connection +without explicit "close" statements on result handles. + +on the Session side, module-level methods will be installed within the objectstore module, such as flush(), delete(), etc. +which call this method on the thread-local session. + +Note: this mod creates a global, thread-local session context named sqlalchemy.objectstore. All mappers created +while this mod is installed will reference this global context when creating new mapped object instances. 
+""" + +class Objectstore(SessionContext): + def __getattr__(self, key): + return getattr(self.current, key) + def get_session(self): + return self.current + +def assign_mapper(class_, *args, **kwargs): + assignmapper.assign_mapper(objectstore, class_, *args, **kwargs) + +def _mapper_extension(): + return SessionContext._get_mapper_extension(objectstore) + +objectstore = Objectstore(Session) +def install_plugin(): + sqlalchemy.objectstore = objectstore + global_extensions.append(_mapper_extension) + engine.default_strategy = 'threadlocal' + sqlalchemy.assign_mapper = assign_mapper + +def uninstall_plugin(): + engine.default_strategy = 'plain' + global_extensions.remove(_mapper_extension) + +install_plugin() diff --git a/lib/sqlalchemy/mapping/__init__.py b/lib/sqlalchemy/orm/__init__.py index d21b02aa5..662736f22 100644 --- a/lib/sqlalchemy/mapping/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -8,50 +8,44 @@ the mapper package provides object-relational functionality, building upon the schema and sql packages and tying operations to class properties and constructors. """ -import sqlalchemy.sql as sql -import sqlalchemy.schema as schema -import sqlalchemy.engine as engine -import sqlalchemy.util as util -import objectstore -from exceptions import * -import types as types +from sqlalchemy import sql, schema, engine, util, exceptions from mapper import * -from properties import * -import mapper as mapperlib +from mapper import mapper_registry +from query import Query +from util import polymorphic_union +import properties +from session import Session as create_session __all__ = ['relation', 'backref', 'eagerload', 'lazyload', 'noload', 'deferred', 'defer', 'undefer', - 'mapper', 'clear_mappers', 'objectstore', 'sql', 'extension', 'class_mapper', 'object_mapper', 'MapperExtension', - 'assign_mapper', 'cascade_mappers' + 'mapper', 'clear_mappers', 'sql', 'extension', 'class_mapper', 'object_mapper', 'MapperExtension', 'Query', + 'cascade_mappers', 'polymorphic_union', 'create_session', ] def relation(*args, **kwargs): """provides a relationship of a primary Mapper to a secondary Mapper, which corresponds to a parent-child or associative table relationship.""" if len(args) > 1 and isinstance(args[0], type): - raise ArgumentError("relation(class, table, **kwargs) is deprecated. Please use relation(class, **kwargs) or relation(mapper, **kwargs).") + raise exceptions.ArgumentError("relation(class, table, **kwargs) is deprecated. 
Please use relation(class, **kwargs) or relation(mapper, **kwargs).")
     return _relation_loader(*args, **kwargs)
 
 def _relation_loader(mapper, secondary=None, primaryjoin=None, secondaryjoin=None, lazy=True, **kwargs):
     if lazy:
-        return LazyLoader(mapper, secondary, primaryjoin, secondaryjoin, **kwargs)
+        return properties.LazyLoader(mapper, secondary, primaryjoin, secondaryjoin, **kwargs)
     elif lazy is None:
-        return PropertyLoader(mapper, secondary, primaryjoin, secondaryjoin, **kwargs)
+        return properties.PropertyLoader(mapper, secondary, primaryjoin, secondaryjoin, **kwargs)
     else:
-        return EagerLoader(mapper, secondary, primaryjoin, secondaryjoin, **kwargs)
+        return properties.EagerLoader(mapper, secondary, primaryjoin, secondaryjoin, **kwargs)
 
 def backref(name, **kwargs):
-    return BackRef(name, **kwargs)
+    return properties.BackRef(name, **kwargs)
 
 def deferred(*columns, **kwargs):
     """returns a DeferredColumnProperty, which indicates that this object attribute
     should only be loaded from its corresponding table column when first accessed."""
-    return DeferredColumnProperty(*columns, **kwargs)
+    return properties.DeferredColumnProperty(*columns, **kwargs)
 
 def mapper(class_, table=None, *args, **params):
-    """returns a new or already cached Mapper object."""
-    if table is None:
-        return class_mapper(class_)
-
+    """returns a new Mapper object."""
     return Mapper(class_, table, *args, **params)
 
 def clear_mappers():
@@ -72,58 +66,29 @@ def extension(ext):
 def eagerload(name, **kwargs):
     """returns a MapperOption that will convert the property of the given name
     into an eager load. Used with mapper.options()"""
-    return EagerLazyOption(name, toeager=True, **kwargs)
+    return properties.EagerLazyOption(name, toeager=True, **kwargs)
 
 def lazyload(name, **kwargs):
     """returns a MapperOption that will convert the property of the given name
     into a lazy load. Used with mapper.options()"""
-    return EagerLazyOption(name, toeager=False, **kwargs)
+    return properties.EagerLazyOption(name, toeager=False, **kwargs)
 
 def noload(name, **kwargs):
     """returns a MapperOption that will convert the property of the given name
     into a non-load. Used with mapper.options()"""
-    return EagerLazyOption(name, toeager=None, **kwargs)
+    return properties.EagerLazyOption(name, toeager=None, **kwargs)
 
 def defer(name, **kwargs):
     """returns a MapperOption that will convert the column property of the given name
     into a deferred load. Used with mapper.options()"""
-    return DeferredOption(name, defer=True)
+    return properties.DeferredOption(name, defer=True)
 
 def undefer(name, **kwargs):
     """returns a MapperOption that will convert the column property of the given name
     into a non-deferred (regular column) load.
Used with mapper.options.""" - return DeferredOption(name, defer=False) + return properties.DeferredOption(name, defer=False) -def assign_mapper(class_, *args, **params): - params.setdefault("is_primary", True) - if not isinstance(getattr(class_, '__init__'), types.MethodType): - def __init__(self, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) - class_.__init__ = __init__ - m = mapper(class_, *args, **params) - class_.mapper = m - class_.get = m.get - class_.select = m.select - class_.select_by = m.select_by - class_.selectone = m.selectone - class_.get_by = m.get_by - def commit(self): - objectstore.commit(self) - def delete(self): - objectstore.delete(self) - def expire(self): - objectstore.expire(self) - def refresh(self): - objectstore.refresh(self) - def expunge(self): - objectstore.expunge(self) - class_.commit = commit - class_.delete = delete - class_.expire = expire - class_.refresh = refresh - class_.expunge = expunge def cascade_mappers(*classes_or_mappers): """given a list of classes and/or mappers, identifies the foreign key relationships diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py new file mode 100644 index 000000000..f805207fa --- /dev/null +++ b/lib/sqlalchemy/orm/dependency.py @@ -0,0 +1,369 @@ +# orm/dependency.py +# Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + + +"""bridges the PropertyLoader (i.e. a relation()) and the UOWTransaction +together to allow processing of scalar- and list-based dependencies at flush time.""" + +from sync import ONETOMANY,MANYTOONE,MANYTOMANY +from sqlalchemy import sql + +def create_dependency_processor(key, syncrules, cascade, secondary=None, association=None, is_backref=False, post_update=False): + types = { + ONETOMANY : OneToManyDP, + MANYTOONE: ManyToOneDP, + MANYTOMANY : ManyToManyDP, + } + if association is not None: + return AssociationDP(key, syncrules, cascade, secondary, association, is_backref, post_update) + else: + return types[syncrules.direction](key, syncrules, cascade, secondary, association, is_backref, post_update) + +class DependencyProcessor(object): + def __init__(self, key, syncrules, cascade, secondary=None, association=None, is_backref=False, post_update=False): + # TODO: update instance variable names to be more meaningful + self.syncrules = syncrules + self.cascade = cascade + self.mapper = syncrules.child_mapper + self.parent = syncrules.parent_mapper + self.association = association + self.secondary = secondary + self.direction = syncrules.direction + self.is_backref = is_backref + self.post_update = post_update + self.key = key + + def register_dependencies(self, uowcommit): + """tells a UOWTransaction what mappers are dependent on which, with regards + to the two or three mappers handled by this PropertyLoader. + + Also registers itself as a "processor" for one of its mappers, which + will be executed after that mapper's objects have been saved or before + they've been deleted. The process operation manages attributes and dependent + operations upon the objects of one of the involved mappers.""" + raise NotImplementedError() + + def whose_dependent_on_who(self, obj1, obj2): + """given an object pair assuming obj2 is a child of obj1, returns a tuple + with the dependent object second, or None if they are equal. + used by objectstore's object-level topological sort (i.e. 
cyclical + table dependency).""" + if obj1 is obj2: + return None + elif self.direction == ONETOMANY: + return (obj1, obj2) + else: + return (obj2, obj1) + + def process_dependencies(self, task, deplist, uowcommit, delete = False): + """this method is called during a flush operation to synchronize data between a parent and child object. + it is called within the context of the various mappers and sometimes individual objects sorted according to their + insert/update/delete order (topological sort).""" + raise NotImplementedError() + + def preprocess_dependencies(self, task, deplist, uowcommit, delete = False): + """used before the flushes' topological sort to traverse through related objects and insure every + instance which will require save/update/delete is properly added to the UOWTransaction.""" + raise NotImplementedError() + + def _synchronize(self, obj, child, associationrow, clearkeys): + """called during a flush to synchronize primary key identifier values between a parent/child object, as well as + to an associationrow in the case of many-to-many.""" + raise NotImplementedError() + + def get_object_dependencies(self, obj, uowcommit, passive = True): + """returns the list of objects that are dependent on the given object, as according to the relationship + this dependency processor represents""" + return uowcommit.uow.attributes.get_history(obj, self.key, passive = passive) + + +class OneToManyDP(DependencyProcessor): + def register_dependencies(self, uowcommit): + if self.post_update: + stub = MapperStub(self.mapper) + uowcommit.register_dependency(self.mapper, stub) + uowcommit.register_dependency(self.parent, stub) + uowcommit.register_processor(stub, self, self.parent) + else: + uowcommit.register_dependency(self.parent, self.mapper) + uowcommit.register_processor(self.parent, self, self.parent) + def process_dependencies(self, task, deplist, uowcommit, delete = False): + #print self.mapper.table.name + " " + self.key + " " + repr(len(deplist)) + " process_dep isdelete " + repr(delete) + " direction " + repr(self.direction) + if delete: + # head object is being deleted, and we manage its list of child objects + # the child objects have to have their foreign key to the parent set to NULL + if not self.cascade.delete_orphan or self.post_update: + for obj in deplist: + childlist = self.get_object_dependencies(obj, uowcommit, passive=False) + for child in childlist.deleted_items(): + if child is not None and childlist.hasparent(child) is False: + self._synchronize(obj, child, None, True) + if self.post_update: + uowcommit.register_object(child, postupdate=True) + for child in childlist.unchanged_items(): + if child is not None: + self._synchronize(obj, child, None, True) + if self.post_update: + uowcommit.register_object(child, postupdate=True) + else: + for obj in deplist: + childlist = self.get_object_dependencies(obj, uowcommit, passive=True) + if childlist is not None: + for child in childlist.added_items(): + self._synchronize(obj, child, None, False) + if child is not None and self.post_update: + uowcommit.register_object(child, postupdate=True) + for child in childlist.deleted_items(): + if not self.cascade.delete_orphan: + self._synchronize(obj, child, None, True) + + def preprocess_dependencies(self, task, deplist, uowcommit, delete = False): + #print self.mapper.table.name + " " + self.key + " " + repr(len(deplist)) + " process_dep isdelete " + repr(delete) + " direction " + repr(self.direction) + + if delete: + # head object is being deleted, and we manage its list of 
child objects + # the child objects have to have their foreign key to the parent set to NULL + if self.post_update: + # TODO: post_update instructions should be established in this step as well + # (and executed in the regular traversal) + pass + elif self.cascade.delete_orphan: + for obj in deplist: + childlist = self.get_object_dependencies(obj, uowcommit, passive=False) + for child in childlist.deleted_items(): + if child is not None and childlist.hasparent(child) is False: + uowcommit.register_object(child, isdelete=True) + for c in self.mapper.cascade_iterator('delete', child): + uowcommit.register_object(c, isdelete=True) + for child in childlist.unchanged_items(): + if child is not None: + uowcommit.register_object(child, isdelete=True) + for c in self.mapper.cascade_iterator('delete', child): + uowcommit.register_object(c, isdelete=True) + else: + for obj in deplist: + childlist = self.get_object_dependencies(obj, uowcommit, passive=False) + for child in childlist.deleted_items(): + if child is not None and childlist.hasparent(child) is False: + uowcommit.register_object(child) + for child in childlist.unchanged_items(): + if child is not None: + uowcommit.register_object(child) + else: + for obj in deplist: + childlist = self.get_object_dependencies(obj, uowcommit, passive=True) + if childlist is not None: + for child in childlist.added_items(): + if child is not None: + uowcommit.register_object(child) + for child in childlist.deleted_items(): + if not self.cascade.delete_orphan: + uowcommit.register_object(child, isdelete=False) + elif childlist.hasparent(child) is False: + uowcommit.register_object(child, isdelete=True) + for c in self.mapper.cascade_iterator('delete', child): + uowcommit.register_object(c, isdelete=True) + + def _synchronize(self, obj, child, associationrow, clearkeys): + source = obj + dest = child + if dest is None: + return + self.syncrules.execute(source, dest, obj, child, clearkeys) + +class ManyToOneDP(DependencyProcessor): + def register_dependencies(self, uowcommit): + if self.post_update: + stub = MapperStub(self.mapper) + uowcommit.register_dependency(self.mapper, stub) + uowcommit.register_dependency(self.parent, stub) + uowcommit.register_processor(stub, self, self.parent) + else: + uowcommit.register_dependency(self.mapper, self.parent) + uowcommit.register_processor(self.mapper, self, self.parent) + def process_dependencies(self, task, deplist, uowcommit, delete = False): + #print self.mapper.mapped_table.name + " " + self.key + " " + repr(len(deplist)) + " process_dep isdelete " + repr(delete) + " direction " + repr(self.direction) + if delete: + if self.post_update and not self.cascade.delete_orphan: + # post_update means we have to update our row to not reference the child object + # before we can DELETE the row + for obj in deplist: + self._synchronize(obj, None, None, True) + uowcommit.register_object(obj, postupdate=True) + else: + for obj in deplist: + childlist = self.get_object_dependencies(obj, uowcommit, passive=True) + if childlist is not None: + for child in childlist.added_items(): + self._synchronize(obj, child, None, False) + if self.post_update: + uowcommit.register_object(obj, postupdate=True) + + def preprocess_dependencies(self, task, deplist, uowcommit, delete = False): + #print self.mapper.table.name + " " + self.key + " " + repr(len(deplist)) + " process_dep isdelete " + repr(delete) + " direction " + repr(self.direction) + # TODO: post_update instructions should be established in this step as well + # (and executed in the 
regular traversal) + if self.post_update: + return + if delete: + if self.cascade.delete: + for obj in deplist: + childlist = self.get_object_dependencies(obj, uowcommit, passive=False) + for child in childlist.deleted_items() + childlist.unchanged_items(): + if child is not None and childlist.hasparent(child) is False: + uowcommit.register_object(child, isdelete=True) + for c in self.mapper.cascade_iterator('delete', child): + uowcommit.register_object(c, isdelete=True) + else: + for obj in deplist: + uowcommit.register_object(obj) + if self.cascade.delete_orphan: + childlist = self.get_object_dependencies(obj, uowcommit, passive=False) + for child in childlist.deleted_items(): + if childlist.hasparent(child) is False: + uowcommit.register_object(child, isdelete=True) + for c in self.mapper.cascade_iterator('delete', child): + uowcommit.register_object(c, isdelete=True) + + def _synchronize(self, obj, child, associationrow, clearkeys): + source = child + dest = obj + if dest is None: + return + self.syncrules.execute(source, dest, obj, child, clearkeys) + +class ManyToManyDP(DependencyProcessor): + def register_dependencies(self, uowcommit): + # many-to-many. create a "Stub" mapper to represent the + # "middle table" in the relationship. This stub mapper doesnt save + # or delete any objects, but just marks a dependency on the two + # related mappers. its dependency processor then populates the + # association table. + + if self.is_backref: + # if we are the "backref" half of a two-way backref + # relationship, let the other mapper handle inserting the rows + return + stub = MapperStub(self.mapper) + uowcommit.register_dependency(self.parent, stub) + uowcommit.register_dependency(self.mapper, stub) + uowcommit.register_processor(stub, self, self.parent) + + def process_dependencies(self, task, deplist, uowcommit, delete = False): + #print self.mapper.table.name + " " + self.key + " " + repr(len(deplist)) + " process_dep isdelete " + repr(delete) + " direction " + repr(self.direction) + connection = uowcommit.transaction.connection(self.mapper) + secondary_delete = [] + secondary_insert = [] + if delete: + for obj in deplist: + childlist = self.get_object_dependencies(obj, uowcommit, passive=False) + for child in childlist.deleted_items() + childlist.unchanged_items(): + associationrow = {} + self._synchronize(obj, child, associationrow, False) + secondary_delete.append(associationrow) + else: + for obj in deplist: + childlist = self.get_object_dependencies(obj, uowcommit) + if childlist is None: continue + for child in childlist.added_items(): + associationrow = {} + self._synchronize(obj, child, associationrow, False) + secondary_insert.append(associationrow) + for child in childlist.deleted_items(): + associationrow = {} + self._synchronize(obj, child, associationrow, False) + secondary_delete.append(associationrow) + if len(secondary_delete): + # TODO: precompile the delete/insert queries and store them as instance variables + # on the PropertyLoader + statement = self.secondary.delete(sql.and_(*[c == sql.bindparam(c.key) for c in self.secondary.c])) + connection.execute(statement, secondary_delete) + if len(secondary_insert): + statement = self.secondary.insert() + connection.execute(statement, secondary_insert) + + def preprocess_dependencies(self, task, deplist, uowcommit, delete = False): + pass + def _synchronize(self, obj, child, associationrow, clearkeys): + dest = associationrow + source = None + if dest is None: + return + self.syncrules.execute(source, dest, obj, child, 
clearkeys) + +class AssociationDP(OneToManyDP): + def register_dependencies(self, uowcommit): + # association object. our mapper should be dependent on both + # the parent mapper and the association object mapper. + # this is where we put the "stub" as a marker, so we get + # association/parent->stub->self, then we process the child + # elments after the 'stub' save, which is before our own + # mapper's save. + stub = MapperStub(self.association) + uowcommit.register_dependency(self.parent, stub) + uowcommit.register_dependency(self.association, stub) + uowcommit.register_dependency(stub, self.mapper) + uowcommit.register_processor(stub, self, self.parent) + def process_dependencies(self, task, deplist, uowcommit, delete = False): + #print self.mapper.table.name + " " + self.key + " " + repr(len(deplist)) + " process_dep isdelete " + repr(delete) + " direction " + repr(self.direction) + # manage association objects. + for obj in deplist: + childlist = self.get_object_dependencies(obj, uowcommit, passive=True) + if childlist is None: continue + + #print "DIRECTION", self.direction + d = {} + for child in childlist: + self._synchronize(obj, child, None, False) + key = self.mapper.instance_key(child) + #print "SYNCHRONIZED", child, "INSTANCE KEY", key + d[key] = child + uowcommit.unregister_object(child) + + for child in childlist.added_items(): + uowcommit.register_object(child) + key = self.mapper.instance_key(child) + #print "ADDED, INSTANCE KEY", key + d[key] = child + + for child in childlist.unchanged_items(): + key = self.mapper.instance_key(child) + o = d[key] + o._instance_key= key + + for child in childlist.deleted_items(): + key = self.mapper.instance_key(child) + #print "DELETED, INSTANCE KEY", key + if d.has_key(key): + o = d[key] + o._instance_key = key + uowcommit.unregister_object(child) + else: + #print "DELETE ASSOC OBJ", repr(child) + uowcommit.register_object(child, isdelete=True) + def preprocess_dependencies(self, task, deplist, uowcommit, delete = False): + # TODO: clean up the association step in process_dependencies and move the + # appropriate sections of it to here + pass + + +class MapperStub(object): + """poses as a Mapper representing the association table in a many-to-many + join, when performing a flush(). 
+ + The Task objects in the objectstore module treat it just like + any other Mapper, but in fact it only serves as a "dependency" placeholder + for the many-to-many update task.""" + def __init__(self, mapper): + self.mapper = mapper + def register_dependencies(self, uowcommit): + pass + def save_obj(self, *args, **kwargs): + pass + def delete_obj(self, *args, **kwargs): + pass + def _primary_mapper(self): + return self diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/orm/mapper.py index 7977cae6a..21daafad8 100644 --- a/lib/sqlalchemy/mapping/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1,20 +1,18 @@ -# mapper/mapper.py +# orm/mapper.py # Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php - -import sqlalchemy.sql as sql -import sqlalchemy.schema as schema -import sqlalchemy.util as util +from sqlalchemy import sql, schema, util, exceptions +from sqlalchemy import sql_util as sqlutil import util as mapperutil import sync -from sqlalchemy.exceptions import * -import query -import objectstore -import sys -import weakref +import query as querylib +import session as sessionlib +import sys, weakref, sets + +__all__ = ['Mapper', 'MapperExtension', 'class_mapper', 'object_mapper'] # a dictionary mapping classes to their primary mappers mapper_registry = weakref.WeakKeyDictionary() @@ -36,11 +34,11 @@ class Mapper(object): relation() function.""" def __init__(self, class_, - table, - primarytable = None, + local_table, properties = None, primary_key = None, is_primary = False, + non_primary = False, inherits = None, inherit_condition = None, extension = None, @@ -49,99 +47,155 @@ class Mapper(object): entity_name = None, always_refresh = False, version_id_col = None, - construct_new = False, - **kwargs): + polymorphic_on=None, + polymorphic_map=None, + polymorphic_identity=None, + concrete=False, + select_table=None): - if primarytable is not None: - sys.stderr.write("'primarytable' argument to mapper is deprecated\n") - - ext = MapperExtension() - + # uber-pendantic style of making mapper chain, as various testbase/ + # threadlocal/assignmapper combinations keep putting dupes etc. in the list + # TODO: do something that isnt 21 lines.... 
+ extlist = util.HashSet() for ext_class in global_extensions: - ext = ext_class().chain(ext) + if isinstance(ext_class, MapperExtension): + extlist.append(ext_class) + else: + extlist.append(ext_class()) if extension is not None: for ext_obj in util.to_list(extension): - ext = ext_obj.chain(ext) + extlist.append(ext_obj) + + self.extension = None + previous = None + for ext in extlist: + if self.extension is None: + self.extension = ext + if previous is not None: + previous.chain(ext) + previous = ext + if self.extension is None: + self.extension = MapperExtension() - self.extension = ext - self.class_ = class_ self.entity_name = entity_name self.class_key = ClassKey(class_, entity_name) self.is_primary = is_primary + self.non_primary = non_primary self.order_by = order_by self._options = {} self.always_refresh = always_refresh self.version_id_col = version_id_col - self.construct_new = construct_new + self._inheriting_mappers = sets.Set() + self.polymorphic_on = polymorphic_on + if polymorphic_map is None: + self.polymorphic_map = {} + else: + self.polymorphic_map = polymorphic_map + self.__surrogate_mapper = None + self._surrogate_parent = None if not issubclass(class_, object): - raise ArgumentError("Class '%s' is not a new-style class" % class_.__name__) - - if isinstance(table, sql.Select): - # some db's, noteably postgres, dont want to select from a select - # without an alias. also if we make our own alias internally, then - # the configured properties on the mapper are not matched against the alias - # we make, theres workarounds but it starts to get really crazy (its crazy enough - # the SQL that gets generated) so just require an alias - raise ArgumentError("Mapping against a Select object requires that it has a name. Use an alias to give it a name, i.e. s = select(...).alias('myselect')") - else: - self.table = table + raise exceptions.ArgumentError("Class '%s' is not a new-style class" % class_.__name__) + + # set up various Selectable units: + + # mapped_table - the Selectable that represents a join of the underlying Tables to be saved (or just the Table) + # local_table - the Selectable that was passed to this Mapper's constructor, if any + # select_table - the Selectable that will be used during queries. if this is specified + # as a constructor keyword argument, it takes precendence over mapped_table, otherwise its mapped_table + # unjoined_table - our Selectable, minus any joins constructed against the inherits table. + # this is either select_table if it was given explicitly, or in the case of a mapper that inherits + # its local_table + # tables - a collection of underlying Table objects pulled from mapped_table + + for table in (local_table, select_table): + if table is not None and isinstance(local_table, sql.SelectBaseMixin): + # some db's, noteably postgres, dont want to select from a select + # without an alias. also if we make our own alias internally, then + # the configured properties on the mapper are not matched against the alias + # we make, theres workarounds but it starts to get really crazy (its crazy enough + # the SQL that gets generated) so just require an alias + raise exceptions.ArgumentError("Mapping against a Select object requires that it has a name. Use an alias to give it a name, i.e. 
s = select(...).alias('myselect')") + + self.local_table = local_table if inherits is not None: + if isinstance(inherits, type): + inherits = class_mapper(inherits) if self.class_.__mro__[1] != inherits.class_: - raise ArgumentError("Class '%s' does not inherit from '%s'" % (self.class_.__name__, inherits.class_.__name__)) - self.primarytable = inherits.primarytable + raise exceptions.ArgumentError("Class '%s' does not inherit from '%s'" % (self.class_.__name__, inherits.class_.__name__)) # inherit_condition is optional. - if not table is inherits.noninherited_table: - if inherit_condition is None: - # figure out inherit condition from our table to the immediate table - # of the inherited mapper, not its full table which could pull in other - # stuff we dont want (allows test/inheritance.InheritTest4 to pass) - inherit_condition = sql.join(inherits.noninherited_table, table).onclause - self.table = sql.join(inherits.table, table, inherit_condition) - #print "inherit condition", str(self.table.onclause) - - # generate sync rules. similarly to creating the on clause, specify a - # stricter set of tables to create "sync rules" by,based on the immediate - # inherited table, rather than all inherited tables - self._synchronizer = sync.ClauseSynchronizer(self, self, sync.ONETOMANY) - self._synchronizer.compile(self.table.onclause, util.HashSet([inherits.noninherited_table]), mapperutil.TableFinder(table)) - # the old rule - #self._synchronizer.compile(self.table.onclause, inherits.tables, TableFinder(table)) + if local_table is None: + self.local_table = local_table = inherits.local_table + if not local_table is inherits.local_table: + if concrete: + self._synchronizer= None + self.mapped_table = self.local_table + else: + if inherit_condition is None: + # figure out inherit condition from our table to the immediate table + # of the inherited mapper, not its full table which could pull in other + # stuff we dont want (allows test/inheritance.InheritTest4 to pass) + inherit_condition = sql.join(inherits.local_table, self.local_table).onclause + self.mapped_table = sql.join(inherits.mapped_table, self.local_table, inherit_condition) + #print "inherit condition", str(self.table.onclause) + + # generate sync rules. 
similarly to creating the on clause, specify a + # stricter set of tables to create "sync rules" by,based on the immediate + # inherited table, rather than all inherited tables + self._synchronizer = sync.ClauseSynchronizer(self, self, sync.ONETOMANY) + self._synchronizer.compile(self.mapped_table.onclause, util.HashSet([inherits.local_table]), sqlutil.TableFinder(self.local_table)) else: self._synchronizer = None + self.mapped_table = self.local_table self.inherits = inherits - self.noninherited_table = table + if polymorphic_identity is not None: + inherits.add_polymorphic_mapping(polymorphic_identity, self) + self.polymorphic_identity = polymorphic_identity + if self.polymorphic_on is None: + self.effective_polymorphic_on = inherits.effective_polymorphic_on + else: + self.effective_polymorphic_on = self.polymorphic_on if self.order_by is False: self.order_by = inherits.order_by else: - self.primarytable = self.table - self.noninherited_table = self.table self._synchronizer = None self.inherits = None + self.mapped_table = self.local_table + if polymorphic_identity is not None: + self.add_polymorphic_mapping(polymorphic_identity, self) + self.polymorphic_identity = polymorphic_identity + self.effective_polymorphic_on = self.polymorphic_on + + if select_table is not None: + self.select_table = select_table + else: + self.select_table = self.mapped_table + self.unjoined_table = self.local_table # locate all tables contained within the "table" passed in, which # may be a join or other construct - self.tables = mapperutil.TableFinder(self.table) + self.tables = sqlutil.TableFinder(self.mapped_table) # determine primary key columns, either passed in, or get them from our set of tables self.pks_by_table = {} if primary_key is not None: for k in primary_key: self.pks_by_table.setdefault(k.table, util.HashSet(ordered=True)).append(k) - if k.table != self.table: + if k.table != self.mapped_table: # associate pk cols from subtables to the "main" table - self.pks_by_table.setdefault(self.table, util.HashSet(ordered=True)).append(k) + self.pks_by_table.setdefault(self.mapped_table, util.HashSet(ordered=True)).append(k) + # TODO: need local_table properly accounted for when custom primary key is sent else: - for t in self.tables + [self.table]: + for t in self.tables + [self.mapped_table]: try: l = self.pks_by_table[t] except KeyError: l = self.pks_by_table.setdefault(t, util.HashSet(ordered=True)) if not len(t.primary_key): - raise ArgumentError("Table " + t.name + " has no primary key columns. Specify primary_key argument to mapper.") + raise exceptions.ArgumentError("Table " + t.name + " has no primary key columns. 
Specify primary_key argument to mapper.") for k in t.primary_key: l.append(k) @@ -155,40 +209,29 @@ class Mapper(object): # table columns mapped to lists of MapperProperty objects # using a list allows a single column to be defined as # populating multiple object attributes - self.columntoproperty = {} + self.columntoproperty = TranslatingDict(self.mapped_table) # load custom properties if properties is not None: for key, prop in properties.iteritems(): - if sql.is_column(prop): - try: - prop = self.table._get_col_by_original(prop) - except KeyError: - raise ArgumentError("Column '%s' is not represented in mapper's table" % prop._label) - self.columns[key] = prop - prop = ColumnProperty(prop) - elif isinstance(prop, list) and sql.is_column(prop[0]): - try: - prop = [self.table._get_col_by_original(p) for p in prop] - except KeyError, e: - raise ArgumentError("Column '%s' is not represented in mapper's table" % e.args[0]) - self.columns[key] = prop[0] - prop = ColumnProperty(*prop) - self.props[key] = prop - if isinstance(prop, ColumnProperty): - for col in prop.columns: - proplist = self.columntoproperty.setdefault(col.original, []) - proplist.append(prop) + self.add_property(key, prop, False) + + if inherits is not None: + inherits._inheriting_mappers.add(self) + for key, prop in inherits.props.iteritems(): + if not self.props.has_key(key): + p = prop.copy() + if p.adapt(self): + self.add_property(key, p, init=False) # load properties from the main table object, # not overriding those set up in the 'properties' argument - for column in self.table.columns: + for column in self.mapped_table.columns: + if self.columntoproperty.has_key(column): + continue if not self.columns.has_key(column.key): self.columns[column.key] = column - if self.columntoproperty.has_key(column.original): - continue - prop = self.props.get(column.key, None) if prop is None: prop = ColumnProperty(column) @@ -198,126 +241,109 @@ class Mapper(object): # column at index 0 determines which result column is used to populate the object # attribute, in the case of mapping against a join with column names repeated # (and particularly in an inheritance relationship) + # TODO: clarify this comment prop.columns.insert(0, column) #prop.columns.append(column) else: if not allow_column_override: - raise ArgumentError("WARNING: column '%s' not being added due to property '%s'. Specify 'allow_column_override=True' to mapper() to ignore this condition." % (column.key, repr(prop))) + raise exceptions.ArgumentError("WARNING: column '%s' not being added due to property '%s'. Specify 'allow_column_override=True' to mapper() to ignore this condition." 
% (column.key, repr(prop))) else: continue # its a ColumnProperty - match the ultimate table columns # back to the property - proplist = self.columntoproperty.setdefault(column.original, []) + proplist = self.columntoproperty.setdefault(column, []) proplist.append(prop) - - if not mapper_registry.has_key(self.class_key) or self.is_primary or (inherits is not None and inherits._is_primary_mapper()): - objectstore.global_attributes.reset_class_managed(self.class_) - self._init_class() - if inherits is not None: - for key, prop in inherits.props.iteritems(): - if not self.props.has_key(key): - self.props[key] = prop.copy() - self.props[key].parent = self - # self.props[key].key = None # force re-init + if not non_primary and (not mapper_registry.has_key(self.class_key) or self.is_primary or (inherits is not None and inherits._is_primary_mapper())): + sessionlib.global_attributes.reset_class_managed(self.class_) + self._init_class() + elif not non_primary: + raise exceptions.ArgumentError("Class '%s' already has a primary mapper defined. Use is_primary=True to assign a new primary mapper to the class, or use non_primary=True to create a non primary Mapper" % self.class_) + + for key in self.polymorphic_map.keys(): + if isinstance(self.polymorphic_map[key], type): + self.polymorphic_map[key] = class_mapper(self.polymorphic_map[key]) + l = [(key, prop) for key, prop in self.props.iteritems()] for key, prop in l: if getattr(prop, 'key', None) is None: prop.init(key, self) - # this prints a summary of the object attributes and how they - # will be mapped to table columns - #print "mapper %s, columntoproperty:" % (self.class_.__name__) - #for key, value in self.columntoproperty.iteritems(): - # print key.table.name, key.key, [(v.key, v) for v in value] - - def _get_query(self): - try: - if self._query.mapper is not self: - self._query = query.Query(self) - return self._query - except AttributeError: - self._query = query.Query(self) - return self._query - query = property(_get_query, doc=\ - """returns an instance of sqlalchemy.mapping.query.Query, which implements all the query-constructing - methods such as get(), select(), select_by(), etc. The default Query object uses the global thread-local - Session from the objectstore package. 
To get a Query object for a specific Session, call the - using(session) method.""") + # select_table specified...set up a surrogate mapper that will be used for selects + # select_table has to encompass all the columns of the mapped_table either directly + # or through proxying relationships + if self.select_table is not self.mapped_table: + props = {} + if properties is not None: + for key, prop in properties.iteritems(): + if sql.is_column(prop): + props[key] = self.select_table.corresponding_column(prop) + elif (isinstance(prop, list) and sql.is_column(prop[0])): + props[key] = [self.select_table.corresponding_column(c) for c in prop] + self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, properties=props, polymorphic_map=self.polymorphic_map, polymorphic_on=self.polymorphic_on) + + def add_polymorphic_mapping(self, key, class_or_mapper, entity_name=None): + if isinstance(class_or_mapper, type): + class_or_mapper = class_mapper(class_or_mapper, entity_name=entity_name) + self.polymorphic_map[key] = class_or_mapper - def get(self, *ident, **kwargs): - """calls get() on this mapper's default Query object.""" - return self.query.get(*ident, **kwargs) - - def _get(self, key, ident=None, reload=False): - return self.query._get(key, ident=ident, reload=reload) - - def get_by(self, *args, **params): - """calls get_by() on this mapper's default Query object.""" - return self.query.get_by(*args, **params) - - def select_by(self, *args, **params): - """calls select_by() on this mapper's default Query object.""" - return self.query.select_by(*args, **params) - - def selectfirst_by(self, *args, **params): - """calls selectfirst_by() on this mapper's default Query object.""" - return self.query.selectfirst_by(*args, **params) - - def selectone_by(self, *args, **params): - """calls selectone_by() on this mapper's default Query object.""" - return self.query.selectone_by(*args, **params) - - def count_by(self, *args, **params): - """calls count_by() on this mapper's default Query object.""" - return self.query.count_by(*args, **params) - - def selectfirst(self, *args, **params): - """calls selectfirst() on this mapper's default Query object.""" - return self.query.selectfirst(*args, **params) - - def selectone(self, *args, **params): - """calls selectone() on this mapper's default Query object.""" - return self.query.selectone(*args, **params) - - def select(self, arg=None, **kwargs): - """calls select() on this mapper's default Query object.""" - return self.query.select(arg=arg, **kwargs) - - def select_whereclause(self, whereclause=None, params=None, **kwargs): - """calls select_whereclause() on this mapper's default Query object.""" - return self.query.select_whereclause(whereclause=whereclause, params=params, **kwargs) - - def count(self, whereclause=None, params=None, **kwargs): - """calls count() on this mapper's default Query object.""" - return self.query.count(whereclause=whereclause, params=params, **kwargs) - - def select_statement(self, statement, **params): - """calls select_statement() on this mapper's default Query object.""" - return self.query.select_statement(statement, **params) - - def select_text(self, text, **params): - return self.query.select_text(text, **params) + def add_properties(self, dict_of_properties): + """adds the given dictionary of properties to this mapper, using add_property.""" + for key, value in dict_of_properties.iteritems(): + self.add_property(key, value, True) + + def _create_prop_from_column(self, column, 
skipmissing=False): + if sql.is_column(column): + try: + column = self.mapped_table.corresponding_column(column) + except KeyError: + if skipmissing: + return + raise exceptions.ArgumentError("Column '%s' is not represented in mapper's table" % column._label) + return ColumnProperty(column) + elif isinstance(column, list) and sql.is_column(column[0]): + try: + column = [self.mapped_table.corresponding_column(c) for c in column] + except KeyError, e: + # TODO: want to take the columns we have from this + if skipmissing: + return + raise exceptions.ArgumentError("Column '%s' is not represented in mapper's table" % e.args[0]) + return ColumnProperty(*column) + else: + return None - def add_property(self, key, prop): + def add_property(self, key, prop, init=True, skipmissing=False): """adds an additional property to this mapper. this is the same as if it were specified within the 'properties' argument to the constructor. if the named property already exists, this will replace it. Useful for circular relationships, or overriding the parameters of auto-generated properties such as backreferences.""" - if sql.is_column(prop): - self.columns[key] = prop - prop = ColumnProperty(prop) + + if not isinstance(prop, MapperProperty): + prop = self._create_prop_from_column(prop, skipmissing=skipmissing) + if prop is None: + raise exceptions.ArgumentError("'%s' is not an instance of MapperProperty or Column" % repr(prop)) + self.props[key] = prop + if isinstance(prop, ColumnProperty): + self.columns[key] = prop.columns[0] for col in prop.columns: - proplist = self.columntoproperty.setdefault(col.original, []) + proplist = self.columntoproperty.setdefault(col, []) proplist.append(prop) - prop.init(key, self) + + if init: + prop.init(key, self) + + for mapper in self._inheriting_mappers: + p = prop.copy() + if p.adapt(mapper): + mapper.add_property(key, p, init=False) def __str__(self): - return "Mapper|" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + self.primarytable.name + return "Mapper|" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + self.mapped_table.name def _is_primary_mapper(self): """returns True if this mapper is the primary mapper for its class key (class + entity_name)""" @@ -328,39 +354,59 @@ class Mapper(object): return mapper_registry[self.class_key] def is_assigned(self, instance): - """returns True if this mapper is the primary mapper for the given instance. this is dependent + """returns True if this mapper handles the given instance. this is dependent not only on class assignment but the optional "entity_name" parameter as well.""" return instance.__class__ is self.class_ and getattr(instance, '_entity_name', None) == self.entity_name + def _assign_entity_name(self, instance): + """assigns this Mapper's entity name to the given instance. subsequent Mapper lookups for this + instance will return the primary mapper corresponding to this Mapper's class and entity name.""" + instance._entity_name = self.entity_name + def _init_class(self): - """sets up our classes' overridden __init__ method, this mappers hash key as its - '_mapper' property, and our columns as its 'c' property. 
if the class already had a - mapper, the old __init__ method is kept the same.""" + """decorates the __init__ method on the mapped class to include auto-session attachment logic, + and associates this Mapper with its class via the mapper_registry.""" oldinit = self.class_.__init__ def init(self, *args, **kwargs): self._entity_name = kwargs.pop('_sa_entity_name', None) # this gets the AttributeManager to do some pre-initialization, # in order to save on KeyErrors later on - objectstore.global_attributes.init_attr(self) + sessionlib.global_attributes.init_attr(self) - nohist = kwargs.pop('_mapper_nohistory', False) - session = kwargs.pop('_sa_session', objectstore.get_session()) - if not nohist: - # register new with the correct session, before the object's - # constructor is called, since further assignments within the - # constructor would otherwise bind it to whatever get_session() is. - session.register_new(self) + if kwargs.has_key('_sa_session'): + session = kwargs.pop('_sa_session') + else: + # works for whatever mapper the class is associated with + mapper = mapper_registry.get(ClassKey(self.__class__, self._entity_name)) + if mapper is not None: + session = mapper.extension.get_session() + if session is EXT_PASS: + session = None + else: + session = None + if session is not None: + session._register_new(self) if oldinit is not None: oldinit(self, *args, **kwargs) - # override oldinit, insuring that its not already one of our - # own modified inits + # override oldinit, ensuring that it's not already a Mapper-decorated init method if oldinit is None or not hasattr(oldinit, '_sa_mapper_init'): init._sa_mapper_init = True self.class_.__init__ = init mapper_registry[self.class_key] = self if self.entity_name is None: self.class_.c = self.c + + def get_session(self): + """returns the contextual session provided by the mapper extension chain + + raises InvalidRequestError if a session cannot be retrieved from the + extension chain + """ + s = self.extension.get_session() + if s is EXT_PASS: + raise exceptions.InvalidRequestError("No contextual Session is established. Use a MapperExtension that implements get_session or use 'import sqlalchemy.mods.threadlocal' to establish a default thread-local contextual session.") + return s def has_eager(self): """returns True if one of the properties attached to this Mapper is eager loading""" @@ -370,14 +416,11 @@ class Mapper(object): self.props[key] = prop prop.init(key, self) - def instances(self, cursor, *mappers, **kwargs): + def instances(self, cursor, session, *mappers, **kwargs): """given a cursor (ResultProxy) from an SQLEngine, returns a list of object instances corresponding to the rows in the cursor.""" limit = kwargs.get('limit', None) offset = kwargs.get('offset', None) - session = kwargs.get('session', None) - if session is None: - session = objectstore.get_session() populate_existing = kwargs.get('populate_existing', False) result = util.HistoryArraySet() @@ -399,28 +442,24 @@ class Mapper(object): # store new stuff in the identity map for value in imap.values(): - session.register_clean(value) + session._register_clean(value) if mappers: result = [result] + otherresults return result - def identity_key(self, *primary_key): - """returns the instance key for the given identity value. 
this is a global tracking object used by the objectstore, and is usually available off a mapped object as instance._instance_key.""" - return objectstore.get_id_key(tuple(primary_key), self.class_, self.entity_name) - + def identity_key(self, primary_key): + """returns the instance key for the given identity value. this is a global tracking object used by the Session, and is usually available off a mapped object as instance._instance_key.""" + return sessionlib.get_id_key(util.to_list(primary_key), self.class_, self.entity_name) + def instance_key(self, instance): - """returns the instance key for the given instance. this is a global tracking object used by the objectstore, and is usually available off a mapped object as instance._instance_key.""" - return self.identity_key(*self.identity(instance)) + """returns the instance key for the given instance. this is a global tracking object used by the Session, and is usually available off a mapped object as instance._instance_key.""" + return self.identity_key(self.identity(instance)) def identity(self, instance): """returns the identity (list of primary key values) for the given instance. The list of values can be fed directly into the get() method as mapper.get(*key).""" - return [self._getattrbycolumn(instance, column) for column in self.pks_by_table[self.table]] + return [self._getattrbycolumn(instance, column) for column in self.pks_by_table[self.mapped_table]] - def compile(self, whereclause = None, **options): - """works like select, except returns the SQL statement object without - compiling or executing it""" - return self.query._compile(whereclause, **options) def copy(self, **kwargs): mapper = Mapper.__new__(Mapper) @@ -428,13 +467,6 @@ class Mapper(object): mapper.__dict__.update(kwargs) mapper.props = self.props.copy() return mapper - - def using(self, session): - """returns a new Query object with the given Session.""" - if objectstore.get_session() is session: - return self.query - else: - return query.Query(self, session=session) def options(self, *options, **kwargs): """uses this mapper as a prototype for a new mapper with different behavior. @@ -450,41 +482,20 @@ class Mapper(object): self._options[optkey] = mapper return mapper - def _get_criterion(self, key, value): - """used by select_by to match a key/value pair against - local properties, column names, or a matching property in this mapper's - list of relations.""" - if self.props.has_key(key): - return self.props[key].columns[0] == value - elif self.table.c.has_key(key): - return self.table.c[key] == value - else: - for prop in self.props.values(): - c = prop.get_criterion(key, value) - if c is not None: - return c - else: - return None - - def __getattr__(self, key): - if (key.startswith('select_by_') or key.startswith('get_by_')): - return getattr(self.query, key) - else: - raise AttributeError(key) def _getpropbycolumn(self, column, raiseerror=True): try: - prop = self.columntoproperty[column.original] + prop = self.columntoproperty[column] except KeyError: try: prop = self.props[column.key] if not raiseerror: return None - raise InvalidRequestError("Column '%s.%s' is not available, due to conflicting property '%s':%s" % (column.table.name, column.name, column.key, repr(prop))) + raise exceptions.InvalidRequestError("Column '%s.%s' is not available, due to conflicting property '%s':%s" % (column.table.name, column.name, column.key, repr(prop))) except KeyError: if not raiseerror: return None - raise InvalidRequestError("No column %s.%s is configured on mapper %s..." 
% (column.table.name, column.name, str(self))) + raise exceptions.InvalidRequestError("No column %s.%s is configured on mapper %s..." % (column.table.name, column.name, str(self))) return prop[0] def _getattrbycolumn(self, obj, column, raiseerror=True): @@ -494,15 +505,19 @@ class Mapper(object): return prop.getattr(obj) def _setattrbycolumn(self, obj, column, value): - self.columntoproperty[column.original][0].setattr(obj, value) - + self.columntoproperty[column][0].setattr(obj, value) + + def primary_mapper(self): + return mapper_registry[self.class_key] + def save_obj(self, objects, uow, postupdate=False): """called by a UnitOfWork object to save objects, which involves either an INSERT or an UPDATE statement for each table used by this mapper, for each element of the list.""" - + #print "SAVE_OBJ MAPPER", self.class_.__name__, objects + connection = uow.transaction.connection(self) for table in self.tables: - #print "SAVE_OBJ table ", table.name + #print "SAVE_OBJ table ", self.class_.__name__, table.name # looping through our set of tables, which are all "real" tables, as opposed # to our main table which might be a select statement or something non-writeable @@ -511,6 +526,7 @@ class Mapper(object): # they are separate execs via execute(), not executemany() if not self._has_pks(table): + #print "NO PKS ?", str(table) # if we dont have a full set of primary keys for this table, we cant really # do any CRUD with it, so skip. this occurs if we are mapping against a query # that joins on other tables so its not really an error condition. @@ -532,9 +548,9 @@ class Mapper(object): # time" isinsert = not postupdate and not hasattr(obj, "_instance_key") if isinsert: - self.extension.before_insert(self, obj) + self.extension.before_insert(self, connection, obj) else: - self.extension.before_update(self, obj) + self.extension.before_update(self, connection, obj) hasdata = False for col in table.columns: if col is self.version_id_col: @@ -558,6 +574,11 @@ class Mapper(object): value = self._getattrbycolumn(obj, col) if value is not None: params[col.key] = value + elif self.effective_polymorphic_on is not None and self.effective_polymorphic_on.shares_lineage(col): + if isinsert: + value = self.polymorphic_identity + if col.default is None or value is not None: + params[col.key] = value else: # column is not a primary key ? 
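# (editor's note, hedged gloss, not part of the original commit: in the branch
# above, any column that shares lineage with the mapper's discriminator column
# (effective_polymorphic_on) is populated on INSERT from self.polymorphic_identity,
# so that _instance() can later route each fetched row back to the matching
# inheriting mapper via polymorphic_map; UPDATEs leave the discriminator untouched.)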
if not isinsert: @@ -601,19 +622,20 @@ class Mapper(object): clause.clauses.append(self.version_id_col == sql.bindparam(self.version_id_col._label, type=col.type)) statement = table.update(clause) rows = 0 + supports_sane_rowcount = True for rec in update: (obj, params) = rec - c = statement.execute(params) - self._postfetch(table, obj, c, c.last_updated_params()) - self.extension.after_update(self, obj) + c = connection.execute(statement, params) + self._postfetch(connection, table, obj, c, c.last_updated_params()) + self.extension.after_update(self, connection, obj) rows += c.cursor.rowcount if c.supports_sane_rowcount() and rows != len(update): - raise CommitError("ConcurrencyError - updated rowcount %d does not match number of objects updated %d" % (rows, len(update))) + raise exceptions.FlushError("ConcurrencyError - updated rowcount %d does not match number of objects updated %d" % (rows, len(update))) if len(insert): statement = table.insert() for rec in insert: (obj, params) = rec - c = statement.execute(**params) + c = connection.execute(statement, params) primary_key = c.last_inserted_ids() if primary_key is not None: i = 0 @@ -622,12 +644,12 @@ class Mapper(object): if self._getattrbycolumn(obj, col) is None: self._setattrbycolumn(obj, col, primary_key[i]) i+=1 - self._postfetch(table, obj, c, c.last_inserted_params()) + self._postfetch(connection, table, obj, c, c.last_inserted_params()) if self._synchronizer is not None: self._synchronizer.execute(obj, obj) - self.extension.after_insert(self, obj) + self.extension.after_insert(self, connection, obj) - def _postfetch(self, table, obj, resultproxy, params): + def _postfetch(self, connection, table, obj, resultproxy, params): """after an INSERT or UPDATE, asks the returned result if PassiveDefaults fired off on the database side which need to be post-fetched, *or* if pre-exec defaults like ColumnDefaults were fired off and should be populated into the instance. 
this is only for non-primary key columns.""" @@ -635,7 +657,7 @@ class Mapper(object): clause = sql.and_() for p in self.pks_by_table[table]: clause.clauses.append(p == self._getattrbycolumn(obj, p)) - row = table.select(clause).execute().fetchone() + row = connection.execute(table.select(clause), None).fetchone() for c in table.c: if self._getattrbycolumn(obj, c, False) is None: self._setattrbycolumn(obj, c, row[c]) @@ -652,10 +674,13 @@ class Mapper(object): def delete_obj(self, objects, uow): """called by a UnitOfWork object to delete objects, which involves a DELETE statement for each table used by this mapper, for each object in the list.""" + connection = uow.transaction.connection(self) + for table in util.reversed(self.tables): if not self._has_pks(table): continue delete = [] + deleted_objects = [] for obj in objects: params = {} if not hasattr(obj, "_instance_key"): @@ -666,7 +691,8 @@ class Mapper(object): params[col.key] = self._getattrbycolumn(obj, col) if self.version_id_col is not None: params[self.version_id_col.key] = self._getattrbycolumn(obj, self.version_id_col) - self.extension.before_delete(self, obj) + self.extension.before_delete(self, connection, obj) + deleted_objects.append(obj) if len(delete): clause = sql.and_() for col in self.pks_by_table[table]: @@ -674,14 +700,16 @@ class Mapper(object): if self.version_id_col is not None: clause.clauses.append(self.version_id_col == sql.bindparam(self.version_id_col.key, type=self.version_id_col.type)) statement = table.delete(clause) - c = statement.execute(*delete) + c = connection.execute(statement, delete) if c.supports_sane_rowcount() and c.rowcount != len(delete): - raise CommitError("ConcurrencyError - updated rowcount %d does not match number of objects updated %d" % (c.cursor.rowcount, len(delete))) + raise exceptions.FlushError("ConcurrencyError - updated rowcount %d does not match number of objects updated %d" % (c.cursor.rowcount, len(delete))) + for obj in deleted_objects: + self.extension.after_delete(self, connection, obj) def _has_pks(self, table): try: for k in self.pks_by_table[table]: - if not self.columntoproperty.has_key(k.original): + if not self.columntoproperty.has_key(k): return False else: return True @@ -689,46 +717,58 @@ class Mapper(object): return False def register_dependencies(self, uowcommit, *args, **kwargs): - """called by an instance of objectstore.UOWTransaction to register + """called by an instance of unitofwork.UOWTransaction to register which mappers are dependent on which, as well as DependencyProcessor objects which will process lists of objects in between saves and deletes.""" for prop in self.props.values(): prop.register_dependencies(uowcommit, *args, **kwargs) if self.inherits is not None: uowcommit.register_dependency(self.inherits, self) - - def register_deleted(self, obj, uow): - for prop in self.props.values(): - prop.register_deleted(obj, uow) - - def _identity_key(self, row): - return objectstore.get_row_key(row, self.class_, self.pks_by_table[self.table], self.entity_name) + def cascade_iterator(self, type, object, recursive=None): + if recursive is None: + recursive=sets.Set() + if object not in recursive: + recursive.add(object) + yield object + for prop in self.props.values(): + for c in prop.cascade_iterator(type, object, recursive): + yield c + + def _row_identity_key(self, row): + return sessionlib.get_row_key(row, self.class_, self.pks_by_table[self.mapped_table], self.entity_name) + def get_select_mapper(self): + return self.__surrogate_mapper or self + def 
_instance(self, session, row, imap, result = None, populate_existing = False): """pulls an object instance from the given row and appends it to the given result list. if the instance already exists in the given identity map, its not added. in either case, executes all the property loaders on the instance to also process extra information in the row.""" + + if self.polymorphic_on is not None: + discriminator = row[self.polymorphic_on] + mapper = self.polymorphic_map[discriminator] + if mapper is not self: + row = self.translate_row(mapper, row) + return mapper._instance(session, row, imap, result=result, populate_existing=populate_existing) + # look in main identity map. if its there, we dont do anything to it, # including modifying any of its related items lists, as its already # been exposed to being modified by the application. - if session is None: - session = objectstore.get_session() - populate_existing = populate_existing or self.always_refresh - identitykey = self._identity_key(row) + identitykey = self._row_identity_key(row) if session.has_key(identitykey): instance = session._get(identitykey) - isnew = False if populate_existing or session.is_expired(instance, unexpire=True): if not imap.has_key(identitykey): imap[identitykey] = instance for prop in self.props.values(): prop.execute(session, instance, row, identitykey, imap, True) - if self.extension.append_result(self, row, imap, result, instance, isnew, populate_existing=populate_existing) is EXT_PASS: + if self.extension.append_result(self, session, row, imap, result, instance, isnew, populate_existing=populate_existing) is EXT_PASS: if result is not None: result.append_nohistory(instance) return instance @@ -738,11 +778,11 @@ class Mapper(object): if not exists: # check if primary key cols in the result are None - this indicates # an instance of the object is not present in the row - for col in self.pks_by_table[self.table]: + for col in self.pks_by_table[self.mapped_table]: if row[col] is None: return None # plugin point - instance = self.extension.create_instance(self, row, imap, self.class_) + instance = self.extension.create_instance(self, session, row, imap, self.class_) if instance is EXT_PASS: instance = self._create_instance(session) imap[identitykey] = instance @@ -757,44 +797,104 @@ class Mapper(object): # instances from the row and possibly populate this item. if self.extension.populate_instance(self, session, instance, row, identitykey, imap, isnew) is EXT_PASS: self.populate_instance(session, instance, row, identitykey, imap, isnew) - if self.extension.append_result(self, row, imap, result, instance, isnew, populate_existing=populate_existing) is EXT_PASS: + if self.extension.append_result(self, session, row, imap, result, instance, isnew, populate_existing=populate_existing) is EXT_PASS: if result is not None: result.append_nohistory(instance) return instance def _create_instance(self, session): - if not self.construct_new: - return self.class_(_mapper_nohistory=True, _sa_entity_name=self.entity_name, _sa_session=session) - obj = self.class_.__new__(self.class_) obj._entity_name = self.entity_name - + # this gets the AttributeManager to do some pre-initialization, # in order to save on KeyErrors later on - objectstore.global_attributes.init_attr(obj) - - session._bind_to(obj) + sessionlib.global_attributes.init_attr(obj) return obj def translate_row(self, tomapper, row): """attempts to take a row and translate its values to a row that can - be understood by another mapper. 
breaks the column references down to their - bare keynames to accomplish this. So far this works for the various polymorphic - examples.""" + be understood by another mapper.""" newrow = util.DictDecorator(row) - for c in self.table.c: - newrow[c.name] = row[c] - for c in tomapper.table.c: - newrow[c] = newrow[c.name] + for c in tomapper.mapped_table.c: + c2 = self.mapped_table.corresponding_column(c, keys_ok=True, raiseerr=True) + newrow[c] = row[c2] return newrow def populate_instance(self, session, instance, row, identitykey, imap, isnew, frommapper=None): if frommapper is not None: row = frommapper.translate_row(self, row) - for prop in self.props.values(): prop.execute(session, instance, row, identitykey, imap, isnew) + + # deprecated query methods. Query is constructed from Session, and the rest + # of these methods are called off of Query now. + def query(self, session=None): + """deprecated. use Query instead.""" + if session is not None: + return querylib.Query(self, session=session) + + try: + if self._query.mapper is not self: + self._query = querylib.Query(self) + return self._query + except AttributeError: + self._query = querylib.Query(self) + return self._query + def using(self, session): + """deprecated. use Query instead.""" + return querylib.Query(self, session=session) + def __getattr__(self, key): + """deprecated. use Query instead.""" + if (key.startswith('select_by_') or key.startswith('get_by_')): + return getattr(self.query(), key) + else: + raise AttributeError(key) + def compile(self, whereclause = None, **options): + """deprecated. use Query instead.""" + return self.query()._compile(whereclause, **options) + def get(self, ident, **kwargs): + """deprecated. use Query instead.""" + return self.query().get(ident, **kwargs) + def _get(self, key, ident=None, reload=False): + """deprecated. use Query instead.""" + return self.query()._get(key, ident=ident, reload=reload) + def get_by(self, *args, **params): + """deprecated. use Query instead.""" + return self.query().get_by(*args, **params) + def select_by(self, *args, **params): + """deprecated. use Query instead.""" + return self.query().select_by(*args, **params) + def selectfirst_by(self, *args, **params): + """deprecated. use Query instead.""" + return self.query().selectfirst_by(*args, **params) + def selectone_by(self, *args, **params): + """deprecated. use Query instead.""" + return self.query().selectone_by(*args, **params) + def count_by(self, *args, **params): + """deprecated. use Query instead.""" + return self.query().count_by(*args, **params) + def selectfirst(self, *args, **params): + """deprecated. use Query instead.""" + return self.query().selectfirst(*args, **params) + def selectone(self, *args, **params): + """deprecated. use Query instead.""" + return self.query().selectone(*args, **params) + def select(self, arg=None, **kwargs): + """deprecated. use Query instead.""" + return self.query().select(arg=arg, **kwargs) + def select_whereclause(self, whereclause=None, params=None, **kwargs): + """deprecated. use Query instead.""" + return self.query().select_whereclause(whereclause=whereclause, params=params, **kwargs) + def count(self, whereclause=None, params=None, **kwargs): + """deprecated. use Query instead.""" + return self.query().count(whereclause=whereclause, params=params, **kwargs) + def select_statement(self, statement, **params): + """deprecated. use Query instead.""" + return self.query().select_statement(statement, **params) + def select_text(self, text, **params): + """deprecated. 
use Query instead.""" + return self.query().select_text(text, **params) class MapperProperty(object): """an element attached to a Mapper that describes and assists in the loading and saving @@ -803,9 +903,11 @@ class MapperProperty(object): """called when the mapper receives a row. instance is the parent instance corresponding to the row. """ raise NotImplementedError() + def cascade_iterator(self, type, object, recursive=None): + return [] def copy(self): raise NotImplementedError() - def get_criterion(self, key, value): + def get_criterion(self, query, key, value): """Returns a WHERE clause suitable for this MapperProperty corresponding to the given key/value pair, where the key is a column or object property name, and value is a value to be matched. This is only picked up by PropertyLoaders. @@ -826,6 +928,14 @@ class MapperProperty(object): self.key = key self.parent = parent self.do_init(key, parent) + def adapt(self, newparent): + """adapts this MapperProperty to a new parent, assuming the new parent is an inheriting + descendant of the old parent. Should return True if the adaptation was successful, or + False if this MapperProperty cannot be adapted to the new parent (the case for this is, + the parent mapper has a polymorphic select, and this property represents a column that is not + represented in the new mapper's mapped table)""" + self.parent = newparent + return True def do_init(self, key, parent): """template method for subclasses""" pass @@ -860,8 +970,18 @@ class MapperExtension(object): def __init__(self): self.next = None def chain(self, ext): + if ext is self: + raise "nu uh " + repr(self) + " " + repr(ext) self.next = ext - return self + return self + def get_session(self): + """called to retrieve a contextual Session instance with which to + register a new object. Note: this is not called if a session is + provided with the __init__ params (i.e. _sa_session)""" + if self.next is None: + return EXT_PASS + else: + return self.next.get_session() def select_by(self, query, *args, **kwargs): """overrides the select_by method of the Query object""" if self.next is None: @@ -874,7 +994,7 @@ class MapperExtension(object): return EXT_PASS else: return self.next.select(query, *args, **kwargs) - def create_instance(self, mapper, row, imap, class_): + def create_instance(self, mapper, session, row, imap, class_): """called when a new object instance is about to be created from a row. the method can choose to create the instance itself, or it can return None to indicate normal object creation should take place. @@ -891,8 +1011,8 @@ class MapperExtension(object): if self.next is None: return EXT_PASS else: - return self.next.create_instance(mapper, row, imap, class_) - def append_result(self, mapper, row, imap, result, instance, isnew, populate_existing=False): + return self.next.create_instance(mapper, session, row, imap, class_) + def append_result(self, mapper, session, row, imap, result, instance, isnew, populate_existing=False): """called when an object instance is being appended to a result list. 
If this method returns True, it is assumed that the mapper should do the appending, else @@ -921,7 +1041,7 @@ class MapperExtension(object): if self.next is None: return EXT_PASS else: - return self.next.append_result(mapper, row, imap, result, instance, isnew, populate_existing) + return self.next.append_result(mapper, session, row, imap, result, instance, isnew, populate_existing) def populate_instance(self, mapper, session, instance, row, identitykey, imap, isnew): """called right before the mapper, after creating an instance from a row, passes the row to its MapperProperty objects which are responsible for populating the object's attributes. @@ -938,29 +1058,58 @@ class MapperExtension(object): return EXT_PASS else: return self.next.populate_instance(mapper, session, instance, row, identitykey, imap, isnew) - def before_insert(self, mapper, instance): + def before_insert(self, mapper, connection, instance): """called before an object instance is INSERTed into its table. this is a good place to set up primary key values and such that aren't handled otherwise.""" if self.next is not None: - self.next.before_insert(mapper, instance) - def before_update(self, mapper, instance): + self.next.before_insert(mapper, connection, instance) + def before_update(self, mapper, connection, instance): """called before an object instance is UPDATED""" if self.next is not None: - self.next.before_update(mapper, instance) - def after_update(self, mapper, instance): + self.next.before_update(mapper, connection, instance) + def after_update(self, mapper, connection, instance): """called after an object instance is UPDATED""" if self.next is not None: - self.next.after_update(mapper, instance) - def after_insert(self, mapper, instance): + self.next.after_update(mapper, connection, instance) + def after_insert(self, mapper, connection, instance): """called after an object instance has been INSERTed""" if self.next is not None: - self.next.after_insert(mapper, instance) - def before_delete(self, mapper, instance): + self.next.after_insert(mapper, connection, instance) + def before_delete(self, mapper, connection, instance): """called before an object instance is DELETEed""" if self.next is not None: - self.next.before_delete(mapper, instance) + self.next.before_delete(mapper, connection, instance) + def after_delete(self, mapper, connection, instance): + """called after an object instance is DELETEed""" + if self.next is not None: + self.next.after_delete(mapper, connection, instance) +class TranslatingDict(dict): + """a dictionary that stores ColumnElement objects as keys. incoming ColumnElement + keys are translated against those of an underlying FromClause for all operations. 
+ This way the columns from any Selectable that is derived from or underlying this + TranslatingDict's selectable can be used as keys.""" + def __init__(self, selectable): + super(TranslatingDict, self).__init__() + self.selectable = selectable + def __translate_col(self, col): + ourcol = self.selectable.corresponding_column(col, keys_ok=False, raiseerr=False) + if ourcol is None: + return col + else: + return ourcol + def __getitem__(self, col): + return super(TranslatingDict, self).__getitem__(self.__translate_col(col)) + def has_key(self, col): + return super(TranslatingDict, self).has_key(self.__translate_col(col)) + def __setitem__(self, col, value): + return super(TranslatingDict, self).__setitem__(self.__translate_col(col), value) + def __contains__(self, col): + return self.has_key(col) + def setdefault(self, col, value): + return super(TranslatingDict, self).setdefault(self.__translate_col(col), value) + class ClassKey(object): """keys a class and an entity name to a mapper, via the mapper_registry""" def __init__(self, class_, entity_name): @@ -981,17 +1130,20 @@ def hash_key(obj): else: return repr(obj) -def object_mapper(object): +def object_mapper(object, raiseerror=True, entity_name=None): """given an object, returns the primary Mapper associated with the object or the object's class.""" try: - return mapper_registry[ClassKey(object.__class__, getattr(object, '_entity_name', None))] + return mapper_registry[ClassKey(object.__class__, getattr(object, '_entity_name', entity_name))] except KeyError: - raise InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (object.__class__.__name__, getattr(object, '_entity_name', None))) + if raiseerror: + raise exceptions.InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (object.__class__.__name__, getattr(object, '_entity_name', None))) + else: + return None def class_mapper(class_, entity_name=None): """given a ClassKey, returns the primary Mapper associated with the key.""" try: return mapper_registry[ClassKey(class_, entity_name)] except (KeyError, AttributeError): - raise InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (class_.__name__, entity_name)) + raise exceptions.InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (class_.__name__, entity_name)) diff --git a/lib/sqlalchemy/mapping/properties.py b/lib/sqlalchemy/orm/properties.py index b7ff9fbb2..7b15aa773 100644 --- a/lib/sqlalchemy/mapping/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -4,23 +4,19 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""defines a set of MapperProperty objects, including basic column properties as +"""defines a set of mapper.MapperProperty objects, including basic column properties as well as relationships. 
also defines some MapperOptions that can be used with the properties.""" -from mapper import * -import sqlalchemy.sql as sql -import sqlalchemy.schema as schema -import sqlalchemy.engine as engine -import sqlalchemy.util as util -import sqlalchemy.attributes as attributes +from sqlalchemy import sql, schema, util, attributes, exceptions import sync import mapper -import objectstore -from sqlalchemy.exceptions import * -import random +import session as sessionlib +import dependency +import util as mapperutil +import sets, random -class ColumnProperty(MapperProperty): +class ColumnProperty(mapper.MapperProperty): """describes an object attribute that corresponds to a table column.""" def __init__(self, *columns): """the list of columns describes a single object property. if there @@ -32,21 +28,22 @@ class ColumnProperty(MapperProperty): def setattr(self, object, value): setattr(object, self.key, value) def get_history(self, obj, passive=False): - return objectstore.global_attributes.get_history(obj, self.key, passive=passive) + return sessionlib.global_attributes.get_history(obj, self.key, passive=passive) def copy(self): return ColumnProperty(*self.columns) def setup(self, key, statement, eagertable=None, **options): for c in self.columns: if eagertable is not None: - statement.append_column(eagertable._get_col_by_original(c)) + statement.append_column(eagertable.corresponding_column(c)) else: statement.append_column(c) def do_init(self, key, parent): self.key = key + self.parent = parent # establish a SmartProperty property manager on the object for this key if parent._is_primary_mapper(): #print "regiser col on class %s key %s" % (parent.class_.__name__, key) - objectstore.uow().register_attribute(parent.class_, key, uselist = False) + sessionlib.global_attributes.register_attribute(parent.class_, key, uselist = False) def execute(self, session, instance, row, identitykey, imap, isnew): if isnew: #print "POPULATING OBJ", instance.__class__.__name__, "COL", self.columns[0]._label, "WITH DATA", row[self.columns[0]], "ROW IS A", row.__class__.__name__, "COL ID", id(self.columns[0]) @@ -68,9 +65,13 @@ class DeferredColumnProperty(ColumnProperty): # establish a SmartProperty property manager on the object for this key, # containing a callable to load in the attribute if self.is_primary(): - objectstore.uow().register_attribute(parent.class_, key, uselist=False, callable_=lambda i:self.setup_loader(i)) + sessionlib.global_attributes.register_attribute(parent.class_, key, uselist=False, callable_=lambda i:self.setup_loader(i)) def setup_loader(self, instance): + if not self.parent.is_assigned(instance): + return mapper.object_mapper(instance).props[self.key].setup_loader(instance) def lazyload(): + session = sessionlib.object_session(instance) + connection = session.connection(self.parent) clause = sql.and_() try: pk = self.parent.pks_by_table[self.columns[0].table] @@ -82,37 +83,40 @@ class DeferredColumnProperty(ColumnProperty): return None clause.clauses.append(primary_key == attr) - if self.group is not None: - groupcols = [p for p in self.parent.props.values() if isinstance(p, DeferredColumnProperty) and p.group==self.group] - row = sql.select([g.columns[0] for g in groupcols], clause, use_labels=True).execute().fetchone() - for prop in groupcols: - if prop is self: - continue - instance.__dict__[prop.key] = row[prop.columns[0]] - objectstore.global_attributes.create_history(instance, prop.key, uselist=False) - return row[self.columns[0]] - else: - return sql.select([self.columns[0]], 
clause, use_labels=True).scalar() + try: + if self.group is not None: + groupcols = [p for p in self.parent.props.values() if isinstance(p, DeferredColumnProperty) and p.group==self.group] + row = connection.execute(sql.select([g.columns[0] for g in groupcols], clause, use_labels=True), None).fetchone() + for prop in groupcols: + if prop is self: + continue + instance.__dict__[prop.key] = row[prop.columns[0]] + sessionlib.global_attributes.create_history(instance, prop.key, uselist=False) + return row[self.columns[0]] + else: + return connection.scalar(sql.select([self.columns[0]], clause, use_labels=True),None) + finally: + connection.close() return lazyload def setup(self, key, statement, **options): pass def execute(self, session, instance, row, identitykey, imap, isnew): if isnew: if not self.is_primary(): - objectstore.global_attributes.create_history(instance, self.key, False, callable_=self.setup_loader(instance)) + sessionlib.global_attributes.create_history(instance, self.key, False, callable_=self.setup_loader(instance)) else: - objectstore.global_attributes.reset_history(instance, self.key) + sessionlib.global_attributes.reset_history(instance, self.key) mapper.ColumnProperty = ColumnProperty -class PropertyLoader(MapperProperty): +class PropertyLoader(mapper.MapperProperty): ONETOMANY = 0 MANYTOONE = 1 MANYTOMANY = 2 """describes an object property that holds a single item or list of items that correspond to a related database table.""" - def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey=None, uselist=None, private=False, association=None, use_alias=None, selectalias=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False): + def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None): self.uselist = uselist self.argument = argument self.secondary = secondary @@ -123,20 +127,23 @@ class PropertyLoader(MapperProperty): # would like to have foreignkey be a list. # however, have to figure out how to do - # <column> in <list>, since column overrides the == operator or somethign + # <column> in <list>, since column overrides the == operator # and it doesnt work self.foreignkey = foreignkey #util.to_set(foreignkey) if foreignkey: self.foreigntable = foreignkey.table else: self.foreigntable = None - - self.private = private + + if cascade is not None: + self.cascade = mapperutil.CascadeOptions(cascade) + else: + if private: + self.cascade = mapperutil.CascadeOptions("all, delete-orphan") + else: + self.cascade = mapperutil.CascadeOptions("save-update") + self.association = association - if selectalias is not None: - print "'selectalias' argument to relation() is deprecated. eager loads automatically alias-ize tables now." - if use_alias is not None: - print "'use_alias' argument to relation() is deprecated. eager loads automatically alias-ize tables now." 
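# (editor's note, illustrative gloss, not part of the original commit: the
# cascade defaulting above means an explicit cascade="..." string always wins;
# the legacy private=True flag becomes cascade="all, delete-orphan", and a plain
# relation() defaults to cascade="save-update". For example, with a hypothetical
# mapped class Address:
#     relation(Address, private=True)   # roughly equivalent to cascade="all, delete-orphan"
#     relation(Address)                 # roughly equivalent to cascade="save-update"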
self.order_by = order_by self.attributeext=attributeext if isinstance(backref, str): @@ -145,6 +152,24 @@ class PropertyLoader(MapperProperty): self.backref = backref self.is_backref = is_backref + private = property(lambda s:s.cascade.delete_orphan) + + def cascade_iterator(self, type, object, recursive=None): + if not type in self.cascade: + return + if recursive is None: + recursive = sets.Set() + + childlist = sessionlib.global_attributes.get_history(object, self.key, passive=True) + + for c in childlist.added_items() + childlist.deleted_items() + childlist.unchanged_items(): + if c is not None: + if c not in recursive: + recursive.add(c) + yield c + for c2 in self.mapper.cascade_iterator(type, c, recursive): + yield c2 + def copy(self): x = self.__class__.__new__(self.__class__) x.__dict__.update(self.__dict__) @@ -155,31 +180,34 @@ class PropertyLoader(MapperProperty): pass def do_init(self, key, parent): - import sqlalchemy.mapping + import sqlalchemy.orm if isinstance(self.argument, type): - self.mapper = sqlalchemy.mapping.class_mapper(self.argument) + self.mapper = mapper.class_mapper(self.argument) else: self.mapper = self.argument + self.mapper = self.mapper.get_select_mapper() + if self.association is not None: if isinstance(self.association, type): - self.association = sqlalchemy.mapping.class_mapper(self.association) + self.association = mapper.class_mapper(self.association) - self.target = self.mapper.table + self.target = self.mapper.mapped_table self.key = key self.parent = parent if self.secondaryjoin is not None and self.secondary is None: - raise ArgumentError("Property '" + self.key + "' specified with secondary join condition but no secondary argument") + raise exceptions.ArgumentError("Property '" + self.key + "' specified with secondary join condition but no secondary argument") # if join conditions were not specified, figure them out based on foreign keys if self.secondary is not None: if self.secondaryjoin is None: - self.secondaryjoin = sql.join(self.mapper.noninherited_table, self.secondary).onclause + self.secondaryjoin = sql.join(self.mapper.unjoined_table, self.secondary).onclause if self.primaryjoin is None: - self.primaryjoin = sql.join(parent.noninherited_table, self.secondary).onclause + self.primaryjoin = sql.join(parent.unjoined_table, self.secondary).onclause else: if self.primaryjoin is None: - self.primaryjoin = sql.join(parent.noninherited_table, self.target).onclause + self.primaryjoin = sql.join(parent.unjoined_table, self.target).onclause + # if the foreign key wasn't specified and there's no association table, try to figure # out who is dependent on whom. we don't need all the foreign keys represented in the join, # just one of them. 
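To make the foreign-key-driven inference described above concrete, here is a hedged, self-contained sketch in plain Python (the names below are hypothetical illustrations, not SQLAlchemy's actual API) of how a relation's direction can be deduced from which table holds the foreign key in the primary join, assuming exactly one foreign-key side as the comment above states:

    ONETOMANY, MANYTOONE, MANYTOMANY = 0, 1, 2

    def infer_direction(parent_table, target_table, fk_table, secondary=None):
        # an association (secondary) table between the two implies many-to-many
        if secondary is not None:
            return MANYTOMANY
        # FK lives on the target/child table: one parent row owns many child rows
        if fk_table == target_table:
            return ONETOMANY
        # FK lives on the parent table: many parent rows reference one target row
        if fk_table == parent_table:
            return MANYTOONE
        raise ValueError("cannot determine relation direction")

    # usage: addresses.user_id references users.id, with no association table
    assert infer_direction("users", "addresses", "addresses") == ONETOMANY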
@@ -193,13 +221,14 @@ class PropertyLoader(MapperProperty): if self.direction is None: self.direction = self._get_direction() - if self.uselist is None and self.direction == PropertyLoader.MANYTOONE: + if self.uselist is None and self.direction == sync.MANYTOONE: self.uselist = False if self.uselist is None: self.uselist = True self._compile_synchronizers() + self._dependency_processor = dependency.create_dependency_processor(self.key, self.syncrules, self.cascade, secondary=self.secondary, association=self.association, is_backref=self.is_backref, post_update=self.post_update) # primary property handler, set up class attributes if self.is_primary(): @@ -213,32 +242,32 @@ class PropertyLoader(MapperProperty): if self.backref is not None: self.backref.compile(self) - elif not objectstore.global_attributes.is_class_managed(parent.class_, key): - raise ArgumentError("Non-primary property created for attribute '%s' on class '%s', but that attribute is not managed! Insure that the primary mapper for this class defines this property" % (key, parent.class_.__name__)) + elif not sessionlib.global_attributes.is_class_managed(parent.class_, key): + raise exceptions.ArgumentError("Attempting to assign a new relation '%s' to a non-primary mapper on class '%s'. New relations can only be added to the primary mapper, i.e. the very first mapper created for class '%s' " % (key, parent.class_.__name__, parent.class_.__name__)) self.do_init_subclass(key, parent) def _set_class_attribute(self, class_, key): """sets attribute behavior on our target class.""" - objectstore.uow().register_attribute(class_, key, uselist = self.uselist, deleteremoved = self.private, extension=self.attributeext) + sessionlib.global_attributes.register_attribute(class_, key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade, trackparent=True) def _get_direction(self): """determines our 'direction', i.e. 
do we represent one to many, many to many, etc.""" - #print self.key, repr(self.parent.table.name), repr(self.parent.primarytable.name), repr(self.foreignkey.table.name), repr(self.target), repr(self.foreigntable.name) + #print self.key, repr(self.parent.mapped_table.name), repr(self.parent.primarytable.name), repr(self.foreignkey.table.name), repr(self.target), repr(self.foreigntable.name) if self.secondaryjoin is not None: - return PropertyLoader.MANYTOMANY - elif self.parent.table is self.target: + return sync.MANYTOMANY + elif self.parent.mapped_table is self.target: if self.foreignkey.primary_key: - return PropertyLoader.MANYTOONE + return sync.MANYTOONE else: - return PropertyLoader.ONETOMANY - elif self.foreigntable == self.mapper.noninherited_table: - return PropertyLoader.ONETOMANY - elif self.foreigntable == self.parent.noninherited_table: - return PropertyLoader.MANYTOONE + return sync.ONETOMANY + elif self.foreigntable == self.mapper.unjoined_table: + return sync.ONETOMANY + elif self.foreigntable == self.parent.unjoined_table: + return sync.MANYTOONE else: - raise ArgumentError("Cant determine relation direction") + raise exceptions.ArgumentError("Cant determine relation direction") def _find_dependent(self): """searches through the primary join condition to determine which side @@ -248,297 +277,39 @@ class PropertyLoader(MapperProperty): # set as a reference to allow assignment from inside a first-class function dependent = [None] def foo(binary): - if binary.operator != '=': + if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): return - if isinstance(binary.left, schema.Column) and binary.left.primary_key: + if binary.left.primary_key: if dependent[0] is binary.left.table: - raise ArgumentError("bidirectional dependency not supported...specify foreignkey") + raise exceptions.ArgumentError("Could not determine the parent/child relationship for property '%s', based on join condition '%s' (table '%s' appears on both sides of the relationship, or in an otherwise ambiguous manner). please specify the 'foreignkey' keyword parameter to the relation() function indicating a column on the remote side of the relationship" % (self.key, str(self.primaryjoin), str(binary.left.table))) dependent[0] = binary.right.table self.foreignkey= binary.right - elif isinstance(binary.right, schema.Column) and binary.right.primary_key: + elif binary.right.primary_key: if dependent[0] is binary.right.table: - raise ArgumentError("bidirectional dependency not supported...specify foreignkey") + raise exceptions.ArgumentError("Could not determine the parent/child relationship for property '%s', based on join condition '%s' (table '%s' appears on both sides of the relationship, or in an otherwise ambiguous manner). please specify the 'foreignkey' keyword parameter to the relation() function indicating a column on the remote side of the relationship" % (self.key, str(self.primaryjoin), str(binary.right.table))) dependent[0] = binary.left.table self.foreignkey = binary.left visitor = BinaryVisitor(foo) self.primaryjoin.accept_visitor(visitor) if dependent[0] is None: - raise ArgumentError("cant determine primary foreign key in the join relationship....specify foreignkey=<column> or foreignkey=[<columns>]") + raise exceptions.ArgumentError("Could not determine the parent/child relationship for property '%s', based on join condition '%s' (no relationships joining tables '%s' and '%s' could be located). 
please specify the 'foreignkey' keyword parameter to the relation() function indicating a column on the remote side of the relationship" % (self.key, str(self.primaryjoin), str(binary.left.table), str(binary.right.table))) else: self.foreigntable = dependent[0] - - def get_criterion(self, key, value): - """given a key/value pair, determines if this PropertyLoader's mapper contains a key of the - given name in its property list, or if this PropertyLoader's association mapper, if any, - contains a key of the given name in its property list, and returns a WHERE clause against - the given value if found. - - this is called by a mappers select_by method to formulate a set of key/value pairs into - a WHERE criterion that spans multiple tables if needed.""" - # TODO: optimization: change mapper to accept a WHERE clause with separate bind parameters - # then cache the generated WHERE clauses here, since the creation + the copy_container - # is an extra expense - if self.mapper.props.has_key(key): - if self.secondaryjoin is not None: - c = (self.mapper.props[key].columns[0]==value) & self.primaryjoin & self.secondaryjoin - else: - c = (self.mapper.props[key].columns[0]==value) & self.primaryjoin - return c.copy_container() - elif self.mapper.table.c.has_key(key): - if self.secondaryjoin is not None: - c = (self.mapper.table.c[key].columns[0]==value) & self.primaryjoin & self.secondaryjoin - else: - c = (self.mapper.table.c[key].columns[0]==value) & self.primaryjoin - return c.copy_container() - elif self.association is not None: - c = self.mapper._get_criterion(key, value) & self.primaryjoin - return c.copy_container() - return None - - def register_deleted(self, obj, uow): - if not self.private: - return - - if self.uselist: - childlist = uow.attributes.get_history(obj, self.key, passive = False) - else: - childlist = uow.attributes.get_history(obj, self.key) - for child in childlist.deleted_items() + childlist.unchanged_items(): - if child is not None: - uow.register_deleted(child) - - class MapperStub(object): - """poses as a Mapper representing the association table in a many-to-many - join, when performing a commit(). - - The Task objects in the objectstore module treat it just like - any other Mapper, but in fact it only serves as a "dependency" placeholder - for the many-to-many update task.""" - def __init__(self, mapper): - self.mapper = mapper - def save_obj(self, *args, **kwargs): - pass - def delete_obj(self, *args, **kwargs): - pass - def _primary_mapper(self): - return self - - def register_dependencies(self, uowcommit): - """tells a UOWTransaction what mappers are dependent on which, with regards - to the two or three mappers handled by this PropertyLoader. - - Also registers itself as a "processor" for one of its mappers, which - will be executed after that mapper's objects have been saved or before - they've been deleted. The process operation manages attributes and dependent - operations upon the objects of one of the involved mappers.""" - if self.association is not None: - # association object. our mapper should be dependent on both - # the parent mapper and the association object mapper. - # this is where we put the "stub" as a marker, so we get - # association/parent->stub->self, then we process the child - # elments after the 'stub' save, which is before our own - # mapper's save. 
- stub = PropertyLoader.MapperStub(self.association) - uowcommit.register_dependency(self.parent, stub) - uowcommit.register_dependency(self.association, stub) - uowcommit.register_dependency(stub, self.mapper) - uowcommit.register_processor(stub, self, self.parent, False) - uowcommit.register_processor(stub, self, self.parent, True) - - elif self.direction == PropertyLoader.MANYTOMANY: - # many-to-many. create a "Stub" mapper to represent the - # "middle table" in the relationship. This stub mapper doesnt save - # or delete any objects, but just marks a dependency on the two - # related mappers. its dependency processor then populates the - # association table. - - if self.is_backref: - # if we are the "backref" half of a two-way backref - # relationship, let the other mapper handle inserting the rows - return - stub = PropertyLoader.MapperStub(self.mapper) - uowcommit.register_dependency(self.parent, stub) - uowcommit.register_dependency(self.mapper, stub) - uowcommit.register_processor(stub, self, self.parent, False) - uowcommit.register_processor(stub, self, self.parent, True) - elif self.direction == PropertyLoader.ONETOMANY: - if self.post_update: - stub = PropertyLoader.MapperStub(self.mapper) - uowcommit.register_dependency(self.mapper, stub) - uowcommit.register_dependency(self.parent, stub) - uowcommit.register_processor(stub, self, self.parent, False) - uowcommit.register_processor(stub, self, self.parent, True) - else: - uowcommit.register_dependency(self.parent, self.mapper) - uowcommit.register_processor(self.parent, self, self.parent, False) - uowcommit.register_processor(self.parent, self, self.parent, True) - elif self.direction == PropertyLoader.MANYTOONE: - if self.post_update: - stub = PropertyLoader.MapperStub(self.mapper) - uowcommit.register_dependency(self.mapper, stub) - uowcommit.register_dependency(self.parent, stub) - uowcommit.register_processor(stub, self, self.parent, False) - uowcommit.register_processor(stub, self, self.parent, True) - else: - uowcommit.register_dependency(self.mapper, self.parent) - uowcommit.register_processor(self.mapper, self, self.parent, False) - uowcommit.register_processor(self.mapper, self, self.parent, True) - else: - raise AssertionError(" no foreign key ?") - - def get_object_dependencies(self, obj, uowcommit, passive = True): - return uowcommit.uow.attributes.get_history(obj, self.key, passive = passive) - - def whose_dependent_on_who(self, obj1, obj2): - """given an object pair assuming obj2 is a child of obj1, returns a tuple - with the dependent object second, or None if they are equal. - used by objectstore's object-level topological sort (i.e. cyclical - table dependency).""" - if obj1 is obj2: - return None - elif self.direction == PropertyLoader.ONETOMANY: - return (obj1, obj2) - else: - return (obj2, obj1) - - def process_dependencies(self, task, deplist, uowcommit, delete = False): - """this method is called during a commit operation to synchronize data between a parent and child object. 
- it also can establish child or parent objects within the unit of work as "to be saved" or "deleted" - in some cases.""" - #print self.mapper.table.name + " " + self.key + " " + repr(len(deplist)) + " process_dep isdelete " + repr(delete) + " direction " + repr(self.direction) - - def getlist(obj, passive=True): - l = self.get_object_dependencies(obj, uowcommit, passive) - uowcommit.register_saved_history(l) - return l - - # plugin point - - if self.direction == PropertyLoader.MANYTOMANY: - secondary_delete = [] - secondary_insert = [] - if delete: - for obj in deplist: - childlist = getlist(obj, False) - for child in childlist.deleted_items() + childlist.unchanged_items(): - associationrow = {} - self._synchronize(obj, child, associationrow, False) - secondary_delete.append(associationrow) - else: - for obj in deplist: - childlist = getlist(obj) - if childlist is None: continue - for child in childlist.added_items(): - associationrow = {} - self._synchronize(obj, child, associationrow, False) - secondary_insert.append(associationrow) - for child in childlist.deleted_items(): - associationrow = {} - self._synchronize(obj, child, associationrow, False) - secondary_delete.append(associationrow) - if len(secondary_delete): - # TODO: precompile the delete/insert queries and store them as instance variables - # on the PropertyLoader - statement = self.secondary.delete(sql.and_(*[c == sql.bindparam(c.key) for c in self.secondary.c])) - statement.execute(*secondary_delete) - if len(secondary_insert): - statement = self.secondary.insert() - statement.execute(*secondary_insert) - elif self.direction == PropertyLoader.MANYTOONE and delete: - if self.private: - for obj in deplist: - childlist = getlist(obj, False) - for child in childlist.deleted_items() + childlist.unchanged_items(): - if child is None: - continue - # if private child object, and is in the uow's "deleted" list, - # insure its in the list of items to be deleted - if child in uowcommit.uow.deleted: - uowcommit.register_object(child, isdelete=True) - elif self.post_update: - # post_update means we have to update our row to not reference the child object - # before we can DELETE the row - for obj in deplist: - self._synchronize(obj, None, None, True) - uowcommit.register_object(obj, postupdate=True) - elif self.direction == PropertyLoader.ONETOMANY and delete: - # head object is being deleted, and we manage its list of child objects - # the child objects have to have their foreign key to the parent set to NULL - if self.private and not self.post_update: - for obj in deplist: - childlist = getlist(obj, False) - for child in childlist.deleted_items() + childlist.unchanged_items(): - if child is None: - continue - # if private child object, and is in the uow's "deleted" list, - # insure its in the list of items to be deleted - if child in uowcommit.uow.deleted: - uowcommit.register_object(child, isdelete=True) - else: - for obj in deplist: - childlist = getlist(obj, False) - for child in childlist.deleted_items() + childlist.unchanged_items(): - if child is not None: - self._synchronize(obj, child, None, True) - uowcommit.register_object(child, postupdate=self.post_update) - elif self.association is not None: - # manage association objects. 
- for obj in deplist: - childlist = getlist(obj, passive=True) - if childlist is None: continue - - #print "DIRECTION", self.direction - d = {} - for child in childlist: - self._synchronize(obj, child, None, False) - key = self.mapper.instance_key(child) - #print "SYNCHRONIZED", child, "INSTANCE KEY", key - d[key] = child - uowcommit.unregister_object(child) - - for child in childlist.added_items(): - uowcommit.register_object(child) - key = self.mapper.instance_key(child) - #print "ADDED, INSTANCE KEY", key - d[key] = child - - for child in childlist.unchanged_items(): - key = self.mapper.instance_key(child) - o = d[key] - o._instance_key= key - - for child in childlist.deleted_items(): - key = self.mapper.instance_key(child) - #print "DELETED, INSTANCE KEY", key - if d.has_key(key): - o = d[key] - o._instance_key = key - uowcommit.unregister_object(child) - else: - #print "DELETE ASSOC OBJ", repr(child) - uowcommit.register_object(child, isdelete=True) + def get_join(self): + if self.secondaryjoin is not None: + return self.primaryjoin & self.secondaryjoin else: - for obj in deplist: - childlist = getlist(obj, passive=True) - if childlist is not None: - for child in childlist.added_items(): - self._synchronize(obj, child, None, False) - if self.direction == PropertyLoader.ONETOMANY and child is not None: - uowcommit.register_object(child, postupdate=self.post_update) - if self.direction == PropertyLoader.MANYTOONE: - uowcommit.register_object(obj, postupdate=self.post_update) - if self.direction != PropertyLoader.MANYTOONE: - for child in childlist.deleted_items(): - if not self.private: - self._synchronize(obj, child, None, True) - uowcommit.register_object(child, isdelete=self.private) + return self.primaryjoin def execute(self, session, instance, row, identitykey, imap, isnew): if self.is_primary(): return #print "PLAIN PROPLOADER EXEC NON-PRIAMRY", repr(id(self)), repr(self.mapper.class_), self.key - objectstore.global_attributes.create_history(instance, self.key, self.uselist) + sessionlib.global_attributes.create_history(instance, self.key, self.uselist, cascade=self.cascade, trackparent=True) + + def register_dependencies(self, uowcommit): + self._dependency_processor.register_dependencies(uowcommit) def _compile_synchronizers(self): """assembles a list of 'synchronization rules', which are instructions on how to populate @@ -547,71 +318,54 @@ class PropertyLoader(MapperProperty): The list of rules is used within commits by the _synchronize() method when dependent objects are processed.""" - - - parent_tables = util.HashSet(self.parent.tables + [self.parent.primarytable]) - target_tables = util.HashSet(self.mapper.tables + [self.mapper.primarytable]) + parent_tables = util.HashSet(self.parent.tables + [self.parent.mapped_table]) + target_tables = util.HashSet(self.mapper.tables + [self.mapper.mapped_table]) self.syncrules = sync.ClauseSynchronizer(self.parent, self.mapper, self.direction) - if self.direction == PropertyLoader.MANYTOMANY: + if self.direction == sync.MANYTOMANY: #print "COMPILING p/c", self.parent, self.mapper self.syncrules.compile(self.primaryjoin, parent_tables, [self.secondary], False) self.syncrules.compile(self.secondaryjoin, target_tables, [self.secondary], True) else: self.syncrules.compile(self.primaryjoin, parent_tables, target_tables) - def _synchronize(self, obj, child, associationrow, clearkeys): - """called during a commit to execute the full list of syncrules on the - given object/child/optional association row""" - if self.direction == 
PropertyLoader.ONETOMANY: - source = obj - dest = child - elif self.direction == PropertyLoader.MANYTOONE: - source = child - dest = obj - elif self.direction == PropertyLoader.MANYTOMANY: - dest = associationrow - source = None - - if dest is None: - return - - self.syncrules.execute(source, dest, obj, child, clearkeys) - class LazyLoader(PropertyLoader): def do_init_subclass(self, key, parent): - (self.lazywhere, self.lazybinds, self.lazyreverse) = create_lazy_clause(self.parent.noninherited_table, self.primaryjoin, self.secondaryjoin, self.foreignkey) + (self.lazywhere, self.lazybinds, self.lazyreverse) = create_lazy_clause(self.parent.unjoined_table, self.primaryjoin, self.secondaryjoin, self.foreignkey) # determine if our "lazywhere" clause is the same as the mapper's # get() clause. then we can just use mapper.get() - self.use_get = not self.uselist and self.mapper.query._get_clause.compare(self.lazywhere) + self.use_get = not self.uselist and self.mapper.query()._get_clause.compare(self.lazywhere) def _set_class_attribute(self, class_, key): # establish a class-level lazy loader on our class #print "SETCLASSATTR LAZY", repr(class_), key - objectstore.global_attributes.register_attribute(class_, key, uselist = self.uselist, deleteremoved = self.private, callable_=lambda i: self.setup_loader(i), extension=self.attributeext) + sessionlib.global_attributes.register_attribute(class_, key, uselist = self.uselist, callable_=lambda i: self.setup_loader(i), extension=self.attributeext, cascade=self.cascade, trackparent=True) def setup_loader(self, instance): if not self.parent.is_assigned(instance): - return object_mapper(instance).props[self.key].setup_loader(instance) + return mapper.object_mapper(instance).props[self.key].setup_loader(instance) def lazyload(): params = {} allparams = True - session = objectstore.get_session(instance) - #print "setting up loader, lazywhere", str(self.lazywhere) - for col, bind in self.lazybinds.iteritems(): - params[bind.key] = self.parent._getattrbycolumn(instance, col) - if params[bind.key] is None: - allparams = False - break + session = sessionlib.object_session(instance) + #print "setting up loader, lazywhere", str(self.lazywhere), "binds", self.lazybinds + if session is not None: + for col, bind in self.lazybinds.iteritems(): + params[bind.key] = self.parent._getattrbycolumn(instance, col) + if params[bind.key] is None: + allparams = False + break + else: + allparams = False if allparams: # if we have a simple straight-primary key load, use mapper.get() # to possibly save a DB round trip if self.use_get: ident = [] - for primary_key in self.mapper.pks_by_table[self.mapper.table]: + for primary_key in self.mapper.pks_by_table[self.mapper.mapped_table]: bind = self.lazyreverse[primary_key] ident.append(params[bind.key]) - return self.mapper.using(session).get(*ident) + return self.mapper.using(session).get(ident) elif self.order_by is not False: order_by = self.order_by elif self.secondary is not None and self.secondary.default_order_by() is not None: @@ -637,49 +391,48 @@ class LazyLoader(PropertyLoader): #print "EXEC NON-PRIAMRY", repr(self.mapper.class_), self.key # we are not the primary manager for this attribute on this class - set up a per-instance lazyloader, # which will override the class-level behavior - objectstore.global_attributes.create_history(instance, self.key, self.uselist, callable_=self.setup_loader(instance)) + sessionlib.global_attributes.create_history(instance, self.key, self.uselist, callable_=self.setup_loader(instance), 
cascade=self.cascade, trackparent=True) else: #print "EXEC PRIMARY", repr(self.mapper.class_), self.key # we are the primary manager for this attribute on this class - reset its per-instance attribute state, # so that the class-level lazy loader is executed when next referenced on this instance. # this usually is not needed unless the constructor of the object referenced the attribute before we got # to load data into it. - objectstore.global_attributes.reset_history(instance, self.key) + sessionlib.global_attributes.reset_history(instance, self.key) def create_lazy_clause(table, primaryjoin, secondaryjoin, foreignkey): binds = {} - reverselookup = {} - + reverse = {} def bind_label(): return "lazy_" + hex(random.randint(0, 65535))[2:] - + def visit_binary(binary): circular = isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column) and binary.left.table is binary.right.table if isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column) and ((not circular and binary.left.table is table) or (circular and binary.right is foreignkey)): col = binary.left binary.left = binds.setdefault(binary.left, sql.BindParamClause(bind_label(), None, shortname = binary.left.name)) - reverselookup[binary.right] = binds[col] - #binary.swap() + reverse[binary.right] = binds[col] if isinstance(binary.right, schema.Column) and isinstance(binary.left, schema.Column) and ((not circular and binary.right.table is table) or (circular and binary.left is foreignkey)): col = binary.right binary.right = binds.setdefault(binary.right, sql.BindParamClause(bind_label(), None, shortname = binary.right.name)) - reverselookup[binary.left] = binds[col] - + reverse[binary.left] = binds[col] + lazywhere = primaryjoin.copy_container() li = BinaryVisitor(visit_binary) lazywhere.accept_visitor(li) - #print "PRIMARYJOIN", str(lazywhere), [b.key for b in binds.values()] if secondaryjoin is not None: lazywhere = sql.and_(lazywhere, secondaryjoin) - return (lazywhere, binds, reverselookup) + return (lazywhere, binds, reverse) -class EagerLoader(PropertyLoader): +class EagerLoader(LazyLoader): """loads related objects inline with a parent query.""" def do_init_subclass(self, key, parent, recursion_stack=None): + if recursion_stack is None: + LazyLoader.do_init_subclass(self, key, parent) parent._has_eager = True self.eagertarget = self.target.alias() @@ -723,7 +476,7 @@ class EagerLoader(PropertyLoader): if isinstance(prop, EagerLoader): eagerprops.append(prop) if len(eagerprops): - recursion_stack[self.parent.table] = True + recursion_stack[self.parent.mapped_table] = True self.mapper = self.mapper.copy() try: for prop in eagerprops: @@ -733,16 +486,16 @@ class EagerLoader(PropertyLoader): continue p = prop.copy() self.mapper.props[prop.key] = p -# print "we are:", id(self), self.target.name, (self.secondary and self.secondary.name or "None"), self.parent.table.name -# print "prop is",id(prop), prop.target.name, (prop.secondary and prop.secondary.name or "None"), prop.parent.table.name +# print "we are:", id(self), self.target.name, (self.secondary and self.secondary.name or "None"), self.parent.mapped_table.name +# print "prop is",id(prop), prop.target.name, (prop.secondary and prop.secondary.name or "None"), prop.parent.mapped_table.name p.do_init_subclass(prop.key, prop.parent, recursion_stack) p._create_eager_chain(in_chain=True, recursion_stack=recursion_stack) p.eagerprimary = p.eagerprimary.copy_container() -# aliasizer = Aliasizer(p.parent.table, 
aliases={p.parent.table:self.eagertarget}) +# aliasizer = Aliasizer(p.parent.mapped_table, aliases={p.parent.mapped_table:self.eagertarget}) p.eagerprimary.accept_visitor(self.aliasizer) - #print "new eagertqarget", p.eagertarget.name, (p.secondary and p.secondary.name or "none"), p.parent.table.name + #print "new eagertqarget", p.eagertarget.name, (p.secondary and p.secondary.name or "none"), p.parent.mapped_table.name finally: - del recursion_stack[self.parent.table] + del recursion_stack[self.parent.mapped_table] self._row_decorator = self._create_decorator_row() @@ -755,7 +508,7 @@ class EagerLoader(PropertyLoader): orderby = util.to_list(orderby) for i in range(0, len(orderby)): if isinstance(orderby[i], schema.Column): - orderby[i] = self.eagertarget._get_col_by_original(orderby[i]) + orderby[i] = self.eagertarget.corresponding_column(orderby[i]) else: orderby[i].accept_visitor(self.aliasizer) return orderby @@ -769,7 +522,7 @@ class EagerLoader(PropertyLoader): if hasattr(statement, '_outerjoin'): towrap = statement._outerjoin else: - towrap = self.parent.table + towrap = self.parent.mapped_table # print "hello, towrap", str(towrap) if self.secondaryjoin is not None: @@ -795,26 +548,34 @@ class EagerLoader(PropertyLoader): """receive a row. tell our mapper to look for a new object instance in the row, and attach it to a list on the parent instance.""" + decorated_row = self._decorate_row(row) + try: + # check for identity key + identity_key = self.mapper._row_identity_key(decorated_row) + except KeyError: + # else degrade to a lazy loader + LazyLoader.execute(self, session, instance, row, identitykey, imap, isnew) + return + if isnew: # new row loaded from the database. initialize a blank container on the instance. # this will override any per-class lazyloading type of stuff. - h = objectstore.global_attributes.create_history(instance, self.key, self.uselist) + h = sessionlib.global_attributes.create_history(instance, self.key, self.uselist, cascade=self.cascade, trackparent=True) if not self.uselist: if isnew: - h.setattr_clean(self._instance(session, row, imap)) + h.setattr_clean(self.mapper._instance(session, decorated_row, imap, None)) else: # call _instance on the row, even though the object has been created, # so that we further descend into properties - self._instance(session, row, imap) + self.mapper._instance(session, decorated_row, imap, None) return elif isnew: result_list = h else: result_list = getattr(instance, self.key) - - self._instance(session, row, imap, result_list) + self.mapper._instance(session, decorated_row, imap, result_list) def _create_decorator_row(self): class DecoratorDict(object): @@ -828,14 +589,13 @@ class EagerLoader(PropertyLoader): return map.keys() map = {} for c in self.eagertarget.c: - parent = self.target._get_col_by_original(c.original) + parent = self.target.corresponding_column(c) map[parent] = c map[parent._label] = c map[parent.name] = c return DecoratorDict - - def _instance(self, session, row, imap, result_list=None): - """gets an instance from a row, via this EagerLoader's mapper.""" + + def _decorate_row(self, row): # since the EagerLoader makes an Alias of its mapper's table, # we translate the actual result columns back to what they # would normally be into a "virtual row" which is passed to the child mapper. @@ -843,10 +603,13 @@ class EagerLoader(PropertyLoader): # (neither do any MapperExtensions). 
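Because the eager loader selects its target table under an alias, result columns come back keyed to that alias; _create_decorator_row builds a translation map (routing the plain column, its label, and its name through corresponding_column) so the child mapper reads the row as if no alias existed. A rough sketch of the same trick with made-up key names, since the real code maps actual Column objects rather than strings:

# rough sketch of the "decorated row" translation; key names here are invented
class DecoratedRow(object):
    """present a row keyed to an alias as if keyed to the original table."""
    def __init__(self, row, translation):
        self.row = row
        self.translation = translation
    def __getitem__(self, key):
        return self.row[self.translation[key]]

# suppose the eager loader selected 'addresses' under the alias 'addresses_1'
raw = {'addresses_1_id': 7, 'addresses_1_email': 'ed@example.com'}
translation = {
    'id': 'addresses_1_id',                   # plain column name
    'email': 'addresses_1_email',
    'addresses_email': 'addresses_1_email',   # labeled form
}
row = DecoratedRow(raw, translation)
print(row['email'])   # 'ed@example.com' -- the child mapper never sees the alias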
The row is keyed off the Column object # (which is what mappers use) as well as its "label" (which might be what # user-defined code is using) - row = self._row_decorator(row) - return self.mapper._instance(session, row, imap, result_list) + try: + return self._row_decorator(row) + except AttributeError: + self._create_eager_chain() + return self._row_decorator(row) -class GenericOption(MapperOption): +class GenericOption(mapper.MapperOption): """a mapper option that can handle dotted property names, descending down through the relations of a mapper until it reaches the target.""" @@ -879,26 +642,37 @@ class BackRef(object): """called by the owning PropertyLoader to set up a backreference on the PropertyLoader's mapper.""" # try to set a LazyLoader on our mapper referencing the parent mapper + mapper = prop.mapper.primary_mapper() if not prop.mapper.props.has_key(self.key): - if prop.secondaryjoin is not None: - # if setting up a backref to a many-to-many, reverse the order - # of the "primary" and "secondary" joins - pj = prop.secondaryjoin - sj = prop.primaryjoin - else: - pj = prop.primaryjoin - sj = None + pj = self.kwargs.pop('primaryjoin', None) + sj = self.kwargs.pop('secondaryjoin', None) + # TODO: we are going to have the newly backref'd property create its + # primary/secondary join through normal means, and only override if they are + # specified to the constructor. think about if this is really going to work + # all the way. + #if pj is None: + # if prop.secondaryjoin is not None: + # # if setting up a backref to a many-to-many, reverse the order + # # of the "primary" and "secondary" joins + # pj = prop.secondaryjoin + # sj = prop.primaryjoin + # else: + # pj = prop.primaryjoin + # sj = None lazy = self.kwargs.pop('lazy', True) if lazy: cls = LazyLoader else: cls = EagerLoader - relation = cls(prop.parent, prop.secondary, pj, sj, backref=prop.key, is_backref=True, **self.kwargs) - prop.mapper.add_property(self.key, relation); + # the backref property is set on the primary mapper + parent = prop.parent.primary_mapper() + relation = cls(parent, prop.secondary, pj, sj, backref=prop.key, is_backref=True, **self.kwargs) + mapper.add_property(self.key, relation); else: # else set one of us as the "backreference" - if not prop.mapper.props[self.key].is_backref: + if not mapper.props[self.key].is_backref: prop.is_backref=True + prop._dependency_processor.is_backref=True def get_extension(self): """returns an attribute extension to use with this backreference.""" return attributes.GenericBackrefExtension(self.key) @@ -964,14 +738,14 @@ class Aliasizer(sql.ClauseVisitor): for i in range(0, len(clist.clauses)): if isinstance(clist.clauses[i], schema.Column) and self.tables.has_key(clist.clauses[i].table): orig = clist.clauses[i] - clist.clauses[i] = self.get_alias(clist.clauses[i].table)._get_col_by_original(clist.clauses[i]) + clist.clauses[i] = self.get_alias(clist.clauses[i].table).corresponding_column(clist.clauses[i]) if clist.clauses[i] is None: raise "cant get orig for " + str(orig) + " against table " + orig.table.name + " " + self.get_alias(orig.table).name def visit_binary(self, binary): if isinstance(binary.left, schema.Column) and self.tables.has_key(binary.left.table): - binary.left = self.get_alias(binary.left.table)._get_col_by_original(binary.left) + binary.left = self.get_alias(binary.left.table).corresponding_column(binary.left) if isinstance(binary.right, schema.Column) and self.tables.has_key(binary.right.table): - binary.right = 
self.get_alias(binary.right.table)._get_col_by_original(binary.right) + binary.right = self.get_alias(binary.right.table).corresponding_column(binary.right) class BinaryVisitor(sql.ClauseVisitor): def __init__(self, func): diff --git a/lib/sqlalchemy/mapping/query.py b/lib/sqlalchemy/orm/query.py index 283e8c189..cb51da02a 100644 --- a/lib/sqlalchemy/mapping/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1,24 +1,26 @@ -# mapper/query.py +# orm/query.py # Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php - -import objectstore -import sqlalchemy.sql as sql -import sqlalchemy.util as util +import session as sessionlib +from sqlalchemy import sql, util, exceptions import mapper -from sqlalchemy.exceptions import * class Query(object): """encapsulates the object-fetching operations provided by Mappers.""" - def __init__(self, mapper, **kwargs): - self.mapper = mapper + def __init__(self, class_or_mapper, session=None, entity_name=None, **kwargs): + if isinstance(class_or_mapper, type): + self.mapper = class_mapper(class_or_mapper, entity_name=entity_name) + else: + self.mapper = class_or_mapper + self.mapper = self.mapper.get_select_mapper() + self.always_refresh = kwargs.pop('always_refresh', self.mapper.always_refresh) self.order_by = kwargs.pop('order_by', self.mapper.order_by) self.extension = kwargs.pop('extension', self.mapper.extension) - self._session = kwargs.pop('session', None) + self._session = session if not hasattr(mapper, '_get_clause'): _get_clause = sql.and_() for primary_key in self.mapper.pks_by_table[self.table]: @@ -27,21 +29,30 @@ class Query(object): self._get_clause = self.mapper._get_clause def _get_session(self): if self._session is None: - return objectstore.get_session() + return self.mapper.get_session() else: return self._session - table = property(lambda s:s.mapper.table) - props = property(lambda s:s.mapper.props) + table = property(lambda s:s.mapper.select_table) session = property(_get_session) - def get(self, *ident, **kwargs): + def get(self, ident, **kwargs): """returns an instance of the object based on the given identifier, or None - if not found. The *ident argument is a - list of primary key columns in the order of the table def's primary key columns.""" - key = self.mapper.identity_key(*ident) - #print "key: " + repr(key) + " ident: " + repr(ident) + if not found. The ident argument is a scalar or tuple of primary key column values + in the order of the table def's primary key columns.""" + key = self.mapper.identity_key(ident) return self._get(key, ident, **kwargs) + def load(self, ident, **kwargs): + """returns an instance of the object based on the given identifier. If not found, + raises an exception. The method will *remove all pending changes* to the object + already existing in the Session. The ident argument is a scalar or tuple of primary + key column values in the order of the table def's primary key columns.""" + key = self.mapper.identity_key(ident) + instance = self._get(key, ident, reload=True, **kwargs) + if instance is None: + raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident)) + return instance + def get_by(self, *args, **params): """returns a single object instance based on the given key/value criterion. this is either the first value in the result list, or None if the list is @@ -55,7 +66,7 @@ class Query(object): e.g. 
u = usermapper.get_by(user_name = 'fred') """ - x = self.select_whereclause(self._by_clause(*args, **params), limit=1) + x = self.select_whereclause(self.join_by(*args, **params), limit=1) if x: return x[0] else: @@ -65,6 +76,7 @@ class Query(object): """returns an array of object instances based on the given clauses and key/value criterion. *args is a list of zero or more ClauseElements which will be connected by AND operators. + **params is a set of zero or more key/value parameters which are converted into ClauseElements. the keys are mapped to property or column names mapped by this mapper's Table, and the values are coerced into a WHERE clause separated by AND operators. If the local property/column @@ -77,8 +89,76 @@ class Query(object): ret = self.extension.select_by(self, *args, **params) if ret is not mapper.EXT_PASS: return ret - return self.select_whereclause(self._by_clause(*args, **params)) + return self.select_whereclause(self.join_by(*args, **params)) + + def join_by(self, *args, **params): + """like select_by, but returns a ClauseElement representing the WHERE clause that would normally + be sent to select_whereclause by select_by.""" + clause = None + for arg in args: + if clause is None: + clause = arg + else: + clause &= arg + for key, value in params.iteritems(): + (keys, prop) = self._locate_prop(key) + c = (prop.columns[0]==value) & self.join_via(keys) + if clause is None: + clause = c + else: + clause &= c + return clause + + def _locate_prop(self, key): + import properties + keys = [] + def search_for_prop(mapper): + if mapper.props.has_key(key): + prop = mapper.props[key] + if isinstance(prop, properties.PropertyLoader): + keys.insert(0, prop.key) + return prop + else: + for prop in mapper.props.values(): + if not isinstance(prop, properties.PropertyLoader): + continue + x = search_for_prop(prop.mapper) + if x: + keys.insert(0, prop.key) + return x + else: + return None + p = search_for_prop(self.mapper) + if p is None: + raise exceptions.InvalidRequestError("Cant locate property named '%s'" % key) + return [keys, p] + + def join_to(self, key): + """given the key name of a property, will recursively descend through all child properties + from this Query's mapper to locate the property, and will return a ClauseElement + representing a join from this Query's mapper to the endmost mapper.""" + [keys, p] = self._locate_prop(key) + return self.join_via(keys) + + def join_via(self, keys): + """given a list of keys that represents a path from this Query's mapper to a related mapper + based on names of relations from one mapper to the next, returns a + ClauseElement representing a join from this Query's mapper to the endmost mapper. + """ + mapper = self.mapper + clause = None + for key in keys: + prop = mapper.props[key] + if clause is None: + clause = prop.get_join() + else: + clause &= prop.get_join() + mapper = prop.mapper + + return clause + + def selectfirst_by(self, *args, **params): """works like select_by(), but only returns the first result by itself, or None if no objects returned. 
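join_by above turns keyword criteria into a WHERE clause that can span relations: _locate_prop finds the path of relation keys leading to the property, and join_via ANDs together each hop's get_join() while descending from mapper to mapper. The chaining step, modeled with strings standing in for real ClauseElements:

# toy model of join_via: AND each hop's join clause while descending the path
class Relation(object):
    def __init__(self, join, mapper):
        self.join = join
        self.mapper = mapper     # the child mapper's property dict, here a plain dict
    def get_join(self):
        return self.join

def join_via(props, keys):
    clause = None
    for key in keys:
        prop = props[key]
        if clause is None:
            clause = prop.get_join()
        else:
            clause = clause + " AND " + prop.get_join()
        props = prop.mapper
    return clause

orders_props = {}
address_props = {'orders': Relation('addresses.id = orders.address_id', orders_props)}
user_props = {'addresses': Relation('users.id = addresses.user_id', address_props)}
print(join_via(user_props, ['addresses', 'orders']))
# users.id = addresses.user_id AND addresses.id = orders.address_id

join_by then ANDs the actual criterion (prop.columns[0] == value) onto this chain, which is how select_by criteria reach columns on related tables.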
Synonymous with get_by()""" @@ -86,15 +166,15 @@ class Query(object): def selectone_by(self, *args, **params): """works like selectfirst_by(), but throws an error if not exactly one result was returned.""" - ret = self.select_whereclause(self._by_clause(*args, **params), limit=2) + ret = self.select_whereclause(self.join_by(*args, **params), limit=2) if len(ret) == 1: return ret[0] - raise InvalidRequestError('Multiple rows returned for selectone_by') + raise exceptions.InvalidRequestError('Multiple rows returned for selectone_by') def count_by(self, *args, **params): """returns the count of instances based on the given clauses and key/value criterion. The criterion is constructed in the same way as the select_by() method.""" - return self.count(self._by_clause(*args, **params)) + return self.count(self.join_by(*args, **params)) def selectfirst(self, *args, **params): """works like select(), but only returns the first result by itself, or None if no @@ -111,7 +191,7 @@ class Query(object): ret = list(self.select(*args, **params)[0:2]) if len(ret) == 1: return ret[0] - raise InvalidRequestError('Multiple rows returned for selectone') + raise exceptions.InvalidRequestError('Multiple rows returned for selectone') def select(self, arg=None, **kwargs): """selects instances of the object from the database. @@ -138,17 +218,18 @@ class Query(object): def count(self, whereclause=None, params=None, **kwargs): s = self.table.count(whereclause) - if params is not None: - return s.scalar(**params) - else: - return s.scalar() + return self.session.scalar(self.mapper, s, params=params) def select_statement(self, statement, **params): return self._select_statement(statement, params=params) def select_text(self, text, **params): - t = sql.text(text, engine=self.mapper.primarytable.engine) - return self.instances(t.execute(**params)) + t = sql.text(text) + return self.instances(t, params=params) + + def options(self, *args, **kwargs): + """returns a new Query object using the given MapperOptions.""" + return self.mapper.options(*args, **kwargs).using(session=self._session) def __getattr__(self, key): if (key.startswith('select_by_')): @@ -164,28 +245,13 @@ class Query(object): else: raise AttributeError(key) - def instances(self, *args, **kwargs): - return self.mapper.instances(session=self.session, *args, **kwargs) + def instances(self, clauseelement, params=None, *args, **kwargs): + result = self.session.execute(self.mapper, clauseelement, params=params) + try: + return self.mapper.instances(result, self.session, **kwargs) + finally: + result.close() - def _by_clause(self, *args, **params): - clause = None - for arg in args: - if clause is None: - clause = arg - else: - clause &= arg - for key, value in params.iteritems(): - if value is False: - continue - c = self.mapper._get_criterion(key, value) - if c is None: - raise InvalidRequestError("Cant find criterion for property '"+ key + "'") - if clause is None: - clause = c - else: - clause &= c - return clause - def _get(self, key, ident=None, reload=False): if not reload and not self.always_refresh: try: @@ -195,6 +261,8 @@ class Query(object): if ident is None: ident = key[1] + else: + ident = util.to_list(ident) i = 0 params = {} for primary_key in self.mapper.pks_by_table[self.table]: @@ -210,7 +278,7 @@ class Query(object): statement.use_labels = True if params is None: params = {} - return self.instances(statement.execute(**params), **kwargs) + return self.instances(statement, params=params, **kwargs) def _should_nest(self, **kwargs): """returns 
True if the given statement options indicate that we should "nest" the @@ -229,7 +297,7 @@ class Query(object): if order_by is False: if self.table.default_order_by() is not None: order_by = self.table.default_order_by() - + if self._should_nest(**kwargs): from_obj.append(self.table) s2 = sql.select(self.table.primary_key, whereclause, use_labels=True, from_obj=from_obj, **kwargs) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py new file mode 100644 index 000000000..42f442b09 --- /dev/null +++ b/lib/sqlalchemy/orm/session.py @@ -0,0 +1,453 @@ +# objectstore.py +# Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from sqlalchemy import util, exceptions, sql +import unitofwork, query +import weakref +import sqlalchemy + +class SessionTransaction(object): + def __init__(self, session, parent=None, autoflush=True): + self.session = session + self.connections = {} + self.parent = parent + self.autoflush = autoflush + def connection(self, mapper_or_class, entity_name=None): + if isinstance(mapper_or_class, type): + mapper_or_class = class_mapper(mapper_or_class, entity_name=entity_name) + if self.parent is not None: + return self.parent.connection(mapper_or_class) + engine = self.session.get_bind(mapper_or_class) + return self.get_or_add(engine) + def _begin(self): + return SessionTransaction(self.session, self) + def add(self, connectable): + if self.connections.has_key(connectable.engine): + raise exceptions.InvalidRequestError("Session already has a Connection associated for the given Connection's Engine") + return self.get_or_add(connectable) + def get_or_add(self, connectable): + # we reference the 'engine' attribute on the given object, which in the case of + # Connection, ProxyEngine, Engine, ComposedSQLEngine, whatever, should return the original + # "Engine" object that is handling the connection. + if self.connections.has_key(connectable.engine): + return self.connections[connectable.engine][0] + e = connectable.engine + c = connectable.contextual_connect() + if not self.connections.has_key(e): + self.connections[e] = (c, c.begin()) + return self.connections[e][0] + def commit(self): + if self.parent is not None: + return + if self.autoflush: + self.session.flush() + for t in self.connections.values(): + t[1].commit() + self.close() + def rollback(self): + if self.parent is not None: + self.parent.rollback() + return + for k, t in self.connections.iteritems(): + t[1].rollback() + self.close() + def close(self): + if self.parent is not None: + return + for t in self.connections.values(): + t[0].close() + self.session.transaction = None + +class Session(object): + """encapsulates a set of objects being operated upon within an object-relational operation.""" + def __init__(self, bind_to=None, hash_key=None, import_session=None, echo_uow=False): + if import_session is not None: + self.uow = unitofwork.UnitOfWork(identity_map=import_session.uow.identity_map) + else: + self.uow = unitofwork.UnitOfWork() + + self.bind_to = bind_to + self.binds = {} + self.echo_uow = echo_uow + self.transaction = None + if hash_key is None: + self.hash_key = id(self) + else: + self.hash_key = hash_key + _sessions[self.hash_key] = self + + def create_transaction(self, **kwargs): + """returns a new SessionTransaction corresponding to an existing or new transaction. 
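The SessionTransaction just above keeps one (connection, transaction) pair per underlying Engine, and only the outermost transaction actually commits; a nested create_transaction() returns a child that defers commit/close to its parent. The typical bracketing, sketched against the API shown in this file (an already-bound engine and a mapped User class are assumed):

# usage sketch only; 'engine' and the mapped User class are assumptions
sess = Session(bind_to=engine)
trans = sess.create_transaction()     # outermost: owns commit; nested calls defer
try:
    sess.save(User('ed'))
    trans.commit()                    # autoflush defaults True, so commit() flushes first
except:
    trans.rollback()
    raise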
+ if the transaction is new, the returned SessionTransaction will have commit control + over the underlying transaction, else will have rollback control only.""" + if self.transaction is not None: + return self.transaction._begin() + else: + self.transaction = SessionTransaction(self, **kwargs) + return self.transaction + def connect(self, mapper=None, **kwargs): + """returns a unique connection corresponding to the given mapper. this connection + will not be part of any pre-existing transactional context.""" + return self.get_bind(mapper).connect(**kwargs) + def connection(self, mapper, **kwargs): + """returns a Connection corresponding to the given mapper. used by the execute() + method which performs select operations for Mapper and Query. + if this Session is transactional, + the connection will be in the context of this session's transaction. otherwise, the connection + is returned by the contextual_connect method, which some Engines override to return a thread-local + connection, and will have close_with_result set to True. + + the given **kwargs will be sent to the engine's contextual_connect() method, if no transaction is in progress.""" + if self.transaction is not None: + return self.transaction.connection(mapper) + else: + return self.get_bind(mapper).contextual_connect(**kwargs) + def execute(self, mapper, clause, params, **kwargs): + """using the given mapper to identify the appropriate Engine or Connection to be used for statement execution, + executes the given ClauseElement using the provided parameter dictionary. Returns a ResultProxy corresponding + to the execution's results. If this method allocates a new Connection for the operation, then the ResultProxy's close() + method will release the resources of the underlying Connection, otherwise its a no-op. + """ + return self.connection(mapper, close_with_result=True).execute(clause, params, **kwargs) + def scalar(self, mapper, clause, params, **kwargs): + """works like execute() but returns a scalar result.""" + return self.connection(mapper, close_with_result=True).scalar(clause, params, **kwargs) + + def close(self): + """closes this Session. + """ + self.clear() + if self.transaction is not None: + self.transaction.close() + + def clear(self): + """removes all object instances from this Session. this is equivalent to calling expunge() for all + objects in this Session.""" + for instance in self: + self._unattach(instance) + self.uow = unitofwork.UnitOfWork() + + def mapper(self, class_, entity_name=None): + """given an Class, returns the primary Mapper responsible for persisting it""" + return class_mapper(class_, entity_name = entity_name) + def bind_mapper(self, mapper, bindto): + """binds the given Mapper to the given Engine or Connection. All subsequent operations involving this + Mapper will use the given bindto.""" + self.binds[mapper] = bindto + def bind_table(self, table, bindto): + """binds the given Table to the given Engine or Connection. All subsequent operations involving this + Table will use the given bindto.""" + self.binds[table] = bindto + def get_bind(self, mapper): + """given a Mapper, returns the Engine or Connection which is used to execute statements on behalf of this + Mapper. Calling connect() on the return result will always result in a Connection object. This method + disregards any SessionTransaction that may be in progress. + + The order of searching is as follows: + + if an Engine or Connection was bound to this Mapper specifically within this Session, returns that + Engine or Connection. 
+ + if an Engine or Connection was bound to this Mapper's underlying Table within this Session + (i.e. not to the Table directly), returns that Engine or Conneciton. + + if an Engine or Connection was bound to this Session, returns that Engine or Connection. + + finally, returns the Engine which was bound directly to the Table's MetaData object. + + If no Engine is bound to the Table, an exception is raised. + """ + if mapper is None: + return self.bind_to + elif self.binds.has_key(mapper): + return self.binds[mapper] + elif self.binds.has_key(mapper.mapped_table): + return self.binds[mapper.mapped_table] + elif self.bind_to is not None: + return self.bind_to + else: + e = mapper.mapped_table.engine + if e is None: + raise exceptions.InvalidRequestError("Could not locate any Engine bound to mapper '%s'" % str(mapper)) + return e + def query(self, mapper_or_class, entity_name=None): + """given a mapper or Class, returns a new Query object corresponding to this Session and the mapper, or the classes' primary mapper.""" + if isinstance(mapper_or_class, type): + return query.Query(class_mapper(mapper_or_class, entity_name=entity_name), self) + else: + return query.Query(mapper_or_class, self) + def _sql(self): + class SQLProxy(object): + def __getattr__(self, key): + def call(*args, **kwargs): + kwargs[engine] = self.engine + return getattr(sql, key)(*args, **kwargs) + + sql = property(_sql) + + + def get_id_key(ident, class_, entity_name=None): + """returns an identity-map key for use in storing/retrieving an item from the identity + map, given a tuple of the object's primary key values. + + ident - a tuple of primary key values corresponding to the object to be stored. these + values should be in the same order as the primary keys of the table + + class_ - a reference to the object's class + + entity_name - optional string name to further qualify the class + """ + return (class_, tuple(ident), entity_name) + get_id_key = staticmethod(get_id_key) + + def get_row_key(row, class_, primary_key, entity_name=None): + """returns an identity-map key for use in storing/retrieving an item from the identity + map, given a result set row. + + row - a sqlalchemy.dbengine.RowProxy instance or other map corresponding result-set + column names to their values within a row. + + class_ - a reference to the object's class + + primary_key - a list of column objects that will target the primary key values + in the given row. + + entity_name - optional string name to further qualify the class + """ + return (class_, tuple([row[column] for column in primary_key]), entity_name) + get_row_key = staticmethod(get_row_key) + + def begin(self, *obj): + """deprecated""" + raise exceptions.InvalidRequestError("Session.begin() is deprecated. use install_mod('legacy_session') to enable the old behavior") + def commit(self, *obj): + """deprecated""" + raise exceptions.InvalidRequestError("Session.commit() is deprecated. use install_mod('legacy_session') to enable the old behavior") + + def flush(self, objects=None): + """flushes all the object modifications present in this session to the database. 'objects' + is a list or tuple of objects specifically to be flushed.""" + self.uow.flush(self, objects, echo=self.echo_uow) + + def get(self, class_, ident, **kwargs): + """returns an instance of the object based on the given identifier, or None + if not found. The ident argument is a scalar or tuple of primary key column values in the order of the + table def's primary key columns. 
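The identity-map key built by get_id_key/get_row_key above is simply a (class, primary-key tuple, entity_name) triple, and get() consults the map under that key before issuing a SELECT. Restated as a runnable micro-demo (the real primary_key argument is a list of Column objects against a RowProxy; plain strings and a dict stand in here):

# the key-building logic above, restated with plain dicts standing in for rows
def get_id_key(ident, class_, entity_name=None):
    return (class_, tuple(ident), entity_name)

def get_row_key(row, class_, primary_key, entity_name=None):
    return (class_, tuple([row[col] for col in primary_key]), entity_name)

class User(object):
    pass

print(get_id_key((5,), User))                          # (User, (5,), None)
print(get_row_key({'user_id': 5}, User, ['user_id']))  # the same key, built from a row
assert get_id_key((5,), User) == get_row_key({'user_id': 5}, User, ['user_id'])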
+ + the entity_name keyword argument may also be specified which further qualifies the underlying + Mapper used to perform the query.""" + entity_name = kwargs.get('entity_name', None) + return self.query(class_, entity_name=entity_name).get(ident) + + def load(self, class_, ident, **kwargs): + """returns an instance of the object based on the given identifier. If not found, + raises an exception. The method will *remove all pending changes* to the object + already existing in the Session. The ident argument is a scalar or tuple of + primary key columns in the order of the table def's primary key columns. + + the entity_name keyword argument may also be specified which further qualifies the underlying + Mapper used to perform the query.""" + entity_name = kwargs.get('entity_name', None) + return self.query(class_, entity_name=entity_name).load(ident) + + def refresh(self, object): + """reloads the attributes for the given object from the database, clears + any changes made.""" + self.uow.refresh(self, object) + + def expire(self, object): + """invalidates the data in the given object and sets them to refresh themselves + the next time they are requested.""" + self.uow.expire(self, object) + + def expunge(self, object): + """removes the given object from this Session. this will free all internal references to the object.""" + self.uow.expunge(object) + + def save(self, object, entity_name=None): + """ + Adds a transient (unsaved) instance to this Session. This operation cascades the "save_or_update" + method to associated instances if the relation is mapped with cascade="save-update". + + The 'entity_name' keyword argument will further qualify the specific Mapper used to handle this + instance. + """ + for c in object_mapper(object, entity_name=entity_name).cascade_iterator('save-update', object): + if c is object: + self._save_impl(c, entity_name=entity_name) + else: + self.save_or_update(c, entity_name=entity_name) + + def update(self, object, entity_name=None): + """Brings the given detached (saved) instance into this Session. + If there is a persistent instance with the same identifier (i.e. a saved instance already associated with this + Session), an exception is thrown. 
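save(), update() and save_or_update() above all walk cascade_iterator('save-update', ...) and then choose _save_impl or _update_impl per instance; the deciding test is just whether the instance already carries an _instance_key from a prior flush. In miniature:

# the dispatch in save_or_update(), reduced to its essential test
def save_or_update_choice(obj):
    if getattr(obj, '_instance_key', None) is None:
        return 'save'      # transient: no identity from a previous flush
    return 'update'        # detached: carries an identity, rejoins this Session

class Thing(object):
    pass

t = Thing()
print(save_or_update_choice(t))            # 'save'
t._instance_key = (Thing, (1,), None)      # what a flushed instance would carry
print(save_or_update_choice(t))            # 'update'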
+ This operation cascades the "save_or_update" method to associated instances if the relation is mapped + with cascade="save-update".""" + for c in object_mapper(object, entity_name=entity_name).cascade_iterator('save-update', object): + if c is object: + self._update_impl(c, entity_name=entity_name) + else: + self.save_or_update(c, entity_name=entity_name) + + def save_or_update(self, object, entity_name=None): + for c in object_mapper(object, entity_name=entity_name).cascade_iterator('save-update', object): + key = getattr(object, '_instance_key', None) + if key is None: + self._save_impl(c, entity_name=entity_name) + else: + self._update_impl(c, entity_name=entity_name) + + def delete(self, object, entity_name=None): + for c in object_mapper(object, entity_name=entity_name).cascade_iterator('delete', object): + self.uow.register_deleted(c) + + def merge(self, object, entity_name=None): + instance = None + for obj in object_mapper(object, entity_name=entity_name).cascade_iterator('merge', object): + key = getattr(obj, '_instance_key', None) + if key is None: + mapper = object_mapper(object, entity_name=entity_name) + ident = mapper.identity(object) + for k in ident: + if k is None: + raise exceptions.InvalidRequestError("Instance '%s' does not have a full set of identity values, and does not represent a saved entity in the database. Use the add() method to add unsaved instances to this Session." % repr(obj)) + key = mapper.identity_key(ident) + u = self.uow + if u.identity_map.has_key(key): + # TODO: copy the state of the given object into this one. tricky ! + inst = u.identity_map[key] + else: + inst = self.get(object.__class__, *key[1]) + if obj is object: + instance = inst + + return instance + + def _save_impl(self, object, **kwargs): + if hasattr(object, '_instance_key'): + if not self.uow.has_key(object._instance_key): + raise exceptions.InvalidRequestError("Instance '%s' is already persistent in a different Session" % repr(object)) + else: + entity_name = kwargs.get('entity_name', None) + if entity_name is not None: + m = class_mapper(object.__class__, entity_name=entity_name) + m._assign_entity_name(object) + self._register_new(object) + + def _update_impl(self, object, **kwargs): + if self._is_attached(object) and object not in self.deleted: + return + if not hasattr(object, '_instance_key'): + raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % repr(object)) + if global_attributes.is_modified(object): + self._register_dirty(object) + else: + self._register_clean(object) + + def _register_changed(self, obj): + if hasattr(obj, '_instance_key'): + self._register_dirty(obj) + else: + self._register_new(obj) + def _register_new(self, obj): + self._attach(obj) + self.uow.register_new(obj) + def _register_dirty(self, obj): + self._attach(obj) + self.uow.register_dirty(obj) + def _register_clean(self, obj): + self._attach(obj) + self.uow.register_clean(obj) + def _register_deleted(self, obj): + self._attach(obj) + self.uow.register_deleted(obj) + + def _attach(self, obj): + """given an object, attaches it to this session. """ + if getattr(obj, '_sa_session_id', None) != self.hash_key: + old = getattr(obj, '_sa_session_id', None) + if old is not None: + raise exceptions.InvalidRequestError("Object '%s' is already attached to session '%s' (this is '%s')" % (repr(obj), old, id(self))) + + # auto-removal from the old session is disabled. 
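_attach stamps each instance with the owning session's hash_key rather than a direct reference; object_session() (defined further down in this file) then resolves that key through a module-level weakref.WeakValueDictionary, so a session can be garbage-collected even while stamped objects outlive it. A compact model of the pattern (MiniSession is illustrative; the collection behavior assumes CPython refcounting):

import weakref

_sessions = weakref.WeakValueDictionary()   # same shape as the registry below

class MiniSession(object):
    def __init__(self):
        self.hash_key = id(self)
        _sessions[self.hash_key] = self
    def attach(self, obj):
        obj._sa_session_id = self.hash_key   # stamp a key, not a reference

def object_session(obj):
    key = getattr(obj, '_sa_session_id', None)
    if key is None:
        return None
    return _sessions.get(key)

class Thing(object):
    pass

s = MiniSession()
t = Thing()
s.attach(t)
assert object_session(t) is s
del s                                  # only the weak registry held the session...
assert object_session(t) is None       # ...so the stamp now dangles harmlessly (CPython)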
but if we decide to + # turn it back on, do it as below: gingerly since _sessions is a WeakValueDict + # and it might be affected by other threads + try: + sess = _sessions[old] + except KeyError: + sess = None + if sess is not None: + sess.expunge(old) + key = getattr(obj, '_instance_key', None) + if key is not None: + self.identity_map[key] = obj + obj._sa_session_id = self.hash_key + + def _unattach(self, obj): + if not self._is_attached(obj): #getattr(obj, '_sa_session_id', None) != self.hash_key: + raise exceptions.InvalidRequestError("Object '%s' is not attached to this Session" % repr(obj)) + del obj._sa_session_id + + def _is_attached(self, obj): + return getattr(obj, '_sa_session_id', None) == self.hash_key + def __contains__(self, obj): + return self._is_attached(obj) and (obj in self.uow.new or self.uow.has_key(obj._instance_key)) + def __iter__(self): + return iter(self.uow.identity_map.values()) + def _get(self, key): + return self.uow._get(key) + def has_key(self, key): + return self.uow.has_key(key) + def is_expired(self, instance, **kwargs): + return self.uow.is_expired(instance, **kwargs) + + dirty = property(lambda s:s.uow.dirty, doc="a Set of all objects marked as 'dirty' within this Session") + deleted = property(lambda s:s.uow.deleted, doc="a Set of all objects marked as 'deleted' within this Session") + new = property(lambda s:s.uow.new, doc="a Set of all objects marked as 'new' within this Session.") + identity_map = property(lambda s:s.uow.identity_map, doc="a WeakValueDictionary consisting of all objects within this Session keyed to their _instance_key value.") + + + def import_instance(self, *args, **kwargs): + """deprecated; a synynom for merge()""" + return self.merge(*args, **kwargs) + +def get_id_key(ident, class_, entity_name=None): + return Session.get_id_key(ident, class_, entity_name) + +def get_row_key(row, class_, primary_key, entity_name=None): + return Session.get_row_key(row, class_, primary_key, entity_name) + +def object_mapper(obj, **kwargs): + return sqlalchemy.orm.object_mapper(obj, **kwargs) + +def class_mapper(class_, **kwargs): + return sqlalchemy.orm.class_mapper(class_, **kwargs) + +# this is the AttributeManager instance used to provide attribute behavior on objects. +# to all the "global variable police" out there: its a stateless object. +global_attributes = unitofwork.global_attributes + +# this dictionary maps the hash key of a Session to the Session itself, and +# acts as a Registry with which to locate Sessions. this is to enable +# object instances to be associated with Sessions without having to attach the +# actual Session object directly to the object instance. +_sessions = weakref.WeakValueDictionary() + +def object_session(obj): + hashkey = getattr(obj, '_sa_session_id', None) + if hashkey is not None: + return _sessions.get(hashkey) + return None + +unitofwork.object_session = object_session + + +def get_session(obj=None): + """deprecated""" + if obj is not None: + return object_session(obj) + raise exceptions.InvalidRequestError("get_session() is deprecated, and does not return the thread-local session anymore. 
Use the SessionContext.mapper_extension or import sqlalchemy.mod.threadlocal to establish a default thread-local context.") diff --git a/lib/sqlalchemy/mapping/sync.py b/lib/sqlalchemy/orm/sync.py index cfce9b6b6..8bb7d5aff 100644 --- a/lib/sqlalchemy/mapping/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -10,7 +10,7 @@ import sqlalchemy.sql as sql import sqlalchemy.schema as schema from sqlalchemy.exceptions import * -"""contains the ClauseSynchronizer class which is used to map attributes between two objects +"""contains the ClauseSynchronizer class, which is used to map attributes between two objects in a manner corresponding to a SQL clause that compares column values.""" ONETOMANY = 0 diff --git a/lib/sqlalchemy/mapping/topological.py b/lib/sqlalchemy/orm/topological.py index 495eec8ce..89e760039 100644 --- a/lib/sqlalchemy/mapping/topological.py +++ b/lib/sqlalchemy/orm/topological.py @@ -102,7 +102,7 @@ class QueueDependencySorter(object): n.cycles = Set([n]) continue else: - raise CommitError("Self-referential dependency detected " + repr(t)) + raise FlushError("Self-referential dependency detected " + repr(t)) childnode = nodes[t[1]] parentnode = nodes[t[0]] self._add_edge(edges, (parentnode, childnode)) @@ -136,7 +136,7 @@ class QueueDependencySorter(object): continue else: # long cycles not allowed - raise CommitError("Circular dependency detected " + repr(edges) + repr(queue)) + raise FlushError("Circular dependency detected " + repr(edges) + repr(queue)) node = queue.pop() if not hasattr(node, '_cyclical'): output.append(node) @@ -328,7 +328,7 @@ class TreeDependencySorter(object): elif parentnode.is_descendant_of(childnode): # check for a line thats backwards with nodes in between, this is a # circular dependency (although confirmation on this would be helpful) - raise CommitError("Circular dependency detected") + raise FlushError("Circular dependency detected") elif not childnode.is_descendant_of(parentnode): # if relationship doesnt exist, connect nodes together root = childnode.get_sibling_ancestor(parentnode) diff --git a/lib/sqlalchemy/mapping/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 873bed548..3a2750edb 100644 --- a/lib/sqlalchemy/mapping/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -1,4 +1,4 @@ -# unitofwork.py +# orm/unitofwork.py # Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under @@ -34,46 +34,57 @@ class UOWProperty(attributes.SmartProperty): super(UOWProperty, self).__init__(*args, **kwargs) self.class_ = class_ property = property(lambda s:class_mapper(s.class_).props[s.key], doc="returns the MapperProperty object associated with this property") - -class UOWListElement(attributes.ListElement): + + +class UOWListElement(attributes.ListAttribute): """overrides ListElement to provide unit-of-work "dirty" hooks when list attributes are modified, plus specialzed append() method.""" - def __init__(self, obj, key, data=None, deleteremoved=False, **kwargs): - attributes.ListElement.__init__(self, obj, key, data=data, **kwargs) - self.deleteremoved = deleteremoved - def list_value_changed(self, obj, key, item, listval, isdelete): - sess = get_session(obj) - if not isdelete and sess.deleted.contains(item): - #raise InvalidRequestError("re-inserting a deleted value into a list") - del sess.deleted[item] - sess.modified_lists.append(self) - if self.deleteremoved and isdelete: - sess.register_deleted(item) + def __init__(self, obj, key, data=None, cascade=None, **kwargs): + 
attributes.ListAttribute.__init__(self, obj, key, data=data, **kwargs) + self.cascade = cascade + def do_value_changed(self, obj, key, item, listval, isdelete): + sess = object_session(obj) + if sess is not None: + sess._register_changed(obj) + if self.cascade is not None: + if not isdelete: + if self.cascade.save_update: + sess.save_or_update(item) def append(self, item, _mapper_nohistory = False): if _mapper_nohistory: self.append_nohistory(item) else: - attributes.ListElement.append(self, item) + attributes.ListAttribute.append(self, item) + +class UOWScalarElement(attributes.ScalarAttribute): + def __init__(self, obj, key, cascade=None, **kwargs): + attributes.ScalarAttribute.__init__(self, obj, key, **kwargs) + self.cascade=cascade + def do_value_changed(self, oldvalue, newvalue): + obj = self.obj + sess = object_session(obj) + if sess is not None: + sess._register_changed(obj) + if newvalue is not None and self.cascade is not None: + if self.cascade.save_update: + sess.save_or_update(newvalue) class UOWAttributeManager(attributes.AttributeManager): """overrides AttributeManager to provide unit-of-work "dirty" hooks when scalar attribues are modified, plus factory methods for UOWProperrty/UOWListElement.""" def __init__(self): attributes.AttributeManager.__init__(self) - def value_changed(self, obj, key, value): - if hasattr(obj, '_instance_key'): - get_session(obj).register_dirty(obj) - else: - get_session(obj).register_new(obj) - def create_prop(self, class_, key, uselist, callable_, **kwargs): return UOWProperty(class_, self, key, uselist, callable_, **kwargs) + def create_scalar(self, obj, key, **kwargs): + return UOWScalarElement(obj, key, **kwargs) + def create_list(self, obj, key, list_, **kwargs): return UOWListElement(obj, key, list_, **kwargs) class UnitOfWork(object): - """main UOW object which stores lists of dirty/new/deleted objects, as well as 'modified_lists' for list attributes. provides top-level "flush" functionality as well as the transaction boundaries with the SQLEngine(s) involved in a write operation.""" + """main UOW object which stores lists of dirty/new/deleted objects. 
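The UOWListElement/UOWScalarElement hooks above make attribute mutation drive the unit of work: do_value_changed registers the owner as changed with its session, and a configured 'save-update' cascade pulls newly assigned children in as well. Stripped to its essentials with an illustrative descriptor (not the real attributes package):

class DirtyTrackingScalar(object):
    """toy descriptor: mark the owner dirty on assignment, cascade new values."""
    def __init__(self, name, cascade_save_update=False):
        self.name = '_' + name
        self.cascade = cascade_save_update
    def __get__(self, obj, owner):
        return getattr(obj, self.name, None)
    def __set__(self, obj, value):
        setattr(obj, self.name, value)
        sess = getattr(obj, 'session', None)
        if sess is not None:
            sess.register_changed(obj)
            if self.cascade and value is not None:
                sess.save_or_update(value)

class RecordingSession(object):
    def __init__(self):
        self.events = []
    def register_changed(self, obj):
        self.events.append('dirty')
    def save_or_update(self, obj):
        self.events.append('save_or_update')

class Parent(object):
    child = DirtyTrackingScalar('child', cascade_save_update=True)

p = Parent()
p.session = RecordingSession()
p.child = object()
print(p.session.events)   # ['dirty', 'save_or_update']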
provides top-level "flush" functionality as well as the transaction boundaries with the SQLEngine(s) involved in a write operation.""" def __init__(self, identity_map=None): if identity_map is not None: self.identity_map = identity_map @@ -83,7 +94,7 @@ class UnitOfWork(object): self.attributes = global_attributes self.new = util.HashSet(ordered = True) self.dirty = util.HashSet() - self.modified_lists = util.HashSet() + self.deleted = util.HashSet() def get(self, class_, *id): @@ -97,14 +108,14 @@ class UnitOfWork(object): def _put(self, key, obj): self.identity_map[key] = obj - def refresh(self, obj): + def refresh(self, sess, obj): self.rollback_object(obj) - object_mapper(obj)._get(obj._instance_key, reload=True) + sess.query(obj.__class__)._get(obj._instance_key, reload=True) - def expire(self, obj): + def expire(self, sess, obj): self.rollback_object(obj) def exp(): - object_mapper(obj)._get(obj._instance_key, reload=True) + sess.query(obj.__class__)._get(obj._instance_key, reload=True) global_attributes.trigger_history(obj, exp) def is_expired(self, obj, unexpire=False): @@ -174,14 +185,18 @@ class UnitOfWork(object): self.attributes.commit(obj) def register_new(self, obj): + if hasattr(obj, '_instance_key'): + raise InvalidRequestError("Object '%s' already has an identity - it can't be registered as new" % repr(obj)) if not self.new.contains(obj): self.new.append(obj) + self.unregister_deleted(obj) def register_dirty(self, obj): if not self.dirty.contains(obj): self._validate_obj(obj) self.dirty.append(obj) - + self.unregister_deleted(obj) + def is_dirty(self, obj): if not self.dirty.contains(obj): return False @@ -192,10 +207,6 @@ class UnitOfWork(object): if not self.deleted.contains(obj): self._validate_obj(obj) self.deleted.append(obj) - mapper = object_mapper(obj) - # TODO: should the cascading delete dependency thing - # happen within PropertyLoader.process_dependencies ?
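The registration methods above keep the pending collections mutually consistent: registering an object under one role removes the conflicting "deleted" role. A minimal sketch of that bookkeeping, independent of SQLAlchemy (class and attribute names are illustrative only):

    class MiniUnitOfWork(object):
        """tracks pending objects the way UnitOfWork does above: gaining a
        new/dirty role removes the object from the conflicting deleted set."""
        def __init__(self):
            self.new = []          # insertion-ordered, like HashSet(ordered=True)
            self.dirty = set()
            self.deleted = set()
        def register_new(self, obj):
            if getattr(obj, '_instance_key', None) is not None:
                raise ValueError("object already has an identity")
            if obj not in self.new:
                self.new.append(obj)
            self.deleted.discard(obj)
        def register_dirty(self, obj):
            self.dirty.add(obj)
            self.deleted.discard(obj)
        def register_deleted(self, obj):
            self.deleted.add(obj)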
- mapper.register_deleted(obj, self) def unregister_deleted(self, obj): try: @@ -203,10 +214,10 @@ class UnitOfWork(object): except KeyError: pass - def flush(self, session, *objects): + def flush(self, session, objects=None, echo=False): flush_context = UOWTransaction(self, session) - if len(objects): + if objects is not None: objset = util.HashSet(iter=objects) else: objset = None @@ -217,42 +228,20 @@ class UnitOfWork(object): if self.deleted.contains(obj): continue flush_context.register_object(obj) - for item in self.modified_lists: - obj = item.obj - if objset is not None and not objset.contains(obj): - continue - if self.deleted.contains(obj): - continue - flush_context.register_object(obj, listonly = True) - flush_context.register_saved_history(item) - -# for o in item.added_items() + item.deleted_items(): -# if self.deleted.contains(o): -# continue -# flush_context.register_object(o, listonly=True) - + for obj in self.deleted: if objset is not None and not objset.contains(obj): continue flush_context.register_object(obj, isdelete=True) - - engines = util.HashSet() - for mapper in flush_context.mappers: - for e in session.engines(mapper): - engines.append(e) - - echo_commit = False - for e in engines: - echo_commit = echo_commit or e.echo_uow - e.begin() + + trans = session.create_transaction(autoflush=False) + flush_context.transaction = trans try: - flush_context.execute(echo=echo_commit) + flush_context.execute(echo=echo) + trans.commit() except: - for e in engines: - e.rollback() + trans.rollback() raise - for e in engines: - e.commit() flush_context.post_exec() @@ -279,8 +268,8 @@ class UOWTransaction(object): self.mappers = util.HashSet() self.dependencies = {} self.tasks = {} - self.saved_histories = util.HashSet() self.__modified = False + self.__is_executing = False def register_object(self, obj, isdelete = False, listonly = False, postupdate=False, **kwargs): """adds an object to this UOWTransaction to be updated in the database. @@ -292,6 +281,7 @@ class UOWTransaction(object): save/delete operation on the object itself, unless an additional save/delete registration is entered for the object.""" #print "REGISTER", repr(obj), repr(getattr(obj, '_instance_key', None)), str(isdelete), str(listonly) + # things can get really confusing if theres duplicate instances floating around, # so make sure everything is OK self.uow._validate_obj(obj) @@ -299,25 +289,31 @@ class UOWTransaction(object): mapper = object_mapper(obj) self.mappers.append(mapper) task = self.get_task_by_mapper(mapper) - + if postupdate: mod = task.append_postupdate(obj) - self.__modified = self.__modified or mod + if mod: self._mark_modified() return - + # for a cyclical task, things need to be sorted out already, # so this object should have already been added to the appropriate sub-task # can put an assertion here to make sure.... 
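The rewritten flush() above drops the old per-engine begin/echo/commit bookkeeping in favor of a single session transaction that brackets the whole operation. Its control flow, condensed into a sketch (not the actual method; create_transaction is the session API shown in the diff):

    def flush_sketch(session, flush_context):
        # one transaction brackets the whole flush; autoflush is disabled so
        # the flush itself cannot recursively trigger another flush
        trans = session.create_transaction(autoflush=False)
        flush_context.transaction = trans
        try:
            flush_context.execute()
            trans.commit()
        except:
            trans.rollback()   # any error undoes every statement the flush issued
            raise
        flush_context.post_exec()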
if task.circular: return - + mod = task.append(obj, listonly, isdelete=isdelete, **kwargs) - self.__modified = self.__modified or mod + if mod: self._mark_modified() def unregister_object(self, obj): mapper = object_mapper(obj) task = self.get_task_by_mapper(mapper) - task.delete(obj) + if obj in task.objects: + task.delete(obj) + self._mark_modified() + + def _mark_modified(self): + #if self.__is_executing: + # raise "test assertion failed" self.__modified = True def get_task_by_mapper(self, mapper): @@ -329,16 +325,18 @@ try: return self.tasks[mapper] except KeyError: - return UOWTask(self, mapper) + task = UOWTask(self, mapper) + task.mapper.register_dependencies(self) + return task def register_dependency(self, mapper, dependency): """called by mapper.PropertyLoader to register the objects handled by one mapper being dependent on the objects handled by another.""" # correct for primary mapper (the mapper officially associated with the class) self.dependencies[(mapper._primary_mapper(), dependency._primary_mapper())] = True - self.__modified = True + self._mark_modified() - def register_processor(self, mapper, processor, mapperfrom, isdeletefrom): + def register_processor(self, mapper, processor, mapperfrom): """called by mapper.PropertyLoader to register itself as a "processor", which will be associated with a particular UOWTask, and be given a list of "dependent" objects corresponding to another UOWTask to be processed, either after that secondary @@ -346,23 +344,38 @@ class UOWTransaction(object): # when the task from "mapper" executes, take the objects from the task corresponding # to "mapperfrom"'s list of save/delete objects, and send them to "processor" # for dependency processing - #print "registerprocessor", str(mapper), repr(processor.key), str(mapperfrom), repr(isdeletefrom) + #print "registerprocessor", str(mapper), repr(processor.key), str(mapperfrom) # correct for primary mapper (the mapper officially associated with the class) mapper = mapper._primary_mapper() mapperfrom = mapperfrom._primary_mapper() task = self.get_task_by_mapper(mapper) targettask = self.get_task_by_mapper(mapperfrom) - task.dependencies.append(UOWDependencyProcessor(processor, targettask, isdeletefrom)) - self.__modified = True - - def register_saved_history(self, listobj): - self.saved_histories.append(listobj) + up = UOWDependencyProcessor(processor, targettask) + task.dependencies.append(up) + self._mark_modified() def execute(self, echo=False): - for task in self.tasks.values(): - task.mapper.register_dependencies(self) - + # pre-execute dependency processors. this process may + # result in new tasks, objects and/or dependency processors being added, + # particularly with 'delete-orphan' cascade rules. + # keep running through the full list of tasks until all + # objects have been processed. + while True: + ret = False + for task in self.tasks.values(): + for up in task.dependencies: + if up.preexecute(self): + ret = True + if not ret: + break + + # flip the execution flag on. in some test cases + # we like to check this flag against any new objects being added, since everything + # should be registered by now. there is a slight exception in the case of + # post_update requests; this should be fixed.
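The new pre-execution phase in execute() above is a fixpoint loop: every UOWDependencyProcessor is run until one complete pass over all tasks reports no new work. The bare pattern, extracted (illustrative; 'processors' stands in for the task/dependency iteration):

    def run_until_stable(processors):
        """call each processor repeatedly; a processor returns True when it
        added work (new tasks/objects), so a fully quiet pass means convergence."""
        while True:
            changed = False
            for proc in processors:
                if proc():
                    changed = True
            if not changed:
                return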
+ self.__is_executing = True + head = self._sort_dependencies() self.__modified = False if LOG or echo: @@ -372,11 +385,10 @@ class UOWTransaction(object): print "Task dump:\n" + head.dump() if head is not None: head.execute(self) + #if self.__modified and head is not None: + # raise "Assertion failed ! new pre-execute dependency step should eliminate post-execute changes (except post_update stuff)." if LOG or echo: - if self.__modified and head is not None: - print "\nAfter Execute:\n" + head.dump() - else: - print "\nExecute complete (no post-exec changes)\n" + print "\nExecute complete\n" def post_exec(self): """after an execute/flush is completed, all of the objects and lists that have @@ -388,18 +400,6 @@ class UOWTransaction(object): self.uow._remove_deleted(elem.obj) else: self.uow.register_clean(elem.obj) - - for obj in self.saved_histories: - try: - obj.commit() - del self.uow.modified_lists[obj] - except KeyError: - pass - - # this assertion only applies to a full flush(), not a - # partial one - #if len(self.uow.new) > 0 or len(self.uow.dirty) >0 or len(self.uow.modified_lists) > 0: - # raise "assertion failed" def _sort_dependencies(self): """creates a hierarchical tree of dependent tasks. the root node is returned. @@ -422,6 +422,7 @@ class UOWTransaction(object): mappers = util.HashSet() for task in self.tasks.values(): mappers.append(task.mapper) + head = DependencySorter(self.dependencies, mappers).sort(allow_all_cycles=True) #print str(head) task = sort_hier(head) @@ -432,31 +433,77 @@ class UOWTaskElement(object): """an element within a UOWTask. corresponds to a single object instance to be saved, deleted, or just part of the transaction as a placeholder for further dependencies (i.e. 'listonly'). - in the case of self-referential mappers, may also store a "childtask", which is a - UOWTask containing objects dependent on this element's object instance.""" + in the case of self-referential mappers, may also store a list of childtasks, + further UOWTasks containing objects dependent on this element's object instance.""" def __init__(self, obj): self.obj = obj - self.listonly = True + self.__listonly = True self.childtasks = [] - self.isdelete = False - self.mapper = None + self.__isdelete = False + self.__preprocessed = {} + def _get_listonly(self): + return self.__listonly + def _set_listonly(self, value): + """set_listonly is a one-way setter, will only go from True to False.""" + if not value and self.__listonly: + self.__listonly = False + self.clear_preprocessed() + def _get_isdelete(self): + return self.__isdelete + def _set_isdelete(self, value): + if self.__isdelete is not value: + self.__isdelete = value + self.clear_preprocessed() + listonly = property(_get_listonly, _set_listonly) + isdelete = property(_get_isdelete, _set_isdelete) + + def mark_preprocessed(self, processor): + """marks this element as "preprocessed" by a particular UOWDependencyProcessor. preprocessing is the step + which sweeps through all the relationships on all the objects in the flush transaction and adds other objects + which are also affected, In some cases it can switch an object from "tosave" to "todelete". changes to the state + of this UOWTaskElement will reset all "preprocessed" flags, causing it to be preprocessed again. 
When all UOWTaskElements + have been fully preprocessed by all UOWDependencyProcessors, then the topological sort can be done.""" + self.__preprocessed[processor] = True + def is_preprocessed(self, processor): + return self.__preprocessed.get(processor, False) + def clear_preprocessed(self): + self.__preprocessed.clear() def __repr__(self): return "UOWTaskElement/%d: %s/%d %s" % (id(self), self.obj.__class__.__name__, id(self.obj), (self.listonly and 'listonly' or (self.isdelete and 'delete' or 'save')) ) class UOWDependencyProcessor(object): """in between the saving and deleting of objects, process "dependent" data, such as filling in a foreign key on a child item from a new primary key, or deleting association rows before a - delete.""" + delete. This object acts as a proxy to a DependencyProcessor.""" + def __init__(self, processor, targettask): self.processor = processor self.targettask = targettask - self.isdeletefrom = isdeletefrom - + + def preexecute(self, trans): + """traverses all objects handled by this dependency processor and locates additional objects which should be + part of the transaction, such as those affected by deletes, orphans to be deleted, etc. Returns True if any + objects were preprocessed, or False if no objects were preprocessed.""" + def getobj(elem): + elem.mark_preprocessed(self) + return elem.obj + + ret = False + elements = [getobj(elem) for elem in self.targettask.tosave_elements if elem.obj is not None and not elem.is_preprocessed(self)] + if len(elements): + ret = True + self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=False) + + elements = [getobj(elem) for elem in self.targettask.todelete_elements if elem.obj is not None and not elem.is_preprocessed(self)] + if len(elements): + ret = True + self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=True) + return ret + def execute(self, trans, delete): if not delete: - self.processor.process_dependencies(self.targettask, [elem.obj for elem in self.targettask.tosave_elements() if elem.obj is not None], trans, delete = delete) + self.processor.process_dependencies(self.targettask, [elem.obj for elem in self.targettask.tosave_elements if elem.obj is not None], trans, delete=False) else: - self.processor.process_dependencies(self.targettask, [elem.obj for elem in self.targettask.todelete_elements() if elem.obj is not None], trans, delete = delete) + self.processor.process_dependencies(self.targettask, [elem.obj for elem in self.targettask.todelete_elements if elem.obj is not None], trans, delete=True) def get_object_dependencies(self, obj, trans, passive): return self.processor.get_object_dependencies(obj, trans, passive=passive) @@ -465,7 +512,7 @@ class UOWDependencyProcessor(object): return self.processor.whose_dependent_on_who(obj, o) def branch(self, task): - return UOWDependencyProcessor(self.processor, task, self.isdeletefrom) + return UOWDependencyProcessor(self.processor, task) class UOWTask(object): def __init__(self, uowtransaction, mapper): @@ -477,13 +524,11 @@ class UOWTask(object): self.dependencies = [] self.cyclical_dependencies = [] self.circular = None - self.postcircular = None self.childtasks = [] -# print "NEW TASK", repr(self) def is_empty(self): return len(self.objects) == 0 and len(self.dependencies) == 0 and len(self.childtasks) == 0 - + def append(self, obj, listonly = False, childtask = None, isdelete = False): """appends an object to this task, to be either saved or
deleted depending on the 'isdelete' attribute of this UOWTask. 'listonly' indicates that the object should @@ -504,6 +549,8 @@ class UOWTask(object): rec.childtasks.append(childtask) if isdelete: rec.isdelete = True + #if not childtask: + # rec.preprocessed = False return retval def append_postupdate(self, obj): @@ -524,41 +571,29 @@ class UOWTask(object): self.circular.execute(trans) return - self.mapper.save_obj(self.tosave_objects(), trans) - for dep in self.cyclical_save_dependencies(): - dep.execute(trans, delete=False) - for element in self.tosave_elements(): + self.mapper.save_obj(self.tosave_objects, trans) + for dep in self.cyclical_dependencies: + dep.execute(trans, False) + for element in self.tosave_elements: for task in element.childtasks: task.execute(trans) - for dep in self.save_dependencies(): - dep.execute(trans, delete=False) - for dep in self.delete_dependencies(): - dep.execute(trans, delete=True) - for dep in self.cyclical_delete_dependencies(): - dep.execute(trans, delete=True) + for dep in self.dependencies: + dep.execute(trans, False) + for dep in self.dependencies: + dep.execute(trans, True) + for dep in self.cyclical_dependencies: + dep.execute(trans, True) for child in self.childtasks: child.execute(trans) - for element in self.todelete_elements(): + for element in self.todelete_elements: for task in element.childtasks: task.execute(trans) - self.mapper.delete_obj(self.todelete_objects(), trans) - - def tosave_elements(self): - return [rec for rec in self.objects.values() if not rec.isdelete] - def todelete_elements(self): - return [rec for rec in self.objects.values() if rec.isdelete] - def tosave_objects(self): - return [rec.obj for rec in self.objects.values() if rec.obj is not None and not rec.listonly and rec.isdelete is False] - def todelete_objects(self): - return [rec.obj for rec in self.objects.values() if rec.obj is not None and not rec.listonly and rec.isdelete is True] - def save_dependencies(self): - return [dep for dep in self.dependencies if not dep.isdeletefrom] - def cyclical_save_dependencies(self): - return [dep for dep in self.cyclical_dependencies if not dep.isdeletefrom] - def delete_dependencies(self): - return [dep for dep in self.dependencies if dep.isdeletefrom] - def cyclical_delete_dependencies(self): - return [dep for dep in self.cyclical_dependencies if dep.isdeletefrom] + self.mapper.delete_obj(self.todelete_objects, trans) + + tosave_elements = property(lambda self: [rec for rec in self.objects.values() if not rec.isdelete]) + todelete_elements = property(lambda self:[rec for rec in self.objects.values() if rec.isdelete]) + tosave_objects = property(lambda self:[rec.obj for rec in self.objects.values() if rec.obj is not None and not rec.listonly and rec.isdelete is False]) + todelete_objects = property(lambda self:[rec.obj for rec in self.objects.values() if rec.obj is not None and not rec.listonly and rec.isdelete is True]) def _sort_circular_dependencies(self, trans, cycles): """for a single task, creates a hierarchical tree of "subtasks" which associate @@ -567,30 +602,30 @@ class UOWTask(object): of its object list contain dependencies on each other. this is not the normal case; this logic only kicks in when something like - a hierarchical tree is being represented.""" + a hierarchical tree is being represented. 
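UOWTask.execute() above fixes the ordering that foreign keys demand: parent rows are inserted/updated before dependency processors wire up child rows, and dependent rows are processed before the final deletes. Reduced to a sketch (attribute names mirror the diff; this is not the real method):

    def execute_order_sketch(task, trans):
        task.mapper.save_obj(task.tosave_objects, trans)      # INSERTs/UPDATEs first
        for dep in task.dependencies:
            dep.execute(trans, False)                         # fill in FKs, assoc rows
        for dep in task.dependencies:
            dep.execute(trans, True)                          # clear rows pointing at deletes
        task.mapper.delete_obj(task.todelete_objects, trans)  # DELETEs last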
+ + """ allobjects = [] for task in cycles: allobjects += task.objects.keys() tuples = [] - objecttotask = {} - cycles = Set(cycles) + #print "BEGIN CIRC SORT-------" + #print "PRE-CIRC:" + #print list(cycles)[0].dump() + # dependency processors that arent part of the cyclical thing # get put here extradeplist = [] - def get_object_task(parent, obj): - try: - return objecttotask[obj] - except KeyError: - t = UOWTask(None, parent.mapper) - t.parent = parent - objecttotask[obj] = t - return t - + object_to_original_task = {} + + # organizes a set of new UOWTasks that will be assembled into + # the final tree, for the purposes of holding new UOWDependencyProcessors + # which process small sub-sections of dependent parent/child operations dependencies = {} def get_dependency_task(obj, depprocessor): try: @@ -605,10 +640,7 @@ class UOWTask(object): dp[depprocessor] = l return l - # work out a list of all the "dependency processors" that - # represent objects that have to be dependency sorted at the - # per-object level. all other dependency processors go in - # "extradep." + # organize all original UOWDependencyProcessors by their target task deps_by_targettask = {} for task in cycles: for dep in task.dependencies: @@ -620,53 +652,38 @@ class UOWTask(object): for task in cycles: for taskelement in task.objects.values(): obj = taskelement.obj + object_to_original_task[obj] = task #print "OBJ", repr(obj), "TASK", repr(task) - # create a placeholder UOWTask that may be built into the final - # task tree - get_object_task(task, obj) for dep in deps_by_targettask.get(task, []): - (processor, targettask, isdelete) = (dep.processor, dep.targettask, dep.isdeletefrom) - if taskelement.isdelete is not dep.isdeletefrom: + # is this dependency involved in one of the cycles ? + cyclicaldep = dep.targettask in cycles and trans.get_task_by_mapper(dep.processor.mapper) in cycles + if not cyclicaldep: continue - #print "GETING LIST OFF PROC", processor.key, "OBJ", repr(obj) - - # traverse through the modified child items of each object. normally this - # is done via PropertyLoader in properties.py, but we need all the info - # up front here to do the object-level topological sort. + + (processor, targettask) = (dep.processor, dep.targettask) + isdelete = taskelement.isdelete # list of dependent objects from this object childlist = dep.get_object_dependencies(obj, trans, passive = True) # the task corresponding to the processor's objects childtask = trans.get_task_by_mapper(processor.mapper) - # is this dependency involved in one of the cycles ? 
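The 'tuples' list assembled above is the input to a per-object topological sort: ordered (parent, child) pairs between individual instances. A self-contained version of such a sort (illustrative; the real sorter in orm/topological.py additionally tolerates permitted cycles):

    def toposort(pairs, items):
        """orders 'items' so that for every (parent, child) pair the parent
        comes first; raises on a cycle, cf. FlushError in topological.py."""
        children = {}
        indegree = dict((i, 0) for i in items)
        for parent, child in pairs:
            children.setdefault(parent, []).append(child)
            indegree[child] += 1
        ready = [i for i in items if indegree[i] == 0]
        output = []
        while ready:
            node = ready.pop()
            output.append(node)
            for child in children.get(node, []):
                indegree[child] -= 1
                if indegree[child] == 0:
                    ready.append(child)
        if len(output) != len(items):
            raise ValueError("circular dependency detected")
        return output

    # toposort([('user', 'address')], ['user', 'address']) -> ['user', 'address']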
- cyclicaldep = dep.targettask in cycles and trans.get_task_by_mapper(dep.processor.mapper) in cycles - if isdelete: - childlist = childlist.unchanged_items() + childlist.deleted_items() - else: - childlist = childlist.added_items() + +# if isdelete: +# childlist = childlist.unchanged_items() + childlist.deleted_items() +# else: +# childlist = childlist.added_items() + + childlist = childlist.added_items() + childlist.unchanged_items() + childlist.deleted_items() for o in childlist: - if o is None: - # this can be None due to the many-to-one dependency processor added - # for deleted items, line 385 properties.py + if o is None or o not in childtask.objects: continue - if not o in childtask.objects: - # item needs to be saved since its added, or attached to a deleted object - childtask.append(o, isdelete=isdelete and dep.processor.private) - if cyclicaldep: - # cyclical, so create a placeholder UOWTask that may be built into the - # final task tree - t = get_object_task(childtask, o) - if not cyclicaldep: - # not cyclical, so we are done with this - continue - # cyclical, so create an ordered pair for the dependency sort whosdep = dep.whose_dependent_on_who(obj, o) if whosdep is not None: tuples.append(whosdep) - # then locate a UOWDependencyProcessor to add the object onto, which - # will handle the modifications between saves/deletes + # create a UOWDependencyProcessor representing this pair of objects. + # append it to a UOWTask if whosdep[0] is obj: get_dependency_task(whosdep[0], dep).append(whosdep[0], isdelete=isdelete) else: @@ -680,23 +697,38 @@ class UOWTask(object): #print str(head) + hierarchical_tasks = {} + def get_object_task(obj): + try: + return hierarchical_tasks[obj] + except KeyError: + originating_task = object_to_original_task[obj] + return hierarchical_tasks.setdefault(obj, UOWTask(None, originating_task.mapper)) + def make_task_tree(node, parenttask): """takes a dependency-sorted tree of objects and creates a tree of UOWTasks""" - t = objecttotask[node.item] + #print "MAKETASKTREE", node.item + + t = get_object_task(node.item) + for n in node.children: + t2 = make_task_tree(n, t) + can_add_to_parent = t.mapper is parenttask.mapper - if can_add_to_parent: - parenttask.append(node.item, t.parent.objects[node.item].listonly, isdelete=t.parent.objects[node.item].isdelete, childtask=t) - else: - t.append(node.item, t.parent.objects[node.item].listonly, isdelete=t.parent.objects[node.item].isdelete) - parenttask.append(None, listonly=False, isdelete=t.parent.objects[node.item].isdelete, childtask=t) + original_task = object_to_original_task[node.item] + if original_task.objects.has_key(node.item): + if can_add_to_parent: + parenttask.append(node.item, original_task.objects[node.item].listonly, isdelete=original_task.objects[node.item].isdelete, childtask=t) + else: + t.append(node.item, original_task.objects[node.item].listonly, isdelete=original_task.objects[node.item].isdelete) + parenttask.append(None, listonly=False, isdelete=original_task.objects[node.item].isdelete, childtask=t) + #else: + # parenttask.append(None, listonly=False, isdelete=original_task.objects[node.item].isdelete, childtask=t) if dependencies.has_key(node.item): for depprocessor, deptask in dependencies[node.item].iteritems(): if can_add_to_parent: parenttask.cyclical_dependencies.append(depprocessor.branch(deptask)) else: t.cyclical_dependencies.append(depprocessor.branch(deptask)) - for n in node.children: - t2 = make_task_tree(n, t) return t # this is the new "circular" UOWTask which will 
execute in place of "self" @@ -707,6 +739,7 @@ class UOWTask(object): t.dependencies += [d for d in extradeplist] t.childtasks = self.childtasks make_task_tree(head, t) + #print t.dump() return t def dump(self): @@ -729,44 +762,43 @@ class UOWTask(object): buf.write(text) headers[text] = True - def _dump_processor(proc): - if proc.isdeletefrom: + def _dump_processor(proc, deletes): + if deletes: val = [t for t in proc.targettask.objects.values() if t.isdelete] else: val = [t for t in proc.targettask.objects.values() if not t.isdelete] - buf.write(_indent() + " |- UOWDependencyProcessor(%d) %s attribute on %s (%s)\n" % ( - id(proc), + buf.write(_indent() + " |- %s attribute on %s (UOWDependencyProcessor(%d) processing %s)\n" % ( repr(proc.processor.key), - (proc.isdeletefrom and - "%s's to be deleted" % _repr_task_class(proc.targettask) - or "saved %s's" % _repr_task_class(proc.targettask)), + ("%s's to be %s" % (_repr_task_class(proc.targettask), deletes and "deleted" or "saved")), + id(proc), _repr_task(proc.targettask)) ) if len(val) == 0: buf.write(_indent() + " | |-" + "(no objects)\n") for v in val: - buf.write(_indent() + " | |-" + _repr_task_element(v) + "\n") + buf.write(_indent() + " | |-" + _repr_task_element(v, proc.processor.key) + "\n") - def _repr_task_element(te): + def _repr_task_element(te, attribute=None): if te.obj is None: objid = "(placeholder)" else: - objid = "%s(%d)" % (te.obj.__class__.__name__, id(te.obj)) - return "UOWTaskElement(%d): %s %s%s" % (id(te), objid, (te.listonly and '(listonly)' or (te.isdelete and '(delete' or '(save')), - (te.mapper is not None and " w/ " + str(te.mapper) + ")" or ")") - ) + if attribute is not None: + objid = "%s(%d).%s" % (te.obj.__class__.__name__, id(te.obj), attribute) + else: + objid = "%s(%d)" % (te.obj.__class__.__name__, id(te.obj)) + return "%s (UOWTaskElement(%d, %s))" % (objid, id(te), (te.listonly and 'listonly' or (te.isdelete and 'delete' or 'save'))) def _repr_task(task): if task.mapper is not None: if task.mapper.__class__.__name__ == 'Mapper': - name = task.mapper.class_.__name__ + "/" + str(task.mapper.primarytable) + "/" + str(id(task.mapper)) + name = task.mapper.class_.__name__ + "/" + task.mapper.local_table.name + "/" + str(task.mapper.entity_name) else: name = repr(task.mapper) else: name = '(none)' - return ("UOWTask(%d) '%s'" % (id(task), name)) + return ("UOWTask(%d, %s)" % (id(task), name)) def _repr_task_class(task): if task.mapper is not None and task.mapper.__class__.__name__ == 'Mapper': return task.mapper.class_.__name__ @@ -790,51 +822,55 @@ class UOWTask(object): buf.write(i + " " + _repr_task(self)) buf.write("\n") - for rec in self.tosave_elements(): + for rec in self.tosave_elements: if rec.listonly: continue header(buf, _indent() + " |- Save elements\n") - buf.write(_indent() + " |- Save: " + _repr_task_element(rec) + "\n") - for dep in self.cyclical_save_dependencies(): + buf.write(_indent() + " |- " + _repr_task_element(rec) + "\n") + for dep in self.cyclical_dependencies: header(buf, _indent() + " |- Cyclical Save dependencies\n") - _dump_processor(dep) - for element in self.tosave_elements(): + _dump_processor(dep, False) + for element in self.tosave_elements: for task in element.childtasks: header(buf, _indent() + " |- Save subelements of UOWTaskElement(%s)\n" % id(element)) task._dump(buf, indent + 1) - for dep in self.save_dependencies(): + for dep in self.dependencies: header(buf, _indent() + " |- Save dependencies\n") - _dump_processor(dep) - for dep in self.delete_dependencies(): + 
_dump_processor(dep, False) + for dep in self.dependencies: header(buf, _indent() + " |- Delete dependencies\n") - _dump_processor(dep) - for dep in self.cyclical_delete_dependencies(): + _dump_processor(dep, True) + for dep in self.cyclical_dependencies: header(buf, _indent() + " |- Cyclical Delete dependencies\n") - _dump_processor(dep) + _dump_processor(dep, True) for child in self.childtasks: header(buf, _indent() + " |- Child tasks\n") child._dump(buf, indent + 1) # for obj in self.postupdate: # header(buf, _indent() + " |- Post Update objects\n") # buf.write(_repr(obj) + "\n") - for element in self.todelete_elements(): + for element in self.todelete_elements: for task in element.childtasks: header(buf, _indent() + " |- Delete subelements of UOWTaskElement(%s)\n" % id(element)) task._dump(buf, indent + 1) - for rec in self.todelete_elements(): + for rec in self.todelete_elements: if rec.listonly: continue header(buf, _indent() + " |- Delete elements\n") - buf.write(_indent() + " |- Delete: " + _repr_task_element(rec) + "\n") + buf.write(_indent() + " |- " + _repr_task_element(rec) + "\n") - buf.write(_indent() + " |----\n") + if self.is_empty(): + buf.write(_indent() + " |- (empty task)\n") + else: + buf.write(_indent() + " |----\n") + buf.write(_indent() + "\n") def __repr__(self): if self.mapper is not None: if self.mapper.__class__.__name__ == 'Mapper': - name = self.mapper.class_.__name__ + "/" + self.mapper.primarytable.name + name = self.mapper.class_.__name__ + "/" + self.mapper.local_table.name else: name = repr(self.mapper) else: diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py new file mode 100644 index 000000000..19cb21367 --- /dev/null +++ b/lib/sqlalchemy/orm/util.py @@ -0,0 +1,55 @@ +# mapper/util.py +# Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import sets +import sqlalchemy.sql as sql + +class CascadeOptions(object): + """keeps track of the options sent to relation().cascade""" + def __init__(self, arg=""): + values = sets.Set([c.strip() for c in arg.split(',')]) + self.delete_orphan = "delete-orphan" in values + self.delete = "delete" in values or self.delete_orphan or "all" in values + self.save_update = "save-update" in values or "all" in values + self.merge = "merge" in values or "all" in values + self.expunge = "expunge" in values or "all" in values + self.refresh_expire = "refresh-expire" in values or "all" in values + def __contains__(self, item): + return getattr(self, item.replace("-", "_"), False) + + +def polymorphic_union(table_map, typecolname, aliasname='p_union'): + colnames = sets.Set() + colnamemaps = {} + + for key in table_map.keys(): + table = table_map[key] + + # mysql doesnt like selecting from a select; make it an alias of the select + if isinstance(table, sql.Select): + table = table.alias() + table_map[key] = table + + m = {} + for c in table.c: + colnames.add(c.name) + m[c.name] = c + colnamemaps[table] = m + + def col(name, table): + try: + return colnamemaps[table][name] + except KeyError: + return sql.null().label(name) + + result = [] + for type, table in table_map.iteritems(): + if typecolname is not None: + result.append(sql.select([col(name, table) for name in colnames] + [sql.column("'%s'" % type).label(typecolname)], from_obj=[table])) + else: + result.append(sql.select([col(name, table) for name in colnames], from_obj=[table])) + return 
sql.union_all(*result).alias(aliasname) +
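polymorphic_union above pads each per-type SELECT out to the union of all column names, emitting a NULL label for columns a given table lacks and appending a literal discriminator column. The same row-shaping, demonstrated on plain dictionaries (a rough analogue, not the SQL construct itself):

    def pad_rows(table_map):
        """give every row the union of all column names, None where a source
        lacks the column, plus a 'type' discriminator -- the shape the
        generated UNION ALL produces."""
        colnames = set()
        for rows in table_map.values():
            for row in rows:
                colnames.update(row)
        result = []
        for typ, rows in table_map.items():
            for row in rows:
                padded = dict((name, row.get(name)) for name in colnames)
                padded['type'] = typ
                result.append(padded)
        return result

    # pad_rows({'engineer': [{'id': 1, 'lang': 'python'}],
    #           'manager':  [{'id': 2, 'reports': 3}]})
    # -> every row carries id, lang and reports (None where absent) plus its 'type'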
\ No newline at end of file diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py index 1603c9873..4452f6419 100644 --- a/lib/sqlalchemy/pool.py +++ b/lib/sqlalchemy/pool.py @@ -70,41 +70,38 @@ def clear_managers(): class Pool(object): - def __init__(self, echo = False, use_threadlocal = True, logger=None): - self._threadconns = weakref.WeakValueDictionary() + def __init__(self, echo = False, use_threadlocal = True, logger=None, **kwargs): + self._threadconns = {} #weakref.WeakValueDictionary() self._use_threadlocal = use_threadlocal - self._echo = echo + self.echo = echo self._logger = logger or util.Logger(origin='pool') def unique_connection(self): - return ConnectionFairy(self) + return ConnectionFairy(self).checkout() def connect(self): if not self._use_threadlocal: - return ConnectionFairy(self) + return ConnectionFairy(self).checkout() try: - return self._threadconns[thread.get_ident()] + return self._threadconns[thread.get_ident()].checkout() except KeyError: - agent = ConnectionFairy(self) + agent = ConnectionFairy(self).checkout() self._threadconns[thread.get_ident()] = agent return agent - def return_conn(self, conn): - if self._echo: - self.log("return connection to pool") - self.do_return_conn(conn) + def return_conn(self, agent): + if self._use_threadlocal: + try: + del self._threadconns[thread.get_ident()] + except KeyError: + pass + self.do_return_conn(agent.connection) def get(self): - if self._echo: - self.log("get connection from pool") - self.log(self.status()) return self.do_get() def return_invalid(self): - if self._echo: - self.log("return invalid connection to pool") - self.log(self.status()) self.do_return_invalid() def do_get(self): @@ -125,6 +122,7 @@ class Pool(object): class ConnectionFairy(object): def __init__(self, pool, connection=None): self.pool = pool + self.__counter = 0 if connection is not None: self.connection = connection else: @@ -134,16 +132,35 @@ class ConnectionFairy(object): self.connection = None self.pool.return_invalid() raise + if self.pool.echo: + self.pool.log("Connection %s checked out from pool" % repr(self.connection)) def invalidate(self): + if self.pool.echo: + self.pool.log("Invalidate connection %s" % repr(self.connection)) + self.connection.rollback() self.connection = None self.pool.return_invalid() def cursor(self): return CursorFairy(self, self.connection.cursor()) def __getattr__(self, key): return getattr(self.connection, key) + def checkout(self): + if self.connection is None: + raise "this connection is closed" + self.__counter +=1 + return self + def close(self): + self.__counter -=1 + if self.__counter == 0: + self._close() def __del__(self): + self._close() + def _close(self): if self.connection is not None: - self.pool.return_conn(self.connection) + if self.pool.echo: + self.pool.log("Connection %s being returned to pool" % repr(self.connection)) + self.connection.rollback() + self.pool.return_conn(self) self.pool = None self.connection = None @@ -156,19 +173,18 @@ class CursorFairy(object): class SingletonThreadPool(Pool): """Maintains one connection per each thread, never moving to another thread. 
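ConnectionFairy above now reference-counts its checkouts: connect() may hand the same thread-local fairy out several times, and only the close() balancing the first checkout actually rolls back and returns the DBAPI connection to the pool. The counting scheme in isolation (illustrative names):

    class RefCountedHandle(object):
        def __init__(self, resource, on_release):
            self.resource = resource
            self.on_release = on_release   # called once, when the last user closes
            self._count = 0
        def checkout(self):
            if self.resource is None:
                raise Exception("this handle is closed")
            self._count += 1
            return self
        def close(self):
            self._count -= 1
            if self._count == 0:           # last balancing close: really release
                self.on_release(self.resource)
                self.resource = None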
this is - used for SQLite and other databases with a similar restriction.""" + used for SQLite.""" def __init__(self, creator, **params): Pool.__init__(self, **params) self._conns = {} self._creator = creator def status(self): - return "SingletonThreadPool thread:%d size: %d" % (thread.get_ident(), len(self._conns)) + return "SingletonThreadPool id:%d thread:%d size: %d" % (id(self), thread.get_ident(), len(self._conns)) def do_return_conn(self, conn): - if self._conns.get(thread.get_ident(), None) is None: - self._conns[thread.get_ident()] = conn - + pass + def do_return_invalid(self): try: del self._conns[thread.get_ident()] @@ -177,54 +193,48 @@ class SingletonThreadPool(Pool): def do_get(self): try: - c = self._conns[thread.get_ident()] - if c is None: - return self._creator() + return self._conns[thread.get_ident()] except KeyError: c = self._creator() - self._conns[thread.get_ident()] = None - return c + self._conns[thread.get_ident()] = c + return c class QueuePool(Pool): """uses Queue.Queue to maintain a fixed-size list of connections.""" - def __init__(self, creator, pool_size = 5, max_overflow = 10, **params): + def __init__(self, creator, pool_size = 5, max_overflow = 10, timeout=30, **params): Pool.__init__(self, **params) self._creator = creator self._pool = Queue.Queue(pool_size) self._overflow = 0 - pool_size self._max_overflow = max_overflow + self._timeout = timeout def do_return_conn(self, conn): - if self._echo: - self.log("return QP connection to pool") try: self._pool.put(conn, False) except Queue.Full: self._overflow -= 1 def do_return_invalid(self): - if self._echo: - self.log("return invalid connection") if self._pool.full(): self._overflow -= 1 def do_get(self): - if self._echo: - self.log("get QP connection from pool") - self.log(self.status()) try: - return self._pool.get(self._max_overflow > -1 and self._overflow >= self._max_overflow) + return self._pool.get(self._max_overflow > -1 and self._overflow >= self._max_overflow, self._timeout) except Queue.Empty: self._overflow += 1 return self._creator() - def __del__(self): + def dispose(self): while True: try: conn = self._pool.get(False) conn.close() except Queue.Empty: break + def __del__(self): + self.dispose() def status(self): tup = (self.size(), self.checkedin(), self.overflow(), self.checkedout()) diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index c88249011..8b0c9f0b3 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -14,26 +14,15 @@ structure with its own clause-specific objects as well as the visitor interface, the schema package "plugs in" to the SQL package. 
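The QueuePool changes above add a checkout timeout to the existing overflow accounting: get() blocks only once pool_size plus max_overflow connections exist, and returning a connection to a full queue simply discards the surplus. A simplified, non-thread-exact sketch ('queue' is Python 3's spelling of the Queue module):

    import queue

    class MiniQueuePool(object):
        def __init__(self, creator, pool_size=5, max_overflow=10, timeout=30):
            self._creator = creator
            self._pool = queue.Queue(pool_size)
            self._overflow = 0 - pool_size   # reaches 0 once pool_size conns exist
            self._max_overflow = max_overflow
            self._timeout = timeout
        def get(self):
            # block (up to timeout) only after pool_size + max_overflow is reached
            block = self._max_overflow > -1 and self._overflow >= self._max_overflow
            try:
                return self._pool.get(block, self._timeout)
            except queue.Empty:
                self._overflow += 1
                return self._creator()
        def put(self, conn):
            try:
                self._pool.put(conn, False)
            except queue.Full:
                self._overflow -= 1          # surplus connection is simply dropped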
""" -import sql -from util import * -from types import * -from exceptions import * +from sqlalchemy import sql, types, exceptions,util +import sqlalchemy import copy, re, string __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', - 'SchemaEngine', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault'] - -class SchemaMeta(type): - """provides universal constructor arguments for all SchemaItems""" - def __call__(self, *args, **kwargs): - engine = kwargs.pop('engine', None) - obj = type.__call__(self, *args, **kwargs) - obj._engine = engine - return obj - + 'MetaData', 'BoundMetaData', 'DynamicMetaData', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault'] + class SchemaItem(object): """base class for items that define a database schema.""" - __metaclass__ = SchemaMeta def _init_items(self, *args): for item in args: if item is not None: @@ -43,70 +32,66 @@ class SchemaItem(object): raise NotImplementedError() def __repr__(self): return "%s()" % self.__class__.__name__ - -class EngineMixin(object): - """a mixin for SchemaItems that provides an "engine" accessor.""" - def _derived_engine(self): - """subclasses override this method to return an AbstractEngine - bound to a parent item""" + def _derived_metadata(self): + """subclasses override this method to return a the MetaData + to which this item is bound""" return None def _get_engine(self): - if self._engine is not None: - return self._engine - else: - return self._derived_engine() - engine = property(_get_engine) + return self._derived_metadata().engine + engine = property(lambda s:s._get_engine()) + metadata = property(lambda s:s._derived_metadata()) -def _get_table_key(engine, name, schema): - if schema is not None:# and schema == engine.get_default_schema_name(): - schema = None +def _get_table_key(name, schema): if schema is None: return name else: return schema + "." + name -class TableSingleton(SchemaMeta): +class TableSingleton(type): """a metaclass used by the Table object to provide singleton behavior.""" - def __call__(self, name, engine=None, *args, **kwargs): + def __call__(self, name, metadata, *args, **kwargs): try: - if engine is not None and not isinstance(engine, SchemaEngine): - args = [engine] + list(args) - engine = default_engine + if isinstance(metadata, sql.Engine): + # backwards compatibility - get a BoundSchema associated with the engine + engine = metadata + if not hasattr(engine, '_legacy_metadata'): + engine._legacy_metadata = BoundMetaData(engine) + metadata = engine._legacy_metadata name = str(name) # in case of incoming unicode schema = kwargs.get('schema', None) autoload = kwargs.pop('autoload', False) + autoload_with = kwargs.pop('autoload_with', False) redefine = kwargs.pop('redefine', False) mustexist = kwargs.pop('mustexist', False) useexisting = kwargs.pop('useexisting', False) - if not engine: - table = type.__call__(self, name, engine, **kwargs) - table._init_items(*args) - return table - key = _get_table_key(engine, name, schema) - table = engine.tables[key] + key = _get_table_key(name, schema) + table = metadata.tables[key] if len(args): if redefine: table.reload_values(*args) elif not useexisting: - raise ArgumentError("Table '%s.%s' is already defined. specify 'redefine=True' to remap columns, or 'useexisting=True' to use the existing table" % (schema, name)) + raise exceptions.ArgumentError("Table '%s.%s' is already defined. 
specify 'redefine=True' to remap columns, or 'useexisting=True' to use the existing table" % (schema, name)) return table except KeyError: if mustexist: - raise ArgumentError("Table '%s.%s' not defined" % (schema, name)) - table = type.__call__(self, name, engine, **kwargs) - engine.tables[key] = table + raise exceptions.ArgumentError("Table '%s.%s' not defined" % (schema, name)) + table = type.__call__(self, name, metadata, **kwargs) + table._set_parent(metadata) # load column definitions from the database if 'autoload' is defined # we do it after the table is in the singleton dictionary to support # circular foreign keys if autoload: - engine.reflecttable(table) + if autoload_with: + autoload_with.reflecttable(table) + else: + metadata.engine.reflecttable(table) # initialize all the column, etc. objects. done after # reflection to allow user-overrides table._init_items(*args) return table -class Table(sql.TableClause, SchemaItem): +class Table(SchemaItem, sql.TableClause): """represents a relational database table. This subclasses sql.TableClause to provide a table that is "wired" to an engine. Whereas TableClause represents a table as its used in a SQL expression, Table represents a table as its created in the database. @@ -114,7 +99,7 @@ class Table(sql.TableClause, SchemaItem): Be sure to look at sqlalchemy.sql.TableImpl for additional methods defined on a Table.""" __metaclass__ = TableSingleton - def __init__(self, name, engine, **kwargs): + def __init__(self, name, metadata, **kwargs): """Table objects can be constructed directly. The init method is actually called via the TableSingleton metaclass. Arguments are: @@ -123,9 +108,6 @@ class Table(sql.TableClause, SchemaItem): Further tables constructed with the same name/schema combination will return the same Table instance. - engine : a SchemaEngine instance to provide services to this table. Usually a subclass of - sql.SQLEngine. - *args : should contain a listing of the Column objects for this table. **kwargs : options include: @@ -148,17 +130,18 @@ class Table(sql.TableClause, SchemaItem): """ super(Table, self).__init__(name) - self._engine = engine + self._metadata = metadata self.schema = kwargs.pop('schema', None) if self.schema is not None: self.fullname = "%s.%s" % (self.schema, self.name) else: self.fullname = self.name self.kwargs = kwargs - + def _derived_metadata(self): + return self._metadata def __repr__(self): return "Table(%s)" % string.join( - [repr(self.name)] + [repr(self.engine)] + + [repr(self.name)] + [repr(self.metadata)] + [repr(x) for x in self.columns] + ["%s=%s" % (k, repr(getattr(self, k))) for k in ['schema']] , ',\n') @@ -191,9 +174,9 @@ class Table(sql.TableClause, SchemaItem): def append_index(self, index): self.indexes[index.name] = index - def _set_parent(self, schema): - schema.tables[self.name] = self - self.schema = schema + def _set_parent(self, metadata): + metadata.tables[self.name] = self + self._metadata = metadata def accept_schema_visitor(self, visitor): """traverses the given visitor across the Column objects inside this Table, then calls the visit_table method on the visitor.""" @@ -227,29 +210,35 @@ class Table(sql.TableClause, SchemaItem): return index def deregister(self): - """removes this table from it's engines table registry. this does not + """removes this table from it's metadata. 
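The TableSingleton metaclass above keys each Table by its (schema-qualified) name within a MetaData, so constructing the "same" table twice yields one instance, with redefine/useexisting/mustexist governing repeat construction. The underlying pattern in isolation (modern metaclass syntax; names are illustrative):

    class Registry(type):
        """metaclass: one instance per key within a caller-supplied collection."""
        def __call__(cls, name, collection, schema=None, **kw):
            key = name if schema is None else schema + "." + name
            try:
                return collection[key]
            except KeyError:
                obj = type.__call__(cls, name, schema=schema, **kw)
                collection[key] = obj
                return obj

    class MiniTable(metaclass=Registry):
        def __init__(self, name, schema=None):
            self.name, self.schema = name, schema

    tables = {}
    assert MiniTable('users', tables) is MiniTable('users', tables)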
this does not issue a SQL DROP statement.""" - key = _get_table_key(self.engine, self.name, self.schema) - del self.engine.tables[key] - def create(self, **params): - self.engine.create(self) + key = _get_table_key(self.name, self.schema) + del self.metadata.tables[key] + def create(self, connectable=None): + if connectable is not None: + connectable.create(self) + else: + self.engine.create(self) return self - def drop(self, **params): - self.engine.drop(self) - def toengine(self, engine, schema=None): - """returns a singleton instance of this Table with a different engine""" + def drop(self, connectable=None): + if connectable is not None: + connectable.drop(self) + else: + self.engine.drop(self) + def tometadata(self, metadata, schema=None): + """returns a singleton instance of this Table with a different Schema""" try: if schema is None: schema = self.schema - key = _get_table_key(engine, self.name, schema) - return engine.tables[key] - except: + key = _get_table_key(self.name, schema) + return metadata.tables[key] + except KeyError: args = [] for c in self.columns: args.append(c.copy()) - return Table(self.name, engine, schema=schema, *args) + return Table(self.name, metadata, schema=schema, *args) -class Column(sql.ColumnClause, SchemaItem): +class Column(SchemaItem, sql.ColumnClause): """represents a column in a database table. this is a subclass of sql.ColumnClause and represents an actual existing table in the database, in a similar fashion as TableClause/Table.""" def __init__(self, name, type, *args, **kwargs): @@ -297,7 +286,6 @@ class Column(sql.ColumnClause, SchemaItem): order of their creation. """ - name = str(name) # in case of incoming unicode super(Column, self).__init__(name, None, type) self.args = args @@ -310,20 +298,30 @@ class Column(sql.ColumnClause, SchemaItem): self.unique = kwargs.pop('unique', None) self.onupdate = kwargs.pop('onupdate', None) if self.index is not None and self.unique is not None: - raise ArgumentError("Column may not define both index and unique") + raise exceptions.ArgumentError("Column may not define both index and unique") self._foreign_key = None - self._orig = None - self._parent = None if len(kwargs): - raise ArgumentError("Unknown arguments passed to Column: " + repr(kwargs.keys())) + raise exceptions.ArgumentError("Unknown arguments passed to Column: " + repr(kwargs.keys())) - primary_key = SimpleProperty('_primary_key') - foreign_key = SimpleProperty('_foreign_key') - original = property(lambda s: s._orig or s) - parent = property(lambda s:s._parent or s) - engine = property(lambda s: s.table.engine) + primary_key = util.SimpleProperty('_primary_key') + foreign_key = util.SimpleProperty('_foreign_key') columns = property(lambda self:[self]) + def __str__(self): + if self.table is not None: + tname = self.table.displayname + if tname is not None: + return tname + "." 
+ self.name + else: + return self.name + else: + return self.name + + def _derived_metadata(self): + return self.table.metadata + def _get_engine(self): + return self.table.engine + def __repr__(self): return "Column(%s)" % string.join( [repr(self.name)] + [repr(self.type)] + @@ -343,7 +341,7 @@ class Column(sql.ColumnClause, SchemaItem): def _set_parent(self, table): if getattr(self, 'table', None) is not None: - raise ArgumentError("this Column already has a table!") + raise exceptions.ArgumentError("this Column already has a table!") table.append_column(self) if self.index or self.unique: table.append_index_column(self, index=self.index, @@ -374,7 +372,7 @@ class Column(sql.ColumnClause, SchemaItem): fk = self.foreign_key.copy() c = Column(name or self.name, self.type, fk, self.default, key = name or self.key, primary_key = self.primary_key, nullable = self.nullable, hidden = self.hidden) c.table = selectable - c._orig = self.original + c.orig_set = self.orig_set c._parent = self if not c.hidden: selectable.columns[c.key] = c @@ -403,7 +401,7 @@ class ForeignKey(SchemaItem): """Constructs a new ForeignKey object. "column" can be a schema.Column object representing the relationship, or just its string name given as "tablename.columnname". schema can be specified as - "schemaname.tablename.columnname" """ + "schema.tablename.columnname" """ if isinstance(column, unicode): column = str(column) self._colspec = column @@ -426,7 +424,7 @@ class ForeignKey(SchemaItem): def references(self, table): """returns True if the given table is referenced by this ForeignKey.""" - return table._get_col_by_original(self.column, False) is not None + return table.corresponding_column(self.column, False) is not None def _init_column(self): # ForeignKey inits its remote column as late as possible, so tables can @@ -435,13 +433,13 @@ class ForeignKey(SchemaItem): if isinstance(self._colspec, str): m = re.match(r"^([\w_-]+)(?:\.([\w_-]+))?(?:\.([\w_-]+))?$", self._colspec) if m is None: - raise ArgumentError("Invalid foreign key column specification: " + self._colspec) + raise exceptions.ArgumentError("Invalid foreign key column specification: " + self._colspec) if m.group(3) is None: (tname, colname) = m.group(1, 2) - schema = self.parent.original.table.schema + schema = list(self.parent.orig_set)[0].table.schema else: (schema,tname,colname) = m.group(1,2,3) - table = Table(tname, self.parent.engine, mustexist=True, schema=schema) + table = Table(tname, list(self.parent.orig_set)[0].metadata, mustexist=True, schema=schema) if colname is None: key = self.parent self._column = table.c[self.parent.key] @@ -466,20 +464,25 @@ class ForeignKey(SchemaItem): self.parent.foreign_key = self self.parent.table.foreign_keys.append(self) -class DefaultGenerator(SchemaItem, EngineMixin): +class DefaultGenerator(SchemaItem): """Base class for column "default" values.""" - def __init__(self, for_update=False): + def __init__(self, for_update=False, metadata=None): self.for_update = for_update - def _derived_engine(self): - return self.column.table.engine + self._metadata = metadata + def _derived_metadata(self): + try: + return self.column.table.metadata + except AttributeError: + return self._metadata def _set_parent(self, column): self.column = column + self._metadata = self.column.table.metadata if self.for_update: self.column.onupdate = self else: self.column.default = self - def execute(self): - return self.accept_schema_visitor(self.engine.defaultrunner(self.engine.proxy)) + def execute(self, **kwargs): + return 
self.engine.execute_default(self, **kwargs) def __repr__(self): return "DefaultGenerator()" @@ -534,7 +537,7 @@ class Sequence(DefaultGenerator): return visitor.visit_sequence(self) -class Index(SchemaItem, EngineMixin): +class Index(SchemaItem): """Represents an index of columns from a database table """ def __init__(self, name, *columns, **kw): @@ -555,8 +558,8 @@ class Index(SchemaItem, EngineMixin): self.unique = kw.pop('unique', False) self._init_items(*columns) - def _derived_engine(self): - return self.table.engine + def _derived_metadata(self): + return self.table.metadata def _init_items(self, *args): for column in args: self.append_column(column) @@ -569,23 +572,27 @@ class Index(SchemaItem, EngineMixin): self.table.append_index(self) elif column.table != self.table: # all columns muse be from same table - raise ArgumentError("All index columns must be from same table. " + raise exceptions.ArgumentError("All index columns must be from same table. " "%s is from %s not %s" % (column, column.table, self.table)) elif column.name in [ c.name for c in self.columns ]: - raise ArgumentError("A column may not appear twice in the " + raise exceptions.ArgumentError("A column may not appear twice in the " "same index (%s already has column %s)" % (self.name, column)) self.columns.append(column) - def create(self): - self.engine.create(self) - return self - def drop(self): - self.engine.drop(self) - def execute(self): - self.create() + def create(self, engine=None): + if engine is not None: + engine.create(self) + else: + self.engine.create(self) + return self + def drop(self, engine=None): + if engine is not None: + engine.drop(self) + else: + self.engine.drop(self) def accept_schema_visitor(self, visitor): visitor.visit_index(self) def __str__(self): @@ -596,22 +603,105 @@ class Index(SchemaItem, EngineMixin): for c in self.columns]), (self.unique and ', unique=True') or '') -class SchemaEngine(sql.AbstractEngine): - """a factory object used to create implementations for schema objects. 
This object - is the ultimate base class for the engine.SQLEngine class.""" - - def __init__(self): +class MetaData(SchemaItem): + """represents a collection of Tables and their associated schema constructs.""" + def __init__(self, name=None): # a dictionary that stores Table objects keyed off their name (and possibly schema name) self.tables = {} - def reflecttable(self, table): - """given a table, will query the database and populate its Column and ForeignKey - objects.""" - raise NotImplementedError() - def schemagenerator(self, **params): - raise NotImplementedError() - def schemadropper(self, **params): - raise NotImplementedError() - + self.name = name + def is_bound(self): + return False + def clear(self): + self.tables.clear() + def table_iterator(self, reverse=True): + return self._sort_tables(self.tables.values(), reverse=reverse) + + def create_all(self, engine=None, tables=None): + if not tables: + tables = self.tables.values() + + if engine is None and self.is_bound(): + engine = self.engine + + def do(conn): + e = conn.engine + ts = self._sort_tables( tables ) + for table in ts: + if e.dialect.has_table(conn, table.name): + continue + conn.create(table) + engine.run_callable(do) + + def drop_all(self, engine=None, tables=None): + if not tables: + tables = self.tables.values() + + if engine is None and self.is_bound(): + engine = self.engine + + def do(conn): + e = conn.engine + ts = self._sort_tables( tables, reverse=True ) + for table in ts: + if e.dialect.has_table(conn, table.name): + conn.drop(table) + engine.run_callable(do) + + def _sort_tables(self, tables, reverse=False): + import sqlalchemy.sql_util + sorter = sqlalchemy.sql_util.TableCollection() + for t in self.tables.values(): + sorter.add(t) + return sorter.sort(reverse=reverse) + + def _derived_metadata(self): + return self + def _get_engine(self): + if not self.is_bound(): + return None + return self._engine + +class BoundMetaData(MetaData): + """builds upon MetaData to provide the capability to bind to an Engine implementation.""" + def __init__(self, engine_or_url, name=None, **kwargs): + super(BoundMetaData, self).__init__(name) + if isinstance(engine_or_url, str): + self._engine = sqlalchemy.create_engine(engine_or_url, **kwargs) + else: + self._engine = engine_or_url + def is_bound(self): + return True + +class DynamicMetaData(MetaData): + """builds upon MetaData to provide the capability to bind to multiple Engine implementations + on a dynamically alterable, thread-local basis.""" + def __init__(self, name=None, threadlocal=True): + super(DynamicMetaData, self).__init__(name) + if threadlocal: + self.context = util.ThreadLocal() + else: + self.context = self + self.__engines = {} + def connect(self, engine_or_url, **kwargs): + if isinstance(engine_or_url, str): + try: + self.context._engine = self.__engines[engine_or_url] + except KeyError: + e = sqlalchemy.create_engine(engine_or_url, **kwargs) + self.__engines[engine_or_url] = e + self.context._engine = e + else: + if not self.__engines.has_key(engine_or_url): + self.__engines[engine_or_url] = engine_or_url + self.context._engine = engine_or_url + def is_bound(self): + return self.context._engine is not None + def dispose(self): + """disposes all Engines to which this DynamicMetaData has been connected.""" + for e in self.__engines.values(): + e.dispose() + engine=property(lambda s:s.context._engine) + class SchemaVisitor(sql.ClauseVisitor): """defines the visiting for SchemaItem objects""" def visit_schema(self, schema): @@ -642,5 +732,6 @@ class 
SchemaVisitor(sql.ClauseVisitor): """visit a Sequence.""" pass - +default_metadata = DynamicMetaData('default') + diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 38866184f..d1d1d837e 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -5,11 +5,9 @@ """defines the base components of SQL expression trees.""" -import schema -import util -import types as sqltypes -from exceptions import * -import string, re, random +from sqlalchemy import util, exceptions +from sqlalchemy import types as sqltypes +import string, re, random, sets types = __import__('types') __all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'between_', 'case', 'cast', 'union', 'union_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists'] @@ -220,8 +218,7 @@ def text(text, engine=None, *args, **kwargs): text - the text of the SQL statement to be created. use :<param> to specify bind parameters; they will be compiled to their engine-specific format. - engine - an optional engine to be used for this text query. Alternatively, call the - text() method off the engine directly. + engine - an optional engine to be used for this text query. bindparams - a list of bindparam() instances which can be used to define the types and/or initial values for the bind parameters within the textual statement; @@ -257,28 +254,33 @@ def _is_literal(element): def is_column(col): return isinstance(col, ColumnElement) -class AbstractEngine(object): - """represents a 'thing that can produce Compiler objects an execute them'.""" +class Engine(object): + """represents a 'thing that can produce Compiled objects and execute them'.""" def execute_compiled(self, compiled, parameters, echo=None, **kwargs): raise NotImplementedError() def compiler(self, statement, parameters, **kwargs): raise NotImplementedError() +class AbstractDialect(object): + """represents the behavior of a particular database. Used by Compiled objects.""" + pass + class ClauseParameters(util.OrderedDict): """represents a dictionary/iterator of bind parameter key names/values. Includes parameters compiled with a Compiled object as well as additional arguments passed to the Compiled object's get_params() method. Parameter values will be converted as per the TypeEngine objects present in the bind parameter objects. The non-converted value can be retrieved via the get_original method. 
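ClauseParameters above converts bind values lazily: values are stored raw, TypeEngine conversion is applied as a value is read back out, and get_original() bypasses conversion entirely. The same convert-on-access idea in miniature (illustrative; a plain callable stands in for a TypeEngine):

    class ConvertingDict(dict):
        """stores raw values; applies a per-key converter only on read-back."""
        def __init__(self):
            dict.__init__(self)
            self.converters = {}
        def set_parameter(self, key, value, converter=None):
            self[key] = value                    # stored raw
            if converter is not None:
                self.converters[key] = converter
        def get_original(self, key):
            return dict.__getitem__(self, key)   # bypasses conversion
        def __getitem__(self, key):
            value = dict.__getitem__(self, key)
            converter = self.converters.get(key)
            return converter(value) if converter is not None else value

    # p = ConvertingDict(); p.set_parameter('n', '42', int)
    # p['n'] -> 42; p.get_original('n') -> '42'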
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index 38866184f..d1d1d837e 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -5,11 +5,9 @@
 """defines the base components of SQL expression trees."""
 
-import schema
-import util
-import types as sqltypes
-from exceptions import *
-import string, re, random
+from sqlalchemy import util, exceptions
+from sqlalchemy import types as sqltypes
+import string, re, random, sets
 types = __import__('types')
 
 __all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'between_', 'case', 'cast', 'union', 'union_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists']
@@ -220,8 +218,7 @@ def text(text, engine=None, *args, **kwargs):
     text - the text of the SQL statement to be created.  use :<param> to specify bind parameters;
     they will be compiled to their engine-specific format.
 
-    engine - an optional engine to be used for this text query.  Alternatively, call the
-    text() method off the engine directly.
+    engine - an optional engine to be used for this text query.
 
     bindparams - a list of bindparam() instances which can be used to define the
     types and/or initial values for the bind parameters within the textual statement;
@@ -257,28 +254,33 @@ def _is_literal(element):
 def is_column(col):
     return isinstance(col, ColumnElement)
 
-class AbstractEngine(object):
-    """represents a 'thing that can produce Compiler objects an execute them'."""
+class Engine(object):
+    """represents a 'thing that can produce Compiled objects and execute them'."""
     def execute_compiled(self, compiled, parameters, echo=None, **kwargs):
         raise NotImplementedError()
     def compiler(self, statement, parameters, **kwargs):
         raise NotImplementedError()
 
+class AbstractDialect(object):
+    """represents the behavior of a particular database.  Used by Compiled objects."""
+    pass
+
 class ClauseParameters(util.OrderedDict):
     """represents a dictionary/iterator of bind parameter key names/values.  Includes parameters
     compiled with a Compiled object as well as additional arguments passed to the Compiled object's
     get_params() method.  Parameter values will be converted as per the TypeEngine objects present
     in the bind parameter objects.  The non-converted value can be retrieved via the get_original
     method.  For Compiled objects that compile positional parameters, the values() iteration of
     the object will return the parameter values in the correct order."""
-    def __init__(self, engine=None):
+    def __init__(self, dialect):
         super(ClauseParameters, self).__init__(self)
-        self.engine = engine
+        self.dialect=dialect
         self.binds = {}
     def set_parameter(self, key, value, bindparam):
         self[key] = value
         self.binds[key] = bindparam
     def get_original(self, key):
+        """returns the given parameter as it was originally placed in this ClauseParameters object, without any Type conversion"""
         return super(ClauseParameters, self).__getitem__(key)
     def __getitem__(self, key):
         v = super(ClauseParameters, self).__getitem__(key)
-        if self.engine is not None and self.binds.has_key(key):
-            v = self.binds[key].typeprocess(v, self.engine)
+        if self.binds.has_key(key):
+            v = self.binds[key].typeprocess(v, self.dialect)
         return v
     def values(self):
         return [self[key] for key in self]
@@ -318,7 +320,7 @@ class Compiled(ClauseVisitor):
     object be dependent on the actual values of those bind parameters, even though
     it may reference those values as defaults."""
 
-    def __init__(self, statement, parameters, engine=None):
+    def __init__(self, dialect, statement, parameters, engine=None):
         """constructs a new Compiled object.
 
         statement - ClauseElement to be compiled
@@ -332,11 +334,12 @@ class Compiled(ClauseVisitor):
         clauses of an UPDATE statement.  The keys of the parameter dictionary can
         either be the string names of columns or ColumnClause objects.
 
-        engine - optional SQLEngine to compile this statement against"""
-        self.parameters = parameters
+        engine - optional Engine to compile this statement against"""
+        self.dialect = dialect
         self.statement = statement
+        self.parameters = parameters
         self.engine = engine
-
+
     def __str__(self):
         """returns the string text of the generated SQL statement."""
         raise NotImplementedError()
@@ -357,13 +360,10 @@ class Compiled(ClauseVisitor):
     def execute(self, *multiparams, **params):
         """executes this compiled object using the AbstractEngine it is bound to."""
 
-        if len(multiparams):
-            params = multiparams
-
         e = self.engine
         if e is None:
-            raise InvalidRequestError("This Compiled object is not bound to any engine.")
-        return e.execute_compiled(self, params)
+            raise exceptions.InvalidRequestError("This Compiled object is not bound to any engine.")
+        return e.execute_compiled(self, *multiparams, **params)
 
     def scalar(self, *multiparams, **params):
         """executes this compiled object via the execute() method, then
@@ -373,30 +373,25 @@ class Compiled(ClauseVisitor):
         # in a result set is not performance-wise any different than specifying limit=1
         # else we'd have to construct a copy of the select() object with the limit
         # installed (else if we change the existing select, not threadsafe)
-        row = self.execute(*multiparams, **params).fetchone()
-        if row is not None:
-            return row[0]
-        else:
-            return None
+        r = self.execute(*multiparams, **params)
+        row = r.fetchone()
+        try:
+            if row is not None:
+                return row[0]
+            else:
+                return None
+        finally:
+            r.close()
 
 class Executor(object):
-    """handles the compilation/execution of a ClauseElement within the context of a particular
-    AbtractEngine.  This AbstractEngine will usually be a SQLEngine or ConnectionProxy."""
+    """context-sensitive executor for the using() function."""
     def __init__(self, clauseelement, abstractengine=None):
         self.engine=abstractengine
         self.clauseelement = clauseelement
     def execute(self, *multiparams, **params):
-        return self.compile(*multiparams, **params).execute(*multiparams, **params)
+        return self.clauseelement.execute_using(self.engine)
     def scalar(self, *multiparams, **params):
-        return self.compile(*multiparams, **params).scalar(*multiparams, **params)
-    def compile(self, *multiparams, **params):
-        if len(multiparams):
-            bindparams = multiparams[0]
-        else:
-            bindparams = params
-        compiler = self.engine.compiler(self.clauseelement, bindparams)
-        compiler.compile()
-        return compiler
+        return self.clauseelement.scalar_using(self.engine)
 
 class ClauseElement(object):
     """base class for elements of a programmatically constructed SQL expression."""
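Given the rewrite of Executor above, using() is now only a thin, context-sensitive redirect.  A sketch of the equivalence, where stmt is any ClauseElement and db is an Engine (both hypothetical):

    result = stmt.using(db).execute()     # defers to stmt.execute_using(db)
    value = stmt.using(db).scalar()       # defers to stmt.scalar_using(db)

    # the direct spellings added to ClauseElement in the next hunk:
    result = stmt.execute_using(db)
    value = stmt.scalar_using(db)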
+ """ + + if (isinstance(parameters, list) or isinstance(parameters, tuple)): + parameters = parameters[0] if compiler is None: - if engine is not None: + if dialect is not None: + compiler = dialect.compiler(self, parameters) + elif engine is not None: compiler = engine.compiler(self, parameters) elif self.engine is not None: compiler = self.engine.compiler(self, parameters) if compiler is None: import sqlalchemy.ansisql as ansisql - compiler = ansisql.ANSICompiler(self, parameters=parameters) + compiler = ansisql.ANSIDialect().compiler(self, parameters=parameters) compiler.compile() return compiler @@ -481,10 +502,10 @@ class ClauseElement(object): return str(self.compile()) def execute(self, *multiparams, **params): - return self.using(self.engine).execute(*multiparams, **params) + return self.execute_using(self.engine, *multiparams, **params) def scalar(self, *multiparams, **params): - return self.using(self.engine).scalar(*multiparams, **params) + return self.scalar_using(self.engine, *multiparams, **params) def __and__(self, other): return and_(self, other) @@ -543,7 +564,7 @@ class CompareMixin(object): def __div__(self, other): return self._operate('/', other) def __mod__(self, other): - return self._operate('%', other) + return self._operate('%', other) def __truediv__(self, other): return self._operate('/', other) def _bind_param(self, obj): @@ -554,11 +575,11 @@ class CompareMixin(object): return BooleanExpression(self._compare_self(), null(), 'IS') elif operator == '!=': return BooleanExpression(self._compare_self(), null(), 'IS NOT') - return BooleanExpression(self._compare_self(), null(), 'IS') else: raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL") elif _is_literal(obj): obj = self._bind_param(obj) + return BooleanExpression(self._compare_self(), obj, operator, type=self._compare_type(obj)) def _operate(self, operator, obj): if _is_literal(obj): @@ -588,24 +609,43 @@ class Selectable(ClauseElement): return True class ColumnElement(Selectable, CompareMixin): - """represents a column element within the list of a Selectable's columns. Provides - default implementations for the things a "column" needs, including a "primary_key" flag, - a "foreign_key" accessor, an "original" accessor which represents the ultimate column - underlying a string of labeled/select-wrapped columns, and "columns" which returns a list - of the single column, providing the same list-based interface as a FromClause.""" - primary_key = property(lambda self:getattr(self, '_primary_key', False)) - foreign_key = property(lambda self:getattr(self, '_foreign_key', False)) - original = property(lambda self:getattr(self, '_original', self)) - parent = property(lambda self:getattr(self, '_parent', self)) - columns = property(lambda self:[self]) + """represents a column element within the list of a Selectable's columns. + A ColumnElement can either be directly associated with a TableClause, or + a free-standing textual column with no table, or is a "proxy" column, indicating + it is placed on a Selectable such as an Alias or Select statement and ultimately corresponds + to a TableClause-attached column (or in the case of a CompositeSelect, a proxy ColumnElement + may correspond to several TableClause-attached columns).""" + + primary_key = property(lambda self:getattr(self, '_primary_key', False), doc="primary key flag. 
@@ -588,24 +609,43 @@ class Selectable(ClauseElement):
         return True
 
 class ColumnElement(Selectable, CompareMixin):
-    """represents a column element within the list of a Selectable's columns.  Provides
-    default implementations for the things a "column" needs, including a "primary_key" flag,
-    a "foreign_key" accessor, an "original" accessor which represents the ultimate column
-    underlying a string of labeled/select-wrapped columns, and "columns" which returns a list
-    of the single column, providing the same list-based interface as a FromClause."""
-    primary_key = property(lambda self:getattr(self, '_primary_key', False))
-    foreign_key = property(lambda self:getattr(self, '_foreign_key', False))
-    original = property(lambda self:getattr(self, '_original', self))
-    parent = property(lambda self:getattr(self, '_parent', self))
-    columns = property(lambda self:[self])
+    """represents a column element within the list of a Selectable's columns.
+    A ColumnElement can either be directly associated with a TableClause, or
+    a free-standing textual column with no table, or is a "proxy" column, indicating
+    it is placed on a Selectable such as an Alias or Select statement and ultimately corresponds
+    to a TableClause-attached column (or in the case of a CompositeSelect, a proxy ColumnElement
+    may correspond to several TableClause-attached columns)."""
+
+    primary_key = property(lambda self:getattr(self, '_primary_key', False), doc="primary key flag.  indicates if this Column represents part or whole of a primary key.")
+    foreign_key = property(lambda self:getattr(self, '_foreign_key', False), doc="foreign key accessor.  points to a ForeignKey object which represents a Foreign Key placed on this column's ultimate ancestor.")
+    columns = property(lambda self:[self], doc="Columns accessor which just returns self, to provide compatibility with Selectable objects.")
+
+    def _get_orig_set(self):
+        try:
+            return self.__orig_set
+        except AttributeError:
+            self.__orig_set = sets.Set([self])
+            return self.__orig_set
+    def _set_orig_set(self, s):
+        if len(s) == 0:
+            s.add(self)
+        self.__orig_set = s
+    orig_set = property(_get_orig_set, _set_orig_set,doc="""a Set containing TableClause-bound, non-proxied ColumnElements for which this ColumnElement is a proxy.  In all cases except for a column proxied from a Union (i.e. CompoundSelect), this set will be just one element.""")
+
+    def shares_lineage(self, othercolumn):
+        """returns True if the given ColumnElement has a common ancestor to this ColumnElement."""
+        for c in self.orig_set:
+            if c in othercolumn.orig_set:
+                return True
+        else:
+            return False
     def _make_proxy(self, selectable, name=None):
         """creates a new ColumnElement representing this ColumnElement as it appears in the select list
-        of an enclosing selectable.  The default implementation returns a ColumnClause if a name is given,
-        else just returns self.  This has various mechanics with schema.Column and sql.Label so that
-        Column objects as well as non-column objects like Function and BinaryClause can both appear in the
-        select list of an enclosing selectable."""
+        of a descending selectable.  The default implementation returns a ColumnClause if a name is given,
+        else just returns self."""
         if name is not None:
             co = ColumnClause(name, selectable)
+            co.orig_set = self.orig_set
             selectable.columns[name]= co
             return co
         else:
@@ -615,16 +655,17 @@ class FromClause(Selectable):
     """represents an element that can be used within the FROM clause of a SELECT statement."""
     def __init__(self, from_name = None):
         self.from_name = self.name = from_name
+    def _display_name(self):
+        if self.named_with_column():
+            return self.name
+        else:
+            return None
+    displayname = property(_display_name)
     def _get_from_objects(self):
         # this could also be [self], at the moment it doesnt matter to the Select object
         return []
     def default_order_by(self):
-        if not self.engine.default_ordering:
-            return None
-        elif self.oid_column is not None:
-            return [self.oid_column]
-        else:
-            return self.primary_key
+        return [self.oid_column]
     def accept_visitor(self, visitor):
         visitor.visit_fromclause(self)
     def count(self, whereclause=None, **params):
@@ -635,6 +676,9 @@ class FromClause(Selectable):
         return Join(self, right, isouter = True, *args, **kwargs)
     def alias(self, name=None):
         return Alias(self, name)
+    def named_with_column(self):
+        """True if the name of this FromClause may be prepended to a column in a generated SQL statement"""
+        return False
     def _locate_oid_column(self):
         """subclasses override this to return an appropriate OID column"""
         return None
@@ -642,18 +686,24 @@ class FromClause(Selectable):
         if not hasattr(self, '_oid_column'):
             self._oid_column = self._locate_oid_column()
         return self._oid_column
-    def _get_col_by_original(self, column, raiseerr=True):
-        """given a column which is a schema.Column object attached to a schema.Table object
-        (i.e. an "original" column), return the Column object from this
-        Selectable which corresponds to that original Column, or None if this Selectable
-        does not contain the column."""
-        try:
-            return self.original_columns[column.original]
-        except KeyError:
+    def corresponding_column(self, column, raiseerr=True, keys_ok=False):
+        """given a ColumnElement, return the ColumnElement object from this
+        Selectable which corresponds to that original Column via a proxy relationship."""
+        for c in column.orig_set:
+            try:
+                return self.original_columns[c]
+            except KeyError:
+                pass
+        else:
+            if keys_ok:
+                try:
+                    return self.c[column.key]
+                except KeyError:
+                    pass
             if not raiseerr:
                 return None
            else:
-                raise InvalidRequestError("cant get orig for " + str(column) + " with table " + column.table.name + " from table " + self.name)
+                raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(column.table), self.name))
 
     def _get_exported_attribute(self, name):
         try:
@@ -665,10 +715,12 @@ class FromClause(Selectable):
     c = property(lambda s:s._get_exported_attribute('_columns'))
     primary_key = property(lambda s:s._get_exported_attribute('_primary_key'))
     foreign_keys = property(lambda s:s._get_exported_attribute('_foreign_keys'))
-    original_columns = property(lambda s:s._get_exported_attribute('_orig_cols'))
+    original_columns = property(lambda s:s._get_exported_attribute('_orig_cols'), doc="a dictionary mapping an original Table-bound column to a proxied column in this FromClause.")
 
     oid_column = property(_get_oid_column)
 
     def _export_columns(self):
+        """this method is called the first time any of the "exported attributes" are accessed.  it receives from the Selectable
+        a list of all columns to be exported and creates "proxy" columns for each one."""
         if hasattr(self, '_columns'):
             # TODO: put a mutex here ?  this is a key place for threading probs
             return
@@ -681,9 +733,11 @@ class FromClause(Selectable):
             if column.is_selectable():
                 for co in column.columns:
                     cp = self._proxy_column(co)
-                    self._orig_cols[co.original] = cp
-        if getattr(self, 'oid_column', None):
-            self._orig_cols[self.oid_column.original] = self.oid_column
+                    for ci in cp.orig_set:
+                        self._orig_cols[ci] = cp
+        if self.oid_column is not None:
+            for ci in self.oid_column.orig_set:
+                self._orig_cols[ci] = self.oid_column
     def _exportable_columns(self):
         return []
     def _proxy_column(self, column):
@@ -702,8 +756,8 @@ class BindParamClause(ClauseElement, CompareMixin):
         return []
     def copy_container(self):
         return BindParamClause(self.key, self.value, self.shortname, self.type)
-    def typeprocess(self, value, engine):
-        return self.type.engine_impl(engine).convert_bind_param(value, engine)
+    def typeprocess(self, value, dialect):
+        return self.type.dialect_impl(dialect).convert_bind_param(value, dialect)
     def compare(self, other):
         """compares this BindParamClause to the given clause.
 
@@ -720,8 +774,9 @@ class TypeClause(ClauseElement):
         self.type = type
     def accept_visitor(self, visitor):
         visitor.visit_typeclause(self)
-    def _get_from_objects(self):
-        return []
+    def _get_from_objects(self):
+        return []
+
 class TextClause(ClauseElement):
     """represents a literal SQL text fragment.  public constructor is the
     text() function.
@@ -909,7 +964,8 @@ class FunctionGenerator(object):
         self.__names.append(name)
         return self
     def __call__(self, *c, **kwargs):
-        return Function(self.__names[-1], packagenames=self.__names[0:-1], engine=self.__engine, *c, **kwargs)
+        kwargs.setdefault('engine', self.__engine)
+        return Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **kwargs)
 
 class BinaryClause(ClauseElement):
     """represents two clauses with an operator in between"""
@@ -956,15 +1012,13 @@ class Join(FromClause):
     def __init__(self, left, right, onclause=None, isouter = False):
         self.left = left
         self.right = right
-        # TODO: if no onclause, do NATURAL JOIN
         if onclause is None:
             self.onclause = self._match_primaries(left, right)
         else:
             self.onclause = onclause
         self.isouter = isouter
 
-    name = property(lambda self: "Join on %s, %s" % (self.left.name, self.right.name))
-
+    name = property(lambda s: "Join object on " + s.left.name + " " + s.right.name)
     def _locate_oid_column(self):
         return self.left.oid_column
 
@@ -981,15 +1035,15 @@ class Join(FromClause):
         crit = []
         for fk in secondary.foreign_keys:
             if fk.references(primary):
-                crit.append(primary._get_col_by_original(fk.column) == fk.parent)
+                crit.append(primary.corresponding_column(fk.column) == fk.parent)
                 self.foreignkey = fk.parent
         if primary is not secondary:
             for fk in primary.foreign_keys:
                 if fk.references(secondary):
-                    crit.append(secondary._get_col_by_original(fk.column) == fk.parent)
+                    crit.append(secondary.corresponding_column(fk.column) == fk.parent)
                     self.foreignkey = fk.parent
         if len(crit) == 0:
-            raise ArgumentError("Cant find any foreign key relationships between '%s' and '%s'" % (primary.name, secondary.name))
+            raise exceptions.ArgumentError("Cant find any foreign key relationships between '%s' and '%s'" % (primary.name, secondary.name))
         elif len(crit) == 1:
             return (crit[0])
         else:
@@ -1037,12 +1091,13 @@ class Alias(FromClause):
         self.original = baseselectable
         self.selectable = selectable
         if alias is None:
-            n = getattr(self.original, 'name', None)
-            if n is None:
-                n = 'anon'
-            elif len(n) > 15:
-                n = n[0:15]
-            alias = n + "_" + hex(random.randint(0, 65535))[2:]
+            if self.original.named_with_column():
+                alias = getattr(self.original, 'name', None)
+            if alias is None:
+                alias = 'anon'
+            elif len(alias) > 15:
+                alias = alias[0:15]
+            alias = alias + "_" + hex(random.randint(0, 65535))[2:]
         self.name = alias
 
     def _locate_oid_column(self):
@@ -1050,8 +1105,11 @@ class Alias(FromClause):
             return self.selectable.oid_column._make_proxy(self)
         else:
             return None
-
+
+    def named_with_column(self):
+        return True
+
     def _exportable_columns(self):
+        #return self.selectable._exportable_columns()
         return self.selectable.columns
 
     def accept_visitor(self, visitor):
@@ -1076,10 +1134,8 @@ class Label(ColumnElement):
         self.type = sqltypes.to_instance(type)
         obj.parens=True
     key = property(lambda s: s.name)
     _label = property(lambda s: s.name)
-
-    original = property(lambda s:s.obj.original)
-    parent = property(lambda s:s.obj.parent)
+    orig_set = property(lambda s:s.obj.orig_set)
     def accept_visitor(self, visitor):
         self.obj.accept_visitor(visitor)
         visitor.visit_label(self)
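The FunctionGenerator change above routes the generator's engine through kwargs, so an explicit engine keyword now wins over the stored one.  A sketch of the attribute-driven func syntax it backs (the users table is hypothetical):

    print func.count(users.c.id)       # renders as: count(users.id)
    stmt = select([func.max(users.c.id)])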
@@ -1091,19 +1147,20 @@ class ColumnClause(ColumnElement):
     """represents a textual column clause in a SQL statement.  May or may not be bound
     to an underlying Selectable."""
-    def __init__(self, text, selectable=None, type=None):
-        self.key = self.name = self.text = text
+    def __init__(self, text, selectable=None, type=None, hidden=False):
+        self.key = self.name = text
         self.table = selectable
         self.type = sqltypes.to_instance(type)
+        self.hidden = hidden
         self.__label = None
     def _get_label(self):
         if self.__label is None:
-            if self.table is not None and self.table.name is not None:
-                self.__label = self.table.name + "_" + self.text
+            if self.table is not None and self.table.named_with_column():
+                self.__label = self.table.name + "_" + self.name
+                if self.table.c.has_key(self.__label) or len(self.__label) >= 30:
+                    self.__label = self.__label[0:24] + "_" + hex(random.randint(0, 65535))[2:]
             else:
-                self.__label = self.text
-            if (self.table is not None and self.table.c.has_key(self.__label)) or len(self.__label) >= 30:
-                self.__label = self.__label[0:24] + "_" + hex(random.randint(0, 65535))[2:]
+                self.__label = self.name
         return self.__label
     _label = property(_get_label)
     def accept_visitor(self, visitor):
@@ -1113,21 +1170,19 @@ class ColumnClause(ColumnElement):
         for example, this could translate the column "name" from a Table object
         to an Alias of a Select off of that Table object."""
-        return selectable._get_col_by_original(self.original, False)
+        return selectable.corresponding_column(self.original, False)
 
     def _get_from_objects(self):
         if self.table is not None:
             return [self.table]
         else:
             return []
     def _bind_param(self, obj):
-        if self.table.name is None:
-            return BindParamClause(self.text, obj, shortname=self.text, type=self.type)
-        else:
-            return BindParamClause(self._label, obj, shortname = self.text, type=self.type)
+        return BindParamClause(self._label, obj, shortname = self.name, type=self.type)
     def _make_proxy(self, selectable, name = None):
-        c = ColumnClause(name or self.text, selectable)
-        c._original = self.original
-        selectable.columns[c.name] = c
+        c = ColumnClause(name or self.name, selectable, hidden=self.hidden)
+        c.orig_set = self.orig_set
+        if not self.hidden:
+            selectable.columns[c.name] = c
         return c
     def _compare_type(self, obj):
         return self.type
@@ -1144,29 +1199,25 @@ class TableClause(FromClause):
         self._primary_key = []
         for c in columns:
             self.append_column(c)
+        self._oid_column = ColumnClause('oid', self, hidden=True)
 
     indexes = property(lambda s:s._indexes)
-
+
+    def named_with_column(self):
+        return True
+
     def append_column(self, c):
-        self._columns[c.text] = c
+        self._columns[c.name] = c
         c.table = self
     def _locate_oid_column(self):
-        if self.engine is None:
-            return None
-        if self.engine.oid_column_name() is not None:
-            _oid_column = schema.Column(self.engine.oid_column_name(), sqltypes.Integer, hidden=True)
-            _oid_column._set_parent(self)
-            self._orig_columns()[_oid_column.original] = _oid_column
-            return _oid_column
-        else:
-            return None
+        return self._oid_column
     def _orig_columns(self):
         try:
             return self._orig_cols
         except AttributeError:
             self._orig_cols= {}
             for c in self.columns:
-                self._orig_cols[c.original] = c
+                for ci in c.orig_set:
+                    self._orig_cols[ci] = c
             return self._orig_cols
     columns = property(lambda s:s._columns)
     c = property(lambda s:s._columns)
@@ -1177,6 +1228,7 @@ class TableClause(FromClause):
     def _clear(self):
         """clears all attributes on this TableClause so that new items can be added again"""
         self.columns.clear()
+        self.indexes.clear()
         self.foreign_keys[:] = []
         self.primary_key[:] = []
         try:
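The _get_label() rewrite above keys label generation off named_with_column() and moves the collision/length check inside the table-named branch.  A sketch of where those labels surface (users table hypothetical):

    s = select([users], use_labels=True)
    # each column is labeled <tablename>_<colname>, e.g. users_id, users_name;
    # a label that collides or reaches 30 characters is truncated and given a random hex suffix
    print s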
@@ -1240,6 +1292,7 @@ class CompoundSelect(SelectBaseMixin, FromClause):
     def __init__(self, keyword, *selects, **kwargs):
+        SelectBaseMixin.__init__(self)
         self.keyword = keyword
         self.selects = selects
         self.use_labels = kwargs.pop('use_labels', False)
@@ -1251,21 +1304,34 @@ class CompoundSelect(SelectBaseMixin, FromClause):
             s.order_by(None)
         self.group_by(*kwargs.get('group_by', [None]))
         self.order_by(*kwargs.get('order_by', [None]))
+        self._col_map = {}
 
+#    name = property(lambda s:s.keyword + " statement")
+    def _foo(self):
+        raise "this is a temporary assertion while we refactor SQL to not call 'name' on non-table Selectables"
+    name = property(lambda s:s._foo()) #"SELECT statement")
+
     def _locate_oid_column(self):
         return self.selects[0].oid_column
-
     def _exportable_columns(self):
         for s in self.selects:
             for c in s.c:
                 yield c
-
     def _proxy_column(self, column):
         if self.use_labels:
-            return column._make_proxy(self, name=column._label)
+            col = column._make_proxy(self, name=column._label)
         else:
-            return column._make_proxy(self, name=column.name)
+            col = column._make_proxy(self, name=column.name)
+        try:
+            colset = self._col_map[col.name]
+        except KeyError:
+            colset = sets.Set()
+            self._col_map[col.name] = colset
+        [colset.add(c) for c in col.orig_set]
+        col.orig_set = colset
+        return col
+
     def accept_visitor(self, visitor):
         self.order_by_clause.accept_visitor(visitor)
         self.group_by_clause.accept_visitor(visitor)
@@ -1284,9 +1350,9 @@ class Select(SelectBaseMixin, FromClause):
     """represents a SELECT statement, with appendable clauses, as well
     as the ability to execute itself and return a result set."""
     def __init__(self, columns=None, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, for_update=False, engine=None, limit=None, offset=None, scalar=False, correlate=True):
+        SelectBaseMixin.__init__(self)
         self._froms = util.OrderedDict()
         self.use_labels = use_labels
-        self.name = None
         self.whereclause = None
         self.having = None
         self._engine = engine
@@ -1331,8 +1397,11 @@ class Select(SelectBaseMixin, FromClause):
 
         for f in from_obj:
             self.append_from(f)
-
-
+
+    def _foo(self):
+        raise "this is a temporary assertion while we refactor SQL to not call 'name' on non-table Selectables"
+    name = property(lambda s:s._foo()) #"SELECT statement")
+
     class CorrelatedVisitor(ClauseVisitor):
         """visits a clause, locates any Select clauses, and tells them
         that they should correlate their FROM list to that of their parent."""
@@ -1401,6 +1470,9 @@ class Select(SelectBaseMixin, FromClause):
     def _locate_oid_column(self):
         for f in self._froms.values():
+            if f is self:
+                # TODO: why would we be in our own _froms list ?
+                raise exceptions.AssertionError("Select statement should not be in its own _froms list")
             oid = f.oid_column
             if oid is not None:
                 return oid
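CompoundSelect._proxy_column above merges the orig_set of same-named columns from each member SELECT, so a union column keeps lineage to all of its sources.  A sketch (both tables hypothetical):

    u = union(select([users.c.name]), select([managers.c.name]))
    # the single proxied "name" column corresponds to both source columns
    assert u.corresponding_column(users.c.name) is u.corresponding_column(managers.c.name)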
@@ -1429,16 +1501,8 @@ class Select(SelectBaseMixin, FromClause):
         return union(self, other, **kwargs)
     def union_all(self, other, **kwargs):
         return union_all(self, other, **kwargs)
-
-#    def scalar(self, *multiparams, **params):
-        # need to set limit=1, but only in this thread.  we probably need to make a copy of the select().  this
-        # is expensive.  I think cursor.fetchone(), then discard remaining results
-        # should be fine with most DBs
-        # for now use base scalar() method
-
     def _find_engine(self):
-        """tries to return a SQLEngine, either explicitly set in this object, or searched
+        """tries to return an Engine, either explicitly set in this object, or searched
         within the from clauses for one"""
 
         if self._engine is not None:
@@ -1454,7 +1518,6 @@ class Select(SelectBaseMixin, FromClause):
 
 class UpdateBase(ClauseElement):
     """forms the base for INSERT, UPDATE, and DELETE statements."""
-
     def _process_colparams(self, parameters):
         """receives the "values" of an INSERT or UPDATE statement and constructs
         appropriate bind parameters."""
@@ -1483,17 +1546,14 @@ class UpdateBase(ClauseElement):
             except KeyError:
                 del parameters[key]
         return parameters
-
     def _find_engine(self):
-        return self._engine
-
+        return self.table.engine
 
 class Insert(UpdateBase):
     def __init__(self, table, values=None, **params):
         self.table = table
         self.select = None
         self.parameters = self._process_colparams(values)
-        self._engine = self.table.engine
 
     def accept_visitor(self, visitor):
         if self.select is not None:
@@ -1506,7 +1566,6 @@ class Update(UpdateBase):
         self.table = table
         self.whereclause = whereclause
         self.parameters = self._process_colparams(values)
-        self._engine = self.table.engine
 
     def accept_visitor(self, visitor):
         if self.whereclause is not None:
@@ -1517,7 +1576,6 @@ class Delete(UpdateBase):
     def __init__(self, table, whereclause, **params):
         self.table = table
         self.whereclause = whereclause
-        self._engine = self.table.engine
 
     def accept_visitor(self, visitor):
         if self.whereclause is not None:
diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql_util.py
new file mode 100644
index 000000000..e3170ebce
--- /dev/null
+++ b/lib/sqlalchemy/sql_util.py
@@ -0,0 +1,59 @@
+import sqlalchemy.sql as sql
+import sqlalchemy.schema as schema
+
+"""utility functions that build upon SQL and Schema constructs"""
+
+
+class TableCollection(object):
+    def __init__(self):
+        self.tables = []
+    def add(self, table):
+        self.tables.append(table)
+    def sort(self, reverse=False ):
+        import sqlalchemy.orm.topological
+        tuples = []
+        class TVisitor(schema.SchemaVisitor):
+            def visit_foreign_key(self, fkey):
+                parent_table = fkey.column.table
+                child_table = fkey.parent.table
+                tuples.append( ( parent_table, child_table ) )
+        vis = TVisitor()
+        for table in self.tables:
+            table.accept_schema_visitor(vis)
+        sorter = sqlalchemy.orm.topological.QueueDependencySorter( tuples, self.tables )
+        head = sorter.sort()
+        sequence = []
+        def to_sequence( node, seq=sequence):
+            seq.append( node.item )
+            for child in node.children:
+                to_sequence( child )
+        to_sequence( head )
+        if reverse:
+            sequence.reverse()
+        return sequence
+
+
+class TableFinder(TableCollection, sql.ClauseVisitor):
+    """given a Clause, locates all the Tables within it into a list."""
+    def __init__(self, table, check_columns=False):
+        TableCollection.__init__(self)
+        self.check_columns = check_columns
+        if table is not None:
+            table.accept_visitor(self)
+    def visit_table(self, table):
+        self.tables.append(table)
+    def __len__(self):
+        return len(self.tables)
+    def __getitem__(self, i):
+        return self.tables[i]
+    def __iter__(self):
+        return iter(self.tables)
+    def __contains__(self, obj):
+        return obj in self.tables
+    def __add__(self, obj):
+        return self.tables + list(obj)
+    def visit_column(self, column):
+        if self.check_columns:
+            column.table.accept_visitor(self)
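The new sql_util module supplies the dependency ordering used by MetaData.create_all() and drop_all() above.  A sketch of using it directly (users/addresses tables hypothetical, with addresses carrying a foreign key to users):

    import sqlalchemy.sql_util as sql_util

    # TableFinder collects every Table referenced by a clause
    finder = sql_util.TableFinder(users.join(addresses))
    assert users in finder and addresses in finder

    # TableCollection.sort() returns parents before children; reverse=True gives drop order
    coll = sql_util.TableCollection()
    coll.add(users)
    coll.add(addresses)
    print coll.sort()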
\ No newline at end of file
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
index 74961dbf8..65b5d14fa 100644
--- a/lib/sqlalchemy/types.py
+++ b/lib/sqlalchemy/types.py
@@ -26,46 +26,61 @@ class AbstractType(object):
             self._impl_dict = {}
             return self._impl_dict
     impl_dict = property(_get_impl_dict)
+
 class TypeEngine(AbstractType):
     def __init__(self, *args, **params):
         pass
     def engine_impl(self, engine):
+        """deprecated; call dialect_impl with a dialect directly."""
+        return self.dialect_impl(engine.dialect)
+    def dialect_impl(self, dialect):
         try:
-            return self.impl_dict[engine]
-        except KeyError:
-            return self.impl_dict.setdefault(engine, engine.type_descriptor(self))
+            return self.impl_dict[dialect]
+        except:
+            return self.impl_dict.setdefault(dialect, dialect.type_descriptor(self))
+    def _get_impl(self):
+        if hasattr(self, '_impl'):
+            return self._impl
+        else:
+            return NULLTYPE
+    def _set_impl(self, impl):
+        self._impl = impl
+    impl = property(_get_impl, _set_impl)
     def get_col_spec(self):
         raise NotImplementedError()
-    def convert_bind_param(self, value, engine):
+    def convert_bind_param(self, value, dialect):
         return value
-    def convert_result_value(self, value, engine):
+    def convert_result_value(self, value, dialect):
         return value
     def adapt(self, cls):
         return cls()
 
-AbstractType.impl = TypeEngine
 class TypeDecorator(AbstractType):
     def __init__(self, *args, **kwargs):
+        if not hasattr(self.__class__, 'impl'):
+            raise exceptions.AssertionError("TypeDecorator implementations require a class-level variable 'impl' which refers to the class of type being decorated")
         self.impl = self.__class__.impl(*args, **kwargs)
     def engine_impl(self, engine):
+        return self.dialect_impl(engine.dialect)
+    def dialect_impl(self, dialect):
         try:
-            return self.impl_dict[engine]
+            return self.impl_dict[dialect]
         except:
-            typedesc = engine.type_descriptor(self.impl)
+            typedesc = dialect.type_descriptor(self.impl)
             tt = self.copy()
             if not isinstance(tt, self.__class__):
                 raise exceptions.AssertionError("Type object %s does not properly implement the copy() method, it must return an object of type %s" % (self, self.__class__))
             tt.impl = typedesc
-            self.impl_dict[engine] = tt
+            self.impl_dict[dialect] = tt
             return tt
     def get_col_spec(self):
         return self.impl.get_col_spec()
-    def convert_bind_param(self, value, engine):
-        return self.impl.convert_bind_param(value, engine)
-    def convert_result_value(self, value, engine):
-        return self.impl.convert_result_value(value, engine)
+    def convert_bind_param(self, value, dialect):
+        return self.impl.convert_bind_param(value, dialect)
+    def convert_result_value(self, value, dialect):
+        return self.impl.convert_result_value(value, dialect)
     def copy(self):
         instance = self.__class__.__new__(self.__class__)
         instance.__dict__.update(self.__dict__)
@@ -95,9 +110,9 @@ def adapt_type(typeobj, colspecs):
 class NullTypeEngine(TypeEngine):
     def get_col_spec(self):
         raise NotImplementedError()
-    def convert_bind_param(self, value, engine):
+    def convert_bind_param(self, value, dialect):
         return value
-    def convert_result_value(self, value, engine):
+    def convert_result_value(self, value, dialect):
         return value
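TypeDecorator now insists on a class-level impl and caches one adapted copy per dialect.  A sketch of a custom decorated type written against those rules (the CSVList type is invented for illustration):

    from sqlalchemy.types import TypeDecorator, String

    class CSVList(TypeDecorator):
        # required: 'impl' names the TypeEngine class being decorated
        impl = String
        def convert_bind_param(self, value, dialect):
            # store a list of strings as comma-separated text
            return self.impl.convert_bind_param(','.join(value), dialect)
        def convert_result_value(self, value, dialect):
            return self.impl.convert_result_value(value, dialect).split(',')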
@@ -111,27 +126,27 @@ class String(TypeEngine):
         self.length = length
     def adapt(self, impltype):
         return impltype(length=self.length)
-    def convert_bind_param(self, value, engine):
-        if not engine.convert_unicode or value is None or not isinstance(value, unicode):
+    def convert_bind_param(self, value, dialect):
+        if not dialect.convert_unicode or value is None or not isinstance(value, unicode):
             return value
         else:
-            return value.encode(engine.encoding)
-    def convert_result_value(self, value, engine):
-        if not engine.convert_unicode or value is None or isinstance(value, unicode):
+            return value.encode(dialect.encoding)
+    def convert_result_value(self, value, dialect):
+        if not dialect.convert_unicode or value is None or isinstance(value, unicode):
             return value
         else:
-            return value.decode(engine.encoding)
+            return value.decode(dialect.encoding)
 
 class Unicode(TypeDecorator):
     impl = String
-    def convert_bind_param(self, value, engine):
+    def convert_bind_param(self, value, dialect):
         if value is not None and isinstance(value, unicode):
-            return value.encode(engine.encoding)
+            return value.encode(dialect.encoding)
         else:
             return value
-    def convert_result_value(self, value, engine):
+    def convert_result_value(self, value, dialect):
         if value is not None and not isinstance(value, unicode):
-            return value.decode(engine.encoding)
+            return value.decode(dialect.encoding)
         else:
             return value
 
@@ -172,11 +187,14 @@ class Time(TypeEngine):
 class Binary(TypeEngine):
     def __init__(self, length=None):
         self.length = length
-    def convert_bind_param(self, value, engine):
-        return engine.dbapi().Binary(value)
-    def convert_result_value(self, value, engine):
+    def convert_bind_param(self, value, dialect):
+        if value is not None:
+            return dialect.dbapi().Binary(value)
+        else:
+            return None
+    def convert_result_value(self, value, dialect):
         return value
-    def adap(self, impltype):
+    def adapt(self, impltype):
         return impltype(length=self.length)
 
 class PickleType(TypeDecorator):
@@ -185,15 +203,15 @@ class PickleType(TypeDecorator):
         """allows the pickle protocol to be specified"""
         self.protocol = protocol
         super(PickleType, self).__init__()
-    def convert_result_value(self, value, engine):
+    def convert_result_value(self, value, dialect):
         if value is None:
             return None
-        buf = self.impl.convert_result_value(value, engine)
+        buf = self.impl.convert_result_value(value, dialect)
         return pickle.loads(str(buf))
-    def convert_bind_param(self, value, engine):
+    def convert_bind_param(self, value, dialect):
         if value is None:
             return None
-        return self.impl.convert_bind_param(pickle.dumps(value, self.protocol), engine)
+        return self.impl.convert_bind_param(pickle.dumps(value, self.protocol), dialect)
 
 class Boolean(TypeEngine):
     pass
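PickleType above now passes None through in both directions and honors the chosen pickle protocol.  A usage sketch (table and metadata are hypothetical, reusing the BoundMetaData example from the schema.py section):

    prefs = Table('prefs', meta,
        Column('id', Integer, primary_key=True),
        Column('data', PickleType()))
    prefs.insert().execute(id=1, data={'theme': 'dark'})   # the dict is pickled on the way in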
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py
index 23e828ef3..fec58e0bf 100644
--- a/lib/sqlalchemy/util.py
+++ b/lib/sqlalchemy/util.py
@@ -255,46 +255,57 @@ class HistoryArraySet(UserList.UserList):
         """sets the data for this HistoryArraySet to be that of the given data.
         duplicates in the incoming list will be removed."""
         # first mark everything current as "deleted"
-        for i in self.data:
-            self.records[i] = False
+        for item in self.data:
+            self.records[item] = False
+            self.do_value_deleted(item)
         # switch array
         self.data = data
         # TODO: fix this up, remove items from array while iterating
         for i in range(0, len(self.data)):
-            if not self._setrecord(self.data[i]):
-                del self.data[i]
-                i -= 1
+            if not self.__setrecord(self.data[i], False):
+                del self.data[i]
+                i -= 1
+        for item in self.data:
+            self.do_value_appended(item)
     def history_contains(self, obj):
         """returns true if the given object exists within the history
         for this HistoryArrayList."""
         return self.records.has_key(obj)
     def __hash__(self):
         return id(self)
-    def _setrecord(self, item):
-        if self.readonly:
-            raise InvalidRequestError("This list is read only")
+    def do_value_appended(self, value):
+        pass
+    def do_value_deleted(self, value):
+        pass
+    def __setrecord(self, item, dochanged=True):
         try:
             val = self.records[item]
             if val is True or val is None:
                 return False
             else:
                 self.records[item] = None
+                if dochanged:
+                    self.do_value_appended(item)
                 return True
         except KeyError:
             self.records[item] = True
+            if dochanged:
+                self.do_value_appended(item)
             return True
-    def _delrecord(self, item):
-        if self.readonly:
-            raise InvalidRequestError("This list is read only")
+    def __delrecord(self, item, dochanged=True):
         try:
             val = self.records[item]
             if val is None:
                 self.records[item] = False
+                if dochanged:
+                    self.do_value_deleted(item)
                 return True
             elif val is True:
                 del self.records[item]
+                if dochanged:
+                    self.do_value_deleted(item)
                 return True
             return False
         except KeyError:
@@ -350,12 +361,13 @@ class HistoryArraySet(UserList.UserList):
     def has_item(self, item):
         return self.records.has_key(item) and self.records[item] is not False
     def __setitem__(self, i, item):
-        if self._setrecord(item):
+        if self.__setrecord(item):
             self.data[i] = item
     def __delitem__(self, i):
-        self._delrecord(self.data[i])
+        self.__delrecord(self.data[i])
         del self.data[i]
     def __setslice__(self, i, j, other):
+        print "HAS SETSLICE"
         i = max(i, 0); j = max(j, 0)
         if isinstance(other, UserList.UserList):
             l = other.data
@@ -363,25 +375,26 @@ class HistoryArraySet(UserList.UserList):
             l = other
         else:
             l = list(other)
-        g = [a for a in l if self._setrecord(a)]
+        [self.__delrecord(x) for x in self.data[i:]]
+        g = [a for a in l if self.__setrecord(a)]
         self.data[i:] = g
     def __delslice__(self, i, j):
         i = max(i, 0); j = max(j, 0)
         for a in self.data[i:j]:
-            self._delrecord(a)
+            self.__delrecord(a)
         del self.data[i:j]
     def append(self, item):
-        if self._setrecord(item):
+        if self.__setrecord(item):
             self.data.append(item)
     def insert(self, i, item):
-        if self._setrecord(item):
+        if self.__setrecord(item):
             self.data.insert(i, item)
     def pop(self, i=-1):
         item = self.data[i]
-        if self._delrecord(item):
+        if self.__delrecord(item):
             return self.data.pop(i)
     def remove(self, item):
-        if self._delrecord(item):
+        if self.__delrecord(item):
             self.data.remove(item)
     def extend(self, item_list):
         for item in item_list:
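The util.py rewrite privatizes the record-keeping methods and adds the do_value_appended()/do_value_deleted() hooks, fired as items actually change state.  A sketch of a subclass observing changes:

    import sqlalchemy.util as util

    class WatchedList(util.HistoryArraySet):
        # hooks are no-ops in the base class; override them to observe mutations
        def do_value_appended(self, value):
            print "appended:", value
        def do_value_deleted(self, value):
            print "deleted:", value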
