summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crypto/src/tls/DtlsClientProtocol.cs2
-rw-r--r--crypto/src/tls/DtlsRecordLayer.cs55
-rw-r--r--crypto/src/tls/DtlsReliableHandshake.cs50
-rw-r--r--crypto/src/tls/DtlsVerifier.cs108
-rw-r--r--crypto/src/tls/TlsClientProtocol.cs2
-rw-r--r--crypto/test/src/tls/test/DtlsProtocolTest.cs32
6 files changed, 132 insertions, 117 deletions
diff --git a/crypto/src/tls/DtlsClientProtocol.cs b/crypto/src/tls/DtlsClientProtocol.cs
index 72484e178..c1bad2e6f 100644
--- a/crypto/src/tls/DtlsClientProtocol.cs
+++ b/crypto/src/tls/DtlsClientProtocol.cs
@@ -525,7 +525,7 @@ namespace Org.BouncyCastle.Tls
 
 
             ClientHello clientHello = new ClientHello(legacy_version, securityParameters.ClientRandom, session_id,
-                TlsUtilities.EmptyBytes, state.offeredCipherSuites, state.clientExtensions, 0);
+                cookie: TlsUtilities.EmptyBytes, state.offeredCipherSuites, state.clientExtensions, 0);
 
             MemoryStream buf = new MemoryStream();
             clientHello.Encode(state.clientContext, buf);
diff --git a/crypto/src/tls/DtlsRecordLayer.cs b/crypto/src/tls/DtlsRecordLayer.cs
index efe9e7312..e3567aa46 100644
--- a/crypto/src/tls/DtlsRecordLayer.cs
+++ b/crypto/src/tls/DtlsRecordLayer.cs
@@ -4,7 +4,6 @@ using System.IO;
 using System.Net.Sockets;
 
 using Org.BouncyCastle.Tls.Crypto;
-using Org.BouncyCastle.Tls.Crypto.Impl;
 using Org.BouncyCastle.Utilities;
 using Org.BouncyCastle.Utilities.Date;
 
@@ -13,43 +12,45 @@ namespace Org.BouncyCastle.Tls
     internal class DtlsRecordLayer
         : DatagramTransport
     {
-        private const int RECORD_HEADER_LENGTH = 13;
+        internal const int RecordHeaderLength = 13;
+
         private const int MAX_FRAGMENT_LENGTH = 1 << 14;
         private const long TCP_MSL = 1000L * 60 * 2;
         private const long RETRANSMIT_TIMEOUT = TCP_MSL * 2;
 
         /// <exception cref="IOException"/>
-        internal static byte[] ReceiveClientHelloRecord(byte[] data, int dataOff, int dataLen)
+        internal static int ReceiveClientHelloRecord(byte[] data, int dataOff, int dataLen)
         {
-            if (dataLen < RECORD_HEADER_LENGTH)
-            {
-                return null;
-            }
+            if (dataLen < RecordHeaderLength)
+                return -1;
 
             short contentType = TlsUtilities.ReadUint8(data, dataOff + 0);
             if (ContentType.handshake != contentType)
-                return null;
+                return -1;
 
             ProtocolVersion version = TlsUtilities.ReadVersion(data, dataOff + 1);
             if (!ProtocolVersion.DTLSv10.IsEqualOrEarlierVersionOf(version))
-                return null;
+                return -1;
 
             int epoch = TlsUtilities.ReadUint16(data, dataOff + 3);
             if (0 != epoch)
-                return null;
+                return -1;
 
             //long sequenceNumber = TlsUtilities.ReadUint48(data, dataOff + 5);
 
             int length = TlsUtilities.ReadUint16(data, dataOff + 11);
-            if (dataLen < RECORD_HEADER_LENGTH + length)
-                return null;
+            if (length < 1 || length > MAX_FRAGMENT_LENGTH)
+                return -1;
 
-            if (length > MAX_FRAGMENT_LENGTH)
-                return null;
+            if (dataLen < RecordHeaderLength + length)
+                return -1;
+
+            short msgType = TlsUtilities.ReadUint8(data, dataOff + RecordHeaderLength);
+            if (HandshakeType.client_hello != msgType)
+                return -1;
 
             // NOTE: We ignore/drop any data after the first record 
-            return TlsUtilities.CopyOfRangeExact(data, dataOff + RECORD_HEADER_LENGTH,
-                dataOff + RECORD_HEADER_LENGTH + length);
+            return length;
         }
 
         /// <exception cref="IOException"/>
@@ -57,14 +58,14 @@ namespace Org.BouncyCastle.Tls
         {
             TlsUtilities.CheckUint16(message.Length);
 
-            byte[] record = new byte[RECORD_HEADER_LENGTH + message.Length];
+            byte[] record = new byte[RecordHeaderLength + message.Length];
             TlsUtilities.WriteUint8(ContentType.handshake, record, 0);
             TlsUtilities.WriteVersion(ProtocolVersion.DTLSv10, record, 1);
             TlsUtilities.WriteUint16(0, record, 3);
             TlsUtilities.WriteUint48(recordSeq, record, 5);
             TlsUtilities.WriteUint16(message.Length, record, 11);
 
-            Array.Copy(message, 0, record, RECORD_HEADER_LENGTH, message.Length);
+            Array.Copy(message, 0, record, RecordHeaderLength, message.Length);
 
             SendDatagram(sender, record, 0, record.Length);
         }
@@ -124,8 +125,8 @@ namespace Org.BouncyCastle.Tls
 
             this.m_inHandshake = true;
 
-            this.m_currentEpoch = new DtlsEpoch(0, TlsNullNullCipher.Instance, RECORD_HEADER_LENGTH,
-                RECORD_HEADER_LENGTH);
+            this.m_currentEpoch = new DtlsEpoch(0, TlsNullNullCipher.Instance, RecordHeaderLength,
+                RecordHeaderLength);
             this.m_pendingEpoch = null;
             this.m_readEpoch = m_currentEpoch;
             this.m_writeEpoch = m_currentEpoch;
@@ -179,8 +180,8 @@ namespace Org.BouncyCastle.Tls
              */
 
             var securityParameters = m_context.SecurityParameters;
-            int recordHeaderLengthRead = RECORD_HEADER_LENGTH + (securityParameters.ConnectionIDPeer?.Length ?? 0);
-            int recordHeaderLengthWrite = RECORD_HEADER_LENGTH + (securityParameters.ConnectionIDLocal?.Length ?? 0);
+            int recordHeaderLengthRead = RecordHeaderLength + (securityParameters.ConnectionIDPeer?.Length ?? 0);
+            int recordHeaderLengthWrite = RecordHeaderLength + (securityParameters.ConnectionIDLocal?.Length ?? 0);
 
             // TODO Check for overflow
             this.m_pendingEpoch = new DtlsEpoch(m_writeEpoch.Epoch + 1, pendingCipher, recordHeaderLengthRead,
@@ -684,7 +685,7 @@ namespace Org.BouncyCastle.Tls
 #endif
         {
             // NOTE: received < 0 (timeout) is covered by this first case
-            if (received < RECORD_HEADER_LENGTH)
+            if (received < RecordHeaderLength)
                 return -1;
 
             // TODO[dtls13] Deal with opaque record type for 1.3 AEAD ciphers
@@ -729,7 +730,7 @@ namespace Org.BouncyCastle.Tls
 
 
             int recordHeaderLength = recordEpoch.RecordHeaderLengthRead;
-            if (recordHeaderLength > RECORD_HEADER_LENGTH)
+            if (recordHeaderLength > RecordHeaderLength)
             {
                 if (ContentType.tls12_cid != recordType)
                     return -1;
@@ -990,7 +991,7 @@ namespace Org.BouncyCastle.Tls
         {
             Debug.Assert(m_recordQueue.Available > 0);
 
-            int recordLength = RECORD_HEADER_LENGTH;
+            int recordLength = RecordHeaderLength;
             if (m_recordQueue.Available >= recordLength)
             {
                 short recordType = m_recordQueue.ReadUint8(0);
@@ -1033,7 +1034,7 @@ namespace Org.BouncyCastle.Tls
                 return ReceivePendingRecord(buf, off, len);
 
             int received = ReceiveDatagram(buf, off, len, waitMillis);
-            if (received >= RECORD_HEADER_LENGTH)
+            if (received >= RecordHeaderLength)
             {
                 this.m_inConnection = true;
 
@@ -1151,7 +1152,7 @@ namespace Org.BouncyCastle.Tls
                 TlsUtilities.WriteUint16(recordEpoch, encoded.buf, encoded.off + 3);
                 TlsUtilities.WriteUint48(recordSequenceNumber, encoded.buf, encoded.off + 5);
 
-                if (recordHeaderLength > RECORD_HEADER_LENGTH)
+                if (recordHeaderLength > RecordHeaderLength)
                 {
                     byte[] connectionID = m_context.SecurityParameters.ConnectionIDLocal;
                     Array.Copy(connectionID, 0, encoded.buf, encoded.off + 11, connectionID.Length);
diff --git a/crypto/src/tls/DtlsReliableHandshake.cs b/crypto/src/tls/DtlsReliableHandshake.cs
index 42a98a991..b1107f7a1 100644
--- a/crypto/src/tls/DtlsReliableHandshake.cs
+++ b/crypto/src/tls/DtlsReliableHandshake.cs
@@ -8,47 +8,41 @@ namespace Org.BouncyCastle.Tls
 {
     internal class DtlsReliableHandshake
     {
-        private const int MAX_RECEIVE_AHEAD = 16;
-        private const int MESSAGE_HEADER_LENGTH = 12;
+        internal const int MessageHeaderLength = 12;
 
+        private const int MAX_RECEIVE_AHEAD = 16;
         private const int MAX_RESEND_MILLIS = 60000;
 
         /// <exception cref="IOException"/>
-        internal static DtlsRequest ReadClientRequest(byte[] data, int dataOff, int dataLen, Stream dtlsOutput)
+        internal static MemoryStream ReceiveClientHelloMessage(byte[] msg, int msgOff, int msgLen)
         {
             // TODO Support the possibility of a fragmented ClientHello datagram
 
-            byte[] message = DtlsRecordLayer.ReceiveClientHelloRecord(data, dataOff, dataLen);
-            if (null == message || message.Length < MESSAGE_HEADER_LENGTH)
+            if (msgLen < MessageHeaderLength)
                 return null;
 
-            long recordSeq = TlsUtilities.ReadUint48(data, dataOff + 5);
-
-            short msgType = TlsUtilities.ReadUint8(message, 0);
+            short msgType = TlsUtilities.ReadUint8(msg, msgOff);
             if (HandshakeType.client_hello != msgType)
                 return null;
 
-            int length = TlsUtilities.ReadUint24(message, 1);
-            if (message.Length != MESSAGE_HEADER_LENGTH + length)
+            int length = TlsUtilities.ReadUint24(msg, msgOff + 1);
+            if (msgLen != MessageHeaderLength + length)
                 return null;
 
             // TODO Consider stricter HelloVerifyRequest-related checks
-            //int messageSeq = TlsUtilities.ReadUint16(message, 4);
+            //int messageSeq = TlsUtilities.ReadUint16(msg, msgOff + 4);
             //if (messageSeq > 1)
             //    return null;
 
-            int fragmentOffset = TlsUtilities.ReadUint24(message, 6);
+            int fragmentOffset = TlsUtilities.ReadUint24(msg, msgOff + 6);
             if (0 != fragmentOffset)
                 return null;
 
-            int fragmentLength = TlsUtilities.ReadUint24(message, 9);
+            int fragmentLength = TlsUtilities.ReadUint24(msg, msgOff + 9);
             if (length != fragmentLength)
                 return null;
 
-            ClientHello clientHello = ClientHello.Parse(
-                new MemoryStream(message, MESSAGE_HEADER_LENGTH, length, false), dtlsOutput);
-
-            return new DtlsRequest(recordSeq, message, clientHello);
+            return new MemoryStream(msg, msgOff + MessageHeaderLength, length, false);
         }
 
         /// <exception cref="IOException"/>
@@ -58,7 +52,7 @@ namespace Org.BouncyCastle.Tls
 
             int length = 3 + cookie.Length;
 
-            byte[] message = new byte[MESSAGE_HEADER_LENGTH + length];
+            byte[] message = new byte[MessageHeaderLength + length];
             TlsUtilities.WriteUint8(HandshakeType.hello_verify_request, message, 0);
             TlsUtilities.WriteUint24(length, message, 1);
             //TlsUtilities.WriteUint16(0, message, 4);
@@ -66,8 +60,8 @@ namespace Org.BouncyCastle.Tls
             TlsUtilities.WriteUint24(length, message, 9);
 
             // HelloVerifyRequest fields
-            TlsUtilities.WriteVersion(ProtocolVersion.DTLSv10, message, MESSAGE_HEADER_LENGTH + 0);
-            TlsUtilities.WriteOpaque8(cookie, message, MESSAGE_HEADER_LENGTH + 2);
+            TlsUtilities.WriteVersion(ProtocolVersion.DTLSv10, message, MessageHeaderLength + 0);
+            TlsUtilities.WriteOpaque8(cookie, message, MessageHeaderLength + 2);
 
             DtlsRecordLayer.SendHelloVerifyRequestRecord(sender, recordSeq, message);
         }
@@ -111,7 +105,7 @@ namespace Org.BouncyCastle.Tls
 
                 // Simulate a previous flight consisting of the request ClientHello
                 DtlsReassembler reassembler = new DtlsReassembler(HandshakeType.client_hello,
-                    message.Length - MESSAGE_HEADER_LENGTH);
+                    message.Length - MessageHeaderLength);
                 m_currentInboundFlight[messageSeq] = reassembler;
 
                 // We sent HelloVerifyRequest with (message) sequence number 0
@@ -215,7 +209,7 @@ namespace Org.BouncyCastle.Tls
             default:
             {
                 byte[] body = message.Body;
-                byte[] buf = new byte[MESSAGE_HEADER_LENGTH];
+                byte[] buf = new byte[MessageHeaderLength];
                 TlsUtilities.WriteUint8(msg_type, buf, 0);
                 TlsUtilities.WriteUint24(body.Length, buf, 1);
                 TlsUtilities.WriteUint16(message.Seq, buf, 4);
@@ -360,10 +354,10 @@ namespace Org.BouncyCastle.Tls
         {
             bool checkPreviousFlight = false;
 
-            while (len >= MESSAGE_HEADER_LENGTH)
+            while (len >= MessageHeaderLength)
             {
                 int fragment_length = TlsUtilities.ReadUint24(buf, off + 9);
-                int message_length = fragment_length + MESSAGE_HEADER_LENGTH;
+                int message_length = fragment_length + MessageHeaderLength;
                 if (len < message_length)
                 {
                     // NOTE: Truncated message - ignore it
@@ -400,7 +394,7 @@ namespace Org.BouncyCastle.Tls
                         m_currentInboundFlight[message_seq] = reassembler;
                     }
 
-                    reassembler.ContributeFragment(msg_type, length, buf, off + MESSAGE_HEADER_LENGTH, fragment_offset,
+                    reassembler.ContributeFragment(msg_type, length, buf, off + MessageHeaderLength, fragment_offset,
                         fragment_length);
                 }
                 else if (m_previousInboundFlight != null)
@@ -412,7 +406,7 @@ namespace Org.BouncyCastle.Tls
 
                     if (m_previousInboundFlight.TryGetValue(message_seq, out var reassembler))
                     {
-                        reassembler.ContributeFragment(msg_type, length, buf, off + MESSAGE_HEADER_LENGTH,
+                        reassembler.ContributeFragment(msg_type, length, buf, off + MessageHeaderLength,
                             fragment_offset, fragment_length);
                         checkPreviousFlight = true;
                     }
@@ -446,7 +440,7 @@ namespace Org.BouncyCastle.Tls
         private void WriteMessage(Message message)
         {
             int sendLimit = m_recordLayer.GetSendLimit();
-            int fragmentLimit = sendLimit - MESSAGE_HEADER_LENGTH;
+            int fragmentLimit = sendLimit - MessageHeaderLength;
 
             // TODO Support a higher minimum fragment size?
             if (fragmentLimit < 1)
@@ -471,7 +465,7 @@ namespace Org.BouncyCastle.Tls
         /// <exception cref="IOException"/>
         private void WriteHandshakeFragment(Message message, int fragment_offset, int fragment_length)
         {
-            RecordLayerBuffer fragment = new RecordLayerBuffer(MESSAGE_HEADER_LENGTH + fragment_length);
+            RecordLayerBuffer fragment = new RecordLayerBuffer(MessageHeaderLength + fragment_length);
             TlsUtilities.WriteUint8(message.Type, fragment);
             TlsUtilities.WriteUint24(message.Body.Length, fragment);
             TlsUtilities.WriteUint16(message.Seq, fragment);
diff --git a/crypto/src/tls/DtlsVerifier.cs b/crypto/src/tls/DtlsVerifier.cs
index e691685e6..01437d648 100644
--- a/crypto/src/tls/DtlsVerifier.cs
+++ b/crypto/src/tls/DtlsVerifier.cs
@@ -1,89 +1,79 @@
-using System;
-using System.IO;
+using System.IO;
 
+using Org.BouncyCastle.Security;
 using Org.BouncyCastle.Tls.Crypto;
 using Org.BouncyCastle.Utilities;
 
 namespace Org.BouncyCastle.Tls
 {
+    /// <summary>
+    /// Implements cookie generation/verification for a DTLS server as described in RFC 4347,
+    /// 4.2.1. Denial of Service Countermeasures.
+    /// </summary>
+    /// <remarks>
+    /// RFC 4347 4.2.1 additionally recommends changing the secret frequently. This class does not handle that
+    /// internally, so the instance should be replaced instead.
+    /// </remarks>
     public class DtlsVerifier
     {
-        private static TlsMac CreateCookieMac(TlsCrypto crypto)
-        {
-            TlsMac mac = crypto.CreateHmac(MacAlgorithm.hmac_sha256);
-
-            byte[] secret = new byte[mac.MacLength];
-            crypto.SecureRandom.NextBytes(secret);
-
-            mac.SetKey(secret, 0, secret.Length);
-
-            return mac;
-        }
-
-        private readonly TlsMac m_cookieMac;
-        private readonly TlsMacSink m_cookieMacSink;
+        private readonly TlsCrypto m_crypto;
+        private readonly byte[] m_macKey;
 
         public DtlsVerifier(TlsCrypto crypto)
         {
-            this.m_cookieMac = CreateCookieMac(crypto);
-            this.m_cookieMacSink = new TlsMacSink(m_cookieMac);
+            m_crypto = crypto;
+            m_macKey = SecureRandom.GetNextBytes(crypto.SecureRandom, 32);
         }
 
         public virtual DtlsRequest VerifyRequest(byte[] clientID, byte[] data, int dataOff, int dataLen,
             DatagramSender sender)
         {
-            lock (this)
+            try
             {
-                bool resetCookieMac = true;
+                int msgLen = DtlsRecordLayer.ReceiveClientHelloRecord(data, dataOff, dataLen);
+                if (msgLen < 0)
+                    return null;
 
-                try
-                {
-                    m_cookieMac.Update(clientID, 0, clientID.Length);
+                int bodyLength = msgLen - DtlsReliableHandshake.MessageHeaderLength;
+                if (bodyLength < 39) // Minimum (syntactically) valid DTLS ClientHello length
+                    return null;
 
-                    DtlsRequest request = DtlsReliableHandshake.ReadClientRequest(data, dataOff, dataLen,
-                        m_cookieMacSink);
-                    if (null != request)
-                    {
-                        byte[] expectedCookie = m_cookieMac.CalculateMac();
-                        resetCookieMac = false;
+                int msgOff = dataOff + DtlsRecordLayer.RecordHeaderLength;
 
-                        // TODO Consider stricter HelloVerifyRequest protocol
-                        //switch (request.MessageSeq)
-                        //{
-                        //case 0:
-                        //{
-                        //    DtlsReliableHandshake.SendHelloVerifyRequest(sender, request.RecordSeq, expectedCookie);
-                        //    break;
-                        //}
-                        //case 1:
-                        //{
-                        //    if (Arrays.FixedTimeEquals(expectedCookie, request.ClientHello.Cookie))
-                        //        return request;
+                var buf = DtlsReliableHandshake.ReceiveClientHelloMessage(msg: data, msgOff, msgLen);
+                if (buf == null)
+                    return null;
 
-                        //    break;
-                        //}
-                        //}
+                var macInput = new MemoryStream(bodyLength);
+                ClientHello clientHello = ClientHello.Parse(buf, dtlsOutput: macInput);
+                if (clientHello == null)
+                    return null;
 
-                        if (Arrays.FixedTimeEquals(expectedCookie, request.ClientHello.Cookie))
-                            return request;
+                long recordSeq = TlsUtilities.ReadUint48(data, dataOff + 5);
 
-                        DtlsReliableHandshake.SendHelloVerifyRequest(sender, request.RecordSeq, expectedCookie);
-                    }
-                }
-                catch (IOException)
-                {
-                    // Ignore
-                }
-                finally
+                byte[] cookie = clientHello.Cookie;
+
+                TlsMac mac = m_crypto.CreateHmac(MacAlgorithm.hmac_sha256);
+                mac.SetKey(m_macKey, 0, m_macKey.Length);
+                mac.Update(clientID, 0, clientID.Length);
+                macInput.WriteTo(new TlsMacSink(mac));
+                byte[] expectedCookie = mac.CalculateMac();
+
+                if (Arrays.FixedTimeEquals(expectedCookie, cookie))
                 {
-                    if (resetCookieMac)
-                    {
-                        m_cookieMac.Reset();
-                    }
+                    byte[] message = TlsUtilities.CopyOfRangeExact(data, msgOff, msgOff + msgLen);
+
+                    return new DtlsRequest(recordSeq, message, clientHello);
                 }
 
-                return null;
+                DtlsReliableHandshake.SendHelloVerifyRequest(sender, recordSeq, expectedCookie);
+            }
+            catch (IOException)
+            {
+                // Ignore
             }
+
+            return null;
         }
     }
 }
diff --git a/crypto/src/tls/TlsClientProtocol.cs b/crypto/src/tls/TlsClientProtocol.cs
index 6aa1acf2f..d26f60ef1 100644
--- a/crypto/src/tls/TlsClientProtocol.cs
+++ b/crypto/src/tls/TlsClientProtocol.cs
@@ -1771,7 +1771,7 @@ namespace Org.BouncyCastle.Tls
             int bindersSize = null == m_clientBinders ? 0 : m_clientBinders.m_bindersSize;
 
             this.m_clientHello = new ClientHello(legacy_version, securityParameters.ClientRandom, legacy_session_id,
-                null, offeredCipherSuites, m_clientExtensions, bindersSize);
+                cookie: null, offeredCipherSuites, m_clientExtensions, bindersSize);
 
             SendClientHelloMessage();
         }
diff --git a/crypto/test/src/tls/test/DtlsProtocolTest.cs b/crypto/test/src/tls/test/DtlsProtocolTest.cs
index 388003666..7fc49fb51 100644
--- a/crypto/test/src/tls/test/DtlsProtocolTest.cs
+++ b/crypto/test/src/tls/test/DtlsProtocolTest.cs
@@ -1,4 +1,5 @@
 using System;
+using System.Text;
 using System.Threading;
 
 using NUnit.Framework;
@@ -70,7 +71,36 @@ namespace Org.BouncyCastle.Tls.Tests
                 try
                 {
                     MockDtlsServer server = new MockDtlsServer();
-                    DtlsTransport dtlsServer = m_serverProtocol.Accept(server, m_serverTransport);
+
+                    DtlsRequest request = null;
+
+                    // Use DtlsVerifier to require a HelloVerifyRequest cookie exchange before accepting
+                    {
+                        DtlsVerifier verifier = new DtlsVerifier(server.Crypto);
+
+                        // NOTE: Test value only - would typically be the client IP address
+                        byte[] clientID = Encoding.UTF8.GetBytes("MockDtlsClient");
+
+                        int receiveLimit = m_serverTransport.GetReceiveLimit();
+                        int dummyOffset = server.Crypto.SecureRandom.Next(16) + 1;
+                        byte[] transportBuf = new byte[dummyOffset + m_serverTransport.GetReceiveLimit()];
+
+                        do
+                        {
+                            if (m_isShutdown)
+                                return;
+
+                            int length = m_serverTransport.Receive(transportBuf, dummyOffset, receiveLimit, 1000);
+                            if (length > 0)
+                            {
+                                request = verifier.VerifyRequest(clientID, transportBuf, dummyOffset, length,
+                                    m_serverTransport);
+                            }
+                        }
+                        while (request == null);
+                    }
+
+                    DtlsTransport dtlsServer = m_serverProtocol.Accept(server, m_serverTransport, request);
                     byte[] buf = new byte[dtlsServer.GetReceiveLimit()];
                     while (!m_isShutdown)
                     {