diff options
author | Peter Dettman <peter.dettman@bouncycastle.org> | 2023-10-26 16:28:58 +0700 |
---|---|---|
committer | Peter Dettman <peter.dettman@bouncycastle.org> | 2023-10-26 16:28:58 +0700 |
commit | baffac980d9962290dc401f2d81c6c980e4d81b8 (patch) | |
tree | 6c7411b7ed45a70c0e279c5d8f6a554623a99124 /crypto | |
parent | Refactoring in Ed448 (diff) | |
download | BouncyCastle.NET-ed25519-baffac980d9962290dc401f2d81c6c980e4d81b8.tar.xz |
DTLS: Fixed retransmission in response to re-receipt of an aggregated ChangeCipherSpec
- see https://github.com/bcgit/bc-java/pull/1491
Diffstat (limited to 'crypto')
-rw-r--r-- | crypto/Readme.html | 1 | ||||
-rw-r--r-- | crypto/src/tls/DtlsRecordLayer.cs | 16 | ||||
-rw-r--r-- | crypto/src/tls/TlsUtilities.cs | 13 | ||||
-rw-r--r-- | crypto/test/src/tls/test/DtlsAggregatedHandshakeRetransmissionTest.cs | 138 | ||||
-rw-r--r-- | crypto/test/src/tls/test/DtlsHandshakeRetransmissionTest.cs | 134 | ||||
-rw-r--r-- | crypto/test/src/tls/test/FilteredDatagramTransport.cs | 112 | ||||
-rw-r--r-- | crypto/test/src/tls/test/LoggingDatagramTransport.cs | 15 | ||||
-rw-r--r-- | crypto/test/src/tls/test/MinimalHandshakeAggregator.cs | 254 | ||||
-rw-r--r-- | crypto/test/src/tls/test/MockDtlsClient.cs | 6 | ||||
-rw-r--r-- | crypto/test/src/tls/test/ServerHandshakeDropper.cs | 63 | ||||
-rw-r--r-- | crypto/test/src/tls/test/UnreliableDatagramTransport.cs | 23 |
11 files changed, 736 insertions, 39 deletions
diff --git a/crypto/Readme.html b/crypto/Readme.html index 27745b848..91d33d20f 100644 --- a/crypto/Readme.html +++ b/crypto/Readme.html @@ -336,6 +336,7 @@ <li>DTLS: Fixed an exception during server handshake when 1.2 is negotiated and the ClientHello contained no extensions.</li> <li>HC128Engine now strictly requires 128 bits of IV.</li> <li>DTLS: Fixed server support for client_certificate_type extension.</li> + <li>DTLS: Fixed retransmission in response to re-receipt of an aggregated ChangeCipherSpec.</li> </ul> <h5>Additional Features and Functionality</h5> <ul> diff --git a/crypto/src/tls/DtlsRecordLayer.cs b/crypto/src/tls/DtlsRecordLayer.cs index e3567aa46..fe3b58d41 100644 --- a/crypto/src/tls/DtlsRecordLayer.cs +++ b/crypto/src/tls/DtlsRecordLayer.cs @@ -715,10 +715,12 @@ namespace Org.BouncyCastle.Tls { recordEpoch = m_readEpoch; } - else if (recordType == ContentType.handshake && null != m_retransmitEpoch - && epoch == m_retransmitEpoch.Epoch) + else if (null != m_retransmitEpoch && epoch == m_retransmitEpoch.Epoch) { - recordEpoch = m_retransmitEpoch; + if (recordType == ContentType.handshake) + { + recordEpoch = m_retransmitEpoch; + } } if (null == recordEpoch) @@ -994,7 +996,6 @@ namespace Org.BouncyCastle.Tls int recordLength = RecordHeaderLength; if (m_recordQueue.Available >= recordLength) { - short recordType = m_recordQueue.ReadUint8(0); int epoch = m_recordQueue.ReadUint16(3); DtlsEpoch recordEpoch = null; @@ -1002,8 +1003,7 @@ namespace Org.BouncyCastle.Tls { recordEpoch = m_readEpoch; } - else if (recordType == ContentType.handshake && null != m_retransmitEpoch - && epoch == m_retransmitEpoch.Epoch) + else if (null != m_retransmitEpoch && epoch == m_retransmitEpoch.Epoch) { recordEpoch = m_retransmitEpoch; } @@ -1038,7 +1038,6 @@ namespace Org.BouncyCastle.Tls { this.m_inConnection = true; - short recordType = TlsUtilities.ReadUint8(buf, off); int epoch = TlsUtilities.ReadUint16(buf, off + 3); DtlsEpoch recordEpoch = null; @@ -1046,8 +1045,7 @@ namespace Org.BouncyCastle.Tls { recordEpoch = m_readEpoch; } - else if (recordType == ContentType.handshake && null != m_retransmitEpoch - && epoch == m_retransmitEpoch.Epoch) + else if (null != m_retransmitEpoch && epoch == m_retransmitEpoch.Epoch) { recordEpoch = m_retransmitEpoch; } diff --git a/crypto/src/tls/TlsUtilities.cs b/crypto/src/tls/TlsUtilities.cs index 2887b0df1..67a49e5ef 100644 --- a/crypto/src/tls/TlsUtilities.cs +++ b/crypto/src/tls/TlsUtilities.cs @@ -770,11 +770,20 @@ namespace Org.BouncyCastle.Tls public static int ReadUint16(byte[] buf, int offset) { - int n = (buf[offset] & 0xff) << 8; - n |= (buf[++offset] & 0xff); + int n = buf[offset] << 8; + n |= buf[++offset]; return n; } +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + public static int ReadUint16(ReadOnlySpan<byte> buffer) + { + int n = buffer[0] << 8; + n |= buffer[1]; + return n; + } +#endif + public static int ReadUint24(Stream input) { int i1 = input.ReadByte(); diff --git a/crypto/test/src/tls/test/DtlsAggregatedHandshakeRetransmissionTest.cs b/crypto/test/src/tls/test/DtlsAggregatedHandshakeRetransmissionTest.cs new file mode 100644 index 000000000..3c78b7e52 --- /dev/null +++ b/crypto/test/src/tls/test/DtlsAggregatedHandshakeRetransmissionTest.cs @@ -0,0 +1,138 @@ +using System; +using System.Text; +using System.Threading; + +using NUnit.Framework; + +using Org.BouncyCastle.Tls.Crypto; +using Org.BouncyCastle.Tls.Crypto.Impl.BC; +using Org.BouncyCastle.Utilities; + +namespace Org.BouncyCastle.Tls.Tests +{ + [TestFixture] + public class DtlsAggregatedHandshakeRetransmissionTest + { + [Test] + public void TestClientServer() + { + DtlsClientProtocol clientProtocol = new DtlsClientProtocol(); + DtlsServerProtocol serverProtocol = new DtlsServerProtocol(); + + MockDatagramAssociation network = new MockDatagramAssociation(1500); + + ServerTask serverTask = new ServerTask(serverProtocol, network.Server); + + Thread serverThread = new Thread(new ThreadStart(serverTask.Run)); + serverThread.Start(); + + DatagramTransport clientTransport = network.Client; + + clientTransport = new ServerHandshakeDropper(clientTransport, true); + + clientTransport = new LoggingDatagramTransport(clientTransport, Console.Out); + + clientTransport = new MinimalHandshakeAggregator(clientTransport, false, true); + + MockDtlsClient client = new MockDtlsClient(null); + + client.SetHandshakeTimeoutMillis(30000); // Test gets stuck, so we need it to time out. + + DtlsTransport dtlsClient = clientProtocol.Connect(client, clientTransport); + + for (int i = 1; i <= 10; ++i) + { + byte[] data = new byte[i]; + Arrays.Fill(data, (byte)i); + dtlsClient.Send(data, 0, data.Length); + } + + byte[] buf = new byte[dtlsClient.GetReceiveLimit()]; + while (dtlsClient.Receive(buf, 0, buf.Length, 100) >= 0) + { + } + + dtlsClient.Close(); + + serverTask.Shutdown(serverThread); + } + + internal class ServerTask + { + private readonly DtlsServerProtocol m_serverProtocol; + private readonly DatagramTransport m_serverTransport; + private volatile bool m_isShutdown = false; + + internal ServerTask(DtlsServerProtocol serverProtocol, DatagramTransport serverTransport) + { + this.m_serverProtocol = serverProtocol; + this.m_serverTransport = serverTransport; + } + + public void Run() + { + try + { + TlsCrypto serverCrypto = new BcTlsCrypto(); + + DtlsRequest request = null; + + // Use DtlsVerifier to require a HelloVerifyRequest cookie exchange before accepting + { + DtlsVerifier verifier = new DtlsVerifier(serverCrypto); + + // NOTE: Test value only - would typically be the client IP address + byte[] clientID = Encoding.UTF8.GetBytes("MockDtlsClient"); + + int receiveLimit = m_serverTransport.GetReceiveLimit(); + int dummyOffset = serverCrypto.SecureRandom.Next(16) + 1; + byte[] buf = new byte[dummyOffset + m_serverTransport.GetReceiveLimit()]; + + do + { + if (m_isShutdown) + return; + + int length = m_serverTransport.Receive(buf, dummyOffset, receiveLimit, 100); + if (length > 0) + { + request = verifier.VerifyRequest(clientID, buf, dummyOffset, length, m_serverTransport); + } + } + while (request == null); + } + + // NOTE: A real server would handle each DtlsRequest in a new task/thread and continue accepting + { + MockDtlsServer server = new MockDtlsServer(serverCrypto); + DtlsTransport dtlsTransport = m_serverProtocol.Accept(server, m_serverTransport, request); + byte[] buf = new byte[dtlsTransport.GetReceiveLimit()]; + while (!m_isShutdown) + { + int length = dtlsTransport.Receive(buf, 0, buf.Length, 100); + if (length >= 0) + { + dtlsTransport.Send(buf, 0, length); + } + } + dtlsTransport.Close(); + } + } + catch (Exception e) + { + Console.Error.WriteLine(e); + Console.Error.Flush(); + } + } + + internal void Shutdown(Thread serverThread) + { + if (!m_isShutdown) + { + this.m_isShutdown = true; + serverThread.Join(); + } + } + } + } +} diff --git a/crypto/test/src/tls/test/DtlsHandshakeRetransmissionTest.cs b/crypto/test/src/tls/test/DtlsHandshakeRetransmissionTest.cs new file mode 100644 index 000000000..6c897ff04 --- /dev/null +++ b/crypto/test/src/tls/test/DtlsHandshakeRetransmissionTest.cs @@ -0,0 +1,134 @@ +using System; +using System.Text; +using System.Threading; + +using NUnit.Framework; + +using Org.BouncyCastle.Tls.Crypto; +using Org.BouncyCastle.Tls.Crypto.Impl.BC; +using Org.BouncyCastle.Utilities; + +namespace Org.BouncyCastle.Tls.Tests +{ + [TestFixture] + public class DtlsHandshakeRetransmissionTest + { + [Test] + public void TestClientServer() + { + DtlsClientProtocol clientProtocol = new DtlsClientProtocol(); + DtlsServerProtocol serverProtocol = new DtlsServerProtocol(); + + MockDatagramAssociation network = new MockDatagramAssociation(1500); + + ServerTask serverTask = new ServerTask(serverProtocol, network.Server); + + Thread serverThread = new Thread(new ThreadStart(serverTask.Run)); + serverThread.Start(); + + DatagramTransport clientTransport = network.Client; + + clientTransport = new ServerHandshakeDropper(clientTransport, true); + + clientTransport = new LoggingDatagramTransport(clientTransport, Console.Out); + + MockDtlsClient client = new MockDtlsClient(null); + + DtlsTransport dtlsClient = clientProtocol.Connect(client, clientTransport); + + for (int i = 1; i <= 10; ++i) + { + byte[] data = new byte[i]; + Arrays.Fill(data, (byte)i); + dtlsClient.Send(data, 0, data.Length); + } + + byte[] buf = new byte[dtlsClient.GetReceiveLimit()]; + while (dtlsClient.Receive(buf, 0, buf.Length, 100) >= 0) + { + } + + dtlsClient.Close(); + + serverTask.Shutdown(serverThread); + } + + internal class ServerTask + { + private readonly DtlsServerProtocol m_serverProtocol; + private readonly DatagramTransport m_serverTransport; + private volatile bool m_isShutdown = false; + + internal ServerTask(DtlsServerProtocol serverProtocol, DatagramTransport serverTransport) + { + this.m_serverProtocol = serverProtocol; + this.m_serverTransport = serverTransport; + } + + public void Run() + { + try + { + TlsCrypto serverCrypto = new BcTlsCrypto(); + + DtlsRequest request = null; + + // Use DtlsVerifier to require a HelloVerifyRequest cookie exchange before accepting + { + DtlsVerifier verifier = new DtlsVerifier(serverCrypto); + + // NOTE: Test value only - would typically be the client IP address + byte[] clientID = Encoding.UTF8.GetBytes("MockDtlsClient"); + + int receiveLimit = m_serverTransport.GetReceiveLimit(); + int dummyOffset = serverCrypto.SecureRandom.Next(16) + 1; + byte[] buf = new byte[dummyOffset + m_serverTransport.GetReceiveLimit()]; + + do + { + if (m_isShutdown) + return; + + int length = m_serverTransport.Receive(buf, dummyOffset, receiveLimit, 100); + if (length > 0) + { + request = verifier.VerifyRequest(clientID, buf, dummyOffset, length, m_serverTransport); + } + } + while (request == null); + } + + // NOTE: A real server would handle each DtlsRequest in a new task/thread and continue accepting + { + MockDtlsServer server = new MockDtlsServer(serverCrypto); + DtlsTransport dtlsTransport = m_serverProtocol.Accept(server, m_serverTransport, request); + byte[] buf = new byte[dtlsTransport.GetReceiveLimit()]; + while (!m_isShutdown) + { + int length = dtlsTransport.Receive(buf, 0, buf.Length, 100); + if (length >= 0) + { + dtlsTransport.Send(buf, 0, length); + } + } + dtlsTransport.Close(); + } + } + catch (Exception e) + { + Console.Error.WriteLine(e); + Console.Error.Flush(); + } + } + + internal void Shutdown(Thread serverThread) + { + if (!m_isShutdown) + { + this.m_isShutdown = true; + serverThread.Join(); + } + } + } + } +} diff --git a/crypto/test/src/tls/test/FilteredDatagramTransport.cs b/crypto/test/src/tls/test/FilteredDatagramTransport.cs new file mode 100644 index 000000000..23c0839d6 --- /dev/null +++ b/crypto/test/src/tls/test/FilteredDatagramTransport.cs @@ -0,0 +1,112 @@ +using System; + +using Org.BouncyCastle.Utilities.Date; + +namespace Org.BouncyCastle.Tls.Tests +{ + public class FilteredDatagramTransport + : DatagramTransport + { + public delegate bool FilterPredicate(byte[] buf, int off, int len); + + public static bool AlwaysAllow(byte[] buf, int off, int len) => true; + + private readonly DatagramTransport m_transport; + + private readonly FilterPredicate m_allowReceiving, m_allowSending; + + public FilteredDatagramTransport(DatagramTransport transport, FilterPredicate allowReceiving, + FilterPredicate allowSending) + { + m_transport = transport; + m_allowReceiving = allowReceiving; + m_allowSending = allowSending; + } + + public virtual int GetReceiveLimit() => m_transport.GetReceiveLimit(); + + public virtual int GetSendLimit() => m_transport.GetSendLimit(); + + public virtual int Receive(byte[] buf, int off, int len, int waitMillis) + { + long endMillis = DateTimeUtilities.CurrentUnixMs() + waitMillis; + for (;;) + { + int length = m_transport.Receive(buf, off, len, waitMillis); + if (length < 0 || m_allowReceiving(buf, off, len)) + return length; + + Console.WriteLine("PACKET FILTERED ({0} byte packet not received)", length); + + long now = DateTimeUtilities.CurrentUnixMs(); + if (now >= endMillis) + return -1; + + waitMillis = (int)(endMillis - now); + } + } + +//#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER +#if NET6_0_OR_GREATER + public virtual int Receive(Span<byte> buffer, int waitMillis) + { + long endMillis = DateTimeUtilities.CurrentUnixMs() + waitMillis; + for (;;) + { + int length = m_transport.Receive(buffer, waitMillis); + if (length < 0 || m_allowReceiving(buffer.ToArray(), 0, buffer.Length)) + return length; + + Console.WriteLine("PACKET FILTERED ({0} byte packet not received)", length); + + long now = DateTimeUtilities.CurrentUnixMs(); + if (now >= endMillis) + return -1; + + waitMillis = (int)(endMillis - now); + } + } +#endif + + public virtual void Send(byte[] buf, int off, int len) + { + if (!m_allowSending(buf, off, len)) + { + Console.WriteLine("PACKET FILTERED ({0} byte packet not sent)", len); + } + else + { + m_transport.Send(buf, off, len); + } + } + + //#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER +#if NET6_0_OR_GREATER + public virtual void Send(ReadOnlySpan<byte> buffer) + { + if (!m_allowSending(buffer.ToArray(), 0, buffer.Length)) + { + Console.WriteLine("PACKET FILTERED ({0} byte packet not sent)", buffer.Length); + } + else + { + m_transport.Send(buffer); + } + } +#endif + + public virtual void Close() => m_transport.Close(); + + //static FilterPredicate ALWAYS_ALLOW = new FilterPredicate() { + // @Override + // public boolean allowPacket(byte[] buf, int off, int len) + // { + // return true; + // } + //}; + + //interface FilterPredicate { + // boolean allowPacket(byte[] buf, int off, int len); + //} + } +} diff --git a/crypto/test/src/tls/test/LoggingDatagramTransport.cs b/crypto/test/src/tls/test/LoggingDatagramTransport.cs index 59113cf73..d03e99551 100644 --- a/crypto/test/src/tls/test/LoggingDatagramTransport.cs +++ b/crypto/test/src/tls/test/LoggingDatagramTransport.cs @@ -22,15 +22,9 @@ namespace Org.BouncyCastle.Tls.Tests this.m_launchTimestamp = DateTimeUtilities.CurrentUnixMs(); } - public virtual int GetReceiveLimit() - { - return m_transport.GetReceiveLimit(); - } + public virtual int GetReceiveLimit() => m_transport.GetReceiveLimit(); - public virtual int GetSendLimit() - { - return m_transport.GetSendLimit(); - } + public virtual int GetSendLimit() => m_transport.GetSendLimit(); public virtual int Receive(byte[] buf, int off, int len, int waitMillis) { @@ -80,10 +74,7 @@ namespace Org.BouncyCastle.Tls.Tests } #endif - public virtual void Close() - { - m_transport.Close(); - } + public virtual void Close() => m_transport.Close(); //#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER #if NET6_0_OR_GREATER diff --git a/crypto/test/src/tls/test/MinimalHandshakeAggregator.cs b/crypto/test/src/tls/test/MinimalHandshakeAggregator.cs new file mode 100644 index 000000000..645e32851 --- /dev/null +++ b/crypto/test/src/tls/test/MinimalHandshakeAggregator.cs @@ -0,0 +1,254 @@ +using System; + +using Org.BouncyCastle.Utilities.Date; + +namespace Org.BouncyCastle.Tls.Tests +{ + /** + * A very minimal and stupid class to aggregate DTLS handshake messages. Only sufficient for unit tests. + */ + public class MinimalHandshakeAggregator + : DatagramTransport + { + private readonly DatagramTransport m_transport; + + private readonly bool m_aggregateReceiving, m_aggregateSending; + + private byte[] m_receiveBuf, m_sendBuf; + + private int m_receiveRecordCount, m_sendRecordCount; + + private byte[] AddToBuf(byte[] baseBuf, byte[] buf, int off, int len) + { + byte[] ret = new byte[baseBuf.Length + len]; + Array.Copy(baseBuf, 0, ret, 0, baseBuf.Length); + Array.Copy(buf, off, ret, baseBuf.Length, len); + return ret; + } + +//#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER +#if NET6_0_OR_GREATER + private byte[] AddToBuf(byte[] baseBuf, ReadOnlySpan<byte> buf) + { + byte[] ret = new byte[baseBuf.Length + buf.Length]; + Array.Copy(baseBuf, 0, ret, 0, baseBuf.Length); + buf.CopyTo(ret[baseBuf.Length..]); + return ret; + } +#endif + + private void AddToReceiveBuf(byte[] buf, int off, int len) + { + m_receiveBuf = AddToBuf(m_receiveBuf, buf, off, len); + m_receiveRecordCount++; + } + +//#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER +#if NET6_0_OR_GREATER + private void AddToReceiveBuf(ReadOnlySpan<byte> buf) + { + m_receiveBuf = AddToBuf(m_receiveBuf, buf); + m_receiveRecordCount++; + } +#endif + + private void ResetReceiveBuf() + { + m_receiveBuf = new byte[0]; + m_receiveRecordCount = 0; + } + + private void AddToSendBuf(byte[] buf, int off, int len) + { + m_sendBuf = AddToBuf(m_sendBuf, buf, off, len); + m_sendRecordCount++; + } + +//#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER +#if NET6_0_OR_GREATER + private void AddToSendBuf(ReadOnlySpan<byte> buf) + { + m_sendBuf = AddToBuf(m_sendBuf, buf); + m_sendRecordCount++; + } +#endif + + private void ResetSendBuf() + { + m_sendBuf = new byte[0]; + m_sendRecordCount = 0; + } + + /** Whether the buffered aggregated data should be flushed after this packet. + * This is done on the end of the first flight - ClientHello and ServerHelloDone - and anything that is + * Epoch 1. + */ + private bool FlushAfterThisPacket(byte[] buf, int off, int len) + { + int epoch = TlsUtilities.ReadUint16(buf, off + 3); + if (epoch > 0) + return true; + + short contentType = TlsUtilities.ReadUint8(buf, off); + if (ContentType.handshake != contentType) + return false; + + short msgType = TlsUtilities.ReadUint8(buf, off + 13); + switch (msgType) + { + case HandshakeType.client_hello: + case HandshakeType.server_hello_done: + return true; + default: + return false; + } + } + +//#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER +#if NET6_0_OR_GREATER + private bool FlushAfterThisPacket(ReadOnlySpan<byte> buffer) + { + int epoch = TlsUtilities.ReadUint16(buffer[3..]); + if (epoch > 0) + return true; + + short contentType = TlsUtilities.ReadUint8(buffer); + if (ContentType.handshake != contentType) + return false; + + short msgType = TlsUtilities.ReadUint8(buffer[13..]); + switch (msgType) + { + case HandshakeType.client_hello: + case HandshakeType.server_hello_done: + return true; + default: + return false; + } + } +#endif + + public MinimalHandshakeAggregator(DatagramTransport transport, bool aggregateReceiving, bool aggregateSending) + { + m_transport = transport; + m_aggregateReceiving = aggregateReceiving; + m_aggregateSending = aggregateSending; + + ResetReceiveBuf(); + ResetSendBuf(); + } + + public virtual int GetReceiveLimit() => m_transport.GetReceiveLimit(); + + public virtual int GetSendLimit() => m_transport.GetSendLimit(); + + public virtual int Receive(byte[] buf, int off, int len, int waitMillis) + { + long endMillis = DateTimeUtilities.CurrentUnixMs() + waitMillis; + for (;;) + { + int length = m_transport.Receive(buf, off, len, waitMillis); + if (length < 0 || !m_aggregateReceiving) + return length; + + AddToReceiveBuf(buf, off, length); + + if (FlushAfterThisPacket(buf, off, length)) + { + if (m_receiveRecordCount > 1) + { + Console.WriteLine("RECEIVING {0} RECORDS IN {1} BYTE PACKET", m_receiveRecordCount, length); + } + Array.Copy(m_receiveBuf, 0, buf, off, System.Math.Min(len, m_receiveBuf.Length)); + ResetReceiveBuf(); + return length; + } + + long now = DateTimeUtilities.CurrentUnixMs(); + if (now >= endMillis) + return -1; + + waitMillis = (int)(endMillis - now); + } + } + +//#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER +#if NET6_0_OR_GREATER + public virtual int Receive(Span<byte> buffer, int waitMillis) + { + long endMillis = DateTimeUtilities.CurrentUnixMs() + waitMillis; + for (;;) + { + int length = m_transport.Receive(buffer, waitMillis); + if (length < 0 || !m_aggregateReceiving) + return length; + + AddToReceiveBuf(buffer); + + if (FlushAfterThisPacket(buffer)) + { + if (m_receiveRecordCount > 1) + { + Console.WriteLine("RECEIVING {0} RECORDS IN {1} BYTE PACKET", m_receiveRecordCount, length); + } + int resultLength = System.Math.Min(buffer.Length, m_receiveBuf.Length); + m_receiveBuf.AsSpan(0, resultLength).CopyTo(buffer); + ResetReceiveBuf(); + return resultLength; + } + + long now = DateTimeUtilities.CurrentUnixMs(); + if (now >= endMillis) + return -1; + + waitMillis = (int)(endMillis - now); + } + } +#endif + + public virtual void Send(byte[] buf, int off, int len) + { + if (!m_aggregateSending) + { + m_transport.Send(buf, off, len); + return; + } + AddToSendBuf(buf, off, len); + + if (FlushAfterThisPacket(buf, off, len)) + { + if (m_sendRecordCount > 1) + { + Console.WriteLine("SENDING {0} RECORDS IN {1} BYTE PACKET", m_sendRecordCount, m_sendBuf.Length); + } + m_transport.Send(m_sendBuf, 0, m_sendBuf.Length); + ResetSendBuf(); + } + } + + //#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER +#if NET6_0_OR_GREATER + public virtual void Send(ReadOnlySpan<byte> buffer) + { + if (!m_aggregateSending) + { + m_transport.Send(buffer); + return; + } + AddToSendBuf(buffer); + + if (FlushAfterThisPacket(buffer)) + { + if (m_sendRecordCount > 1) + { + Console.WriteLine("SENDING {0} RECORDS IN {1} BYTE PACKET", m_sendRecordCount, m_sendBuf.Length); + } + m_transport.Send(m_sendBuf, 0, m_sendBuf.Length); + ResetSendBuf(); + } + } +#endif + + public virtual void Close() => m_transport.Close(); + } +} diff --git a/crypto/test/src/tls/test/MockDtlsClient.cs b/crypto/test/src/tls/test/MockDtlsClient.cs index 9898acd08..2e7ab9450 100644 --- a/crypto/test/src/tls/test/MockDtlsClient.cs +++ b/crypto/test/src/tls/test/MockDtlsClient.cs @@ -15,6 +15,8 @@ namespace Org.BouncyCastle.Tls.Tests { internal TlsSession m_session; + private int m_handshakeTimeoutMillis = 0; + internal MockDtlsClient(TlsSession session) : base(new BcTlsCrypto()) { @@ -26,6 +28,10 @@ namespace Org.BouncyCastle.Tls.Tests return this.m_session; } + public override int GetHandshakeTimeoutMillis() => m_handshakeTimeoutMillis; + + public void SetHandshakeTimeoutMillis(int millis) => m_handshakeTimeoutMillis = millis; + public override void NotifyAlertRaised(short alertLevel, short alertDescription, string message, Exception cause) { diff --git a/crypto/test/src/tls/test/ServerHandshakeDropper.cs b/crypto/test/src/tls/test/ServerHandshakeDropper.cs new file mode 100644 index 000000000..89c752333 --- /dev/null +++ b/crypto/test/src/tls/test/ServerHandshakeDropper.cs @@ -0,0 +1,63 @@ +using System; + +namespace Org.BouncyCastle.Tls.Tests +{ + /** This is a [Transport] wrapper which causes the first retransmission of the second flight of a server + * handshake to be dropped. */ + public class ServerHandshakeDropper + : FilteredDatagramTransport + { + private static FilterPredicate Choose(bool condition, FilterPredicate left, FilterPredicate right) + { + if (condition) { return left; } else { return right; } + } + + public ServerHandshakeDropper(DatagramTransport transport, bool dropOnReceive) + : base(transport, + Choose(dropOnReceive, new DropFirstServerFinalFlight().AllowPacket, AlwaysAllow), + Choose(dropOnReceive, AlwaysAllow, new DropFirstServerFinalFlight().AllowPacket)) + { + } + + /** This drops the first instance of DTLS packets that either begin with a ChangeCipherSpec, or handshake in + * epoch 1. This is the server's final flight of the handshake. It will test whether the client properly + * retransmits its second flight, and the server properly retransmits the dropped flight. + */ + private class DropFirstServerFinalFlight + { + private bool m_sawChangeCipherSpec = false; + private bool m_sawEpoch1Handshake = false; + + private bool IsChangeCipherSpec(byte[] buf, int off, int len) + { + short contentType = TlsUtilities.ReadUint8(buf, off); + return ContentType.change_cipher_spec == contentType; + } + + private bool IsEpoch1Handshake(byte[] buf, int off, int len) + { + short contentType = TlsUtilities.ReadUint8(buf, off); + if (ContentType.handshake != contentType) + return false; + + int epoch = TlsUtilities.ReadUint16(buf, off + 3); + return 1 == epoch; + } + + public bool AllowPacket(byte[] buf, int off, int len) + { + if (!m_sawChangeCipherSpec && IsChangeCipherSpec(buf, off, len)) + { + m_sawChangeCipherSpec = true; + return false; + } + if (!m_sawEpoch1Handshake && IsEpoch1Handshake(buf, off, len)) + { + m_sawEpoch1Handshake = true; + return false; + } + return true; + } + } + } +} diff --git a/crypto/test/src/tls/test/UnreliableDatagramTransport.cs b/crypto/test/src/tls/test/UnreliableDatagramTransport.cs index 7769db9d1..0aed2d7a1 100644 --- a/crypto/test/src/tls/test/UnreliableDatagramTransport.cs +++ b/crypto/test/src/tls/test/UnreliableDatagramTransport.cs @@ -25,15 +25,9 @@ namespace Org.BouncyCastle.Tls.Tests this.m_percentPacketLossSending = percentPacketLossSending; } - public virtual int GetReceiveLimit() - { - return m_transport.GetReceiveLimit(); - } + public virtual int GetReceiveLimit() => m_transport.GetReceiveLimit(); - public virtual int GetSendLimit() - { - return m_transport.GetSendLimit(); - } + public virtual int GetSendLimit() => m_transport.GetSendLimit(); public virtual int Receive(byte[] buf, int off, int len, int waitMillis) { @@ -48,7 +42,7 @@ namespace Org.BouncyCastle.Tls.Tests if (length < 0 || !LostPacket(m_percentPacketLossReceiving)) return length; - Console.WriteLine("PACKET LOSS (" + length + " byte packet not received)"); + Console.WriteLine("PACKET LOSS ({0} byte packet not received)", length); long now = DateTimeUtilities.CurrentUnixMs(); if (now >= endMillis) @@ -70,7 +64,7 @@ namespace Org.BouncyCastle.Tls.Tests if (length < 0 || !LostPacket(m_percentPacketLossReceiving)) return length; - Console.WriteLine("PACKET LOSS (" + length + " byte packet not received)"); + Console.WriteLine("PACKET LOSS ({0} byte packet not received)", length); long now = DateTimeUtilities.CurrentUnixMs(); if (now >= endMillis) @@ -89,7 +83,7 @@ namespace Org.BouncyCastle.Tls.Tests #else if (LostPacket(m_percentPacketLossSending)) { - Console.WriteLine("PACKET LOSS (" + len + " byte packet not sent)"); + Console.WriteLine("PACKET LOSS ({0} byte packet not sent)", len); } else { @@ -104,7 +98,7 @@ namespace Org.BouncyCastle.Tls.Tests { if (LostPacket(m_percentPacketLossSending)) { - Console.WriteLine("PACKET LOSS (" + buffer.Length + " byte packet not sent)"); + Console.WriteLine("PACKET LOSS ({0} byte packet not sent)", buffer.Length); } else { @@ -113,10 +107,7 @@ namespace Org.BouncyCastle.Tls.Tests } #endif - public virtual void Close() - { - m_transport.Close(); - } + public virtual void Close() => m_transport.Close(); private bool LostPacket(int percentPacketLoss) { |