diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/connectors/mxodbc.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/access/base.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/maxdb/base.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 44 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/mxodbc.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/pyodbc.py | 13 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/sqlite/base.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/sybase/base.py | 32 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/sybase/pyodbc.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 141 |
11 files changed, 187 insertions, 79 deletions
diff --git a/lib/sqlalchemy/connectors/mxodbc.py b/lib/sqlalchemy/connectors/mxodbc.py index 29b047d23..68b88019c 100644 --- a/lib/sqlalchemy/connectors/mxodbc.py +++ b/lib/sqlalchemy/connectors/mxodbc.py @@ -96,9 +96,4 @@ class MxODBCConnector(Connector): version.append(n) return tuple(version) - def do_execute(self, cursor, statement, parameters, context=None): - # TODO: dont need tuple() here - # TODO: use cursor.execute() - cursor.executedirect(statement, tuple(parameters)) - diff --git a/lib/sqlalchemy/dialects/access/base.py b/lib/sqlalchemy/dialects/access/base.py index c10e77011..7dfb3153e 100644 --- a/lib/sqlalchemy/dialects/access/base.py +++ b/lib/sqlalchemy/dialects/access/base.py @@ -371,9 +371,9 @@ class AccessCompiler(compiler.SQLCompiler): return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN ") + \ self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause)) - def visit_extract(self, extract): + def visit_extract(self, extract, **kw): field = self.extract_map.get(extract.field, extract.field) - return 'DATEPART("%s", %s)' % (field, self.process(extract.expr)) + return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw)) class AccessDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): diff --git a/lib/sqlalchemy/dialects/maxdb/base.py b/lib/sqlalchemy/dialects/maxdb/base.py index 504c31209..758cfaf05 100644 --- a/lib/sqlalchemy/dialects/maxdb/base.py +++ b/lib/sqlalchemy/dialects/maxdb/base.py @@ -558,8 +558,8 @@ class MaxDBCompiler(compiler.SQLCompiler): return labels - def order_by_clause(self, select): - order_by = self.process(select._order_by_clause) + def order_by_clause(self, select, **kw): + order_by = self.process(select._order_by_clause, **kw) # ORDER BY clauses in DISTINCT queries must reference aliased # inner columns by alias name, not true column name. diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 254aa54fd..4d697854f 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -843,8 +843,9 @@ class MSExecutionContext(default.DefaultExecutionContext): class MSSQLCompiler(compiler.SQLCompiler): returning_precedes_values = True - extract_map = compiler.SQLCompiler.extract_map.copy() - extract_map.update ({ + extract_map = util.update_copy( + compiler.SQLCompiler.extract_map, + { 'doy': 'dayofyear', 'dow': 'weekday', 'milliseconds': 'millisecond', @@ -937,9 +938,9 @@ class MSSQLCompiler(compiler.SQLCompiler): kwargs['mssql_aliased'] = True return super(MSSQLCompiler, self).visit_alias(alias, **kwargs) - def visit_extract(self, extract): + def visit_extract(self, extract, **kw): field = self.extract_map.get(extract.field, extract.field) - return 'DATEPART("%s", %s)' % (field, self.process(extract.expr)) + return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw)) def visit_rollback_to_savepoint(self, savepoint_stmt): return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt) @@ -1011,8 +1012,8 @@ class MSSQLCompiler(compiler.SQLCompiler): # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use return '' - def order_by_clause(self, select): - order_by = self.process(select._order_by_clause) + def order_by_clause(self, select, **kw): + order_by = self.process(select._order_by_clause, **kw) # MSSQL only allows ORDER BY in subqueries if there is a LIMIT if order_by and (not self.is_subquery() or select._limit): @@ -1020,6 +1021,37 @@ class MSSQLCompiler(compiler.SQLCompiler): else: return "" +class MSSQLStrictCompiler(MSSQLCompiler): + """A subclass of MSSQLCompiler which disables the usage of bind + parameters where not allowed natively by MS-SQL. + + A dialect may use this compiler on a platform where native + binds are used. + + """ + ansi_bind_rules = True + + def visit_in_op(self, binary, **kw): + kw['literal_binds'] = True + return "%s IN %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw) + ) + + def visit_notin_op(self, binary, **kw): + kw['literal_binds'] = True + return "%s NOT IN %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw) + ) + + def visit_function(self, func, **kw): + kw['literal_binds'] = True + return super(MSSQLStrictCompiler, self).visit_function(func, **kw) + + #def render_literal_value(self, value): + # TODO! use mxODBC's literal quoting services here + class MSDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): diff --git a/lib/sqlalchemy/dialects/mssql/mxodbc.py b/lib/sqlalchemy/dialects/mssql/mxodbc.py index bf14601b8..59cf65d63 100644 --- a/lib/sqlalchemy/dialects/mssql/mxodbc.py +++ b/lib/sqlalchemy/dialects/mssql/mxodbc.py @@ -4,9 +4,10 @@ import sys from sqlalchemy import types as sqltypes from sqlalchemy.connectors.mxodbc import MxODBCConnector from sqlalchemy.dialects.mssql.pyodbc import MSExecutionContext_pyodbc -from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect - +from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect, \ + MSSQLCompiler, MSSQLStrictCompiler + class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc): """ The pyodbc execution context is useful for enabling @@ -20,7 +21,11 @@ class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc): class MSDialect_mxodbc(MxODBCConnector, MSDialect): execution_ctx_cls = MSExecutionContext_mxodbc - + + # TODO: may want to use this only if FreeTDS is not in use, + # since FreeTDS doesn't seem to use native binds. + statement_compiler = MSSQLStrictCompiler + def __init__(self, description_encoding='latin-1', **params): super(MSDialect_mxodbc, self).__init__(**params) self.description_encoding = description_encoding diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py index 54b43320a..34050271f 100644 --- a/lib/sqlalchemy/dialects/mssql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -1,3 +1,16 @@ +""" +Support for MS-SQL via pyodbc. + +http://pypi.python.org/pypi/pyodbc/ + +Connect strings are of the form:: + + mssql+pyodbc://<username>:<password>@<dsn>/ + mssql+pyodbc://<username>:<password>@<host>/<database> + + +""" + from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect from sqlalchemy.connectors.pyodbc import PyODBCConnector from sqlalchemy import types as sqltypes diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 37e63fbc1..98df8d0cb 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -218,10 +218,10 @@ class SQLiteCompiler(compiler.SQLCompiler): else: return self.process(cast.clause) - def visit_extract(self, extract): + def visit_extract(self, extract, **kw): try: return "CAST(STRFTIME('%s', %s) AS INTEGER)" % ( - self.extract_map[extract.field], self.process(extract.expr)) + self.extract_map[extract.field], self.process(extract.expr, **kw)) except KeyError: raise exc.ArgumentError( "%s is not a valid extract argument." % extract.field) diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py index 5d20faaf9..c440015d0 100644 --- a/lib/sqlalchemy/dialects/sybase/base.py +++ b/lib/sqlalchemy/dialects/sybase/base.py @@ -236,9 +236,11 @@ class SybaseExecutionContext(default.DefaultExecutionContext): return lastrowid class SybaseSQLCompiler(compiler.SQLCompiler): + ansi_bind_rules = True - extract_map = compiler.SQLCompiler.extract_map.copy() - extract_map.update ({ + extract_map = util.update_copy( + compiler.SQLCompiler.extract_map, + { 'doy': 'dayofyear', 'dow': 'weekday', 'milliseconds': 'millisecond' @@ -267,33 +269,17 @@ class SybaseSQLCompiler(compiler.SQLCompiler): # Limit in sybase is after the select keyword return "" - def dont_visit_binary(self, binary): - """Move bind parameters to the right-hand side of an operator, where possible.""" - if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq: - return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator)) - else: - return super(SybaseSQLCompiler, self).visit_binary(binary) - - def dont_label_select_column(self, select, column, asfrom): - if isinstance(column, expression.Function): - return column.label(None) - else: - return super(SybaseSQLCompiler, self).label_select_column(select, column, asfrom) - -# def visit_getdate_func(self, fn, **kw): - # TODO: need to cast? something ? -# pass - - def visit_extract(self, extract): + def visit_extract(self, extract, **kw): field = self.extract_map.get(extract.field, extract.field) - return 'DATEPART("%s", %s)' % (field, self.process(extract.expr)) + return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw)) def for_update_clause(self, select): # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use return '' - def order_by_clause(self, select): - order_by = self.process(select._order_by_clause) + def order_by_clause(self, select, **kw): + kw['literal_binds'] = True + order_by = self.process(select._order_by_clause, **kw) # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT if order_by and (not self.is_subquery() or select._limit): diff --git a/lib/sqlalchemy/dialects/sybase/pyodbc.py b/lib/sqlalchemy/dialects/sybase/pyodbc.py index 4f89fe334..1bfdb6151 100644 --- a/lib/sqlalchemy/dialects/sybase/pyodbc.py +++ b/lib/sqlalchemy/dialects/sybase/pyodbc.py @@ -1,7 +1,12 @@ """ Support for Sybase via pyodbc. -This dialect is a stub only and is likely non functional at this time. +http://pypi.python.org/pypi/pyodbc/ + +Connect strings are of the form:: + + sybase+pyodbc://<username>:<password>@<dsn>/ + sybase+pyodbc://<username>:<password>@<host>/<database> """ diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index ce24a9ae4..2ef8fd104 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -62,6 +62,7 @@ class DefaultDialect(base.Dialect): supports_sane_rowcount = True supports_sane_multi_rowcount = True dbapi_type_map = {} + colspecs = {} default_paramstyle = 'named' supports_default_values = False supports_empty_insert = True diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index be3375def..4e9175ae8 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -26,6 +26,7 @@ import re from sqlalchemy import schema, engine, util, exc from sqlalchemy.sql import operators, functions, util as sql_util, visitors from sqlalchemy.sql import expression as sql +import decimal RESERVED_WORDS = set([ 'all', 'analyse', 'analyze', 'and', 'any', 'array', @@ -183,6 +184,12 @@ class SQLCompiler(engine.Compiled): # clauses before the VALUES or WHERE clause (i.e. MSSQL) returning_precedes_values = False + # SQL 92 doesn't allow bind parameters to be used + # in the columns clause of a SELECT, nor does it allow + # ambiguous expressions like "? = ?". A compiler + # subclass can set this flag to False if the target + # driver/DB enforces this + ansi_bind_rules = False def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): """Construct a new ``DefaultCompiler`` object. @@ -260,9 +267,14 @@ class SQLCompiler(engine.Compiled): else: if bindparam.required: if _group_number: - raise exc.InvalidRequestError("A value is required for bind parameter %r, in parameter group %d" % (bindparam.key, _group_number)) + raise exc.InvalidRequestError( + "A value is required for bind parameter %r, " + "in parameter group %d" % + (bindparam.key, _group_number)) else: - raise exc.InvalidRequestError("A value is required for bind parameter %r" % bindparam.key) + raise exc.InvalidRequestError( + "A value is required for bind parameter %r" + % bindparam.key) elif util.callable(bindparam.value): pd[name] = bindparam.value() else: @@ -290,10 +302,10 @@ class SQLCompiler(engine.Compiled): """ return "" - def visit_grouping(self, grouping, **kwargs): - return "(" + self.process(grouping.element) + ")" + def visit_grouping(self, grouping, asfrom=False, **kwargs): + return "(" + self.process(grouping.element, **kwargs) + ")" - def visit_label(self, label, result_map=None, within_columns_clause=False): + def visit_label(self, label, result_map=None, within_columns_clause=False, **kw): # only render labels within the columns clause # or ORDER BY clause of a select. dialect-specific compilers # can modify this behavior. @@ -305,11 +317,15 @@ class SQLCompiler(engine.Compiled): result_map[labelname.lower()] = \ (label.name, (label, label.element, labelname), label.element.type) - return self.process(label.element) + \ + return self.process(label.element, + within_columns_clause=within_columns_clause, + **kw) + \ OPERATORS[operators.as_] + \ self.preparer.format_label(label, labelname) else: - return self.process(label.element) + return self.process(label.element, + within_columns_clause=within_columns_clause, + **kw) def visit_column(self, column, result_map=None, **kwargs): name = column.name @@ -384,27 +400,28 @@ class SQLCompiler(engine.Compiled): sep = " " else: sep = OPERATORS[clauselist.operator] - return sep.join(s for s in (self.process(c) for c in clauselist.clauses) + return sep.join(s for s in (self.process(c, **kwargs) for c in clauselist.clauses) if s is not None) def visit_case(self, clause, **kwargs): x = "CASE " if clause.value is not None: - x += self.process(clause.value) + " " + x += self.process(clause.value, **kwargs) + " " for cond, result in clause.whens: - x += "WHEN " + self.process(cond) + " THEN " + self.process(result) + " " + x += "WHEN " + self.process(cond, **kwargs) + \ + " THEN " + self.process(result, **kwargs) + " " if clause.else_ is not None: - x += "ELSE " + self.process(clause.else_) + " " + x += "ELSE " + self.process(clause.else_, **kwargs) + " " x += "END" return x def visit_cast(self, cast, **kwargs): return "CAST(%s AS %s)" % \ - (self.process(cast.clause), self.process(cast.typeclause)) + (self.process(cast.clause, **kwargs), self.process(cast.typeclause, **kwargs)) def visit_extract(self, extract, **kwargs): field = self.extract_map.get(extract.field, extract.field) - return "EXTRACT(%s FROM %s)" % (field, self.process(extract.expr)) + return "EXTRACT(%s FROM %s)" % (field, self.process(extract.expr, **kwargs)) def visit_function(self, func, result_map=None, **kwargs): if result_map is not None: @@ -421,22 +438,23 @@ class SQLCompiler(engine.Compiled): def function_argspec(self, func, **kwargs): return self.process(func.clause_expr, **kwargs) - def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs): + def visit_compound_select(self, cs, asfrom=False, parens=True, compound_index=1, **kwargs): entry = self.stack and self.stack[-1] or {} self.stack.append({'from':entry.get('from', None), 'iswrapper':True}) keyword = self.compound_keywords.get(cs.keyword) text = (" " + keyword + " ").join( - (self.process(c, asfrom=asfrom, parens=False, compound_index=i) + (self.process(c, asfrom=asfrom, parens=False, + compound_index=i, **kwargs) for i, c in enumerate(cs.selects)) ) - group_by = self.process(cs._group_by_clause, asfrom=asfrom) + group_by = self.process(cs._group_by_clause, asfrom=asfrom, **kwargs) if group_by: text += " GROUP BY " + group_by - text += self.order_by_clause(cs) + text += self.order_by_clause(cs, **kwargs) text += (cs._limit is not None or cs._offset is not None) and self.limit_clause(cs) or "" self.stack.pop(-1) @@ -453,32 +471,47 @@ class SQLCompiler(engine.Compiled): s = s + OPERATORS[unary.modifier] return s - def visit_binary(self, binary, **kwargs): - + def visit_binary(self, binary, **kw): + # don't allow "? = ?" to render + if self.ansi_bind_rules and \ + isinstance(binary.left, sql._BindParamClause) and \ + isinstance(binary.right, sql._BindParamClause): + kw['literal_binds'] = True + return self._operator_dispatch(binary.operator, binary, - lambda opstr: self.process(binary.left) + opstr + self.process(binary.right), - **kwargs + lambda opstr: self.process(binary.left, **kw) + + opstr + + self.process(binary.right, **kw), + **kw ) def visit_like_op(self, binary, **kw): escape = binary.modifiers.get("escape", None) - return '%s LIKE %s' % (self.process(binary.left), self.process(binary.right)) \ + return '%s LIKE %s' % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw)) \ + (escape and ' ESCAPE \'%s\'' % escape or '') def visit_notlike_op(self, binary, **kw): escape = binary.modifiers.get("escape", None) - return '%s NOT LIKE %s' % (self.process(binary.left), self.process(binary.right)) \ + return '%s NOT LIKE %s' % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw)) \ + (escape and ' ESCAPE \'%s\'' % escape or '') def visit_ilike_op(self, binary, **kw): escape = binary.modifiers.get("escape", None) - return 'lower(%s) LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \ + return 'lower(%s) LIKE lower(%s)' % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw)) \ + (escape and ' ESCAPE \'%s\'' % escape or '') def visit_notilike_op(self, binary, **kw): escape = binary.modifiers.get("escape", None) - return 'lower(%s) NOT LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \ + return 'lower(%s) NOT LIKE lower(%s)' % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw)) \ + (escape and ' ESCAPE \'%s\'' % escape or '') def _operator_dispatch(self, operator, element, fn, **kw): @@ -491,7 +524,16 @@ class SQLCompiler(engine.Compiled): else: return fn(" " + operator + " ") - def visit_bindparam(self, bindparam, **kwargs): + def visit_bindparam(self, bindparam, within_columns_clause=False, + literal_binds=False, **kwargs): + if literal_binds or \ + (within_columns_clause and \ + self.ansi_bind_rules): + if bindparam.value is None: + raise exc.CompileError("Bind parameter without a " + "renderable value not allowed here.") + return self.render_literal_bindparam(bindparam, within_columns_clause=True, **kwargs) + name = self._truncate_bindparam(bindparam) if name in self.binds: existing = self.binds[name] @@ -510,7 +552,36 @@ class SQLCompiler(engine.Compiled): self.binds[bindparam.key] = self.binds[name] = bindparam return self.bindparam_string(name) - + + def render_literal_bindparam(self, bindparam, **kw): + value = bindparam.value + processor = bindparam.bind_processor(self.dialect) + if processor: + value = processor(value) + return self.render_literal_value(value, bindparam.type) + + def render_literal_value(self, value, type_): + """Render the value of a bind parameter as a quoted literal. + + This is used for statement sections that do not accept bind paramters + on the target driver/database. + + This should be implemented by subclasses using the quoting services + of the DBAPI. + + """ + if isinstance(value, basestring): + value = value.replace("'", "''") + return "'%s'" % value + elif value is None: + return "NULL" + elif isinstance(value, (float, int, long)): + return repr(value) + elif isinstance(value, decimal.Decimal): + return str(value) + else: + raise NotImplementedError("Don't know how to literal-quote value %r" % value) + def _truncate_bindparam(self, bindparam): if bindparam in self.bind_names: return self.bind_names[bindparam] @@ -624,33 +695,33 @@ class SQLCompiler(engine.Compiled): text = "SELECT " # we're off to a good start ! if select._prefixes: - text += " ".join(self.process(x) for x in select._prefixes) + " " + text += " ".join(self.process(x, **kwargs) for x in select._prefixes) + " " text += self.get_select_precolumns(select) text += ', '.join(inner_columns) if froms: text += " \nFROM " - text += ', '.join(self.process(f, asfrom=True) for f in froms) + text += ', '.join(self.process(f, asfrom=True, **kwargs) for f in froms) else: text += self.default_from() if select._whereclause is not None: - t = self.process(select._whereclause) + t = self.process(select._whereclause, **kwargs) if t: text += " \nWHERE " + t if select._group_by_clause.clauses: - group_by = self.process(select._group_by_clause) + group_by = self.process(select._group_by_clause, **kwargs) if group_by: text += " GROUP BY " + group_by if select._having is not None: - t = self.process(select._having) + t = self.process(select._having, **kwargs) if t: text += " \nHAVING " + t if select._order_by_clause.clauses: - text += self.order_by_clause(select) + text += self.order_by_clause(select, **kwargs) if select._limit is not None or select._offset is not None: text += self.limit_clause(select) if select.for_update: @@ -670,8 +741,8 @@ class SQLCompiler(engine.Compiled): """ return select._distinct and "DISTINCT " or "" - def order_by_clause(self, select): - order_by = self.process(select._order_by_clause) + def order_by_clause(self, select, **kw): + order_by = self.process(select._order_by_clause, **kw) if order_by: return " ORDER BY " + order_by else: |
