summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql.py')
-rw-r--r--lib/sqlalchemy/sql.py64
1 files changed, 47 insertions, 17 deletions
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index 6924a60ce..74f085cb1 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -1075,15 +1075,24 @@ class ColumnCollection(util.OrderedProperties):
super(ColumnCollection, self).__init__()
[self.add(c) for c in cols]
+ def __str__(self):
+ return repr([str(c) for c in self])
+
def add(self, column):
"""Add a column to this collection.
The key attribute of the column will be used as the hash key
for this dictionary.
"""
-
self[column.key] = column
-
+
+ def remove(self, column):
+ del self[column.key]
+
+ def extend(self, iter):
+ for c in iter:
+ self.add(c)
+
def __eq__(self, other):
l = []
for c in other:
@@ -1243,6 +1252,16 @@ class FromClause(Selectable):
self._primary_key = ColumnCollection()
self._foreign_keys = util.Set()
self._orig_cols = {}
+ for co in self._adjusted_exportable_columns():
+ cp = self._proxy_column(co)
+ for ci in cp.orig_set:
+ self._orig_cols[ci] = cp
+ if self.oid_column is not None:
+ for ci in self.oid_column.orig_set:
+ self._orig_cols[ci] = self.oid_column
+
+ def _adjusted_exportable_columns(self):
+ """return the list of ColumnElements represented within this FromClause's _exportable_columns"""
export = self._exportable_columns()
for column in export:
try:
@@ -1250,13 +1269,8 @@ class FromClause(Selectable):
except AttributeError:
continue
for co in s.columns:
- cp = self._proxy_column(co)
- for ci in cp.orig_set:
- self._orig_cols[ci] = cp
- if self.oid_column is not None:
- for ci in self.oid_column.orig_set:
- self._orig_cols[ci] = self.oid_column
-
+ yield co
+
def _exportable_columns(self):
return []
@@ -1661,10 +1675,23 @@ class Join(FromClause):
else:
self.onclause = onclause
self.isouter = isouter
-
+ self.__folded_equivalents = None
+ self._init_primary_key()
+
name = property(lambda s: "Join object on " + s.left.name + " " + s.right.name)
encodedname = property(lambda s: s.name.encode('ascii', 'backslashreplace'))
-
+
+ def _init_primary_key(self):
+ pkcol = util.Set()
+ for col in self._adjusted_exportable_columns():
+ if col.primary_key:
+ pkcol.add(col)
+ for col in list(pkcol):
+ for f in col.foreign_keys:
+ if f.column in pkcol:
+ pkcol.remove(col)
+ self.primary_key.extend(pkcol)
+
def _locate_oid_column(self):
return self.left.oid_column
@@ -1673,8 +1700,6 @@ class Join(FromClause):
def _proxy_column(self, column):
self._columns[column._label] = column
- if column.primary_key:
- self._primary_key.add(column)
for f in column.foreign_keys:
self._foreign_keys.add(f)
return column
@@ -1706,6 +1731,8 @@ class Join(FromClause):
return True
def _get_folded_equivalents(self, equivs=None):
+ if self.__folded_equivalents is not None:
+ return self.__folded_equivalents
if equivs is None:
equivs = util.Set()
class LocateEquivs(NoColumnVisitor):
@@ -1731,7 +1758,8 @@ class Join(FromClause):
used.add(c.name)
else:
collist.append(c)
- return collist
+ self.__folded_equivalents = collist
+ return self.__folded_equivalents
def select(self, whereclause = None, fold_equivalents=False, **kwargs):
"""Create a ``Select`` from this ``Join``.
@@ -1740,9 +1768,11 @@ class Join(FromClause):
the WHERE criterion that will be sent to the ``select()`` function
fold_equivalents
- based on the join criterion of this ``Join``, do not include equivalent
- columns in the column list of the resulting select. this will recursively
- apply to any joins directly nested by this one as well.
+ based on the join criterion of this ``Join``, do not include repeat
+ column names in the column list of the resulting select, for columns that
+ are calculated to be "equivalent" based on the join criterion of this
+ ``Join``. this will recursively apply to any joins directly nested by
+ this one as well.
\**kwargs
all other kwargs are sent to the underlying ``select()`` function