diff options
author | Allen George <allen.george@gmail.com> | 2017-01-30 07:15:00 -0500 |
---|---|---|
committer | James E. King, III <jking@apache.org> | 2017-04-27 08:46:02 -0400 |
commit | 0e22c362b967bd3765ee3da349faa789904a0707 (patch) | |
tree | cf7271e15659c1181abb6ed8c57b599d79d026f3 | |
parent | 9db23b7be330f47037b4e3e5e374eda5e38b0dfd (diff) | |
download | thrift-0e22c362b967bd3765ee3da349faa789904a0707.tar.gz |
THRIFT-4176: Implement threaded server for Rust
Client: rs
* Create a TIoChannel construct
* Separate TTransport into TReadTransport and TWriteTransport
* Restructure types to avoid shared ownership
* Remove user-visible boxing and ref-counting
* Replace TSimpleServer with a thread-pool based TServer
This closes #1255
29 files changed, 3209 insertions, 1874 deletions
diff --git a/.rustfmt.toml b/.rustfmt.toml new file mode 100644 index 000000000..2962d47aa --- /dev/null +++ b/.rustfmt.toml @@ -0,0 +1,7 @@ +max_width = 100 +fn_args_layout = "Block" +array_layout = "Block" +where_style = "Rfc" +generics_indent = "Block" +fn_call_style = "Block" +reorder_imported_names = true diff --git a/compiler/cpp/src/thrift/generate/t_rs_generator.cc b/compiler/cpp/src/thrift/generate/t_rs_generator.cc index c34ed173f..30f46f227 100644 --- a/compiler/cpp/src/thrift/generate/t_rs_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_rs_generator.cc @@ -31,10 +31,9 @@ using std::string; using std::vector; using std::set; -static const string endl = "\n"; // avoid ostream << std::endl flushes - -static const string SERVICE_RESULT_VARIABLE = "result_value"; -static const string RESULT_STRUCT_SUFFIX = "Result"; +static const string endl("\n"); // avoid ostream << std::endl flushes +static const string SERVICE_RESULT_VARIABLE("result_value"); +static const string RESULT_STRUCT_SUFFIX("Result"); static const string RUST_RESERVED_WORDS[] = { "abstract", "alignof", "as", "become", "box", "break", "const", "continue", @@ -55,6 +54,9 @@ const set<string> RUST_RESERVED_WORDS_SET( RUST_RESERVED_WORDS + sizeof(RUST_RESERVED_WORDS)/sizeof(RUST_RESERVED_WORDS[0]) ); +static const string SYNC_CLIENT_GENERIC_BOUND_VARS("<IP, OP>"); +static const string SYNC_CLIENT_GENERIC_BOUNDS("where IP: TInputProtocol, OP: TOutputProtocol"); + // FIXME: extract common TMessageIdentifier function // FIXME: have to_rust_type deal with Option @@ -364,9 +366,10 @@ private: ); // Return a string containing all the unpacked service call args given a service call function - // `t_function`. Prepends the args with `&mut self` and includes the arg types in the returned string, - // for example: `fn foo(&mut self, field_0: String)`. - string rust_sync_service_call_declaration(t_function* tfunc); + // `t_function`. Prepends the args with either `&mut self` or `&self` and includes the arg types + // in the returned string, for example: + // `fn foo(&mut self, field_0: String)`. + string rust_sync_service_call_declaration(t_function* tfunc, bool self_is_mutable); // Return a string containing all the unpacked service call args given a service call function // `t_function`. Only includes the arg names, each of which is prefixed with the optional prefix @@ -512,6 +515,9 @@ void t_rs_generator::render_attributes_and_includes() { // constructors take *all* struct parameters, which can trigger the "too many arguments" warning // some auto-gen'd types can be deeply nested. clippy recommends factoring them out which is hard to autogen f_gen_ << "#![cfg_attr(feature = \"cargo-clippy\", allow(too_many_arguments, type_complexity))]" << endl; + // prevent rustfmt from running against this file + // lines are too long, code is (thankfully!) not visual-indented, etc. + f_gen_ << "#![cfg_attr(rustfmt, rustfmt_skip)]" << endl; f_gen_ << endl; // add standard includes @@ -2050,7 +2056,7 @@ void t_rs_generator::render_sync_client_trait(t_service *tservice) { for(func_iter = functions.begin(); func_iter != functions.end(); ++func_iter) { t_function* tfunc = (*func_iter); string func_name = service_call_client_function_name(tfunc); - string func_args = rust_sync_service_call_declaration(tfunc); + string func_args = rust_sync_service_call_declaration(tfunc, true); string func_return = to_rust_type(tfunc->get_returntype()); render_rustdoc((t_doc*) tfunc); f_gen_ << indent() << "fn " << func_name << func_args << " -> thrift::Result<" << func_return << ">;" << endl; @@ -2069,8 +2075,14 @@ void t_rs_generator::render_sync_client_marker_trait(t_service *tservice) { void t_rs_generator::render_sync_client_marker_trait_impls(t_service *tservice, const string &impl_struct_name) { f_gen_ << indent() - << "impl " << rust_namespace(tservice) << rust_sync_client_marker_trait_name(tservice) - << " for " << impl_struct_name + << "impl " + << SYNC_CLIENT_GENERIC_BOUND_VARS + << " " + << rust_namespace(tservice) << rust_sync_client_marker_trait_name(tservice) + << " for " + << impl_struct_name << SYNC_CLIENT_GENERIC_BOUND_VARS + << " " + << SYNC_CLIENT_GENERIC_BOUNDS << " {}" << endl; @@ -2081,11 +2093,19 @@ void t_rs_generator::render_sync_client_marker_trait_impls(t_service *tservice, } void t_rs_generator::render_sync_client_definition_and_impl(const string& client_impl_name) { + // render the definition for the client struct - f_gen_ << "pub struct " << client_impl_name << " {" << endl; + f_gen_ + << "pub struct " + << client_impl_name + << SYNC_CLIENT_GENERIC_BOUND_VARS + << " " + << SYNC_CLIENT_GENERIC_BOUNDS + << " {" + << endl; indent_up(); - f_gen_ << indent() << "_i_prot: Box<TInputProtocol>," << endl; - f_gen_ << indent() << "_o_prot: Box<TOutputProtocol>," << endl; + f_gen_ << indent() << "_i_prot: IP," << endl; + f_gen_ << indent() << "_o_prot: OP," << endl; f_gen_ << indent() << "_sequence_number: i32," << endl; indent_down(); f_gen_ << "}" << endl; @@ -2093,7 +2113,16 @@ void t_rs_generator::render_sync_client_definition_and_impl(const string& client // render the struct implementation // this includes the new() function as well as the helper send/recv methods for each service call - f_gen_ << "impl " << client_impl_name << " {" << endl; + f_gen_ + << "impl " + << SYNC_CLIENT_GENERIC_BOUND_VARS + << " " + << client_impl_name + << SYNC_CLIENT_GENERIC_BOUND_VARS + << " " + << SYNC_CLIENT_GENERIC_BOUNDS + << " {" + << endl; indent_up(); render_sync_client_lifecycle_functions(client_impl_name); indent_down(); @@ -2104,8 +2133,9 @@ void t_rs_generator::render_sync_client_definition_and_impl(const string& client void t_rs_generator::render_sync_client_lifecycle_functions(const string& client_struct) { f_gen_ << indent() - << "pub fn new(input_protocol: Box<TInputProtocol>, output_protocol: Box<TOutputProtocol>) -> " + << "pub fn new(input_protocol: IP, output_protocol: OP) -> " << client_struct + << SYNC_CLIENT_GENERIC_BOUND_VARS << " {" << endl; indent_up(); @@ -2121,11 +2151,20 @@ void t_rs_generator::render_sync_client_lifecycle_functions(const string& client } void t_rs_generator::render_sync_client_tthriftclient_impl(const string &client_impl_name) { - f_gen_ << indent() << "impl TThriftClient for " << client_impl_name << " {" << endl; + f_gen_ + << indent() + << "impl " + << SYNC_CLIENT_GENERIC_BOUND_VARS + << " TThriftClient for " + << client_impl_name + << SYNC_CLIENT_GENERIC_BOUND_VARS + << " " + << SYNC_CLIENT_GENERIC_BOUNDS + << " {" << endl; indent_up(); - f_gen_ << indent() << "fn i_prot_mut(&mut self) -> &mut TInputProtocol { &mut *self._i_prot }" << endl; - f_gen_ << indent() << "fn o_prot_mut(&mut self) -> &mut TOutputProtocol { &mut *self._o_prot }" << endl; + f_gen_ << indent() << "fn i_prot_mut(&mut self) -> &mut TInputProtocol { &mut self._i_prot }" << endl; + f_gen_ << indent() << "fn o_prot_mut(&mut self) -> &mut TOutputProtocol { &mut self._o_prot }" << endl; f_gen_ << indent() << "fn sequence_number(&self) -> i32 { self._sequence_number }" << endl; f_gen_ << indent() @@ -2172,7 +2211,7 @@ string t_rs_generator::sync_client_marker_traits_for_extension(t_service *tservi void t_rs_generator::render_sync_send_recv_wrapper(t_function* tfunc) { string func_name = service_call_client_function_name(tfunc); - string func_decl_args = rust_sync_service_call_declaration(tfunc); + string func_decl_args = rust_sync_service_call_declaration(tfunc, true); string func_call_args = rust_sync_service_call_invocation(tfunc); string func_return = to_rust_type(tfunc->get_returntype()); @@ -2268,12 +2307,17 @@ void t_rs_generator::render_sync_recv(t_function* tfunc) { f_gen_ << indent() << "}" << endl; } -string t_rs_generator::rust_sync_service_call_declaration(t_function* tfunc) { +string t_rs_generator::rust_sync_service_call_declaration(t_function* tfunc, bool self_is_mutable) { ostringstream func_args; - func_args << "(&mut self"; + + if (self_is_mutable) { + func_args << "(&mut self"; + } else { + func_args << "(&self"; + } if (has_args(tfunc)) { - func_args << ", "; // put comma after "&mut self" + func_args << ", "; // put comma after "self" func_args << struct_to_declaration(tfunc->get_arglist(), T_ARGS); } @@ -2388,7 +2432,7 @@ void t_rs_generator::render_sync_handler_trait(t_service *tservice) { for(func_iter = functions.begin(); func_iter != functions.end(); ++func_iter) { t_function* tfunc = (*func_iter); string func_name = service_call_handler_function_name(tfunc); - string func_args = rust_sync_service_call_declaration(tfunc); + string func_args = rust_sync_service_call_declaration(tfunc, false); string func_return = to_rust_type(tfunc->get_returntype()); render_rustdoc((t_doc*) tfunc); f_gen_ @@ -2472,7 +2516,7 @@ void t_rs_generator::render_sync_processor_definition_and_impl(t_service *tservi f_gen_ << indent() - << "fn process(&mut self, i_prot: &mut TInputProtocol, o_prot: &mut TOutputProtocol) -> thrift::Result<()> {" + << "fn process(&self, i_prot: &mut TInputProtocol, o_prot: &mut TOutputProtocol) -> thrift::Result<()> {" << endl; indent_up(); f_gen_ << indent() << "let message_ident = i_prot.read_message_begin()?;" << endl; @@ -2511,7 +2555,7 @@ void t_rs_generator::render_sync_process_delegation_functions(t_service *tservic f_gen_ << indent() << "fn " << function_name - << "(&mut self, " + << "(&self, " << "incoming_sequence_number: i32, " << "i_prot: &mut TInputProtocol, " << "o_prot: &mut TOutputProtocol) " @@ -2524,7 +2568,7 @@ void t_rs_generator::render_sync_process_delegation_functions(t_service *tservic << actual_processor << "::" << function_name << "(" - << "&mut self.handler, " + << "&self.handler, " << "incoming_sequence_number, " << "i_prot, " << "o_prot" @@ -2576,7 +2620,7 @@ void t_rs_generator::render_sync_process_function(t_function *tfunc, const strin << indent() << "pub fn process_" << rust_snake_case(tfunc->get_name()) << "<H: " << handler_type << ">" - << "(handler: &mut H, " + << "(handler: &H, " << sequence_number_param << ": i32, " << "i_prot: &mut TInputProtocol, " << output_protocol_param << ": &mut TOutputProtocol) " diff --git a/lib/rs/Cargo.toml b/lib/rs/Cargo.toml index 07c5e6754..be34785af 100644 --- a/lib/rs/Cargo.toml +++ b/lib/rs/Cargo.toml @@ -11,8 +11,9 @@ exclude = ["Makefile*", "test/**"] keywords = ["thrift"] [dependencies] +byteorder = "0.5.3" integer-encoding = "1.0.3" log = "~0.3.6" -byteorder = "0.5.3" +threadpool = "1.0" try_from = "0.2.0" diff --git a/lib/rs/src/autogen.rs b/lib/rs/src/autogen.rs index 289c7be9a..54d4080e8 100644 --- a/lib/rs/src/autogen.rs +++ b/lib/rs/src/autogen.rs @@ -22,7 +22,7 @@ //! to implement required functionality. Users should never have to use code //! in this module directly. -use ::protocol::{TInputProtocol, TOutputProtocol}; +use protocol::{TInputProtocol, TOutputProtocol}; /// Specifies the minimum functionality an auto-generated client should provide /// to communicate with a Thrift server. diff --git a/lib/rs/src/errors.rs b/lib/rs/src/errors.rs index a6049d5a0..e36cb3b60 100644 --- a/lib/rs/src/errors.rs +++ b/lib/rs/src/errors.rs @@ -21,7 +21,7 @@ use std::fmt::{Debug, Display, Formatter}; use std::{error, fmt, io, string}; use try_from::TryFrom; -use ::protocol::{TFieldIdentifier, TInputProtocol, TOutputProtocol, TStructIdentifier, TType}; +use protocol::{TFieldIdentifier, TInputProtocol, TOutputProtocol, TStructIdentifier, TType}; // FIXME: should all my error structs impl error::Error as well? // FIXME: should all fields in TransportError, ProtocolError and ApplicationError be optional? @@ -198,8 +198,8 @@ impl Error { /// Create an `ApplicationError` from its wire representation. /// /// Application code **should never** call this method directly. - pub fn read_application_error_from_in_protocol(i: &mut TInputProtocol) - -> ::Result<ApplicationError> { + pub fn read_application_error_from_in_protocol(i: &mut TInputProtocol,) + -> ::Result<ApplicationError> { let mut message = "general remote error".to_owned(); let mut kind = ApplicationErrorKind::Unknown; @@ -212,7 +212,9 @@ impl Error { break; } - let id = field_ident.id.expect("sender should always specify id for non-STOP field"); + let id = field_ident + .id + .expect("sender should always specify id for non-STOP field"); match id { 1 => { @@ -222,8 +224,9 @@ impl Error { } 2 => { let remote_type_as_int = i.read_i32()?; - let remote_kind: ApplicationErrorKind = TryFrom::try_from(remote_type_as_int) - .unwrap_or(ApplicationErrorKind::Unknown); + let remote_kind: ApplicationErrorKind = + TryFrom::try_from(remote_type_as_int) + .unwrap_or(ApplicationErrorKind::Unknown); i.read_field_end()?; kind = remote_kind; } @@ -235,20 +238,23 @@ impl Error { i.read_struct_end()?; - Ok(ApplicationError { - kind: kind, - message: message, - }) + Ok( + ApplicationError { + kind: kind, + message: message, + }, + ) } /// Convert an `ApplicationError` into its wire representation and write /// it to the remote. /// /// Application code **should never** call this method directly. - pub fn write_application_error_to_out_protocol(e: &ApplicationError, - o: &mut TOutputProtocol) - -> ::Result<()> { - o.write_struct_begin(&TStructIdentifier { name: "TApplicationException".to_owned() })?; + pub fn write_application_error_to_out_protocol( + e: &ApplicationError, + o: &mut TOutputProtocol, + ) -> ::Result<()> { + o.write_struct_begin(&TStructIdentifier { name: "TApplicationException".to_owned() },)?; let message_field = TFieldIdentifier::new("message", TType::String, 1); let type_field = TFieldIdentifier::new("type", TType::I32, 2); @@ -303,19 +309,23 @@ impl Display for Error { impl From<String> for Error { fn from(s: String) -> Self { - Error::Application(ApplicationError { - kind: ApplicationErrorKind::Unknown, - message: s, - }) + Error::Application( + ApplicationError { + kind: ApplicationErrorKind::Unknown, + message: s, + }, + ) } } impl<'a> From<&'a str> for Error { fn from(s: &'a str) -> Self { - Error::Application(ApplicationError { - kind: ApplicationErrorKind::Unknown, - message: String::from(s), - }) + Error::Application( + ApplicationError { + kind: ApplicationErrorKind::Unknown, + message: String::from(s), + }, + ) } } @@ -418,10 +428,14 @@ impl TryFrom<i32> for TransportErrorKind { 5 => Ok(TransportErrorKind::NegativeSize), 6 => Ok(TransportErrorKind::SizeLimit), _ => { - Err(Error::Protocol(ProtocolError { - kind: ProtocolErrorKind::Unknown, - message: format!("cannot convert {} to TransportErrorKind", from), - })) + Err( + Error::Protocol( + ProtocolError { + kind: ProtocolErrorKind::Unknown, + message: format!("cannot convert {} to TransportErrorKind", from), + }, + ), + ) } } } @@ -433,34 +447,44 @@ impl From<io::Error> for Error { io::ErrorKind::ConnectionReset | io::ErrorKind::ConnectionRefused | io::ErrorKind::NotConnected => { - Error::Transport(TransportError { - kind: TransportErrorKind::NotOpen, - message: err.description().to_owned(), - }) + Error::Transport( + TransportError { + kind: TransportErrorKind::NotOpen, + message: err.description().to_owned(), + }, + ) } io::ErrorKind::AlreadyExists => { - Error::Transport(TransportError { - kind: TransportErrorKind::AlreadyOpen, - message: err.description().to_owned(), - }) + Error::Transport( + TransportError { + kind: TransportErrorKind::AlreadyOpen, + message: err.description().to_owned(), + }, + ) } io::ErrorKind::TimedOut => { - Error::Transport(TransportError { - kind: TransportErrorKind::TimedOut, - message: err.description().to_owned(), - }) + Error::Transport( + TransportError { + kind: TransportErrorKind::TimedOut, + message: err.description().to_owned(), + }, + ) } io::ErrorKind::UnexpectedEof => { - Error::Transport(TransportError { - kind: TransportErrorKind::EndOfFile, - message: err.description().to_owned(), - }) + Error::Transport( + TransportError { + kind: TransportErrorKind::EndOfFile, + message: err.description().to_owned(), + }, + ) } _ => { - Error::Transport(TransportError { - kind: TransportErrorKind::Unknown, - message: err.description().to_owned(), // FIXME: use io error's debug string - }) + Error::Transport( + TransportError { + kind: TransportErrorKind::Unknown, + message: err.description().to_owned(), // FIXME: use io error's debug string + }, + ) } } } @@ -468,10 +492,12 @@ impl From<io::Error> for Error { impl From<string::FromUtf8Error> for Error { fn from(err: string::FromUtf8Error) -> Self { - Error::Protocol(ProtocolError { - kind: ProtocolErrorKind::InvalidData, - message: err.description().to_owned(), // FIXME: use fmt::Error's debug string - }) + Error::Protocol( + ProtocolError { + kind: ProtocolErrorKind::InvalidData, + message: err.description().to_owned(), // FIXME: use fmt::Error's debug string + }, + ) } } @@ -558,10 +584,14 @@ impl TryFrom<i32> for ProtocolErrorKind { 5 => Ok(ProtocolErrorKind::NotImplemented), 6 => Ok(ProtocolErrorKind::DepthLimit), _ => { - Err(Error::Protocol(ProtocolError { - kind: ProtocolErrorKind::Unknown, - message: format!("cannot convert {} to ProtocolErrorKind", from), - })) + Err( + Error::Protocol( + ProtocolError { + kind: ProtocolErrorKind::Unknown, + message: format!("cannot convert {} to ProtocolErrorKind", from), + }, + ), + ) } } } @@ -668,10 +698,14 @@ impl TryFrom<i32> for ApplicationErrorKind { 9 => Ok(ApplicationErrorKind::InvalidProtocol), 10 => Ok(ApplicationErrorKind::UnsupportedClientType), _ => { - Err(Error::Application(ApplicationError { - kind: ApplicationErrorKind::Unknown, - message: format!("cannot convert {} to ApplicationErrorKind", from), - })) + Err( + Error::Application( + ApplicationError { + kind: ApplicationErrorKind::Unknown, + message: format!("cannot convert {} to ApplicationErrorKind", from), + }, + ), + ) } } } diff --git a/lib/rs/src/lib.rs b/lib/rs/src/lib.rs index ad1872146..7ebb10cc4 100644 --- a/lib/rs/src/lib.rs +++ b/lib/rs/src/lib.rs @@ -26,11 +26,12 @@ //! 4. server //! 5. autogen //! -//! The modules are layered as shown in the diagram below. The `generated` +//! The modules are layered as shown in the diagram below. The `autogen'd` //! layer is generated by the Thrift compiler's Rust plugin. It uses the //! types and functions defined in this crate to serialize and deserialize //! messages and implement RPC. Users interact with these types and services -//! by writing their own code on top. +//! by writing their own code that uses the auto-generated clients and +//! servers. //! //! ```text //! +-----------+ @@ -49,6 +50,7 @@ extern crate byteorder; extern crate integer_encoding; +extern crate threadpool; extern crate try_from; #[macro_use] diff --git a/lib/rs/src/protocol/binary.rs b/lib/rs/src/protocol/binary.rs index 54613a532..e03ec9437 100644 --- a/lib/rs/src/protocol/binary.rs +++ b/lib/rs/src/protocol/binary.rs @@ -16,14 +16,11 @@ // under the License. use byteorder::{BigEndian, ByteOrder, ReadBytesExt, WriteBytesExt}; -use std::cell::RefCell; use std::convert::From; -use std::io::{Read, Write}; -use std::rc::Rc; use try_from::TryFrom; -use ::{ProtocolError, ProtocolErrorKind}; -use ::transport::TTransport; +use {ProtocolError, ProtocolErrorKind}; +use transport::{TReadTransport, TWriteTransport}; use super::{TFieldIdentifier, TInputProtocol, TInputProtocolFactory, TListIdentifier, TMapIdentifier, TMessageIdentifier, TMessageType}; use super::{TOutputProtocol, TOutputProtocolFactory, TSetIdentifier, TStructIdentifier, TType}; @@ -41,32 +38,35 @@ const BINARY_PROTOCOL_VERSION_1: u32 = 0x80010000; /// Create and use a `TBinaryInputProtocol`. /// /// ```no_run -/// use std::cell::RefCell; -/// use std::rc::Rc; /// use thrift::protocol::{TBinaryInputProtocol, TInputProtocol}; -/// use thrift::transport::{TTcpTransport, TTransport}; +/// use thrift::transport::TTcpChannel; /// -/// let mut transport = TTcpTransport::new(); -/// transport.open("localhost:9090").unwrap(); -/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>)); +/// let mut channel = TTcpChannel::new(); +/// channel.open("localhost:9090").unwrap(); /// -/// let mut i_prot = TBinaryInputProtocol::new(transport, true); +/// let mut protocol = TBinaryInputProtocol::new(channel, true); /// -/// let recvd_bool = i_prot.read_bool().unwrap(); -/// let recvd_string = i_prot.read_string().unwrap(); +/// let recvd_bool = protocol.read_bool().unwrap(); +/// let recvd_string = protocol.read_string().unwrap(); /// ``` -pub struct TBinaryInputProtocol<'a> { +#[derive(Debug)] +pub struct TBinaryInputProtocol<T> +where + T: TReadTransport, +{ strict: bool, - transport: Rc<RefCell<Box<TTransport + 'a>>>, + transport: T, } -impl<'a> TBinaryInputProtocol<'a> { +impl<'a, T> TBinaryInputProtocol<T> +where + T: TReadTransport, +{ /// Create a `TBinaryInputProtocol` that reads bytes from `transport`. /// /// Set `strict` to `true` if all incoming messages contain the protocol /// version number in the protocol header. - pub fn new(transport: Rc<RefCell<Box<TTransport + 'a>>>, - strict: bool) -> TBinaryInputProtocol<'a> { + pub fn new(transport: T, strict: bool) -> TBinaryInputProtocol<T> { TBinaryInputProtocol { strict: strict, transport: transport, @@ -74,11 +74,14 @@ impl<'a> TBinaryInputProtocol<'a> { } } -impl<'a> TInputProtocol for TBinaryInputProtocol<'a> { +impl<T> TInputProtocol for TBinaryInputProtocol<T> +where + T: TReadTransport, +{ #[cfg_attr(feature = "cargo-clippy", allow(collapsible_if))] fn read_message_begin(&mut self) -> ::Result<TMessageIdentifier> { let mut first_bytes = vec![0; 4]; - self.transport.borrow_mut().read_exact(&mut first_bytes[..])?; + self.transport.read_exact(&mut first_bytes[..])?; // the thrift version header is intentionally negative // so the first check we'll do is see if the sign bit is set @@ -87,10 +90,14 @@ impl<'a> TInputProtocol for TBinaryInputProtocol<'a> { // apparently we got a protocol-version header - check // it, and if it matches, read the rest of the fields if first_bytes[0..2] != [0x80, 0x01] { - Err(::Error::Protocol(ProtocolError { - kind: ProtocolErrorKind::BadVersion, - message: format!("received bad version: {:?}", &first_bytes[0..2]), - })) + Err( + ::Error::Protocol( + ProtocolError { + kind: ProtocolErrorKind::BadVersion, + message: format!("received bad version: {:?}", &first_bytes[0..2]), + }, + ), + ) } else { let message_type: TMessageType = TryFrom::try_from(first_bytes[3])?; let name = self.read_string()?; @@ -103,17 +110,21 @@ impl<'a> TInputProtocol for TBinaryInputProtocol<'a> { if self.strict { // we're in strict mode however, and that always // requires the protocol-version header to be written first - Err(::Error::Protocol(ProtocolError { - kind: ProtocolErrorKind::BadVersion, - message: format!("received bad version: {:?}", &first_bytes[0..2]), - })) + Err( + ::Error::Protocol( + ProtocolError { + kind: ProtocolErrorKind::BadVersion, + message: format!("received bad version: {:?}", &first_bytes[0..2]), + }, + ), + ) } else { // in the non-strict version the first message field // is the message name. strings (byte arrays) are length-prefixed, // so we've just read the length in the first 4 bytes let name_size = BigEndian::read_i32(&first_bytes) as usize; let mut name_buf: Vec<u8> = Vec::with_capacity(name_size); - self.transport.borrow_mut().read_exact(&mut name_buf)?; + self.transport.read_exact(&mut name_buf)?; let name = String::from_utf8(name_buf)?; // read the rest of the fields @@ -143,7 +154,7 @@ impl<'a> TInputProtocol for TBinaryInputProtocol<'a> { TType::Stop => Ok(0), _ => self.read_i16(), }?; - Ok(TFieldIdentifier::new::<Option<String>, String, i16>(None, field_type, id)) + Ok(TFieldIdentifier::new::<Option<String>, String, i16>(None, field_type, id),) } fn read_field_end(&mut self) -> ::Result<()> { @@ -151,9 +162,12 @@ impl<'a> TInputProtocol for TBinaryInputProtocol<'a> { } fn read_bytes(&mut self) -> ::Result<Vec<u8>> { - let num_bytes = self.transport.borrow_mut().read_i32::<BigEndian>()? as usize; + let num_bytes = self.transport.read_i32::<BigEndian>()? as usize; let mut buf = vec![0u8; num_bytes]; - self.transport.borrow_mut().read_exact(&mut buf).map(|_| buf).map_err(From::from) + self.transport + .read_exact(&mut buf) + .map(|_| buf) + .map_err(From::from) } fn read_bool(&mut self) -> ::Result<bool> { @@ -165,23 +179,31 @@ impl<'a> TInputProtocol for TBinaryInputProtocol<'a> { } fn read_i8(&mut self) -> ::Result<i8> { - self.transport.borrow_mut().read_i8().map_err(From::from) + self.transport.read_i8().map_err(From::from) } fn read_i16(&mut self) -> ::Result<i16> { - self.transport.borrow_mut().read_i16::<BigEndian>().map_err(From::from) + self.transport + .read_i16::<BigEndian>() + .map_err(From::from) } fn read_i32(&mut self) -> ::Result<i32> { - self.transport.borrow_mut().read_i32::<BigEndian>().map_err(From::from) + self.transport + .read_i32::<BigEndian>() + .map_err(From::from) } fn read_i64(&mut self) -> ::Result<i64> { - self.transport.borrow_mut().read_i64::<BigEndian>().map_err(From::from) + self.transport + .read_i64::<BigEndian>() + .map_err(From::from) } fn read_double(&mut self) -> ::Result<f64> { - self.transport.borrow_mut().read_f64::<BigEndian>().map_err(From::from) + self.transport + .read_f64::<BigEndian>() + .map_err(From::from) } fn read_string(&mut self) -> ::Result<String> { @@ -224,7 +246,7 @@ impl<'a> TInputProtocol for TBinaryInputProtocol<'a> { // fn read_byte(&mut self) -> ::Result<u8> { - self.transport.borrow_mut().read_u8().map_err(From::from) + self.transport.read_u8().map_err(From::from) } } @@ -240,8 +262,8 @@ impl TBinaryInputProtocolFactory { } impl TInputProtocolFactory for TBinaryInputProtocolFactory { - fn create<'a>(&mut self, transport: Rc<RefCell<Box<TTransport + 'a>>>) -> Box<TInputProtocol + 'a> { - Box::new(TBinaryInputProtocol::new(transport, true)) as Box<TInputProtocol + 'a> + fn create(&self, transport: Box<TReadTransport + Send>) -> Box<TInputProtocol + Send> { + Box::new(TBinaryInputProtocol::new(transport, true)) } } @@ -256,32 +278,35 @@ impl TInputProtocolFactory for TBinaryInputProtocolFactory { /// Create and use a `TBinaryOutputProtocol`. /// /// ```no_run -/// use std::cell::RefCell; -/// use std::rc::Rc; /// use thrift::protocol::{TBinaryOutputProtocol, TOutputProtocol}; -/// use thrift::transport::{TTcpTransport, TTransport}; +/// use thrift::transport::TTcpChannel; /// -/// let mut transport = TTcpTransport::new(); -/// transport.open("localhost:9090").unwrap(); -/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>)); +/// let mut channel = TTcpChannel::new(); +/// channel.open("localhost:9090").unwrap(); /// -/// let mut o_prot = TBinaryOutputProtocol::new(transport, true); +/// let mut protocol = TBinaryOutputProtocol::new(channel, true); /// -/// o_prot.write_bool(true).unwrap(); -/// o_prot.write_string("test_string").unwrap(); +/// protocol.write_bool(true).unwrap(); +/// protocol.write_string("test_string").unwrap(); /// ``` -pub struct TBinaryOutputProtocol<'a> { +#[derive(Debug)] +pub struct TBinaryOutputProtocol<T> +where + T: TWriteTransport, +{ strict: bool, - transport: Rc<RefCell<Box<TTransport + 'a>>>, + pub transport: T, // FIXME: do not make public; only public for testing! } -impl<'a> TBinaryOutputProtocol<'a> { +impl<T> TBinaryOutputProtocol<T> +where + T: TWriteTransport, +{ /// Create a `TBinaryOutputProtocol` that writes bytes to `transport`. /// /// Set `strict` to `true` if all outgoing messages should contain the /// protocol version number in the protocol header. - pub fn new(transport: Rc<RefCell<Box<TTransport + 'a>>>, - strict: bool) -> TBinaryOutputProtocol<'a> { + pub fn new(transport: T, strict: bool) -> TBinaryOutputProtocol<T> { TBinaryOutputProtocol { strict: strict, transport: transport, @@ -289,16 +314,22 @@ impl<'a> TBinaryOutputProtocol<'a> { } fn write_transport(&mut self, buf: &[u8]) -> ::Result<()> { - self.transport.borrow_mut().write(buf).map(|_| ()).map_err(From::from) + self.transport + .write(buf) + .map(|_| ()) + .map_err(From::from) } } -impl<'a> TOutputProtocol for TBinaryOutputProtocol<'a> { +impl<T> TOutputProtocol for TBinaryOutputProtocol<T> +where + T: TWriteTransport, +{ fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> ::Result<()> { if self.strict { let message_type: u8 = identifier.message_type.into(); let header = BINARY_PROTOCOL_VERSION_1 | (message_type as u32); - self.transport.borrow_mut().write_u32::<BigEndian>(header)?; + self.transport.write_u32::<BigEndian>(header)?; self.write_string(&identifier.name)?; self.write_i32(identifier.sequence_number) } else { @@ -322,11 +353,17 @@ impl<'a> TOutputProtocol for TBinaryOutputProtocol<'a> { fn write_field_begin(&mut self, identifier: &TFieldIdentifier) -> ::Result<()> { if identifier.id.is_none() && identifier.field_type != TType::Stop { - return Err(::Error::Protocol(ProtocolError { - kind: ProtocolErrorKind::Unknown, - message: format!("cannot write identifier {:?} without sequence number", - &identifier), - })); + return Err( + ::Error::Protocol( + ProtocolError { + kind: ProtocolErrorKind::Unknown, + message: format!( + "cannot write identifier {:?} without sequence number", + &identifier + ), + }, + ), + ); } self.write_byte(field_type_to_u8(identifier.field_type))?; @@ -359,23 +396,31 @@ impl<'a> TOutputProtocol for TBinaryOutputProtocol<'a> { } fn write_i8(&mut self, i: i8) -> ::Result<()> { - self.transport.borrow_mut().write_i8(i).map_err(From::from) + self.transport.write_i8(i).map_err(From::from) } fn write_i16(&mut self, i: i16) -> ::Result<()> { - self.transport.borrow_mut().write_i16::<BigEndian>(i).map_err(From::from) + self.transport + .write_i16::<BigEndian>(i) + .map_err(From::from) } fn write_i32(&mut self, i: i32) -> ::Result<()> { - self.transport.borrow_mut().write_i32::<BigEndian>(i).map_err(From::from) + self.transport + .write_i32::<BigEndian>(i) + .map_err(From::from) } fn write_i64(&mut self, i: i64) -> ::Result<()> { - self.transport.borrow_mut().write_i64::<BigEndian>(i).map_err(From::from) + self.transport + .write_i64::<BigEndian>(i) + .map_err(From::from) } fn write_double(&mut self, d: f64) -> ::Result<()> { - self.transport.borrow_mut().write_f64::<BigEndian>(d).map_err(From::from) + self.transport + .write_f64::<BigEndian>(d) + .map_err(From::from) } fn write_string(&mut self, s: &str) -> ::Result<()> { @@ -401,10 +446,12 @@ impl<'a> TOutputProtocol for TBinaryOutputProtocol<'a> { } fn write_map_begin(&mut self, identifier: &TMapIdentifier) -> ::Result<()> { - let key_type = identifier.key_type + let key_type = identifier + .key_type .expect("map identifier to write should contain key type"); self.write_byte(field_type_to_u8(key_type))?; - let val_type = identifier.value_type + let val_type = identifier + .value_type .expect("map identifier to write should contain value type"); self.write_byte(field_type_to_u8(val_type))?; self.write_i32(identifier.size) @@ -415,14 +462,14 @@ impl<'a> TOutputProtocol for TBinaryOutputProtocol<'a> { } fn flush(&mut self) -> ::Result<()> { - self.transport.borrow_mut().flush().map_err(From::from) + self.transport.flush().map_err(From::from) } // utility // fn write_byte(&mut self, b: u8) -> ::Result<()> { - self.transport.borrow_mut().write_u8(b).map_err(From::from) + self.transport.write_u8(b).map_err(From::from) } } @@ -438,8 +485,8 @@ impl TBinaryOutputProtocolFactory { } impl TOutputProtocolFactory for TBinaryOutputProtocolFactory { - fn create(&mut self, transport: Rc<RefCell<Box<TTransport>>>) -> Box<TOutputProtocol> { - Box::new(TBinaryOutputProtocol::new(transport, true)) as Box<TOutputProtocol> + fn create(&self, transport: Box<TWriteTransport + Send>) -> Box<TOutputProtocol + Send> { + Box::new(TBinaryOutputProtocol::new(transport, true)) } } @@ -481,10 +528,14 @@ fn field_type_from_u8(b: u8) -> ::Result<TType> { 0x10 => Ok(TType::Utf8), 0x11 => Ok(TType::Utf16), unkn => { - Err(::Error::Protocol(ProtocolError { - kind: ProtocolErrorKind::InvalidData, - message: format!("cannot convert {} to TType", unkn), - })) + Err( + ::Error::Protocol( + ProtocolError { + kind: ProtocolErrorKind::InvalidData, + message: format!("cannot convert {} to TType", unkn), + }, + ), + ) } } } @@ -492,56 +543,79 @@ fn field_type_from_u8(b: u8) -> ::Result<TType> { #[cfg(test)] mod tests { - use std::rc::Rc; - use std::cell::RefCell; - - use ::protocol::{TFieldIdentifier, TMessageIdentifier, TMessageType, TInputProtocol, - TListIdentifier, TMapIdentifier, TOutputProtocol, TSetIdentifier, - TStructIdentifier, TType}; - use ::transport::{TPassThruTransport, TTransport}; - use ::transport::mem::TBufferTransport; + use protocol::{TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier, + TMessageIdentifier, TMessageType, TOutputProtocol, TSetIdentifier, + TStructIdentifier, TType}; + use transport::{ReadHalf, TBufferChannel, TIoChannel, WriteHalf}; use super::*; #[test] fn must_write_message_call_begin() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); let ident = TMessageIdentifier::new("test", TMessageType::Call, 1); assert!(o_prot.write_message_begin(&ident).is_ok()); - let buf = trans.borrow().write_buffer_to_vec(); - - let expected: [u8; 16] = [0x80, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x74, 0x65, - 0x73, 0x74, 0x00, 0x00, 0x00, 0x01]; - - assert_eq!(&expected, buf.as_slice()); + let expected: [u8; 16] = [ + 0x80, + 0x01, + 0x00, + 0x01, + 0x00, + 0x00, + 0x00, + 0x04, + 0x74, + 0x65, + 0x73, + 0x74, + 0x00, + 0x00, + 0x00, + 0x01, + ]; + + assert_eq_written_bytes!(o_prot, expected); } - #[test] fn must_write_message_reply_begin() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); let ident = TMessageIdentifier::new("test", TMessageType::Reply, 10); assert!(o_prot.write_message_begin(&ident).is_ok()); - let buf = trans.borrow().write_buffer_to_vec(); - - let expected: [u8; 16] = [0x80, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x04, 0x74, 0x65, - 0x73, 0x74, 0x00, 0x00, 0x00, 0x0A]; - - assert_eq!(&expected, buf.as_slice()); + let expected: [u8; 16] = [ + 0x80, + 0x01, + 0x00, + 0x02, + 0x00, + 0x00, + 0x00, + 0x04, + 0x74, + 0x65, + 0x73, + 0x74, + 0x00, + 0x00, + 0x00, + 0x0A, + ]; + + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_strict_message_begin() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); let sent_ident = TMessageIdentifier::new("test", TMessageType::Call, 1); assert!(o_prot.write_message_begin(&sent_ident).is_ok()); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); let received_ident = assert_success!(i_prot.read_message_begin()); assert_eq!(&received_ident, &sent_ident); @@ -564,24 +638,26 @@ mod tests { #[test] fn must_write_field_begin() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); - assert!(o_prot.write_field_begin(&TFieldIdentifier::new("some_field", TType::String, 22)) - .is_ok()); + assert!( + o_prot + .write_field_begin(&TFieldIdentifier::new("some_field", TType::String, 22)) + .is_ok() + ); let expected: [u8; 3] = [0x0B, 0x00, 0x16]; - let buf = trans.borrow().write_buffer_to_vec(); - assert_eq!(&expected, buf.as_slice()); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_field_begin() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); let sent_field_ident = TFieldIdentifier::new("foo", TType::I64, 20); assert!(o_prot.write_field_begin(&sent_field_ident).is_ok()); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); let expected_ident = TFieldIdentifier { name: None, @@ -594,22 +670,21 @@ mod tests { #[test] fn must_write_stop_field() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); assert!(o_prot.write_field_stop().is_ok()); let expected: [u8; 1] = [0x00]; - let buf = trans.borrow().write_buffer_to_vec(); - assert_eq!(&expected, buf.as_slice()); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_field_stop() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); assert!(o_prot.write_field_stop().is_ok()); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); let expected_ident = TFieldIdentifier { name: None, @@ -628,23 +703,26 @@ mod tests { #[test] fn must_write_list_begin() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); - assert!(o_prot.write_list_begin(&TListIdentifier::new(TType::Bool, 5)).is_ok()); + assert!( + o_prot + .write_list_begin(&TListIdentifier::new(TType::Bool, 5)) + .is_ok() + ); let expected: [u8; 5] = [0x02, 0x00, 0x00, 0x00, 0x05]; - let buf = trans.borrow().write_buffer_to_vec(); - assert_eq!(&expected, buf.as_slice()); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_list_begin() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); let ident = TListIdentifier::new(TType::List, 900); assert!(o_prot.write_list_begin(&ident).is_ok()); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); let received_ident = assert_success!(i_prot.read_list_begin()); assert_eq!(&received_ident, &ident); @@ -657,23 +735,26 @@ mod tests { #[test] fn must_write_set_begin() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); - assert!(o_prot.write_set_begin(&TSetIdentifier::new(TType::I16, 7)).is_ok()); + assert!( + o_prot + .write_set_begin(&TSetIdentifier::new(TType::I16, 7)) + .is_ok() + ); let expected: [u8; 5] = [0x06, 0x00, 0x00, 0x00, 0x07]; - let buf = trans.borrow().write_buffer_to_vec(); - assert_eq!(&expected, buf.as_slice()); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_set_begin() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); let ident = TSetIdentifier::new(TType::I64, 2000); assert!(o_prot.write_set_begin(&ident).is_ok()); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); let received_ident_result = i_prot.read_set_begin(); assert!(received_ident_result.is_ok()); @@ -687,24 +768,26 @@ mod tests { #[test] fn must_write_map_begin() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); - assert!(o_prot.write_map_begin(&TMapIdentifier::new(TType::I64, TType::Struct, 32)) - .is_ok()); + assert!( + o_prot + .write_map_begin(&TMapIdentifier::new(TType::I64, TType::Struct, 32)) + .is_ok() + ); let expected: [u8; 6] = [0x0A, 0x0C, 0x00, 0x00, 0x00, 0x20]; - let buf = trans.borrow().write_buffer_to_vec(); - assert_eq!(&expected, buf.as_slice()); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_map_begin() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); let ident = TMapIdentifier::new(TType::Map, TType::Set, 100); assert!(o_prot.write_map_begin(&ident).is_ok()); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); let received_ident = assert_success!(i_prot.read_map_begin()); assert_eq!(&received_ident, &ident); @@ -717,31 +800,29 @@ mod tests { #[test] fn must_write_bool_true() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); assert!(o_prot.write_bool(true).is_ok()); let expected: [u8; 1] = [0x01]; - let buf = trans.borrow().write_buffer_to_vec(); - assert_eq!(&expected, buf.as_slice()); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_write_bool_false() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); assert!(o_prot.write_bool(false).is_ok()); let expected: [u8; 1] = [0x00]; - let buf = trans.borrow().write_buffer_to_vec(); - assert_eq!(&expected, buf.as_slice()); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_read_bool_true() { - let (trans, mut i_prot, _) = test_objects(); + let (mut i_prot, _) = test_objects(); - trans.borrow_mut().set_readable_bytes(&[0x01]); + set_readable_bytes!(i_prot, &[0x01]); let read_bool = assert_success!(i_prot.read_bool()); assert_eq!(read_bool, true); @@ -749,9 +830,9 @@ mod tests { #[test] fn must_read_bool_false() { - let (trans, mut i_prot, _) = test_objects(); + let (mut i_prot, _) = test_objects(); - trans.borrow_mut().set_readable_bytes(&[0x00]); + set_readable_bytes!(i_prot, &[0x00]); let read_bool = assert_success!(i_prot.read_bool()); assert_eq!(read_bool, false); @@ -759,9 +840,9 @@ mod tests { #[test] fn must_allow_any_non_zero_value_to_be_interpreted_as_bool_true() { - let (trans, mut i_prot, _) = test_objects(); + let (mut i_prot, _) = test_objects(); - trans.borrow_mut().set_readable_bytes(&[0xAC]); + set_readable_bytes!(i_prot, &[0xAC]); let read_bool = assert_success!(i_prot.read_bool()); assert_eq!(read_bool, true); @@ -769,52 +850,77 @@ mod tests { #[test] fn must_write_bytes() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); let bytes: [u8; 10] = [0x0A, 0xCC, 0xD1, 0x84, 0x99, 0x12, 0xAB, 0xBB, 0x45, 0xDF]; assert!(o_prot.write_bytes(&bytes).is_ok()); - let buf = trans.borrow().write_buffer_to_vec(); + let buf = o_prot.transport.write_bytes(); assert_eq!(&buf[0..4], [0x00, 0x00, 0x00, 0x0A]); // length assert_eq!(&buf[4..], bytes); // actual bytes } #[test] fn must_round_trip_bytes() { - let (trans, mut i_prot, mut o_prot) = test_objects(); - - let bytes: [u8; 25] = [0x20, 0xFD, 0x18, 0x84, 0x99, 0x12, 0xAB, 0xBB, 0x45, 0xDF, 0x34, - 0xDC, 0x98, 0xA4, 0x6D, 0xF3, 0x99, 0xB4, 0xB7, 0xD4, 0x9C, 0xA5, - 0xB3, 0xC9, 0x88]; + let (mut i_prot, mut o_prot) = test_objects(); + + let bytes: [u8; 25] = [ + 0x20, + 0xFD, + 0x18, + 0x84, + 0x99, + 0x12, + 0xAB, + 0xBB, + 0x45, + 0xDF, + 0x34, + 0xDC, + 0x98, + 0xA4, + 0x6D, + 0xF3, + 0x99, + 0xB4, + 0xB7, + 0xD4, + 0x9C, + 0xA5, + 0xB3, + 0xC9, + 0x88, + ]; assert!(o_prot.write_bytes(&bytes).is_ok()); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); let received_bytes = assert_success!(i_prot.read_bytes()); assert_eq!(&received_bytes, &bytes); } - fn test_objects<'a> - () - -> (Rc<RefCell<Box<TBufferTransport>>>, TBinaryInputProtocol<'a>, TBinaryOutputProtocol<'a>) + fn test_objects() + -> (TBinaryInputProtocol<ReadHalf<TBufferChannel>>, + TBinaryOutputProtocol<WriteHalf<TBufferChannel>>) { + let mem = TBufferChannel::with_capacity(40, 40); - let mem = Rc::new(RefCell::new(Box::new(TBufferTransport::with_capacity(40, 40)))); + let (r_mem, w_mem) = mem.split().unwrap(); - let inner: Box<TTransport> = Box::new(TPassThruTransport { inner: mem.clone() }); - let inner = Rc::new(RefCell::new(inner)); + let i_prot = TBinaryInputProtocol::new(r_mem, true); + let o_prot = TBinaryOutputProtocol::new(w_mem, true); - let i_prot = TBinaryInputProtocol::new(inner.clone(), true); - let o_prot = TBinaryOutputProtocol::new(inner.clone(), true); - - (mem, i_prot, o_prot) + (i_prot, o_prot) } - fn assert_no_write<F: FnMut(&mut TBinaryOutputProtocol) -> ::Result<()>>(mut write_fn: F) { - let (trans, _, mut o_prot) = test_objects(); + fn assert_no_write<F>(mut write_fn: F) + where + F: FnMut(&mut TBinaryOutputProtocol<WriteHalf<TBufferChannel>>) -> ::Result<()>, + { + let (_, mut o_prot) = test_objects(); assert!(write_fn(&mut o_prot).is_ok()); - assert_eq!(trans.borrow().write_buffer_as_ref().len(), 0); + assert_eq!(o_prot.transport.write_bytes().len(), 0); } } diff --git a/lib/rs/src/protocol/compact.rs b/lib/rs/src/protocol/compact.rs index 353514d30..dfe11f852 100644 --- a/lib/rs/src/protocol/compact.rs +++ b/lib/rs/src/protocol/compact.rs @@ -17,15 +17,12 @@ use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use integer_encoding::{VarIntReader, VarIntWriter}; -use std::cell::RefCell; use std::convert::From; -use std::rc::Rc; -use std::io::{Read, Write}; use try_from::TryFrom; -use ::transport::TTransport; -use super::{TFieldIdentifier, TListIdentifier, TMapIdentifier, TMessageIdentifier, TMessageType, - TInputProtocol, TInputProtocolFactory}; +use transport::{TReadTransport, TWriteTransport}; +use super::{TFieldIdentifier, TInputProtocol, TInputProtocolFactory, TListIdentifier, + TMapIdentifier, TMessageIdentifier, TMessageType}; use super::{TOutputProtocol, TOutputProtocolFactory, TSetIdentifier, TStructIdentifier, TType}; const COMPACT_PROTOCOL_ID: u8 = 0x82; @@ -39,21 +36,22 @@ const COMPACT_VERSION_MASK: u8 = 0x1F; /// Create and use a `TCompactInputProtocol`. /// /// ```no_run -/// use std::cell::RefCell; -/// use std::rc::Rc; /// use thrift::protocol::{TCompactInputProtocol, TInputProtocol}; -/// use thrift::transport::{TTcpTransport, TTransport}; +/// use thrift::transport::TTcpChannel; /// -/// let mut transport = TTcpTransport::new(); -/// transport.open("localhost:9090").unwrap(); -/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>)); +/// let mut channel = TTcpChannel::new(); +/// channel.open("localhost:9090").unwrap(); /// -/// let mut i_prot = TCompactInputProtocol::new(transport); +/// let mut protocol = TCompactInputProtocol::new(channel); /// -/// let recvd_bool = i_prot.read_bool().unwrap(); -/// let recvd_string = i_prot.read_string().unwrap(); +/// let recvd_bool = protocol.read_bool().unwrap(); +/// let recvd_string = protocol.read_string().unwrap(); /// ``` -pub struct TCompactInputProtocol<'a> { +#[derive(Debug)] +pub struct TCompactInputProtocol<T> +where + T: TReadTransport, +{ // Identifier of the last field deserialized for a struct. last_read_field_id: i16, // Stack of the last read field ids (a new entry is added each time a nested struct is read). @@ -63,12 +61,15 @@ pub struct TCompactInputProtocol<'a> { // and reading the field only occurs after the field id is read. pending_read_bool_value: Option<bool>, // Underlying transport used for byte-level operations. - transport: Rc<RefCell<Box<TTransport + 'a>>>, + transport: T, } -impl<'a> TCompactInputProtocol<'a> { +impl<T> TCompactInputProtocol<T> +where + T: TReadTransport, +{ /// Create a `TCompactInputProtocol` that reads bytes from `transport`. - pub fn new(transport: Rc<RefCell<Box<TTransport + 'a>>>) -> TCompactInputProtocol<'a> { + pub fn new(transport: T) -> TCompactInputProtocol<T> { TCompactInputProtocol { last_read_field_id: 0, read_field_id_stack: Vec::new(), @@ -87,21 +88,28 @@ impl<'a> TCompactInputProtocol<'a> { // high bits set high if count and type encoded separately element_count = possible_element_count as i32; } else { - element_count = self.transport.borrow_mut().read_varint::<u32>()? as i32; + element_count = self.transport.read_varint::<u32>()? as i32; } Ok((element_type, element_count)) } } -impl<'a> TInputProtocol for TCompactInputProtocol<'a> { +impl<T> TInputProtocol for TCompactInputProtocol<T> +where + T: TReadTransport, +{ fn read_message_begin(&mut self) -> ::Result<TMessageIdentifier> { let compact_id = self.read_byte()?; if compact_id != COMPACT_PROTOCOL_ID { - Err(::Error::Protocol(::ProtocolError { - kind: ::ProtocolErrorKind::BadVersion, - message: format!("invalid compact protocol header {:?}", compact_id), - })) + Err( + ::Error::Protocol( + ::ProtocolError { + kind: ::ProtocolErrorKind::BadVersion, + message: format!("invalid compact protocol header {:?}", compact_id), + }, + ), + ) } else { Ok(()) }?; @@ -109,11 +117,17 @@ impl<'a> TInputProtocol for TCompactInputProtocol<'a> { let type_and_byte = self.read_byte()?; let received_version = type_and_byte & COMPACT_VERSION_MASK; if received_version != COMPACT_VERSION { - Err(::Error::Protocol(::ProtocolError { - kind: ::ProtocolErrorKind::BadVersion, - message: format!("cannot process compact protocol version {:?}", - received_version), - })) + Err( + ::Error::Protocol( + ::ProtocolError { + kind: ::ProtocolErrorKind::BadVersion, + message: format!( + "cannot process compact protocol version {:?}", + received_version + ), + }, + ), + ) } else { Ok(()) }?; @@ -125,7 +139,7 @@ impl<'a> TInputProtocol for TCompactInputProtocol<'a> { self.last_read_field_id = 0; - Ok(TMessageIdentifier::new(service_call_name, message_type, sequence_number)) + Ok(TMessageIdentifier::new(service_call_name, message_type, sequence_number),) } fn read_message_end(&mut self) -> ::Result<()> { @@ -165,9 +179,13 @@ impl<'a> TInputProtocol for TCompactInputProtocol<'a> { match field_type { TType::Stop => { - Ok(TFieldIdentifier::new::<Option<String>, String, Option<i16>>(None, - TType::Stop, - None)) + Ok( + TFieldIdentifier::new::<Option<String>, String, Option<i16>>( + None, + TType::Stop, + None, + ), + ) } _ => { if field_delta != 0 { @@ -176,11 +194,13 @@ impl<'a> TInputProtocol for TCompactInputProtocol<'a> { self.last_read_field_id = self.read_i16()?; }; - Ok(TFieldIdentifier { - name: None, - field_type: field_type, - id: Some(self.last_read_field_id), - }) + Ok( + TFieldIdentifier { + name: None, + field_type: field_type, + id: Some(self.last_read_field_id), + }, + ) } } } @@ -198,10 +218,14 @@ impl<'a> TInputProtocol for TCompactInputProtocol<'a> { 0x01 => Ok(true), 0x02 => Ok(false), unkn => { - Err(::Error::Protocol(::ProtocolError { - kind: ::ProtocolErrorKind::InvalidData, - message: format!("cannot convert {} into bool", unkn), - })) + Err( + ::Error::Protocol( + ::ProtocolError { + kind: ::ProtocolErrorKind::InvalidData, + message: format!("cannot convert {} into bool", unkn), + }, + ), + ) } } } @@ -209,9 +233,12 @@ impl<'a> TInputProtocol for TCompactInputProtocol<'a> { } fn read_bytes(&mut self) -> ::Result<Vec<u8>> { - let len = self.transport.borrow_mut().read_varint::<u32>()?; + let len = self.transport.read_varint::<u32>()?; let mut buf = vec![0u8; len as usize]; - self.transport.borrow_mut().read_exact(&mut buf).map_err(From::from).map(|_| buf) + self.transport + .read_exact(&mut buf) + .map_err(From::from) + .map(|_| buf) } fn read_i8(&mut self) -> ::Result<i8> { @@ -219,19 +246,21 @@ impl<'a> TInputProtocol for TCompactInputProtocol<'a> { } fn read_i16(&mut self) -> ::Result<i16> { - self.transport.borrow_mut().read_varint::<i16>().map_err(From::from) + self.transport.read_varint::<i16>().map_err(From::from) } fn read_i32(&mut self) -> ::Result<i32> { - self.transport.borrow_mut().read_varint::<i32>().map_err(From::from) + self.transport.read_varint::<i32>().map_err(From::from) } fn read_i64(&mut self) -> ::Result<i64> { - self.transport.borrow_mut().read_varint::<i64>().map_err(From::from) + self.transport.read_varint::<i64>().map_err(From::from) } fn read_double(&mut self) -> ::Result<f64> { - self.transport.borrow_mut().read_f64::<BigEndian>().map_err(From::from) + self.transport + .read_f64::<BigEndian>() + .map_err(From::from) } fn read_string(&mut self) -> ::Result<String> { @@ -258,7 +287,7 @@ impl<'a> TInputProtocol for TCompactInputProtocol<'a> { } fn read_map_begin(&mut self) -> ::Result<TMapIdentifier> { - let element_count = self.transport.borrow_mut().read_varint::<u32>()? as i32; + let element_count = self.transport.read_varint::<u32>()? as i32; if element_count == 0 { Ok(TMapIdentifier::new(None, None, 0)) } else { @@ -278,7 +307,10 @@ impl<'a> TInputProtocol for TCompactInputProtocol<'a> { fn read_byte(&mut self) -> ::Result<u8> { let mut buf = [0u8; 1]; - self.transport.borrow_mut().read_exact(&mut buf).map_err(From::from).map(|_| buf[0]) + self.transport + .read_exact(&mut buf) + .map_err(From::from) + .map(|_| buf[0]) } } @@ -294,8 +326,8 @@ impl TCompactInputProtocolFactory { } impl TInputProtocolFactory for TCompactInputProtocolFactory { - fn create<'a>(&mut self, transport: Rc<RefCell<Box<TTransport + 'a>>>) -> Box<TInputProtocol + 'a> { - Box::new(TCompactInputProtocol::new(transport)) as Box<TInputProtocol + 'a> + fn create(&self, transport: Box<TReadTransport + Send>) -> Box<TInputProtocol + Send> { + Box::new(TCompactInputProtocol::new(transport)) } } @@ -306,35 +338,39 @@ impl TInputProtocolFactory for TCompactInputProtocolFactory { /// Create and use a `TCompactOutputProtocol`. /// /// ```no_run -/// use std::cell::RefCell; -/// use std::rc::Rc; /// use thrift::protocol::{TCompactOutputProtocol, TOutputProtocol}; -/// use thrift::transport::{TTcpTransport, TTransport}; +/// use thrift::transport::TTcpChannel; /// -/// let mut transport = TTcpTransport::new(); -/// transport.open("localhost:9090").unwrap(); -/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>)); +/// let mut channel = TTcpChannel::new(); +/// channel.open("localhost:9090").unwrap(); /// -/// let mut o_prot = TCompactOutputProtocol::new(transport); +/// let mut protocol = TCompactOutputProtocol::new(channel); /// -/// o_prot.write_bool(true).unwrap(); -/// o_prot.write_string("test_string").unwrap(); +/// protocol.write_bool(true).unwrap(); +/// protocol.write_string("test_string").unwrap(); /// ``` -pub struct TCompactOutputProtocol<'a> { +#[derive(Debug)] +pub struct TCompactOutputProtocol<T> +where + T: TWriteTransport, +{ // Identifier of the last field serialized for a struct. last_write_field_id: i16, - // Stack of the last written field ids (a new entry is added each time a nested struct is written). + // Stack of the last written field ids (new entry added each time a nested struct is written). write_field_id_stack: Vec<i16>, // Field identifier of the boolean field to be written. // Saved because boolean fields and their value are encoded in a single byte pending_write_bool_field_identifier: Option<TFieldIdentifier>, // Underlying transport used for byte-level operations. - transport: Rc<RefCell<Box<TTransport + 'a>>>, + transport: T, } -impl<'a> TCompactOutputProtocol<'a> { +impl<T> TCompactOutputProtocol<T> +where + T: TWriteTransport, +{ /// Create a `TCompactOutputProtocol` that writes bytes to `transport`. - pub fn new(transport: Rc<RefCell<Box<TTransport + 'a>>>) -> TCompactOutputProtocol<'a> { + pub fn new(transport: T) -> TCompactOutputProtocol<T> { TCompactOutputProtocol { last_write_field_id: 0, write_field_id_stack: Vec::new(), @@ -365,7 +401,6 @@ impl<'a> TCompactOutputProtocol<'a> { let header = 0xF0 | elem_identifier; self.write_byte(header)?; self.transport - .borrow_mut() .write_varint(element_count as u32) .map_err(From::from) .map(|_| ()) @@ -379,7 +414,10 @@ impl<'a> TCompactOutputProtocol<'a> { } } -impl<'a> TOutputProtocol for TCompactOutputProtocol<'a> { +impl<T> TOutputProtocol for TCompactOutputProtocol<T> +where + T: TWriteTransport, +{ fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> ::Result<()> { self.write_byte(COMPACT_PROTOCOL_ID)?; self.write_byte((u8::from(identifier.message_type) << 5) | COMPACT_VERSION)?; @@ -401,8 +439,9 @@ impl<'a> TOutputProtocol for TCompactOutputProtocol<'a> { fn write_struct_end(&mut self) -> ::Result<()> { self.assert_no_pending_bool_write(); - self.last_write_field_id = - self.write_field_id_stack.pop().expect("should have previous field ids"); + self.last_write_field_id = self.write_field_id_stack + .pop() + .expect("should have previous field ids"); Ok(()) } @@ -410,16 +449,20 @@ impl<'a> TOutputProtocol for TCompactOutputProtocol<'a> { match identifier.field_type { TType::Bool => { if self.pending_write_bool_field_identifier.is_some() { - panic!("should not have a pending bool while writing another bool with id: \ + panic!( + "should not have a pending bool while writing another bool with id: \ {:?}", - identifier) + identifier + ) } self.pending_write_bool_field_identifier = Some(identifier.clone()); Ok(()) } _ => { let field_type = type_to_u8(identifier.field_type); - let field_id = identifier.id.expect("non-stop field should have field id"); + let field_id = identifier + .id + .expect("non-stop field should have field id"); self.write_field_header(field_type, field_id) } } @@ -453,8 +496,8 @@ impl<'a> TOutputProtocol for TCompactOutputProtocol<'a> { } fn write_bytes(&mut self, b: &[u8]) -> ::Result<()> { - self.transport.borrow_mut().write_varint(b.len() as u32)?; - self.transport.borrow_mut().write_all(b).map_err(From::from) + self.transport.write_varint(b.len() as u32)?; + self.transport.write_all(b).map_err(From::from) } fn write_i8(&mut self, i: i8) -> ::Result<()> { @@ -462,19 +505,30 @@ impl<'a> TOutputProtocol for TCompactOutputProtocol<'a> { } fn write_i16(&mut self, i: i16) -> ::Result<()> { - self.transport.borrow_mut().write_varint(i).map_err(From::from).map(|_| ()) + self.transport + .write_varint(i) + .map_err(From::from) + .map(|_| ()) } fn write_i32(&mut self, i: i32) -> ::Result<()> { - self.transport.borrow_mut().write_varint(i).map_err(From::from).map(|_| ()) + self.transport + .write_varint(i) + .map_err(From::from) + .map(|_| ()) } fn write_i64(&mut self, i: i64) -> ::Result<()> { - self.transport.borrow_mut().write_varint(i).map_err(From::from).map(|_| ()) + self.transport + .write_varint(i) + .map_err(From::from) + .map(|_| ()) } fn write_double(&mut self, d: f64) -> ::Result<()> { - self.transport.borrow_mut().write_f64::<BigEndian>(d).map_err(From::from) + self.transport + .write_f64::<BigEndian>(d) + .map_err(From::from) } fn write_string(&mut self, s: &str) -> ::Result<()> { @@ -501,13 +555,15 @@ impl<'a> TOutputProtocol for TCompactOutputProtocol<'a> { if identifier.size == 0 { self.write_byte(0) } else { - self.transport.borrow_mut().write_varint(identifier.size as u32)?; + self.transport.write_varint(identifier.size as u32)?; - let key_type = identifier.key_type + let key_type = identifier + .key_type .expect("map identifier to write should contain key type"); let key_type_byte = collection_type_to_u8(key_type) << 4; - let val_type = identifier.value_type + let val_type = identifier + .value_type .expect("map identifier to write should contain value type"); let val_type_byte = collection_type_to_u8(val_type); @@ -521,14 +577,17 @@ impl<'a> TOutputProtocol for TCompactOutputProtocol<'a> { } fn flush(&mut self) -> ::Result<()> { - self.transport.borrow_mut().flush().map_err(From::from) + self.transport.flush().map_err(From::from) } // utility // fn write_byte(&mut self, b: u8) -> ::Result<()> { - self.transport.borrow_mut().write(&[b]).map_err(From::from).map(|_| ()) + self.transport + .write(&[b]) + .map_err(From::from) + .map(|_| ()) } } @@ -544,8 +603,8 @@ impl TCompactOutputProtocolFactory { } impl TOutputProtocolFactory for TCompactOutputProtocolFactory { - fn create(&mut self, transport: Rc<RefCell<Box<TTransport>>>) -> Box<TOutputProtocol> { - Box::new(TCompactOutputProtocol::new(transport)) as Box<TOutputProtocol> + fn create(&self, transport: Box<TWriteTransport + Send>) -> Box<TOutputProtocol + Send> { + Box::new(TCompactOutputProtocol::new(transport)) } } @@ -594,10 +653,14 @@ fn u8_to_type(b: u8) -> ::Result<TType> { 0x0B => Ok(TType::Map), 0x0C => Ok(TType::Struct), unkn => { - Err(::Error::Protocol(::ProtocolError { - kind: ::ProtocolErrorKind::InvalidData, - message: format!("cannot convert {} into TType", unkn), - })) + Err( + ::Error::Protocol( + ::ProtocolError { + kind: ::ProtocolErrorKind::InvalidData, + message: format!("cannot convert {} into TType", unkn), + }, + ), + ) } } } @@ -605,54 +668,65 @@ fn u8_to_type(b: u8) -> ::Result<TType> { #[cfg(test)] mod tests { - use std::rc::Rc; - use std::cell::RefCell; - - use ::protocol::{TFieldIdentifier, TMessageIdentifier, TMessageType, TInputProtocol, - TListIdentifier, TMapIdentifier, TOutputProtocol, TSetIdentifier, - TStructIdentifier, TType}; - use ::transport::{TPassThruTransport, TTransport}; - use ::transport::mem::TBufferTransport; + use protocol::{TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier, + TMessageIdentifier, TMessageType, TOutputProtocol, TSetIdentifier, + TStructIdentifier, TType}; + use transport::{ReadHalf, TBufferChannel, TIoChannel, WriteHalf}; use super::*; #[test] fn must_write_message_begin_0() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); assert_success!(o_prot.write_message_begin(&TMessageIdentifier::new("foo", TMessageType::Call, 431))); - let expected: [u8; 8] = - [0x82 /* protocol ID */, 0x21 /* message type | protocol version */, 0xDE, - 0x06 /* zig-zag varint sequence number */, 0x03 /* message-name length */, - 0x66, 0x6F, 0x6F /* "foo" */]; + let expected: [u8; 8] = [ + 0x82, /* protocol ID */ + 0x21, /* message type | protocol version */ + 0xDE, + 0x06, /* zig-zag varint sequence number */ + 0x03, /* message-name length */ + 0x66, + 0x6F, + 0x6F /* "foo" */, + ]; - assert_eq!(trans.borrow().write_buffer_as_ref(), &expected); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_write_message_begin_1() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); - assert_success!(o_prot.write_message_begin(&TMessageIdentifier::new("bar", TMessageType::Reply, 991828))); + assert_success!( + o_prot.write_message_begin(&TMessageIdentifier::new("bar", TMessageType::Reply, 991828)) + ); - let expected: [u8; 9] = - [0x82 /* protocol ID */, 0x41 /* message type | protocol version */, 0xA8, - 0x89, 0x79 /* zig-zag varint sequence number */, - 0x03 /* message-name length */, 0x62, 0x61, 0x72 /* "bar" */]; + let expected: [u8; 9] = [ + 0x82, /* protocol ID */ + 0x41, /* message type | protocol version */ + 0xA8, + 0x89, + 0x79, /* zig-zag varint sequence number */ + 0x03, /* message-name length */ + 0x62, + 0x61, + 0x72 /* "bar" */, + ]; - assert_eq!(trans.borrow().write_buffer_as_ref(), &expected); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_message_begin() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); let ident = TMessageIdentifier::new("service_call", TMessageType::Call, 1283948); assert_success!(o_prot.write_message_begin(&ident)); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); let res = assert_success!(i_prot.read_message_begin()); assert_eq!(&res, &ident); @@ -668,7 +742,7 @@ mod tests { #[test] fn must_write_struct_with_delta_fields() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); // no bytes should be written however assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); @@ -692,20 +766,20 @@ mod tests { assert_success!(o_prot.write_field_stop()); assert_success!(o_prot.write_struct_end()); - // get bytes written - let buf = trans.borrow_mut().write_buffer_to_vec(); - - let expected: [u8; 5] = [0x03 /* field type */, 0x00 /* first field id */, - 0x44 /* field delta (4) | field type */, - 0x59 /* field delta (5) | field type */, - 0x00 /* field stop */]; + let expected: [u8; 5] = [ + 0x03, /* field type */ + 0x00, /* first field id */ + 0x44, /* field delta (4) | field type */ + 0x59, /* field delta (5) | field type */ + 0x00 /* field stop */, + ]; - assert_eq!(&buf, &expected); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_struct_with_delta_fields() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); // no bytes should be written however assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); @@ -732,40 +806,57 @@ mod tests { assert_success!(o_prot.write_field_stop()); assert_success!(o_prot.write_struct_end()); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); // read the struct back assert_success!(i_prot.read_struct_begin()); let read_ident_1 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_1, - TFieldIdentifier { name: None, ..field_ident_1 }); + assert_eq!( + read_ident_1, + TFieldIdentifier { + name: None, + ..field_ident_1 + } + ); assert_success!(i_prot.read_field_end()); let read_ident_2 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_2, - TFieldIdentifier { name: None, ..field_ident_2 }); + assert_eq!( + read_ident_2, + TFieldIdentifier { + name: None, + ..field_ident_2 + } + ); assert_success!(i_prot.read_field_end()); let read_ident_3 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_3, - TFieldIdentifier { name: None, ..field_ident_3 }); + assert_eq!( + read_ident_3, + TFieldIdentifier { + name: None, + ..field_ident_3 + } + ); assert_success!(i_prot.read_field_end()); let read_ident_4 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_4, - TFieldIdentifier { - name: None, - field_type: TType::Stop, - id: None, - }); + assert_eq!( + read_ident_4, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); assert_success!(i_prot.read_struct_end()); } #[test] fn must_write_struct_with_non_zero_initial_field_and_delta_fields() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); // no bytes should be written however assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); @@ -789,20 +880,19 @@ mod tests { assert_success!(o_prot.write_field_stop()); assert_success!(o_prot.write_struct_end()); - // get bytes written - let buf = trans.borrow_mut().write_buffer_to_vec(); + let expected: [u8; 4] = [ + 0x15, /* field delta (1) | field type */ + 0x1A, /* field delta (1) | field type */ + 0x48, /* field delta (4) | field type */ + 0x00 /* field stop */, + ]; - let expected: [u8; 4] = [0x15 /* field delta (1) | field type */, - 0x1A /* field delta (1) | field type */, - 0x48 /* field delta (4) | field type */, - 0x00 /* field stop */]; - - assert_eq!(&buf, &expected); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_struct_with_non_zero_initial_field_and_delta_fields() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); // no bytes should be written however assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); @@ -829,40 +919,57 @@ mod tests { assert_success!(o_prot.write_field_stop()); assert_success!(o_prot.write_struct_end()); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); // read the struct back assert_success!(i_prot.read_struct_begin()); let read_ident_1 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_1, - TFieldIdentifier { name: None, ..field_ident_1 }); + assert_eq!( + read_ident_1, + TFieldIdentifier { + name: None, + ..field_ident_1 + } + ); assert_success!(i_prot.read_field_end()); let read_ident_2 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_2, - TFieldIdentifier { name: None, ..field_ident_2 }); + assert_eq!( + read_ident_2, + TFieldIdentifier { + name: None, + ..field_ident_2 + } + ); assert_success!(i_prot.read_field_end()); let read_ident_3 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_3, - TFieldIdentifier { name: None, ..field_ident_3 }); + assert_eq!( + read_ident_3, + TFieldIdentifier { + name: None, + ..field_ident_3 + } + ); assert_success!(i_prot.read_field_end()); let read_ident_4 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_4, - TFieldIdentifier { - name: None, - field_type: TType::Stop, - id: None, - }); + assert_eq!( + read_ident_4, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); assert_success!(i_prot.read_struct_end()); } #[test] fn must_write_struct_with_long_fields() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); // no bytes should be written however assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); @@ -885,21 +992,23 @@ mod tests { assert_success!(o_prot.write_field_stop()); assert_success!(o_prot.write_struct_end()); - // get bytes written - let buf = trans.borrow_mut().write_buffer_to_vec(); - - let expected: [u8; 8] = - [0x05 /* field type */, 0x00 /* first field id */, - 0x06 /* field type */, 0x20 /* zig-zag varint field id */, - 0x0A /* field type */, 0xC6, 0x01 /* zig-zag varint field id */, - 0x00 /* field stop */]; + let expected: [u8; 8] = [ + 0x05, /* field type */ + 0x00, /* first field id */ + 0x06, /* field type */ + 0x20, /* zig-zag varint field id */ + 0x0A, /* field type */ + 0xC6, + 0x01, /* zig-zag varint field id */ + 0x00 /* field stop */, + ]; - assert_eq!(&buf, &expected); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_struct_with_long_fields() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); // no bytes should be written however assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); @@ -925,40 +1034,57 @@ mod tests { assert_success!(o_prot.write_field_stop()); assert_success!(o_prot.write_struct_end()); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); // read the struct back assert_success!(i_prot.read_struct_begin()); let read_ident_1 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_1, - TFieldIdentifier { name: None, ..field_ident_1 }); + assert_eq!( + read_ident_1, + TFieldIdentifier { + name: None, + ..field_ident_1 + } + ); assert_success!(i_prot.read_field_end()); let read_ident_2 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_2, - TFieldIdentifier { name: None, ..field_ident_2 }); + assert_eq!( + read_ident_2, + TFieldIdentifier { + name: None, + ..field_ident_2 + } + ); assert_success!(i_prot.read_field_end()); let read_ident_3 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_3, - TFieldIdentifier { name: None, ..field_ident_3 }); + assert_eq!( + read_ident_3, + TFieldIdentifier { + name: None, + ..field_ident_3 + } + ); assert_success!(i_prot.read_field_end()); let read_ident_4 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_4, - TFieldIdentifier { - name: None, - field_type: TType::Stop, - id: None, - }); + assert_eq!( + read_ident_4, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); assert_success!(i_prot.read_struct_end()); } #[test] fn must_write_struct_with_mix_of_long_and_delta_fields() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); // no bytes should be written however assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); @@ -989,22 +1115,25 @@ mod tests { assert_success!(o_prot.write_field_stop()); assert_success!(o_prot.write_struct_end()); - // get bytes written - let buf = trans.borrow_mut().write_buffer_to_vec(); - - let expected: [u8; 10] = - [0x16 /* field delta (1) | field type */, - 0x85 /* field delta (8) | field type */, 0x0A /* field type */, 0xD0, - 0x0F /* zig-zag varint field id */, 0x0A /* field type */, 0xA2, - 0x1F /* zig-zag varint field id */, - 0x3A /* field delta (3) | field type */, 0x00 /* field stop */]; + let expected: [u8; 10] = [ + 0x16, /* field delta (1) | field type */ + 0x85, /* field delta (8) | field type */ + 0x0A, /* field type */ + 0xD0, + 0x0F, /* zig-zag varint field id */ + 0x0A, /* field type */ + 0xA2, + 0x1F, /* zig-zag varint field id */ + 0x3A, /* field delta (3) | field type */ + 0x00 /* field stop */, + ]; - assert_eq!(&buf, &expected); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_struct_with_mix_of_long_and_delta_fields() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); // no bytes should be written however let struct_ident = TStructIdentifier::new("foo"); @@ -1041,43 +1170,70 @@ mod tests { assert_success!(o_prot.write_field_stop()); assert_success!(o_prot.write_struct_end()); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); // read the struct back assert_success!(i_prot.read_struct_begin()); let read_ident_1 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_1, - TFieldIdentifier { name: None, ..field_ident_1 }); + assert_eq!( + read_ident_1, + TFieldIdentifier { + name: None, + ..field_ident_1 + } + ); assert_success!(i_prot.read_field_end()); let read_ident_2 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_2, - TFieldIdentifier { name: None, ..field_ident_2 }); + assert_eq!( + read_ident_2, + TFieldIdentifier { + name: None, + ..field_ident_2 + } + ); assert_success!(i_prot.read_field_end()); let read_ident_3 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_3, - TFieldIdentifier { name: None, ..field_ident_3 }); + assert_eq!( + read_ident_3, + TFieldIdentifier { + name: None, + ..field_ident_3 + } + ); assert_success!(i_prot.read_field_end()); let read_ident_4 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_4, - TFieldIdentifier { name: None, ..field_ident_4 }); + assert_eq!( + read_ident_4, + TFieldIdentifier { + name: None, + ..field_ident_4 + } + ); assert_success!(i_prot.read_field_end()); let read_ident_5 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_5, - TFieldIdentifier { name: None, ..field_ident_5 }); + assert_eq!( + read_ident_5, + TFieldIdentifier { + name: None, + ..field_ident_5 + } + ); assert_success!(i_prot.read_field_end()); let read_ident_6 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_6, - TFieldIdentifier { - name: None, - field_type: TType::Stop, - id: None, - }); + assert_eq!( + read_ident_6, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); assert_success!(i_prot.read_struct_end()); } @@ -1087,7 +1243,7 @@ mod tests { // last field of the containing struct is a delta // first field of the the contained struct is a delta - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); // start containing struct assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); @@ -1123,17 +1279,17 @@ mod tests { assert_success!(o_prot.write_field_stop()); assert_success!(o_prot.write_struct_end()); - // get bytes written - let buf = trans.borrow_mut().write_buffer_to_vec(); + let expected: [u8; 7] = [ + 0x16, /* field delta (1) | field type */ + 0x85, /* field delta (8) | field type */ + 0x73, /* field delta (7) | field type */ + 0x07, /* field type */ + 0x30, /* zig-zag varint field id */ + 0x00, /* field stop - contained */ + 0x00 /* field stop - containing */, + ]; - let expected: [u8; 7] = - [0x16 /* field delta (1) | field type */, - 0x85 /* field delta (8) | field type */, - 0x73 /* field delta (7) | field type */, 0x07 /* field type */, - 0x30 /* zig-zag varint field id */, 0x00 /* field stop - contained */, - 0x00 /* field stop - containing */]; - - assert_eq!(&buf, &expected); + assert_eq_written_bytes!(o_prot, expected); } #[test] @@ -1141,7 +1297,7 @@ mod tests { // last field of the containing struct is a delta // first field of the the contained struct is a delta - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); // start containing struct assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); @@ -1181,52 +1337,76 @@ mod tests { assert_success!(o_prot.write_field_stop()); assert_success!(o_prot.write_struct_end()); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); // read containing struct back assert_success!(i_prot.read_struct_begin()); let read_ident_1 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_1, - TFieldIdentifier { name: None, ..field_ident_1 }); + assert_eq!( + read_ident_1, + TFieldIdentifier { + name: None, + ..field_ident_1 + } + ); assert_success!(i_prot.read_field_end()); let read_ident_2 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_2, - TFieldIdentifier { name: None, ..field_ident_2 }); + assert_eq!( + read_ident_2, + TFieldIdentifier { + name: None, + ..field_ident_2 + } + ); assert_success!(i_prot.read_field_end()); // read contained struct back assert_success!(i_prot.read_struct_begin()); let read_ident_3 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_3, - TFieldIdentifier { name: None, ..field_ident_3 }); + assert_eq!( + read_ident_3, + TFieldIdentifier { + name: None, + ..field_ident_3 + } + ); assert_success!(i_prot.read_field_end()); let read_ident_4 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_4, - TFieldIdentifier { name: None, ..field_ident_4 }); + assert_eq!( + read_ident_4, + TFieldIdentifier { + name: None, + ..field_ident_4 + } + ); assert_success!(i_prot.read_field_end()); // end contained struct let read_ident_6 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_6, - TFieldIdentifier { - name: None, - field_type: TType::Stop, - id: None, - }); + assert_eq!( + read_ident_6, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); assert_success!(i_prot.read_struct_end()); // end containing struct let read_ident_7 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_7, - TFieldIdentifier { - name: None, - field_type: TType::Stop, - id: None, - }); + assert_eq!( + read_ident_7, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); assert_success!(i_prot.read_struct_end()); } @@ -1235,7 +1415,7 @@ mod tests { // last field of the containing struct is a delta // first field of the the contained struct is a full write - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); // start containing struct assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); @@ -1271,17 +1451,17 @@ mod tests { assert_success!(o_prot.write_field_stop()); assert_success!(o_prot.write_struct_end()); - // get bytes written - let buf = trans.borrow_mut().write_buffer_to_vec(); + let expected: [u8; 7] = [ + 0x16, /* field delta (1) | field type */ + 0x85, /* field delta (8) | field type */ + 0x07, /* field type */ + 0x30, /* zig-zag varint field id */ + 0x33, /* field delta (3) | field type */ + 0x00, /* field stop - contained */ + 0x00 /* field stop - containing */, + ]; - let expected: [u8; 7] = - [0x16 /* field delta (1) | field type */, - 0x85 /* field delta (8) | field type */, 0x07 /* field type */, - 0x30 /* zig-zag varint field id */, - 0x33 /* field delta (3) | field type */, 0x00 /* field stop - contained */, - 0x00 /* field stop - containing */]; - - assert_eq!(&buf, &expected); + assert_eq_written_bytes!(o_prot, expected); } #[test] @@ -1289,7 +1469,7 @@ mod tests { // last field of the containing struct is a delta // first field of the the contained struct is a full write - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); // start containing struct assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); @@ -1329,52 +1509,76 @@ mod tests { assert_success!(o_prot.write_field_stop()); assert_success!(o_prot.write_struct_end()); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); // read containing struct back assert_success!(i_prot.read_struct_begin()); let read_ident_1 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_1, - TFieldIdentifier { name: None, ..field_ident_1 }); + assert_eq!( + read_ident_1, + TFieldIdentifier { + name: None, + ..field_ident_1 + } + ); assert_success!(i_prot.read_field_end()); let read_ident_2 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_2, - TFieldIdentifier { name: None, ..field_ident_2 }); + assert_eq!( + read_ident_2, + TFieldIdentifier { + name: None, + ..field_ident_2 + } + ); assert_success!(i_prot.read_field_end()); // read contained struct back assert_success!(i_prot.read_struct_begin()); let read_ident_3 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_3, - TFieldIdentifier { name: None, ..field_ident_3 }); + assert_eq!( + read_ident_3, + TFieldIdentifier { + name: None, + ..field_ident_3 + } + ); assert_success!(i_prot.read_field_end()); let read_ident_4 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_4, - TFieldIdentifier { name: None, ..field_ident_4 }); + assert_eq!( + read_ident_4, + TFieldIdentifier { + name: None, + ..field_ident_4 + } + ); assert_success!(i_prot.read_field_end()); // end contained struct let read_ident_6 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_6, - TFieldIdentifier { - name: None, - field_type: TType::Stop, - id: None, - }); + assert_eq!( + read_ident_6, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); assert_success!(i_prot.read_struct_end()); // end containing struct let read_ident_7 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_7, - TFieldIdentifier { - name: None, - field_type: TType::Stop, - id: None, - }); + assert_eq!( + read_ident_7, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); assert_success!(i_prot.read_struct_end()); } @@ -1383,7 +1587,7 @@ mod tests { // last field of the containing struct is a full write // first field of the the contained struct is a delta write - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); // start containing struct assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); @@ -1419,21 +1623,22 @@ mod tests { assert_success!(o_prot.write_field_stop()); assert_success!(o_prot.write_struct_end()); - // get bytes written - let buf = trans.borrow_mut().write_buffer_to_vec(); - - let expected: [u8; 7] = - [0x16 /* field delta (1) | field type */, 0x08 /* field type */, - 0x2A /* zig-zag varint field id */, 0x77 /* field delta(7) | field type */, - 0x33 /* field delta (3) | field type */, 0x00 /* field stop - contained */, - 0x00 /* field stop - containing */]; + let expected: [u8; 7] = [ + 0x16, /* field delta (1) | field type */ + 0x08, /* field type */ + 0x2A, /* zig-zag varint field id */ + 0x77, /* field delta(7) | field type */ + 0x33, /* field delta (3) | field type */ + 0x00, /* field stop - contained */ + 0x00 /* field stop - containing */, + ]; - assert_eq!(&buf, &expected); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_nested_structs_2() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); // start containing struct assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); @@ -1473,52 +1678,76 @@ mod tests { assert_success!(o_prot.write_field_stop()); assert_success!(o_prot.write_struct_end()); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); // read containing struct back assert_success!(i_prot.read_struct_begin()); let read_ident_1 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_1, - TFieldIdentifier { name: None, ..field_ident_1 }); + assert_eq!( + read_ident_1, + TFieldIdentifier { + name: None, + ..field_ident_1 + } + ); assert_success!(i_prot.read_field_end()); let read_ident_2 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_2, - TFieldIdentifier { name: None, ..field_ident_2 }); + assert_eq!( + read_ident_2, + TFieldIdentifier { + name: None, + ..field_ident_2 + } + ); assert_success!(i_prot.read_field_end()); // read contained struct back assert_success!(i_prot.read_struct_begin()); let read_ident_3 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_3, - TFieldIdentifier { name: None, ..field_ident_3 }); + assert_eq!( + read_ident_3, + TFieldIdentifier { + name: None, + ..field_ident_3 + } + ); assert_success!(i_prot.read_field_end()); let read_ident_4 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_4, - TFieldIdentifier { name: None, ..field_ident_4 }); + assert_eq!( + read_ident_4, + TFieldIdentifier { + name: None, + ..field_ident_4 + } + ); assert_success!(i_prot.read_field_end()); // end contained struct let read_ident_6 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_6, - TFieldIdentifier { - name: None, - field_type: TType::Stop, - id: None, - }); + assert_eq!( + read_ident_6, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); assert_success!(i_prot.read_struct_end()); // end containing struct let read_ident_7 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_7, - TFieldIdentifier { - name: None, - field_type: TType::Stop, - id: None, - }); + assert_eq!( + read_ident_7, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); assert_success!(i_prot.read_struct_end()); } @@ -1527,7 +1756,7 @@ mod tests { // last field of the containing struct is a full write // first field of the the contained struct is a full write - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); // start containing struct assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); @@ -1563,17 +1792,18 @@ mod tests { assert_success!(o_prot.write_field_stop()); assert_success!(o_prot.write_struct_end()); - // get bytes written - let buf = trans.borrow_mut().write_buffer_to_vec(); - - let expected: [u8; 8] = - [0x16 /* field delta (1) | field type */, 0x08 /* field type */, - 0x2A /* zig-zag varint field id */, 0x07 /* field type */, - 0x2A /* zig-zag varint field id */, - 0x63 /* field delta (6) | field type */, 0x00 /* field stop - contained */, - 0x00 /* field stop - containing */]; + let expected: [u8; 8] = [ + 0x16, /* field delta (1) | field type */ + 0x08, /* field type */ + 0x2A, /* zig-zag varint field id */ + 0x07, /* field type */ + 0x2A, /* zig-zag varint field id */ + 0x63, /* field delta (6) | field type */ + 0x00, /* field stop - contained */ + 0x00 /* field stop - containing */, + ]; - assert_eq!(&buf, &expected); + assert_eq_written_bytes!(o_prot, expected); } #[test] @@ -1581,7 +1811,7 @@ mod tests { // last field of the containing struct is a full write // first field of the the contained struct is a full write - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); // start containing struct assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); @@ -1621,58 +1851,82 @@ mod tests { assert_success!(o_prot.write_field_stop()); assert_success!(o_prot.write_struct_end()); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); // read containing struct back assert_success!(i_prot.read_struct_begin()); let read_ident_1 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_1, - TFieldIdentifier { name: None, ..field_ident_1 }); + assert_eq!( + read_ident_1, + TFieldIdentifier { + name: None, + ..field_ident_1 + } + ); assert_success!(i_prot.read_field_end()); let read_ident_2 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_2, - TFieldIdentifier { name: None, ..field_ident_2 }); + assert_eq!( + read_ident_2, + TFieldIdentifier { + name: None, + ..field_ident_2 + } + ); assert_success!(i_prot.read_field_end()); // read contained struct back assert_success!(i_prot.read_struct_begin()); let read_ident_3 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_3, - TFieldIdentifier { name: None, ..field_ident_3 }); + assert_eq!( + read_ident_3, + TFieldIdentifier { + name: None, + ..field_ident_3 + } + ); assert_success!(i_prot.read_field_end()); let read_ident_4 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_4, - TFieldIdentifier { name: None, ..field_ident_4 }); + assert_eq!( + read_ident_4, + TFieldIdentifier { + name: None, + ..field_ident_4 + } + ); assert_success!(i_prot.read_field_end()); // end contained struct let read_ident_6 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_6, - TFieldIdentifier { - name: None, - field_type: TType::Stop, - id: None, - }); + assert_eq!( + read_ident_6, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); assert_success!(i_prot.read_struct_end()); // end containing struct let read_ident_7 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_7, - TFieldIdentifier { - name: None, - field_type: TType::Stop, - id: None, - }); + assert_eq!( + read_ident_7, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); assert_success!(i_prot.read_struct_end()); } #[test] fn must_write_bool_field() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); // no bytes should be written however assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); @@ -1703,20 +1957,22 @@ mod tests { assert_success!(o_prot.write_field_stop()); assert_success!(o_prot.write_struct_end()); - // get bytes written - let buf = trans.borrow_mut().write_buffer_to_vec(); + let expected: [u8; 7] = [ + 0x11, /* field delta (1) | true */ + 0x82, /* field delta (8) | false */ + 0x01, /* true */ + 0x34, /* field id */ + 0x02, /* false */ + 0x5A, /* field id */ + 0x00 /* stop field */, + ]; - let expected: [u8; 7] = [0x11 /* field delta (1) | true */, - 0x82 /* field delta (8) | false */, 0x01 /* true */, - 0x34 /* field id */, 0x02 /* false */, - 0x5A /* field id */, 0x00 /* stop field */]; - - assert_eq!(&buf, &expected); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_bool_field() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); // no bytes should be written however let struct_ident = TStructIdentifier::new("foo"); @@ -1752,46 +2008,68 @@ mod tests { assert_success!(o_prot.write_field_stop()); assert_success!(o_prot.write_struct_end()); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); // read the struct back assert_success!(i_prot.read_struct_begin()); let read_ident_1 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_1, - TFieldIdentifier { name: None, ..field_ident_1 }); + assert_eq!( + read_ident_1, + TFieldIdentifier { + name: None, + ..field_ident_1 + } + ); let read_value_1 = assert_success!(i_prot.read_bool()); assert_eq!(read_value_1, true); assert_success!(i_prot.read_field_end()); let read_ident_2 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_2, - TFieldIdentifier { name: None, ..field_ident_2 }); + assert_eq!( + read_ident_2, + TFieldIdentifier { + name: None, + ..field_ident_2 + } + ); let read_value_2 = assert_success!(i_prot.read_bool()); assert_eq!(read_value_2, false); assert_success!(i_prot.read_field_end()); let read_ident_3 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_3, - TFieldIdentifier { name: None, ..field_ident_3 }); + assert_eq!( + read_ident_3, + TFieldIdentifier { + name: None, + ..field_ident_3 + } + ); let read_value_3 = assert_success!(i_prot.read_bool()); assert_eq!(read_value_3, true); assert_success!(i_prot.read_field_end()); let read_ident_4 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_4, - TFieldIdentifier { name: None, ..field_ident_4 }); + assert_eq!( + read_ident_4, + TFieldIdentifier { + name: None, + ..field_ident_4 + } + ); let read_value_4 = assert_success!(i_prot.read_bool()); assert_eq!(read_value_4, false); assert_success!(i_prot.read_field_end()); let read_ident_5 = assert_success!(i_prot.read_field_begin()); - assert_eq!(read_ident_5, - TFieldIdentifier { - name: None, - field_type: TType::Stop, - id: None, - }); + assert_eq!( + read_ident_5, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); assert_success!(i_prot.read_struct_end()); } @@ -1799,7 +2077,7 @@ mod tests { #[test] #[should_panic] fn must_fail_if_write_field_end_without_writing_bool_value() { - let (_, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Bool, 1))); o_prot.write_field_end().unwrap(); @@ -1808,7 +2086,7 @@ mod tests { #[test] #[should_panic] fn must_fail_if_write_stop_field_without_writing_bool_value() { - let (_, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Bool, 1))); o_prot.write_field_stop().unwrap(); @@ -1817,7 +2095,7 @@ mod tests { #[test] #[should_panic] fn must_fail_if_write_struct_end_without_writing_bool_value() { - let (_, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Bool, 1))); o_prot.write_struct_end().unwrap(); @@ -1826,7 +2104,7 @@ mod tests { #[test] #[should_panic] fn must_fail_if_write_struct_end_without_any_fields() { - let (_, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); o_prot.write_struct_end().unwrap(); } @@ -1837,24 +2115,24 @@ mod tests { #[test] fn must_write_small_sized_list_begin() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); assert_success!(o_prot.write_list_begin(&TListIdentifier::new(TType::I64, 4))); let expected: [u8; 1] = [0x46 /* size | elem_type */]; - assert_eq!(trans.borrow().write_buffer_as_ref(), &expected); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_small_sized_list_begin() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); let ident = TListIdentifier::new(TType::I08, 10); assert_success!(o_prot.write_list_begin(&ident)); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); let res = assert_success!(i_prot.read_list_begin()); assert_eq!(&res, &ident); @@ -1862,26 +2140,29 @@ mod tests { #[test] fn must_write_large_sized_list_begin() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); let res = o_prot.write_list_begin(&TListIdentifier::new(TType::List, 9999)); assert!(res.is_ok()); - let expected: [u8; 3] = [0xF9 /* 0xF0 | elem_type */, 0x8F, - 0x4E /* size as varint */]; + let expected: [u8; 3] = [ + 0xF9, /* 0xF0 | elem_type */ + 0x8F, + 0x4E /* size as varint */, + ]; - assert_eq!(trans.borrow().write_buffer_as_ref(), &expected); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_large_sized_list_begin() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); let ident = TListIdentifier::new(TType::Set, 47381); assert_success!(o_prot.write_list_begin(&ident)); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); let res = assert_success!(i_prot.read_list_begin()); assert_eq!(&res, &ident); @@ -1894,24 +2175,24 @@ mod tests { #[test] fn must_write_small_sized_set_begin() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); assert_success!(o_prot.write_set_begin(&TSetIdentifier::new(TType::Struct, 2))); let expected: [u8; 1] = [0x2C /* size | elem_type */]; - assert_eq!(trans.borrow().write_buffer_as_ref(), &expected); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_small_sized_set_begin() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); let ident = TSetIdentifier::new(TType::I16, 7); assert_success!(o_prot.write_set_begin(&ident)); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); let res = assert_success!(i_prot.read_set_begin()); assert_eq!(&res, &ident); @@ -1919,25 +2200,29 @@ mod tests { #[test] fn must_write_large_sized_set_begin() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); assert_success!(o_prot.write_set_begin(&TSetIdentifier::new(TType::Double, 23891))); - let expected: [u8; 4] = [0xF7 /* 0xF0 | elem_type */, 0xD3, 0xBA, - 0x01 /* size as varint */]; + let expected: [u8; 4] = [ + 0xF7, /* 0xF0 | elem_type */ + 0xD3, + 0xBA, + 0x01 /* size as varint */, + ]; - assert_eq!(trans.borrow().write_buffer_as_ref(), &expected); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_large_sized_set_begin() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); let ident = TSetIdentifier::new(TType::Map, 3928429); assert_success!(o_prot.write_set_begin(&ident)); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); let res = assert_success!(i_prot.read_set_begin()); assert_eq!(&res, &ident); @@ -1950,53 +2235,58 @@ mod tests { #[test] fn must_write_zero_sized_map_begin() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); assert_success!(o_prot.write_map_begin(&TMapIdentifier::new(TType::String, TType::I32, 0))); let expected: [u8; 1] = [0x00]; // since size is zero we don't write anything - assert_eq!(trans.borrow().write_buffer_as_ref(), &expected); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_read_zero_sized_map_begin() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); assert_success!(o_prot.write_map_begin(&TMapIdentifier::new(TType::Double, TType::I32, 0))); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); let res = assert_success!(i_prot.read_map_begin()); - assert_eq!(&res, - &TMapIdentifier { - key_type: None, - value_type: None, - size: 0, - }); + assert_eq!( + &res, + &TMapIdentifier { + key_type: None, + value_type: None, + size: 0, + } + ); } #[test] fn must_write_map_begin() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); assert_success!(o_prot.write_map_begin(&TMapIdentifier::new(TType::Double, TType::String, 238))); - let expected: [u8; 3] = [0xEE, 0x01 /* size as varint */, - 0x78 /* key type | val type */]; + let expected: [u8; 3] = [ + 0xEE, + 0x01, /* size as varint */ + 0x78 /* key type | val type */, + ]; - assert_eq!(trans.borrow().write_buffer_as_ref(), &expected); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_map_begin() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); let ident = TMapIdentifier::new(TType::Map, TType::List, 1928349); assert_success!(o_prot.write_map_begin(&ident)); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); let res = assert_success!(i_prot.read_map_begin()); assert_eq!(&res, &ident); @@ -2009,23 +2299,26 @@ mod tests { #[test] fn must_write_map_with_bool_key_and_value() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); assert_success!(o_prot.write_map_begin(&TMapIdentifier::new(TType::Bool, TType::Bool, 1))); assert_success!(o_prot.write_bool(true)); assert_success!(o_prot.write_bool(false)); assert_success!(o_prot.write_map_end()); - let expected: [u8; 4] = [0x01 /* size as varint */, - 0x11 /* key type | val type */, 0x01 /* key: true */, - 0x02 /* val: false */]; + let expected: [u8; 4] = [ + 0x01, /* size as varint */ + 0x11, /* key type | val type */ + 0x01, /* key: true */ + 0x02 /* val: false */, + ]; - assert_eq!(trans.borrow().write_buffer_as_ref(), &expected); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_map_with_bool_value() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); let map_ident = TMapIdentifier::new(TType::Bool, TType::Bool, 2); assert_success!(o_prot.write_map_begin(&map_ident)); @@ -2035,7 +2328,7 @@ mod tests { assert_success!(o_prot.write_bool(true)); assert_success!(o_prot.write_map_end()); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); // map header let rcvd_ident = assert_success!(i_prot.read_map_begin()); @@ -2058,28 +2351,30 @@ mod tests { #[test] fn must_read_map_end() { - let (_, mut i_prot, _) = test_objects(); + let (mut i_prot, _) = test_objects(); assert!(i_prot.read_map_end().is_ok()); // will blow up if we try to read from empty buffer } - fn test_objects<'a> - () - -> (Rc<RefCell<Box<TBufferTransport>>>, TCompactInputProtocol<'a>, TCompactOutputProtocol<'a>) + fn test_objects() + -> (TCompactInputProtocol<ReadHalf<TBufferChannel>>, + TCompactOutputProtocol<WriteHalf<TBufferChannel>>) { - let mem = Rc::new(RefCell::new(Box::new(TBufferTransport::with_capacity(80, 80)))); + let mem = TBufferChannel::with_capacity(80, 80); - let inner: Box<TTransport> = Box::new(TPassThruTransport { inner: mem.clone() }); - let inner = Rc::new(RefCell::new(inner)); + let (r_mem, w_mem) = mem.split().unwrap(); - let i_prot = TCompactInputProtocol::new(inner.clone()); - let o_prot = TCompactOutputProtocol::new(inner.clone()); + let i_prot = TCompactInputProtocol::new(r_mem); + let o_prot = TCompactOutputProtocol::new(w_mem); - (mem, i_prot, o_prot) + (i_prot, o_prot) } - fn assert_no_write<F: FnMut(&mut TCompactOutputProtocol) -> ::Result<()>>(mut write_fn: F) { - let (trans, _, mut o_prot) = test_objects(); + fn assert_no_write<F>(mut write_fn: F) + where + F: FnMut(&mut TCompactOutputProtocol<WriteHalf<TBufferChannel>>) -> ::Result<()>, + { + let (_, mut o_prot) = test_objects(); assert!(write_fn(&mut o_prot).is_ok()); - assert_eq!(trans.borrow().write_buffer_as_ref().len(), 0); + assert_eq!(o_prot.transport.write_bytes().len(), 0); } } diff --git a/lib/rs/src/protocol/mod.rs b/lib/rs/src/protocol/mod.rs index b230d6363..4f139147c 100644 --- a/lib/rs/src/protocol/mod.rs +++ b/lib/rs/src/protocol/mod.rs @@ -19,59 +19,77 @@ //! //! # Examples //! -//! Create and use a `TOutputProtocol`. +//! Create and use a `TInputProtocol`. //! //! ```no_run -//! use std::cell::RefCell; -//! use std::rc::Rc; -//! use thrift::protocol::{TBinaryOutputProtocol, TFieldIdentifier, TOutputProtocol, TType}; -//! use thrift::transport::{TTcpTransport, TTransport}; +//! use thrift::protocol::{TBinaryInputProtocol, TInputProtocol}; +//! use thrift::transport::TTcpChannel; //! //! // create the I/O channel -//! let mut transport = TTcpTransport::new(); -//! transport.open("127.0.0.1:9090").unwrap(); -//! let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>)); +//! let mut channel = TTcpChannel::new(); +//! channel.open("127.0.0.1:9090").unwrap(); //! -//! // create the protocol to encode types into bytes -//! let mut o_prot = TBinaryOutputProtocol::new(transport.clone(), true); +//! // create the protocol to decode bytes into types +//! let mut protocol = TBinaryInputProtocol::new(channel, true); //! -//! // write types -//! o_prot.write_field_begin(&TFieldIdentifier::new("string_thing", TType::String, 1)).unwrap(); -//! o_prot.write_string("foo").unwrap(); -//! o_prot.write_field_end().unwrap(); +//! // read types from the wire +//! let field_identifier = protocol.read_field_begin().unwrap(); +//! let field_contents = protocol.read_string().unwrap(); +//! let field_end = protocol.read_field_end().unwrap(); //! ``` //! -//! Create and use a `TInputProtocol`. +//! Create and use a `TOutputProtocol`. //! //! ```no_run -//! use std::cell::RefCell; -//! use std::rc::Rc; -//! use thrift::protocol::{TBinaryInputProtocol, TInputProtocol}; -//! use thrift::transport::{TTcpTransport, TTransport}; +//! use thrift::protocol::{TBinaryOutputProtocol, TFieldIdentifier, TOutputProtocol, TType}; +//! use thrift::transport::TTcpChannel; //! //! // create the I/O channel -//! let mut transport = TTcpTransport::new(); -//! transport.open("127.0.0.1:9090").unwrap(); -//! let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>)); +//! let mut channel = TTcpChannel::new(); +//! channel.open("127.0.0.1:9090").unwrap(); //! -//! // create the protocol to decode bytes into types -//! let mut i_prot = TBinaryInputProtocol::new(transport.clone(), true); +//! // create the protocol to encode types into bytes +//! let mut protocol = TBinaryOutputProtocol::new(channel, true); //! -//! // read types from the wire -//! let field_identifier = i_prot.read_field_begin().unwrap(); -//! let field_contents = i_prot.read_string().unwrap(); -//! let field_end = i_prot.read_field_end().unwrap(); +//! // write types +//! protocol.write_field_begin(&TFieldIdentifier::new("string_thing", TType::String, 1)).unwrap(); +//! protocol.write_string("foo").unwrap(); +//! protocol.write_field_end().unwrap(); //! ``` -use std::cell::RefCell; use std::fmt; use std::fmt::{Display, Formatter}; use std::convert::From; -use std::rc::Rc; use try_from::TryFrom; -use ::{ProtocolError, ProtocolErrorKind}; -use ::transport::TTransport; +use {ProtocolError, ProtocolErrorKind}; +use transport::{TReadTransport, TWriteTransport}; + +#[cfg(test)] +macro_rules! assert_eq_written_bytes { + ($o_prot:ident, $expected_bytes:ident) => { + { + assert_eq!($o_prot.transport.write_bytes(), &$expected_bytes); + } + }; +} + +// FIXME: should take both read and write +#[cfg(test)] +macro_rules! copy_write_buffer_to_read_buffer { + ($o_prot:ident) => { + { + $o_prot.transport.copy_write_buffer_to_read_buffer(); + } + }; +} + +#[cfg(test)] +macro_rules! set_readable_bytes { + ($i_prot:ident, $bytes:expr) => { + $i_prot.transport.set_readable_bytes($bytes); + } +} mod binary; mod compact; @@ -107,20 +125,17 @@ const MAXIMUM_SKIP_DEPTH: i8 = 64; /// Create and use a `TInputProtocol` /// /// ```no_run -/// use std::cell::RefCell; -/// use std::rc::Rc; /// use thrift::protocol::{TBinaryInputProtocol, TInputProtocol}; -/// use thrift::transport::{TTcpTransport, TTransport}; +/// use thrift::transport::TTcpChannel; /// -/// let mut transport = TTcpTransport::new(); -/// transport.open("127.0.0.1:9090").unwrap(); -/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>)); +/// let mut channel = TTcpChannel::new(); +/// channel.open("127.0.0.1:9090").unwrap(); /// -/// let mut i_prot = TBinaryInputProtocol::new(transport.clone(), true); +/// let mut protocol = TBinaryInputProtocol::new(channel, true); /// -/// let field_identifier = i_prot.read_field_begin().unwrap(); -/// let field_contents = i_prot.read_string().unwrap(); -/// let field_end = i_prot.read_field_end().unwrap(); +/// let field_identifier = protocol.read_field_begin().unwrap(); +/// let field_contents = protocol.read_string().unwrap(); +/// let field_end = protocol.read_field_end().unwrap(); /// ``` pub trait TInputProtocol { /// Read the beginning of a Thrift message. @@ -171,10 +186,14 @@ pub trait TInputProtocol { /// Skip a field with type `field_type` recursively up to `depth` levels. fn skip_till_depth(&mut self, field_type: TType, depth: i8) -> ::Result<()> { if depth == 0 { - return Err(::Error::Protocol(ProtocolError { - kind: ProtocolErrorKind::DepthLimit, - message: format!("cannot parse past {:?}", field_type), - })); + return Err( + ::Error::Protocol( + ProtocolError { + kind: ProtocolErrorKind::DepthLimit, + message: format!("cannot parse past {:?}", field_type), + }, + ), + ); } match field_type { @@ -213,9 +232,11 @@ pub trait TInputProtocol { TType::Map => { let map_ident = self.read_map_begin()?; for _ in 0..map_ident.size { - let key_type = map_ident.key_type + let key_type = map_ident + .key_type .expect("non-zero sized map should contain key type"); - let val_type = map_ident.value_type + let val_type = map_ident + .value_type .expect("non-zero sized map should contain value type"); self.skip_till_depth(key_type, depth - 1)?; self.skip_till_depth(val_type, depth - 1)?; @@ -223,10 +244,14 @@ pub trait TInputProtocol { self.read_map_end() } u => { - Err(::Error::Protocol(ProtocolError { - kind: ProtocolErrorKind::Unknown, - message: format!("cannot skip field type {:?}", &u), - })) + Err( + ::Error::Protocol( + ProtocolError { + kind: ProtocolErrorKind::Unknown, + message: format!("cannot skip field type {:?}", &u), + }, + ), + ) } } } @@ -259,20 +284,17 @@ pub trait TInputProtocol { /// Create and use a `TOutputProtocol` /// /// ```no_run -/// use std::cell::RefCell; -/// use std::rc::Rc; /// use thrift::protocol::{TBinaryOutputProtocol, TFieldIdentifier, TOutputProtocol, TType}; -/// use thrift::transport::{TTcpTransport, TTransport}; +/// use thrift::transport::TTcpChannel; /// -/// let mut transport = TTcpTransport::new(); -/// transport.open("127.0.0.1:9090").unwrap(); -/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>)); +/// let mut channel = TTcpChannel::new(); +/// channel.open("127.0.0.1:9090").unwrap(); /// -/// let mut o_prot = TBinaryOutputProtocol::new(transport.clone(), true); +/// let mut protocol = TBinaryOutputProtocol::new(channel, true); /// -/// o_prot.write_field_begin(&TFieldIdentifier::new("string_thing", TType::String, 1)).unwrap(); -/// o_prot.write_string("foo").unwrap(); -/// o_prot.write_field_end().unwrap(); +/// protocol.write_field_begin(&TFieldIdentifier::new("string_thing", TType::String, 1)).unwrap(); +/// protocol.write_string("foo").unwrap(); +/// protocol.write_field_end().unwrap(); /// ``` pub trait TOutputProtocol { /// Write the beginning of a Thrift message. @@ -330,6 +352,192 @@ pub trait TOutputProtocol { fn write_byte(&mut self, b: u8) -> ::Result<()>; // FIXME: REMOVE } +impl<P> TInputProtocol for Box<P> +where + P: TInputProtocol + ?Sized, +{ + fn read_message_begin(&mut self) -> ::Result<TMessageIdentifier> { + (**self).read_message_begin() + } + + fn read_message_end(&mut self) -> ::Result<()> { + (**self).read_message_end() + } + + fn read_struct_begin(&mut self) -> ::Result<Option<TStructIdentifier>> { + (**self).read_struct_begin() + } + + fn read_struct_end(&mut self) -> ::Result<()> { + (**self).read_struct_end() + } + + fn read_field_begin(&mut self) -> ::Result<TFieldIdentifier> { + (**self).read_field_begin() + } + + fn read_field_end(&mut self) -> ::Result<()> { + (**self).read_field_end() + } + + fn read_bool(&mut self) -> ::Result<bool> { + (**self).read_bool() + } + + fn read_bytes(&mut self) -> ::Result<Vec<u8>> { + (**self).read_bytes() + } + + fn read_i8(&mut self) -> ::Result<i8> { + (**self).read_i8() + } + + fn read_i16(&mut self) -> ::Result<i16> { + (**self).read_i16() + } + + fn read_i32(&mut self) -> ::Result<i32> { + (**self).read_i32() + } + + fn read_i64(&mut self) -> ::Result<i64> { + (**self).read_i64() + } + + fn read_double(&mut self) -> ::Result<f64> { + (**self).read_double() + } + + fn read_string(&mut self) -> ::Result<String> { + (**self).read_string() + } + + fn read_list_begin(&mut self) -> ::Result<TListIdentifier> { + (**self).read_list_begin() + } + + fn read_list_end(&mut self) -> ::Result<()> { + (**self).read_list_end() + } + + fn read_set_begin(&mut self) -> ::Result<TSetIdentifier> { + (**self).read_set_begin() + } + + fn read_set_end(&mut self) -> ::Result<()> { + (**self).read_set_end() + } + + fn read_map_begin(&mut self) -> ::Result<TMapIdentifier> { + (**self).read_map_begin() + } + + fn read_map_end(&mut self) -> ::Result<()> { + (**self).read_map_end() + } + + fn read_byte(&mut self) -> ::Result<u8> { + (**self).read_byte() + } +} + +impl<P> TOutputProtocol for Box<P> +where + P: TOutputProtocol + ?Sized, +{ + fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> ::Result<()> { + (**self).write_message_begin(identifier) + } + + fn write_message_end(&mut self) -> ::Result<()> { + (**self).write_message_end() + } + + fn write_struct_begin(&mut self, identifier: &TStructIdentifier) -> ::Result<()> { + (**self).write_struct_begin(identifier) + } + + fn write_struct_end(&mut self) -> ::Result<()> { + (**self).write_struct_end() + } + + fn write_field_begin(&mut self, identifier: &TFieldIdentifier) -> ::Result<()> { + (**self).write_field_begin(identifier) + } + + fn write_field_end(&mut self) -> ::Result<()> { + (**self).write_field_end() + } + + fn write_field_stop(&mut self) -> ::Result<()> { + (**self).write_field_stop() + } + + fn write_bool(&mut self, b: bool) -> ::Result<()> { + (**self).write_bool(b) + } + + fn write_bytes(&mut self, b: &[u8]) -> ::Result<()> { + (**self).write_bytes(b) + } + + fn write_i8(&mut self, i: i8) -> ::Result<()> { + (**self).write_i8(i) + } + + fn write_i16(&mut self, i: i16) -> ::Result<()> { + (**self).write_i16(i) + } + + fn write_i32(&mut self, i: i32) -> ::Result<()> { + (**self).write_i32(i) + } + + fn write_i64(&mut self, i: i64) -> ::Result<()> { + (**self).write_i64(i) + } + + fn write_double(&mut self, d: f64) -> ::Result<()> { + (**self).write_double(d) + } + + fn write_string(&mut self, s: &str) -> ::Result<()> { + (**self).write_string(s) + } + + fn write_list_begin(&mut self, identifier: &TListIdentifier) -> ::Result<()> { + (**self).write_list_begin(identifier) + } + + fn write_list_end(&mut self) -> ::Result<()> { + (**self).write_list_end() + } + + fn write_set_begin(&mut self, identifier: &TSetIdentifier) -> ::Result<()> { + (**self).write_set_begin(identifier) + } + + fn write_set_end(&mut self) -> ::Result<()> { + (**self).write_set_end() + } + + fn write_map_begin(&mut self, identifier: &TMapIdentifier) -> ::Result<()> { + (**self).write_map_begin(identifier) + } + + fn write_map_end(&mut self) -> ::Result<()> { + (**self).write_map_end() + } + + fn flush(&mut self) -> ::Result<()> { + (**self).flush() + } + + fn write_byte(&mut self, b: u8) -> ::Result<()> { + (**self).write_byte(b) + } +} + /// Helper type used by servers to create `TInputProtocol` instances for /// accepted client connections. /// @@ -338,21 +546,27 @@ pub trait TOutputProtocol { /// Create a `TInputProtocolFactory` and use it to create a `TInputProtocol`. /// /// ```no_run -/// use std::cell::RefCell; -/// use std::rc::Rc; /// use thrift::protocol::{TBinaryInputProtocolFactory, TInputProtocolFactory}; -/// use thrift::transport::{TTcpTransport, TTransport}; +/// use thrift::transport::TTcpChannel; /// -/// let mut transport = TTcpTransport::new(); -/// transport.open("127.0.0.1:9090").unwrap(); -/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>)); +/// let mut channel = TTcpChannel::new(); +/// channel.open("127.0.0.1:9090").unwrap(); /// -/// let mut i_proto_factory = TBinaryInputProtocolFactory::new(); -/// let i_prot = i_proto_factory.create(transport); +/// let factory = TBinaryInputProtocolFactory::new(); +/// let protocol = factory.create(Box::new(channel)); /// ``` pub trait TInputProtocolFactory { - /// Create a `TInputProtocol` that reads bytes from `transport`. - fn create(&mut self, transport: Rc<RefCell<Box<TTransport>>>) -> Box<TInputProtocol>; + // Create a `TInputProtocol` that reads bytes from `transport`. + fn create(&self, transport: Box<TReadTransport + Send>) -> Box<TInputProtocol + Send>; +} + +impl<T> TInputProtocolFactory for Box<T> +where + T: TInputProtocolFactory + ?Sized, +{ + fn create(&self, transport: Box<TReadTransport + Send>) -> Box<TInputProtocol + Send> { + (**self).create(transport) + } } /// Helper type used by servers to create `TOutputProtocol` instances for @@ -363,21 +577,27 @@ pub trait TInputProtocolFactory { /// Create a `TOutputProtocolFactory` and use it to create a `TOutputProtocol`. /// /// ```no_run -/// use std::cell::RefCell; -/// use std::rc::Rc; /// use thrift::protocol::{TBinaryOutputProtocolFactory, TOutputProtocolFactory}; -/// use thrift::transport::{TTcpTransport, TTransport}; +/// use thrift::transport::TTcpChannel; /// -/// let mut transport = TTcpTransport::new(); -/// transport.open("127.0.0.1:9090").unwrap(); -/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>)); +/// let mut channel = TTcpChannel::new(); +/// channel.open("127.0.0.1:9090").unwrap(); /// -/// let mut o_proto_factory = TBinaryOutputProtocolFactory::new(); -/// let o_prot = o_proto_factory.create(transport); +/// let factory = TBinaryOutputProtocolFactory::new(); +/// let protocol = factory.create(Box::new(channel)); /// ``` pub trait TOutputProtocolFactory { /// Create a `TOutputProtocol` that writes bytes to `transport`. - fn create(&mut self, transport: Rc<RefCell<Box<TTransport>>>) -> Box<TOutputProtocol>; + fn create(&self, transport: Box<TWriteTransport + Send>) -> Box<TOutputProtocol + Send>; +} + +impl<T> TOutputProtocolFactory for Box<T> +where + T: TOutputProtocolFactory + ?Sized, +{ + fn create(&self, transport: Box<TWriteTransport + Send>) -> Box<TOutputProtocol + Send> { + (**self).create(transport) + } } /// Thrift message identifier. @@ -394,10 +614,11 @@ pub struct TMessageIdentifier { impl TMessageIdentifier { /// Create a `TMessageIdentifier` for a Thrift service-call named `name` /// with message type `message_type` and sequence number `sequence_number`. - pub fn new<S: Into<String>>(name: S, - message_type: TMessageType, - sequence_number: i32) - -> TMessageIdentifier { + pub fn new<S: Into<String>>( + name: S, + message_type: TMessageType, + sequence_number: i32, + ) -> TMessageIdentifier { TMessageIdentifier { name: name.into(), message_type: message_type, @@ -443,9 +664,10 @@ impl TFieldIdentifier { /// /// `id` should be `None` if `field_type` is `TType::Stop`. pub fn new<N, S, I>(name: N, field_type: TType, id: I) -> TFieldIdentifier - where N: Into<Option<S>>, - S: Into<String>, - I: Into<Option<i16>> + where + N: Into<Option<S>>, + S: Into<String>, + I: Into<Option<i16>>, { TFieldIdentifier { name: name.into().map(|n| n.into()), @@ -510,8 +732,9 @@ impl TMapIdentifier { /// Create a `TMapIdentifier` for a map with `size` entries of type /// `key_type -> value_type`. pub fn new<K, V>(key_type: K, value_type: V, size: i32) -> TMapIdentifier - where K: Into<Option<TType>>, - V: Into<Option<TType>> + where + K: Into<Option<TType>>, + V: Into<Option<TType>>, { TMapIdentifier { key_type: key_type.into(), @@ -565,10 +788,14 @@ impl TryFrom<u8> for TMessageType { 0x03 => Ok(TMessageType::Exception), 0x04 => Ok(TMessageType::OneWay), unkn => { - Err(::Error::Protocol(ProtocolError { - kind: ProtocolErrorKind::InvalidData, - message: format!("cannot convert {} to TMessageType", unkn), - })) + Err( + ::Error::Protocol( + ProtocolError { + kind: ProtocolErrorKind::InvalidData, + message: format!("cannot convert {} to TMessageType", unkn), + }, + ), + ) } } } @@ -642,10 +869,14 @@ pub fn verify_expected_sequence_number(expected: i32, actual: i32) -> ::Result<( if expected == actual { Ok(()) } else { - Err(::Error::Application(::ApplicationError { - kind: ::ApplicationErrorKind::BadSequenceId, - message: format!("expected {} got {}", expected, actual), - })) + Err( + ::Error::Application( + ::ApplicationError { + kind: ::ApplicationErrorKind::BadSequenceId, + message: format!("expected {} got {}", expected, actual), + }, + ), + ) } } @@ -657,10 +888,14 @@ pub fn verify_expected_service_call(expected: &str, actual: &str) -> ::Result<() if expected == actual { Ok(()) } else { - Err(::Error::Application(::ApplicationError { - kind: ::ApplicationErrorKind::WrongMethodName, - message: format!("expected {} got {}", expected, actual), - })) + Err( + ::Error::Application( + ::ApplicationError { + kind: ::ApplicationErrorKind::WrongMethodName, + message: format!("expected {} got {}", expected, actual), + }, + ), + ) } } @@ -672,10 +907,14 @@ pub fn verify_expected_message_type(expected: TMessageType, actual: TMessageType if expected == actual { Ok(()) } else { - Err(::Error::Application(::ApplicationError { - kind: ::ApplicationErrorKind::InvalidMessageType, - message: format!("expected {} got {}", expected, actual), - })) + Err( + ::Error::Application( + ::ApplicationError { + kind: ::ApplicationErrorKind::InvalidMessageType, + message: format!("expected {} got {}", expected, actual), + }, + ), + ) } } @@ -686,10 +925,14 @@ pub fn verify_required_field_exists<T>(field_name: &str, field: &Option<T>) -> : match *field { Some(_) => Ok(()), None => { - Err(::Error::Protocol(::ProtocolError { - kind: ::ProtocolErrorKind::Unknown, - message: format!("missing required field {}", field_name), - })) + Err( + ::Error::Protocol( + ::ProtocolError { + kind: ::ProtocolErrorKind::Unknown, + message: format!("missing required field {}", field_name), + }, + ), + ) } } } @@ -700,10 +943,67 @@ pub fn verify_required_field_exists<T>(field_name: &str, field: &Option<T>) -> : /// /// Return `TFieldIdentifier.id` if an id exists, `Err` otherwise. pub fn field_id(field_ident: &TFieldIdentifier) -> ::Result<i16> { - field_ident.id.ok_or_else(|| { - ::Error::Protocol(::ProtocolError { - kind: ::ProtocolErrorKind::Unknown, - message: format!("missing field in in {:?}", field_ident), - }) - }) + field_ident + .id + .ok_or_else( + || { + ::Error::Protocol( + ::ProtocolError { + kind: ::ProtocolErrorKind::Unknown, + message: format!("missing field in in {:?}", field_ident), + }, + ) + }, + ) +} + +#[cfg(test)] +mod tests { + + use std::io::Cursor; + + use super::*; + use transport::{TReadTransport, TWriteTransport}; + + #[test] + fn must_create_usable_input_protocol_from_concrete_input_protocol() { + let r: Box<TReadTransport> = Box::new(Cursor::new([0, 1, 2])); + let mut t = TCompactInputProtocol::new(r); + takes_input_protocol(&mut t) + } + + #[test] + fn must_create_usable_input_protocol_from_boxed_input() { + let r: Box<TReadTransport> = Box::new(Cursor::new([0, 1, 2])); + let mut t: Box<TInputProtocol> = Box::new(TCompactInputProtocol::new(r)); + takes_input_protocol(&mut t) + } + + #[test] + fn must_create_usable_output_protocol_from_concrete_output_protocol() { + let w: Box<TWriteTransport> = Box::new(vec![0u8; 10]); + let mut t = TCompactOutputProtocol::new(w); + takes_output_protocol(&mut t) + } + + #[test] + fn must_create_usable_output_protocol_from_boxed_output() { + let w: Box<TWriteTransport> = Box::new(vec![0u8; 10]); + let mut t: Box<TOutputProtocol> = Box::new(TCompactOutputProtocol::new(w)); + takes_output_protocol(&mut t) + } + + fn takes_input_protocol<R>(t: &mut R) + where + R: TInputProtocol, + { + t.read_byte().unwrap(); + } + + fn takes_output_protocol<W>(t: &mut W) + where + W: TOutputProtocol, + { + t.flush().unwrap(); + } } diff --git a/lib/rs/src/protocol/multiplexed.rs b/lib/rs/src/protocol/multiplexed.rs index a30aca80a..db08027f2 100644 --- a/lib/rs/src/protocol/multiplexed.rs +++ b/lib/rs/src/protocol/multiplexed.rs @@ -37,33 +37,37 @@ use super::{TFieldIdentifier, TListIdentifier, TMapIdentifier, TMessageIdentifie /// Create and use a `TMultiplexedOutputProtocol`. /// /// ```no_run -/// use std::cell::RefCell; -/// use std::rc::Rc; /// use thrift::protocol::{TMessageIdentifier, TMessageType, TOutputProtocol}; /// use thrift::protocol::{TBinaryOutputProtocol, TMultiplexedOutputProtocol}; -/// use thrift::transport::{TTcpTransport, TTransport}; +/// use thrift::transport::TTcpChannel; /// -/// let mut transport = TTcpTransport::new(); -/// transport.open("localhost:9090").unwrap(); -/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>)); +/// let mut channel = TTcpChannel::new(); +/// channel.open("localhost:9090").unwrap(); /// -/// let o_prot = TBinaryOutputProtocol::new(transport, true); -/// let mut o_prot = TMultiplexedOutputProtocol::new("service_name", Box::new(o_prot)); +/// let protocol = TBinaryOutputProtocol::new(channel, true); +/// let mut protocol = TMultiplexedOutputProtocol::new("service_name", protocol); /// /// let ident = TMessageIdentifier::new("svc_call", TMessageType::Call, 1); -/// o_prot.write_message_begin(&ident).unwrap(); +/// protocol.write_message_begin(&ident).unwrap(); /// ``` -pub struct TMultiplexedOutputProtocol<'a> { +#[derive(Debug)] +pub struct TMultiplexedOutputProtocol<P> +where + P: TOutputProtocol, +{ service_name: String, - inner: Box<TOutputProtocol + 'a>, + inner: P, } -impl<'a> TMultiplexedOutputProtocol<'a> { +impl<P> TMultiplexedOutputProtocol<P> +where + P: TOutputProtocol, +{ /// Create a `TMultiplexedOutputProtocol` that identifies outgoing messages /// as originating from a service named `service_name` and sends them over /// the `wrapped` `TOutputProtocol`. Outgoing messages are encoded and sent /// by `wrapped`, not by this instance. - pub fn new(service_name: &str, wrapped: Box<TOutputProtocol + 'a>) -> TMultiplexedOutputProtocol<'a> { + pub fn new(service_name: &str, wrapped: P) -> TMultiplexedOutputProtocol<P> { TMultiplexedOutputProtocol { service_name: service_name.to_owned(), inner: wrapped, @@ -72,7 +76,10 @@ impl<'a> TMultiplexedOutputProtocol<'a> { } // FIXME: avoid passthrough methods -impl<'a> TOutputProtocol for TMultiplexedOutputProtocol<'a> { +impl<P> TOutputProtocol for TMultiplexedOutputProtocol<P> +where + P: TOutputProtocol, +{ fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> ::Result<()> { match identifier.message_type { // FIXME: is there a better way to override identifier here? TMessageType::Call | TMessageType::OneWay => { @@ -181,39 +188,50 @@ impl<'a> TOutputProtocol for TMultiplexedOutputProtocol<'a> { #[cfg(test)] mod tests { - use std::cell::RefCell; - use std::rc::Rc; - - use ::protocol::{TBinaryOutputProtocol, TMessageIdentifier, TMessageType, TOutputProtocol}; - use ::transport::{TPassThruTransport, TTransport}; - use ::transport::mem::TBufferTransport; + use protocol::{TBinaryOutputProtocol, TMessageIdentifier, TMessageType, TOutputProtocol}; + use transport::{TBufferChannel, TIoChannel, WriteHalf}; use super::*; #[test] fn must_write_message_begin_with_prefixed_service_name() { - let (trans, mut o_prot) = test_objects(); + let mut o_prot = test_objects(); let ident = TMessageIdentifier::new("bar", TMessageType::Call, 2); assert_success!(o_prot.write_message_begin(&ident)); - let expected: [u8; 19] = - [0x80, 0x01 /* protocol identifier */, 0x00, 0x01 /* message type */, 0x00, - 0x00, 0x00, 0x07, 0x66, 0x6F, 0x6F /* "foo" */, 0x3A /* ":" */, 0x62, 0x61, - 0x72 /* "bar" */, 0x00, 0x00, 0x00, 0x02 /* sequence number */]; - - assert_eq!(&trans.borrow().write_buffer_to_vec(), &expected); - } - - fn test_objects<'a>() -> (Rc<RefCell<Box<TBufferTransport>>>, TMultiplexedOutputProtocol<'a>) { - let mem = Rc::new(RefCell::new(Box::new(TBufferTransport::with_capacity(40, 40)))); - - let inner: Box<TTransport> = Box::new(TPassThruTransport { inner: mem.clone() }); - let inner = Rc::new(RefCell::new(inner)); - - let o_prot = TBinaryOutputProtocol::new(inner.clone(), true); - let o_prot = TMultiplexedOutputProtocol::new("foo", Box::new(o_prot)); - - (mem, o_prot) + let expected: [u8; 19] = [ + 0x80, + 0x01, /* protocol identifier */ + 0x00, + 0x01, /* message type */ + 0x00, + 0x00, + 0x00, + 0x07, + 0x66, + 0x6F, + 0x6F, /* "foo" */ + 0x3A, /* ":" */ + 0x62, + 0x61, + 0x72, /* "bar" */ + 0x00, + 0x00, + 0x00, + 0x02 /* sequence number */, + ]; + + assert_eq!(o_prot.inner.transport.write_bytes(), expected); + } + + fn test_objects + () + -> TMultiplexedOutputProtocol<TBinaryOutputProtocol<WriteHalf<TBufferChannel>>> + { + let c = TBufferChannel::with_capacity(40, 40); + let (_, w_chan) = c.split().unwrap(); + let prot = TBinaryOutputProtocol::new(w_chan, true); + TMultiplexedOutputProtocol::new("foo", prot) } } diff --git a/lib/rs/src/protocol/stored.rs b/lib/rs/src/protocol/stored.rs index 6826c00a8..b3f305f03 100644 --- a/lib/rs/src/protocol/stored.rs +++ b/lib/rs/src/protocol/stored.rs @@ -17,8 +17,8 @@ use std::convert::Into; -use ::ProtocolErrorKind; -use super::{TFieldIdentifier, TListIdentifier, TMapIdentifier, TMessageIdentifier, TInputProtocol, +use ProtocolErrorKind; +use super::{TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier, TMessageIdentifier, TSetIdentifier, TStructIdentifier}; /// `TInputProtocol` required to use a `TMultiplexedProcessor`. @@ -40,35 +40,34 @@ use super::{TFieldIdentifier, TListIdentifier, TMapIdentifier, TMessageIdentifie /// Create and use a `TStoredInputProtocol`. /// /// ```no_run -/// use std::cell::RefCell; -/// use std::rc::Rc; /// use thrift; /// use thrift::protocol::{TInputProtocol, TMessageIdentifier, TMessageType, TOutputProtocol}; /// use thrift::protocol::{TBinaryInputProtocol, TBinaryOutputProtocol, TStoredInputProtocol}; /// use thrift::server::TProcessor; -/// use thrift::transport::{TTcpTransport, TTransport}; +/// use thrift::transport::{TIoChannel, TTcpChannel}; /// /// // sample processor /// struct ActualProcessor; /// impl TProcessor for ActualProcessor { /// fn process( -/// &mut self, +/// &self, /// _: &mut TInputProtocol, /// _: &mut TOutputProtocol /// ) -> thrift::Result<()> { /// unimplemented!() /// } /// } -/// let mut processor = ActualProcessor {}; +/// let processor = ActualProcessor {}; /// /// // construct the shared transport -/// let mut transport = TTcpTransport::new(); -/// transport.open("localhost:9090").unwrap(); -/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>)); +/// let mut channel = TTcpChannel::new(); +/// channel.open("localhost:9090").unwrap(); +/// +/// let (i_chan, o_chan) = channel.split().unwrap(); /// /// // construct the actual input and output protocols -/// let mut i_prot = TBinaryInputProtocol::new(transport.clone(), true); -/// let mut o_prot = TBinaryOutputProtocol::new(transport.clone(), true); +/// let mut i_prot = TBinaryInputProtocol::new(i_chan, true); +/// let mut o_prot = TBinaryOutputProtocol::new(o_chan, true); /// /// // message identifier received from remote and modified to remove the service name /// let new_msg_ident = TMessageIdentifier::new("service_call", TMessageType::Call, 1); @@ -77,6 +76,7 @@ use super::{TFieldIdentifier, TListIdentifier, TMapIdentifier, TMessageIdentifie /// let mut proxy_i_prot = TStoredInputProtocol::new(&mut i_prot, new_msg_ident); /// let res = processor.process(&mut proxy_i_prot, &mut o_prot); /// ``` +// FIXME: implement Debug pub struct TStoredInputProtocol<'a> { inner: &'a mut TInputProtocol, message_ident: Option<TMessageIdentifier>, @@ -88,9 +88,10 @@ impl<'a> TStoredInputProtocol<'a> { /// `TInputProtocol`. `message_ident` is the modified message identifier - /// with service name stripped - that will be passed to /// `wrapped.read_message_begin(...)`. - pub fn new(wrapped: &mut TInputProtocol, - message_ident: TMessageIdentifier) - -> TStoredInputProtocol { + pub fn new( + wrapped: &mut TInputProtocol, + message_ident: TMessageIdentifier, + ) -> TStoredInputProtocol { TStoredInputProtocol { inner: wrapped, message_ident: message_ident.into(), @@ -100,10 +101,16 @@ impl<'a> TStoredInputProtocol<'a> { impl<'a> TInputProtocol for TStoredInputProtocol<'a> { fn read_message_begin(&mut self) -> ::Result<TMessageIdentifier> { - self.message_ident.take().ok_or_else(|| { - ::errors::new_protocol_error(ProtocolErrorKind::Unknown, - "message identifier already read") - }) + self.message_ident + .take() + .ok_or_else( + || { + ::errors::new_protocol_error( + ProtocolErrorKind::Unknown, + "message identifier already read", + ) + }, + ) } fn read_message_end(&mut self) -> ::Result<()> { diff --git a/lib/rs/src/server/mod.rs b/lib/rs/src/server/mod.rs index ceac18a62..21c392c45 100644 --- a/lib/rs/src/server/mod.rs +++ b/lib/rs/src/server/mod.rs @@ -15,15 +15,15 @@ // specific language governing permissions and limitations // under the License. -//! Types required to implement a Thrift server. +//! Types used to implement a Thrift server. -use ::protocol::{TInputProtocol, TOutputProtocol}; +use protocol::{TInputProtocol, TOutputProtocol}; -mod simple; mod multiplexed; +mod threaded; -pub use self::simple::TSimpleServer; pub use self::multiplexed::TMultiplexedProcessor; +pub use self::threaded::TServer; /// Handles incoming Thrift messages and dispatches them to the user-defined /// handler functions. @@ -56,14 +56,14 @@ pub use self::multiplexed::TMultiplexedProcessor; /// /// // `TProcessor` implementation for `SimpleService` /// impl TProcessor for SimpleServiceSyncProcessor { -/// fn process(&mut self, i: &mut TInputProtocol, o: &mut TOutputProtocol) -> thrift::Result<()> { +/// fn process(&self, i: &mut TInputProtocol, o: &mut TOutputProtocol) -> thrift::Result<()> { /// unimplemented!(); /// } /// } /// /// // service functions for SimpleService /// trait SimpleServiceSyncHandler { -/// fn service_call(&mut self) -> thrift::Result<()>; +/// fn service_call(&self) -> thrift::Result<()>; /// } /// /// // @@ -73,7 +73,7 @@ pub use self::multiplexed::TMultiplexedProcessor; /// // define a handler that will be invoked when `service_call` is received /// struct SimpleServiceHandlerImpl; /// impl SimpleServiceSyncHandler for SimpleServiceHandlerImpl { -/// fn service_call(&mut self) -> thrift::Result<()> { +/// fn service_call(&self) -> thrift::Result<()> { /// unimplemented!(); /// } /// } @@ -82,7 +82,7 @@ pub use self::multiplexed::TMultiplexedProcessor; /// let processor = SimpleServiceSyncProcessor::new(SimpleServiceHandlerImpl {}); /// /// // at this point you can pass the processor to the server -/// // let server = TSimpleServer::new(..., processor); +/// // let server = TServer::new(..., processor); /// ``` pub trait TProcessor { /// Process a Thrift service call. @@ -91,5 +91,5 @@ pub trait TProcessor { /// the response to `o`. /// /// Returns `()` if the handler was executed; `Err` otherwise. - fn process(&mut self, i: &mut TInputProtocol, o: &mut TOutputProtocol) -> ::Result<()>; + fn process(&self, i: &mut TInputProtocol, o: &mut TOutputProtocol) -> ::Result<()>; } diff --git a/lib/rs/src/server/multiplexed.rs b/lib/rs/src/server/multiplexed.rs index d2314a12a..b1243a86f 100644 --- a/lib/rs/src/server/multiplexed.rs +++ b/lib/rs/src/server/multiplexed.rs @@ -17,9 +17,10 @@ use std::collections::HashMap; use std::convert::Into; +use std::sync::{Arc, Mutex}; -use ::{new_application_error, ApplicationErrorKind}; -use ::protocol::{TInputProtocol, TMessageIdentifier, TOutputProtocol, TStoredInputProtocol}; +use {ApplicationErrorKind, new_application_error}; +use protocol::{TInputProtocol, TMessageIdentifier, TOutputProtocol, TStoredInputProtocol}; use super::TProcessor; @@ -33,8 +34,9 @@ use super::TProcessor; /// /// A `TMultiplexedProcessor` can only handle messages sent by a /// `TMultiplexedOutputProtocol`. +// FIXME: implement Debug pub struct TMultiplexedProcessor { - processors: HashMap<String, Box<TProcessor>>, + processors: Mutex<HashMap<String, Arc<Box<TProcessor>>>>, } impl TMultiplexedProcessor { @@ -46,46 +48,62 @@ impl TMultiplexedProcessor { /// Return `false` if a mapping previously existed (the previous mapping is /// *not* overwritten). #[cfg_attr(feature = "cargo-clippy", allow(map_entry))] - pub fn register_processor<S: Into<String>>(&mut self, - service_name: S, - processor: Box<TProcessor>) - -> bool { + pub fn register_processor<S: Into<String>>( + &mut self, + service_name: S, + processor: Box<TProcessor>, + ) -> bool { + let mut processors = self.processors.lock().unwrap(); + let name = service_name.into(); - if self.processors.contains_key(&name) { + if processors.contains_key(&name) { false } else { - self.processors.insert(name, processor); + processors.insert(name, Arc::new(processor)); true } } } impl TProcessor for TMultiplexedProcessor { - fn process(&mut self, - i_prot: &mut TInputProtocol, - o_prot: &mut TOutputProtocol) - -> ::Result<()> { + fn process(&self, i_prot: &mut TInputProtocol, o_prot: &mut TOutputProtocol) -> ::Result<()> { let msg_ident = i_prot.read_message_begin()?; - let sep_index = msg_ident.name + let sep_index = msg_ident + .name .find(':') - .ok_or_else(|| { - new_application_error(ApplicationErrorKind::Unknown, - "no service separator found in incoming message") - })?; + .ok_or_else( + || { + new_application_error( + ApplicationErrorKind::Unknown, + "no service separator found in incoming message", + ) + }, + )?; let (svc_name, svc_call) = msg_ident.name.split_at(sep_index); - match self.processors.get_mut(svc_name) { - Some(ref mut processor) => { - let new_msg_ident = TMessageIdentifier::new(svc_call, - msg_ident.message_type, - msg_ident.sequence_number); + let processor: Option<Arc<Box<TProcessor>>> = { + let processors = self.processors.lock().unwrap(); + processors.get(svc_name).cloned() + }; + + match processor { + Some(arc) => { + let new_msg_ident = TMessageIdentifier::new( + svc_call, + msg_ident.message_type, + msg_ident.sequence_number, + ); let mut proxy_i_prot = TStoredInputProtocol::new(i_prot, new_msg_ident); - processor.process(&mut proxy_i_prot, o_prot) + (*arc).process(&mut proxy_i_prot, o_prot) } None => { - Err(new_application_error(ApplicationErrorKind::Unknown, - format!("no processor found for service {}", svc_name))) + Err( + new_application_error( + ApplicationErrorKind::Unknown, + format!("no processor found for service {}", svc_name), + ), + ) } } } diff --git a/lib/rs/src/server/simple.rs b/lib/rs/src/server/simple.rs deleted file mode 100644 index 89ed9778e..000000000 --- a/lib/rs/src/server/simple.rs +++ /dev/null @@ -1,189 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::cell::RefCell; -use std::net::{TcpListener, TcpStream}; -use std::rc::Rc; - -use ::{ApplicationError, ApplicationErrorKind}; -use ::protocol::{TInputProtocolFactory, TOutputProtocolFactory}; -use ::transport::{TTcpTransport, TTransport, TTransportFactory}; - -use super::TProcessor; - -/// Single-threaded blocking Thrift socket server. -/// -/// A `TSimpleServer` listens on a given address and services accepted -/// connections *synchronously* and *sequentially* - i.e. in a blocking manner, -/// one at a time - on the main thread. Each accepted connection has an input -/// half and an output half, each of which uses a `TTransport` and `TProtocol` -/// to translate messages to and from byes. Any combination of `TProtocol` and -/// `TTransport` may be used. -/// -/// # Examples -/// -/// Creating and running a `TSimpleServer` using Thrift-compiler-generated -/// service code. -/// -/// ```no_run -/// use thrift; -/// use thrift::protocol::{TInputProtocolFactory, TOutputProtocolFactory}; -/// use thrift::protocol::{TBinaryInputProtocolFactory, TBinaryOutputProtocolFactory}; -/// use thrift::protocol::{TInputProtocol, TOutputProtocol}; -/// use thrift::transport::{TBufferedTransportFactory, TTransportFactory}; -/// use thrift::server::{TProcessor, TSimpleServer}; -/// -/// // -/// // auto-generated -/// // -/// -/// // processor for `SimpleService` -/// struct SimpleServiceSyncProcessor; -/// impl SimpleServiceSyncProcessor { -/// fn new<H: SimpleServiceSyncHandler>(processor: H) -> SimpleServiceSyncProcessor { -/// unimplemented!(); -/// } -/// } -/// -/// // `TProcessor` implementation for `SimpleService` -/// impl TProcessor for SimpleServiceSyncProcessor { -/// fn process(&mut self, i: &mut TInputProtocol, o: &mut TOutputProtocol) -> thrift::Result<()> { -/// unimplemented!(); -/// } -/// } -/// -/// // service functions for SimpleService -/// trait SimpleServiceSyncHandler { -/// fn service_call(&mut self) -> thrift::Result<()>; -/// } -/// -/// // -/// // user-code follows -/// // -/// -/// // define a handler that will be invoked when `service_call` is received -/// struct SimpleServiceHandlerImpl; -/// impl SimpleServiceSyncHandler for SimpleServiceHandlerImpl { -/// fn service_call(&mut self) -> thrift::Result<()> { -/// unimplemented!(); -/// } -/// } -/// -/// // instantiate the processor -/// let processor = SimpleServiceSyncProcessor::new(SimpleServiceHandlerImpl {}); -/// -/// // instantiate the server -/// let i_tr_fact: Box<TTransportFactory> = Box::new(TBufferedTransportFactory::new()); -/// let i_pr_fact: Box<TInputProtocolFactory> = Box::new(TBinaryInputProtocolFactory::new()); -/// let o_tr_fact: Box<TTransportFactory> = Box::new(TBufferedTransportFactory::new()); -/// let o_pr_fact: Box<TOutputProtocolFactory> = Box::new(TBinaryOutputProtocolFactory::new()); -/// -/// let mut server = TSimpleServer::new( -/// i_tr_fact, -/// i_pr_fact, -/// o_tr_fact, -/// o_pr_fact, -/// processor -/// ); -/// -/// // start listening for incoming connections -/// match server.listen("127.0.0.1:8080") { -/// Ok(_) => println!("listen completed"), -/// Err(e) => println!("listen failed with error {:?}", e), -/// } -/// ``` -pub struct TSimpleServer<PR: TProcessor> { - i_trans_factory: Box<TTransportFactory>, - i_proto_factory: Box<TInputProtocolFactory>, - o_trans_factory: Box<TTransportFactory>, - o_proto_factory: Box<TOutputProtocolFactory>, - processor: PR, -} - -impl<PR: TProcessor> TSimpleServer<PR> { - /// Create a `TSimpleServer`. - /// - /// Each accepted connection has an input and output half, each of which - /// requires a `TTransport` and `TProtocol`. `TSimpleServer` uses - /// `input_transport_factory` and `input_protocol_factory` to create - /// implementations for the input, and `output_transport_factory` and - /// `output_protocol_factory` to create implementations for the output. - pub fn new(input_transport_factory: Box<TTransportFactory>, - input_protocol_factory: Box<TInputProtocolFactory>, - output_transport_factory: Box<TTransportFactory>, - output_protocol_factory: Box<TOutputProtocolFactory>, - processor: PR) - -> TSimpleServer<PR> { - TSimpleServer { - i_trans_factory: input_transport_factory, - i_proto_factory: input_protocol_factory, - o_trans_factory: output_transport_factory, - o_proto_factory: output_protocol_factory, - processor: processor, - } - } - - /// Listen for incoming connections on `listen_address`. - /// - /// `listen_address` should be in the form `host:port`, - /// for example: `127.0.0.1:8080`. - /// - /// Return `()` if successful. - /// - /// Return `Err` when the server cannot bind to `listen_address` or there - /// is an unrecoverable error. - pub fn listen(&mut self, listen_address: &str) -> ::Result<()> { - let listener = TcpListener::bind(listen_address)?; - for stream in listener.incoming() { - match stream { - Ok(s) => self.handle_incoming_connection(s), - Err(e) => warn!("failed to accept remote connection with error {:?}", e), - } - } - - Err(::Error::Application(ApplicationError { - kind: ApplicationErrorKind::Unknown, - message: "aborted listen loop".into(), - })) - } - - fn handle_incoming_connection(&mut self, stream: TcpStream) { - // create the shared tcp stream - let stream = TTcpTransport::with_stream(stream); - let stream: Box<TTransport> = Box::new(stream); - let stream = Rc::new(RefCell::new(stream)); - - // input protocol and transport - let i_tran = self.i_trans_factory.create(stream.clone()); - let i_tran = Rc::new(RefCell::new(i_tran)); - let mut i_prot = self.i_proto_factory.create(i_tran); - - // output protocol and transport - let o_tran = self.o_trans_factory.create(stream.clone()); - let o_tran = Rc::new(RefCell::new(o_tran)); - let mut o_prot = self.o_proto_factory.create(o_tran); - - // process loop - loop { - let r = self.processor.process(&mut *i_prot, &mut *o_prot); - if let Err(e) = r { - warn!("processor failed with error: {:?}", e); - break; // FIXME: close here - } - } - } -} diff --git a/lib/rs/src/server/threaded.rs b/lib/rs/src/server/threaded.rs new file mode 100644 index 000000000..a486c5aad --- /dev/null +++ b/lib/rs/src/server/threaded.rs @@ -0,0 +1,239 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::net::{TcpListener, TcpStream}; +use std::sync::Arc; +use threadpool::ThreadPool; + +use {ApplicationError, ApplicationErrorKind}; +use protocol::{TInputProtocol, TInputProtocolFactory, TOutputProtocol, TOutputProtocolFactory}; +use transport::{TIoChannel, TReadTransportFactory, TTcpChannel, TWriteTransportFactory}; + +use super::TProcessor; + +/// Fixed-size thread-pool blocking Thrift server. +/// +/// A `TServer` listens on a given address and submits accepted connections +/// to an **unbounded** queue. Connections from this queue are serviced by +/// the first available worker thread from a **fixed-size** thread pool. Each +/// accepted connection is handled by that worker thread, and communication +/// over this thread occurs sequentially and synchronously (i.e. calls block). +/// Accepted connections have an input half and an output half, each of which +/// uses a `TTransport` and `TInputProtocol`/`TOutputProtocol` to translate +/// messages to and from byes. Any combination of `TInputProtocol`, `TOutputProtocol` +/// and `TTransport` may be used. +/// +/// # Examples +/// +/// Creating and running a `TServer` using Thrift-compiler-generated +/// service code. +/// +/// ```no_run +/// use thrift; +/// use thrift::protocol::{TInputProtocolFactory, TOutputProtocolFactory}; +/// use thrift::protocol::{TBinaryInputProtocolFactory, TBinaryOutputProtocolFactory}; +/// use thrift::protocol::{TInputProtocol, TOutputProtocol}; +/// use thrift::transport::{TBufferedReadTransportFactory, TBufferedWriteTransportFactory, TReadTransportFactory, TWriteTransportFactory}; +/// use thrift::server::{TProcessor, TServer}; +/// +/// // +/// // auto-generated +/// // +/// +/// // processor for `SimpleService` +/// struct SimpleServiceSyncProcessor; +/// impl SimpleServiceSyncProcessor { +/// fn new<H: SimpleServiceSyncHandler>(processor: H) -> SimpleServiceSyncProcessor { +/// unimplemented!(); +/// } +/// } +/// +/// // `TProcessor` implementation for `SimpleService` +/// impl TProcessor for SimpleServiceSyncProcessor { +/// fn process(&self, i: &mut TInputProtocol, o: &mut TOutputProtocol) -> thrift::Result<()> { +/// unimplemented!(); +/// } +/// } +/// +/// // service functions for SimpleService +/// trait SimpleServiceSyncHandler { +/// fn service_call(&self) -> thrift::Result<()>; +/// } +/// +/// // +/// // user-code follows +/// // +/// +/// // define a handler that will be invoked when `service_call` is received +/// struct SimpleServiceHandlerImpl; +/// impl SimpleServiceSyncHandler for SimpleServiceHandlerImpl { +/// fn service_call(&self) -> thrift::Result<()> { +/// unimplemented!(); +/// } +/// } +/// +/// // instantiate the processor +/// let processor = SimpleServiceSyncProcessor::new(SimpleServiceHandlerImpl {}); +/// +/// // instantiate the server +/// let i_tr_fact: Box<TReadTransportFactory> = Box::new(TBufferedReadTransportFactory::new()); +/// let i_pr_fact: Box<TInputProtocolFactory> = Box::new(TBinaryInputProtocolFactory::new()); +/// let o_tr_fact: Box<TWriteTransportFactory> = Box::new(TBufferedWriteTransportFactory::new()); +/// let o_pr_fact: Box<TOutputProtocolFactory> = Box::new(TBinaryOutputProtocolFactory::new()); +/// +/// let mut server = TServer::new( +/// i_tr_fact, +/// i_pr_fact, +/// o_tr_fact, +/// o_pr_fact, +/// processor, +/// 10 +/// ); +/// +/// // start listening for incoming connections +/// match server.listen("127.0.0.1:8080") { +/// Ok(_) => println!("listen completed"), +/// Err(e) => println!("listen failed with error {:?}", e), +/// } +/// ``` +#[derive(Debug)] +pub struct TServer<PRC, RTF, IPF, WTF, OPF> +where + PRC: TProcessor + Send + Sync + 'static, + RTF: TReadTransportFactory + 'static, + IPF: TInputProtocolFactory + 'static, + WTF: TWriteTransportFactory + 'static, + OPF: TOutputProtocolFactory + 'static, +{ + r_trans_factory: RTF, + i_proto_factory: IPF, + w_trans_factory: WTF, + o_proto_factory: OPF, + processor: Arc<PRC>, + worker_pool: ThreadPool, +} + +impl<PRC, RTF, IPF, WTF, OPF> TServer<PRC, RTF, IPF, WTF, OPF> + where PRC: TProcessor + Send + Sync + 'static, + RTF: TReadTransportFactory + 'static, + IPF: TInputProtocolFactory + 'static, + WTF: TWriteTransportFactory + 'static, + OPF: TOutputProtocolFactory + 'static { + /// Create a `TServer`. + /// + /// Each accepted connection has an input and output half, each of which + /// requires a `TTransport` and `TProtocol`. `TServer` uses + /// `read_transport_factory` and `input_protocol_factory` to create + /// implementations for the input, and `write_transport_factory` and + /// `output_protocol_factory` to create implementations for the output. + pub fn new( + read_transport_factory: RTF, + input_protocol_factory: IPF, + write_transport_factory: WTF, + output_protocol_factory: OPF, + processor: PRC, + num_workers: usize, + ) -> TServer<PRC, RTF, IPF, WTF, OPF> { + TServer { + r_trans_factory: read_transport_factory, + i_proto_factory: input_protocol_factory, + w_trans_factory: write_transport_factory, + o_proto_factory: output_protocol_factory, + processor: Arc::new(processor), + worker_pool: ThreadPool::new_with_name( + "Thrift service processor".to_owned(), + num_workers, + ), + } + } + + /// Listen for incoming connections on `listen_address`. + /// + /// `listen_address` should be in the form `host:port`, + /// for example: `127.0.0.1:8080`. + /// + /// Return `()` if successful. + /// + /// Return `Err` when the server cannot bind to `listen_address` or there + /// is an unrecoverable error. + pub fn listen(&mut self, listen_address: &str) -> ::Result<()> { + let listener = TcpListener::bind(listen_address)?; + for stream in listener.incoming() { + match stream { + Ok(s) => { + let (i_prot, o_prot) = self.new_protocols_for_connection(s)?; + let processor = self.processor.clone(); + self.worker_pool + .execute(move || handle_incoming_connection(processor, i_prot, o_prot),); + } + Err(e) => { + warn!("failed to accept remote connection with error {:?}", e); + } + } + } + + Err( + ::Error::Application( + ApplicationError { + kind: ApplicationErrorKind::Unknown, + message: "aborted listen loop".into(), + }, + ), + ) + } + + + fn new_protocols_for_connection( + &mut self, + stream: TcpStream, + ) -> ::Result<(Box<TInputProtocol + Send>, Box<TOutputProtocol + Send>)> { + // create the shared tcp stream + let channel = TTcpChannel::with_stream(stream); + + // split it into two - one to be owned by the + // input tran/proto and the other by the output + let (r_chan, w_chan) = channel.split()?; + + // input protocol and transport + let r_tran = self.r_trans_factory.create(Box::new(r_chan)); + let i_prot = self.i_proto_factory.create(r_tran); + + // output protocol and transport + let w_tran = self.w_trans_factory.create(Box::new(w_chan)); + let o_prot = self.o_proto_factory.create(w_tran); + + Ok((i_prot, o_prot)) + } +} + +fn handle_incoming_connection<PRC>( + processor: Arc<PRC>, + i_prot: Box<TInputProtocol>, + o_prot: Box<TOutputProtocol>, +) where + PRC: TProcessor, +{ + let mut i_prot = i_prot; + let mut o_prot = o_prot; + loop { + let r = processor.process(&mut *i_prot, &mut *o_prot); + if let Err(e) = r { + warn!("processor completed with error: {:?}", e); + break; + } + } +} diff --git a/lib/rs/src/transport/buffered.rs b/lib/rs/src/transport/buffered.rs index 3f240d82a..b588ec1a7 100644 --- a/lib/rs/src/transport/buffered.rs +++ b/lib/rs/src/transport/buffered.rs @@ -15,104 +15,94 @@ // specific language governing permissions and limitations // under the License. -use std::cell::RefCell; use std::cmp; use std::io; use std::io::{Read, Write}; -use std::rc::Rc; -use super::{TTransport, TTransportFactory}; +use super::{TReadTransport, TReadTransportFactory, TWriteTransport, TWriteTransportFactory}; /// Default capacity of the read buffer in bytes. -const DEFAULT_RBUFFER_CAPACITY: usize = 4096; +const READ_CAPACITY: usize = 4096; /// Default capacity of the write buffer in bytes.. -const DEFAULT_WBUFFER_CAPACITY: usize = 4096; +const WRITE_CAPACITY: usize = 4096; -/// Transport that communicates with endpoints using a byte stream. +/// Transport that reads messages via an internal buffer. /// -/// A `TBufferedTransport` maintains a fixed-size internal write buffer. All -/// writes are made to this buffer and are sent to the wrapped transport only -/// when `TTransport::flush()` is called. On a flush a fixed-length header with a -/// count of the buffered bytes is written, followed by the bytes themselves. -/// -/// A `TBufferedTransport` also maintains a fixed-size internal read buffer. -/// On a call to `TTransport::read(...)` one full message - both fixed-length -/// header and bytes - is read from the wrapped transport and buffered. +/// A `TBufferedReadTransport` maintains a fixed-size internal read buffer. +/// On a call to `TBufferedReadTransport::read(...)` one full message - both +/// fixed-length header and bytes - is read from the wrapped channel and buffered. /// Subsequent read calls are serviced from the internal buffer until it is /// exhausted, at which point the next full message is read from the wrapped -/// transport. +/// channel. /// /// # Examples /// -/// Create and use a `TBufferedTransport`. +/// Create and use a `TBufferedReadTransport`. /// /// ```no_run -/// use std::cell::RefCell; -/// use std::rc::Rc; -/// use std::io::{Read, Write}; -/// use thrift::transport::{TBufferedTransport, TTcpTransport, TTransport}; +/// use std::io::Read; +/// use thrift::transport::{TBufferedReadTransport, TTcpChannel}; /// -/// let mut t = TTcpTransport::new(); -/// t.open("localhost:9090").unwrap(); +/// let mut c = TTcpChannel::new(); +/// c.open("localhost:9090").unwrap(); /// -/// let t = Rc::new(RefCell::new(Box::new(t) as Box<TTransport>)); -/// let mut t = TBufferedTransport::new(t); +/// let mut t = TBufferedReadTransport::new(c); /// -/// // read /// t.read(&mut vec![0u8; 1]).unwrap(); -/// -/// // write -/// t.write(&[0x00]).unwrap(); -/// t.flush().unwrap(); /// ``` -pub struct TBufferedTransport { - rbuf: Box<[u8]>, - rpos: usize, - rcap: usize, - wbuf: Vec<u8>, - inner: Rc<RefCell<Box<TTransport>>>, +#[derive(Debug)] +pub struct TBufferedReadTransport<C> +where + C: Read, +{ + buf: Box<[u8]>, + pos: usize, + cap: usize, + chan: C, } -impl TBufferedTransport { +impl<C> TBufferedReadTransport<C> +where + C: Read, +{ /// Create a `TBufferedTransport` with default-sized internal read and - /// write buffers that wraps an `inner` `TTransport`. - pub fn new(inner: Rc<RefCell<Box<TTransport>>>) -> TBufferedTransport { - TBufferedTransport::with_capacity(DEFAULT_RBUFFER_CAPACITY, DEFAULT_WBUFFER_CAPACITY, inner) + /// write buffers that wraps the given `TIoChannel`. + pub fn new(channel: C) -> TBufferedReadTransport<C> { + TBufferedReadTransport::with_capacity(READ_CAPACITY, channel) } /// Create a `TBufferedTransport` with an internal read buffer of size - /// `read_buffer_capacity` and an internal write buffer of size - /// `write_buffer_capacity` that wraps an `inner` `TTransport`. - pub fn with_capacity(read_buffer_capacity: usize, - write_buffer_capacity: usize, - inner: Rc<RefCell<Box<TTransport>>>) - -> TBufferedTransport { - TBufferedTransport { - rbuf: vec![0; read_buffer_capacity].into_boxed_slice(), - rpos: 0, - rcap: 0, - wbuf: Vec::with_capacity(write_buffer_capacity), - inner: inner, + /// `read_capacity` and an internal write buffer of size + /// `write_capacity` that wraps the given `TIoChannel`. + pub fn with_capacity(read_capacity: usize, channel: C) -> TBufferedReadTransport<C> { + TBufferedReadTransport { + buf: vec![0; read_capacity].into_boxed_slice(), + pos: 0, + cap: 0, + chan: channel, } } fn get_bytes(&mut self) -> io::Result<&[u8]> { - if self.rcap - self.rpos == 0 { - self.rpos = 0; - self.rcap = self.inner.borrow_mut().read(&mut self.rbuf)?; + if self.cap - self.pos == 0 { + self.pos = 0; + self.cap = self.chan.read(&mut self.buf)?; } - Ok(&self.rbuf[self.rpos..self.rcap]) + Ok(&self.buf[self.pos..self.cap]) } fn consume(&mut self, consumed: usize) { // TODO: was a bug here += <-- test somehow - self.rpos = cmp::min(self.rcap, self.rpos + consumed); + self.pos = cmp::min(self.cap, self.pos + consumed); } } -impl Read for TBufferedTransport { +impl<C> Read for TBufferedReadTransport<C> +where + C: Read, +{ fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { let mut bytes_read = 0; @@ -137,65 +127,127 @@ impl Read for TBufferedTransport { } } -impl Write for TBufferedTransport { +/// Factory for creating instances of `TBufferedReadTransport`. +#[derive(Default)] +pub struct TBufferedReadTransportFactory; + +impl TBufferedReadTransportFactory { + pub fn new() -> TBufferedReadTransportFactory { + TBufferedReadTransportFactory {} + } +} + +impl TReadTransportFactory for TBufferedReadTransportFactory { + /// Create a `TBufferedReadTransport`. + fn create(&self, channel: Box<Read + Send>) -> Box<TReadTransport + Send> { + Box::new(TBufferedReadTransport::new(channel)) + } +} + +/// Transport that writes messages via an internal buffer. +/// +/// A `TBufferedWriteTransport` maintains a fixed-size internal write buffer. +/// All writes are made to this buffer and are sent to the wrapped channel only +/// when `TBufferedWriteTransport::flush()` is called. On a flush a fixed-length +/// header with a count of the buffered bytes is written, followed by the bytes +/// themselves. +/// +/// # Examples +/// +/// Create and use a `TBufferedWriteTransport`. +/// +/// ```no_run +/// use std::io::Write; +/// use thrift::transport::{TBufferedWriteTransport, TTcpChannel}; +/// +/// let mut c = TTcpChannel::new(); +/// c.open("localhost:9090").unwrap(); +/// +/// let mut t = TBufferedWriteTransport::new(c); +/// +/// t.write(&[0x00]).unwrap(); +/// t.flush().unwrap(); +/// ``` +#[derive(Debug)] +pub struct TBufferedWriteTransport<C> +where + C: Write, +{ + buf: Vec<u8>, + channel: C, +} + +impl<C> TBufferedWriteTransport<C> +where + C: Write, +{ + /// Create a `TBufferedTransport` with default-sized internal read and + /// write buffers that wraps the given `TIoChannel`. + pub fn new(channel: C) -> TBufferedWriteTransport<C> { + TBufferedWriteTransport::with_capacity(WRITE_CAPACITY, channel) + } + + /// Create a `TBufferedTransport` with an internal read buffer of size + /// `read_capacity` and an internal write buffer of size + /// `write_capacity` that wraps the given `TIoChannel`. + pub fn with_capacity(write_capacity: usize, channel: C) -> TBufferedWriteTransport<C> { + TBufferedWriteTransport { + buf: Vec::with_capacity(write_capacity), + channel: channel, + } + } +} + +impl<C> Write for TBufferedWriteTransport<C> +where + C: Write, +{ fn write(&mut self, buf: &[u8]) -> io::Result<usize> { - let avail_bytes = cmp::min(buf.len(), self.wbuf.capacity() - self.wbuf.len()); - self.wbuf.extend_from_slice(&buf[..avail_bytes]); - assert!(self.wbuf.len() <= self.wbuf.capacity(), - "copy overflowed buffer"); + let avail_bytes = cmp::min(buf.len(), self.buf.capacity() - self.buf.len()); + self.buf.extend_from_slice(&buf[..avail_bytes]); + assert!( + self.buf.len() <= self.buf.capacity(), + "copy overflowed buffer" + ); Ok(avail_bytes) } fn flush(&mut self) -> io::Result<()> { - self.inner.borrow_mut().write_all(&self.wbuf)?; - self.inner.borrow_mut().flush()?; - self.wbuf.clear(); + self.channel.write_all(&self.buf)?; + self.channel.flush()?; + self.buf.clear(); Ok(()) } } -/// Factory for creating instances of `TBufferedTransport` +/// Factory for creating instances of `TBufferedWriteTransport`. #[derive(Default)] -pub struct TBufferedTransportFactory; +pub struct TBufferedWriteTransportFactory; -impl TBufferedTransportFactory { - /// Create a `TBufferedTransportFactory`. - pub fn new() -> TBufferedTransportFactory { - TBufferedTransportFactory {} +impl TBufferedWriteTransportFactory { + pub fn new() -> TBufferedWriteTransportFactory { + TBufferedWriteTransportFactory {} } } -impl TTransportFactory for TBufferedTransportFactory { - fn create(&self, inner: Rc<RefCell<Box<TTransport>>>) -> Box<TTransport> { - Box::new(TBufferedTransport::new(inner)) as Box<TTransport> +impl TWriteTransportFactory for TBufferedWriteTransportFactory { + /// Create a `TBufferedWriteTransport`. + fn create(&self, channel: Box<Write + Send>) -> Box<TWriteTransport + Send> { + Box::new(TBufferedWriteTransport::new(channel)) } } #[cfg(test)] mod tests { - use std::cell::RefCell; use std::io::{Read, Write}; - use std::rc::Rc; use super::*; - use ::transport::{TPassThruTransport, TTransport}; - use ::transport::mem::TBufferTransport; - - macro_rules! new_transports { - ($wbc:expr, $rbc:expr) => ( - { - let mem = Rc::new(RefCell::new(Box::new(TBufferTransport::with_capacity($wbc, $rbc)))); - let thru: Box<TTransport> = Box::new(TPassThruTransport { inner: mem.clone() }); - let thru = Rc::new(RefCell::new(thru)); - (mem, thru) - } - ); - } + use transport::TBufferChannel; #[test] fn must_return_zero_if_read_buffer_is_empty() { - let (_, thru) = new_transports!(10, 0); - let mut t = TBufferedTransport::with_capacity(10, 0, thru); + let mem = TBufferChannel::with_capacity(10, 0); + let mut t = TBufferedReadTransport::with_capacity(10, mem); let mut b = vec![0; 10]; let read_result = t.read(&mut b); @@ -205,8 +257,8 @@ mod tests { #[test] fn must_return_zero_if_caller_reads_into_zero_capacity_buffer() { - let (_, thru) = new_transports!(10, 0); - let mut t = TBufferedTransport::with_capacity(10, 0, thru); + let mem = TBufferChannel::with_capacity(10, 0); + let mut t = TBufferedReadTransport::with_capacity(10, mem); let read_result = t.read(&mut []); @@ -215,10 +267,10 @@ mod tests { #[test] fn must_return_zero_if_nothing_more_can_be_read() { - let (mem, thru) = new_transports!(4, 0); - let mut t = TBufferedTransport::with_capacity(4, 0, thru); + let mem = TBufferChannel::with_capacity(4, 0); + let mut t = TBufferedReadTransport::with_capacity(4, mem); - mem.borrow_mut().set_readable_bytes(&[0, 1, 2, 3]); + t.chan.set_readable_bytes(&[0, 1, 2, 3]); // read buffer is exactly the same size as bytes available let mut buf = vec![0u8; 4]; @@ -239,10 +291,10 @@ mod tests { #[test] fn must_fill_user_buffer_with_only_as_many_bytes_as_available() { - let (mem, thru) = new_transports!(4, 0); - let mut t = TBufferedTransport::with_capacity(4, 0, thru); + let mem = TBufferChannel::with_capacity(4, 0); + let mut t = TBufferedReadTransport::with_capacity(4, mem); - mem.borrow_mut().set_readable_bytes(&[0, 1, 2, 3]); + t.chan.set_readable_bytes(&[0, 1, 2, 3]); // read buffer is much larger than the bytes available let mut buf = vec![0u8; 8]; @@ -268,15 +320,16 @@ mod tests { // we have a much smaller buffer than the // underlying transport has bytes available - let (mem, thru) = new_transports!(10, 0); - let mut t = TBufferedTransport::with_capacity(2, 0, thru); + let mem = TBufferChannel::with_capacity(10, 0); + let mut t = TBufferedReadTransport::with_capacity(2, mem); // fill the underlying transport's byte buffer let mut readable_bytes = [0u8; 10]; for i in 0..10 { readable_bytes[i] = i as u8; } - mem.borrow_mut().set_readable_bytes(&readable_bytes); + + t.chan.set_readable_bytes(&readable_bytes); // we ask to read into a buffer that's much larger // than the one the buffered transport has; as a result @@ -312,8 +365,8 @@ mod tests { #[test] fn must_return_zero_if_nothing_can_be_written() { - let (_, thru) = new_transports!(0, 0); - let mut t = TBufferedTransport::with_capacity(0, 0, thru); + let mem = TBufferChannel::with_capacity(0, 0); + let mut t = TBufferedWriteTransport::with_capacity(0, mem); let b = vec![0; 10]; let r = t.write(&b); @@ -323,19 +376,20 @@ mod tests { #[test] fn must_return_zero_if_caller_calls_write_with_empty_buffer() { - let (mem, thru) = new_transports!(0, 10); - let mut t = TBufferedTransport::with_capacity(0, 10, thru); + let mem = TBufferChannel::with_capacity(0, 10); + let mut t = TBufferedWriteTransport::with_capacity(10, mem); let r = t.write(&[]); + let expected: [u8; 0] = []; assert_eq!(r.unwrap(), 0); - assert_eq!(mem.borrow_mut().write_buffer_as_ref(), &[]); + assert_eq_transport_written_bytes!(t, expected); } #[test] fn must_return_zero_if_write_buffer_full() { - let (_, thru) = new_transports!(0, 0); - let mut t = TBufferedTransport::with_capacity(0, 4, thru); + let mem = TBufferChannel::with_capacity(0, 0); + let mut t = TBufferedWriteTransport::with_capacity(4, mem); let b = [0x00, 0x01, 0x02, 0x03]; @@ -350,26 +404,22 @@ mod tests { #[test] fn must_only_write_to_inner_transport_on_flush() { - let (mem, thru) = new_transports!(10, 10); - let mut t = TBufferedTransport::new(thru); + let mem = TBufferChannel::with_capacity(10, 10); + let mut t = TBufferedWriteTransport::new(mem); let b: [u8; 5] = [0, 1, 2, 3, 4]; assert_eq!(t.write(&b).unwrap(), 5); - assert_eq!(mem.borrow_mut().write_buffer_as_ref().len(), 0); + assert_eq_transport_num_written_bytes!(t, 0); assert!(t.flush().is_ok()); - { - let inner = mem.borrow_mut(); - let underlying_buffer = inner.write_buffer_as_ref(); - assert_eq!(b, underlying_buffer); - } + assert_eq_transport_written_bytes!(t, b); } #[test] fn must_write_successfully_after_flush() { - let (mem, thru) = new_transports!(0, 5); - let mut t = TBufferedTransport::with_capacity(0, 5, thru); + let mem = TBufferChannel::with_capacity(0, 5); + let mut t = TBufferedWriteTransport::with_capacity(5, mem); // write and flush let b: [u8; 5] = [0, 1, 2, 3, 4]; @@ -377,24 +427,16 @@ mod tests { assert!(t.flush().is_ok()); // check the flushed bytes - { - let inner = mem.borrow_mut(); - let underlying_buffer = inner.write_buffer_as_ref(); - assert_eq!(b, underlying_buffer); - } + assert_eq_transport_written_bytes!(t, b); // reset our underlying transport - mem.borrow_mut().empty_write_buffer(); + t.channel.empty_write_buffer(); // write and flush again assert_eq!(t.write(&b).unwrap(), 5); assert!(t.flush().is_ok()); // check the flushed bytes - { - let inner = mem.borrow_mut(); - let underlying_buffer = inner.write_buffer_as_ref(); - assert_eq!(b, underlying_buffer); - } + assert_eq_transport_written_bytes!(t, b); } } diff --git a/lib/rs/src/transport/framed.rs b/lib/rs/src/transport/framed.rs index 75c12f435..d78d2f7a1 100644 --- a/lib/rs/src/transport/framed.rs +++ b/lib/rs/src/transport/framed.rs @@ -16,165 +16,242 @@ // under the License. use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; -use std::cell::RefCell; use std::cmp; use std::io; use std::io::{ErrorKind, Read, Write}; -use std::rc::Rc; -use super::{TTransport, TTransportFactory}; +use super::{TReadTransport, TReadTransportFactory, TWriteTransport, TWriteTransportFactory}; /// Default capacity of the read buffer in bytes. -const WRITE_BUFFER_CAPACITY: usize = 4096; +const READ_CAPACITY: usize = 4096; -/// Default capacity of the write buffer in bytes.. -const DEFAULT_WBUFFER_CAPACITY: usize = 4096; +/// Default capacity of the write buffer in bytes. +const WRITE_CAPACITY: usize = 4096; -/// Transport that communicates with endpoints using framed messages. +/// Transport that reads framed messages. /// -/// A `TFramedTransport` maintains a fixed-size internal write buffer. All -/// writes are made to this buffer and are sent to the wrapped transport only -/// when `TTransport::flush()` is called. On a flush a fixed-length header with a -/// count of the buffered bytes is written, followed by the bytes themselves. -/// -/// A `TFramedTransport` also maintains a fixed-size internal read buffer. -/// On a call to `TTransport::read(...)` one full message - both fixed-length -/// header and bytes - is read from the wrapped transport and buffered. -/// Subsequent read calls are serviced from the internal buffer until it is -/// exhausted, at which point the next full message is read from the wrapped -/// transport. +/// A `TFramedReadTransport` maintains a fixed-size internal read buffer. +/// On a call to `TFramedReadTransport::read(...)` one full message - both +/// fixed-length header and bytes - is read from the wrapped channel and +/// buffered. Subsequent read calls are serviced from the internal buffer +/// until it is exhausted, at which point the next full message is read +/// from the wrapped channel. /// /// # Examples /// -/// Create and use a `TFramedTransport`. +/// Create and use a `TFramedReadTransport`. /// /// ```no_run -/// use std::cell::RefCell; -/// use std::rc::Rc; -/// use std::io::{Read, Write}; -/// use thrift::transport::{TFramedTransport, TTcpTransport, TTransport}; +/// use std::io::Read; +/// use thrift::transport::{TFramedReadTransport, TTcpChannel}; /// -/// let mut t = TTcpTransport::new(); -/// t.open("localhost:9090").unwrap(); +/// let mut c = TTcpChannel::new(); +/// c.open("localhost:9090").unwrap(); /// -/// let t = Rc::new(RefCell::new(Box::new(t) as Box<TTransport>)); -/// let mut t = TFramedTransport::new(t); +/// let mut t = TFramedReadTransport::new(c); /// -/// // read /// t.read(&mut vec![0u8; 1]).unwrap(); -/// -/// // write -/// t.write(&[0x00]).unwrap(); -/// t.flush().unwrap(); /// ``` -pub struct TFramedTransport { - rbuf: Box<[u8]>, - rpos: usize, - rcap: usize, - wbuf: Box<[u8]>, - wpos: usize, - inner: Rc<RefCell<Box<TTransport>>>, +#[derive(Debug)] +pub struct TFramedReadTransport<C> +where + C: Read, +{ + buf: Box<[u8]>, + pos: usize, + cap: usize, + chan: C, } -impl TFramedTransport { +impl<C> TFramedReadTransport<C> +where + C: Read, +{ /// Create a `TFramedTransport` with default-sized internal read and - /// write buffers that wraps an `inner` `TTransport`. - pub fn new(inner: Rc<RefCell<Box<TTransport>>>) -> TFramedTransport { - TFramedTransport::with_capacity(WRITE_BUFFER_CAPACITY, DEFAULT_WBUFFER_CAPACITY, inner) + /// write buffers that wraps the given `TIoChannel`. + pub fn new(channel: C) -> TFramedReadTransport<C> { + TFramedReadTransport::with_capacity(READ_CAPACITY, channel) } /// Create a `TFramedTransport` with an internal read buffer of size - /// `read_buffer_capacity` and an internal write buffer of size - /// `write_buffer_capacity` that wraps an `inner` `TTransport`. - pub fn with_capacity(read_buffer_capacity: usize, - write_buffer_capacity: usize, - inner: Rc<RefCell<Box<TTransport>>>) - -> TFramedTransport { - TFramedTransport { - rbuf: vec![0; read_buffer_capacity].into_boxed_slice(), - rpos: 0, - rcap: 0, - wbuf: vec![0; write_buffer_capacity].into_boxed_slice(), - wpos: 0, - inner: inner, + /// `read_capacity` and an internal write buffer of size + /// `write_capacity` that wraps the given `TIoChannel`. + pub fn with_capacity(read_capacity: usize, channel: C) -> TFramedReadTransport<C> { + TFramedReadTransport { + buf: vec![0; read_capacity].into_boxed_slice(), + pos: 0, + cap: 0, + chan: channel, } } } -impl Read for TFramedTransport { +impl<C> Read for TFramedReadTransport<C> +where + C: Read, +{ fn read(&mut self, b: &mut [u8]) -> io::Result<usize> { - if self.rcap - self.rpos == 0 { - let message_size = self.inner.borrow_mut().read_i32::<BigEndian>()? as usize; - if message_size > self.rbuf.len() { - return Err(io::Error::new(ErrorKind::Other, - format!("bytes to be read ({}) exceeds buffer \ + if self.cap - self.pos == 0 { + let message_size = self.chan.read_i32::<BigEndian>()? as usize; + if message_size > self.buf.len() { + return Err( + io::Error::new( + ErrorKind::Other, + format!( + "bytes to be read ({}) exceeds buffer \ capacity ({})", - message_size, - self.rbuf.len()))); + message_size, + self.buf.len() + ), + ), + ); } - self.inner.borrow_mut().read_exact(&mut self.rbuf[..message_size])?; - self.rpos = 0; - self.rcap = message_size as usize; + self.chan.read_exact(&mut self.buf[..message_size])?; + self.pos = 0; + self.cap = message_size as usize; } - let nread = cmp::min(b.len(), self.rcap - self.rpos); - b[..nread].clone_from_slice(&self.rbuf[self.rpos..self.rpos + nread]); - self.rpos += nread; + let nread = cmp::min(b.len(), self.cap - self.pos); + b[..nread].clone_from_slice(&self.buf[self.pos..self.pos + nread]); + self.pos += nread; Ok(nread) } } -impl Write for TFramedTransport { +/// Factory for creating instances of `TFramedReadTransport`. +#[derive(Default)] +pub struct TFramedReadTransportFactory; + +impl TFramedReadTransportFactory { + pub fn new() -> TFramedReadTransportFactory { + TFramedReadTransportFactory {} + } +} + +impl TReadTransportFactory for TFramedReadTransportFactory { + /// Create a `TFramedReadTransport`. + fn create(&self, channel: Box<Read + Send>) -> Box<TReadTransport + Send> { + Box::new(TFramedReadTransport::new(channel)) + } +} + +/// Transport that writes framed messages. +/// +/// A `TFramedWriteTransport` maintains a fixed-size internal write buffer. All +/// writes are made to this buffer and are sent to the wrapped channel only +/// when `TFramedWriteTransport::flush()` is called. On a flush a fixed-length +/// header with a count of the buffered bytes is written, followed by the bytes +/// themselves. +/// +/// # Examples +/// +/// Create and use a `TFramedWriteTransport`. +/// +/// ```no_run +/// use std::io::Write; +/// use thrift::transport::{TFramedWriteTransport, TTcpChannel}; +/// +/// let mut c = TTcpChannel::new(); +/// c.open("localhost:9090").unwrap(); +/// +/// let mut t = TFramedWriteTransport::new(c); +/// +/// t.write(&[0x00]).unwrap(); +/// t.flush().unwrap(); +/// ``` +#[derive(Debug)] +pub struct TFramedWriteTransport<C> +where + C: Write, +{ + buf: Box<[u8]>, + pos: usize, + channel: C, +} + +impl<C> TFramedWriteTransport<C> +where + C: Write, +{ + /// Create a `TFramedTransport` with default-sized internal read and + /// write buffers that wraps the given `TIoChannel`. + pub fn new(channel: C) -> TFramedWriteTransport<C> { + TFramedWriteTransport::with_capacity(WRITE_CAPACITY, channel) + } + + /// Create a `TFramedTransport` with an internal read buffer of size + /// `read_capacity` and an internal write buffer of size + /// `write_capacity` that wraps the given `TIoChannel`. + pub fn with_capacity(write_capacity: usize, channel: C) -> TFramedWriteTransport<C> { + TFramedWriteTransport { + buf: vec![0; write_capacity].into_boxed_slice(), + pos: 0, + channel: channel, + } + } +} + +impl<C> Write for TFramedWriteTransport<C> +where + C: Write, +{ fn write(&mut self, b: &[u8]) -> io::Result<usize> { - if b.len() > (self.wbuf.len() - self.wpos) { - return Err(io::Error::new(ErrorKind::Other, - format!("bytes to be written ({}) exceeds buffer \ + if b.len() > (self.buf.len() - self.pos) { + return Err( + io::Error::new( + ErrorKind::Other, + format!( + "bytes to be written ({}) exceeds buffer \ capacity ({})", - b.len(), - self.wbuf.len() - self.wpos))); + b.len(), + self.buf.len() - self.pos + ), + ), + ); } let nwrite = b.len(); // always less than available write buffer capacity - self.wbuf[self.wpos..(self.wpos + nwrite)].clone_from_slice(b); - self.wpos += nwrite; + self.buf[self.pos..(self.pos + nwrite)].clone_from_slice(b); + self.pos += nwrite; Ok(nwrite) } fn flush(&mut self) -> io::Result<()> { - let message_size = self.wpos; + let message_size = self.pos; if let 0 = message_size { return Ok(()); } else { - self.inner.borrow_mut().write_i32::<BigEndian>(message_size as i32)?; + self.channel + .write_i32::<BigEndian>(message_size as i32)?; } let mut byte_index = 0; - while byte_index < self.wpos { - let nwrite = self.inner.borrow_mut().write(&self.wbuf[byte_index..self.wpos])?; - byte_index = cmp::min(byte_index + nwrite, self.wpos); + while byte_index < self.pos { + let nwrite = self.channel.write(&self.buf[byte_index..self.pos])?; + byte_index = cmp::min(byte_index + nwrite, self.pos); } - self.wpos = 0; - self.inner.borrow_mut().flush() + self.pos = 0; + self.channel.flush() } } -/// Factory for creating instances of `TFramedTransport`. +/// Factory for creating instances of `TFramedWriteTransport`. #[derive(Default)] -pub struct TFramedTransportFactory; +pub struct TFramedWriteTransportFactory; -impl TFramedTransportFactory { - // Create a `TFramedTransportFactory`. - pub fn new() -> TFramedTransportFactory { - TFramedTransportFactory {} +impl TFramedWriteTransportFactory { + pub fn new() -> TFramedWriteTransportFactory { + TFramedWriteTransportFactory {} } } -impl TTransportFactory for TFramedTransportFactory { - fn create(&self, inner: Rc<RefCell<Box<TTransport>>>) -> Box<TTransport> { - Box::new(TFramedTransport::new(inner)) as Box<TTransport> +impl TWriteTransportFactory for TFramedWriteTransportFactory { + /// Create a `TFramedWriteTransport`. + fn create(&self, channel: Box<Write + Send>) -> Box<TWriteTransport + Send> { + Box::new(TFramedWriteTransport::new(channel)) } } @@ -183,5 +260,5 @@ mod tests { // use std::io::{Read, Write}; // // use super::*; - // use ::transport::mem::TBufferTransport; + // use ::transport::mem::TBufferChannel; } diff --git a/lib/rs/src/transport/mem.rs b/lib/rs/src/transport/mem.rs index 97ec50345..86ac6bb25 100644 --- a/lib/rs/src/transport/mem.rs +++ b/lib/rs/src/transport/mem.rs @@ -17,9 +17,11 @@ use std::cmp; use std::io; +use std::sync::{Arc, Mutex}; -/// Simple transport that contains both a fixed-length internal read buffer and -/// a fixed-length internal write buffer. +use super::{ReadHalf, TIoChannel, WriteHalf}; + +/// In-memory read and write channel with fixed-size read and write buffers. /// /// On a `write` bytes are written to the internal write buffer. Writes are no /// longer accepted once this buffer is full. Callers must `empty_write_buffer()` @@ -29,37 +31,61 @@ use std::io; /// `set_readable_bytes(...)`. Callers can then read until the buffer is /// depleted. No further reads are accepted until the internal read buffer is /// replenished again. -pub struct TBufferTransport { - rbuf: Box<[u8]>, - rpos: usize, - ridx: usize, - rcap: usize, - wbuf: Box<[u8]>, - wpos: usize, - wcap: usize, +#[derive(Debug)] +pub struct TBufferChannel { + read: Arc<Mutex<ReadData>>, + write: Arc<Mutex<WriteData>>, +} + +#[derive(Debug)] +struct ReadData { + buf: Box<[u8]>, + pos: usize, + idx: usize, + cap: usize, } -impl TBufferTransport { - /// Constructs a new, empty `TBufferTransport` with the given +#[derive(Debug)] +struct WriteData { + buf: Box<[u8]>, + pos: usize, + cap: usize, +} + +impl TBufferChannel { + /// Constructs a new, empty `TBufferChannel` with the given /// read buffer capacity and write buffer capacity. - pub fn with_capacity(read_buffer_capacity: usize, - write_buffer_capacity: usize) - -> TBufferTransport { - TBufferTransport { - rbuf: vec![0; read_buffer_capacity].into_boxed_slice(), - ridx: 0, - rpos: 0, - rcap: read_buffer_capacity, - wbuf: vec![0; write_buffer_capacity].into_boxed_slice(), - wpos: 0, - wcap: write_buffer_capacity, + pub fn with_capacity(read_capacity: usize, write_capacity: usize) -> TBufferChannel { + TBufferChannel { + read: Arc::new( + Mutex::new( + ReadData { + buf: vec![0; read_capacity].into_boxed_slice(), + idx: 0, + pos: 0, + cap: read_capacity, + }, + ), + ), + write: Arc::new( + Mutex::new( + WriteData { + buf: vec![0; write_capacity].into_boxed_slice(), + pos: 0, + cap: write_capacity, + }, + ), + ), } } - /// Return a slice containing the bytes held by the internal read buffer. - /// Returns an empty slice if no readable bytes are present. - pub fn read_buffer(&self) -> &[u8] { - &self.rbuf[..self.ridx] + /// Return a copy of the bytes held by the internal read buffer. + /// Returns an empty vector if no readable bytes are present. + pub fn read_bytes(&self) -> Vec<u8> { + let rdata = self.read.as_ref().lock().unwrap(); + let mut buf = vec![0u8; rdata.idx]; + buf.copy_from_slice(&rdata.buf[..rdata.idx]); + buf } // FIXME: do I really need this API call? @@ -68,8 +94,9 @@ impl TBufferTransport { /// /// Subsequent calls to `read` will return nothing. pub fn empty_read_buffer(&mut self) { - self.rpos = 0; - self.ridx = 0; + let mut rdata = self.read.as_ref().lock().unwrap(); + rdata.pos = 0; + rdata.idx = 0; } /// Copy bytes from the source buffer `buf` into the internal read buffer, @@ -77,37 +104,36 @@ impl TBufferTransport { /// which is `min(buf.len(), internal_read_buf.len())`. pub fn set_readable_bytes(&mut self, buf: &[u8]) -> usize { self.empty_read_buffer(); - let max_bytes = cmp::min(self.rcap, buf.len()); - self.rbuf[..max_bytes].clone_from_slice(&buf[..max_bytes]); - self.ridx = max_bytes; + let mut rdata = self.read.as_ref().lock().unwrap(); + let max_bytes = cmp::min(rdata.cap, buf.len()); + rdata.buf[..max_bytes].clone_from_slice(&buf[..max_bytes]); + rdata.idx = max_bytes; max_bytes } - /// Return a slice containing the bytes held by the internal write buffer. - /// Returns an empty slice if no bytes were written. - pub fn write_buffer_as_ref(&self) -> &[u8] { - &self.wbuf[..self.wpos] - } - - /// Return a vector with a copy of the bytes held by the internal write buffer. + /// Return a copy of the bytes held by the internal write buffer. /// Returns an empty vector if no bytes were written. - pub fn write_buffer_to_vec(&self) -> Vec<u8> { - let mut buf = vec![0u8; self.wpos]; - buf.copy_from_slice(&self.wbuf[..self.wpos]); + pub fn write_bytes(&self) -> Vec<u8> { + let wdata = self.write.as_ref().lock().unwrap(); + let mut buf = vec![0u8; wdata.pos]; + buf.copy_from_slice(&wdata.buf[..wdata.pos]); buf } /// Resets the internal write buffer, making it seem like no bytes were - /// written. Calling `write_buffer` after this returns an empty slice. + /// written. Calling `write_buffer` after this returns an empty vector. pub fn empty_write_buffer(&mut self) { - self.wpos = 0; + let mut wdata = self.write.as_ref().lock().unwrap(); + wdata.pos = 0; } /// Overwrites the contents of the read buffer with the contents of the /// write buffer. The write buffer is emptied after this operation. pub fn copy_write_buffer_to_read_buffer(&mut self) { + // FIXME: redo this entire method let buf = { - let b = self.write_buffer_as_ref(); + let wdata = self.write.as_ref().lock().unwrap(); + let b = &wdata.buf[..wdata.pos]; let mut b_ret = vec![0; b.len()]; b_ret.copy_from_slice(b); b_ret @@ -120,20 +146,45 @@ impl TBufferTransport { } } -impl io::Read for TBufferTransport { +impl TIoChannel for TBufferChannel { + fn split(self) -> ::Result<(ReadHalf<Self>, WriteHalf<Self>)> + where + Self: Sized, + { + Ok( + (ReadHalf { + handle: TBufferChannel { + read: self.read.clone(), + write: self.write.clone(), + }, + }, + WriteHalf { + handle: TBufferChannel { + read: self.read.clone(), + write: self.write.clone(), + }, + }), + ) + } +} + +impl io::Read for TBufferChannel { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { - let nread = cmp::min(buf.len(), self.ridx - self.rpos); - buf[..nread].clone_from_slice(&self.rbuf[self.rpos..self.rpos + nread]); - self.rpos += nread; + let mut rdata = self.read.as_ref().lock().unwrap(); + let nread = cmp::min(buf.len(), rdata.idx - rdata.pos); + buf[..nread].clone_from_slice(&rdata.buf[rdata.pos..rdata.pos + nread]); + rdata.pos += nread; Ok(nread) } } -impl io::Write for TBufferTransport { +impl io::Write for TBufferChannel { fn write(&mut self, buf: &[u8]) -> io::Result<usize> { - let nwrite = cmp::min(buf.len(), self.wcap - self.wpos); - self.wbuf[self.wpos..self.wpos + nwrite].clone_from_slice(&buf[..nwrite]); - self.wpos += nwrite; + let mut wdata = self.write.as_ref().lock().unwrap(); + let nwrite = cmp::min(buf.len(), wdata.cap - wdata.pos); + let (start, end) = (wdata.pos, wdata.pos + nwrite); + wdata.buf[start..end].clone_from_slice(&buf[..nwrite]); + wdata.pos += nwrite; Ok(nwrite) } @@ -146,68 +197,68 @@ impl io::Write for TBufferTransport { mod tests { use std::io::{Read, Write}; - use super::TBufferTransport; + use super::TBufferChannel; #[test] fn must_empty_write_buffer() { - let mut t = TBufferTransport::with_capacity(0, 1); + let mut t = TBufferChannel::with_capacity(0, 1); let bytes_to_write: [u8; 1] = [0x01]; let result = t.write(&bytes_to_write); assert_eq!(result.unwrap(), 1); - assert_eq!(&t.write_buffer_as_ref(), &bytes_to_write); + assert_eq!(&t.write_bytes(), &bytes_to_write); t.empty_write_buffer(); - assert_eq!(t.write_buffer_as_ref().len(), 0); + assert_eq!(t.write_bytes().len(), 0); } #[test] fn must_accept_writes_after_buffer_emptied() { - let mut t = TBufferTransport::with_capacity(0, 2); + let mut t = TBufferChannel::with_capacity(0, 2); let bytes_to_write: [u8; 2] = [0x01, 0x02]; // first write (all bytes written) let result = t.write(&bytes_to_write); assert_eq!(result.unwrap(), 2); - assert_eq!(&t.write_buffer_as_ref(), &bytes_to_write); + assert_eq!(&t.write_bytes(), &bytes_to_write); // try write again (nothing should be written) let result = t.write(&bytes_to_write); assert_eq!(result.unwrap(), 0); - assert_eq!(&t.write_buffer_as_ref(), &bytes_to_write); // still the same as before + assert_eq!(&t.write_bytes(), &bytes_to_write); // still the same as before // now reset the buffer t.empty_write_buffer(); - assert_eq!(t.write_buffer_as_ref().len(), 0); + assert_eq!(t.write_bytes().len(), 0); // now try write again - the write should succeed let result = t.write(&bytes_to_write); assert_eq!(result.unwrap(), 2); - assert_eq!(&t.write_buffer_as_ref(), &bytes_to_write); + assert_eq!(&t.write_bytes(), &bytes_to_write); } #[test] fn must_accept_multiple_writes_until_buffer_is_full() { - let mut t = TBufferTransport::with_capacity(0, 10); + let mut t = TBufferChannel::with_capacity(0, 10); // first write (all bytes written) let bytes_to_write_0: [u8; 2] = [0x01, 0x41]; let write_0_result = t.write(&bytes_to_write_0); assert_eq!(write_0_result.unwrap(), 2); - assert_eq!(t.write_buffer_as_ref(), &bytes_to_write_0); + assert_eq!(t.write_bytes(), &bytes_to_write_0); // second write (all bytes written, starting at index 2) let bytes_to_write_1: [u8; 7] = [0x24, 0x41, 0x32, 0x33, 0x11, 0x98, 0xAF]; let write_1_result = t.write(&bytes_to_write_1); assert_eq!(write_1_result.unwrap(), 7); - assert_eq!(&t.write_buffer_as_ref()[2..], &bytes_to_write_1); + assert_eq!(&t.write_bytes()[2..], &bytes_to_write_1); // third write (only 1 byte written - that's all we have space for) let bytes_to_write_2: [u8; 3] = [0xBF, 0xDA, 0x98]; let write_2_result = t.write(&bytes_to_write_2); assert_eq!(write_2_result.unwrap(), 1); - assert_eq!(&t.write_buffer_as_ref()[9..], &bytes_to_write_2[0..1]); // how does this syntax work?! + assert_eq!(&t.write_bytes()[9..], &bytes_to_write_2[0..1]); // how does this syntax work?! // fourth write (no writes are accepted) let bytes_to_write_3: [u8; 3] = [0xBF, 0xAA, 0xFD]; @@ -219,50 +270,50 @@ mod tests { expected.extend_from_slice(&bytes_to_write_0); expected.extend_from_slice(&bytes_to_write_1); expected.extend_from_slice(&bytes_to_write_2[0..1]); - assert_eq!(t.write_buffer_as_ref(), &expected[..]); + assert_eq!(t.write_bytes(), &expected[..]); } #[test] fn must_empty_read_buffer() { - let mut t = TBufferTransport::with_capacity(1, 0); + let mut t = TBufferChannel::with_capacity(1, 0); let bytes_to_read: [u8; 1] = [0x01]; let result = t.set_readable_bytes(&bytes_to_read); assert_eq!(result, 1); - assert_eq!(&t.read_buffer(), &bytes_to_read); + assert_eq!(t.read_bytes(), &bytes_to_read); t.empty_read_buffer(); - assert_eq!(t.read_buffer().len(), 0); + assert_eq!(t.read_bytes().len(), 0); } #[test] fn must_allow_readable_bytes_to_be_set_after_read_buffer_emptied() { - let mut t = TBufferTransport::with_capacity(1, 0); + let mut t = TBufferChannel::with_capacity(1, 0); let bytes_to_read_0: [u8; 1] = [0x01]; let result = t.set_readable_bytes(&bytes_to_read_0); assert_eq!(result, 1); - assert_eq!(&t.read_buffer(), &bytes_to_read_0); + assert_eq!(t.read_bytes(), &bytes_to_read_0); t.empty_read_buffer(); - assert_eq!(t.read_buffer().len(), 0); + assert_eq!(t.read_bytes().len(), 0); let bytes_to_read_1: [u8; 1] = [0x02]; let result = t.set_readable_bytes(&bytes_to_read_1); assert_eq!(result, 1); - assert_eq!(&t.read_buffer(), &bytes_to_read_1); + assert_eq!(t.read_bytes(), &bytes_to_read_1); } #[test] fn must_accept_multiple_reads_until_all_bytes_read() { - let mut t = TBufferTransport::with_capacity(10, 0); + let mut t = TBufferChannel::with_capacity(10, 0); let readable_bytes: [u8; 10] = [0xFF, 0xEE, 0xDD, 0xCC, 0xBB, 0x00, 0x1A, 0x2B, 0x3C, 0x4D]; // check that we're able to set the bytes to be read let result = t.set_readable_bytes(&readable_bytes); assert_eq!(result, 10); - assert_eq!(&t.read_buffer(), &readable_bytes); + assert_eq!(t.read_bytes(), &readable_bytes); // first read let mut read_buf_0 = vec![0; 5]; @@ -300,21 +351,21 @@ mod tests { #[test] fn must_allow_reads_to_succeed_after_read_buffer_replenished() { - let mut t = TBufferTransport::with_capacity(3, 0); + let mut t = TBufferChannel::with_capacity(3, 0); let readable_bytes_0: [u8; 3] = [0x02, 0xAB, 0x33]; // check that we're able to set the bytes to be read let result = t.set_readable_bytes(&readable_bytes_0); assert_eq!(result, 3); - assert_eq!(&t.read_buffer(), &readable_bytes_0); + assert_eq!(t.read_bytes(), &readable_bytes_0); let mut read_buf = vec![0; 4]; // drain the read buffer let read_result = t.read(&mut read_buf); assert_eq!(read_result.unwrap(), 3); - assert_eq!(t.read_buffer(), &read_buf[0..3]); + assert_eq!(t.read_bytes(), &read_buf[0..3]); // check that a subsequent read fails let read_result = t.read(&mut read_buf); @@ -332,11 +383,11 @@ mod tests { // check that we're able to set the bytes to be read let result = t.set_readable_bytes(&readable_bytes_1); assert_eq!(result, 2); - assert_eq!(&t.read_buffer(), &readable_bytes_1); + assert_eq!(t.read_bytes(), &readable_bytes_1); // read again let read_result = t.read(&mut read_buf); assert_eq!(read_result.unwrap(), 2); - assert_eq!(t.read_buffer(), &read_buf[0..2]); + assert_eq!(t.read_bytes(), &read_buf[0..2]); } } diff --git a/lib/rs/src/transport/mod.rs b/lib/rs/src/transport/mod.rs index 1c39f5087..939278643 100644 --- a/lib/rs/src/transport/mod.rs +++ b/lib/rs/src/transport/mod.rs @@ -15,37 +15,266 @@ // specific language governing permissions and limitations // under the License. -//! Types required to send and receive bytes over an I/O channel. +//! Types used to send and receive bytes over an I/O channel. //! -//! The core type is the `TTransport` trait, through which a `TProtocol` can -//! send and receive primitives over the wire. While `TProtocol` instances deal -//! with primitive types, `TTransport` instances understand only bytes. +//! The core types are the `TReadTransport`, `TWriteTransport` and the +//! `TIoChannel` traits, through which `TInputProtocol` or +//! `TOutputProtocol` can receive and send primitives over the wire. While +//! `TInputProtocol` and `TOutputProtocol` instances deal with language primitives +//! the types in this module understand only bytes. -use std::cell::RefCell; use std::io; -use std::rc::Rc; +use std::io::{Read, Write}; +use std::ops::{Deref, DerefMut}; + +#[cfg(test)] +macro_rules! assert_eq_transport_num_written_bytes { + ($transport:ident, $num_written_bytes:expr) => { + { + assert_eq!($transport.channel.write_bytes().len(), $num_written_bytes); + } + }; +} + + +#[cfg(test)] +macro_rules! assert_eq_transport_written_bytes { + ($transport:ident, $expected_bytes:ident) => { + { + assert_eq!($transport.channel.write_bytes(), &$expected_bytes); + } + }; +} mod buffered; mod framed; -mod passthru; mod socket; +mod mem; + +pub use self::buffered::{TBufferedReadTransport, TBufferedReadTransportFactory, + TBufferedWriteTransport, TBufferedWriteTransportFactory}; +pub use self::framed::{TFramedReadTransport, TFramedReadTransportFactory, TFramedWriteTransport, + TFramedWriteTransportFactory}; +pub use self::mem::TBufferChannel; +pub use self::socket::TTcpChannel; + +/// Identifies a transport used by a `TInputProtocol` to receive bytes. +pub trait TReadTransport: Read {} + +/// Helper type used by a server to create `TReadTransport` instances for +/// accepted client connections. +pub trait TReadTransportFactory { + /// Create a `TTransport` that wraps a channel over which bytes are to be read. + fn create(&self, channel: Box<Read + Send>) -> Box<TReadTransport + Send>; +} + +/// Identifies a transport used by `TOutputProtocol` to send bytes. +pub trait TWriteTransport: Write {} + +/// Helper type used by a server to create `TWriteTransport` instances for +/// accepted client connections. +pub trait TWriteTransportFactory { + /// Create a `TTransport` that wraps a channel over which bytes are to be sent. + fn create(&self, channel: Box<Write + Send>) -> Box<TWriteTransport + Send>; +} + +impl<T> TReadTransport for T +where + T: Read, +{ +} + +impl<T> TWriteTransport for T +where + T: Write, +{ +} + +// FIXME: implement the Debug trait for boxed transports + +impl<T> TReadTransportFactory for Box<T> +where + T: TReadTransportFactory + ?Sized, +{ + fn create(&self, channel: Box<Read + Send>) -> Box<TReadTransport + Send> { + (**self).create(channel) + } +} + +impl<T> TWriteTransportFactory for Box<T> +where + T: TWriteTransportFactory + ?Sized, +{ + fn create(&self, channel: Box<Write + Send>) -> Box<TWriteTransport + Send> { + (**self).create(channel) + } +} + +/// Identifies a splittable bidirectional I/O channel used to send and receive bytes. +pub trait TIoChannel: Read + Write { + /// Split the channel into a readable half and a writable half, where the + /// readable half implements `io::Read` and the writable half implements + /// `io::Write`. Returns `None` if the channel was not initialized, or if it + /// cannot be split safely. + /// + /// Returned halves may share the underlying OS channel or buffer resources. + /// Implementations **should ensure** that these two halves can be safely + /// used independently by concurrent threads. + fn split(self) -> ::Result<(::transport::ReadHalf<Self>, ::transport::WriteHalf<Self>)> + where + Self: Sized; +} + +/// The readable half of an object returned from `TIoChannel::split`. +#[derive(Debug)] +pub struct ReadHalf<C> +where + C: Read, +{ + handle: C, +} + +/// The writable half of an object returned from `TIoChannel::split`. +#[derive(Debug)] +pub struct WriteHalf<C> +where + C: Write, +{ + handle: C, +} + +impl<C> Read for ReadHalf<C> +where + C: Read, +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + self.handle.read(buf) + } +} + +impl<C> Write for WriteHalf<C> +where + C: Write, +{ + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + self.handle.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.handle.flush() + } +} + +impl<C> Deref for ReadHalf<C> +where + C: Read, +{ + type Target = C; + + fn deref(&self) -> &Self::Target { + &self.handle + } +} + +impl<C> DerefMut for ReadHalf<C> +where + C: Read, +{ + fn deref_mut(&mut self) -> &mut C { + &mut self.handle + } +} + +impl<C> Deref for WriteHalf<C> +where + C: Write, +{ + type Target = C; + + fn deref(&self) -> &Self::Target { + &self.handle + } +} + +impl<C> DerefMut for WriteHalf<C> +where + C: Write, +{ + fn deref_mut(&mut self) -> &mut C { + &mut self.handle + } +} + +#[cfg(test)] +mod tests { + + use std::io::Cursor; + + use super::*; + + #[test] + fn must_create_usable_read_channel_from_concrete_read_type() { + let r = Cursor::new([0, 1, 2]); + let _ = TBufferedReadTransport::new(r); + } + + #[test] + fn must_create_usable_read_channel_from_boxed_read() { + let r: Box<Read> = Box::new(Cursor::new([0, 1, 2])); + let _ = TBufferedReadTransport::new(r); + } + + #[test] + fn must_create_usable_write_channel_from_concrete_write_type() { + let w = vec![0u8; 10]; + let _ = TBufferedWriteTransport::new(w); + } + + #[test] + fn must_create_usable_write_channel_from_boxed_write() { + let w: Box<Write> = Box::new(vec![0u8; 10]); + let _ = TBufferedWriteTransport::new(w); + } + + #[test] + fn must_create_usable_read_transport_from_concrete_read_transport() { + let r = Cursor::new([0, 1, 2]); + let mut t = TBufferedReadTransport::new(r); + takes_read_transport(&mut t) + } + + #[test] + fn must_create_usable_read_transport_from_boxed_read() { + let r = Cursor::new([0, 1, 2]); + let mut t: Box<TReadTransport> = Box::new(TBufferedReadTransport::new(r)); + takes_read_transport(&mut t) + } -pub mod mem; + #[test] + fn must_create_usable_write_transport_from_concrete_write_transport() { + let w = vec![0u8; 10]; + let mut t = TBufferedWriteTransport::new(w); + takes_write_transport(&mut t) + } -pub use self::mem::TBufferTransport; -pub use self::buffered::{TBufferedTransport, TBufferedTransportFactory}; -pub use self::framed::{TFramedTransport, TFramedTransportFactory}; -pub use self::passthru::TPassThruTransport; -pub use self::socket::TTcpTransport; + #[test] + fn must_create_usable_write_transport_from_boxed_write() { + let w = vec![0u8; 10]; + let mut t: Box<TWriteTransport> = Box::new(TBufferedWriteTransport::new(w)); + takes_write_transport(&mut t) + } -/// Identifies an I/O channel that can be used to send and receive bytes. -pub trait TTransport: io::Read + io::Write {} -impl<I: io::Read + io::Write> TTransport for I {} + fn takes_read_transport<R>(t: &mut R) + where + R: TReadTransport, + { + t.bytes(); + } -/// Helper type used by servers to create `TTransport` instances for accepted -/// client connections. -pub trait TTransportFactory { - /// Create a `TTransport` that wraps an `inner` transport, thus creating - /// a transport stack. - fn create(&self, inner: Rc<RefCell<Box<TTransport>>>) -> Box<TTransport>; + fn takes_write_transport<W>(t: &mut W) + where + W: TWriteTransport, + { + t.flush().unwrap(); + } } diff --git a/lib/rs/src/transport/passthru.rs b/lib/rs/src/transport/passthru.rs deleted file mode 100644 index 60dc3a63f..000000000 --- a/lib/rs/src/transport/passthru.rs +++ /dev/null @@ -1,73 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::cell::RefCell; -use std::rc::Rc; -use std::io; -use std::io::{Read, Write}; - -use super::TTransport; - -/// Proxy that wraps an inner `TTransport` and delegates all calls to it. -/// -/// Unlike other `TTransport` wrappers, `TPassThruTransport` is generic with -/// regards to the wrapped transport. This allows callers to use methods -/// specific to the type being wrapped instead of being constrained to methods -/// on the `TTransport` trait. -/// -/// # Examples -/// -/// Create and use a `TPassThruTransport`. -/// -/// ```no_run -/// use std::cell::RefCell; -/// use std::rc::Rc; -/// use thrift::transport::{TPassThruTransport, TTcpTransport}; -/// -/// let t = TTcpTransport::new(); -/// let t = TPassThruTransport::new(Rc::new(RefCell::new(Box::new(t)))); -/// -/// // since the type parameter is maintained, we are able -/// // to use functions specific to `TTcpTransport` -/// t.inner.borrow_mut().open("localhost:9090").unwrap(); -/// ``` -pub struct TPassThruTransport<I: TTransport> { - pub inner: Rc<RefCell<Box<I>>>, -} - -impl<I: TTransport> TPassThruTransport<I> { - /// Create a `TPassThruTransport` that wraps an `inner` TTransport. - pub fn new(inner: Rc<RefCell<Box<I>>>) -> TPassThruTransport<I> { - TPassThruTransport { inner: inner } - } -} - -impl<I: TTransport> Read for TPassThruTransport<I> { - fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { - self.inner.borrow_mut().read(buf) - } -} - -impl<I: TTransport> Write for TPassThruTransport<I> { - fn write(&mut self, buf: &[u8]) -> io::Result<usize> { - self.inner.borrow_mut().write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - self.inner.borrow_mut().flush() - } -} diff --git a/lib/rs/src/transport/socket.rs b/lib/rs/src/transport/socket.rs index 9f2b8ba31..16b59ef21 100644 --- a/lib/rs/src/transport/socket.rs +++ b/lib/rs/src/transport/socket.rs @@ -21,69 +21,74 @@ use std::io::{ErrorKind, Read, Write}; use std::net::{Shutdown, TcpStream}; use std::ops::Drop; -use ::{TransportError, TransportErrorKind}; +use {TransportErrorKind, new_transport_error}; +use super::{ReadHalf, TIoChannel, WriteHalf}; -/// Communicate with a Thrift service over a TCP socket. +/// Bidirectional TCP/IP channel. /// /// # Examples /// -/// Create a `TTcpTransport`. +/// Create a `TTcpChannel`. /// /// ```no_run /// use std::io::{Read, Write}; -/// use thrift::transport::TTcpTransport; +/// use thrift::transport::TTcpChannel; /// -/// let mut t = TTcpTransport::new(); -/// t.open("localhost:9090").unwrap(); +/// let mut c = TTcpChannel::new(); +/// c.open("localhost:9090").unwrap(); /// /// let mut buf = vec![0u8; 4]; -/// t.read(&mut buf).unwrap(); -/// t.write(&vec![0, 1, 2]).unwrap(); +/// c.read(&mut buf).unwrap(); +/// c.write(&vec![0, 1, 2]).unwrap(); /// ``` /// -/// Create a `TTcpTransport` by wrapping an existing `TcpStream`. +/// Create a `TTcpChannel` by wrapping an existing `TcpStream`. /// /// ```no_run /// use std::io::{Read, Write}; /// use std::net::TcpStream; -/// use thrift::transport::TTcpTransport; +/// use thrift::transport::TTcpChannel; /// /// let stream = TcpStream::connect("127.0.0.1:9189").unwrap(); -/// let mut t = TTcpTransport::with_stream(stream); /// -/// // no need to call t.open() since we've already connected above +/// // no need to call c.open() since we've already connected above +/// let mut c = TTcpChannel::with_stream(stream); /// /// let mut buf = vec![0u8; 4]; -/// t.read(&mut buf).unwrap(); -/// t.write(&vec![0, 1, 2]).unwrap(); +/// c.read(&mut buf).unwrap(); +/// c.write(&vec![0, 1, 2]).unwrap(); /// ``` -#[derive(Default)] -pub struct TTcpTransport { +#[derive(Debug, Default)] +pub struct TTcpChannel { stream: Option<TcpStream>, } -impl TTcpTransport { - /// Create an uninitialized `TTcpTransport`. +impl TTcpChannel { + /// Create an uninitialized `TTcpChannel`. /// - /// The returned instance must be opened using `TTcpTransport::open(...)` + /// The returned instance must be opened using `TTcpChannel::open(...)` /// before it can be used. - pub fn new() -> TTcpTransport { - TTcpTransport { stream: None } + pub fn new() -> TTcpChannel { + TTcpChannel { stream: None } } - /// Create a `TTcpTransport` that wraps an existing `TcpStream`. + /// Create a `TTcpChannel` that wraps an existing `TcpStream`. /// /// The passed-in stream is assumed to have been opened before being wrapped - /// by the created `TTcpTransport` instance. - pub fn with_stream(stream: TcpStream) -> TTcpTransport { - TTcpTransport { stream: Some(stream) } + /// by the created `TTcpChannel` instance. + pub fn with_stream(stream: TcpStream) -> TTcpChannel { + TTcpChannel { stream: Some(stream) } } /// Connect to `remote_address`, which should have the form `host:port`. pub fn open(&mut self, remote_address: &str) -> ::Result<()> { if self.stream.is_some() { - Err(::Error::Transport(TransportError::new(TransportErrorKind::AlreadyOpen, - "transport previously opened"))) + Err( + new_transport_error( + TransportErrorKind::AlreadyOpen, + "tcp connection previously opened", + ), + ) } else { match TcpStream::connect(&remote_address) { Ok(s) => { @@ -95,33 +100,62 @@ impl TTcpTransport { } } - /// Shutdown this transport. + /// Shut down this channel. /// /// Both send and receive halves are closed, and this instance can no /// longer be used to communicate with another endpoint. pub fn close(&mut self) -> ::Result<()> { - self.if_set(|s| s.shutdown(Shutdown::Both)).map_err(From::from) + self.if_set(|s| s.shutdown(Shutdown::Both)) + .map_err(From::from) } fn if_set<F, T>(&mut self, mut stream_operation: F) -> io::Result<T> - where F: FnMut(&mut TcpStream) -> io::Result<T> + where + F: FnMut(&mut TcpStream) -> io::Result<T>, { if let Some(ref mut s) = self.stream { stream_operation(s) } else { - Err(io::Error::new(ErrorKind::NotConnected, "tcp endpoint not connected")) + Err(io::Error::new(ErrorKind::NotConnected, "tcp endpoint not connected"),) } } } -impl Read for TTcpTransport { +impl TIoChannel for TTcpChannel { + fn split(self) -> ::Result<(ReadHalf<Self>, WriteHalf<Self>)> + where + Self: Sized, + { + let mut s = self; + + s.stream + .as_mut() + .and_then(|s| s.try_clone().ok()) + .map( + |cloned| { + (ReadHalf { handle: TTcpChannel { stream: s.stream.take() } }, + WriteHalf { handle: TTcpChannel { stream: Some(cloned) } }) + }, + ) + .ok_or_else( + || { + new_transport_error( + TransportErrorKind::Unknown, + "cannot clone underlying tcp stream", + ) + }, + ) + } +} + +impl Read for TTcpChannel { fn read(&mut self, b: &mut [u8]) -> io::Result<usize> { self.if_set(|s| s.read(b)) } } -impl Write for TTcpTransport { +impl Write for TTcpChannel { fn write(&mut self, b: &[u8]) -> io::Result<usize> { self.if_set(|s| s.write(b)) } @@ -131,11 +165,11 @@ impl Write for TTcpTransport { } } -// Do I have to implement the Drop trait? TcpStream closes the socket on drop. -impl Drop for TTcpTransport { +// FIXME: Do I have to implement the Drop trait? TcpStream closes the socket on drop. +impl Drop for TTcpChannel { fn drop(&mut self) { if let Err(e) = self.close() { - warn!("error while closing socket transport: {:?}", e) + warn!("error while closing socket: {:?}", e) } } } diff --git a/lib/rs/test/src/bin/kitchen_sink_client.rs b/lib/rs/test/src/bin/kitchen_sink_client.rs index 27171beff..9738298cb 100644 --- a/lib/rs/test/src/bin/kitchen_sink_client.rs +++ b/lib/rs/test/src/bin/kitchen_sink_client.rs @@ -21,13 +21,11 @@ extern crate clap; extern crate kitchen_sink; extern crate thrift; -use std::cell::RefCell; -use std::rc::Rc; - use kitchen_sink::base_two::{TNapkinServiceSyncClient, TRamenServiceSyncClient}; use kitchen_sink::midlayer::{MealServiceSyncClient, TMealServiceSyncClient}; use kitchen_sink::ultimate::{FullMealServiceSyncClient, TFullMealServiceSyncClient}; -use thrift::transport::{TFramedTransport, TTcpTransport, TTransport}; +use thrift::transport::{ReadHalf, TFramedReadTransport, TFramedWriteTransport, TIoChannel, + TTcpChannel, WriteHalf}; use thrift::protocol::{TBinaryInputProtocol, TBinaryOutputProtocol, TCompactInputProtocol, TCompactOutputProtocol, TInputProtocol, TOutputProtocol}; @@ -50,24 +48,25 @@ fn run() -> thrift::Result<()> { (@arg port: --port +takes_value "Port on which the Thrift test server is listening") (@arg protocol: --protocol +takes_value "Thrift protocol implementation to use (\"binary\", \"compact\")") (@arg service: --service +takes_value "Service type to contact (\"part\", \"full\")") - ).get_matches(); + ) + .get_matches(); let host = matches.value_of("host").unwrap_or("127.0.0.1"); let port = value_t!(matches, "port", u16).unwrap_or(9090); let protocol = matches.value_of("protocol").unwrap_or("compact"); let service = matches.value_of("service").unwrap_or("part"); - let t = open_tcp_transport(host, port)?; - let t = Rc::new(RefCell::new(Box::new(TFramedTransport::new(t)) as Box<TTransport>)); + let (i_chan, o_chan) = tcp_channel(host, port)?; + let (i_tran, o_tran) = (TFramedReadTransport::new(i_chan), TFramedWriteTransport::new(o_chan)); let (i_prot, o_prot): (Box<TInputProtocol>, Box<TOutputProtocol>) = match protocol { "binary" => { - (Box::new(TBinaryInputProtocol::new(t.clone(), true)), - Box::new(TBinaryOutputProtocol::new(t.clone(), true))) + (Box::new(TBinaryInputProtocol::new(i_tran, true)), + Box::new(TBinaryOutputProtocol::new(o_tran, true))) } "compact" => { - (Box::new(TCompactInputProtocol::new(t.clone())), - Box::new(TCompactOutputProtocol::new(t.clone()))) + (Box::new(TCompactInputProtocol::new(i_tran)), + Box::new(TCompactOutputProtocol::new(o_tran))) } unmatched => return Err(format!("unsupported protocol {}", unmatched).into()), }; @@ -75,28 +74,31 @@ fn run() -> thrift::Result<()> { run_client(service, i_prot, o_prot) } -fn run_client(service: &str, - i_prot: Box<TInputProtocol>, - o_prot: Box<TOutputProtocol>) - -> thrift::Result<()> { +fn run_client( + service: &str, + i_prot: Box<TInputProtocol>, + o_prot: Box<TOutputProtocol>, +) -> thrift::Result<()> { match service { "full" => run_full_meal_service(i_prot, o_prot), "part" => run_meal_service(i_prot, o_prot), - _ => Err(thrift::Error::from(format!("unknown service type {}", service))), + _ => Err(thrift::Error::from(format!("unknown service type {}", service)),), } } -fn open_tcp_transport(host: &str, port: u16) -> thrift::Result<Rc<RefCell<Box<TTransport>>>> { - let mut t = TTcpTransport::new(); - match t.open(&format!("{}:{}", host, port)) { - Ok(()) => Ok(Rc::new(RefCell::new(Box::new(t) as Box<TTransport>))), - Err(e) => Err(e), - } +fn tcp_channel( + host: &str, + port: u16, +) -> thrift::Result<(ReadHalf<TTcpChannel>, WriteHalf<TTcpChannel>)> { + let mut c = TTcpChannel::new(); + c.open(&format!("{}:{}", host, port))?; + c.split() } -fn run_meal_service(i_prot: Box<TInputProtocol>, - o_prot: Box<TOutputProtocol>) - -> thrift::Result<()> { +fn run_meal_service( + i_prot: Box<TInputProtocol>, + o_prot: Box<TOutputProtocol>, +) -> thrift::Result<()> { let mut client = MealServiceSyncClient::new(i_prot, o_prot); // client.full_meal(); // <-- IMPORTANT: if you uncomment this, compilation *should* fail @@ -110,9 +112,10 @@ fn run_meal_service(i_prot: Box<TInputProtocol>, Ok(()) } -fn run_full_meal_service(i_prot: Box<TInputProtocol>, - o_prot: Box<TOutputProtocol>) - -> thrift::Result<()> { +fn run_full_meal_service( + i_prot: Box<TInputProtocol>, + o_prot: Box<TOutputProtocol>, +) -> thrift::Result<()> { let mut client = FullMealServiceSyncClient::new(i_prot, o_prot); execute_call("full", "ramen", || client.ramen(100))?; @@ -124,17 +127,20 @@ fn run_full_meal_service(i_prot: Box<TInputProtocol>, } fn execute_call<F, R>(service_type: &str, call_name: &str, mut f: F) -> thrift::Result<()> - where F: FnMut() -> thrift::Result<R> +where + F: FnMut() -> thrift::Result<R>, { let res = f(); match res { Ok(_) => println!("{}: completed {} call", service_type, call_name), Err(ref e) => { - println!("{}: failed {} call with error {:?}", - service_type, - call_name, - e) + println!( + "{}: failed {} call with error {:?}", + service_type, + call_name, + e + ) } } diff --git a/lib/rs/test/src/bin/kitchen_sink_server.rs b/lib/rs/test/src/bin/kitchen_sink_server.rs index 4ce4fa377..19112cdbb 100644 --- a/lib/rs/test/src/bin/kitchen_sink_server.rs +++ b/lib/rs/test/src/bin/kitchen_sink_server.rs @@ -22,7 +22,7 @@ extern crate kitchen_sink; extern crate thrift; use kitchen_sink::base_one::Noodle; -use kitchen_sink::base_two::{Napkin, Ramen, NapkinServiceSyncHandler, RamenServiceSyncHandler}; +use kitchen_sink::base_two::{Napkin, NapkinServiceSyncHandler, Ramen, RamenServiceSyncHandler}; use kitchen_sink::midlayer::{Dessert, Meal, MealServiceSyncHandler, MealServiceSyncProcessor}; use kitchen_sink::ultimate::{Drink, FullMeal, FullMealAndDrinks, FullMealAndDrinksServiceSyncProcessor, FullMealServiceSyncHandler}; @@ -30,8 +30,9 @@ use kitchen_sink::ultimate::FullMealAndDrinksServiceSyncHandler; use thrift::protocol::{TBinaryInputProtocolFactory, TBinaryOutputProtocolFactory, TCompactInputProtocolFactory, TCompactOutputProtocolFactory, TInputProtocolFactory, TOutputProtocolFactory}; -use thrift::transport::{TFramedTransportFactory, TTransportFactory}; -use thrift::server::TSimpleServer; +use thrift::transport::{TFramedReadTransportFactory, TFramedWriteTransportFactory, + TReadTransportFactory, TWriteTransportFactory}; +use thrift::server::TServer; fn main() { match run() { @@ -52,7 +53,8 @@ fn run() -> thrift::Result<()> { (@arg port: --port +takes_value "port on which the test server listens") (@arg protocol: --protocol +takes_value "Thrift protocol implementation to use (\"binary\", \"compact\")") (@arg service: --service +takes_value "Service type to contact (\"part\", \"full\")") - ).get_matches(); + ) + .get_matches(); let port = value_t!(matches, "port", u16).unwrap_or(9090); let protocol = matches.value_of("protocol").unwrap_or("compact"); @@ -61,9 +63,8 @@ fn run() -> thrift::Result<()> { println!("binding to {}", listen_address); - let (i_transport_factory, o_transport_factory): (Box<TTransportFactory>, - Box<TTransportFactory>) = - (Box::new(TFramedTransportFactory {}), Box::new(TFramedTransportFactory {})); + let r_transport_factory = TFramedReadTransportFactory::new(); + let w_transport_factory = TFramedWriteTransportFactory::new(); let (i_protocol_factory, o_protocol_factory): (Box<TInputProtocolFactory>, Box<TOutputProtocolFactory>) = @@ -93,51 +94,75 @@ fn run() -> thrift::Result<()> { // Since what I'm doing is uncommon I'm just going to duplicate the code match &*service { "part" => { - run_meal_server(&listen_address, - i_transport_factory, - i_protocol_factory, - o_transport_factory, - o_protocol_factory) + run_meal_server( + &listen_address, + r_transport_factory, + i_protocol_factory, + w_transport_factory, + o_protocol_factory, + ) } "full" => { - run_full_meal_server(&listen_address, - i_transport_factory, - i_protocol_factory, - o_transport_factory, - o_protocol_factory) + run_full_meal_server( + &listen_address, + r_transport_factory, + i_protocol_factory, + w_transport_factory, + o_protocol_factory, + ) } unknown => Err(format!("unsupported service type {}", unknown).into()), } } -fn run_meal_server(listen_address: &str, - i_transport_factory: Box<TTransportFactory>, - i_protocol_factory: Box<TInputProtocolFactory>, - o_transport_factory: Box<TTransportFactory>, - o_protocol_factory: Box<TOutputProtocolFactory>) - -> thrift::Result<()> { +fn run_meal_server<RTF, IPF, WTF, OPF>( + listen_address: &str, + r_transport_factory: RTF, + i_protocol_factory: IPF, + w_transport_factory: WTF, + o_protocol_factory: OPF, +) -> thrift::Result<()> +where + RTF: TReadTransportFactory + 'static, + IPF: TInputProtocolFactory + 'static, + WTF: TWriteTransportFactory + 'static, + OPF: TOutputProtocolFactory + 'static, +{ let processor = MealServiceSyncProcessor::new(PartHandler {}); - let mut server = TSimpleServer::new(i_transport_factory, - i_protocol_factory, - o_transport_factory, - o_protocol_factory, - processor); + let mut server = TServer::new( + r_transport_factory, + i_protocol_factory, + w_transport_factory, + o_protocol_factory, + processor, + 1, + ); server.listen(listen_address) } -fn run_full_meal_server(listen_address: &str, - i_transport_factory: Box<TTransportFactory>, - i_protocol_factory: Box<TInputProtocolFactory>, - o_transport_factory: Box<TTransportFactory>, - o_protocol_factory: Box<TOutputProtocolFactory>) - -> thrift::Result<()> { +fn run_full_meal_server<RTF, IPF, WTF, OPF>( + listen_address: &str, + r_transport_factory: RTF, + i_protocol_factory: IPF, + w_transport_factory: WTF, + o_protocol_factory: OPF, +) -> thrift::Result<()> +where + RTF: TReadTransportFactory + 'static, + IPF: TInputProtocolFactory + 'static, + WTF: TWriteTransportFactory + 'static, + OPF: TOutputProtocolFactory + 'static, +{ let processor = FullMealAndDrinksServiceSyncProcessor::new(FullHandler {}); - let mut server = TSimpleServer::new(i_transport_factory, - i_protocol_factory, - o_transport_factory, - o_protocol_factory, - processor); + let mut server = TServer::new( + r_transport_factory, + i_protocol_factory, + w_transport_factory, + o_protocol_factory, + processor, + 1, + ); server.listen(listen_address) } @@ -145,21 +170,21 @@ fn run_full_meal_server(listen_address: &str, struct PartHandler; impl MealServiceSyncHandler for PartHandler { - fn handle_meal(&mut self) -> thrift::Result<Meal> { + fn handle_meal(&self) -> thrift::Result<Meal> { println!("part: handling meal call"); Ok(meal()) } } impl RamenServiceSyncHandler for PartHandler { - fn handle_ramen(&mut self, _: i32) -> thrift::Result<Ramen> { + fn handle_ramen(&self, _: i32) -> thrift::Result<Ramen> { println!("part: handling ramen call"); Ok(ramen()) } } impl NapkinServiceSyncHandler for PartHandler { - fn handle_napkin(&mut self) -> thrift::Result<Napkin> { + fn handle_napkin(&self) -> thrift::Result<Napkin> { println!("part: handling napkin call"); Ok(napkin()) } @@ -171,34 +196,34 @@ impl NapkinServiceSyncHandler for PartHandler { struct FullHandler; impl FullMealAndDrinksServiceSyncHandler for FullHandler { - fn handle_full_meal_and_drinks(&mut self) -> thrift::Result<FullMealAndDrinks> { + fn handle_full_meal_and_drinks(&self) -> thrift::Result<FullMealAndDrinks> { Ok(FullMealAndDrinks::new(full_meal(), Drink::WHISKEY)) } } impl FullMealServiceSyncHandler for FullHandler { - fn handle_full_meal(&mut self) -> thrift::Result<FullMeal> { + fn handle_full_meal(&self) -> thrift::Result<FullMeal> { println!("full: handling full meal call"); Ok(full_meal()) } } impl MealServiceSyncHandler for FullHandler { - fn handle_meal(&mut self) -> thrift::Result<Meal> { + fn handle_meal(&self) -> thrift::Result<Meal> { println!("full: handling meal call"); Ok(meal()) } } impl RamenServiceSyncHandler for FullHandler { - fn handle_ramen(&mut self, _: i32) -> thrift::Result<Ramen> { + fn handle_ramen(&self, _: i32) -> thrift::Result<Ramen> { println!("full: handling ramen call"); Ok(ramen()) } } impl NapkinServiceSyncHandler for FullHandler { - fn handle_napkin(&mut self) -> thrift::Result<Napkin> { + fn handle_napkin(&self) -> thrift::Result<Napkin> { println!("full: handling napkin call"); Ok(napkin()) } diff --git a/lib/rs/test/src/lib.rs b/lib/rs/test/src/lib.rs index 8a7ccd0ae..53f487340 100644 --- a/lib/rs/test/src/lib.rs +++ b/lib/rs/test/src/lib.rs @@ -48,6 +48,9 @@ mod tests { #[test] fn must_be_able_to_use_defaults() { - let _ = midlayer::Meal { noodle: Some(base_one::Noodle::default()), ..Default::default() }; + let _ = midlayer::Meal { + noodle: Some(base_one::Noodle::default()), + ..Default::default() + }; } } diff --git a/test/rs/src/bin/test_client.rs b/test/rs/src/bin/test_client.rs index a2ea83204..aad78a058 100644 --- a/test/rs/src/bin/test_client.rs +++ b/test/rs/src/bin/test_client.rs @@ -22,14 +22,14 @@ extern crate thrift; extern crate thrift_test; // huh. I have to do this to use my lib use ordered_float::OrderedFloat; -use std::cell::RefCell; use std::collections::{BTreeMap, BTreeSet}; use std::fmt::Debug; -use std::rc::Rc; use thrift::protocol::{TBinaryInputProtocol, TBinaryOutputProtocol, TCompactInputProtocol, TCompactOutputProtocol, TInputProtocol, TOutputProtocol}; -use thrift::transport::{TBufferedTransport, TFramedTransport, TTcpTransport, TTransport}; +use thrift::transport::{ReadHalf, TBufferedReadTransport, TBufferedWriteTransport, + TFramedReadTransport, TFramedWriteTransport, TIoChannel, TReadTransport, + TTcpChannel, TWriteTransport, WriteHalf}; use thrift_test::*; fn main() { @@ -58,7 +58,8 @@ fn run() -> thrift::Result<()> { (@arg transport: --transport +takes_value "Thrift transport implementation to use (\"buffered\", \"framed\")") (@arg protocol: --protocol +takes_value "Thrift protocol implementation to use (\"binary\", \"compact\")") (@arg testloops: -n --testloops +takes_value "Number of times to run tests") - ).get_matches(); + ) + .get_matches(); let host = matches.value_of("host").unwrap_or("127.0.0.1"); let port = value_t!(matches, "port", u16).unwrap_or(9090); @@ -66,32 +67,39 @@ fn run() -> thrift::Result<()> { let transport = matches.value_of("transport").unwrap_or("buffered"); let protocol = matches.value_of("protocol").unwrap_or("binary"); - let t = open_tcp_transport(host, port)?; + let (i_chan, o_chan) = tcp_channel(host, port)?; - let t: Box<TTransport> = match transport { - "buffered" => Box::new(TBufferedTransport::new(t)), - "framed" => Box::new(TFramedTransport::new(t)), + let (i_tran, o_tran) = match transport { + "buffered" => { + (Box::new(TBufferedReadTransport::new(i_chan)) as Box<TReadTransport>, + Box::new(TBufferedWriteTransport::new(o_chan)) as Box<TWriteTransport>) + } + "framed" => { + (Box::new(TFramedReadTransport::new(i_chan)) as Box<TReadTransport>, + Box::new(TFramedWriteTransport::new(o_chan)) as Box<TWriteTransport>) + } unmatched => return Err(format!("unsupported transport {}", unmatched).into()), }; - let t = Rc::new(RefCell::new(t)); let (i_prot, o_prot): (Box<TInputProtocol>, Box<TOutputProtocol>) = match protocol { "binary" => { - (Box::new(TBinaryInputProtocol::new(t.clone(), true)), - Box::new(TBinaryOutputProtocol::new(t.clone(), true))) + (Box::new(TBinaryInputProtocol::new(i_tran, true)), + Box::new(TBinaryOutputProtocol::new(o_tran, true))) } "compact" => { - (Box::new(TCompactInputProtocol::new(t.clone())), - Box::new(TCompactOutputProtocol::new(t.clone()))) + (Box::new(TCompactInputProtocol::new(i_tran)), + Box::new(TCompactOutputProtocol::new(o_tran))) } unmatched => return Err(format!("unsupported protocol {}", unmatched).into()), }; - println!("connecting to {}:{} with {}+{} stack", - host, - port, - protocol, - transport); + println!( + "connecting to {}:{} with {}+{} stack", + host, + port, + protocol, + transport + ); let mut client = ThriftTestSyncClient::new(i_prot, o_prot); @@ -102,16 +110,19 @@ fn run() -> thrift::Result<()> { Ok(()) } -// FIXME: expose "open" through the client interface so I don't have to early open the transport -fn open_tcp_transport(host: &str, port: u16) -> thrift::Result<Rc<RefCell<Box<TTransport>>>> { - let mut t = TTcpTransport::new(); - match t.open(&format!("{}:{}", host, port)) { - Ok(()) => Ok(Rc::new(RefCell::new(Box::new(t) as Box<TTransport>))), - Err(e) => Err(e), - } +// FIXME: expose "open" through the client interface so I don't have to early +// open +fn tcp_channel( + host: &str, + port: u16, +) -> thrift::Result<(ReadHalf<TTcpChannel>, WriteHalf<TTcpChannel>)> { + let mut c = TTcpChannel::new(); + c.open(&format!("{}:{}", host, port))?; + c.split() } -fn make_thrift_calls(client: &mut ThriftTestSyncClient) -> Result<(), thrift::Error> { +fn make_thrift_calls(client: &mut ThriftTestSyncClient<Box<TInputProtocol>, Box<TOutputProtocol>>,) + -> Result<(), thrift::Error> { println!("testVoid"); client.test_void()?; @@ -131,12 +142,15 @@ fn make_thrift_calls(client: &mut ThriftTestSyncClient) -> Result<(), thrift::Er verify_expected_result(client.test_i32(1159348374), 1159348374)?; println!("testi64"); - // try!(verify_expected_result(client.test_i64(-8651829879438294565), -8651829879438294565)); + // try!(verify_expected_result(client.test_i64(-8651829879438294565), + // -8651829879438294565)); verify_expected_result(client.test_i64(i64::min_value()), i64::min_value())?; println!("testDouble"); - verify_expected_result(client.test_double(OrderedFloat::from(42.42)), - OrderedFloat::from(42.42))?; + verify_expected_result( + client.test_double(OrderedFloat::from(42.42)), + OrderedFloat::from(42.42), + )?; println!("testTypedef"); { @@ -175,10 +189,14 @@ fn make_thrift_calls(client: &mut ThriftTestSyncClient) -> Result<(), thrift::Er } // Xtruct again, with optional values - // FIXME: apparently the erlang thrift server does not like opt-in-req-out parameters that are undefined. Joy. + // FIXME: apparently the erlang thrift server does not like opt-in-req-out + // parameters that are undefined. Joy. // { - // let x_snd = Xtruct { string_thing: Some("foo".to_owned()), byte_thing: None, i32_thing: None, i64_thing: Some(12938492818) }; - // let x_cmp = Xtruct { string_thing: Some("foo".to_owned()), byte_thing: Some(0), i32_thing: Some(0), i64_thing: Some(12938492818) }; // the C++ server is responding correctly + // let x_snd = Xtruct { string_thing: Some("foo".to_owned()), byte_thing: None, + // i32_thing: None, i64_thing: Some(12938492818) }; + // let x_cmp = Xtruct { string_thing: Some("foo".to_owned()), byte_thing: + // Some(0), i32_thing: Some(0), i64_thing: Some(12938492818) }; // the C++ + // server is responding correctly // try!(verify_expected_result(client.test_struct(x_snd), x_cmp)); // } // @@ -188,22 +206,26 @@ fn make_thrift_calls(client: &mut ThriftTestSyncClient) -> Result<(), thrift::Er { let x_snd = Xtruct2 { byte_thing: Some(32), - struct_thing: Some(Xtruct { - string_thing: Some("foo".to_owned()), - byte_thing: Some(1), - i32_thing: Some(324382098), - i64_thing: Some(12938492818), - }), + struct_thing: Some( + Xtruct { + string_thing: Some("foo".to_owned()), + byte_thing: Some(1), + i32_thing: Some(324382098), + i64_thing: Some(12938492818), + }, + ), i32_thing: Some(293481098), }; let x_cmp = Xtruct2 { byte_thing: Some(32), - struct_thing: Some(Xtruct { - string_thing: Some("foo".to_owned()), - byte_thing: Some(1), - i32_thing: Some(324382098), - i64_thing: Some(12938492818), - }), + struct_thing: Some( + Xtruct { + string_thing: Some("foo".to_owned()), + byte_thing: Some(1), + i32_thing: Some(324382098), + i64_thing: Some(12938492818), + }, + ), i32_thing: Some(293481098), }; verify_expected_result(client.test_nest(x_snd), x_cmp)?; @@ -270,7 +292,8 @@ fn make_thrift_calls(client: &mut ThriftTestSyncClient) -> Result<(), thrift::Er } // nested map - // expect : {-4 => {-4 => -4, -3 => -3, -2 => -2, -1 => -1, }, 4 => {1 => 1, 2 => 2, 3 => 3, 4 => 4, }, } + // expect : {-4 => {-4 => -4, -3 => -3, -2 => -2, -1 => -1, }, 4 => {1 => 1, 2 + // => 2, 3 => 3, 4 => 4, }, } println!("testMapMap"); { let mut m_cmp_nested_0: BTreeMap<i32, i32> = BTreeMap::new(); @@ -302,13 +325,10 @@ fn make_thrift_calls(client: &mut ThriftTestSyncClient) -> Result<(), thrift::Er i64_thing: Some(-19234123981), }; - verify_expected_result(client.test_multi(1, - -123948, - -19234123981, - m_snd, - Numberz::EIGHT, - 81), - s_cmp)?; + verify_expected_result( + client.test_multi(1, -123948, -19234123981, m_snd, Numberz::EIGHT, 81), + s_cmp, + )?; } // Insanity @@ -324,24 +344,30 @@ fn make_thrift_calls(client: &mut ThriftTestSyncClient) -> Result<(), thrift::Er arg_map_usermap.insert(Numberz::EIGHT, 19); let mut arg_vec_xtructs: Vec<Xtruct> = Vec::new(); - arg_vec_xtructs.push(Xtruct { - string_thing: Some("foo".to_owned()), - byte_thing: Some(8), - i32_thing: Some(29), - i64_thing: Some(92384), - }); - arg_vec_xtructs.push(Xtruct { - string_thing: Some("bar".to_owned()), - byte_thing: Some(28), - i32_thing: Some(2), - i64_thing: Some(-1281), - }); - arg_vec_xtructs.push(Xtruct { - string_thing: Some("baz".to_owned()), - byte_thing: Some(0), - i32_thing: Some(3948539), - i64_thing: Some(-12938492), - }); + arg_vec_xtructs.push( + Xtruct { + string_thing: Some("foo".to_owned()), + byte_thing: Some(8), + i32_thing: Some(29), + i64_thing: Some(92384), + }, + ); + arg_vec_xtructs.push( + Xtruct { + string_thing: Some("bar".to_owned()), + byte_thing: Some(28), + i32_thing: Some(2), + i64_thing: Some(-1281), + }, + ); + arg_vec_xtructs.push( + Xtruct { + string_thing: Some("baz".to_owned()), + byte_thing: Some(0), + i32_thing: Some(3948539), + i64_thing: Some(-12938492), + }, + ); let mut s_cmp_nested_1: BTreeMap<Numberz, Insanity> = BTreeMap::new(); let insanity = Insanity { @@ -372,7 +398,7 @@ fn make_thrift_calls(client: &mut ThriftTestSyncClient) -> Result<(), thrift::Er Err(thrift::Error::User(ref e)) => { match e.downcast_ref::<Xception>() { Some(x) => Ok(x), - None => Err(thrift::Error::User("did not get expected Xception struct".into())), + None => Err(thrift::Error::User("did not get expected Xception struct".into()),), } } _ => Err(thrift::Error::User("did not get exception".into())), @@ -414,7 +440,7 @@ fn make_thrift_calls(client: &mut ThriftTestSyncClient) -> Result<(), thrift::Er Err(thrift::Error::User(ref e)) => { match e.downcast_ref::<Xception>() { Some(x) => Ok(x), - None => Err(thrift::Error::User("did not get expected Xception struct".into())), + None => Err(thrift::Error::User("did not get expected Xception struct".into()),), } } _ => Err(thrift::Error::User("did not get exception".into())), @@ -435,7 +461,7 @@ fn make_thrift_calls(client: &mut ThriftTestSyncClient) -> Result<(), thrift::Er Err(thrift::Error::User(ref e)) => { match e.downcast_ref::<Xception2>() { Some(x) => Ok(x), - None => Err(thrift::Error::User("did not get expected Xception struct".into())), + None => Err(thrift::Error::User("did not get expected Xception struct".into()),), } } _ => Err(thrift::Error::User("did not get exception".into())), @@ -443,12 +469,17 @@ fn make_thrift_calls(client: &mut ThriftTestSyncClient) -> Result<(), thrift::Er let x_cmp = Xception2 { error_code: Some(2002), - struct_thing: Some(Xtruct { - string_thing: Some("This is an Xception2".to_owned()), - byte_thing: Some(0), /* since this is an OPT_IN_REQ_OUT field the sender sets a default */ - i32_thing: Some(0), /* since this is an OPT_IN_REQ_OUT field the sender sets a default */ - i64_thing: Some(0), /* since this is an OPT_IN_REQ_OUT field the sender sets a default */ - }), + struct_thing: Some( + Xtruct { + string_thing: Some("This is an Xception2".to_owned()), + // since this is an OPT_IN_REQ_OUT field the sender sets a default + byte_thing: Some(0), + // since this is an OPT_IN_REQ_OUT field the sender sets a default + i32_thing: Some(0), + // since this is an OPT_IN_REQ_OUT field the sender sets a default + i64_thing: Some(0), + }, + ), }; verify_expected_result(Ok(x), &x_cmp)?; @@ -458,17 +489,18 @@ fn make_thrift_calls(client: &mut ThriftTestSyncClient) -> Result<(), thrift::Er { let r = client.test_multi_exception("haha".to_owned(), "RETURNED".to_owned()); let x = match r { - Err(e) => { - Err(thrift::Error::User(format!("received an unexpected exception {:?}", e).into())) - } + Err(e) => Err(thrift::Error::User(format!("received an unexpected exception {:?}", e).into(),),), _ => r, }?; let x_cmp = Xtruct { string_thing: Some("RETURNED".to_owned()), - byte_thing: Some(0), // since this is an OPT_IN_REQ_OUT field the sender sets a default - i32_thing: Some(0), // since this is an OPT_IN_REQ_OUT field the sender sets a default - i64_thing: Some(0), // since this is an OPT_IN_REQ_OUT field the sender sets a default + // since this is an OPT_IN_REQ_OUT field the sender sets a default + byte_thing: Some(0), + // since this is an OPT_IN_REQ_OUT field the sender sets a default + i32_thing: Some(0), + // since this is an OPT_IN_REQ_OUT field the sender sets a default + i64_thing: Some(0), }; verify_expected_result(Ok(x), x_cmp)?; @@ -479,20 +511,22 @@ fn make_thrift_calls(client: &mut ThriftTestSyncClient) -> Result<(), thrift::Er client.test_oneway(1)?; } - // final test to verify that the connection is still writable after the one-way call + // final test to verify that the connection is still writable after the one-way + // call client.test_void() } -fn verify_expected_result<T: Debug + PartialEq + Sized>(actual: Result<T, thrift::Error>, - expected: T) - -> Result<(), thrift::Error> { +#[cfg_attr(feature = "cargo-clippy", allow(needless_pass_by_value))] +fn verify_expected_result<T: Debug + PartialEq + Sized>( + actual: Result<T, thrift::Error>, + expected: T, +) -> Result<(), thrift::Error> { match actual { Ok(v) => { if v == expected { Ok(()) } else { - Err(thrift::Error::User(format!("expected {:?} but got {:?}", &expected, &v) - .into())) + Err(thrift::Error::User(format!("expected {:?} but got {:?}", &expected, &v).into()),) } } Err(e) => Err(e), diff --git a/test/rs/src/bin/test_server.rs b/test/rs/src/bin/test_server.rs index 613cd5559..9c738ab01 100644 --- a/test/rs/src/bin/test_server.rs +++ b/test/rs/src/bin/test_server.rs @@ -29,8 +29,10 @@ use std::time::Duration; use thrift::protocol::{TBinaryInputProtocolFactory, TBinaryOutputProtocolFactory, TCompactInputProtocolFactory, TCompactOutputProtocolFactory, TInputProtocolFactory, TOutputProtocolFactory}; -use thrift::server::TSimpleServer; -use thrift::transport::{TBufferedTransportFactory, TFramedTransportFactory, TTransportFactory}; +use thrift::server::TServer; +use thrift::transport::{TBufferedReadTransportFactory, TBufferedWriteTransportFactory, + TFramedReadTransportFactory, TFramedWriteTransportFactory, + TReadTransportFactory, TWriteTransportFactory}; use thrift_test::*; fn main() { @@ -49,7 +51,6 @@ fn run() -> thrift::Result<()> { // --domain-socket // --named-pipe // --ssl - // --workers let matches = clap_app!(rust_test_client => (version: "1.0") (author: "Apache Thrift Developers <dev@thrift.apache.org>") @@ -57,29 +58,35 @@ fn run() -> thrift::Result<()> { (@arg port: --port +takes_value "port on which the test server listens") (@arg transport: --transport +takes_value "transport implementation to use (\"buffered\", \"framed\")") (@arg protocol: --protocol +takes_value "protocol implementation to use (\"binary\", \"compact\")") - (@arg server_type: --server_type +takes_value "type of server instantiated (\"simple\", \"thread-pool\", \"threaded\", \"non-blocking\")") - ).get_matches(); + (@arg server_type: --server_type +takes_value "type of server instantiated (\"simple\", \"thread-pool\")") + (@arg workers: -n --workers +takes_value "number of thread-pool workers (\"4\")") + ) + .get_matches(); let port = value_t!(matches, "port", u16).unwrap_or(9090); let transport = matches.value_of("transport").unwrap_or("buffered"); let protocol = matches.value_of("protocol").unwrap_or("binary"); - let server_type = matches.value_of("server_type").unwrap_or("simple"); + let server_type = matches.value_of("server_type").unwrap_or("thread-pool"); + let workers = value_t!(matches, "workers", usize).unwrap_or(4); let listen_address = format!("127.0.0.1:{}", port); println!("binding to {}", listen_address); - let (i_transport_factory, o_transport_factory): (Box<TTransportFactory>, - Box<TTransportFactory>) = match &*transport { - "buffered" => { - (Box::new(TBufferedTransportFactory::new()), Box::new(TBufferedTransportFactory::new())) - } - "framed" => { - (Box::new(TFramedTransportFactory::new()), Box::new(TFramedTransportFactory::new())) - } - unknown => { - return Err(format!("unsupported transport type {}", unknown).into()); - } - }; + let (i_transport_factory, o_transport_factory): (Box<TReadTransportFactory>, + Box<TWriteTransportFactory>) = + match &*transport { + "buffered" => { + (Box::new(TBufferedReadTransportFactory::new()), + Box::new(TBufferedWriteTransportFactory::new())) + } + "framed" => { + (Box::new(TFramedReadTransportFactory::new()), + Box::new(TFramedWriteTransportFactory::new())) + } + unknown => { + return Err(format!("unsupported transport type {}", unknown).into()); + } + }; let (i_protocol_factory, o_protocol_factory): (Box<TInputProtocolFactory>, Box<TOutputProtocolFactory>) = @@ -101,11 +108,24 @@ fn run() -> thrift::Result<()> { let mut server = match &*server_type { "simple" => { - TSimpleServer::new(i_transport_factory, - i_protocol_factory, - o_transport_factory, - o_protocol_factory, - processor) + TServer::new( + i_transport_factory, + i_protocol_factory, + o_transport_factory, + o_protocol_factory, + processor, + 1, + ) + } + "thread-pool" => { + TServer::new( + i_transport_factory, + i_protocol_factory, + o_transport_factory, + o_protocol_factory, + processor, + workers, + ) } unknown => { return Err(format!("unsupported server type {}", unknown).into()); @@ -117,95 +137,93 @@ fn run() -> thrift::Result<()> { struct ThriftTestSyncHandlerImpl; impl ThriftTestSyncHandler for ThriftTestSyncHandlerImpl { - fn handle_test_void(&mut self) -> thrift::Result<()> { + fn handle_test_void(&self) -> thrift::Result<()> { println!("testVoid()"); Ok(()) } - fn handle_test_string(&mut self, thing: String) -> thrift::Result<String> { + fn handle_test_string(&self, thing: String) -> thrift::Result<String> { println!("testString({})", &thing); Ok(thing) } - fn handle_test_bool(&mut self, thing: bool) -> thrift::Result<bool> { + fn handle_test_bool(&self, thing: bool) -> thrift::Result<bool> { println!("testBool({})", thing); Ok(thing) } - fn handle_test_byte(&mut self, thing: i8) -> thrift::Result<i8> { + fn handle_test_byte(&self, thing: i8) -> thrift::Result<i8> { println!("testByte({})", thing); Ok(thing) } - fn handle_test_i32(&mut self, thing: i32) -> thrift::Result<i32> { + fn handle_test_i32(&self, thing: i32) -> thrift::Result<i32> { println!("testi32({})", thing); Ok(thing) } - fn handle_test_i64(&mut self, thing: i64) -> thrift::Result<i64> { + fn handle_test_i64(&self, thing: i64) -> thrift::Result<i64> { println!("testi64({})", thing); Ok(thing) } - fn handle_test_double(&mut self, - thing: OrderedFloat<f64>) - -> thrift::Result<OrderedFloat<f64>> { + fn handle_test_double(&self, thing: OrderedFloat<f64>) -> thrift::Result<OrderedFloat<f64>> { println!("testDouble({})", thing); Ok(thing) } - fn handle_test_binary(&mut self, thing: Vec<u8>) -> thrift::Result<Vec<u8>> { + fn handle_test_binary(&self, thing: Vec<u8>) -> thrift::Result<Vec<u8>> { println!("testBinary({:?})", thing); Ok(thing) } - fn handle_test_struct(&mut self, thing: Xtruct) -> thrift::Result<Xtruct> { + fn handle_test_struct(&self, thing: Xtruct) -> thrift::Result<Xtruct> { println!("testStruct({:?})", thing); Ok(thing) } - fn handle_test_nest(&mut self, thing: Xtruct2) -> thrift::Result<Xtruct2> { + fn handle_test_nest(&self, thing: Xtruct2) -> thrift::Result<Xtruct2> { println!("testNest({:?})", thing); Ok(thing) } - fn handle_test_map(&mut self, thing: BTreeMap<i32, i32>) -> thrift::Result<BTreeMap<i32, i32>> { + fn handle_test_map(&self, thing: BTreeMap<i32, i32>) -> thrift::Result<BTreeMap<i32, i32>> { println!("testMap({:?})", thing); Ok(thing) } - fn handle_test_string_map(&mut self, - thing: BTreeMap<String, String>) - -> thrift::Result<BTreeMap<String, String>> { + fn handle_test_string_map( + &self, + thing: BTreeMap<String, String>, + ) -> thrift::Result<BTreeMap<String, String>> { println!("testStringMap({:?})", thing); Ok(thing) } - fn handle_test_set(&mut self, thing: BTreeSet<i32>) -> thrift::Result<BTreeSet<i32>> { + fn handle_test_set(&self, thing: BTreeSet<i32>) -> thrift::Result<BTreeSet<i32>> { println!("testSet({:?})", thing); Ok(thing) } - fn handle_test_list(&mut self, thing: Vec<i32>) -> thrift::Result<Vec<i32>> { + fn handle_test_list(&self, thing: Vec<i32>) -> thrift::Result<Vec<i32>> { println!("testList({:?})", thing); Ok(thing) } - fn handle_test_enum(&mut self, thing: Numberz) -> thrift::Result<Numberz> { + fn handle_test_enum(&self, thing: Numberz) -> thrift::Result<Numberz> { println!("testEnum({:?})", thing); Ok(thing) } - fn handle_test_typedef(&mut self, thing: UserId) -> thrift::Result<UserId> { + fn handle_test_typedef(&self, thing: UserId) -> thrift::Result<UserId> { println!("testTypedef({})", thing); Ok(thing) } /// @return map<i32,map<i32,i32>> - returns a dictionary with these values: - /// {-4 => {-4 => -4, -3 => -3, -2 => -2, -1 => -1, }, 4 => {1 => 1, 2 => 2, 3 => 3, 4 => 4, }, } - fn handle_test_map_map(&mut self, - hello: i32) - -> thrift::Result<BTreeMap<i32, BTreeMap<i32, i32>>> { + /// {-4 => {-4 => -4, -3 => -3, -2 => -2, -1 => -1, }, 4 => {1 => 1, 2 => + /// 2, 3 => 3, 4 => 4, }, } + fn handle_test_map_map(&self, hello: i32) -> thrift::Result<BTreeMap<i32, BTreeMap<i32, i32>>> { println!("testMapMap({})", hello); let mut inner_map_0: BTreeMap<i32, i32> = BTreeMap::new(); @@ -232,9 +250,10 @@ impl ThriftTestSyncHandler for ThriftTestSyncHandlerImpl { /// 2 => { 6 => <empty Insanity struct>, }, /// } /// return map<UserId, map<Numberz,Insanity>> - a map with the above values - fn handle_test_insanity(&mut self, - argument: Insanity) - -> thrift::Result<BTreeMap<UserId, BTreeMap<Numberz, Insanity>>> { + fn handle_test_insanity( + &self, + argument: Insanity, + ) -> thrift::Result<BTreeMap<UserId, BTreeMap<Numberz, Insanity>>> { println!("testInsanity({:?})", argument); let mut map_0: BTreeMap<Numberz, Insanity> = BTreeMap::new(); map_0.insert(Numberz::TWO, argument.clone()); @@ -254,15 +273,18 @@ impl ThriftTestSyncHandler for ThriftTestSyncHandlerImpl { Ok(ret) } - /// returns an Xtruct with string_thing = "Hello2", byte_thing = arg0, i32_thing = arg1 and i64_thing = arg2 - fn handle_test_multi(&mut self, - arg0: i8, - arg1: i32, - arg2: i64, - _: BTreeMap<i16, String>, - _: Numberz, - _: UserId) - -> thrift::Result<Xtruct> { + /// returns an Xtruct with: + /// string_thing = "Hello2", byte_thing = arg0, i32_thing = arg1 and + /// i64_thing = arg2 + fn handle_test_multi( + &self, + arg0: i8, + arg1: i32, + arg2: i64, + _: BTreeMap<i16, String>, + _: Numberz, + _: UserId, + ) -> thrift::Result<Xtruct> { let x_ret = Xtruct { string_thing: Some("Hello2".to_owned()), byte_thing: Some(arg0), @@ -273,64 +295,77 @@ impl ThriftTestSyncHandler for ThriftTestSyncHandlerImpl { Ok(x_ret) } - /// if arg == "Xception" throw Xception with errorCode = 1001 and message = arg + /// if arg == "Xception" throw Xception with errorCode = 1001 and message = + /// arg /// else if arg == "TException" throw TException /// else do not throw anything - fn handle_test_exception(&mut self, arg: String) -> thrift::Result<()> { + fn handle_test_exception(&self, arg: String) -> thrift::Result<()> { println!("testException({})", arg); match &*arg { "Xception" => { - Err((Xception { - error_code: Some(1001), - message: Some(arg), - }) - .into()) + Err( + (Xception { + error_code: Some(1001), + message: Some(arg), + }) + .into(), + ) } "TException" => Err("this is a random error".into()), _ => Ok(()), } } - /// if arg0 == "Xception" throw Xception with errorCode = 1001 and message = "This is an Xception" - /// else if arg0 == "Xception2" throw Xception2 with errorCode = 2002 and struct_thing.string_thing = "This is an Xception2" - // else do not throw anything and return Xtruct with string_thing = arg1 - fn handle_test_multi_exception(&mut self, - arg0: String, - arg1: String) - -> thrift::Result<Xtruct> { + /// if arg0 == "Xception": + /// throw Xception with errorCode = 1001 and message = "This is an + /// Xception" + /// else if arg0 == "Xception2": + /// throw Xception2 with errorCode = 2002 and struct_thing.string_thing = + /// "This is an Xception2" + // else: + // do not throw anything and return Xtruct with string_thing = arg1 + fn handle_test_multi_exception(&self, arg0: String, arg1: String) -> thrift::Result<Xtruct> { match &*arg0 { "Xception" => { - Err((Xception { - error_code: Some(1001), - message: Some("This is an Xception".to_owned()), - }) - .into()) + Err( + (Xception { + error_code: Some(1001), + message: Some("This is an Xception".to_owned()), + }) + .into(), + ) } "Xception2" => { - Err((Xception2 { - error_code: Some(2002), - struct_thing: Some(Xtruct { - string_thing: Some("This is an Xception2".to_owned()), - byte_thing: None, - i32_thing: None, - i64_thing: None, - }), - }) - .into()) + Err( + (Xception2 { + error_code: Some(2002), + struct_thing: Some( + Xtruct { + string_thing: Some("This is an Xception2".to_owned()), + byte_thing: None, + i32_thing: None, + i64_thing: None, + }, + ), + }) + .into(), + ) } _ => { - Ok(Xtruct { - string_thing: Some(arg1), - byte_thing: None, - i32_thing: None, - i64_thing: None, - }) + Ok( + Xtruct { + string_thing: Some(arg1), + byte_thing: None, + i32_thing: None, + i64_thing: None, + }, + ) } } } - fn handle_test_oneway(&mut self, seconds_to_sleep: i32) -> thrift::Result<()> { + fn handle_test_oneway(&self, seconds_to_sleep: i32) -> thrift::Result<()> { thread::sleep(Duration::from_secs(seconds_to_sleep as u64)); Ok(()) } diff --git a/tutorial/rs/README.md b/tutorial/rs/README.md index 4d0d7c8af..384e9f8bb 100644 --- a/tutorial/rs/README.md +++ b/tutorial/rs/README.md @@ -35,13 +35,12 @@ extern crate thrift; extern crate try_from; // generated Rust module -mod tutorial; +use tutorial; -use std::cell::RefCell; -use std::rc::Rc; -use thrift::protocol::{TInputProtocol, TOutputProtocol}; use thrift::protocol::{TCompactInputProtocol, TCompactOutputProtocol}; -use thrift::transport::{TFramedTransport, TTcpTransport, TTransport}; +use thrift::protocol::{TInputProtocol, TOutputProtocol}; +use thrift::transport::{TFramedReadTransport, TFramedWriteTransport}; +use thrift::transport::{TIoChannel, TTcpChannel}; use tutorial::{CalculatorSyncClient, TCalculatorSyncClient}; use tutorial::{Operation, Work}; @@ -61,28 +60,16 @@ fn run() -> thrift::Result<()> { // println!("connect to server on 127.0.0.1:9090"); - let mut t = TTcpTransport::new(); - let t = match t.open("127.0.0.1:9090") { - Ok(()) => t, - Err(e) => { - return Err( - format!("failed to connect with {:?}", e).into() - ); - } - }; - - let t = Rc::new(RefCell::new( - Box::new(t) as Box<TTransport> - )); - let t = Rc::new(RefCell::new( - Box::new(TFramedTransport::new(t)) as Box<TTransport> - )); + let mut c = TTcpTransport::new(); + c.open("127.0.0.1:9090")?; - let i_prot: Box<TInputProtocol> = Box::new( - TCompactInputProtocol::new(t.clone()) + let (i_chan, o_chan) = c.split()?; + + let i_prot = TCompactInputProtocol::new( + TFramedReadTransport::new(i_chan) ); - let o_prot: Box<TOutputProtocol> = Box::new( - TCompactOutputProtocol::new(t.clone()) + let o_prot = TCompactOutputProtocol::new( + TFramedWriteTransport::new(o_chan) ); let client = CalculatorSyncClient::new(i_prot, o_prot); @@ -177,10 +164,10 @@ A typedef is translated to a `pub type` declaration. ```thrift typedef i64 UserId -typedef map<string, Bonk> MapType +typedef map<string, UserId> MapType ``` ```rust -pub type UserId = 164; +pub type UserId = i64; pub type MapType = BTreeMap<String, Bonk>; ``` @@ -327,4 +314,4 @@ pub struct Foo { ## Known Issues * Struct constants are not supported -* Map, list and set constants require a const holder struct
\ No newline at end of file +* Map, list and set constants require a const holder struct diff --git a/tutorial/rs/src/bin/tutorial_client.rs b/tutorial/rs/src/bin/tutorial_client.rs index 2b0d4f908..24ab4be06 100644 --- a/tutorial/rs/src/bin/tutorial_client.rs +++ b/tutorial/rs/src/bin/tutorial_client.rs @@ -21,15 +21,12 @@ extern crate clap; extern crate thrift; extern crate thrift_tutorial; -use std::cell::RefCell; -use std::rc::Rc; - -use thrift::protocol::{TInputProtocol, TOutputProtocol}; use thrift::protocol::{TCompactInputProtocol, TCompactOutputProtocol}; -use thrift::transport::{TFramedTransport, TTcpTransport, TTransport}; +use thrift::transport::{ReadHalf, TFramedReadTransport, TFramedWriteTransport, TIoChannel, + TTcpChannel, WriteHalf}; use thrift_tutorial::shared::TSharedServiceSyncClient; -use thrift_tutorial::tutorial::{CalculatorSyncClient, TCalculatorSyncClient, Operation, Work}; +use thrift_tutorial::tutorial::{CalculatorSyncClient, Operation, TCalculatorSyncClient, Work}; fn main() { match run() { @@ -73,7 +70,8 @@ fn run() -> thrift::Result<()> { let logid = 32; // let's do...a multiply! - let res = client.calculate(logid, Work::new(7, 8, Operation::MULTIPLY, None))?; + let res = client + .calculate(logid, Work::new(7, 8, Operation::MULTIPLY, None))?; println!("multiplied 7 and 8 and got {}", res); // let's get the log for it @@ -102,34 +100,31 @@ fn run() -> thrift::Result<()> { Ok(()) } -fn new_client(host: &str, port: u16) -> thrift::Result<CalculatorSyncClient> { - let mut t = TTcpTransport::new(); +type ClientInputProtocol = TCompactInputProtocol<TFramedReadTransport<ReadHalf<TTcpChannel>>>; +type ClientOutputProtocol = TCompactOutputProtocol<TFramedWriteTransport<WriteHalf<TTcpChannel>>>; + +fn new_client + ( + host: &str, + port: u16, +) -> thrift::Result<CalculatorSyncClient<ClientInputProtocol, ClientOutputProtocol>> { + let mut c = TTcpChannel::new(); // open the underlying TCP stream println!("connecting to tutorial server on {}:{}", host, port); - let t = match t.open(&format!("{}:{}", host, port)) { - Ok(()) => t, - Err(e) => { - return Err(format!("failed to open tcp stream to {}:{} error:{:?}", - host, - port, - e) - .into()); - } - }; - - // refcounted because it's shared by both input and output transports - let t = Rc::new(RefCell::new(Box::new(t) as Box<TTransport>)); + c.open(&format!("{}:{}", host, port))?; - // wrap a raw socket (slow) with a buffered transport of some kind - let t = Box::new(TFramedTransport::new(t)) as Box<TTransport>; + // clone the TCP channel into two halves, one which + // we'll use for reading, the other for writing + let (i_chan, o_chan) = c.split()?; - // refcounted again because it's shared by both input and output protocols - let t = Rc::new(RefCell::new(t)); + // wrap the raw sockets (slow) with a buffered transport of some kind + let i_tran = TFramedReadTransport::new(i_chan); + let o_tran = TFramedWriteTransport::new(o_chan); // now create the protocol implementations - let i_prot = Box::new(TCompactInputProtocol::new(t.clone())) as Box<TInputProtocol>; - let o_prot = Box::new(TCompactOutputProtocol::new(t.clone())) as Box<TOutputProtocol>; + let i_prot = TCompactInputProtocol::new(i_tran); + let o_prot = TCompactOutputProtocol::new(o_tran); // we're done! Ok(CalculatorSyncClient::new(i_prot, o_prot)) diff --git a/tutorial/rs/src/bin/tutorial_server.rs b/tutorial/rs/src/bin/tutorial_server.rs index 9cc186649..8db8eed26 100644 --- a/tutorial/rs/src/bin/tutorial_server.rs +++ b/tutorial/rs/src/bin/tutorial_server.rs @@ -24,12 +24,12 @@ extern crate thrift_tutorial; use std::collections::HashMap; use std::convert::{From, Into}; use std::default::Default; +use std::sync::Mutex; -use thrift::protocol::{TInputProtocolFactory, TOutputProtocolFactory}; use thrift::protocol::{TCompactInputProtocolFactory, TCompactOutputProtocolFactory}; -use thrift::server::TSimpleServer; +use thrift::server::TServer; -use thrift::transport::{TFramedTransportFactory, TTransportFactory}; +use thrift::transport::{TFramedReadTransportFactory, TFramedWriteTransportFactory}; use thrift_tutorial::shared::{SharedServiceSyncHandler, SharedStruct}; use thrift_tutorial::tutorial::{CalculatorSyncHandler, CalculatorSyncProcessor}; use thrift_tutorial::tutorial::{InvalidOperation, Operation, Work}; @@ -58,33 +58,36 @@ fn run() -> thrift::Result<()> { println!("binding to {}", listen_address); - let i_tran_fact: Box<TTransportFactory> = Box::new(TFramedTransportFactory::new()); - let i_prot_fact: Box<TInputProtocolFactory> = Box::new(TCompactInputProtocolFactory::new()); + let i_tran_fact = TFramedReadTransportFactory::new(); + let i_prot_fact = TCompactInputProtocolFactory::new(); - let o_tran_fact: Box<TTransportFactory> = Box::new(TFramedTransportFactory::new()); - let o_prot_fact: Box<TOutputProtocolFactory> = Box::new(TCompactOutputProtocolFactory::new()); + let o_tran_fact = TFramedWriteTransportFactory::new(); + let o_prot_fact = TCompactOutputProtocolFactory::new(); // demux incoming messages let processor = CalculatorSyncProcessor::new(CalculatorServer { ..Default::default() }); // create the server and start listening - let mut server = TSimpleServer::new(i_tran_fact, - i_prot_fact, - o_tran_fact, - o_prot_fact, - processor); + let mut server = TServer::new( + i_tran_fact, + i_prot_fact, + o_tran_fact, + o_prot_fact, + processor, + 10, + ); server.listen(&listen_address) } /// Handles incoming Calculator service calls. struct CalculatorServer { - log: HashMap<i32, SharedStruct>, + log: Mutex<HashMap<i32, SharedStruct>>, } impl Default for CalculatorServer { fn default() -> CalculatorServer { - CalculatorServer { log: HashMap::new() } + CalculatorServer { log: Mutex::new(HashMap::new()) } } } @@ -94,9 +97,9 @@ impl Default for CalculatorServer { // SharedService handler impl SharedServiceSyncHandler for CalculatorServer { - fn handle_get_struct(&mut self, key: i32) -> thrift::Result<SharedStruct> { - self.log - .get(&key) + fn handle_get_struct(&self, key: i32) -> thrift::Result<SharedStruct> { + let log = self.log.lock().unwrap(); + log.get(&key) .cloned() .ok_or_else(|| format!("could not find log for key {}", key).into()) } @@ -104,25 +107,27 @@ impl SharedServiceSyncHandler for CalculatorServer { // Calculator handler impl CalculatorSyncHandler for CalculatorServer { - fn handle_ping(&mut self) -> thrift::Result<()> { + fn handle_ping(&self) -> thrift::Result<()> { println!("pong!"); Ok(()) } - fn handle_add(&mut self, num1: i32, num2: i32) -> thrift::Result<i32> { + fn handle_add(&self, num1: i32, num2: i32) -> thrift::Result<i32> { println!("handling add: n1:{} n2:{}", num1, num2); Ok(num1 + num2) } - fn handle_calculate(&mut self, logid: i32, w: Work) -> thrift::Result<i32> { + fn handle_calculate(&self, logid: i32, w: Work) -> thrift::Result<i32> { println!("handling calculate: l:{}, w:{:?}", logid, w); let res = if let Some(ref op) = w.op { if w.num1.is_none() || w.num2.is_none() { - Err(InvalidOperation { - what_op: Some(*op as i32), - why: Some("no operands specified".to_owned()), - }) + Err( + InvalidOperation { + what_op: Some(*op as i32), + why: Some("no operands specified".to_owned()), + }, + ) } else { // so that I don't have to call unwrap() multiple times below let num1 = w.num1.as_ref().expect("operands checked"); @@ -134,10 +139,12 @@ impl CalculatorSyncHandler for CalculatorServer { Operation::MULTIPLY => Ok(num1 * num2), Operation::DIVIDE => { if *num2 == 0 { - Err(InvalidOperation { - what_op: Some(*op as i32), - why: Some("divide by 0".to_owned()), - }) + Err( + InvalidOperation { + what_op: Some(*op as i32), + why: Some("divide by 0".to_owned()), + }, + ) } else { Ok(num1 / num2) } @@ -145,12 +152,13 @@ impl CalculatorSyncHandler for CalculatorServer { } } } else { - Err(InvalidOperation::new(None, "no operation specified".to_owned())) + Err(InvalidOperation::new(None, "no operation specified".to_owned()),) }; // if the operation was successful log it if let Ok(ref v) = res { - self.log.insert(logid, SharedStruct::new(logid, format!("{}", v))); + let mut log = self.log.lock().unwrap(); + log.insert(logid, SharedStruct::new(logid, format!("{}", v))); } // the try! macro automatically maps errors @@ -161,7 +169,7 @@ impl CalculatorSyncHandler for CalculatorServer { res.map_err(From::from) } - fn handle_zip(&mut self) -> thrift::Result<()> { + fn handle_zip(&self) -> thrift::Result<()> { println!("handling zip"); Ok(()) } |