diff options
Diffstat (limited to 'lib/sqlalchemy/engine')
| -rw-r--r-- | lib/sqlalchemy/engine/__init__.py | 30 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/base.py | 445 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 416 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/interfaces.py | 51 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/reflection.py | 436 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/result.py | 384 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/strategies.py | 120 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/threadlocal.py | 50 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/url.py | 115 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/util.py | 16 |
10 files changed, 1204 insertions, 859 deletions
diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index 6342b3c21..590359c38 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -57,10 +57,9 @@ from .interfaces import ( Dialect, ExecutionContext, ExceptionContext, - # backwards compat Compiled, - TypeCompiler + TypeCompiler, ) from .base import ( @@ -82,9 +81,7 @@ from .result import ( RowProxy, ) -from .util import ( - connection_memoize -) +from .util import connection_memoize from . import util, strategies @@ -92,7 +89,7 @@ from . import util, strategies # backwards compat from ..sql import ddl -default_strategy = 'plain' +default_strategy = "plain" def create_engine(*args, **kwargs): @@ -460,12 +457,12 @@ def create_engine(*args, **kwargs): """ - strategy = kwargs.pop('strategy', default_strategy) + strategy = kwargs.pop("strategy", default_strategy) strategy = strategies.strategies[strategy] return strategy.create(*args, **kwargs) -def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs): +def engine_from_config(configuration, prefix="sqlalchemy.", **kwargs): """Create a new Engine instance using a configuration dictionary. The dictionary is typically produced from a config file. @@ -497,16 +494,15 @@ def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs): """ - options = dict((key[len(prefix):], configuration[key]) - for key in configuration - if key.startswith(prefix)) - options['_coerce_config'] = True + options = dict( + (key[len(prefix) :], configuration[key]) + for key in configuration + if key.startswith(prefix) + ) + options["_coerce_config"] = True options.update(kwargs) - url = options.pop('url') + url = options.pop("url") return create_engine(url, **options) -__all__ = ( - 'create_engine', - 'engine_from_config', -) +__all__ = ("create_engine", "engine_from_config") diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 4a057ee59..75d03b744 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -61,10 +61,16 @@ class Connection(Connectable): """ - def __init__(self, engine, connection=None, close_with_result=False, - _branch_from=None, _execution_options=None, - _dispatch=None, - _has_events=None): + def __init__( + self, + engine, + connection=None, + close_with_result=False, + _branch_from=None, + _execution_options=None, + _dispatch=None, + _has_events=None, + ): """Construct a new Connection. The constructor here is not public and is only called only by an @@ -86,8 +92,11 @@ class Connection(Connectable): self._has_events = _branch_from._has_events self.schema_for_object = _branch_from.schema_for_object else: - self.__connection = connection \ - if connection is not None else engine.raw_connection() + self.__connection = ( + connection + if connection is not None + else engine.raw_connection() + ) self.__transaction = None self.__savepoint_seq = 0 self.should_close_with_result = close_with_result @@ -101,7 +110,8 @@ class Connection(Connectable): # want to handle any of the engine's events in that case. self.dispatch = self.dispatch._join(engine.dispatch) self._has_events = _has_events or ( - _has_events is None and engine._has_events) + _has_events is None and engine._has_events + ) assert not _execution_options self._execution_options = engine._execution_options @@ -134,7 +144,8 @@ class Connection(Connectable): _branch_from=self, _execution_options=self._execution_options, _has_events=self._has_events, - _dispatch=self.dispatch) + _dispatch=self.dispatch, + ) @property def _root(self): @@ -322,8 +333,10 @@ class Connection(Connectable): def closed(self): """Return True if this connection is closed.""" - return '_Connection__connection' not in self.__dict__ \ + return ( + "_Connection__connection" not in self.__dict__ and not self.__can_reconnect + ) @property def invalidated(self): @@ -425,7 +438,8 @@ class Connection(Connectable): if self.__transaction is not None: raise exc.InvalidRequestError( "Can't reconnect until invalid " - "transaction is rolled back") + "transaction is rolled back" + ) self.__connection = self.engine.raw_connection(_connection=self) self.__invalid = False return self.__connection @@ -437,14 +451,15 @@ class Connection(Connectable): # dialect initializer, where the connection is not wrapped in # _ConnectionFairy - return getattr(self.__connection, 'is_valid', False) + return getattr(self.__connection, "is_valid", False) @property def _still_open_and_connection_is_valid(self): - return \ - not self.closed and \ - not self.invalidated and \ - getattr(self.__connection, 'is_valid', False) + return ( + not self.closed + and not self.invalidated + and getattr(self.__connection, "is_valid", False) + ) @property def info(self): @@ -656,7 +671,8 @@ class Connection(Connectable): if self.__transaction is not None: raise exc.InvalidRequestError( "Cannot start a two phase transaction when a transaction " - "is already in progress.") + "is already in progress." + ) if xid is None: xid = self.engine.dialect.create_xid() self.__transaction = TwoPhaseTransaction(self, xid) @@ -705,8 +721,10 @@ class Connection(Connectable): except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) finally: - if not self.__invalid and \ - self.connection._reset_agent is self.__transaction: + if ( + not self.__invalid + and self.connection._reset_agent is self.__transaction + ): self.connection._reset_agent = None self.__transaction = None else: @@ -725,8 +743,10 @@ class Connection(Connectable): except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) finally: - if not self.__invalid and \ - self.connection._reset_agent is self.__transaction: + if ( + not self.__invalid + and self.connection._reset_agent is self.__transaction + ): self.connection._reset_agent = None self.__transaction = None @@ -738,7 +758,7 @@ class Connection(Connectable): if name is None: self.__savepoint_seq += 1 - name = 'sa_savepoint_%s' % self.__savepoint_seq + name = "sa_savepoint_%s" % self.__savepoint_seq if self._still_open_and_connection_is_valid: self.engine.dialect.do_savepoint(self, name) return name @@ -797,7 +817,8 @@ class Connection(Connectable): assert isinstance(self.__transaction, TwoPhaseTransaction) try: self.engine.dialect.do_rollback_twophase( - self, xid, is_prepared) + self, xid, is_prepared + ) finally: if self.connection._reset_agent is self.__transaction: self.connection._reset_agent = None @@ -950,16 +971,16 @@ class Connection(Connectable): def _execute_function(self, func, multiparams, params): """Execute a sql.FunctionElement object.""" - return self._execute_clauseelement(func.select(), - multiparams, params) + return self._execute_clauseelement(func.select(), multiparams, params) def _execute_default(self, default, multiparams, params): """Execute a schema.ColumnDefault object.""" if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: - default, multiparams, params = \ - fn(self, default, multiparams, params) + default, multiparams, params = fn( + self, default, multiparams, params + ) try: try: @@ -972,8 +993,7 @@ class Connection(Connectable): conn = self._revalidate_connection() dialect = self.dialect - ctx = dialect.execution_ctx_cls._init_default( - dialect, self, conn) + ctx = dialect.execution_ctx_cls._init_default(dialect, self, conn) except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) @@ -982,8 +1002,9 @@ class Connection(Connectable): self.close() if self._has_events or self.engine._has_events: - self.dispatch.after_execute(self, - default, multiparams, params, ret) + self.dispatch.after_execute( + self, default, multiparams, params, ret + ) return ret @@ -992,25 +1013,25 @@ class Connection(Connectable): if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: - ddl, multiparams, params = \ - fn(self, ddl, multiparams, params) + ddl, multiparams, params = fn(self, ddl, multiparams, params) dialect = self.dialect compiled = ddl.compile( dialect=dialect, schema_translate_map=self.schema_for_object - if not self.schema_for_object.is_default else None) + if not self.schema_for_object.is_default + else None, + ) ret = self._execute_context( dialect, dialect.execution_ctx_cls._init_ddl, compiled, None, - compiled + compiled, ) if self._has_events or self.engine._has_events: - self.dispatch.after_execute(self, - ddl, multiparams, params, ret) + self.dispatch.after_execute(self, ddl, multiparams, params, ret) return ret def _execute_clauseelement(self, elem, multiparams, params): @@ -1018,8 +1039,7 @@ class Connection(Connectable): if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: - elem, multiparams, params = \ - fn(self, elem, multiparams, params) + elem, multiparams, params = fn(self, elem, multiparams, params) distilled_params = _distill_params(multiparams, params) if distilled_params: @@ -1030,38 +1050,45 @@ class Connection(Connectable): keys = [] dialect = self.dialect - if 'compiled_cache' in self._execution_options: + if "compiled_cache" in self._execution_options: key = ( - dialect, elem, tuple(sorted(keys)), + dialect, + elem, + tuple(sorted(keys)), self.schema_for_object.hash_key, - len(distilled_params) > 1 + len(distilled_params) > 1, ) - compiled_sql = self._execution_options['compiled_cache'].get(key) + compiled_sql = self._execution_options["compiled_cache"].get(key) if compiled_sql is None: compiled_sql = elem.compile( - dialect=dialect, column_keys=keys, + dialect=dialect, + column_keys=keys, inline=len(distilled_params) > 1, schema_translate_map=self.schema_for_object - if not self.schema_for_object.is_default else None + if not self.schema_for_object.is_default + else None, ) - self._execution_options['compiled_cache'][key] = compiled_sql + self._execution_options["compiled_cache"][key] = compiled_sql else: compiled_sql = elem.compile( - dialect=dialect, column_keys=keys, + dialect=dialect, + column_keys=keys, inline=len(distilled_params) > 1, schema_translate_map=self.schema_for_object - if not self.schema_for_object.is_default else None) + if not self.schema_for_object.is_default + else None, + ) ret = self._execute_context( dialect, dialect.execution_ctx_cls._init_compiled, compiled_sql, distilled_params, - compiled_sql, distilled_params + compiled_sql, + distilled_params, ) if self._has_events or self.engine._has_events: - self.dispatch.after_execute(self, - elem, multiparams, params, ret) + self.dispatch.after_execute(self, elem, multiparams, params, ret) return ret def _execute_compiled(self, compiled, multiparams, params): @@ -1069,8 +1096,9 @@ class Connection(Connectable): if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: - compiled, multiparams, params = \ - fn(self, compiled, multiparams, params) + compiled, multiparams, params = fn( + self, compiled, multiparams, params + ) dialect = self.dialect parameters = _distill_params(multiparams, params) @@ -1079,11 +1107,13 @@ class Connection(Connectable): dialect.execution_ctx_cls._init_compiled, compiled, parameters, - compiled, parameters + compiled, + parameters, ) if self._has_events or self.engine._has_events: - self.dispatch.after_execute(self, - compiled, multiparams, params, ret) + self.dispatch.after_execute( + self, compiled, multiparams, params, ret + ) return ret def _execute_text(self, statement, multiparams, params): @@ -1091,8 +1121,9 @@ class Connection(Connectable): if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: - statement, multiparams, params = \ - fn(self, statement, multiparams, params) + statement, multiparams, params = fn( + self, statement, multiparams, params + ) dialect = self.dialect parameters = _distill_params(multiparams, params) @@ -1101,16 +1132,18 @@ class Connection(Connectable): dialect.execution_ctx_cls._init_statement, statement, parameters, - statement, parameters + statement, + parameters, ) if self._has_events or self.engine._has_events: - self.dispatch.after_execute(self, - statement, multiparams, params, ret) + self.dispatch.after_execute( + self, statement, multiparams, params, ret + ) return ret - def _execute_context(self, dialect, constructor, - statement, parameters, - *args): + def _execute_context( + self, dialect, constructor, statement, parameters, *args + ): """Create an :class:`.ExecutionContext` and execute, returning a :class:`.ResultProxy`.""" @@ -1127,31 +1160,36 @@ class Connection(Connectable): context = constructor(dialect, self, conn, *args) except BaseException as e: self._handle_dbapi_exception( - e, - util.text_type(statement), parameters, - None, None) + e, util.text_type(statement), parameters, None, None + ) if context.compiled: context.pre_exec() - cursor, statement, parameters = context.cursor, \ - context.statement, \ - context.parameters + cursor, statement, parameters = ( + context.cursor, + context.statement, + context.parameters, + ) if not context.executemany: parameters = parameters[0] if self._has_events or self.engine._has_events: for fn in self.dispatch.before_cursor_execute: - statement, parameters = \ - fn(self, cursor, statement, parameters, - context, context.executemany) + statement, parameters = fn( + self, + cursor, + statement, + parameters, + context, + context.executemany, + ) if self._echo: self.engine.logger.info(statement) self.engine.logger.info( - "%r", - sql_util._repr_params(parameters, batches=10) + "%r", sql_util._repr_params(parameters, batches=10) ) evt_handled = False @@ -1164,10 +1202,8 @@ class Connection(Connectable): break if not evt_handled: self.dialect.do_executemany( - cursor, - statement, - parameters, - context) + cursor, statement, parameters, context + ) elif not parameters and context.no_parameters: if self.dialect._has_events: for fn in self.dialect.dispatch.do_execute_no_params: @@ -1176,9 +1212,8 @@ class Connection(Connectable): break if not evt_handled: self.dialect.do_execute_no_params( - cursor, - statement, - context) + cursor, statement, context + ) else: if self.dialect._has_events: for fn in self.dialect.dispatch.do_execute: @@ -1187,24 +1222,22 @@ class Connection(Connectable): break if not evt_handled: self.dialect.do_execute( - cursor, - statement, - parameters, - context) + cursor, statement, parameters, context + ) except BaseException as e: self._handle_dbapi_exception( - e, - statement, - parameters, - cursor, - context) + e, statement, parameters, cursor, context + ) if self._has_events or self.engine._has_events: - self.dispatch.after_cursor_execute(self, cursor, - statement, - parameters, - context, - context.executemany) + self.dispatch.after_cursor_execute( + self, + cursor, + statement, + parameters, + context, + context.executemany, + ) if context.compiled: context.post_exec() @@ -1245,39 +1278,32 @@ class Connection(Connectable): """ if self._has_events or self.engine._has_events: for fn in self.dispatch.before_cursor_execute: - statement, parameters = \ - fn(self, cursor, statement, parameters, - context, - False) + statement, parameters = fn( + self, cursor, statement, parameters, context, False + ) if self._echo: self.engine.logger.info(statement) self.engine.logger.info("%r", parameters) try: - for fn in () if not self.dialect._has_events \ - else self.dialect.dispatch.do_execute: + for fn in ( + () + if not self.dialect._has_events + else self.dialect.dispatch.do_execute + ): if fn(cursor, statement, parameters, context): break else: - self.dialect.do_execute( - cursor, - statement, - parameters, - context) + self.dialect.do_execute(cursor, statement, parameters, context) except BaseException as e: self._handle_dbapi_exception( - e, - statement, - parameters, - cursor, - context) + e, statement, parameters, cursor, context + ) if self._has_events or self.engine._has_events: - self.dispatch.after_cursor_execute(self, cursor, - statement, - parameters, - context, - False) + self.dispatch.after_cursor_execute( + self, cursor, statement, parameters, context, False + ) def _safe_close_cursor(self, cursor): """Close the given cursor, catching exceptions @@ -1289,17 +1315,15 @@ class Connection(Connectable): except Exception: # log the error through the connection pool's logger. self.engine.pool.logger.error( - "Error closing cursor", exc_info=True) + "Error closing cursor", exc_info=True + ) _reentrant_error = False _is_disconnect = False - def _handle_dbapi_exception(self, - e, - statement, - parameters, - cursor, - context): + def _handle_dbapi_exception( + self, e, statement, parameters, cursor, context + ): exc_info = sys.exc_info() if context and context.exception is None: @@ -1309,15 +1333,14 @@ class Connection(Connectable): if not self._is_disconnect: self._is_disconnect = ( - isinstance(e, self.dialect.dbapi.Error) and - not self.closed and - self.dialect.is_disconnect( + isinstance(e, self.dialect.dbapi.Error) + and not self.closed + and self.dialect.is_disconnect( e, self.__connection if not self.invalidated else None, - cursor) - ) or ( - is_exit_exception and not self.closed - ) + cursor, + ) + ) or (is_exit_exception and not self.closed) if context: context.is_disconnect = self._is_disconnect @@ -1326,20 +1349,24 @@ class Connection(Connectable): if self._reentrant_error: util.raise_from_cause( - exc.DBAPIError.instance(statement, - parameters, - e, - self.dialect.dbapi.Error, - dialect=self.dialect), - exc_info + exc.DBAPIError.instance( + statement, + parameters, + e, + self.dialect.dbapi.Error, + dialect=self.dialect, + ), + exc_info, ) self._reentrant_error = True try: # non-DBAPI error - if we already got a context, # or there's no string statement, don't wrap it - should_wrap = isinstance(e, self.dialect.dbapi.Error) or \ - (statement is not None - and context is None and not is_exit_exception) + should_wrap = isinstance(e, self.dialect.dbapi.Error) or ( + statement is not None + and context is None + and not is_exit_exception + ) if should_wrap: sqlalchemy_exception = exc.DBAPIError.instance( @@ -1348,30 +1375,37 @@ class Connection(Connectable): e, self.dialect.dbapi.Error, connection_invalidated=self._is_disconnect, - dialect=self.dialect) + dialect=self.dialect, + ) else: sqlalchemy_exception = None newraise = None - if (self._has_events or self.engine._has_events) and \ - not self._execution_options.get( - 'skip_user_error_events', False): + if ( + self._has_events or self.engine._has_events + ) and not self._execution_options.get( + "skip_user_error_events", False + ): # legacy dbapi_error event if should_wrap and context: - self.dispatch.dbapi_error(self, - cursor, - statement, - parameters, - context, - e) + self.dispatch.dbapi_error( + self, cursor, statement, parameters, context, e + ) # new handle_error event ctx = ExceptionContextImpl( - e, sqlalchemy_exception, self.engine, - self, cursor, statement, - parameters, context, self._is_disconnect, - invalidate_pool_on_disconnect) + e, + sqlalchemy_exception, + self.engine, + self, + cursor, + statement, + parameters, + context, + self._is_disconnect, + invalidate_pool_on_disconnect, + ) for fn in self.dispatch.handle_error: try: @@ -1388,13 +1422,15 @@ class Connection(Connectable): if self._is_disconnect != ctx.is_disconnect: self._is_disconnect = ctx.is_disconnect if sqlalchemy_exception: - sqlalchemy_exception.connection_invalidated = \ + sqlalchemy_exception.connection_invalidated = ( ctx.is_disconnect + ) # set up potentially user-defined value for # invalidate pool. - invalidate_pool_on_disconnect = \ + invalidate_pool_on_disconnect = ( ctx.invalidate_pool_on_disconnect + ) if should_wrap and context: context.handle_dbapi_exception(e) @@ -1408,10 +1444,7 @@ class Connection(Connectable): if newraise: util.raise_from_cause(newraise, exc_info) elif should_wrap: - util.raise_from_cause( - sqlalchemy_exception, - exc_info - ) + util.raise_from_cause(sqlalchemy_exception, exc_info) else: util.reraise(*exc_info) @@ -1441,7 +1474,8 @@ class Connection(Connectable): None, e, dialect.dbapi.Error, - connection_invalidated=is_disconnect) + connection_invalidated=is_disconnect, + ) else: sqlalchemy_exception = None @@ -1449,8 +1483,17 @@ class Connection(Connectable): if engine._has_events: ctx = ExceptionContextImpl( - e, sqlalchemy_exception, engine, None, None, None, - None, None, is_disconnect, True) + e, + sqlalchemy_exception, + engine, + None, + None, + None, + None, + None, + is_disconnect, + True, + ) for fn in engine.dispatch.handle_error: try: # handler returns an exception; @@ -1463,18 +1506,15 @@ class Connection(Connectable): newraise = _raised break - if sqlalchemy_exception and \ - is_disconnect != ctx.is_disconnect: - sqlalchemy_exception.connection_invalidated = \ - is_disconnect = ctx.is_disconnect + if sqlalchemy_exception and is_disconnect != ctx.is_disconnect: + sqlalchemy_exception.connection_invalidated = ( + is_disconnect + ) = ctx.is_disconnect if newraise: util.raise_from_cause(newraise, exc_info) elif should_wrap: - util.raise_from_cause( - sqlalchemy_exception, - exc_info - ) + util.raise_from_cause(sqlalchemy_exception, exc_info) else: util.reraise(*exc_info) @@ -1545,16 +1585,25 @@ class Connection(Connectable): return callable_(self, *args, **kwargs) def _run_visitor(self, visitorcallable, element, **kwargs): - visitorcallable(self.dialect, self, - **kwargs).traverse_single(element) + visitorcallable(self.dialect, self, **kwargs).traverse_single(element) class ExceptionContextImpl(ExceptionContext): """Implement the :class:`.ExceptionContext` interface.""" - def __init__(self, exception, sqlalchemy_exception, - engine, connection, cursor, statement, parameters, - context, is_disconnect, invalidate_pool_on_disconnect): + def __init__( + self, + exception, + sqlalchemy_exception, + engine, + connection, + cursor, + statement, + parameters, + context, + is_disconnect, + invalidate_pool_on_disconnect, + ): self.engine = engine self.connection = connection self.sqlalchemy_exception = sqlalchemy_exception @@ -1691,12 +1740,14 @@ class NestedTransaction(Transaction): def _do_rollback(self): if self.is_active: self.connection._rollback_to_savepoint_impl( - self._savepoint, self._parent) + self._savepoint, self._parent + ) def _do_commit(self): if self.is_active: self.connection._release_savepoint_impl( - self._savepoint, self._parent) + self._savepoint, self._parent + ) class TwoPhaseTransaction(Transaction): @@ -1771,10 +1822,16 @@ class Engine(Connectable, log.Identified): """ - def __init__(self, pool, dialect, url, - logging_name=None, echo=None, proxy=None, - execution_options=None - ): + def __init__( + self, + pool, + dialect, + url, + logging_name=None, + echo=None, + proxy=None, + execution_options=None, + ): self.pool = pool self.url = url self.dialect = dialect @@ -1805,8 +1862,7 @@ class Engine(Connectable, log.Identified): :meth:`.Engine.execution_options` """ - self._execution_options = \ - self._execution_options.union(opt) + self._execution_options = self._execution_options.union(opt) self.dispatch.set_engine_execution_options(self, opt) self.dialect.set_engine_execution_options(self, opt) @@ -1894,7 +1950,7 @@ class Engine(Connectable, log.Identified): echo = log.echo_property() def __repr__(self): - return 'Engine(%r)' % self.url + return "Engine(%r)" % self.url def dispose(self): """Dispose of the connection pool used by this :class:`.Engine`. @@ -1934,8 +1990,9 @@ class Engine(Connectable, log.Identified): else: yield connection - def _run_visitor(self, visitorcallable, element, - connection=None, **kwargs): + def _run_visitor( + self, visitorcallable, element, connection=None, **kwargs + ): with self._optional_conn_ctx_manager(connection) as conn: conn._run_visitor(visitorcallable, element, **kwargs) @@ -2122,7 +2179,8 @@ class Engine(Connectable, log.Identified): self, self._wrap_pool_connect(self.pool.connect, None), close_with_result=close_with_result, - **kwargs) + **kwargs + ) def table_names(self, schema=None, connection=None): """Return a list of all table names available in the database. @@ -2159,7 +2217,8 @@ class Engine(Connectable, log.Identified): except dialect.dbapi.Error as e: if connection is None: Connection._handle_dbapi_exception_noconnection( - e, dialect, self) + e, dialect, self + ) else: util.reraise(*sys.exc_info()) @@ -2185,7 +2244,8 @@ class Engine(Connectable, log.Identified): """ return self._wrap_pool_connect( - self.pool.unique_connection, _connection) + self.pool.unique_connection, _connection + ) class OptionEngine(Engine): @@ -2225,10 +2285,11 @@ class OptionEngine(Engine): pool = property(_get_pool, _set_pool) def _get_has_events(self): - return self._proxied._has_events or \ - self.__dict__.get('_has_events', False) + return self._proxied._has_events or self.__dict__.get( + "_has_events", False + ) def _set_has_events(self, value): - self.__dict__['_has_events'] = value + self.__dict__["_has_events"] = value _has_events = property(_get_has_events, _set_has_events) diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 028abc4c2..d7c2518fe 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -24,13 +24,11 @@ import weakref from .. import event AUTOCOMMIT_REGEXP = re.compile( - r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)', - re.I | re.UNICODE) + r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)", re.I | re.UNICODE +) # When we're handed literal SQL, ensure it's a SELECT query -SERVER_SIDE_CURSOR_RE = re.compile( - r'\s*SELECT', - re.I | re.UNICODE) +SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE) class DefaultDialect(interfaces.Dialect): @@ -68,16 +66,18 @@ class DefaultDialect(interfaces.Dialect): supports_simple_order_by_label = True - engine_config_types = util.immutabledict([ - ('convert_unicode', util.bool_or_str('force')), - ('pool_timeout', util.asint), - ('echo', util.bool_or_str('debug')), - ('echo_pool', util.bool_or_str('debug')), - ('pool_recycle', util.asint), - ('pool_size', util.asint), - ('max_overflow', util.asint), - ('pool_threadlocal', util.asbool), - ]) + engine_config_types = util.immutabledict( + [ + ("convert_unicode", util.bool_or_str("force")), + ("pool_timeout", util.asint), + ("echo", util.bool_or_str("debug")), + ("echo_pool", util.bool_or_str("debug")), + ("pool_recycle", util.asint), + ("pool_size", util.asint), + ("max_overflow", util.asint), + ("pool_threadlocal", util.asbool), + ] + ) # if the NUMERIC type # returns decimal.Decimal. @@ -93,9 +93,9 @@ class DefaultDialect(interfaces.Dialect): supports_unicode_statements = False supports_unicode_binds = False returns_unicode_strings = False - description_encoding = 'use_encoding' + description_encoding = "use_encoding" - name = 'default' + name = "default" # length at which to truncate # any identifier. @@ -111,7 +111,7 @@ class DefaultDialect(interfaces.Dialect): supports_sane_rowcount = True supports_sane_multi_rowcount = True colspecs = {} - default_paramstyle = 'named' + default_paramstyle = "named" supports_default_values = False supports_empty_insert = True supports_multivalues_insert = False @@ -175,19 +175,26 @@ class DefaultDialect(interfaces.Dialect): """ - def __init__(self, convert_unicode=False, - encoding='utf-8', paramstyle=None, dbapi=None, - implicit_returning=None, - supports_right_nested_joins=None, - case_sensitive=True, - supports_native_boolean=None, - empty_in_strategy='static', - label_length=None, **kwargs): - - if not getattr(self, 'ported_sqla_06', True): + def __init__( + self, + convert_unicode=False, + encoding="utf-8", + paramstyle=None, + dbapi=None, + implicit_returning=None, + supports_right_nested_joins=None, + case_sensitive=True, + supports_native_boolean=None, + empty_in_strategy="static", + label_length=None, + **kwargs + ): + + if not getattr(self, "ported_sqla_06", True): util.warn( - "The %s dialect is not yet ported to the 0.6 format" % - self.name) + "The %s dialect is not yet ported to the 0.6 format" + % self.name + ) self.convert_unicode = convert_unicode self.encoding = encoding @@ -202,7 +209,7 @@ class DefaultDialect(interfaces.Dialect): self.paramstyle = self.default_paramstyle if implicit_returning is not None: self.implicit_returning = implicit_returning - self.positional = self.paramstyle in ('qmark', 'format', 'numeric') + self.positional = self.paramstyle in ("qmark", "format", "numeric") self.identifier_preparer = self.preparer(self) self.type_compiler = self.type_compiler(self) if supports_right_nested_joins is not None: @@ -212,33 +219,33 @@ class DefaultDialect(interfaces.Dialect): self.case_sensitive = case_sensitive self.empty_in_strategy = empty_in_strategy - if empty_in_strategy == 'static': + if empty_in_strategy == "static": self._use_static_in = True - elif empty_in_strategy in ('dynamic', 'dynamic_warn'): + elif empty_in_strategy in ("dynamic", "dynamic_warn"): self._use_static_in = False - self._warn_on_empty_in = empty_in_strategy == 'dynamic_warn' + self._warn_on_empty_in = empty_in_strategy == "dynamic_warn" else: raise exc.ArgumentError( "empty_in_strategy may be 'static', " - "'dynamic', or 'dynamic_warn'") + "'dynamic', or 'dynamic_warn'" + ) if label_length and label_length > self.max_identifier_length: raise exc.ArgumentError( "Label length of %d is greater than this dialect's" - " maximum identifier length of %d" % - (label_length, self.max_identifier_length)) + " maximum identifier length of %d" + % (label_length, self.max_identifier_length) + ) self.label_length = label_length - if self.description_encoding == 'use_encoding': - self._description_decoder = \ - processors.to_unicode_processor_factory( - encoding - ) + if self.description_encoding == "use_encoding": + self._description_decoder = processors.to_unicode_processor_factory( + encoding + ) elif self.description_encoding is not None: - self._description_decoder = \ - processors.to_unicode_processor_factory( - self.description_encoding - ) + self._description_decoder = processors.to_unicode_processor_factory( + self.description_encoding + ) self._encoder = codecs.getencoder(self.encoding) self._decoder = processors.to_unicode_processor_factory(self.encoding) @@ -256,30 +263,35 @@ class DefaultDialect(interfaces.Dialect): @classmethod def get_pool_class(cls, url): - return getattr(cls, 'poolclass', pool.QueuePool) + return getattr(cls, "poolclass", pool.QueuePool) def initialize(self, connection): try: - self.server_version_info = \ - self._get_server_version_info(connection) + self.server_version_info = self._get_server_version_info( + connection + ) except NotImplementedError: self.server_version_info = None try: - self.default_schema_name = \ - self._get_default_schema_name(connection) + self.default_schema_name = self._get_default_schema_name( + connection + ) except NotImplementedError: self.default_schema_name = None try: - self.default_isolation_level = \ - self.get_isolation_level(connection.connection) + self.default_isolation_level = self.get_isolation_level( + connection.connection + ) except NotImplementedError: self.default_isolation_level = None self.returns_unicode_strings = self._check_unicode_returns(connection) - if self.description_encoding is not None and \ - self._check_unicode_description(connection): + if ( + self.description_encoding is not None + and self._check_unicode_description(connection) + ): self._description_decoder = self.description_encoding = None self.do_rollback(connection.connection) @@ -311,7 +323,8 @@ class DefaultDialect(interfaces.Dialect): def check_unicode(test): statement = cast_to( - expression.select([test]).compile(dialect=self)) + expression.select([test]).compile(dialect=self) + ) try: cursor = connection.connection.cursor() connection._cursor_execute(cursor, statement, parameters) @@ -320,8 +333,10 @@ class DefaultDialect(interfaces.Dialect): except exc.DBAPIError as de: # note that _cursor_execute() will have closed the cursor # if an exception is thrown. - util.warn("Exception attempting to " - "detect unicode returns: %r" % de) + util.warn( + "Exception attempting to " + "detect unicode returns: %r" % de + ) return False else: return isinstance(row[0], util.text_type) @@ -330,13 +345,13 @@ class DefaultDialect(interfaces.Dialect): # detect plain VARCHAR expression.cast( expression.literal_column("'test plain returns'"), - sqltypes.VARCHAR(60) + sqltypes.VARCHAR(60), ), # detect if there's an NVARCHAR type with different behavior # available expression.cast( expression.literal_column("'test unicode returns'"), - sqltypes.Unicode(60) + sqltypes.Unicode(60), ), ] @@ -364,9 +379,9 @@ class DefaultDialect(interfaces.Dialect): try: cursor.execute( cast_to( - expression.select([ - expression.literal_column("'x'").label("some_label") - ]).compile(dialect=self) + expression.select( + [expression.literal_column("'x'").label("some_label")] + ).compile(dialect=self) ) ) return isinstance(cursor.description[0][0], util.text_type) @@ -385,10 +400,12 @@ class DefaultDialect(interfaces.Dialect): return sqltypes.adapt_type(typeobj, self.colspecs) def reflecttable( - self, connection, table, include_columns, exclude_columns, **opts): + self, connection, table, include_columns, exclude_columns, **opts + ): insp = reflection.Inspector.from_engine(connection) return insp.reflecttable( - table, include_columns, exclude_columns, **opts) + table, include_columns, exclude_columns, **opts + ) def get_pk_constraint(self, conn, table_name, schema=None, **kw): """Compatibility method, adapts the result of get_primary_keys() @@ -396,16 +413,16 @@ class DefaultDialect(interfaces.Dialect): """ return { - 'constrained_columns': - self.get_primary_keys(conn, table_name, - schema=schema, **kw) + "constrained_columns": self.get_primary_keys( + conn, table_name, schema=schema, **kw + ) } def validate_identifier(self, ident): if len(ident) > self.max_identifier_length: raise exc.IdentifierError( - "Identifier '%s' exceeds maximum length of %d characters" % - (ident, self.max_identifier_length) + "Identifier '%s' exceeds maximum length of %d characters" + % (ident, self.max_identifier_length) ) def connect(self, *cargs, **cparams): @@ -417,16 +434,16 @@ class DefaultDialect(interfaces.Dialect): return [[], opts] def set_engine_execution_options(self, engine, opts): - if 'isolation_level' in opts: - isolation_level = opts['isolation_level'] + if "isolation_level" in opts: + isolation_level = opts["isolation_level"] @event.listens_for(engine, "engine_connect") def set_isolation(connection, branch): if not branch: self._set_connection_isolation(connection, isolation_level) - if 'schema_translate_map' in opts: - getter = schema._schema_getter(opts['schema_translate_map']) + if "schema_translate_map" in opts: + getter = schema._schema_getter(opts["schema_translate_map"]) engine.schema_for_object = getter @event.listens_for(engine, "engine_connect") @@ -434,11 +451,11 @@ class DefaultDialect(interfaces.Dialect): connection.schema_for_object = getter def set_connection_execution_options(self, connection, opts): - if 'isolation_level' in opts: - self._set_connection_isolation(connection, opts['isolation_level']) + if "isolation_level" in opts: + self._set_connection_isolation(connection, opts["isolation_level"]) - if 'schema_translate_map' in opts: - getter = schema._schema_getter(opts['schema_translate_map']) + if "schema_translate_map" in opts: + getter = schema._schema_getter(opts["schema_translate_map"]) connection.schema_for_object = getter def _set_connection_isolation(self, connection, level): @@ -447,10 +464,12 @@ class DefaultDialect(interfaces.Dialect): "Connection is already established with a Transaction; " "setting isolation_level may implicitly rollback or commit " "the existing transaction, or have no effect until " - "next transaction") + "next transaction" + ) self.set_isolation_level(connection.connection, level) - connection.connection._connection_record.\ - finalize_callback.append(self.reset_isolation_level) + connection.connection._connection_record.finalize_callback.append( + self.reset_isolation_level + ) def do_begin(self, dbapi_connection): pass @@ -593,8 +612,9 @@ class DefaultExecutionContext(interfaces.ExecutionContext): return self @classmethod - def _init_compiled(cls, dialect, connection, dbapi_connection, - compiled, parameters): + def _init_compiled( + cls, dialect, connection, dbapi_connection, compiled, parameters + ): """Initialize execution context for a Compiled construct.""" self = cls.__new__(cls) @@ -609,16 +629,20 @@ class DefaultExecutionContext(interfaces.ExecutionContext): assert compiled.can_execute self.execution_options = compiled.execution_options.union( - connection._execution_options) + connection._execution_options + ) self.result_column_struct = ( - compiled._result_columns, compiled._ordered_columns, - compiled._textual_ordered_columns) + compiled._result_columns, + compiled._ordered_columns, + compiled._textual_ordered_columns, + ) self.unicode_statement = util.text_type(compiled) if not dialect.supports_unicode_statements: self.statement = self.unicode_statement.encode( - self.dialect.encoding) + self.dialect.encoding + ) else: self.statement = self.unicode_statement @@ -630,9 +654,10 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if not parameters: self.compiled_parameters = [compiled.construct_params()] else: - self.compiled_parameters = \ - [compiled.construct_params(m, _group_number=grp) for - grp, m in enumerate(parameters)] + self.compiled_parameters = [ + compiled.construct_params(m, _group_number=grp) + for grp, m in enumerate(parameters) + ] self.executemany = len(parameters) > 1 @@ -642,7 +667,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.is_crud = True self._is_explicit_returning = bool(compiled.statement._returning) self._is_implicit_returning = bool( - compiled.returning and not compiled.statement._returning) + compiled.returning and not compiled.statement._returning + ) if self.compiled.insert_prefetch or self.compiled.update_prefetch: if self.executemany: @@ -680,7 +706,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): dialect._encoder(key)[0], processors[key](compiled_params[key]) if key in processors - else compiled_params[key] + else compiled_params[key], ) for key in compiled_params ) @@ -690,7 +716,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): key, processors[key](compiled_params[key]) if key in processors - else compiled_params[key] + else compiled_params[key], ) for key in compiled_params ) @@ -708,14 +734,15 @@ class DefaultExecutionContext(interfaces.ExecutionContext): """ if self.executemany: raise exc.InvalidRequestError( - "'expanding' parameters can't be used with " - "executemany()") + "'expanding' parameters can't be used with " "executemany()" + ) if self.compiled.positional and self.compiled._numeric_binds: # I'm not familiar with any DBAPI that uses 'numeric' raise NotImplementedError( "'expanding' bind parameters not supported with " - "'numeric' paramstyle at this time.") + "'numeric' paramstyle at this time." + ) self._expanded_parameters = {} @@ -729,7 +756,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): to_update_sets = {} for name in ( - self.compiled.positiontup if compiled.positional + self.compiled.positiontup + if compiled.positional else self.compiled.binds ): parameter = self.compiled.binds[name] @@ -748,12 +776,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if not values: to_update = to_update_sets[name] = [] - replacement_expressions[name] = ( - self.compiled.visit_empty_set_expr( - parameter._expanding_in_types - if parameter._expanding_in_types - else [parameter.type] - ) + replacement_expressions[ + name + ] = self.compiled.visit_empty_set_expr( + parameter._expanding_in_types + if parameter._expanding_in_types + else [parameter.type] ) elif isinstance(values[0], (tuple, list)): @@ -763,15 +791,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext): for j, value in enumerate(tuple_element, 1) ] replacement_expressions[name] = ", ".join( - "(%s)" % ", ".join( - self.compiled.bindtemplate % { - "name": - to_update[i * len(tuple_element) + j][0] + "(%s)" + % ", ".join( + self.compiled.bindtemplate + % { + "name": to_update[ + i * len(tuple_element) + j + ][0] } for j, value in enumerate(tuple_element) ) for i, tuple_element in enumerate(values) - ) else: to_update = to_update_sets[name] = [ @@ -779,20 +809,21 @@ class DefaultExecutionContext(interfaces.ExecutionContext): for i, value in enumerate(values, 1) ] replacement_expressions[name] = ", ".join( - self.compiled.bindtemplate % { - "name": key} + self.compiled.bindtemplate % {"name": key} for key, value in to_update ) compiled_params.update(to_update) processors.update( (key, processors[name]) - for key, value in to_update if name in processors + for key, value in to_update + if name in processors ) if compiled.positional: positiontup.extend(name for name, value in to_update) self._expanded_parameters[name] = [ - expand_key for expand_key, value in to_update] + expand_key for expand_key, value in to_update + ] elif compiled.positional: positiontup.append(name) @@ -800,15 +831,14 @@ class DefaultExecutionContext(interfaces.ExecutionContext): return replacement_expressions[m.group(1)] self.statement = re.sub( - r"\[EXPANDING_(\S+)\]", - process_expanding, - self.statement + r"\[EXPANDING_(\S+)\]", process_expanding, self.statement ) return positiontup @classmethod - def _init_statement(cls, dialect, connection, dbapi_connection, - statement, parameters): + def _init_statement( + cls, dialect, connection, dbapi_connection, statement, parameters + ): """Initialize execution context for a string SQL statement.""" self = cls.__new__(cls) @@ -836,13 +866,15 @@ class DefaultExecutionContext(interfaces.ExecutionContext): for d in parameters ] or [{}] else: - self.parameters = [dialect.execute_sequence_format(p) - for p in parameters] + self.parameters = [ + dialect.execute_sequence_format(p) for p in parameters + ] self.executemany = len(parameters) > 1 - if not dialect.supports_unicode_statements and \ - isinstance(statement, util.text_type): + if not dialect.supports_unicode_statements and isinstance( + statement, util.text_type + ): self.unicode_statement = statement self.statement = dialect._encoder(statement)[0] else: @@ -890,11 +922,13 @@ class DefaultExecutionContext(interfaces.ExecutionContext): @util.memoized_property def should_autocommit(self): - autocommit = self.execution_options.get('autocommit', - not self.compiled and - self.statement and - expression.PARSE_AUTOCOMMIT - or False) + autocommit = self.execution_options.get( + "autocommit", + not self.compiled + and self.statement + and expression.PARSE_AUTOCOMMIT + or False, + ) if autocommit is expression.PARSE_AUTOCOMMIT: return self.should_autocommit_text(self.unicode_statement) @@ -912,8 +946,10 @@ class DefaultExecutionContext(interfaces.ExecutionContext): """ conn = self.root_connection - if isinstance(stmt, util.text_type) and \ - not self.dialect.supports_unicode_statements: + if ( + isinstance(stmt, util.text_type) + and not self.dialect.supports_unicode_statements + ): stmt = self.dialect._encoder(stmt)[0] if self.dialect.positional: @@ -926,8 +962,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if type_ is not None: # apply type post processors to the result proc = type_._cached_result_processor( - self.dialect, - self.cursor.description[0][1] + self.dialect, self.cursor.description[0][1] ) if proc: return proc(r) @@ -945,22 +980,30 @@ class DefaultExecutionContext(interfaces.ExecutionContext): return False if self.dialect.server_side_cursors: - use_server_side = \ - self.execution_options.get('stream_results', True) and ( - (self.compiled and isinstance(self.compiled.statement, - expression.Selectable) - or - ( - (not self.compiled or - isinstance(self.compiled.statement, - expression.TextClause)) - and self.statement and SERVER_SIDE_CURSOR_RE.match( - self.statement)) - ) + use_server_side = self.execution_options.get( + "stream_results", True + ) and ( + ( + self.compiled + and isinstance( + self.compiled.statement, expression.Selectable + ) + or ( + ( + not self.compiled + or isinstance( + self.compiled.statement, expression.TextClause + ) + ) + and self.statement + and SERVER_SIDE_CURSOR_RE.match(self.statement) + ) ) + ) else: - use_server_side = \ - self.execution_options.get('stream_results', False) + use_server_side = self.execution_options.get( + "stream_results", False + ) return use_server_side @@ -1039,11 +1082,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext): return self.dialect.supports_sane_multi_rowcount def _setup_crud_result_proxy(self): - if self.isinsert and \ - not self.executemany: - if not self._is_implicit_returning and \ - not self.compiled.inline and \ - self.dialect.postfetch_lastrowid: + if self.isinsert and not self.executemany: + if ( + not self._is_implicit_returning + and not self.compiled.inline + and self.dialect.postfetch_lastrowid + ): self._setup_ins_pk_from_lastrowid() @@ -1087,12 +1131,14 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if autoinc_col is not None: # apply type post processors to the lastrowid proc = autoinc_col.type._cached_result_processor( - self.dialect, None) + self.dialect, None + ) if proc is not None: lastrowid = proc(lastrowid) self.inserted_primary_key = [ - lastrowid if c is autoinc_col else - compiled_params.get(key_getter(c), None) + lastrowid + if c is autoinc_col + else compiled_params.get(key_getter(c), None) for c in table.primary_key ] else: @@ -1108,8 +1154,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): table = self.compiled.statement.table compiled_params = self.compiled_parameters[0] self.inserted_primary_key = [ - compiled_params.get(key_getter(c), None) - for c in table.primary_key + compiled_params.get(key_getter(c), None) for c in table.primary_key ] def _setup_ins_pk_from_implicit_returning(self, row): @@ -1129,11 +1174,13 @@ class DefaultExecutionContext(interfaces.ExecutionContext): ] def lastrow_has_defaults(self): - return (self.isinsert or self.isupdate) and \ - bool(self.compiled.postfetch) + return (self.isinsert or self.isupdate) and bool( + self.compiled.postfetch + ) def set_input_sizes( - self, translate=None, include_types=None, exclude_types=None): + self, translate=None, include_types=None, exclude_types=None + ): """Given a cursor and ClauseParameters, call the appropriate style of ``setinputsizes()`` on the cursor, using DB-API types from the bind parameter's ``TypeEngine`` objects. @@ -1143,7 +1190,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): """ - if not hasattr(self.compiled, 'bind_names'): + if not hasattr(self.compiled, "bind_names"): return inputsizes = {} @@ -1153,12 +1200,18 @@ class DefaultExecutionContext(interfaces.ExecutionContext): dialect_impl_cls = type(dialect_impl) dbtype = dialect_impl.get_dbapi_type(self.dialect.dbapi) - if dbtype is not None and ( - not exclude_types or dbtype not in exclude_types and - dialect_impl_cls not in exclude_types - ) and ( - not include_types or dbtype in include_types or - dialect_impl_cls in include_types + if ( + dbtype is not None + and ( + not exclude_types + or dbtype not in exclude_types + and dialect_impl_cls not in exclude_types + ) + and ( + not include_types + or dbtype in include_types + or dialect_impl_cls in include_types + ) ): inputsizes[bindparam] = dbtype else: @@ -1177,14 +1230,16 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if dbtype is not None: if key in self._expanded_parameters: positional_inputsizes.extend( - [dbtype] * len(self._expanded_parameters[key])) + [dbtype] * len(self._expanded_parameters[key]) + ) else: positional_inputsizes.append(dbtype) try: self.cursor.setinputsizes(*positional_inputsizes) except BaseException as e: self.root_connection._handle_dbapi_exception( - e, None, None, None, self) + e, None, None, None, self + ) else: keyword_inputsizes = {} for bindparam, key in self.compiled.bind_names.items(): @@ -1199,8 +1254,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): key = self.dialect._encoder(key)[0] if key in self._expanded_parameters: keyword_inputsizes.update( - (expand_key, dbtype) for expand_key - in self._expanded_parameters[key] + (expand_key, dbtype) + for expand_key in self._expanded_parameters[key] ) else: keyword_inputsizes[key] = dbtype @@ -1208,7 +1263,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.cursor.setinputsizes(**keyword_inputsizes) except BaseException as e: self.root_connection._handle_dbapi_exception( - e, None, None, None, self) + e, None, None, None, self + ) def _exec_default(self, column, default, type_): if default.is_sequence: @@ -1290,10 +1346,13 @@ class DefaultExecutionContext(interfaces.ExecutionContext): except AttributeError: raise exc.InvalidRequestError( "get_current_parameters() can only be invoked in the " - "context of a Python side column default function") - if isolate_multiinsert_groups and \ - self.isinsert and \ - self.compiled.statement._has_multi_parameters: + "context of a Python side column default function" + ) + if ( + isolate_multiinsert_groups + and self.isinsert + and self.compiled.statement._has_multi_parameters + ): if column._is_multiparam_column: index = column.index + 1 d = {column.original.key: parameters[column.key]} @@ -1302,8 +1361,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): index = 0 keys = self.compiled.statement.parameters[0].keys() d.update( - (key, parameters["%s_m%d" % (key, index)]) - for key in keys + (key, parameters["%s_m%d" % (key, index)]) for key in keys ) return d else: @@ -1360,12 +1418,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext): def _process_executesingle_defaults(self): key_getter = self.compiled._key_getters_for_crud_column[2] - self.current_parameters = compiled_parameters = \ - self.compiled_parameters[0] + self.current_parameters = ( + compiled_parameters + ) = self.compiled_parameters[0] for c in self.compiled.insert_prefetch: - if c.default and \ - not c.default.is_sequence and c.default.is_scalar: + if c.default and not c.default.is_sequence and c.default.is_scalar: val = c.default.arg else: val = self.get_insert_default(c) diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 9c3b24e9a..e10e6e884 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -198,7 +198,8 @@ class Dialect(object): pass def reflecttable( - self, connection, table, include_columns, exclude_columns): + self, connection, table, include_columns, exclude_columns + ): """Load table description from the database. Given a :class:`.Connection` and a @@ -367,7 +368,8 @@ class Dialect(object): raise NotImplementedError() def get_unique_constraints( - self, connection, table_name, schema=None, **kw): + self, connection, table_name, schema=None, **kw + ): r"""Return information about unique constraints in `table_name`. Given a string `table_name` and an optional string `schema`, return @@ -389,8 +391,7 @@ class Dialect(object): raise NotImplementedError() - def get_check_constraints( - self, connection, table_name, schema=None, **kw): + def get_check_constraints(self, connection, table_name, schema=None, **kw): r"""Return information about check constraints in `table_name`. Given a string `table_name` and an optional string `schema`, return @@ -412,8 +413,7 @@ class Dialect(object): raise NotImplementedError() - def get_table_comment( - self, connection, table_name, schema=None, **kw): + def get_table_comment(self, connection, table_name, schema=None, **kw): r"""Return the "comment" for the table identified by `table_name`. Given a string `table_name` and an optional string `schema`, return @@ -613,8 +613,9 @@ class Dialect(object): raise NotImplementedError() - def do_rollback_twophase(self, connection, xid, is_prepared=True, - recover=False): + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): """Rollback a two phase transaction on the given connection. :param connection: a :class:`.Connection`. @@ -627,8 +628,9 @@ class Dialect(object): raise NotImplementedError() - def do_commit_twophase(self, connection, xid, is_prepared=True, - recover=False): + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): """Commit a two phase transaction on the given connection. @@ -664,8 +666,9 @@ class Dialect(object): raise NotImplementedError() - def do_execute_no_params(self, cursor, statement, parameters, - context=None): + def do_execute_no_params( + self, cursor, statement, parameters, context=None + ): """Provide an implementation of ``cursor.execute(statement)``. The parameter collection should not be sent. @@ -899,6 +902,7 @@ class CreateEnginePlugin(object): .. versionadded:: 1.1 """ + def __init__(self, url, kwargs): """Contruct a new :class:`.CreateEnginePlugin`. @@ -1129,20 +1133,24 @@ class Connectable(object): raise NotImplementedError() - @util.deprecated("0.7", - "Use the create() method on the given schema " - "object directly, i.e. :meth:`.Table.create`, " - ":meth:`.Index.create`, :meth:`.MetaData.create_all`") + @util.deprecated( + "0.7", + "Use the create() method on the given schema " + "object directly, i.e. :meth:`.Table.create`, " + ":meth:`.Index.create`, :meth:`.MetaData.create_all`", + ) def create(self, entity, **kwargs): """Emit CREATE statements for the given schema entity. """ raise NotImplementedError() - @util.deprecated("0.7", - "Use the drop() method on the given schema " - "object directly, i.e. :meth:`.Table.drop`, " - ":meth:`.Index.drop`, :meth:`.MetaData.drop_all`") + @util.deprecated( + "0.7", + "Use the drop() method on the given schema " + "object directly, i.e. :meth:`.Table.drop`, " + ":meth:`.Index.drop`, :meth:`.MetaData.drop_all`", + ) def drop(self, entity, **kwargs): """Emit DROP statements for the given schema entity. """ @@ -1160,8 +1168,7 @@ class Connectable(object): """ raise NotImplementedError() - def _run_visitor(self, visitorcallable, element, - **kwargs): + def _run_visitor(self, visitorcallable, element, **kwargs): raise NotImplementedError() def _execute_clauseelement(self, elem, multiparams=None, params=None): diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index 841bb4dfb..9b5fa2459 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -37,17 +37,17 @@ from .base import Connectable @util.decorator def cache(fn, self, con, *args, **kw): - info_cache = kw.get('info_cache', None) + info_cache = kw.get("info_cache", None) if info_cache is None: return fn(self, con, *args, **kw) key = ( fn.__name__, tuple(a for a in args if isinstance(a, util.string_types)), - tuple((k, v) for k, v in kw.items() if - isinstance(v, - util.string_types + util.int_types + (float, ) - ) - ) + tuple( + (k, v) + for k, v in kw.items() + if isinstance(v, util.string_types + util.int_types + (float,)) + ), ) ret = info_cache.get(key) if ret is None: @@ -99,7 +99,7 @@ class Inspector(object): self.bind = bind # set the engine - if hasattr(bind, 'engine'): + if hasattr(bind, "engine"): self.engine = bind.engine else: self.engine = bind @@ -130,7 +130,7 @@ class Inspector(object): See the example at :class:`.Inspector`. """ - if hasattr(bind.dialect, 'inspector'): + if hasattr(bind.dialect, "inspector"): return bind.dialect.inspector(bind) return Inspector(bind) @@ -153,9 +153,10 @@ class Inspector(object): """Return all schema names. """ - if hasattr(self.dialect, 'get_schema_names'): - return self.dialect.get_schema_names(self.bind, - info_cache=self.info_cache) + if hasattr(self.dialect, "get_schema_names"): + return self.dialect.get_schema_names( + self.bind, info_cache=self.info_cache + ) return [] def get_table_names(self, schema=None, order_by=None): @@ -196,17 +197,18 @@ class Inspector(object): """ - if hasattr(self.dialect, 'get_table_names'): + if hasattr(self.dialect, "get_table_names"): tnames = self.dialect.get_table_names( - self.bind, schema, info_cache=self.info_cache) + self.bind, schema, info_cache=self.info_cache + ) else: tnames = self.engine.table_names(schema) - if order_by == 'foreign_key': + if order_by == "foreign_key": tuples = [] for tname in tnames: for fkey in self.get_foreign_keys(tname, schema): - if tname != fkey['referred_table']: - tuples.append((fkey['referred_table'], tname)) + if tname != fkey["referred_table"]: + tuples.append((fkey["referred_table"], tname)) tnames = list(topological.sort(tuples, tnames)) return tnames @@ -234,9 +236,10 @@ class Inspector(object): with an already-given :class:`.MetaData`. """ - if hasattr(self.dialect, 'get_table_names'): + if hasattr(self.dialect, "get_table_names"): tnames = self.dialect.get_table_names( - self.bind, schema, info_cache=self.info_cache) + self.bind, schema, info_cache=self.info_cache + ) else: tnames = self.engine.table_names(schema) @@ -246,20 +249,17 @@ class Inspector(object): fknames_for_table = {} for tname in tnames: fkeys = self.get_foreign_keys(tname, schema) - fknames_for_table[tname] = set( - [fk['name'] for fk in fkeys] - ) + fknames_for_table[tname] = set([fk["name"] for fk in fkeys]) for fkey in fkeys: - if tname != fkey['referred_table']: - tuples.add((fkey['referred_table'], tname)) + if tname != fkey["referred_table"]: + tuples.add((fkey["referred_table"], tname)) try: candidate_sort = list(topological.sort(tuples, tnames)) except exc.CircularDependencyError as err: for edge in err.edges: tuples.remove(edge) remaining_fkcs.update( - (edge[1], fkc) - for fkc in fknames_for_table[edge[1]] + (edge[1], fkc) for fkc in fknames_for_table[edge[1]] ) candidate_sort = list(topological.sort(tuples, tnames)) @@ -278,7 +278,8 @@ class Inspector(object): """ return self.dialect.get_temp_table_names( - self.bind, info_cache=self.info_cache) + self.bind, info_cache=self.info_cache + ) def get_temp_view_names(self): """return a list of temporary view names for the current bind. @@ -290,7 +291,8 @@ class Inspector(object): """ return self.dialect.get_temp_view_names( - self.bind, info_cache=self.info_cache) + self.bind, info_cache=self.info_cache + ) def get_table_options(self, table_name, schema=None, **kw): """Return a dictionary of options specified when the table of the @@ -306,10 +308,10 @@ class Inspector(object): use :class:`.quoted_name`. """ - if hasattr(self.dialect, 'get_table_options'): + if hasattr(self.dialect, "get_table_options"): return self.dialect.get_table_options( - self.bind, table_name, schema, - info_cache=self.info_cache, **kw) + self.bind, table_name, schema, info_cache=self.info_cache, **kw + ) return {} def get_view_names(self, schema=None): @@ -320,8 +322,9 @@ class Inspector(object): """ - return self.dialect.get_view_names(self.bind, schema, - info_cache=self.info_cache) + return self.dialect.get_view_names( + self.bind, schema, info_cache=self.info_cache + ) def get_view_definition(self, view_name, schema=None): """Return definition for `view_name`. @@ -332,7 +335,8 @@ class Inspector(object): """ return self.dialect.get_view_definition( - self.bind, view_name, schema, info_cache=self.info_cache) + self.bind, view_name, schema, info_cache=self.info_cache + ) def get_columns(self, table_name, schema=None, **kw): """Return information about columns in `table_name`. @@ -364,18 +368,21 @@ class Inspector(object): """ - col_defs = self.dialect.get_columns(self.bind, table_name, schema, - info_cache=self.info_cache, - **kw) + col_defs = self.dialect.get_columns( + self.bind, table_name, schema, info_cache=self.info_cache, **kw + ) for col_def in col_defs: # make this easy and only return instances for coltype - coltype = col_def['type'] + coltype = col_def["type"] if not isinstance(coltype, TypeEngine): - col_def['type'] = coltype() + col_def["type"] = coltype() return col_defs - @deprecated('0.7', 'Call to deprecated method get_primary_keys.' - ' Use get_pk_constraint instead.') + @deprecated( + "0.7", + "Call to deprecated method get_primary_keys." + " Use get_pk_constraint instead.", + ) def get_primary_keys(self, table_name, schema=None, **kw): """Return information about primary keys in `table_name`. @@ -383,9 +390,9 @@ class Inspector(object): primary key information as a list of column names. """ - return self.dialect.get_pk_constraint(self.bind, table_name, schema, - info_cache=self.info_cache, - **kw)['constrained_columns'] + return self.dialect.get_pk_constraint( + self.bind, table_name, schema, info_cache=self.info_cache, **kw + )["constrained_columns"] def get_pk_constraint(self, table_name, schema=None, **kw): """Return information about primary key constraint on `table_name`. @@ -407,9 +414,9 @@ class Inspector(object): use :class:`.quoted_name`. """ - return self.dialect.get_pk_constraint(self.bind, table_name, schema, - info_cache=self.info_cache, - **kw) + return self.dialect.get_pk_constraint( + self.bind, table_name, schema, info_cache=self.info_cache, **kw + ) def get_foreign_keys(self, table_name, schema=None, **kw): """Return information about foreign_keys in `table_name`. @@ -442,9 +449,9 @@ class Inspector(object): """ - return self.dialect.get_foreign_keys(self.bind, table_name, schema, - info_cache=self.info_cache, - **kw) + return self.dialect.get_foreign_keys( + self.bind, table_name, schema, info_cache=self.info_cache, **kw + ) def get_indexes(self, table_name, schema=None, **kw): """Return information about indexes in `table_name`. @@ -476,9 +483,9 @@ class Inspector(object): """ - return self.dialect.get_indexes(self.bind, table_name, - schema, - info_cache=self.info_cache, **kw) + return self.dialect.get_indexes( + self.bind, table_name, schema, info_cache=self.info_cache, **kw + ) def get_unique_constraints(self, table_name, schema=None, **kw): """Return information about unique constraints in `table_name`. @@ -504,7 +511,8 @@ class Inspector(object): """ return self.dialect.get_unique_constraints( - self.bind, table_name, schema, info_cache=self.info_cache, **kw) + self.bind, table_name, schema, info_cache=self.info_cache, **kw + ) def get_table_comment(self, table_name, schema=None, **kw): """Return information about the table comment for ``table_name``. @@ -523,8 +531,8 @@ class Inspector(object): """ return self.dialect.get_table_comment( - self.bind, table_name, schema, info_cache=self.info_cache, - **kw) + self.bind, table_name, schema, info_cache=self.info_cache, **kw + ) def get_check_constraints(self, table_name, schema=None, **kw): """Return information about check constraints in `table_name`. @@ -550,10 +558,12 @@ class Inspector(object): """ return self.dialect.get_check_constraints( - self.bind, table_name, schema, info_cache=self.info_cache, **kw) + self.bind, table_name, schema, info_cache=self.info_cache, **kw + ) - def reflecttable(self, table, include_columns, exclude_columns=(), - _extend_on=None): + def reflecttable( + self, table, include_columns, exclude_columns=(), _extend_on=None + ): """Given a Table object, load its internal constructs based on introspection. @@ -599,7 +609,8 @@ class Inspector(object): # reflect table options, like mysql_engine tbl_opts = self.get_table_options( - table_name, schema, **table.dialect_kwargs) + table_name, schema, **table.dialect_kwargs + ) if tbl_opts: # add additional kwargs to the Table if the dialect # returned them @@ -615,185 +626,251 @@ class Inspector(object): cols_by_orig_name = {} for col_d in self.get_columns( - table_name, schema, **table.dialect_kwargs): + table_name, schema, **table.dialect_kwargs + ): found_table = True self._reflect_column( - table, col_d, include_columns, - exclude_columns, cols_by_orig_name) + table, + col_d, + include_columns, + exclude_columns, + cols_by_orig_name, + ) if not found_table: raise exc.NoSuchTableError(table.name) self._reflect_pk( - table_name, schema, table, cols_by_orig_name, exclude_columns) + table_name, schema, table, cols_by_orig_name, exclude_columns + ) self._reflect_fk( - table_name, schema, table, cols_by_orig_name, - exclude_columns, _extend_on, reflection_options) + table_name, + schema, + table, + cols_by_orig_name, + exclude_columns, + _extend_on, + reflection_options, + ) self._reflect_indexes( - table_name, schema, table, cols_by_orig_name, - include_columns, exclude_columns, reflection_options) + table_name, + schema, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + reflection_options, + ) self._reflect_unique_constraints( - table_name, schema, table, cols_by_orig_name, - include_columns, exclude_columns, reflection_options) + table_name, + schema, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + reflection_options, + ) self._reflect_check_constraints( - table_name, schema, table, cols_by_orig_name, - include_columns, exclude_columns, reflection_options) + table_name, + schema, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + reflection_options, + ) self._reflect_table_comment( table_name, schema, table, reflection_options ) def _reflect_column( - self, table, col_d, include_columns, - exclude_columns, cols_by_orig_name): + self, table, col_d, include_columns, exclude_columns, cols_by_orig_name + ): - orig_name = col_d['name'] + orig_name = col_d["name"] table.dispatch.column_reflect(self, table, col_d) # fetch name again as column_reflect is allowed to # change it - name = col_d['name'] - if (include_columns and name not in include_columns) \ - or (exclude_columns and name in exclude_columns): + name = col_d["name"] + if (include_columns and name not in include_columns) or ( + exclude_columns and name in exclude_columns + ): return - coltype = col_d['type'] + coltype = col_d["type"] col_kw = dict( (k, col_d[k]) for k in [ - 'nullable', 'autoincrement', 'quote', 'info', 'key', - 'comment'] + "nullable", + "autoincrement", + "quote", + "info", + "key", + "comment", + ] if k in col_d ) - if 'dialect_options' in col_d: - col_kw.update(col_d['dialect_options']) + if "dialect_options" in col_d: + col_kw.update(col_d["dialect_options"]) colargs = [] - if col_d.get('default') is not None: - default = col_d['default'] + if col_d.get("default") is not None: + default = col_d["default"] if isinstance(default, sql.elements.TextClause): default = sa_schema.DefaultClause(default, _reflected=True) elif not isinstance(default, sa_schema.FetchedValue): default = sa_schema.DefaultClause( - sql.text(col_d['default']), _reflected=True) + sql.text(col_d["default"]), _reflected=True + ) colargs.append(default) - if 'sequence' in col_d: + if "sequence" in col_d: self._reflect_col_sequence(col_d, colargs) - cols_by_orig_name[orig_name] = col = \ - sa_schema.Column(name, coltype, *colargs, **col_kw) + cols_by_orig_name[orig_name] = col = sa_schema.Column( + name, coltype, *colargs, **col_kw + ) if col.key in table.primary_key: col.primary_key = True table.append_column(col) def _reflect_col_sequence(self, col_d, colargs): - if 'sequence' in col_d: + if "sequence" in col_d: # TODO: mssql and sybase are using this. - seq = col_d['sequence'] - sequence = sa_schema.Sequence(seq['name'], 1, 1) - if 'start' in seq: - sequence.start = seq['start'] - if 'increment' in seq: - sequence.increment = seq['increment'] + seq = col_d["sequence"] + sequence = sa_schema.Sequence(seq["name"], 1, 1) + if "start" in seq: + sequence.start = seq["start"] + if "increment" in seq: + sequence.increment = seq["increment"] colargs.append(sequence) def _reflect_pk( - self, table_name, schema, table, - cols_by_orig_name, exclude_columns): + self, table_name, schema, table, cols_by_orig_name, exclude_columns + ): pk_cons = self.get_pk_constraint( - table_name, schema, **table.dialect_kwargs) + table_name, schema, **table.dialect_kwargs + ) if pk_cons: pk_cols = [ cols_by_orig_name[pk] - for pk in pk_cons['constrained_columns'] + for pk in pk_cons["constrained_columns"] if pk in cols_by_orig_name and pk not in exclude_columns ] # update pk constraint name - table.primary_key.name = pk_cons.get('name') + table.primary_key.name = pk_cons.get("name") # tell the PKConstraint to re-initialize # its column collection table.primary_key._reload(pk_cols) def _reflect_fk( - self, table_name, schema, table, cols_by_orig_name, - exclude_columns, _extend_on, reflection_options): + self, + table_name, + schema, + table, + cols_by_orig_name, + exclude_columns, + _extend_on, + reflection_options, + ): fkeys = self.get_foreign_keys( - table_name, schema, **table.dialect_kwargs) + table_name, schema, **table.dialect_kwargs + ) for fkey_d in fkeys: - conname = fkey_d['name'] + conname = fkey_d["name"] # look for columns by orig name in cols_by_orig_name, # but support columns that are in-Python only as fallback constrained_columns = [ - cols_by_orig_name[c].key - if c in cols_by_orig_name else c - for c in fkey_d['constrained_columns'] + cols_by_orig_name[c].key if c in cols_by_orig_name else c + for c in fkey_d["constrained_columns"] ] if exclude_columns and set(constrained_columns).intersection( - exclude_columns): + exclude_columns + ): continue - referred_schema = fkey_d['referred_schema'] - referred_table = fkey_d['referred_table'] - referred_columns = fkey_d['referred_columns'] + referred_schema = fkey_d["referred_schema"] + referred_table = fkey_d["referred_table"] + referred_columns = fkey_d["referred_columns"] refspec = [] if referred_schema is not None: - sa_schema.Table(referred_table, table.metadata, - autoload=True, schema=referred_schema, - autoload_with=self.bind, - _extend_on=_extend_on, - **reflection_options - ) + sa_schema.Table( + referred_table, + table.metadata, + autoload=True, + schema=referred_schema, + autoload_with=self.bind, + _extend_on=_extend_on, + **reflection_options + ) for column in referred_columns: - refspec.append(".".join( - [referred_schema, referred_table, column])) + refspec.append( + ".".join([referred_schema, referred_table, column]) + ) else: - sa_schema.Table(referred_table, table.metadata, autoload=True, - autoload_with=self.bind, - schema=sa_schema.BLANK_SCHEMA, - _extend_on=_extend_on, - **reflection_options - ) + sa_schema.Table( + referred_table, + table.metadata, + autoload=True, + autoload_with=self.bind, + schema=sa_schema.BLANK_SCHEMA, + _extend_on=_extend_on, + **reflection_options + ) for column in referred_columns: refspec.append(".".join([referred_table, column])) - if 'options' in fkey_d: - options = fkey_d['options'] + if "options" in fkey_d: + options = fkey_d["options"] else: options = {} table.append_constraint( - sa_schema.ForeignKeyConstraint(constrained_columns, refspec, - conname, link_to_name=True, - **options)) + sa_schema.ForeignKeyConstraint( + constrained_columns, + refspec, + conname, + link_to_name=True, + **options + ) + ) def _reflect_indexes( - self, table_name, schema, table, cols_by_orig_name, - include_columns, exclude_columns, reflection_options): + self, + table_name, + schema, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + reflection_options, + ): # Indexes indexes = self.get_indexes(table_name, schema) for index_d in indexes: - name = index_d['name'] - columns = index_d['column_names'] - unique = index_d['unique'] - flavor = index_d.get('type', 'index') - dialect_options = index_d.get('dialect_options', {}) - - duplicates = index_d.get('duplicates_constraint') - if include_columns and \ - not set(columns).issubset(include_columns): + name = index_d["name"] + columns = index_d["column_names"] + unique = index_d["unique"] + flavor = index_d.get("type", "index") + dialect_options = index_d.get("dialect_options", {}) + + duplicates = index_d.get("duplicates_constraint") + if include_columns and not set(columns).issubset(include_columns): util.warn( - "Omitting %s key for (%s), key covers omitted columns." % - (flavor, ', '.join(columns))) + "Omitting %s key for (%s), key covers omitted columns." + % (flavor, ", ".join(columns)) + ) continue if duplicates: continue @@ -802,26 +879,36 @@ class Inspector(object): idx_cols = [] for c in columns: try: - idx_col = cols_by_orig_name[c] \ - if c in cols_by_orig_name else table.c[c] + idx_col = ( + cols_by_orig_name[c] + if c in cols_by_orig_name + else table.c[c] + ) except KeyError: util.warn( "%s key '%s' was not located in " - "columns for table '%s'" % ( - flavor, c, table_name - )) + "columns for table '%s'" % (flavor, c, table_name) + ) else: idx_cols.append(idx_col) sa_schema.Index( - name, *idx_cols, + name, + *idx_cols, _table=table, - **dict(list(dialect_options.items()) + [('unique', unique)]) + **dict(list(dialect_options.items()) + [("unique", unique)]) ) def _reflect_unique_constraints( - self, table_name, schema, table, cols_by_orig_name, - include_columns, exclude_columns, reflection_options): + self, + table_name, + schema, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + reflection_options, + ): # Unique Constraints try: @@ -831,15 +918,14 @@ class Inspector(object): return for const_d in constraints: - conname = const_d['name'] - columns = const_d['column_names'] - duplicates = const_d.get('duplicates_index') - if include_columns and \ - not set(columns).issubset(include_columns): + conname = const_d["name"] + columns = const_d["column_names"] + duplicates = const_d.get("duplicates_index") + if include_columns and not set(columns).issubset(include_columns): util.warn( "Omitting unique constraint key for (%s), " - "key covers omitted columns." % - ', '.join(columns)) + "key covers omitted columns." % ", ".join(columns) + ) continue if duplicates: continue @@ -848,20 +934,32 @@ class Inspector(object): constrained_cols = [] for c in columns: try: - constrained_col = cols_by_orig_name[c] \ - if c in cols_by_orig_name else table.c[c] + constrained_col = ( + cols_by_orig_name[c] + if c in cols_by_orig_name + else table.c[c] + ) except KeyError: util.warn( "unique constraint key '%s' was not located in " - "columns for table '%s'" % (c, table_name)) + "columns for table '%s'" % (c, table_name) + ) else: constrained_cols.append(constrained_col) table.append_constraint( - sa_schema.UniqueConstraint(*constrained_cols, name=conname)) + sa_schema.UniqueConstraint(*constrained_cols, name=conname) + ) def _reflect_check_constraints( - self, table_name, schema, table, cols_by_orig_name, - include_columns, exclude_columns, reflection_options): + self, + table_name, + schema, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + reflection_options, + ): try: constraints = self.get_check_constraints(table_name, schema) except NotImplementedError: @@ -869,14 +967,14 @@ class Inspector(object): return for const_d in constraints: - table.append_constraint( - sa_schema.CheckConstraint(**const_d)) + table.append_constraint(sa_schema.CheckConstraint(**const_d)) def _reflect_table_comment( - self, table_name, schema, table, reflection_options): + self, table_name, schema, table, reflection_options + ): try: comment_dict = self.get_table_comment(table_name, schema) except NotImplementedError: return else: - table.comment = comment_dict.get('text', None) + table.comment = comment_dict.get("text", None) diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index d4c862375..5ad0d2909 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -27,20 +27,25 @@ try: # the extension is present. def rowproxy_reconstructor(cls, state): return safe_rowproxy_reconstructor(cls, state) + + except ImportError: + def rowproxy_reconstructor(cls, state): obj = cls.__new__(cls) obj.__setstate__(state) return obj + try: from sqlalchemy.cresultproxy import BaseRowProxy + _baserowproxy_usecext = True except ImportError: _baserowproxy_usecext = False class BaseRowProxy(object): - __slots__ = ('_parent', '_row', '_processors', '_keymap') + __slots__ = ("_parent", "_row", "_processors", "_keymap") def __init__(self, parent, row, processors, keymap): """RowProxy objects are constructed by ResultProxy objects.""" @@ -51,8 +56,10 @@ except ImportError: self._keymap = keymap def __reduce__(self): - return (rowproxy_reconstructor, - (self.__class__, self.__getstate__())) + return ( + rowproxy_reconstructor, + (self.__class__, self.__getstate__()), + ) def values(self): """Return the values represented by this RowProxy as a list.""" @@ -76,8 +83,9 @@ except ImportError: except TypeError: if isinstance(key, slice): l = [] - for processor, value in zip(self._processors[key], - self._row[key]): + for processor, value in zip( + self._processors[key], self._row[key] + ): if processor is None: l.append(value) else: @@ -88,7 +96,8 @@ except ImportError: if index is None: raise exc.InvalidRequestError( "Ambiguous column name '%s' in " - "result set column descriptions" % obj) + "result set column descriptions" % obj + ) if processor is not None: return processor(self._row[index]) else: @@ -110,29 +119,29 @@ class RowProxy(BaseRowProxy): mapped to the original Columns that produced this result set (for results that correspond to constructed SQL expressions). """ + __slots__ = () def __contains__(self, key): return self._parent._has_key(key) def __getstate__(self): - return { - '_parent': self._parent, - '_row': tuple(self) - } + return {"_parent": self._parent, "_row": tuple(self)} def __setstate__(self, state): - self._parent = parent = state['_parent'] - self._row = state['_row'] + self._parent = parent = state["_parent"] + self._row = state["_row"] self._processors = parent._processors self._keymap = parent._keymap __hash__ = None def _op(self, other, op): - return op(tuple(self), tuple(other)) \ - if isinstance(other, RowProxy) \ + return ( + op(tuple(self), tuple(other)) + if isinstance(other, RowProxy) else op(tuple(self), other) + ) def __lt__(self, other): return self._op(other, operator.lt) @@ -176,6 +185,7 @@ class RowProxy(BaseRowProxy): def itervalues(self): return iter(self) + try: # Register RowProxy with Sequence, # so sequence protocol is implemented @@ -189,8 +199,13 @@ class ResultMetaData(object): context.""" __slots__ = ( - '_keymap', 'case_sensitive', 'matched_on_name', - '_processors', 'keys', '_orig_processors') + "_keymap", + "case_sensitive", + "matched_on_name", + "_processors", + "keys", + "_orig_processors", + ) def __init__(self, parent, cursor_description): context = parent.context @@ -200,18 +215,25 @@ class ResultMetaData(object): self._orig_processors = None if context.result_column_struct: - result_columns, cols_are_ordered, textual_ordered = \ + result_columns, cols_are_ordered, textual_ordered = ( context.result_column_struct + ) num_ctx_cols = len(result_columns) else: - result_columns = cols_are_ordered = \ - num_ctx_cols = textual_ordered = False + result_columns = ( + cols_are_ordered + ) = num_ctx_cols = textual_ordered = False # merge cursor.description with the column info # present in the compiled structure, if any raw = self._merge_cursor_description( - context, cursor_description, result_columns, - num_ctx_cols, cols_are_ordered, textual_ordered) + context, + cursor_description, + result_columns, + num_ctx_cols, + cols_are_ordered, + textual_ordered, + ) self._keymap = {} if not _baserowproxy_usecext: @@ -223,23 +245,20 @@ class ResultMetaData(object): len_raw = len(raw) - self._keymap.update([ - (elem[0], (elem[3], elem[4], elem[0])) - for elem in raw - ] + [ - (elem[0] - len_raw, (elem[3], elem[4], elem[0])) - for elem in raw - ]) + self._keymap.update( + [(elem[0], (elem[3], elem[4], elem[0])) for elem in raw] + + [ + (elem[0] - len_raw, (elem[3], elem[4], elem[0])) + for elem in raw + ] + ) # processors in key order for certain per-row # views like __iter__ and slices self._processors = [elem[3] for elem in raw] # keymap by primary string... - by_key = dict([ - (elem[2], (elem[3], elem[4], elem[0])) - for elem in raw - ]) + by_key = dict([(elem[2], (elem[3], elem[4], elem[0])) for elem in raw]) # for compiled SQL constructs, copy additional lookup keys into # the key lookup map, such as Column objects, labels, @@ -264,29 +283,38 @@ class ResultMetaData(object): # copy secondary elements from compiled columns # into self._keymap, write in the potentially "ambiguous" # element - self._keymap.update([ - (obj_elem, by_key[elem[2]]) - for elem in raw if elem[4] - for obj_elem in elem[4] - ]) + self._keymap.update( + [ + (obj_elem, by_key[elem[2]]) + for elem in raw + if elem[4] + for obj_elem in elem[4] + ] + ) # if we did a pure positional match, then reset the # original "expression element" back to the "unambiguous" # entry. This is a new behavior in 1.1 which impacts # TextAsFrom but also straight compiled SQL constructs. if not self.matched_on_name: - self._keymap.update([ - (elem[4][0], (elem[3], elem[4], elem[0])) - for elem in raw if elem[4] - ]) + self._keymap.update( + [ + (elem[4][0], (elem[3], elem[4], elem[0])) + for elem in raw + if elem[4] + ] + ) else: # no dupes - copy secondary elements from compiled # columns into self._keymap - self._keymap.update([ - (obj_elem, (elem[3], elem[4], elem[0])) - for elem in raw if elem[4] - for obj_elem in elem[4] - ]) + self._keymap.update( + [ + (obj_elem, (elem[3], elem[4], elem[0])) + for elem in raw + if elem[4] + for obj_elem in elem[4] + ] + ) # update keymap with primary string names taking # precedence @@ -294,14 +322,19 @@ class ResultMetaData(object): # update keymap with "translated" names (sqlite-only thing) if not num_ctx_cols and context._translate_colname: - self._keymap.update([ - (elem[5], self._keymap[elem[2]]) - for elem in raw if elem[5] - ]) + self._keymap.update( + [(elem[5], self._keymap[elem[2]]) for elem in raw if elem[5]] + ) def _merge_cursor_description( - self, context, cursor_description, result_columns, - num_ctx_cols, cols_are_ordered, textual_ordered): + self, + context, + cursor_description, + result_columns, + num_ctx_cols, + cols_are_ordered, + textual_ordered, + ): """Merge a cursor.description with compiled result column information. There are at least four separate strategies used here, selected @@ -357,10 +390,12 @@ class ResultMetaData(object): case_sensitive = context.dialect.case_sensitive - if num_ctx_cols and \ - cols_are_ordered and \ - not textual_ordered and \ - num_ctx_cols == len(cursor_description): + if ( + num_ctx_cols + and cols_are_ordered + and not textual_ordered + and num_ctx_cols == len(cursor_description) + ): self.keys = [elem[0] for elem in result_columns] # pure positional 1-1 case; doesn't need to read # the names from cursor.description @@ -373,9 +408,9 @@ class ResultMetaData(object): type_, key, cursor_description[idx][1] ), obj, - None - ) for idx, (key, name, obj, type_) - in enumerate(result_columns) + None, + ) + for idx, (key, name, obj, type_) in enumerate(result_columns) ] else: # name-based or text-positional cases, where we need @@ -383,26 +418,32 @@ class ResultMetaData(object): if textual_ordered: # textual positional case raw_iterator = self._merge_textual_cols_by_position( - context, cursor_description, result_columns) + context, cursor_description, result_columns + ) elif num_ctx_cols: # compiled SQL with a mismatch of description cols # vs. compiled cols, or textual w/ unordered columns raw_iterator = self._merge_cols_by_name( - context, cursor_description, result_columns) + context, cursor_description, result_columns + ) else: # no compiled SQL, just a raw string raw_iterator = self._merge_cols_by_none( - context, cursor_description) + context, cursor_description + ) return [ ( - idx, colname, colname, + idx, + colname, + colname, context.get_result_processor( - mapped_type, colname, coltype), - obj, untranslated) - - for idx, colname, mapped_type, coltype, obj, untranslated - in raw_iterator + mapped_type, colname, coltype + ), + obj, + untranslated, + ) + for idx, colname, mapped_type, coltype, obj, untranslated in raw_iterator ] def _colnames_from_description(self, context, cursor_description): @@ -416,10 +457,14 @@ class ResultMetaData(object): dialect = context.dialect case_sensitive = dialect.case_sensitive translate_colname = context._translate_colname - description_decoder = dialect._description_decoder \ - if dialect.description_encoding else None - normalize_name = dialect.normalize_name \ - if dialect.requires_name_normalize else None + description_decoder = ( + dialect._description_decoder + if dialect.description_encoding + else None + ) + normalize_name = ( + dialect.normalize_name if dialect.requires_name_normalize else None + ) untranslated = None self.keys = [] @@ -444,20 +489,25 @@ class ResultMetaData(object): yield idx, colname, untranslated, coltype def _merge_textual_cols_by_position( - self, context, cursor_description, result_columns): + self, context, cursor_description, result_columns + ): dialect = context.dialect num_ctx_cols = len(result_columns) if result_columns else None if num_ctx_cols > len(cursor_description): util.warn( "Number of columns in textual SQL (%d) is " - "smaller than number of columns requested (%d)" % ( - num_ctx_cols, len(cursor_description) - )) + "smaller than number of columns requested (%d)" + % (num_ctx_cols, len(cursor_description)) + ) seen = set() - for idx, colname, untranslated, coltype in \ - self._colnames_from_description(context, cursor_description): + for ( + idx, + colname, + untranslated, + coltype, + ) in self._colnames_from_description(context, cursor_description): if idx < num_ctx_cols: ctx_rec = result_columns[idx] obj = ctx_rec[2] @@ -465,7 +515,8 @@ class ResultMetaData(object): if obj[0] in seen: raise exc.InvalidRequestError( "Duplicate column expression requested " - "in textual SQL: %r" % obj[0]) + "in textual SQL: %r" % obj[0] + ) seen.add(obj[0]) else: mapped_type = sqltypes.NULLTYPE @@ -479,8 +530,12 @@ class ResultMetaData(object): result_map = self._create_result_map(result_columns, case_sensitive) self.matched_on_name = True - for idx, colname, untranslated, coltype in \ - self._colnames_from_description(context, cursor_description): + for ( + idx, + colname, + untranslated, + coltype, + ) in self._colnames_from_description(context, cursor_description): try: ctx_rec = result_map[colname] except KeyError: @@ -493,8 +548,12 @@ class ResultMetaData(object): def _merge_cols_by_none(self, context, cursor_description): dialect = context.dialect - for idx, colname, untranslated, coltype in \ - self._colnames_from_description(context, cursor_description): + for ( + idx, + colname, + untranslated, + coltype, + ) in self._colnames_from_description(context, cursor_description): yield idx, colname, sqltypes.NULLTYPE, coltype, None, untranslated @classmethod @@ -525,27 +584,28 @@ class ResultMetaData(object): # or colummn('name') constructs to ColumnElements, or after a # pickle/unpickle roundtrip elif isinstance(key, expression.ColumnElement): - if key._label and ( - key._label - if self.case_sensitive - else key._label.lower()) in map: - result = map[key._label - if self.case_sensitive - else key._label.lower()] - elif hasattr(key, 'name') and ( - key.name - if self.case_sensitive - else key.name.lower()) in map: + if ( + key._label + and (key._label if self.case_sensitive else key._label.lower()) + in map + ): + result = map[ + key._label if self.case_sensitive else key._label.lower() + ] + elif ( + hasattr(key, "name") + and (key.name if self.case_sensitive else key.name.lower()) + in map + ): # match is only on name. - result = map[key.name - if self.case_sensitive - else key.name.lower()] + result = map[ + key.name if self.case_sensitive else key.name.lower() + ] # search extra hard to make sure this # isn't a column/label name overlap. # this check isn't currently available if the row # was unpickled. - if result is not None and \ - result[1] is not None: + if result is not None and result[1] is not None: for obj in result[1]: if key._compare_name_for_result(obj): break @@ -554,8 +614,9 @@ class ResultMetaData(object): if result is None: if raiseerr: raise exc.NoSuchColumnError( - "Could not locate column in row for column '%s'" % - expression._string_or_unprintable(key)) + "Could not locate column in row for column '%s'" + % expression._string_or_unprintable(key) + ) else: return None else: @@ -580,34 +641,35 @@ class ResultMetaData(object): if index is None: raise exc.InvalidRequestError( "Ambiguous column name '%s' in " - "result set column descriptions" % obj) + "result set column descriptions" % obj + ) return operator.itemgetter(index) def __getstate__(self): return { - '_pickled_keymap': dict( + "_pickled_keymap": dict( (key, index) for key, (processor, obj, index) in self._keymap.items() if isinstance(key, util.string_types + util.int_types) ), - 'keys': self.keys, + "keys": self.keys, "case_sensitive": self.case_sensitive, - "matched_on_name": self.matched_on_name + "matched_on_name": self.matched_on_name, } def __setstate__(self, state): # the row has been processed at pickling time so we don't need any # processor anymore - self._processors = [None for _ in range(len(state['keys']))] + self._processors = [None for _ in range(len(state["keys"]))] self._keymap = keymap = {} - for key, index in state['_pickled_keymap'].items(): + for key, index in state["_pickled_keymap"].items(): # not preserving "obj" here, unfortunately our # proxy comparison fails with the unpickle keymap[key] = (None, None, index) - self.keys = state['keys'] - self.case_sensitive = state['case_sensitive'] - self.matched_on_name = state['matched_on_name'] + self.keys = state["keys"] + self.case_sensitive = state["case_sensitive"] + self.matched_on_name = state["matched_on_name"] class ResultProxy(object): @@ -643,8 +705,9 @@ class ResultProxy(object): self.dialect = context.dialect self.cursor = self._saved_cursor = context.cursor self.connection = context.root_connection - self._echo = self.connection._echo and \ - context.engine._should_log_debug() + self._echo = ( + self.connection._echo and context.engine._should_log_debug() + ) self._init_metadata() def _getter(self, key, raiseerr=True): @@ -666,18 +729,22 @@ class ResultProxy(object): def _init_metadata(self): cursor_description = self._cursor_description() if cursor_description is not None: - if self.context.compiled and \ - 'compiled_cache' in self.context.execution_options: + if ( + self.context.compiled + and "compiled_cache" in self.context.execution_options + ): if self.context.compiled._cached_metadata: self._metadata = self.context.compiled._cached_metadata else: - self._metadata = self.context.compiled._cached_metadata = \ - ResultMetaData(self, cursor_description) + self._metadata = ( + self.context.compiled._cached_metadata + ) = ResultMetaData(self, cursor_description) else: self._metadata = ResultMetaData(self, cursor_description) if self._echo: self.context.engine.logger.debug( - "Col %r", tuple(x[0] for x in cursor_description)) + "Col %r", tuple(x[0] for x in cursor_description) + ) def keys(self): """Return the current set of string keys for rows.""" @@ -731,7 +798,8 @@ class ResultProxy(object): return self.context.rowcount except BaseException as e: self.connection._handle_dbapi_exception( - e, None, None, self.cursor, self.context) + e, None, None, self.cursor, self.context + ) @property def lastrowid(self): @@ -753,8 +821,8 @@ class ResultProxy(object): return self._saved_cursor.lastrowid except BaseException as e: self.connection._handle_dbapi_exception( - e, None, None, - self._saved_cursor, self.context) + e, None, None, self._saved_cursor, self.context + ) @property def returns_rows(self): @@ -913,17 +981,18 @@ class ResultProxy(object): if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled " - "expression construct.") + "Statement is not a compiled " "expression construct." + ) elif not self.context.isinsert: raise exc.InvalidRequestError( - "Statement is not an insert() " - "expression construct.") + "Statement is not an insert() " "expression construct." + ) elif self.context._is_explicit_returning: raise exc.InvalidRequestError( "Can't call inserted_primary_key " "when returning() " - "is used.") + "is used." + ) return self.context.inserted_primary_key @@ -938,12 +1007,12 @@ class ResultProxy(object): """ if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled " - "expression construct.") + "Statement is not a compiled " "expression construct." + ) elif not self.context.isupdate: raise exc.InvalidRequestError( - "Statement is not an update() " - "expression construct.") + "Statement is not an update() " "expression construct." + ) elif self.context.executemany: return self.context.compiled_parameters else: @@ -960,12 +1029,12 @@ class ResultProxy(object): """ if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled " - "expression construct.") + "Statement is not a compiled " "expression construct." + ) elif not self.context.isinsert: raise exc.InvalidRequestError( - "Statement is not an insert() " - "expression construct.") + "Statement is not an insert() " "expression construct." + ) elif self.context.executemany: return self.context.compiled_parameters else: @@ -1013,12 +1082,13 @@ class ResultProxy(object): if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled " - "expression construct.") + "Statement is not a compiled " "expression construct." + ) elif not self.context.isinsert and not self.context.isupdate: raise exc.InvalidRequestError( "Statement is not an insert() or update() " - "expression construct.") + "expression construct." + ) return self.context.postfetch_cols def prefetch_cols(self): @@ -1035,12 +1105,13 @@ class ResultProxy(object): if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled " - "expression construct.") + "Statement is not a compiled " "expression construct." + ) elif not self.context.isinsert and not self.context.isupdate: raise exc.InvalidRequestError( "Statement is not an insert() or update() " - "expression construct.") + "expression construct." + ) return self.context.prefetch_cols def supports_sane_rowcount(self): @@ -1086,7 +1157,7 @@ class ResultProxy(object): if self._metadata is None: raise exc.ResourceClosedError( "This result object does not return rows. " - "It has been closed automatically.", + "It has been closed automatically." ) elif self.closed: raise exc.ResourceClosedError("This result object is closed.") @@ -1106,8 +1177,9 @@ class ResultProxy(object): l.append(process_row(metadata, row, processors, keymap)) return l else: - return [process_row(metadata, row, processors, keymap) - for row in rows] + return [ + process_row(metadata, row, processors, keymap) for row in rows + ] def fetchall(self): """Fetch all rows, just like DB-API ``cursor.fetchall()``. @@ -1132,8 +1204,8 @@ class ResultProxy(object): return l except BaseException as e: self.connection._handle_dbapi_exception( - e, None, None, - self.cursor, self.context) + e, None, None, self.cursor, self.context + ) def fetchmany(self, size=None): """Fetch many rows, just like DB-API @@ -1161,8 +1233,8 @@ class ResultProxy(object): return l except BaseException as e: self.connection._handle_dbapi_exception( - e, None, None, - self.cursor, self.context) + e, None, None, self.cursor, self.context + ) def fetchone(self): """Fetch one row, just like DB-API ``cursor.fetchone()``. @@ -1190,8 +1262,8 @@ class ResultProxy(object): return None except BaseException as e: self.connection._handle_dbapi_exception( - e, None, None, - self.cursor, self.context) + e, None, None, self.cursor, self.context + ) def first(self): """Fetch the first row and then close the result set unconditionally. @@ -1209,8 +1281,8 @@ class ResultProxy(object): row = self._fetchone_impl() except BaseException as e: self.connection._handle_dbapi_exception( - e, None, None, - self.cursor, self.context) + e, None, None, self.cursor, self.context + ) try: if row is not None: @@ -1268,7 +1340,8 @@ class BufferedRowResultProxy(ResultProxy): def _init_metadata(self): self._max_row_buffer = self.context.execution_options.get( - 'max_row_buffer', None) + "max_row_buffer", None + ) self.__buffer_rows() super(BufferedRowResultProxy, self)._init_metadata() @@ -1284,13 +1357,13 @@ class BufferedRowResultProxy(ResultProxy): 50: 100, 100: 250, 250: 500, - 500: 1000 + 500: 1000, } def __buffer_rows(self): if self.cursor is None: return - size = getattr(self, '_bufsize', 1) + size = getattr(self, "_bufsize", 1) self.__rowbuffer = collections.deque(self.cursor.fetchmany(size)) self._bufsize = self.size_growth.get(size, size) if self._max_row_buffer is not None: @@ -1385,8 +1458,9 @@ class BufferedColumnRow(RowProxy): row[index] = processor(row[index]) index += 1 row = tuple(row) - super(BufferedColumnRow, self).__init__(parent, row, - processors, keymap) + super(BufferedColumnRow, self).__init__( + parent, row, processors, keymap + ) class BufferedColumnResultProxy(ResultProxy): diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index d4f5185de..4aecb9537 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -51,18 +51,20 @@ class DefaultEngineStrategy(EngineStrategy): plugins = u._instantiate_plugins(kwargs) - u.query.pop('plugin', None) - kwargs.pop('plugins', None) + u.query.pop("plugin", None) + kwargs.pop("plugins", None) entrypoint = u._get_entrypoint() dialect_cls = entrypoint.get_dialect_cls(u) - if kwargs.pop('_coerce_config', False): + if kwargs.pop("_coerce_config", False): + def pop_kwarg(key, default=None): value = kwargs.pop(key, default) if key in dialect_cls.engine_config_types: value = dialect_cls.engine_config_types[key](value) return value + else: pop_kwarg = kwargs.pop @@ -72,7 +74,7 @@ class DefaultEngineStrategy(EngineStrategy): if k in kwargs: dialect_args[k] = pop_kwarg(k) - dbapi = kwargs.pop('module', None) + dbapi = kwargs.pop("module", None) if dbapi is None: dbapi_args = {} for k in util.get_func_kwargs(dialect_cls.dbapi): @@ -80,7 +82,7 @@ class DefaultEngineStrategy(EngineStrategy): dbapi_args[k] = pop_kwarg(k) dbapi = dialect_cls.dbapi(**dbapi_args) - dialect_args['dbapi'] = dbapi + dialect_args["dbapi"] = dbapi for plugin in plugins: plugin.handle_dialect_kwargs(dialect_cls, dialect_args) @@ -90,41 +92,43 @@ class DefaultEngineStrategy(EngineStrategy): # assemble connection arguments (cargs, cparams) = dialect.create_connect_args(u) - cparams.update(pop_kwarg('connect_args', {})) + cparams.update(pop_kwarg("connect_args", {})) cargs = list(cargs) # allow mutability # look for existing pool or create - pool = pop_kwarg('pool', None) + pool = pop_kwarg("pool", None) if pool is None: + def connect(connection_record=None): if dialect._has_events: for fn in dialect.dispatch.do_connect: connection = fn( - dialect, connection_record, cargs, cparams) + dialect, connection_record, cargs, cparams + ) if connection is not None: return connection return dialect.connect(*cargs, **cparams) - creator = pop_kwarg('creator', connect) + creator = pop_kwarg("creator", connect) - poolclass = pop_kwarg('poolclass', None) + poolclass = pop_kwarg("poolclass", None) if poolclass is None: poolclass = dialect_cls.get_pool_class(u) - pool_args = { - 'dialect': dialect - } + pool_args = {"dialect": dialect} # consume pool arguments from kwargs, translating a few of # the arguments - translate = {'logging_name': 'pool_logging_name', - 'echo': 'echo_pool', - 'timeout': 'pool_timeout', - 'recycle': 'pool_recycle', - 'events': 'pool_events', - 'use_threadlocal': 'pool_threadlocal', - 'reset_on_return': 'pool_reset_on_return', - 'pre_ping': 'pool_pre_ping', - 'use_lifo': 'pool_use_lifo'} + translate = { + "logging_name": "pool_logging_name", + "echo": "echo_pool", + "timeout": "pool_timeout", + "recycle": "pool_recycle", + "events": "pool_events", + "use_threadlocal": "pool_threadlocal", + "reset_on_return": "pool_reset_on_return", + "pre_ping": "pool_pre_ping", + "use_lifo": "pool_use_lifo", + } for k in util.get_cls_kwargs(poolclass): tk = translate.get(k, k) if tk in kwargs: @@ -149,7 +153,7 @@ class DefaultEngineStrategy(EngineStrategy): if k in kwargs: engine_args[k] = pop_kwarg(k) - _initialize = kwargs.pop('_initialize', True) + _initialize = kwargs.pop("_initialize", True) # all kwargs should be consumed if kwargs: @@ -157,32 +161,40 @@ class DefaultEngineStrategy(EngineStrategy): "Invalid argument(s) %s sent to create_engine(), " "using configuration %s/%s/%s. Please check that the " "keyword arguments are appropriate for this combination " - "of components." % (','.join("'%s'" % k for k in kwargs), - dialect.__class__.__name__, - pool.__class__.__name__, - engineclass.__name__)) + "of components." + % ( + ",".join("'%s'" % k for k in kwargs), + dialect.__class__.__name__, + pool.__class__.__name__, + engineclass.__name__, + ) + ) engine = engineclass(pool, dialect, u, **engine_args) if _initialize: do_on_connect = dialect.on_connect() if do_on_connect: + def on_connect(dbapi_connection, connection_record): conn = getattr( - dbapi_connection, '_sqla_unwrap', dbapi_connection) + dbapi_connection, "_sqla_unwrap", dbapi_connection + ) if conn is None: return do_on_connect(conn) - event.listen(pool, 'first_connect', on_connect) - event.listen(pool, 'connect', on_connect) + event.listen(pool, "first_connect", on_connect) + event.listen(pool, "connect", on_connect) def first_connect(dbapi_connection, connection_record): - c = base.Connection(engine, connection=dbapi_connection, - _has_events=False) + c = base.Connection( + engine, connection=dbapi_connection, _has_events=False + ) c._execution_options = util.immutabledict() dialect.initialize(c) - event.listen(pool, 'first_connect', first_connect, once=True) + + event.listen(pool, "first_connect", first_connect, once=True) dialect_cls.engine_created(engine) if entrypoint is not dialect_cls: @@ -197,18 +209,20 @@ class DefaultEngineStrategy(EngineStrategy): class PlainEngineStrategy(DefaultEngineStrategy): """Strategy for configuring a regular Engine.""" - name = 'plain' + name = "plain" engine_cls = base.Engine + PlainEngineStrategy() class ThreadLocalEngineStrategy(DefaultEngineStrategy): """Strategy for configuring an Engine with threadlocal behavior.""" - name = 'threadlocal' + name = "threadlocal" engine_cls = threadlocal.TLEngine + ThreadLocalEngineStrategy() @@ -220,7 +234,7 @@ class MockEngineStrategy(EngineStrategy): """ - name = 'mock' + name = "mock" def create(self, name_or_url, executor, **kwargs): # create url.URL object @@ -245,7 +259,7 @@ class MockEngineStrategy(EngineStrategy): self.execute = execute engine = property(lambda s: s) - dialect = property(attrgetter('_dialect')) + dialect = property(attrgetter("_dialect")) name = property(lambda s: s._dialect.name) schema_for_object = schema._schema_getter(None) @@ -258,29 +272,35 @@ class MockEngineStrategy(EngineStrategy): def compiler(self, statement, parameters, **kwargs): return self._dialect.compiler( - statement, parameters, engine=self, **kwargs) + statement, parameters, engine=self, **kwargs + ) def create(self, entity, **kwargs): - kwargs['checkfirst'] = False + kwargs["checkfirst"] = False from sqlalchemy.engine import ddl - ddl.SchemaGenerator( - self.dialect, self, **kwargs).traverse_single(entity) + ddl.SchemaGenerator(self.dialect, self, **kwargs).traverse_single( + entity + ) def drop(self, entity, **kwargs): - kwargs['checkfirst'] = False + kwargs["checkfirst"] = False from sqlalchemy.engine import ddl - ddl.SchemaDropper( - self.dialect, self, **kwargs).traverse_single(entity) - def _run_visitor(self, visitorcallable, element, - connection=None, - **kwargs): - kwargs['checkfirst'] = False - visitorcallable(self.dialect, self, - **kwargs).traverse_single(element) + ddl.SchemaDropper(self.dialect, self, **kwargs).traverse_single( + entity + ) + + def _run_visitor( + self, visitorcallable, element, connection=None, **kwargs + ): + kwargs["checkfirst"] = False + visitorcallable(self.dialect, self, **kwargs).traverse_single( + element + ) def execute(self, object, *multiparams, **params): raise NotImplementedError() + MockEngineStrategy() diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py index 0ec1f9613..5b2bdabc0 100644 --- a/lib/sqlalchemy/engine/threadlocal.py +++ b/lib/sqlalchemy/engine/threadlocal.py @@ -19,7 +19,6 @@ import weakref class TLConnection(base.Connection): - def __init__(self, *arg, **kw): super(TLConnection, self).__init__(*arg, **kw) self.__opencount = 0 @@ -43,6 +42,7 @@ class TLEngine(base.Engine): transactions. """ + _tl_connection_cls = TLConnection def __init__(self, *args, **kwargs): @@ -50,7 +50,7 @@ class TLEngine(base.Engine): self._connections = util.threading.local() def contextual_connect(self, **kw): - if not hasattr(self._connections, 'conn'): + if not hasattr(self._connections, "conn"): connection = None else: connection = self._connections.conn() @@ -60,29 +60,31 @@ class TLEngine(base.Engine): # or not connection.connection.is_valid: connection = self._tl_connection_cls( self, - self._wrap_pool_connect( - self.pool.connect, connection), - **kw) + self._wrap_pool_connect(self.pool.connect, connection), + **kw + ) self._connections.conn = weakref.ref(connection) return connection._increment_connect() def begin_twophase(self, xid=None): - if not hasattr(self._connections, 'trans'): + if not hasattr(self._connections, "trans"): self._connections.trans = [] self._connections.trans.append( - self.contextual_connect().begin_twophase(xid=xid)) + self.contextual_connect().begin_twophase(xid=xid) + ) return self def begin_nested(self): - if not hasattr(self._connections, 'trans'): + if not hasattr(self._connections, "trans"): self._connections.trans = [] self._connections.trans.append( - self.contextual_connect().begin_nested()) + self.contextual_connect().begin_nested() + ) return self def begin(self): - if not hasattr(self._connections, 'trans'): + if not hasattr(self._connections, "trans"): self._connections.trans = [] self._connections.trans.append(self.contextual_connect().begin()) return self @@ -97,21 +99,27 @@ class TLEngine(base.Engine): self.rollback() def prepare(self): - if not hasattr(self._connections, 'trans') or \ - not self._connections.trans: + if ( + not hasattr(self._connections, "trans") + or not self._connections.trans + ): return self._connections.trans[-1].prepare() def commit(self): - if not hasattr(self._connections, 'trans') or \ - not self._connections.trans: + if ( + not hasattr(self._connections, "trans") + or not self._connections.trans + ): return trans = self._connections.trans.pop(-1) trans.commit() def rollback(self): - if not hasattr(self._connections, 'trans') or \ - not self._connections.trans: + if ( + not hasattr(self._connections, "trans") + or not self._connections.trans + ): return trans = self._connections.trans.pop(-1) trans.rollback() @@ -122,9 +130,11 @@ class TLEngine(base.Engine): @property def closed(self): - return not hasattr(self._connections, 'conn') or \ - self._connections.conn() is None or \ - self._connections.conn().closed + return ( + not hasattr(self._connections, "conn") + or self._connections.conn() is None + or self._connections.conn().closed + ) def close(self): if not self.closed: @@ -135,4 +145,4 @@ class TLEngine(base.Engine): self._connections.trans = [] def __repr__(self): - return 'TLEngine(%r)' % self.url + return "TLEngine(%r)" % self.url diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index 1662efe20..e92e57b8e 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -50,8 +50,16 @@ class URL(object): """ - def __init__(self, drivername, username=None, password=None, - host=None, port=None, database=None, query=None): + def __init__( + self, + drivername, + username=None, + password=None, + host=None, + port=None, + database=None, + query=None, + ): self.drivername = drivername self.username = username self.password_original = password @@ -68,26 +76,26 @@ class URL(object): if self.username is not None: s += _rfc_1738_quote(self.username) if self.password is not None: - s += ':' + ('***' if hide_password - else _rfc_1738_quote(self.password)) + s += ":" + ( + "***" if hide_password else _rfc_1738_quote(self.password) + ) s += "@" if self.host is not None: - if ':' in self.host: + if ":" in self.host: s += "[%s]" % self.host else: s += self.host if self.port is not None: - s += ':' + str(self.port) + s += ":" + str(self.port) if self.database is not None: - s += '/' + self.database + s += "/" + self.database if self.query: keys = list(self.query) keys.sort() - s += '?' + "&".join( - "%s=%s" % ( - k, - element - ) for k in keys for element in util.to_list(self.query[k]) + s += "?" + "&".join( + "%s=%s" % (k, element) + for k in keys + for element in util.to_list(self.query[k]) ) return s @@ -101,14 +109,15 @@ class URL(object): return hash(str(self)) def __eq__(self, other): - return \ - isinstance(other, URL) and \ - self.drivername == other.drivername and \ - self.username == other.username and \ - self.password == other.password and \ - self.host == other.host and \ - self.database == other.database and \ - self.query == other.query + return ( + isinstance(other, URL) + and self.drivername == other.drivername + and self.username == other.username + and self.password == other.password + and self.host == other.host + and self.database == other.database + and self.query == other.query + ) @property def password(self): @@ -122,20 +131,20 @@ class URL(object): self.password_original = password def get_backend_name(self): - if '+' not in self.drivername: + if "+" not in self.drivername: return self.drivername else: - return self.drivername.split('+')[0] + return self.drivername.split("+")[0] def get_driver_name(self): - if '+' not in self.drivername: + if "+" not in self.drivername: return self.get_dialect().driver else: - return self.drivername.split('+')[1] + return self.drivername.split("+")[1] def _instantiate_plugins(self, kwargs): - plugin_names = util.to_list(self.query.get('plugin', ())) - plugin_names += kwargs.get('plugins', []) + plugin_names = util.to_list(self.query.get("plugin", ())) + plugin_names += kwargs.get("plugins", []) return [ plugins.load(plugin_name)(self, kwargs) @@ -149,17 +158,19 @@ class URL(object): returned class implements the get_dialect_cls() method. """ - if '+' not in self.drivername: + if "+" not in self.drivername: name = self.drivername else: - name = self.drivername.replace('+', '.') + name = self.drivername.replace("+", ".") cls = registry.load(name) # check for legacy dialects that # would return a module with 'dialect' as the # actual class - if hasattr(cls, 'dialect') and \ - isinstance(cls.dialect, type) and \ - issubclass(cls.dialect, Dialect): + if ( + hasattr(cls, "dialect") + and isinstance(cls.dialect, type) + and issubclass(cls.dialect, Dialect) + ): return cls.dialect else: return cls @@ -187,7 +198,7 @@ class URL(object): """ translated = {} - attribute_names = ['host', 'database', 'username', 'password', 'port'] + attribute_names = ["host", "database", "username", "password", "port"] for sname in attribute_names: if names: name = names.pop(0) @@ -214,7 +225,8 @@ def make_url(name_or_url): def _parse_rfc1738_args(name): - pattern = re.compile(r''' + pattern = re.compile( + r""" (?P<name>[\w\+]+):// (?: (?P<username>[^:/]*) @@ -228,21 +240,23 @@ def _parse_rfc1738_args(name): (?::(?P<port>[^/]*))? )? (?:/(?P<database>.*))? - ''', re.X) + """, + re.X, + ) m = pattern.match(name) if m is not None: components = m.groupdict() - if components['database'] is not None: - tokens = components['database'].split('?', 2) - components['database'] = tokens[0] + if components["database"] is not None: + tokens = components["database"].split("?", 2) + components["database"] = tokens[0] if len(tokens) > 1: query = {} for key, value in util.parse_qsl(tokens[1]): if util.py2k: - key = key.encode('ascii') + key = key.encode("ascii") if key in query: query[key] = util.to_list(query[key]) query[key].append(value) @@ -252,26 +266,27 @@ def _parse_rfc1738_args(name): query = None else: query = None - components['query'] = query + components["query"] = query - if components['username'] is not None: - components['username'] = _rfc_1738_unquote(components['username']) + if components["username"] is not None: + components["username"] = _rfc_1738_unquote(components["username"]) - if components['password'] is not None: - components['password'] = _rfc_1738_unquote(components['password']) + if components["password"] is not None: + components["password"] = _rfc_1738_unquote(components["password"]) - ipv4host = components.pop('ipv4host') - ipv6host = components.pop('ipv6host') - components['host'] = ipv4host or ipv6host - name = components.pop('name') + ipv4host = components.pop("ipv4host") + ipv6host = components.pop("ipv6host") + components["host"] = ipv4host or ipv6host + name = components.pop("name") return URL(name, **components) else: raise exc.ArgumentError( - "Could not parse rfc1738 URL from string '%s'" % name) + "Could not parse rfc1738 URL from string '%s'" % name + ) def _rfc_1738_quote(text): - return re.sub(r'[:@/]', lambda m: "%%%X" % ord(m.group(0)), text) + return re.sub(r"[:@/]", lambda m: "%%%X" % ord(m.group(0)), text) def _rfc_1738_unquote(text): @@ -279,7 +294,7 @@ def _rfc_1738_unquote(text): def _parse_keyvalue_args(name): - m = re.match(r'(\w+)://(.*)', name) + m = re.match(r"(\w+)://(.*)", name) if m is not None: (name, args) = m.group(1, 2) opts = dict(util.parse_qsl(args)) diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py index 17bc9a3b4..76bb8f4b5 100644 --- a/lib/sqlalchemy/engine/util.py +++ b/lib/sqlalchemy/engine/util.py @@ -46,28 +46,34 @@ def py_fallback(): elif len(multiparams) == 1: zero = multiparams[0] if isinstance(zero, (list, tuple)): - if not zero or hasattr(zero[0], '__iter__') and \ - not hasattr(zero[0], 'strip'): + if ( + not zero + or hasattr(zero[0], "__iter__") + and not hasattr(zero[0], "strip") + ): # execute(stmt, [{}, {}, {}, ...]) # execute(stmt, [(), (), (), ...]) return zero else: # execute(stmt, ("value", "value")) return [zero] - elif hasattr(zero, 'keys'): + elif hasattr(zero, "keys"): # execute(stmt, {"key":"value"}) return [zero] else: # execute(stmt, "value") return [[zero]] else: - if hasattr(multiparams[0], '__iter__') and \ - not hasattr(multiparams[0], 'strip'): + if hasattr(multiparams[0], "__iter__") and not hasattr( + multiparams[0], "strip" + ): return multiparams else: return [multiparams] return locals() + + try: from sqlalchemy.cutils import _distill_params except ImportError: |
