summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/__init__.py21
-rwxr-xr-xtests/test_module.py20
2 files changed, 40 insertions, 1 deletions
diff --git a/lib/__init__.py b/lib/__init__.py
index 7676f3c..f42d081 100644
--- a/lib/__init__.py
+++ b/lib/__init__.py
@@ -97,6 +97,24 @@ else:
_ext.register_adapter(Decimal, Adapter)
del Decimal, Adapter
+import re
+
+def _param_escape(s,
+ re_escape=re.compile(r"([\\'])"),
+ re_space=re.compile(r'\s')):
+ """
+ Apply the escaping rule required by PQconnectdb
+ """
+ if not s: return "''"
+
+ s = re_escape.sub(r'\\\1', s)
+ if re_space.search(s):
+ s = "'" + s + "'"
+
+ return s
+
+del re
+
def connect(dsn=None,
database=None, user=None, password=None, host=None, port=None,
@@ -147,7 +165,8 @@ def connect(dsn=None,
items.extend(
[(k, v) for (k, v) in kwargs.iteritems() if v is not None])
- dsn = " ".join(["%s=%s" % item for item in items])
+ dsn = " ".join(["%s=%s" % (k, _param_escape(str(v)))
+ for (k, v) in items])
if not dsn:
raise InterfaceError('missing dsn and no parameters')
diff --git a/tests/test_module.py b/tests/test_module.py
index 66eeccf..5d45187 100755
--- a/tests/test_module.py
+++ b/tests/test_module.py
@@ -99,6 +99,26 @@ class ConnectTestCase(unittest.TestCase):
self.assertEqual(self.args[1], None)
self.assert_(self.args[2])
+ def test_empty_param(self):
+ psycopg2.connect(database='sony', password='')
+ self.assertEqual(self.args[0], "dbname=sony password=''")
+
+ def test_escape(self):
+ psycopg2.connect(database='hello world')
+ self.assertEqual(self.args[0], "dbname='hello world'")
+
+ psycopg2.connect(database=r'back\slash')
+ self.assertEqual(self.args[0], r"dbname=back\\slash")
+
+ psycopg2.connect(database="quo'te")
+ self.assertEqual(self.args[0], r"dbname=quo\'te")
+
+ psycopg2.connect(database="with\ttab")
+ self.assertEqual(self.args[0], "dbname='with\ttab'")
+
+ psycopg2.connect(database=r"\every thing'")
+ self.assertEqual(self.args[0], r"dbname='\\every thing\''")
+
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)