summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/connectors/pyodbc.py24
-rw-r--r--lib/sqlalchemy/dialects/access/base.py4
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py40
-rw-r--r--lib/sqlalchemy/dialects/mssql/mxodbc.py40
-rw-r--r--lib/sqlalchemy/dialects/mssql/pyodbc.py28
-rw-r--r--lib/sqlalchemy/dialects/mysql/oursql.py2
-rw-r--r--lib/sqlalchemy/dialects/sqlite/base.py4
-rw-r--r--lib/sqlalchemy/dialects/sybase/__init__.py2
-rw-r--r--lib/sqlalchemy/dialects/sybase/base.py57
-rw-r--r--lib/sqlalchemy/dialects/sybase/pyodbc.py13
-rw-r--r--lib/sqlalchemy/dialects/sybase/pysybase.py20
-rw-r--r--lib/sqlalchemy/engine/default.py1
-rw-r--r--lib/sqlalchemy/sql/compiler.py53
13 files changed, 174 insertions, 114 deletions
diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py
index ce8e84c33..e503135f7 100644
--- a/lib/sqlalchemy/connectors/pyodbc.py
+++ b/lib/sqlalchemy/connectors/pyodbc.py
@@ -19,6 +19,10 @@ class PyODBCConnector(Connector):
# hold the desired driver name
pyodbc_driver_name = None
+ # will be set to True after initialize()
+ # if the freetds.so is detected
+ freetds = False
+
@classmethod
def dbapi(cls):
return __import__('pyodbc')
@@ -76,6 +80,26 @@ class PyODBCConnector(Connector):
else:
return False
+ def initialize(self, connection):
+ # determine FreeTDS first. can't issue SQL easily
+ # without getting unicode_statements/binds set up.
+
+ pyodbc = self.dbapi
+
+ dbapi_con = connection.connection
+
+ self.freetds = bool(re.match(r".*libtdsodbc.*\.so", dbapi_con.getinfo(pyodbc.SQL_DRIVER_NAME)))
+
+ # the "Py2K only" part here is theoretical.
+ # have not tried pyodbc + python3.1 yet.
+ # Py2K
+ self.supports_unicode_statements = not self.freetds
+ self.supports_unicode_binds = not self.freetds
+ # end Py2K
+
+ # run other initialization which asks for user name, etc.
+ super(PyODBCConnector, self).initialize(connection)
+
def _get_server_version_info(self, connection):
dbapi_con = connection.connection
version = []
diff --git a/lib/sqlalchemy/dialects/access/base.py b/lib/sqlalchemy/dialects/access/base.py
index c10e77011..7dfb3153e 100644
--- a/lib/sqlalchemy/dialects/access/base.py
+++ b/lib/sqlalchemy/dialects/access/base.py
@@ -371,9 +371,9 @@ class AccessCompiler(compiler.SQLCompiler):
return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN ") + \
self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause))
- def visit_extract(self, extract):
+ def visit_extract(self, extract, **kw):
field = self.extract_map.get(extract.field, extract.field)
- return 'DATEPART("%s", %s)' % (field, self.process(extract.expr))
+ return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw))
class AccessDDLCompiler(compiler.DDLCompiler):
def get_column_specification(self, column, **kwargs):
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index eb4073b94..4d697854f 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -843,8 +843,9 @@ class MSExecutionContext(default.DefaultExecutionContext):
class MSSQLCompiler(compiler.SQLCompiler):
returning_precedes_values = True
- extract_map = compiler.SQLCompiler.extract_map.copy()
- extract_map.update ({
+ extract_map = util.update_copy(
+ compiler.SQLCompiler.extract_map,
+ {
'doy': 'dayofyear',
'dow': 'weekday',
'milliseconds': 'millisecond',
@@ -937,9 +938,9 @@ class MSSQLCompiler(compiler.SQLCompiler):
kwargs['mssql_aliased'] = True
return super(MSSQLCompiler, self).visit_alias(alias, **kwargs)
- def visit_extract(self, extract):
+ def visit_extract(self, extract, **kw):
field = self.extract_map.get(extract.field, extract.field)
- return 'DATEPART("%s", %s)' % (field, self.process(extract.expr))
+ return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw))
def visit_rollback_to_savepoint(self, savepoint_stmt):
return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt)
@@ -1020,6 +1021,37 @@ class MSSQLCompiler(compiler.SQLCompiler):
else:
return ""
+class MSSQLStrictCompiler(MSSQLCompiler):
+ """A subclass of MSSQLCompiler which disables the usage of bind
+ parameters where not allowed natively by MS-SQL.
+
+ A dialect may use this compiler on a platform where native
+ binds are used.
+
+ """
+ ansi_bind_rules = True
+
+ def visit_in_op(self, binary, **kw):
+ kw['literal_binds'] = True
+ return "%s IN %s" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw)
+ )
+
+ def visit_notin_op(self, binary, **kw):
+ kw['literal_binds'] = True
+ return "%s NOT IN %s" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw)
+ )
+
+ def visit_function(self, func, **kw):
+ kw['literal_binds'] = True
+ return super(MSSQLStrictCompiler, self).visit_function(func, **kw)
+
+ #def render_literal_value(self, value):
+ # TODO! use mxODBC's literal quoting services here
+
class MSDDLCompiler(compiler.DDLCompiler):
def get_column_specification(self, column, **kwargs):
diff --git a/lib/sqlalchemy/dialects/mssql/mxodbc.py b/lib/sqlalchemy/dialects/mssql/mxodbc.py
index 3dcc78b8c..59cf65d63 100644
--- a/lib/sqlalchemy/dialects/mssql/mxodbc.py
+++ b/lib/sqlalchemy/dialects/mssql/mxodbc.py
@@ -4,43 +4,9 @@ import sys
from sqlalchemy import types as sqltypes
from sqlalchemy.connectors.mxodbc import MxODBCConnector
from sqlalchemy.dialects.mssql.pyodbc import MSExecutionContext_pyodbc
-from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect, MSSQLCompiler
+from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect, \
+ MSSQLCompiler, MSSQLStrictCompiler
-# TODO: does Pyodbc on windows have the same limitations ?
-# if so this compiler can be moved to a common "odbc.py" module
-# here
-# *or* - should we implement this for MS-SQL across the board
-# since its technically MS-SQL's behavior ?
-# perhaps yes, with a dialect flag "strict_binds" to turn it off
-class MSSQLCompiler_mxodbc(MSSQLCompiler):
- binds_in_columns_clause = False
-
- def visit_in_op(self, binary, **kw):
- kw['literal_binds'] = True
- return "%s IN %s" % (
- self.process(binary.left, **kw),
- self.process(binary.right, **kw)
- )
-
- def visit_notin_op(self, binary, **kw):
- kw['literal_binds'] = True
- return "%s NOT IN %s" % (
- self.process(binary.left, **kw),
- self.process(binary.right, **kw)
- )
-
- def visit_function(self, func, **kw):
- kw['literal_binds'] = True
- return super(MSSQLCompiler_mxodbc, self).visit_function(func, **kw)
-
- def render_literal_value(self, value):
- # TODO! use mxODBC's literal quoting services here
- if isinstance(value, basestring):
- value = value.replace("'", "''")
- return "'%s'" % value
- else:
- return repr(value)
-
class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc):
"""
@@ -58,7 +24,7 @@ class MSDialect_mxodbc(MxODBCConnector, MSDialect):
# TODO: may want to use this only if FreeTDS is not in use,
# since FreeTDS doesn't seem to use native binds.
- statement_compiler = MSSQLCompiler_mxodbc
+ statement_compiler = MSSQLStrictCompiler
def __init__(self, description_encoding='latin-1', **params):
super(MSDialect_mxodbc, self).__init__(**params)
diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py
index 9ef065b1a..34050271f 100644
--- a/lib/sqlalchemy/dialects/mssql/pyodbc.py
+++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py
@@ -1,3 +1,16 @@
+"""
+Support for MS-SQL via pyodbc.
+
+http://pypi.python.org/pypi/pyodbc/
+
+Connect strings are of the form::
+
+ mssql+pyodbc://<username>:<password>@<dsn>/
+ mssql+pyodbc://<username>:<password>@<host>/<database>
+
+
+"""
+
from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect
from sqlalchemy.connectors.pyodbc import PyODBCConnector
from sqlalchemy import types as sqltypes
@@ -60,19 +73,4 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect):
self.description_encoding = description_encoding
self.use_scope_identity = self.dbapi and hasattr(self.dbapi.Cursor, 'nextset')
- def initialize(self, connection):
- super(MSDialect_pyodbc, self).initialize(connection)
- pyodbc = self.dbapi
-
- dbapi_con = connection.connection
-
- self.freetds = re.match(r".*libtdsodbc.*\.so", dbapi_con.getinfo(pyodbc.SQL_DRIVER_NAME))
-
- # the "Py2K only" part here is theoretical.
- # have not tried pyodbc + python3.1 yet.
- # Py2K
- self.supports_unicode_statements = not self.freetds
- self.supports_unicode_binds = not self.freetds
- # end Py2K
-
dialect = MSDialect_pyodbc
diff --git a/lib/sqlalchemy/dialects/mysql/oursql.py b/lib/sqlalchemy/dialects/mysql/oursql.py
index 6c8bbcac4..605b39760 100644
--- a/lib/sqlalchemy/dialects/mysql/oursql.py
+++ b/lib/sqlalchemy/dialects/mysql/oursql.py
@@ -56,8 +56,6 @@ class MySQLDialect_oursql(MySQLDialect):
driver = 'oursql'
# Py3K
# description_encoding = None
-# supports_unicode_binds = False
-# supports_unicode_statements = False
# Py2K
supports_unicode_binds = True
supports_unicode_statements = True
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
index 37e63fbc1..98df8d0cb 100644
--- a/lib/sqlalchemy/dialects/sqlite/base.py
+++ b/lib/sqlalchemy/dialects/sqlite/base.py
@@ -218,10 +218,10 @@ class SQLiteCompiler(compiler.SQLCompiler):
else:
return self.process(cast.clause)
- def visit_extract(self, extract):
+ def visit_extract(self, extract, **kw):
try:
return "CAST(STRFTIME('%s', %s) AS INTEGER)" % (
- self.extract_map[extract.field], self.process(extract.expr))
+ self.extract_map[extract.field], self.process(extract.expr, **kw))
except KeyError:
raise exc.ArgumentError(
"%s is not a valid extract argument." % extract.field)
diff --git a/lib/sqlalchemy/dialects/sybase/__init__.py b/lib/sqlalchemy/dialects/sybase/__init__.py
index 573aedde3..4d9b07007 100644
--- a/lib/sqlalchemy/dialects/sybase/__init__.py
+++ b/lib/sqlalchemy/dialects/sybase/__init__.py
@@ -1,4 +1,4 @@
-from sqlalchemy.dialects.sybase import base, pysybase
+from sqlalchemy.dialects.sybase import base, pysybase, pyodbc
from base import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\
diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py
index 2addba2f8..c440015d0 100644
--- a/lib/sqlalchemy/dialects/sybase/base.py
+++ b/lib/sqlalchemy/dialects/sybase/base.py
@@ -176,7 +176,19 @@ ischema_names = {
class SybaseExecutionContext(default.DefaultExecutionContext):
_enable_identity_insert = False
-
+
+ def set_ddl_autocommit(self, connection, value):
+ """Must be implemented by subclasses to accommodate DDL executions.
+
+ "connection" is the raw unwrapped DBAPI connection. "value"
+ is True or False. when True, the connection should be configured
+ such that a DDL can take place subsequently. when False,
+ a DDL has taken place and the connection should be resumed
+ into non-autocommit mode.
+
+ """
+ raise NotImplementedError()
+
def pre_exec(self):
if self.isinsert:
tbl = self.compiled.statement.table
@@ -192,7 +204,22 @@ class SybaseExecutionContext(default.DefaultExecutionContext):
self.cursor.execute("SET IDENTITY_INSERT %s ON" %
self.dialect.identifier_preparer.format_table(tbl))
+ if self.isddl:
+ # TODO: to enhance this, we can detect "ddl in tran" on the
+ # database settings. this error message should be improved to
+ # include a note about that.
+ if not self.should_autocommit:
+ raise exc.InvalidRequestError("The Sybase dialect only supports "
+ "DDL in 'autocommit' mode at this time.")
+
+ self.root_connection.engine.logger.info("AUTOCOMMIT (Assuming no Sybase 'ddl in tran')")
+
+ self.set_ddl_autocommit(self.root_connection.connection.connection, True)
+
+
def post_exec(self):
+ if self.isddl:
+ self.set_ddl_autocommit(self.root_connection, False)
if self._enable_identity_insert:
self.cursor.execute(
@@ -209,9 +236,11 @@ class SybaseExecutionContext(default.DefaultExecutionContext):
return lastrowid
class SybaseSQLCompiler(compiler.SQLCompiler):
+ ansi_bind_rules = True
- extract_map = compiler.SQLCompiler.extract_map.copy()
- extract_map.update ({
+ extract_map = util.update_copy(
+ compiler.SQLCompiler.extract_map,
+ {
'doy': 'dayofyear',
'dow': 'weekday',
'milliseconds': 'millisecond'
@@ -240,32 +269,16 @@ class SybaseSQLCompiler(compiler.SQLCompiler):
# Limit in sybase is after the select keyword
return ""
- def dont_visit_binary(self, binary):
- """Move bind parameters to the right-hand side of an operator, where possible."""
- if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq:
- return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator))
- else:
- return super(SybaseSQLCompiler, self).visit_binary(binary)
-
- def dont_label_select_column(self, select, column, asfrom):
- if isinstance(column, expression.Function):
- return column.label(None)
- else:
- return super(SybaseSQLCompiler, self).label_select_column(select, column, asfrom)
-
-# def visit_getdate_func(self, fn, **kw):
- # TODO: need to cast? something ?
-# pass
-
- def visit_extract(self, extract):
+ def visit_extract(self, extract, **kw):
field = self.extract_map.get(extract.field, extract.field)
- return 'DATEPART("%s", %s)' % (field, self.process(extract.expr))
+ return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw))
def for_update_clause(self, select):
# "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use
return ''
def order_by_clause(self, select, **kw):
+ kw['literal_binds'] = True
order_by = self.process(select._order_by_clause, **kw)
# SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT
diff --git a/lib/sqlalchemy/dialects/sybase/pyodbc.py b/lib/sqlalchemy/dialects/sybase/pyodbc.py
index 642ae3219..1bfdb6151 100644
--- a/lib/sqlalchemy/dialects/sybase/pyodbc.py
+++ b/lib/sqlalchemy/dialects/sybase/pyodbc.py
@@ -1,7 +1,12 @@
"""
Support for Sybase via pyodbc.
-This dialect is a stub only and is likely non functional at this time.
+http://pypi.python.org/pypi/pyodbc/
+
+Connect strings are of the form::
+
+ sybase+pyodbc://<username>:<password>@<dsn>/
+ sybase+pyodbc://<username>:<password>@<host>/<database>
"""
@@ -10,7 +15,11 @@ from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContex
from sqlalchemy.connectors.pyodbc import PyODBCConnector
class SybaseExecutionContext_pyodbc(SybaseExecutionContext):
- pass
+ def set_ddl_autocommit(self, connection, value):
+ if value:
+ connection.autocommit = True
+ else:
+ connection.autocommit = False
class SybaseDialect_pyodbc(PyODBCConnector, SybaseDialect):
diff --git a/lib/sqlalchemy/dialects/sybase/pysybase.py b/lib/sqlalchemy/dialects/sybase/pysybase.py
index 195407384..200ce11a2 100644
--- a/lib/sqlalchemy/dialects/sybase/pysybase.py
+++ b/lib/sqlalchemy/dialects/sybase/pysybase.py
@@ -20,6 +20,14 @@ from sqlalchemy.dialects.sybase.base import SybaseDialect, \
class SybaseExecutionContext_pysybase(SybaseExecutionContext):
+
+ def set_ddl_autocommit(self, dbapi_connection, value):
+ if value:
+ # call commit() on the Sybase connection directly,
+ # to avoid any side effects of calling a Connection
+ # transactional method inside of pre_exec()
+ dbapi_connection.commit()
+
def pre_exec(self):
SybaseExecutionContext.pre_exec(self)
@@ -28,18 +36,6 @@ class SybaseExecutionContext_pysybase(SybaseExecutionContext):
param["@" + key] = param[key]
del param[key]
- if self.isddl:
- # TODO: to enhance this, we can detect "ddl in tran" on the
- # database settings. this error message should be improved to
- # include a note about that.
- if not self.should_autocommit:
- raise exc.InvalidRequestError("The Sybase dialect only supports "
- "DDL in 'autocommit' mode at this time.")
- # call commit() on the Sybase connection directly,
- # to avoid any side effects of calling a Connection
- # transactional method inside of pre_exec()
- self.root_connection.engine.logger.info("COMMIT (Assuming no Sybase 'ddl in tran')")
- self.root_connection.connection.commit()
class SybaseSQLCompiler_pysybase(SybaseSQLCompiler):
def bindparam_string(self, name):
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index ce24a9ae4..2ef8fd104 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -62,6 +62,7 @@ class DefaultDialect(base.Dialect):
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
dbapi_type_map = {}
+ colspecs = {}
default_paramstyle = 'named'
supports_default_values = False
supports_empty_insert = True
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index a3008d085..4e9175ae8 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -26,6 +26,7 @@ import re
from sqlalchemy import schema, engine, util, exc
from sqlalchemy.sql import operators, functions, util as sql_util, visitors
from sqlalchemy.sql import expression as sql
+import decimal
RESERVED_WORDS = set([
'all', 'analyse', 'analyze', 'and', 'any', 'array',
@@ -184,10 +185,11 @@ class SQLCompiler(engine.Compiled):
returning_precedes_values = False
# SQL 92 doesn't allow bind parameters to be used
- # in the columns clause of a SELECT. A compiler
+ # in the columns clause of a SELECT, nor does it allow
+ # ambiguous expressions like "? = ?". A compiler
# subclass can set this flag to False if the target
# driver/DB enforces this
- binds_in_columns_clause = True
+ ansi_bind_rules = False
def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
"""Construct a new ``DefaultCompiler`` object.
@@ -303,7 +305,7 @@ class SQLCompiler(engine.Compiled):
def visit_grouping(self, grouping, asfrom=False, **kwargs):
return "(" + self.process(grouping.element, **kwargs) + ")"
- def visit_label(self, label, result_map=None, within_columns_clause=False):
+ def visit_label(self, label, result_map=None, within_columns_clause=False, **kw):
# only render labels within the columns clause
# or ORDER BY clause of a select. dialect-specific compilers
# can modify this behavior.
@@ -315,11 +317,15 @@ class SQLCompiler(engine.Compiled):
result_map[labelname.lower()] = \
(label.name, (label, label.element, labelname), label.element.type)
- return self.process(label.element) + \
+ return self.process(label.element,
+ within_columns_clause=within_columns_clause,
+ **kw) + \
OPERATORS[operators.as_] + \
self.preparer.format_label(label, labelname)
else:
- return self.process(label.element)
+ return self.process(label.element,
+ within_columns_clause=within_columns_clause,
+ **kw)
def visit_column(self, column, result_map=None, **kwargs):
name = column.name
@@ -465,14 +471,19 @@ class SQLCompiler(engine.Compiled):
s = s + OPERATORS[unary.modifier]
return s
- def visit_binary(self, binary, **kwargs):
-
+ def visit_binary(self, binary, **kw):
+ # don't allow "? = ?" to render
+ if self.ansi_bind_rules and \
+ isinstance(binary.left, sql._BindParamClause) and \
+ isinstance(binary.right, sql._BindParamClause):
+ kw['literal_binds'] = True
+
return self._operator_dispatch(binary.operator,
binary,
- lambda opstr: self.process(binary.left, **kwargs) +
+ lambda opstr: self.process(binary.left, **kw) +
opstr +
- self.process(binary.right, **kwargs),
- **kwargs
+ self.process(binary.right, **kw),
+ **kw
)
def visit_like_op(self, binary, **kw):
@@ -517,8 +528,10 @@ class SQLCompiler(engine.Compiled):
literal_binds=False, **kwargs):
if literal_binds or \
(within_columns_clause and \
- not self.binds_in_columns_clause) and \
- bindparam.value is not None:
+ self.ansi_bind_rules):
+ if bindparam.value is None:
+ raise exc.CompileError("Bind parameter without a "
+ "renderable value not allowed here.")
return self.render_literal_bindparam(bindparam, within_columns_clause=True, **kwargs)
name = self._truncate_bindparam(bindparam)
@@ -545,9 +558,9 @@ class SQLCompiler(engine.Compiled):
processor = bindparam.bind_processor(self.dialect)
if processor:
value = processor(value)
- return self.render_literal_value(value)
+ return self.render_literal_value(value, bindparam.type)
- def render_literal_value(self, value):
+ def render_literal_value(self, value, type_):
"""Render the value of a bind parameter as a quoted literal.
This is used for statement sections that do not accept bind paramters
@@ -557,7 +570,17 @@ class SQLCompiler(engine.Compiled):
of the DBAPI.
"""
- raise NotImplementedError()
+ if isinstance(value, basestring):
+ value = value.replace("'", "''")
+ return "'%s'" % value
+ elif value is None:
+ return "NULL"
+ elif isinstance(value, (float, int, long)):
+ return repr(value)
+ elif isinstance(value, decimal.Decimal):
+ return str(value)
+ else:
+ raise NotImplementedError("Don't know how to literal-quote value %r" % value)
def _truncate_bindparam(self, bindparam):
if bindparam in self.bind_names: