summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYuxuan 'fishy' Wang <yuxuan.wang@reddit.com>2022-02-22 18:48:17 -0800
committerYuxuan 'fishy' Wang <fishywang@gmail.com>2022-02-23 09:17:50 -0800
commit9bee877e663f11f4cbdd3a4f02938c8ab9fe8976 (patch)
tree151cdff0268536208f0eb1cebfdf1f2ec968e869
parent103a11c9c28ac963a3b2591ecac641db3cbaa113 (diff)
downloadthrift-9bee877e663f11f4cbdd3a4f02938c8ab9fe8976.tar.gz
THRIFT-5527: Don't swallow idl exceptions in Process function
Client: go This allows ProcessorMiddlewares to access such exceptions, unless there's a network error writing the response (which takes priority). While I'm here, also make the indentation of Process function more consistent, and make it consistent on returning false and an error when the reading/writing fails.
-rw-r--r--compiler/cpp/src/thrift/generate/t_go_generator.cc144
-rw-r--r--lib/go/test/Makefile.am8
-rw-r--r--lib/go/test/ProcessorMiddlewareTest.thrift32
-rw-r--r--lib/go/test/tests/processor_middleware_test.go108
4 files changed, 249 insertions, 43 deletions
diff --git a/compiler/cpp/src/thrift/generate/t_go_generator.cc b/compiler/cpp/src/thrift/generate/t_go_generator.cc
index 7897b621a..3b885f103 100644
--- a/compiler/cpp/src/thrift/generate/t_go_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_go_generator.cc
@@ -959,8 +959,8 @@ string t_go_generator::go_imports_begin(bool consts) {
// If not writing constants, and there are enums, need extra imports.
if (!consts && get_program()->get_enums().size() > 0) {
system_packages.push_back("database/sql/driver");
- system_packages.push_back("errors");
}
+ system_packages.push_back("errors");
system_packages.push_back("fmt");
system_packages.push_back("time");
// For the thrift import, always do rename import to make sure it's called thrift.
@@ -980,6 +980,7 @@ string t_go_generator::go_imports_end() {
"// (needed to ensure safety because of naive import list construction.)\n"
"var _ = thrift.ZERO\n"
"var _ = fmt.Printf\n"
+ "var _ = errors.New\n"
"var _ = context.Background\n"
"var _ = time.Now\n"
"var _ = bytes.Equal\n\n");
@@ -2964,21 +2965,27 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function*
<< ") Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err "
"thrift.TException) {" << endl;
indent_up();
+ string write_err;
+ if (!tfunction->is_oneway()) {
+ write_err = tmp("_write_err");
+ f_types_ << indent() << "var " << write_err << " error" << endl;
+ }
f_types_ << indent() << "args := " << argsname << "{}" << endl;
- f_types_ << indent() << "var err2 error" << endl;
- f_types_ << indent() << "if err2 = args." << read_method_name_ << "(ctx, iprot); err2 != nil {" << endl;
- f_types_ << indent() << " iprot.ReadMessageEnd(ctx)" << endl;
+ f_types_ << indent() << "if err2 := args." << read_method_name_ << "(ctx, iprot); err2 != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << "iprot.ReadMessageEnd(ctx)" << endl;
if (!tfunction->is_oneway()) {
f_types_ << indent()
- << " x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err2.Error())"
+ << "x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err2.Error())"
<< endl;
- f_types_ << indent() << " oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name())
+ f_types_ << indent() << "oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name())
<< "\", thrift.EXCEPTION, seqId)" << endl;
- f_types_ << indent() << " x.Write(ctx, oprot)" << endl;
- f_types_ << indent() << " oprot.WriteMessageEnd(ctx)" << endl;
- f_types_ << indent() << " oprot.Flush(ctx)" << endl;
+ f_types_ << indent() << "x.Write(ctx, oprot)" << endl;
+ f_types_ << indent() << "oprot.WriteMessageEnd(ctx)" << endl;
+ f_types_ << indent() << "oprot.Flush(ctx)" << endl;
}
- f_types_ << indent() << " return false, thrift.WrapTException(err2)" << endl;
+ f_types_ << indent() << "return false, thrift.WrapTException(err2)" << endl;
+ indent_down();
f_types_ << indent() << "}" << endl;
f_types_ << indent() << "iprot.ReadMessageEnd(ctx)" << endl << endl;
@@ -3037,9 +3044,6 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function*
f_types_ << indent() << "result := " << resultname << "{}" << endl;
}
bool need_reference = type_need_reference(tfunction->get_returntype());
- if (!tfunction->is_oneway() && !tfunction->get_returntype()->is_void()) {
- f_types_ << indent() << "var retval " << type_to_go_type(tfunction->get_returntype()) << endl;
- }
f_types_ << indent() << "if ";
@@ -3053,7 +3057,7 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function*
t_struct* arg_struct = tfunction->get_arglist();
const std::vector<t_field*>& fields = arg_struct->get_members();
vector<t_field*>::const_iterator f_iter;
- f_types_ << "err2 = p.handler." << publicize(tfunction->get_name()) << "(";
+ f_types_ << "err2 := p.handler." << publicize(tfunction->get_name()) << "(";
bool first = true;
f_types_ << "ctx";
@@ -3069,7 +3073,9 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function*
}
f_types_ << "); err2 != nil {" << endl;
- f_types_ << indent() << " tickerCancel()" << endl;
+ indent_up();
+ f_types_ << indent() << "tickerCancel()" << endl;
+ f_types_ << indent() << "err = thrift.WrapTException(err2)" << endl;
t_struct* exceptions = tfunction->get_xceptions();
const vector<t_field*>& x_fields = exceptions->get_members();
@@ -3079,36 +3085,74 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function*
vector<t_field*>::const_iterator xf_iter;
for (xf_iter = x_fields.begin(); xf_iter != x_fields.end(); ++xf_iter) {
- f_types_ << indent() << " case " << type_to_go_type(((*xf_iter)->get_type())) << ":"
+ f_types_ << indent() << "case " << type_to_go_type(((*xf_iter)->get_type())) << ":"
<< endl;
+ indent_up();
f_types_ << indent() << "result." << publicize((*xf_iter)->get_name()) << " = v" << endl;
+ indent_down();
}
- f_types_ << indent() << " default:" << endl;
+ f_types_ << indent() << "default:" << endl;
+ indent_up();
}
if (!tfunction->is_oneway()) {
// Avoid writing the error to the wire if it's ErrAbandonRequest
- f_types_ << indent() << " if err2 == thrift.ErrAbandonRequest {" << endl;
- f_types_ << indent() << " return false, thrift.WrapTException(err2)" << endl;
- f_types_ << indent() << " }" << endl;
+ f_types_ << indent() << "if errors.Is(err2, thrift.ErrAbandonRequest) {" << endl;
+ indent_up();
+ f_types_ << indent() << "return false, thrift.WrapTException(err2)" << endl;
+ indent_down();
+ f_types_ << indent() << "}" << endl;
- f_types_ << indent() << " x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "
+ string exc(tmp("_exc"));
+ f_types_ << indent() << exc << " := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "
"\"Internal error processing " << escape_string(tfunction->get_name())
<< ": \" + err2.Error())" << endl;
- f_types_ << indent() << " oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name())
- << "\", thrift.EXCEPTION, seqId)" << endl;
- f_types_ << indent() << " x.Write(ctx, oprot)" << endl;
- f_types_ << indent() << " oprot.WriteMessageEnd(ctx)" << endl;
- f_types_ << indent() << " oprot.Flush(ctx)" << endl;
- }
- f_types_ << indent() << " return true, thrift.WrapTException(err2)" << endl;
+ f_types_ << indent() << "if err2 := oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name())
+ << "\", thrift.EXCEPTION, seqId); err2 != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+ indent_down();
+ f_types_ << indent() << "}" << endl;
- if (!x_fields.empty()) {
+ f_types_ << indent() << "if err2 := " << exc << ".Write(ctx, oprot); "
+ << write_err << " == nil && err2 != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+ indent_down();
+ f_types_ << indent() << "}" << endl;
+
+ f_types_ << indent() << "if err2 := oprot.WriteMessageEnd(ctx); "
+ << write_err << " == nil && err2 != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+ indent_down();
+ f_types_ << indent() << "}" << endl;
+
+ f_types_ << indent() << "if err2 := oprot.Flush(ctx); "
+ << write_err << " == nil && err2 != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+ indent_down();
+ f_types_ << indent() << "}" << endl;
+
+ f_types_ << indent() << "if " << write_err << " != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << "return false, thrift.WrapTException(" << write_err << ")" << endl;
+ indent_down();
f_types_ << indent() << "}" << endl;
+
+ // return success=true as long as writing to the wire was successful.
+ f_types_ << indent() << "return true, err" << endl;
}
+ if (!x_fields.empty()) {
+ indent_down();
+ f_types_ << indent() << "}" << endl; // closes switch
+ }
+
+ indent_down();
f_types_ << indent() << "}"; // closes err2 != nil
if (!tfunction->is_oneway()) {
@@ -3126,29 +3170,47 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function*
f_types_ << endl;
}
f_types_ << indent() << "tickerCancel()" << endl;
- f_types_ << indent() << "if err2 = oprot.WriteMessageBegin(ctx, \""
+
+ f_types_ << indent() << "if err2 := oprot.WriteMessageBegin(ctx, \""
<< escape_string(tfunction->get_name()) << "\", thrift.REPLY, seqId); err2 != nil {"
<< endl;
- f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl;
+ indent_up();
+ f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+ indent_down();
f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "if err2 = result." << write_method_name_ << "(ctx, oprot); err == nil && err2 != nil {" << endl;
- f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl;
+
+ f_types_ << indent() << "if err2 := result." << write_method_name_ << "(ctx, oprot); "
+ << write_err << " == nil && err2 != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+ indent_down();
f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "if err2 = oprot.WriteMessageEnd(ctx); err == nil && err2 != nil {"
- << endl;
- f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl;
+
+ f_types_ << indent() << "if err2 := oprot.WriteMessageEnd(ctx); "
+ << write_err << " == nil && err2 != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+ indent_down();
f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "if err2 = oprot.Flush(ctx); err == nil && err2 != nil {" << endl;
- f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl;
+
+ f_types_ << indent() << "if err2 := oprot.Flush(ctx); " << write_err << " == nil && err2 != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+ indent_down();
f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "if err != nil {" << endl;
- f_types_ << indent() << " return" << endl;
+
+ f_types_ << indent() << "if " << write_err << " != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << "return false, thrift.WrapTException(" << write_err << ")" << endl;
+ indent_down();
f_types_ << indent() << "}" << endl;
+
+ // return success=true as long as writing to the wire was successful.
f_types_ << indent() << "return true, err" << endl;
} else {
f_types_ << endl;
f_types_ << indent() << "tickerCancel()" << endl;
- f_types_ << indent() << "return true, nil" << endl;
+ f_types_ << indent() << "return true, err" << endl;
}
indent_down();
f_types_ << indent() << "}" << endl << endl;
diff --git a/lib/go/test/Makefile.am b/lib/go/test/Makefile.am
index 4b3ecda93..2cca411ac 100644
--- a/lib/go/test/Makefile.am
+++ b/lib/go/test/Makefile.am
@@ -52,7 +52,8 @@ gopath: $(THRIFT) $(THRIFTTEST) \
EqualsTest.thrift \
ConflictArgNamesTest.thrift \
ConstOptionalFieldImport.thrift \
- ConstOptionalField.thrift
+ ConstOptionalField.thrift \
+ ProcessorMiddlewareTest.thrift
mkdir -p gopath/src
grep -v list.*map.*list.*map $(THRIFTTEST) | grep -v 'set<Insanity>' > ThriftTest.thrift
$(THRIFT) $(THRIFTARGS) -r IncludesTest.thrift
@@ -84,6 +85,7 @@ gopath: $(THRIFT) $(THRIFTTEST) \
$(THRIFT) $(THRIFTARGS) EqualsTest.thrift
$(THRIFT) $(THRIFTARGS) ConflictArgNamesTest.thrift
$(THRIFT) $(THRIFTARGS) -r ConstOptionalField.thrift
+ $(THRIFT) $(THRIFTARGS) ProcessorMiddlewareTest.thrift
ln -nfs ../../tests gopath/src/tests
cp -r ./dontexportrwtest gopath/src
touch gopath
@@ -106,7 +108,8 @@ check: gopath
./gopath/src/servicestest/container_test-remote \
./gopath/src/duplicateimportstest \
./gopath/src/equalstest \
- ./gopath/src/conflictargnamestest
+ ./gopath/src/conflictargnamestest \
+ ./gopath/src/processormiddlewaretest
$(GO) test -mod=mod github.com/apache/thrift/lib/go/thrift
$(GO) test -mod=mod ./gopath/src/tests ./gopath/src/dontexportrwtest
@@ -145,6 +148,7 @@ EXTRA_DIST = \
NamesTest.thrift \
OnewayTest.thrift \
OptionalFieldsTest.thrift \
+ ProcessorMiddlewareTest.thrift \
RefAnnotationFieldsTest.thrift \
RequiredFieldTest.thrift \
ServicesTest.thrift \
diff --git a/lib/go/test/ProcessorMiddlewareTest.thrift b/lib/go/test/ProcessorMiddlewareTest.thrift
new file mode 100644
index 000000000..2d4f5f4b8
--- /dev/null
+++ b/lib/go/test/ProcessorMiddlewareTest.thrift
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ * Contains some contributions under the Thrift Software License.
+ * Please see doc/old-thrift-license.txt in the Thrift distribution for
+ * details.
+ */
+
+exception Error {
+ 1: optional string foo,
+}
+
+service Service {
+ void ping() throws (
+ 1: Error error,
+ );
+}
diff --git a/lib/go/test/tests/processor_middleware_test.go b/lib/go/test/tests/processor_middleware_test.go
new file mode 100644
index 000000000..1bd911cfe
--- /dev/null
+++ b/lib/go/test/tests/processor_middleware_test.go
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package tests
+
+import (
+ "context"
+ "errors"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/apache/thrift/lib/go/test/gopath/src/processormiddlewaretest"
+ "github.com/apache/thrift/lib/go/thrift"
+)
+
+const errorMessage = "foo error"
+
+type serviceImpl struct{}
+
+func (serviceImpl) Ping(_ context.Context) (err error) {
+ return &processormiddlewaretest.Error{
+ Foo: thrift.StringPtr(errorMessage),
+ }
+}
+
+func middleware(t *testing.T) thrift.ProcessorMiddleware {
+ return func(name string, next thrift.TProcessorFunction) thrift.TProcessorFunction {
+ return thrift.WrappedTProcessorFunction{
+ Wrapped: func(ctx context.Context, seqId int32, in, out thrift.TProtocol) (_ bool, err thrift.TException) {
+ defer func() {
+ checkError(t, err)
+ }()
+ return next.Process(ctx, seqId, in, out)
+ },
+ }
+ }
+}
+
+func checkError(tb testing.TB, err error) {
+ tb.Helper()
+
+ var idlErr *processormiddlewaretest.Error
+ if !errors.As(err, &idlErr) {
+ tb.Errorf("expected error to be of type *processormiddlewaretest.Error, actual %T, %#v", err, err)
+ return
+ }
+ if actual := idlErr.GetFoo(); actual != errorMessage {
+ tb.Errorf("expected error message to be %q, actual %q", errorMessage, actual)
+ }
+}
+
+func TestProcessorMiddleware(t *testing.T) {
+ const timeout = time.Second
+
+ processor := processormiddlewaretest.NewServiceProcessor(&serviceImpl{})
+ serverTransport, err := thrift.NewTServerSocket("127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("Could not find available server port: %v", err)
+ }
+ server := thrift.NewTSimpleServer4(
+ thrift.WrapProcessor(processor, middleware(t)),
+ serverTransport,
+ thrift.NewTHeaderTransportFactoryConf(nil, nil),
+ thrift.NewTHeaderProtocolFactoryConf(nil),
+ )
+ defer server.Stop()
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ server.Serve()
+ }()
+
+ time.Sleep(10 * time.Millisecond)
+
+ cfg := &thrift.TConfiguration{
+ ConnectTimeout: timeout,
+ SocketTimeout: timeout,
+ }
+ transport := thrift.NewTSocketFromAddrConf(serverTransport.Addr(), cfg)
+ if err := transport.Open(); err != nil {
+ t.Fatalf("Could not open client transport: %v", err)
+ }
+ defer transport.Close()
+ protocol := thrift.NewTHeaderProtocolConf(transport, nil)
+
+ client := processormiddlewaretest.NewServiceClient(thrift.NewTStandardClient(protocol, protocol))
+
+ err = client.Ping(context.Background())
+ checkError(t, err)
+}