"""Unit tests for the SASL mechanisms provided by :mod:`amqp.sasl`."""
import contextlib
import socket
import sys
from io import BytesIO
from unittest.mock import Mock, call, patch

import pytest

from amqp import sasl
from amqp.serialization import _write_table


class test_SASL:
    """Tests for the base SASL class and the PLAIN, AMQPLAIN, GSSAPI and
    EXTERNAL mechanisms."""

    def test_sasl_notimplemented(self):
        """The abstract base exposes neither a mechanism name nor start()."""
        mech = sasl.SASL()
        with pytest.raises(NotImplementedError):
            mech.mechanism
        with pytest.raises(NotImplementedError):
            mech.start(None)

    def test_plain(self):
        """PLAIN responds with NUL-separated '', username and password."""
        username, password = 'foo', 'bar'
        mech = sasl.PLAIN(username, password)
        response = mech.start(None)
        assert isinstance(response, bytes)
        assert response.split(b'\0') == \
            [b'', username.encode('utf-8'), password.encode('utf-8')]

    def test_plain_no_password(self):
        """PLAIN without a password answers NotImplemented."""
        username, password = 'foo', None
        mech = sasl.PLAIN(username, password)
        response = mech.start(None)
        assert response == NotImplemented

    def test_amqplain(self):
        """AMQPLAIN responds with a LOGIN/PASSWORD field table."""
        username, password = 'foo', 'bar'
        mech = sasl.AMQPLAIN(username, password)
        response = mech.start(None)
        assert isinstance(response, bytes)

        login_response = BytesIO()
        _write_table({b'LOGIN': username, b'PASSWORD': password},
                     login_response.write, [])
        # The mechanism's response omits the first 4 bytes that
        # _write_table emits (presumably the table's length prefix).
        expected_response = login_response.getvalue()[4:]

        assert response == expected_response

    def test_amqplain_no_password(self):
        """AMQPLAIN without a password answers NotImplemented."""
        username, password = 'foo', None
        mech = sasl.AMQPLAIN(username, password)
        response = mech.start(None)
        assert response == NotImplemented

    def test_gssapi_missing(self):
        """Without gssapi, the returned mechanism refuses to instantiate."""
        # A None entry in sys.modules makes ``import gssapi`` raise
        # ImportError even when the package is actually installed (merely
        # popping the entry would let it be re-imported).  patch.dict puts
        # the original entry back even if the assertion below fails.
        with patch.dict(sys.modules, {'gssapi': None}):
            GSSAPI = sasl._get_gssapi_mechanism()
            with pytest.raises(NotImplementedError):
                GSSAPI()

    @contextlib.contextmanager
    def fake_gssapi(self):
        """Temporarily install a Mock as the ``gssapi`` package.

        ``gssapi``, ``gssapi.raw`` and ``gssapi.raw.misc`` are all registered
        in ``sys.modules`` so importing any of them yields the mock.  The
        original ``sys.modules`` contents are restored on exit, including
        when the body raises.

        Yields:
            The Mock standing in for the ``gssapi`` package.
        """
        gssapi = Mock()

        # Use a real exception class rather than a Mock attribute, so code
        # that raises or catches GSSError (presumably amqp.sasl) works.
        class GSSError(Exception):
            pass

        gssapi.raw.misc.GSSError = GSSError
        fake_modules = {
            'gssapi': gssapi,
            'gssapi.raw': gssapi.raw,
            'gssapi.raw.misc': gssapi.raw.misc,
        }
        with patch.dict(sys.modules, fake_modules):
            yield gssapi

    def test_gssapi_rdns(self):
        """With rdns=True the service name uses the peer's reverse-DNS name."""
        with self.fake_gssapi() as gssapi, \
                patch('socket.gethostbyaddr') as gethostbyaddr:
            connection = Mock()
            connection.transport.sock.getpeername.return_value = ('192.0.2.0',
                                                                  5672)
            connection.transport.sock.family = socket.AF_INET
            gethostbyaddr.return_value = ('broker.example.org', (), ())
            GSSAPI = sasl._get_gssapi_mechanism()

            mech = GSSAPI(rdns=True)
            mech.start(connection)

            connection.transport.sock.getpeername.assert_called_with()
            gethostbyaddr.assert_called_with('192.0.2.0')
            gssapi.Name.assert_called_with(
                b'amqp@broker.example.org',
                gssapi.NameType.hostbased_service)

    def test_gssapi_no_rdns(self):
        """By default the service name uses the transport's host."""
        with self.fake_gssapi() as gssapi:
            connection = Mock()
            connection.transport.host = 'broker.example.org'
            GSSAPI = sasl._get_gssapi_mechanism()

            mech = GSSAPI()
            mech.start(connection)

            gssapi.Name.assert_called_with(
                b'amqp@broker.example.org',
                gssapi.NameType.hostbased_service)

    def test_gssapi_step_without_client_name(self):
        """Without client_name no credentials are passed to the context."""
        with self.fake_gssapi() as gssapi:
            context = Mock()
            context.step.return_value = b'secrets'
            name = Mock()
            gssapi.SecurityContext.return_value = context
            gssapi.Name.return_value = name
            connection = Mock()
            connection.transport.host = 'broker.example.org'
            GSSAPI = sasl._get_gssapi_mechanism()

            mech = GSSAPI()
            response = mech.start(connection)

            gssapi.SecurityContext.assert_called_with(name=name, creds=None)
            context.step.assert_called_with(None)
            assert response == b'secrets'

    def test_gssapi_step_with_client_name(self):
        """With client_name, credentials for that name are passed along."""
        with self.fake_gssapi() as gssapi:
            context = Mock()
            context.step.return_value = b'secrets'
            client_name, service_name, credentials = Mock(), Mock(), Mock()
            gssapi.SecurityContext.return_value = context
            gssapi.Credentials.return_value = credentials
            # First Name() call builds the client name, second the service.
            gssapi.Name.side_effect = [client_name, service_name]
            connection = Mock()
            connection.transport.host = 'broker.example.org'
            GSSAPI = sasl._get_gssapi_mechanism()

            mech = GSSAPI(client_name='amqp-client/client.example.org')
            response = mech.start(connection)

            gssapi.Name.assert_has_calls([
                call(b'amqp-client/client.example.org'),
                call(b'amqp@broker.example.org',
                     gssapi.NameType.hostbased_service)])
            gssapi.Credentials.assert_called_with(name=client_name)
            gssapi.SecurityContext.assert_called_with(name=service_name,
                                                      creds=credentials)
            context.step.assert_called_with(None)
            assert response == b'secrets'

    def test_external(self):
        """EXTERNAL responds with an empty byte string."""
        mech = sasl.EXTERNAL()
        response = mech.start(None)
        assert isinstance(response, bytes)
        assert response == b''