diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/connectors/pyodbc.py | 24 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/access/base.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 40 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/mxodbc.py | 40 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/pyodbc.py | 28 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mysql/oursql.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/sqlite/base.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/sybase/__init__.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/sybase/base.py | 57 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/sybase/pyodbc.py | 13 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/sybase/pysybase.py | 20 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 53 |
13 files changed, 174 insertions, 114 deletions
diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py index ce8e84c33..e503135f7 100644 --- a/lib/sqlalchemy/connectors/pyodbc.py +++ b/lib/sqlalchemy/connectors/pyodbc.py @@ -19,6 +19,10 @@ class PyODBCConnector(Connector): # hold the desired driver name pyodbc_driver_name = None + # will be set to True after initialize() + # if the freetds.so is detected + freetds = False + @classmethod def dbapi(cls): return __import__('pyodbc') @@ -76,6 +80,26 @@ class PyODBCConnector(Connector): else: return False + def initialize(self, connection): + # determine FreeTDS first. can't issue SQL easily + # without getting unicode_statements/binds set up. + + pyodbc = self.dbapi + + dbapi_con = connection.connection + + self.freetds = bool(re.match(r".*libtdsodbc.*\.so", dbapi_con.getinfo(pyodbc.SQL_DRIVER_NAME))) + + # the "Py2K only" part here is theoretical. + # have not tried pyodbc + python3.1 yet. + # Py2K + self.supports_unicode_statements = not self.freetds + self.supports_unicode_binds = not self.freetds + # end Py2K + + # run other initialization which asks for user name, etc. + super(PyODBCConnector, self).initialize(connection) + def _get_server_version_info(self, connection): dbapi_con = connection.connection version = [] diff --git a/lib/sqlalchemy/dialects/access/base.py b/lib/sqlalchemy/dialects/access/base.py index c10e77011..7dfb3153e 100644 --- a/lib/sqlalchemy/dialects/access/base.py +++ b/lib/sqlalchemy/dialects/access/base.py @@ -371,9 +371,9 @@ class AccessCompiler(compiler.SQLCompiler): return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN ") + \ self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause)) - def visit_extract(self, extract): + def visit_extract(self, extract, **kw): field = self.extract_map.get(extract.field, extract.field) - return 'DATEPART("%s", %s)' % (field, self.process(extract.expr)) + return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw)) class AccessDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index eb4073b94..4d697854f 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -843,8 +843,9 @@ class MSExecutionContext(default.DefaultExecutionContext): class MSSQLCompiler(compiler.SQLCompiler): returning_precedes_values = True - extract_map = compiler.SQLCompiler.extract_map.copy() - extract_map.update ({ + extract_map = util.update_copy( + compiler.SQLCompiler.extract_map, + { 'doy': 'dayofyear', 'dow': 'weekday', 'milliseconds': 'millisecond', @@ -937,9 +938,9 @@ class MSSQLCompiler(compiler.SQLCompiler): kwargs['mssql_aliased'] = True return super(MSSQLCompiler, self).visit_alias(alias, **kwargs) - def visit_extract(self, extract): + def visit_extract(self, extract, **kw): field = self.extract_map.get(extract.field, extract.field) - return 'DATEPART("%s", %s)' % (field, self.process(extract.expr)) + return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw)) def visit_rollback_to_savepoint(self, savepoint_stmt): return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt) @@ -1020,6 +1021,37 @@ class MSSQLCompiler(compiler.SQLCompiler): else: return "" +class MSSQLStrictCompiler(MSSQLCompiler): + """A subclass of MSSQLCompiler which disables the usage of bind + parameters where not allowed natively by MS-SQL. + + A dialect may use this compiler on a platform where native + binds are used. + + """ + ansi_bind_rules = True + + def visit_in_op(self, binary, **kw): + kw['literal_binds'] = True + return "%s IN %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw) + ) + + def visit_notin_op(self, binary, **kw): + kw['literal_binds'] = True + return "%s NOT IN %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw) + ) + + def visit_function(self, func, **kw): + kw['literal_binds'] = True + return super(MSSQLStrictCompiler, self).visit_function(func, **kw) + + #def render_literal_value(self, value): + # TODO! use mxODBC's literal quoting services here + class MSDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): diff --git a/lib/sqlalchemy/dialects/mssql/mxodbc.py b/lib/sqlalchemy/dialects/mssql/mxodbc.py index 3dcc78b8c..59cf65d63 100644 --- a/lib/sqlalchemy/dialects/mssql/mxodbc.py +++ b/lib/sqlalchemy/dialects/mssql/mxodbc.py @@ -4,43 +4,9 @@ import sys from sqlalchemy import types as sqltypes from sqlalchemy.connectors.mxodbc import MxODBCConnector from sqlalchemy.dialects.mssql.pyodbc import MSExecutionContext_pyodbc -from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect, MSSQLCompiler +from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect, \ + MSSQLCompiler, MSSQLStrictCompiler -# TODO: does Pyodbc on windows have the same limitations ? -# if so this compiler can be moved to a common "odbc.py" module -# here -# *or* - should we implement this for MS-SQL across the board -# since its technically MS-SQL's behavior ? -# perhaps yes, with a dialect flag "strict_binds" to turn it off -class MSSQLCompiler_mxodbc(MSSQLCompiler): - binds_in_columns_clause = False - - def visit_in_op(self, binary, **kw): - kw['literal_binds'] = True - return "%s IN %s" % ( - self.process(binary.left, **kw), - self.process(binary.right, **kw) - ) - - def visit_notin_op(self, binary, **kw): - kw['literal_binds'] = True - return "%s NOT IN %s" % ( - self.process(binary.left, **kw), - self.process(binary.right, **kw) - ) - - def visit_function(self, func, **kw): - kw['literal_binds'] = True - return super(MSSQLCompiler_mxodbc, self).visit_function(func, **kw) - - def render_literal_value(self, value): - # TODO! use mxODBC's literal quoting services here - if isinstance(value, basestring): - value = value.replace("'", "''") - return "'%s'" % value - else: - return repr(value) - class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc): """ @@ -58,7 +24,7 @@ class MSDialect_mxodbc(MxODBCConnector, MSDialect): # TODO: may want to use this only if FreeTDS is not in use, # since FreeTDS doesn't seem to use native binds. - statement_compiler = MSSQLCompiler_mxodbc + statement_compiler = MSSQLStrictCompiler def __init__(self, description_encoding='latin-1', **params): super(MSDialect_mxodbc, self).__init__(**params) diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py index 9ef065b1a..34050271f 100644 --- a/lib/sqlalchemy/dialects/mssql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -1,3 +1,16 @@ +""" +Support for MS-SQL via pyodbc. + +http://pypi.python.org/pypi/pyodbc/ + +Connect strings are of the form:: + + mssql+pyodbc://<username>:<password>@<dsn>/ + mssql+pyodbc://<username>:<password>@<host>/<database> + + +""" + from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect from sqlalchemy.connectors.pyodbc import PyODBCConnector from sqlalchemy import types as sqltypes @@ -60,19 +73,4 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): self.description_encoding = description_encoding self.use_scope_identity = self.dbapi and hasattr(self.dbapi.Cursor, 'nextset') - def initialize(self, connection): - super(MSDialect_pyodbc, self).initialize(connection) - pyodbc = self.dbapi - - dbapi_con = connection.connection - - self.freetds = re.match(r".*libtdsodbc.*\.so", dbapi_con.getinfo(pyodbc.SQL_DRIVER_NAME)) - - # the "Py2K only" part here is theoretical. - # have not tried pyodbc + python3.1 yet. - # Py2K - self.supports_unicode_statements = not self.freetds - self.supports_unicode_binds = not self.freetds - # end Py2K - dialect = MSDialect_pyodbc diff --git a/lib/sqlalchemy/dialects/mysql/oursql.py b/lib/sqlalchemy/dialects/mysql/oursql.py index 6c8bbcac4..605b39760 100644 --- a/lib/sqlalchemy/dialects/mysql/oursql.py +++ b/lib/sqlalchemy/dialects/mysql/oursql.py @@ -56,8 +56,6 @@ class MySQLDialect_oursql(MySQLDialect): driver = 'oursql' # Py3K # description_encoding = None -# supports_unicode_binds = False -# supports_unicode_statements = False # Py2K supports_unicode_binds = True supports_unicode_statements = True diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 37e63fbc1..98df8d0cb 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -218,10 +218,10 @@ class SQLiteCompiler(compiler.SQLCompiler): else: return self.process(cast.clause) - def visit_extract(self, extract): + def visit_extract(self, extract, **kw): try: return "CAST(STRFTIME('%s', %s) AS INTEGER)" % ( - self.extract_map[extract.field], self.process(extract.expr)) + self.extract_map[extract.field], self.process(extract.expr, **kw)) except KeyError: raise exc.ArgumentError( "%s is not a valid extract argument." % extract.field) diff --git a/lib/sqlalchemy/dialects/sybase/__init__.py b/lib/sqlalchemy/dialects/sybase/__init__.py index 573aedde3..4d9b07007 100644 --- a/lib/sqlalchemy/dialects/sybase/__init__.py +++ b/lib/sqlalchemy/dialects/sybase/__init__.py @@ -1,4 +1,4 @@ -from sqlalchemy.dialects.sybase import base, pysybase +from sqlalchemy.dialects.sybase import base, pysybase, pyodbc from base import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\ diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py index 2addba2f8..c440015d0 100644 --- a/lib/sqlalchemy/dialects/sybase/base.py +++ b/lib/sqlalchemy/dialects/sybase/base.py @@ -176,7 +176,19 @@ ischema_names = { class SybaseExecutionContext(default.DefaultExecutionContext): _enable_identity_insert = False - + + def set_ddl_autocommit(self, connection, value): + """Must be implemented by subclasses to accommodate DDL executions. + + "connection" is the raw unwrapped DBAPI connection. "value" + is True or False. when True, the connection should be configured + such that a DDL can take place subsequently. when False, + a DDL has taken place and the connection should be resumed + into non-autocommit mode. + + """ + raise NotImplementedError() + def pre_exec(self): if self.isinsert: tbl = self.compiled.statement.table @@ -192,7 +204,22 @@ class SybaseExecutionContext(default.DefaultExecutionContext): self.cursor.execute("SET IDENTITY_INSERT %s ON" % self.dialect.identifier_preparer.format_table(tbl)) + if self.isddl: + # TODO: to enhance this, we can detect "ddl in tran" on the + # database settings. this error message should be improved to + # include a note about that. + if not self.should_autocommit: + raise exc.InvalidRequestError("The Sybase dialect only supports " + "DDL in 'autocommit' mode at this time.") + + self.root_connection.engine.logger.info("AUTOCOMMIT (Assuming no Sybase 'ddl in tran')") + + self.set_ddl_autocommit(self.root_connection.connection.connection, True) + + def post_exec(self): + if self.isddl: + self.set_ddl_autocommit(self.root_connection, False) if self._enable_identity_insert: self.cursor.execute( @@ -209,9 +236,11 @@ class SybaseExecutionContext(default.DefaultExecutionContext): return lastrowid class SybaseSQLCompiler(compiler.SQLCompiler): + ansi_bind_rules = True - extract_map = compiler.SQLCompiler.extract_map.copy() - extract_map.update ({ + extract_map = util.update_copy( + compiler.SQLCompiler.extract_map, + { 'doy': 'dayofyear', 'dow': 'weekday', 'milliseconds': 'millisecond' @@ -240,32 +269,16 @@ class SybaseSQLCompiler(compiler.SQLCompiler): # Limit in sybase is after the select keyword return "" - def dont_visit_binary(self, binary): - """Move bind parameters to the right-hand side of an operator, where possible.""" - if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq: - return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator)) - else: - return super(SybaseSQLCompiler, self).visit_binary(binary) - - def dont_label_select_column(self, select, column, asfrom): - if isinstance(column, expression.Function): - return column.label(None) - else: - return super(SybaseSQLCompiler, self).label_select_column(select, column, asfrom) - -# def visit_getdate_func(self, fn, **kw): - # TODO: need to cast? something ? -# pass - - def visit_extract(self, extract): + def visit_extract(self, extract, **kw): field = self.extract_map.get(extract.field, extract.field) - return 'DATEPART("%s", %s)' % (field, self.process(extract.expr)) + return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw)) def for_update_clause(self, select): # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use return '' def order_by_clause(self, select, **kw): + kw['literal_binds'] = True order_by = self.process(select._order_by_clause, **kw) # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT diff --git a/lib/sqlalchemy/dialects/sybase/pyodbc.py b/lib/sqlalchemy/dialects/sybase/pyodbc.py index 642ae3219..1bfdb6151 100644 --- a/lib/sqlalchemy/dialects/sybase/pyodbc.py +++ b/lib/sqlalchemy/dialects/sybase/pyodbc.py @@ -1,7 +1,12 @@ """ Support for Sybase via pyodbc. -This dialect is a stub only and is likely non functional at this time. +http://pypi.python.org/pypi/pyodbc/ + +Connect strings are of the form:: + + sybase+pyodbc://<username>:<password>@<dsn>/ + sybase+pyodbc://<username>:<password>@<host>/<database> """ @@ -10,7 +15,11 @@ from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContex from sqlalchemy.connectors.pyodbc import PyODBCConnector class SybaseExecutionContext_pyodbc(SybaseExecutionContext): - pass + def set_ddl_autocommit(self, connection, value): + if value: + connection.autocommit = True + else: + connection.autocommit = False class SybaseDialect_pyodbc(PyODBCConnector, SybaseDialect): diff --git a/lib/sqlalchemy/dialects/sybase/pysybase.py b/lib/sqlalchemy/dialects/sybase/pysybase.py index 195407384..200ce11a2 100644 --- a/lib/sqlalchemy/dialects/sybase/pysybase.py +++ b/lib/sqlalchemy/dialects/sybase/pysybase.py @@ -20,6 +20,14 @@ from sqlalchemy.dialects.sybase.base import SybaseDialect, \ class SybaseExecutionContext_pysybase(SybaseExecutionContext): + + def set_ddl_autocommit(self, dbapi_connection, value): + if value: + # call commit() on the Sybase connection directly, + # to avoid any side effects of calling a Connection + # transactional method inside of pre_exec() + dbapi_connection.commit() + def pre_exec(self): SybaseExecutionContext.pre_exec(self) @@ -28,18 +36,6 @@ class SybaseExecutionContext_pysybase(SybaseExecutionContext): param["@" + key] = param[key] del param[key] - if self.isddl: - # TODO: to enhance this, we can detect "ddl in tran" on the - # database settings. this error message should be improved to - # include a note about that. - if not self.should_autocommit: - raise exc.InvalidRequestError("The Sybase dialect only supports " - "DDL in 'autocommit' mode at this time.") - # call commit() on the Sybase connection directly, - # to avoid any side effects of calling a Connection - # transactional method inside of pre_exec() - self.root_connection.engine.logger.info("COMMIT (Assuming no Sybase 'ddl in tran')") - self.root_connection.connection.commit() class SybaseSQLCompiler_pysybase(SybaseSQLCompiler): def bindparam_string(self, name): diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index ce24a9ae4..2ef8fd104 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -62,6 +62,7 @@ class DefaultDialect(base.Dialect): supports_sane_rowcount = True supports_sane_multi_rowcount = True dbapi_type_map = {} + colspecs = {} default_paramstyle = 'named' supports_default_values = False supports_empty_insert = True diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index a3008d085..4e9175ae8 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -26,6 +26,7 @@ import re from sqlalchemy import schema, engine, util, exc from sqlalchemy.sql import operators, functions, util as sql_util, visitors from sqlalchemy.sql import expression as sql +import decimal RESERVED_WORDS = set([ 'all', 'analyse', 'analyze', 'and', 'any', 'array', @@ -184,10 +185,11 @@ class SQLCompiler(engine.Compiled): returning_precedes_values = False # SQL 92 doesn't allow bind parameters to be used - # in the columns clause of a SELECT. A compiler + # in the columns clause of a SELECT, nor does it allow + # ambiguous expressions like "? = ?". A compiler # subclass can set this flag to False if the target # driver/DB enforces this - binds_in_columns_clause = True + ansi_bind_rules = False def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): """Construct a new ``DefaultCompiler`` object. @@ -303,7 +305,7 @@ class SQLCompiler(engine.Compiled): def visit_grouping(self, grouping, asfrom=False, **kwargs): return "(" + self.process(grouping.element, **kwargs) + ")" - def visit_label(self, label, result_map=None, within_columns_clause=False): + def visit_label(self, label, result_map=None, within_columns_clause=False, **kw): # only render labels within the columns clause # or ORDER BY clause of a select. dialect-specific compilers # can modify this behavior. @@ -315,11 +317,15 @@ class SQLCompiler(engine.Compiled): result_map[labelname.lower()] = \ (label.name, (label, label.element, labelname), label.element.type) - return self.process(label.element) + \ + return self.process(label.element, + within_columns_clause=within_columns_clause, + **kw) + \ OPERATORS[operators.as_] + \ self.preparer.format_label(label, labelname) else: - return self.process(label.element) + return self.process(label.element, + within_columns_clause=within_columns_clause, + **kw) def visit_column(self, column, result_map=None, **kwargs): name = column.name @@ -465,14 +471,19 @@ class SQLCompiler(engine.Compiled): s = s + OPERATORS[unary.modifier] return s - def visit_binary(self, binary, **kwargs): - + def visit_binary(self, binary, **kw): + # don't allow "? = ?" to render + if self.ansi_bind_rules and \ + isinstance(binary.left, sql._BindParamClause) and \ + isinstance(binary.right, sql._BindParamClause): + kw['literal_binds'] = True + return self._operator_dispatch(binary.operator, binary, - lambda opstr: self.process(binary.left, **kwargs) + + lambda opstr: self.process(binary.left, **kw) + opstr + - self.process(binary.right, **kwargs), - **kwargs + self.process(binary.right, **kw), + **kw ) def visit_like_op(self, binary, **kw): @@ -517,8 +528,10 @@ class SQLCompiler(engine.Compiled): literal_binds=False, **kwargs): if literal_binds or \ (within_columns_clause and \ - not self.binds_in_columns_clause) and \ - bindparam.value is not None: + self.ansi_bind_rules): + if bindparam.value is None: + raise exc.CompileError("Bind parameter without a " + "renderable value not allowed here.") return self.render_literal_bindparam(bindparam, within_columns_clause=True, **kwargs) name = self._truncate_bindparam(bindparam) @@ -545,9 +558,9 @@ class SQLCompiler(engine.Compiled): processor = bindparam.bind_processor(self.dialect) if processor: value = processor(value) - return self.render_literal_value(value) + return self.render_literal_value(value, bindparam.type) - def render_literal_value(self, value): + def render_literal_value(self, value, type_): """Render the value of a bind parameter as a quoted literal. This is used for statement sections that do not accept bind paramters @@ -557,7 +570,17 @@ class SQLCompiler(engine.Compiled): of the DBAPI. """ - raise NotImplementedError() + if isinstance(value, basestring): + value = value.replace("'", "''") + return "'%s'" % value + elif value is None: + return "NULL" + elif isinstance(value, (float, int, long)): + return repr(value) + elif isinstance(value, decimal.Decimal): + return str(value) + else: + raise NotImplementedError("Don't know how to literal-quote value %r" % value) def _truncate_bindparam(self, bindparam): if bindparam in self.bind_names: |
