diff options
author | Jonathan Beluch <jbeluch@squarespace.com> | 2016-08-23 15:10:53 -0400 |
---|---|---|
committer | Jonathan Beluch <jbeluch@squarespace.com> | 2016-08-24 14:16:10 -0400 |
commit | edf8380b90f60f8e7a4f7b0d5edef5c4aa563279 (patch) | |
tree | 67809ad2997f0fc874d1e7b62cef9d6a81dece9a | |
parent | f10eba00ea7c92315b4b39c69627058ad4931448 (diff) | |
download | sqlalchemy-pr/301.tar.gz |
Fix conn close behavior for thread local connectionspr/301
-rw-r--r-- | lib/sqlalchemy/engine/threadlocal.py | 10 | ||||
-rw-r--r-- | test/engine/test_bind.py | 17 |
2 files changed, 21 insertions, 6 deletions
diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py index 505d1fadd..badfde8dd 100644 --- a/lib/sqlalchemy/engine/threadlocal.py +++ b/lib/sqlalchemy/engine/threadlocal.py @@ -64,6 +64,8 @@ class TLEngine(base.Engine): self.pool.connect, connection), **kw) self._connections.conn = weakref.ref(connection) + else: + connection.should_close_with_result = kw.get('close_with_result', False) return connection._increment_connect() @@ -81,11 +83,9 @@ class TLEngine(base.Engine): self.contextual_connect().begin_nested()) return self - def begin(self): - if not hasattr(self._connections, 'trans'): - self._connections.trans = [] - self._connections.trans.append(self.contextual_connect().begin()) - return self + def begin(self, close_with_result=False): + conn = self.contextual_connect() + return base.Engine._trans_ctx(conn, conn.begin(), close_with_result) def __enter__(self): return self diff --git a/test/engine/test_bind.py b/test/engine/test_bind.py index 69ab721c1..14c2e0162 100644 --- a/test/engine/test_bind.py +++ b/test/engine/test_bind.py @@ -2,7 +2,7 @@ including the deprecated versions of these arguments""" from sqlalchemy.testing import assert_raises, assert_raises_message -from sqlalchemy import engine, exc +from sqlalchemy import engine, exc, create_engine from sqlalchemy import MetaData, ThreadLocalMetaData from sqlalchemy import Integer, text from sqlalchemy.testing.schema import Table @@ -23,6 +23,21 @@ class BindTest(fixtures.TestBase): assert not conn.closed assert conn.closed + def test_tlbind_close_conn(self): + e = create_engine(testing.db.url, strategy='threadlocal') + conn = e.contextual_connect() + e.execute('select 1').fetchall() + conn.close() + assert conn.closed + + def test_tlbind_close_trans_conn(self): + e = create_engine(testing.db.url, strategy='threadlocal') + conn = e.contextual_connect() + with e.begin(): + pass + conn.close() + assert conn.closed + def test_bind_close_conn(self): e = testing.db conn = e.connect() |