1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
|
# engine/threadlocal.py
# Copyright (C) 2005-2018 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Provides a thread-local transactional wrapper around the root Engine class.
The ``threadlocal`` module is invoked when using the
``strategy="threadlocal"`` flag with :func:`~sqlalchemy.engine.create_engine`.
This module is semi-private and is invoked automatically when the threadlocal
engine strategy is used.
"""
import weakref
from . import base
from .. import util
class TLConnection(base.Connection):
    """A ``base.Connection`` that counts how many times it was handed out.

    The counter is bumped by :meth:`_increment_connect` and reduced by
    :meth:`close`; the underlying connection is only really closed when
    :meth:`close` is called while the counter is exactly one, or when
    :meth:`_force_close` is used.
    """

    def __init__(self, *arg, **kw):
        """Forward all arguments to the parent and start with zero checkouts."""
        super(TLConnection, self).__init__(*arg, **kw)
        # Outstanding checkouts of this connection (name-mangled, so it is
        # private to this class).
        self.__opencount = 0

    def _increment_connect(self):
        """Register one more checkout and return this same connection."""
        self.__opencount = self.__opencount + 1
        return self

    def close(self):
        """Give back one checkout.

        The parent ``close()`` runs only when this is the single
        outstanding checkout; the counter is decremented afterwards on
        every call, whether or not the parent ``close()`` was invoked.
        """
        is_last_checkout = self.__opencount == 1
        if is_last_checkout:
            base.Connection.close(self)
        self.__opencount = self.__opencount - 1

    def _force_close(self):
        """Reset the checkout counter to zero and close unconditionally."""
        self.__opencount = 0
        base.Connection.close(self)
class TLEngine(base.Engine):
    """An Engine that includes support for thread-local managed
    transactions.

    Per-thread state lives on ``self._connections`` (a
    ``util.threading.local()`` instance):

    * ``conn`` -- a weak reference to the current connection, set by
      :meth:`contextual_connect`.
    * ``trans`` -- a list used as a stack of transactions started through
      :meth:`begin`, :meth:`begin_nested` and :meth:`begin_twophase`.
    """

    _tl_connection_cls = TLConnection

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent and create the local state."""
        super(TLEngine, self).__init__(*args, **kwargs)
        self._connections = util.threading.local()

    def contextual_connect(self, **kw):
        """Return the current connection, creating one when needed.

        A new ``_tl_connection_cls`` instance is built when no connection
        is recorded, when the recorded weak reference is dead, or when the
        recorded connection reports ``closed``.  Keyword arguments are only
        used when a new connection is built.

        :return: the connection, after its ``_increment_connect()`` call.
        """
        if not hasattr(self._connections, "conn"):
            connection = None
        else:
            # Dereference the weakref; yields None once the referent is gone.
            connection = self._connections.conn()

        # A pool-level validity check
        # (``not connection.connection.is_valid``) could be added to this
        # condition to guard against pool-level reapers, if desired.
        if connection is None or connection.closed:
            # The previous (missing or closed) connection is handed to
            # ``_wrap_pool_connect`` together with the pool's connect
            # callable.
            connection = self._tl_connection_cls(
                self,
                self._wrap_pool_connect(self.pool.connect, connection),
                **kw
            )
            # Only a weak reference is stored here.
            self._connections.conn = weakref.ref(connection)

        return connection._increment_connect()

    def begin_twophase(self, xid=None):
        """Begin a two-phase transaction and push it on the stack.

        :param xid: passed through to the connection's ``begin_twophase()``.
        :return: this engine.
        """
        if not hasattr(self._connections, "trans"):
            self._connections.trans = []
        self._connections.trans.append(
            self.contextual_connect().begin_twophase(xid=xid)
        )
        return self

    def begin_nested(self):
        """Begin a nested transaction and push it on the stack.

        :return: this engine.
        """
        if not hasattr(self._connections, "trans"):
            self._connections.trans = []
        self._connections.trans.append(
            self.contextual_connect().begin_nested()
        )
        return self

    def begin(self):
        """Begin a transaction and push it on the stack.

        :return: this engine, which also acts as a context manager (see
            ``__enter__`` / ``__exit__``).
        """
        if not hasattr(self._connections, "trans"):
            self._connections.trans = []
        self._connections.trans.append(self.contextual_connect().begin())
        return self

    def __enter__(self):
        """Return this engine; entering does not begin a transaction."""
        return self

    def __exit__(self, type_, value, traceback):
        """Commit when the block exited cleanly, otherwise roll back."""
        if type_ is None:
            self.commit()
        else:
            self.rollback()

    def prepare(self):
        """Call ``prepare()`` on the innermost transaction, if there is one.

        Does nothing when the transaction stack is missing or empty.
        """
        trans_stack = getattr(self._connections, "trans", None)
        if not trans_stack:
            return
        trans_stack[-1].prepare()

    def commit(self):
        """Pop the innermost transaction and commit it.

        Does nothing when the transaction stack is missing or empty.  The
        transaction is removed from the stack before ``commit()`` is
        called on it.
        """
        trans_stack = getattr(self._connections, "trans", None)
        if not trans_stack:
            return
        trans = trans_stack.pop(-1)
        trans.commit()

    def rollback(self):
        """Pop the innermost transaction and roll it back.

        Does nothing when the transaction stack is missing or empty.  The
        transaction is removed from the stack before ``rollback()`` is
        called on it.
        """
        trans_stack = getattr(self._connections, "trans", None)
        if not trans_stack:
            return
        trans = trans_stack.pop(-1)
        trans.rollback()

    def dispose(self):
        """Replace the local state with a fresh one, then dispose the parent."""
        self._connections = util.threading.local()
        super(TLEngine, self).dispose()

    @property
    def closed(self):
        """True when there is no live, open connection recorded.

        That is: no ``conn`` reference was stored, the weak reference is
        dead, or the referenced connection reports ``closed``.
        """
        if not hasattr(self._connections, "conn"):
            return True
        # Dereference the weakref exactly once and keep the strong
        # reference: calling it a second time could return None if the
        # referent was collected in between.
        connection = self._connections.conn()
        return connection is None or connection.closed

    def close(self):
        """Force-close the current connection and reset the local state.

        Does nothing when :attr:`closed` is already true.  Otherwise the
        connection is released once, force-closed, the stored ``conn``
        reference is deleted and the transaction stack is replaced with an
        empty list.
        """
        if not self.closed:
            # Hold a strong reference for the whole sequence rather than
            # re-dereferencing the weakref after the first close().
            connection = self.contextual_connect()
            connection.close()
            connection._force_close()
            del self._connections.conn
            self._connections.trans = []

    def __repr__(self):
        return "TLEngine(%r)" % self.url
|