diff options
author | Andy McCurdy <andy@andymccurdy.com> | 2014-05-13 15:01:00 -0700 |
---|---|---|
committer | Andy McCurdy <andy@andymccurdy.com> | 2014-05-13 15:01:00 -0700 |
commit | 61030206f101f385b1b4e04569c57b0f81fda754 (patch) | |
tree | 0d60a4bed5bd69387278ea4c53452ab02fa7aadf | |
parent | 17ff3774a74a251f17ffb1235b8d38a17fa1b13f (diff) | |
download | redis-py-61030206f101f385b1b4e04569c57b0f81fda754.tar.gz |
allow cert_reqs to be a string and convert it to the appropriate SSL constant.
-rwxr-xr-x | redis/connection.py | 10 | ||||
-rw-r--r-- | tests/test_connection_pool.py | 26 |
2 files changed, 36 insertions, 0 deletions
diff --git a/redis/connection.py b/redis/connection.py index 5607508..c209750 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -570,6 +570,16 @@ class SSLConnection(Connection): self.certfile = certfile if cert_reqs is None: cert_reqs = ssl.CERT_NONE + elif isinstance(cert_reqs, basestring): + CERT_REQS = { + 'none': ssl.CERT_NONE, + 'optional': ssl.CERT_OPTIONAL, + 'required': ssl.CERT_REQUIRED + } + if cert_reqs not in CERT_REQS: + raise RedisError( + "Invalid SSL Certificate Required Flag: %s" % cert_reqs) + cert_reqs = CERT_REQS[cert_reqs] self.cert_reqs = cert_reqs self.ca_certs = ca_certs diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 04d5e8c..934636b 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -6,6 +6,7 @@ import time import re from threading import Thread +from redis.connection import ssl_available from .conftest import skip_if_server_version_lt @@ -289,6 +290,31 @@ class TestConnectionPoolUnixSocketURLParsing(object): } +class TestSSLConnectionURLParsing(object): + @pytest.mark.skipif(not ssl_available, reason="SSL not installed") + def test_defaults(self): + pool = redis.ConnectionPool.from_url('rediss://localhost') + assert pool.connection_class == redis.SSLConnection + assert pool.connection_kwargs == { + 'host': 'localhost', + 'port': 6379, + 'db': 0, + 'password': None, + } + + @pytest.mark.skipif(not ssl_available, reason="SSL not installed") + def test_cert_reqs_options(self): + import ssl + pool = redis.ConnectionPool.from_url('rediss://?cert_reqs=none') + assert pool.get_connection('_').cert_reqs == ssl.CERT_NONE + + pool = redis.ConnectionPool.from_url('rediss://?cert_reqs=optional') + assert pool.get_connection('_').cert_reqs == ssl.CERT_OPTIONAL + + pool = redis.ConnectionPool.from_url('rediss://?cert_reqs=required') + assert pool.get_connection('_').cert_reqs == ssl.CERT_REQUIRED + + class TestConnection(object): def test_on_connect_error(self): """ |