diff options
author | Yuxuan 'fishy' Wang <yuxuan.wang@reddit.com> | 2022-02-22 18:48:17 -0800 |
---|---|---|
committer | Yuxuan 'fishy' Wang <fishywang@gmail.com> | 2022-02-23 09:17:50 -0800 |
commit | 9bee877e663f11f4cbdd3a4f02938c8ab9fe8976 (patch) | |
tree | 151cdff0268536208f0eb1cebfdf1f2ec968e869 | |
parent | 103a11c9c28ac963a3b2591ecac641db3cbaa113 (diff) | |
download | thrift-9bee877e663f11f4cbdd3a4f02938c8ab9fe8976.tar.gz |
THRIFT-5527: Don't swallow idl exceptions in Process function
Client: go
This allows ProcessorMiddlewares to access such exceptions, unless
there's a network error writing the response (which takes priority).
While I'm here, also make the indentation of Process function more
consistent, and make it consistent on returning false and an error when
the reading/writing fails.
-rw-r--r-- | compiler/cpp/src/thrift/generate/t_go_generator.cc | 144 | ||||
-rw-r--r-- | lib/go/test/Makefile.am | 8 | ||||
-rw-r--r-- | lib/go/test/ProcessorMiddlewareTest.thrift | 32 | ||||
-rw-r--r-- | lib/go/test/tests/processor_middleware_test.go | 108 |
4 files changed, 249 insertions, 43 deletions
diff --git a/compiler/cpp/src/thrift/generate/t_go_generator.cc b/compiler/cpp/src/thrift/generate/t_go_generator.cc index 7897b621a..3b885f103 100644 --- a/compiler/cpp/src/thrift/generate/t_go_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_go_generator.cc @@ -959,8 +959,8 @@ string t_go_generator::go_imports_begin(bool consts) { // If not writing constants, and there are enums, need extra imports. if (!consts && get_program()->get_enums().size() > 0) { system_packages.push_back("database/sql/driver"); - system_packages.push_back("errors"); } + system_packages.push_back("errors"); system_packages.push_back("fmt"); system_packages.push_back("time"); // For the thrift import, always do rename import to make sure it's called thrift. @@ -980,6 +980,7 @@ string t_go_generator::go_imports_end() { "// (needed to ensure safety because of naive import list construction.)\n" "var _ = thrift.ZERO\n" "var _ = fmt.Printf\n" + "var _ = errors.New\n" "var _ = context.Background\n" "var _ = time.Now\n" "var _ = bytes.Equal\n\n"); @@ -2964,21 +2965,27 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* << ") Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err " "thrift.TException) {" << endl; indent_up(); + string write_err; + if (!tfunction->is_oneway()) { + write_err = tmp("_write_err"); + f_types_ << indent() << "var " << write_err << " error" << endl; + } f_types_ << indent() << "args := " << argsname << "{}" << endl; - f_types_ << indent() << "var err2 error" << endl; - f_types_ << indent() << "if err2 = args." << read_method_name_ << "(ctx, iprot); err2 != nil {" << endl; - f_types_ << indent() << " iprot.ReadMessageEnd(ctx)" << endl; + f_types_ << indent() << "if err2 := args." << read_method_name_ << "(ctx, iprot); err2 != nil {" << endl; + indent_up(); + f_types_ << indent() << "iprot.ReadMessageEnd(ctx)" << endl; if (!tfunction->is_oneway()) { f_types_ << indent() - << " x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err2.Error())" + << "x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err2.Error())" << endl; - f_types_ << indent() << " oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name()) + f_types_ << indent() << "oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name()) << "\", thrift.EXCEPTION, seqId)" << endl; - f_types_ << indent() << " x.Write(ctx, oprot)" << endl; - f_types_ << indent() << " oprot.WriteMessageEnd(ctx)" << endl; - f_types_ << indent() << " oprot.Flush(ctx)" << endl; + f_types_ << indent() << "x.Write(ctx, oprot)" << endl; + f_types_ << indent() << "oprot.WriteMessageEnd(ctx)" << endl; + f_types_ << indent() << "oprot.Flush(ctx)" << endl; } - f_types_ << indent() << " return false, thrift.WrapTException(err2)" << endl; + f_types_ << indent() << "return false, thrift.WrapTException(err2)" << endl; + indent_down(); f_types_ << indent() << "}" << endl; f_types_ << indent() << "iprot.ReadMessageEnd(ctx)" << endl << endl; @@ -3037,9 +3044,6 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* f_types_ << indent() << "result := " << resultname << "{}" << endl; } bool need_reference = type_need_reference(tfunction->get_returntype()); - if (!tfunction->is_oneway() && !tfunction->get_returntype()->is_void()) { - f_types_ << indent() << "var retval " << type_to_go_type(tfunction->get_returntype()) << endl; - } f_types_ << indent() << "if "; @@ -3053,7 +3057,7 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* t_struct* arg_struct = tfunction->get_arglist(); const std::vector<t_field*>& fields = arg_struct->get_members(); vector<t_field*>::const_iterator f_iter; - f_types_ << "err2 = p.handler." << publicize(tfunction->get_name()) << "("; + f_types_ << "err2 := p.handler." << publicize(tfunction->get_name()) << "("; bool first = true; f_types_ << "ctx"; @@ -3069,7 +3073,9 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* } f_types_ << "); err2 != nil {" << endl; - f_types_ << indent() << " tickerCancel()" << endl; + indent_up(); + f_types_ << indent() << "tickerCancel()" << endl; + f_types_ << indent() << "err = thrift.WrapTException(err2)" << endl; t_struct* exceptions = tfunction->get_xceptions(); const vector<t_field*>& x_fields = exceptions->get_members(); @@ -3079,36 +3085,74 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* vector<t_field*>::const_iterator xf_iter; for (xf_iter = x_fields.begin(); xf_iter != x_fields.end(); ++xf_iter) { - f_types_ << indent() << " case " << type_to_go_type(((*xf_iter)->get_type())) << ":" + f_types_ << indent() << "case " << type_to_go_type(((*xf_iter)->get_type())) << ":" << endl; + indent_up(); f_types_ << indent() << "result." << publicize((*xf_iter)->get_name()) << " = v" << endl; + indent_down(); } - f_types_ << indent() << " default:" << endl; + f_types_ << indent() << "default:" << endl; + indent_up(); } if (!tfunction->is_oneway()) { // Avoid writing the error to the wire if it's ErrAbandonRequest - f_types_ << indent() << " if err2 == thrift.ErrAbandonRequest {" << endl; - f_types_ << indent() << " return false, thrift.WrapTException(err2)" << endl; - f_types_ << indent() << " }" << endl; + f_types_ << indent() << "if errors.Is(err2, thrift.ErrAbandonRequest) {" << endl; + indent_up(); + f_types_ << indent() << "return false, thrift.WrapTException(err2)" << endl; + indent_down(); + f_types_ << indent() << "}" << endl; - f_types_ << indent() << " x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, " + string exc(tmp("_exc")); + f_types_ << indent() << exc << " := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, " "\"Internal error processing " << escape_string(tfunction->get_name()) << ": \" + err2.Error())" << endl; - f_types_ << indent() << " oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name()) - << "\", thrift.EXCEPTION, seqId)" << endl; - f_types_ << indent() << " x.Write(ctx, oprot)" << endl; - f_types_ << indent() << " oprot.WriteMessageEnd(ctx)" << endl; - f_types_ << indent() << " oprot.Flush(ctx)" << endl; - } - f_types_ << indent() << " return true, thrift.WrapTException(err2)" << endl; + f_types_ << indent() << "if err2 := oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name()) + << "\", thrift.EXCEPTION, seqId); err2 != nil {" << endl; + indent_up(); + f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl; + indent_down(); + f_types_ << indent() << "}" << endl; - if (!x_fields.empty()) { + f_types_ << indent() << "if err2 := " << exc << ".Write(ctx, oprot); " + << write_err << " == nil && err2 != nil {" << endl; + indent_up(); + f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl; + indent_down(); + f_types_ << indent() << "}" << endl; + + f_types_ << indent() << "if err2 := oprot.WriteMessageEnd(ctx); " + << write_err << " == nil && err2 != nil {" << endl; + indent_up(); + f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl; + indent_down(); + f_types_ << indent() << "}" << endl; + + f_types_ << indent() << "if err2 := oprot.Flush(ctx); " + << write_err << " == nil && err2 != nil {" << endl; + indent_up(); + f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl; + indent_down(); + f_types_ << indent() << "}" << endl; + + f_types_ << indent() << "if " << write_err << " != nil {" << endl; + indent_up(); + f_types_ << indent() << "return false, thrift.WrapTException(" << write_err << ")" << endl; + indent_down(); f_types_ << indent() << "}" << endl; + + // return success=true as long as writing to the wire was successful. + f_types_ << indent() << "return true, err" << endl; } + if (!x_fields.empty()) { + indent_down(); + f_types_ << indent() << "}" << endl; // closes switch + } + + indent_down(); f_types_ << indent() << "}"; // closes err2 != nil if (!tfunction->is_oneway()) { @@ -3126,29 +3170,47 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* f_types_ << endl; } f_types_ << indent() << "tickerCancel()" << endl; - f_types_ << indent() << "if err2 = oprot.WriteMessageBegin(ctx, \"" + + f_types_ << indent() << "if err2 := oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name()) << "\", thrift.REPLY, seqId); err2 != nil {" << endl; - f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl; + indent_up(); + f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl; + indent_down(); f_types_ << indent() << "}" << endl; - f_types_ << indent() << "if err2 = result." << write_method_name_ << "(ctx, oprot); err == nil && err2 != nil {" << endl; - f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl; + + f_types_ << indent() << "if err2 := result." << write_method_name_ << "(ctx, oprot); " + << write_err << " == nil && err2 != nil {" << endl; + indent_up(); + f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl; + indent_down(); f_types_ << indent() << "}" << endl; - f_types_ << indent() << "if err2 = oprot.WriteMessageEnd(ctx); err == nil && err2 != nil {" - << endl; - f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl; + + f_types_ << indent() << "if err2 := oprot.WriteMessageEnd(ctx); " + << write_err << " == nil && err2 != nil {" << endl; + indent_up(); + f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl; + indent_down(); f_types_ << indent() << "}" << endl; - f_types_ << indent() << "if err2 = oprot.Flush(ctx); err == nil && err2 != nil {" << endl; - f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl; + + f_types_ << indent() << "if err2 := oprot.Flush(ctx); " << write_err << " == nil && err2 != nil {" << endl; + indent_up(); + f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl; + indent_down(); f_types_ << indent() << "}" << endl; - f_types_ << indent() << "if err != nil {" << endl; - f_types_ << indent() << " return" << endl; + + f_types_ << indent() << "if " << write_err << " != nil {" << endl; + indent_up(); + f_types_ << indent() << "return false, thrift.WrapTException(" << write_err << ")" << endl; + indent_down(); f_types_ << indent() << "}" << endl; + + // return success=true as long as writing to the wire was successful. f_types_ << indent() << "return true, err" << endl; } else { f_types_ << endl; f_types_ << indent() << "tickerCancel()" << endl; - f_types_ << indent() << "return true, nil" << endl; + f_types_ << indent() << "return true, err" << endl; } indent_down(); f_types_ << indent() << "}" << endl << endl; diff --git a/lib/go/test/Makefile.am b/lib/go/test/Makefile.am index 4b3ecda93..2cca411ac 100644 --- a/lib/go/test/Makefile.am +++ b/lib/go/test/Makefile.am @@ -52,7 +52,8 @@ gopath: $(THRIFT) $(THRIFTTEST) \ EqualsTest.thrift \ ConflictArgNamesTest.thrift \ ConstOptionalFieldImport.thrift \ - ConstOptionalField.thrift + ConstOptionalField.thrift \ + ProcessorMiddlewareTest.thrift mkdir -p gopath/src grep -v list.*map.*list.*map $(THRIFTTEST) | grep -v 'set<Insanity>' > ThriftTest.thrift $(THRIFT) $(THRIFTARGS) -r IncludesTest.thrift @@ -84,6 +85,7 @@ gopath: $(THRIFT) $(THRIFTTEST) \ $(THRIFT) $(THRIFTARGS) EqualsTest.thrift $(THRIFT) $(THRIFTARGS) ConflictArgNamesTest.thrift $(THRIFT) $(THRIFTARGS) -r ConstOptionalField.thrift + $(THRIFT) $(THRIFTARGS) ProcessorMiddlewareTest.thrift ln -nfs ../../tests gopath/src/tests cp -r ./dontexportrwtest gopath/src touch gopath @@ -106,7 +108,8 @@ check: gopath ./gopath/src/servicestest/container_test-remote \ ./gopath/src/duplicateimportstest \ ./gopath/src/equalstest \ - ./gopath/src/conflictargnamestest + ./gopath/src/conflictargnamestest \ + ./gopath/src/processormiddlewaretest $(GO) test -mod=mod github.com/apache/thrift/lib/go/thrift $(GO) test -mod=mod ./gopath/src/tests ./gopath/src/dontexportrwtest @@ -145,6 +148,7 @@ EXTRA_DIST = \ NamesTest.thrift \ OnewayTest.thrift \ OptionalFieldsTest.thrift \ + ProcessorMiddlewareTest.thrift \ RefAnnotationFieldsTest.thrift \ RequiredFieldTest.thrift \ ServicesTest.thrift \ diff --git a/lib/go/test/ProcessorMiddlewareTest.thrift b/lib/go/test/ProcessorMiddlewareTest.thrift new file mode 100644 index 000000000..2d4f5f4b8 --- /dev/null +++ b/lib/go/test/ProcessorMiddlewareTest.thrift @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * Contains some contributions under the Thrift Software License. + * Please see doc/old-thrift-license.txt in the Thrift distribution for + * details. + */ + +exception Error { + 1: optional string foo, +} + +service Service { + void ping() throws ( + 1: Error error, + ); +} diff --git a/lib/go/test/tests/processor_middleware_test.go b/lib/go/test/tests/processor_middleware_test.go new file mode 100644 index 000000000..1bd911cfe --- /dev/null +++ b/lib/go/test/tests/processor_middleware_test.go @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package tests + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/apache/thrift/lib/go/test/gopath/src/processormiddlewaretest" + "github.com/apache/thrift/lib/go/thrift" +) + +const errorMessage = "foo error" + +type serviceImpl struct{} + +func (serviceImpl) Ping(_ context.Context) (err error) { + return &processormiddlewaretest.Error{ + Foo: thrift.StringPtr(errorMessage), + } +} + +func middleware(t *testing.T) thrift.ProcessorMiddleware { + return func(name string, next thrift.TProcessorFunction) thrift.TProcessorFunction { + return thrift.WrappedTProcessorFunction{ + Wrapped: func(ctx context.Context, seqId int32, in, out thrift.TProtocol) (_ bool, err thrift.TException) { + defer func() { + checkError(t, err) + }() + return next.Process(ctx, seqId, in, out) + }, + } + } +} + +func checkError(tb testing.TB, err error) { + tb.Helper() + + var idlErr *processormiddlewaretest.Error + if !errors.As(err, &idlErr) { + tb.Errorf("expected error to be of type *processormiddlewaretest.Error, actual %T, %#v", err, err) + return + } + if actual := idlErr.GetFoo(); actual != errorMessage { + tb.Errorf("expected error message to be %q, actual %q", errorMessage, actual) + } +} + +func TestProcessorMiddleware(t *testing.T) { + const timeout = time.Second + + processor := processormiddlewaretest.NewServiceProcessor(&serviceImpl{}) + serverTransport, err := thrift.NewTServerSocket("127.0.0.1:0") + if err != nil { + t.Fatalf("Could not find available server port: %v", err) + } + server := thrift.NewTSimpleServer4( + thrift.WrapProcessor(processor, middleware(t)), + serverTransport, + thrift.NewTHeaderTransportFactoryConf(nil, nil), + thrift.NewTHeaderProtocolFactoryConf(nil), + ) + defer server.Stop() + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + server.Serve() + }() + + time.Sleep(10 * time.Millisecond) + + cfg := &thrift.TConfiguration{ + ConnectTimeout: timeout, + SocketTimeout: timeout, + } + transport := thrift.NewTSocketFromAddrConf(serverTransport.Addr(), cfg) + if err := transport.Open(); err != nil { + t.Fatalf("Could not open client transport: %v", err) + } + defer transport.Close() + protocol := thrift.NewTHeaderProtocolConf(transport, nil) + + client := processormiddlewaretest.NewServiceClient(thrift.NewTStandardClient(protocol, protocol)) + + err = client.Ping(context.Background()) + checkError(t, err) +} |