diff options
Diffstat (limited to 'lib/go/test/tests/client_middleware_exception_test.go')
-rw-r--r-- | lib/go/test/tests/client_middleware_exception_test.go | 189 |
1 files changed, 189 insertions, 0 deletions
diff --git a/lib/go/test/tests/client_middleware_exception_test.go b/lib/go/test/tests/client_middleware_exception_test.go new file mode 100644 index 000000000..5cb42ab8b --- /dev/null +++ b/lib/go/test/tests/client_middleware_exception_test.go @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package tests + +import ( + "context" + "errors" + "testing" + + "github.com/apache/thrift/lib/go/test/gopath/src/clientmiddlewareexceptiontest" + "github.com/apache/thrift/lib/go/thrift" +) + +type fakeClientMiddlewareExceptionTestHandler func(ctx context.Context) (*clientmiddlewareexceptiontest.FooResponse, error) + +func (f fakeClientMiddlewareExceptionTestHandler) Foo(ctx context.Context) (*clientmiddlewareexceptiontest.FooResponse, error) { + return f(ctx) +} + +type clientMiddlewareErrorChecker func(err error) error + +var clientMiddlewareExceptionCases = []struct { + label string + handler fakeClientMiddlewareExceptionTestHandler + checker clientMiddlewareErrorChecker +}{ + { + label: "no-error", + handler: func(_ context.Context) (*clientmiddlewareexceptiontest.FooResponse, error) { + return new(clientmiddlewareexceptiontest.FooResponse), nil + }, + checker: func(err error) error { + if err != nil { + return errors.New("expected err to be nil") + } + return nil + }, + }, + { + label: "exception-1", + handler: func(_ context.Context) (*clientmiddlewareexceptiontest.FooResponse, error) { + return nil, new(clientmiddlewareexceptiontest.Exception1) + }, + checker: func(err error) error { + if !errors.As(err, new(*clientmiddlewareexceptiontest.Exception1)) { + return errors.New("expected err to be of type *clientmiddlewareexceptiontest.Exception1") + } + return nil + }, + }, + { + label: "no-error", + handler: func(_ context.Context) (*clientmiddlewareexceptiontest.FooResponse, error) { + return nil, new(clientmiddlewareexceptiontest.Exception2) + }, + checker: func(err error) error { + if !errors.As(err, new(*clientmiddlewareexceptiontest.Exception2)) { + return errors.New("expected err to be of type *clientmiddlewareexceptiontest.Exception2") + } + return nil + }, + }, +} + +func TestClientMiddlewareException(t *testing.T) { + for _, c := range clientMiddlewareExceptionCases { + t.Run(c.label, func(t *testing.T) { + serverSocket, err := thrift.NewTServerSocket(":0") + if err != nil { + t.Fatalf("failed to create server socket: %v", err) + } + processor := clientmiddlewareexceptiontest.NewClientMiddlewareExceptionTestProcessor(c.handler) + server := thrift.NewTSimpleServer2(processor, serverSocket) + if err := server.Listen(); err != nil { + t.Fatalf("failed to listen server: %v", err) + } + addr := serverSocket.Addr().String() + go server.Serve() + t.Cleanup(func() { + server.Stop() + }) + + var cfg *thrift.TConfiguration + socket := thrift.NewTSocketConf(addr, cfg) + if err := socket.Open(); err != nil { + t.Fatalf("failed to create client connection: %v", err) + } + t.Cleanup(func() { + socket.Close() + }) + inProtocol := thrift.NewTBinaryProtocolConf(socket, cfg) + outProtocol := thrift.NewTBinaryProtocolConf(socket, cfg) + middleware := func(next thrift.TClient) thrift.TClient { + return thrift.WrappedTClient{ + Wrapped: func(ctx context.Context, method string, args, result thrift.TStruct) (_ thrift.ResponseMeta, err error) { + defer func() { + if checkErr := c.checker(err); checkErr != nil { + t.Errorf("middleware result unexpected: %v (result=%#v, err=%#v)", checkErr, result, err) + } + }() + return next.Call(ctx, method, args, result) + }, + } + } + client := thrift.WrapClient( + thrift.NewTStandardClient(inProtocol, outProtocol), + middleware, + thrift.ExtractIDLExceptionClientMiddleware, + ) + result, err := clientmiddlewareexceptiontest.NewClientMiddlewareExceptionTestClient(client).Foo(context.Background()) + if checkErr := c.checker(err); checkErr != nil { + t.Errorf("final result unexpected: %v (result=%#v, err=%#v)", checkErr, result, err) + } + }) + } +} + +func TestExtractExceptionFromResult(t *testing.T) { + + for _, c := range clientMiddlewareExceptionCases { + t.Run(c.label, func(t *testing.T) { + serverSocket, err := thrift.NewTServerSocket(":0") + if err != nil { + t.Fatalf("failed to create server socket: %v", err) + } + processor := clientmiddlewareexceptiontest.NewClientMiddlewareExceptionTestProcessor(c.handler) + server := thrift.NewTSimpleServer2(processor, serverSocket) + if err := server.Listen(); err != nil { + t.Fatalf("failed to listen server: %v", err) + } + addr := serverSocket.Addr().String() + go server.Serve() + t.Cleanup(func() { + server.Stop() + }) + + var cfg *thrift.TConfiguration + socket := thrift.NewTSocketConf(addr, cfg) + if err := socket.Open(); err != nil { + t.Fatalf("failed to create client connection: %v", err) + } + t.Cleanup(func() { + socket.Close() + }) + inProtocol := thrift.NewTBinaryProtocolConf(socket, cfg) + outProtocol := thrift.NewTBinaryProtocolConf(socket, cfg) + middleware := func(next thrift.TClient) thrift.TClient { + return thrift.WrappedTClient{ + Wrapped: func(ctx context.Context, method string, args, result thrift.TStruct) (_ thrift.ResponseMeta, err error) { + defer func() { + if err == nil { + err = thrift.ExtractExceptionFromResult(result) + } + if checkErr := c.checker(err); checkErr != nil { + t.Errorf("middleware result unexpected: %v (result=%#v, err=%#v)", checkErr, result, err) + } + }() + return next.Call(ctx, method, args, result) + }, + } + } + client := thrift.WrapClient( + thrift.NewTStandardClient(inProtocol, outProtocol), + middleware, + ) + result, err := clientmiddlewareexceptiontest.NewClientMiddlewareExceptionTestClient(client).Foo(context.Background()) + if checkErr := c.checker(err); checkErr != nil { + t.Errorf("final result unexpected: %v (result=%#v, err=%#v)", checkErr, result, err) + } + }) + } +} |