diff options
-rw-r--r-- | compiler/cpp/src/thrift/generate/t_go_generator.cc | 144 | ||||
-rw-r--r-- | lib/go/test/Makefile.am | 8 | ||||
-rw-r--r-- | lib/go/test/ProcessorMiddlewareTest.thrift | 32 | ||||
-rw-r--r-- | lib/go/test/tests/processor_middleware_test.go | 108 |
4 files changed, 249 insertions, 43 deletions
diff --git a/compiler/cpp/src/thrift/generate/t_go_generator.cc b/compiler/cpp/src/thrift/generate/t_go_generator.cc index 7897b621a..3b885f103 100644 --- a/compiler/cpp/src/thrift/generate/t_go_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_go_generator.cc @@ -959,8 +959,8 @@ string t_go_generator::go_imports_begin(bool consts) { // If not writing constants, and there are enums, need extra imports. if (!consts && get_program()->get_enums().size() > 0) { system_packages.push_back("database/sql/driver"); - system_packages.push_back("errors"); } + system_packages.push_back("errors"); system_packages.push_back("fmt"); system_packages.push_back("time"); // For the thrift import, always do rename import to make sure it's called thrift. @@ -980,6 +980,7 @@ string t_go_generator::go_imports_end() { "// (needed to ensure safety because of naive import list construction.)\n" "var _ = thrift.ZERO\n" "var _ = fmt.Printf\n" + "var _ = errors.New\n" "var _ = context.Background\n" "var _ = time.Now\n" "var _ = bytes.Equal\n\n"); @@ -2964,21 +2965,27 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* << ") Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err " "thrift.TException) {" << endl; indent_up(); + string write_err; + if (!tfunction->is_oneway()) { + write_err = tmp("_write_err"); + f_types_ << indent() << "var " << write_err << " error" << endl; + } f_types_ << indent() << "args := " << argsname << "{}" << endl; - f_types_ << indent() << "var err2 error" << endl; - f_types_ << indent() << "if err2 = args." << read_method_name_ << "(ctx, iprot); err2 != nil {" << endl; - f_types_ << indent() << " iprot.ReadMessageEnd(ctx)" << endl; + f_types_ << indent() << "if err2 := args." << read_method_name_ << "(ctx, iprot); err2 != nil {" << endl; + indent_up(); + f_types_ << indent() << "iprot.ReadMessageEnd(ctx)" << endl; if (!tfunction->is_oneway()) { f_types_ << indent() - << " x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err2.Error())" + << "x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err2.Error())" << endl; - f_types_ << indent() << " oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name()) + f_types_ << indent() << "oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name()) << "\", thrift.EXCEPTION, seqId)" << endl; - f_types_ << indent() << " x.Write(ctx, oprot)" << endl; - f_types_ << indent() << " oprot.WriteMessageEnd(ctx)" << endl; - f_types_ << indent() << " oprot.Flush(ctx)" << endl; + f_types_ << indent() << "x.Write(ctx, oprot)" << endl; + f_types_ << indent() << "oprot.WriteMessageEnd(ctx)" << endl; + f_types_ << indent() << "oprot.Flush(ctx)" << endl; } - f_types_ << indent() << " return false, thrift.WrapTException(err2)" << endl; + f_types_ << indent() << "return false, thrift.WrapTException(err2)" << endl; + indent_down(); f_types_ << indent() << "}" << endl; f_types_ << indent() << "iprot.ReadMessageEnd(ctx)" << endl << endl; @@ -3037,9 +3044,6 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* f_types_ << indent() << "result := " << resultname << "{}" << endl; } bool need_reference = type_need_reference(tfunction->get_returntype()); - if (!tfunction->is_oneway() && !tfunction->get_returntype()->is_void()) { - f_types_ << indent() << "var retval " << type_to_go_type(tfunction->get_returntype()) << endl; - } f_types_ << indent() << "if "; @@ -3053,7 +3057,7 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* t_struct* arg_struct = tfunction->get_arglist(); const std::vector<t_field*>& fields = arg_struct->get_members(); vector<t_field*>::const_iterator f_iter; - f_types_ << "err2 = p.handler." << publicize(tfunction->get_name()) << "("; + f_types_ << "err2 := p.handler." << publicize(tfunction->get_name()) << "("; bool first = true; f_types_ << "ctx"; @@ -3069,7 +3073,9 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* } f_types_ << "); err2 != nil {" << endl; - f_types_ << indent() << " tickerCancel()" << endl; + indent_up(); + f_types_ << indent() << "tickerCancel()" << endl; + f_types_ << indent() << "err = thrift.WrapTException(err2)" << endl; t_struct* exceptions = tfunction->get_xceptions(); const vector<t_field*>& x_fields = exceptions->get_members(); @@ -3079,36 +3085,74 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* vector<t_field*>::const_iterator xf_iter; for (xf_iter = x_fields.begin(); xf_iter != x_fields.end(); ++xf_iter) { - f_types_ << indent() << " case " << type_to_go_type(((*xf_iter)->get_type())) << ":" + f_types_ << indent() << "case " << type_to_go_type(((*xf_iter)->get_type())) << ":" << endl; + indent_up(); f_types_ << indent() << "result." << publicize((*xf_iter)->get_name()) << " = v" << endl; + indent_down(); } - f_types_ << indent() << " default:" << endl; + f_types_ << indent() << "default:" << endl; + indent_up(); } if (!tfunction->is_oneway()) { // Avoid writing the error to the wire if it's ErrAbandonRequest - f_types_ << indent() << " if err2 == thrift.ErrAbandonRequest {" << endl; - f_types_ << indent() << " return false, thrift.WrapTException(err2)" << endl; - f_types_ << indent() << " }" << endl; + f_types_ << indent() << "if errors.Is(err2, thrift.ErrAbandonRequest) {" << endl; + indent_up(); + f_types_ << indent() << "return false, thrift.WrapTException(err2)" << endl; + indent_down(); + f_types_ << indent() << "}" << endl; - f_types_ << indent() << " x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, " + string exc(tmp("_exc")); + f_types_ << indent() << exc << " := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, " "\"Internal error processing " << escape_string(tfunction->get_name()) << ": \" + err2.Error())" << endl; - f_types_ << indent() << " oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name()) - << "\", thrift.EXCEPTION, seqId)" << endl; - f_types_ << indent() << " x.Write(ctx, oprot)" << endl; - f_types_ << indent() << " oprot.WriteMessageEnd(ctx)" << endl; - f_types_ << indent() << " oprot.Flush(ctx)" << endl; - } - f_types_ << indent() << " return true, thrift.WrapTException(err2)" << endl; + f_types_ << indent() << "if err2 := oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name()) + << "\", thrift.EXCEPTION, seqId); err2 != nil {" << endl; + indent_up(); + f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl; + indent_down(); + f_types_ << indent() << "}" << endl; - if (!x_fields.empty()) { + f_types_ << indent() << "if err2 := " << exc << ".Write(ctx, oprot); " + << write_err << " == nil && err2 != nil {" << endl; + indent_up(); + f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl; + indent_down(); + f_types_ << indent() << "}" << endl; + + f_types_ << indent() << "if err2 := oprot.WriteMessageEnd(ctx); " + << write_err << " == nil && err2 != nil {" << endl; + indent_up(); + f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl; + indent_down(); + f_types_ << indent() << "}" << endl; + + f_types_ << indent() << "if err2 := oprot.Flush(ctx); " + << write_err << " == nil && err2 != nil {" << endl; + indent_up(); + f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl; + indent_down(); + f_types_ << indent() << "}" << endl; + + f_types_ << indent() << "if " << write_err << " != nil {" << endl; + indent_up(); + f_types_ << indent() << "return false, thrift.WrapTException(" << write_err << ")" << endl; + indent_down(); f_types_ << indent() << "}" << endl; + + // return success=true as long as writing to the wire was successful. + f_types_ << indent() << "return true, err" << endl; } + if (!x_fields.empty()) { + indent_down(); + f_types_ << indent() << "}" << endl; // closes switch + } + + indent_down(); f_types_ << indent() << "}"; // closes err2 != nil if (!tfunction->is_oneway()) { @@ -3126,29 +3170,47 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* f_types_ << endl; } f_types_ << indent() << "tickerCancel()" << endl; - f_types_ << indent() << "if err2 = oprot.WriteMessageBegin(ctx, \"" + + f_types_ << indent() << "if err2 := oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name()) << "\", thrift.REPLY, seqId); err2 != nil {" << endl; - f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl; + indent_up(); + f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl; + indent_down(); f_types_ << indent() << "}" << endl; - f_types_ << indent() << "if err2 = result." << write_method_name_ << "(ctx, oprot); err == nil && err2 != nil {" << endl; - f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl; + + f_types_ << indent() << "if err2 := result." << write_method_name_ << "(ctx, oprot); " + << write_err << " == nil && err2 != nil {" << endl; + indent_up(); + f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl; + indent_down(); f_types_ << indent() << "}" << endl; - f_types_ << indent() << "if err2 = oprot.WriteMessageEnd(ctx); err == nil && err2 != nil {" - << endl; - f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl; + + f_types_ << indent() << "if err2 := oprot.WriteMessageEnd(ctx); " + << write_err << " == nil && err2 != nil {" << endl; + indent_up(); + f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl; + indent_down(); f_types_ << indent() << "}" << endl; - f_types_ << indent() << "if err2 = oprot.Flush(ctx); err == nil && err2 != nil {" << endl; - f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl; + + f_types_ << indent() << "if err2 := oprot.Flush(ctx); " << write_err << " == nil && err2 != nil {" << endl; + indent_up(); + f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl; + indent_down(); f_types_ << indent() << "}" << endl; - f_types_ << indent() << "if err != nil {" << endl; - f_types_ << indent() << " return" << endl; + + f_types_ << indent() << "if " << write_err << " != nil {" << endl; + indent_up(); + f_types_ << indent() << "return false, thrift.WrapTException(" << write_err << ")" << endl; + indent_down(); f_types_ << indent() << "}" << endl; + + // return success=true as long as writing to the wire was successful. f_types_ << indent() << "return true, err" << endl; } else { f_types_ << endl; f_types_ << indent() << "tickerCancel()" << endl; - f_types_ << indent() << "return true, nil" << endl; + f_types_ << indent() << "return true, err" << endl; } indent_down(); f_types_ << indent() << "}" << endl << endl; diff --git a/lib/go/test/Makefile.am b/lib/go/test/Makefile.am index 4b3ecda93..2cca411ac 100644 --- a/lib/go/test/Makefile.am +++ b/lib/go/test/Makefile.am @@ -52,7 +52,8 @@ gopath: $(THRIFT) $(THRIFTTEST) \ EqualsTest.thrift \ ConflictArgNamesTest.thrift \ ConstOptionalFieldImport.thrift \ - ConstOptionalField.thrift + ConstOptionalField.thrift \ + ProcessorMiddlewareTest.thrift mkdir -p gopath/src grep -v list.*map.*list.*map $(THRIFTTEST) | grep -v 'set<Insanity>' > ThriftTest.thrift $(THRIFT) $(THRIFTARGS) -r IncludesTest.thrift @@ -84,6 +85,7 @@ gopath: $(THRIFT) $(THRIFTTEST) \ $(THRIFT) $(THRIFTARGS) EqualsTest.thrift $(THRIFT) $(THRIFTARGS) ConflictArgNamesTest.thrift $(THRIFT) $(THRIFTARGS) -r ConstOptionalField.thrift + $(THRIFT) $(THRIFTARGS) ProcessorMiddlewareTest.thrift ln -nfs ../../tests gopath/src/tests cp -r ./dontexportrwtest gopath/src touch gopath @@ -106,7 +108,8 @@ check: gopath ./gopath/src/servicestest/container_test-remote \ ./gopath/src/duplicateimportstest \ ./gopath/src/equalstest \ - ./gopath/src/conflictargnamestest + ./gopath/src/conflictargnamestest \ + ./gopath/src/processormiddlewaretest $(GO) test -mod=mod github.com/apache/thrift/lib/go/thrift $(GO) test -mod=mod ./gopath/src/tests ./gopath/src/dontexportrwtest @@ -145,6 +148,7 @@ EXTRA_DIST = \ NamesTest.thrift \ OnewayTest.thrift \ OptionalFieldsTest.thrift \ + ProcessorMiddlewareTest.thrift \ RefAnnotationFieldsTest.thrift \ RequiredFieldTest.thrift \ ServicesTest.thrift \ diff --git a/lib/go/test/ProcessorMiddlewareTest.thrift b/lib/go/test/ProcessorMiddlewareTest.thrift new file mode 100644 index 000000000..2d4f5f4b8 --- /dev/null +++ b/lib/go/test/ProcessorMiddlewareTest.thrift @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * Contains some contributions under the Thrift Software License. + * Please see doc/old-thrift-license.txt in the Thrift distribution for + * details. + */ + +exception Error { + 1: optional string foo, +} + +service Service { + void ping() throws ( + 1: Error error, + ); +} diff --git a/lib/go/test/tests/processor_middleware_test.go b/lib/go/test/tests/processor_middleware_test.go new file mode 100644 index 000000000..1bd911cfe --- /dev/null +++ b/lib/go/test/tests/processor_middleware_test.go @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package tests + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/apache/thrift/lib/go/test/gopath/src/processormiddlewaretest" + "github.com/apache/thrift/lib/go/thrift" +) + +const errorMessage = "foo error" + +type serviceImpl struct{} + +func (serviceImpl) Ping(_ context.Context) (err error) { + return &processormiddlewaretest.Error{ + Foo: thrift.StringPtr(errorMessage), + } +} + +func middleware(t *testing.T) thrift.ProcessorMiddleware { + return func(name string, next thrift.TProcessorFunction) thrift.TProcessorFunction { + return thrift.WrappedTProcessorFunction{ + Wrapped: func(ctx context.Context, seqId int32, in, out thrift.TProtocol) (_ bool, err thrift.TException) { + defer func() { + checkError(t, err) + }() + return next.Process(ctx, seqId, in, out) + }, + } + } +} + +func checkError(tb testing.TB, err error) { + tb.Helper() + + var idlErr *processormiddlewaretest.Error + if !errors.As(err, &idlErr) { + tb.Errorf("expected error to be of type *processormiddlewaretest.Error, actual %T, %#v", err, err) + return + } + if actual := idlErr.GetFoo(); actual != errorMessage { + tb.Errorf("expected error message to be %q, actual %q", errorMessage, actual) + } +} + +func TestProcessorMiddleware(t *testing.T) { + const timeout = time.Second + + processor := processormiddlewaretest.NewServiceProcessor(&serviceImpl{}) + serverTransport, err := thrift.NewTServerSocket("127.0.0.1:0") + if err != nil { + t.Fatalf("Could not find available server port: %v", err) + } + server := thrift.NewTSimpleServer4( + thrift.WrapProcessor(processor, middleware(t)), + serverTransport, + thrift.NewTHeaderTransportFactoryConf(nil, nil), + thrift.NewTHeaderProtocolFactoryConf(nil), + ) + defer server.Stop() + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + server.Serve() + }() + + time.Sleep(10 * time.Millisecond) + + cfg := &thrift.TConfiguration{ + ConnectTimeout: timeout, + SocketTimeout: timeout, + } + transport := thrift.NewTSocketFromAddrConf(serverTransport.Addr(), cfg) + if err := transport.Open(); err != nil { + t.Fatalf("Could not open client transport: %v", err) + } + defer transport.Close() + protocol := thrift.NewTHeaderProtocolConf(transport, nil) + + client := processormiddlewaretest.NewServiceClient(thrift.NewTStandardClient(protocol, protocol)) + + err = client.Ping(context.Background()) + checkError(t, err) +} |