summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/asyncio/engine.py
blob: 619cf85086318b51eba0ce36d03971dbc36fb07f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
from typing import Any
from typing import Callable
from typing import Mapping
from typing import Optional

from . import exc as async_exc
from .base import StartableContext
from .result import AsyncResult
from ... import exc
from ... import util
from ...engine import Connection
from ...engine import create_engine as _create_engine
from ...engine import Engine
from ...engine import Result
from ...engine import Transaction
from ...engine.base import OptionEngineMixin
from ...sql import Executable
from ...util.concurrency import greenlet_spawn


def create_async_engine(*arg, **kw):
    """Build an :class:`_asyncio.AsyncEngine` around a new sync engine.

    Positional and keyword arguments are forwarded to
    :func:`_sa.create_engine`, with ``future=True`` always forced on.
    The dialect named in the URL must be asyncio-compatible, for example
    :ref:`dialect-postgresql-asyncpg`.

    :raises AsyncMethodRequired: if a truthy ``server_side_cursors``
     keyword is given; streaming is requested per statement via
     ``connection.stream()`` instead of engine-wide.

    :return: an :class:`_asyncio.AsyncEngine` wrapping the sync engine.

    .. versionadded:: 1.4

    """

    wants_server_side_cursors = kw.get("server_side_cursors", False)
    if wants_server_side_cursors:
        raise async_exc.AsyncMethodRequired(
            "Can't set server_side_cursors for async engine globally; "
            "use the connection.stream() method for an async "
            "streaming result set"
        )

    # the async proxy always sits on top of a "future" style engine
    kw["future"] = True
    return AsyncEngine(_create_engine(*arg, **kw))


class AsyncConnection(StartableContext):
    """An asyncio proxy for a :class:`_engine.Connection`.

    :class:`_asyncio.AsyncConnection` is acquired using the
    :meth:`_asyncio.AsyncEngine.connect`
    method of :class:`_asyncio.AsyncEngine`::

        from sqlalchemy.ext.asyncio import create_async_engine
        engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname")

        async with engine.connect() as conn:
            result = await conn.execute(select(table))

    .. versionadded:: 1.4

    """  # noqa

    __slots__ = (
        "sync_engine",
        "sync_connection",
    )

    def __init__(
        self, sync_engine: Engine, sync_connection: Optional[Connection] = None
    ):
        """Construct the proxy.

        :param sync_engine: the sync :class:`_engine.Engine` from which the
         connection is procured when :meth:`.start` is called.
        :param sync_connection: optional already-established sync
         connection; when given, the proxy is considered started.

        """
        self.sync_engine = sync_engine
        self.sync_connection = sync_connection

    async def start(self):
        """Start this :class:`_asyncio.AsyncConnection` object's context
        outside of using a Python ``with:`` block.

        :raises InvalidRequestError: if a sync connection is already present.

        :return: this :class:`_asyncio.AsyncConnection`.

        """
        if self.sync_connection:
            raise exc.InvalidRequestError("connection is already started")
        # Engine.connect() may block on the pool / DBAPI, so run it in the
        # instrumented greenlet rather than directly on the event loop.
        self.sync_connection = await (greenlet_spawn(self.sync_engine.connect))
        return self

    def _sync_connection(self):
        """Return the proxied sync connection, failing if not yet started."""
        if not self.sync_connection:
            # NOTE(review): presumably inherited from StartableContext and
            # expected to raise - confirm against .base
            self._raise_for_not_started()
        return self.sync_connection

    def begin(self) -> "AsyncTransaction":
        """Begin a transaction prior to autobegin occurring.

        The returned :class:`_asyncio.AsyncTransaction` is not started by
        this method; the call only verifies that the connection itself has
        been started.

        """
        self._sync_connection()
        return AsyncTransaction(self)

    def begin_nested(self) -> "AsyncTransaction":
        """Begin a nested transaction and return a transaction handle.

        The returned :class:`_asyncio.AsyncTransaction` is constructed with
        ``nested=True`` and is not started by this method.

        """
        self._sync_connection()
        return AsyncTransaction(self, nested=True)

    async def commit(self):
        """Commit the transaction that is currently in progress.

        This method commits the current transaction if one has been started.
        If no transaction was started, the method has no effect, assuming
        the connection is in a non-invalidated state.

        A transaction is begun on a :class:`_future.Connection` automatically
        whenever a statement is first executed, or when the
        :meth:`_future.Connection.begin` method is called.

        """
        conn = self._sync_connection()
        await greenlet_spawn(conn.commit)

    async def rollback(self):
        """Roll back the transaction that is currently in progress.

        This method rolls back the current transaction if one has been started.
        If no transaction was started, the method has no effect.  If a
        transaction was started and the connection is in an invalidated state,
        the transaction is cleared using this method.

        A transaction is begun on a :class:`_future.Connection` automatically
        whenever a statement is first executed, or when the
        :meth:`_future.Connection.begin` method is called.

        """
        conn = self._sync_connection()
        await greenlet_spawn(conn.rollback)

    async def close(self):
        """Close this :class:`_asyncio.AsyncConnection`.

        This has the effect of also rolling back the transaction if one
        is in place.

        """
        conn = self._sync_connection()
        await greenlet_spawn(conn.close)

    async def exec_driver_sql(
        self,
        statement: str,
        parameters: Optional[Mapping] = None,
        execution_options: Mapping = util.EMPTY_DICT,
    ) -> Result:
        r"""Executes a driver-level SQL string and return buffered
        :class:`_engine.Result`.

        :param statement: the SQL string, passed through to the sync
         connection's ``exec_driver_sql()`` method.
        :param parameters: optional parameters for the statement.
        :param execution_options: optional dictionary of execution options.

        :raises AsyncMethodRequired: if the execution produced a
         server-side cursor; use :meth:`.stream` for that case.

        :return: a buffered :class:`_engine.Result` object.

        """

        conn = self._sync_connection()

        result = await greenlet_spawn(
            conn.exec_driver_sql, statement, parameters, execution_options,
        )
        if result.context._is_server_side:
            raise async_exc.AsyncMethodRequired(
                "Can't use the connection.exec_driver_sql() method with a "
                "server-side cursor. "
                "Use the connection.stream() method for an async "
                "streaming result set."
            )

        return result

    async def stream(
        self,
        statement: Executable,
        parameters: Optional[Mapping] = None,
        execution_options: Mapping = util.EMPTY_DICT,
    ) -> AsyncResult:
        """Execute a statement and return a streaming
        :class:`_asyncio.AsyncResult` object.

        The ``stream_results`` execution option is always set to ``True``
        for the execution, overriding any value given in
        ``execution_options``.

        :raises InvalidRequestError: if the execution did not produce a
         server-side cursor.

        """

        conn = self._sync_connection()

        result = await greenlet_spawn(
            conn._execute_20,
            statement,
            parameters,
            util.EMPTY_DICT.merge_with(
                execution_options, {"stream_results": True}
            ),
        )
        if not result.context._is_server_side:
            # A real exception rather than ``assert``: assertions are removed
            # under ``python -O``, which would let a buffered result be
            # wrapped in AsyncResult unnoticed.
            raise exc.InvalidRequestError(
                "connection.stream() expected a server-side result, but "
                "the statement execution returned a buffered result"
            )
        return AsyncResult(result)

    async def execute(
        self,
        statement: Executable,
        parameters: Optional[Mapping] = None,
        execution_options: Mapping = util.EMPTY_DICT,
    ) -> Result:
        r"""Executes a SQL statement construct and return a buffered
        :class:`_engine.Result`.

        :param statement: The statement to be executed.  This is always
         an object that is in both the :class:`_expression.ClauseElement` and
         :class:`_expression.Executable` hierarchies, including:

         * :class:`_expression.Select`
         * :class:`_expression.Insert`, :class:`_expression.Update`,
           :class:`_expression.Delete`
         * :class:`_expression.TextClause` and
           :class:`_expression.TextualSelect`
         * :class:`_schema.DDL` and objects which inherit from
           :class:`_schema.DDLElement`

        :param parameters: parameters which will be bound into the statement.
         This may be either a dictionary of parameter names to values,
         or a mutable sequence (e.g. a list) of dictionaries.  When a
         list of dictionaries is passed, the underlying statement execution
         will make use of the DBAPI ``cursor.executemany()`` method.
         When a single dictionary is passed, the DBAPI ``cursor.execute()``
         method will be used.

        :param execution_options: optional dictionary of execution options,
         which will be associated with the statement execution.  This
         dictionary can provide a subset of the options that are accepted
         by :meth:`_future.Connection.execution_options`.

        :raises AsyncMethodRequired: if the execution produced a
         server-side cursor; use :meth:`.stream` for that case.

        :return: a :class:`_engine.Result` object.

        """
        conn = self._sync_connection()

        result = await greenlet_spawn(
            conn._execute_20, statement, parameters, execution_options,
        )
        if result.context._is_server_side:
            raise async_exc.AsyncMethodRequired(
                "Can't use the connection.execute() method with a "
                "server-side cursor. "
                "Use the connection.stream() method for an async "
                "streaming result set."
            )
        return result

    async def scalar(
        self,
        statement: Executable,
        parameters: Optional[Mapping] = None,
        execution_options: Mapping = util.EMPTY_DICT,
    ) -> Any:
        r"""Executes a SQL statement construct and returns a scalar object.

        This method is shorthand for invoking the
        :meth:`_engine.Result.scalar` method after invoking the
        :meth:`.execute` method of this object.  Parameters are equivalent.

        :return: a scalar Python value representing the first column of the
         first row returned.

        """
        result = await self.execute(statement, parameters, execution_options)
        return result.scalar()

    async def run_sync(self, fn: Callable, *arg, **kw) -> Any:
        """Invoke the given sync callable, passing the proxied sync
        connection as the first argument.

        This method maintains the asyncio event loop all the way through
        to the database connection by running the given callable in a
        specially instrumented greenlet.

        E.g.::

            async with async_engine.begin() as conn:
                await conn.run_sync(metadata.create_all)

        :param fn: callable invoked as ``fn(sync_connection, *arg, **kw)``.

        :return: whatever ``fn`` returns.

        """

        conn = self._sync_connection()

        return await greenlet_spawn(fn, conn, *arg, **kw)

    def __await__(self):
        # ``await engine.connect()`` is equivalent to awaiting start()
        return self.start().__await__()

    async def __aexit__(self, type_, value, traceback):
        # leaving the ``async with`` block always closes the connection,
        # whether or not an exception is propagating
        await self.close()


class AsyncEngine:
    """An asyncio proxy for a :class:`_engine.Engine`.

    :class:`_asyncio.AsyncEngine` is acquired using the
    :func:`_asyncio.create_async_engine` function::

        from sqlalchemy.ext.asyncio import create_async_engine
        engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname")

    .. versionadded:: 1.4


    """  # noqa

    __slots__ = ("sync_engine",)

    # class used by connect() to build connection proxies
    _connection_cls = AsyncConnection

    # declared here, assigned at module level once the subclass exists
    _option_cls: type

    class _trans_ctx(StartableContext):
        """Context returned by :meth:`.AsyncEngine.begin`.

        Starts a connection, begins a transaction on it, and on exit
        commits or rolls back the transaction and closes the connection.

        """

        def __init__(self, conn):
            self.conn = conn

        async def start(self):
            """Start the connection and its transaction; return the
            connection."""
            await self.conn.start()
            self.transaction = self.conn.begin()
            await self.transaction.__aenter__()

            return self.conn

        async def __aexit__(self, type_, value, traceback):
            try:
                if type_ is not None:
                    await self.transaction.rollback()
                elif self.transaction.is_active:
                    await self.transaction.commit()
            finally:
                # Close even when commit()/rollback() raises; otherwise the
                # connection would never be released by this context.
                await self.conn.close()

    def __init__(self, sync_engine: Engine):
        """Construct the proxy around the given sync
        :class:`_engine.Engine`."""
        self.sync_engine = sync_engine

    def begin(self):
        """Return a context manager which when entered will deliver an
        :class:`_asyncio.AsyncConnection` with an
        :class:`_asyncio.AsyncTransaction` established.

        E.g.::

            async with async_engine.begin() as conn:
                await conn.execute(
                    text("insert into table (x, y, z) values (1, 2, 3)")
                )
                await conn.execute(text("my_special_procedure(5)"))

        On exit the transaction is committed if still active and no
        exception is propagating, rolled back if an exception is
        propagating, and the connection is closed in either case.

        """
        conn = self.connect()
        return self._trans_ctx(conn)

    def connect(self) -> AsyncConnection:
        """Return an :class:`_asyncio.AsyncConnection` object.

        The :class:`_asyncio.AsyncConnection` will procure a database
        connection from the underlying connection pool when it is entered
        as an async context manager::

            async with async_engine.connect() as conn:
                result = await conn.execute(select(user_table))

        The :class:`_asyncio.AsyncConnection` may also be started outside of a
        context manager by invoking its :meth:`_asyncio.AsyncConnection.start`
        method.

        """

        return self._connection_cls(self.sync_engine)

    async def raw_connection(self) -> Any:
        """Return a "raw" DBAPI connection from the connection pool.

        The sync engine's ``raw_connection()`` method is invoked inside the
        instrumented greenlet and its result is returned as is.

        .. seealso::

            :ref:`dbapi_connections`

        """
        return await greenlet_spawn(self.sync_engine.raw_connection)


class AsyncOptionEngine(OptionEngineMixin, AsyncEngine):
    """An :class:`AsyncEngine` combined with ``OptionEngineMixin``.

    Defines no members of its own; all behavior comes from the two bases.

    """


# Bind the ``_option_cls`` attribute that AsyncEngine only declares as an
# annotation.  It is assigned here because AsyncOptionEngine subclasses
# AsyncEngine and so does not exist while AsyncEngine's body is executing.
AsyncEngine._option_cls = AsyncOptionEngine


class AsyncTransaction(StartableContext):
    """An asyncio proxy for a :class:`_engine.Transaction`.

    The sync transaction is created lazily by :meth:`.start`; until then
    the state-dependent members of this object are unavailable.

    """

    __slots__ = ("connection", "sync_transaction", "nested")

    def __init__(self, connection: AsyncConnection, nested: bool = False):
        self.nested = nested
        self.connection = connection
        # populated by start()
        self.sync_transaction: Optional[Transaction] = None

    def _sync_transaction(self):
        """Return the proxied sync transaction, failing if not started."""
        sync_trans = self.sync_transaction
        if not sync_trans:
            self._raise_for_not_started()
        return sync_trans

    @property
    def is_valid(self) -> bool:
        """The ``is_valid`` flag of the proxied sync transaction."""
        sync_trans = self._sync_transaction()
        return sync_trans.is_valid

    @property
    def is_active(self) -> bool:
        """The ``is_active`` flag of the proxied sync transaction."""
        sync_trans = self._sync_transaction()
        return sync_trans.is_active

    async def close(self):
        """Close this :class:`.Transaction`.

        When this transaction is the base transaction of a begin/commit
        nesting it is rolled back; otherwise the method simply returns.

        This allows a Transaction to be cancelled without affecting the
        scope of an enclosing transaction.

        """
        sync_trans = self._sync_transaction()
        await greenlet_spawn(sync_trans.close)

    async def rollback(self):
        """Roll back this :class:`.Transaction`."""
        sync_trans = self._sync_transaction()
        await greenlet_spawn(sync_trans.rollback)

    async def commit(self):
        """Commit this :class:`.Transaction`."""
        sync_trans = self._sync_transaction()
        await greenlet_spawn(sync_trans.commit)

    async def start(self):
        """Start this :class:`_asyncio.AsyncTransaction` object's context
        outside of using a Python ``with:`` block.

        Invokes ``begin_nested()`` on the sync connection when this object
        was constructed with ``nested=True``, otherwise ``begin()``.

        :return: this :class:`_asyncio.AsyncTransaction`.

        """
        sync_conn = self.connection._sync_connection()
        if self.nested:
            begin_fn = sync_conn.begin_nested
        else:
            begin_fn = sync_conn.begin
        self.sync_transaction = await greenlet_spawn(begin_fn)
        return self

    async def __aexit__(self, type_, value, traceback):
        # An exception is propagating, or the transaction is no longer
        # active: there is nothing to commit, so roll back.
        if type_ is not None or not self.is_active:
            await self.rollback()
            return

        try:
            await self.commit()
        except BaseException:
            # roll back, then re-raise the commit failure
            with util.safe_reraise():
                await self.rollback()


def _get_sync_engine(async_engine):
    """Return the ``sync_engine`` attribute of the given object.

    :param async_engine: object expected to expose a ``sync_engine``
     attribute, such as an :class:`AsyncEngine`.

    :raises ArgumentError: if the object has no ``sync_engine`` attribute.

    :return: the value of ``async_engine.sync_engine``.

    """
    try:
        return async_engine.sync_engine
    except AttributeError as e:
        # Wrap the operand in a 1-tuple: a bare tuple argument would be
        # unpacked by ``%`` and raise TypeError instead of ArgumentError.
        raise exc.ArgumentError(
            "AsyncEngine expected, got %r" % (async_engine,)
        ) from e