summaryrefslogtreecommitdiff
path: root/lib/go/test/tests/client_middleware_exception_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'lib/go/test/tests/client_middleware_exception_test.go')
-rw-r--r--lib/go/test/tests/client_middleware_exception_test.go189
1 files changed, 189 insertions, 0 deletions
diff --git a/lib/go/test/tests/client_middleware_exception_test.go b/lib/go/test/tests/client_middleware_exception_test.go
new file mode 100644
index 000000000..5cb42ab8b
--- /dev/null
+++ b/lib/go/test/tests/client_middleware_exception_test.go
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package tests
+
+import (
+ "context"
+ "errors"
+ "testing"
+
+ "github.com/apache/thrift/lib/go/test/gopath/src/clientmiddlewareexceptiontest"
+ "github.com/apache/thrift/lib/go/thrift"
+)
+
+type fakeClientMiddlewareExceptionTestHandler func(ctx context.Context) (*clientmiddlewareexceptiontest.FooResponse, error)
+
+func (f fakeClientMiddlewareExceptionTestHandler) Foo(ctx context.Context) (*clientmiddlewareexceptiontest.FooResponse, error) {
+ return f(ctx)
+}
+
+type clientMiddlewareErrorChecker func(err error) error
+
+var clientMiddlewareExceptionCases = []struct {
+ label string
+ handler fakeClientMiddlewareExceptionTestHandler
+ checker clientMiddlewareErrorChecker
+}{
+ {
+ label: "no-error",
+ handler: func(_ context.Context) (*clientmiddlewareexceptiontest.FooResponse, error) {
+ return new(clientmiddlewareexceptiontest.FooResponse), nil
+ },
+ checker: func(err error) error {
+ if err != nil {
+ return errors.New("expected err to be nil")
+ }
+ return nil
+ },
+ },
+ {
+ label: "exception-1",
+ handler: func(_ context.Context) (*clientmiddlewareexceptiontest.FooResponse, error) {
+ return nil, new(clientmiddlewareexceptiontest.Exception1)
+ },
+ checker: func(err error) error {
+ if !errors.As(err, new(*clientmiddlewareexceptiontest.Exception1)) {
+ return errors.New("expected err to be of type *clientmiddlewareexceptiontest.Exception1")
+ }
+ return nil
+ },
+ },
+ {
+ label: "no-error",
+ handler: func(_ context.Context) (*clientmiddlewareexceptiontest.FooResponse, error) {
+ return nil, new(clientmiddlewareexceptiontest.Exception2)
+ },
+ checker: func(err error) error {
+ if !errors.As(err, new(*clientmiddlewareexceptiontest.Exception2)) {
+ return errors.New("expected err to be of type *clientmiddlewareexceptiontest.Exception2")
+ }
+ return nil
+ },
+ },
+}
+
+func TestClientMiddlewareException(t *testing.T) {
+ for _, c := range clientMiddlewareExceptionCases {
+ t.Run(c.label, func(t *testing.T) {
+ serverSocket, err := thrift.NewTServerSocket(":0")
+ if err != nil {
+ t.Fatalf("failed to create server socket: %v", err)
+ }
+ processor := clientmiddlewareexceptiontest.NewClientMiddlewareExceptionTestProcessor(c.handler)
+ server := thrift.NewTSimpleServer2(processor, serverSocket)
+ if err := server.Listen(); err != nil {
+ t.Fatalf("failed to listen server: %v", err)
+ }
+ addr := serverSocket.Addr().String()
+ go server.Serve()
+ t.Cleanup(func() {
+ server.Stop()
+ })
+
+ var cfg *thrift.TConfiguration
+ socket := thrift.NewTSocketConf(addr, cfg)
+ if err := socket.Open(); err != nil {
+ t.Fatalf("failed to create client connection: %v", err)
+ }
+ t.Cleanup(func() {
+ socket.Close()
+ })
+ inProtocol := thrift.NewTBinaryProtocolConf(socket, cfg)
+ outProtocol := thrift.NewTBinaryProtocolConf(socket, cfg)
+ middleware := func(next thrift.TClient) thrift.TClient {
+ return thrift.WrappedTClient{
+ Wrapped: func(ctx context.Context, method string, args, result thrift.TStruct) (_ thrift.ResponseMeta, err error) {
+ defer func() {
+ if checkErr := c.checker(err); checkErr != nil {
+ t.Errorf("middleware result unexpected: %v (result=%#v, err=%#v)", checkErr, result, err)
+ }
+ }()
+ return next.Call(ctx, method, args, result)
+ },
+ }
+ }
+ client := thrift.WrapClient(
+ thrift.NewTStandardClient(inProtocol, outProtocol),
+ middleware,
+ thrift.ExtractIDLExceptionClientMiddleware,
+ )
+ result, err := clientmiddlewareexceptiontest.NewClientMiddlewareExceptionTestClient(client).Foo(context.Background())
+ if checkErr := c.checker(err); checkErr != nil {
+ t.Errorf("final result unexpected: %v (result=%#v, err=%#v)", checkErr, result, err)
+ }
+ })
+ }
+}
+
+func TestExtractExceptionFromResult(t *testing.T) {
+
+ for _, c := range clientMiddlewareExceptionCases {
+ t.Run(c.label, func(t *testing.T) {
+ serverSocket, err := thrift.NewTServerSocket(":0")
+ if err != nil {
+ t.Fatalf("failed to create server socket: %v", err)
+ }
+ processor := clientmiddlewareexceptiontest.NewClientMiddlewareExceptionTestProcessor(c.handler)
+ server := thrift.NewTSimpleServer2(processor, serverSocket)
+ if err := server.Listen(); err != nil {
+ t.Fatalf("failed to listen server: %v", err)
+ }
+ addr := serverSocket.Addr().String()
+ go server.Serve()
+ t.Cleanup(func() {
+ server.Stop()
+ })
+
+ var cfg *thrift.TConfiguration
+ socket := thrift.NewTSocketConf(addr, cfg)
+ if err := socket.Open(); err != nil {
+ t.Fatalf("failed to create client connection: %v", err)
+ }
+ t.Cleanup(func() {
+ socket.Close()
+ })
+ inProtocol := thrift.NewTBinaryProtocolConf(socket, cfg)
+ outProtocol := thrift.NewTBinaryProtocolConf(socket, cfg)
+ middleware := func(next thrift.TClient) thrift.TClient {
+ return thrift.WrappedTClient{
+ Wrapped: func(ctx context.Context, method string, args, result thrift.TStruct) (_ thrift.ResponseMeta, err error) {
+ defer func() {
+ if err == nil {
+ err = thrift.ExtractExceptionFromResult(result)
+ }
+ if checkErr := c.checker(err); checkErr != nil {
+ t.Errorf("middleware result unexpected: %v (result=%#v, err=%#v)", checkErr, result, err)
+ }
+ }()
+ return next.Call(ctx, method, args, result)
+ },
+ }
+ }
+ client := thrift.WrapClient(
+ thrift.NewTStandardClient(inProtocol, outProtocol),
+ middleware,
+ )
+ result, err := clientmiddlewareexceptiontest.NewClientMiddlewareExceptionTestClient(client).Foo(context.Background())
+ if checkErr := c.checker(err); checkErr != nil {
+ t.Errorf("final result unexpected: %v (result=%#v, err=%#v)", checkErr, result, err)
+ }
+ })
+ }
+}