From f0336417cae1c32f4ea59a9f9851a15f269340f7 Mon Sep 17 00:00:00 2001 From: tokcum <47994370+tokcum@users.noreply.github.com> Date: Wed, 30 Mar 2022 11:39:08 +0200 Subject: THRIFT-5283: add support for Unix Domain Sockets in lib/rs (#2545) Client: rs --- lib/rs/src/server/threaded.rs | 61 ++++++++++++++--- lib/rs/src/transport/socket.rs | 15 +++++ lib/rs/test/Cargo.toml | 1 + lib/rs/test/src/bin/kitchen_sink_client.rs | 65 ++++++++++++++---- lib/rs/test/src/bin/kitchen_sink_server.rs | 49 +++++++++++--- test/rs/src/bin/test_client.rs | 104 +++++++++++++++++++++++------ test/rs/src/bin/test_server.rs | 23 +++++-- test/rs/src/lib.rs | 4 -- test/tests.json | 3 +- 9 files changed, 256 insertions(+), 69 deletions(-) diff --git a/lib/rs/src/server/threaded.rs b/lib/rs/src/server/threaded.rs index 897235c2b..ad55b4459 100644 --- a/lib/rs/src/server/threaded.rs +++ b/lib/rs/src/server/threaded.rs @@ -17,10 +17,15 @@ use log::warn; -use std::net::{TcpListener, TcpStream, ToSocketAddrs}; +use std::net::{TcpListener, ToSocketAddrs}; use std::sync::Arc; use threadpool::ThreadPool; +#[cfg(unix)] +use std::os::unix::net::UnixListener; +#[cfg(unix)] +use std::path::Path; + use crate::protocol::{ TInputProtocol, TInputProtocolFactory, TOutputProtocol, TOutputProtocolFactory, }; @@ -178,10 +183,8 @@ where for stream in listener.incoming() { match stream { Ok(s) => { - let (i_prot, o_prot) = self.new_protocols_for_connection(s)?; - let processor = self.processor.clone(); - self.worker_pool - .execute(move || handle_incoming_connection(processor, i_prot, o_prot)); + let channel = TTcpChannel::with_stream(s); + self.handle_stream(channel)?; } Err(e) => { warn!("failed to accept remote connection with error {:?}", e); @@ -195,19 +198,55 @@ where })) } - fn new_protocols_for_connection( + /// Listen for incoming connections on `listen_path`. + /// + /// `listen_path` should implement `AsRef` trait. + /// + /// Return `()` if successful. + /// + /// Return `Err` when the server cannot bind to `listen_path` or there + /// is an unrecoverable error. + #[cfg(unix)] + pub fn listen_uds>(&mut self, listen_path: P) -> crate::Result<()> { + let listener = UnixListener::bind(listen_path)?; + for stream in listener.incoming() { + match stream { + Ok(s) => { + self.handle_stream(s)?; + } + Err(e) => { + warn!( + "failed to accept connection via unix domain socket with error {:?}", + e + ); + } + } + } + + Err(crate::Error::Application(ApplicationError { + kind: ApplicationErrorKind::Unknown, + message: "aborted listen loop".into(), + })) + } + + fn handle_stream(&mut self, stream: S) -> crate::Result<()> { + let (i_prot, o_prot) = self.new_protocols_for_connection(stream)?; + let processor = self.processor.clone(); + self.worker_pool + .execute(move || handle_incoming_connection(processor, i_prot, o_prot)); + Ok(()) + } + + fn new_protocols_for_connection( &mut self, - stream: TcpStream, + stream: S, ) -> crate::Result<( Box, Box, )> { - // create the shared tcp stream - let channel = TTcpChannel::with_stream(stream); - // split it into two - one to be owned by the // input tran/proto and the other by the output - let (r_chan, w_chan) = channel.split()?; + let (r_chan, w_chan) = stream.split()?; // input protocol and transport let r_tran = self.r_trans_factory.create(Box::new(r_chan)); diff --git a/lib/rs/src/transport/socket.rs b/lib/rs/src/transport/socket.rs index 275bcd459..48d6dda13 100644 --- a/lib/rs/src/transport/socket.rs +++ b/lib/rs/src/transport/socket.rs @@ -20,6 +20,9 @@ use std::io; use std::io::{ErrorKind, Read, Write}; use std::net::{Shutdown, TcpStream, ToSocketAddrs}; +#[cfg(unix)] +use std::os::unix::net::UnixStream; + use super::{ReadHalf, TIoChannel, WriteHalf}; use crate::{new_transport_error, TransportErrorKind}; @@ -166,3 +169,15 @@ impl Write for TTcpChannel { self.if_set(|s| s.flush()) } } + +#[cfg(unix)] +impl TIoChannel for UnixStream { + fn split(self) -> crate::Result<(ReadHalf, WriteHalf)> + where + Self: Sized, + { + let socket_rx = self.try_clone().unwrap(); + + Ok((ReadHalf::new(self), WriteHalf::new(socket_rx))) + } +} diff --git a/lib/rs/test/Cargo.toml b/lib/rs/test/Cargo.toml index 0ba96fdfd..47b8cbf7a 100644 --- a/lib/rs/test/Cargo.toml +++ b/lib/rs/test/Cargo.toml @@ -9,6 +9,7 @@ publish = false [dependencies] clap = "~2.33" bitflags = "=1.2" +log = "0.4" [dependencies.thrift] path = "../" diff --git a/lib/rs/test/src/bin/kitchen_sink_client.rs b/lib/rs/test/src/bin/kitchen_sink_client.rs index 74197de7f..b98afb814 100644 --- a/lib/rs/test/src/bin/kitchen_sink_client.rs +++ b/lib/rs/test/src/bin/kitchen_sink_client.rs @@ -16,8 +16,16 @@ // under the License. use clap::{clap_app, value_t}; +use log::*; use std::convert::Into; +use std::net::TcpStream; +use std::net::ToSocketAddrs; + +#[cfg(unix)] +use std::os::unix::net::UnixStream; +#[cfg(unix)] +use std::path::Path; use kitchen_sink::base_two::{TNapkinServiceSyncClient, TRamenServiceSyncClient}; use kitchen_sink::midlayer::{MealServiceSyncClient, TMealServiceSyncClient}; @@ -30,9 +38,9 @@ use thrift::protocol::{ TBinaryInputProtocol, TBinaryOutputProtocol, TCompactInputProtocol, TCompactOutputProtocol, TInputProtocol, TOutputProtocol, }; -use thrift::transport::{ - ReadHalf, TFramedReadTransport, TFramedWriteTransport, TIoChannel, TTcpChannel, WriteHalf, -}; +use thrift::transport::{TFramedReadTransport, TFramedWriteTransport, TIoChannel, TTcpChannel}; + +type IoProtocol = (Box, Box); fn main() { match run() { @@ -51,6 +59,7 @@ fn run() -> thrift::Result<()> { (about: "Thrift Rust kitchen sink client") (@arg host: --host +takes_value "Host on which the Thrift test server is located") (@arg port: --port +takes_value "Port on which the Thrift test server is listening") + (@arg domain_socket: --("domain-socket") + takes_value "Unix Domain Socket on which the Thrift test server is listening") (@arg protocol: --protocol +takes_value "Thrift protocol implementation to use (\"binary\", \"compact\")") (@arg service: --service +takes_value "Service type to contact (\"part\", \"full\", \"recursive\")") ) @@ -58,10 +67,47 @@ fn run() -> thrift::Result<()> { let host = matches.value_of("host").unwrap_or("127.0.0.1"); let port = value_t!(matches, "port", u16).unwrap_or(9090); + let domain_socket = matches.value_of("domain_socket"); let protocol = matches.value_of("protocol").unwrap_or("compact"); let service = matches.value_of("service").unwrap_or("part"); - let (i_chan, o_chan) = tcp_channel(host, port)?; + let (i_prot, o_prot) = match domain_socket { + None => { + let listen_address = format!("{}:{}", host, port); + info!("Client binds to {} with {}", listen_address, protocol); + bind(listen_address, protocol)? + } + Some(domain_socket) => { + info!("Client binds to {} (UDS) with {}", domain_socket, protocol); + bind_uds(domain_socket, protocol)? + } + }; + + run_client(service, i_prot, o_prot) +} + +fn bind(listen_address: A, protocol: &str) -> Result { + let stream = TcpStream::connect(listen_address)?; + let channel = TTcpChannel::with_stream(stream); + + let (i_prot, o_prot) = build(channel, protocol)?; + Ok((i_prot, o_prot)) +} + +#[cfg(unix)] +fn bind_uds>(domain_socket: P, protocol: &str) -> Result { + let stream = UnixStream::connect(domain_socket)?; + + let (i_prot, o_prot) = build(stream, protocol)?; + Ok((i_prot, o_prot)) +} + +fn build( + channel: C, + protocol: &str, +) -> thrift::Result<(Box, Box)> { + let (i_chan, o_chan) = channel.split()?; + let (i_tran, o_tran) = ( TFramedReadTransport::new(i_chan), TFramedWriteTransport::new(o_chan), @@ -79,7 +125,7 @@ fn run() -> thrift::Result<()> { unmatched => return Err(format!("unsupported protocol {}", unmatched).into()), }; - run_client(service, i_prot, o_prot) + Ok((i_prot, o_prot)) } fn run_client( @@ -98,15 +144,6 @@ fn run_client( } } -fn tcp_channel( - host: &str, - port: u16, -) -> thrift::Result<(ReadHalf, WriteHalf)> { - let mut c = TTcpChannel::new(); - c.open(&format!("{}:{}", host, port))?; - c.split() -} - fn exec_meal_client( i_prot: Box, o_prot: Box, diff --git a/lib/rs/test/src/bin/kitchen_sink_server.rs b/lib/rs/test/src/bin/kitchen_sink_server.rs index 8b910b3bf..ea571c686 100644 --- a/lib/rs/test/src/bin/kitchen_sink_server.rs +++ b/lib/rs/test/src/bin/kitchen_sink_server.rs @@ -16,6 +16,7 @@ // under the License. use clap::{clap_app, value_t}; +use log::*; use thrift; use thrift::protocol::{ @@ -28,6 +29,7 @@ use thrift::transport::{ TWriteTransportFactory, }; +use crate::Socket::{ListenAddress, UnixDomainSocket}; use kitchen_sink::base_one::Noodle; use kitchen_sink::base_two::{ BrothType, Napkin, NapkinServiceSyncHandler, Ramen, RamenServiceSyncHandler, @@ -42,6 +44,11 @@ use kitchen_sink::ultimate::{ FullMealServiceSyncHandler, }; +enum Socket { + ListenAddress(String), + UnixDomainSocket(String), +} + fn main() { match run() { Ok(()) => println!("kitchen sink server completed successfully"), @@ -57,18 +64,29 @@ fn run() -> thrift::Result<()> { (version: "0.1.0") (author: "Apache Thrift Developers ") (about: "Thrift Rust kitchen sink test server") - (@arg port: --port +takes_value "port on which the test server listens") + (@arg port: --port +takes_value "Port on which the Thrift test server listens") + (@arg domain_socket: --("domain-socket") + takes_value "Unix Domain Socket on which the Thrift test server listens") (@arg protocol: --protocol +takes_value "Thrift protocol implementation to use (\"binary\", \"compact\")") (@arg service: --service +takes_value "Service type to contact (\"part\", \"full\", \"recursive\")") ) .get_matches(); let port = value_t!(matches, "port", u16).unwrap_or(9090); + let domain_socket = matches.value_of("domain_socket"); let protocol = matches.value_of("protocol").unwrap_or("compact"); let service = matches.value_of("service").unwrap_or("part"); let listen_address = format!("127.0.0.1:{}", port); - println!("binding to {}", listen_address); + let socket = match domain_socket { + None => { + info!("Server is binding to {}", listen_address); + Socket::ListenAddress(listen_address) + } + Some(domain_socket) => { + info!("Server is binding to {} (UDS)", domain_socket); + Socket::UnixDomainSocket(domain_socket.to_string()) + } + }; let r_transport_factory = TFramedReadTransportFactory::new(); let w_transport_factory = TFramedWriteTransportFactory::new(); @@ -102,21 +120,21 @@ fn run() -> thrift::Result<()> { // Since what I'm doing is uncommon I'm just going to duplicate the code match &*service { "part" => run_meal_server( - &listen_address, + socket, r_transport_factory, i_protocol_factory, w_transport_factory, o_protocol_factory, ), "full" => run_full_meal_server( - &listen_address, + socket, r_transport_factory, i_protocol_factory, w_transport_factory, o_protocol_factory, ), "recursive" => run_recursive_server( - &listen_address, + socket, r_transport_factory, i_protocol_factory, w_transport_factory, @@ -127,7 +145,7 @@ fn run() -> thrift::Result<()> { } fn run_meal_server( - listen_address: &str, + socket: Socket, r_transport_factory: RTF, i_protocol_factory: IPF, w_transport_factory: WTF, @@ -149,11 +167,14 @@ where 1, ); - server.listen(listen_address) + match socket { + ListenAddress(listen_address) => server.listen(listen_address), + UnixDomainSocket(s) => server.listen_uds(s), + } } fn run_full_meal_server( - listen_address: &str, + socket: Socket, r_transport_factory: RTF, i_protocol_factory: IPF, w_transport_factory: WTF, @@ -175,7 +196,10 @@ where 1, ); - server.listen(listen_address) + match socket { + ListenAddress(listen_address) => server.listen(listen_address), + UnixDomainSocket(s) => server.listen_uds(s), + } } struct PartHandler; @@ -267,7 +291,7 @@ fn napkin() -> Napkin { } fn run_recursive_server( - listen_address: &str, + socket: Socket, r_transport_factory: RTF, i_protocol_factory: IPF, w_transport_factory: WTF, @@ -289,7 +313,10 @@ where 1, ); - server.listen(listen_address) + match socket { + ListenAddress(listen_address) => server.listen(listen_address), + UnixDomainSocket(s) => server.listen_uds(s), + } } struct RecursiveTestServerHandler; diff --git a/test/rs/src/bin/test_client.rs b/test/rs/src/bin/test_client.rs index 8623915d4..8274aaeb2 100644 --- a/test/rs/src/bin/test_client.rs +++ b/test/rs/src/bin/test_client.rs @@ -21,7 +21,12 @@ use log::*; use std::collections::{BTreeMap, BTreeSet}; use std::fmt::Debug; -use std::net::TcpStream; +use std::net::{TcpStream, ToSocketAddrs}; + +#[cfg(unix)] +use std::os::unix::net::UnixStream; +#[cfg(unix)] +use std::path::Path; use thrift; use thrift::protocol::{ @@ -35,6 +40,11 @@ use thrift::transport::{ use thrift::OrderedFloat; use thrift_test::*; +type ThriftClientPair = ( + ThriftTestSyncClient, Box>, + Option, Box>>, +); + fn main() { env_logger::init(); @@ -51,7 +61,6 @@ fn main() { fn run() -> thrift::Result<()> { // unsupported options: - // --domain-socket // --pipe // --anon-pipes // --ssl @@ -62,56 +71,107 @@ fn run() -> thrift::Result<()> { (about: "Rust Thrift test client") (@arg host: --host +takes_value "Host on which the Thrift test server is located") (@arg port: --port +takes_value "Port on which the Thrift test server is listening") - (@arg transport: --transport +takes_value "Thrift transport implementation to use (\"buffered\", \"framed\")") + (@arg domain_socket: --("domain-socket") +takes_value "Unix Domain Socket on which the Thrift test server is listening") (@arg protocol: --protocol +takes_value "Thrift protocol implementation to use (\"binary\", \"compact\", \"multi\", \"multic\")") + (@arg transport: --transport +takes_value "Thrift transport implementation to use (\"buffered\", \"framed\")") (@arg testloops: -n --testloops +takes_value "Number of times to run tests") ) .get_matches(); let host = matches.value_of("host").unwrap_or("127.0.0.1"); let port = value_t!(matches, "port", u16).unwrap_or(9090); - let testloops = value_t!(matches, "testloops", u8).unwrap_or(1); - let transport = matches.value_of("transport").unwrap_or("buffered"); + let domain_socket = matches.value_of("domain_socket"); let protocol = matches.value_of("protocol").unwrap_or("binary"); + let transport = matches.value_of("transport").unwrap_or("buffered"); + let testloops = value_t!(matches, "testloops", u8).unwrap_or(1); + + let (mut thrift_test_client, mut second_service_client) = match domain_socket { + None => { + let listen_address = format!("{}:{}", host, port); + info!( + "Client binds to {} with {}+{} stack", + listen_address, protocol, transport + ); + bind(listen_address.as_str(), protocol, transport)? + } + Some(domain_socket) => { + info!( + "Client binds to {} (UDS) with {}+{} stack", + domain_socket, protocol, transport + ); + bind_uds(domain_socket, protocol, transport)? + } + }; + + for _ in 0..testloops { + make_thrift_calls(&mut thrift_test_client, &mut second_service_client)? + } + Ok(()) +} + +fn bind( + listen_address: A, + protocol: &str, + transport: &str, +) -> Result { // create a TCPStream that will be shared by all Thrift clients // service calls from multiple Thrift clients will be interleaved over the same connection // this isn't a problem for us because we're single-threaded and all calls block to completion - let shared_stream = TcpStream::connect(format!("{}:{}", host, port))?; + let shared_stream = TcpStream::connect(listen_address)?; - let mut second_service_client = if protocol.starts_with("multi") { + let second_service_client = if protocol.starts_with("multi") { let shared_stream_clone = shared_stream.try_clone()?; - let (i_prot, o_prot) = build(shared_stream_clone, transport, protocol, "SecondService")?; + let channel = TTcpChannel::with_stream(shared_stream_clone); + let (i_prot, o_prot) = build(channel, transport, protocol, "SecondService")?; Some(SecondServiceSyncClient::new(i_prot, o_prot)) } else { None }; - let mut thrift_test_client = { - let (i_prot, o_prot) = build(shared_stream, transport, protocol, "ThriftTest")?; + let thrift_test_client = { + let channel = TTcpChannel::with_stream(shared_stream); + let (i_prot, o_prot) = build(channel, transport, protocol, "ThriftTest")?; ThriftTestSyncClient::new(i_prot, o_prot) }; - info!( - "connecting to {}:{} with {}+{} stack", - host, port, protocol, transport - ); + Ok((thrift_test_client, second_service_client)) +} - for _ in 0..testloops { - make_thrift_calls(&mut thrift_test_client, &mut second_service_client)? - } +#[cfg(unix)] +fn bind_uds>( + domain_socket: P, + protocol: &str, + transport: &str, +) -> Result { + // create a UnixStream that will be shared by all Thrift clients + // service calls from multiple Thrift clients will be interleaved over the same connection + // this isn't a problem for us because we're single-threaded and all calls block to completion + let shared_stream = UnixStream::connect(domain_socket)?; - Ok(()) + let second_service_client = if protocol.starts_with("multi") { + let shared_stream_clone = shared_stream.try_clone()?; + let (i_prot, o_prot) = build(shared_stream_clone, transport, protocol, "SecondService")?; + Some(SecondServiceSyncClient::new(i_prot, o_prot)) + } else { + None + }; + + let thrift_test_client = { + let (i_prot, o_prot) = build(shared_stream, transport, protocol, "ThriftTest")?; + ThriftTestSyncClient::new(i_prot, o_prot) + }; + + Ok((thrift_test_client, second_service_client)) } -fn build( - stream: TcpStream, +fn build( + channel: C, transport: &str, protocol: &str, service_name: &str, ) -> thrift::Result<(Box, Box)> { - let c = TTcpChannel::with_stream(stream); - let (i_chan, o_chan) = c.split()?; + let (i_chan, o_chan) = channel.split()?; let (i_tran, o_tran): (Box, Box) = match transport { "buffered" => ( diff --git a/test/rs/src/bin/test_server.rs b/test/rs/src/bin/test_server.rs index 6a05e79e5..7e6d08f1c 100644 --- a/test/rs/src/bin/test_server.rs +++ b/test/rs/src/bin/test_server.rs @@ -52,7 +52,6 @@ fn main() { fn run() -> thrift::Result<()> { // unsupported options: - // --domain-socket // --pipe // --ssl let matches = clap_app!(rust_test_client => @@ -60,21 +59,26 @@ fn run() -> thrift::Result<()> { (author: "Apache Thrift Developers ") (about: "Rust Thrift test server") (@arg port: --port +takes_value "port on which the test server listens") + (@arg domain_socket: --("domain-socket") +takes_value "Unix Domain Socket on which the test server listens") (@arg transport: --transport +takes_value "transport implementation to use (\"buffered\", \"framed\")") (@arg protocol: --protocol +takes_value "protocol implementation to use (\"binary\", \"compact\")") - (@arg server_type: --server_type +takes_value "type of server instantiated (\"simple\", \"thread-pool\")") + (@arg server_type: --("server-type") +takes_value "type of server instantiated (\"simple\", \"thread-pool\")") (@arg workers: -n --workers +takes_value "number of thread-pool workers (\"4\")") ) - .get_matches(); + .get_matches(); let port = value_t!(matches, "port", u16).unwrap_or(9090); + let domain_socket = matches.value_of("domain_socket"); let transport = matches.value_of("transport").unwrap_or("buffered"); let protocol = matches.value_of("protocol").unwrap_or("binary"); let server_type = matches.value_of("server_type").unwrap_or("thread-pool"); let workers = value_t!(matches, "workers", usize).unwrap_or(4); let listen_address = format!("127.0.0.1:{}", port); - info!("binding to {}", listen_address); + match domain_socket { + None => info!("Server is binding to {}", listen_address), + Some(domain_socket) => info!("Server is binding to {} (UDS)", domain_socket), + } let (i_transport_factory, o_transport_factory): ( Box, @@ -135,7 +139,10 @@ fn run() -> thrift::Result<()> { workers, ); - server.listen(&listen_address) + match domain_socket { + None => server.listen(&listen_address), + Some(domain_socket) => server.listen_uds(domain_socket), + } } else { let mut server = TServer::new( i_transport_factory, @@ -146,9 +153,13 @@ fn run() -> thrift::Result<()> { workers, ); - server.listen(&listen_address) + match domain_socket { + None => server.listen(&listen_address), + Some(domain_socket) => server.listen_uds(domain_socket), + } } } + unknown => Err(format!("unsupported server type {}", unknown).into()), } } diff --git a/test/rs/src/lib.rs b/test/rs/src/lib.rs index 3c7cfc09e..9cfd7a66f 100644 --- a/test/rs/src/lib.rs +++ b/test/rs/src/lib.rs @@ -15,9 +15,5 @@ // specific language governing permissions and limitations // under the License. - - - - mod thrift_test; pub use crate::thrift_test::*; diff --git a/test/tests.json b/test/tests.json index a8dbef7d4..3563dc9ab 100644 --- a/test/tests.json +++ b/test/tests.json @@ -679,7 +679,8 @@ ] }, "sockets": [ - "ip" + "ip", + "domain" ], "transports": [ "buffered", -- cgit v1.2.1