summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/go/thrift/http_client.go8
-rw-r--r--lib/go/thrift/http_client_test.go60
2 files changed, 63 insertions, 5 deletions
diff --git a/lib/go/thrift/http_client.go b/lib/go/thrift/http_client.go
index 1924a1ae2..19c63a985 100644
--- a/lib/go/thrift/http_client.go
+++ b/lib/go/thrift/http_client.go
@@ -197,6 +197,14 @@ func (p *THttpClient) Flush(ctx context.Context) error {
// Close any previous response body to avoid leaking connections.
p.closeResponse()
+ // Request might not have been fully read by http client.
+ // Reset so we don't send the remains on next call.
+ defer func() {
+ if p.requestBuffer != nil {
+ p.requestBuffer.Reset()
+ }
+ }()
+
req, err := http.NewRequest("POST", p.url.String(), p.requestBuffer)
if err != nil {
return NewTTransportExceptionFromError(err)
diff --git a/lib/go/thrift/http_client_test.go b/lib/go/thrift/http_client_test.go
index a7977a385..eba366815 100644
--- a/lib/go/thrift/http_client_test.go
+++ b/lib/go/thrift/http_client_test.go
@@ -20,6 +20,8 @@
package thrift
import (
+ "bytes"
+ "context"
"net/http"
"testing"
)
@@ -32,14 +34,14 @@ func TestHttpClient(t *testing.T) {
trans, err := NewTHttpPostClient("http://" + addr.String())
if err != nil {
l.Close()
- t.Fatalf("Unable to connect to %s: %s", addr.String(), err)
+ t.Fatalf("Unable to connect to %s: %v", addr.String(), err)
}
TransportTest(t, trans, trans)
t.Run("nilBuffer", func(t *testing.T) {
_ = trans.Close()
if _, err = trans.Write([]byte{1, 2, 3, 4}); err == nil {
- t.Fatalf("writing to a closed transport did not result in an error")
+ t.Fatal("writing to a closed transport did not result in an error")
}
})
}
@@ -52,7 +54,7 @@ func TestHttpClientHeaders(t *testing.T) {
trans, err := NewTHttpPostClient("http://" + addr.String())
if err != nil {
l.Close()
- t.Fatalf("Unable to connect to %s: %s", addr.String(), err)
+ t.Fatalf("Unable to connect to %s: %v", addr.String(), err)
}
TransportHeaderTest(t, trans, trans)
}
@@ -72,7 +74,7 @@ func TestHttpCustomClient(t *testing.T) {
})
if err != nil {
l.Close()
- t.Fatalf("Unable to connect to %s: %s", addr.String(), err)
+ t.Fatalf("Unable to connect to %s: %v", addr.String(), err)
}
TransportHeaderTest(t, trans, trans)
@@ -94,7 +96,7 @@ func TestHttpCustomClientPackageScope(t *testing.T) {
trans, err := NewTHttpPostClient("http://" + addr.String())
if err != nil {
l.Close()
- t.Fatalf("Unable to connect to %s: %s", addr.String(), err)
+ t.Fatalf("Unable to connect to %s: %v", addr.String(), err)
}
TransportHeaderTest(t, trans, trans)
@@ -103,6 +105,54 @@ func TestHttpCustomClientPackageScope(t *testing.T) {
}
}
+func TestHTTPClientFlushesRequestBufferOnErrors(t *testing.T) {
+ var (
+ write1 = []byte("write 1")
+ write2 = []byte("write 2")
+ )
+
+ l, addr := HttpClientSetupForTest(t)
+ if l != nil {
+ defer l.Close()
+ }
+ trans, err := NewTHttpPostClient("http://" + addr.String())
+ if err != nil {
+ t.Fatalf("Unable to connect to %s: %v", addr.String(), err)
+ }
+ defer trans.Close()
+
+ _, err = trans.Write(write1)
+ if err != nil {
+ t.Fatalf("Failed to write to transport: %v", err)
+ }
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ err = trans.Flush(ctx)
+ if err == nil {
+ t.Fatal("Expected flush error")
+ }
+
+ _, err = trans.Write(write2)
+ if err != nil {
+ t.Fatalf("Failed to write to transport: %v", err)
+ }
+ err = trans.Flush(context.Background())
+ if err != nil {
+ t.Fatalf("Failed to flush: %v", err)
+ }
+
+ data := make([]byte, 1024)
+ n, err := trans.Read(data)
+ if err != nil {
+ t.Fatalf("Failed to read: %v", err)
+ }
+
+ data = data[:n]
+ if !bytes.Equal(data, write2) {
+ t.Fatalf("Received unexpected data: %q, expected: %q", data, write2)
+ }
+}
+
type customHttpTransport struct {
hit bool
}