diff options
25 files changed, 282 insertions, 81 deletions
diff --git a/lib/netstd/Tests/Thrift.IntegrationTests/Protocols/ProtocolsOperationsTests.cs b/lib/netstd/Tests/Thrift.IntegrationTests/Protocols/ProtocolsOperationsTests.cs index b1f841892..b8df515de 100644 --- a/lib/netstd/Tests/Thrift.IntegrationTests/Protocols/ProtocolsOperationsTests.cs +++ b/lib/netstd/Tests/Thrift.IntegrationTests/Protocols/ProtocolsOperationsTests.cs @@ -31,6 +31,7 @@ namespace Thrift.IntegrationTests.Protocols public class ProtocolsOperationsTests { private readonly CompareLogic _compareLogic = new CompareLogic(); + private static readonly TConfiguration Configuration = null; // or new TConfiguration() if needed [DataTestMethod] [DataRow(typeof(TBinaryProtocol), TMessageType.Call)] @@ -494,7 +495,7 @@ namespace Thrift.IntegrationTests.Protocols private static Tuple<Stream, TProtocol> GetProtocolInstance(Type protocolType) { var memoryStream = new MemoryStream(); - var streamClientTransport = new TStreamTransport(memoryStream, memoryStream); + var streamClientTransport = new TStreamTransport(memoryStream, memoryStream,Configuration); var protocol = (TProtocol) Activator.CreateInstance(protocolType, streamClientTransport); return new Tuple<Stream, TProtocol>(memoryStream, protocol); } diff --git a/lib/netstd/Tests/Thrift.Tests/Protocols/TJsonProtocolTests.cs b/lib/netstd/Tests/Thrift.Tests/Protocols/TJsonProtocolTests.cs index 970ce7ece..4054a29f2 100644 --- a/lib/netstd/Tests/Thrift.Tests/Protocols/TJsonProtocolTests.cs +++ b/lib/netstd/Tests/Thrift.Tests/Protocols/TJsonProtocolTests.cs @@ -21,7 +21,6 @@ using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; -using NSubstitute; using Thrift.Protocol; using Thrift.Protocol.Entities; using Thrift.Transport; @@ -36,7 +35,7 @@ namespace Thrift.Tests.Protocols [TestMethod] public void TJSONProtocol_Can_Create_Instance_Test() { - var httpClientTransport = Substitute.For<THttpTransport>(new Uri("http://localhost"), null, null); + var httpClientTransport = new THttpTransport( new Uri("http://localhost"), null, null, null); var result = new TJSONProtocolWrapper(httpClientTransport); @@ -45,7 +44,7 @@ namespace Thrift.Tests.Protocols Assert.IsNotNull(result.WrappedReader); Assert.IsNotNull(result.Transport); Assert.IsTrue(result.WrappedRecursionDepth == 0); - Assert.IsTrue(result.WrappedRecursionLimit == TProtocol.DefaultRecursionDepth); + Assert.IsTrue(result.WrappedRecursionLimit == TConfiguration.DEFAULT_RECURSION_DEPTH); Assert.IsTrue(result.Transport.Equals(httpClientTransport)); Assert.IsTrue(result.WrappedContext.GetType().Name.Equals("JSONBaseContext", StringComparison.OrdinalIgnoreCase)); diff --git a/lib/netstd/Thrift/Protocol/TProtocol.cs b/lib/netstd/Thrift/Protocol/TProtocol.cs index 75edb11d1..dca3f9efc 100644 --- a/lib/netstd/Thrift/Protocol/TProtocol.cs +++ b/lib/netstd/Thrift/Protocol/TProtocol.cs @@ -27,7 +27,6 @@ namespace Thrift.Protocol // ReSharper disable once InconsistentNaming public abstract class TProtocol : IDisposable { - public const int DefaultRecursionDepth = 64; private bool _isDisposed; protected int RecursionDepth; @@ -36,7 +35,7 @@ namespace Thrift.Protocol protected TProtocol(TTransport trans) { Trans = trans; - RecursionLimit = DefaultRecursionDepth; + RecursionLimit = trans.Configuration.RecursionLimit; RecursionDepth = 0; } diff --git a/lib/netstd/Thrift/TConfiguration.cs b/lib/netstd/Thrift/TConfiguration.cs new file mode 100644 index 000000000..c8dde1043 --- /dev/null +++ b/lib/netstd/Thrift/TConfiguration.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Thrift +{ + public class TConfiguration + { + public const int DEFAULT_MAX_MESSAGE_SIZE = 100 * 1024 * 1024; + public const int DEFAULT_MAX_FRAME_SIZE = 16384000; // this value is used consistently across all Thrift libraries + public const int DEFAULT_RECURSION_DEPTH = 64; + + public int MaxMessageSize { get; set; } = DEFAULT_MAX_MESSAGE_SIZE; + public int MaxFrameSize { get; set; } = DEFAULT_MAX_FRAME_SIZE; + public int RecursionLimit { get; set; } = DEFAULT_RECURSION_DEPTH; + + // TODO(JensG): add connection and i/o timeouts + } +} diff --git a/lib/netstd/Thrift/Transport/Client/THttpTransport.cs b/lib/netstd/Thrift/Transport/Client/THttpTransport.cs index 4f8454c0b..bbd94fa98 100644 --- a/lib/netstd/Thrift/Transport/Client/THttpTransport.cs +++ b/lib/netstd/Thrift/Transport/Client/THttpTransport.cs @@ -28,7 +28,7 @@ using System.Threading.Tasks; namespace Thrift.Transport.Client { // ReSharper disable once InconsistentNaming - public class THttpTransport : TTransport + public class THttpTransport : TEndpointTransport { private readonly X509Certificate[] _certificates; private readonly Uri _uri; @@ -39,13 +39,14 @@ namespace Thrift.Transport.Client private MemoryStream _outputStream = new MemoryStream(); private bool _isDisposed; - public THttpTransport(Uri uri, IDictionary<string, string> customRequestHeaders = null, string userAgent = null) - : this(uri, Enumerable.Empty<X509Certificate>(), customRequestHeaders, userAgent) + public THttpTransport(Uri uri, TConfiguration config, IDictionary<string, string> customRequestHeaders = null, string userAgent = null) + : this(uri, config, Enumerable.Empty<X509Certificate>(), customRequestHeaders, userAgent) { } - public THttpTransport(Uri uri, IEnumerable<X509Certificate> certificates, + public THttpTransport(Uri uri, TConfiguration config, IEnumerable<X509Certificate> certificates, IDictionary<string, string> customRequestHeaders, string userAgent = null) + : base(config) { _uri = uri; _certificates = (certificates ?? Enumerable.Empty<X509Certificate>()).ToArray(); @@ -104,6 +105,8 @@ namespace Thrift.Transport.Client if (_inputStream == null) throw new TTransportException(TTransportException.ExceptionType.NotOpen, "No request has been sent"); + CheckReadBytesAvailable(length); + try { var ret = await _inputStream.ReadAsync(buffer, offset, length, cancellationToken); @@ -112,6 +115,7 @@ namespace Thrift.Transport.Client throw new TTransportException(TTransportException.ExceptionType.EndOfFile, "No more data available"); } + CountConsumedMessageBytes(ret); return ret; } catch (IOException iox) @@ -196,9 +200,11 @@ namespace Thrift.Transport.Client finally { _outputStream = new MemoryStream(); + ResetConsumedMessageSize(); } } + // IDisposable protected override void Dispose(bool disposing) { diff --git a/lib/netstd/Thrift/Transport/Client/TMemoryBufferTransport.cs b/lib/netstd/Thrift/Transport/Client/TMemoryBufferTransport.cs index cdbbc0d36..abf8f14c4 100644 --- a/lib/netstd/Thrift/Transport/Client/TMemoryBufferTransport.cs +++ b/lib/netstd/Thrift/Transport/Client/TMemoryBufferTransport.cs @@ -24,18 +24,20 @@ using System.Threading.Tasks; namespace Thrift.Transport.Client { // ReSharper disable once InconsistentNaming - public class TMemoryBufferTransport : TTransport + public class TMemoryBufferTransport : TEndpointTransport { private bool IsDisposed; private byte[] Bytes; private int _bytesUsed; - public TMemoryBufferTransport(int initialCapacity = 2048) + public TMemoryBufferTransport(TConfiguration config, int initialCapacity = 2048) + : base(config) { Bytes = new byte[initialCapacity]; } - public TMemoryBufferTransport(byte[] buf) + public TMemoryBufferTransport(byte[] buf, TConfiguration config) + :base(config) { Bytes = (byte[])buf.Clone(); _bytesUsed = Bytes.Length; @@ -112,13 +114,18 @@ namespace Thrift.Transport.Client if ((0 > newPos) || (newPos > _bytesUsed)) throw new ArgumentException(nameof(origin)); Position = newPos; + + ResetConsumedMessageSize(); + CountConsumedMessageBytes(Position); } public override ValueTask<int> ReadAsync(byte[] buffer, int offset, int length, CancellationToken cancellationToken) { + CheckReadBytesAvailable(length); var count = Math.Min(Length - Position, length); Buffer.BlockCopy(Bytes, Position, buffer, offset, count); Position += count; + CountConsumedMessageBytes(count); return new ValueTask<int>(count); } @@ -142,6 +149,7 @@ namespace Thrift.Transport.Client { await Task.FromCanceled(cancellationToken); } + ResetConsumedMessageSize(); } public byte[] GetBuffer() @@ -157,7 +165,6 @@ namespace Thrift.Transport.Client return true; } - // IDisposable protected override void Dispose(bool disposing) { diff --git a/lib/netstd/Thrift/Transport/Client/TNamedPipeTransport.cs b/lib/netstd/Thrift/Transport/Client/TNamedPipeTransport.cs index 1ae6074b8..f7f10b71a 100644 --- a/lib/netstd/Thrift/Transport/Client/TNamedPipeTransport.cs +++ b/lib/netstd/Thrift/Transport/Client/TNamedPipeTransport.cs @@ -23,17 +23,18 @@ using System.Threading.Tasks; namespace Thrift.Transport.Client { // ReSharper disable once InconsistentNaming - public class TNamedPipeTransport : TTransport + public class TNamedPipeTransport : TEndpointTransport { private NamedPipeClientStream PipeStream; private readonly int ConnectTimeout; - public TNamedPipeTransport(string pipe, int timeout = Timeout.Infinite) - : this(".", pipe, timeout) + public TNamedPipeTransport(string pipe, TConfiguration config, int timeout = Timeout.Infinite) + : this(".", pipe, config, timeout) { } - public TNamedPipeTransport(string server, string pipe, int timeout = Timeout.Infinite) + public TNamedPipeTransport(string server, string pipe, TConfiguration config, int timeout = Timeout.Infinite) + : base(config) { var serverName = string.IsNullOrWhiteSpace(server) ? server : "."; ConnectTimeout = (timeout > 0) ? timeout : Timeout.Infinite; @@ -51,6 +52,7 @@ namespace Thrift.Transport.Client } await PipeStream.ConnectAsync( ConnectTimeout, cancellationToken); + ResetConsumedMessageSize(); } public override void Close() @@ -69,7 +71,10 @@ namespace Thrift.Transport.Client throw new TTransportException(TTransportException.ExceptionType.NotOpen); } - return await PipeStream.ReadAsync(buffer, offset, length, cancellationToken); + CheckReadBytesAvailable(length); + var numRead = await PipeStream.ReadAsync(buffer, offset, length, cancellationToken); + CountConsumedMessageBytes(numRead); + return numRead; } public override async Task WriteAsync(byte[] buffer, int offset, int length, CancellationToken cancellationToken) @@ -98,8 +103,10 @@ namespace Thrift.Transport.Client { await Task.FromCanceled(cancellationToken); } + ResetConsumedMessageSize(); } + protected override void Dispose(bool disposing) { if(disposing) diff --git a/lib/netstd/Thrift/Transport/Client/TSocketTransport.cs b/lib/netstd/Thrift/Transport/Client/TSocketTransport.cs index dd506bc98..d55915434 100644 --- a/lib/netstd/Thrift/Transport/Client/TSocketTransport.cs +++ b/lib/netstd/Thrift/Transport/Client/TSocketTransport.cs @@ -30,13 +30,15 @@ namespace Thrift.Transport.Client private bool _isDisposed; - public TSocketTransport(TcpClient client) + public TSocketTransport(TcpClient client, TConfiguration config) + : base(config) { TcpClient = client ?? throw new ArgumentNullException(nameof(client)); SetInputOutputStream(); } - public TSocketTransport(IPAddress host, int port, int timeout = 0) + public TSocketTransport(IPAddress host, int port, TConfiguration config, int timeout = 0) + : base(config) { Host = host; Port = port; @@ -47,7 +49,8 @@ namespace Thrift.Transport.Client SetInputOutputStream(); } - public TSocketTransport(string host, int port, int timeout = 0) + public TSocketTransport(string host, int port, TConfiguration config, int timeout = 0) + : base(config) { try { diff --git a/lib/netstd/Thrift/Transport/Client/TStreamTransport.cs b/lib/netstd/Thrift/Transport/Client/TStreamTransport.cs index d8574d610..e04b3b333 100644 --- a/lib/netstd/Thrift/Transport/Client/TStreamTransport.cs +++ b/lib/netstd/Thrift/Transport/Client/TStreamTransport.cs @@ -22,15 +22,17 @@ using System.Threading.Tasks; namespace Thrift.Transport.Client { // ReSharper disable once InconsistentNaming - public class TStreamTransport : TTransport + public class TStreamTransport : TEndpointTransport { private bool _isDisposed; - protected TStreamTransport() + protected TStreamTransport(TConfiguration config) + :base(config) { } - public TStreamTransport(Stream inputStream, Stream outputStream) + public TStreamTransport(Stream inputStream, Stream outputStream, TConfiguration config) + : base(config) { InputStream = inputStream; OutputStream = outputStream; @@ -38,7 +40,14 @@ namespace Thrift.Transport.Client protected Stream OutputStream { get; set; } - protected Stream InputStream { get; set; } + private Stream _InputStream = null; + protected Stream InputStream { + get => _InputStream; + set { + _InputStream = value; + ResetConsumedMessageSize(); + } + } public override bool IsOpen => true; @@ -90,8 +99,10 @@ namespace Thrift.Transport.Client public override async Task FlushAsync(CancellationToken cancellationToken) { await OutputStream.FlushAsync(cancellationToken); + ResetConsumedMessageSize(); } + // IDisposable protected override void Dispose(bool disposing) { diff --git a/lib/netstd/Thrift/Transport/Client/TTlsSocketTransport.cs b/lib/netstd/Thrift/Transport/Client/TTlsSocketTransport.cs index a926a38f9..0980526f2 100644 --- a/lib/netstd/Thrift/Transport/Client/TTlsSocketTransport.cs +++ b/lib/netstd/Thrift/Transport/Client/TTlsSocketTransport.cs @@ -42,11 +42,12 @@ namespace Thrift.Transport.Client private SslStream _secureStream; private int _timeout; - public TTlsSocketTransport(TcpClient client, + public TTlsSocketTransport(TcpClient client, TConfiguration config, X509Certificate2 certificate, bool isServer = false, RemoteCertificateValidationCallback certValidator = null, LocalCertificateSelectionCallback localCertificateSelectionCallback = null, SslProtocols sslProtocols = SslProtocols.Tls12) + : base(config) { _client = client; _certificate = certificate; @@ -68,12 +69,12 @@ namespace Thrift.Transport.Client } } - public TTlsSocketTransport(IPAddress host, int port, + public TTlsSocketTransport(IPAddress host, int port, TConfiguration config, string certificatePath, RemoteCertificateValidationCallback certValidator = null, LocalCertificateSelectionCallback localCertificateSelectionCallback = null, SslProtocols sslProtocols = SslProtocols.Tls12) - : this(host, port, 0, + : this(host, port, config, 0, new X509Certificate2(certificatePath), certValidator, localCertificateSelectionCallback, @@ -81,12 +82,12 @@ namespace Thrift.Transport.Client { } - public TTlsSocketTransport(IPAddress host, int port, + public TTlsSocketTransport(IPAddress host, int port, TConfiguration config, X509Certificate2 certificate = null, RemoteCertificateValidationCallback certValidator = null, LocalCertificateSelectionCallback localCertificateSelectionCallback = null, SslProtocols sslProtocols = SslProtocols.Tls12) - : this(host, port, 0, + : this(host, port, config, 0, certificate, certValidator, localCertificateSelectionCallback, @@ -94,11 +95,12 @@ namespace Thrift.Transport.Client { } - public TTlsSocketTransport(IPAddress host, int port, int timeout, + public TTlsSocketTransport(IPAddress host, int port, TConfiguration config, int timeout, X509Certificate2 certificate, RemoteCertificateValidationCallback certValidator = null, LocalCertificateSelectionCallback localCertificateSelectionCallback = null, SslProtocols sslProtocols = SslProtocols.Tls12) + : base(config) { _host = host; _port = port; @@ -111,11 +113,12 @@ namespace Thrift.Transport.Client InitSocket(); } - public TTlsSocketTransport(string host, int port, int timeout, + public TTlsSocketTransport(string host, int port, TConfiguration config, int timeout, X509Certificate2 certificate, RemoteCertificateValidationCallback certValidator = null, LocalCertificateSelectionCallback localCertificateSelectionCallback = null, SslProtocols sslProtocols = SslProtocols.Tls12) + : base(config) { try { diff --git a/lib/netstd/Thrift/Transport/TBufferedTransport.cs b/lib/netstd/Thrift/Transport/Layered/TBufferedTransport.cs index e4fdd3a8d..10cec3c3d 100644 --- a/lib/netstd/Thrift/Transport/TBufferedTransport.cs +++ b/lib/netstd/Thrift/Transport/Layered/TBufferedTransport.cs @@ -24,12 +24,11 @@ using System.Threading.Tasks; namespace Thrift.Transport { // ReSharper disable once InconsistentNaming - public class TBufferedTransport : TTransport + public class TBufferedTransport : TLayeredTransport { private readonly int DesiredBufferSize; - private readonly Client.TMemoryBufferTransport ReadBuffer = new Client.TMemoryBufferTransport(1024); - private readonly Client.TMemoryBufferTransport WriteBuffer = new Client.TMemoryBufferTransport(1024); - private readonly TTransport InnerTransport; + private readonly Client.TMemoryBufferTransport ReadBuffer; + private readonly Client.TMemoryBufferTransport WriteBuffer; private bool IsDisposed; public class Factory : TTransportFactory @@ -42,19 +41,20 @@ namespace Thrift.Transport //TODO: should support only specified input transport? public TBufferedTransport(TTransport transport, int bufSize = 1024) + : base(transport) { if (bufSize <= 0) { throw new ArgumentOutOfRangeException(nameof(bufSize), "Buffer size must be a positive number."); } - InnerTransport = transport ?? throw new ArgumentNullException(nameof(transport)); DesiredBufferSize = bufSize; - if (DesiredBufferSize != ReadBuffer.Capacity) - ReadBuffer.Capacity = DesiredBufferSize; - if (DesiredBufferSize != WriteBuffer.Capacity) - WriteBuffer.Capacity = DesiredBufferSize; + WriteBuffer = new Client.TMemoryBufferTransport(InnerTransport.Configuration, bufSize); + ReadBuffer = new Client.TMemoryBufferTransport(InnerTransport.Configuration, bufSize); + + Debug.Assert(DesiredBufferSize == ReadBuffer.Capacity); + Debug.Assert(DesiredBufferSize == WriteBuffer.Capacity); } public TTransport UnderlyingTransport diff --git a/lib/netstd/Thrift/Transport/TFramedTransport.cs b/lib/netstd/Thrift/Transport/Layered/TFramedTransport.cs index de6df7238..c842a16aa 100644 --- a/lib/netstd/Thrift/Transport/TFramedTransport.cs +++ b/lib/netstd/Thrift/Transport/Layered/TFramedTransport.cs @@ -23,13 +23,12 @@ using System.Threading.Tasks; namespace Thrift.Transport { // ReSharper disable once InconsistentNaming - public class TFramedTransport : TTransport + public class TFramedTransport : TLayeredTransport { private const int HeaderSize = 4; private readonly byte[] HeaderBuf = new byte[HeaderSize]; - private readonly Client.TMemoryBufferTransport ReadBuffer = new Client.TMemoryBufferTransport(); - private readonly Client.TMemoryBufferTransport WriteBuffer = new Client.TMemoryBufferTransport(); - private readonly TTransport InnerTransport; + private readonly Client.TMemoryBufferTransport ReadBuffer; + private readonly Client.TMemoryBufferTransport WriteBuffer; private bool IsDisposed; @@ -42,9 +41,10 @@ namespace Thrift.Transport } public TFramedTransport(TTransport transport) + : base(transport) { - InnerTransport = transport ?? throw new ArgumentNullException(nameof(transport)); - + ReadBuffer = new Client.TMemoryBufferTransport(Configuration); + WriteBuffer = new Client.TMemoryBufferTransport(Configuration); InitWriteBuffer(); } @@ -86,7 +86,11 @@ namespace Thrift.Transport private async ValueTask ReadFrameAsync(CancellationToken cancellationToken) { await InnerTransport.ReadAllAsync(HeaderBuf, 0, HeaderSize, cancellationToken); - var size = DecodeFrameSize(HeaderBuf); + int size = DecodeFrameSize(HeaderBuf); + + if ((0 > size) || (size > Configuration.MaxFrameSize)) // size must be in the range 0 to allowed max + throw new TTransportException(TTransportException.ExceptionType.Unknown, $"Maximum frame size exceeded ({size} bytes)"); + UpdateKnownMessageSize(size + HeaderSize); ReadBuffer.SetLength(size); ReadBuffer.Seek(0, SeekOrigin.Begin); diff --git a/lib/netstd/Thrift/Transport/Layered/TLayeredTransport.cs b/lib/netstd/Thrift/Transport/Layered/TLayeredTransport.cs new file mode 100644 index 000000000..59d98ff1d --- /dev/null +++ b/lib/netstd/Thrift/Transport/Layered/TLayeredTransport.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Thrift.Transport +{ + public abstract class TLayeredTransport : TTransport + { + public readonly TTransport InnerTransport; + + public override TConfiguration Configuration { get => InnerTransport.Configuration; } + + public TLayeredTransport(TTransport transport) + { + InnerTransport = transport ?? throw new ArgumentNullException(nameof(transport)); + } + + public override void UpdateKnownMessageSize(long size) + { + InnerTransport.UpdateKnownMessageSize(size); + } + } +} diff --git a/lib/netstd/Thrift/Transport/Server/THttpServerTransport.cs b/lib/netstd/Thrift/Transport/Server/THttpServerTransport.cs index 2a40db396..7271f504e 100644 --- a/lib/netstd/Thrift/Transport/Server/THttpServerTransport.cs +++ b/lib/netstd/Thrift/Transport/Server/THttpServerTransport.cs @@ -42,27 +42,31 @@ namespace Thrift.Transport.Server protected TTransportFactory OutputTransportFactory; protected ITAsyncProcessor Processor; + protected TConfiguration Configuration; public THttpServerTransport( ITAsyncProcessor processor, + TConfiguration config, RequestDelegate next = null, ILoggerFactory loggerFactory = null) - : this(processor, new TBinaryProtocol.Factory(), null, next, loggerFactory) + : this(processor, config, new TBinaryProtocol.Factory(), null, next, loggerFactory) { } public THttpServerTransport( ITAsyncProcessor processor, + TConfiguration config, TProtocolFactory protocolFactory, TTransportFactory transFactory = null, RequestDelegate next = null, ILoggerFactory loggerFactory = null) - : this(processor, protocolFactory, protocolFactory, transFactory, transFactory, next, loggerFactory) + : this(processor, config, protocolFactory, protocolFactory, transFactory, transFactory, next, loggerFactory) { } public THttpServerTransport( ITAsyncProcessor processor, + TConfiguration config, TProtocolFactory inputProtocolFactory, TProtocolFactory outputProtocolFactory, TTransportFactory inputTransFactory = null, @@ -73,6 +77,8 @@ namespace Thrift.Transport.Server // loggerFactory == null is not illegal anymore Processor = processor ?? throw new ArgumentNullException(nameof(processor)); + Configuration = config; // may be null + InputProtocolFactory = inputProtocolFactory ?? throw new ArgumentNullException(nameof(inputProtocolFactory)); OutputProtocolFactory = outputProtocolFactory ?? throw new ArgumentNullException(nameof(outputProtocolFactory)); @@ -91,7 +97,7 @@ namespace Thrift.Transport.Server public async Task ProcessRequestAsync(HttpContext context, CancellationToken cancellationToken) { - var transport = new TStreamTransport(context.Request.Body, context.Response.Body); + var transport = new TStreamTransport(context.Request.Body, context.Response.Body, Configuration); try { diff --git a/lib/netstd/Thrift/Transport/Server/TNamedPipeServerTransport.cs b/lib/netstd/Thrift/Transport/Server/TNamedPipeServerTransport.cs index b2f29b4e1..a8b64c495 100644 --- a/lib/netstd/Thrift/Transport/Server/TNamedPipeServerTransport.cs +++ b/lib/netstd/Thrift/Transport/Server/TNamedPipeServerTransport.cs @@ -38,7 +38,8 @@ namespace Thrift.Transport.Server private volatile bool _isPending = true; private NamedPipeServerStream _stream = null; - public TNamedPipeServerTransport(string pipeAddress) + public TNamedPipeServerTransport(string pipeAddress, TConfiguration config) + : base(config) { _pipeAddress = pipeAddress; } @@ -224,7 +225,7 @@ namespace Thrift.Transport.Server await _stream.WaitForConnectionAsync(cancellationToken); - var trans = new ServerTransport(_stream); + var trans = new ServerTransport(_stream, Configuration); _stream = null; // pass ownership to ServerTransport //_isPending = false; @@ -243,11 +244,12 @@ namespace Thrift.Transport.Server } } - private class ServerTransport : TTransport + private class ServerTransport : TEndpointTransport { private readonly NamedPipeServerStream PipeStream; - public ServerTransport(NamedPipeServerStream stream) + public ServerTransport(NamedPipeServerStream stream, TConfiguration config) + : base(config) { PipeStream = stream; } @@ -274,7 +276,10 @@ namespace Thrift.Transport.Server throw new TTransportException(TTransportException.ExceptionType.NotOpen); } - return await PipeStream.ReadAsync(buffer, offset, length, cancellationToken); + CheckReadBytesAvailable(length); + var numBytes = await PipeStream.ReadAsync(buffer, offset, length, cancellationToken); + CountConsumedMessageBytes(numBytes); + return numBytes; } public override async Task WriteAsync(byte[] buffer, int offset, int length, CancellationToken cancellationToken) @@ -303,6 +308,8 @@ namespace Thrift.Transport.Server { await Task.FromCanceled(cancellationToken); } + + ResetConsumedMessageSize(); } protected override void Dispose(bool disposing) diff --git a/lib/netstd/Thrift/Transport/Server/TServerSocketTransport.cs b/lib/netstd/Thrift/Transport/Server/TServerSocketTransport.cs index 86d82e3fc..6656b641a 100644 --- a/lib/netstd/Thrift/Transport/Server/TServerSocketTransport.cs +++ b/lib/netstd/Thrift/Transport/Server/TServerSocketTransport.cs @@ -31,14 +31,15 @@ namespace Thrift.Transport.Server private readonly int _clientTimeout; private TcpListener _server; - public TServerSocketTransport(TcpListener listener, int clientTimeout = 0) + public TServerSocketTransport(TcpListener listener, TConfiguration config, int clientTimeout = 0) + : base(config) { _server = listener; _clientTimeout = clientTimeout; } - public TServerSocketTransport(int port, int clientTimeout = 0) - : this(null, clientTimeout) + public TServerSocketTransport(int port, TConfiguration config, int clientTimeout = 0) + : this(null, config, clientTimeout) { try { @@ -93,7 +94,7 @@ namespace Thrift.Transport.Server try { - tSocketTransport = new TSocketTransport(tcpClient) + tSocketTransport = new TSocketTransport(tcpClient,Configuration) { Timeout = _clientTimeout }; diff --git a/lib/netstd/Thrift/Transport/Server/TServerTransport.cs b/lib/netstd/Thrift/Transport/Server/TServerTransport.cs index dd60f6a12..31f578d54 100644 --- a/lib/netstd/Thrift/Transport/Server/TServerTransport.cs +++ b/lib/netstd/Thrift/Transport/Server/TServerTransport.cs @@ -23,6 +23,13 @@ namespace Thrift.Transport // ReSharper disable once InconsistentNaming public abstract class TServerTransport { + public readonly TConfiguration Configuration; + + public TServerTransport(TConfiguration config) + { + Configuration = config ?? new TConfiguration(); + } + public abstract void Listen(); public abstract void Close(); public abstract bool IsClientPending(); diff --git a/lib/netstd/Thrift/Transport/Server/TTlsServerSocketTransport.cs b/lib/netstd/Thrift/Transport/Server/TTlsServerSocketTransport.cs index 231b83f5a..9f7456252 100644 --- a/lib/netstd/Thrift/Transport/Server/TTlsServerSocketTransport.cs +++ b/lib/netstd/Thrift/Transport/Server/TTlsServerSocketTransport.cs @@ -39,10 +39,12 @@ namespace Thrift.Transport.Server public TTlsServerSocketTransport( TcpListener listener, + TConfiguration config, X509Certificate2 certificate, RemoteCertificateValidationCallback clientCertValidator = null, LocalCertificateSelectionCallback localCertificateSelectionCallback = null, SslProtocols sslProtocols = SslProtocols.Tls12) + : base(config) { if (!certificate.HasPrivateKey) { @@ -59,11 +61,12 @@ namespace Thrift.Transport.Server public TTlsServerSocketTransport( int port, + TConfiguration config, X509Certificate2 certificate, RemoteCertificateValidationCallback clientCertValidator = null, LocalCertificateSelectionCallback localCertificateSelectionCallback = null, SslProtocols sslProtocols = SslProtocols.Tls12) - : this(null, certificate, clientCertValidator, localCertificateSelectionCallback, sslProtocols) + : this(null, config, certificate, clientCertValidator, localCertificateSelectionCallback, sslProtocols) { try { @@ -117,8 +120,8 @@ namespace Thrift.Transport.Server client.SendTimeout = client.ReceiveTimeout = _clientTimeout; //wrap the client in an SSL Socket passing in the SSL cert - var tTlsSocket = new TTlsSocketTransport( - client, + var tTlsSocket = new TTlsSocketTransport( + client, Configuration, _serverCertificate, true, _clientCertValidator, _localCertificateSelectionCallback, _sslProtocols); diff --git a/lib/netstd/Thrift/Transport/TEndpointTransport.cs b/lib/netstd/Thrift/Transport/TEndpointTransport.cs new file mode 100644 index 000000000..810f3f4ad --- /dev/null +++ b/lib/netstd/Thrift/Transport/TEndpointTransport.cs @@ -0,0 +1,75 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; + +namespace Thrift.Transport +{ + + abstract public class TEndpointTransport : TTransport + { + protected long MaxMessageSize { get => Configuration.MaxMessageSize; } + protected long RemainingMessageSize { get; private set; } + + private readonly TConfiguration _configuration; + public override TConfiguration Configuration { get => _configuration; } + + public TEndpointTransport( TConfiguration config) + { + _configuration = config ?? new TConfiguration(); + Debug.Assert(Configuration != null); + + ResetConsumedMessageSize(); + } + + /// <summary> + /// Resets RemainingMessageSize to the configured maximum + /// </summary> + protected void ResetConsumedMessageSize(long knownSize = -1) + { + if(knownSize >= 0) + RemainingMessageSize = Math.Min( MaxMessageSize, knownSize); + else + RemainingMessageSize = MaxMessageSize; + } + + /// <summary> + /// Updates RemainingMessageSize to reflect then known real message size (e.g. framed transport). + /// Will throw if we already consumed too many bytes. + /// </summary> + /// <param name="size"></param> + public override void UpdateKnownMessageSize(long size) + { + var consumed = MaxMessageSize - RemainingMessageSize; + ResetConsumedMessageSize(size); + CountConsumedMessageBytes(consumed); + } + + /// <summary> + /// Throws if there are not enough bytes in the input stream to satisfy a read of numBytes bytes of data + /// </summary> + /// <param name="numBytes"></param> + protected void CheckReadBytesAvailable(long numBytes) + { + if (RemainingMessageSize < numBytes) + throw new TTransportException(TTransportException.ExceptionType.EndOfFile, "MaxMessageSize reached"); + } + + /// <summary> + /// Consumes numBytes from the RemainingMessageSize. + /// </summary> + /// <param name="numBytes"></param> + protected void CountConsumedMessageBytes(long numBytes) + { + if (RemainingMessageSize >= numBytes) + { + RemainingMessageSize -= numBytes; + } + else + { + RemainingMessageSize = 0; + throw new TTransportException(TTransportException.ExceptionType.EndOfFile, "MaxMessageSize reached"); + } + } + } +} diff --git a/lib/netstd/Thrift/Transport/TTransport.cs b/lib/netstd/Thrift/Transport/TTransport.cs index 799801202..8f510ddb9 100644 --- a/lib/netstd/Thrift/Transport/TTransport.cs +++ b/lib/netstd/Thrift/Transport/TTransport.cs @@ -30,7 +30,10 @@ namespace Thrift.Transport //TODO: think how to avoid peek byte private readonly byte[] _peekBuffer = new byte[1]; private bool _hasPeekByte; + public abstract bool IsOpen { get; } + public abstract TConfiguration Configuration { get; } + public abstract void UpdateKnownMessageSize(long size); public void Dispose() { diff --git a/test/netstd/Client/Performance/PerformanceTests.cs b/test/netstd/Client/Performance/PerformanceTests.cs index 041d12eae..05c64b240 100644 --- a/test/netstd/Client/Performance/PerformanceTests.cs +++ b/test/netstd/Client/Performance/PerformanceTests.cs @@ -20,6 +20,7 @@ using System.Collections.Generic; using System.Text; using ThriftTest; using Thrift.Collections; +using Thrift; using Thrift.Protocol; using System.Threading; using Thrift.Transport.Client; @@ -36,6 +37,7 @@ namespace Client.Tests private TMemoryBufferTransport MemBuffer; private TTransport Transport; private LayeredChoice Layered; + private readonly TConfiguration Configuration = new TConfiguration(); internal static int Execute() { @@ -52,6 +54,11 @@ namespace Client.Tests return 0; } + public PerformanceTests() + { + Configuration.MaxFrameSize = Configuration.MaxMessageSize; // default frame size is too small for this test + } + private async Task ProtocolPeformanceTestAsync() { Console.WriteLine("Setting up for ProtocolPeformanceTestAsync ..."); @@ -76,9 +83,9 @@ namespace Client.Tests { // read happens after write here, so let's take over the written bytes if (forWrite) - MemBuffer = new TMemoryBufferTransport(); + MemBuffer = new TMemoryBufferTransport(Configuration); else - MemBuffer = new TMemoryBufferTransport(MemBuffer.GetBuffer()); + MemBuffer = new TMemoryBufferTransport(MemBuffer.GetBuffer(), Configuration); // layered transports anyone? switch (Layered) diff --git a/test/netstd/Client/TestClient.cs b/test/netstd/Client/TestClient.cs index 13ae31343..0c147dcc7 100644 --- a/test/netstd/Client/TestClient.cs +++ b/test/netstd/Client/TestClient.cs @@ -28,6 +28,7 @@ using System.ServiceModel; using System.Text; using System.Threading; using System.Threading.Tasks; +using Thrift; using Thrift.Collections; using Thrift.Protocol; using Thrift.Transport; @@ -72,6 +73,7 @@ namespace ThriftTest public LayeredChoice layered = LayeredChoice.None; public ProtocolChoice protocol = ProtocolChoice.Binary; public TransportChoice transport = TransportChoice.Socket; + private readonly TConfiguration Configuration = null; // or new TConfiguration() if needed internal void Parse(List<string> args) { @@ -235,12 +237,12 @@ namespace ThriftTest { case TransportChoice.Http: Debug.Assert(url != null); - trans = new THttpTransport(new Uri(url), null); + trans = new THttpTransport(new Uri(url), Configuration); break; case TransportChoice.NamedPipe: Debug.Assert(pipe != null); - trans = new TNamedPipeTransport(pipe); + trans = new TNamedPipeTransport(pipe,Configuration); break; case TransportChoice.TlsSocket: @@ -250,14 +252,15 @@ namespace ThriftTest throw new InvalidOperationException("Certificate doesn't contain private key"); } - trans = new TTlsSocketTransport(host, port, 0, cert, + trans = new TTlsSocketTransport(host, port, Configuration, 0, + cert, (sender, certificate, chain, errors) => true, null, SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12); break; case TransportChoice.Socket: default: - trans = new TSocketTransport(host, port); + trans = new TSocketTransport(host, port, Configuration); break; } diff --git a/test/netstd/Server/TestServer.cs b/test/netstd/Server/TestServer.cs index 280f4e983..68461dc9c 100644 --- a/test/netstd/Server/TestServer.cs +++ b/test/netstd/Server/TestServer.cs @@ -148,6 +148,8 @@ namespace ThriftTest public class TestServer { public static int _clientID = -1; + private static readonly TConfiguration Configuration = null; // or new TConfiguration() if needed + public delegate void TestLogDelegate(string msg, params object[] values); public class MyServerEventHandler : TServerEventHandler @@ -552,7 +554,7 @@ namespace ThriftTest { case TransportChoice.NamedPipe: Debug.Assert(param.pipe != null); - trans = new TNamedPipeServerTransport(param.pipe); + trans = new TNamedPipeServerTransport(param.pipe, Configuration); break; @@ -564,14 +566,15 @@ namespace ThriftTest throw new InvalidOperationException("Certificate doesn't contain private key"); } - trans = new TTlsServerSocketTransport( param.port, cert, + trans = new TTlsServerSocketTransport(param.port, Configuration, + cert, (sender, certificate, chain, errors) => true, null, SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12); break; case TransportChoice.Socket: default: - trans = new TServerSocketTransport(param.port, 0); + trans = new TServerSocketTransport(param.port, Configuration); break; } diff --git a/tutorial/netstd/Client/Program.cs b/tutorial/netstd/Client/Program.cs index f9509fa2d..857b3e808 100644 --- a/tutorial/netstd/Client/Program.cs +++ b/tutorial/netstd/Client/Program.cs @@ -40,6 +40,7 @@ namespace Client { private static ServiceCollection ServiceCollection = new ServiceCollection(); private static ILogger Logger; + private static readonly TConfiguration Configuration = null; // new TConfiguration() if needed private static void DisplayHelp() { @@ -143,7 +144,7 @@ Sample: private static TTransport GetTransport(string[] args) { - TTransport transport = new TSocketTransport(IPAddress.Loopback, 9090); + TTransport transport = new TSocketTransport(IPAddress.Loopback, 9090, Configuration); // construct endpoint transport var transportArg = args.FirstOrDefault(x => x.StartsWith("-tr"))?.Split(':')?[1]; @@ -152,19 +153,20 @@ Sample: switch (selectedTransport) { case Transport.Tcp: - transport = new TSocketTransport(IPAddress.Loopback, 9090); + transport = new TSocketTransport(IPAddress.Loopback, 9090, Configuration); break; case Transport.NamedPipe: - transport = new TNamedPipeTransport(".test"); + transport = new TNamedPipeTransport(".test", Configuration); break; case Transport.Http: - transport = new THttpTransport(new Uri("http://localhost:9090"), null); + transport = new THttpTransport(new Uri("http://localhost:9090"), Configuration); break; case Transport.TcpTls: - transport = new TTlsSocketTransport(IPAddress.Loopback, 9090, GetCertificate(), CertValidator, LocalCertificateSelectionCallback); + transport = new TTlsSocketTransport(IPAddress.Loopback, 9090, Configuration, + GetCertificate(), CertValidator, LocalCertificateSelectionCallback); break; default: diff --git a/tutorial/netstd/Server/Program.cs b/tutorial/netstd/Server/Program.cs index e1dab01e0..c1e0cb3ec 100644 --- a/tutorial/netstd/Server/Program.cs +++ b/tutorial/netstd/Server/Program.cs @@ -44,6 +44,7 @@ namespace Server { private static ServiceCollection ServiceCollection = new ServiceCollection(); private static ILogger Logger; + private static readonly TConfiguration Configuration = null; // new TConfiguration() if needed public static void Main(string[] args) { @@ -163,13 +164,14 @@ Sample: switch (transport) { case Transport.Tcp: - serverTransport = new TServerSocketTransport(9090); + serverTransport = new TServerSocketTransport(9090, Configuration); break; case Transport.NamedPipe: - serverTransport = new TNamedPipeServerTransport(".test"); + serverTransport = new TNamedPipeServerTransport(".test", Configuration); break; case Transport.TcpTls: - serverTransport = new TTlsServerSocketTransport(9090, GetCertificate(), ClientCertValidator, LocalCertificateSelectionCallback); + serverTransport = new TTlsServerSocketTransport(9090, Configuration, + GetCertificate(), ClientCertValidator, LocalCertificateSelectionCallback); break; } |