diff options
Diffstat (limited to 'src/resolve/resolved-dns-stream.c')
-rw-r--r-- | src/resolve/resolved-dns-stream.c | 265 |
1 files changed, 108 insertions, 157 deletions
diff --git a/src/resolve/resolved-dns-stream.c b/src/resolve/resolved-dns-stream.c index 066daef96e..aee339a4c8 100644 --- a/src/resolve/resolved-dns-stream.c +++ b/src/resolve/resolved-dns-stream.c @@ -11,14 +11,15 @@ #define DNS_STREAM_TIMEOUT_USEC (10 * USEC_PER_SEC) #define DNS_STREAMS_MAX 128 -#define WRITE_TLS_DATA 1 - static void dns_stream_stop(DnsStream *s) { assert(s); s->io_event_source = sd_event_source_unref(s->io_event_source); s->timeout_event_source = sd_event_source_unref(s->timeout_event_source); s->fd = safe_close(s->fd); + + /* Disconnect us from the server object if we are now not usable anymore */ + dns_stream_detach(s); } static int dns_stream_update_io(DnsStream *s) { @@ -38,26 +39,33 @@ static int dns_stream_update_io(DnsStream *s) { if (!s->read_packet || s->n_read < sizeof(s->read_size) + s->read_packet->size) f |= EPOLLIN; +#if ENABLE_DNS_OVER_TLS + /* For handshake and clean closing purposes, TLS can override requested events */ + if (s->dnstls_events) + f = s->dnstls_events; +#endif + return sd_event_source_set_io_events(s->io_event_source, f); } static int dns_stream_complete(DnsStream *s, int error) { + _cleanup_(dns_stream_unrefp) _unused_ DnsStream *ref = dns_stream_ref(s); /* Protect stream while we process it */ + assert(s); #if ENABLE_DNS_OVER_TLS - if (s->tls_session && IN_SET(error, ETIMEDOUT, 0)) { + if (s->encrypted) { int r; - r = gnutls_bye(s->tls_session, GNUTLS_SHUT_RDWR); - if (r == GNUTLS_E_AGAIN && !s->tls_bye) { - dns_stream_ref(s); /* keep reference for closing TLS session */ - s->tls_bye = true; - } else + r = dnstls_stream_shutdown(s, error); + if (r != -EAGAIN) dns_stream_stop(s); } else #endif dns_stream_stop(s); + dns_stream_detach(s); + if (s->complete) s->complete(s, error); else /* the default action if no completion function is set is to close the stream */ @@ -191,34 +199,24 @@ static int dns_stream_identify(DnsStream *s) { return 0; } -static ssize_t dns_stream_writev(DnsStream *s, const struct iovec *iov, size_t iovcnt, int flags) { - ssize_t r; +ssize_t dns_stream_writev(DnsStream *s, const struct iovec *iov, size_t iovcnt, int flags) { + ssize_t m; assert(s); assert(iov); #if ENABLE_DNS_OVER_TLS - if (s->tls_session && !(flags & WRITE_TLS_DATA)) { + if (s->encrypted && !(flags & DNS_STREAM_WRITE_TLS_DATA)) { ssize_t ss; size_t i; - r = 0; + m = 0; for (i = 0; i < iovcnt; i++) { - ss = gnutls_record_send(s->tls_session, iov[i].iov_base, iov[i].iov_len); - if (ss < 0) { - switch(ss) { - - case GNUTLS_E_INTERRUPTED: - return -EINTR; - case GNUTLS_E_AGAIN: - return -EAGAIN; - default: - log_debug("Failed to invoke gnutls_record_send: %s", gnutls_strerror(ss)); - return -EIO; - } - } + ss = dnstls_stream_write(s, iov[i].iov_base, iov[i].iov_len); + if (ss < 0) + return ss; - r += ss; + m += ss; if (ss != (ssize_t) iov[i].iov_len) continue; } @@ -232,80 +230,47 @@ static ssize_t dns_stream_writev(DnsStream *s, const struct iovec *iov, size_t i .msg_namelen = s->tfo_salen }; - r = sendmsg(s->fd, &hdr, MSG_FASTOPEN); - if (r < 0) { + m = sendmsg(s->fd, &hdr, MSG_FASTOPEN); + if (m < 0) { if (errno == EOPNOTSUPP) { s->tfo_salen = 0; - r = connect(s->fd, &s->tfo_address.sa, s->tfo_salen); - if (r < 0) + if (connect(s->fd, &s->tfo_address.sa, s->tfo_salen) < 0) return -errno; - r = -EAGAIN; - } else if (errno == EINPROGRESS) - r = -EAGAIN; + return -EAGAIN; + } + if (errno == EINPROGRESS) + return -EAGAIN; + + return -errno; } else s->tfo_salen = 0; /* connection is made */ } else { - r = writev(s->fd, iov, iovcnt); - if (r < 0) - r = -errno; + m = writev(s->fd, iov, iovcnt); + if (m < 0) + return -errno; } - return r; + return m; } static ssize_t dns_stream_read(DnsStream *s, void *buf, size_t count) { ssize_t ss; #if ENABLE_DNS_OVER_TLS - if (s->tls_session) { - ss = gnutls_record_recv(s->tls_session, buf, count); - if (ss < 0) { - switch(ss) { - - case GNUTLS_E_INTERRUPTED: - return -EINTR; - case GNUTLS_E_AGAIN: - return -EAGAIN; - default: - log_debug("Failed to invoke gnutls_record_send: %s", gnutls_strerror(ss)); - return -EIO; - } - } else if (s->on_connection) { - int r; - - r = s->on_connection(s); - s->on_connection = NULL; /* only call once */ - if (r < 0) - return r; - } - } else + if (s->encrypted) + ss = dnstls_stream_read(s, buf, count); + else #endif { ss = read(s->fd, buf, count); if (ss < 0) - ss = -errno; + return -errno; } return ss; } -#if ENABLE_DNS_OVER_TLS -static ssize_t dns_stream_tls_writev(gnutls_transport_ptr_t p, const giovec_t * iov, int iovcnt) { - int r; - - assert(p); - - r = dns_stream_writev((DnsStream*) p, (struct iovec*) iov, iovcnt, WRITE_TLS_DATA); - if (r < 0) { - errno = -r; - return -1; - } - - return r; -} -#endif - static int on_stream_timeout(sd_event_source *es, usec_t usec, void *userdata) { DnsStream *s = userdata; @@ -315,42 +280,24 @@ static int on_stream_timeout(sd_event_source *es, usec_t usec, void *userdata) { } static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *userdata) { - DnsStream *s = userdata; + _cleanup_(dns_stream_unrefp) DnsStream *s = dns_stream_ref(userdata); /* Protect stream while we process it */ int r; assert(s); #if ENABLE_DNS_OVER_TLS - if (s->tls_bye) { - assert(s->tls_session); - - r = gnutls_bye(s->tls_session, GNUTLS_SHUT_RDWR); - if (r != GNUTLS_E_AGAIN) { - s->tls_bye = false; - dns_stream_unref(s); - } - - return 0; - } - - if (s->tls_handshake < 0) { - assert(s->tls_session); - - s->tls_handshake = gnutls_handshake(s->tls_session); - if (s->tls_handshake >= 0) { - if (s->on_connection && !(gnutls_session_get_flags(s->tls_session) & GNUTLS_SFLAGS_FALSE_START)) { - r = s->on_connection(s); - s->on_connection = NULL; /* only call once */ - if (r < 0) - return r; - } - } else { - if (gnutls_error_is_fatal(s->tls_handshake)) - return dns_stream_complete(s, ECONNREFUSED); - else - return 0; - } + if (s->encrypted) { + r = dnstls_stream_on_io(s, revents); + if (r == DNSTLS_STREAM_CLOSED) + return 0; + if (r == -EAGAIN) + return dns_stream_update_io(s); + if (r < 0) + return dns_stream_complete(s, -r); + r = dns_stream_update_io(s); + if (r < 0) + return r; } #endif @@ -368,10 +315,8 @@ static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *use struct iovec iov[2]; ssize_t ss; - iov[0].iov_base = &s->write_size; - iov[0].iov_len = sizeof(s->write_size); - iov[1].iov_base = DNS_PACKET_DATA(s->write_packet); - iov[1].iov_len = s->write_packet->size; + iov[0] = IOVEC_MAKE(&s->write_size, sizeof(s->write_size)); + iov[1] = IOVEC_MAKE(DNS_PACKET_DATA(s->write_packet), s->write_packet->size); IOVEC_INCREMENT(iov, 2, s->n_written); @@ -449,8 +394,8 @@ static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *use (uint8_t*) DNS_PACKET_DATA(s->read_packet) + s->n_read - sizeof(s->read_size), sizeof(s->read_size) + be16toh(s->read_size) - s->n_read); if (ss < 0) { - if (!IN_SET(errno, EINTR, EAGAIN)) - return dns_stream_complete(s, errno); + if (!IN_SET(-ss, EINTR, EAGAIN)) + return dns_stream_complete(s, -ss); } else if (ss == 0) return dns_stream_complete(s, ECONNRESET); else @@ -482,32 +427,22 @@ static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *use return 0; } -DnsStream *dns_stream_unref(DnsStream *s) { +static DnsStream *dns_stream_free(DnsStream *s) { DnsPacket *p; Iterator i; - if (!s) - return NULL; - - assert(s->n_ref > 0); - s->n_ref--; - - if (s->n_ref > 0) - return NULL; + assert(s); dns_stream_stop(s); - if (s->server && s->server->stream == s) - s->server->stream = NULL; - if (s->manager) { LIST_REMOVE(streams, s->manager->dns_streams, s); s->manager->n_dns_streams--; } #if ENABLE_DNS_OVER_TLS - if (s->tls_session) - gnutls_deinit(s->tls_session); + if (s->encrypted) + dnstls_stream_free(s); #endif ORDERED_SET_FOREACH(p, s->write_queue, i) @@ -522,38 +457,39 @@ DnsStream *dns_stream_unref(DnsStream *s) { return mfree(s); } -DnsStream *dns_stream_ref(DnsStream *s) { - if (!s) - return NULL; - - assert(s->n_ref > 0); - s->n_ref++; +DEFINE_TRIVIAL_REF_UNREF_FUNC(DnsStream, dns_stream, dns_stream_free); - return s; -} +int dns_stream_new( + Manager *m, + DnsStream **ret, + DnsProtocol protocol, + int fd, + const union sockaddr_union *tfo_address) { -int dns_stream_new(Manager *m, DnsStream **ret, DnsProtocol protocol, int fd, const union sockaddr_union *tfo_address) { _cleanup_(dns_stream_unrefp) DnsStream *s = NULL; int r; assert(m); + assert(ret); assert(fd >= 0); if (m->n_dns_streams > DNS_STREAMS_MAX) return -EBUSY; - s = new0(DnsStream, 1); + s = new(DnsStream, 1); if (!s) return -ENOMEM; + *s = (DnsStream) { + .n_ref = 1, + .fd = -1, + .protocol = protocol, + }; + r = ordered_set_ensure_allocated(&s->write_queue, &dns_packet_hash_ops); if (r < 0) return r; - s->n_ref = 1; - s->fd = -1; - s->protocol = protocol; - r = sd_event_add_io(m->event, &s->io_event_source, fd, EPOLLIN, on_stream_io, s); if (r < 0) return r; @@ -572,39 +508,26 @@ int dns_stream_new(Manager *m, DnsStream **ret, DnsProtocol protocol, int fd, co (void) sd_event_source_set_description(s->timeout_event_source, "dns-stream-timeout"); LIST_PREPEND(streams, m->dns_streams, s); + m->n_dns_streams++; s->manager = m; + s->fd = fd; + if (tfo_address) { s->tfo_address = *tfo_address; s->tfo_salen = tfo_address->sa.sa_family == AF_INET6 ? sizeof(tfo_address->in6) : sizeof(tfo_address->in); } - m->n_dns_streams++; - *ret = TAKE_PTR(s); return 0; } -#if ENABLE_DNS_OVER_TLS -int dns_stream_connect_tls(DnsStream *s, gnutls_session_t tls_session) { - gnutls_transport_set_ptr2(tls_session, (gnutls_transport_ptr_t) (long) s->fd, s); - gnutls_transport_set_vec_push_function(tls_session, &dns_stream_tls_writev); - - s->encrypted = true; - s->tls_session = tls_session; - s->tls_handshake = gnutls_handshake(tls_session); - if (s->tls_handshake < 0 && gnutls_error_is_fatal(s->tls_handshake)) - return -ECONNREFUSED; - - return 0; -} -#endif - int dns_stream_write_packet(DnsStream *s, DnsPacket *p) { int r; assert(s); + assert(p); r = ordered_set_put(s->write_queue, p); if (r < 0) @@ -614,3 +537,31 @@ int dns_stream_write_packet(DnsStream *s, DnsPacket *p) { return dns_stream_update_io(s); } + +DnsPacket *dns_stream_take_read_packet(DnsStream *s) { + assert(s); + + if (!s->read_packet) + return NULL; + + if (s->n_read < sizeof(s->read_size)) + return NULL; + + if (s->n_read < sizeof(s->read_size) + be16toh(s->read_size)) + return NULL; + + s->n_read = 0; + return TAKE_PTR(s->read_packet); +} + +void dns_stream_detach(DnsStream *s) { + assert(s); + + if (!s->server) + return; + + if (s->server->stream != s) + return; + + dns_server_unref_stream(s->server); +} |