summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMorozov <weugek@gmail.com>2020-12-15 10:35:57 +0300
committerGitHub <noreply@github.com>2020-12-15 07:35:57 +0000
commit4461728f18542eba5d211f9fc412557aab61c491 (patch)
tree7b37ce33ae85dd4a16e215e5a098bd2603bee421
parent70792f2191e5e7345bf08f766638e166d5937f32 (diff)
downloadthrift-4461728f18542eba5d211f9fc412557aab61c491.tar.gz
THRIFT-5324: reset http client buffer after flush
THttpClient did not reset its internal buffer when HTTP client returned an error, leaving the whole or partially read message in the buffer. Now we reset the buffer in defer. Client: go
-rw-r--r--lib/go/thrift/http_client.go8
-rw-r--r--lib/go/thrift/http_client_test.go60
2 files changed, 63 insertions, 5 deletions
diff --git a/lib/go/thrift/http_client.go b/lib/go/thrift/http_client.go
index 1924a1ae2..19c63a985 100644
--- a/lib/go/thrift/http_client.go
+++ b/lib/go/thrift/http_client.go
@@ -197,6 +197,14 @@ func (p *THttpClient) Flush(ctx context.Context) error {
// Close any previous response body to avoid leaking connections.
p.closeResponse()
+ // Request might not have been fully read by http client.
+ // Reset so we don't send the remains on next call.
+ defer func() {
+ if p.requestBuffer != nil {
+ p.requestBuffer.Reset()
+ }
+ }()
+
req, err := http.NewRequest("POST", p.url.String(), p.requestBuffer)
if err != nil {
return NewTTransportExceptionFromError(err)
diff --git a/lib/go/thrift/http_client_test.go b/lib/go/thrift/http_client_test.go
index a7977a385..eba366815 100644
--- a/lib/go/thrift/http_client_test.go
+++ b/lib/go/thrift/http_client_test.go
@@ -20,6 +20,8 @@
package thrift
import (
+ "bytes"
+ "context"
"net/http"
"testing"
)
@@ -32,14 +34,14 @@ func TestHttpClient(t *testing.T) {
trans, err := NewTHttpPostClient("http://" + addr.String())
if err != nil {
l.Close()
- t.Fatalf("Unable to connect to %s: %s", addr.String(), err)
+ t.Fatalf("Unable to connect to %s: %v", addr.String(), err)
}
TransportTest(t, trans, trans)
t.Run("nilBuffer", func(t *testing.T) {
_ = trans.Close()
if _, err = trans.Write([]byte{1, 2, 3, 4}); err == nil {
- t.Fatalf("writing to a closed transport did not result in an error")
+ t.Fatal("writing to a closed transport did not result in an error")
}
})
}
@@ -52,7 +54,7 @@ func TestHttpClientHeaders(t *testing.T) {
trans, err := NewTHttpPostClient("http://" + addr.String())
if err != nil {
l.Close()
- t.Fatalf("Unable to connect to %s: %s", addr.String(), err)
+ t.Fatalf("Unable to connect to %s: %v", addr.String(), err)
}
TransportHeaderTest(t, trans, trans)
}
@@ -72,7 +74,7 @@ func TestHttpCustomClient(t *testing.T) {
})
if err != nil {
l.Close()
- t.Fatalf("Unable to connect to %s: %s", addr.String(), err)
+ t.Fatalf("Unable to connect to %s: %v", addr.String(), err)
}
TransportHeaderTest(t, trans, trans)
@@ -94,7 +96,7 @@ func TestHttpCustomClientPackageScope(t *testing.T) {
trans, err := NewTHttpPostClient("http://" + addr.String())
if err != nil {
l.Close()
- t.Fatalf("Unable to connect to %s: %s", addr.String(), err)
+ t.Fatalf("Unable to connect to %s: %v", addr.String(), err)
}
TransportHeaderTest(t, trans, trans)
@@ -103,6 +105,54 @@ func TestHttpCustomClientPackageScope(t *testing.T) {
}
}
+func TestHTTPClientFlushesRequestBufferOnErrors(t *testing.T) {
+ var (
+ write1 = []byte("write 1")
+ write2 = []byte("write 2")
+ )
+
+ l, addr := HttpClientSetupForTest(t)
+ if l != nil {
+ defer l.Close()
+ }
+ trans, err := NewTHttpPostClient("http://" + addr.String())
+ if err != nil {
+ t.Fatalf("Unable to connect to %s: %v", addr.String(), err)
+ }
+ defer trans.Close()
+
+ _, err = trans.Write(write1)
+ if err != nil {
+ t.Fatalf("Failed to write to transport: %v", err)
+ }
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ err = trans.Flush(ctx)
+ if err == nil {
+ t.Fatal("Expected flush error")
+ }
+
+ _, err = trans.Write(write2)
+ if err != nil {
+ t.Fatalf("Failed to write to transport: %v", err)
+ }
+ err = trans.Flush(context.Background())
+ if err != nil {
+ t.Fatalf("Failed to flush: %v", err)
+ }
+
+ data := make([]byte, 1024)
+ n, err := trans.Read(data)
+ if err != nil {
+ t.Fatalf("Failed to read: %v", err)
+ }
+
+ data = data[:n]
+ if !bytes.Equal(data, write2) {
+ t.Fatalf("Received unexpected data: %q, expected: %q", data, write2)
+ }
+}
+
type customHttpTransport struct {
hit bool
}