diff --git a/crypto/src/tls/ByteQueue.cs b/crypto/src/tls/ByteQueue.cs
index e06ad6346..a92f79baf 100644
--- a/crypto/src/tls/ByteQueue.cs
+++ b/crypto/src/tls/ByteQueue.cs
@@ -193,6 +193,14 @@ namespace Org.BouncyCastle.Tls
return TlsUtilities.ReadInt32(m_databuf, m_skipped);
}
+ public short ReadUint8(int skip)
+ {
+ if (m_available < skip + 1)
+ throw new InvalidOperationException("Not enough data to read");
+
+ return TlsUtilities.ReadUint8(m_databuf, m_skipped + skip);
+ }
+
public int ReadUint16(int skip)
{
if (m_available < skip + 2)
diff --git a/crypto/src/tls/DtlsRecordLayer.cs b/crypto/src/tls/DtlsRecordLayer.cs
index 5d8c217b0..860c2dc31 100644
--- a/crypto/src/tls/DtlsRecordLayer.cs
+++ b/crypto/src/tls/DtlsRecordLayer.cs
@@ -3,6 +3,7 @@ using System.IO;
using System.Net.Sockets;
using Org.BouncyCastle.Tls.Crypto;
+using Org.BouncyCastle.Tls.Crypto.Impl;
using Org.BouncyCastle.Utilities;
using Org.BouncyCastle.Utilities.Date;
@@ -234,15 +235,39 @@ namespace Org.BouncyCastle.Tls
/// <exception cref="IOException"/>
public virtual int GetReceiveLimit()
{
- return System.Math.Min(m_plaintextLimit,
- m_readEpoch.Cipher.GetPlaintextLimit(m_transport.GetReceiveLimit() - RECORD_HEADER_LENGTH));
+ int ciphertextLimit = m_transport.GetReceiveLimit() - m_readEpoch.RecordHeaderLengthRead;
+ var cipher = m_readEpoch.Cipher;
+
+ int plaintextDecodeLimit;
+ if (cipher is AbstractTlsCipher abstractTlsCipher)
+ {
+ plaintextDecodeLimit = abstractTlsCipher.GetPlaintextDecodeLimit(ciphertextLimit);
+ }
+ else
+ {
+ plaintextDecodeLimit = cipher.GetPlaintextLimit(ciphertextLimit);
+ }
+
+ return System.Math.Min(m_plaintextLimit, plaintextDecodeLimit);
}
/// <exception cref="IOException"/>
public virtual int GetSendLimit()
{
- return System.Math.Min(m_plaintextLimit,
- m_writeEpoch.Cipher.GetPlaintextLimit(m_transport.GetSendLimit() - RECORD_HEADER_LENGTH));
+ var cipher = m_writeEpoch.Cipher;
+ int ciphertextLimit = m_transport.GetSendLimit() - m_writeEpoch.RecordHeaderLengthWrite;
+
+ int plaintextEncodeLimit;
+ if (cipher is AbstractTlsCipher abstractTlsCipher)
+ {
+ plaintextEncodeLimit = abstractTlsCipher.GetPlaintextEncodeLimit(ciphertextLimit);
+ }
+ else
+ {
+ plaintextEncodeLimit = cipher.GetPlaintextLimit(ciphertextLimit);
+ }
+
+ return System.Math.Min(m_plaintextLimit, plaintextEncodeLimit);
}
/// <exception cref="IOException"/>
@@ -296,18 +321,16 @@ namespace Org.BouncyCastle.Tls
waitMillis = 1;
}
- int receiveLimit = System.Math.Min(len, GetReceiveLimit()) + RECORD_HEADER_LENGTH;
+ int receiveLimit = m_transport.GetReceiveLimit();
if (null == record || record.Length < receiveLimit)
{
record = new byte[receiveLimit];
}
int received = ReceiveRecord(record, 0, receiveLimit, waitMillis);
- int processed = ProcessRecord(received, record, buf, off);
+ int processed = ProcessRecord(received, record, buf, off, len);
if (processed >= 0)
- {
return processed;
- }
currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
waitMillis = Timeout.GetWaitMillis(timeout, currentTimeMillis);
@@ -366,7 +389,7 @@ namespace Org.BouncyCastle.Tls
waitMillis = 1;
}
- int receiveLimit = System.Math.Min(buffer.Length, GetReceiveLimit()) + RECORD_HEADER_LENGTH;
+ int receiveLimit = m_transport.GetReceiveLimit();
if (null == record || record.Length < receiveLimit)
{
record = new byte[receiveLimit];
@@ -375,9 +398,7 @@ namespace Org.BouncyCastle.Tls
int received = ReceiveRecord(record, 0, receiveLimit, waitMillis);
int processed = ProcessRecord(received, record, buffer);
if (processed >= 0)
- {
return processed;
- }
currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
waitMillis = Timeout.GetWaitMillis(timeout, currentTimeMillis);
@@ -599,17 +620,13 @@ namespace Org.BouncyCastle.Tls
#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
private int ProcessRecord(int received, byte[] record, Span<byte> buffer)
#else
- private int ProcessRecord(int received, byte[] record, byte[] buf, int off)
+ private int ProcessRecord(int received, byte[] record, byte[] buf, int off, int len)
#endif
{
// NOTE: received < 0 (timeout) is covered by this first case
if (received < RECORD_HEADER_LENGTH)
return -1;
- int length = TlsUtilities.ReadUint16(record, 11);
- if (received != (length + RECORD_HEADER_LENGTH))
- return -1;
-
// TODO[dtls13] Deal with opaque record type for 1.3 AEAD ciphers
short recordType = TlsUtilities.ReadUint8(record, 0);
@@ -620,11 +637,16 @@ namespace Org.BouncyCastle.Tls
case ContentType.change_cipher_spec:
case ContentType.handshake:
case ContentType.heartbeat:
+ case ContentType.tls12_cid:
break;
default:
return -1;
}
+ ProtocolVersion recordVersion = TlsUtilities.ReadVersion(record, 1);
+ if (!recordVersion.IsDtls)
+ return -1;
+
int epoch = TlsUtilities.ReadUint16(record, 3);
DtlsEpoch recordEpoch = null;
@@ -645,8 +667,23 @@ namespace Org.BouncyCastle.Tls
if (recordEpoch.ReplayWindow.ShouldDiscard(seq))
return -1;
- ProtocolVersion recordVersion = TlsUtilities.ReadVersion(record, 1);
- if (!recordVersion.IsDtls)
+
+ int recordHeaderLength = recordEpoch.RecordHeaderLengthRead;
+ if (recordHeaderLength > RECORD_HEADER_LENGTH)
+ {
+ if (ContentType.tls12_cid != recordType)
+ return -1;
+
+ if (received < recordHeaderLength)
+ return -1;
+
+ byte[] connectionID = m_context.SecurityParameters.ConnectionIDPeer;
+ if (!Arrays.FixedTimeEquals(connectionID.Length, connectionID, 0, record, 11))
+ return -1;
+ }
+
+ int length = TlsUtilities.ReadUint16(record, recordHeaderLength - 2);
+ if (received != (length + recordHeaderLength))
return -1;
if (null != m_readVersion && !m_readVersion.Equals(recordVersion))
@@ -660,7 +697,7 @@ namespace Org.BouncyCastle.Tls
ReadEpoch == 0
&& length > 0
&& ContentType.handshake == recordType
- && HandshakeType.client_hello == TlsUtilities.ReadUint8(record, RECORD_HEADER_LENGTH);
+ && HandshakeType.client_hello == TlsUtilities.ReadUint8(record, recordHeaderLength);
if (!isClientHelloFragment)
return -1;
@@ -668,8 +705,20 @@ namespace Org.BouncyCastle.Tls
long macSeqNo = GetMacSequenceNumber(recordEpoch.Epoch, seq);
- TlsDecodeResult decoded = recordEpoch.Cipher.DecodeCiphertext(macSeqNo, recordType, recordVersion, record,
- RECORD_HEADER_LENGTH, length);
+ TlsDecodeResult decoded;
+ try
+ {
+ decoded = recordEpoch.Cipher.DecodeCiphertext(macSeqNo, recordType, recordVersion, record,
+ recordHeaderLength, length);
+ }
+ catch (TlsFatalAlert fatalAlert) when (AlertDescription.bad_record_mac == fatalAlert.AlertDescription)
+ {
+ /*
+ * RFC 9146 6. DTLS implementations MUST silently discard records with bad MACs or that are otherwise
+ * invalid.
+ */
+ return -1;
+ }
recordEpoch.ReplayWindow.ReportAuthenticated(seq);
@@ -685,7 +734,7 @@ namespace Org.BouncyCastle.Tls
ReadEpoch == 0
&& length > 0
&& ContentType.handshake == recordType
- && HandshakeType.hello_verify_request == TlsUtilities.ReadUint8(record, RECORD_HEADER_LENGTH);
+ && HandshakeType.hello_verify_request == TlsUtilities.ReadUint8(record, recordHeaderLength);
if (isHelloVerifyRequest)
{
@@ -818,6 +867,7 @@ namespace Org.BouncyCastle.Tls
return -1;
}
+ case ContentType.tls12_cid:
default:
return -1;
}
@@ -833,11 +883,19 @@ namespace Org.BouncyCastle.Tls
this.m_retransmitTimeout = null;
}
+ // NOTE: Internal error implies GetReceiveLimit() was not used to allocate result space
#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+ if (decoded.len > buffer.Length)
+ throw new TlsFatalAlert(AlertDescription.internal_error);
+
decoded.buf.AsSpan(decoded.off, decoded.len).CopyTo(buffer);
#else
+ if (decoded.len > len)
+ throw new TlsFatalAlert(AlertDescription.internal_error);
+
Array.Copy(decoded.buf, decoded.off, buf, off, decoded.len);
#endif
+
return decoded.len;
}
@@ -846,13 +904,38 @@ namespace Org.BouncyCastle.Tls
{
if (m_recordQueue.Available > 0)
{
- int length = 0;
- if (m_recordQueue.Available >= RECORD_HEADER_LENGTH)
+ int recordLength = RECORD_HEADER_LENGTH;
+ if (m_recordQueue.Available >= recordLength)
{
- length = m_recordQueue.ReadUint16(11);
+ short recordType = m_recordQueue.ReadUint8(0);
+ int epoch = m_recordQueue.ReadUint16(3);
+
+ DtlsEpoch recordEpoch = null;
+ if (epoch == m_readEpoch.Epoch)
+ {
+ recordEpoch = m_readEpoch;
+ }
+ else if (recordType == ContentType.handshake && null != m_retransmitEpoch
+ && epoch == m_retransmitEpoch.Epoch)
+ {
+ recordEpoch = m_retransmitEpoch;
+ }
+
+ if (null == recordEpoch)
+ {
+ m_recordQueue.RemoveData(m_recordQueue.Available);
+ return -1;
+ }
+
+ recordLength = recordEpoch.RecordHeaderLengthRead;
+ if (m_recordQueue.Available >= recordLength)
+ {
+ int fragmentLength = m_recordQueue.ReadUint16(recordLength - 2);
+ recordLength += fragmentLength;
+ }
}
- int received = System.Math.Min(m_recordQueue.Available, RECORD_HEADER_LENGTH + length);
+ int received = System.Math.Min(m_recordQueue.Available, recordLength);
m_recordQueue.RemoveData(buf, off, received, 0);
return received;
}
@@ -863,12 +946,33 @@ namespace Org.BouncyCastle.Tls
{
this.m_inConnection = true;
- int fragmentLength = TlsUtilities.ReadUint16(buf, off + 11);
- int recordLength = RECORD_HEADER_LENGTH + fragmentLength;
- if (received > recordLength)
+ short recordType = TlsUtilities.ReadUint8(buf, off);
+ int epoch = TlsUtilities.ReadUint16(buf, off + 3);
+
+ DtlsEpoch recordEpoch = null;
+ if (epoch == m_readEpoch.Epoch)
+ {
+ recordEpoch = m_readEpoch;
+ }
+ else if (recordType == ContentType.handshake && null != m_retransmitEpoch
+ && epoch == m_retransmitEpoch.Epoch)
+ {
+ recordEpoch = m_retransmitEpoch;
+ }
+
+ if (null == recordEpoch)
+ return -1;
+
+ int recordHeaderLength = recordEpoch.RecordHeaderLengthRead;
+ if (received >= recordHeaderLength)
{
- m_recordQueue.AddData(buf, off + recordLength, received - recordLength);
- received = recordLength;
+ int fragmentLength = TlsUtilities.ReadUint16(buf, off + recordHeaderLength - 2);
+ int recordLength = recordHeaderLength + fragmentLength;
+ if (received > recordLength)
+ {
+ m_recordQueue.AddData(buf, off + recordLength, received - recordLength);
+ received = recordLength;
+ }
}
}
@@ -939,22 +1043,31 @@ namespace Org.BouncyCastle.Tls
long macSequenceNumber = GetMacSequenceNumber(recordEpoch, recordSequenceNumber);
ProtocolVersion recordVersion = m_writeVersion;
+ int recordHeaderLength = m_writeEpoch.RecordHeaderLengthWrite;
+
#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
TlsEncodeResult encoded = m_writeEpoch.Cipher.EncodePlaintext(macSequenceNumber, contentType,
- recordVersion, RECORD_HEADER_LENGTH, buffer);
+ recordVersion, recordHeaderLength, buffer);
#else
TlsEncodeResult encoded = m_writeEpoch.Cipher.EncodePlaintext(macSequenceNumber, contentType,
- recordVersion, RECORD_HEADER_LENGTH, buf, off, len);
+ recordVersion, recordHeaderLength, buf, off, len);
#endif
- int ciphertextLength = encoded.len - RECORD_HEADER_LENGTH;
+ int ciphertextLength = encoded.len - recordHeaderLength;
TlsUtilities.CheckUint16(ciphertextLength);
TlsUtilities.WriteUint8(encoded.recordType, encoded.buf, encoded.off + 0);
TlsUtilities.WriteVersion(recordVersion, encoded.buf, encoded.off + 1);
TlsUtilities.WriteUint16(recordEpoch, encoded.buf, encoded.off + 3);
TlsUtilities.WriteUint48(recordSequenceNumber, encoded.buf, encoded.off + 5);
- TlsUtilities.WriteUint16(ciphertextLength, encoded.buf, encoded.off + 11);
+
+ if (recordHeaderLength > RECORD_HEADER_LENGTH)
+ {
+ byte[] connectionID = m_context.SecurityParameters.ConnectionIDLocal;
+ Array.Copy(connectionID, 0, encoded.buf, encoded.off + 11, connectionID.Length);
+ }
+
+ TlsUtilities.WriteUint16(ciphertextLength, encoded.buf, encoded.off + (recordHeaderLength - 2));
SendDatagram(m_transport, encoded.buf, encoded.off, encoded.len);
}
|