summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compiler/cpp/src/thrift/generate/t_go_generator.cc144
-rw-r--r--lib/go/test/Makefile.am8
-rw-r--r--lib/go/test/ProcessorMiddlewareTest.thrift32
-rw-r--r--lib/go/test/tests/processor_middleware_test.go108
4 files changed, 249 insertions, 43 deletions
diff --git a/compiler/cpp/src/thrift/generate/t_go_generator.cc b/compiler/cpp/src/thrift/generate/t_go_generator.cc
index 7897b621a..3b885f103 100644
--- a/compiler/cpp/src/thrift/generate/t_go_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_go_generator.cc
@@ -959,8 +959,8 @@ string t_go_generator::go_imports_begin(bool consts) {
// If not writing constants, and there are enums, need extra imports.
if (!consts && get_program()->get_enums().size() > 0) {
system_packages.push_back("database/sql/driver");
- system_packages.push_back("errors");
}
+ system_packages.push_back("errors");
system_packages.push_back("fmt");
system_packages.push_back("time");
// For the thrift import, always do rename import to make sure it's called thrift.
@@ -980,6 +980,7 @@ string t_go_generator::go_imports_end() {
"// (needed to ensure safety because of naive import list construction.)\n"
"var _ = thrift.ZERO\n"
"var _ = fmt.Printf\n"
+ "var _ = errors.New\n"
"var _ = context.Background\n"
"var _ = time.Now\n"
"var _ = bytes.Equal\n\n");
@@ -2964,21 +2965,27 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function*
<< ") Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err "
"thrift.TException) {" << endl;
indent_up();
+ string write_err;
+ if (!tfunction->is_oneway()) {
+ write_err = tmp("_write_err");
+ f_types_ << indent() << "var " << write_err << " error" << endl;
+ }
f_types_ << indent() << "args := " << argsname << "{}" << endl;
- f_types_ << indent() << "var err2 error" << endl;
- f_types_ << indent() << "if err2 = args." << read_method_name_ << "(ctx, iprot); err2 != nil {" << endl;
- f_types_ << indent() << " iprot.ReadMessageEnd(ctx)" << endl;
+ f_types_ << indent() << "if err2 := args." << read_method_name_ << "(ctx, iprot); err2 != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << "iprot.ReadMessageEnd(ctx)" << endl;
if (!tfunction->is_oneway()) {
f_types_ << indent()
- << " x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err2.Error())"
+ << "x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err2.Error())"
<< endl;
- f_types_ << indent() << " oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name())
+ f_types_ << indent() << "oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name())
<< "\", thrift.EXCEPTION, seqId)" << endl;
- f_types_ << indent() << " x.Write(ctx, oprot)" << endl;
- f_types_ << indent() << " oprot.WriteMessageEnd(ctx)" << endl;
- f_types_ << indent() << " oprot.Flush(ctx)" << endl;
+ f_types_ << indent() << "x.Write(ctx, oprot)" << endl;
+ f_types_ << indent() << "oprot.WriteMessageEnd(ctx)" << endl;
+ f_types_ << indent() << "oprot.Flush(ctx)" << endl;
}
- f_types_ << indent() << " return false, thrift.WrapTException(err2)" << endl;
+ f_types_ << indent() << "return false, thrift.WrapTException(err2)" << endl;
+ indent_down();
f_types_ << indent() << "}" << endl;
f_types_ << indent() << "iprot.ReadMessageEnd(ctx)" << endl << endl;
@@ -3037,9 +3044,6 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function*
f_types_ << indent() << "result := " << resultname << "{}" << endl;
}
bool need_reference = type_need_reference(tfunction->get_returntype());
- if (!tfunction->is_oneway() && !tfunction->get_returntype()->is_void()) {
- f_types_ << indent() << "var retval " << type_to_go_type(tfunction->get_returntype()) << endl;
- }
f_types_ << indent() << "if ";
@@ -3053,7 +3057,7 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function*
t_struct* arg_struct = tfunction->get_arglist();
const std::vector<t_field*>& fields = arg_struct->get_members();
vector<t_field*>::const_iterator f_iter;
- f_types_ << "err2 = p.handler." << publicize(tfunction->get_name()) << "(";
+ f_types_ << "err2 := p.handler." << publicize(tfunction->get_name()) << "(";
bool first = true;
f_types_ << "ctx";
@@ -3069,7 +3073,9 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function*
}
f_types_ << "); err2 != nil {" << endl;
- f_types_ << indent() << " tickerCancel()" << endl;
+ indent_up();
+ f_types_ << indent() << "tickerCancel()" << endl;
+ f_types_ << indent() << "err = thrift.WrapTException(err2)" << endl;
t_struct* exceptions = tfunction->get_xceptions();
const vector<t_field*>& x_fields = exceptions->get_members();
@@ -3079,36 +3085,74 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function*
vector<t_field*>::const_iterator xf_iter;
for (xf_iter = x_fields.begin(); xf_iter != x_fields.end(); ++xf_iter) {
- f_types_ << indent() << " case " << type_to_go_type(((*xf_iter)->get_type())) << ":"
+ f_types_ << indent() << "case " << type_to_go_type(((*xf_iter)->get_type())) << ":"
<< endl;
+ indent_up();
f_types_ << indent() << "result." << publicize((*xf_iter)->get_name()) << " = v" << endl;
+ indent_down();
}
- f_types_ << indent() << " default:" << endl;
+ f_types_ << indent() << "default:" << endl;
+ indent_up();
}
if (!tfunction->is_oneway()) {
// Avoid writing the error to the wire if it's ErrAbandonRequest
- f_types_ << indent() << " if err2 == thrift.ErrAbandonRequest {" << endl;
- f_types_ << indent() << " return false, thrift.WrapTException(err2)" << endl;
- f_types_ << indent() << " }" << endl;
+ f_types_ << indent() << "if errors.Is(err2, thrift.ErrAbandonRequest) {" << endl;
+ indent_up();
+ f_types_ << indent() << "return false, thrift.WrapTException(err2)" << endl;
+ indent_down();
+ f_types_ << indent() << "}" << endl;
- f_types_ << indent() << " x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "
+ string exc(tmp("_exc"));
+ f_types_ << indent() << exc << " := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "
"\"Internal error processing " << escape_string(tfunction->get_name())
<< ": \" + err2.Error())" << endl;
- f_types_ << indent() << " oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name())
- << "\", thrift.EXCEPTION, seqId)" << endl;
- f_types_ << indent() << " x.Write(ctx, oprot)" << endl;
- f_types_ << indent() << " oprot.WriteMessageEnd(ctx)" << endl;
- f_types_ << indent() << " oprot.Flush(ctx)" << endl;
- }
- f_types_ << indent() << " return true, thrift.WrapTException(err2)" << endl;
+ f_types_ << indent() << "if err2 := oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name())
+ << "\", thrift.EXCEPTION, seqId); err2 != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+ indent_down();
+ f_types_ << indent() << "}" << endl;
- if (!x_fields.empty()) {
+ f_types_ << indent() << "if err2 := " << exc << ".Write(ctx, oprot); "
+ << write_err << " == nil && err2 != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+ indent_down();
+ f_types_ << indent() << "}" << endl;
+
+ f_types_ << indent() << "if err2 := oprot.WriteMessageEnd(ctx); "
+ << write_err << " == nil && err2 != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+ indent_down();
+ f_types_ << indent() << "}" << endl;
+
+ f_types_ << indent() << "if err2 := oprot.Flush(ctx); "
+ << write_err << " == nil && err2 != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+ indent_down();
+ f_types_ << indent() << "}" << endl;
+
+ f_types_ << indent() << "if " << write_err << " != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << "return false, thrift.WrapTException(" << write_err << ")" << endl;
+ indent_down();
f_types_ << indent() << "}" << endl;
+
+ // return success=true as long as writing to the wire was successful.
+ f_types_ << indent() << "return true, err" << endl;
}
+ if (!x_fields.empty()) {
+ indent_down();
+ f_types_ << indent() << "}" << endl; // closes switch
+ }
+
+ indent_down();
f_types_ << indent() << "}"; // closes err2 != nil
if (!tfunction->is_oneway()) {
@@ -3126,29 +3170,47 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function*
f_types_ << endl;
}
f_types_ << indent() << "tickerCancel()" << endl;
- f_types_ << indent() << "if err2 = oprot.WriteMessageBegin(ctx, \""
+
+ f_types_ << indent() << "if err2 := oprot.WriteMessageBegin(ctx, \""
<< escape_string(tfunction->get_name()) << "\", thrift.REPLY, seqId); err2 != nil {"
<< endl;
- f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl;
+ indent_up();
+ f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+ indent_down();
f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "if err2 = result." << write_method_name_ << "(ctx, oprot); err == nil && err2 != nil {" << endl;
- f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl;
+
+ f_types_ << indent() << "if err2 := result." << write_method_name_ << "(ctx, oprot); "
+ << write_err << " == nil && err2 != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+ indent_down();
f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "if err2 = oprot.WriteMessageEnd(ctx); err == nil && err2 != nil {"
- << endl;
- f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl;
+
+ f_types_ << indent() << "if err2 := oprot.WriteMessageEnd(ctx); "
+ << write_err << " == nil && err2 != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+ indent_down();
f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "if err2 = oprot.Flush(ctx); err == nil && err2 != nil {" << endl;
- f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl;
+
+ f_types_ << indent() << "if err2 := oprot.Flush(ctx); " << write_err << " == nil && err2 != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+ indent_down();
f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "if err != nil {" << endl;
- f_types_ << indent() << " return" << endl;
+
+ f_types_ << indent() << "if " << write_err << " != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << "return false, thrift.WrapTException(" << write_err << ")" << endl;
+ indent_down();
f_types_ << indent() << "}" << endl;
+
+ // return success=true as long as writing to the wire was successful.
f_types_ << indent() << "return true, err" << endl;
} else {
f_types_ << endl;
f_types_ << indent() << "tickerCancel()" << endl;
- f_types_ << indent() << "return true, nil" << endl;
+ f_types_ << indent() << "return true, err" << endl;
}
indent_down();
f_types_ << indent() << "}" << endl << endl;
diff --git a/lib/go/test/Makefile.am b/lib/go/test/Makefile.am
index 4b3ecda93..2cca411ac 100644
--- a/lib/go/test/Makefile.am
+++ b/lib/go/test/Makefile.am
@@ -52,7 +52,8 @@ gopath: $(THRIFT) $(THRIFTTEST) \
EqualsTest.thrift \
ConflictArgNamesTest.thrift \
ConstOptionalFieldImport.thrift \
- ConstOptionalField.thrift
+ ConstOptionalField.thrift \
+ ProcessorMiddlewareTest.thrift
mkdir -p gopath/src
grep -v list.*map.*list.*map $(THRIFTTEST) | grep -v 'set<Insanity>' > ThriftTest.thrift
$(THRIFT) $(THRIFTARGS) -r IncludesTest.thrift
@@ -84,6 +85,7 @@ gopath: $(THRIFT) $(THRIFTTEST) \
$(THRIFT) $(THRIFTARGS) EqualsTest.thrift
$(THRIFT) $(THRIFTARGS) ConflictArgNamesTest.thrift
$(THRIFT) $(THRIFTARGS) -r ConstOptionalField.thrift
+ $(THRIFT) $(THRIFTARGS) ProcessorMiddlewareTest.thrift
ln -nfs ../../tests gopath/src/tests
cp -r ./dontexportrwtest gopath/src
touch gopath
@@ -106,7 +108,8 @@ check: gopath
./gopath/src/servicestest/container_test-remote \
./gopath/src/duplicateimportstest \
./gopath/src/equalstest \
- ./gopath/src/conflictargnamestest
+ ./gopath/src/conflictargnamestest \
+ ./gopath/src/processormiddlewaretest
$(GO) test -mod=mod github.com/apache/thrift/lib/go/thrift
$(GO) test -mod=mod ./gopath/src/tests ./gopath/src/dontexportrwtest
@@ -145,6 +148,7 @@ EXTRA_DIST = \
NamesTest.thrift \
OnewayTest.thrift \
OptionalFieldsTest.thrift \
+ ProcessorMiddlewareTest.thrift \
RefAnnotationFieldsTest.thrift \
RequiredFieldTest.thrift \
ServicesTest.thrift \
diff --git a/lib/go/test/ProcessorMiddlewareTest.thrift b/lib/go/test/ProcessorMiddlewareTest.thrift
new file mode 100644
index 000000000..2d4f5f4b8
--- /dev/null
+++ b/lib/go/test/ProcessorMiddlewareTest.thrift
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ * Contains some contributions under the Thrift Software License.
+ * Please see doc/old-thrift-license.txt in the Thrift distribution for
+ * details.
+ */
+
+exception Error {
+ 1: optional string foo,
+}
+
+service Service {
+ void ping() throws (
+ 1: Error error,
+ );
+}
diff --git a/lib/go/test/tests/processor_middleware_test.go b/lib/go/test/tests/processor_middleware_test.go
new file mode 100644
index 000000000..1bd911cfe
--- /dev/null
+++ b/lib/go/test/tests/processor_middleware_test.go
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package tests
+
+import (
+ "context"
+ "errors"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/apache/thrift/lib/go/test/gopath/src/processormiddlewaretest"
+ "github.com/apache/thrift/lib/go/thrift"
+)
+
+const errorMessage = "foo error"
+
+type serviceImpl struct{}
+
+func (serviceImpl) Ping(_ context.Context) (err error) {
+ return &processormiddlewaretest.Error{
+ Foo: thrift.StringPtr(errorMessage),
+ }
+}
+
+func middleware(t *testing.T) thrift.ProcessorMiddleware {
+ return func(name string, next thrift.TProcessorFunction) thrift.TProcessorFunction {
+ return thrift.WrappedTProcessorFunction{
+ Wrapped: func(ctx context.Context, seqId int32, in, out thrift.TProtocol) (_ bool, err thrift.TException) {
+ defer func() {
+ checkError(t, err)
+ }()
+ return next.Process(ctx, seqId, in, out)
+ },
+ }
+ }
+}
+
+func checkError(tb testing.TB, err error) {
+ tb.Helper()
+
+ var idlErr *processormiddlewaretest.Error
+ if !errors.As(err, &idlErr) {
+ tb.Errorf("expected error to be of type *processormiddlewaretest.Error, actual %T, %#v", err, err)
+ return
+ }
+ if actual := idlErr.GetFoo(); actual != errorMessage {
+ tb.Errorf("expected error message to be %q, actual %q", errorMessage, actual)
+ }
+}
+
+func TestProcessorMiddleware(t *testing.T) {
+ const timeout = time.Second
+
+ processor := processormiddlewaretest.NewServiceProcessor(&serviceImpl{})
+ serverTransport, err := thrift.NewTServerSocket("127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("Could not find available server port: %v", err)
+ }
+ server := thrift.NewTSimpleServer4(
+ thrift.WrapProcessor(processor, middleware(t)),
+ serverTransport,
+ thrift.NewTHeaderTransportFactoryConf(nil, nil),
+ thrift.NewTHeaderProtocolFactoryConf(nil),
+ )
+ defer server.Stop()
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ server.Serve()
+ }()
+
+ time.Sleep(10 * time.Millisecond)
+
+ cfg := &thrift.TConfiguration{
+ ConnectTimeout: timeout,
+ SocketTimeout: timeout,
+ }
+ transport := thrift.NewTSocketFromAddrConf(serverTransport.Addr(), cfg)
+ if err := transport.Open(); err != nil {
+ t.Fatalf("Could not open client transport: %v", err)
+ }
+ defer transport.Close()
+ protocol := thrift.NewTHeaderProtocolConf(transport, nil)
+
+ client := processormiddlewaretest.NewServiceClient(thrift.NewTStandardClient(protocol, protocol))
+
+ err = client.Ping(context.Background())
+ checkError(t, err)
+}