Diffstat (limited to 'lib/sqlalchemy/engine')
-rw-r--r--  lib/sqlalchemy/engine/__init__.py      30
-rw-r--r--  lib/sqlalchemy/engine/base.py         445
-rw-r--r--  lib/sqlalchemy/engine/default.py      416
-rw-r--r--  lib/sqlalchemy/engine/interfaces.py    51
-rw-r--r--  lib/sqlalchemy/engine/reflection.py   436
-rw-r--r--  lib/sqlalchemy/engine/result.py       384
-rw-r--r--  lib/sqlalchemy/engine/strategies.py   120
-rw-r--r--  lib/sqlalchemy/engine/threadlocal.py   50
-rw-r--r--  lib/sqlalchemy/engine/url.py          115
-rw-r--r--  lib/sqlalchemy/engine/util.py          16
10 files changed, 1204 insertions, 859 deletions
diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py
index 6342b3c21..590359c38 100644
--- a/lib/sqlalchemy/engine/__init__.py
+++ b/lib/sqlalchemy/engine/__init__.py
@@ -57,10 +57,9 @@ from .interfaces import (
Dialect,
ExecutionContext,
ExceptionContext,
-
# backwards compat
Compiled,
- TypeCompiler
+ TypeCompiler,
)
from .base import (
@@ -82,9 +81,7 @@ from .result import (
RowProxy,
)
-from .util import (
- connection_memoize
-)
+from .util import connection_memoize
from . import util, strategies
@@ -92,7 +89,7 @@ from . import util, strategies
# backwards compat
from ..sql import ddl
-default_strategy = 'plain'
+default_strategy = "plain"
def create_engine(*args, **kwargs):
@@ -460,12 +457,12 @@ def create_engine(*args, **kwargs):
"""
- strategy = kwargs.pop('strategy', default_strategy)
+ strategy = kwargs.pop("strategy", default_strategy)
strategy = strategies.strategies[strategy]
return strategy.create(*args, **kwargs)
-def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs):
+def engine_from_config(configuration, prefix="sqlalchemy.", **kwargs):
"""Create a new Engine instance using a configuration dictionary.
The dictionary is typically produced from a config file.
@@ -497,16 +494,15 @@ def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs):
"""
- options = dict((key[len(prefix):], configuration[key])
- for key in configuration
- if key.startswith(prefix))
- options['_coerce_config'] = True
+ options = dict(
+ (key[len(prefix) :], configuration[key])
+ for key in configuration
+ if key.startswith(prefix)
+ )
+ options["_coerce_config"] = True
options.update(kwargs)
- url = options.pop('url')
+ url = options.pop("url")
return create_engine(url, **options)
-__all__ = (
- 'create_engine',
- 'engine_from_config',
-)
+__all__ = ("create_engine", "engine_from_config")
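The __init__.py changes above are formatting-only; create_engine() and engine_from_config() behave as before. For context, a minimal usage sketch of the two entry points (the SQLite URL and config keys are placeholders, not part of the patch):

from sqlalchemy import create_engine, engine_from_config

# Direct construction.
engine = create_engine("sqlite:///example.db", echo=True)

# Construction from a flat configuration dictionary, e.g. one parsed from an
# .ini file.  Keys carrying the prefix are stripped and handed to
# create_engine(); string values such as "true" or "3600" are coerced because
# engine_from_config sets the _coerce_config flag seen above.
config = {
    "sqlalchemy.url": "sqlite:///example.db",
    "sqlalchemy.echo": "true",
    "sqlalchemy.pool_recycle": "3600",
}
engine = engine_from_config(config, prefix="sqlalchemy.")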
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 4a057ee59..75d03b744 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -61,10 +61,16 @@ class Connection(Connectable):
"""
- def __init__(self, engine, connection=None, close_with_result=False,
- _branch_from=None, _execution_options=None,
- _dispatch=None,
- _has_events=None):
+ def __init__(
+ self,
+ engine,
+ connection=None,
+ close_with_result=False,
+ _branch_from=None,
+ _execution_options=None,
+ _dispatch=None,
+ _has_events=None,
+ ):
"""Construct a new Connection.
The constructor here is not public and is only called only by an
@@ -86,8 +92,11 @@ class Connection(Connectable):
self._has_events = _branch_from._has_events
self.schema_for_object = _branch_from.schema_for_object
else:
- self.__connection = connection \
- if connection is not None else engine.raw_connection()
+ self.__connection = (
+ connection
+ if connection is not None
+ else engine.raw_connection()
+ )
self.__transaction = None
self.__savepoint_seq = 0
self.should_close_with_result = close_with_result
@@ -101,7 +110,8 @@ class Connection(Connectable):
# want to handle any of the engine's events in that case.
self.dispatch = self.dispatch._join(engine.dispatch)
self._has_events = _has_events or (
- _has_events is None and engine._has_events)
+ _has_events is None and engine._has_events
+ )
assert not _execution_options
self._execution_options = engine._execution_options
@@ -134,7 +144,8 @@ class Connection(Connectable):
_branch_from=self,
_execution_options=self._execution_options,
_has_events=self._has_events,
- _dispatch=self.dispatch)
+ _dispatch=self.dispatch,
+ )
@property
def _root(self):
@@ -322,8 +333,10 @@ class Connection(Connectable):
def closed(self):
"""Return True if this connection is closed."""
- return '_Connection__connection' not in self.__dict__ \
+ return (
+ "_Connection__connection" not in self.__dict__
and not self.__can_reconnect
+ )
@property
def invalidated(self):
@@ -425,7 +438,8 @@ class Connection(Connectable):
if self.__transaction is not None:
raise exc.InvalidRequestError(
"Can't reconnect until invalid "
- "transaction is rolled back")
+ "transaction is rolled back"
+ )
self.__connection = self.engine.raw_connection(_connection=self)
self.__invalid = False
return self.__connection
@@ -437,14 +451,15 @@ class Connection(Connectable):
# dialect initializer, where the connection is not wrapped in
# _ConnectionFairy
- return getattr(self.__connection, 'is_valid', False)
+ return getattr(self.__connection, "is_valid", False)
@property
def _still_open_and_connection_is_valid(self):
- return \
- not self.closed and \
- not self.invalidated and \
- getattr(self.__connection, 'is_valid', False)
+ return (
+ not self.closed
+ and not self.invalidated
+ and getattr(self.__connection, "is_valid", False)
+ )
@property
def info(self):
@@ -656,7 +671,8 @@ class Connection(Connectable):
if self.__transaction is not None:
raise exc.InvalidRequestError(
"Cannot start a two phase transaction when a transaction "
- "is already in progress.")
+ "is already in progress."
+ )
if xid is None:
xid = self.engine.dialect.create_xid()
self.__transaction = TwoPhaseTransaction(self, xid)
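The begin_twophase() hunk above only rewraps its error message. A minimal sketch of the two-phase API it guards, assuming a backend with two-phase support such as PostgreSQL (URL and table are placeholders):

from sqlalchemy import create_engine, text

engine = create_engine("postgresql://scott:tiger@localhost/test")  # placeholder

with engine.connect() as conn:
    xact = conn.begin_twophase()  # xid is generated by the dialect when omitted
    conn.execute(text("INSERT INTO t (x) VALUES (1)"))
    xact.prepare()                # phase one: PREPARE TRANSACTION
    xact.commit()                 # phase two: COMMIT PREPARED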
@@ -705,8 +721,10 @@ class Connection(Connectable):
except BaseException as e:
self._handle_dbapi_exception(e, None, None, None, None)
finally:
- if not self.__invalid and \
- self.connection._reset_agent is self.__transaction:
+ if (
+ not self.__invalid
+ and self.connection._reset_agent is self.__transaction
+ ):
self.connection._reset_agent = None
self.__transaction = None
else:
@@ -725,8 +743,10 @@ class Connection(Connectable):
except BaseException as e:
self._handle_dbapi_exception(e, None, None, None, None)
finally:
- if not self.__invalid and \
- self.connection._reset_agent is self.__transaction:
+ if (
+ not self.__invalid
+ and self.connection._reset_agent is self.__transaction
+ ):
self.connection._reset_agent = None
self.__transaction = None
@@ -738,7 +758,7 @@ class Connection(Connectable):
if name is None:
self.__savepoint_seq += 1
- name = 'sa_savepoint_%s' % self.__savepoint_seq
+ name = "sa_savepoint_%s" % self.__savepoint_seq
if self._still_open_and_connection_is_valid:
self.engine.dialect.do_savepoint(self, name)
return name
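The generated name above ("sa_savepoint_<n>") backs Connection.begin_nested(). A small sketch, assuming an existing table t (URL is a placeholder):

from sqlalchemy import create_engine, text

engine = create_engine("sqlite:///example.db")  # placeholder

with engine.connect() as conn:
    outer = conn.begin()
    conn.execute(text("INSERT INTO t (x) VALUES (1)"))

    savepoint = conn.begin_nested()  # emits SAVEPOINT sa_savepoint_1
    conn.execute(text("INSERT INTO t (x) VALUES (2)"))
    savepoint.rollback()             # ROLLBACK TO SAVEPOINT; the first row survives

    outer.commit()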
@@ -797,7 +817,8 @@ class Connection(Connectable):
assert isinstance(self.__transaction, TwoPhaseTransaction)
try:
self.engine.dialect.do_rollback_twophase(
- self, xid, is_prepared)
+ self, xid, is_prepared
+ )
finally:
if self.connection._reset_agent is self.__transaction:
self.connection._reset_agent = None
@@ -950,16 +971,16 @@ class Connection(Connectable):
def _execute_function(self, func, multiparams, params):
"""Execute a sql.FunctionElement object."""
- return self._execute_clauseelement(func.select(),
- multiparams, params)
+ return self._execute_clauseelement(func.select(), multiparams, params)
def _execute_default(self, default, multiparams, params):
"""Execute a schema.ColumnDefault object."""
if self._has_events or self.engine._has_events:
for fn in self.dispatch.before_execute:
- default, multiparams, params = \
- fn(self, default, multiparams, params)
+ default, multiparams, params = fn(
+ self, default, multiparams, params
+ )
try:
try:
@@ -972,8 +993,7 @@ class Connection(Connectable):
conn = self._revalidate_connection()
dialect = self.dialect
- ctx = dialect.execution_ctx_cls._init_default(
- dialect, self, conn)
+ ctx = dialect.execution_ctx_cls._init_default(dialect, self, conn)
except BaseException as e:
self._handle_dbapi_exception(e, None, None, None, None)
@@ -982,8 +1002,9 @@ class Connection(Connectable):
self.close()
if self._has_events or self.engine._has_events:
- self.dispatch.after_execute(self,
- default, multiparams, params, ret)
+ self.dispatch.after_execute(
+ self, default, multiparams, params, ret
+ )
return ret
@@ -992,25 +1013,25 @@ class Connection(Connectable):
if self._has_events or self.engine._has_events:
for fn in self.dispatch.before_execute:
- ddl, multiparams, params = \
- fn(self, ddl, multiparams, params)
+ ddl, multiparams, params = fn(self, ddl, multiparams, params)
dialect = self.dialect
compiled = ddl.compile(
dialect=dialect,
schema_translate_map=self.schema_for_object
- if not self.schema_for_object.is_default else None)
+ if not self.schema_for_object.is_default
+ else None,
+ )
ret = self._execute_context(
dialect,
dialect.execution_ctx_cls._init_ddl,
compiled,
None,
- compiled
+ compiled,
)
if self._has_events or self.engine._has_events:
- self.dispatch.after_execute(self,
- ddl, multiparams, params, ret)
+ self.dispatch.after_execute(self, ddl, multiparams, params, ret)
return ret
def _execute_clauseelement(self, elem, multiparams, params):
@@ -1018,8 +1039,7 @@ class Connection(Connectable):
if self._has_events or self.engine._has_events:
for fn in self.dispatch.before_execute:
- elem, multiparams, params = \
- fn(self, elem, multiparams, params)
+ elem, multiparams, params = fn(self, elem, multiparams, params)
distilled_params = _distill_params(multiparams, params)
if distilled_params:
@@ -1030,38 +1050,45 @@ class Connection(Connectable):
keys = []
dialect = self.dialect
- if 'compiled_cache' in self._execution_options:
+ if "compiled_cache" in self._execution_options:
key = (
- dialect, elem, tuple(sorted(keys)),
+ dialect,
+ elem,
+ tuple(sorted(keys)),
self.schema_for_object.hash_key,
- len(distilled_params) > 1
+ len(distilled_params) > 1,
)
- compiled_sql = self._execution_options['compiled_cache'].get(key)
+ compiled_sql = self._execution_options["compiled_cache"].get(key)
if compiled_sql is None:
compiled_sql = elem.compile(
- dialect=dialect, column_keys=keys,
+ dialect=dialect,
+ column_keys=keys,
inline=len(distilled_params) > 1,
schema_translate_map=self.schema_for_object
- if not self.schema_for_object.is_default else None
+ if not self.schema_for_object.is_default
+ else None,
)
- self._execution_options['compiled_cache'][key] = compiled_sql
+ self._execution_options["compiled_cache"][key] = compiled_sql
else:
compiled_sql = elem.compile(
- dialect=dialect, column_keys=keys,
+ dialect=dialect,
+ column_keys=keys,
inline=len(distilled_params) > 1,
schema_translate_map=self.schema_for_object
- if not self.schema_for_object.is_default else None)
+ if not self.schema_for_object.is_default
+ else None,
+ )
ret = self._execute_context(
dialect,
dialect.execution_ctx_cls._init_compiled,
compiled_sql,
distilled_params,
- compiled_sql, distilled_params
+ compiled_sql,
+ distilled_params,
)
if self._has_events or self.engine._has_events:
- self.dispatch.after_execute(self,
- elem, multiparams, params, ret)
+ self.dispatch.after_execute(self, elem, multiparams, params, ret)
return ret
def _execute_compiled(self, compiled, multiparams, params):
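The compiled_cache branch in _execute_clauseelement() above keys cached Compiled objects on (dialect, element, column keys, schema-translate hash, multi-row flag). A sketch of feeding it a cache dictionary per connection; the table definition is illustrative:

from sqlalchemy import Column, Integer, MetaData, Table, create_engine

engine = create_engine("sqlite://")  # placeholder in-memory database
metadata = MetaData()
t = Table("t", metadata, Column("x", Integer))
metadata.create_all(engine)

cache = {}
with engine.connect() as conn:
    cached_conn = conn.execution_options(compiled_cache=cache)
    for i in range(3):
        # The insert construct is compiled once; subsequent executions reuse
        # the Compiled object stored under the key built above.
        cached_conn.execute(t.insert(), {"x": i})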
@@ -1069,8 +1096,9 @@ class Connection(Connectable):
if self._has_events or self.engine._has_events:
for fn in self.dispatch.before_execute:
- compiled, multiparams, params = \
- fn(self, compiled, multiparams, params)
+ compiled, multiparams, params = fn(
+ self, compiled, multiparams, params
+ )
dialect = self.dialect
parameters = _distill_params(multiparams, params)
@@ -1079,11 +1107,13 @@ class Connection(Connectable):
dialect.execution_ctx_cls._init_compiled,
compiled,
parameters,
- compiled, parameters
+ compiled,
+ parameters,
)
if self._has_events or self.engine._has_events:
- self.dispatch.after_execute(self,
- compiled, multiparams, params, ret)
+ self.dispatch.after_execute(
+ self, compiled, multiparams, params, ret
+ )
return ret
def _execute_text(self, statement, multiparams, params):
@@ -1091,8 +1121,9 @@ class Connection(Connectable):
if self._has_events or self.engine._has_events:
for fn in self.dispatch.before_execute:
- statement, multiparams, params = \
- fn(self, statement, multiparams, params)
+ statement, multiparams, params = fn(
+ self, statement, multiparams, params
+ )
dialect = self.dialect
parameters = _distill_params(multiparams, params)
@@ -1101,16 +1132,18 @@ class Connection(Connectable):
dialect.execution_ctx_cls._init_statement,
statement,
parameters,
- statement, parameters
+ statement,
+ parameters,
)
if self._has_events or self.engine._has_events:
- self.dispatch.after_execute(self,
- statement, multiparams, params, ret)
+ self.dispatch.after_execute(
+ self, statement, multiparams, params, ret
+ )
return ret
- def _execute_context(self, dialect, constructor,
- statement, parameters,
- *args):
+ def _execute_context(
+ self, dialect, constructor, statement, parameters, *args
+ ):
"""Create an :class:`.ExecutionContext` and execute, returning
a :class:`.ResultProxy`."""
@@ -1127,31 +1160,36 @@ class Connection(Connectable):
context = constructor(dialect, self, conn, *args)
except BaseException as e:
self._handle_dbapi_exception(
- e,
- util.text_type(statement), parameters,
- None, None)
+ e, util.text_type(statement), parameters, None, None
+ )
if context.compiled:
context.pre_exec()
- cursor, statement, parameters = context.cursor, \
- context.statement, \
- context.parameters
+ cursor, statement, parameters = (
+ context.cursor,
+ context.statement,
+ context.parameters,
+ )
if not context.executemany:
parameters = parameters[0]
if self._has_events or self.engine._has_events:
for fn in self.dispatch.before_cursor_execute:
- statement, parameters = \
- fn(self, cursor, statement, parameters,
- context, context.executemany)
+ statement, parameters = fn(
+ self,
+ cursor,
+ statement,
+ parameters,
+ context,
+ context.executemany,
+ )
if self._echo:
self.engine.logger.info(statement)
self.engine.logger.info(
- "%r",
- sql_util._repr_params(parameters, batches=10)
+ "%r", sql_util._repr_params(parameters, batches=10)
)
evt_handled = False
@@ -1164,10 +1202,8 @@ class Connection(Connectable):
break
if not evt_handled:
self.dialect.do_executemany(
- cursor,
- statement,
- parameters,
- context)
+ cursor, statement, parameters, context
+ )
elif not parameters and context.no_parameters:
if self.dialect._has_events:
for fn in self.dialect.dispatch.do_execute_no_params:
@@ -1176,9 +1212,8 @@ class Connection(Connectable):
break
if not evt_handled:
self.dialect.do_execute_no_params(
- cursor,
- statement,
- context)
+ cursor, statement, context
+ )
else:
if self.dialect._has_events:
for fn in self.dialect.dispatch.do_execute:
@@ -1187,24 +1222,22 @@ class Connection(Connectable):
break
if not evt_handled:
self.dialect.do_execute(
- cursor,
- statement,
- parameters,
- context)
+ cursor, statement, parameters, context
+ )
except BaseException as e:
self._handle_dbapi_exception(
- e,
- statement,
- parameters,
- cursor,
- context)
+ e, statement, parameters, cursor, context
+ )
if self._has_events or self.engine._has_events:
- self.dispatch.after_cursor_execute(self, cursor,
- statement,
- parameters,
- context,
- context.executemany)
+ self.dispatch.after_cursor_execute(
+ self,
+ cursor,
+ statement,
+ parameters,
+ context,
+ context.executemany,
+ )
if context.compiled:
context.post_exec()
@@ -1245,39 +1278,32 @@ class Connection(Connectable):
"""
if self._has_events or self.engine._has_events:
for fn in self.dispatch.before_cursor_execute:
- statement, parameters = \
- fn(self, cursor, statement, parameters,
- context,
- False)
+ statement, parameters = fn(
+ self, cursor, statement, parameters, context, False
+ )
if self._echo:
self.engine.logger.info(statement)
self.engine.logger.info("%r", parameters)
try:
- for fn in () if not self.dialect._has_events \
- else self.dialect.dispatch.do_execute:
+ for fn in (
+ ()
+ if not self.dialect._has_events
+ else self.dialect.dispatch.do_execute
+ ):
if fn(cursor, statement, parameters, context):
break
else:
- self.dialect.do_execute(
- cursor,
- statement,
- parameters,
- context)
+ self.dialect.do_execute(cursor, statement, parameters, context)
except BaseException as e:
self._handle_dbapi_exception(
- e,
- statement,
- parameters,
- cursor,
- context)
+ e, statement, parameters, cursor, context
+ )
if self._has_events or self.engine._has_events:
- self.dispatch.after_cursor_execute(self, cursor,
- statement,
- parameters,
- context,
- False)
+ self.dispatch.after_cursor_execute(
+ self, cursor, statement, parameters, context, False
+ )
def _safe_close_cursor(self, cursor):
"""Close the given cursor, catching exceptions
@@ -1289,17 +1315,15 @@ class Connection(Connectable):
except Exception:
# log the error through the connection pool's logger.
self.engine.pool.logger.error(
- "Error closing cursor", exc_info=True)
+ "Error closing cursor", exc_info=True
+ )
_reentrant_error = False
_is_disconnect = False
- def _handle_dbapi_exception(self,
- e,
- statement,
- parameters,
- cursor,
- context):
+ def _handle_dbapi_exception(
+ self, e, statement, parameters, cursor, context
+ ):
exc_info = sys.exc_info()
if context and context.exception is None:
@@ -1309,15 +1333,14 @@ class Connection(Connectable):
if not self._is_disconnect:
self._is_disconnect = (
- isinstance(e, self.dialect.dbapi.Error) and
- not self.closed and
- self.dialect.is_disconnect(
+ isinstance(e, self.dialect.dbapi.Error)
+ and not self.closed
+ and self.dialect.is_disconnect(
e,
self.__connection if not self.invalidated else None,
- cursor)
- ) or (
- is_exit_exception and not self.closed
- )
+ cursor,
+ )
+ ) or (is_exit_exception and not self.closed)
if context:
context.is_disconnect = self._is_disconnect
@@ -1326,20 +1349,24 @@ class Connection(Connectable):
if self._reentrant_error:
util.raise_from_cause(
- exc.DBAPIError.instance(statement,
- parameters,
- e,
- self.dialect.dbapi.Error,
- dialect=self.dialect),
- exc_info
+ exc.DBAPIError.instance(
+ statement,
+ parameters,
+ e,
+ self.dialect.dbapi.Error,
+ dialect=self.dialect,
+ ),
+ exc_info,
)
self._reentrant_error = True
try:
# non-DBAPI error - if we already got a context,
# or there's no string statement, don't wrap it
- should_wrap = isinstance(e, self.dialect.dbapi.Error) or \
- (statement is not None
- and context is None and not is_exit_exception)
+ should_wrap = isinstance(e, self.dialect.dbapi.Error) or (
+ statement is not None
+ and context is None
+ and not is_exit_exception
+ )
if should_wrap:
sqlalchemy_exception = exc.DBAPIError.instance(
@@ -1348,30 +1375,37 @@ class Connection(Connectable):
e,
self.dialect.dbapi.Error,
connection_invalidated=self._is_disconnect,
- dialect=self.dialect)
+ dialect=self.dialect,
+ )
else:
sqlalchemy_exception = None
newraise = None
- if (self._has_events or self.engine._has_events) and \
- not self._execution_options.get(
- 'skip_user_error_events', False):
+ if (
+ self._has_events or self.engine._has_events
+ ) and not self._execution_options.get(
+ "skip_user_error_events", False
+ ):
# legacy dbapi_error event
if should_wrap and context:
- self.dispatch.dbapi_error(self,
- cursor,
- statement,
- parameters,
- context,
- e)
+ self.dispatch.dbapi_error(
+ self, cursor, statement, parameters, context, e
+ )
# new handle_error event
ctx = ExceptionContextImpl(
- e, sqlalchemy_exception, self.engine,
- self, cursor, statement,
- parameters, context, self._is_disconnect,
- invalidate_pool_on_disconnect)
+ e,
+ sqlalchemy_exception,
+ self.engine,
+ self,
+ cursor,
+ statement,
+ parameters,
+ context,
+ self._is_disconnect,
+ invalidate_pool_on_disconnect,
+ )
for fn in self.dispatch.handle_error:
try:
@@ -1388,13 +1422,15 @@ class Connection(Connectable):
if self._is_disconnect != ctx.is_disconnect:
self._is_disconnect = ctx.is_disconnect
if sqlalchemy_exception:
- sqlalchemy_exception.connection_invalidated = \
+ sqlalchemy_exception.connection_invalidated = (
ctx.is_disconnect
+ )
# set up potentially user-defined value for
# invalidate pool.
- invalidate_pool_on_disconnect = \
+ invalidate_pool_on_disconnect = (
ctx.invalidate_pool_on_disconnect
+ )
if should_wrap and context:
context.handle_dbapi_exception(e)
@@ -1408,10 +1444,7 @@ class Connection(Connectable):
if newraise:
util.raise_from_cause(newraise, exc_info)
elif should_wrap:
- util.raise_from_cause(
- sqlalchemy_exception,
- exc_info
- )
+ util.raise_from_cause(sqlalchemy_exception, exc_info)
else:
util.reraise(*exc_info)
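The block above hands each handle_error listener an ExceptionContextImpl and raises whatever exception a listener returns (newraise). A sketch of such a listener; the wrapper exception class is illustrative:

from sqlalchemy import create_engine, event

engine = create_engine("sqlite://")  # placeholder


class MyAppError(Exception):
    pass


@event.listens_for(engine, "handle_error")
def translate_errors(context):
    # context carries the original DBAPI exception, the statement and
    # parameters, and the is_disconnect flag populated above.
    if context.is_disconnect:
        return None  # leave disconnect handling to SQLAlchemy
    return MyAppError("query failed: %s" % (context.statement,))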
@@ -1441,7 +1474,8 @@ class Connection(Connectable):
None,
e,
dialect.dbapi.Error,
- connection_invalidated=is_disconnect)
+ connection_invalidated=is_disconnect,
+ )
else:
sqlalchemy_exception = None
@@ -1449,8 +1483,17 @@ class Connection(Connectable):
if engine._has_events:
ctx = ExceptionContextImpl(
- e, sqlalchemy_exception, engine, None, None, None,
- None, None, is_disconnect, True)
+ e,
+ sqlalchemy_exception,
+ engine,
+ None,
+ None,
+ None,
+ None,
+ None,
+ is_disconnect,
+ True,
+ )
for fn in engine.dispatch.handle_error:
try:
# handler returns an exception;
@@ -1463,18 +1506,15 @@ class Connection(Connectable):
newraise = _raised
break
- if sqlalchemy_exception and \
- is_disconnect != ctx.is_disconnect:
- sqlalchemy_exception.connection_invalidated = \
- is_disconnect = ctx.is_disconnect
+ if sqlalchemy_exception and is_disconnect != ctx.is_disconnect:
+ sqlalchemy_exception.connection_invalidated = (
+ is_disconnect
+ ) = ctx.is_disconnect
if newraise:
util.raise_from_cause(newraise, exc_info)
elif should_wrap:
- util.raise_from_cause(
- sqlalchemy_exception,
- exc_info
- )
+ util.raise_from_cause(sqlalchemy_exception, exc_info)
else:
util.reraise(*exc_info)
@@ -1545,16 +1585,25 @@ class Connection(Connectable):
return callable_(self, *args, **kwargs)
def _run_visitor(self, visitorcallable, element, **kwargs):
- visitorcallable(self.dialect, self,
- **kwargs).traverse_single(element)
+ visitorcallable(self.dialect, self, **kwargs).traverse_single(element)
class ExceptionContextImpl(ExceptionContext):
"""Implement the :class:`.ExceptionContext` interface."""
- def __init__(self, exception, sqlalchemy_exception,
- engine, connection, cursor, statement, parameters,
- context, is_disconnect, invalidate_pool_on_disconnect):
+ def __init__(
+ self,
+ exception,
+ sqlalchemy_exception,
+ engine,
+ connection,
+ cursor,
+ statement,
+ parameters,
+ context,
+ is_disconnect,
+ invalidate_pool_on_disconnect,
+ ):
self.engine = engine
self.connection = connection
self.sqlalchemy_exception = sqlalchemy_exception
@@ -1691,12 +1740,14 @@ class NestedTransaction(Transaction):
def _do_rollback(self):
if self.is_active:
self.connection._rollback_to_savepoint_impl(
- self._savepoint, self._parent)
+ self._savepoint, self._parent
+ )
def _do_commit(self):
if self.is_active:
self.connection._release_savepoint_impl(
- self._savepoint, self._parent)
+ self._savepoint, self._parent
+ )
class TwoPhaseTransaction(Transaction):
@@ -1771,10 +1822,16 @@ class Engine(Connectable, log.Identified):
"""
- def __init__(self, pool, dialect, url,
- logging_name=None, echo=None, proxy=None,
- execution_options=None
- ):
+ def __init__(
+ self,
+ pool,
+ dialect,
+ url,
+ logging_name=None,
+ echo=None,
+ proxy=None,
+ execution_options=None,
+ ):
self.pool = pool
self.url = url
self.dialect = dialect
@@ -1805,8 +1862,7 @@ class Engine(Connectable, log.Identified):
:meth:`.Engine.execution_options`
"""
- self._execution_options = \
- self._execution_options.union(opt)
+ self._execution_options = self._execution_options.union(opt)
self.dispatch.set_engine_execution_options(self, opt)
self.dialect.set_engine_execution_options(self, opt)
@@ -1894,7 +1950,7 @@ class Engine(Connectable, log.Identified):
echo = log.echo_property()
def __repr__(self):
- return 'Engine(%r)' % self.url
+ return "Engine(%r)" % self.url
def dispose(self):
"""Dispose of the connection pool used by this :class:`.Engine`.
@@ -1934,8 +1990,9 @@ class Engine(Connectable, log.Identified):
else:
yield connection
- def _run_visitor(self, visitorcallable, element,
- connection=None, **kwargs):
+ def _run_visitor(
+ self, visitorcallable, element, connection=None, **kwargs
+ ):
with self._optional_conn_ctx_manager(connection) as conn:
conn._run_visitor(visitorcallable, element, **kwargs)
@@ -2122,7 +2179,8 @@ class Engine(Connectable, log.Identified):
self,
self._wrap_pool_connect(self.pool.connect, None),
close_with_result=close_with_result,
- **kwargs)
+ **kwargs
+ )
def table_names(self, schema=None, connection=None):
"""Return a list of all table names available in the database.
@@ -2159,7 +2217,8 @@ class Engine(Connectable, log.Identified):
except dialect.dbapi.Error as e:
if connection is None:
Connection._handle_dbapi_exception_noconnection(
- e, dialect, self)
+ e, dialect, self
+ )
else:
util.reraise(*sys.exc_info())
@@ -2185,7 +2244,8 @@ class Engine(Connectable, log.Identified):
"""
return self._wrap_pool_connect(
- self.pool.unique_connection, _connection)
+ self.pool.unique_connection, _connection
+ )
class OptionEngine(Engine):
@@ -2225,10 +2285,11 @@ class OptionEngine(Engine):
pool = property(_get_pool, _set_pool)
def _get_has_events(self):
- return self._proxied._has_events or \
- self.__dict__.get('_has_events', False)
+ return self._proxied._has_events or self.__dict__.get(
+ "_has_events", False
+ )
def _set_has_events(self, value):
- self.__dict__['_has_events'] = value
+ self.__dict__["_has_events"] = value
_has_events = property(_get_has_events, _set_has_events)
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 028abc4c2..d7c2518fe 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -24,13 +24,11 @@ import weakref
from .. import event
AUTOCOMMIT_REGEXP = re.compile(
- r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)',
- re.I | re.UNICODE)
+ r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)", re.I | re.UNICODE
+)
# When we're handed literal SQL, ensure it's a SELECT query
-SERVER_SIDE_CURSOR_RE = re.compile(
- r'\s*SELECT',
- re.I | re.UNICODE)
+SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE)
class DefaultDialect(interfaces.Dialect):
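The two module-level patterns above drive textual autocommit detection and the literal-SQL check used for server-side cursors. A self-contained illustration of what they match (same patterns, copied for the example):

import re

AUTOCOMMIT_REGEXP = re.compile(
    r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)", re.I | re.UNICODE
)
SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE)

assert AUTOCOMMIT_REGEXP.match("  insert into t (x) values (1)")
assert AUTOCOMMIT_REGEXP.match("select x from t") is None
assert SERVER_SIDE_CURSOR_RE.match("SELECT x FROM t")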
@@ -68,16 +66,18 @@ class DefaultDialect(interfaces.Dialect):
supports_simple_order_by_label = True
- engine_config_types = util.immutabledict([
- ('convert_unicode', util.bool_or_str('force')),
- ('pool_timeout', util.asint),
- ('echo', util.bool_or_str('debug')),
- ('echo_pool', util.bool_or_str('debug')),
- ('pool_recycle', util.asint),
- ('pool_size', util.asint),
- ('max_overflow', util.asint),
- ('pool_threadlocal', util.asbool),
- ])
+ engine_config_types = util.immutabledict(
+ [
+ ("convert_unicode", util.bool_or_str("force")),
+ ("pool_timeout", util.asint),
+ ("echo", util.bool_or_str("debug")),
+ ("echo_pool", util.bool_or_str("debug")),
+ ("pool_recycle", util.asint),
+ ("pool_size", util.asint),
+ ("max_overflow", util.asint),
+ ("pool_threadlocal", util.asbool),
+ ]
+ )
# if the NUMERIC type
# returns decimal.Decimal.
@@ -93,9 +93,9 @@ class DefaultDialect(interfaces.Dialect):
supports_unicode_statements = False
supports_unicode_binds = False
returns_unicode_strings = False
- description_encoding = 'use_encoding'
+ description_encoding = "use_encoding"
- name = 'default'
+ name = "default"
# length at which to truncate
# any identifier.
@@ -111,7 +111,7 @@ class DefaultDialect(interfaces.Dialect):
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
colspecs = {}
- default_paramstyle = 'named'
+ default_paramstyle = "named"
supports_default_values = False
supports_empty_insert = True
supports_multivalues_insert = False
@@ -175,19 +175,26 @@ class DefaultDialect(interfaces.Dialect):
"""
- def __init__(self, convert_unicode=False,
- encoding='utf-8', paramstyle=None, dbapi=None,
- implicit_returning=None,
- supports_right_nested_joins=None,
- case_sensitive=True,
- supports_native_boolean=None,
- empty_in_strategy='static',
- label_length=None, **kwargs):
-
- if not getattr(self, 'ported_sqla_06', True):
+ def __init__(
+ self,
+ convert_unicode=False,
+ encoding="utf-8",
+ paramstyle=None,
+ dbapi=None,
+ implicit_returning=None,
+ supports_right_nested_joins=None,
+ case_sensitive=True,
+ supports_native_boolean=None,
+ empty_in_strategy="static",
+ label_length=None,
+ **kwargs
+ ):
+
+ if not getattr(self, "ported_sqla_06", True):
util.warn(
- "The %s dialect is not yet ported to the 0.6 format" %
- self.name)
+ "The %s dialect is not yet ported to the 0.6 format"
+ % self.name
+ )
self.convert_unicode = convert_unicode
self.encoding = encoding
@@ -202,7 +209,7 @@ class DefaultDialect(interfaces.Dialect):
self.paramstyle = self.default_paramstyle
if implicit_returning is not None:
self.implicit_returning = implicit_returning
- self.positional = self.paramstyle in ('qmark', 'format', 'numeric')
+ self.positional = self.paramstyle in ("qmark", "format", "numeric")
self.identifier_preparer = self.preparer(self)
self.type_compiler = self.type_compiler(self)
if supports_right_nested_joins is not None:
@@ -212,33 +219,33 @@ class DefaultDialect(interfaces.Dialect):
self.case_sensitive = case_sensitive
self.empty_in_strategy = empty_in_strategy
- if empty_in_strategy == 'static':
+ if empty_in_strategy == "static":
self._use_static_in = True
- elif empty_in_strategy in ('dynamic', 'dynamic_warn'):
+ elif empty_in_strategy in ("dynamic", "dynamic_warn"):
self._use_static_in = False
- self._warn_on_empty_in = empty_in_strategy == 'dynamic_warn'
+ self._warn_on_empty_in = empty_in_strategy == "dynamic_warn"
else:
raise exc.ArgumentError(
"empty_in_strategy may be 'static', "
- "'dynamic', or 'dynamic_warn'")
+ "'dynamic', or 'dynamic_warn'"
+ )
if label_length and label_length > self.max_identifier_length:
raise exc.ArgumentError(
"Label length of %d is greater than this dialect's"
- " maximum identifier length of %d" %
- (label_length, self.max_identifier_length))
+ " maximum identifier length of %d"
+ % (label_length, self.max_identifier_length)
+ )
self.label_length = label_length
- if self.description_encoding == 'use_encoding':
- self._description_decoder = \
- processors.to_unicode_processor_factory(
- encoding
- )
+ if self.description_encoding == "use_encoding":
+ self._description_decoder = processors.to_unicode_processor_factory(
+ encoding
+ )
elif self.description_encoding is not None:
- self._description_decoder = \
- processors.to_unicode_processor_factory(
- self.description_encoding
- )
+ self._description_decoder = processors.to_unicode_processor_factory(
+ self.description_encoding
+ )
self._encoder = codecs.getencoder(self.encoding)
self._decoder = processors.to_unicode_processor_factory(self.encoding)
@@ -256,30 +263,35 @@ class DefaultDialect(interfaces.Dialect):
@classmethod
def get_pool_class(cls, url):
- return getattr(cls, 'poolclass', pool.QueuePool)
+ return getattr(cls, "poolclass", pool.QueuePool)
def initialize(self, connection):
try:
- self.server_version_info = \
- self._get_server_version_info(connection)
+ self.server_version_info = self._get_server_version_info(
+ connection
+ )
except NotImplementedError:
self.server_version_info = None
try:
- self.default_schema_name = \
- self._get_default_schema_name(connection)
+ self.default_schema_name = self._get_default_schema_name(
+ connection
+ )
except NotImplementedError:
self.default_schema_name = None
try:
- self.default_isolation_level = \
- self.get_isolation_level(connection.connection)
+ self.default_isolation_level = self.get_isolation_level(
+ connection.connection
+ )
except NotImplementedError:
self.default_isolation_level = None
self.returns_unicode_strings = self._check_unicode_returns(connection)
- if self.description_encoding is not None and \
- self._check_unicode_description(connection):
+ if (
+ self.description_encoding is not None
+ and self._check_unicode_description(connection)
+ ):
self._description_decoder = self.description_encoding = None
self.do_rollback(connection.connection)
@@ -311,7 +323,8 @@ class DefaultDialect(interfaces.Dialect):
def check_unicode(test):
statement = cast_to(
- expression.select([test]).compile(dialect=self))
+ expression.select([test]).compile(dialect=self)
+ )
try:
cursor = connection.connection.cursor()
connection._cursor_execute(cursor, statement, parameters)
@@ -320,8 +333,10 @@ class DefaultDialect(interfaces.Dialect):
except exc.DBAPIError as de:
# note that _cursor_execute() will have closed the cursor
# if an exception is thrown.
- util.warn("Exception attempting to "
- "detect unicode returns: %r" % de)
+ util.warn(
+ "Exception attempting to "
+ "detect unicode returns: %r" % de
+ )
return False
else:
return isinstance(row[0], util.text_type)
@@ -330,13 +345,13 @@ class DefaultDialect(interfaces.Dialect):
# detect plain VARCHAR
expression.cast(
expression.literal_column("'test plain returns'"),
- sqltypes.VARCHAR(60)
+ sqltypes.VARCHAR(60),
),
# detect if there's an NVARCHAR type with different behavior
# available
expression.cast(
expression.literal_column("'test unicode returns'"),
- sqltypes.Unicode(60)
+ sqltypes.Unicode(60),
),
]
@@ -364,9 +379,9 @@ class DefaultDialect(interfaces.Dialect):
try:
cursor.execute(
cast_to(
- expression.select([
- expression.literal_column("'x'").label("some_label")
- ]).compile(dialect=self)
+ expression.select(
+ [expression.literal_column("'x'").label("some_label")]
+ ).compile(dialect=self)
)
)
return isinstance(cursor.description[0][0], util.text_type)
@@ -385,10 +400,12 @@ class DefaultDialect(interfaces.Dialect):
return sqltypes.adapt_type(typeobj, self.colspecs)
def reflecttable(
- self, connection, table, include_columns, exclude_columns, **opts):
+ self, connection, table, include_columns, exclude_columns, **opts
+ ):
insp = reflection.Inspector.from_engine(connection)
return insp.reflecttable(
- table, include_columns, exclude_columns, **opts)
+ table, include_columns, exclude_columns, **opts
+ )
def get_pk_constraint(self, conn, table_name, schema=None, **kw):
"""Compatibility method, adapts the result of get_primary_keys()
@@ -396,16 +413,16 @@ class DefaultDialect(interfaces.Dialect):
"""
return {
- 'constrained_columns':
- self.get_primary_keys(conn, table_name,
- schema=schema, **kw)
+ "constrained_columns": self.get_primary_keys(
+ conn, table_name, schema=schema, **kw
+ )
}
def validate_identifier(self, ident):
if len(ident) > self.max_identifier_length:
raise exc.IdentifierError(
- "Identifier '%s' exceeds maximum length of %d characters" %
- (ident, self.max_identifier_length)
+ "Identifier '%s' exceeds maximum length of %d characters"
+ % (ident, self.max_identifier_length)
)
def connect(self, *cargs, **cparams):
@@ -417,16 +434,16 @@ class DefaultDialect(interfaces.Dialect):
return [[], opts]
def set_engine_execution_options(self, engine, opts):
- if 'isolation_level' in opts:
- isolation_level = opts['isolation_level']
+ if "isolation_level" in opts:
+ isolation_level = opts["isolation_level"]
@event.listens_for(engine, "engine_connect")
def set_isolation(connection, branch):
if not branch:
self._set_connection_isolation(connection, isolation_level)
- if 'schema_translate_map' in opts:
- getter = schema._schema_getter(opts['schema_translate_map'])
+ if "schema_translate_map" in opts:
+ getter = schema._schema_getter(opts["schema_translate_map"])
engine.schema_for_object = getter
@event.listens_for(engine, "engine_connect")
@@ -434,11 +451,11 @@ class DefaultDialect(interfaces.Dialect):
connection.schema_for_object = getter
def set_connection_execution_options(self, connection, opts):
- if 'isolation_level' in opts:
- self._set_connection_isolation(connection, opts['isolation_level'])
+ if "isolation_level" in opts:
+ self._set_connection_isolation(connection, opts["isolation_level"])
- if 'schema_translate_map' in opts:
- getter = schema._schema_getter(opts['schema_translate_map'])
+ if "schema_translate_map" in opts:
+ getter = schema._schema_getter(opts["schema_translate_map"])
connection.schema_for_object = getter
def _set_connection_isolation(self, connection, level):
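These two hooks consume the isolation_level and schema_translate_map execution options, at engine-connect time and per connection respectively. A sketch of setting them; the URL, isolation level names, and schema names are placeholders:

from sqlalchemy import create_engine

engine = create_engine("postgresql://scott:tiger@localhost/test")  # placeholder

# Engine level: every pooled connection is switched via the "engine_connect"
# listener installed above; returns an OptionEngine wrapper.
autocommit_engine = engine.execution_options(isolation_level="AUTOCOMMIT")

# Connection level: both options applied to a single checked-out connection.
with engine.connect() as conn:
    conn = conn.execution_options(
        isolation_level="SERIALIZABLE",
        schema_translate_map={None: "tenant_1", "archive": "archive_2019"},
    )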
@@ -447,10 +464,12 @@ class DefaultDialect(interfaces.Dialect):
"Connection is already established with a Transaction; "
"setting isolation_level may implicitly rollback or commit "
"the existing transaction, or have no effect until "
- "next transaction")
+ "next transaction"
+ )
self.set_isolation_level(connection.connection, level)
- connection.connection._connection_record.\
- finalize_callback.append(self.reset_isolation_level)
+ connection.connection._connection_record.finalize_callback.append(
+ self.reset_isolation_level
+ )
def do_begin(self, dbapi_connection):
pass
@@ -593,8 +612,9 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
return self
@classmethod
- def _init_compiled(cls, dialect, connection, dbapi_connection,
- compiled, parameters):
+ def _init_compiled(
+ cls, dialect, connection, dbapi_connection, compiled, parameters
+ ):
"""Initialize execution context for a Compiled construct."""
self = cls.__new__(cls)
@@ -609,16 +629,20 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
assert compiled.can_execute
self.execution_options = compiled.execution_options.union(
- connection._execution_options)
+ connection._execution_options
+ )
self.result_column_struct = (
- compiled._result_columns, compiled._ordered_columns,
- compiled._textual_ordered_columns)
+ compiled._result_columns,
+ compiled._ordered_columns,
+ compiled._textual_ordered_columns,
+ )
self.unicode_statement = util.text_type(compiled)
if not dialect.supports_unicode_statements:
self.statement = self.unicode_statement.encode(
- self.dialect.encoding)
+ self.dialect.encoding
+ )
else:
self.statement = self.unicode_statement
@@ -630,9 +654,10 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
if not parameters:
self.compiled_parameters = [compiled.construct_params()]
else:
- self.compiled_parameters = \
- [compiled.construct_params(m, _group_number=grp) for
- grp, m in enumerate(parameters)]
+ self.compiled_parameters = [
+ compiled.construct_params(m, _group_number=grp)
+ for grp, m in enumerate(parameters)
+ ]
self.executemany = len(parameters) > 1
@@ -642,7 +667,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
self.is_crud = True
self._is_explicit_returning = bool(compiled.statement._returning)
self._is_implicit_returning = bool(
- compiled.returning and not compiled.statement._returning)
+ compiled.returning and not compiled.statement._returning
+ )
if self.compiled.insert_prefetch or self.compiled.update_prefetch:
if self.executemany:
@@ -680,7 +706,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
dialect._encoder(key)[0],
processors[key](compiled_params[key])
if key in processors
- else compiled_params[key]
+ else compiled_params[key],
)
for key in compiled_params
)
@@ -690,7 +716,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
key,
processors[key](compiled_params[key])
if key in processors
- else compiled_params[key]
+ else compiled_params[key],
)
for key in compiled_params
)
@@ -708,14 +734,15 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
"""
if self.executemany:
raise exc.InvalidRequestError(
- "'expanding' parameters can't be used with "
- "executemany()")
+ "'expanding' parameters can't be used with " "executemany()"
+ )
if self.compiled.positional and self.compiled._numeric_binds:
# I'm not familiar with any DBAPI that uses 'numeric'
raise NotImplementedError(
"'expanding' bind parameters not supported with "
- "'numeric' paramstyle at this time.")
+ "'numeric' paramstyle at this time."
+ )
self._expanded_parameters = {}
@@ -729,7 +756,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
to_update_sets = {}
for name in (
- self.compiled.positiontup if compiled.positional
+ self.compiled.positiontup
+ if compiled.positional
else self.compiled.binds
):
parameter = self.compiled.binds[name]
@@ -748,12 +776,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
if not values:
to_update = to_update_sets[name] = []
- replacement_expressions[name] = (
- self.compiled.visit_empty_set_expr(
- parameter._expanding_in_types
- if parameter._expanding_in_types
- else [parameter.type]
- )
+ replacement_expressions[
+ name
+ ] = self.compiled.visit_empty_set_expr(
+ parameter._expanding_in_types
+ if parameter._expanding_in_types
+ else [parameter.type]
)
elif isinstance(values[0], (tuple, list)):
@@ -763,15 +791,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
for j, value in enumerate(tuple_element, 1)
]
replacement_expressions[name] = ", ".join(
- "(%s)" % ", ".join(
- self.compiled.bindtemplate % {
- "name":
- to_update[i * len(tuple_element) + j][0]
+ "(%s)"
+ % ", ".join(
+ self.compiled.bindtemplate
+ % {
+ "name": to_update[
+ i * len(tuple_element) + j
+ ][0]
}
for j, value in enumerate(tuple_element)
)
for i, tuple_element in enumerate(values)
-
)
else:
to_update = to_update_sets[name] = [
@@ -779,20 +809,21 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
for i, value in enumerate(values, 1)
]
replacement_expressions[name] = ", ".join(
- self.compiled.bindtemplate % {
- "name": key}
+ self.compiled.bindtemplate % {"name": key}
for key, value in to_update
)
compiled_params.update(to_update)
processors.update(
(key, processors[name])
- for key, value in to_update if name in processors
+ for key, value in to_update
+ if name in processors
)
if compiled.positional:
positiontup.extend(name for name, value in to_update)
self._expanded_parameters[name] = [
- expand_key for expand_key, value in to_update]
+ expand_key for expand_key, value in to_update
+ ]
elif compiled.positional:
positiontup.append(name)
@@ -800,15 +831,14 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
return replacement_expressions[m.group(1)]
self.statement = re.sub(
- r"\[EXPANDING_(\S+)\]",
- process_expanding,
- self.statement
+ r"\[EXPANDING_(\S+)\]", process_expanding, self.statement
)
return positiontup
@classmethod
- def _init_statement(cls, dialect, connection, dbapi_connection,
- statement, parameters):
+ def _init_statement(
+ cls, dialect, connection, dbapi_connection, statement, parameters
+ ):
"""Initialize execution context for a string SQL statement."""
self = cls.__new__(cls)
@@ -836,13 +866,15 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
for d in parameters
] or [{}]
else:
- self.parameters = [dialect.execute_sequence_format(p)
- for p in parameters]
+ self.parameters = [
+ dialect.execute_sequence_format(p) for p in parameters
+ ]
self.executemany = len(parameters) > 1
- if not dialect.supports_unicode_statements and \
- isinstance(statement, util.text_type):
+ if not dialect.supports_unicode_statements and isinstance(
+ statement, util.text_type
+ ):
self.unicode_statement = statement
self.statement = dialect._encoder(statement)[0]
else:
@@ -890,11 +922,13 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
@util.memoized_property
def should_autocommit(self):
- autocommit = self.execution_options.get('autocommit',
- not self.compiled and
- self.statement and
- expression.PARSE_AUTOCOMMIT
- or False)
+ autocommit = self.execution_options.get(
+ "autocommit",
+ not self.compiled
+ and self.statement
+ and expression.PARSE_AUTOCOMMIT
+ or False,
+ )
if autocommit is expression.PARSE_AUTOCOMMIT:
return self.should_autocommit_text(self.unicode_statement)
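should_autocommit resolves the autocommit execution option, falling back to the regex test for plain textual statements. A sketch; the mutating function name is purely illustrative:

from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")  # placeholder in-memory database

with engine.connect() as conn:
    # Textual DML matches AUTOCOMMIT_REGEXP, so it autocommits on its own.
    conn.execute(text("CREATE TABLE t (x INTEGER)"))
    conn.execute(text("INSERT INTO t (x) VALUES (1)"))

# A SELECT that mutates state does not match the regex; the option requests
# the commit explicitly (this would be executed against a backend that
# actually provides such a function).
stmt = text("SELECT set_counter(1)").execution_options(autocommit=True)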
@@ -912,8 +946,10 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
"""
conn = self.root_connection
- if isinstance(stmt, util.text_type) and \
- not self.dialect.supports_unicode_statements:
+ if (
+ isinstance(stmt, util.text_type)
+ and not self.dialect.supports_unicode_statements
+ ):
stmt = self.dialect._encoder(stmt)[0]
if self.dialect.positional:
@@ -926,8 +962,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
if type_ is not None:
# apply type post processors to the result
proc = type_._cached_result_processor(
- self.dialect,
- self.cursor.description[0][1]
+ self.dialect, self.cursor.description[0][1]
)
if proc:
return proc(r)
@@ -945,22 +980,30 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
return False
if self.dialect.server_side_cursors:
- use_server_side = \
- self.execution_options.get('stream_results', True) and (
- (self.compiled and isinstance(self.compiled.statement,
- expression.Selectable)
- or
- (
- (not self.compiled or
- isinstance(self.compiled.statement,
- expression.TextClause))
- and self.statement and SERVER_SIDE_CURSOR_RE.match(
- self.statement))
- )
+ use_server_side = self.execution_options.get(
+ "stream_results", True
+ ) and (
+ (
+ self.compiled
+ and isinstance(
+ self.compiled.statement, expression.Selectable
+ )
+ or (
+ (
+ not self.compiled
+ or isinstance(
+ self.compiled.statement, expression.TextClause
+ )
+ )
+ and self.statement
+ and SERVER_SIDE_CURSOR_RE.match(self.statement)
+ )
)
+ )
else:
- use_server_side = \
- self.execution_options.get('stream_results', False)
+ use_server_side = self.execution_options.get(
+ "stream_results", False
+ )
return use_server_side
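This flag decides whether a server-side cursor is used; the usual way to request one from user code is the stream_results execution option, assuming a dialect that supports it (e.g. psycopg2). URL and table name below are placeholders:

from sqlalchemy import create_engine, text

engine = create_engine("postgresql://scott:tiger@localhost/test")  # placeholder

with engine.connect() as conn:
    result = conn.execution_options(stream_results=True).execute(
        text("SELECT * FROM big_table")
    )
    while True:
        chunk = result.fetchmany(1000)
        if not chunk:
            break
        for row in chunk:
            pass  # rows arrive in batches from the named server-side cursor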
@@ -1039,11 +1082,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
return self.dialect.supports_sane_multi_rowcount
def _setup_crud_result_proxy(self):
- if self.isinsert and \
- not self.executemany:
- if not self._is_implicit_returning and \
- not self.compiled.inline and \
- self.dialect.postfetch_lastrowid:
+ if self.isinsert and not self.executemany:
+ if (
+ not self._is_implicit_returning
+ and not self.compiled.inline
+ and self.dialect.postfetch_lastrowid
+ ):
self._setup_ins_pk_from_lastrowid()
@@ -1087,12 +1131,14 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
if autoinc_col is not None:
# apply type post processors to the lastrowid
proc = autoinc_col.type._cached_result_processor(
- self.dialect, None)
+ self.dialect, None
+ )
if proc is not None:
lastrowid = proc(lastrowid)
self.inserted_primary_key = [
- lastrowid if c is autoinc_col else
- compiled_params.get(key_getter(c), None)
+ lastrowid
+ if c is autoinc_col
+ else compiled_params.get(key_getter(c), None)
for c in table.primary_key
]
else:
@@ -1108,8 +1154,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
table = self.compiled.statement.table
compiled_params = self.compiled_parameters[0]
self.inserted_primary_key = [
- compiled_params.get(key_getter(c), None)
- for c in table.primary_key
+ compiled_params.get(key_getter(c), None) for c in table.primary_key
]
def _setup_ins_pk_from_implicit_returning(self, row):
@@ -1129,11 +1174,13 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
]
def lastrow_has_defaults(self):
- return (self.isinsert or self.isupdate) and \
- bool(self.compiled.postfetch)
+ return (self.isinsert or self.isupdate) and bool(
+ self.compiled.postfetch
+ )
def set_input_sizes(
- self, translate=None, include_types=None, exclude_types=None):
+ self, translate=None, include_types=None, exclude_types=None
+ ):
"""Given a cursor and ClauseParameters, call the appropriate
style of ``setinputsizes()`` on the cursor, using DB-API types
from the bind parameter's ``TypeEngine`` objects.
@@ -1143,7 +1190,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
"""
- if not hasattr(self.compiled, 'bind_names'):
+ if not hasattr(self.compiled, "bind_names"):
return
inputsizes = {}
@@ -1153,12 +1200,18 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
dialect_impl_cls = type(dialect_impl)
dbtype = dialect_impl.get_dbapi_type(self.dialect.dbapi)
- if dbtype is not None and (
- not exclude_types or dbtype not in exclude_types and
- dialect_impl_cls not in exclude_types
- ) and (
- not include_types or dbtype in include_types or
- dialect_impl_cls in include_types
+ if (
+ dbtype is not None
+ and (
+ not exclude_types
+ or dbtype not in exclude_types
+ and dialect_impl_cls not in exclude_types
+ )
+ and (
+ not include_types
+ or dbtype in include_types
+ or dialect_impl_cls in include_types
+ )
):
inputsizes[bindparam] = dbtype
else:
@@ -1177,14 +1230,16 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
if dbtype is not None:
if key in self._expanded_parameters:
positional_inputsizes.extend(
- [dbtype] * len(self._expanded_parameters[key]))
+ [dbtype] * len(self._expanded_parameters[key])
+ )
else:
positional_inputsizes.append(dbtype)
try:
self.cursor.setinputsizes(*positional_inputsizes)
except BaseException as e:
self.root_connection._handle_dbapi_exception(
- e, None, None, None, self)
+ e, None, None, None, self
+ )
else:
keyword_inputsizes = {}
for bindparam, key in self.compiled.bind_names.items():
@@ -1199,8 +1254,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
key = self.dialect._encoder(key)[0]
if key in self._expanded_parameters:
keyword_inputsizes.update(
- (expand_key, dbtype) for expand_key
- in self._expanded_parameters[key]
+ (expand_key, dbtype)
+ for expand_key in self._expanded_parameters[key]
)
else:
keyword_inputsizes[key] = dbtype
@@ -1208,7 +1263,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
self.cursor.setinputsizes(**keyword_inputsizes)
except BaseException as e:
self.root_connection._handle_dbapi_exception(
- e, None, None, None, self)
+ e, None, None, None, self
+ )
def _exec_default(self, column, default, type_):
if default.is_sequence:
@@ -1290,10 +1346,13 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
except AttributeError:
raise exc.InvalidRequestError(
"get_current_parameters() can only be invoked in the "
- "context of a Python side column default function")
- if isolate_multiinsert_groups and \
- self.isinsert and \
- self.compiled.statement._has_multi_parameters:
+ "context of a Python side column default function"
+ )
+ if (
+ isolate_multiinsert_groups
+ and self.isinsert
+ and self.compiled.statement._has_multi_parameters
+ ):
if column._is_multiparam_column:
index = column.index + 1
d = {column.original.key: parameters[column.key]}
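get_current_parameters() may only be called from inside a Python-side column default callable, as the error message above states. A minimal sketch of such a default; table and column names are illustrative:

from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine

def derive_slug(context):
    # context is the active DefaultExecutionContext; get_current_parameters()
    # exposes the parameters of the row currently being inserted.
    params = context.get_current_parameters()
    return params["title"].lower().replace(" ", "-")

metadata = MetaData()
articles = Table(
    "articles",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("title", String(100)),
    Column("slug", String(100), default=derive_slug),
)

engine = create_engine("sqlite://")  # placeholder
metadata.create_all(engine)
with engine.connect() as conn:
    conn.execute(articles.insert(), {"title": "Hello World"})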
@@ -1302,8 +1361,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
index = 0
keys = self.compiled.statement.parameters[0].keys()
d.update(
- (key, parameters["%s_m%d" % (key, index)])
- for key in keys
+ (key, parameters["%s_m%d" % (key, index)]) for key in keys
)
return d
else:
@@ -1360,12 +1418,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
def _process_executesingle_defaults(self):
key_getter = self.compiled._key_getters_for_crud_column[2]
- self.current_parameters = compiled_parameters = \
- self.compiled_parameters[0]
+ self.current_parameters = (
+ compiled_parameters
+ ) = self.compiled_parameters[0]
for c in self.compiled.insert_prefetch:
- if c.default and \
- not c.default.is_sequence and c.default.is_scalar:
+ if c.default and not c.default.is_sequence and c.default.is_scalar:
val = c.default.arg
else:
val = self.get_insert_default(c)
diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py
index 9c3b24e9a..e10e6e884 100644
--- a/lib/sqlalchemy/engine/interfaces.py
+++ b/lib/sqlalchemy/engine/interfaces.py
@@ -198,7 +198,8 @@ class Dialect(object):
pass
def reflecttable(
- self, connection, table, include_columns, exclude_columns):
+ self, connection, table, include_columns, exclude_columns
+ ):
"""Load table description from the database.
Given a :class:`.Connection` and a
@@ -367,7 +368,8 @@ class Dialect(object):
raise NotImplementedError()
def get_unique_constraints(
- self, connection, table_name, schema=None, **kw):
+ self, connection, table_name, schema=None, **kw
+ ):
r"""Return information about unique constraints in `table_name`.
Given a string `table_name` and an optional string `schema`, return
@@ -389,8 +391,7 @@ class Dialect(object):
raise NotImplementedError()
- def get_check_constraints(
- self, connection, table_name, schema=None, **kw):
+ def get_check_constraints(self, connection, table_name, schema=None, **kw):
r"""Return information about check constraints in `table_name`.
Given a string `table_name` and an optional string `schema`, return
@@ -412,8 +413,7 @@ class Dialect(object):
raise NotImplementedError()
- def get_table_comment(
- self, connection, table_name, schema=None, **kw):
+ def get_table_comment(self, connection, table_name, schema=None, **kw):
r"""Return the "comment" for the table identified by `table_name`.
Given a string `table_name` and an optional string `schema`, return
@@ -613,8 +613,9 @@ class Dialect(object):
raise NotImplementedError()
- def do_rollback_twophase(self, connection, xid, is_prepared=True,
- recover=False):
+ def do_rollback_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
"""Rollback a two phase transaction on the given connection.
:param connection: a :class:`.Connection`.
@@ -627,8 +628,9 @@ class Dialect(object):
raise NotImplementedError()
- def do_commit_twophase(self, connection, xid, is_prepared=True,
- recover=False):
+ def do_commit_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
"""Commit a two phase transaction on the given connection.
@@ -664,8 +666,9 @@ class Dialect(object):
raise NotImplementedError()
- def do_execute_no_params(self, cursor, statement, parameters,
- context=None):
+ def do_execute_no_params(
+ self, cursor, statement, parameters, context=None
+ ):
"""Provide an implementation of ``cursor.execute(statement)``.
The parameter collection should not be sent.
@@ -899,6 +902,7 @@ class CreateEnginePlugin(object):
.. versionadded:: 1.1
"""
+
def __init__(self, url, kwargs):
"""Contruct a new :class:`.CreateEnginePlugin`.
@@ -1129,20 +1133,24 @@ class Connectable(object):
raise NotImplementedError()
- @util.deprecated("0.7",
- "Use the create() method on the given schema "
- "object directly, i.e. :meth:`.Table.create`, "
- ":meth:`.Index.create`, :meth:`.MetaData.create_all`")
+ @util.deprecated(
+ "0.7",
+ "Use the create() method on the given schema "
+ "object directly, i.e. :meth:`.Table.create`, "
+ ":meth:`.Index.create`, :meth:`.MetaData.create_all`",
+ )
def create(self, entity, **kwargs):
"""Emit CREATE statements for the given schema entity.
"""
raise NotImplementedError()
- @util.deprecated("0.7",
- "Use the drop() method on the given schema "
- "object directly, i.e. :meth:`.Table.drop`, "
- ":meth:`.Index.drop`, :meth:`.MetaData.drop_all`")
+ @util.deprecated(
+ "0.7",
+ "Use the drop() method on the given schema "
+ "object directly, i.e. :meth:`.Table.drop`, "
+ ":meth:`.Index.drop`, :meth:`.MetaData.drop_all`",
+ )
def drop(self, entity, **kwargs):
"""Emit DROP statements for the given schema entity.
"""
@@ -1160,8 +1168,7 @@ class Connectable(object):
"""
raise NotImplementedError()
- def _run_visitor(self, visitorcallable, element,
- **kwargs):
+ def _run_visitor(self, visitorcallable, element, **kwargs):
raise NotImplementedError()
def _execute_clauseelement(self, elem, multiparams=None, params=None):
diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py
index 841bb4dfb..9b5fa2459 100644
--- a/lib/sqlalchemy/engine/reflection.py
+++ b/lib/sqlalchemy/engine/reflection.py
@@ -37,17 +37,17 @@ from .base import Connectable
@util.decorator
def cache(fn, self, con, *args, **kw):
- info_cache = kw.get('info_cache', None)
+ info_cache = kw.get("info_cache", None)
if info_cache is None:
return fn(self, con, *args, **kw)
key = (
fn.__name__,
tuple(a for a in args if isinstance(a, util.string_types)),
- tuple((k, v) for k, v in kw.items() if
- isinstance(v,
- util.string_types + util.int_types + (float, )
- )
- )
+ tuple(
+ (k, v)
+ for k, v in kw.items()
+ if isinstance(v, util.string_types + util.int_types + (float,))
+ ),
)
ret = info_cache.get(key)
if ret is None:
@@ -99,7 +99,7 @@ class Inspector(object):
self.bind = bind
# set the engine
- if hasattr(bind, 'engine'):
+ if hasattr(bind, "engine"):
self.engine = bind.engine
else:
self.engine = bind
@@ -130,7 +130,7 @@ class Inspector(object):
See the example at :class:`.Inspector`.
"""
- if hasattr(bind.dialect, 'inspector'):
+ if hasattr(bind.dialect, "inspector"):
return bind.dialect.inspector(bind)
return Inspector(bind)
@@ -153,9 +153,10 @@ class Inspector(object):
"""Return all schema names.
"""
- if hasattr(self.dialect, 'get_schema_names'):
- return self.dialect.get_schema_names(self.bind,
- info_cache=self.info_cache)
+ if hasattr(self.dialect, "get_schema_names"):
+ return self.dialect.get_schema_names(
+ self.bind, info_cache=self.info_cache
+ )
return []
def get_table_names(self, schema=None, order_by=None):
@@ -196,17 +197,18 @@ class Inspector(object):
"""
- if hasattr(self.dialect, 'get_table_names'):
+ if hasattr(self.dialect, "get_table_names"):
tnames = self.dialect.get_table_names(
- self.bind, schema, info_cache=self.info_cache)
+ self.bind, schema, info_cache=self.info_cache
+ )
else:
tnames = self.engine.table_names(schema)
- if order_by == 'foreign_key':
+ if order_by == "foreign_key":
tuples = []
for tname in tnames:
for fkey in self.get_foreign_keys(tname, schema):
- if tname != fkey['referred_table']:
- tuples.append((fkey['referred_table'], tname))
+ if tname != fkey["referred_table"]:
+ tuples.append((fkey["referred_table"], tname))
tnames = list(topological.sort(tuples, tnames))
return tnames
@@ -234,9 +236,10 @@ class Inspector(object):
with an already-given :class:`.MetaData`.
"""
- if hasattr(self.dialect, 'get_table_names'):
+ if hasattr(self.dialect, "get_table_names"):
tnames = self.dialect.get_table_names(
- self.bind, schema, info_cache=self.info_cache)
+ self.bind, schema, info_cache=self.info_cache
+ )
else:
tnames = self.engine.table_names(schema)
@@ -246,20 +249,17 @@ class Inspector(object):
fknames_for_table = {}
for tname in tnames:
fkeys = self.get_foreign_keys(tname, schema)
- fknames_for_table[tname] = set(
- [fk['name'] for fk in fkeys]
- )
+ fknames_for_table[tname] = set([fk["name"] for fk in fkeys])
for fkey in fkeys:
- if tname != fkey['referred_table']:
- tuples.add((fkey['referred_table'], tname))
+ if tname != fkey["referred_table"]:
+ tuples.add((fkey["referred_table"], tname))
try:
candidate_sort = list(topological.sort(tuples, tnames))
except exc.CircularDependencyError as err:
for edge in err.edges:
tuples.remove(edge)
remaining_fkcs.update(
- (edge[1], fkc)
- for fkc in fknames_for_table[edge[1]]
+ (edge[1], fkc) for fkc in fknames_for_table[edge[1]]
)
candidate_sort = list(topological.sort(tuples, tnames))
@@ -278,7 +278,8 @@ class Inspector(object):
"""
return self.dialect.get_temp_table_names(
- self.bind, info_cache=self.info_cache)
+ self.bind, info_cache=self.info_cache
+ )
def get_temp_view_names(self):
"""return a list of temporary view names for the current bind.
@@ -290,7 +291,8 @@ class Inspector(object):
"""
return self.dialect.get_temp_view_names(
- self.bind, info_cache=self.info_cache)
+ self.bind, info_cache=self.info_cache
+ )
def get_table_options(self, table_name, schema=None, **kw):
"""Return a dictionary of options specified when the table of the
@@ -306,10 +308,10 @@ class Inspector(object):
use :class:`.quoted_name`.
"""
- if hasattr(self.dialect, 'get_table_options'):
+ if hasattr(self.dialect, "get_table_options"):
return self.dialect.get_table_options(
- self.bind, table_name, schema,
- info_cache=self.info_cache, **kw)
+ self.bind, table_name, schema, info_cache=self.info_cache, **kw
+ )
return {}
def get_view_names(self, schema=None):
@@ -320,8 +322,9 @@ class Inspector(object):
"""
- return self.dialect.get_view_names(self.bind, schema,
- info_cache=self.info_cache)
+ return self.dialect.get_view_names(
+ self.bind, schema, info_cache=self.info_cache
+ )
def get_view_definition(self, view_name, schema=None):
"""Return definition for `view_name`.
@@ -332,7 +335,8 @@ class Inspector(object):
"""
return self.dialect.get_view_definition(
- self.bind, view_name, schema, info_cache=self.info_cache)
+ self.bind, view_name, schema, info_cache=self.info_cache
+ )
def get_columns(self, table_name, schema=None, **kw):
"""Return information about columns in `table_name`.
@@ -364,18 +368,21 @@ class Inspector(object):
"""
- col_defs = self.dialect.get_columns(self.bind, table_name, schema,
- info_cache=self.info_cache,
- **kw)
+ col_defs = self.dialect.get_columns(
+ self.bind, table_name, schema, info_cache=self.info_cache, **kw
+ )
for col_def in col_defs:
# make this easy and only return instances for coltype
- coltype = col_def['type']
+ coltype = col_def["type"]
if not isinstance(coltype, TypeEngine):
- col_def['type'] = coltype()
+ col_def["type"] = coltype()
return col_defs
- @deprecated('0.7', 'Call to deprecated method get_primary_keys.'
- ' Use get_pk_constraint instead.')
+ @deprecated(
+ "0.7",
+ "Call to deprecated method get_primary_keys."
+ " Use get_pk_constraint instead.",
+ )
def get_primary_keys(self, table_name, schema=None, **kw):
"""Return information about primary keys in `table_name`.
@@ -383,9 +390,9 @@ class Inspector(object):
primary key information as a list of column names.
"""
- return self.dialect.get_pk_constraint(self.bind, table_name, schema,
- info_cache=self.info_cache,
- **kw)['constrained_columns']
+ return self.dialect.get_pk_constraint(
+ self.bind, table_name, schema, info_cache=self.info_cache, **kw
+ )["constrained_columns"]
def get_pk_constraint(self, table_name, schema=None, **kw):
"""Return information about primary key constraint on `table_name`.
@@ -407,9 +414,9 @@ class Inspector(object):
use :class:`.quoted_name`.
"""
- return self.dialect.get_pk_constraint(self.bind, table_name, schema,
- info_cache=self.info_cache,
- **kw)
+ return self.dialect.get_pk_constraint(
+ self.bind, table_name, schema, info_cache=self.info_cache, **kw
+ )
def get_foreign_keys(self, table_name, schema=None, **kw):
"""Return information about foreign_keys in `table_name`.
@@ -442,9 +449,9 @@ class Inspector(object):
"""
- return self.dialect.get_foreign_keys(self.bind, table_name, schema,
- info_cache=self.info_cache,
- **kw)
+ return self.dialect.get_foreign_keys(
+ self.bind, table_name, schema, info_cache=self.info_cache, **kw
+ )
def get_indexes(self, table_name, schema=None, **kw):
"""Return information about indexes in `table_name`.
@@ -476,9 +483,9 @@ class Inspector(object):
"""
- return self.dialect.get_indexes(self.bind, table_name,
- schema,
- info_cache=self.info_cache, **kw)
+ return self.dialect.get_indexes(
+ self.bind, table_name, schema, info_cache=self.info_cache, **kw
+ )
def get_unique_constraints(self, table_name, schema=None, **kw):
"""Return information about unique constraints in `table_name`.
@@ -504,7 +511,8 @@ class Inspector(object):
"""
return self.dialect.get_unique_constraints(
- self.bind, table_name, schema, info_cache=self.info_cache, **kw)
+ self.bind, table_name, schema, info_cache=self.info_cache, **kw
+ )
def get_table_comment(self, table_name, schema=None, **kw):
"""Return information about the table comment for ``table_name``.
@@ -523,8 +531,8 @@ class Inspector(object):
"""
return self.dialect.get_table_comment(
- self.bind, table_name, schema, info_cache=self.info_cache,
- **kw)
+ self.bind, table_name, schema, info_cache=self.info_cache, **kw
+ )
def get_check_constraints(self, table_name, schema=None, **kw):
"""Return information about check constraints in `table_name`.
@@ -550,10 +558,12 @@ class Inspector(object):
"""
return self.dialect.get_check_constraints(
- self.bind, table_name, schema, info_cache=self.info_cache, **kw)
+ self.bind, table_name, schema, info_cache=self.info_cache, **kw
+ )
- def reflecttable(self, table, include_columns, exclude_columns=(),
- _extend_on=None):
+ def reflecttable(
+ self, table, include_columns, exclude_columns=(), _extend_on=None
+ ):
"""Given a Table object, load its internal constructs based on
introspection.
@@ -599,7 +609,8 @@ class Inspector(object):
# reflect table options, like mysql_engine
tbl_opts = self.get_table_options(
- table_name, schema, **table.dialect_kwargs)
+ table_name, schema, **table.dialect_kwargs
+ )
if tbl_opts:
# add additional kwargs to the Table if the dialect
# returned them
@@ -615,185 +626,251 @@ class Inspector(object):
cols_by_orig_name = {}
for col_d in self.get_columns(
- table_name, schema, **table.dialect_kwargs):
+ table_name, schema, **table.dialect_kwargs
+ ):
found_table = True
self._reflect_column(
- table, col_d, include_columns,
- exclude_columns, cols_by_orig_name)
+ table,
+ col_d,
+ include_columns,
+ exclude_columns,
+ cols_by_orig_name,
+ )
if not found_table:
raise exc.NoSuchTableError(table.name)
self._reflect_pk(
- table_name, schema, table, cols_by_orig_name, exclude_columns)
+ table_name, schema, table, cols_by_orig_name, exclude_columns
+ )
self._reflect_fk(
- table_name, schema, table, cols_by_orig_name,
- exclude_columns, _extend_on, reflection_options)
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ exclude_columns,
+ _extend_on,
+ reflection_options,
+ )
self._reflect_indexes(
- table_name, schema, table, cols_by_orig_name,
- include_columns, exclude_columns, reflection_options)
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ reflection_options,
+ )
self._reflect_unique_constraints(
- table_name, schema, table, cols_by_orig_name,
- include_columns, exclude_columns, reflection_options)
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ reflection_options,
+ )
self._reflect_check_constraints(
- table_name, schema, table, cols_by_orig_name,
- include_columns, exclude_columns, reflection_options)
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ reflection_options,
+ )
self._reflect_table_comment(
table_name, schema, table, reflection_options
)
def _reflect_column(
- self, table, col_d, include_columns,
- exclude_columns, cols_by_orig_name):
+ self, table, col_d, include_columns, exclude_columns, cols_by_orig_name
+ ):
- orig_name = col_d['name']
+ orig_name = col_d["name"]
table.dispatch.column_reflect(self, table, col_d)
# fetch name again as column_reflect is allowed to
# change it
- name = col_d['name']
- if (include_columns and name not in include_columns) \
- or (exclude_columns and name in exclude_columns):
+ name = col_d["name"]
+ if (include_columns and name not in include_columns) or (
+ exclude_columns and name in exclude_columns
+ ):
return
- coltype = col_d['type']
+ coltype = col_d["type"]
col_kw = dict(
(k, col_d[k])
for k in [
- 'nullable', 'autoincrement', 'quote', 'info', 'key',
- 'comment']
+ "nullable",
+ "autoincrement",
+ "quote",
+ "info",
+ "key",
+ "comment",
+ ]
if k in col_d
)
- if 'dialect_options' in col_d:
- col_kw.update(col_d['dialect_options'])
+ if "dialect_options" in col_d:
+ col_kw.update(col_d["dialect_options"])
colargs = []
- if col_d.get('default') is not None:
- default = col_d['default']
+ if col_d.get("default") is not None:
+ default = col_d["default"]
if isinstance(default, sql.elements.TextClause):
default = sa_schema.DefaultClause(default, _reflected=True)
elif not isinstance(default, sa_schema.FetchedValue):
default = sa_schema.DefaultClause(
- sql.text(col_d['default']), _reflected=True)
+ sql.text(col_d["default"]), _reflected=True
+ )
colargs.append(default)
- if 'sequence' in col_d:
+ if "sequence" in col_d:
self._reflect_col_sequence(col_d, colargs)
- cols_by_orig_name[orig_name] = col = \
- sa_schema.Column(name, coltype, *colargs, **col_kw)
+ cols_by_orig_name[orig_name] = col = sa_schema.Column(
+ name, coltype, *colargs, **col_kw
+ )
if col.key in table.primary_key:
col.primary_key = True
table.append_column(col)
def _reflect_col_sequence(self, col_d, colargs):
- if 'sequence' in col_d:
+ if "sequence" in col_d:
# TODO: mssql and sybase are using this.
- seq = col_d['sequence']
- sequence = sa_schema.Sequence(seq['name'], 1, 1)
- if 'start' in seq:
- sequence.start = seq['start']
- if 'increment' in seq:
- sequence.increment = seq['increment']
+ seq = col_d["sequence"]
+ sequence = sa_schema.Sequence(seq["name"], 1, 1)
+ if "start" in seq:
+ sequence.start = seq["start"]
+ if "increment" in seq:
+ sequence.increment = seq["increment"]
colargs.append(sequence)
def _reflect_pk(
- self, table_name, schema, table,
- cols_by_orig_name, exclude_columns):
+ self, table_name, schema, table, cols_by_orig_name, exclude_columns
+ ):
pk_cons = self.get_pk_constraint(
- table_name, schema, **table.dialect_kwargs)
+ table_name, schema, **table.dialect_kwargs
+ )
if pk_cons:
pk_cols = [
cols_by_orig_name[pk]
- for pk in pk_cons['constrained_columns']
+ for pk in pk_cons["constrained_columns"]
if pk in cols_by_orig_name and pk not in exclude_columns
]
# update pk constraint name
- table.primary_key.name = pk_cons.get('name')
+ table.primary_key.name = pk_cons.get("name")
# tell the PKConstraint to re-initialize
# its column collection
table.primary_key._reload(pk_cols)
def _reflect_fk(
- self, table_name, schema, table, cols_by_orig_name,
- exclude_columns, _extend_on, reflection_options):
+ self,
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ exclude_columns,
+ _extend_on,
+ reflection_options,
+ ):
fkeys = self.get_foreign_keys(
- table_name, schema, **table.dialect_kwargs)
+ table_name, schema, **table.dialect_kwargs
+ )
for fkey_d in fkeys:
- conname = fkey_d['name']
+ conname = fkey_d["name"]
# look for columns by orig name in cols_by_orig_name,
# but support columns that are in-Python only as fallback
constrained_columns = [
- cols_by_orig_name[c].key
- if c in cols_by_orig_name else c
- for c in fkey_d['constrained_columns']
+ cols_by_orig_name[c].key if c in cols_by_orig_name else c
+ for c in fkey_d["constrained_columns"]
]
if exclude_columns and set(constrained_columns).intersection(
- exclude_columns):
+ exclude_columns
+ ):
continue
- referred_schema = fkey_d['referred_schema']
- referred_table = fkey_d['referred_table']
- referred_columns = fkey_d['referred_columns']
+ referred_schema = fkey_d["referred_schema"]
+ referred_table = fkey_d["referred_table"]
+ referred_columns = fkey_d["referred_columns"]
refspec = []
if referred_schema is not None:
- sa_schema.Table(referred_table, table.metadata,
- autoload=True, schema=referred_schema,
- autoload_with=self.bind,
- _extend_on=_extend_on,
- **reflection_options
- )
+ sa_schema.Table(
+ referred_table,
+ table.metadata,
+ autoload=True,
+ schema=referred_schema,
+ autoload_with=self.bind,
+ _extend_on=_extend_on,
+ **reflection_options
+ )
for column in referred_columns:
- refspec.append(".".join(
- [referred_schema, referred_table, column]))
+ refspec.append(
+ ".".join([referred_schema, referred_table, column])
+ )
else:
- sa_schema.Table(referred_table, table.metadata, autoload=True,
- autoload_with=self.bind,
- schema=sa_schema.BLANK_SCHEMA,
- _extend_on=_extend_on,
- **reflection_options
- )
+ sa_schema.Table(
+ referred_table,
+ table.metadata,
+ autoload=True,
+ autoload_with=self.bind,
+ schema=sa_schema.BLANK_SCHEMA,
+ _extend_on=_extend_on,
+ **reflection_options
+ )
for column in referred_columns:
refspec.append(".".join([referred_table, column]))
- if 'options' in fkey_d:
- options = fkey_d['options']
+ if "options" in fkey_d:
+ options = fkey_d["options"]
else:
options = {}
table.append_constraint(
- sa_schema.ForeignKeyConstraint(constrained_columns, refspec,
- conname, link_to_name=True,
- **options))
+ sa_schema.ForeignKeyConstraint(
+ constrained_columns,
+ refspec,
+ conname,
+ link_to_name=True,
+ **options
+ )
+ )
def _reflect_indexes(
- self, table_name, schema, table, cols_by_orig_name,
- include_columns, exclude_columns, reflection_options):
+ self,
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ reflection_options,
+ ):
# Indexes
indexes = self.get_indexes(table_name, schema)
for index_d in indexes:
- name = index_d['name']
- columns = index_d['column_names']
- unique = index_d['unique']
- flavor = index_d.get('type', 'index')
- dialect_options = index_d.get('dialect_options', {})
-
- duplicates = index_d.get('duplicates_constraint')
- if include_columns and \
- not set(columns).issubset(include_columns):
+ name = index_d["name"]
+ columns = index_d["column_names"]
+ unique = index_d["unique"]
+ flavor = index_d.get("type", "index")
+ dialect_options = index_d.get("dialect_options", {})
+
+ duplicates = index_d.get("duplicates_constraint")
+ if include_columns and not set(columns).issubset(include_columns):
util.warn(
- "Omitting %s key for (%s), key covers omitted columns." %
- (flavor, ', '.join(columns)))
+ "Omitting %s key for (%s), key covers omitted columns."
+ % (flavor, ", ".join(columns))
+ )
continue
if duplicates:
continue
@@ -802,26 +879,36 @@ class Inspector(object):
idx_cols = []
for c in columns:
try:
- idx_col = cols_by_orig_name[c] \
- if c in cols_by_orig_name else table.c[c]
+ idx_col = (
+ cols_by_orig_name[c]
+ if c in cols_by_orig_name
+ else table.c[c]
+ )
except KeyError:
util.warn(
"%s key '%s' was not located in "
- "columns for table '%s'" % (
- flavor, c, table_name
- ))
+ "columns for table '%s'" % (flavor, c, table_name)
+ )
else:
idx_cols.append(idx_col)
sa_schema.Index(
- name, *idx_cols,
+ name,
+ *idx_cols,
_table=table,
- **dict(list(dialect_options.items()) + [('unique', unique)])
+ **dict(list(dialect_options.items()) + [("unique", unique)])
)
def _reflect_unique_constraints(
- self, table_name, schema, table, cols_by_orig_name,
- include_columns, exclude_columns, reflection_options):
+ self,
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ reflection_options,
+ ):
# Unique Constraints
try:
@@ -831,15 +918,14 @@ class Inspector(object):
return
for const_d in constraints:
- conname = const_d['name']
- columns = const_d['column_names']
- duplicates = const_d.get('duplicates_index')
- if include_columns and \
- not set(columns).issubset(include_columns):
+ conname = const_d["name"]
+ columns = const_d["column_names"]
+ duplicates = const_d.get("duplicates_index")
+ if include_columns and not set(columns).issubset(include_columns):
util.warn(
"Omitting unique constraint key for (%s), "
- "key covers omitted columns." %
- ', '.join(columns))
+ "key covers omitted columns." % ", ".join(columns)
+ )
continue
if duplicates:
continue
@@ -848,20 +934,32 @@ class Inspector(object):
constrained_cols = []
for c in columns:
try:
- constrained_col = cols_by_orig_name[c] \
- if c in cols_by_orig_name else table.c[c]
+ constrained_col = (
+ cols_by_orig_name[c]
+ if c in cols_by_orig_name
+ else table.c[c]
+ )
except KeyError:
util.warn(
"unique constraint key '%s' was not located in "
- "columns for table '%s'" % (c, table_name))
+ "columns for table '%s'" % (c, table_name)
+ )
else:
constrained_cols.append(constrained_col)
table.append_constraint(
- sa_schema.UniqueConstraint(*constrained_cols, name=conname))
+ sa_schema.UniqueConstraint(*constrained_cols, name=conname)
+ )
def _reflect_check_constraints(
- self, table_name, schema, table, cols_by_orig_name,
- include_columns, exclude_columns, reflection_options):
+ self,
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ reflection_options,
+ ):
try:
constraints = self.get_check_constraints(table_name, schema)
except NotImplementedError:
@@ -869,14 +967,14 @@ class Inspector(object):
return
for const_d in constraints:
- table.append_constraint(
- sa_schema.CheckConstraint(**const_d))
+ table.append_constraint(sa_schema.CheckConstraint(**const_d))
def _reflect_table_comment(
- self, table_name, schema, table, reflection_options):
+ self, table_name, schema, table, reflection_options
+ ):
try:
comment_dict = self.get_table_comment(table_name, schema)
except NotImplementedError:
return
else:
- table.comment = comment_dict.get('text', None)
+ table.comment = comment_dict.get("text", None)
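Aside, not part of the diff: a minimal sketch of the Inspector methods touched above (from_engine, get_table_names with order_by="foreign_key", get_columns, get_pk_constraint), assuming a throwaway in-memory SQLite database:

    from sqlalchemy import (
        Column, ForeignKey, Integer, MetaData, Table, create_engine,
    )
    from sqlalchemy.engine import reflection

    engine = create_engine("sqlite://")
    metadata = MetaData()
    Table("parent", metadata, Column("id", Integer, primary_key=True))
    Table(
        "child",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("parent_id", Integer, ForeignKey("parent.id")),
    )
    metadata.create_all(engine)

    insp = reflection.Inspector.from_engine(engine)
    # topological sort shown above: referred tables sort before their dependents
    print(insp.get_table_names(order_by="foreign_key"))   # ['parent', 'child']
    print(insp.get_columns("child"))                      # list of column dicts
    print(insp.get_pk_constraint("child"))                # {'constrained_columns': ['id'], ...}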
diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py
index d4c862375..5ad0d2909 100644
--- a/lib/sqlalchemy/engine/result.py
+++ b/lib/sqlalchemy/engine/result.py
@@ -27,20 +27,25 @@ try:
# the extension is present.
def rowproxy_reconstructor(cls, state):
return safe_rowproxy_reconstructor(cls, state)
+
+
except ImportError:
+
def rowproxy_reconstructor(cls, state):
obj = cls.__new__(cls)
obj.__setstate__(state)
return obj
+
try:
from sqlalchemy.cresultproxy import BaseRowProxy
+
_baserowproxy_usecext = True
except ImportError:
_baserowproxy_usecext = False
class BaseRowProxy(object):
- __slots__ = ('_parent', '_row', '_processors', '_keymap')
+ __slots__ = ("_parent", "_row", "_processors", "_keymap")
def __init__(self, parent, row, processors, keymap):
"""RowProxy objects are constructed by ResultProxy objects."""
@@ -51,8 +56,10 @@ except ImportError:
self._keymap = keymap
def __reduce__(self):
- return (rowproxy_reconstructor,
- (self.__class__, self.__getstate__()))
+ return (
+ rowproxy_reconstructor,
+ (self.__class__, self.__getstate__()),
+ )
def values(self):
"""Return the values represented by this RowProxy as a list."""
@@ -76,8 +83,9 @@ except ImportError:
except TypeError:
if isinstance(key, slice):
l = []
- for processor, value in zip(self._processors[key],
- self._row[key]):
+ for processor, value in zip(
+ self._processors[key], self._row[key]
+ ):
if processor is None:
l.append(value)
else:
@@ -88,7 +96,8 @@ except ImportError:
if index is None:
raise exc.InvalidRequestError(
"Ambiguous column name '%s' in "
- "result set column descriptions" % obj)
+ "result set column descriptions" % obj
+ )
if processor is not None:
return processor(self._row[index])
else:
@@ -110,29 +119,29 @@ class RowProxy(BaseRowProxy):
mapped to the original Columns that produced this result set (for
results that correspond to constructed SQL expressions).
"""
+
__slots__ = ()
def __contains__(self, key):
return self._parent._has_key(key)
def __getstate__(self):
- return {
- '_parent': self._parent,
- '_row': tuple(self)
- }
+ return {"_parent": self._parent, "_row": tuple(self)}
def __setstate__(self, state):
- self._parent = parent = state['_parent']
- self._row = state['_row']
+ self._parent = parent = state["_parent"]
+ self._row = state["_row"]
self._processors = parent._processors
self._keymap = parent._keymap
__hash__ = None
def _op(self, other, op):
- return op(tuple(self), tuple(other)) \
- if isinstance(other, RowProxy) \
+ return (
+ op(tuple(self), tuple(other))
+ if isinstance(other, RowProxy)
else op(tuple(self), other)
+ )
def __lt__(self, other):
return self._op(other, operator.lt)
@@ -176,6 +185,7 @@ class RowProxy(BaseRowProxy):
def itervalues(self):
return iter(self)
+
try:
# Register RowProxy with Sequence,
# so sequence protocol is implemented
@@ -189,8 +199,13 @@ class ResultMetaData(object):
context."""
__slots__ = (
- '_keymap', 'case_sensitive', 'matched_on_name',
- '_processors', 'keys', '_orig_processors')
+ "_keymap",
+ "case_sensitive",
+ "matched_on_name",
+ "_processors",
+ "keys",
+ "_orig_processors",
+ )
def __init__(self, parent, cursor_description):
context = parent.context
@@ -200,18 +215,25 @@ class ResultMetaData(object):
self._orig_processors = None
if context.result_column_struct:
- result_columns, cols_are_ordered, textual_ordered = \
+ result_columns, cols_are_ordered, textual_ordered = (
context.result_column_struct
+ )
num_ctx_cols = len(result_columns)
else:
- result_columns = cols_are_ordered = \
- num_ctx_cols = textual_ordered = False
+ result_columns = (
+ cols_are_ordered
+ ) = num_ctx_cols = textual_ordered = False
# merge cursor.description with the column info
# present in the compiled structure, if any
raw = self._merge_cursor_description(
- context, cursor_description, result_columns,
- num_ctx_cols, cols_are_ordered, textual_ordered)
+ context,
+ cursor_description,
+ result_columns,
+ num_ctx_cols,
+ cols_are_ordered,
+ textual_ordered,
+ )
self._keymap = {}
if not _baserowproxy_usecext:
@@ -223,23 +245,20 @@ class ResultMetaData(object):
len_raw = len(raw)
- self._keymap.update([
- (elem[0], (elem[3], elem[4], elem[0]))
- for elem in raw
- ] + [
- (elem[0] - len_raw, (elem[3], elem[4], elem[0]))
- for elem in raw
- ])
+ self._keymap.update(
+ [(elem[0], (elem[3], elem[4], elem[0])) for elem in raw]
+ + [
+ (elem[0] - len_raw, (elem[3], elem[4], elem[0]))
+ for elem in raw
+ ]
+ )
# processors in key order for certain per-row
# views like __iter__ and slices
self._processors = [elem[3] for elem in raw]
# keymap by primary string...
- by_key = dict([
- (elem[2], (elem[3], elem[4], elem[0]))
- for elem in raw
- ])
+ by_key = dict([(elem[2], (elem[3], elem[4], elem[0])) for elem in raw])
# for compiled SQL constructs, copy additional lookup keys into
# the key lookup map, such as Column objects, labels,
@@ -264,29 +283,38 @@ class ResultMetaData(object):
# copy secondary elements from compiled columns
# into self._keymap, write in the potentially "ambiguous"
# element
- self._keymap.update([
- (obj_elem, by_key[elem[2]])
- for elem in raw if elem[4]
- for obj_elem in elem[4]
- ])
+ self._keymap.update(
+ [
+ (obj_elem, by_key[elem[2]])
+ for elem in raw
+ if elem[4]
+ for obj_elem in elem[4]
+ ]
+ )
# if we did a pure positional match, then reset the
# original "expression element" back to the "unambiguous"
# entry. This is a new behavior in 1.1 which impacts
# TextAsFrom but also straight compiled SQL constructs.
if not self.matched_on_name:
- self._keymap.update([
- (elem[4][0], (elem[3], elem[4], elem[0]))
- for elem in raw if elem[4]
- ])
+ self._keymap.update(
+ [
+ (elem[4][0], (elem[3], elem[4], elem[0]))
+ for elem in raw
+ if elem[4]
+ ]
+ )
else:
# no dupes - copy secondary elements from compiled
# columns into self._keymap
- self._keymap.update([
- (obj_elem, (elem[3], elem[4], elem[0]))
- for elem in raw if elem[4]
- for obj_elem in elem[4]
- ])
+ self._keymap.update(
+ [
+ (obj_elem, (elem[3], elem[4], elem[0]))
+ for elem in raw
+ if elem[4]
+ for obj_elem in elem[4]
+ ]
+ )
# update keymap with primary string names taking
# precedence
@@ -294,14 +322,19 @@ class ResultMetaData(object):
# update keymap with "translated" names (sqlite-only thing)
if not num_ctx_cols and context._translate_colname:
- self._keymap.update([
- (elem[5], self._keymap[elem[2]])
- for elem in raw if elem[5]
- ])
+ self._keymap.update(
+ [(elem[5], self._keymap[elem[2]]) for elem in raw if elem[5]]
+ )
def _merge_cursor_description(
- self, context, cursor_description, result_columns,
- num_ctx_cols, cols_are_ordered, textual_ordered):
+ self,
+ context,
+ cursor_description,
+ result_columns,
+ num_ctx_cols,
+ cols_are_ordered,
+ textual_ordered,
+ ):
"""Merge a cursor.description with compiled result column information.
There are at least four separate strategies used here, selected
@@ -357,10 +390,12 @@ class ResultMetaData(object):
case_sensitive = context.dialect.case_sensitive
- if num_ctx_cols and \
- cols_are_ordered and \
- not textual_ordered and \
- num_ctx_cols == len(cursor_description):
+ if (
+ num_ctx_cols
+ and cols_are_ordered
+ and not textual_ordered
+ and num_ctx_cols == len(cursor_description)
+ ):
self.keys = [elem[0] for elem in result_columns]
# pure positional 1-1 case; doesn't need to read
# the names from cursor.description
@@ -373,9 +408,9 @@ class ResultMetaData(object):
type_, key, cursor_description[idx][1]
),
obj,
- None
- ) for idx, (key, name, obj, type_)
- in enumerate(result_columns)
+ None,
+ )
+ for idx, (key, name, obj, type_) in enumerate(result_columns)
]
else:
# name-based or text-positional cases, where we need
@@ -383,26 +418,32 @@ class ResultMetaData(object):
if textual_ordered:
# textual positional case
raw_iterator = self._merge_textual_cols_by_position(
- context, cursor_description, result_columns)
+ context, cursor_description, result_columns
+ )
elif num_ctx_cols:
# compiled SQL with a mismatch of description cols
# vs. compiled cols, or textual w/ unordered columns
raw_iterator = self._merge_cols_by_name(
- context, cursor_description, result_columns)
+ context, cursor_description, result_columns
+ )
else:
# no compiled SQL, just a raw string
raw_iterator = self._merge_cols_by_none(
- context, cursor_description)
+ context, cursor_description
+ )
return [
(
- idx, colname, colname,
+ idx,
+ colname,
+ colname,
context.get_result_processor(
- mapped_type, colname, coltype),
- obj, untranslated)
-
- for idx, colname, mapped_type, coltype, obj, untranslated
- in raw_iterator
+ mapped_type, colname, coltype
+ ),
+ obj,
+ untranslated,
+ )
+ for idx, colname, mapped_type, coltype, obj, untranslated in raw_iterator
]
def _colnames_from_description(self, context, cursor_description):
@@ -416,10 +457,14 @@ class ResultMetaData(object):
dialect = context.dialect
case_sensitive = dialect.case_sensitive
translate_colname = context._translate_colname
- description_decoder = dialect._description_decoder \
- if dialect.description_encoding else None
- normalize_name = dialect.normalize_name \
- if dialect.requires_name_normalize else None
+ description_decoder = (
+ dialect._description_decoder
+ if dialect.description_encoding
+ else None
+ )
+ normalize_name = (
+ dialect.normalize_name if dialect.requires_name_normalize else None
+ )
untranslated = None
self.keys = []
@@ -444,20 +489,25 @@ class ResultMetaData(object):
yield idx, colname, untranslated, coltype
def _merge_textual_cols_by_position(
- self, context, cursor_description, result_columns):
+ self, context, cursor_description, result_columns
+ ):
dialect = context.dialect
num_ctx_cols = len(result_columns) if result_columns else None
if num_ctx_cols > len(cursor_description):
util.warn(
"Number of columns in textual SQL (%d) is "
- "smaller than number of columns requested (%d)" % (
- num_ctx_cols, len(cursor_description)
- ))
+ "smaller than number of columns requested (%d)"
+ % (num_ctx_cols, len(cursor_description))
+ )
seen = set()
- for idx, colname, untranslated, coltype in \
- self._colnames_from_description(context, cursor_description):
+ for (
+ idx,
+ colname,
+ untranslated,
+ coltype,
+ ) in self._colnames_from_description(context, cursor_description):
if idx < num_ctx_cols:
ctx_rec = result_columns[idx]
obj = ctx_rec[2]
@@ -465,7 +515,8 @@ class ResultMetaData(object):
if obj[0] in seen:
raise exc.InvalidRequestError(
"Duplicate column expression requested "
- "in textual SQL: %r" % obj[0])
+ "in textual SQL: %r" % obj[0]
+ )
seen.add(obj[0])
else:
mapped_type = sqltypes.NULLTYPE
@@ -479,8 +530,12 @@ class ResultMetaData(object):
result_map = self._create_result_map(result_columns, case_sensitive)
self.matched_on_name = True
- for idx, colname, untranslated, coltype in \
- self._colnames_from_description(context, cursor_description):
+ for (
+ idx,
+ colname,
+ untranslated,
+ coltype,
+ ) in self._colnames_from_description(context, cursor_description):
try:
ctx_rec = result_map[colname]
except KeyError:
@@ -493,8 +548,12 @@ class ResultMetaData(object):
def _merge_cols_by_none(self, context, cursor_description):
dialect = context.dialect
- for idx, colname, untranslated, coltype in \
- self._colnames_from_description(context, cursor_description):
+ for (
+ idx,
+ colname,
+ untranslated,
+ coltype,
+ ) in self._colnames_from_description(context, cursor_description):
yield idx, colname, sqltypes.NULLTYPE, coltype, None, untranslated
@classmethod
@@ -525,27 +584,28 @@ class ResultMetaData(object):
        # or column('name') constructs to ColumnElements, or after a
# pickle/unpickle roundtrip
elif isinstance(key, expression.ColumnElement):
- if key._label and (
- key._label
- if self.case_sensitive
- else key._label.lower()) in map:
- result = map[key._label
- if self.case_sensitive
- else key._label.lower()]
- elif hasattr(key, 'name') and (
- key.name
- if self.case_sensitive
- else key.name.lower()) in map:
+ if (
+ key._label
+ and (key._label if self.case_sensitive else key._label.lower())
+ in map
+ ):
+ result = map[
+ key._label if self.case_sensitive else key._label.lower()
+ ]
+ elif (
+ hasattr(key, "name")
+ and (key.name if self.case_sensitive else key.name.lower())
+ in map
+ ):
# match is only on name.
- result = map[key.name
- if self.case_sensitive
- else key.name.lower()]
+ result = map[
+ key.name if self.case_sensitive else key.name.lower()
+ ]
# search extra hard to make sure this
# isn't a column/label name overlap.
# this check isn't currently available if the row
# was unpickled.
- if result is not None and \
- result[1] is not None:
+ if result is not None and result[1] is not None:
for obj in result[1]:
if key._compare_name_for_result(obj):
break
@@ -554,8 +614,9 @@ class ResultMetaData(object):
if result is None:
if raiseerr:
raise exc.NoSuchColumnError(
- "Could not locate column in row for column '%s'" %
- expression._string_or_unprintable(key))
+ "Could not locate column in row for column '%s'"
+ % expression._string_or_unprintable(key)
+ )
else:
return None
else:
@@ -580,34 +641,35 @@ class ResultMetaData(object):
if index is None:
raise exc.InvalidRequestError(
"Ambiguous column name '%s' in "
- "result set column descriptions" % obj)
+ "result set column descriptions" % obj
+ )
return operator.itemgetter(index)
def __getstate__(self):
return {
- '_pickled_keymap': dict(
+ "_pickled_keymap": dict(
(key, index)
for key, (processor, obj, index) in self._keymap.items()
if isinstance(key, util.string_types + util.int_types)
),
- 'keys': self.keys,
+ "keys": self.keys,
"case_sensitive": self.case_sensitive,
- "matched_on_name": self.matched_on_name
+ "matched_on_name": self.matched_on_name,
}
def __setstate__(self, state):
# the row has been processed at pickling time so we don't need any
# processor anymore
- self._processors = [None for _ in range(len(state['keys']))]
+ self._processors = [None for _ in range(len(state["keys"]))]
self._keymap = keymap = {}
- for key, index in state['_pickled_keymap'].items():
+ for key, index in state["_pickled_keymap"].items():
# not preserving "obj" here, unfortunately our
# proxy comparison fails with the unpickle
keymap[key] = (None, None, index)
- self.keys = state['keys']
- self.case_sensitive = state['case_sensitive']
- self.matched_on_name = state['matched_on_name']
+ self.keys = state["keys"]
+ self.case_sensitive = state["case_sensitive"]
+ self.matched_on_name = state["matched_on_name"]
class ResultProxy(object):
@@ -643,8 +705,9 @@ class ResultProxy(object):
self.dialect = context.dialect
self.cursor = self._saved_cursor = context.cursor
self.connection = context.root_connection
- self._echo = self.connection._echo and \
- context.engine._should_log_debug()
+ self._echo = (
+ self.connection._echo and context.engine._should_log_debug()
+ )
self._init_metadata()
def _getter(self, key, raiseerr=True):
@@ -666,18 +729,22 @@ class ResultProxy(object):
def _init_metadata(self):
cursor_description = self._cursor_description()
if cursor_description is not None:
- if self.context.compiled and \
- 'compiled_cache' in self.context.execution_options:
+ if (
+ self.context.compiled
+ and "compiled_cache" in self.context.execution_options
+ ):
if self.context.compiled._cached_metadata:
self._metadata = self.context.compiled._cached_metadata
else:
- self._metadata = self.context.compiled._cached_metadata = \
- ResultMetaData(self, cursor_description)
+ self._metadata = (
+ self.context.compiled._cached_metadata
+ ) = ResultMetaData(self, cursor_description)
else:
self._metadata = ResultMetaData(self, cursor_description)
if self._echo:
self.context.engine.logger.debug(
- "Col %r", tuple(x[0] for x in cursor_description))
+ "Col %r", tuple(x[0] for x in cursor_description)
+ )
def keys(self):
"""Return the current set of string keys for rows."""
@@ -731,7 +798,8 @@ class ResultProxy(object):
return self.context.rowcount
except BaseException as e:
self.connection._handle_dbapi_exception(
- e, None, None, self.cursor, self.context)
+ e, None, None, self.cursor, self.context
+ )
@property
def lastrowid(self):
@@ -753,8 +821,8 @@ class ResultProxy(object):
return self._saved_cursor.lastrowid
except BaseException as e:
self.connection._handle_dbapi_exception(
- e, None, None,
- self._saved_cursor, self.context)
+ e, None, None, self._saved_cursor, self.context
+ )
@property
def returns_rows(self):
@@ -913,17 +981,18 @@ class ResultProxy(object):
if not self.context.compiled:
raise exc.InvalidRequestError(
- "Statement is not a compiled "
- "expression construct.")
+ "Statement is not a compiled " "expression construct."
+ )
elif not self.context.isinsert:
raise exc.InvalidRequestError(
- "Statement is not an insert() "
- "expression construct.")
+ "Statement is not an insert() " "expression construct."
+ )
elif self.context._is_explicit_returning:
raise exc.InvalidRequestError(
"Can't call inserted_primary_key "
"when returning() "
- "is used.")
+ "is used."
+ )
return self.context.inserted_primary_key
@@ -938,12 +1007,12 @@ class ResultProxy(object):
"""
if not self.context.compiled:
raise exc.InvalidRequestError(
- "Statement is not a compiled "
- "expression construct.")
+ "Statement is not a compiled " "expression construct."
+ )
elif not self.context.isupdate:
raise exc.InvalidRequestError(
- "Statement is not an update() "
- "expression construct.")
+ "Statement is not an update() " "expression construct."
+ )
elif self.context.executemany:
return self.context.compiled_parameters
else:
@@ -960,12 +1029,12 @@ class ResultProxy(object):
"""
if not self.context.compiled:
raise exc.InvalidRequestError(
- "Statement is not a compiled "
- "expression construct.")
+ "Statement is not a compiled " "expression construct."
+ )
elif not self.context.isinsert:
raise exc.InvalidRequestError(
- "Statement is not an insert() "
- "expression construct.")
+ "Statement is not an insert() " "expression construct."
+ )
elif self.context.executemany:
return self.context.compiled_parameters
else:
@@ -1013,12 +1082,13 @@ class ResultProxy(object):
if not self.context.compiled:
raise exc.InvalidRequestError(
- "Statement is not a compiled "
- "expression construct.")
+ "Statement is not a compiled " "expression construct."
+ )
elif not self.context.isinsert and not self.context.isupdate:
raise exc.InvalidRequestError(
"Statement is not an insert() or update() "
- "expression construct.")
+ "expression construct."
+ )
return self.context.postfetch_cols
def prefetch_cols(self):
@@ -1035,12 +1105,13 @@ class ResultProxy(object):
if not self.context.compiled:
raise exc.InvalidRequestError(
- "Statement is not a compiled "
- "expression construct.")
+ "Statement is not a compiled " "expression construct."
+ )
elif not self.context.isinsert and not self.context.isupdate:
raise exc.InvalidRequestError(
"Statement is not an insert() or update() "
- "expression construct.")
+ "expression construct."
+ )
return self.context.prefetch_cols
def supports_sane_rowcount(self):
@@ -1086,7 +1157,7 @@ class ResultProxy(object):
if self._metadata is None:
raise exc.ResourceClosedError(
"This result object does not return rows. "
- "It has been closed automatically.",
+ "It has been closed automatically."
)
elif self.closed:
raise exc.ResourceClosedError("This result object is closed.")
@@ -1106,8 +1177,9 @@ class ResultProxy(object):
l.append(process_row(metadata, row, processors, keymap))
return l
else:
- return [process_row(metadata, row, processors, keymap)
- for row in rows]
+ return [
+ process_row(metadata, row, processors, keymap) for row in rows
+ ]
def fetchall(self):
"""Fetch all rows, just like DB-API ``cursor.fetchall()``.
@@ -1132,8 +1204,8 @@ class ResultProxy(object):
return l
except BaseException as e:
self.connection._handle_dbapi_exception(
- e, None, None,
- self.cursor, self.context)
+ e, None, None, self.cursor, self.context
+ )
def fetchmany(self, size=None):
"""Fetch many rows, just like DB-API
@@ -1161,8 +1233,8 @@ class ResultProxy(object):
return l
except BaseException as e:
self.connection._handle_dbapi_exception(
- e, None, None,
- self.cursor, self.context)
+ e, None, None, self.cursor, self.context
+ )
def fetchone(self):
"""Fetch one row, just like DB-API ``cursor.fetchone()``.
@@ -1190,8 +1262,8 @@ class ResultProxy(object):
return None
except BaseException as e:
self.connection._handle_dbapi_exception(
- e, None, None,
- self.cursor, self.context)
+ e, None, None, self.cursor, self.context
+ )
def first(self):
"""Fetch the first row and then close the result set unconditionally.
@@ -1209,8 +1281,8 @@ class ResultProxy(object):
row = self._fetchone_impl()
except BaseException as e:
self.connection._handle_dbapi_exception(
- e, None, None,
- self.cursor, self.context)
+ e, None, None, self.cursor, self.context
+ )
try:
if row is not None:
@@ -1268,7 +1340,8 @@ class BufferedRowResultProxy(ResultProxy):
def _init_metadata(self):
self._max_row_buffer = self.context.execution_options.get(
- 'max_row_buffer', None)
+ "max_row_buffer", None
+ )
self.__buffer_rows()
super(BufferedRowResultProxy, self)._init_metadata()
@@ -1284,13 +1357,13 @@ class BufferedRowResultProxy(ResultProxy):
50: 100,
100: 250,
250: 500,
- 500: 1000
+ 500: 1000,
}
def __buffer_rows(self):
if self.cursor is None:
return
- size = getattr(self, '_bufsize', 1)
+ size = getattr(self, "_bufsize", 1)
self.__rowbuffer = collections.deque(self.cursor.fetchmany(size))
self._bufsize = self.size_growth.get(size, size)
if self._max_row_buffer is not None:
@@ -1385,8 +1458,9 @@ class BufferedColumnRow(RowProxy):
row[index] = processor(row[index])
index += 1
row = tuple(row)
- super(BufferedColumnRow, self).__init__(parent, row,
- processors, keymap)
+ super(BufferedColumnRow, self).__init__(
+ parent, row, processors, keymap
+ )
class BufferedColumnResultProxy(ResultProxy):
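Aside, not part of the diff: the _keymap built above serves three lookup styles on a RowProxy. A minimal sketch, assuming an in-memory SQLite engine:

    from sqlalchemy import (
        Column, Integer, MetaData, String, Table, create_engine, select,
    )

    engine = create_engine("sqlite://")
    metadata = MetaData()
    users = Table(
        "users",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
    )
    metadata.create_all(engine)
    engine.execute(users.insert(), {"name": "spongebob"})

    row = engine.execute(select([users])).first()
    # integer index, string key, and Column object all resolve through _keymap
    print(row[0], row["name"], row[users.c.name])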
diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py
index d4f5185de..4aecb9537 100644
--- a/lib/sqlalchemy/engine/strategies.py
+++ b/lib/sqlalchemy/engine/strategies.py
@@ -51,18 +51,20 @@ class DefaultEngineStrategy(EngineStrategy):
plugins = u._instantiate_plugins(kwargs)
- u.query.pop('plugin', None)
- kwargs.pop('plugins', None)
+ u.query.pop("plugin", None)
+ kwargs.pop("plugins", None)
entrypoint = u._get_entrypoint()
dialect_cls = entrypoint.get_dialect_cls(u)
- if kwargs.pop('_coerce_config', False):
+ if kwargs.pop("_coerce_config", False):
+
def pop_kwarg(key, default=None):
value = kwargs.pop(key, default)
if key in dialect_cls.engine_config_types:
value = dialect_cls.engine_config_types[key](value)
return value
+
else:
pop_kwarg = kwargs.pop
@@ -72,7 +74,7 @@ class DefaultEngineStrategy(EngineStrategy):
if k in kwargs:
dialect_args[k] = pop_kwarg(k)
- dbapi = kwargs.pop('module', None)
+ dbapi = kwargs.pop("module", None)
if dbapi is None:
dbapi_args = {}
for k in util.get_func_kwargs(dialect_cls.dbapi):
@@ -80,7 +82,7 @@ class DefaultEngineStrategy(EngineStrategy):
dbapi_args[k] = pop_kwarg(k)
dbapi = dialect_cls.dbapi(**dbapi_args)
- dialect_args['dbapi'] = dbapi
+ dialect_args["dbapi"] = dbapi
for plugin in plugins:
plugin.handle_dialect_kwargs(dialect_cls, dialect_args)
@@ -90,41 +92,43 @@ class DefaultEngineStrategy(EngineStrategy):
# assemble connection arguments
(cargs, cparams) = dialect.create_connect_args(u)
- cparams.update(pop_kwarg('connect_args', {}))
+ cparams.update(pop_kwarg("connect_args", {}))
cargs = list(cargs) # allow mutability
# look for existing pool or create
- pool = pop_kwarg('pool', None)
+ pool = pop_kwarg("pool", None)
if pool is None:
+
def connect(connection_record=None):
if dialect._has_events:
for fn in dialect.dispatch.do_connect:
connection = fn(
- dialect, connection_record, cargs, cparams)
+ dialect, connection_record, cargs, cparams
+ )
if connection is not None:
return connection
return dialect.connect(*cargs, **cparams)
- creator = pop_kwarg('creator', connect)
+ creator = pop_kwarg("creator", connect)
- poolclass = pop_kwarg('poolclass', None)
+ poolclass = pop_kwarg("poolclass", None)
if poolclass is None:
poolclass = dialect_cls.get_pool_class(u)
- pool_args = {
- 'dialect': dialect
- }
+ pool_args = {"dialect": dialect}
# consume pool arguments from kwargs, translating a few of
# the arguments
- translate = {'logging_name': 'pool_logging_name',
- 'echo': 'echo_pool',
- 'timeout': 'pool_timeout',
- 'recycle': 'pool_recycle',
- 'events': 'pool_events',
- 'use_threadlocal': 'pool_threadlocal',
- 'reset_on_return': 'pool_reset_on_return',
- 'pre_ping': 'pool_pre_ping',
- 'use_lifo': 'pool_use_lifo'}
+ translate = {
+ "logging_name": "pool_logging_name",
+ "echo": "echo_pool",
+ "timeout": "pool_timeout",
+ "recycle": "pool_recycle",
+ "events": "pool_events",
+ "use_threadlocal": "pool_threadlocal",
+ "reset_on_return": "pool_reset_on_return",
+ "pre_ping": "pool_pre_ping",
+ "use_lifo": "pool_use_lifo",
+ }
for k in util.get_cls_kwargs(poolclass):
tk = translate.get(k, k)
if tk in kwargs:
@@ -149,7 +153,7 @@ class DefaultEngineStrategy(EngineStrategy):
if k in kwargs:
engine_args[k] = pop_kwarg(k)
- _initialize = kwargs.pop('_initialize', True)
+ _initialize = kwargs.pop("_initialize", True)
# all kwargs should be consumed
if kwargs:
@@ -157,32 +161,40 @@ class DefaultEngineStrategy(EngineStrategy):
"Invalid argument(s) %s sent to create_engine(), "
"using configuration %s/%s/%s. Please check that the "
"keyword arguments are appropriate for this combination "
- "of components." % (','.join("'%s'" % k for k in kwargs),
- dialect.__class__.__name__,
- pool.__class__.__name__,
- engineclass.__name__))
+ "of components."
+ % (
+ ",".join("'%s'" % k for k in kwargs),
+ dialect.__class__.__name__,
+ pool.__class__.__name__,
+ engineclass.__name__,
+ )
+ )
engine = engineclass(pool, dialect, u, **engine_args)
if _initialize:
do_on_connect = dialect.on_connect()
if do_on_connect:
+
def on_connect(dbapi_connection, connection_record):
conn = getattr(
- dbapi_connection, '_sqla_unwrap', dbapi_connection)
+ dbapi_connection, "_sqla_unwrap", dbapi_connection
+ )
if conn is None:
return
do_on_connect(conn)
- event.listen(pool, 'first_connect', on_connect)
- event.listen(pool, 'connect', on_connect)
+ event.listen(pool, "first_connect", on_connect)
+ event.listen(pool, "connect", on_connect)
def first_connect(dbapi_connection, connection_record):
- c = base.Connection(engine, connection=dbapi_connection,
- _has_events=False)
+ c = base.Connection(
+ engine, connection=dbapi_connection, _has_events=False
+ )
c._execution_options = util.immutabledict()
dialect.initialize(c)
- event.listen(pool, 'first_connect', first_connect, once=True)
+
+ event.listen(pool, "first_connect", first_connect, once=True)
dialect_cls.engine_created(engine)
if entrypoint is not dialect_cls:
@@ -197,18 +209,20 @@ class DefaultEngineStrategy(EngineStrategy):
class PlainEngineStrategy(DefaultEngineStrategy):
"""Strategy for configuring a regular Engine."""
- name = 'plain'
+ name = "plain"
engine_cls = base.Engine
+
PlainEngineStrategy()
class ThreadLocalEngineStrategy(DefaultEngineStrategy):
"""Strategy for configuring an Engine with threadlocal behavior."""
- name = 'threadlocal'
+ name = "threadlocal"
engine_cls = threadlocal.TLEngine
+
ThreadLocalEngineStrategy()
@@ -220,7 +234,7 @@ class MockEngineStrategy(EngineStrategy):
"""
- name = 'mock'
+ name = "mock"
def create(self, name_or_url, executor, **kwargs):
# create url.URL object
@@ -245,7 +259,7 @@ class MockEngineStrategy(EngineStrategy):
self.execute = execute
engine = property(lambda s: s)
- dialect = property(attrgetter('_dialect'))
+ dialect = property(attrgetter("_dialect"))
name = property(lambda s: s._dialect.name)
schema_for_object = schema._schema_getter(None)
@@ -258,29 +272,35 @@ class MockEngineStrategy(EngineStrategy):
def compiler(self, statement, parameters, **kwargs):
return self._dialect.compiler(
- statement, parameters, engine=self, **kwargs)
+ statement, parameters, engine=self, **kwargs
+ )
def create(self, entity, **kwargs):
- kwargs['checkfirst'] = False
+ kwargs["checkfirst"] = False
from sqlalchemy.engine import ddl
- ddl.SchemaGenerator(
- self.dialect, self, **kwargs).traverse_single(entity)
+ ddl.SchemaGenerator(self.dialect, self, **kwargs).traverse_single(
+ entity
+ )
def drop(self, entity, **kwargs):
- kwargs['checkfirst'] = False
+ kwargs["checkfirst"] = False
from sqlalchemy.engine import ddl
- ddl.SchemaDropper(
- self.dialect, self, **kwargs).traverse_single(entity)
- def _run_visitor(self, visitorcallable, element,
- connection=None,
- **kwargs):
- kwargs['checkfirst'] = False
- visitorcallable(self.dialect, self,
- **kwargs).traverse_single(element)
+ ddl.SchemaDropper(self.dialect, self, **kwargs).traverse_single(
+ entity
+ )
+
+ def _run_visitor(
+ self, visitorcallable, element, connection=None, **kwargs
+ ):
+ kwargs["checkfirst"] = False
+ visitorcallable(self.dialect, self, **kwargs).traverse_single(
+ element
+ )
def execute(self, object, *multiparams, **params):
raise NotImplementedError()
+
MockEngineStrategy()
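Aside, not part of the diff: the "mock" strategy reformatted above is typically used to dump DDL as strings rather than execute it. A minimal sketch; the executor name "dump" is illustrative:

    from sqlalchemy import Column, Integer, MetaData, Table, create_engine

    metadata = MetaData()
    Table("t", metadata, Column("id", Integer, primary_key=True))

    def dump(sql, *multiparams, **params):
        # the MockConnection hands each DDL construct to this executor
        print(sql.compile(dialect=engine.dialect))

    engine = create_engine("postgresql://", strategy="mock", executor=dump)
    metadata.create_all(engine)   # prints CREATE TABLE t (...) instead of executing it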
diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py
index 0ec1f9613..5b2bdabc0 100644
--- a/lib/sqlalchemy/engine/threadlocal.py
+++ b/lib/sqlalchemy/engine/threadlocal.py
@@ -19,7 +19,6 @@ import weakref
class TLConnection(base.Connection):
-
def __init__(self, *arg, **kw):
super(TLConnection, self).__init__(*arg, **kw)
self.__opencount = 0
@@ -43,6 +42,7 @@ class TLEngine(base.Engine):
transactions.
"""
+
_tl_connection_cls = TLConnection
def __init__(self, *args, **kwargs):
@@ -50,7 +50,7 @@ class TLEngine(base.Engine):
self._connections = util.threading.local()
def contextual_connect(self, **kw):
- if not hasattr(self._connections, 'conn'):
+ if not hasattr(self._connections, "conn"):
connection = None
else:
connection = self._connections.conn()
@@ -60,29 +60,31 @@ class TLEngine(base.Engine):
# or not connection.connection.is_valid:
connection = self._tl_connection_cls(
self,
- self._wrap_pool_connect(
- self.pool.connect, connection),
- **kw)
+ self._wrap_pool_connect(self.pool.connect, connection),
+ **kw
+ )
self._connections.conn = weakref.ref(connection)
return connection._increment_connect()
def begin_twophase(self, xid=None):
- if not hasattr(self._connections, 'trans'):
+ if not hasattr(self._connections, "trans"):
self._connections.trans = []
self._connections.trans.append(
- self.contextual_connect().begin_twophase(xid=xid))
+ self.contextual_connect().begin_twophase(xid=xid)
+ )
return self
def begin_nested(self):
- if not hasattr(self._connections, 'trans'):
+ if not hasattr(self._connections, "trans"):
self._connections.trans = []
self._connections.trans.append(
- self.contextual_connect().begin_nested())
+ self.contextual_connect().begin_nested()
+ )
return self
def begin(self):
- if not hasattr(self._connections, 'trans'):
+ if not hasattr(self._connections, "trans"):
self._connections.trans = []
self._connections.trans.append(self.contextual_connect().begin())
return self
@@ -97,21 +99,27 @@ class TLEngine(base.Engine):
self.rollback()
def prepare(self):
- if not hasattr(self._connections, 'trans') or \
- not self._connections.trans:
+ if (
+ not hasattr(self._connections, "trans")
+ or not self._connections.trans
+ ):
return
self._connections.trans[-1].prepare()
def commit(self):
- if not hasattr(self._connections, 'trans') or \
- not self._connections.trans:
+ if (
+ not hasattr(self._connections, "trans")
+ or not self._connections.trans
+ ):
return
trans = self._connections.trans.pop(-1)
trans.commit()
def rollback(self):
- if not hasattr(self._connections, 'trans') or \
- not self._connections.trans:
+ if (
+ not hasattr(self._connections, "trans")
+ or not self._connections.trans
+ ):
return
trans = self._connections.trans.pop(-1)
trans.rollback()
@@ -122,9 +130,11 @@ class TLEngine(base.Engine):
@property
def closed(self):
- return not hasattr(self._connections, 'conn') or \
- self._connections.conn() is None or \
- self._connections.conn().closed
+ return (
+ not hasattr(self._connections, "conn")
+ or self._connections.conn() is None
+ or self._connections.conn().closed
+ )
def close(self):
if not self.closed:
@@ -135,4 +145,4 @@ class TLEngine(base.Engine):
self._connections.trans = []
def __repr__(self):
- return 'TLEngine(%r)' % self.url
+ return "TLEngine(%r)" % self.url
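Aside, not part of the diff: the thread-local transaction stack reworked above, in use. A minimal sketch, assuming an in-memory SQLite engine:

    from sqlalchemy import create_engine

    engine = create_engine("sqlite://", strategy="threadlocal")

    engine.begin()                                   # push a transaction for this thread
    engine.execute("CREATE TABLE t (x INTEGER)")
    engine.execute("INSERT INTO t (x) VALUES (1)")
    engine.commit()                                  # pop and commit it

    print(engine.execute("SELECT x FROM t").fetchall())   # [(1,)]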
diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py
index 1662efe20..e92e57b8e 100644
--- a/lib/sqlalchemy/engine/url.py
+++ b/lib/sqlalchemy/engine/url.py
@@ -50,8 +50,16 @@ class URL(object):
"""
- def __init__(self, drivername, username=None, password=None,
- host=None, port=None, database=None, query=None):
+ def __init__(
+ self,
+ drivername,
+ username=None,
+ password=None,
+ host=None,
+ port=None,
+ database=None,
+ query=None,
+ ):
self.drivername = drivername
self.username = username
self.password_original = password
@@ -68,26 +76,26 @@ class URL(object):
if self.username is not None:
s += _rfc_1738_quote(self.username)
if self.password is not None:
- s += ':' + ('***' if hide_password
- else _rfc_1738_quote(self.password))
+ s += ":" + (
+ "***" if hide_password else _rfc_1738_quote(self.password)
+ )
s += "@"
if self.host is not None:
- if ':' in self.host:
+ if ":" in self.host:
s += "[%s]" % self.host
else:
s += self.host
if self.port is not None:
- s += ':' + str(self.port)
+ s += ":" + str(self.port)
if self.database is not None:
- s += '/' + self.database
+ s += "/" + self.database
if self.query:
keys = list(self.query)
keys.sort()
- s += '?' + "&".join(
- "%s=%s" % (
- k,
- element
- ) for k in keys for element in util.to_list(self.query[k])
+ s += "?" + "&".join(
+ "%s=%s" % (k, element)
+ for k in keys
+ for element in util.to_list(self.query[k])
)
return s
@@ -101,14 +109,15 @@ class URL(object):
return hash(str(self))
def __eq__(self, other):
- return \
- isinstance(other, URL) and \
- self.drivername == other.drivername and \
- self.username == other.username and \
- self.password == other.password and \
- self.host == other.host and \
- self.database == other.database and \
- self.query == other.query
+ return (
+ isinstance(other, URL)
+ and self.drivername == other.drivername
+ and self.username == other.username
+ and self.password == other.password
+ and self.host == other.host
+ and self.database == other.database
+ and self.query == other.query
+ )
@property
def password(self):
@@ -122,20 +131,20 @@ class URL(object):
self.password_original = password
def get_backend_name(self):
- if '+' not in self.drivername:
+ if "+" not in self.drivername:
return self.drivername
else:
- return self.drivername.split('+')[0]
+ return self.drivername.split("+")[0]
def get_driver_name(self):
- if '+' not in self.drivername:
+ if "+" not in self.drivername:
return self.get_dialect().driver
else:
- return self.drivername.split('+')[1]
+ return self.drivername.split("+")[1]
def _instantiate_plugins(self, kwargs):
- plugin_names = util.to_list(self.query.get('plugin', ()))
- plugin_names += kwargs.get('plugins', [])
+ plugin_names = util.to_list(self.query.get("plugin", ()))
+ plugin_names += kwargs.get("plugins", [])
return [
plugins.load(plugin_name)(self, kwargs)
@@ -149,17 +158,19 @@ class URL(object):
returned class implements the get_dialect_cls() method.
"""
- if '+' not in self.drivername:
+ if "+" not in self.drivername:
name = self.drivername
else:
- name = self.drivername.replace('+', '.')
+ name = self.drivername.replace("+", ".")
cls = registry.load(name)
# check for legacy dialects that
# would return a module with 'dialect' as the
# actual class
- if hasattr(cls, 'dialect') and \
- isinstance(cls.dialect, type) and \
- issubclass(cls.dialect, Dialect):
+ if (
+ hasattr(cls, "dialect")
+ and isinstance(cls.dialect, type)
+ and issubclass(cls.dialect, Dialect)
+ ):
return cls.dialect
else:
return cls
@@ -187,7 +198,7 @@ class URL(object):
"""
translated = {}
- attribute_names = ['host', 'database', 'username', 'password', 'port']
+ attribute_names = ["host", "database", "username", "password", "port"]
for sname in attribute_names:
if names:
name = names.pop(0)
@@ -214,7 +225,8 @@ def make_url(name_or_url):
def _parse_rfc1738_args(name):
- pattern = re.compile(r'''
+ pattern = re.compile(
+ r"""
(?P<name>[\w\+]+)://
(?:
(?P<username>[^:/]*)
@@ -228,21 +240,23 @@ def _parse_rfc1738_args(name):
(?::(?P<port>[^/]*))?
)?
(?:/(?P<database>.*))?
- ''', re.X)
+ """,
+ re.X,
+ )
m = pattern.match(name)
if m is not None:
components = m.groupdict()
- if components['database'] is not None:
- tokens = components['database'].split('?', 2)
- components['database'] = tokens[0]
+ if components["database"] is not None:
+ tokens = components["database"].split("?", 2)
+ components["database"] = tokens[0]
if len(tokens) > 1:
query = {}
for key, value in util.parse_qsl(tokens[1]):
if util.py2k:
- key = key.encode('ascii')
+ key = key.encode("ascii")
if key in query:
query[key] = util.to_list(query[key])
query[key].append(value)
@@ -252,26 +266,27 @@ def _parse_rfc1738_args(name):
query = None
else:
query = None
- components['query'] = query
+ components["query"] = query
- if components['username'] is not None:
- components['username'] = _rfc_1738_unquote(components['username'])
+ if components["username"] is not None:
+ components["username"] = _rfc_1738_unquote(components["username"])
- if components['password'] is not None:
- components['password'] = _rfc_1738_unquote(components['password'])
+ if components["password"] is not None:
+ components["password"] = _rfc_1738_unquote(components["password"])
- ipv4host = components.pop('ipv4host')
- ipv6host = components.pop('ipv6host')
- components['host'] = ipv4host or ipv6host
- name = components.pop('name')
+ ipv4host = components.pop("ipv4host")
+ ipv6host = components.pop("ipv6host")
+ components["host"] = ipv4host or ipv6host
+ name = components.pop("name")
return URL(name, **components)
else:
raise exc.ArgumentError(
- "Could not parse rfc1738 URL from string '%s'" % name)
+ "Could not parse rfc1738 URL from string '%s'" % name
+ )
def _rfc_1738_quote(text):
- return re.sub(r'[:@/]', lambda m: "%%%X" % ord(m.group(0)), text)
+ return re.sub(r"[:@/]", lambda m: "%%%X" % ord(m.group(0)), text)
def _rfc_1738_unquote(text):
@@ -279,7 +294,7 @@ def _rfc_1738_unquote(text):
def _parse_keyvalue_args(name):
- m = re.match(r'(\w+)://(.*)', name)
+ m = re.match(r"(\w+)://(.*)", name)
if m is not None:
(name, args) = m.group(1, 2)
opts = dict(util.parse_qsl(args))
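Aside, not part of the diff: a minimal sketch of the rfc1738 parsing and accessors exercised above; the URL itself is illustrative:

    from sqlalchemy.engine import url

    u = url.make_url(
        "postgresql+psycopg2://scott:tiger@localhost:5432/test?sslmode=require"
    )
    print(u.get_backend_name())    # 'postgresql'
    print(u.get_driver_name())     # 'psycopg2'
    print(u.username, u.host, u.port, u.database)   # scott localhost 5432 test
    print(u.query)                 # {'sslmode': 'require'}
    print(repr(u))                 # password rendered as *** via hide_password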
diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py
index 17bc9a3b4..76bb8f4b5 100644
--- a/lib/sqlalchemy/engine/util.py
+++ b/lib/sqlalchemy/engine/util.py
@@ -46,28 +46,34 @@ def py_fallback():
elif len(multiparams) == 1:
zero = multiparams[0]
if isinstance(zero, (list, tuple)):
- if not zero or hasattr(zero[0], '__iter__') and \
- not hasattr(zero[0], 'strip'):
+ if (
+ not zero
+ or hasattr(zero[0], "__iter__")
+ and not hasattr(zero[0], "strip")
+ ):
# execute(stmt, [{}, {}, {}, ...])
# execute(stmt, [(), (), (), ...])
return zero
else:
# execute(stmt, ("value", "value"))
return [zero]
- elif hasattr(zero, 'keys'):
+ elif hasattr(zero, "keys"):
# execute(stmt, {"key":"value"})
return [zero]
else:
# execute(stmt, "value")
return [[zero]]
else:
- if hasattr(multiparams[0], '__iter__') and \
- not hasattr(multiparams[0], 'strip'):
+ if hasattr(multiparams[0], "__iter__") and not hasattr(
+ multiparams[0], "strip"
+ ):
return multiparams
else:
return [multiparams]
return locals()
+
+
try:
from sqlalchemy.cutils import _distill_params
except ImportError:
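Aside, not part of the diff: the call shapes that the pure-Python _distill_params fallback above normalizes, mirroring the comments in the code. A minimal sketch against an in-memory SQLite engine:

    from sqlalchemy import create_engine

    engine = create_engine("sqlite://")
    engine.execute("CREATE TABLE t (x INTEGER, y INTEGER)")

    # execute(stmt, [(), (), ...]) -> passed through as an executemany
    engine.execute("INSERT INTO t (x, y) VALUES (?, ?)", [(1, 2), (3, 4)])

    # execute(stmt, ("value", "value")) -> wrapped as a single parameter set
    engine.execute("INSERT INTO t (x, y) VALUES (?, ?)", (5, 6))

    print(engine.execute("SELECT count(*) FROM t").scalar())   # 3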