summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlexander Dutton <gh@alexdutton.co.uk>2016-12-29 10:55:49 +0000
committerAsif Saifuddin Auvi <auvipy@users.noreply.github.com>2016-12-29 16:55:49 +0600
commit787793c66e1027e3cfd5585bfb374e163ec3948c (patch)
tree24bd612772e0699ece93d1e288197d8ee9249998
parentbb0be0e9333b134ed76e091e7fa2223c0617e9fc (diff)
downloadpy-amqp-787793c66e1027e3cfd5585bfb374e163ec3948c.tar.gz
Refactored to make authn modular; added GSSAPI (#110)
* Refactored to make authn modular; added GSSAPI * Getting tests passing again * Add missing __future__ import to new source file * Made it more obvious that login_method/login_response not heeded However, we could create a SASL subclass that provides the functionality for backwards compatibility. * Also try PLAIN if we're given a userid/password. * Squashing flakes * Reimplemented login_method/login_response; flake fixing * An extra test; deflaking * Remove a blank line to sate flake8. Great. * SASL tests! Also wrapped getting GSSAPI mech impl in a function, so we can test it with and without a `gssapi` module present. * Fixing tests I'd misnamed my test class while debugging something, and not reverted it. Now the new tests are being run, another mistake appeared, now fixed. * More tests for the coverage gods * Attempt GSSAPI (but don't fail hard), and use userid This adds fail_soft and client_name arguments to GSSAPI, so that: * amqp.connection can add GSSAPI to its default list of auth mechanisms, but not fail if the gssapi module isn't available, or if GSSAPI credentials aren't available * A user can use userid to specify which GSSAPI client credentials should be used. Authentication mechanisms can also now return NotImplemented if they decline to do SASL negotiation at start() time. * Added another test for GSSAPI client name explicitly specified * Make pydocstyle happy. * Improving test coverage
-rw-r--r--amqp/connection.py54
-rw-r--r--amqp/sasl.py158
-rw-r--r--amqp/transport.py2
-rw-r--r--docs/reference/amqp.sasl.rst11
-rw-r--r--docs/reference/index.rst1
-rw-r--r--t/unit/test_connection.py76
-rw-r--r--t/unit/test_sasl.py149
-rw-r--r--t/unit/test_transport.py2
8 files changed, 421 insertions, 32 deletions
diff --git a/amqp/connection.py b/amqp/connection.py
index ccd5bac..e7312c5 100644
--- a/amqp/connection.py
+++ b/amqp/connection.py
@@ -21,11 +21,10 @@ import socket
import uuid
import warnings
-from io import BytesIO
-
from vine import ensure_promise
from . import __version__
+from . import sasl
from . import spec
from .abstract_channel import AbstractChannel
from .channel import Channel
@@ -36,7 +35,6 @@ from .exceptions import (
)
from .five import array, items, monotonic, range, values
from .method_framing import frame_handler, frame_writer
-from .serialization import _write_table
from .transport import Transport
try:
@@ -100,8 +98,9 @@ class Connection(AbstractChannel):
(defaults to 'localhost', if a port is not specified then
5672 is used)
- If login_response is not specified, one is built up for you from
- userid and password if they are present.
+ Authentication can be controlled by passing one or more
+ `amqp.sasl.SASL` instances as the `authentication` parameter, or
+ by using the userid and password parameters (for AMQPLAIN and PLAIN).
The 'ssl' parameter may be simply True/False, or for Python >= 2.6
a dictionary of options to pass to ssl.wrap_socket() such as
@@ -188,7 +187,8 @@ class Connection(AbstractChannel):
)
def __init__(self, host='localhost:5672', userid='guest', password='guest',
- login_method='AMQPLAIN', login_response=None,
+ login_method=None, login_response=None,
+ authentication=(),
virtual_host='/', locale='en_US', client_properties=None,
ssl=False, connect_timeout=None, channel_max=None,
frame_max=None, heartbeat=0, on_open=None, on_blocked=None,
@@ -199,20 +199,22 @@ class Connection(AbstractChannel):
self._connection_id = uuid.uuid4().hex
channel_max = channel_max or 65535
frame_max = frame_max or 131072
- if (login_response is None) \
- and (userid is not None) \
- and (password is not None):
- login_response = BytesIO()
- _write_table({'LOGIN': userid, 'PASSWORD': password},
- login_response.write, [])
- # Skip the length at the beginning
- login_response = login_response.getvalue()[4:]
+ if authentication:
+ if isinstance(authentication, sasl.SASL):
+ authentication = (authentication,)
+ self.authentication = authentication
+ elif login_method is not None and login_response is not None:
+ self.authentication = (sasl.RAW(login_method, login_response),)
+ elif userid is not None and password is not None:
+ self.authentication = (sasl.GSSAPI(userid, fail_soft=True),
+ sasl.AMQPLAIN(userid, password),
+ sasl.PLAIN(userid, password))
+ else:
+ raise ValueError("Must supply authentication or userid/password")
self.client_properties = dict(
self.library_properties, **client_properties or {}
)
- self.login_method = login_method
- self.login_response = login_response
self.locale = locale
self.host = host
self.virtual_host = virtual_host
@@ -342,7 +344,7 @@ class Connection(AbstractChannel):
self.version_major = version_major
self.version_minor = version_minor
self.server_properties = server_properties
- self.mechanisms = mechanisms.split(' ')
+ self.mechanisms = mechanisms.split(b' ')
self.locales = locales.split(' ')
AMQP_LOGGER.debug(
START_DEBUG_FMT,
@@ -363,10 +365,24 @@ class Connection(AbstractChannel):
# this key present in client_properties, so we remove it.
client_properties.pop('capabilities', None)
+ for authentication in self.authentication:
+ if authentication.mechanism in self.mechanisms:
+ login_response = authentication.start(self)
+ if login_response is not NotImplemented:
+ break
+ else:
+ raise ConnectionError(
+ "Couldn't find appropriate auth mechanism "
+ "(can offer: {0}; available: {1})".format(
+ b", ".join(m.mechanism
+ for m in self.authentication
+ if m.mechanism).decode(),
+ b", ".join(self.mechanisms).decode()))
+
self.send_method(
spec.Connection.StartOk, argsig,
- (client_properties, self.login_method,
- self.login_response, self.locale),
+ (client_properties, authentication.mechanism,
+ login_response, self.locale),
)
def _on_secure(self, challenge):
diff --git a/amqp/sasl.py b/amqp/sasl.py
new file mode 100644
index 0000000..e2a0d8f
--- /dev/null
+++ b/amqp/sasl.py
@@ -0,0 +1,158 @@
+"""SASL mechanisms for AMQP authentication."""
+from __future__ import absolute_import, unicode_literals
+
+from io import BytesIO
+import socket
+
+import warnings
+
+from amqp.serialization import _write_table
+
+
+class SASL(object):
+ """The base class for all amqp SASL authentication mechanisms.
+
+ You should sub-class this if you're implementing your own authentication.
+ """
+
+ @property
+ def mechanism(self):
+ """Return a bytes containing the SASL mechanism name."""
+ raise NotImplementedError
+
+ def start(self, connection):
+ """Return the first response to a SASL challenge as a bytes object."""
+ raise NotImplementedError
+
+
+class PLAIN(SASL):
+ """PLAIN SASL authentication mechanism.
+
+ See https://tools.ietf.org/html/rfc4616 for details
+ """
+
+ mechanism = b'PLAIN'
+
+ def __init__(self, username, password):
+ self.username, self.password = username, password
+
+ def start(self, connection):
+ login_response = BytesIO()
+ login_response.write(b'\0')
+ login_response.write(self.username.encode('utf-8'))
+ login_response.write(b'\0')
+ login_response.write(self.password.encode('utf-8'))
+ return login_response.getvalue()
+
+
+class AMQPLAIN(SASL):
+ """AMQPLAIN SASL authentication mechanism.
+
+ This is a non-standard mechanism used by AMQP servers.
+ """
+
+ mechanism = b'AMQPLAIN'
+
+ def __init__(self, username, password):
+ self.username, self.password = username, password
+
+ def start(self, connection):
+ login_response = BytesIO()
+ _write_table({b'LOGIN': self.username, b'PASSWORD': self.password},
+ login_response.write, [])
+ # Skip the length at the beginning
+ return login_response.getvalue()[4:]
+
+
+def _get_gssapi_mechanism():
+ try:
+ import gssapi
+ except ImportError:
+ class FakeGSSAPI(SASL):
+ """A no-op SASL mechanism for when gssapi isn't available."""
+
+ mechanism = None
+
+ def __init__(self, client_name=None, service=b'amqp',
+ rdns=False, fail_soft=False):
+ if not fail_soft:
+ raise NotImplementedError(
+ "You need to install the `gssapi` module for GSSAPI "
+ "SASL support")
+
+ def start(self): # pragma: no cover
+ return NotImplemented
+ return FakeGSSAPI
+ else:
+ import gssapi.raw.misc
+
+ class GSSAPI(SASL):
+ """GSSAPI SASL authentication mechanism.
+
+ See https://tools.ietf.org/html/rfc4752 for details
+ """
+
+ mechanism = b'GSSAPI'
+
+ def __init__(self, client_name=None, service=b'amqp',
+ rdns=False, fail_soft=False):
+ if client_name and not isinstance(client_name, bytes):
+ client_name = client_name.encode('ascii')
+ self.client_name = client_name
+ self.fail_soft = fail_soft
+ self.service = service
+ self.rdns = rdns
+
+ def get_hostname(self, connection):
+ sock = connection.transport.sock
+ if self.rdns and sock.family in (socket.AF_INET,
+ socket.AF_INET6):
+ peer = sock.getpeername()
+ hostname, _, _ = socket.gethostbyaddr(peer[0])
+ else:
+ hostname = connection.transport.host
+ if not isinstance(hostname, bytes):
+ hostname = hostname.encode('ascii')
+ return hostname
+
+ def start(self, connection):
+ try:
+ if self.client_name:
+ creds = gssapi.Credentials(
+ name=gssapi.Name(self.client_name))
+ else:
+ creds = None
+ hostname = self.get_hostname(connection)
+ name = gssapi.Name(b'@'.join([self.service, hostname]),
+ gssapi.NameType.hostbased_service)
+ context = gssapi.SecurityContext(name=name, creds=creds)
+ return context.step(None)
+ except gssapi.raw.misc.GSSError:
+ if self.fail_soft:
+ return NotImplemented
+ else:
+ raise
+ return GSSAPI
+
+GSSAPI = _get_gssapi_mechanism()
+
+
+class RAW(SASL):
+ """A generic custom SASL mechanism.
+
+ This mechanism takes a mechanism name and response to send to the server,
+ so can be used for simple custom authentication schemes.
+ """
+
+ mechanism = None
+
+ def __init__(self, mechanism, response):
+ assert isinstance(mechanism, bytes)
+ assert isinstance(response, bytes)
+ self.mechanism, self.response = mechanism, response
+ warnings.warn("Passing login_method and login_response to Connection "
+ "is deprecated. Please implement a SASL subclass "
+ "instead.", DeprecationWarning)
+
+ def start(self, connection):
+ return self.response
diff --git a/amqp/transport.py b/amqp/transport.py
index 5787165..fe0ebe3 100644
--- a/amqp/transport.py
+++ b/amqp/transport.py
@@ -286,7 +286,7 @@ class SSLTransport(_AbstractTransport):
def _setup_transport(self):
"""Wrap the socket in an SSL object."""
- self.sock = self._wrap_socket(self.sock, **self.sslopts or {})
+ self.sock = self._wrap_socket(self.sock, **self.sslopts)
self.sock.do_handshake()
self._quick_recv = self.sock.read
diff --git a/docs/reference/amqp.sasl.rst b/docs/reference/amqp.sasl.rst
new file mode 100644
index 0000000..2c062a9
--- /dev/null
+++ b/docs/reference/amqp.sasl.rst
@@ -0,0 +1,11 @@
+=====================================================
+ amqp.spec
+=====================================================
+
+.. contents::
+ :local:
+.. currentmodule:: amqp.sasl
+
+.. automodule:: amqp.sasl
+ :members:
+ :undoc-members:
diff --git a/docs/reference/index.rst b/docs/reference/index.rst
index 43e1ac7..a214a84 100644
--- a/docs/reference/index.rst
+++ b/docs/reference/index.rst
@@ -19,6 +19,7 @@
amqp.method_framing
amqp.platform
amqp.protocol
+ amqp.sasl
amqp.serialization
amqp.spec
amqp.utils
diff --git a/t/unit/test_connection.py b/t/unit/test_connection.py
index 5c4f2bb..11b8521 100644
--- a/t/unit/test_connection.py
+++ b/t/unit/test_connection.py
@@ -3,6 +3,7 @@ from __future__ import absolute_import, unicode_literals
import pytest
import socket
+import warnings
from case import ContextMock, Mock, call
from amqp import Connection
@@ -10,6 +11,7 @@ from amqp import spec
from amqp.connection import SSLError
from amqp.exceptions import ConnectionError, NotFound, ResourceError
from amqp.five import items
+from amqp.sasl import SASL, AMQPLAIN, PLAIN
from amqp.transport import TCPTransport
@@ -22,6 +24,7 @@ class test_Connection:
self.conn = Connection(
frame_handler=self.frame_handler,
frame_writer=self.frame_writer,
+ authentication=AMQPLAIN('foo', 'bar'),
)
self.conn.Channel = Mock(name='Channel')
self.conn.Transport = Mock(name='Transport')
@@ -29,9 +32,27 @@ class test_Connection:
self.conn.send_method = Mock(name='send_method')
self.conn.frame_writer = Mock(name='frame_writer')
- def test_login_response(self):
- self.conn = Connection(login_response='foo')
- assert self.conn.login_response == 'foo'
+ def test_sasl_authentication(self):
+ authentication = SASL()
+ self.conn = Connection(authentication=authentication)
+ assert self.conn.authentication == (authentication,)
+
+ def test_sasl_authentication_iterable(self):
+ authentication = SASL()
+ self.conn = Connection(authentication=(authentication,))
+ assert self.conn.authentication == (authentication,)
+
+ def test_amqplain(self):
+ self.conn = Connection(userid='foo', password='bar')
+ assert isinstance(self.conn.authentication[1], AMQPLAIN)
+ assert self.conn.authentication[1].username == 'foo'
+ assert self.conn.authentication[1].password == 'bar'
+
+ def test_plain(self):
+ self.conn = Connection(userid='foo', password='bar')
+ assert isinstance(self.conn.authentication[2], PLAIN)
+ assert self.conn.authentication[2].username == 'foo'
+ assert self.conn.authentication[2].password == 'bar'
def test_enter_exit(self):
self.conn.connect = Mock(name='connect')
@@ -68,23 +89,56 @@ class test_Connection:
callback.assert_called_with()
def test_on_start(self):
- self.conn._on_start(3, 4, {'foo': 'bar'}, 'x y z', 'en_US en_GB')
+ self.conn._on_start(3, 4, {'foo': 'bar'}, b'x y z AMQPLAIN PLAIN',
+ 'en_US en_GB')
assert self.conn.version_major == 3
assert self.conn.version_minor == 4
assert self.conn.server_properties == {'foo': 'bar'}
- assert self.conn.mechanisms == ['x', 'y', 'z']
+ assert self.conn.mechanisms == [b'x', b'y', b'z',
+ b'AMQPLAIN', b'PLAIN']
assert self.conn.locales == ['en_US', 'en_GB']
self.conn.send_method.assert_called_with(
spec.Connection.StartOk, 'FsSs', (
- self.conn.client_properties, self.conn.login_method,
- self.conn.login_response, self.conn.locale,
+ self.conn.client_properties, b'AMQPLAIN',
+ self.conn.authentication[0].start(self.conn), self.conn.locale,
+ ),
+ )
+
+ def test_missing_credentials(self):
+ with pytest.raises(ValueError):
+ self.conn = Connection(userid=None, password=None)
+ with pytest.raises(ValueError):
+ self.conn = Connection(password=None)
+
+ def test_mechanism_mismatch(self):
+ with pytest.raises(ConnectionError):
+ self.conn._on_start(3, 4, {'foo': 'bar'}, b'x y z',
+ 'en_US en_GB')
+
+ def test_login_method_response(self):
+ # An old way of doing things.:
+ login_method, login_response = b'foo', b'bar'
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always")
+ self.conn = Connection(login_method=login_method,
+ login_response=login_response)
+ self.conn.send_method = Mock(name='send_method')
+ self.conn._on_start(3, 4, {'foo': 'bar'}, login_method,
+ 'en_US en_GB')
+ assert len(w) == 1
+ assert issubclass(w[0].category, DeprecationWarning)
+
+ self.conn.send_method.assert_called_with(
+ spec.Connection.StartOk, 'FsSs', (
+ self.conn.client_properties, login_method,
+ login_response, self.conn.locale,
),
)
def test_on_start__consumer_cancel_notify(self):
self.conn._on_start(
3, 4, {'capabilities': {'consumer_cancel_notify': 1}},
- '', '',
+ b'AMQPLAIN', '',
)
cap = self.conn.client_properties['capabilities']
assert cap['consumer_cancel_notify']
@@ -92,7 +146,7 @@ class test_Connection:
def test_on_start__connection_blocked(self):
self.conn._on_start(
3, 4, {'capabilities': {'connection.blocked': 1}},
- '', '',
+ b'AMQPLAIN', '',
)
cap = self.conn.client_properties['capabilities']
assert cap['connection.blocked']
@@ -100,7 +154,7 @@ class test_Connection:
def test_on_start__authentication_failure_close(self):
self.conn._on_start(
3, 4, {'capabilities': {'authentication_failure_close': 1}},
- '', '',
+ b'AMQPLAIN', '',
)
cap = self.conn.client_properties['capabilities']
assert cap['authentication_failure_close']
@@ -108,7 +162,7 @@ class test_Connection:
def test_on_start__authentication_failure_close__disabled(self):
self.conn._on_start(
3, 4, {'capabilities': {}},
- '', '',
+ b'AMQPLAIN', '',
)
assert 'capabilities' not in self.conn.client_properties
diff --git a/t/unit/test_sasl.py b/t/unit/test_sasl.py
new file mode 100644
index 0000000..fc9c0ba
--- /dev/null
+++ b/t/unit/test_sasl.py
@@ -0,0 +1,149 @@
+from __future__ import absolute_import, unicode_literals
+
+import contextlib
+import socket
+from io import BytesIO
+
+from case import Mock, patch, call
+import pytest
+import sys
+
+from amqp import sasl
+from amqp.serialization import _write_table
+
+
+class test_SASL:
+ def test_sasl_notimplemented(self):
+ mech = sasl.SASL()
+ with pytest.raises(NotImplementedError):
+ mech.mechanism
+ with pytest.raises(NotImplementedError):
+ mech.start(None)
+
+ def test_plain(self):
+ username, password = 'foo', 'bar'
+ mech = sasl.PLAIN(username, password)
+ response = mech.start(None)
+ assert isinstance(response, bytes)
+ assert response.split(b'\0') == \
+ [b'', username.encode('utf-8'), password.encode('utf-8')]
+
+ def test_amqplain(self):
+ username, password = 'foo', 'bar'
+ mech = sasl.AMQPLAIN(username, password)
+ response = mech.start(None)
+ assert isinstance(response, bytes)
+ login_response = BytesIO()
+ _write_table({b'LOGIN': username, b'PASSWORD': password},
+ login_response.write, [])
+ expected_response = login_response.getvalue()[4:]
+ assert response == expected_response
+
+ def test_gssapi_missing(self):
+ gssapi = sys.modules.pop('gssapi', None)
+ GSSAPI = sasl._get_gssapi_mechanism()
+ with pytest.raises(NotImplementedError):
+ GSSAPI()
+ if gssapi is not None:
+ sys.modules['gssapi'] = gssapi
+
+ @contextlib.contextmanager
+ def fake_gssapi(self):
+ orig_gssapi = sys.modules.pop('gssapi', None)
+ orig_gssapi_raw = sys.modules.pop('gssapi.raw', None)
+ orig_gssapi_raw_misc = sys.modules.pop('gssapi.raw.misc', None)
+ gssapi = sys.modules['gssapi'] = Mock()
+ sys.modules['gssapi.raw'] = gssapi.raw
+ sys.modules['gssapi.raw.misc'] = gssapi.raw.misc
+
+ class GSSError(Exception):
+ pass
+
+ gssapi.raw.misc.GSSError = GSSError
+ try:
+ yield gssapi
+ finally:
+ if orig_gssapi is None:
+ del sys.modules['gssapi']
+ else:
+ sys.modules['gssapi'] = orig_gssapi
+ if orig_gssapi_raw is None:
+ del sys.modules['gssapi.raw']
+ else:
+ sys.modules['gssapi.raw'] = orig_gssapi_raw
+ if orig_gssapi_raw_misc is None:
+ del sys.modules['gssapi.raw.misc']
+ else:
+ sys.modules['gssapi.raw.misc'] = orig_gssapi_raw_misc
+
+ @patch('socket.gethostbyaddr')
+ def test_gssapi_rdns(self, gethostbyaddr):
+ with self.fake_gssapi() as gssapi:
+ connection = Mock()
+ connection.transport.sock.getpeername.return_value = ('192.0.2.0',
+ 5672)
+ connection.transport.sock.family = socket.AF_INET
+ gethostbyaddr.return_value = ('broker.example.org', (), ())
+ GSSAPI = sasl._get_gssapi_mechanism()
+
+ mech = GSSAPI(rdns=True)
+ mech.start(connection)
+
+ connection.transport.sock.getpeername.assert_called()
+ gethostbyaddr.assert_called_with('192.0.2.0')
+ gssapi.Name.assert_called_with(b'amqp@broker.example.org',
+ gssapi.NameType.hostbased_service)
+
+ def test_gssapi_no_rdns(self):
+ with self.fake_gssapi() as gssapi:
+ connection = Mock()
+ connection.transport.host = 'broker.example.org'
+ GSSAPI = sasl._get_gssapi_mechanism()
+
+ mech = GSSAPI()
+ mech.start(connection)
+
+ gssapi.Name.assert_called_with(b'amqp@broker.example.org',
+ gssapi.NameType.hostbased_service)
+
+ def test_gssapi_step_without_client_name(self):
+ with self.fake_gssapi() as gssapi:
+ context = Mock()
+ context.step.return_value = b'secrets'
+ name = Mock()
+ gssapi.SecurityContext.return_value = context
+ gssapi.Name.return_value = name
+ connection = Mock()
+ connection.transport.host = 'broker.example.org'
+ GSSAPI = sasl._get_gssapi_mechanism()
+
+ mech = GSSAPI()
+ response = mech.start(connection)
+
+ gssapi.SecurityContext.assert_called_with(name=name, creds=None)
+ context.step.assert_called_with(None)
+ assert response == b'secrets'
+
+ def test_gssapi_step_with_client_name(self):
+ with self.fake_gssapi() as gssapi:
+ context = Mock()
+ context.step.return_value = b'secrets'
+ client_name, service_name, credentials = Mock(), Mock(), Mock()
+ gssapi.SecurityContext.return_value = context
+ gssapi.Credentials.return_value = credentials
+ gssapi.Name.side_effect = [client_name, service_name]
+ connection = Mock()
+ connection.transport.host = 'broker.example.org'
+ GSSAPI = sasl._get_gssapi_mechanism()
+
+ mech = GSSAPI(client_name='amqp-client/client.example.org')
+ response = mech.start(connection)
+ gssapi.Name.assert_has_calls([
+ call(b'amqp-client/client.example.org'),
+ call(b'amqp@broker.example.org',
+ gssapi.NameType.hostbased_service)])
+ gssapi.Credentials.assert_called_with(name=client_name)
+ gssapi.SecurityContext.assert_called_with(name=service_name,
+ creds=credentials)
+ context.step.assert_called_with(None)
+ assert response == b'secrets'
diff --git a/t/unit/test_transport.py b/t/unit/test_transport.py
index 0aa8017..32c4bbd 100644
--- a/t/unit/test_transport.py
+++ b/t/unit/test_transport.py
@@ -323,7 +323,7 @@ class test_SSLTransport:
self.t.sock.do_handshake.assert_called_with()
assert self.t._quick_recv is self.t.sock.read
- @patch('ssl.wrap_socket', create=True)
+ @patch('ssl.wrap_socket')
def test_wrap_socket(self, wrap_socket):
sock = Mock()
self.t._wrap_context = Mock()