diff options
author | Alexander Dutton <gh@alexdutton.co.uk> | 2016-12-29 10:55:49 +0000 |
---|---|---|
committer | Asif Saifuddin Auvi <auvipy@users.noreply.github.com> | 2016-12-29 16:55:49 +0600 |
commit | 787793c66e1027e3cfd5585bfb374e163ec3948c (patch) | |
tree | 24bd612772e0699ece93d1e288197d8ee9249998 | |
parent | bb0be0e9333b134ed76e091e7fa2223c0617e9fc (diff) | |
download | py-amqp-787793c66e1027e3cfd5585bfb374e163ec3948c.tar.gz |
Refactored to make authn modular; added GSSAPI (#110)
* Refactored to make authn modular; added GSSAPI
* Getting tests passing again
* Add missing __future__ import to new source file
* Made it more obvious that login_method/login_response not heeded
However, we could create a SASL subclass that provides the functionality
for backwards compatibility.
* Also try PLAIN if we're given a userid/password.
* Squashing flakes
* Reimplemented login_method/login_response; flake fixing
* An extra test; deflaking
* Remove a blank line to sate flake8. Great.
* SASL tests!
Also wrapped getting GSSAPI mech impl in a function, so we can test it
with and without a `gssapi` module present.
* Fixing tests
I'd misnamed my test class while debugging something, and not reverted
it. Now the new tests are being run, another mistake appeared, now
fixed.
* More tests for the coverage gods
* Attempt GSSAPI (but don't fail hard), and use userid
This adds fail_soft and client_name arguments to GSSAPI, so that:
* amqp.connection can add GSSAPI to its default list of auth mechanisms,
but not fail if the gssapi module isn't available, or if GSSAPI
credentials aren't available
* A user can use userid to specify which GSSAPI client credentials
should be used.
Authentication mechanisms can also now return NotImplemented if they
decline to do SASL negotiation at start() time.
* Added another test for GSSAPI client name explicitly specified
* Make pydocstyle happy.
* Improving test coverage
-rw-r--r-- | amqp/connection.py | 54 | ||||
-rw-r--r-- | amqp/sasl.py | 158 | ||||
-rw-r--r-- | amqp/transport.py | 2 | ||||
-rw-r--r-- | docs/reference/amqp.sasl.rst | 11 | ||||
-rw-r--r-- | docs/reference/index.rst | 1 | ||||
-rw-r--r-- | t/unit/test_connection.py | 76 | ||||
-rw-r--r-- | t/unit/test_sasl.py | 149 | ||||
-rw-r--r-- | t/unit/test_transport.py | 2 |
8 files changed, 421 insertions, 32 deletions
diff --git a/amqp/connection.py b/amqp/connection.py index ccd5bac..e7312c5 100644 --- a/amqp/connection.py +++ b/amqp/connection.py @@ -21,11 +21,10 @@ import socket import uuid import warnings -from io import BytesIO - from vine import ensure_promise from . import __version__ +from . import sasl from . import spec from .abstract_channel import AbstractChannel from .channel import Channel @@ -36,7 +35,6 @@ from .exceptions import ( ) from .five import array, items, monotonic, range, values from .method_framing import frame_handler, frame_writer -from .serialization import _write_table from .transport import Transport try: @@ -100,8 +98,9 @@ class Connection(AbstractChannel): (defaults to 'localhost', if a port is not specified then 5672 is used) - If login_response is not specified, one is built up for you from - userid and password if they are present. + Authentication can be controlled by passing one or more + `amqp.sasl.SASL` instances as the `authentication` parameter, or + by using the userid and password parameters (for AMQPLAIN and PLAIN). The 'ssl' parameter may be simply True/False, or for Python >= 2.6 a dictionary of options to pass to ssl.wrap_socket() such as @@ -188,7 +187,8 @@ class Connection(AbstractChannel): ) def __init__(self, host='localhost:5672', userid='guest', password='guest', - login_method='AMQPLAIN', login_response=None, + login_method=None, login_response=None, + authentication=(), virtual_host='/', locale='en_US', client_properties=None, ssl=False, connect_timeout=None, channel_max=None, frame_max=None, heartbeat=0, on_open=None, on_blocked=None, @@ -199,20 +199,22 @@ class Connection(AbstractChannel): self._connection_id = uuid.uuid4().hex channel_max = channel_max or 65535 frame_max = frame_max or 131072 - if (login_response is None) \ - and (userid is not None) \ - and (password is not None): - login_response = BytesIO() - _write_table({'LOGIN': userid, 'PASSWORD': password}, - login_response.write, []) - # Skip the length at the beginning - login_response = login_response.getvalue()[4:] + if authentication: + if isinstance(authentication, sasl.SASL): + authentication = (authentication,) + self.authentication = authentication + elif login_method is not None and login_response is not None: + self.authentication = (sasl.RAW(login_method, login_response),) + elif userid is not None and password is not None: + self.authentication = (sasl.GSSAPI(userid, fail_soft=True), + sasl.AMQPLAIN(userid, password), + sasl.PLAIN(userid, password)) + else: + raise ValueError("Must supply authentication or userid/password") self.client_properties = dict( self.library_properties, **client_properties or {} ) - self.login_method = login_method - self.login_response = login_response self.locale = locale self.host = host self.virtual_host = virtual_host @@ -342,7 +344,7 @@ class Connection(AbstractChannel): self.version_major = version_major self.version_minor = version_minor self.server_properties = server_properties - self.mechanisms = mechanisms.split(' ') + self.mechanisms = mechanisms.split(b' ') self.locales = locales.split(' ') AMQP_LOGGER.debug( START_DEBUG_FMT, @@ -363,10 +365,24 @@ class Connection(AbstractChannel): # this key present in client_properties, so we remove it. client_properties.pop('capabilities', None) + for authentication in self.authentication: + if authentication.mechanism in self.mechanisms: + login_response = authentication.start(self) + if login_response is not NotImplemented: + break + else: + raise ConnectionError( + "Couldn't find appropriate auth mechanism " + "(can offer: {0}; available: {1})".format( + b", ".join(m.mechanism + for m in self.authentication + if m.mechanism).decode(), + b", ".join(self.mechanisms).decode())) + self.send_method( spec.Connection.StartOk, argsig, - (client_properties, self.login_method, - self.login_response, self.locale), + (client_properties, authentication.mechanism, + login_response, self.locale), ) def _on_secure(self, challenge): diff --git a/amqp/sasl.py b/amqp/sasl.py new file mode 100644 index 0000000..e2a0d8f --- /dev/null +++ b/amqp/sasl.py @@ -0,0 +1,158 @@ +"""SASL mechanisms for AMQP authentication.""" +from __future__ import absolute_import, unicode_literals + +from io import BytesIO +import socket + +import warnings + +from amqp.serialization import _write_table + + +class SASL(object): + """The base class for all amqp SASL authentication mechanisms. + + You should sub-class this if you're implementing your own authentication. + """ + + @property + def mechanism(self): + """Return a bytes containing the SASL mechanism name.""" + raise NotImplementedError + + def start(self, connection): + """Return the first response to a SASL challenge as a bytes object.""" + raise NotImplementedError + + +class PLAIN(SASL): + """PLAIN SASL authentication mechanism. + + See https://tools.ietf.org/html/rfc4616 for details + """ + + mechanism = b'PLAIN' + + def __init__(self, username, password): + self.username, self.password = username, password + + def start(self, connection): + login_response = BytesIO() + login_response.write(b'\0') + login_response.write(self.username.encode('utf-8')) + login_response.write(b'\0') + login_response.write(self.password.encode('utf-8')) + return login_response.getvalue() + + +class AMQPLAIN(SASL): + """AMQPLAIN SASL authentication mechanism. + + This is a non-standard mechanism used by AMQP servers. + """ + + mechanism = b'AMQPLAIN' + + def __init__(self, username, password): + self.username, self.password = username, password + + def start(self, connection): + login_response = BytesIO() + _write_table({b'LOGIN': self.username, b'PASSWORD': self.password}, + login_response.write, []) + # Skip the length at the beginning + return login_response.getvalue()[4:] + + +def _get_gssapi_mechanism(): + try: + import gssapi + except ImportError: + class FakeGSSAPI(SASL): + """A no-op SASL mechanism for when gssapi isn't available.""" + + mechanism = None + + def __init__(self, client_name=None, service=b'amqp', + rdns=False, fail_soft=False): + if not fail_soft: + raise NotImplementedError( + "You need to install the `gssapi` module for GSSAPI " + "SASL support") + + def start(self): # pragma: no cover + return NotImplemented + return FakeGSSAPI + else: + import gssapi.raw.misc + + class GSSAPI(SASL): + """GSSAPI SASL authentication mechanism. + + See https://tools.ietf.org/html/rfc4752 for details + """ + + mechanism = b'GSSAPI' + + def __init__(self, client_name=None, service=b'amqp', + rdns=False, fail_soft=False): + if client_name and not isinstance(client_name, bytes): + client_name = client_name.encode('ascii') + self.client_name = client_name + self.fail_soft = fail_soft + self.service = service + self.rdns = rdns + + def get_hostname(self, connection): + sock = connection.transport.sock + if self.rdns and sock.family in (socket.AF_INET, + socket.AF_INET6): + peer = sock.getpeername() + hostname, _, _ = socket.gethostbyaddr(peer[0]) + else: + hostname = connection.transport.host + if not isinstance(hostname, bytes): + hostname = hostname.encode('ascii') + return hostname + + def start(self, connection): + try: + if self.client_name: + creds = gssapi.Credentials( + name=gssapi.Name(self.client_name)) + else: + creds = None + hostname = self.get_hostname(connection) + name = gssapi.Name(b'@'.join([self.service, hostname]), + gssapi.NameType.hostbased_service) + context = gssapi.SecurityContext(name=name, creds=creds) + return context.step(None) + except gssapi.raw.misc.GSSError: + if self.fail_soft: + return NotImplemented + else: + raise + return GSSAPI + +GSSAPI = _get_gssapi_mechanism() + + +class RAW(SASL): + """A generic custom SASL mechanism. + + This mechanism takes a mechanism name and response to send to the server, + so can be used for simple custom authentication schemes. + """ + + mechanism = None + + def __init__(self, mechanism, response): + assert isinstance(mechanism, bytes) + assert isinstance(response, bytes) + self.mechanism, self.response = mechanism, response + warnings.warn("Passing login_method and login_response to Connection " + "is deprecated. Please implement a SASL subclass " + "instead.", DeprecationWarning) + + def start(self, connection): + return self.response diff --git a/amqp/transport.py b/amqp/transport.py index 5787165..fe0ebe3 100644 --- a/amqp/transport.py +++ b/amqp/transport.py @@ -286,7 +286,7 @@ class SSLTransport(_AbstractTransport): def _setup_transport(self): """Wrap the socket in an SSL object.""" - self.sock = self._wrap_socket(self.sock, **self.sslopts or {}) + self.sock = self._wrap_socket(self.sock, **self.sslopts) self.sock.do_handshake() self._quick_recv = self.sock.read diff --git a/docs/reference/amqp.sasl.rst b/docs/reference/amqp.sasl.rst new file mode 100644 index 0000000..2c062a9 --- /dev/null +++ b/docs/reference/amqp.sasl.rst @@ -0,0 +1,11 @@ +===================================================== + amqp.spec +===================================================== + +.. contents:: + :local: +.. currentmodule:: amqp.sasl + +.. automodule:: amqp.sasl + :members: + :undoc-members: diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 43e1ac7..a214a84 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -19,6 +19,7 @@ amqp.method_framing amqp.platform amqp.protocol + amqp.sasl amqp.serialization amqp.spec amqp.utils diff --git a/t/unit/test_connection.py b/t/unit/test_connection.py index 5c4f2bb..11b8521 100644 --- a/t/unit/test_connection.py +++ b/t/unit/test_connection.py @@ -3,6 +3,7 @@ from __future__ import absolute_import, unicode_literals import pytest import socket +import warnings from case import ContextMock, Mock, call from amqp import Connection @@ -10,6 +11,7 @@ from amqp import spec from amqp.connection import SSLError from amqp.exceptions import ConnectionError, NotFound, ResourceError from amqp.five import items +from amqp.sasl import SASL, AMQPLAIN, PLAIN from amqp.transport import TCPTransport @@ -22,6 +24,7 @@ class test_Connection: self.conn = Connection( frame_handler=self.frame_handler, frame_writer=self.frame_writer, + authentication=AMQPLAIN('foo', 'bar'), ) self.conn.Channel = Mock(name='Channel') self.conn.Transport = Mock(name='Transport') @@ -29,9 +32,27 @@ class test_Connection: self.conn.send_method = Mock(name='send_method') self.conn.frame_writer = Mock(name='frame_writer') - def test_login_response(self): - self.conn = Connection(login_response='foo') - assert self.conn.login_response == 'foo' + def test_sasl_authentication(self): + authentication = SASL() + self.conn = Connection(authentication=authentication) + assert self.conn.authentication == (authentication,) + + def test_sasl_authentication_iterable(self): + authentication = SASL() + self.conn = Connection(authentication=(authentication,)) + assert self.conn.authentication == (authentication,) + + def test_amqplain(self): + self.conn = Connection(userid='foo', password='bar') + assert isinstance(self.conn.authentication[1], AMQPLAIN) + assert self.conn.authentication[1].username == 'foo' + assert self.conn.authentication[1].password == 'bar' + + def test_plain(self): + self.conn = Connection(userid='foo', password='bar') + assert isinstance(self.conn.authentication[2], PLAIN) + assert self.conn.authentication[2].username == 'foo' + assert self.conn.authentication[2].password == 'bar' def test_enter_exit(self): self.conn.connect = Mock(name='connect') @@ -68,23 +89,56 @@ class test_Connection: callback.assert_called_with() def test_on_start(self): - self.conn._on_start(3, 4, {'foo': 'bar'}, 'x y z', 'en_US en_GB') + self.conn._on_start(3, 4, {'foo': 'bar'}, b'x y z AMQPLAIN PLAIN', + 'en_US en_GB') assert self.conn.version_major == 3 assert self.conn.version_minor == 4 assert self.conn.server_properties == {'foo': 'bar'} - assert self.conn.mechanisms == ['x', 'y', 'z'] + assert self.conn.mechanisms == [b'x', b'y', b'z', + b'AMQPLAIN', b'PLAIN'] assert self.conn.locales == ['en_US', 'en_GB'] self.conn.send_method.assert_called_with( spec.Connection.StartOk, 'FsSs', ( - self.conn.client_properties, self.conn.login_method, - self.conn.login_response, self.conn.locale, + self.conn.client_properties, b'AMQPLAIN', + self.conn.authentication[0].start(self.conn), self.conn.locale, + ), + ) + + def test_missing_credentials(self): + with pytest.raises(ValueError): + self.conn = Connection(userid=None, password=None) + with pytest.raises(ValueError): + self.conn = Connection(password=None) + + def test_mechanism_mismatch(self): + with pytest.raises(ConnectionError): + self.conn._on_start(3, 4, {'foo': 'bar'}, b'x y z', + 'en_US en_GB') + + def test_login_method_response(self): + # An old way of doing things.: + login_method, login_response = b'foo', b'bar' + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + self.conn = Connection(login_method=login_method, + login_response=login_response) + self.conn.send_method = Mock(name='send_method') + self.conn._on_start(3, 4, {'foo': 'bar'}, login_method, + 'en_US en_GB') + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + + self.conn.send_method.assert_called_with( + spec.Connection.StartOk, 'FsSs', ( + self.conn.client_properties, login_method, + login_response, self.conn.locale, ), ) def test_on_start__consumer_cancel_notify(self): self.conn._on_start( 3, 4, {'capabilities': {'consumer_cancel_notify': 1}}, - '', '', + b'AMQPLAIN', '', ) cap = self.conn.client_properties['capabilities'] assert cap['consumer_cancel_notify'] @@ -92,7 +146,7 @@ class test_Connection: def test_on_start__connection_blocked(self): self.conn._on_start( 3, 4, {'capabilities': {'connection.blocked': 1}}, - '', '', + b'AMQPLAIN', '', ) cap = self.conn.client_properties['capabilities'] assert cap['connection.blocked'] @@ -100,7 +154,7 @@ class test_Connection: def test_on_start__authentication_failure_close(self): self.conn._on_start( 3, 4, {'capabilities': {'authentication_failure_close': 1}}, - '', '', + b'AMQPLAIN', '', ) cap = self.conn.client_properties['capabilities'] assert cap['authentication_failure_close'] @@ -108,7 +162,7 @@ class test_Connection: def test_on_start__authentication_failure_close__disabled(self): self.conn._on_start( 3, 4, {'capabilities': {}}, - '', '', + b'AMQPLAIN', '', ) assert 'capabilities' not in self.conn.client_properties diff --git a/t/unit/test_sasl.py b/t/unit/test_sasl.py new file mode 100644 index 0000000..fc9c0ba --- /dev/null +++ b/t/unit/test_sasl.py @@ -0,0 +1,149 @@ +from __future__ import absolute_import, unicode_literals + +import contextlib +import socket +from io import BytesIO + +from case import Mock, patch, call +import pytest +import sys + +from amqp import sasl +from amqp.serialization import _write_table + + +class test_SASL: + def test_sasl_notimplemented(self): + mech = sasl.SASL() + with pytest.raises(NotImplementedError): + mech.mechanism + with pytest.raises(NotImplementedError): + mech.start(None) + + def test_plain(self): + username, password = 'foo', 'bar' + mech = sasl.PLAIN(username, password) + response = mech.start(None) + assert isinstance(response, bytes) + assert response.split(b'\0') == \ + [b'', username.encode('utf-8'), password.encode('utf-8')] + + def test_amqplain(self): + username, password = 'foo', 'bar' + mech = sasl.AMQPLAIN(username, password) + response = mech.start(None) + assert isinstance(response, bytes) + login_response = BytesIO() + _write_table({b'LOGIN': username, b'PASSWORD': password}, + login_response.write, []) + expected_response = login_response.getvalue()[4:] + assert response == expected_response + + def test_gssapi_missing(self): + gssapi = sys.modules.pop('gssapi', None) + GSSAPI = sasl._get_gssapi_mechanism() + with pytest.raises(NotImplementedError): + GSSAPI() + if gssapi is not None: + sys.modules['gssapi'] = gssapi + + @contextlib.contextmanager + def fake_gssapi(self): + orig_gssapi = sys.modules.pop('gssapi', None) + orig_gssapi_raw = sys.modules.pop('gssapi.raw', None) + orig_gssapi_raw_misc = sys.modules.pop('gssapi.raw.misc', None) + gssapi = sys.modules['gssapi'] = Mock() + sys.modules['gssapi.raw'] = gssapi.raw + sys.modules['gssapi.raw.misc'] = gssapi.raw.misc + + class GSSError(Exception): + pass + + gssapi.raw.misc.GSSError = GSSError + try: + yield gssapi + finally: + if orig_gssapi is None: + del sys.modules['gssapi'] + else: + sys.modules['gssapi'] = orig_gssapi + if orig_gssapi_raw is None: + del sys.modules['gssapi.raw'] + else: + sys.modules['gssapi.raw'] = orig_gssapi_raw + if orig_gssapi_raw_misc is None: + del sys.modules['gssapi.raw.misc'] + else: + sys.modules['gssapi.raw.misc'] = orig_gssapi_raw_misc + + @patch('socket.gethostbyaddr') + def test_gssapi_rdns(self, gethostbyaddr): + with self.fake_gssapi() as gssapi: + connection = Mock() + connection.transport.sock.getpeername.return_value = ('192.0.2.0', + 5672) + connection.transport.sock.family = socket.AF_INET + gethostbyaddr.return_value = ('broker.example.org', (), ()) + GSSAPI = sasl._get_gssapi_mechanism() + + mech = GSSAPI(rdns=True) + mech.start(connection) + + connection.transport.sock.getpeername.assert_called() + gethostbyaddr.assert_called_with('192.0.2.0') + gssapi.Name.assert_called_with(b'amqp@broker.example.org', + gssapi.NameType.hostbased_service) + + def test_gssapi_no_rdns(self): + with self.fake_gssapi() as gssapi: + connection = Mock() + connection.transport.host = 'broker.example.org' + GSSAPI = sasl._get_gssapi_mechanism() + + mech = GSSAPI() + mech.start(connection) + + gssapi.Name.assert_called_with(b'amqp@broker.example.org', + gssapi.NameType.hostbased_service) + + def test_gssapi_step_without_client_name(self): + with self.fake_gssapi() as gssapi: + context = Mock() + context.step.return_value = b'secrets' + name = Mock() + gssapi.SecurityContext.return_value = context + gssapi.Name.return_value = name + connection = Mock() + connection.transport.host = 'broker.example.org' + GSSAPI = sasl._get_gssapi_mechanism() + + mech = GSSAPI() + response = mech.start(connection) + + gssapi.SecurityContext.assert_called_with(name=name, creds=None) + context.step.assert_called_with(None) + assert response == b'secrets' + + def test_gssapi_step_with_client_name(self): + with self.fake_gssapi() as gssapi: + context = Mock() + context.step.return_value = b'secrets' + client_name, service_name, credentials = Mock(), Mock(), Mock() + gssapi.SecurityContext.return_value = context + gssapi.Credentials.return_value = credentials + gssapi.Name.side_effect = [client_name, service_name] + connection = Mock() + connection.transport.host = 'broker.example.org' + GSSAPI = sasl._get_gssapi_mechanism() + + mech = GSSAPI(client_name='amqp-client/client.example.org') + response = mech.start(connection) + gssapi.Name.assert_has_calls([ + call(b'amqp-client/client.example.org'), + call(b'amqp@broker.example.org', + gssapi.NameType.hostbased_service)]) + gssapi.Credentials.assert_called_with(name=client_name) + gssapi.SecurityContext.assert_called_with(name=service_name, + creds=credentials) + context.step.assert_called_with(None) + assert response == b'secrets' diff --git a/t/unit/test_transport.py b/t/unit/test_transport.py index 0aa8017..32c4bbd 100644 --- a/t/unit/test_transport.py +++ b/t/unit/test_transport.py @@ -323,7 +323,7 @@ class test_SSLTransport: self.t.sock.do_handshake.assert_called_with() assert self.t._quick_recv is self.t.sock.read - @patch('ssl.wrap_socket', create=True) + @patch('ssl.wrap_socket') def test_wrap_socket(self, wrap_socket): sock = Mock() self.t._wrap_context = Mock() |