summary refs log tree commit diff
path: root/crypto
diff options
context:
space:
mode:
authorPeter Dettman <peter.dettman@bouncycastle.org>2023-10-26 16:28:58 +0700
committerPeter Dettman <peter.dettman@bouncycastle.org>2023-10-26 16:28:58 +0700
commitbaffac980d9962290dc401f2d81c6c980e4d81b8 (patch)
tree6c7411b7ed45a70c0e279c5d8f6a554623a99124 /crypto
parentRefactoring in Ed448 (diff)
downloadBouncyCastle.NET-ed25519-baffac980d9962290dc401f2d81c6c980e4d81b8.tar.xz
DTLS: Fixed retransmission in response to re-receipt of an aggregated ChangeCipherSpec
- see https://github.com/bcgit/bc-java/pull/1491
Diffstat (limited to 'crypto')
-rw-r--r--crypto/Readme.html1
-rw-r--r--crypto/src/tls/DtlsRecordLayer.cs16
-rw-r--r--crypto/src/tls/TlsUtilities.cs13
-rw-r--r--crypto/test/src/tls/test/DtlsAggregatedHandshakeRetransmissionTest.cs138
-rw-r--r--crypto/test/src/tls/test/DtlsHandshakeRetransmissionTest.cs134
-rw-r--r--crypto/test/src/tls/test/FilteredDatagramTransport.cs112
-rw-r--r--crypto/test/src/tls/test/LoggingDatagramTransport.cs15
-rw-r--r--crypto/test/src/tls/test/MinimalHandshakeAggregator.cs254
-rw-r--r--crypto/test/src/tls/test/MockDtlsClient.cs6
-rw-r--r--crypto/test/src/tls/test/ServerHandshakeDropper.cs63
-rw-r--r--crypto/test/src/tls/test/UnreliableDatagramTransport.cs23
11 files changed, 736 insertions, 39 deletions
diff --git a/crypto/Readme.html b/crypto/Readme.html
index 27745b848..91d33d20f 100644
--- a/crypto/Readme.html
+++ b/crypto/Readme.html
@@ -336,6 +336,7 @@
             <li>DTLS: Fixed an exception during server handshake when 1.2 is negotiated and the ClientHello contained no extensions.</li>
             <li>HC128Engine now strictly requires 128 bits of IV.</li>
             <li>DTLS: Fixed server support for client_certificate_type extension.</li>
+            <li>DTLS: Fixed retransmission in response to re-receipt of an aggregated ChangeCipherSpec.</li>
         </ul>
         <h5>Additional Features and Functionality</h5>
         <ul>
diff --git a/crypto/src/tls/DtlsRecordLayer.cs b/crypto/src/tls/DtlsRecordLayer.cs
index e3567aa46..fe3b58d41 100644
--- a/crypto/src/tls/DtlsRecordLayer.cs
+++ b/crypto/src/tls/DtlsRecordLayer.cs
@@ -715,10 +715,12 @@ namespace Org.BouncyCastle.Tls
             {
                 recordEpoch = m_readEpoch;
             }
-            else if (recordType == ContentType.handshake && null != m_retransmitEpoch
-                && epoch == m_retransmitEpoch.Epoch)
+            else if (null != m_retransmitEpoch && epoch == m_retransmitEpoch.Epoch)
             {
-                recordEpoch = m_retransmitEpoch;
+                if (recordType == ContentType.handshake)
+                {
+                    recordEpoch = m_retransmitEpoch;
+                }
             }
 
             if (null == recordEpoch)
@@ -994,7 +996,6 @@ namespace Org.BouncyCastle.Tls
             int recordLength = RecordHeaderLength;
             if (m_recordQueue.Available >= recordLength)
             {
-                short recordType = m_recordQueue.ReadUint8(0);
                 int epoch = m_recordQueue.ReadUint16(3);
 
                 DtlsEpoch recordEpoch = null;
@@ -1002,8 +1003,7 @@ namespace Org.BouncyCastle.Tls
                 {
                     recordEpoch = m_readEpoch;
                 }
-                else if (recordType == ContentType.handshake && null != m_retransmitEpoch
-                    && epoch == m_retransmitEpoch.Epoch)
+                else if (null != m_retransmitEpoch && epoch == m_retransmitEpoch.Epoch)
                 {
                     recordEpoch = m_retransmitEpoch;
                 }
@@ -1038,7 +1038,6 @@ namespace Org.BouncyCastle.Tls
             {
                 this.m_inConnection = true;
 
-                short recordType = TlsUtilities.ReadUint8(buf, off);
                 int epoch = TlsUtilities.ReadUint16(buf, off + 3);
 
                 DtlsEpoch recordEpoch = null;
@@ -1046,8 +1045,7 @@ namespace Org.BouncyCastle.Tls
                 {
                     recordEpoch = m_readEpoch;
                 }
-                else if (recordType == ContentType.handshake && null != m_retransmitEpoch
-                    && epoch == m_retransmitEpoch.Epoch)
+                else if (null != m_retransmitEpoch && epoch == m_retransmitEpoch.Epoch)
                 {
                     recordEpoch = m_retransmitEpoch;
                 }
diff --git a/crypto/src/tls/TlsUtilities.cs b/crypto/src/tls/TlsUtilities.cs
index 2887b0df1..67a49e5ef 100644
--- a/crypto/src/tls/TlsUtilities.cs
+++ b/crypto/src/tls/TlsUtilities.cs
@@ -770,11 +770,20 @@ namespace Org.BouncyCastle.Tls
 
         public static int ReadUint16(byte[] buf, int offset)
         {
-            int n = (buf[offset] & 0xff) << 8;
-            n |= (buf[++offset] & 0xff);
+            int n = buf[offset] << 8;
+            n |= buf[++offset];
             return n;
         }
 
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+        public static int ReadUint16(ReadOnlySpan<byte> buffer)
+        {
+            int n = buffer[0] << 8;
+            n |= buffer[1];
+            return n;
+        }
+#endif
+
         public static int ReadUint24(Stream input)
         {
             int i1 = input.ReadByte();
diff --git a/crypto/test/src/tls/test/DtlsAggregatedHandshakeRetransmissionTest.cs b/crypto/test/src/tls/test/DtlsAggregatedHandshakeRetransmissionTest.cs
new file mode 100644
index 000000000..3c78b7e52
--- /dev/null
+++ b/crypto/test/src/tls/test/DtlsAggregatedHandshakeRetransmissionTest.cs
@@ -0,0 +1,138 @@
+using System;
+using System.Text;
+using System.Threading;
+
+using NUnit.Framework;
+
+using Org.BouncyCastle.Tls.Crypto;
+using Org.BouncyCastle.Tls.Crypto.Impl.BC;
+using Org.BouncyCastle.Utilities;
+
+namespace Org.BouncyCastle.Tls.Tests
+{
+    [TestFixture]
+    public class DtlsAggregatedHandshakeRetransmissionTest
+    {
+        [Test]
+        public void TestClientServer()
+        {
+            DtlsClientProtocol clientProtocol = new DtlsClientProtocol();
+            DtlsServerProtocol serverProtocol = new DtlsServerProtocol();
+
+            MockDatagramAssociation network = new MockDatagramAssociation(1500);
+
+            ServerTask serverTask = new ServerTask(serverProtocol, network.Server);
+
+            Thread serverThread = new Thread(new ThreadStart(serverTask.Run));
+            serverThread.Start();
+
+            DatagramTransport clientTransport = network.Client;
+
+            clientTransport = new ServerHandshakeDropper(clientTransport, true);
+
+            clientTransport = new LoggingDatagramTransport(clientTransport, Console.Out);
+
+            clientTransport = new MinimalHandshakeAggregator(clientTransport, false, true);
+
+            MockDtlsClient client = new MockDtlsClient(null);
+
+            client.SetHandshakeTimeoutMillis(30000);    // Test gets stuck, so we need it to time out.
+
+            DtlsTransport dtlsClient = clientProtocol.Connect(client, clientTransport);
+
+            for (int i = 1; i <= 10; ++i)
+            {
+                byte[] data = new byte[i];
+                Arrays.Fill(data, (byte)i);
+                dtlsClient.Send(data, 0, data.Length);
+            }
+
+            byte[] buf = new byte[dtlsClient.GetReceiveLimit()];
+            while (dtlsClient.Receive(buf, 0, buf.Length, 100) >= 0)
+            {
+            }
+
+            dtlsClient.Close();
+
+            serverTask.Shutdown(serverThread);
+        }
+
+        internal class ServerTask
+        {
+            private readonly DtlsServerProtocol m_serverProtocol;
+            private readonly DatagramTransport m_serverTransport;
+            private volatile bool m_isShutdown = false;
+
+            internal ServerTask(DtlsServerProtocol serverProtocol, DatagramTransport serverTransport)
+            {
+                this.m_serverProtocol = serverProtocol;
+                this.m_serverTransport = serverTransport;
+            }
+
+            public void Run()
+            {
+                try
+                {
+                    TlsCrypto serverCrypto = new BcTlsCrypto();
+
+                    DtlsRequest request = null;
+
+                    // Use DtlsVerifier to require a HelloVerifyRequest cookie exchange before accepting
+                    {
+                        DtlsVerifier verifier = new DtlsVerifier(serverCrypto);
+
+                        // NOTE: Test value only - would typically be the client IP address
+                        byte[] clientID = Encoding.UTF8.GetBytes("MockDtlsClient");
+
+                        int receiveLimit = m_serverTransport.GetReceiveLimit();
+                        int dummyOffset = serverCrypto.SecureRandom.Next(16) + 1;
+                        byte[] buf = new byte[dummyOffset + m_serverTransport.GetReceiveLimit()];
+
+                        do
+                        {
+                            if (m_isShutdown)
+                                return;
+
+                            int length = m_serverTransport.Receive(buf, dummyOffset, receiveLimit, 100);
+                            if (length > 0)
+                            {
+                                request = verifier.VerifyRequest(clientID, buf, dummyOffset, length, m_serverTransport);
+                            }
+                        }
+                        while (request == null);
+                    }
+
+                    // NOTE: A real server would handle each DtlsRequest in a new task/thread and continue accepting
+                    {
+                        MockDtlsServer server = new MockDtlsServer(serverCrypto);
+                        DtlsTransport dtlsTransport = m_serverProtocol.Accept(server, m_serverTransport, request);
+                        byte[] buf = new byte[dtlsTransport.GetReceiveLimit()];
+                        while (!m_isShutdown)
+                        {
+                            int length = dtlsTransport.Receive(buf, 0, buf.Length, 100);
+                            if (length >= 0)
+                            {
+                                dtlsTransport.Send(buf, 0, length);
+                            }
+                        }
+                        dtlsTransport.Close();
+                    }
+                }
+                catch (Exception e)
+                {
+                    Console.Error.WriteLine(e);
+                    Console.Error.Flush();
+                }
+            }
+
+            internal void Shutdown(Thread serverThread)
+            {
+                if (!m_isShutdown)
+                {
+                    this.m_isShutdown = true;
+                    serverThread.Join();
+                }
+            }
+        }
+    }
+}
diff --git a/crypto/test/src/tls/test/DtlsHandshakeRetransmissionTest.cs b/crypto/test/src/tls/test/DtlsHandshakeRetransmissionTest.cs
new file mode 100644
index 000000000..6c897ff04
--- /dev/null
+++ b/crypto/test/src/tls/test/DtlsHandshakeRetransmissionTest.cs
@@ -0,0 +1,134 @@
+using System;
+using System.Text;
+using System.Threading;
+
+using NUnit.Framework;
+
+using Org.BouncyCastle.Tls.Crypto;
+using Org.BouncyCastle.Tls.Crypto.Impl.BC;
+using Org.BouncyCastle.Utilities;
+
+namespace Org.BouncyCastle.Tls.Tests
+{
+    [TestFixture]
+    public class DtlsHandshakeRetransmissionTest
+    {
+        [Test]
+        public void TestClientServer()
+        {
+            DtlsClientProtocol clientProtocol = new DtlsClientProtocol();
+            DtlsServerProtocol serverProtocol = new DtlsServerProtocol();
+
+            MockDatagramAssociation network = new MockDatagramAssociation(1500);
+
+            ServerTask serverTask = new ServerTask(serverProtocol, network.Server);
+
+            Thread serverThread = new Thread(new ThreadStart(serverTask.Run));
+            serverThread.Start();
+
+            DatagramTransport clientTransport = network.Client;
+
+            clientTransport = new ServerHandshakeDropper(clientTransport, true);
+
+            clientTransport = new LoggingDatagramTransport(clientTransport, Console.Out);
+
+            MockDtlsClient client = new MockDtlsClient(null);
+
+            DtlsTransport dtlsClient = clientProtocol.Connect(client, clientTransport);
+
+            for (int i = 1; i <= 10; ++i)
+            {
+                byte[] data = new byte[i];
+                Arrays.Fill(data, (byte)i);
+                dtlsClient.Send(data, 0, data.Length);
+            }
+
+            byte[] buf = new byte[dtlsClient.GetReceiveLimit()];
+            while (dtlsClient.Receive(buf, 0, buf.Length, 100) >= 0)
+            {
+            }
+
+            dtlsClient.Close();
+
+            serverTask.Shutdown(serverThread);
+        }
+
+        internal class ServerTask
+        {
+            private readonly DtlsServerProtocol m_serverProtocol;
+            private readonly DatagramTransport m_serverTransport;
+            private volatile bool m_isShutdown = false;
+
+            internal ServerTask(DtlsServerProtocol serverProtocol, DatagramTransport serverTransport)
+            {
+                this.m_serverProtocol = serverProtocol;
+                this.m_serverTransport = serverTransport;
+            }
+
+            public void Run()
+            {
+                try
+                {
+                    TlsCrypto serverCrypto = new BcTlsCrypto();
+
+                    DtlsRequest request = null;
+
+                    // Use DtlsVerifier to require a HelloVerifyRequest cookie exchange before accepting
+                    {
+                        DtlsVerifier verifier = new DtlsVerifier(serverCrypto);
+
+                        // NOTE: Test value only - would typically be the client IP address
+                        byte[] clientID = Encoding.UTF8.GetBytes("MockDtlsClient");
+
+                        int receiveLimit = m_serverTransport.GetReceiveLimit();
+                        int dummyOffset = serverCrypto.SecureRandom.Next(16) + 1;
+                        byte[] buf = new byte[dummyOffset + m_serverTransport.GetReceiveLimit()];
+
+                        do
+                        {
+                            if (m_isShutdown)
+                                return;
+
+                            int length = m_serverTransport.Receive(buf, dummyOffset, receiveLimit, 100);
+                            if (length > 0)
+                            {
+                                request = verifier.VerifyRequest(clientID, buf, dummyOffset, length, m_serverTransport);
+                            }
+                        }
+                        while (request == null);
+                    }
+
+                    // NOTE: A real server would handle each DtlsRequest in a new task/thread and continue accepting
+                    {
+                        MockDtlsServer server = new MockDtlsServer(serverCrypto);
+                        DtlsTransport dtlsTransport = m_serverProtocol.Accept(server, m_serverTransport, request);
+                        byte[] buf = new byte[dtlsTransport.GetReceiveLimit()];
+                        while (!m_isShutdown)
+                        {
+                            int length = dtlsTransport.Receive(buf, 0, buf.Length, 100);
+                            if (length >= 0)
+                            {
+                                dtlsTransport.Send(buf, 0, length);
+                            }
+                        }
+                        dtlsTransport.Close();
+                    }
+                }
+                catch (Exception e)
+                {
+                    Console.Error.WriteLine(e);
+                    Console.Error.Flush();
+                }
+            }
+
+            internal void Shutdown(Thread serverThread)
+            {
+                if (!m_isShutdown)
+                {
+                    this.m_isShutdown = true;
+                    serverThread.Join();
+                }
+            }
+        }
+    }
+}
diff --git a/crypto/test/src/tls/test/FilteredDatagramTransport.cs b/crypto/test/src/tls/test/FilteredDatagramTransport.cs
new file mode 100644
index 000000000..23c0839d6
--- /dev/null
+++ b/crypto/test/src/tls/test/FilteredDatagramTransport.cs
@@ -0,0 +1,112 @@
+using System;
+
+using Org.BouncyCastle.Utilities.Date;
+
+namespace Org.BouncyCastle.Tls.Tests
+{
+    public class FilteredDatagramTransport
+        : DatagramTransport
+    {
+        public delegate bool FilterPredicate(byte[] buf, int off, int len);
+
+        public static bool AlwaysAllow(byte[] buf, int off, int len) => true;
+
+        private readonly DatagramTransport m_transport;
+
+        private readonly FilterPredicate m_allowReceiving, m_allowSending;
+
+        public FilteredDatagramTransport(DatagramTransport transport, FilterPredicate allowReceiving,
+            FilterPredicate allowSending)
+        {
+            m_transport = transport;
+            m_allowReceiving = allowReceiving;
+            m_allowSending = allowSending;
+        }
+
+        public virtual int GetReceiveLimit() => m_transport.GetReceiveLimit();
+
+        public virtual int GetSendLimit() => m_transport.GetSendLimit();
+
+        public virtual int Receive(byte[] buf, int off, int len, int waitMillis)
+        {
+            long endMillis = DateTimeUtilities.CurrentUnixMs() + waitMillis;
+            for (;;)
+            {
+                int length = m_transport.Receive(buf, off, len, waitMillis);
+                if (length < 0 || m_allowReceiving(buf, off, len))
+                    return length;
+
+                Console.WriteLine("PACKET FILTERED ({0} byte packet not received)", length);
+
+                long now = DateTimeUtilities.CurrentUnixMs();
+                if (now >= endMillis)
+                    return -1;
+
+                waitMillis = (int)(endMillis - now);
+            }
+        }
+
+//#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+#if NET6_0_OR_GREATER
+        public virtual int Receive(Span<byte> buffer, int waitMillis)
+        {
+            long endMillis = DateTimeUtilities.CurrentUnixMs() + waitMillis;
+            for (;;)
+            {
+                int length = m_transport.Receive(buffer, waitMillis);
+                if (length < 0 || m_allowReceiving(buffer.ToArray(), 0, buffer.Length))
+                    return length;
+
+                Console.WriteLine("PACKET FILTERED ({0} byte packet not received)", length);
+
+                long now = DateTimeUtilities.CurrentUnixMs();
+                if (now >= endMillis)
+                    return -1;
+
+                waitMillis = (int)(endMillis - now);
+            }
+        }
+#endif
+
+        public virtual void Send(byte[] buf, int off, int len)
+        {
+            if (!m_allowSending(buf, off, len))
+            {
+                Console.WriteLine("PACKET FILTERED ({0} byte packet not sent)", len);
+            }
+            else
+            {
+                m_transport.Send(buf, off, len);
+            }
+        }
+
+        //#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+#if NET6_0_OR_GREATER
+        public virtual void Send(ReadOnlySpan<byte> buffer)
+        {
+            if (!m_allowSending(buffer.ToArray(), 0, buffer.Length))
+            {
+                Console.WriteLine("PACKET FILTERED ({0} byte packet not sent)", buffer.Length);
+            }
+            else
+            {
+                m_transport.Send(buffer);
+            }
+        }
+#endif
+
+        public virtual void Close() => m_transport.Close();
+
+        //static FilterPredicate ALWAYS_ALLOW = new FilterPredicate() {
+        //    @Override
+        //    public boolean allowPacket(byte[] buf, int off, int len)
+        //    {
+        //        return true;
+        //    }
+        //};
+
+        //interface FilterPredicate {
+        //    boolean allowPacket(byte[] buf, int off, int len);
+        //}
+    }
+}
diff --git a/crypto/test/src/tls/test/LoggingDatagramTransport.cs b/crypto/test/src/tls/test/LoggingDatagramTransport.cs
index 59113cf73..d03e99551 100644
--- a/crypto/test/src/tls/test/LoggingDatagramTransport.cs
+++ b/crypto/test/src/tls/test/LoggingDatagramTransport.cs
@@ -22,15 +22,9 @@ namespace Org.BouncyCastle.Tls.Tests
             this.m_launchTimestamp = DateTimeUtilities.CurrentUnixMs();
         }
 
-        public virtual int GetReceiveLimit()
-        {
-            return m_transport.GetReceiveLimit();
-        }
+        public virtual int GetReceiveLimit() => m_transport.GetReceiveLimit();
 
-        public virtual int GetSendLimit()
-        {
-            return m_transport.GetSendLimit();
-        }
+        public virtual int GetSendLimit() => m_transport.GetSendLimit();
 
         public virtual int Receive(byte[] buf, int off, int len, int waitMillis)
         {
@@ -80,10 +74,7 @@ namespace Org.BouncyCastle.Tls.Tests
         }
 #endif
 
-        public virtual void Close()
-        {
-            m_transport.Close();
-        }
+        public virtual void Close() => m_transport.Close();
 
 //#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
 #if NET6_0_OR_GREATER
diff --git a/crypto/test/src/tls/test/MinimalHandshakeAggregator.cs b/crypto/test/src/tls/test/MinimalHandshakeAggregator.cs
new file mode 100644
index 000000000..645e32851
--- /dev/null
+++ b/crypto/test/src/tls/test/MinimalHandshakeAggregator.cs
@@ -0,0 +1,254 @@
+using System;
+
+using Org.BouncyCastle.Utilities.Date;
+
+namespace Org.BouncyCastle.Tls.Tests
+{
+    /**
+     * A very minimal and stupid class to aggregate DTLS handshake messages.  Only sufficient for unit tests.
+     */
+    public class MinimalHandshakeAggregator
+        : DatagramTransport
+    {
+        private readonly DatagramTransport m_transport;
+
+        private readonly bool m_aggregateReceiving, m_aggregateSending;
+
+        private byte[] m_receiveBuf, m_sendBuf;
+
+        private int m_receiveRecordCount, m_sendRecordCount;
+
+        private byte[] AddToBuf(byte[] baseBuf, byte[] buf, int off, int len)
+        {
+            byte[] ret = new byte[baseBuf.Length + len];
+            Array.Copy(baseBuf, 0, ret, 0, baseBuf.Length);
+            Array.Copy(buf, off, ret, baseBuf.Length, len);
+            return ret;
+        }
+
+//#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+#if NET6_0_OR_GREATER
+        private byte[] AddToBuf(byte[] baseBuf, ReadOnlySpan<byte> buf)
+        {
+            byte[] ret = new byte[baseBuf.Length + buf.Length];
+            Array.Copy(baseBuf, 0, ret, 0, baseBuf.Length);
+            buf.CopyTo(ret[baseBuf.Length..]);
+            return ret;
+        }
+#endif
+
+        private void AddToReceiveBuf(byte[] buf, int off, int len)
+        {
+            m_receiveBuf = AddToBuf(m_receiveBuf, buf, off, len);
+            m_receiveRecordCount++;
+        }
+
+//#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+#if NET6_0_OR_GREATER
+        private void AddToReceiveBuf(ReadOnlySpan<byte> buf)
+        {
+            m_receiveBuf = AddToBuf(m_receiveBuf, buf);
+            m_receiveRecordCount++;
+        }
+#endif
+
+        private void ResetReceiveBuf()
+        {
+            m_receiveBuf = new byte[0];
+            m_receiveRecordCount = 0;
+        }
+
+        private void AddToSendBuf(byte[] buf, int off, int len)
+        {
+            m_sendBuf = AddToBuf(m_sendBuf, buf, off, len);
+            m_sendRecordCount++;
+        }
+
+//#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+#if NET6_0_OR_GREATER
+        private void AddToSendBuf(ReadOnlySpan<byte> buf)
+        {
+            m_sendBuf = AddToBuf(m_sendBuf, buf);
+            m_sendRecordCount++;
+        }
+#endif
+
+        private void ResetSendBuf()
+        {
+            m_sendBuf = new byte[0];
+            m_sendRecordCount = 0;
+        }
+
+        /** Whether the buffered aggregated data should be flushed after this packet.
+         * This is done on the end of the first flight - ClientHello and ServerHelloDone - and anything that is
+         * Epoch 1.
+         */
+        private bool FlushAfterThisPacket(byte[] buf, int off, int len)
+        {
+            int epoch = TlsUtilities.ReadUint16(buf, off + 3);
+            if (epoch > 0)
+                return true;
+
+            short contentType = TlsUtilities.ReadUint8(buf, off);
+            if (ContentType.handshake != contentType)
+                return false;
+
+            short msgType = TlsUtilities.ReadUint8(buf, off + 13);
+            switch (msgType)
+            {
+            case HandshakeType.client_hello:
+            case HandshakeType.server_hello_done:
+                return true;
+            default:
+                return false;
+            }
+        }
+
+//#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+#if NET6_0_OR_GREATER
+        private bool FlushAfterThisPacket(ReadOnlySpan<byte> buffer)
+        {
+            int epoch = TlsUtilities.ReadUint16(buffer[3..]);
+            if (epoch > 0)
+                return true;
+
+            short contentType = TlsUtilities.ReadUint8(buffer);
+            if (ContentType.handshake != contentType)
+                return false;
+
+            short msgType = TlsUtilities.ReadUint8(buffer[13..]);
+            switch (msgType)
+            {
+            case HandshakeType.client_hello:
+            case HandshakeType.server_hello_done:
+                return true;
+            default:
+                return false;
+            }
+        }
+#endif
+
+        public MinimalHandshakeAggregator(DatagramTransport transport, bool aggregateReceiving, bool aggregateSending)
+        {
+            m_transport = transport;
+            m_aggregateReceiving = aggregateReceiving;
+            m_aggregateSending = aggregateSending;
+
+            ResetReceiveBuf();
+            ResetSendBuf();
+        }
+
+        public virtual int GetReceiveLimit() => m_transport.GetReceiveLimit();
+
+        public virtual int GetSendLimit() => m_transport.GetSendLimit();
+
+        public virtual int Receive(byte[] buf, int off, int len, int waitMillis)
+        {
+            long endMillis = DateTimeUtilities.CurrentUnixMs() + waitMillis;
+            for (;;)
+            {
+                int length = m_transport.Receive(buf, off, len, waitMillis);
+                if (length < 0 || !m_aggregateReceiving)
+                    return length;
+
+                AddToReceiveBuf(buf, off, length);
+
+                if (FlushAfterThisPacket(buf, off, length))
+                {
+                    if (m_receiveRecordCount > 1)
+                    {
+                        Console.WriteLine("RECEIVING {0} RECORDS IN {1} BYTE PACKET", m_receiveRecordCount, length);
+                    }
+                    Array.Copy(m_receiveBuf, 0, buf, off, System.Math.Min(len, m_receiveBuf.Length));
+                    ResetReceiveBuf();
+                    return length;
+                }
+
+                long now = DateTimeUtilities.CurrentUnixMs();
+                if (now >= endMillis)
+                    return -1;
+
+                waitMillis = (int)(endMillis - now);
+            }
+        }
+
+//#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+#if NET6_0_OR_GREATER
+        public virtual int Receive(Span<byte> buffer, int waitMillis)
+        {
+            long endMillis = DateTimeUtilities.CurrentUnixMs() + waitMillis;
+            for (;;)
+            {
+                int length = m_transport.Receive(buffer, waitMillis);
+                if (length < 0 || !m_aggregateReceiving)
+                    return length;
+
+                AddToReceiveBuf(buffer);
+
+                if (FlushAfterThisPacket(buffer))
+                {
+                    if (m_receiveRecordCount > 1)
+                    {
+                        Console.WriteLine("RECEIVING {0} RECORDS IN {1} BYTE PACKET", m_receiveRecordCount, length);
+                    }
+                    int resultLength = System.Math.Min(buffer.Length, m_receiveBuf.Length);
+                    m_receiveBuf.AsSpan(0, resultLength).CopyTo(buffer);
+                    ResetReceiveBuf();
+                    return resultLength;
+                }
+
+                long now = DateTimeUtilities.CurrentUnixMs();
+                if (now >= endMillis)
+                    return -1;
+
+                waitMillis = (int)(endMillis - now);
+            }
+        }
+#endif
+
+        public virtual void Send(byte[] buf, int off, int len)
+        {
+            if (!m_aggregateSending)
+            {
+                m_transport.Send(buf, off, len);
+                return;
+            }
+            AddToSendBuf(buf, off, len);
+
+            if (FlushAfterThisPacket(buf, off, len))
+            {
+                if (m_sendRecordCount > 1)
+                {
+                    Console.WriteLine("SENDING {0} RECORDS IN {1} BYTE PACKET", m_sendRecordCount, m_sendBuf.Length);
+                }
+                m_transport.Send(m_sendBuf, 0, m_sendBuf.Length);
+                ResetSendBuf();
+            }
+        }
+
+        //#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+#if NET6_0_OR_GREATER
+        public virtual void Send(ReadOnlySpan<byte> buffer)
+        {
+            if (!m_aggregateSending)
+            {
+                m_transport.Send(buffer);
+                return;
+            }
+            AddToSendBuf(buffer);
+
+            if (FlushAfterThisPacket(buffer))
+            {
+                if (m_sendRecordCount > 1)
+                {
+                    Console.WriteLine("SENDING {0} RECORDS IN {1} BYTE PACKET", m_sendRecordCount, m_sendBuf.Length);
+                }
+                m_transport.Send(m_sendBuf, 0, m_sendBuf.Length);
+                ResetSendBuf();
+            }
+        }
+#endif
+
+        public virtual void Close() => m_transport.Close();
+    }
+}
diff --git a/crypto/test/src/tls/test/MockDtlsClient.cs b/crypto/test/src/tls/test/MockDtlsClient.cs
index 9898acd08..2e7ab9450 100644
--- a/crypto/test/src/tls/test/MockDtlsClient.cs
+++ b/crypto/test/src/tls/test/MockDtlsClient.cs
@@ -15,6 +15,8 @@ namespace Org.BouncyCastle.Tls.Tests
     {
         internal TlsSession m_session;
 
+        private int m_handshakeTimeoutMillis = 0;
+
         internal MockDtlsClient(TlsSession session)
             : base(new BcTlsCrypto())
         {
@@ -26,6 +28,10 @@ namespace Org.BouncyCastle.Tls.Tests
             return this.m_session;
         }
 
+        public override int GetHandshakeTimeoutMillis() => m_handshakeTimeoutMillis;
+
+        public void SetHandshakeTimeoutMillis(int millis) => m_handshakeTimeoutMillis = millis;
+
         public override void NotifyAlertRaised(short alertLevel, short alertDescription, string message,
             Exception cause)
         {
diff --git a/crypto/test/src/tls/test/ServerHandshakeDropper.cs b/crypto/test/src/tls/test/ServerHandshakeDropper.cs
new file mode 100644
index 000000000..89c752333
--- /dev/null
+++ b/crypto/test/src/tls/test/ServerHandshakeDropper.cs
@@ -0,0 +1,63 @@
+using System;
+
+namespace Org.BouncyCastle.Tls.Tests
+{
+    /** This is a [Transport] wrapper which causes the first retransmission of the second flight of a server
+     * handshake to be dropped. */
+    public class ServerHandshakeDropper
+        : FilteredDatagramTransport
+    {
+        private static FilterPredicate Choose(bool condition, FilterPredicate left, FilterPredicate right)
+        {
+            if (condition) { return left; } else { return right; }
+        }
+
+        public ServerHandshakeDropper(DatagramTransport transport, bool dropOnReceive)
+            : base(transport,
+                Choose(dropOnReceive, new DropFirstServerFinalFlight().AllowPacket, AlwaysAllow),
+                Choose(dropOnReceive, AlwaysAllow, new DropFirstServerFinalFlight().AllowPacket))
+        {
+        }
+
+        /** This drops the first instance of DTLS packets that either begin with a ChangeCipherSpec, or handshake in
+         * epoch 1.  This is the server's final flight of the handshake.  It will test whether the client properly
+         * retransmits its second flight, and the server properly retransmits the dropped flight.
+         */
+        private class DropFirstServerFinalFlight
+        {
+            private bool m_sawChangeCipherSpec = false;
+            private bool m_sawEpoch1Handshake = false;
+
+            private bool IsChangeCipherSpec(byte[] buf, int off, int len)
+            {
+                short contentType = TlsUtilities.ReadUint8(buf, off);
+                return ContentType.change_cipher_spec == contentType;
+            }
+
+            private bool IsEpoch1Handshake(byte[] buf, int off, int len)
+            {
+                short contentType = TlsUtilities.ReadUint8(buf, off);
+                if (ContentType.handshake != contentType)
+                    return false;
+
+                int epoch = TlsUtilities.ReadUint16(buf, off + 3);
+                return 1 == epoch;
+            }
+
+            public bool AllowPacket(byte[] buf, int off, int len)
+            {
+                if (!m_sawChangeCipherSpec && IsChangeCipherSpec(buf, off, len))
+                {
+                    m_sawChangeCipherSpec = true;
+                    return false;
+                }
+                if (!m_sawEpoch1Handshake && IsEpoch1Handshake(buf, off, len))
+                {
+                    m_sawEpoch1Handshake = true;
+                    return false;
+                }
+                return true;
+            }
+        }
+    }
+}
diff --git a/crypto/test/src/tls/test/UnreliableDatagramTransport.cs b/crypto/test/src/tls/test/UnreliableDatagramTransport.cs
index 7769db9d1..0aed2d7a1 100644
--- a/crypto/test/src/tls/test/UnreliableDatagramTransport.cs
+++ b/crypto/test/src/tls/test/UnreliableDatagramTransport.cs
@@ -25,15 +25,9 @@ namespace Org.BouncyCastle.Tls.Tests
             this.m_percentPacketLossSending = percentPacketLossSending;
         }
 
-        public virtual int GetReceiveLimit()
-        {
-            return m_transport.GetReceiveLimit();
-        }
+        public virtual int GetReceiveLimit() => m_transport.GetReceiveLimit();
 
-        public virtual int GetSendLimit()
-        {
-            return m_transport.GetSendLimit();
-        }
+        public virtual int GetSendLimit() => m_transport.GetSendLimit();
 
         public virtual int Receive(byte[] buf, int off, int len, int waitMillis)
         {
@@ -48,7 +42,7 @@ namespace Org.BouncyCastle.Tls.Tests
                 if (length < 0 || !LostPacket(m_percentPacketLossReceiving))
                     return length;
 
-                Console.WriteLine("PACKET LOSS (" + length + " byte packet not received)");
+                Console.WriteLine("PACKET LOSS ({0} byte packet not received)", length);
 
                 long now = DateTimeUtilities.CurrentUnixMs();
                 if (now >= endMillis)
@@ -70,7 +64,7 @@ namespace Org.BouncyCastle.Tls.Tests
                 if (length < 0 || !LostPacket(m_percentPacketLossReceiving))
                     return length;
 
-                Console.WriteLine("PACKET LOSS (" + length + " byte packet not received)");
+                Console.WriteLine("PACKET LOSS ({0} byte packet not received)", length);
 
                 long now = DateTimeUtilities.CurrentUnixMs();
                 if (now >= endMillis)
@@ -89,7 +83,7 @@ namespace Org.BouncyCastle.Tls.Tests
 #else
             if (LostPacket(m_percentPacketLossSending))
             {
-                Console.WriteLine("PACKET LOSS (" + len + " byte packet not sent)");
+                Console.WriteLine("PACKET LOSS ({0} byte packet not sent)", len);
             }
             else
             {
@@ -104,7 +98,7 @@ namespace Org.BouncyCastle.Tls.Tests
         {
             if (LostPacket(m_percentPacketLossSending))
             {
-                Console.WriteLine("PACKET LOSS (" + buffer.Length + " byte packet not sent)");
+                Console.WriteLine("PACKET LOSS ({0} byte packet not sent)", buffer.Length);
             }
             else
             {
@@ -113,10 +107,7 @@ namespace Org.BouncyCastle.Tls.Tests
         }
 #endif
 
-        public virtual void Close()
-        {
-            m_transport.Close();
-        }
+        public virtual void Close() => m_transport.Close();
 
         private bool LostPacket(int percentPacketLoss)
         {