Diffstat (limited to 'lib/sqlalchemy/engine/default.py')
 lib/sqlalchemy/engine/default.py | 165
 1 file changed, 127 insertions(+), 38 deletions(-)
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 3e8e96a42..ed975b8cf 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -1,5 +1,5 @@
# engine/default.py
-# Copyright (C) 2005-2013 the SQLAlchemy authors and contributors <see AUTHORS file>
+# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -16,7 +16,8 @@ import re
import random
from . import reflection, interfaces, result
from ..sql import compiler, expression
-from .. import exc, types as sqltypes, util, pool, processors
+from .. import types as sqltypes
+from .. import exc, util, pool, processors
import codecs
import weakref
from .. import event
@@ -26,6 +27,7 @@ AUTOCOMMIT_REGEXP = re.compile(
re.I | re.UNICODE)
+
class DefaultDialect(interfaces.Dialect):
"""Default implementation of Dialect"""
@@ -57,6 +59,18 @@ class DefaultDialect(interfaces.Dialect):
supports_simple_order_by_label = True
+ engine_config_types = util.immutabledict([
+ ('convert_unicode', util.bool_or_str('force')),
+ ('pool_timeout', int),
+ ('echo', util.bool_or_str('debug')),
+ ('echo_pool', util.bool_or_str('debug')),
+ ('pool_recycle', int),
+ ('pool_size', int),
+ ('max_overflow', int),
+ ('pool_threadlocal', bool),
+ ('use_native_unicode', bool),
+ ])
+
# if the NUMERIC type
# returns decimal.Decimal.
# *not* the FLOAT type however.
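
Note on the ``engine_config_types`` addition above: these callables coerce string-valued
engine configuration (e.g. as supplied to ``engine_from_config()``) into proper Python
types. A minimal standalone sketch of that style of coercion, using a simplified stand-in
for ``util.bool_or_str`` rather than the library's exact internals::

    def bool_or_str(*text):
        # return a coercer that passes the listed strings through unchanged,
        # otherwise interprets the value as a boolean
        def coerce(value):
            if value in text:
                return value
            return value.lower() in ('true', 'yes', 'on', '1')
        return coerce

    engine_config_types = {
        'echo': bool_or_str('debug'),
        'pool_recycle': int,
        'pool_size': int,
    }

    raw = {'echo': 'debug', 'pool_recycle': '3600', 'pool_size': '5'}
    coerced = {
        key: engine_config_types.get(key, lambda v: v)(value)
        for key, value in raw.items()
    }
    # coerced == {'echo': 'debug', 'pool_recycle': 3600, 'pool_size': 5}
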
@@ -97,6 +111,33 @@ class DefaultDialect(interfaces.Dialect):
server_version_info = None
+ construct_arguments = None
+ """Optional set of argument specifiers for various SQLAlchemy
+ constructs, typically schema items.
+
+ To implement, establish as a series of tuples, as in::
+
+ construct_arguments = [
+ (schema.Index, {
+ "using": False,
+ "where": None,
+ "ops": None
+ })
+ ]
+
+ If the above construct is established on the Postgresql dialect,
+ the ``Index`` construct will now accept additional keyword arguments
+ such as ``postgresql_using``, ``postgresql_where``, etc. Any kind of
+ ``postgresql_XYZ`` argument not corresponding to the above template will
+ be rejected with an ``ArgumentError``, for all those SQLAlchemy constructs
+ which implement the :class:`.DialectKWArgs` class.
+
+ The default is ``None``; older dialects which don't implement the argument
+ will have the old behavior of un-validated kwargs to schema/SQL constructs.
+
+ """
+
# indicates symbol names are
# UPPERCASEd if they are case insensitive
# within the database.
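
Note on the ``construct_arguments`` docstring above: once a dialect declares such a
template, dialect-prefixed keywords become available on the matching construct. A hedged
usage sketch; the table and index names are illustrative, and it assumes the postgresql
dialect registers the ``Index`` template shown in the docstring::

    from sqlalchemy import MetaData, Table, Column, Integer, Index

    metadata = MetaData()
    tbl = Table('example', metadata, Column('data', Integer))

    # accepted because "using" appears in the declared template; an unknown
    # postgresql_-prefixed keyword would raise ArgumentError on constructs
    # implementing DialectKWArgs (no DDL is emitted here)
    idx = Index('ix_example_data', tbl.c.data, postgresql_using='gin')
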
@@ -111,6 +152,7 @@ class DefaultDialect(interfaces.Dialect):
implicit_returning=None,
supports_right_nested_joins=None,
case_sensitive=True,
+ supports_native_boolean=None,
label_length=None, **kwargs):
if not getattr(self, 'ported_sqla_06', True):
@@ -136,7 +178,8 @@ class DefaultDialect(interfaces.Dialect):
self.type_compiler = self.type_compiler(self)
if supports_right_nested_joins is not None:
self.supports_right_nested_joins = supports_right_nested_joins
-
+ if supports_native_boolean is not None:
+ self.supports_native_boolean = supports_native_boolean
self.case_sensitive = case_sensitive
if label_length and label_length > self.max_identifier_length:
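
Note on the new ``supports_native_boolean`` constructor argument above: keyword arguments
not consumed by ``create_engine()`` itself are generally forwarded to the dialect's
``__init__()``, so a flag like this can be supplied at engine creation time. A hedged
sketch; the URL is an assumption::

    from sqlalchemy import create_engine

    # extra keywords are passed along to the dialect constructor
    engine = create_engine(
        "postgresql://scott:tiger@localhost/test",
        supports_native_boolean=False,
    )
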
@@ -159,6 +202,8 @@ class DefaultDialect(interfaces.Dialect):
self._encoder = codecs.getencoder(self.encoding)
self._decoder = processors.to_unicode_processor_factory(self.encoding)
+
+
@util.memoized_property
def _type_memos(self):
return weakref.WeakKeyDictionary()
@@ -191,6 +236,10 @@ class DefaultDialect(interfaces.Dialect):
self.returns_unicode_strings = self._check_unicode_returns(connection)
+ if self.description_encoding is not None and \
+ self._check_unicode_description(connection):
+ self._description_decoder = self.description_encoding = None
+
self.do_rollback(connection.connection)
def on_connect(self):
@@ -207,46 +256,78 @@ class DefaultDialect(interfaces.Dialect):
"""
return None
- def _check_unicode_returns(self, connection):
+ def _check_unicode_returns(self, connection, additional_tests=None):
if util.py2k and not self.supports_unicode_statements:
cast_to = util.binary_type
else:
cast_to = util.text_type
- def check_unicode(formatstr, type_):
- cursor = connection.connection.cursor()
+ if self.positional:
+ parameters = self.execute_sequence_format()
+ else:
+ parameters = {}
+
+ def check_unicode(test):
+ statement = cast_to(expression.select([test]).compile(dialect=self))
try:
- try:
- cursor.execute(
- cast_to(
- expression.select(
- [expression.cast(
- expression.literal_column(
- "'test %s returns'" % formatstr),
- type_)
- ]).compile(dialect=self)
- )
- )
- row = cursor.fetchone()
-
- return isinstance(row[0], util.text_type)
- except self.dbapi.Error as de:
- util.warn("Exception attempting to "
- "detect unicode returns: %r" % de)
- return False
- finally:
+ cursor = connection.connection.cursor()
+ connection._cursor_execute(cursor, statement, parameters)
+ row = cursor.fetchone()
cursor.close()
+ except exc.DBAPIError as de:
+ # note that _cursor_execute() will have closed the cursor
+ # if an exception is thrown.
+ util.warn("Exception attempting to "
+ "detect unicode returns: %r" % de)
+ return False
+ else:
+ return isinstance(row[0], util.text_type)
+
+ tests = [
+ # detect plain VARCHAR
+ expression.cast(
+ expression.literal_column("'test plain returns'"),
+ sqltypes.VARCHAR(60)
+ ),
+ # detect if there's an NVARCHAR type with different behavior available
+ expression.cast(
+ expression.literal_column("'test unicode returns'"),
+ sqltypes.Unicode(60)
+ ),
+ ]
+
+ if additional_tests:
+ tests += additional_tests
+
+ results = set([check_unicode(test) for test in tests])
+
+ if results.issuperset([True, False]):
+ return "conditional"
+ else:
+ return results == set([True])
- # detect plain VARCHAR
- unicode_for_varchar = check_unicode("plain", sqltypes.VARCHAR(60))
-
- # detect if there's an NVARCHAR type with different behavior available
- unicode_for_unicode = check_unicode("unicode", sqltypes.Unicode(60))
+ def _check_unicode_description(self, connection):
+ # all DBAPIs on Py2K return cursor.description as encoded,
+ # until pypy2.1beta2 with sqlite, so let's just check it -
+ # it's likely others will start doing this too in Py2k.
- if unicode_for_unicode and not unicode_for_varchar:
- return "conditional"
+ if util.py2k and not self.supports_unicode_statements:
+ cast_to = util.binary_type
else:
- return unicode_for_varchar
+ cast_to = util.text_type
+
+ cursor = connection.connection.cursor()
+ try:
+ cursor.execute(
+ cast_to(
+ expression.select([
+ expression.literal_column("'x'").label("some_label")
+ ]).compile(dialect=self)
+ )
+ )
+ return isinstance(cursor.description[0][0], util.text_type)
+ finally:
+ cursor.close()
def type_descriptor(self, typeobj):
"""Provide a database-specific :class:`.TypeEngine` object, given
@@ -259,8 +340,7 @@ class DefaultDialect(interfaces.Dialect):
"""
return sqltypes.adapt_type(typeobj, self.colspecs)
- def reflecttable(self, connection, table, include_columns,
- exclude_columns=None):
+ def reflecttable(self, connection, table, include_columns, exclude_columns):
insp = reflection.Inspector.from_engine(connection)
return insp.reflecttable(table, include_columns, exclude_columns)
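
Note on the reworked ``_check_unicode_returns`` above: the per-test results are reduced to
a single dialect-level setting, returning ``"conditional"`` when the tests disagree. A
small standalone sketch of that reduction (not the dialect code itself)::

    def reduce_unicode_results(results):
        # "conditional" when the VARCHAR and NVARCHAR-style tests disagree,
        # otherwise a plain boolean
        results = set(results)
        if results.issuperset([True, False]):
            return "conditional"
        return results == set([True])

    assert reduce_unicode_results([True, True]) is True
    assert reduce_unicode_results([False, False]) is False
    assert reduce_unicode_results([True, False]) == "conditional"
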
@@ -368,6 +448,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
statement = None
postfetch_cols = None
prefetch_cols = None
+ returning_cols = None
_is_implicit_returning = False
_is_explicit_returning = False
@@ -464,6 +545,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
if self.isinsert or self.isupdate:
self.postfetch_cols = self.compiled.postfetch
self.prefetch_cols = self.compiled.prefetch
+ self.returning_cols = self.compiled.returning
self.__process_defaults()
processors = compiled._bind_processors
@@ -722,6 +804,11 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
ipk.append(row[c])
self.inserted_primary_key = ipk
+ self.returned_defaults = row
+
+ def _fetch_implicit_update_returning(self, resultproxy):
+ row = resultproxy.fetchone()
+ self.returned_defaults = row
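
Note on the ``returned_defaults`` assignments above: the row fetched from RETURNING is
exposed on the result object. A hedged usage sketch using the 0.9-era ``return_defaults()``
API; the URL, table, and column names are assumptions and a RETURNING-capable backend such
as PostgreSQL is required::

    from sqlalchemy import (create_engine, MetaData, Table, Column,
                            Integer, String, DateTime, func)

    metadata = MetaData()
    items = Table(
        'items', metadata,
        Column('id', Integer, primary_key=True),
        Column('name', String(50)),
        Column('created_at', DateTime, server_default=func.now()),
    )

    engine = create_engine("postgresql://scott:tiger@localhost/test")
    metadata.create_all(engine)

    with engine.connect() as conn:
        result = conn.execute(
            items.insert().values(name='widget').return_defaults()
        )
        # row of server-generated values (e.g. id, created_at) fetched via
        # RETURNING, or None where the feature isn't available
        print(result.returned_defaults)
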
def lastrow_has_defaults(self):
return (self.isinsert or self.isupdate) and \
@@ -808,6 +895,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
and generate inserted_primary_key collection.
"""
+ key_getter = self.compiled._key_getters_for_crud_column[2]
+
if self.executemany:
if len(self.compiled.prefetch):
scalar_defaults = {}
@@ -831,7 +920,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
else:
val = self.get_update_default(c)
if val is not None:
- param[c.key] = val
+ param[key_getter(c)] = val
del self.current_parameters
else:
self.current_parameters = compiled_parameters = \
@@ -844,12 +933,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
val = self.get_update_default(c)
if val is not None:
- compiled_parameters[c.key] = val
+ compiled_parameters[key_getter(c)] = val
del self.current_parameters
if self.isinsert:
self.inserted_primary_key = [
- self.compiled_parameters[0].get(c.key, None)
+ self.compiled_parameters[0].get(key_getter(c), None)
for c in self.compiled.\
statement.table.primary_key
]
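
Note on the switch from ``c.key`` to ``key_getter(c)`` above: the compiled parameter
dictionary is not always keyed by ``Column.key`` (bind names can be qualified, e.g. in
multi-table UPDATEs, and a column's ``.key`` can differ from its database name), so the
compiler supplies the appropriate getter. A short standalone illustration of the key/name
distinction; the table and names are hypothetical::

    from sqlalchemy import MetaData, Table, Column, Integer, String

    metadata = MetaData()
    users = Table(
        'users', metadata,
        # Python-side key "pk" differs from the database name "user_id"
        Column('user_id', Integer, key='pk', primary_key=True),
        Column('user_name', String(50), key='name'),
    )

    col = users.c.pk
    print(col.key)    # "pk"
    print(col.name)   # "user_id"
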