diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/ansisql.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/mapping/mapper.py | 65 | ||||
| -rw-r--r-- | lib/sqlalchemy/mapping/properties.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql.py | 28 |
4 files changed, 71 insertions, 33 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 8df5e5352..abbc06751 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -262,6 +262,13 @@ class ANSICompiler(sql.Compiled): l = co.label(co._label) l.accept_visitor(self) inner_columns[co._label] = l + elif select.issubquery and isinstance(co, Column): + # SQLite doesnt like selecting from a subquery where the column + # names look like table.colname, so add a label synonomous with + # the column name + l = co.label(co.key) + l.accept_visitor(self) + inner_columns[self.get_str(l.obj)] = l else: co.accept_visitor(self) inner_columns[self.get_str(co)] = co diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/mapping/mapper.py index 41616dceb..4c63f5c0c 100644 --- a/lib/sqlalchemy/mapping/mapper.py +++ b/lib/sqlalchemy/mapping/mapper.py @@ -49,7 +49,7 @@ class Mapper(object): 'primarytable':primarytable, 'properties':properties or {}, 'primary_key':primary_key, - 'is_primary':False, + 'is_primary':None, 'inherits':inherits, 'inherit_condition':inherit_condition, 'extension':extension, @@ -72,8 +72,13 @@ class Mapper(object): primarytable = inherits.primarytable # inherit_condition is optional since the join can figure it out table = sql.join(table, inherits.table, inherit_condition) - - self.table = table + + if isinstance(table, sql.Select): + # some db's, noteably postgres, dont want to select from a select + # without an alias + self.table = table.alias(None) + else: + self.table = table # locate all tables contained within the "table" passed in, which # may be a join or other construct @@ -93,9 +98,10 @@ class Mapper(object): self.pks_by_table = {} if primary_key is not None: for k in primary_key: - self.pks_by_table.setdefault(k.table, []).append(k) + self.pks_by_table.setdefault(k.table, util.HashSet()).append(k) if k.table != self.table: - self.pks_by_table.setdefault(self.table, []).append(k) + # associate pk cols from subtables to the "main" table + self.pks_by_table.setdefault(self.table, util.HashSet()).append(k) else: for t in self.tables + [self.table]: try: @@ -122,10 +128,10 @@ class Mapper(object): # load custom properties if properties is not None: for key, prop in properties.iteritems(): - if isinstance(prop, schema.Column) or isinstance(prop, sql.ColumnElement): + if is_column(prop): self.columns[key] = prop prop = ColumnProperty(prop) - elif isinstance(prop, list) and (isinstance(prop[0], schema.Column) or isinstance(prop[0], sql.ColumnElement)) : + elif isinstance(prop, list) and is_column(prop[0]): self.columns[key] = prop[0] prop = ColumnProperty(*prop) self.props[key] = prop @@ -158,7 +164,11 @@ class Mapper(object): proplist = self.columntoproperty.setdefault(column.original, []) proplist.append(prop) - if not hasattr(self.class_, '_mapper') or self.is_primary or not mapper_registry.has_key(self.class_._mapper) or (inherits is not None and inherits._is_primary_mapper()): + if ( + (not hasattr(self.class_, '_mapper') or not mapper_registry.has_key(self.class_._mapper)) + or self.is_primary + or (inherits is not None and inherits._is_primary_mapper()) + ): objectstore.global_attributes.reset_class_managed(self.class_) self._init_class() @@ -166,13 +176,12 @@ class Mapper(object): for key, prop in inherits.props.iteritems(): if not self.props.has_key(key): self.props[key] = prop._copy() - engines = property(lambda s: [t.engine for t in s.tables]) def add_property(self, key, prop): self.copyargs['properties'][key] = prop - if (isinstance(prop, schema.Column) or isinstance(prop, sql.ColumnElement)): + if is_column(prop): self.columns[key] = prop prop = ColumnProperty(prop) self.props[key] = prop @@ -194,7 +203,7 @@ class Mapper(object): return self.hashkey def _is_primary_mapper(self): - return getattr(self.class_, '_mapper') == self.hashkey + return getattr(self.class_, '_mapper', None) == self.hashkey def _init_class(self): """sets up our classes' overridden __init__ method, this mappers hash key as its @@ -447,6 +456,9 @@ class Mapper(object): list.""" for table in self.tables: + if not self._has_pks(table): + continue + # loop thru tables in the outer loop, objects on the inner loop. # this is important for an object represented across two tables # so that it gets its primary key columns populated for the benefit of the @@ -457,9 +469,8 @@ class Mapper(object): # we have our own idea of the primary key columns # for this table, in the case that the user # specified custom primary key cols. - pk = {} - for k in self.pks_by_table[table]: - pk[k] = k + # also, if we are missing a primary key for this table, then + # just skip inserting/updating the table for obj in objects: # print "SAVE_OBJ we are " + hash_key(self) + " obj: " + obj.__class__.__name__ + repr(id(obj)) @@ -471,8 +482,7 @@ class Mapper(object): hasdata = False for col in table.columns: - #if col.primary_key: - if pk.has_key(col): + if self.pks_by_table[table].contains(col): if hasattr(obj, "_instance_key"): params[col.table.name + "_" + col.key] = self._getattrbycolumn(obj, col) else: @@ -536,6 +546,8 @@ class Mapper(object): """called by a UnitOfWork object to delete objects, which involves a DELETE statement for each table used by this mapper, for each object in the list.""" for table in self.tables: + if not self._has_pks(table): + continue delete = [] for obj in objects: params = {} @@ -556,6 +568,16 @@ class Mapper(object): if table.engine.supports_sane_rowcount() and c.rowcount != len(delete): raise "ConcurrencyError - updated rowcount %d does not match number of objects updated %d" % (c.cursor.rowcount, len(delete)) + def _has_pks(self, table): + try: + for k in self.pks_by_table[table]: + if not self.columntoproperty.has_key(k.original): + return False + else: + return True + except KeyError: + return False + def register_dependencies(self, *args, **kwargs): """called by an instance of objectstore.UOWTransaction to register which mappers are dependent on which, as well as DependencyProcessor @@ -581,12 +603,10 @@ class Mapper(object): if not no_sort: if self.order_by: order_by = self.order_by -# elif self.table.rowid_column is not None: - # order_by = self.table.rowid_column - # else: - # order_by = None - else: + elif self.table.rowid_column is not None: order_by = self.table.rowid_column + else: + order_by = None else: order_by = None @@ -779,6 +799,9 @@ def hash_key(obj): else: return repr(obj) +def is_column(col): + return isinstance(col, schema.Column) or isinstance(col, sql.ColumnElement) + def mapper_hash_key(class_, table, primarytable = None, properties = None, **kwargs): if properties is None: properties = {} diff --git a/lib/sqlalchemy/mapping/properties.py b/lib/sqlalchemy/mapping/properties.py index ba7312c12..e53ee644c 100644 --- a/lib/sqlalchemy/mapping/properties.py +++ b/lib/sqlalchemy/mapping/properties.py @@ -24,7 +24,6 @@ import sqlalchemy.util as util import sqlalchemy.attributes as attributes import mapper import objectstore -import random class ColumnProperty(MapperProperty): """describes an object attribute that corresponds to a table column.""" @@ -856,8 +855,7 @@ class Aliasizer(sql.ClauseVisitor): try: return self.aliases[table] except: - aliasname = table.name + "_" + hex(random.randint(0, 65535))[2:] - return self.aliases.setdefault(table, sql.alias(table, aliasname)) + return self.aliases.setdefault(table, sql.alias(table)) def visit_compound(self, compound): for i in range(0, len(compound.clauses)): diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index b0e86259a..7db60ffb9 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -20,7 +20,7 @@ import sqlalchemy.schema as schema import sqlalchemy.util as util import sqlalchemy.types as types -import string, re +import string, re, random __all__ = ['text', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'union', 'union_all', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists'] @@ -497,7 +497,7 @@ class FromClause(Selectable): return Join(self, right, *args, **kwargs) def outerjoin(self, right, *args, **kwargs): return Join(self, right, isouter = True, *args, **kwargs) - def alias(self, name): + def alias(self, name=None): return Alias(self, name) @@ -751,11 +751,17 @@ class Alias(FromClause): self._columns = util.OrderedProperties() self.foreign_keys = [] if alias is None: - alias = id(self) + n = getattr(selectable, 'name') + if n is None: + n = 'anon' + alias = n + "_" + hex(random.randint(0, 65535))[2:] self.name = alias self.id = self.name self.count = 0 - self.rowid_column = self.selectable.rowid_column._make_proxy(self) + if self.selectable.rowid_column is not None: + self.rowid_column = self.selectable.rowid_column._make_proxy(self) + else: + self.rowid_column = None for co in selectable.columns: co._make_proxy(self) @@ -930,7 +936,7 @@ class TableImpl(FromClause): return Join(self.table, right, *args, **kwargs) def outerjoin(self, right, *args, **kwargs): return Join(self.table, right, isouter = True, *args, **kwargs) - def alias(self, name): + def alias(self, name=None): return Alias(self.table, name) def select(self, whereclause = None, **params): return select([self.table], whereclause, **params) @@ -1082,16 +1088,20 @@ class Select(SelectBaseMixin, FromClause): for f in column._get_from_objects(): f.accept_visitor(self._correlator) - if self.rowid_column is None and hasattr(f, 'rowid_column') and f.rowid_column is not None: - self.rowid_column = f.rowid_column._make_proxy(self) column._process_from_dict(self._froms, False) if column.is_selectable(): + # if its a column unit, add it to our exported + # list of columns. this is where "columns" + # attribute of the select object gets populated. + # notice we are overriding the names of the column + # with either its label or its key, since one or the other + # is used when selecting from a select statement (i.e. a subquery) for co in column.columns: if self.use_labels: - co._make_proxy(self, name = co._label) + co._make_proxy(self, name=co._label) else: - co._make_proxy(self) + co._make_proxy(self, name=co.key) def _get_col_by_original(self, column): if self.use_labels: |
