summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-11-24 00:55:39 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-11-24 00:55:39 +0000
commit3f93103a5ef9128b7b300c51d41dea43dd843834 (patch)
tree7c21cb4a15c91c9d7ae38425da69c96d0ed26caf /lib/sqlalchemy/sql
parent238dc916fa9fca6c79046dea004d108df685e29e (diff)
downloadsqlalchemy-3f93103a5ef9128b7b300c51d41dea43dd843834.tar.gz
- all kinds of cleanup, tiny-to-slightly-significant speed improvements
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/compiler.py78
-rw-r--r--lib/sqlalchemy/sql/expression.py106
-rw-r--r--lib/sqlalchemy/sql/util.py2
-rw-r--r--lib/sqlalchemy/sql/visitors.py12
4 files changed, 111 insertions, 87 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index dd6f0dddd..c1f3bc2a0 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -245,19 +245,25 @@ class DefaultCompiler(engine.Compiled):
n = self.dialect.oid_column_name(column)
if n is not None:
if column.table is None or not column.table.named_with_column():
- return self.preparer.format_column(column, name=n)
+ return n
else:
- return "%s.%s" % (self.preparer.format_table(column.table, use_schema=False, name=ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)), n)
+ return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + n
elif len(column.table.primary_key) != 0:
pk = list(column.table.primary_key)[0]
pkname = (pk.is_literal and name or self._truncated_identifier("colident", pk.name))
- return self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname, table_name=ANONYMOUS_LABEL.sub(self._process_anon, column.table.name))
+ return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + self.preparer.quote(pk, pkname)
else:
return None
elif column.table is None or not column.table.named_with_column():
- return self.preparer.format_column(column, name=name)
+ if getattr(column, "is_literal", False):
+ return name
+ else:
+ return self.preparer.quote(column, name)
else:
- return self.preparer.format_column_with_table(column, column_name=name, table_name=ANONYMOUS_LABEL.sub(self._process_anon, column.table.name))
+ if getattr(column, "is_literal", False):
+ return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + name
+ else:
+ return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + self.preparer.quote(column, name)
def visit_fromclause(self, fromclause, **kwargs):
@@ -588,7 +594,10 @@ class DefaultCompiler(engine.Compiled):
def visit_table(self, table, asfrom=False, **kwargs):
if asfrom:
- return self.preparer.format_table(table)
+ if getattr(table, "schema", None):
+ return self.preparer.quote(table, table.schema) + "." + self.preparer.quote(table, table.name)
+ else:
+ return self.preparer.quote(table, table.name)
else:
return ""
@@ -606,7 +615,7 @@ class DefaultCompiler(engine.Compiled):
return ("INSERT INTO %s (%s) VALUES (%s)" %
(preparer.format_table(insert_stmt.table),
- ', '.join([preparer.format_column(c[0])
+ ', '.join([preparer.quote(c[0], c[0].name)
for c in colparams]),
', '.join([c[1] for c in colparams])))
@@ -616,7 +625,7 @@ class DefaultCompiler(engine.Compiled):
self.isupdate = True
colparams = self._get_colparams(update_stmt)
- text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams], ', ')
+ text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.quote(c[0], c[0].name), c[1]) for c in colparams], ', ')
if update_stmt._whereclause:
text += " WHERE " + self.process(update_stmt._whereclause)
@@ -831,7 +840,7 @@ class SchemaGenerator(DDLBase):
if constraint.name is not None:
self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint))
self.append("PRIMARY KEY ")
- self.append("(%s)" % ', '.join([self.preparer.format_column(c) for c in constraint]))
+ self.append("(%s)" % ', '.join([self.preparer.quote(c, c.name) for c in constraint]))
def visit_foreign_key_constraint(self, constraint):
if constraint.use_alter and self.dialect.supports_alter:
@@ -849,10 +858,11 @@ class SchemaGenerator(DDLBase):
if constraint.name is not None:
self.append("CONSTRAINT %s " %
preparer.format_constraint(constraint))
+ table = list(constraint.elements)[0].column.table
self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
- ', '.join([preparer.format_column(f.parent) for f in constraint.elements]),
- preparer.format_table(list(constraint.elements)[0].column.table),
- ', '.join([preparer.format_column(f.column) for f in constraint.elements])
+ ', '.join([preparer.quote(f.parent, f.parent.name) for f in constraint.elements]),
+ preparer.format_table(table),
+ ', '.join([preparer.quote(f.column, f.column.name) for f in constraint.elements])
))
if constraint.ondelete is not None:
self.append(" ON DELETE %s" % constraint.ondelete)
@@ -864,7 +874,7 @@ class SchemaGenerator(DDLBase):
if constraint.name is not None:
self.append("CONSTRAINT %s " %
self.preparer.format_constraint(constraint))
- self.append(" UNIQUE (%s)" % (', '.join([self.preparer.format_column(c) for c in constraint])))
+ self.append(" UNIQUE (%s)" % (', '.join([self.preparer.quote(c, c.name) for c in constraint])))
def visit_column(self, column):
pass
@@ -877,7 +887,7 @@ class SchemaGenerator(DDLBase):
self.append("INDEX %s ON %s (%s)" \
% (preparer.format_index(index),
preparer.format_table(index.table),
- string.join([preparer.format_column(c) for c in index.columns], ', ')))
+ string.join([preparer.quote(c, c.name) for c in index.columns], ', ')))
self.execute()
class SchemaDropper(DDLBase):
@@ -978,12 +988,12 @@ class IdentifierPreparer(object):
or not self.legal_characters.match(unicode(value))
or (lc_value != value))
- def __generic_obj_format(self, obj, ident):
+ def quote(self, obj, ident):
if getattr(obj, 'quote', False):
return self.quote_identifier(ident)
- try:
+ if ident in self.__strings:
return self.__strings[ident]
- except KeyError:
+ else:
if self._requires_quotes(ident):
self.__strings[ident] = self.quote_identifier(ident)
else:
@@ -994,45 +1004,49 @@ class IdentifierPreparer(object):
return object.quote or self._requires_quotes(object.name)
def format_sequence(self, sequence, use_schema=True):
- name = self.__generic_obj_format(sequence, sequence.name)
+ name = self.quote(sequence, sequence.name)
if use_schema and sequence.schema is not None:
- name = self.__generic_obj_format(sequence, sequence.schema) + "." + name
+ name = self.quote(sequence, sequence.schema) + "." + name
return name
def format_label(self, label, name=None):
- return self.__generic_obj_format(label, name or label.name)
+ return self.quote(label, name or label.name)
def format_alias(self, alias, name=None):
- return self.__generic_obj_format(alias, name or alias.name)
+ return self.quote(alias, name or alias.name)
def format_savepoint(self, savepoint, name=None):
- return self.__generic_obj_format(savepoint, name or savepoint.ident)
+ return self.quote(savepoint, name or savepoint.ident)
def format_constraint(self, constraint):
- return self.__generic_obj_format(constraint, constraint.name)
+ return self.quote(constraint, constraint.name)
def format_index(self, index):
- return self.__generic_obj_format(index, index.name)
+ return self.quote(index, index.name)
def format_table(self, table, use_schema=True, name=None):
"""Prepare a quoted table and schema name."""
if name is None:
name = table.name
- result = self.__generic_obj_format(table, name)
+ result = self.quote(table, name)
if use_schema and getattr(table, "schema", None):
- result = self.__generic_obj_format(table, table.schema) + "." + result
+ result = self.quote(table, table.schema) + "." + result
return result
def format_column(self, column, use_table=False, name=None, table_name=None):
- """Prepare a quoted column name."""
+ """Prepare a quoted column name.
+
+ deprecated. use preparer.quote(col, column.name) or combine with format_table()
+ """
+
if name is None:
name = column.name
if not getattr(column, 'is_literal', False):
if use_table:
- return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.__generic_obj_format(column, name)
+ return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.quote(column, name)
else:
- return self.__generic_obj_format(column, name)
+ return self.quote(column, name)
else:
# literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted
if use_table:
@@ -1040,12 +1054,6 @@ class IdentifierPreparer(object):
else:
return name
- def format_column_with_table(self, column, column_name=None, table_name=None):
- """Prepare a quoted column name with table name."""
-
- return self.format_column(column, use_table=True, name=column_name, table_name=table_name)
-
-
def format_table_seq(self, table, use_schema=True):
"""Format table name and schema as a tuple."""
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index c7ab34272..b3200a7eb 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -863,6 +863,16 @@ class ClauseElement(object):
raise NotImplementedError(repr(self))
+ def _aggregate_hide_froms(self, **modifiers):
+ """Return a list of ``FROM`` clause elements which this ``ClauseElement`` replaces, taking into account
+ previous ClauseElements which this ClauseElement is a clone of."""
+
+ s = self
+ while s is not None:
+ for h in s._hide_froms(**modifiers):
+ yield h
+ s = getattr(s, '_is_clone_of', None)
+
def _hide_froms(self, **modifiers):
"""Return a list of ``FROM`` clause elements which this ``ClauseElement`` replaces."""
@@ -2203,11 +2213,10 @@ class Join(FromClause):
else:
equivs[x] = util.Set([y])
- class BinaryVisitor(visitors.ClauseVisitor):
- def visit_binary(self, binary):
- if binary.operator == operators.eq and isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column):
- add_equiv(binary.left, binary.right)
- BinaryVisitor().traverse(self.onclause)
+ def visit_binary(binary):
+ if binary.operator == operators.eq and isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column):
+ add_equiv(binary.left, binary.right)
+ visitors.traverse(self.onclause, visit_binary=visit_binary)
for col in pkcol:
for fk in col.foreign_keys:
@@ -2719,8 +2728,8 @@ class _SelectBaseMixin(object):
self._offset = offset
self._bind = bind
- self.append_order_by(*util.to_list(order_by, []))
- self.append_group_by(*util.to_list(group_by, []))
+ self._order_by_clause = ClauseList(*util.to_list(order_by, []))
+ self._group_by_clause = ClauseList(*util.to_list(group_by, []))
def as_scalar(self):
"""return a 'scalar' representation of this selectable, which can be used
@@ -2967,30 +2976,41 @@ class Select(_SelectBaseMixin, FromClause):
# usually called via a generative method, create a copy of each collection
# by default
- self._raw_columns = []
self.__correlate = util.Set()
- self._froms = util.OrderedSet()
- self._whereclause = None
self._having = None
self._prefixes = []
- if columns is not None:
- for c in columns:
- self.append_column(c, _copy_collection=False)
-
- if from_obj is not None:
- for f in from_obj:
- self.append_from(f, _copy_collection=False)
+ if columns:
+ self._raw_columns = [
+ isinstance(c, _ScalarSelect) and c.self_group(against=operators.comma_op) or c
+ for c in
+ [_literal_as_column(c) for c in columns]
+ ]
+ else:
+ self._raw_columns = []
+
+ if from_obj:
+ self._froms = util.Set([
+ _is_literal(f) and _TextFromClause(f) or f
+ for f in from_obj
+ ])
+ else:
+ self._froms = util.Set()
- if whereclause is not None:
- self.append_whereclause(whereclause)
+ if whereclause:
+ self._whereclause = _literal_as_text(whereclause)
+ else:
+ self._whereclause = None
- if having is not None:
- self.append_having(having)
+ if having:
+ self._having = _literal_as_text(having)
+ else:
+ self._having = None
- if prefixes is not None:
- for p in prefixes:
- self.append_prefix(p, _copy_collection=False)
+ if prefixes:
+ self._prefixes = [_literal_as_text(p) for p in prefixes]
+ else:
+ self._prefixes = []
_SelectBaseMixin.__init__(self, **kwargs)
@@ -3003,48 +3023,30 @@ class Select(_SelectBaseMixin, FromClause):
correlating.
"""
- froms = util.OrderedSet()
+ froms = util.Set()
hide_froms = util.Set()
for col in self._raw_columns:
- for f in col._hide_froms():
- hide_froms.add(f)
- while hasattr(f, '_is_clone_of'):
- hide_froms.add(f._is_clone_of)
- f = f._is_clone_of
- for f in col._get_from_objects():
- froms.add(f)
+ hide_froms.update(col._aggregate_hide_froms())
+ froms.update(col._get_from_objects())
if self._whereclause is not None:
- for f in self._whereclause._get_from_objects(is_where=True):
- froms.add(f)
+ froms.update(self._whereclause._get_from_objects(is_where=True))
- for elem in self._froms:
- froms.add(elem)
- for f in elem._get_from_objects():
- froms.add(f)
-
- for elem in froms:
- for f in elem._hide_froms():
- hide_froms.add(f)
- while hasattr(f, '_is_clone_of'):
- hide_froms.add(f._is_clone_of)
- f = f._is_clone_of
+ if self._froms:
+ froms.update(self._froms)
+ for elem in self._froms:
+ hide_froms.update(elem._aggregate_hide_froms())
froms = froms.difference(hide_froms)
if len(froms) > 1:
corr = self.__correlate
if self._should_correlate and existing_froms is not None:
- corr = existing_froms.union(corr)
-
- for f in list(corr):
- while hasattr(f, '_is_clone_of'):
- corr.add(f._is_clone_of)
- f = f._is_clone_of
+ corr.update(existing_froms)
f = froms.difference(corr)
- if len(f) == 0:
+ if not f:
raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate()))
return f
else:
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 70d1940e6..3e2d4ec31 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -3,6 +3,7 @@ from sqlalchemy.sql import expression, visitors
"""Utility functions that build upon SQL and Schema constructs."""
+# TODO: replace with plain list. break out sorting funcs into module-level funcs
class TableCollection(object):
def __init__(self, tables=None):
self.tables = tables or []
@@ -65,6 +66,7 @@ class TableCollection(object):
return sequence
+# TODO: replace with plain module-level func
class TableFinder(TableCollection, visitors.NoColumnVisitor):
"""locate all Tables within a clause."""
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 1a0629a17..150ee9cc7 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -98,3 +98,15 @@ class NoColumnVisitor(ClauseVisitor):
"""
__traverse_options__ = {'column_collections':False}
+
+def traverse(clause, **kwargs):
+ clone = kwargs.pop('clone', False)
+ class Vis(ClauseVisitor):
+ __traverse_options__ = kwargs.pop('traverse_options', {})
+ def __getattr__(self, key):
+ if key in kwargs:
+ return kwargs[key]
+ else:
+ return None
+ return Vis().traverse(clause, clone=clone)
+