summaryrefslogtreecommitdiff
path: root/t/unit/test_sasl.py
blob: ac81a66c7be45e91c46ed5de9eabd23d073ae75a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import contextlib
import socket
import sys
from io import BytesIO
from unittest.mock import Mock, call, patch

import pytest

from amqp import sasl
from amqp.serialization import _write_table


class test_SASL:
    """Unit tests for the SASL mechanisms provided by ``amqp.sasl``."""

    def test_sasl_notimplemented(self):
        """The ``SASL`` base class raises for ``mechanism`` and ``start``."""
        mech = sasl.SASL()
        with pytest.raises(NotImplementedError):
            mech.mechanism
        with pytest.raises(NotImplementedError):
            mech.start(None)

    def test_plain(self):
        """PLAIN yields ``b'\\0<username>\\0<password>'`` as bytes."""
        username, password = 'foo', 'bar'
        mech = sasl.PLAIN(username, password)
        response = mech.start(None)
        assert isinstance(response, bytes)
        assert response.split(b'\0') == \
            [b'', username.encode('utf-8'), password.encode('utf-8')]

    def test_plain_no_password(self):
        """PLAIN returns ``NotImplemented`` when the password is ``None``."""
        username, password = 'foo', None
        mech = sasl.PLAIN(username, password)
        response = mech.start(None)
        assert response == NotImplemented

    def test_amqplain(self):
        """AMQPLAIN yields the serialized LOGIN/PASSWORD table as bytes."""
        username, password = 'foo', 'bar'
        mech = sasl.AMQPLAIN(username, password)
        response = mech.start(None)
        assert isinstance(response, bytes)
        login_response = BytesIO()
        _write_table({b'LOGIN': username, b'PASSWORD': password},
                     login_response.write, [])
        # The expected response omits the first four bytes written by
        # ``_write_table`` (presumably the table's length prefix -- confirm
        # against ``amqp.serialization``).
        expected_response = login_response.getvalue()[4:]
        assert response == expected_response

    def test_amqplain_no_password(self):
        """AMQPLAIN returns ``NotImplemented`` when the password is None."""
        username, password = 'foo', None
        mech = sasl.AMQPLAIN(username, password)
        response = mech.start(None)
        assert response == NotImplemented

    def test_gssapi_missing(self):
        """Instantiating GSSAPI raises when ``gssapi`` cannot be imported."""
        absent = object()
        saved = sys.modules.get('gssapi', absent)
        # A ``None`` entry in ``sys.modules`` makes ``import gssapi`` fail
        # even if the real package is installed; merely removing the entry
        # would let the import machinery load it again.
        sys.modules['gssapi'] = None
        try:
            GSSAPI = sasl._get_gssapi_mechanism()
            with pytest.raises(NotImplementedError):
                GSSAPI()
        finally:
            # Always undo the change so a failure here cannot leak into
            # other tests.
            if saved is absent:
                del sys.modules['gssapi']
            else:
                sys.modules['gssapi'] = saved

    @contextlib.contextmanager
    def fake_gssapi(self):
        """Temporarily replace the ``gssapi`` modules with a ``Mock``.

        While the context is active, ``sys.modules`` maps ``gssapi``,
        ``gssapi.raw`` and ``gssapi.raw.misc`` to a ``Mock`` and its
        ``raw`` / ``raw.misc`` attributes, and ``gssapi.raw.misc.GSSError``
        is a real ``Exception`` subclass.  The previous entries are put
        back on exit, including when the body raises.

        Yields:
            The ``Mock`` installed as the ``gssapi`` module.
        """
        names = ('gssapi', 'gssapi.raw', 'gssapi.raw.misc')
        saved = {name: sys.modules.pop(name, None) for name in names}

        class GSSError(Exception):
            pass

        gssapi = Mock()
        gssapi.raw.misc.GSSError = GSSError
        try:
            sys.modules['gssapi'] = gssapi
            sys.modules['gssapi.raw'] = gssapi.raw
            sys.modules['gssapi.raw.misc'] = gssapi.raw.misc
            yield gssapi
        finally:
            for name, module in saved.items():
                if module is None:
                    # Nothing was registered before; tolerate the entry
                    # having already been removed by the body.
                    sys.modules.pop(name, None)
                else:
                    sys.modules[name] = module

    def test_gssapi_rdns(self):
        """With ``rdns=True`` the service name uses the resolved peer host.

        Asserts that the peer address is read from the transport socket,
        passed to ``socket.gethostbyaddr`` and that the returned hostname
        ends up in the host-based service name.
        """
        with self.fake_gssapi() as gssapi, \
                patch('socket.gethostbyaddr') as gethostbyaddr:
            connection = Mock()
            connection.transport.sock.getpeername.return_value = ('192.0.2.0',
                                                                  5672)
            connection.transport.sock.family = socket.AF_INET
            gethostbyaddr.return_value = ('broker.example.org', (), ())
            GSSAPI = sasl._get_gssapi_mechanism()

            mech = GSSAPI(rdns=True)
            mech.start(connection)

            connection.transport.sock.getpeername.assert_called_with()
            gethostbyaddr.assert_called_with('192.0.2.0')
            gssapi.Name.assert_called_with(b'amqp@broker.example.org',
                                           gssapi.NameType.hostbased_service)

    def test_gssapi_no_rdns(self):
        """Without rdns the service name uses ``connection.transport.host``."""
        with self.fake_gssapi() as gssapi:
            connection = Mock()
            connection.transport.host = 'broker.example.org'
            GSSAPI = sasl._get_gssapi_mechanism()

            mech = GSSAPI()
            mech.start(connection)

            gssapi.Name.assert_called_with(b'amqp@broker.example.org',
                                           gssapi.NameType.hostbased_service)

    def test_gssapi_step_without_client_name(self):
        """Without a client name the context is built with ``creds=None``.

        Also asserts that ``start`` returns whatever the security
        context's ``step(None)`` call produced.
        """
        with self.fake_gssapi() as gssapi:
            context = Mock()
            context.step.return_value = b'secrets'
            name = Mock()
            gssapi.SecurityContext.return_value = context
            gssapi.Name.return_value = name
            connection = Mock()
            connection.transport.host = 'broker.example.org'
            GSSAPI = sasl._get_gssapi_mechanism()

            mech = GSSAPI()
            response = mech.start(connection)

            gssapi.SecurityContext.assert_called_with(name=name, creds=None)
            context.step.assert_called_with(None)
            assert response == b'secrets'

    def test_gssapi_step_with_client_name(self):
        """A client name yields credentials that are given to the context.

        Asserts that ``gssapi.Name`` is called first for the client name
        and then for the service name, that ``gssapi.Credentials`` receives
        the client name, and that the security context is created from the
        service name and those credentials.
        """
        with self.fake_gssapi() as gssapi:
            context = Mock()
            context.step.return_value = b'secrets'
            client_name, service_name, credentials = Mock(), Mock(), Mock()
            gssapi.SecurityContext.return_value = context
            gssapi.Credentials.return_value = credentials
            # First ``Name`` call returns the client name, second the
            # service name.
            gssapi.Name.side_effect = [client_name, service_name]
            connection = Mock()
            connection.transport.host = 'broker.example.org'
            GSSAPI = sasl._get_gssapi_mechanism()

            mech = GSSAPI(client_name='amqp-client/client.example.org')
            response = mech.start(connection)
            gssapi.Name.assert_has_calls([
                call(b'amqp-client/client.example.org'),
                call(b'amqp@broker.example.org',
                     gssapi.NameType.hostbased_service)])
            gssapi.Credentials.assert_called_with(name=client_name)
            gssapi.SecurityContext.assert_called_with(name=service_name,
                                                      creds=credentials)
            context.step.assert_called_with(None)
            assert response == b'secrets'

    def test_external(self):
        """EXTERNAL yields an empty bytes response."""
        mech = sasl.EXTERNAL()
        response = mech.start(None)
        assert isinstance(response, bytes)
        assert response == b''