diff options
Diffstat (limited to 'lib/sqlalchemy/sql.py')
| -rw-r--r-- | lib/sqlalchemy/sql.py | 64 |
1 files changed, 47 insertions, 17 deletions
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 6924a60ce..74f085cb1 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -1075,15 +1075,24 @@ class ColumnCollection(util.OrderedProperties): super(ColumnCollection, self).__init__() [self.add(c) for c in cols] + def __str__(self): + return repr([str(c) for c in self]) + def add(self, column): """Add a column to this collection. The key attribute of the column will be used as the hash key for this dictionary. """ - self[column.key] = column - + + def remove(self, column): + del self[column.key] + + def extend(self, iter): + for c in iter: + self.add(c) + def __eq__(self, other): l = [] for c in other: @@ -1243,6 +1252,16 @@ class FromClause(Selectable): self._primary_key = ColumnCollection() self._foreign_keys = util.Set() self._orig_cols = {} + for co in self._adjusted_exportable_columns(): + cp = self._proxy_column(co) + for ci in cp.orig_set: + self._orig_cols[ci] = cp + if self.oid_column is not None: + for ci in self.oid_column.orig_set: + self._orig_cols[ci] = self.oid_column + + def _adjusted_exportable_columns(self): + """return the list of ColumnElements represented within this FromClause's _exportable_columns""" export = self._exportable_columns() for column in export: try: @@ -1250,13 +1269,8 @@ class FromClause(Selectable): except AttributeError: continue for co in s.columns: - cp = self._proxy_column(co) - for ci in cp.orig_set: - self._orig_cols[ci] = cp - if self.oid_column is not None: - for ci in self.oid_column.orig_set: - self._orig_cols[ci] = self.oid_column - + yield co + def _exportable_columns(self): return [] @@ -1661,10 +1675,23 @@ class Join(FromClause): else: self.onclause = onclause self.isouter = isouter - + self.__folded_equivalents = None + self._init_primary_key() + name = property(lambda s: "Join object on " + s.left.name + " " + s.right.name) encodedname = property(lambda s: s.name.encode('ascii', 'backslashreplace')) - + + def _init_primary_key(self): + pkcol = util.Set() + for col in self._adjusted_exportable_columns(): + if col.primary_key: + pkcol.add(col) + for col in list(pkcol): + for f in col.foreign_keys: + if f.column in pkcol: + pkcol.remove(col) + self.primary_key.extend(pkcol) + def _locate_oid_column(self): return self.left.oid_column @@ -1673,8 +1700,6 @@ class Join(FromClause): def _proxy_column(self, column): self._columns[column._label] = column - if column.primary_key: - self._primary_key.add(column) for f in column.foreign_keys: self._foreign_keys.add(f) return column @@ -1706,6 +1731,8 @@ class Join(FromClause): return True def _get_folded_equivalents(self, equivs=None): + if self.__folded_equivalents is not None: + return self.__folded_equivalents if equivs is None: equivs = util.Set() class LocateEquivs(NoColumnVisitor): @@ -1731,7 +1758,8 @@ class Join(FromClause): used.add(c.name) else: collist.append(c) - return collist + self.__folded_equivalents = collist + return self.__folded_equivalents def select(self, whereclause = None, fold_equivalents=False, **kwargs): """Create a ``Select`` from this ``Join``. @@ -1740,9 +1768,11 @@ class Join(FromClause): the WHERE criterion that will be sent to the ``select()`` function fold_equivalents - based on the join criterion of this ``Join``, do not include equivalent - columns in the column list of the resulting select. this will recursively - apply to any joins directly nested by this one as well. + based on the join criterion of this ``Join``, do not include repeat + column names in the column list of the resulting select, for columns that + are calculated to be "equivalent" based on the join criterion of this + ``Join``. this will recursively apply to any joins directly nested by + this one as well. \**kwargs all other kwargs are sent to the underlying ``select()`` function |
