summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-04-02 21:36:11 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-04-02 21:36:11 +0000
commitcdceb3c3714af707bfe3ede10af6536eaf529ca8 (patch)
tree2ccbfb60cd10d995c0309801b0adc4fc3a1f0a44 /lib/sqlalchemy
parent8607de3159fd37923ae99118c499935c4a54d0e2 (diff)
downloadsqlalchemy-cdceb3c3714af707bfe3ede10af6536eaf529ca8.tar.gz
- merged the "execcontext" branch, refactors engine/dialect codepaths
- much more functionality moved into ExecutionContext, which impacted the API used by dialects to some degree - ResultProxy and subclasses now designed sanely - merged patch for #522, Unicode subclasses String directly, MSNVarchar implements for MS-SQL, removed MSUnicode. - String moves its "VARCHAR"/"TEXT" switchy thing into "get_search_list()" function, which VARCHAR and CHAR can override to not return TEXT in any case (didn't do the latter yet) - implements server side cursors for postgres, unit tests, #514 - includes overhaul of dbapi import strategy #480, all dbapi importing happens in dialect method "dbapi()", is only called inside of create_engine() for default and threadlocal strategies. Dialect subclasses have a datamember "dbapi" referencing the loaded module which may be None. - added "mock" engine strategy, doesn't require DBAPI module and gives you a "Connection" which just sends all executes to a callable. can be used to create string output of create_all()/drop_all().
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/ansisql.py48
-rw-r--r--lib/sqlalchemy/databases/firebird.py73
-rw-r--r--lib/sqlalchemy/databases/mssql.py159
-rw-r--r--lib/sqlalchemy/databases/mysql.py50
-rw-r--r--lib/sqlalchemy/databases/oracle.py75
-rw-r--r--lib/sqlalchemy/databases/postgres.py138
-rw-r--r--lib/sqlalchemy/databases/sqlite.py60
-rw-r--r--lib/sqlalchemy/engine/base.py518
-rw-r--r--lib/sqlalchemy/engine/default.py126
-rw-r--r--lib/sqlalchemy/engine/strategies.py64
-rw-r--r--lib/sqlalchemy/engine/url.py4
-rw-r--r--lib/sqlalchemy/logging.py4
-rw-r--r--lib/sqlalchemy/pool.py14
-rw-r--r--lib/sqlalchemy/sql.py2
-rw-r--r--lib/sqlalchemy/types.py82
-rw-r--r--lib/sqlalchemy/util.py4
16 files changed, 709 insertions, 712 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index a75263d91..03053b998 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -49,14 +49,11 @@ class ANSIDialect(default.DefaultDialect):
def create_connect_args(self):
return ([],{})
- def dbapi(self):
- return None
+ def schemagenerator(self, *args, **kwargs):
+ return ANSISchemaGenerator(self, *args, **kwargs)
- def schemagenerator(self, *args, **params):
- return ANSISchemaGenerator(*args, **params)
-
- def schemadropper(self, *args, **params):
- return ANSISchemaDropper(*args, **params)
+ def schemadropper(self, *args, **kwargs):
+ return ANSISchemaDropper(self, *args, **kwargs)
def compiler(self, statement, parameters, **kwargs):
return ANSICompiler(self, statement, parameters, **kwargs)
@@ -97,6 +94,9 @@ class ANSICompiler(sql.Compiled):
sql.Compiled.__init__(self, dialect, statement, parameters, **kwargs)
+ # if we are insert/update. set to true when we visit an INSERT or UPDATE
+ self.isinsert = self.isupdate = False
+
# a dictionary of bind parameter keys to _BindParamClause instances.
self.binds = {}
@@ -789,13 +789,12 @@ class ANSISchemaBase(engine.SchemaIterator):
return alterables
class ANSISchemaGenerator(ANSISchemaBase):
- def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs):
- super(ANSISchemaGenerator, self).__init__(engine, proxy, **kwargs)
+ def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
+ super(ANSISchemaGenerator, self).__init__(connection, **kwargs)
self.checkfirst = checkfirst
self.tables = tables and util.Set(tables) or None
- self.connection = connection
- self.preparer = self.engine.dialect.preparer()
- self.dialect = self.engine.dialect
+ self.preparer = dialect.preparer()
+ self.dialect = dialect
def get_column_specification(self, column, first_pk=False):
raise NotImplementedError()
@@ -804,7 +803,7 @@ class ANSISchemaGenerator(ANSISchemaBase):
collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))]
for table in collection:
table.accept_visitor(self)
- if self.supports_alter():
+ if self.dialect.supports_alter():
for alterable in self.find_alterables(collection):
self.add_foreignkey(alterable)
@@ -857,7 +856,7 @@ class ANSISchemaGenerator(ANSISchemaBase):
def _compile(self, tocompile, parameters):
"""compile the given string/parameters using this SchemaGenerator's dialect."""
- compiler = self.engine.dialect.compiler(tocompile, parameters)
+ compiler = self.dialect.compiler(tocompile, parameters)
compiler.compile()
return compiler
@@ -880,11 +879,8 @@ class ANSISchemaGenerator(ANSISchemaBase):
self.append("PRIMARY KEY ")
self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', ')))
- def supports_alter(self):
- return True
-
def visit_foreign_key_constraint(self, constraint):
- if constraint.use_alter and self.supports_alter():
+ if constraint.use_alter and self.dialect.supports_alter():
return
self.append(", \n\t ")
self.define_foreign_key(constraint)
@@ -927,25 +923,21 @@ class ANSISchemaGenerator(ANSISchemaBase):
self.execute()
class ANSISchemaDropper(ANSISchemaBase):
- def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs):
- super(ANSISchemaDropper, self).__init__(engine, proxy, **kwargs)
+ def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
+ super(ANSISchemaDropper, self).__init__(connection, **kwargs)
self.checkfirst = checkfirst
self.tables = tables
- self.connection = connection
- self.preparer = self.engine.dialect.preparer()
- self.dialect = self.engine.dialect
+ self.preparer = dialect.preparer()
+ self.dialect = dialect
def visit_metadata(self, metadata):
collection = [t for t in metadata.table_iterator(reverse=True, tables=self.tables) if (not self.checkfirst or self.dialect.has_table(self.connection, t.name, schema=t.schema))]
- if self.supports_alter():
+ if self.dialect.supports_alter():
for alterable in self.find_alterables(collection):
self.drop_foreignkey(alterable)
for table in collection:
table.accept_visitor(self)
- def supports_alter(self):
- return True
-
def visit_index(self, index):
self.append("\nDROP INDEX " + index.name)
self.execute()
@@ -1099,3 +1091,5 @@ class ANSIIdentifierPreparer(object):
"""Prepare a quoted column name with table name."""
return self.format_column(column, use_table=True, name=column_name)
+
+dialect = ANSIDialect
diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py
index 91a0869c6..2ab88101a 100644
--- a/lib/sqlalchemy/databases/firebird.py
+++ b/lib/sqlalchemy/databases/firebird.py
@@ -15,12 +15,9 @@ import sqlalchemy.ansisql as ansisql
import sqlalchemy.types as sqltypes
import sqlalchemy.exceptions as exceptions
-try:
+def dbapi():
import kinterbasdb
-except:
- kinterbasdb = None
-
-dbmodule = kinterbasdb
+ return kinterbasdb
_initialized_kb = False
@@ -33,7 +30,6 @@ class FBNumeric(sqltypes.Numeric):
return "NUMERIC(%(precision)s, %(length)s)" % { 'precision': self.precision,
'length' : self.length }
-
class FBInteger(sqltypes.Integer):
def get_col_spec(self):
return "INTEGER"
@@ -111,24 +107,11 @@ class FBExecutionContext(default.DefaultExecutionContext):
class FBDialect(ansisql.ANSIDialect):
- def __init__(self, module = None, **params):
- global _initialized_kb
- self.module = module or dbmodule
- self.opts = {}
-
- if not _initialized_kb:
- _initialized_kb = True
- type_conv = params.get('type_conv', 200) or 200
- if isinstance(type_conv, types.StringTypes):
- type_conv = int(type_conv)
-
- concurrency_level = params.get('concurrency_level', 1) or 1
- if isinstance(concurrency_level, types.StringTypes):
- concurrency_level = int(concurrency_level)
+ def __init__(self, type_conv=200, concurrency_level=1, **kwargs):
+ ansisql.ANSIDialect.__init__(self, **kwargs)
- if kinterbasdb is not None:
- kinterbasdb.init(type_conv=type_conv, concurrency_level=concurrency_level)
- ansisql.ANSIDialect.__init__(self, **params)
+ self.type_conv = type_conv
+ self.concurrency_level= concurrency_level
def create_connect_args(self, url):
opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port'])
@@ -136,15 +119,17 @@ class FBDialect(ansisql.ANSIDialect):
opts['host'] = "%s/%s" % (opts['host'], opts['port'])
del opts['port']
opts.update(url.query)
- # pop arguments that we took at the module level
- opts.pop('type_conv', None)
- opts.pop('concurrency_level', None)
- self.opts = opts
- return ([], self.opts)
+ type_conv = opts.pop('type_conv', self.type_conv)
+ concurrency_level = opts.pop('concurrency_level', self.concurrency_level)
+ global _initialized_kb
+ if not _initialized_kb and self.dbapi is not None:
+ _initialized_kb = True
+ self.dbapi.init(type_conv=type_conv, concurrency_level=concurrency_level)
+ return ([], opts)
- def create_execution_context(self):
- return FBExecutionContext(self)
+ def create_execution_context(self, *args, **kwargs):
+ return FBExecutionContext(self, *args, **kwargs)
def type_descriptor(self, typeobj):
return sqltypes.adapt_type(typeobj, colspecs)
@@ -156,13 +141,13 @@ class FBDialect(ansisql.ANSIDialect):
return FBCompiler(self, statement, bindparams, **kwargs)
def schemagenerator(self, *args, **kwargs):
- return FBSchemaGenerator(*args, **kwargs)
+ return FBSchemaGenerator(self, *args, **kwargs)
def schemadropper(self, *args, **kwargs):
- return FBSchemaDropper(*args, **kwargs)
+ return FBSchemaDropper(self, *args, **kwargs)
- def defaultrunner(self, engine, proxy):
- return FBDefaultRunner(engine, proxy)
+ def defaultrunner(self, connection):
+ return FBDefaultRunner(connection)
def preparer(self):
return FBIdentifierPreparer(self)
@@ -292,9 +277,6 @@ class FBDialect(ansisql.ANSIDialect):
for name,value in fks.iteritems():
table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name))
- def last_inserted_ids(self):
- return self.context.last_inserted_ids
-
def do_execute(self, cursor, statement, parameters, **kwargs):
cursor.execute(statement, parameters or [])
@@ -304,15 +286,6 @@ class FBDialect(ansisql.ANSIDialect):
def do_commit(self, connection):
connection.commit(True)
- def connection(self):
- """Returns a managed DBAPI connection from this SQLEngine's connection pool."""
- c = self._pool.connect()
- c.supportsTransactions = 0
- return c
-
- def dbapi(self):
- return self.module
-
class FBCompiler(ansisql.ANSICompiler):
"""Firebird specific idiosincrasies"""
@@ -364,7 +337,7 @@ class FBCompiler(ansisql.ANSICompiler):
class FBSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column)
- colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
+ colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
@@ -388,11 +361,11 @@ class FBSchemaDropper(ansisql.ANSISchemaDropper):
class FBDefaultRunner(ansisql.ANSIDefaultRunner):
def exec_default_sql(self, default):
- c = sql.select([default.arg], from_obj=["rdb$database"], engine=self.engine).compile()
- return self.proxy(str(c), c.get_params()).fetchone()[0]
+ c = sql.select([default.arg], from_obj=["rdb$database"]).compile(engine=self.engine)
+ return self.connection.execute_compiled(c).scalar()
def visit_sequence(self, seq):
- return self.proxy("SELECT gen_id(" + seq.name + ", 1) FROM rdb$database").fetchone()[0]
+ return self.connection.execute_text("SELECT gen_id(" + seq.name + ", 1) FROM rdb$database").scalar()
RESERVED_WORDS = util.Set(
diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py
index 1852edefb..6d2ff66cd 100644
--- a/lib/sqlalchemy/databases/mssql.py
+++ b/lib/sqlalchemy/databases/mssql.py
@@ -52,7 +52,22 @@ import sqlalchemy.ansisql as ansisql
import sqlalchemy.types as sqltypes
import sqlalchemy.exceptions as exceptions
-
+def dbapi(module_name=None):
+ if module_name:
+ try:
+ dialect_cls = dialect_mapping[module_name]
+ return dialect_cls.import_dbapi()
+ except KeyError:
+ raise exceptions.InvalidRequestError("Unsupported MSSQL module '%s' requested (must be adodbpi, pymssql or pyodbc)" % module_name)
+ else:
+ for dialect_cls in [MSSQLDialect_adodbapi, MSSQLDialect_pymssql, MSSQLDialect_pyodbc]:
+ try:
+ return dialect_cls.import_dbapi()
+ except ImportError, e:
+ pass
+ else:
+ raise ImportError('No DBAPI module detected for MSSQL - please install adodbapi, pymssql or pyodbc')
+
class MSNumeric(sqltypes.Numeric):
def convert_result_value(self, value, dialect):
return value
@@ -142,9 +157,6 @@ class MSString(sqltypes.String):
return "VARCHAR(%(length)s)" % {'length' : self.length}
class MSNVarchar(MSString):
- """NVARCHAR string, does Unicode conversion if `dialect.convert_encoding` is True. """
- impl = sqltypes.Unicode
-
def get_col_spec(self):
if self.length:
return "NVARCHAR(%(length)s)" % {'length' : self.length}
@@ -154,19 +166,7 @@ class MSNVarchar(MSString):
return "NTEXT"
class AdoMSNVarchar(MSNVarchar):
- def convert_bind_param(self, value, dialect):
- return value
-
- def convert_result_value(self, value, dialect):
- return value
-
-class MSUnicode(sqltypes.Unicode):
- """Unicode subclass, does Unicode conversion in all cases, uses NVARCHAR impl."""
- impl = MSNVarchar
-
-class AdoMSUnicode(MSUnicode):
- impl = AdoMSNVarchar
-
+ """overrides bindparam/result processing to not convert any unicode strings"""
def convert_bind_param(self, value, dialect):
return value
@@ -215,9 +215,9 @@ def descriptor():
]}
class MSSQLExecutionContext(default.DefaultExecutionContext):
- def __init__(self, dialect):
+ def __init__(self, *args, **kwargs):
self.IINSERT = self.HASIDENT = False
- super(MSSQLExecutionContext, self).__init__(dialect)
+ super(MSSQLExecutionContext, self).__init__(*args, **kwargs)
def _has_implicit_sequence(self, column):
if column.primary_key and column.autoincrement:
@@ -227,14 +227,14 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
return True
return False
- def pre_exec(self, engine, proxy, compiled, parameters, **kwargs):
+ def pre_exec(self):
"""MS-SQL has a special mode for inserting non-NULL values
into IDENTITY columns.
Activate it if the feature is turned on and needed.
"""
- if getattr(compiled, "isinsert", False):
- tbl = compiled.statement.table
+ if self.compiled.isinsert:
+ tbl = self.compiled.statement.table
if not hasattr(tbl, 'has_sequence'):
tbl.has_sequence = None
for column in tbl.c:
@@ -243,39 +243,43 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
break
self.HASIDENT = bool(tbl.has_sequence)
- if engine.dialect.auto_identity_insert and self.HASIDENT:
- if isinstance(parameters, list):
- self.IINSERT = tbl.has_sequence.key in parameters[0]
+ if self.dialect.auto_identity_insert and self.HASIDENT:
+ if isinstance(self.compiled_parameters, list):
+ self.IINSERT = tbl.has_sequence.key in self.compiled_parameters[0]
else:
- self.IINSERT = tbl.has_sequence.key in parameters
+ self.IINSERT = tbl.has_sequence.key in self.compiled_parameters
else:
self.IINSERT = False
if self.IINSERT:
- proxy("SET IDENTITY_INSERT %s ON" % compiled.statement.table.name)
+ # TODO: quoting rules for table name here ?
+ self.cursor.execute("SET IDENTITY_INSERT %s ON" % self.compiled.statement.table.name)
- super(MSSQLExecutionContext, self).pre_exec(engine, proxy, compiled, parameters, **kwargs)
+ super(MSSQLExecutionContext, self).pre_exec()
- def post_exec(self, engine, proxy, compiled, parameters, **kwargs):
+ def post_exec(self):
"""Turn off the INDENTITY_INSERT mode if it's been activated,
and fetch recently inserted IDENTIFY values (works only for
one column).
"""
- if getattr(compiled, "isinsert", False):
+ if self.compiled.isinsert:
if self.IINSERT:
- proxy("SET IDENTITY_INSERT %s OFF" % compiled.statement.table.name)
+ # TODO: quoting rules for table name here ?
+ self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.compiled.statement.table.name)
self.IINSERT = False
elif self.HASIDENT:
- cursor = proxy("SELECT @@IDENTITY AS lastrowid")
- row = cursor.fetchone()
+ self.cursor.execute("SELECT @@IDENTITY AS lastrowid")
+ row = self.cursor.fetchone()
self._last_inserted_ids = [int(row[0])]
# print "LAST ROW ID", self._last_inserted_ids
self.HASIDENT = False
+ super(MSSQLExecutionContext, self).post_exec()
class MSSQLDialect(ansisql.ANSIDialect):
colspecs = {
+ sqltypes.Unicode : MSNVarchar,
sqltypes.Integer : MSInteger,
sqltypes.Smallinteger: MSSmallInteger,
sqltypes.Numeric : MSNumeric,
@@ -283,7 +287,6 @@ class MSSQLDialect(ansisql.ANSIDialect):
sqltypes.DateTime : MSDateTime,
sqltypes.Date : MSDate,
sqltypes.String : MSString,
- sqltypes.Unicode : MSUnicode,
sqltypes.Binary : MSBinary,
sqltypes.Boolean : MSBoolean,
sqltypes.TEXT : MSText,
@@ -296,7 +299,7 @@ class MSSQLDialect(ansisql.ANSIDialect):
'smallint' : MSSmallInteger,
'tinyint' : MSTinyInteger,
'varchar' : MSString,
- 'nvarchar' : MSUnicode,
+ 'nvarchar' : MSNVarchar,
'char' : MSChar,
'nchar' : MSNChar,
'text' : MSText,
@@ -312,30 +315,16 @@ class MSSQLDialect(ansisql.ANSIDialect):
'image' : MSBinary
}
- def __new__(cls, module_name=None, *args, **kwargs):
- module = kwargs.get('module', None)
+ def __new__(cls, dbapi=None, *args, **kwargs):
if cls != MSSQLDialect:
return super(MSSQLDialect, cls).__new__(cls, *args, **kwargs)
- if module_name:
- dialect = dialect_mapping.get(module_name)
- if not dialect:
- raise exceptions.InvalidRequestError('Unsupported MSSQL module requested (must be adodbpi, pymssql or pyodbc): ' + module_name)
- if not hasattr(dialect, 'module'):
- raise dialect.saved_import_error
+ if dbapi:
+ dialect = dialect_mapping.get(dbapi.__name__)
return dialect(*args, **kwargs)
- elif module:
- return object.__new__(cls, *args, **kwargs)
else:
- for dialect in dialect_preference:
- if hasattr(dialect, 'module'):
- return dialect(*args, **kwargs)
- #raise ImportError('No DBAPI module detected for MSSQL - please install adodbapi, pymssql or pyodbc')
- else:
- return object.__new__(cls, *args, **kwargs)
+ return object.__new__(cls, *args, **kwargs)
- def __init__(self, module_name=None, module=None, auto_identity_insert=True, **params):
- if not hasattr(self, 'module'):
- self.module = module
+ def __init__(self, auto_identity_insert=True, **params):
super(MSSQLDialect, self).__init__(**params)
self.auto_identity_insert = auto_identity_insert
self.text_as_varchar = False
@@ -352,8 +341,8 @@ class MSSQLDialect(ansisql.ANSIDialect):
self.text_as_varchar = bool(opts.pop('text_as_varchar'))
return self.make_connect_string(opts)
- def create_execution_context(self):
- return MSSQLExecutionContext(self)
+ def create_execution_context(self, *args, **kwargs):
+ return MSSQLExecutionContext(self, *args, **kwargs)
def type_descriptor(self, typeobj):
newobj = sqltypes.adapt_type(typeobj, self.colspecs)
@@ -373,13 +362,13 @@ class MSSQLDialect(ansisql.ANSIDialect):
return MSSQLCompiler(self, statement, bindparams, **kwargs)
def schemagenerator(self, *args, **kwargs):
- return MSSQLSchemaGenerator(*args, **kwargs)
+ return MSSQLSchemaGenerator(self, *args, **kwargs)
def schemadropper(self, *args, **kwargs):
- return MSSQLSchemaDropper(*args, **kwargs)
+ return MSSQLSchemaDropper(self, *args, **kwargs)
- def defaultrunner(self, engine, proxy):
- return MSSQLDefaultRunner(engine, proxy)
+ def defaultrunner(self, connection, **kwargs):
+ return MSSQLDefaultRunner(connection, **kwargs)
def preparer(self):
return MSSQLIdentifierPreparer(self)
@@ -411,19 +400,12 @@ class MSSQLDialect(ansisql.ANSIDialect):
def raw_connection(self, connection):
"""Pull the raw pymmsql connection out--sensative to "pool.ConnectionFairy" and pymssql.pymssqlCnx Classes"""
try:
+ # TODO: probably want to move this to individual dialect subclasses to
+ # save on the exception throw + simplify
return connection.connection.__dict__['_pymssqlCnx__cnx']
except:
return connection.connection.adoConn
- def connection(self):
- """returns a managed DBAPI connection from this SQLEngine's connection pool."""
- c = self._pool.connect()
- c.supportsTransactions = 0
- return c
-
- def dbapi(self):
- return self.module
-
def uppercase_table(self, t):
# convert all names to uppercase -- fixes refs to INFORMATION_SCHEMA for case-senstive DBs, and won't matter for case-insensitive
t.name = t.name.upper()
@@ -558,13 +540,14 @@ class MSSQLDialect(ansisql.ANSIDialect):
table.append_constraint(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm))
class MSSQLDialect_pymssql(MSSQLDialect):
- try:
+ def import_dbapi(cls):
import pymssql as module
# pymmsql doesn't have a Binary method. we use string
+ # TODO: monkeypatching here is less than ideal
module.Binary = lambda st: str(st)
- except ImportError, e:
- saved_import_error = e
-
+ return module
+ import_dbapi = classmethod(import_dbapi)
+
def supports_sane_rowcount(self):
return True
@@ -578,7 +561,7 @@ class MSSQLDialect_pymssql(MSSQLDialect):
def create_connect_args(self, url):
r = super(MSSQLDialect_pymssql, self).create_connect_args(url)
if hasattr(self, 'query_timeout'):
- self.module._mssql.set_query_timeout(self.query_timeout)
+ self.dbapi._mssql.set_query_timeout(self.query_timeout)
return r
def make_connect_string(self, keys):
@@ -621,15 +604,16 @@ class MSSQLDialect_pymssql(MSSQLDialect):
## r.fetch_array()
class MSSQLDialect_pyodbc(MSSQLDialect):
- try:
+
+ def import_dbapi(cls):
import pyodbc as module
- except ImportError, e:
- saved_import_error = e
-
+ return module
+ import_dbapi = classmethod(import_dbapi)
+
colspecs = MSSQLDialect.colspecs.copy()
- colspecs[sqltypes.Unicode] = AdoMSUnicode
+ colspecs[sqltypes.Unicode] = AdoMSNVarchar
ischema_names = MSSQLDialect.ischema_names.copy()
- ischema_names['nvarchar'] = AdoMSUnicode
+ ischema_names['nvarchar'] = AdoMSNVarchar
def supports_sane_rowcount(self):
return False
@@ -648,15 +632,15 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
class MSSQLDialect_adodbapi(MSSQLDialect):
- try:
+ def import_dbapi(cls):
import adodbapi as module
- except ImportError, e:
- saved_import_error = e
+ return module
+ import_dbapi = classmethod(import_dbapi)
colspecs = MSSQLDialect.colspecs.copy()
- colspecs[sqltypes.Unicode] = AdoMSUnicode
+ colspecs[sqltypes.Unicode] = AdoMSNVarchar
ischema_names = MSSQLDialect.ischema_names.copy()
- ischema_names['nvarchar'] = AdoMSUnicode
+ ischema_names['nvarchar'] = AdoMSNVarchar
def supports_sane_rowcount(self):
return True
@@ -676,13 +660,11 @@ class MSSQLDialect_adodbapi(MSSQLDialect):
connectors.append("Integrated Security=SSPI")
return [[";".join (connectors)], {}]
-
dialect_mapping = {
'pymssql': MSSQLDialect_pymssql,
'pyodbc': MSSQLDialect_pyodbc,
'adodbapi': MSSQLDialect_adodbapi
}
-dialect_preference = [MSSQLDialect_adodbapi, MSSQLDialect_pymssql, MSSQLDialect_pyodbc]
class MSSQLCompiler(ansisql.ANSICompiler):
@@ -770,7 +752,7 @@ class MSSQLCompiler(ansisql.ANSICompiler):
class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
- colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec()
+ colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
# install a IDENTITY Sequence if we have an implicit IDENTITY column
if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
@@ -797,6 +779,7 @@ class MSSQLSchemaDropper(ansisql.ANSISchemaDropper):
self.execute()
class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner):
+ # TODO: does ms-sql have standalone sequences ?
pass
class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py
index 5fc63234a..65ccb6af1 100644
--- a/lib/sqlalchemy/databases/mysql.py
+++ b/lib/sqlalchemy/databases/mysql.py
@@ -12,12 +12,9 @@ import sqlalchemy.types as sqltypes
import sqlalchemy.exceptions as exceptions
from array import array
-try:
+def dbapi():
import MySQLdb as mysql
- import MySQLdb.constants.CLIENT as CLIENT_FLAGS
-except:
- mysql = None
- CLIENT_FLAGS = None
+ return mysql
def kw_colspec(self, spec):
if self.unsigned:
@@ -158,8 +155,6 @@ class MSLongText(MSText):
return "LONGTEXT"
class MSString(sqltypes.String):
- def __init__(self, length=None, *extra, **kwargs):
- sqltypes.String.__init__(self, length=length)
def get_col_spec(self):
return "VARCHAR(%(length)s)" % {'length' : self.length}
@@ -277,16 +272,12 @@ def descriptor():
]}
class MySQLExecutionContext(default.DefaultExecutionContext):
- def post_exec(self, engine, proxy, compiled, parameters, **kwargs):
- if getattr(compiled, "isinsert", False):
- self._last_inserted_ids = [proxy().lastrowid]
+ def post_exec(self):
+ if self.compiled.isinsert:
+ self._last_inserted_ids = [self.cursor.lastrowid]
class MySQLDialect(ansisql.ANSIDialect):
- def __init__(self, module = None, **kwargs):
- if module is None:
- self.module = mysql
- else:
- self.module = module
+ def __init__(self, **kwargs):
ansisql.ANSIDialect.__init__(self, default_paramstyle='format', **kwargs)
def create_connect_args(self, url):
@@ -305,14 +296,18 @@ class MySQLDialect(ansisql.ANSIDialect):
# TODO: what about options like "ssl", "cursorclass" and "conv" ?
client_flag = opts.get('client_flag', 0)
- if CLIENT_FLAGS is not None:
- client_flag |= CLIENT_FLAGS.FOUND_ROWS
+ if self.dbapi is not None:
+ try:
+ import MySQLdb.constants.CLIENT as CLIENT_FLAGS
+ client_flag |= CLIENT_FLAGS.FOUND_ROWS
+ except:
+ pass
opts['client_flag'] = client_flag
return [[], opts]
- def create_execution_context(self):
- return MySQLExecutionContext(self)
+ def create_execution_context(self, *args, **kwargs):
+ return MySQLExecutionContext(self, *args, **kwargs)
def type_descriptor(self, typeobj):
return sqltypes.adapt_type(typeobj, colspecs)
@@ -324,10 +319,10 @@ class MySQLDialect(ansisql.ANSIDialect):
return MySQLCompiler(self, statement, bindparams, **kwargs)
def schemagenerator(self, *args, **kwargs):
- return MySQLSchemaGenerator(*args, **kwargs)
+ return MySQLSchemaGenerator(self, *args, **kwargs)
def schemadropper(self, *args, **kwargs):
- return MySQLSchemaDropper(*args, **kwargs)
+ return MySQLSchemaDropper(self, *args, **kwargs)
def preparer(self):
return MySQLIdentifierPreparer(self)
@@ -337,14 +332,14 @@ class MySQLDialect(ansisql.ANSIDialect):
rowcount = cursor.executemany(statement, parameters)
if context is not None:
context._rowcount = rowcount
- except mysql.OperationalError, o:
+ except self.dbapi.OperationalError, o:
if o.args[0] == 2006 or o.args[0] == 2014:
cursor.invalidate()
raise o
def do_execute(self, cursor, statement, parameters, **kwargs):
try:
cursor.execute(statement, parameters)
- except mysql.OperationalError, o:
+ except self.dbapi.OperationalError, o:
if o.args[0] == 2006 or o.args[0] == 2014:
cursor.invalidate()
raise o
@@ -361,11 +356,9 @@ class MySQLDialect(ansisql.ANSIDialect):
self._default_schema_name = text("select database()", self).scalar()
return self._default_schema_name
- def dbapi(self):
- return self.module
-
def has_table(self, connection, table_name, schema=None):
- cursor = connection.execute("show table status like '" + table_name + "'")
+ cursor = connection.execute("show table status like %s", [table_name])
+ print "CURSOR", cursor, "ROWCOUNT", cursor.rowcount, "REAL RC", cursor.cursor.rowcount
return bool( not not cursor.rowcount )
def reflecttable(self, connection, table):
@@ -492,8 +485,7 @@ class MySQLCompiler(ansisql.ANSICompiler):
class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, first_pk=False):
- t = column.type.engine_impl(self.engine)
- colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec()
+ colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py
index adea127bf..5377759a2 100644
--- a/lib/sqlalchemy/databases/oracle.py
+++ b/lib/sqlalchemy/databases/oracle.py
@@ -8,15 +8,13 @@
import sys, StringIO, string, re
from sqlalchemy import util, sql, engine, schema, ansisql, exceptions, logging
-import sqlalchemy.engine.default as default
+from sqlalchemy.engine import default, base
import sqlalchemy.types as sqltypes
-try:
+def dbapi():
import cx_Oracle
-except:
- cx_Oracle = None
+ return cx_Oracle
-ORACLE_BINARY_TYPES = [getattr(cx_Oracle, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB", "LONG_BINARY", "LONG_STRING"] if hasattr(cx_Oracle, k)]
class OracleNumeric(sqltypes.Numeric):
def get_col_spec(self):
@@ -149,26 +147,32 @@ def descriptor():
]}
class OracleExecutionContext(default.DefaultExecutionContext):
- def pre_exec(self, engine, proxy, compiled, parameters):
- super(OracleExecutionContext, self).pre_exec(engine, proxy, compiled, parameters)
+ def pre_exec(self):
+ super(OracleExecutionContext, self).pre_exec()
if self.dialect.auto_setinputsizes:
- self.set_input_sizes(proxy(), parameters)
+ self.set_input_sizes()
+
+ def get_result_proxy(self):
+ if self.cursor.description is not None:
+ for column in self.cursor.description:
+ type_code = column[1]
+ if type_code in self.dialect.ORACLE_BINARY_TYPES:
+ return base.BufferedColumnResultProxy(self)
+
+ return base.ResultProxy(self)
class OracleDialect(ansisql.ANSIDialect):
- def __init__(self, use_ansi=True, auto_setinputsizes=True, module=None, threaded=True, **kwargs):
+ def __init__(self, use_ansi=True, auto_setinputsizes=True, threaded=True, **kwargs):
+ ansisql.ANSIDialect.__init__(self, default_paramstyle='named', **kwargs)
self.use_ansi = use_ansi
self.threaded = threaded
- if module is None:
- self.module = cx_Oracle
- else:
- self.module = module
- self.supports_timestamp = hasattr(self.module, 'TIMESTAMP' )
+ self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' )
self.auto_setinputsizes = auto_setinputsizes
- ansisql.ANSIDialect.__init__(self, **kwargs)
-
- def dbapi(self):
- return self.module
-
+ if self.dbapi is not None:
+ self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB", "LONG_BINARY", "LONG_STRING"] if hasattr(self.dbapi, k)]
+ else:
+ self.ORACLE_BINARY_TYPES = []
+
def create_connect_args(self, url):
if url.database:
# if we have a database, then we have a remote host
@@ -177,7 +181,7 @@ class OracleDialect(ansisql.ANSIDialect):
port = int(port)
else:
port = 1521
- dsn = self.module.makedsn(url.host,port,url.database)
+ dsn = self.dbapi.makedsn(url.host,port,url.database)
else:
# we have a local tnsname
dsn = url.host
@@ -206,20 +210,20 @@ class OracleDialect(ansisql.ANSIDialect):
else:
return "rowid"
- def create_execution_context(self):
- return OracleExecutionContext(self)
+ def create_execution_context(self, *args, **kwargs):
+ return OracleExecutionContext(self, *args, **kwargs)
def compiler(self, statement, bindparams, **kwargs):
return OracleCompiler(self, statement, bindparams, **kwargs)
def schemagenerator(self, *args, **kwargs):
- return OracleSchemaGenerator(*args, **kwargs)
+ return OracleSchemaGenerator(self, *args, **kwargs)
def schemadropper(self, *args, **kwargs):
- return OracleSchemaDropper(*args, **kwargs)
+ return OracleSchemaDropper(self, *args, **kwargs)
- def defaultrunner(self, engine, proxy):
- return OracleDefaultRunner(engine, proxy)
+ def defaultrunner(self, connection, **kwargs):
+ return OracleDefaultRunner(connection, **kwargs)
def has_table(self, connection, table_name, schema=None):
cursor = connection.execute("""select table_name from all_tables where table_name=:name""", {'name':table_name.upper()})
@@ -405,15 +409,6 @@ class OracleDialect(ansisql.ANSIDialect):
if context is not None:
context._rowcount = rowcount
- def create_result_proxy_args(self, connection, cursor):
- args = super(OracleDialect, self).create_result_proxy_args(connection, cursor)
- if cursor and cursor.description:
- for column in cursor.description:
- type_code = column[1]
- if type_code in ORACLE_BINARY_TYPES:
- args['should_prefetch'] = True
- break
- return args
OracleDialect.logger = logging.class_logger(OracleDialect)
@@ -569,7 +564,7 @@ class OracleCompiler(ansisql.ANSICompiler):
class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column)
- colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
+ colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
@@ -579,22 +574,22 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
return colspec
def visit_sequence(self, sequence):
- if not self.engine.dialect.has_sequence(self.connection, sequence.name):
+ if not self.dialect.has_sequence(self.connection, sequence.name):
self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
self.execute()
class OracleSchemaDropper(ansisql.ANSISchemaDropper):
def visit_sequence(self, sequence):
- if self.engine.dialect.has_sequence(self.connection, sequence.name):
+ if self.dialect.has_sequence(self.connection, sequence.name):
self.append("DROP SEQUENCE %s" % sequence.name)
self.execute()
class OracleDefaultRunner(ansisql.ANSIDefaultRunner):
def exec_default_sql(self, default):
c = sql.select([default.arg], from_obj=["DUAL"], engine=self.engine).compile()
- return self.proxy(str(c), c.get_params()).fetchone()[0]
+ return self.connection.execute_compiled(c).scalar()
def visit_sequence(self, seq):
- return self.proxy("SELECT " + seq.name + ".nextval FROM DUAL").fetchone()[0]
+ return self.connection.execute_text("SELECT " + seq.name + ".nextval FROM DUAL").scalar()
dialect = OracleDialect
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py
index d83607793..2943d163e 100644
--- a/lib/sqlalchemy/databases/postgres.py
+++ b/lib/sqlalchemy/databases/postgres.py
@@ -4,33 +4,28 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import datetime, sys, StringIO, string, types, re
-
-import sqlalchemy.util as util
-import sqlalchemy.sql as sql
-import sqlalchemy.engine as engine
-import sqlalchemy.engine.default as default
-import sqlalchemy.schema as schema
-import sqlalchemy.ansisql as ansisql
+import datetime, string, types, re, random
+
+from sqlalchemy import util, sql, schema, ansisql, exceptions
+from sqlalchemy.engine import base, default
import sqlalchemy.types as sqltypes
-import sqlalchemy.exceptions as exceptions
from sqlalchemy.databases import information_schema as ischema
-import re
try:
import mx.DateTime.DateTime as mxDateTime
except:
mxDateTime = None
-try:
- import psycopg2 as psycopg
- #import psycopg2.psycopg1 as psycopg
-except:
+def dbapi():
try:
- import psycopg
- except:
- psycopg = None
-
+ import psycopg2 as psycopg
+ except ImportError, e:
+ try:
+ import psycopg
+ except ImportError, e2:
+ raise e
+ return psycopg
+
class PGInet(sqltypes.TypeEngine):
def get_col_spec(self):
return "INET"
@@ -74,8 +69,8 @@ class PG1DateTime(sqltypes.DateTime):
mx_datetime = mxDateTime(value.year, value.month, value.day,
value.hour, value.minute,
seconds)
- return psycopg.TimestampFromMx(mx_datetime)
- return psycopg.TimestampFromMx(value)
+ return dialect.dbapi.TimestampFromMx(mx_datetime)
+ return dialect.dbapi.TimestampFromMx(value)
else:
return None
@@ -101,7 +96,7 @@ class PG1Date(sqltypes.Date):
# TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
# this one doesnt seem to work with the "emulation" mode
if value is not None:
- return psycopg.DateFromMx(value)
+ return dialect.dbapi.DateFromMx(value)
else:
return None
@@ -219,44 +214,49 @@ def descriptor():
]}
class PGExecutionContext(default.DefaultExecutionContext):
- def post_exec(self, engine, proxy, compiled, parameters, **kwargs):
- if getattr(compiled, "isinsert", False) and self.last_inserted_ids is None:
- if not engine.dialect.use_oids:
+
+ def is_select(self):
+ return re.match(r'SELECT', self.statement.lstrip(), re.I) and not re.search(r'FOR UPDATE\s*$', self.statement, re.I)
+
+ def create_cursor(self):
+ if self.dialect.server_side_cursors and self.is_select():
+ # use server-side cursors:
+ # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
+ ident = "c" + hex(random.randint(0, 65535))[2:]
+ return self.connection.connection.cursor(ident)
+ else:
+ return self.connection.connection.cursor()
+
+ def get_result_proxy(self):
+ if self.dialect.server_side_cursors and self.is_select():
+ return base.BufferedRowResultProxy(self)
+ else:
+ return base.ResultProxy(self)
+
+ def post_exec(self):
+ if self.compiled.isinsert and self.last_inserted_ids is None:
+ if not self.dialect.use_oids:
pass
# will raise invalid error when they go to get them
else:
- table = compiled.statement.table
- cursor = proxy()
- if cursor.lastrowid is not None and table is not None and len(table.primary_key):
- s = sql.select(table.primary_key, table.oid_column == cursor.lastrowid)
- c = s.compile(engine=engine)
- cursor = proxy(str(c), c.get_params())
- row = cursor.fetchone()
+ table = self.compiled.statement.table
+ if self.cursor.lastrowid is not None and table is not None and len(table.primary_key):
+ s = sql.select(table.primary_key, table.oid_column == self.cursor.lastrowid)
+ row = self.connection.execute(s).fetchone()
self._last_inserted_ids = [v for v in row]
-
+ super(PGExecutionContext, self).post_exec()
+
class PGDialect(ansisql.ANSIDialect):
- def __init__(self, module=None, use_oids=False, use_information_schema=False, server_side_cursors=False, **params):
+ def __init__(self, use_oids=False, use_information_schema=False, server_side_cursors=False, **kwargs):
+ ansisql.ANSIDialect.__init__(self, default_paramstyle='pyformat', **kwargs)
self.use_oids = use_oids
self.server_side_cursors = server_side_cursors
- if module is None:
- #if psycopg is None:
- # raise exceptions.ArgumentError("Couldnt locate psycopg1 or psycopg2: specify postgres module argument")
- self.module = psycopg
+ if self.dbapi is None or not hasattr(self.dbapi, '__version__') or self.dbapi.__version__.startswith('2'):
+ self.version = 2
else:
- self.module = module
- # figure psycopg version 1 or 2
- try:
- if self.module.__version__.startswith('2'):
- self.version = 2
- else:
- self.version = 1
- except:
self.version = 1
- ansisql.ANSIDialect.__init__(self, **params)
self.use_information_schema = use_information_schema
- # produce consistent paramstyle even if psycopg2 module not present
- if self.module is None:
- self.paramstyle = 'pyformat'
+ self.paramstyle = 'pyformat'
def create_connect_args(self, url):
opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port'])
@@ -268,16 +268,9 @@ class PGDialect(ansisql.ANSIDialect):
opts.update(url.query)
return ([], opts)
- def create_cursor(self, connection):
- if self.server_side_cursors:
- # use server-side cursors:
- # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
- return connection.cursor('x')
- else:
- return connection.cursor()
- def create_execution_context(self):
- return PGExecutionContext(self)
+ def create_execution_context(self, *args, **kwargs):
+ return PGExecutionContext(self, *args, **kwargs)
def max_identifier_length(self):
return 68
@@ -292,13 +285,13 @@ class PGDialect(ansisql.ANSIDialect):
return PGCompiler(self, statement, bindparams, **kwargs)
def schemagenerator(self, *args, **kwargs):
- return PGSchemaGenerator(*args, **kwargs)
+ return PGSchemaGenerator(self, *args, **kwargs)
def schemadropper(self, *args, **kwargs):
- return PGSchemaDropper(*args, **kwargs)
+ return PGSchemaDropper(self, *args, **kwargs)
- def defaultrunner(self, engine, proxy):
- return PGDefaultRunner(engine, proxy)
+ def defaultrunner(self, connection, **kwargs):
+ return PGDefaultRunner(connection, **kwargs)
def preparer(self):
return PGIdentifierPreparer(self)
@@ -326,7 +319,6 @@ class PGDialect(ansisql.ANSIDialect):
``psycopg2`` is not nice enough to produce this correctly for
an executemany, so we do our own executemany here.
"""
-
rowcount = 0
for param in parameters:
c.execute(statement, param)
@@ -334,9 +326,6 @@ class PGDialect(ansisql.ANSIDialect):
if context is not None:
context._rowcount = rowcount
- def dbapi(self):
- return self.module
-
def has_table(self, connection, table_name, schema=None):
# seems like case gets folded in pg_class...
if schema is None:
@@ -542,7 +531,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
else:
colspec += " SERIAL"
else:
- colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
+ colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
@@ -567,8 +556,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
if column.primary_key:
# passive defaults on primary keys have to be overridden
if isinstance(column.default, schema.PassiveDefault):
- c = self.proxy("select %s" % column.default.arg)
- return c.fetchone()[0]
+ return self.connection.execute_text("select %s" % column.default.arg).scalar()
elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
sch = column.table.schema
# TODO: this has to build into the Sequence object so we can get the quoting
@@ -577,17 +565,13 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
else:
exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
- c = self.proxy(exc)
- return c.fetchone()[0]
- else:
- return ansisql.ANSIDefaultRunner.get_column_default(self, column)
- else:
- return ansisql.ANSIDefaultRunner.get_column_default(self, column)
+ return self.connection.execute_text(exc).scalar()
+
+ return super(ansisql.ANSIDefaultRunner, self).get_column_default(column)
def visit_sequence(self, seq):
if not seq.optional:
- c = self.proxy("select nextval('%s')" % seq.name) #TODO: self.dialect.preparer.format_sequence(seq))
- return c.fetchone()[0]
+ return self.connection.execute("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)).scalar()
else:
return None
diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py
index b29be9eed..9270f2a5f 100644
--- a/lib/sqlalchemy/databases/sqlite.py
+++ b/lib/sqlalchemy/databases/sqlite.py
@@ -12,19 +12,19 @@ import sqlalchemy.engine.default as default
import sqlalchemy.types as sqltypes
import datetime,time
-pysqlite2_timesupport = False # Change this if the init.d guys ever get around to supporting time cols
-
-try:
- from pysqlite2 import dbapi2 as sqlite
-except ImportError:
+def dbapi():
try:
- from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name.
- except ImportError:
+ from pysqlite2 import dbapi2 as sqlite
+ except ImportError, e:
try:
- sqlite = __import__('sqlite') # skip ourselves
- except:
- sqlite = None
-
+ from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name.
+ except ImportError:
+ try:
+ sqlite = __import__('sqlite') # skip ourselves
+ except ImportError:
+ raise e
+ return sqlite
+
class SLNumeric(sqltypes.Numeric):
def get_col_spec(self):
if self.precision is None:
@@ -140,10 +140,6 @@ pragma_names = {
'BLOB' : SLBinary,
}
-if pysqlite2_timesupport:
- colspecs.update({sqltypes.Time : SLTime})
- pragma_names.update({'TIME' : SLTime})
-
def descriptor():
return {'name':'sqlite',
'description':'SQLite',
@@ -152,25 +148,29 @@ def descriptor():
]}
class SQLiteExecutionContext(default.DefaultExecutionContext):
- def post_exec(self, engine, proxy, compiled, parameters, **kwargs):
- if getattr(compiled, "isinsert", False):
- self._last_inserted_ids = [proxy().lastrowid]
-
+ def post_exec(self):
+ if self.compiled.isinsert:
+ self._last_inserted_ids = [self.cursor.lastrowid]
+ super(SQLiteExecutionContext, self).post_exec()
+
class SQLiteDialect(ansisql.ANSIDialect):
def __init__(self, **kwargs):
+ ansisql.ANSIDialect.__init__(self, default_paramstyle='qmark', **kwargs)
def vers(num):
return tuple([int(x) for x in num.split('.')])
- self.supports_cast = (sqlite is not None and vers(sqlite.sqlite_version) >= vers("3.2.3"))
- ansisql.ANSIDialect.__init__(self, **kwargs)
+ self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3"))
def compiler(self, statement, bindparams, **kwargs):
return SQLiteCompiler(self, statement, bindparams, **kwargs)
def schemagenerator(self, *args, **kwargs):
- return SQLiteSchemaGenerator(*args, **kwargs)
+ return SQLiteSchemaGenerator(self, *args, **kwargs)
def schemadropper(self, *args, **kwargs):
- return SQLiteSchemaDropper(*args, **kwargs)
+ return SQLiteSchemaDropper(self, *args, **kwargs)
+
+ def supports_alter(self):
+ return False
def preparer(self):
return SQLiteIdentifierPreparer(self)
@@ -182,8 +182,8 @@ class SQLiteDialect(ansisql.ANSIDialect):
def type_descriptor(self, typeobj):
return sqltypes.adapt_type(typeobj, colspecs)
- def create_execution_context(self):
- return SQLiteExecutionContext(self)
+ def create_execution_context(self, **kwargs):
+ return SQLiteExecutionContext(self, **kwargs)
def last_inserted_ids(self):
return self.context.last_inserted_ids
@@ -191,9 +191,6 @@ class SQLiteDialect(ansisql.ANSIDialect):
def oid_column_name(self, column):
return "oid"
- def dbapi(self):
- return sqlite
-
def has_table(self, connection, table_name, schema=None):
cursor = connection.execute("PRAGMA table_info(" + table_name + ")", {})
row = cursor.fetchone()
@@ -321,11 +318,9 @@ class SQLiteCompiler(ansisql.ANSICompiler):
return ansisql.ANSICompiler.binary_operator_string(self, binary)
class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
- def supports_alter(self):
- return False
def get_column_specification(self, column, **kwargs):
- colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec()
+ colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
@@ -345,8 +340,7 @@ class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
# super(SQLiteSchemaGenerator, self).visit_primary_key_constraint(constraint)
class SQLiteSchemaDropper(ansisql.ANSISchemaDropper):
- def supports_alter(self):
- return False
+ pass
class SQLiteIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
def __init__(self, dialect):
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 0baaeb826..d8a9c5299 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -83,7 +83,7 @@ class Dialect(sql.AbstractDialect):
raise NotImplementedError()
def type_descriptor(self, typeobj):
- """Trasform the type from generic to database-specific.
+ """Transform the type from generic to database-specific.
Provides a database-specific TypeEngine object, given the
generic object which comes from the types module. Subclasses
@@ -105,6 +105,10 @@ class Dialect(sql.AbstractDialect):
raise NotImplementedError()
+ def supports_alter(self):
+ """return True if the database supports ALTER TABLE."""
+ raise NotImplementedError()
+
def max_identifier_length(self):
"""Return the maximum length of identifier names.
@@ -118,32 +122,43 @@ class Dialect(sql.AbstractDialect):
def supports_sane_rowcount(self):
"""Indicate whether the dialect properly implements statements rowcount.
- Provided to indicate when MySQL is being used, which does not
- have standard behavior for the "rowcount" function on a statement handle.
+ This was needed for MySQL which had non-standard behavior of rowcount,
+ but this issue has since been resolved.
"""
raise NotImplementedError()
- def schemagenerator(self, engine, proxy, **params):
+ def schemagenerator(self, connection, **kwargs):
"""Return a ``schema.SchemaVisitor`` instance that can generate schemas.
+ connection
+ a Connection to use for statement execution
+
`schemagenerator()` is called via the `create()` method on Table,
Index, and others.
"""
raise NotImplementedError()
- def schemadropper(self, engine, proxy, **params):
+ def schemadropper(self, connection, **kwargs):
"""Return a ``schema.SchemaVisitor`` instance that can drop schemas.
+ connection
+ a Connection to use for statement execution
+
`schemadropper()` is called via the `drop()` method on Table,
Index, and others.
"""
raise NotImplementedError()
- def defaultrunner(self, engine, proxy, **params):
- """Return a ``schema.SchemaVisitor`` instance that can execute defaults."""
+ def defaultrunner(self, connection, **kwargs):
+ """Return a ``schema.SchemaVisitor`` instance that can execute defaults.
+
+ connection
+ a Connection to use for statement execution
+
+ """
raise NotImplementedError()
@@ -154,7 +169,6 @@ class Dialect(sql.AbstractDialect):
ansisql.ANSICompiler, and will produce a string representation
of the given ClauseElement and `parameters` dictionary.
- `compiler()` is called within the context of the compile() method.
"""
raise NotImplementedError()
@@ -188,23 +202,13 @@ class Dialect(sql.AbstractDialect):
raise NotImplementedError()
- def dbapi(self):
- """Establish a connection to the database.
-
- Subclasses override this method to provide the DBAPI module
- used to establish connections.
- """
-
- raise NotImplementedError()
-
def get_default_schema_name(self, connection):
"""Return the currently selected schema given a connection"""
raise NotImplementedError()
- def execution_context(self):
+ def create_execution_context(self, connection, compiled=None, compiled_parameters=None, statement=None, parameters=None):
"""Return a new ExecutionContext object."""
-
raise NotImplementedError()
def do_begin(self, connection):
@@ -232,15 +236,6 @@ class Dialect(sql.AbstractDialect):
raise NotImplementedError()
- def create_cursor(self, connection):
- """Return a new cursor generated from the given connection."""
-
- raise NotImplementedError()
-
- def create_result_proxy_args(self, connection, cursor):
- """Return a dictionary of arguments that should be passed to ResultProxy()."""
-
- raise NotImplementedError()
def compile(self, clauseelement, parameters=None):
"""Compile the given ClauseElement using this Dialect.
@@ -255,42 +250,74 @@ class Dialect(sql.AbstractDialect):
class ExecutionContext(object):
"""A messenger object for a Dialect that corresponds to a single execution.
+ ExecutionContext should have these datamembers:
+
+ connection
+ Connection object which initiated the call to the
+ dialect to create this ExecutionContext.
+
+ dialect
+ dialect which created this ExecutionContext.
+
+ cursor
+ DBAPI cursor procured from the connection
+
+ compiled
+ if passed to constructor, sql.Compiled object being executed
+
+ compiled_parameters
+ if passed to constructor, sql.ClauseParameters object
+
+ statement
+ string version of the statement to be executed. Is either
+ passed to the constructor, or must be created from the
+ sql.Compiled object by the time pre_exec() has completed.
+
+ parameters
+ "raw" parameters suitable for direct execution by the
+ dialect. Either passed to the constructor, or must be
+ created from the sql.ClauseParameters object by the time
+ pre_exec() has completed.
+
+
The Dialect should provide an ExecutionContext via the
create_execution_context() method. The `pre_exec` and `post_exec`
- methods will be called for compiled statements, afterwhich it is
- expected that the various methods `last_inserted_ids`,
- `last_inserted_params`, etc. will contain appropriate values, if
- applicable.
+ methods will be called for compiled statements.
+
"""
- def pre_exec(self, engine, proxy, compiled, parameters):
- """Called before an execution of a compiled statement.
+ def create_cursor(self):
+ """Return a new cursor generated this ExecutionContext's connection."""
- `proxy` is a callable that takes a string statement and a bind
- parameter list/dictionary.
+ raise NotImplementedError()
+
+ def pre_exec(self):
+ """Called before an execution of a compiled statement.
+
+ If compiled and compiled_parameters were passed to this
+ ExecutionContext, the `statement` and `parameters` datamembers
+ must be initialized after this statement is complete.
"""
raise NotImplementedError()
- def post_exec(self, engine, proxy, compiled, parameters):
+ def post_exec(self):
"""Called after the execution of a compiled statement.
-
- `proxy` is a callable that takes a string statement and a bind
- parameter list/dictionary.
+
+ If compiled was passed to this ExecutionContext,
+ the `last_insert_ids`, `last_inserted_params`, etc.
+ datamembers should be available after this method
+ completes.
"""
raise NotImplementedError()
-
- def get_rowcount(self, cursor):
- """Return the count of rows updated/deleted for an UPDATE/DELETE statement."""
-
+
+ def get_result_proxy(self):
+ """return a ResultProxy corresponding to this ExecutionContext."""
raise NotImplementedError()
-
- def supports_sane_rowcount(self):
- """Indicate if the "rowcount" DBAPI cursor function works properly.
-
- Currently, MySQLDB does not properly implement this function.
- """
+
+ def get_rowcount(self):
+ """Return the count of rows updated/deleted for an UPDATE/DELETE statement."""
raise NotImplementedError()
@@ -299,7 +326,7 @@ class ExecutionContext(object):
This does not apply to straight textual clauses; only to
``sql.Insert`` objects compiled against a ``schema.Table`` object,
- which are executed via `statement.execute()`. The order of
+ which are executed via `execute()`. The order of
items in the list is the same as that of the Table's
'primary_key' attribute.
@@ -337,7 +364,7 @@ class ExecutionContext(object):
raise NotImplementedError()
-class Connectable(object):
+class Connectable(sql.Executor):
"""Interface for an object that can provide an Engine and a Connection object which correponds to that Engine."""
def contextual_connect(self):
@@ -362,6 +389,7 @@ class Connectable(object):
raise NotImplementedError()
engine = property(_not_impl, doc="The Engine which this Connectable is associated with.")
+ dialect = property(_not_impl, doc="Dialect which this Connectable is associated with.")
class Connection(Connectable):
"""Represent a single DBAPI connection returned from the underlying connection pool.
@@ -385,7 +413,8 @@ class Connection(Connectable):
except AttributeError:
raise exceptions.InvalidRequestError("This Connection is closed")
- engine = property(lambda s:s.__engine, doc="The Engine with which this Connection is associated (read only)")
+ engine = property(lambda s:s.__engine, doc="The Engine with which this Connection is associated.")
+ dialect = property(lambda s:s.__engine.dialect, doc="Dialect used by this Connection.")
connection = property(_get_connection, doc="The underlying DBAPI connection managed by this Connection.")
should_close_with_result = property(lambda s:s.__close_with_result, doc="Indicates if this Connection should be closed when a corresponding ResultProxy is closed; this is essentially an auto-release mode.")
@@ -429,7 +458,7 @@ class Connection(Connectable):
"""When no Transaction is present, this is called after executions to provide "autocommit" behavior."""
# TODO: have the dialect determine if autocommit can be set on the connection directly without this
# extra step
- if not self.in_transaction() and re.match(r'UPDATE|INSERT|CREATE|DELETE|DROP|ALTER', statement.lstrip().upper()):
+ if not self.in_transaction() and re.match(r'UPDATE|INSERT|CREATE|DELETE|DROP|ALTER', statement.lstrip(), re.I):
self._commit_impl()
def _autorollback(self):
@@ -448,6 +477,9 @@ class Connection(Connectable):
def scalar(self, object, *multiparams, **params):
return self.execute(object, *multiparams, **params).scalar()
+ def compiler(self, statement, parameters, **kwargs):
+ return self.dialect.compiler(statement, parameters, engine=self.engine, **kwargs)
+
def execute(self, object, *multiparams, **params):
for c in type(object).__mro__:
if c in Connection.executors:
@@ -456,7 +488,7 @@ class Connection(Connectable):
raise exceptions.InvalidRequestError("Unexecuteable object type: " + str(type(object)))
def execute_default(self, default, **kwargs):
- return default.accept_visitor(self.__engine.dialect.defaultrunner(self.__engine, self.proxy, **kwargs))
+ return default.accept_visitor(self.__engine.dialect.defaultrunner(self))
def execute_text(self, statement, *multiparams, **params):
if len(multiparams) == 0:
@@ -465,9 +497,9 @@ class Connection(Connectable):
parameters = multiparams[0]
else:
parameters = list(multiparams)
- cursor = self._execute_raw(statement, parameters)
- rpargs = self.__engine.dialect.create_result_proxy_args(self, cursor)
- return ResultProxy(self.__engine, self, cursor, **rpargs)
+ context = self._create_execution_context(statement=statement, parameters=parameters)
+ self._execute_raw(context)
+ return context.get_result_proxy()
def _params_to_listofdicts(self, *multiparams, **params):
if len(multiparams) == 0:
@@ -491,29 +523,57 @@ class Connection(Connectable):
param = multiparams[0]
else:
param = params
- return self.execute_compiled(elem.compile(engine=self.__engine, parameters=param), *multiparams, **params)
+ return self.execute_compiled(elem.compile(dialect=self.dialect, parameters=param), *multiparams, **params)
def execute_compiled(self, compiled, *multiparams, **params):
"""Execute a sql.Compiled object."""
if not compiled.can_execute:
raise exceptions.ArgumentError("Not an executeable clause: %s" % (str(compiled)))
- cursor = self.__engine.dialect.create_cursor(self.connection)
parameters = [compiled.construct_params(m) for m in self._params_to_listofdicts(*multiparams, **params)]
if len(parameters) == 1:
parameters = parameters[0]
- def proxy(statement=None, parameters=None):
- if statement is None:
- return cursor
-
- parameters = self.__engine.dialect.convert_compiled_params(parameters)
- self._execute_raw(statement, parameters, cursor=cursor, context=context)
- return cursor
- context = self.__engine.dialect.create_execution_context()
- context.pre_exec(self.__engine, proxy, compiled, parameters)
- proxy(unicode(compiled), parameters)
- context.post_exec(self.__engine, proxy, compiled, parameters)
- rpargs = self.__engine.dialect.create_result_proxy_args(self, cursor)
- return ResultProxy(self.__engine, self, cursor, context, typemap=compiled.typemap, column_labels=compiled.column_labels, **rpargs)
+ context = self._create_execution_context(compiled=compiled, compiled_parameters=parameters)
+ context.pre_exec()
+ self._execute_raw(context)
+ context.post_exec()
+ return context.get_result_proxy()
+
+ def _create_execution_context(self, **kwargs):
+ return self.__engine.dialect.create_execution_context(connection=self, **kwargs)
+
+ def _execute_raw(self, context):
+ self.__engine.logger.info(context.statement)
+ self.__engine.logger.info(repr(context.parameters))
+ if context.parameters is not None and isinstance(context.parameters, list) and len(context.parameters) > 0 and (isinstance(context.parameters[0], list) or isinstance(context.parameters[0], dict)):
+ self._executemany(context)
+ else:
+ self._execute(context)
+ self._autocommit(context.statement)
+
+ def _execute(self, context):
+ if context.parameters is None:
+ if context.dialect.positional:
+ context.parameters = ()
+ else:
+ context.parameters = {}
+ try:
+ context.dialect.do_execute(context.cursor, context.statement, context.parameters, context=context)
+ except Exception, e:
+ self._autorollback()
+ #self._rollback_impl()
+ if self.__close_with_result:
+ self.close()
+ raise exceptions.SQLError(context.statement, context.parameters, e)
+
+ def _executemany(self, context):
+ try:
+ context.dialect.do_executemany(context.cursor, context.statement, context.parameters, context=context)
+ except Exception, e:
+ self._autorollback()
+ #self._rollback_impl()
+ if self.__close_with_result:
+ self.close()
+ raise exceptions.SQLError(context.statement, context.parameters, e)
# poor man's multimethod/generic function thingy
executors = {
@@ -525,17 +585,17 @@ class Connection(Connectable):
}
def create(self, entity, **kwargs):
- """Create a table or index given an appropriate schema object."""
+ """Create a Table or Index given an appropriate Schema object."""
return self.__engine.create(entity, connection=self, **kwargs)
def drop(self, entity, **kwargs):
- """Drop a table or index given an appropriate schema object."""
+ """Drop a Table or Index given an appropriate Schema object."""
return self.__engine.drop(entity, connection=self, **kwargs)
def reflecttable(self, table, **kwargs):
- """Reflect the columns in the given table from the database."""
+ """Reflect the columns in the given string table name from the database."""
return self.__engine.reflecttable(table, connection=self, **kwargs)
@@ -545,59 +605,6 @@ class Connection(Connectable):
def run_callable(self, callable_):
return callable_(self)
- def _execute_raw(self, statement, parameters=None, cursor=None, context=None, **kwargs):
- if cursor is None:
- cursor = self.__engine.dialect.create_cursor(self.connection)
- if not self.__engine.dialect.supports_unicode_statements():
- # encode to ascii, with full error handling
- statement = statement.encode('ascii')
- self.__engine.logger.info(statement)
- self.__engine.logger.info(repr(parameters))
- if parameters is not None and isinstance(parameters, list) and len(parameters) > 0 and (isinstance(parameters[0], list) or isinstance(parameters[0], dict)):
- self._executemany(cursor, statement, parameters, context=context)
- else:
- self._execute(cursor, statement, parameters, context=context)
- self._autocommit(statement)
- return cursor
-
- def _execute(self, c, statement, parameters, context=None):
- if parameters is None:
- if self.__engine.dialect.positional:
- parameters = ()
- else:
- parameters = {}
- try:
- self.__engine.dialect.do_execute(c, statement, parameters, context=context)
- except Exception, e:
- self._autorollback()
- #self._rollback_impl()
- if self.__close_with_result:
- self.close()
- raise exceptions.SQLError(statement, parameters, e)
-
- def _executemany(self, c, statement, parameters, context=None):
- try:
- self.__engine.dialect.do_executemany(c, statement, parameters, context=context)
- except Exception, e:
- self._autorollback()
- #self._rollback_impl()
- if self.__close_with_result:
- self.close()
- raise exceptions.SQLError(statement, parameters, e)
-
- def proxy(self, statement=None, parameters=None):
- """Execute the given statement string and parameter object.
-
- The parameter object is expected to be the result of a call to
- ``compiled.get_params()``. This callable is a generic version
- of a connection/cursor-specific callable that is produced
- within the execute_compiled method, and is used for objects
- that require this style of proxy when outside of an
- execute_compiled method, primarily the DefaultRunner.
- """
- parameters = self.__engine.dialect.convert_compiled_params(parameters)
- return self._execute_raw(statement, parameters)
-
class Transaction(object):
"""Represent a Transaction in progress.
@@ -630,7 +637,7 @@ class Transaction(object):
self.__connection._commit_impl()
self.__is_active = False
-class Engine(sql.Executor, Connectable):
+class Engine(Connectable):
"""
Connects a ConnectionProvider, a Dialect and a CompilerFactory together to
provide a default implementation of SchemaEngine.
@@ -638,12 +645,13 @@ class Engine(sql.Executor, Connectable):
def __init__(self, connection_provider, dialect, echo=None):
self.connection_provider = connection_provider
- self.dialect=dialect
+ self._dialect=dialect
self.echo = echo
self.logger = logging.instance_logger(self)
name = property(lambda s:sys.modules[s.dialect.__module__].descriptor()['name'])
engine = property(lambda s:s)
+ dialect = property(lambda s:s._dialect)
echo = logging.echo_property()
def dispose(self):
@@ -678,11 +686,11 @@ class Engine(sql.Executor, Connectable):
def _run_visitor(self, visitorcallable, element, connection=None, **kwargs):
if connection is None:
- conn = self.contextual_connect()
+ conn = self.contextual_connect(close_with_result=False)
else:
conn = connection
try:
- element.accept_visitor(visitorcallable(self, conn.proxy, connection=conn, **kwargs))
+ element.accept_visitor(visitorcallable(conn, **kwargs))
finally:
if connection is None:
conn.close()
@@ -807,55 +815,39 @@ class ResultProxy(object):
def convert_result_value(self, arg, engine):
raise exceptions.InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % (self.key))
- def __new__(cls, *args, **kwargs):
- if cls is ResultProxy and kwargs.has_key('should_prefetch') and kwargs['should_prefetch']:
- return PrefetchingResultProxy(*args, **kwargs)
- else:
- return object.__new__(cls, *args, **kwargs)
-
- def __init__(self, engine, connection, cursor, executioncontext=None, typemap=None, column_labels=None, should_prefetch=None):
+ def __init__(self, context):
"""ResultProxy objects are constructed via the execute() method on SQLEngine."""
-
- self.connection = connection
- self.dialect = engine.dialect
- self.cursor = cursor
- self.engine = engine
+ self.context = context
self.closed = False
- self.column_labels = column_labels
- if executioncontext is not None:
- self.__executioncontext = executioncontext
- self.rowcount = executioncontext.get_rowcount(cursor)
- else:
- self.rowcount = cursor.rowcount
- self.__key_cache = {}
- self.__echo = engine.echo == 'debug'
- metadata = cursor.description
- self.props = {}
- self.keys = []
- i = 0
+ self.cursor = context.cursor
+ self.__echo = logging.is_debug_enabled(context.engine.logger)
+ self._init_metadata()
+ dialect = property(lambda s:s.context.dialect)
+ rowcount = property(lambda s:s.context.get_rowcount())
+ connection = property(lambda s:s.context.connection)
+
+ def _init_metadata(self):
+ if hasattr(self, '_ResultProxy__props'):
+ return
+ self.__key_cache = {}
+ self.__props = {}
+ self.__keys = []
+ metadata = self.cursor.description
if metadata is not None:
- for item in metadata:
+ for i, item in enumerate(metadata):
# sqlite possibly prepending table name to colnames so strip
- colname = item[0].split('.')[-1].lower()
- if typemap is not None:
- rec = (typemap.get(colname, types.NULLTYPE), i)
+ colname = item[0].split('.')[-1]
+ if self.context.typemap is not None:
+ rec = (self.context.typemap.get(colname.lower(), types.NULLTYPE), i)
else:
rec = (types.NULLTYPE, i)
if rec[0] is None:
raise DBAPIError("None for metadata " + colname)
- if self.props.setdefault(colname, rec) is not rec:
- self.props[colname] = (ResultProxy.AmbiguousColumn(colname), 0)
- self.keys.append(colname)
- self.props[i] = rec
- i+=1
-
- def _executioncontext(self):
- try:
- return self.__executioncontext
- except AttributeError:
- raise exceptions.InvalidRequestError("This ResultProxy does not have an execution context with which to complete this operation. Execution contexts are not generated for literal SQL execution.")
- executioncontext = property(_executioncontext)
+ if self.__props.setdefault(colname.lower(), rec) is not rec:
+ self.__props[colname.lower()] = (ResultProxy.AmbiguousColumn(colname), 0)
+ self.__keys.append(colname)
+ self.__props[i] = rec
def close(self):
"""Close this ResultProxy, and the underlying DBAPI cursor corresponding to the execution.
@@ -867,13 +859,12 @@ class ResultProxy(object):
This method is also called automatically when all result rows
are exhausted.
"""
-
if not self.closed:
self.closed = True
self.cursor.close()
if self.connection.should_close_with_result and self.dialect.supports_autoclose_results:
self.connection.close()
-
+
def _convert_key(self, key):
"""Convert and cache a key.
@@ -882,25 +873,26 @@ class ResultProxy(object):
metadata; then cache it locally for quick re-access.
"""
- try:
+ if key in self.__key_cache:
return self.__key_cache[key]
- except KeyError:
- if isinstance(key, int) and key in self.props:
- rec = self.props[key]
- elif isinstance(key, basestring) and key.lower() in self.props:
- rec = self.props[key.lower()]
+ else:
+ if isinstance(key, int) and key in self.__props:
+ rec = self.__props[key]
+ elif isinstance(key, basestring) and key.lower() in self.__props:
+ rec = self.__props[key.lower()]
elif isinstance(key, sql.ColumnElement):
- label = self.column_labels.get(key._label, key.name).lower()
- if label in self.props:
- rec = self.props[label]
+ label = self.context.column_labels.get(key._label, key.name).lower()
+ if label in self.__props:
+ rec = self.__props[label]
if not "rec" in locals():
raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % (repr(key)))
self.__key_cache[key] = rec
return rec
-
-
+
+ keys = property(lambda s:s.__keys)
+
def _has_key(self, row, key):
try:
self._convert_key(key)
@@ -908,10 +900,6 @@ class ResultProxy(object):
except KeyError:
return False
- def _get_col(self, row, key):
- rec = self._convert_key(key)
- return rec[0].dialect_impl(self.dialect).convert_result_value(row[rec[1]], self.dialect)
-
def __iter__(self):
while True:
row = self.fetchone()
@@ -926,7 +914,7 @@ class ResultProxy(object):
See ExecutionContext for details.
"""
- return self.executioncontext.last_inserted_ids()
+ return self.context.last_inserted_ids()
def last_updated_params(self):
"""Return ``last_updated_params()`` from the underlying ExecutionContext.
@@ -934,7 +922,7 @@ class ResultProxy(object):
See ExecutionContext for details.
"""
- return self.executioncontext.last_updated_params()
+ return self.context.last_updated_params()
def last_inserted_params(self):
"""Return ``last_inserted_params()`` from the underlying ExecutionContext.
@@ -942,7 +930,7 @@ class ResultProxy(object):
See ExecutionContext for details.
"""
- return self.executioncontext.last_inserted_params()
+ return self.context.last_inserted_params()
def lastrow_has_defaults(self):
"""Return ``lastrow_has_defaults()`` from the underlying ExecutionContext.
@@ -950,7 +938,7 @@ class ResultProxy(object):
See ExecutionContext for details.
"""
- return self.executioncontext.lastrow_has_defaults()
+ return self.context.lastrow_has_defaults()
def supports_sane_rowcount(self):
"""Return ``supports_sane_rowcount()`` from the underlying ExecutionContext.
@@ -958,71 +946,122 @@ class ResultProxy(object):
See ExecutionContext for details.
"""
- return self.executioncontext.supports_sane_rowcount()
+ return self.context.supports_sane_rowcount()
+ def _get_col(self, row, key):
+ rec = self._convert_key(key)
+ return rec[0].dialect_impl(self.dialect).convert_result_value(row[rec[1]], self.dialect)
+
+ def _fetchone_impl(self):
+ return self.cursor.fetchone()
+ def _fetchmany_impl(self, size=None):
+ return self.cursor.fetchmany(size)
+ def _fetchall_impl(self):
+ return self.cursor.fetchall()
+
+ def _process_row(self, row):
+ return RowProxy(self, row)
+
def fetchall(self):
"""Fetch all rows, just like DBAPI ``cursor.fetchall()``."""
- l = []
- for row in self.cursor.fetchall():
- l.append(RowProxy(self, row))
+ l = [self._process_row(row) for row in self._fetchall_impl()]
self.close()
return l
def fetchmany(self, size=None):
"""Fetch many rows, just like DBAPI ``cursor.fetchmany(size=cursor.arraysize)``."""
- if size is None:
- rows = self.cursor.fetchmany()
- else:
- rows = self.cursor.fetchmany(size)
- l = []
- for row in rows:
- l.append(RowProxy(self, row))
+ l = [self._process_row(row) for row in self._fetchmany_impl(size)]
if len(l) == 0:
self.close()
return l
def fetchone(self):
"""Fetch one row, just like DBAPI ``cursor.fetchone()``."""
-
- row = self.cursor.fetchone()
+ row = self._fetchone_impl()
if row is not None:
- return RowProxy(self, row)
+ return self._process_row(row)
else:
self.close()
return None
def scalar(self):
"""Fetch the first column of the first row, and close the result set."""
-
- row = self.cursor.fetchone()
+ row = self._fetchone_impl()
try:
if row is not None:
- return RowProxy(self, row)[0]
+ return self._process_row(row)[0]
else:
return None
finally:
self.close()
-class PrefetchingResultProxy(ResultProxy):
+class BufferedRowResultProxy(ResultProxy):
+ def _init_metadata(self):
+ self.__buffer_rows()
+ super(BufferedRowResultProxy, self)._init_metadata()
+
+ # this is a "growth chart" for the buffering of rows.
+ # each successive __buffer_rows call will use the next
+ # value in the list for the buffer size until the max
+ # is reached
+ size_growth = {
+ 1 : 5,
+ 5 : 10,
+ 10 : 20,
+ 20 : 50,
+ 50 : 100
+ }
+
+ def __buffer_rows(self):
+ size = getattr(self, '_bufsize', 1)
+ self.__rowbuffer = self.cursor.fetchmany(size)
+ #self.context.engine.logger.debug("Buffered %d rows" % size)
+ self._bufsize = self.size_growth.get(size, size)
+
+ def _fetchone_impl(self):
+ if self.closed:
+ return None
+ if len(self.__rowbuffer) == 0:
+ self.__buffer_rows()
+ if len(self.__rowbuffer) == 0:
+ return None
+ return self.__rowbuffer.pop(0)
+
+ def _fetchmany_impl(self, size=None):
+ result = []
+ for x in range(0, size):
+ row = self._fetchone_impl()
+ if row is None:
+ break
+ result.append(row)
+ return result
+
+ def _fetchall_impl(self):
+ return self.__rowbuffer + list(self.cursor.fetchall())
+
+class BufferedColumnResultProxy(ResultProxy):
"""ResultProxy that loads all columns into memory each time fetchone() is
called. If fetchmany() or fetchall() are called, the full grid of results
is fetched.
"""
-
def _get_col(self, row, key):
rec = self._convert_key(key)
return row[rec[1]]
+
+ def _process_row(self, row):
+ sup = super(BufferedColumnResultProxy, self)
+ row = [sup._get_col(row, i) for i in xrange(len(row))]
+ return RowProxy(self, row)
def fetchall(self):
l = []
while True:
row = self.fetchone()
- if row is not None:
- l.append(row)
- else:
+ if row is None:
break
+ l.append(row)
return l
def fetchmany(self, size=None):
@@ -1031,24 +1070,13 @@ class PrefetchingResultProxy(ResultProxy):
l = []
for i in xrange(size):
row = self.fetchone()
- if row is not None:
- l.append(row)
- else:
+ if row is None:
break
+ l.append(row)
return l
- def fetchone(self):
- sup = super(PrefetchingResultProxy, self)
- row = self.cursor.fetchone()
- if row is not None:
- row = [sup._get_col(row, i) for i in xrange(len(row))]
- return RowProxy(self, row)
- else:
- self.close()
- return None
-
class RowProxy(object):
- """Proxie a single cursor row for a parent ResultProxy.
+ """Proxy a single cursor row for a parent ResultProxy.
Mostly follows "ordered dictionary" behavior, mapping result
values to the string-based column name, the integer position of
@@ -1063,7 +1091,7 @@ class RowProxy(object):
self.__parent = parent
self.__row = row
if self.__parent._ResultProxy__echo:
- self.__parent.engine.logger.debug("Row " + repr(row))
+ self.__parent.context.engine.logger.debug("Row " + repr(row))
def close(self):
"""Close the parent ResultProxy."""
@@ -1115,20 +1143,10 @@ class RowProxy(object):
class SchemaIterator(schema.SchemaVisitor):
"""A visitor that can gather text into a buffer and execute the contents of the buffer."""
- def __init__(self, engine, proxy, **params):
+ def __init__(self, connection):
"""Construct a new SchemaIterator.
-
- engine
- the Engine used by this SchemaIterator
-
- proxy
- a callable which takes a statement and bind parameters and
- executes it, returning the cursor (the actual DBAPI cursor).
- The callable should use the same cursor repeatedly.
"""
-
- self.proxy = proxy
- self.engine = engine
+ self.connection = connection
self.buffer = StringIO.StringIO()
def append(self, s):
@@ -1140,7 +1158,7 @@ class SchemaIterator(schema.SchemaVisitor):
"""Execute the contents of the SchemaIterator's buffer."""
try:
- return self.proxy(self.buffer.getvalue(), None)
+ return self.connection.execute(self.buffer.getvalue())
finally:
self.buffer.truncate(0)
@@ -1154,10 +1172,10 @@ class DefaultRunner(schema.SchemaVisitor):
DefaultRunner to allow database-specific behavior.
"""
- def __init__(self, engine, proxy):
- self.proxy = proxy
- self.engine = engine
-
+ def __init__(self, connection):
+ self.connection = connection
+ self.dialect = connection.dialect
+
def get_column_default(self, column):
if column.default is not None:
return column.default.accept_visitor(self)
@@ -1188,8 +1206,8 @@ class DefaultRunner(schema.SchemaVisitor):
return None
def exec_default_sql(self, default):
- c = sql.select([default.arg], engine=self.engine).compile()
- return self.proxy(str(c), c.get_params()).fetchone()[0]
+ c = sql.select([default.arg]).compile(engine=self.connection)
+ return self.connection.execute_compiled(c).scalar()
def visit_column_onupdate(self, onupdate):
if isinstance(onupdate.arg, sql.ClauseElement):
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 86563cd7c..ceecee364 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -26,16 +26,17 @@ class PoolConnectionProvider(base.ConnectionProvider):
class DefaultDialect(base.Dialect):
"""Default implementation of Dialect"""
- def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', **kwargs):
+ def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', paramstyle=None, dbapi=None, **kwargs):
self.convert_unicode = convert_unicode
self.supports_autoclose_results = True
self.encoding = encoding
self.positional = False
self._ischema = None
- self._figure_paramstyle(default=default_paramstyle)
+ self.dbapi = dbapi
+ self._figure_paramstyle(paramstyle=paramstyle, default=default_paramstyle)
- def create_execution_context(self):
- return DefaultExecutionContext(self)
+ def create_execution_context(self, **kwargs):
+ return DefaultExecutionContext(self, **kwargs)
def type_descriptor(self, typeobj):
"""Provide a database-specific ``TypeEngine`` object, given
@@ -56,6 +57,9 @@ class DefaultDialect(base.Dialect):
# TODO: probably raise this and fill out
# db modules better
return 30
+
+ def supports_alter(self):
+ return True
def oid_column_name(self, column):
return None
@@ -92,14 +96,8 @@ class DefaultDialect(base.Dialect):
def do_execute(self, cursor, statement, parameters, **kwargs):
cursor.execute(statement, parameters)
- def defaultrunner(self, engine, proxy):
- return base.DefaultRunner(engine, proxy)
-
- def create_cursor(self, connection):
- return connection.cursor()
-
- def create_result_proxy_args(self, connection, cursor):
- return dict(should_prefetch=False)
+ def defaultrunner(self, connection):
+ return base.DefaultRunner(connection)
def _set_paramstyle(self, style):
self._paramstyle = style
@@ -126,11 +124,10 @@ class DefaultDialect(base.Dialect):
return parameters
def _figure_paramstyle(self, paramstyle=None, default='named'):
- db = self.dbapi()
if paramstyle is not None:
self._paramstyle = paramstyle
- elif db is not None:
- self._paramstyle = db.paramstyle
+ elif self.dbapi is not None:
+ self._paramstyle = self.dbapi.paramstyle
else:
self._paramstyle = default
@@ -146,10 +143,6 @@ class DefaultDialect(base.Dialect):
raise DBAPIError("Unsupported paramstyle '%s'" % self._paramstyle)
def _get_ischema(self):
- # We use a property for ischema so that the accessor
- # creation only happens as needed, since otherwise we
- # have a circularity problem with the generic
- # ansisql.engine()
if self._ischema is None:
import sqlalchemy.databases.information_schema as ischema
self._ischema = ischema.ISchema(self)
@@ -157,20 +150,49 @@ class DefaultDialect(base.Dialect):
ischema = property(_get_ischema, doc="""returns an ISchema object for this engine, which allows access to information_schema tables (if supported)""")
class DefaultExecutionContext(base.ExecutionContext):
- def __init__(self, dialect):
+ def __init__(self, dialect, connection, compiled=None, compiled_parameters=None, statement=None, parameters=None):
self.dialect = dialect
+ self.connection = connection
+ self.compiled = compiled
+ self.compiled_parameters = compiled_parameters
+
+ if compiled is not None:
+ self.typemap = compiled.typemap
+ self.column_labels = compiled.column_labels
+ self.statement = unicode(compiled)
+ else:
+ self.typemap = self.column_labels = None
+ self.parameters = parameters
+ self.statement = statement
- def pre_exec(self, engine, proxy, compiled, parameters):
- self._process_defaults(engine, proxy, compiled, parameters)
+ if not dialect.supports_unicode_statements():
+ self.statement = self.statement.encode('ascii')
+
+ self.cursor = self.create_cursor()
+
+ engine = property(lambda s:s.connection.engine)
+
+ def is_select(self):
+ return re.match(r'SELECT', self.statement.lstrip(), re.I)
+
+ def create_cursor(self):
+ return self.connection.connection.cursor()
+
+ def pre_exec(self):
+ self._process_defaults()
+ self.parameters = self.dialect.convert_compiled_params(self.compiled_parameters)
- def post_exec(self, engine, proxy, compiled, parameters):
+ def post_exec(self):
pass
- def get_rowcount(self, cursor):
+ def get_result_proxy(self):
+ return base.ResultProxy(self)
+
+ def get_rowcount(self):
if hasattr(self, '_rowcount'):
return self._rowcount
else:
- return cursor.rowcount
+ return self.cursor.rowcount
def supports_sane_rowcount(self):
return self.dialect.supports_sane_rowcount()
@@ -187,44 +209,44 @@ class DefaultExecutionContext(base.ExecutionContext):
def lastrow_has_defaults(self):
return self._lastrow_has_defaults
- def set_input_sizes(self, cursor, parameters):
+ def set_input_sizes(self):
"""Given a cursor and ClauseParameters, call the appropriate
style of ``setinputsizes()`` on the cursor, using DBAPI types
from the bind parameter's ``TypeEngine`` objects.
"""
- if isinstance(parameters, list):
- plist = parameters
+ if isinstance(self.compiled_parameters, list):
+ plist = self.compiled_parameters
else:
- plist = [parameters]
+ plist = [self.compiled_parameters]
if self.dialect.positional:
inputsizes = []
for params in plist[0:1]:
for key in params.positional:
typeengine = params.binds[key].type
- dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.module)
+ dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
if dbtype is not None:
inputsizes.append(dbtype)
- cursor.setinputsizes(*inputsizes)
+ self.cursor.setinputsizes(*inputsizes)
else:
inputsizes = {}
for params in plist[0:1]:
for key in params.keys():
typeengine = params.binds[key].type
- dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.module)
+ dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
if dbtype is not None:
inputsizes[key] = dbtype
- cursor.setinputsizes(**inputsizes)
+ self.cursor.setinputsizes(**inputsizes)
- def _process_defaults(self, engine, proxy, compiled, parameters):
+ def _process_defaults(self):
"""``INSERT`` and ``UPDATE`` statements, when compiled, may
have additional columns added to their ``VALUES`` and ``SET``
lists corresponding to column defaults/onupdates that are
present on the ``Table`` object (i.e. ``ColumnDefault``,
``Sequence``, ``PassiveDefault``). This method pre-execs
those ``DefaultGenerator`` objects that require pre-execution
- and sets their values within the parameter list, and flags the
- thread-local state about ``PassiveDefault`` objects that may
+ and sets their values within the parameter list, and flags this
+ ExecutionContext about ``PassiveDefault`` objects that may
require post-fetching the row after it is inserted/updated.
This method relies upon logic within the ``ANSISQLCompiler``
@@ -234,30 +256,28 @@ class DefaultExecutionContext(base.ExecutionContext):
statement.
"""
- if compiled is None: return
-
- if getattr(compiled, "isinsert", False):
- if isinstance(parameters, list):
- plist = parameters
+ if self.compiled.isinsert:
+ if isinstance(self.compiled_parameters, list):
+ plist = self.compiled_parameters
else:
- plist = [parameters]
- drunner = self.dialect.defaultrunner(engine, proxy)
+ plist = [self.compiled_parameters]
+ drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection))
self._lastrow_has_defaults = False
for param in plist:
last_inserted_ids = []
need_lastrowid=False
# check the "default" status of each column in the table
- for c in compiled.statement.table.c:
+ for c in self.compiled.statement.table.c:
# check if it will be populated by a SQL clause - we'll need that
# after execution.
- if c in compiled.inline_params:
+ if c in self.compiled.inline_params:
self._lastrow_has_defaults = True
if c.primary_key:
need_lastrowid = True
# check if its not present at all. see if theres a default
# and fire it off, and add to bind parameters. if
# its a pk, add the value to our last_inserted_ids list,
- # or, if its a SQL-side default, dont do any of that, but we'll need
+ # or, if its a SQL-side default, let it fire off on the DB side, but we'll need
# the SQL-generated value after execution.
elif not c.key in param or param.get_original(c.key) is None:
if isinstance(c.default, schema.PassiveDefault):
@@ -278,19 +298,19 @@ class DefaultExecutionContext(base.ExecutionContext):
else:
self._last_inserted_ids = last_inserted_ids
self._last_inserted_params = param
- elif getattr(compiled, 'isupdate', False):
- if isinstance(parameters, list):
- plist = parameters
+ elif self.compiled.isupdate:
+ if isinstance(self.compiled_parameters, list):
+ plist = self.compiled_parameters
else:
- plist = [parameters]
- drunner = self.dialect.defaultrunner(engine, proxy)
+ plist = [self.compiled_parameters]
+ drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection))
self._lastrow_has_defaults = False
for param in plist:
# check the "onupdate" status of each column in the table
- for c in compiled.statement.table.c:
+ for c in self.compiled.statement.table.c:
# it will be populated by a SQL clause - we'll need that
# after execution.
- if c in compiled.inline_params:
+ if c in self.compiled.inline_params:
pass
# its not in the bind parameters, and theres an "onupdate" defined for the column;
# execute it and add to bind params
diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py
index 8ac721b77..1b760fca8 100644
--- a/lib/sqlalchemy/engine/strategies.py
+++ b/lib/sqlalchemy/engine/strategies.py
@@ -50,6 +50,16 @@ class DefaultEngineStrategy(EngineStrategy):
if k in kwargs:
dialect_args[k] = kwargs.pop(k)
+ dbapi = kwargs.pop('module', None)
+ if dbapi is None:
+ dbapi_args = {}
+ for k in util.get_func_kwargs(module.dbapi):
+ if k in kwargs:
+ dbapi_args[k] = kwargs.pop(k)
+ dbapi = module.dbapi(**dbapi_args)
+
+ dialect_args['dbapi'] = dbapi
+
# create dialect
dialect = module.dialect(**dialect_args)
@@ -60,10 +70,6 @@ class DefaultEngineStrategy(EngineStrategy):
# look for existing pool or create
pool = kwargs.pop('pool', None)
if pool is None:
- dbapi = kwargs.pop('module', dialect.dbapi())
- if dbapi is None:
- raise exceptions.InvalidRequestError("Can't get DBAPI module for dialect '%s'" % dialect)
-
def connect():
try:
return dbapi.connect(*cargs, **cparams)
@@ -73,6 +79,7 @@ class DefaultEngineStrategy(EngineStrategy):
poolclass = kwargs.pop('poolclass', getattr(module, 'poolclass', poollib.QueuePool))
pool_args = {}
+
# consume pool arguments from kwargs, translating a few of the arguments
for k in util.get_cls_kwargs(poolclass):
tk = {'echo':'echo_pool', 'timeout':'pool_timeout', 'recycle':'pool_recycle'}.get(k, k)
@@ -139,3 +146,52 @@ class ThreadLocalEngineStrategy(DefaultEngineStrategy):
return threadlocal.TLEngine
ThreadLocalEngineStrategy()
+
+
+class MockEngineStrategy(EngineStrategy):
+    """Produce a single Connection object which dispatches statement executions
+    to a passed-in function."""
+ def __init__(self):
+ EngineStrategy.__init__(self, 'mock')
+
+ def create(self, name_or_url, executor, **kwargs):
+ # create url.URL object
+ u = url.make_url(name_or_url)
+
+ # get module from sqlalchemy.databases
+ module = u.get_module()
+
+ dialect_args = {}
+ # consume dialect arguments from kwargs
+ for k in util.get_cls_kwargs(module.dialect):
+ if k in kwargs:
+ dialect_args[k] = kwargs.pop(k)
+
+ # create dialect
+ dialect = module.dialect(**dialect_args)
+
+ return MockEngineStrategy.MockConnection(dialect, executor)
+
+ class MockConnection(base.Connectable):
+ def __init__(self, dialect, execute):
+ self._dialect = dialect
+ self.execute = execute
+
+ engine = property(lambda s: s)
+ dialect = property(lambda s:s._dialect)
+
+ def contextual_connect(self):
+ return self
+
+ def create(self, entity, **kwargs):
+ kwargs['checkfirst'] = False
+ entity.accept_visitor(self.dialect.schemagenerator(self, **kwargs))
+
+ def drop(self, entity, **kwargs):
+ kwargs['checkfirst'] = False
+ entity.accept_visitor(self.dialect.schemadropper(self, **kwargs))
+
+ def execute(self, object, *multiparams, **params):
+ raise NotImplementedError()
+
+MockEngineStrategy() \ No newline at end of file
diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py
index edb8cf32e..faa0ffc11 100644
--- a/lib/sqlalchemy/engine/url.py
+++ b/lib/sqlalchemy/engine/url.py
@@ -71,6 +71,10 @@ class URL(object):
def get_module(self):
"""Return the SQLAlchemy database module corresponding to this URL's driver name."""
+ if self.drivername == 'ansi':
+ import sqlalchemy.ansisql
+ return sqlalchemy.ansisql
+
try:
return getattr(__import__('sqlalchemy.databases.%s' % self.drivername).databases, self.drivername)
except ImportError:
diff --git a/lib/sqlalchemy/logging.py b/lib/sqlalchemy/logging.py
index 6f4368707..91326233a 100644
--- a/lib/sqlalchemy/logging.py
+++ b/lib/sqlalchemy/logging.py
@@ -31,8 +31,8 @@ import sys
# py2.5 absolute imports will fix....
logging = __import__('logging')
-# turn off logging at the root sqlalchemy level
-logging.getLogger('sqlalchemy').setLevel(logging.ERROR)
+
+logging.getLogger('sqlalchemy').setLevel(logging.WARN)
default_enabled = False
def default_logging(name):
diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py
index 787fd059f..8d559aff5 100644
--- a/lib/sqlalchemy/pool.py
+++ b/lib/sqlalchemy/pool.py
@@ -237,7 +237,9 @@ class _ConnectionFairy(object):
raise
if self.__pool.echo:
self.__pool.log("Connection %s checked out from pool" % repr(self.connection))
-
+
+ _logger = property(lambda self: self.__pool.logger)
+
def invalidate(self):
if self.connection is None:
raise exceptions.InvalidRequestError("This connection is closed")
@@ -248,7 +250,8 @@ class _ConnectionFairy(object):
def cursor(self, *args, **kwargs):
try:
- return _CursorFairy(self, self.connection.cursor(*args, **kwargs))
+ c = self.connection.cursor(*args, **kwargs)
+ return _CursorFairy(self, c)
except Exception, e:
self.invalidate()
raise
@@ -307,11 +310,14 @@ class _CursorFairy(object):
def invalidate(self):
self.__parent.invalidate()
-
+
def close(self):
if self in self.__parent._cursors:
del self.__parent._cursors[self]
- self.cursor.close()
+ try:
+ self.cursor.close()
+ except Exception, e:
+ self.__parent._logger.warn("Error closing cursor: " + str(e))
def __getattr__(self, key):
return getattr(self.cursor, key)
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index 87cbdaf0c..f6c2315ae 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -508,7 +508,7 @@ class ClauseParameters(object):
return d
def __repr__(self):
- return repr(self.get_original_dict())
+ return self.__class__.__name__ + ":" + repr(self.get_original_dict())
class ClauseVisitor(object):
"""A class that knows how to traverse and visit
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
index 86e323c6e..7d7dbeeed 100644
--- a/lib/sqlalchemy/types.py
+++ b/lib/sqlalchemy/types.py
@@ -53,28 +53,12 @@ class TypeEngine(AbstractType):
def __init__(self, *args, **params):
pass
- def engine_impl(self, engine):
- """Deprecated; call dialect_impl with a dialect directly."""
-
- return self.dialect_impl(engine.dialect)
-
def dialect_impl(self, dialect):
try:
return self.impl_dict[dialect]
except KeyError:
return self.impl_dict.setdefault(dialect, dialect.type_descriptor(self))
- def _get_impl(self):
- if hasattr(self, '_impl'):
- return self._impl
- else:
- return NULLTYPE
-
- def _set_impl(self, impl):
- self._impl = impl
-
- impl = property(_get_impl, _set_impl)
-
def get_col_spec(self):
raise NotImplementedError()
@@ -86,26 +70,25 @@ class TypeEngine(AbstractType):
def adapt(self, cls):
return cls()
-
+
+ def get_search_list(self):
+        """Return a list of classes to test for a match
+ when adapting this type to a dialect-specific type.
+
+ """
+
+ return self.__class__.__mro__[0:-1]
+
class TypeDecorator(AbstractType):
def __init__(self, *args, **kwargs):
if not hasattr(self.__class__, 'impl'):
raise exceptions.AssertionError("TypeDecorator implementations require a class-level variable 'impl' which refers to the class of type being decorated")
self.impl = self.__class__.impl(*args, **kwargs)
- def engine_impl(self, engine):
- return self.dialect_impl(engine.dialect)
-
def dialect_impl(self, dialect):
try:
return self.impl_dict[dialect]
except:
- # see if the dialect has an adaptation of the TypeDecorator itself
- adapted_decorator = dialect.type_descriptor(self)
- if adapted_decorator is not self:
- result = adapted_decorator.dialect_impl(dialect)
- self.impl_dict[dialect] = result
- return result
typedesc = dialect.type_descriptor(self.impl)
tt = self.copy()
if not isinstance(tt, self.__class__):
@@ -168,8 +151,7 @@ def to_instance(typeobj):
def adapt_type(typeobj, colspecs):
if isinstance(typeobj, type):
typeobj = typeobj()
-
- for t in typeobj.__class__.__mro__[0:-1]:
+ for t in typeobj.get_search_list():
try:
impltype = colspecs[t]
break
@@ -198,26 +180,28 @@ class NullTypeEngine(TypeEngine):
return value
class String(TypeEngine):
- def __new__(cls, *args, **kwargs):
- if cls is not String or len(args) > 0 or kwargs.has_key('length'):
- return super(String, cls).__new__(cls, *args, **kwargs)
- else:
- return super(String, TEXT).__new__(TEXT, *args, **kwargs)
-
- def __init__(self, length = None):
+ def __init__(self, length=None, convert_unicode=False):
self.length = length
+ self.convert_unicode = convert_unicode
def adapt(self, impltype):
- return impltype(length=self.length)
+ return impltype(length=self.length, convert_unicode=self.convert_unicode)
def convert_bind_param(self, value, dialect):
- if not dialect.convert_unicode or value is None or not isinstance(value, unicode):
+ if not (self.convert_unicode or dialect.convert_unicode) or value is None or not isinstance(value, unicode):
return value
else:
return value.encode(dialect.encoding)
+ def get_search_list(self):
+ l = super(String, self).get_search_list()
+ if self.length is None:
+ return (TEXT,) + l
+ else:
+ return l
+
def convert_result_value(self, value, dialect):
- if not dialect.convert_unicode or value is None or isinstance(value, unicode):
+ if not (self.convert_unicode or dialect.convert_unicode) or value is None or isinstance(value, unicode):
return value
else:
return value.decode(dialect.encoding)
@@ -228,21 +212,11 @@ class String(TypeEngine):
def compare_values(self, x, y):
return x == y
-class Unicode(TypeDecorator):
- impl = String
-
- def convert_bind_param(self, value, dialect):
- if value is not None and isinstance(value, unicode):
- return value.encode(dialect.encoding)
- else:
- return value
-
- def convert_result_value(self, value, dialect):
- if value is not None and not isinstance(value, unicode):
- return value.decode(dialect.encoding)
- else:
- return value
-
+class Unicode(String):
+ def __init__(self, length=None, **kwargs):
+ kwargs['convert_unicode'] = True
+ super(Unicode, self).__init__(length=length, **kwargs)
+
class Integer(TypeEngine):
"""Integer datatype."""
@@ -310,7 +284,7 @@ class Binary(TypeEngine):
def convert_bind_param(self, value, dialect):
if value is not None:
- return dialect.dbapi().Binary(value)
+ return dialect.dbapi.Binary(value)
else:
return None
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py
index dadcf0dde..238f12493 100644
--- a/lib/sqlalchemy/util.py
+++ b/lib/sqlalchemy/util.py
@@ -94,6 +94,10 @@ def get_cls_kwargs(cls):
kw.append(vn)
return kw
+def get_func_kwargs(func):
+ """Return the full set of legal kwargs for the given `func`."""
+ return [vn for vn in func.func_code.co_varnames]
+
class SimpleProperty(object):
"""A *default* property accessor."""