summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>2015-12-16 12:04:14 +0000
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>2015-12-16 12:04:14 +0000
commit452fd56e044d49d7d0750f99bfa9381645af6bae (patch)
tree5f4bfd443908229bb5f9be67a9b9ac1db8f72ac7
parentfe4cb0d49353f56328b9981a5140ecda65e972b4 (diff)
parent5fd0f6c4eefecb0d6150179c32c43d16c11b173d (diff)
downloadpsycopg2-452fd56e044d49d7d0750f99bfa9381645af6bae.tar.gz
Merge branch 'bug-382'
-rw-r--r--NEWS1
-rw-r--r--lib/errorcodes.py10
-rwxr-xr-xtests/__init__.py2
-rwxr-xr-xtests/test_errcodes.py65
4 files changed, 76 insertions, 2 deletions
diff --git a/NEWS b/NEWS
index 5200c4d..c1e4152 100644
--- a/NEWS
+++ b/NEWS
@@ -27,6 +27,7 @@ What's new in psycopg 2.6.2
- Raise `!NotSupportedError` on unhandled server response status
(:ticket:`#352`).
- Fixed `!PersistentConnectionPool` on Python 3 (:ticket:`#348`).
+- Fixed `!errorcodes.lookup` initialization thread-safety (:ticket:`#382`).
What's new in psycopg 2.6.1
diff --git a/lib/errorcodes.py b/lib/errorcodes.py
index 12c300f..aa5a723 100644
--- a/lib/errorcodes.py
+++ b/lib/errorcodes.py
@@ -38,11 +38,17 @@ def lookup(code, _cache={}):
return _cache[code]
# Generate the lookup map at first usage.
+ tmp = {}
for k, v in globals().iteritems():
if isinstance(v, str) and len(v) in (2, 5):
- _cache[v] = k
+ tmp[v] = k
- return lookup(code)
+ assert tmp
+
+ # Atomic update, to avoid race condition on import (bug #382)
+ _cache.update(tmp)
+
+ return _cache[code]
# autogenerated data: do not edit below this point.
diff --git a/tests/__init__.py b/tests/__init__.py
index 3e677d8..3e0db77 100755
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -34,6 +34,7 @@ import test_connection
import test_copy
import test_cursor
import test_dates
+import test_errcodes
import test_extras_dictcursor
import test_green
import test_lobject
@@ -71,6 +72,7 @@ def test_suite():
suite.addTest(test_copy.test_suite())
suite.addTest(test_cursor.test_suite())
suite.addTest(test_dates.test_suite())
+ suite.addTest(test_errcodes.test_suite())
suite.addTest(test_extras_dictcursor.test_suite())
suite.addTest(test_green.test_suite())
suite.addTest(test_lobject.test_suite())
diff --git a/tests/test_errcodes.py b/tests/test_errcodes.py
new file mode 100755
index 0000000..6cf5ddb
--- /dev/null
+++ b/tests/test_errcodes.py
@@ -0,0 +1,65 @@
+#!/usr/bin/env python
+
+# test_errcodes.py - unit test for psycopg2.errcodes module
+#
+# Copyright (C) 2015 Daniele Varrazzo <daniele.varrazzo@gmail.com>
+#
+# psycopg2 is free software: you can redistribute it and/or modify it
+# under the terms of the GNU Lesser General Public License as published
+# by the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# In addition, as a special exception, the copyright holders give
+# permission to link this program with the OpenSSL library (or with
+# modified versions of OpenSSL that use the same license as OpenSSL),
+# and distribute linked combinations including the two.
+#
+# You must obey the GNU Lesser General Public License in all respects for
+# all of the code used other than OpenSSL.
+#
+# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
+# License for more details.
+
+from testutils import unittest, ConnectingTestCase
+
+try:
+ reload
+except NameError:
+ from imp import reload
+
+from threading import Thread
+from psycopg2 import errorcodes
+
+class ErrocodeTests(ConnectingTestCase):
+ def test_lookup_threadsafe(self):
+
+ # Increase if it does not fail with KeyError
+ MAX_CYCLES = 2000
+
+ errs = []
+ def f(pg_code='40001'):
+ try:
+ errorcodes.lookup(pg_code)
+ except Exception, e:
+ errs.append(e)
+
+ for __ in xrange(MAX_CYCLES):
+ reload(errorcodes)
+ (t1, t2) = (Thread(target=f), Thread(target=f))
+ (t1.start(), t2.start())
+ (t1.join(), t2.join())
+
+ if errs:
+ self.fail(
+ "raised %s errors in %s cycles (first is %s %s)" % (
+ len(errs), MAX_CYCLES,
+ errs[0].__class__.__name__, errs[0]))
+
+
+def test_suite():
+ return unittest.TestLoader().loadTestsFromName(__name__)
+
+if __name__ == "__main__":
+ unittest.main()