diff options
Diffstat (limited to 'src/network/netdev/wireguard.c')
-rw-r--r-- | src/network/netdev/wireguard.c | 153 |
1 files changed, 101 insertions, 52 deletions
diff --git a/src/network/netdev/wireguard.c b/src/network/netdev/wireguard.c index fb91997f7a..167cf65046 100644 --- a/src/network/netdev/wireguard.c +++ b/src/network/netdev/wireguard.c @@ -1,22 +1,24 @@ /* SPDX-License-Identifier: LGPL-2.1+ */ /*** - Copyright © 2016-2017 Jörg Thalheim <joerg@thalheim.io> Copyright © 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. ***/ #include <sys/ioctl.h> #include <net/if.h> +#include "sd-resolve.h" + #include "alloc-util.h" -#include "parse-util.h" #include "fd-util.h" -#include "strv.h" #include "hexdecoct.h" -#include "string-util.h" -#include "wireguard.h" #include "networkd-link.h" -#include "networkd-util.h" #include "networkd-manager.h" +#include "networkd-util.h" +#include "parse-util.h" +#include "resolve-private.h" +#include "string-util.h" +#include "strv.h" +#include "wireguard.h" #include "wireguard-netlink.h" static void resolve_endpoints(NetDev *netdev); @@ -29,10 +31,13 @@ static WireguardPeer *wireguard_peer_new(Wireguard *w, unsigned section) { if (w->last_peer_section == section && w->peers) return w->peers; - peer = new0(WireguardPeer, 1); + peer = new(WireguardPeer, 1); if (!peer) return NULL; - peer->flags = WGPEER_F_REPLACE_ALLOWEDIPS; + + *peer = (WireguardPeer) { + .flags = WGPEER_F_REPLACE_ALLOWEDIPS, + }; LIST_PREPEND(peers, w->peers, peer); w->last_peer_section = section; @@ -42,7 +47,7 @@ static WireguardPeer *wireguard_peer_new(Wireguard *w, unsigned section) { static int set_wireguard_interface(NetDev *netdev) { int r; - unsigned int i, j; + unsigned i, j; WireguardPeer *peer, *peer_start; WireguardIPmask *mask, *mask_start = NULL; _cleanup_(sd_netlink_message_unrefp) sd_netlink_message *message = NULL; @@ -108,7 +113,7 @@ static int set_wireguard_interface(NetDev *netdev) { if (r < 0) break; - r = sd_netlink_message_append_u32(message, WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL, peer->persistent_keepalive_interval); + r = sd_netlink_message_append_u16(message, WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL, peer->persistent_keepalive_interval); if (r < 0) break; @@ -196,12 +201,19 @@ static int set_wireguard_interface(NetDev *netdev) { static WireguardEndpoint* wireguard_endpoint_free(WireguardEndpoint *e) { if (!e) return NULL; - netdev_unref(e->netdev); e->host = mfree(e->host); e->port = mfree(e->port); return mfree(e); } +static void wireguard_endpoint_destroy_callback(WireguardEndpoint *e) { + assert(e); + assert(e->netdev); + + netdev_unref(e->netdev); + wireguard_endpoint_free(e); +} + DEFINE_TRIVIAL_CLEANUP_FUNC(WireguardEndpoint*, wireguard_endpoint_free); static int on_resolve_retry(sd_event_source *s, usec_t usec, void *userdata) { @@ -212,8 +224,11 @@ static int on_resolve_retry(sd_event_source *s, usec_t usec, void *userdata) { w = WIREGUARD(netdev); assert(w); - w->resolve_retry_event_source = sd_event_source_unref(w->resolve_retry_event_source); + if (!netdev->manager) + /* The netdev is detached. */ + return 0; + assert(!w->unresolved_endpoints); w->unresolved_endpoints = TAKE_PTR(w->failed_endpoints); resolve_endpoints(netdev); @@ -232,28 +247,30 @@ static int exponential_backoff_milliseconds(unsigned n_retries) { static int wireguard_resolve_handler(sd_resolve_query *q, int ret, const struct addrinfo *ai, - void *userdata) { + WireguardEndpoint *e) { + _cleanup_(netdev_unrefp) NetDev *netdev_will_unrefed = NULL; NetDev *netdev; Wireguard *w; - _cleanup_(wireguard_endpoint_freep) WireguardEndpoint *e; int r; - assert(userdata); - e = userdata; - netdev = e->netdev; + assert(e); + assert(e->netdev); - assert(netdev); + netdev = e->netdev; w = WIREGUARD(netdev); assert(w); - w->resolve_query = sd_resolve_query_unref(w->resolve_query); + if (!netdev->manager) + /* The netdev is detached. */ + return 0; if (ret != 0) { log_netdev_error(netdev, "Failed to resolve host '%s:%s': %s", e->host, e->port, gai_strerror(ret)); LIST_PREPEND(endpoints, w->failed_endpoints, e); - e = NULL; + (void) sd_resolve_query_set_destroy_callback(q, NULL); /* Avoid freeing endpoint by destroy callback. */ + netdev_will_unrefed = netdev; /* But netdev needs to be unrefed. */ } else if ((ai->ai_family == AF_INET && ai->ai_addrlen == sizeof(struct sockaddr_in)) || - (ai->ai_family == AF_INET6 && ai->ai_addrlen == sizeof(struct sockaddr_in6))) + (ai->ai_family == AF_INET6 && ai->ai_addrlen == sizeof(struct sockaddr_in6))) memcpy(&e->peer->endpoint, ai->ai_addr, ai->ai_addrlen); else log_netdev_error(netdev, "Neither IPv4 nor IPv6 address found for peer endpoint: %s:%s", e->host, e->port); @@ -265,51 +282,69 @@ static int wireguard_resolve_handler(sd_resolve_query *q, set_wireguard_interface(netdev); if (w->failed_endpoints) { + _cleanup_(sd_event_source_unrefp) sd_event_source *s = NULL; + w->n_retries++; r = sd_event_add_time(netdev->manager->event, - &w->resolve_retry_event_source, + &s, CLOCK_MONOTONIC, now(CLOCK_MONOTONIC) + exponential_backoff_milliseconds(w->n_retries), 0, on_resolve_retry, netdev); - if (r < 0) + if (r < 0) { log_netdev_warning_errno(netdev, r, "Could not arm resolve retry handler: %m"); + return 0; + } + + r = sd_event_source_set_destroy_callback(s, (sd_event_destroy_t) netdev_destroy_callback); + if (r < 0) { + log_netdev_warning_errno(netdev, r, "Failed to set destroy callback to event source: %m"); + return 0; + } + + (void) sd_event_source_set_floating(s, true); + netdev_ref(netdev); } return 0; } static void resolve_endpoints(NetDev *netdev) { - int r = 0; - Wireguard *w; - WireguardEndpoint *endpoint; static const struct addrinfo hints = { .ai_family = AF_UNSPEC, .ai_socktype = SOCK_DGRAM, .ai_protocol = IPPROTO_UDP }; + WireguardEndpoint *endpoint; + Wireguard *w; + int r = 0; assert(netdev); w = WIREGUARD(netdev); assert(w); LIST_FOREACH(endpoints, endpoint, w->unresolved_endpoints) { - r = sd_resolve_getaddrinfo(netdev->manager->resolve, - &w->resolve_query, - endpoint->host, - endpoint->port, - &hints, - wireguard_resolve_handler, - endpoint); + r = resolve_getaddrinfo(netdev->manager->resolve, + NULL, + endpoint->host, + endpoint->port, + &hints, + wireguard_resolve_handler, + wireguard_endpoint_destroy_callback, + endpoint); if (r == -ENOBUFS) break; + if (r < 0) { + log_netdev_error_errno(netdev, r, "Failed to create resolver: %m"); + continue; + } - LIST_REMOVE(endpoints, w->unresolved_endpoints, endpoint); + /* Avoid freeing netdev. It will be unrefed by the destroy callback. */ + netdev_ref(netdev); - if (r < 0) - log_netdev_error_errno(netdev, r, "Failed create resolver: %m"); + LIST_REMOVE(endpoints, w->unresolved_endpoints, endpoint); } } @@ -532,12 +567,15 @@ int config_parse_wireguard_allowed_ips(const char *unit, return 0; } - ipmask = new0(WireguardIPmask, 1); + ipmask = new(WireguardIPmask, 1); if (!ipmask) return log_oom(); - ipmask->family = family; - ipmask->ip.in6 = addr.in6; - ipmask->cidr = prefixlen; + + *ipmask = (WireguardIPmask) { + .family = family, + .ip.in6 = addr.in6, + .cidr = prefixlen, + }; LIST_PREPEND(ipmasks, peer->ipmasks, ipmask); } @@ -573,10 +611,6 @@ int config_parse_wireguard_endpoint(const char *unit, if (!peer) return log_oom(); - endpoint = new0(WireguardEndpoint, 1); - if (!endpoint) - return log_oom(); - if (rvalue[0] == '[') { begin = &rvalue[1]; end = strchr(rvalue, ']'); @@ -610,12 +644,17 @@ int config_parse_wireguard_endpoint(const char *unit, if (!port) return log_oom(); - endpoint->peer = TAKE_PTR(peer); - endpoint->host = TAKE_PTR(host); - endpoint->port = TAKE_PTR(port); - endpoint->netdev = netdev_ref(data); - LIST_PREPEND(endpoints, w->unresolved_endpoints, endpoint); - endpoint = NULL; + endpoint = new(WireguardEndpoint, 1); + if (!endpoint) + return log_oom(); + + *endpoint = (WireguardEndpoint) { + .peer = TAKE_PTR(peer), + .host = TAKE_PTR(host), + .port = TAKE_PTR(port), + .netdev = data, + }; + LIST_PREPEND(endpoints, w->unresolved_endpoints, TAKE_PTR(endpoint)); return 0; } @@ -674,11 +713,11 @@ static void wireguard_done(NetDev *netdev) { Wireguard *w; WireguardPeer *peer; WireguardIPmask *mask; + WireguardEndpoint *e; assert(netdev); w = WIREGUARD(netdev); - assert(!w->unresolved_endpoints); - w->resolve_retry_event_source = sd_event_source_unref(w->resolve_retry_event_source); + assert(w); while ((peer = w->peers)) { LIST_REMOVE(peers, w->peers, peer); @@ -688,6 +727,16 @@ static void wireguard_done(NetDev *netdev) { } free(peer); } + + while ((e = w->unresolved_endpoints)) { + LIST_REMOVE(endpoints, w->unresolved_endpoints, e); + wireguard_endpoint_free(e); + } + + while ((e = w->failed_endpoints)) { + LIST_REMOVE(endpoints, w->failed_endpoints, e); + wireguard_endpoint_free(e); + } } const NetDevVTable wireguard_vtable = { |