summaryrefslogtreecommitdiff
path: root/tests/unittests/config/test_cc_wireguard.py
blob: 6c91625b5d7210e1a7e6c3cc6d58139097f57393 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
# This file is part of cloud-init. See LICENSE file for license information.
import pytest

from cloudinit import subp, util
from cloudinit.config import cc_wireguard
from cloudinit.config.schema import (
    SchemaValidationError,
    get_schema,
    validate_cloudconfig_schema,
)
from tests.unittests.helpers import CiTestCase, mock, skipUnlessJsonSchema

NL = "\n"
# Newline, interpolated as {NL} when building expected multi-line messages
# inside f-strings below.
# Dotted module path of the code under test; used as the prefix for
# mock.patch targets, e.g. "%s.subp.subp" % MPATH.
MPATH = "cloudinit.config.cc_wireguard"
# (major, minor) kernel threshold: below it, the happy-path test also expects
# the "wireguard" package to be installed alongside "wireguard-tools".
# NOTE(review): presumably mirrors cc_wireguard's own threshold - keep in sync.
MIN_KERNEL_VERSION = (5, 6)


class FakeCloud:
    """Bare-bones stand-in for the cloud object handed to cc_wireguard.

    Only a ``distro`` attribute is exposed; the tests below pass a
    ``mock.MagicMock`` for it so distro calls can be stubbed and inspected.
    """

    def __init__(self, distro):
        # Store the (usually mocked) distro exactly as given.
        vars(self)["distro"] = distro


class TestWireGuard(CiTestCase):
    """Unit tests for cc_wireguard's validation, install, write and enable
    helpers, and for the handle() entry point."""

    with_logs = True
    allowed_subp = [CiTestCase.SUBP_SHELL_TRUE]

    def test_readiness_probe_schema_non_string_values(self):
        """ValueError raised for any values expected as string type."""
        wg_readinessprobes = [1, ["not-a-valid-command"]]
        errors = [
            "Expected a string for readinessprobe at 0. Found 1",
            "Expected a string for readinessprobe at 1."
            " Found ['not-a-valid-command']",
        ]
        with self.assertRaises(ValueError) as context_mgr:
            cc_wireguard.readinessprobe_command_validation(wg_readinessprobes)
        error_msg = str(context_mgr.exception)
        for error in errors:
            self.assertIn(error, error_msg)

    def test_suppl_schema_error_on_missing_keys(self):
        """ValueError raised reporting any missing required keys"""
        cfg = {}
        match = (
            f"Invalid wireguard interface configuration:{NL}"
            "Missing required wg:interfaces keys: config_path, content, name"
        )
        with self.assertRaisesRegex(ValueError, match):
            cc_wireguard.supplemental_schema_validation(cfg)

    def test_suppl_schema_error_on_non_string_values(self):
        """ValueError raised for any values expected as string type."""
        cfg = {"name": 1, "config_path": 2, "content": 3}
        errors = [
            "Expected a string for wg:interfaces:config_path. Found 2",
            "Expected a string for wg:interfaces:content. Found 3",
            "Expected a string for wg:interfaces:name. Found 1",
        ]
        with self.assertRaises(ValueError) as context_mgr:
            cc_wireguard.supplemental_schema_validation(cfg)
        error_msg = str(context_mgr.exception)
        for error in errors:
            self.assertIn(error, error_msg)

    @mock.patch("%s.util.write_file" % MPATH)
    def test_write_config_failed(self, m_write_file):
        """Errors raised while writing the config file are re-raised as a
        RuntimeError naming the config path."""
        # Use a complete interface (config_path, content and name are all
        # required, per test_suppl_schema_error_on_missing_keys) so the
        # failure can only come from the mocked write, not from bad input.
        # Mocking the write also avoids touching the real filesystem.
        wg_int = {
            "name": "wg0",
            "config_path": "/no/valid/path",
            "content": "test",
        }
        m_write_file.side_effect = OSError("Permission denied")

        with self.assertRaises(RuntimeError) as context_mgr:
            cc_wireguard.write_config(wg_int)
        error_msg = str(context_mgr.exception)
        self.assertIn(
            "Failure writing Wireguard configuration file /no/valid/path:\n",
            error_msg,
        )
        # The underlying write error is carried into the message.
        self.assertIn("Permission denied", error_msg)
        m_write_file.assert_called_once()

    @mock.patch("%s.subp.subp" % MPATH)
    def test_readiness_probe_invalid_command(self, m_subp):
        """Errors when executing readinessprobes are raised."""
        wg_readinessprobes = ["not-a-valid-command"]

        def fake_subp(cmd, capture=None, shell=None):
            # Fail only for the bogus command when run captured via a shell.
            fail_cmds = ["not-a-valid-command"]
            if cmd in fail_cmds and capture and shell:
                raise subp.ProcessExecutionError(
                    "not-a-valid-command: command not found"
                )

        m_subp.side_effect = fake_subp

        with self.assertRaises(RuntimeError) as context_mgr:
            cc_wireguard.readinessprobe(wg_readinessprobes)
        self.assertIn(
            "Failed running readinessprobe command:\n"
            "not-a-valid-command: Unexpected error while"
            " running command.\n"
            "Command: -\nExit code: -\nReason: -\n"
            "Stdout: not-a-valid-command: command not found\nStderr: -",
            str(context_mgr.exception),
        )

    @mock.patch("%s.subp.subp" % MPATH)
    def test_enable_wg_on_error(self, m_subp):
        """Errors when enabling wireguard interfaces are raised."""
        wg_int = {"name": "wg0"}
        distro = mock.MagicMock()
        # Make the service start fail so enable_wg must surface the error.
        distro.manage_service.side_effect = subp.ProcessExecutionError(
            "systemctl start wg-quik@wg0 failed: exit code 1"
        )
        mycloud = FakeCloud(distro)
        with self.assertRaises(RuntimeError) as context_mgr:
            cc_wireguard.enable_wg(wg_int, mycloud)
        self.assertEqual(
            "Failed enabling/starting Wireguard interface(s):\n"
            "Unexpected error while running command.\n"
            "Command: -\nExit code: -\nReason: -\n"
            "Stdout: systemctl start wg-quik@wg0 failed: exit code 1\n"
            "Stderr: -",
            str(context_mgr.exception),
        )

    @mock.patch("%s.subp.which" % MPATH)
    def test_maybe_install_wg_packages_noop_when_wg_tools_present(
        self, m_which
    ):
        """Do nothing if wireguard-tools already exists."""
        m_which.return_value = "/usr/bin/wg"  # already installed
        distro = mock.MagicMock()
        # Would surface as a test error if package sources were refreshed.
        distro.update_package_sources.side_effect = RuntimeError(
            "Some apt error"
        )
        cc_wireguard.maybe_install_wireguard_packages(cloud=FakeCloud(distro))
        distro.update_package_sources.assert_not_called()
        distro.install_packages.assert_not_called()

    @mock.patch("%s.subp.which" % MPATH)
    def test_maybe_install_wf_tools_raises_update_errors(self, m_which):
        """maybe_install_wireguard_packages logs and raises
        apt update errors."""
        m_which.return_value = None
        distro = mock.MagicMock()
        distro.update_package_sources.side_effect = RuntimeError(
            "Some apt error"
        )
        with self.assertRaises(RuntimeError) as context_manager:
            cc_wireguard.maybe_install_wireguard_packages(
                cloud=FakeCloud(distro)
            )
        self.assertEqual("Some apt error", str(context_manager.exception))
        self.assertIn("Package update failed\nTraceback", self.logs.getvalue())

    @mock.patch("%s.subp.which" % MPATH)
    def test_maybe_install_wg_raises_install_errors(self, m_which):
        """maybe_install_wireguard_packages logs and raises package
        install errors."""
        m_which.return_value = None
        distro = mock.MagicMock()
        distro.update_package_sources.return_value = None
        distro.install_packages.side_effect = RuntimeError(
            "Some install error"
        )
        with self.assertRaises(RuntimeError) as context_manager:
            cc_wireguard.maybe_install_wireguard_packages(
                cloud=FakeCloud(distro)
            )
        self.assertEqual("Some install error", str(context_manager.exception))
        self.assertIn(
            "Failed to install wireguard-tools\n", self.logs.getvalue()
        )

    @mock.patch("%s.subp.subp" % MPATH)
    def test_load_wg_module_failed(self, m_subp):
        """load_wireguard_kernel_module logs and raises
        kernel modules loading error."""
        m_subp.side_effect = subp.ProcessExecutionError(
            "Some kernel module load error"
        )
        with self.assertRaises(subp.ProcessExecutionError) as context_manager:
            cc_wireguard.load_wireguard_kernel_module()
        self.assertEqual(
            "Unexpected error while running command.\n"
            "Command: -\nExit code: -\nReason: -\n"
            "Stdout: Some kernel module load error\n"
            "Stderr: -",
            str(context_manager.exception),
        )
        self.assertIn(
            "WARNING: Could not load wireguard module:\n", self.logs.getvalue()
        )

    @mock.patch("%s.util.kernel_version" % MPATH)
    @mock.patch("%s.subp.which" % MPATH)
    def test_maybe_install_wg_packages_happy_path(
        self, m_which, m_kernel_version
    ):
        """maybe_install_wireguard_packages installs wireguard-tools, plus
        the wireguard package on kernels older than MIN_KERNEL_VERSION."""
        m_which.return_value = None
        # Pin the kernel version so both sides of the threshold are covered
        # regardless of the host running the tests.
        cases = (
            ((5, 4), ["wireguard-tools", "wireguard"]),
            (MIN_KERNEL_VERSION, ["wireguard-tools"]),
            ((5, 15), ["wireguard-tools"]),
        )
        for kernel_version, packages in cases:
            with self.subTest(kernel_version=kernel_version):
                m_kernel_version.return_value = kernel_version
                distro = mock.MagicMock()  # No errors raised
                cc_wireguard.maybe_install_wireguard_packages(
                    cloud=FakeCloud(distro)
                )
                distro.update_package_sources.assert_called_once_with()
                distro.install_packages.assert_called_once_with(packages)

    @mock.patch("%s.maybe_install_wireguard_packages" % MPATH)
    def test_handle_no_config(self, m_maybe_install_wireguard_packages):
        """When no wireguard configuration is provided, nothing happens."""
        cfg = {}
        cc_wireguard.handle(
            "wg", cfg=cfg, cloud=None, log=self.logger, args=None
        )
        self.assertIn(
            "DEBUG: Skipping module named wg, no 'wireguard'"
            " configuration found",
            self.logs.getvalue(),
        )
        self.assertEqual(m_maybe_install_wireguard_packages.call_count, 0)

    def test_readiness_probe_with_non_string_values(self):
        """ValueError raised for any values expected as string type."""
        cfg = [1, 2]
        errors = [
            "Expected a string for readinessprobe at 0. Found 1",
            "Expected a string for readinessprobe at 1. Found 2",
        ]
        with self.assertRaises(ValueError) as context_manager:
            cc_wireguard.readinessprobe_command_validation(cfg)
        error_msg = str(context_manager.exception)
        for error in errors:
            self.assertIn(error, error_msg)


class TestWireguardSchema:
    """JSON-schema validation of the top-level ``wireguard`` config key."""

    # (config, error_msg) pairs; error_msg None means the config must
    # validate cleanly, otherwise it is a regex the raised error must match.
    _SCHEMA_CASES = [
        # Valid schemas
        (
            {
                "wireguard": {
                    "interfaces": [
                        {
                            "name": "wg0",
                            "config_path": "/etc/wireguard/wg0.conf",
                            "content": "test",
                        }
                    ]
                }
            },
            None,
        ),
    ]

    @pytest.mark.parametrize("config, error_msg", _SCHEMA_CASES)
    @skipUnlessJsonSchema()
    def test_schema_validation(self, config, error_msg):
        """Validate each case strictly against the full cloud-config schema."""
        if error_msg is None:
            # Expected-valid: any SchemaValidationError fails the test.
            validate_cloudconfig_schema(config, get_schema(), strict=True)
            return
        with pytest.raises(SchemaValidationError, match=error_msg):
            validate_cloudconfig_schema(config, get_schema(), strict=True)


# vi: ts=4 expandtab