diff options
Diffstat (limited to 'lib/sqlalchemy/engine/default.py')
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 165 |
1 files changed, 127 insertions, 38 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 3e8e96a42..ed975b8cf 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1,5 +1,5 @@ # engine/default.py -# Copyright (C) 2005-2013 the SQLAlchemy authors and contributors <see AUTHORS file> +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file> # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -16,7 +16,8 @@ import re import random from . import reflection, interfaces, result from ..sql import compiler, expression -from .. import exc, types as sqltypes, util, pool, processors +from .. import types as sqltypes +from .. import exc, util, pool, processors import codecs import weakref from .. import event @@ -26,6 +27,7 @@ AUTOCOMMIT_REGEXP = re.compile( re.I | re.UNICODE) + class DefaultDialect(interfaces.Dialect): """Default implementation of Dialect""" @@ -57,6 +59,18 @@ class DefaultDialect(interfaces.Dialect): supports_simple_order_by_label = True + engine_config_types = util.immutabledict([ + ('convert_unicode', util.bool_or_str('force')), + ('pool_timeout', int), + ('echo', util.bool_or_str('debug')), + ('echo_pool', util.bool_or_str('debug')), + ('pool_recycle', int), + ('pool_size', int), + ('max_overflow', int), + ('pool_threadlocal', bool), + ('use_native_unicode', bool), + ]) + # if the NUMERIC type # returns decimal.Decimal. # *not* the FLOAT type however. @@ -97,6 +111,33 @@ class DefaultDialect(interfaces.Dialect): server_version_info = None + construct_arguments = None + """Optional set of argument specifiers for various SQLAlchemy + constructs, typically schema items. + + To + implement, establish as a series of tuples, as in:: + + construct_arguments = [ + (schema.Index, { + "using": False, + "where": None, + "ops": None + }) + ] + + If the above construct is established on the Postgresql dialect, + the ``Index`` construct will now accept additional keyword arguments + such as ``postgresql_using``, ``postgresql_where``, etc. Any kind of + ``postgresql_XYZ`` argument not corresponding to the above template will + be rejected with an ``ArgumentError`, for all those SQLAlchemy constructs + which implement the :class:`.DialectKWArgs` class. + + The default is ``None``; older dialects which don't implement the argument + will have the old behavior of un-validated kwargs to schema/SQL constructs. + + """ + # indicates symbol names are # UPPERCASEd if they are case insensitive # within the database. @@ -111,6 +152,7 @@ class DefaultDialect(interfaces.Dialect): implicit_returning=None, supports_right_nested_joins=None, case_sensitive=True, + supports_native_boolean=None, label_length=None, **kwargs): if not getattr(self, 'ported_sqla_06', True): @@ -136,7 +178,8 @@ class DefaultDialect(interfaces.Dialect): self.type_compiler = self.type_compiler(self) if supports_right_nested_joins is not None: self.supports_right_nested_joins = supports_right_nested_joins - + if supports_native_boolean is not None: + self.supports_native_boolean = supports_native_boolean self.case_sensitive = case_sensitive if label_length and label_length > self.max_identifier_length: @@ -159,6 +202,8 @@ class DefaultDialect(interfaces.Dialect): self._encoder = codecs.getencoder(self.encoding) self._decoder = processors.to_unicode_processor_factory(self.encoding) + + @util.memoized_property def _type_memos(self): return weakref.WeakKeyDictionary() @@ -191,6 +236,10 @@ class DefaultDialect(interfaces.Dialect): self.returns_unicode_strings = self._check_unicode_returns(connection) + if self.description_encoding is not None and \ + self._check_unicode_description(connection): + self._description_decoder = self.description_encoding = None + self.do_rollback(connection.connection) def on_connect(self): @@ -207,46 +256,78 @@ class DefaultDialect(interfaces.Dialect): """ return None - def _check_unicode_returns(self, connection): + def _check_unicode_returns(self, connection, additional_tests=None): if util.py2k and not self.supports_unicode_statements: cast_to = util.binary_type else: cast_to = util.text_type - def check_unicode(formatstr, type_): - cursor = connection.connection.cursor() + if self.positional: + parameters = self.execute_sequence_format() + else: + parameters = {} + + def check_unicode(test): + statement = cast_to(expression.select([test]).compile(dialect=self)) try: - try: - cursor.execute( - cast_to( - expression.select( - [expression.cast( - expression.literal_column( - "'test %s returns'" % formatstr), - type_) - ]).compile(dialect=self) - ) - ) - row = cursor.fetchone() - - return isinstance(row[0], util.text_type) - except self.dbapi.Error as de: - util.warn("Exception attempting to " - "detect unicode returns: %r" % de) - return False - finally: + cursor = connection.connection.cursor() + connection._cursor_execute(cursor, statement, parameters) + row = cursor.fetchone() cursor.close() + except exc.DBAPIError as de: + # note that _cursor_execute() will have closed the cursor + # if an exception is thrown. + util.warn("Exception attempting to " + "detect unicode returns: %r" % de) + return False + else: + return isinstance(row[0], util.text_type) + + tests = [ + # detect plain VARCHAR + expression.cast( + expression.literal_column("'test plain returns'"), + sqltypes.VARCHAR(60) + ), + # detect if there's an NVARCHAR type with different behavior available + expression.cast( + expression.literal_column("'test unicode returns'"), + sqltypes.Unicode(60) + ), + ] + + if additional_tests: + tests += additional_tests + + results = set([check_unicode(test) for test in tests]) + + if results.issuperset([True, False]): + return "conditional" + else: + return results == set([True]) - # detect plain VARCHAR - unicode_for_varchar = check_unicode("plain", sqltypes.VARCHAR(60)) - - # detect if there's an NVARCHAR type with different behavior available - unicode_for_unicode = check_unicode("unicode", sqltypes.Unicode(60)) + def _check_unicode_description(self, connection): + # all DBAPIs on Py2K return cursor.description as encoded, + # until pypy2.1beta2 with sqlite, so let's just check it - + # it's likely others will start doing this too in Py2k. - if unicode_for_unicode and not unicode_for_varchar: - return "conditional" + if util.py2k and not self.supports_unicode_statements: + cast_to = util.binary_type else: - return unicode_for_varchar + cast_to = util.text_type + + cursor = connection.connection.cursor() + try: + cursor.execute( + cast_to( + expression.select([ + expression.literal_column("'x'").label("some_label") + ]).compile(dialect=self) + ) + ) + return isinstance(cursor.description[0][0], util.text_type) + finally: + cursor.close() def type_descriptor(self, typeobj): """Provide a database-specific :class:`.TypeEngine` object, given @@ -259,8 +340,7 @@ class DefaultDialect(interfaces.Dialect): """ return sqltypes.adapt_type(typeobj, self.colspecs) - def reflecttable(self, connection, table, include_columns, - exclude_columns=None): + def reflecttable(self, connection, table, include_columns, exclude_columns): insp = reflection.Inspector.from_engine(connection) return insp.reflecttable(table, include_columns, exclude_columns) @@ -368,6 +448,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): statement = None postfetch_cols = None prefetch_cols = None + returning_cols = None _is_implicit_returning = False _is_explicit_returning = False @@ -464,6 +545,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if self.isinsert or self.isupdate: self.postfetch_cols = self.compiled.postfetch self.prefetch_cols = self.compiled.prefetch + self.returning_cols = self.compiled.returning self.__process_defaults() processors = compiled._bind_processors @@ -722,6 +804,11 @@ class DefaultExecutionContext(interfaces.ExecutionContext): ipk.append(row[c]) self.inserted_primary_key = ipk + self.returned_defaults = row + + def _fetch_implicit_update_returning(self, resultproxy): + row = resultproxy.fetchone() + self.returned_defaults = row def lastrow_has_defaults(self): return (self.isinsert or self.isupdate) and \ @@ -808,6 +895,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): and generate inserted_primary_key collection. """ + key_getter = self.compiled._key_getters_for_crud_column[2] + if self.executemany: if len(self.compiled.prefetch): scalar_defaults = {} @@ -831,7 +920,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): else: val = self.get_update_default(c) if val is not None: - param[c.key] = val + param[key_getter(c)] = val del self.current_parameters else: self.current_parameters = compiled_parameters = \ @@ -844,12 +933,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext): val = self.get_update_default(c) if val is not None: - compiled_parameters[c.key] = val + compiled_parameters[key_getter(c)] = val del self.current_parameters if self.isinsert: self.inserted_primary_key = [ - self.compiled_parameters[0].get(c.key, None) + self.compiled_parameters[0].get(key_getter(c), None) for c in self.compiled.\ statement.table.primary_key ] |