diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-04-07 14:15:43 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-04-16 13:35:55 -0400 |
| commit | 2f617f56f2acdce00b88f746c403cf5ed66d4d27 (patch) | |
| tree | 0962f2c43c1a361135ecdab933167fa0963ae58a /lib/sqlalchemy/engine/base.py | |
| parent | bd303b10e2bf69169f07447c7272fc71ac931f10 (diff) | |
| download | sqlalchemy-2f617f56f2acdce00b88f746c403cf5ed66d4d27.tar.gz | |
Create initial 2.0 engine implementation
Implemented the SQLAlchemy 2 :func:`.future.create_engine` function which
is used for forwards compatibility with SQLAlchemy 2. This engine
features always-transactional behavior with autobegin.
Allow execution options per statement execution. This includes
that the before_execute() and after_execute() events now accept
an additional dictionary with these options, empty if not
passed; a legacy event decorator is added for backwards compatibility
which now also emits a deprecation warning.
Add some basic tests for execution, transactions, and
the new result object. Build out on a new testing fixture
that swaps in the future engine completely to start with.
Change-Id: I70e7338bb3f0ce22d2f702537d94bb249bd9fb0a
Fixes: #4644
Diffstat (limited to 'lib/sqlalchemy/engine/base.py')
| -rw-r--r-- | lib/sqlalchemy/engine/base.py | 316 |
1 files changed, 226 insertions, 90 deletions
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 8a340d9ce..09e700b5c 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -12,6 +12,7 @@ import sys from .interfaces import Connectable from .interfaces import ExceptionContext from .util import _distill_params +from .util import _distill_params_20 from .. import exc from .. import inspection from .. import log @@ -52,6 +53,8 @@ class Connection(Connectable): """ _schema_translate_map = None + _is_future = False + _sqla_logger_namespace = "sqlalchemy.engine.Connection" def __init__( self, @@ -85,7 +88,7 @@ class Connection(Connectable): if connection is not None else engine.raw_connection() ) - self.__transaction = None + self._transaction = None self.__savepoint_seq = 0 self.should_close_with_result = close_with_result @@ -168,13 +171,15 @@ class Connection(Connectable): else: return self - def _clone(self): - """Create a shallow copy of this Connection. + def _generate_for_options(self): + """define connection method chaining behavior for execution_options""" - """ - c = self.__class__.__new__(self.__class__) - c.__dict__ = self.__dict__.copy() - return c + if self._is_future: + return self + else: + c = self.__class__.__new__(self.__class__) + c.__dict__ = self.__dict__.copy() + return c def __enter__(self): return self @@ -340,7 +345,7 @@ class Connection(Connectable): """ # noqa - c = self._clone() + c = self._generate_for_options() c._execution_options = c._execution_options.union(opt) if self._has_events or self.engine._has_events: self.dispatch.set_connection_execution_options(c, opt) @@ -469,7 +474,7 @@ class Connection(Connectable): if self.__branch_from: return self.__branch_from._revalidate_connection() if self.__can_reconnect and self.__invalid: - if self.__transaction is not None: + if self._transaction is not None: raise exc.InvalidRequestError( "Can't reconnect until invalid " "transaction is rolled back" @@ -640,14 +645,21 @@ class Connection(Connectable): :class:`_engine.Engine` """ - if self.__branch_from: + if self._is_future: + assert not self.__branch_from + elif self.__branch_from: return self.__branch_from.begin() - if self.__transaction is None: - self.__transaction = RootTransaction(self) - return self.__transaction + if self._transaction is None: + self._transaction = RootTransaction(self) + return self._transaction else: - return Transaction(self, self.__transaction) + if self._is_future: + raise exc.InvalidRequestError( + "a transaction is already begun for this connection" + ) + else: + return Transaction(self, self._transaction) def begin_nested(self): """Begin a nested transaction and return a transaction handle. @@ -667,14 +679,22 @@ class Connection(Connectable): :meth:`_engine.Connection.begin_twophase` """ - if self.__branch_from: + if self._is_future: + assert not self.__branch_from + elif self.__branch_from: return self.__branch_from.begin_nested() - if self.__transaction is None: - self.__transaction = RootTransaction(self) - else: - self.__transaction = NestedTransaction(self, self.__transaction) - return self.__transaction + if self._transaction is None: + if self._is_future: + self._autobegin() + else: + self._transaction = RootTransaction(self) + return self._transaction + + trans = NestedTransaction(self, self._transaction) + if not self._is_future: + self._transaction = trans + return trans def begin_twophase(self, xid=None): """Begin a two-phase or XA transaction and return a transaction @@ -699,15 +719,15 @@ class Connection(Connectable): if self.__branch_from: return self.__branch_from.begin_twophase(xid=xid) - if self.__transaction is not None: + if self._transaction is not None: raise exc.InvalidRequestError( "Cannot start a two phase transaction when a transaction " "is already in progress." ) if xid is None: xid = self.engine.dialect.create_xid() - self.__transaction = TwoPhaseTransaction(self, xid) - return self.__transaction + self._transaction = TwoPhaseTransaction(self, xid) + return self._transaction def recover_twophase(self): return self.engine.dialect.do_recover_twophase(self) @@ -721,8 +741,8 @@ class Connection(Connectable): def in_transaction(self): """Return True if a transaction is in progress.""" return ( - self._root.__transaction is not None - and self._root.__transaction.is_active + self._root._transaction is not None + and self._root._transaction.is_active ) def _begin_impl(self, transaction): @@ -736,7 +756,7 @@ class Connection(Connectable): try: self.engine.dialect.do_begin(self.connection) - if self.connection._reset_agent is None: + if not self._is_future and self.connection._reset_agent is None: self.connection._reset_agent = transaction except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) @@ -757,7 +777,7 @@ class Connection(Connectable): finally: if ( not self.__invalid - and self.connection._reset_agent is self.__transaction + and self.connection._reset_agent is self._transaction ): self.connection._reset_agent = None @@ -776,10 +796,10 @@ class Connection(Connectable): finally: if ( not self.__invalid - and self.connection._reset_agent is self.__transaction + and self.connection._reset_agent is self._transaction ): self.connection._reset_agent = None - self.__transaction = None + self._transaction = None def _savepoint_impl(self, name=None): assert not self.__branch_from @@ -795,13 +815,13 @@ class Connection(Connectable): return name def _discard_transaction(self, trans): - if trans is self.__transaction: + if trans is self._transaction: if trans._is_root: assert trans._parent is trans - self.__transaction = None + self._transaction = None else: assert trans._parent is not trans - self.__transaction = trans._parent + self._transaction = trans._parent def _rollback_to_savepoint_impl( self, name, context, deactivate_only=False @@ -822,7 +842,7 @@ class Connection(Connectable): if self._still_open_and_connection_is_valid: self.engine.dialect.do_release_savepoint(self, name) - self.__transaction = context + self._transaction = context def _begin_twophase_impl(self, transaction): assert not self.__branch_from @@ -835,7 +855,7 @@ class Connection(Connectable): if self._still_open_and_connection_is_valid: self.engine.dialect.do_begin_twophase(self, transaction.xid) - if self.connection._reset_agent is None: + if not self._is_future and self.connection._reset_agent is None: self.connection._reset_agent = transaction def _prepare_twophase_impl(self, xid): @@ -845,7 +865,7 @@ class Connection(Connectable): self.dispatch.prepare_twophase(self, xid) if self._still_open_and_connection_is_valid: - assert isinstance(self.__transaction, TwoPhaseTransaction) + assert isinstance(self._transaction, TwoPhaseTransaction) self.engine.dialect.do_prepare_twophase(self, xid) def _rollback_twophase_impl(self, xid, is_prepared): @@ -855,17 +875,17 @@ class Connection(Connectable): self.dispatch.rollback_twophase(self, xid, is_prepared) if self._still_open_and_connection_is_valid: - assert isinstance(self.__transaction, TwoPhaseTransaction) + assert isinstance(self._transaction, TwoPhaseTransaction) try: self.engine.dialect.do_rollback_twophase( self, xid, is_prepared ) finally: - if self.connection._reset_agent is self.__transaction: + if self.connection._reset_agent is self._transaction: self.connection._reset_agent = None - self.__transaction = None + self._transaction = None else: - self.__transaction = None + self._transaction = None def _commit_twophase_impl(self, xid, is_prepared): assert not self.__branch_from @@ -874,15 +894,20 @@ class Connection(Connectable): self.dispatch.commit_twophase(self, xid, is_prepared) if self._still_open_and_connection_is_valid: - assert isinstance(self.__transaction, TwoPhaseTransaction) + assert isinstance(self._transaction, TwoPhaseTransaction) try: self.engine.dialect.do_commit_twophase(self, xid, is_prepared) finally: - if self.connection._reset_agent is self.__transaction: + if self.connection._reset_agent is self._transaction: self.connection._reset_agent = None - self.__transaction = None + self._transaction = None else: - self.__transaction = None + self._transaction = None + + def _autobegin(self): + assert self._is_future + + return self.begin() def _autorollback(self): if not self._root.in_transaction(): @@ -907,6 +932,8 @@ class Connection(Connectable): and will allow no further operations. """ + assert not self._is_future + if self.__branch_from: util.warn_deprecated_20( "The .close() method on a so-called 'branched' connection is " @@ -929,7 +956,7 @@ class Connection(Connectable): else: conn.close() - if conn._reset_agent is self.__transaction: + if conn._reset_agent is self._transaction: conn._reset_agent = None # the close() process can end up invalidating us, @@ -938,7 +965,7 @@ class Connection(Connectable): if not self.__invalid: del self.__connection self.__can_reconnect = False - self.__transaction = None + self._transaction = None def scalar(self, object_, *multiparams, **params): """Executes and returns the first column of the first row. @@ -1030,8 +1057,11 @@ class Connection(Connectable): "or the Connection.exec_driver_sql() method to invoke a " "driver-level SQL string." ) - distilled_params = _distill_params(multiparams, params) - return self._exec_driver_sql_distilled(object_, distilled_params) + distilled_parameters = _distill_params(multiparams, params) + + return self._exec_driver_sql( + object_, multiparams, params, distilled_parameters + ) try: meth = object_._execute_on_connection except AttributeError as err: @@ -1039,20 +1069,28 @@ class Connection(Connectable): exc.ObjectNotExecutableError(object_), replace_context=err ) else: - return meth(self, multiparams, params) + return meth(self, multiparams, params, util.immutabledict()) - def _execute_function(self, func, multiparams, params): + def _execute_function( + self, func, multiparams, params, execution_options=util.immutabledict() + ): """Execute a sql.FunctionElement object.""" return self._execute_clauseelement(func.select(), multiparams, params) - def _execute_default(self, default, multiparams, params): + def _execute_default( + self, + default, + multiparams, + params, + execution_options=util.immutabledict(), + ): """Execute a schema.ColumnDefault object.""" if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: default, multiparams, params = fn( - self, default, multiparams, params + self, default, multiparams, params, execution_options ) try: @@ -1066,7 +1104,9 @@ class Connection(Connectable): conn = self._revalidate_connection() dialect = self.dialect - ctx = dialect.execution_ctx_cls._init_default(dialect, self, conn) + ctx = dialect.execution_ctx_cls._init_default( + dialect, self, conn, execution_options + ) except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) @@ -1076,17 +1116,21 @@ class Connection(Connectable): if self._has_events or self.engine._has_events: self.dispatch.after_execute( - self, default, multiparams, params, ret + self, default, multiparams, params, execution_options, ret ) return ret - def _execute_ddl(self, ddl, multiparams, params): + def _execute_ddl( + self, ddl, multiparams, params, execution_options=util.immutabledict() + ): """Execute a schema.DDL object.""" if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: - ddl, multiparams, params = fn(self, ddl, multiparams, params) + ddl, multiparams, params = fn( + self, ddl, multiparams, params, execution_options + ) dialect = self.dialect @@ -1098,18 +1142,25 @@ class Connection(Connectable): dialect.execution_ctx_cls._init_ddl, compiled, None, + execution_options, compiled, ) if self._has_events or self.engine._has_events: - self.dispatch.after_execute(self, ddl, multiparams, params, ret) + self.dispatch.after_execute( + self, ddl, multiparams, params, execution_options, ret + ) return ret - def _execute_clauseelement(self, elem, multiparams, params): + def _execute_clauseelement( + self, elem, multiparams, params, execution_options=util.immutabledict() + ): """Execute a sql.ClauseElement object.""" if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: - elem, multiparams, params = fn(self, elem, multiparams, params) + elem, multiparams, params = fn( + self, elem, multiparams, params, execution_options + ) distilled_params = _distill_params(multiparams, params) if distilled_params: @@ -1121,22 +1172,31 @@ class Connection(Connectable): dialect = self.dialect - if "compiled_cache" in self._execution_options: - elem_cache_key, extracted_params = elem._generate_cache_key() + exec_opts = self._execution_options + if execution_options: + exec_opts = exec_opts.union(execution_options) + + if "compiled_cache" in exec_opts: + elem_cache_key = elem._generate_cache_key() + else: + elem_cache_key = None + + if elem_cache_key: + cache_key, extracted_params = elem_cache_key key = ( dialect, - elem_cache_key, + cache_key, tuple(sorted(keys)), bool(self._schema_translate_map), len(distilled_params) > 1, ) - cache = self._execution_options["compiled_cache"] + cache = exec_opts["compiled_cache"] compiled_sql = cache.get(key) if compiled_sql is None: compiled_sql = elem.compile( dialect=dialect, - cache_key=(elem_cache_key, extracted_params), + cache_key=elem_cache_key, column_keys=keys, inline=len(distilled_params) > 1, schema_translate_map=self._schema_translate_map, @@ -1160,22 +1220,31 @@ class Connection(Connectable): dialect.execution_ctx_cls._init_compiled, compiled_sql, distilled_params, + execution_options, compiled_sql, distilled_params, elem, extracted_params, ) if self._has_events or self.engine._has_events: - self.dispatch.after_execute(self, elem, multiparams, params, ret) + self.dispatch.after_execute( + self, elem, multiparams, params, execution_options, ret + ) return ret - def _execute_compiled(self, compiled, multiparams, params): + def _execute_compiled( + self, + compiled, + multiparams, + params, + execution_options=util.immutabledict(), + ): """Execute a sql.Compiled object.""" if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: compiled, multiparams, params = fn( - self, compiled, multiparams, params + self, compiled, multiparams, params, execution_options ) dialect = self.dialect @@ -1185,6 +1254,7 @@ class Connection(Connectable): dialect.execution_ctx_cls._init_compiled, compiled, parameters, + execution_options, compiled, parameters, None, @@ -1192,16 +1262,23 @@ class Connection(Connectable): ) if self._has_events or self.engine._has_events: self.dispatch.after_execute( - self, compiled, multiparams, params, ret + self, compiled, multiparams, params, execution_options, ret ) return ret - def _exec_driver_sql_distilled(self, statement, parameters): + def _exec_driver_sql( + self, + statement, + multiparams, + params, + distilled_parameters, + execution_options=util.immutabledict(), + ): if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: statement, multiparams, params = fn( - self, statement, parameters, {} + self, statement, multiparams, params, execution_options ) dialect = self.dialect @@ -1209,15 +1286,38 @@ class Connection(Connectable): dialect, dialect.execution_ctx_cls._init_statement, statement, - parameters, + distilled_parameters, + execution_options, statement, - parameters, + distilled_parameters, ) if self._has_events or self.engine._has_events: - self.dispatch.after_execute(self, statement, parameters, {}) + self.dispatch.after_execute( + self, statement, multiparams, params, execution_options, ret + ) return ret - def exec_driver_sql(self, statement, parameters=None): + def _execute_20( + self, + statement, + parameters=None, + execution_options=util.immutabledict(), + ): + multiparams, params, distilled_parameters = _distill_params_20( + parameters + ) + try: + meth = statement._execute_on_connection + except AttributeError as err: + util.raise_( + exc.ObjectNotExecutableError(statement), replace_context=err + ) + else: + return meth(self, multiparams, params, execution_options) + + def exec_driver_sql( + self, statement, parameters=None, execution_options=None + ): r"""Executes a SQL statement construct and returns a :class:`_engine.ResultProxy`. @@ -1258,22 +1358,33 @@ class Connection(Connectable): """ - if isinstance(parameters, list) and parameters: - if not isinstance(parameters[0], (dict, tuple)): - raise exc.ArgumentError( - "List argument must consist only of tuples or dictionaries" - ) - elif isinstance(parameters, (dict, tuple)): - parameters = [parameters] + multiparams, params, distilled_parameters = _distill_params_20( + parameters + ) - return self._exec_driver_sql_distilled(statement, parameters or ()) + return self._exec_driver_sql( + statement, + multiparams, + params, + distilled_parameters, + execution_options, + ) def _execute_context( - self, dialect, constructor, statement, parameters, *args + self, + dialect, + constructor, + statement, + parameters, + execution_options, + *args ): """Create an :class:`.ExecutionContext` and execute, returning a :class:`_engine.ResultProxy`.""" + if execution_options: + dialect.set_exec_execution_options(self, execution_options) + try: try: conn = self.__connection @@ -1284,23 +1395,29 @@ class Connection(Connectable): if conn is None: conn = self._revalidate_connection() - context = constructor(dialect, self, conn, *args) + context = constructor( + dialect, self, conn, execution_options, *args + ) except BaseException as e: self._handle_dbapi_exception( e, util.text_type(statement), parameters, None, None ) - if self._root.__transaction and not self._root.__transaction.is_active: + if self._root._transaction and not self._root._transaction.is_active: raise exc.InvalidRequestError( "This connection is on an inactive %stransaction. " "Please rollback() fully before proceeding." % ( "savepoint " - if isinstance(self.__transaction, NestedTransaction) + if isinstance(self._transaction, NestedTransaction) else "" ), code="8s2a", ) + + if self._is_future and self._root._transaction is None: + self._autobegin() + if context.compiled: context.pre_exec() @@ -1386,12 +1503,17 @@ class Connection(Connectable): result = context._setup_result_proxy() - if context.should_autocommit and self._root.__transaction is None: + if ( + not self._is_future + and context.should_autocommit + and self._root._transaction is None + ): self._root._commit_impl(autocommit=True) # for "connectionless" execution, we have to close this # Connection after the statement is complete. if self.should_close_with_result: + assert not self._is_future assert not context._is_future_result # ResultProxy already exhausted rows / has no rows. @@ -1600,6 +1722,7 @@ class Connection(Connectable): self.engine.pool._invalidate(dbapi_conn_wrapper, e) self.invalidate(e) if self.should_close_with_result: + assert not self._is_future self.close() @classmethod @@ -1991,6 +2114,8 @@ class Engine(Connectable, log.Identified): _execution_options = util.immutabledict() _has_events = False _connection_cls = Connection + _sqla_logger_namespace = "sqlalchemy.engine.Engine" + _is_future = False _schema_translate_map = None @@ -2114,7 +2239,7 @@ class Engine(Connectable, log.Identified): """ - return OptionEngine(self, opt) + return self._option_cls(self, opt) def get_execution_options(self): """ Get the non-SQL options which will take effect during execution. @@ -2200,7 +2325,8 @@ class Engine(Connectable, log.Identified): if type_ is not None: self.transaction.rollback() else: - self.transaction.commit() + if self.transaction.is_active: + self.transaction.commit() if not self.close_with_result: self.conn.close() @@ -2239,7 +2365,10 @@ class Engine(Connectable, log.Identified): for a particular :class:`_engine.Connection`. """ - conn = self.connect(close_with_result=close_with_result) + if self._connection_cls._is_future: + conn = self.connect() + else: + conn = self.connect(close_with_result=close_with_result) try: trans = conn.begin() except: @@ -2477,7 +2606,7 @@ class Engine(Connectable, log.Identified): return self._wrap_pool_connect(self.pool.connect, _connection) -class OptionEngine(Engine): +class OptionEngineMixin(object): _sa_propagate_class_events = False def __init__(self, proxied, execution_options): @@ -2523,3 +2652,10 @@ class OptionEngine(Engine): self.__dict__["_has_events"] = value _has_events = property(_get_has_events, _set_has_events) + + +class OptionEngine(OptionEngineMixin, Engine): + pass + + +Engine._option_cls = OptionEngine |
