summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crypto/src/tls/ByteQueue.cs8
-rw-r--r--crypto/src/tls/DtlsRecordLayer.cs183
2 files changed, 156 insertions, 35 deletions
diff --git a/crypto/src/tls/ByteQueue.cs b/crypto/src/tls/ByteQueue.cs
index e06ad6346..a92f79baf 100644
--- a/crypto/src/tls/ByteQueue.cs
+++ b/crypto/src/tls/ByteQueue.cs
@@ -193,6 +193,14 @@ namespace Org.BouncyCastle.Tls
             return TlsUtilities.ReadInt32(m_databuf, m_skipped);
         }
 
+        public short ReadUint8(int skip)
+        {
+            if (m_available < skip + 1)
+                throw new InvalidOperationException("Not enough data to read");
+
+            return TlsUtilities.ReadUint8(m_databuf, m_skipped + skip);
+        }
+
         public int ReadUint16(int skip)
         {
             if (m_available < skip + 2)
diff --git a/crypto/src/tls/DtlsRecordLayer.cs b/crypto/src/tls/DtlsRecordLayer.cs
index 5d8c217b0..860c2dc31 100644
--- a/crypto/src/tls/DtlsRecordLayer.cs
+++ b/crypto/src/tls/DtlsRecordLayer.cs
@@ -3,6 +3,7 @@ using System.IO;
 using System.Net.Sockets;
 
 using Org.BouncyCastle.Tls.Crypto;
+using Org.BouncyCastle.Tls.Crypto.Impl;
 using Org.BouncyCastle.Utilities;
 using Org.BouncyCastle.Utilities.Date;
 
@@ -234,15 +235,39 @@ namespace Org.BouncyCastle.Tls
         /// <exception cref="IOException"/>
         public virtual int GetReceiveLimit()
         {
-            return System.Math.Min(m_plaintextLimit,
-                m_readEpoch.Cipher.GetPlaintextLimit(m_transport.GetReceiveLimit() - RECORD_HEADER_LENGTH));
+            int ciphertextLimit = m_transport.GetReceiveLimit() - m_readEpoch.RecordHeaderLengthRead;
+            var cipher = m_readEpoch.Cipher;
+
+            int plaintextDecodeLimit;
+            if (cipher is AbstractTlsCipher abstractTlsCipher)
+            {
+                plaintextDecodeLimit = abstractTlsCipher.GetPlaintextDecodeLimit(ciphertextLimit);
+            }
+            else
+            {
+                plaintextDecodeLimit = cipher.GetPlaintextLimit(ciphertextLimit);
+            }
+
+            return System.Math.Min(m_plaintextLimit, plaintextDecodeLimit);
         }
 
         /// <exception cref="IOException"/>
         public virtual int GetSendLimit()
         {
-            return System.Math.Min(m_plaintextLimit,
-                m_writeEpoch.Cipher.GetPlaintextLimit(m_transport.GetSendLimit() - RECORD_HEADER_LENGTH));
+            var cipher = m_writeEpoch.Cipher;
+            int ciphertextLimit = m_transport.GetSendLimit() - m_writeEpoch.RecordHeaderLengthWrite;
+
+            int plaintextEncodeLimit;
+            if (cipher is AbstractTlsCipher abstractTlsCipher)
+            {
+                plaintextEncodeLimit = abstractTlsCipher.GetPlaintextEncodeLimit(ciphertextLimit);
+            }
+            else
+            {
+                plaintextEncodeLimit = cipher.GetPlaintextLimit(ciphertextLimit);
+            }
+
+            return System.Math.Min(m_plaintextLimit, plaintextEncodeLimit);
         }
 
         /// <exception cref="IOException"/>
@@ -296,18 +321,16 @@ namespace Org.BouncyCastle.Tls
                     waitMillis = 1;
                 }
 
-                int receiveLimit = System.Math.Min(len, GetReceiveLimit()) + RECORD_HEADER_LENGTH;
+                int receiveLimit = m_transport.GetReceiveLimit();
                 if (null == record || record.Length < receiveLimit)
                 {
                     record = new byte[receiveLimit];
                 }
 
                 int received = ReceiveRecord(record, 0, receiveLimit, waitMillis);
-                int processed = ProcessRecord(received, record, buf, off);
+                int processed = ProcessRecord(received, record, buf, off, len);
                 if (processed >= 0)
-                {
                     return processed;
-                }
 
                 currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
                 waitMillis = Timeout.GetWaitMillis(timeout, currentTimeMillis);
@@ -366,7 +389,7 @@ namespace Org.BouncyCastle.Tls
                     waitMillis = 1;
                 }
 
-                int receiveLimit = System.Math.Min(buffer.Length, GetReceiveLimit()) + RECORD_HEADER_LENGTH;
+                int receiveLimit = m_transport.GetReceiveLimit();
                 if (null == record || record.Length < receiveLimit)
                 {
                     record = new byte[receiveLimit];
@@ -375,9 +398,7 @@ namespace Org.BouncyCastle.Tls
                 int received = ReceiveRecord(record, 0, receiveLimit, waitMillis);
                 int processed = ProcessRecord(received, record, buffer);
                 if (processed >= 0)
-                {
                     return processed;
-                }
 
                 currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
                 waitMillis = Timeout.GetWaitMillis(timeout, currentTimeMillis);
@@ -599,17 +620,13 @@ namespace Org.BouncyCastle.Tls
 #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
         private int ProcessRecord(int received, byte[] record, Span<byte> buffer)
 #else
-        private int ProcessRecord(int received, byte[] record, byte[] buf, int off)
+        private int ProcessRecord(int received, byte[] record, byte[] buf, int off, int len)
 #endif
         {
             // NOTE: received < 0 (timeout) is covered by this first case
             if (received < RECORD_HEADER_LENGTH)
                 return -1;
 
-            int length = TlsUtilities.ReadUint16(record, 11);
-            if (received != (length + RECORD_HEADER_LENGTH))
-                return -1;
-
             // TODO[dtls13] Deal with opaque record type for 1.3 AEAD ciphers
             short recordType = TlsUtilities.ReadUint8(record, 0);
 
@@ -620,11 +637,16 @@ namespace Org.BouncyCastle.Tls
             case ContentType.change_cipher_spec:
             case ContentType.handshake:
             case ContentType.heartbeat:
+            case ContentType.tls12_cid:
                 break;
             default:
                 return -1;
             }
 
+            ProtocolVersion recordVersion = TlsUtilities.ReadVersion(record, 1);
+            if (!recordVersion.IsDtls)
+                return -1;
+
             int epoch = TlsUtilities.ReadUint16(record, 3);
 
             DtlsEpoch recordEpoch = null;
@@ -645,8 +667,23 @@ namespace Org.BouncyCastle.Tls
             if (recordEpoch.ReplayWindow.ShouldDiscard(seq))
                 return -1;
 
-            ProtocolVersion recordVersion = TlsUtilities.ReadVersion(record, 1);
-            if (!recordVersion.IsDtls)
+
+            int recordHeaderLength = recordEpoch.RecordHeaderLengthRead;
+            if (recordHeaderLength > RECORD_HEADER_LENGTH)
+            {
+                if (ContentType.tls12_cid != recordType)
+                    return -1;
+
+                if (received < recordHeaderLength)
+                    return -1;
+
+                byte[] connectionID = m_context.SecurityParameters.ConnectionIDPeer;
+                if (!Arrays.FixedTimeEquals(connectionID.Length, connectionID, 0, record, 11))
+                    return -1;
+            }
+
+            int length = TlsUtilities.ReadUint16(record, recordHeaderLength - 2);
+            if (received != (length + recordHeaderLength))
                 return -1;
 
             if (null != m_readVersion && !m_readVersion.Equals(recordVersion))
@@ -660,7 +697,7 @@ namespace Org.BouncyCastle.Tls
                         ReadEpoch == 0
                     &&  length > 0
                     &&  ContentType.handshake == recordType
-                    &&  HandshakeType.client_hello == TlsUtilities.ReadUint8(record, RECORD_HEADER_LENGTH);
+                    &&  HandshakeType.client_hello == TlsUtilities.ReadUint8(record, recordHeaderLength);
 
                 if (!isClientHelloFragment)
                     return -1;
@@ -668,8 +705,20 @@ namespace Org.BouncyCastle.Tls
 
             long macSeqNo = GetMacSequenceNumber(recordEpoch.Epoch, seq);
 
-            TlsDecodeResult decoded = recordEpoch.Cipher.DecodeCiphertext(macSeqNo, recordType, recordVersion, record,
-                RECORD_HEADER_LENGTH, length);
+            TlsDecodeResult decoded;
+            try
+            {
+                decoded = recordEpoch.Cipher.DecodeCiphertext(macSeqNo, recordType, recordVersion, record,
+                    recordHeaderLength, length);
+            }
+            catch (TlsFatalAlert fatalAlert) when (AlertDescription.bad_record_mac == fatalAlert.AlertDescription)
+            {
+                /*
+                 * RFC 9146 6. DTLS implementations MUST silently discard records with bad MACs or that are otherwise
+                 * invalid.
+                 */
+                return -1;
+            }
 
             recordEpoch.ReplayWindow.ReportAuthenticated(seq);
 
@@ -685,7 +734,7 @@ namespace Org.BouncyCastle.Tls
                         ReadEpoch == 0
                     &&  length > 0
                     &&  ContentType.handshake == recordType
-                    &&  HandshakeType.hello_verify_request == TlsUtilities.ReadUint8(record, RECORD_HEADER_LENGTH);
+                    &&  HandshakeType.hello_verify_request == TlsUtilities.ReadUint8(record, recordHeaderLength);
 
                 if (isHelloVerifyRequest)
                 {
@@ -818,6 +867,7 @@ namespace Org.BouncyCastle.Tls
 
                 return -1;
             }
+            case ContentType.tls12_cid:
             default:
                 return -1;
             }
@@ -833,11 +883,19 @@ namespace Org.BouncyCastle.Tls
                 this.m_retransmitTimeout = null;
             }
 
+            // NOTE: Internal error implies GetReceiveLimit() was not used to allocate result space
 #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+            if (decoded.len > buffer.Length)
+                throw new TlsFatalAlert(AlertDescription.internal_error);
+
             decoded.buf.AsSpan(decoded.off, decoded.len).CopyTo(buffer);
 #else
+            if (decoded.len > len)
+                throw new TlsFatalAlert(AlertDescription.internal_error);
+
             Array.Copy(decoded.buf, decoded.off, buf, off, decoded.len);
 #endif
+
             return decoded.len;
         }
 
@@ -846,13 +904,38 @@ namespace Org.BouncyCastle.Tls
         {
             if (m_recordQueue.Available > 0)
             {
-                int length = 0;
-                if (m_recordQueue.Available >= RECORD_HEADER_LENGTH)
+                int recordLength = RECORD_HEADER_LENGTH;
+                if (m_recordQueue.Available >= recordLength)
                 {
-                    length = m_recordQueue.ReadUint16(11);
+                    short recordType = m_recordQueue.ReadUint8(0);
+                    int epoch = m_recordQueue.ReadUint16(3);
+
+                    DtlsEpoch recordEpoch = null;
+                    if (epoch == m_readEpoch.Epoch)
+                    {
+                        recordEpoch = m_readEpoch;
+                    }
+                    else if (recordType == ContentType.handshake && null != m_retransmitEpoch
+                        && epoch == m_retransmitEpoch.Epoch)
+                    {
+                        recordEpoch = m_retransmitEpoch;
+                    }
+
+                    if (null == recordEpoch)
+                    {
+                        m_recordQueue.RemoveData(m_recordQueue.Available);
+                        return -1;
+                    }
+
+                    recordLength = recordEpoch.RecordHeaderLengthRead;
+                    if (m_recordQueue.Available >= recordLength)
+                    {
+                        int fragmentLength = m_recordQueue.ReadUint16(recordLength - 2);
+                        recordLength += fragmentLength;
+                    }
                 }
 
-                int received = System.Math.Min(m_recordQueue.Available, RECORD_HEADER_LENGTH + length);
+                int received = System.Math.Min(m_recordQueue.Available, recordLength);
                 m_recordQueue.RemoveData(buf, off, received, 0);
                 return received;
             }
@@ -863,12 +946,33 @@ namespace Org.BouncyCastle.Tls
                 {
                     this.m_inConnection = true;
 
-                    int fragmentLength = TlsUtilities.ReadUint16(buf, off + 11);
-                    int recordLength = RECORD_HEADER_LENGTH + fragmentLength;
-                    if (received > recordLength)
+                    short recordType = TlsUtilities.ReadUint8(buf, off);
+                    int epoch = TlsUtilities.ReadUint16(buf, off + 3);
+
+                    DtlsEpoch recordEpoch = null;
+                    if (epoch == m_readEpoch.Epoch)
+                    {
+                        recordEpoch = m_readEpoch;
+                    }
+                    else if (recordType == ContentType.handshake && null != m_retransmitEpoch
+                        && epoch == m_retransmitEpoch.Epoch)
+                    {
+                        recordEpoch = m_retransmitEpoch;
+                    }
+
+                    if (null == recordEpoch)
+                        return -1;
+
+                    int recordHeaderLength = recordEpoch.RecordHeaderLengthRead;
+                    if (received >= recordHeaderLength)
                     {
-                        m_recordQueue.AddData(buf, off + recordLength, received - recordLength);
-                        received = recordLength;
+                        int fragmentLength = TlsUtilities.ReadUint16(buf, off + recordHeaderLength - 2);
+                        int recordLength = recordHeaderLength + fragmentLength;
+                        if (received > recordLength)
+                        {
+                            m_recordQueue.AddData(buf, off + recordLength, received - recordLength);
+                            received = recordLength;
+                        }
                     }
                 }
 
@@ -939,22 +1043,31 @@ namespace Org.BouncyCastle.Tls
                 long macSequenceNumber = GetMacSequenceNumber(recordEpoch, recordSequenceNumber);
                 ProtocolVersion recordVersion = m_writeVersion;
 
+                int recordHeaderLength = m_writeEpoch.RecordHeaderLengthWrite;
+
 #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
                 TlsEncodeResult encoded = m_writeEpoch.Cipher.EncodePlaintext(macSequenceNumber, contentType,
-                    recordVersion, RECORD_HEADER_LENGTH, buffer);
+                    recordVersion, recordHeaderLength, buffer);
 #else
                 TlsEncodeResult encoded = m_writeEpoch.Cipher.EncodePlaintext(macSequenceNumber, contentType,
-                    recordVersion, RECORD_HEADER_LENGTH, buf, off, len);
+                    recordVersion, recordHeaderLength, buf, off, len);
 #endif
 
-                int ciphertextLength = encoded.len - RECORD_HEADER_LENGTH;
+                int ciphertextLength = encoded.len - recordHeaderLength;
                 TlsUtilities.CheckUint16(ciphertextLength);
 
                 TlsUtilities.WriteUint8(encoded.recordType, encoded.buf, encoded.off + 0);
                 TlsUtilities.WriteVersion(recordVersion, encoded.buf, encoded.off + 1);
                 TlsUtilities.WriteUint16(recordEpoch, encoded.buf, encoded.off + 3);
                 TlsUtilities.WriteUint48(recordSequenceNumber, encoded.buf, encoded.off + 5);
-                TlsUtilities.WriteUint16(ciphertextLength, encoded.buf, encoded.off + 11);
+
+                if (recordHeaderLength > RECORD_HEADER_LENGTH)
+                {
+                    byte[] connectionID = m_context.SecurityParameters.ConnectionIDLocal;
+                    Array.Copy(connectionID, 0, encoded.buf, encoded.off + 11, connectionID.Length);
+                }
+
+                TlsUtilities.WriteUint16(ciphertextLength, encoded.buf, encoded.off + (recordHeaderLength - 2));
 
                 SendDatagram(m_transport, encoded.buf, encoded.off, encoded.len);
             }