diff options
author | Morozov <weugek@gmail.com> | 2020-12-15 10:35:57 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-12-15 07:35:57 +0000 |
commit | 4461728f18542eba5d211f9fc412557aab61c491 (patch) | |
tree | 7b37ce33ae85dd4a16e215e5a098bd2603bee421 | |
parent | 70792f2191e5e7345bf08f766638e166d5937f32 (diff) | |
download | thrift-4461728f18542eba5d211f9fc412557aab61c491.tar.gz |
THRIFT-5324: reset http client buffer after flush
THttpClient did not reset its internal buffer when HTTP client returned
an error, leaving the whole or partially read message in the buffer.
Now we reset the buffer in defer.
Client: go
-rw-r--r-- | lib/go/thrift/http_client.go | 8 | ||||
-rw-r--r-- | lib/go/thrift/http_client_test.go | 60 |
2 files changed, 63 insertions, 5 deletions
diff --git a/lib/go/thrift/http_client.go b/lib/go/thrift/http_client.go index 1924a1ae2..19c63a985 100644 --- a/lib/go/thrift/http_client.go +++ b/lib/go/thrift/http_client.go @@ -197,6 +197,14 @@ func (p *THttpClient) Flush(ctx context.Context) error { // Close any previous response body to avoid leaking connections. p.closeResponse() + // Request might not have been fully read by http client. + // Reset so we don't send the remains on next call. + defer func() { + if p.requestBuffer != nil { + p.requestBuffer.Reset() + } + }() + req, err := http.NewRequest("POST", p.url.String(), p.requestBuffer) if err != nil { return NewTTransportExceptionFromError(err) diff --git a/lib/go/thrift/http_client_test.go b/lib/go/thrift/http_client_test.go index a7977a385..eba366815 100644 --- a/lib/go/thrift/http_client_test.go +++ b/lib/go/thrift/http_client_test.go @@ -20,6 +20,8 @@ package thrift import ( + "bytes" + "context" "net/http" "testing" ) @@ -32,14 +34,14 @@ func TestHttpClient(t *testing.T) { trans, err := NewTHttpPostClient("http://" + addr.String()) if err != nil { l.Close() - t.Fatalf("Unable to connect to %s: %s", addr.String(), err) + t.Fatalf("Unable to connect to %s: %v", addr.String(), err) } TransportTest(t, trans, trans) t.Run("nilBuffer", func(t *testing.T) { _ = trans.Close() if _, err = trans.Write([]byte{1, 2, 3, 4}); err == nil { - t.Fatalf("writing to a closed transport did not result in an error") + t.Fatal("writing to a closed transport did not result in an error") } }) } @@ -52,7 +54,7 @@ func TestHttpClientHeaders(t *testing.T) { trans, err := NewTHttpPostClient("http://" + addr.String()) if err != nil { l.Close() - t.Fatalf("Unable to connect to %s: %s", addr.String(), err) + t.Fatalf("Unable to connect to %s: %v", addr.String(), err) } TransportHeaderTest(t, trans, trans) } @@ -72,7 +74,7 @@ func TestHttpCustomClient(t *testing.T) { }) if err != nil { l.Close() - t.Fatalf("Unable to connect to %s: %s", addr.String(), err) + t.Fatalf("Unable to connect to %s: %v", addr.String(), err) } TransportHeaderTest(t, trans, trans) @@ -94,7 +96,7 @@ func TestHttpCustomClientPackageScope(t *testing.T) { trans, err := NewTHttpPostClient("http://" + addr.String()) if err != nil { l.Close() - t.Fatalf("Unable to connect to %s: %s", addr.String(), err) + t.Fatalf("Unable to connect to %s: %v", addr.String(), err) } TransportHeaderTest(t, trans, trans) @@ -103,6 +105,54 @@ func TestHttpCustomClientPackageScope(t *testing.T) { } } +func TestHTTPClientFlushesRequestBufferOnErrors(t *testing.T) { + var ( + write1 = []byte("write 1") + write2 = []byte("write 2") + ) + + l, addr := HttpClientSetupForTest(t) + if l != nil { + defer l.Close() + } + trans, err := NewTHttpPostClient("http://" + addr.String()) + if err != nil { + t.Fatalf("Unable to connect to %s: %v", addr.String(), err) + } + defer trans.Close() + + _, err = trans.Write(write1) + if err != nil { + t.Fatalf("Failed to write to transport: %v", err) + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err = trans.Flush(ctx) + if err == nil { + t.Fatal("Expected flush error") + } + + _, err = trans.Write(write2) + if err != nil { + t.Fatalf("Failed to write to transport: %v", err) + } + err = trans.Flush(context.Background()) + if err != nil { + t.Fatalf("Failed to flush: %v", err) + } + + data := make([]byte, 1024) + n, err := trans.Read(data) + if err != nil { + t.Fatalf("Failed to read: %v", err) + } + + data = data[:n] + if !bytes.Equal(data, write2) { + t.Fatalf("Received unexpected data: %q, expected: %q", data, write2) + } +} + type customHttpTransport struct { hit bool } |