summaryrefslogtreecommitdiff
path: root/pymemcache
diff options
context:
space:
mode:
authorNick Pope <nick.pope@flightdataservices.com>2020-09-02 20:43:31 +0100
committerNick Pope <nick.pope@flightdataservices.com>2020-09-03 20:14:44 +0100
commit9551dfd08e29728b5c77530536c54ad45d75a60e (patch)
treee60aed106410dbf636055be366b7ac384fe6b782 /pymemcache
parent73621c0feca2a636c5bda643485b8223948e0295 (diff)
downloadpymemcache-9551dfd08e29728b5c77530536c54ad45d75a60e.tar.gz
Added support for connections over IPv6.
Fixes #257.
Diffstat (limited to 'pymemcache')
-rw-r--r--pymemcache/client/base.py69
-rw-r--r--pymemcache/client/hash.py9
-rw-r--r--pymemcache/test/test_client.py67
-rw-r--r--pymemcache/test/test_client_hash.py17
4 files changed, 135 insertions, 27 deletions
diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py
index 7b3ed6d..3d790bc 100644
--- a/pymemcache/client/base.py
+++ b/pymemcache/client/base.py
@@ -109,6 +109,27 @@ def check_key_helper(key, allow_unicode_keys, key_prefix=b''):
return key
+def normalize_server_spec(server):
+ if isinstance(server, tuple) or server is None:
+ return server
+ if isinstance(server, list):
+ return tuple(server) # Assume [host, port] provided.
+ if not isinstance(server, six.string_types):
+ raise ValueError('Unknown server provided: %r' % server)
+ if server.startswith('unix:'):
+ return server[5:]
+ if server.startswith('/'):
+ return server
+ if ':' not in server or server.endswith(']'):
+ host, port = server, 11211
+ else:
+ host, port = server.rsplit(':', 1)
+ port = int(port)
+ if host.startswith('['):
+ host = host.strip('[]')
+ return (host, port)
+
+
class Client(object):
"""
A client for a single memcached server.
@@ -253,7 +274,7 @@ class Client(object):
The constructor does not make a connection to memcached. The first
call to a method on the object will do that.
"""
- self.server = server
+ self.server = normalize_server_spec(server)
self.serde = serde or LegacyWrappingSerde(serializer, deserializer)
self.connect_timeout = connect_timeout
self.timeout = timeout
@@ -279,25 +300,41 @@ class Client(object):
def _connect(self):
self.close()
- if isinstance(self.server, (list, tuple)):
- sock = self.socket_module.socket(self.socket_module.AF_INET,
- self.socket_module.SOCK_STREAM)
+ s = self.socket_module
+
+ if not isinstance(self.server, tuple):
+ sockaddr = self.server
+ sock = s.socket(s.AF_UNIX, s.SOCK_STREAM)
- if self.tls_context:
- sock = self.tls_context.wrap_socket(
- sock, server_hostname=self.server[0]
- )
else:
- sock = self.socket_module.socket(self.socket_module.AF_UNIX,
- self.socket_module.SOCK_STREAM)
+ sock = None
+ error = None
+ host, port = self.server
+ info = s.getaddrinfo(host, port, s.AF_UNSPEC, s.SOCK_STREAM,
+ s.IPPROTO_TCP)
+ for family, socktype, proto, _, sockaddr in info:
+ try:
+ sock = s.socket(family, socktype, proto)
+ if self.no_delay:
+ sock.setsockopt(s.IPPROTO_TCP, s.TCP_NODELAY, 1)
+ if self.tls_context:
+ context = self.tls_context
+ sock = context.wrap_socket(sock, server_hostname=host)
+ except Exception as e:
+ error = e
+ if sock is not None:
+ sock.close()
+ sock = None
+ else:
+ break
+
+ if error is not None:
+ raise error
+
try:
sock.settimeout(self.connect_timeout)
- sock.connect(self.server)
+ sock.connect(sockaddr)
sock.settimeout(self.timeout)
- if self.no_delay and sock.family == self.socket_module.AF_INET:
- sock.setsockopt(self.socket_module.IPPROTO_TCP,
- self.socket_module.TCP_NODELAY, 1)
-
except Exception:
sock.close()
raise
@@ -1030,7 +1067,7 @@ class PooledClient(object):
allow_unicode_keys=False,
encoding='ascii',
tls_context=None):
- self.server = server
+ self.server = normalize_server_spec(server)
self.serde = serde or LegacyWrappingSerde(serializer, deserializer)
self.connect_timeout = connect_timeout
self.timeout = timeout
diff --git a/pymemcache/client/hash.py b/pymemcache/client/hash.py
index 9c61cb6..9e273ee 100644
--- a/pymemcache/client/hash.py
+++ b/pymemcache/client/hash.py
@@ -4,7 +4,12 @@ import time
import logging
import six
-from pymemcache.client.base import Client, PooledClient, check_key_helper
+from pymemcache.client.base import (
+ Client,
+ PooledClient,
+ check_key_helper,
+ normalize_server_spec,
+)
from pymemcache.client.rendezvous import RendezvousHash
from pymemcache.exceptions import MemcacheError
@@ -103,7 +108,7 @@ class HashClient(object):
})
for server in servers:
- self.add_server(server)
+ self.add_server(normalize_server_spec(server))
self.encoding = encoding
self.tls_context = tls_context
diff --git a/pymemcache/test/test_client.py b/pymemcache/test/test_client.py
index 96e6551..7b8ac3b 100644
--- a/pymemcache/test/test_client.py
+++ b/pymemcache/test/test_client.py
@@ -21,12 +21,13 @@ import functools
import json
import os
import mock
+import re
import socket
import unittest
import pytest
-from pymemcache.client.base import PooledClient, Client
+from pymemcache.client.base import PooledClient, Client, normalize_server_spec
from pymemcache.exceptions import (
MemcacheClientError,
MemcacheServerError,
@@ -52,7 +53,10 @@ class MockSocket(object):
@property
def family(self):
- return socket.AF_INET
+ # TODO: Use ipaddress module when dropping support for Python < 3.3
+ ipv6_re = re.compile(r'^[0-9a-f:]+$')
+ is_ipv6 = any(ipv6_re.match(c[0]) for c in self.connections)
+ return socket.AF_INET6 if is_ipv6 else socket.AF_INET
def sendall(self, value):
self.send_bufs.append(value)
@@ -103,7 +107,7 @@ class MockSocketModule(object):
self.close_failure = close_failure
self.sockets = []
- def socket(self, family, type):
+ def socket(self, family, type, proto=0, fileno=None):
socket = MockSocket(
[],
connect_failure=self.connect_failure,
@@ -1075,12 +1079,40 @@ class TestClient(ClientTestMixin, unittest.TestCase):
@pytest.mark.unit()
class TestClientSocketConnect(unittest.TestCase):
- def test_socket_connect(self):
- server = ("example.com", 11211)
+ def test_socket_connect_ipv4(self):
+ server = ('127.0.0.1', 11211)
client = Client(server, socket_module=MockSocketModule())
client._connect()
+ print(client.sock.connections)
assert client.sock.connections == [server]
+ assert client.sock.family == socket.AF_INET
+
+ timeout = 2
+ connect_timeout = 3
+ client = Client(
+ server, connect_timeout=connect_timeout, timeout=timeout,
+ socket_module=MockSocketModule())
+ client._connect()
+ assert client.sock.timeouts == [connect_timeout, timeout]
+
+ client = Client(server, socket_module=MockSocketModule())
+ client._connect()
+ assert client.sock.socket_options == []
+
+ client = Client(
+ server, socket_module=MockSocketModule(), no_delay=True)
+ client._connect()
+ assert client.sock.socket_options == [(socket.IPPROTO_TCP,
+ socket.TCP_NODELAY, 1)]
+
+ def test_socket_connect_ipv6(self):
+ server = ('::1', 11211)
+
+ client = Client(server, socket_module=MockSocketModule())
+ client._connect()
+ assert client.sock.connections == [server + (0, 0)]
+ assert client.sock.family == socket.AF_INET6
timeout = 2
connect_timeout = 3
@@ -1330,3 +1362,28 @@ class TestRetryOnEINTR(unittest.TestCase):
b'ue1\r\nEND\r\n',
])
assert client[b'key1'] == b'value1'
+
+
+@pytest.mark.unit()
+class TestNormalizeServerSpec(unittest.TestCase):
+ def test_normalize_server_spec(self):
+ f = normalize_server_spec
+ assert f(None) is None
+ assert f(('127.0.0.1', 12345)) == ('127.0.0.1', 12345)
+ assert f(['127.0.0.1', 12345]) == ('127.0.0.1', 12345)
+ assert f('unix:/run/memcached/socket') == '/run/memcached/socket'
+ assert f('/run/memcached/socket') == '/run/memcached/socket'
+ assert f('localhost') == ('localhost', 11211)
+ assert f('localhost:12345') == ('localhost', 12345)
+ assert f('[::1]') == ('::1', 11211)
+ assert f('[::1]:12345') == ('::1', 12345)
+ assert f('127.0.0.1') == ('127.0.0.1', 11211)
+ assert f('127.0.0.1:12345') == ('127.0.0.1', 12345)
+
+ with pytest.raises(ValueError) as excinfo:
+ f({'host': 12345})
+ assert str(excinfo.value) == "Unknown server provided: {'host': 12345}"
+
+ with pytest.raises(ValueError) as excinfo:
+ f(12345)
+ assert str(excinfo.value) == "Unknown server provided: 12345"
diff --git a/pymemcache/test/test_client_hash.py b/pymemcache/test/test_client_hash.py
index 4ffad5c..5dd4ec4 100644
--- a/pymemcache/test/test_client_hash.py
+++ b/pymemcache/test/test_client_hash.py
@@ -372,12 +372,21 @@ class TestHashClient(ClientTestMixin, unittest.TestCase):
assert isinstance(c, MyClient)
def test_mixed_inet_and_unix_sockets(self):
- servers = [
+ expected = {
'/tmp/pymemcache.{pid}'.format(pid=os.getpid()),
('127.0.0.1', 11211),
- ]
- client = HashClient(servers)
- assert set(servers) == {c.server for c in client.clients.values()}
+ ('::1', 11211),
+ }
+ client = HashClient([
+ '/tmp/pymemcache.{pid}'.format(pid=os.getpid()),
+ '127.0.0.1',
+ '127.0.0.1:11211',
+ '[::1]',
+ '[::1]:11211',
+ ('127.0.0.1', 11211),
+ ('::1', 11211),
+ ])
+ assert expected == {c.server for c in client.clients.values()}
def test_legacy_add_remove_server_signature(self):
server = ('127.0.0.1', 11211)