summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2012-03-03 13:00:44 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2012-03-03 13:00:44 -0500
commit1607b74f8527905ecdc6133b4b4166a9ed675e09 (patch)
treecd752b16ab90c4864a071689c57f3ff946f8b241 /lib/sqlalchemy
parent4d43079e34a66c3718127266bc5eaa3041c69447 (diff)
downloadsqlalchemy-1607b74f8527905ecdc6133b4b4166a9ed675e09.tar.gz
- [feature] Added cte() method to Query,
invokes common table expression support from the Core (see below). [ticket:1859] - [feature] Added support for SQL standard common table expressions (CTE), allowing SELECT objects as the CTE source (DML not yet supported). This is invoked via the cte() method on any select() construct. [ticket:1859]
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py7
-rw-r--r--lib/sqlalchemy/orm/query.py56
-rw-r--r--lib/sqlalchemy/sql/compiler.py58
-rw-r--r--lib/sqlalchemy/sql/expression.py160
4 files changed, 281 insertions, 0 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index f7c94aabc..b73235875 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -949,6 +949,13 @@ class MSSQLCompiler(compiler.SQLCompiler):
]
return 'OUTPUT ' + ', '.join(columns)
+ def get_cte_preamble(self, recursive):
+ # SQL Server finds it too inconvenient to accept
+ # an entirely optional, SQL standard specified,
+ # "RECURSIVE" word with their "WITH",
+ # so here we go
+ return "WITH"
+
def label_select_column(self, select, column, asfrom):
if isinstance(column, expression.Function):
return column.label(None)
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index cafce5e3c..5b7f7c9af 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -450,6 +450,62 @@ class Query(object):
"""
return self.enable_eagerloads(False).statement.alias(name=name)
+ def cte(self, name=None, recursive=False):
+ """Return the full SELECT statement represented by this :class:`.Query`
+ represented as a common table expression (CTE).
+
+ The :meth:`.Query.cte` method is new in 0.7.6.
+
+ Parameters and usage are the same as those of the
+ :meth:`._SelectBase.cte` method; see that method for
+ further details.
+
+ Here is the `Postgresql WITH
+ RECURSIVE example <http://www.postgresql.org/docs/8.4/static/queries-with.html>`_.
+ Note that, in this example, the ``included_parts`` cte and the ``incl_alias`` alias
+ of it are Core selectables, which
+ means the columns are accessed via the ``.c.`` attribute. The ``parts_alias``
+ object is an :func:`.orm.aliased` instance of the ``Part`` entity, so column-mapped
+ attributes are available directly::
+
+ from sqlalchemy.orm import aliased
+
+ class Part(Base):
+ __tablename__ = 'part'
+ part = Column(String)
+ sub_part = Column(String)
+ quantity = Column(Integer)
+
+ included_parts = session.query(
+ Part.sub_part,
+ Part.part,
+ Part.quantity).\\
+ filter(Part.part=="our part").\\
+ cte(name="included_parts", recursive=True)
+
+ incl_alias = aliased(included_parts, name="pr")
+ parts_alias = aliased(Part, name="p")
+ included_parts = included_parts.union(
+ session.query(
+ parts_alias.part,
+ parts_alias.sub_part,
+ parts_alias.quantity).\\
+ filter(parts_alias.part==incl_alias.c.sub_part)
+ )
+
+ q = session.query(
+ included_parts.c.sub_part,
+ func.sum(included_parts.c.quantity).label('total_quantity')
+ ).\
+ group_by(included_parts.c.sub_part)
+
+ See also:
+
+ :meth:`._SelectBase.cte`
+
+ """
+ return self.enable_eagerloads(False).statement.cte(name=name, recursive=recursive)
+
def label(self, name):
"""Return the full SELECT statement represented by this :class:`.Query`, converted
to a scalar subquery with a label of the given name.
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index b955c5608..e8f86634d 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -252,6 +252,10 @@ class SQLCompiler(engine.Compiled):
# column targeting
self.result_map = {}
+ # collect CTEs to tack on top of a SELECT
+ self.ctes = util.OrderedDict()
+ self.ctes_recursive = False
+
# true if the paramstyle is positional
self.positional = dialect.positional
if self.positional:
@@ -749,6 +753,45 @@ class SQLCompiler(engine.Compiled):
else:
return self.bindtemplate % {'name':name}
+ def visit_cte(self, cte, asfrom=False, ashint=False,
+ fromhints=None, **kwargs):
+ if isinstance(cte.name, sql._truncated_label):
+ cte_name = self._truncated_identifier("alias", cte.name)
+ else:
+ cte_name = cte.name
+ if cte.cte_alias:
+ if isinstance(cte.cte_alias, sql._truncated_label):
+ cte_alias = self._truncated_identifier("alias", cte.cte_alias)
+ else:
+ cte_alias = cte.cte_alias
+ if not cte.cte_alias and cte not in self.ctes:
+ if cte.recursive:
+ self.ctes_recursive = True
+ text = self.preparer.format_alias(cte, cte_name)
+ if cte.recursive:
+ if isinstance(cte.original, sql.Select):
+ col_source = cte.original
+ elif isinstance(cte.original, sql.CompoundSelect):
+ col_source = cte.original.selects[0]
+ else:
+ assert False
+ recur_cols = [c.key for c in util.unique_list(col_source.inner_columns)
+ if c is not None]
+
+ text += "(%s)" % (", ".join(recur_cols))
+ text += " AS \n" + \
+ cte.original._compiler_dispatch(
+ self, asfrom=True, **kwargs
+ )
+ self.ctes[cte] = text
+ if asfrom:
+ if cte.cte_alias:
+ text = self.preparer.format_alias(cte, cte_alias)
+ text += " AS " + cte_name
+ else:
+ return self.preparer.format_alias(cte, cte_name)
+ return text
+
def visit_alias(self, alias, asfrom=False, ashint=False,
fromhints=None, **kwargs):
if asfrom or ashint:
@@ -909,6 +952,15 @@ class SQLCompiler(engine.Compiled):
if select.for_update:
text += self.for_update_clause(select)
+ if self.ctes and \
+ compound_index==1 and not entry:
+ cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
+ cte_text += ", \n".join(
+ [txt for txt in self.ctes.values()]
+ )
+ cte_text += "\n "
+ text = cte_text + text
+
self.stack.pop(-1)
if asfrom and parens:
@@ -916,6 +968,12 @@ class SQLCompiler(engine.Compiled):
else:
return text
+ def get_cte_preamble(self, recursive):
+ if recursive:
+ return "WITH RECURSIVE"
+ else:
+ return "WITH"
+
def get_select_precolumns(self, select):
"""Called when building a ``SELECT`` statement, position is just
before column list.
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 4b61e6dc3..22fe6c420 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -3719,6 +3719,47 @@ class Alias(FromClause):
def bind(self):
return self.element.bind
+class CTE(Alias):
+ """Represent a Common Table Expression.
+
+ The :class:`.CTE` object is obtained using the
+ :meth:`._SelectBase.cte` method from any selectable.
+ See that method for complete examples.
+
+ New in 0.7.6.
+
+ """
+ __visit_name__ = 'cte'
+ def __init__(self, selectable,
+ name=None,
+ recursive=False,
+ cte_alias=False):
+ self.recursive = recursive
+ self.cte_alias = cte_alias
+ super(CTE, self).__init__(selectable, name=name)
+
+ def alias(self, name=None):
+ return CTE(
+ self.original,
+ name=name,
+ recursive=self.recursive,
+ cte_alias = self.name
+ )
+
+ def union(self, other):
+ return CTE(
+ self.original.union(other),
+ name=self.name,
+ recursive=self.recursive
+ )
+
+ def union_all(self, other):
+ return CTE(
+ self.original.union_all(other),
+ name=self.name,
+ recursive=self.recursive
+ )
+
class _Grouping(ColumnElement):
"""Represent a grouping within a column expression"""
@@ -4289,6 +4330,125 @@ class _SelectBase(Executable, FromClause):
"""
return self.as_scalar().label(name)
+ def cte(self, name=None, recursive=False):
+ """Return a new :class:`.CTE`, or Common Table Expression instance.
+
+ Common table expressions are a SQL standard whereby SELECT
+ statements can draw upon secondary statements specified along
+ with the primary statement, using a clause called "WITH".
+ Special semantics regarding UNION can also be employed to
+ allow "recursive" queries, where a SELECT statement can draw
+ upon the set of rows that have previously been selected.
+
+ SQLAlchemy detects :class:`.CTE` objects, which are treated
+ similarly to :class:`.Alias` objects, as special elements
+ to be delivered to the FROM clause of the statement as well
+ as to a WITH clause at the top of the statement.
+
+ The :meth:`._SelectBase.cte` method is new in 0.7.6.
+
+ :param name: name given to the common table expression. Like
+ :meth:`._FromClause.alias`, the name can be left as ``None``
+ in which case an anonymous symbol will be used at query
+ compile time.
+ :param recursive: if ``True``, will render ``WITH RECURSIVE``.
+ A recursive common table expression is intended to be used in
+ conjunction with UNION or UNION ALL in order to derive rows
+ from those already selected.
+
+ The following examples illustrate two examples from
+ Postgresql's documentation at
+ http://www.postgresql.org/docs/8.4/static/queries-with.html.
+
+ Example 1, non recursive::
+
+ from sqlalchemy import Table, Column, String, Integer, MetaData, \\
+ select, func
+
+ metadata = MetaData()
+
+ orders = Table('orders', metadata,
+ Column('region', String),
+ Column('amount', Integer),
+ Column('product', String),
+ Column('quantity', Integer)
+ )
+
+ regional_sales = select([
+ orders.c.region,
+ func.sum(orders.c.amount).label('total_sales')
+ ]).group_by(orders.c.region).cte("regional_sales")
+
+
+ top_regions = select([regional_sales.c.region]).\\
+ where(
+ regional_sales.c.total_sales >
+ select([
+ func.sum(regional_sales.c.total_sales)/10
+ ])
+ ).cte("top_regions")
+
+ statement = select([
+ orders.c.region,
+ orders.c.product,
+ func.sum(orders.c.quantity).label("product_units"),
+ func.sum(orders.c.amount).label("product_sales")
+ ]).where(orders.c.region.in_(
+ select([top_regions.c.region])
+ )).group_by(orders.c.region, orders.c.product)
+
+ result = conn.execute(statement).fetchall()
+
+ Example 2, WITH RECURSIVE::
+
+ from sqlalchemy import Table, Column, String, Integer, MetaData, \\
+ select, func
+
+ metadata = MetaData()
+
+ parts = Table('parts', metadata,
+ Column('part', String),
+ Column('sub_part', String),
+ Column('quantity', Integer),
+ )
+
+ included_parts = select([
+ parts.c.sub_part,
+ parts.c.part,
+ parts.c.quantity]).\\
+ where(parts.c.part=='our part').\\
+ cte(recursive=True)
+
+
+ incl_alias = included_parts.alias()
+ parts_alias = parts.alias()
+ included_parts = included_parts.union(
+ select([
+ parts_alias.c.part,
+ parts_alias.c.sub_part,
+ parts_alias.c.quantity
+ ]).
+ where(parts_alias.c.part==incl_alias.c.sub_part)
+ )
+
+ statement = select([
+ included_parts.c.sub_part,
+ func.sum(included_parts.c.quantity).label('total_quantity')
+ ]).\
+ select_from(included_parts.join(parts,
+ included_parts.c.part==parts.c.part)).\\
+ group_by(included_parts.c.sub_part)
+
+ result = conn.execute(statement).fetchall()
+
+
+ See also:
+
+ :meth:`.orm.query.Query.cte` - ORM version of :meth:`._SelectBase.cte`.
+
+ """
+ return CTE(self, name=name, recursive=recursive)
+
@_generative
@util.deprecated('0.6',
message=":func:`.autocommit` is deprecated. Use "