diff options
| author | Jason Kirtland <jek@discorporate.us> | 2008-02-14 20:02:10 +0000 |
|---|---|---|
| committer | Jason Kirtland <jek@discorporate.us> | 2008-02-14 20:02:10 +0000 |
| commit | 71e745e96b8c5be990b3dc949cb99310dd055609 (patch) | |
| tree | 00c748e65e7e85e0231a1c7c504dec6cfcab8e87 /lib/sqlalchemy/sql | |
| parent | 8dd5eb402ef65194af4c54a6fd33a181b7d5eaf0 (diff) | |
| download | sqlalchemy-71e745e96b8c5be990b3dc949cb99310dd055609.tar.gz | |
- Fixed a couple pyflakes, cleaned up imports & whitespace
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 172 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 78 |
3 files changed, 123 insertions, 130 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 3f32778d6..8d8cfa38f 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -7,21 +7,20 @@ """Base SQL and DDL compiler implementations. Provides the [sqlalchemy.sql.compiler#DefaultCompiler] class, which is -responsible for generating all SQL query strings, as well as +responsible for generating all SQL query strings, as well as [sqlalchemy.sql.compiler#SchemaGenerator] and [sqlalchemy.sql.compiler#SchemaDropper] which issue CREATE and DROP DDL for tables, sequences, and indexes. The elements in this module are used by public-facing constructs like [sqlalchemy.sql.expression#ClauseElement] and [sqlalchemy.engine#Engine]. While dialect authors will want to be familiar with this module for the purpose of -creating database-specific compilers and schema generators, the module +creating database-specific compilers and schema generators, the module is otherwise internal to SQLAlchemy. """ import string, re from sqlalchemy import schema, engine, util, exceptions -from sqlalchemy.sql import operators, visitors, functions -from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql import operators, functions from sqlalchemy.sql import expression as sql RESERVED_WORDS = util.Set([ @@ -57,7 +56,7 @@ BIND_TEMPLATES = { 'numeric':"%(position)s", 'named':":%(name)s" } - + OPERATORS = { operators.and_ : 'AND', @@ -96,14 +95,14 @@ OPERATORS = { FUNCTIONS = { functions.coalesce : 'coalesce%(expr)s', - functions.current_date: 'CURRENT_DATE', - functions.current_time: 'CURRENT_TIME', + functions.current_date: 'CURRENT_DATE', + functions.current_time: 'CURRENT_TIME', functions.current_timestamp: 'CURRENT_TIMESTAMP', - functions.current_user: 'CURRENT_USER', - functions.localtime: 'LOCALTIME', + functions.current_user: 'CURRENT_USER', + functions.localtime: 'LOCALTIME', functions.localtimestamp: 'LOCALTIMESTAMP', functions.sysdate: 'sysdate', - functions.session_user :'SESSION_USER', + functions.session_user :'SESSION_USER', functions.user: 'USER' } @@ -118,7 +117,7 @@ class DefaultCompiler(engine.Compiled): operators = OPERATORS functions = FUNCTIONS - + def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): """Construct a new ``DefaultCompiler`` object. @@ -132,35 +131,35 @@ class DefaultCompiler(engine.Compiled): a list of column names to be compiled into an INSERT or UPDATE statement. """ - + super(DefaultCompiler, self).__init__(dialect, statement, column_keys, **kwargs) # if we are insert/update/delete. set to true when we visit an INSERT, UPDATE or DELETE self.isdelete = self.isinsert = self.isupdate = False - + # compile INSERT/UPDATE defaults/sequences inlined (no pre-execute) self.inline = inline or getattr(statement, 'inline', False) - + # a dictionary of bind parameter keys to _BindParamClause instances. self.binds = {} - + # a dictionary of _BindParamClause instances to "compiled" names that are # actually present in the generated SQL self.bind_names = {} # a stack. what recursive compiler doesn't have a stack ? :) self.stack = [] - + # relates label names in the final SQL to # a tuple of local column/label name, ColumnElement object (if any) and TypeEngine. # ResultProxy uses this for type processing and column targeting self.result_map = {} - + # a dictionary of ClauseElement subclasses to counters, which are used to # generate truncated identifier names or "anonymous" identifiers such as # for aliases self.generated_ids = {} - + # paramstyle from the dialect (comes from DB-API) self.paramstyle = self.dialect.paramstyle @@ -168,17 +167,17 @@ class DefaultCompiler(engine.Compiled): self.positional = self.dialect.positional self.bindtemplate = BIND_TEMPLATES[self.paramstyle] - + # a list of the compiled's bind parameter names, used to help # formulate a positional argument list self.positiontup = [] # an IdentifierPreparer that formats the quoting of identifiers self.preparer = self.dialect.identifier_preparer - + def compile(self): self.string = self.process(self.statement) - + def process(self, obj, stack=None, **kwargs): if stack: self.stack.append(stack) @@ -189,23 +188,23 @@ class DefaultCompiler(engine.Compiled): finally: if stack: self.stack.pop(-1) - + def is_subquery(self, select): return self.stack and self.stack[-1].get('is_subquery') - + def get_whereclause(self, obj): - """given a FROM clause, return an additional WHERE condition that should be - applied to a SELECT. - + """given a FROM clause, return an additional WHERE condition that should be + applied to a SELECT. + Currently used by Oracle to provide WHERE criterion for JOIN and OUTER JOIN constructs in non-ansi mode. """ - + return None def construct_params(self, params=None): """return a dictionary of bind parameter keys and values""" - + if params: pd = {} for bindparam, name in self.bind_names.iteritems(): @@ -218,9 +217,9 @@ class DefaultCompiler(engine.Compiled): return pd else: return dict([(self.bind_names[bindparam], bindparam.value) for bindparam in self.bind_names]) - + params = property(construct_params) - + def default_from(self): """Called when a SELECT statement has no froms, and no FROM clause is to be appended. @@ -228,22 +227,22 @@ class DefaultCompiler(engine.Compiled): """ return "" - + def visit_grouping(self, grouping, **kwargs): return "(" + self.process(grouping.elem) + ")" - + def visit_label(self, label, result_map=None): labelname = self._truncated_identifier("colident", label.name) - + if result_map is not None: result_map[labelname.lower()] = (label.name, (label, label.obj, labelname), label.obj.type) - + return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)]) - + def visit_column(self, column, result_map=None, use_schema=False, **kwargs): # there is actually somewhat of a ruleset when you would *not* necessarily - # want to truncate a column identifier, if its mapped to the name of a - # physical column. but thats very hard to identify at this point, and + # want to truncate a column identifier, if its mapped to the name of a + # physical column. but thats very hard to identify at this point, and # the identifier length should be greater than the id lengths of any physical # columns so should not matter. @@ -259,7 +258,7 @@ class DefaultCompiler(engine.Compiled): if result_map is not None: result_map[name.lower()] = (name, (column, ), column.type) - + if column._is_oid: n = self.dialect.oid_column_name(column) if n is not None: @@ -288,7 +287,7 @@ class DefaultCompiler(engine.Compiled): # TODO: some dialects might need different behavior here return text.replace('%', '%%') - + def visit_fromclause(self, fromclause, **kwargs): return fromclause.name @@ -302,7 +301,7 @@ class DefaultCompiler(engine.Compiled): if textclause.typemap is not None: for colname, type_ in textclause.typemap.iteritems(): self.result_map[colname.lower()] = (colname, None, type_) - + def do_bindparam(m): name = m.group(1) if name in textclause.bindparams: @@ -311,7 +310,7 @@ class DefaultCompiler(engine.Compiled): return self.bindparam_string(name) # un-escape any \:params - return BIND_PARAMS_ESC.sub(lambda m: m.group(1), + return BIND_PARAMS_ESC.sub(lambda m: m.group(1), BIND_PARAMS.sub(do_bindparam, textclause.text) ) @@ -339,37 +338,37 @@ class DefaultCompiler(engine.Compiled): result_map[func.name.lower()] = (func.name, None, func.type) name = self.function_string(func) - + if callable(name): return name(*[self.process(x) for x in func.clause_expr]) else: return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func)} - + def function_argspec(self, func): return self.process(func.clause_expr) - + def function_string(self, func): return self.functions.get(func.__class__, func.name + "%(expr)s") def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs): stack_entry = {'select':cs} - + if asfrom: stack_entry['is_subquery'] = True elif self.stack and self.stack[-1].get('select'): stack_entry['is_subquery'] = True self.stack.append(stack_entry) - + text = string.join([self.process(c, asfrom=asfrom, parens=False) for c in cs.selects], " " + cs.keyword + " ") group_by = self.process(cs._group_by_clause, asfrom=asfrom) if group_by: text += " GROUP BY " + group_by - text += self.order_by_clause(cs) + text += self.order_by_clause(cs) text += (cs._limit or cs._offset) and self.limit_clause(cs) or "" - + self.stack.pop(-1) - + if asfrom and parens: return "(" + text + ")" else: @@ -382,19 +381,17 @@ class DefaultCompiler(engine.Compiled): if unary.modifier: s = s + " " + self.operator_string(unary.modifier) return s - + def visit_binary(self, binary, **kwargs): op = self.operator_string(binary.operator) if callable(op): return op(self.process(binary.left), self.process(binary.right)) else: return self.process(binary.left) + " " + op + " " + self.process(binary.right) - - return ret - + def operator_string(self, operator): return self.operators.get(operator, str(operator)) - + def visit_bindparam(self, bindparam, **kwargs): name = self._truncate_bindparam(bindparam) if name in self.binds: @@ -403,22 +400,22 @@ class DefaultCompiler(engine.Compiled): raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key) self.binds[bindparam.key] = self.binds[name] = bindparam return self.bindparam_string(name) - + def _truncate_bindparam(self, bindparam): if bindparam in self.bind_names: return self.bind_names[bindparam] - + bind_name = bindparam.key bind_name = self._truncated_identifier("bindparam", bind_name) # add to bind_names for translation self.bind_names[bindparam] = bind_name - + return bind_name - + def _truncated_identifier(self, ident_class, name): if (ident_class, name) in self.generated_ids: return self.generated_ids[(ident_class, name)] - + anonname = ANONYMOUS_LABEL.sub(self._process_anon, name) if len(anonname) > self.dialect.max_identifier_length: @@ -441,14 +438,14 @@ class DefaultCompiler(engine.Compiled): self.generated_ids[('anon_counter', derived)] = anonymous_counter + 1 self.generated_ids[key] = newname return newname - + def _anonymize(self, name): return ANONYMOUS_LABEL.sub(self._process_anon, name) - + def bindparam_string(self, name): if self.positional: self.positiontup.append(name) - + return self.bindtemplate % {'name':name, 'position':len(self.positiontup)} def visit_alias(self, alias, asfrom=False, **kwargs): @@ -459,13 +456,13 @@ class DefaultCompiler(engine.Compiled): def label_select_column(self, select, column, asfrom): """label columns present in a select().""" - + if isinstance(column, sql._Label): return column - + if select.use_labels and getattr(column, '_label', None): return column.label(column._label) - + if \ asfrom and \ isinstance(column, sql._ColumnClause) and \ @@ -494,12 +491,12 @@ class DefaultCompiler(engine.Compiled): stack_entry['iswrapper'] = True else: column_clause_args = {'result_map':self.result_map} - + if self.stack and 'from' in self.stack[-1]: existingfroms = self.stack[-1]['from'] else: existingfroms = None - + froms = select._get_display_froms(existingfroms) correlate_froms = util.Set() @@ -510,17 +507,17 @@ class DefaultCompiler(engine.Compiled): # TODO: might want to propigate existing froms for select(select(select)) # where innermost select should correlate to outermost # if existingfroms: -# correlate_froms = correlate_froms.union(existingfroms) +# correlate_froms = correlate_froms.union(existingfroms) stack_entry['from'] = correlate_froms self.stack.append(stack_entry) # the actual list of columns to print in the SELECT column list. inner_columns = util.OrderedSet() - + for co in select.inner_columns: l = self.label_select_column(select, co, asfrom=asfrom) inner_columns.add(self.process(l, **column_clause_args)) - + collist = string.join(inner_columns.difference(util.Set([None])), ', ') text = " ".join(["SELECT"] + [self.process(x) for x in select._prefixes]) + " " @@ -539,7 +536,7 @@ class DefaultCompiler(engine.Compiled): whereclause = sql.and_(w, whereclause) else: whereclause = w - + if froms: text += " \nFROM " text += string.join(from_strings, ', ') @@ -559,7 +556,7 @@ class DefaultCompiler(engine.Compiled): t = self.process(select._having) if t: text += " \nHAVING " + t - + text += self.order_by_clause(select) text += (select._limit or select._offset) and self.limit_clause(select) or "" text += self.for_update_clause(select) @@ -625,10 +622,10 @@ class DefaultCompiler(engine.Compiled): ', '.join([preparer.quote(c[0], c[0].name) for c in colparams]), ', '.join([c[1] for c in colparams]))) - + def visit_update(self, update_stmt): self.stack.append({'from':util.Set([update_stmt.table])}) - + self.isupdate = True colparams = self._get_colparams(update_stmt) @@ -636,15 +633,15 @@ class DefaultCompiler(engine.Compiled): if update_stmt._whereclause: text += " WHERE " + self.process(update_stmt._whereclause) - + self.stack.pop(-1) - + return text def _get_colparams(self, stmt): - """create a set of tuples representing column/string pairs for use + """create a set of tuples representing column/string pairs for use in an INSERT or UPDATE statement. - + """ def create_bind_param(col, value): @@ -654,7 +651,7 @@ class DefaultCompiler(engine.Compiled): self.postfetch = [] self.prefetch = [] - + # no parameters in the statement, no parameters in the # compiled params - return binds for all columns if self.column_keys is None and stmt.parameters is None: @@ -688,7 +685,7 @@ class DefaultCompiler(engine.Compiled): if (((isinstance(c.default, schema.Sequence) and not c.default.optional) or not self.dialect.supports_pk_autoincrement) or - (c.default is not None and + (c.default is not None and not isinstance(c.default, schema.Sequence))): values.append((c, create_bind_param(c, None))) self.prefetch.append(c) @@ -732,18 +729,18 @@ class DefaultCompiler(engine.Compiled): text += " WHERE " + self.process(delete_stmt._whereclause) self.stack.pop(-1) - + return text - + def visit_savepoint(self, savepoint_stmt): return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) def visit_rollback_to_savepoint(self, savepoint_stmt): return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) - + def visit_release_savepoint(self, savepoint_stmt): return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) - + def __str__(self): return self.string or '' @@ -1072,10 +1069,10 @@ class IdentifierPreparer(object): def format_column(self, column, use_table=False, name=None, table_name=None): """Prepare a quoted column name. - + deprecated. use preparer.quote(col, column.name) or combine with format_table() """ - + if name is None: name = column.name if not getattr(column, 'is_literal', False): @@ -1121,7 +1118,6 @@ class IdentifierPreparer(object): 'final': final, 'escaped': escaped_final }) self._r_identifiers = r - + return [self._unescape_identifier(i) for i in [a or b for a, b in r.findall(identifiers)]] - diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 0b7684803..b39e406da 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -25,8 +25,7 @@ classes usually have few or no public methods and are less guaranteed to stay the same in future releases. """ -import datetime, re -import itertools +import itertools, re from sqlalchemy import util, exceptions from sqlalchemy.sql import operators, visitors from sqlalchemy import types as sqltypes diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 2cd0a26fd..70a1dcc96 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -1,5 +1,5 @@ -from sqlalchemy import util, schema, topological -from sqlalchemy.sql import expression, visitors, operators +from sqlalchemy import exceptions, schema, topological, util +from sqlalchemy.sql import expression, operators, visitors from itertools import chain """Utility functions that build upon SQL and Schema constructs.""" @@ -30,16 +30,16 @@ def find_tables(clause, check_columns=False, include_aliases=False): def visit_alias(alias): tables.append(alias) kwargs['visit_alias'] = visit_alias - + if check_columns: def visit_column(column): tables.append(column.table) kwargs['visit_column'] = visit_column - + def visit_table(table): tables.append(table) kwargs['visit_table'] = visit_table - + visitors.traverse(clause, traverse_options= {'column_collections':False}, **kwargs) return tables @@ -49,26 +49,26 @@ def find_columns(clause): cols.add(col) visitors.traverse(clause, visit_column=visit_column) return cols - - + + def reduce_columns(columns, *clauses): """given a list of columns, return a 'reduced' set based on natural equivalents. the set is reduced to the smallest list of columns which have no natural equivalent present in the list. A "natural equivalent" means that two columns will ultimately represent the same value because they are related by a foreign key. - + \*clauses is an optional list of join clauses which will be traversed to further identify columns that are "equivalent". - + This function is primarily used to determine the most minimal "primary key" from a selectable, by reducing the set of primary key columns present in the the selectable to just those that are not repeated. - + """ - + columns = util.OrderedSet(columns) - + omit = util.Set() for col in columns: for fk in col.foreign_keys: @@ -78,7 +78,7 @@ def reduce_columns(columns, *clauses): if fk.column.shares_lineage(c): omit.add(col) break - + if clauses: def visit_binary(binary): if binary.operator == operators.eq: @@ -90,7 +90,7 @@ def reduce_columns(columns, *clauses): break for clause in clauses: visitors.traverse(clause, visit_binary=visit_binary) - + return expression.ColumnSet(columns.difference(omit)) def row_adapter(from_, to, equivalent_columns=None): @@ -133,7 +133,7 @@ def row_adapter(from_, to, equivalent_columns=None): return map.keys() AliasedRow.map = map return AliasedRow - + class ColumnsInClause(visitors.ClauseVisitor): """Given a selectable, visit clauses and determine if any columns from the clause are in the selectable. @@ -149,16 +149,16 @@ class ColumnsInClause(visitors.ClauseVisitor): class AbstractClauseProcessor(object): """Traverse and copy a ClauseElement, replacing selected elements based on rules. - + This class implements its own visit-and-copy strategy but maintains the same public interface as visitors.ClauseVisitor. """ - + __traverse_options__ = {'column_collections':False} - + def __init__(self, stop_on=None): self.stop_on = stop_on - + def convert_element(self, elem): """Define the *conversion* method for this ``AbstractClauseProcessor``.""" @@ -166,14 +166,14 @@ class AbstractClauseProcessor(object): def chain(self, visitor): # chaining AbstractClauseProcessor and other ClauseVisitor - # objects separately. All the ACP objects are chained on + # objects separately. All the ACP objects are chained on # their convert_element() method whereas regular visitors # chain on their visit_XXX methods. if isinstance(visitor, AbstractClauseProcessor): attr = '_next_acp' else: attr = '_next' - + tail = self while getattr(tail, attr, None) is not None: tail = getattr(tail, attr) @@ -182,7 +182,7 @@ class AbstractClauseProcessor(object): def copy_and_process(self, list_): """Copy the given list to a new list, with each element traversed individually.""" - + list_ = list(list_) stop_on = util.Set(self.stop_on or []) cloned = {} @@ -198,44 +198,44 @@ class AbstractClauseProcessor(object): stop_on.add(newelem) return newelem v = getattr(v, '_next_acp', None) - + if elem not in cloned: # the full traversal will only make a clone of a particular element # once. cloned[elem] = elem._clone() return cloned[elem] - + def traverse(self, elem, clone=True): if not clone: raise exceptions.ArgumentError("AbstractClauseProcessor 'clone' argument must be True") - + return self._traverse(elem, util.Set(self.stop_on or []), {}, _clone_toplevel=True) - + def _traverse(self, elem, stop_on, cloned, _clone_toplevel=False): if elem in stop_on: return elem - + if _clone_toplevel: elem = self._convert_element(elem, stop_on, cloned) if elem in stop_on: return elem - + def clone(element): return self._convert_element(element, stop_on, cloned) elem._copy_internals(clone=clone) - + v = getattr(self, '_next', None) while v is not None: meth = getattr(v, "visit_%s" % elem.__visit_name__, None) if meth: meth(elem) v = getattr(v, '_next', None) - + for e in elem.get_children(**self.__traverse_options__): if e not in stop_on: self._traverse(e, stop_on, cloned) return elem - + class ClauseAdapter(AbstractClauseProcessor): """Given a clause (like as in a WHERE criterion), locate columns which are embedded within a given selectable, and changes those @@ -273,23 +273,23 @@ class ClauseAdapter(AbstractClauseProcessor): def copy_and_chain(self, adapter): """create a copy of this adapter and chain to the given adapter. - + currently this adapter must be unchained to start, raises - an exception if it's already chained. - + an exception if it's already chained. + Does not modify the given adapter. """ - + if adapter is None: return self - + if hasattr(self, '_next_acp') or hasattr(self, '_next'): raise NotImplementedError("Can't chain_to on an already chained ClauseAdapter (yet)") - + ca = ClauseAdapter(self.selectable, self.include, self.exclude, self.equivalents) ca._next_acp = adapter return ca - + def convert_element(self, col): if isinstance(col, expression.FromClause): if self.selectable.is_derived_from(col): @@ -309,5 +309,3 @@ class ClauseAdapter(AbstractClauseProcessor): if newcol: return newcol return newcol - - |
