diff options
author | Nick Pope <nick.pope@flightdataservices.com> | 2020-09-02 20:43:31 +0100 |
---|---|---|
committer | Nick Pope <nick.pope@flightdataservices.com> | 2020-09-03 20:14:44 +0100 |
commit | 9551dfd08e29728b5c77530536c54ad45d75a60e (patch) | |
tree | e60aed106410dbf636055be366b7ac384fe6b782 /pymemcache | |
parent | 73621c0feca2a636c5bda643485b8223948e0295 (diff) | |
download | pymemcache-9551dfd08e29728b5c77530536c54ad45d75a60e.tar.gz |
Added support for connections over IPv6.
Fixes #257.
Diffstat (limited to 'pymemcache')
-rw-r--r-- | pymemcache/client/base.py | 69 | ||||
-rw-r--r-- | pymemcache/client/hash.py | 9 | ||||
-rw-r--r-- | pymemcache/test/test_client.py | 67 | ||||
-rw-r--r-- | pymemcache/test/test_client_hash.py | 17 |
4 files changed, 135 insertions, 27 deletions
diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py index 7b3ed6d..3d790bc 100644 --- a/pymemcache/client/base.py +++ b/pymemcache/client/base.py @@ -109,6 +109,27 @@ def check_key_helper(key, allow_unicode_keys, key_prefix=b''): return key +def normalize_server_spec(server): + if isinstance(server, tuple) or server is None: + return server + if isinstance(server, list): + return tuple(server) # Assume [host, port] provided. + if not isinstance(server, six.string_types): + raise ValueError('Unknown server provided: %r' % server) + if server.startswith('unix:'): + return server[5:] + if server.startswith('/'): + return server + if ':' not in server or server.endswith(']'): + host, port = server, 11211 + else: + host, port = server.rsplit(':', 1) + port = int(port) + if host.startswith('['): + host = host.strip('[]') + return (host, port) + + class Client(object): """ A client for a single memcached server. @@ -253,7 +274,7 @@ class Client(object): The constructor does not make a connection to memcached. The first call to a method on the object will do that. """ - self.server = server + self.server = normalize_server_spec(server) self.serde = serde or LegacyWrappingSerde(serializer, deserializer) self.connect_timeout = connect_timeout self.timeout = timeout @@ -279,25 +300,41 @@ class Client(object): def _connect(self): self.close() - if isinstance(self.server, (list, tuple)): - sock = self.socket_module.socket(self.socket_module.AF_INET, - self.socket_module.SOCK_STREAM) + s = self.socket_module + + if not isinstance(self.server, tuple): + sockaddr = self.server + sock = s.socket(s.AF_UNIX, s.SOCK_STREAM) - if self.tls_context: - sock = self.tls_context.wrap_socket( - sock, server_hostname=self.server[0] - ) else: - sock = self.socket_module.socket(self.socket_module.AF_UNIX, - self.socket_module.SOCK_STREAM) + sock = None + error = None + host, port = self.server + info = s.getaddrinfo(host, port, s.AF_UNSPEC, s.SOCK_STREAM, + s.IPPROTO_TCP) + for family, socktype, proto, _, sockaddr in info: + try: + sock = s.socket(family, socktype, proto) + if self.no_delay: + sock.setsockopt(s.IPPROTO_TCP, s.TCP_NODELAY, 1) + if self.tls_context: + context = self.tls_context + sock = context.wrap_socket(sock, server_hostname=host) + except Exception as e: + error = e + if sock is not None: + sock.close() + sock = None + else: + break + + if error is not None: + raise error + try: sock.settimeout(self.connect_timeout) - sock.connect(self.server) + sock.connect(sockaddr) sock.settimeout(self.timeout) - if self.no_delay and sock.family == self.socket_module.AF_INET: - sock.setsockopt(self.socket_module.IPPROTO_TCP, - self.socket_module.TCP_NODELAY, 1) - except Exception: sock.close() raise @@ -1030,7 +1067,7 @@ class PooledClient(object): allow_unicode_keys=False, encoding='ascii', tls_context=None): - self.server = server + self.server = normalize_server_spec(server) self.serde = serde or LegacyWrappingSerde(serializer, deserializer) self.connect_timeout = connect_timeout self.timeout = timeout diff --git a/pymemcache/client/hash.py b/pymemcache/client/hash.py index 9c61cb6..9e273ee 100644 --- a/pymemcache/client/hash.py +++ b/pymemcache/client/hash.py @@ -4,7 +4,12 @@ import time import logging import six -from pymemcache.client.base import Client, PooledClient, check_key_helper +from pymemcache.client.base import ( + Client, + PooledClient, + check_key_helper, + normalize_server_spec, +) from pymemcache.client.rendezvous import RendezvousHash from pymemcache.exceptions import MemcacheError @@ -103,7 +108,7 @@ class HashClient(object): }) for server in servers: - self.add_server(server) + self.add_server(normalize_server_spec(server)) self.encoding = encoding self.tls_context = tls_context diff --git a/pymemcache/test/test_client.py b/pymemcache/test/test_client.py index 96e6551..7b8ac3b 100644 --- a/pymemcache/test/test_client.py +++ b/pymemcache/test/test_client.py @@ -21,12 +21,13 @@ import functools import json import os import mock +import re import socket import unittest import pytest -from pymemcache.client.base import PooledClient, Client +from pymemcache.client.base import PooledClient, Client, normalize_server_spec from pymemcache.exceptions import ( MemcacheClientError, MemcacheServerError, @@ -52,7 +53,10 @@ class MockSocket(object): @property def family(self): - return socket.AF_INET + # TODO: Use ipaddress module when dropping support for Python < 3.3 + ipv6_re = re.compile(r'^[0-9a-f:]+$') + is_ipv6 = any(ipv6_re.match(c[0]) for c in self.connections) + return socket.AF_INET6 if is_ipv6 else socket.AF_INET def sendall(self, value): self.send_bufs.append(value) @@ -103,7 +107,7 @@ class MockSocketModule(object): self.close_failure = close_failure self.sockets = [] - def socket(self, family, type): + def socket(self, family, type, proto=0, fileno=None): socket = MockSocket( [], connect_failure=self.connect_failure, @@ -1075,12 +1079,40 @@ class TestClient(ClientTestMixin, unittest.TestCase): @pytest.mark.unit() class TestClientSocketConnect(unittest.TestCase): - def test_socket_connect(self): - server = ("example.com", 11211) + def test_socket_connect_ipv4(self): + server = ('127.0.0.1', 11211) client = Client(server, socket_module=MockSocketModule()) client._connect() + print(client.sock.connections) assert client.sock.connections == [server] + assert client.sock.family == socket.AF_INET + + timeout = 2 + connect_timeout = 3 + client = Client( + server, connect_timeout=connect_timeout, timeout=timeout, + socket_module=MockSocketModule()) + client._connect() + assert client.sock.timeouts == [connect_timeout, timeout] + + client = Client(server, socket_module=MockSocketModule()) + client._connect() + assert client.sock.socket_options == [] + + client = Client( + server, socket_module=MockSocketModule(), no_delay=True) + client._connect() + assert client.sock.socket_options == [(socket.IPPROTO_TCP, + socket.TCP_NODELAY, 1)] + + def test_socket_connect_ipv6(self): + server = ('::1', 11211) + + client = Client(server, socket_module=MockSocketModule()) + client._connect() + assert client.sock.connections == [server + (0, 0)] + assert client.sock.family == socket.AF_INET6 timeout = 2 connect_timeout = 3 @@ -1330,3 +1362,28 @@ class TestRetryOnEINTR(unittest.TestCase): b'ue1\r\nEND\r\n', ]) assert client[b'key1'] == b'value1' + + +@pytest.mark.unit() +class TestNormalizeServerSpec(unittest.TestCase): + def test_normalize_server_spec(self): + f = normalize_server_spec + assert f(None) is None + assert f(('127.0.0.1', 12345)) == ('127.0.0.1', 12345) + assert f(['127.0.0.1', 12345]) == ('127.0.0.1', 12345) + assert f('unix:/run/memcached/socket') == '/run/memcached/socket' + assert f('/run/memcached/socket') == '/run/memcached/socket' + assert f('localhost') == ('localhost', 11211) + assert f('localhost:12345') == ('localhost', 12345) + assert f('[::1]') == ('::1', 11211) + assert f('[::1]:12345') == ('::1', 12345) + assert f('127.0.0.1') == ('127.0.0.1', 11211) + assert f('127.0.0.1:12345') == ('127.0.0.1', 12345) + + with pytest.raises(ValueError) as excinfo: + f({'host': 12345}) + assert str(excinfo.value) == "Unknown server provided: {'host': 12345}" + + with pytest.raises(ValueError) as excinfo: + f(12345) + assert str(excinfo.value) == "Unknown server provided: 12345" diff --git a/pymemcache/test/test_client_hash.py b/pymemcache/test/test_client_hash.py index 4ffad5c..5dd4ec4 100644 --- a/pymemcache/test/test_client_hash.py +++ b/pymemcache/test/test_client_hash.py @@ -372,12 +372,21 @@ class TestHashClient(ClientTestMixin, unittest.TestCase): assert isinstance(c, MyClient) def test_mixed_inet_and_unix_sockets(self): - servers = [ + expected = { '/tmp/pymemcache.{pid}'.format(pid=os.getpid()), ('127.0.0.1', 11211), - ] - client = HashClient(servers) - assert set(servers) == {c.server for c in client.clients.values()} + ('::1', 11211), + } + client = HashClient([ + '/tmp/pymemcache.{pid}'.format(pid=os.getpid()), + '127.0.0.1', + '127.0.0.1:11211', + '[::1]', + '[::1]:11211', + ('127.0.0.1', 11211), + ('::1', 11211), + ]) + assert expected == {c.server for c in client.clients.values()} def test_legacy_add_remove_server_signature(self): server = ('127.0.0.1', 11211) |