diff options
-rw-r--r-- | crypto/src/tls/ByteQueue.cs | 8 | ||||
-rw-r--r-- | crypto/src/tls/DtlsRecordLayer.cs | 183 |
2 files changed, 156 insertions, 35 deletions
diff --git a/crypto/src/tls/ByteQueue.cs b/crypto/src/tls/ByteQueue.cs index e06ad6346..a92f79baf 100644 --- a/crypto/src/tls/ByteQueue.cs +++ b/crypto/src/tls/ByteQueue.cs @@ -193,6 +193,14 @@ namespace Org.BouncyCastle.Tls return TlsUtilities.ReadInt32(m_databuf, m_skipped); } + public short ReadUint8(int skip) + { + if (m_available < skip + 1) + throw new InvalidOperationException("Not enough data to read"); + + return TlsUtilities.ReadUint8(m_databuf, m_skipped + skip); + } + public int ReadUint16(int skip) { if (m_available < skip + 2) diff --git a/crypto/src/tls/DtlsRecordLayer.cs b/crypto/src/tls/DtlsRecordLayer.cs index 5d8c217b0..860c2dc31 100644 --- a/crypto/src/tls/DtlsRecordLayer.cs +++ b/crypto/src/tls/DtlsRecordLayer.cs @@ -3,6 +3,7 @@ using System.IO; using System.Net.Sockets; using Org.BouncyCastle.Tls.Crypto; +using Org.BouncyCastle.Tls.Crypto.Impl; using Org.BouncyCastle.Utilities; using Org.BouncyCastle.Utilities.Date; @@ -234,15 +235,39 @@ namespace Org.BouncyCastle.Tls /// <exception cref="IOException"/> public virtual int GetReceiveLimit() { - return System.Math.Min(m_plaintextLimit, - m_readEpoch.Cipher.GetPlaintextLimit(m_transport.GetReceiveLimit() - RECORD_HEADER_LENGTH)); + int ciphertextLimit = m_transport.GetReceiveLimit() - m_readEpoch.RecordHeaderLengthRead; + var cipher = m_readEpoch.Cipher; + + int plaintextDecodeLimit; + if (cipher is AbstractTlsCipher abstractTlsCipher) + { + plaintextDecodeLimit = abstractTlsCipher.GetPlaintextDecodeLimit(ciphertextLimit); + } + else + { + plaintextDecodeLimit = cipher.GetPlaintextLimit(ciphertextLimit); + } + + return System.Math.Min(m_plaintextLimit, plaintextDecodeLimit); } /// <exception cref="IOException"/> public virtual int GetSendLimit() { - return System.Math.Min(m_plaintextLimit, - m_writeEpoch.Cipher.GetPlaintextLimit(m_transport.GetSendLimit() - RECORD_HEADER_LENGTH)); + var cipher = m_writeEpoch.Cipher; + int ciphertextLimit = m_transport.GetSendLimit() - m_writeEpoch.RecordHeaderLengthWrite; + + int plaintextEncodeLimit; + if (cipher is AbstractTlsCipher abstractTlsCipher) + { + plaintextEncodeLimit = abstractTlsCipher.GetPlaintextEncodeLimit(ciphertextLimit); + } + else + { + plaintextEncodeLimit = cipher.GetPlaintextLimit(ciphertextLimit); + } + + return System.Math.Min(m_plaintextLimit, plaintextEncodeLimit); } /// <exception cref="IOException"/> @@ -296,18 +321,16 @@ namespace Org.BouncyCastle.Tls waitMillis = 1; } - int receiveLimit = System.Math.Min(len, GetReceiveLimit()) + RECORD_HEADER_LENGTH; + int receiveLimit = m_transport.GetReceiveLimit(); if (null == record || record.Length < receiveLimit) { record = new byte[receiveLimit]; } int received = ReceiveRecord(record, 0, receiveLimit, waitMillis); - int processed = ProcessRecord(received, record, buf, off); + int processed = ProcessRecord(received, record, buf, off, len); if (processed >= 0) - { return processed; - } currentTimeMillis = DateTimeUtilities.CurrentUnixMs(); waitMillis = Timeout.GetWaitMillis(timeout, currentTimeMillis); @@ -366,7 +389,7 @@ namespace Org.BouncyCastle.Tls waitMillis = 1; } - int receiveLimit = System.Math.Min(buffer.Length, GetReceiveLimit()) + RECORD_HEADER_LENGTH; + int receiveLimit = m_transport.GetReceiveLimit(); if (null == record || record.Length < receiveLimit) { record = new byte[receiveLimit]; @@ -375,9 +398,7 @@ namespace Org.BouncyCastle.Tls int received = ReceiveRecord(record, 0, receiveLimit, waitMillis); int processed = ProcessRecord(received, record, buffer); if (processed >= 0) - { return processed; - } currentTimeMillis = DateTimeUtilities.CurrentUnixMs(); waitMillis = Timeout.GetWaitMillis(timeout, currentTimeMillis); @@ -599,17 +620,13 @@ namespace Org.BouncyCastle.Tls #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER private int ProcessRecord(int received, byte[] record, Span<byte> buffer) #else - private int ProcessRecord(int received, byte[] record, byte[] buf, int off) + private int ProcessRecord(int received, byte[] record, byte[] buf, int off, int len) #endif { // NOTE: received < 0 (timeout) is covered by this first case if (received < RECORD_HEADER_LENGTH) return -1; - int length = TlsUtilities.ReadUint16(record, 11); - if (received != (length + RECORD_HEADER_LENGTH)) - return -1; - // TODO[dtls13] Deal with opaque record type for 1.3 AEAD ciphers short recordType = TlsUtilities.ReadUint8(record, 0); @@ -620,11 +637,16 @@ namespace Org.BouncyCastle.Tls case ContentType.change_cipher_spec: case ContentType.handshake: case ContentType.heartbeat: + case ContentType.tls12_cid: break; default: return -1; } + ProtocolVersion recordVersion = TlsUtilities.ReadVersion(record, 1); + if (!recordVersion.IsDtls) + return -1; + int epoch = TlsUtilities.ReadUint16(record, 3); DtlsEpoch recordEpoch = null; @@ -645,8 +667,23 @@ namespace Org.BouncyCastle.Tls if (recordEpoch.ReplayWindow.ShouldDiscard(seq)) return -1; - ProtocolVersion recordVersion = TlsUtilities.ReadVersion(record, 1); - if (!recordVersion.IsDtls) + + int recordHeaderLength = recordEpoch.RecordHeaderLengthRead; + if (recordHeaderLength > RECORD_HEADER_LENGTH) + { + if (ContentType.tls12_cid != recordType) + return -1; + + if (received < recordHeaderLength) + return -1; + + byte[] connectionID = m_context.SecurityParameters.ConnectionIDPeer; + if (!Arrays.FixedTimeEquals(connectionID.Length, connectionID, 0, record, 11)) + return -1; + } + + int length = TlsUtilities.ReadUint16(record, recordHeaderLength - 2); + if (received != (length + recordHeaderLength)) return -1; if (null != m_readVersion && !m_readVersion.Equals(recordVersion)) @@ -660,7 +697,7 @@ namespace Org.BouncyCastle.Tls ReadEpoch == 0 && length > 0 && ContentType.handshake == recordType - && HandshakeType.client_hello == TlsUtilities.ReadUint8(record, RECORD_HEADER_LENGTH); + && HandshakeType.client_hello == TlsUtilities.ReadUint8(record, recordHeaderLength); if (!isClientHelloFragment) return -1; @@ -668,8 +705,20 @@ namespace Org.BouncyCastle.Tls long macSeqNo = GetMacSequenceNumber(recordEpoch.Epoch, seq); - TlsDecodeResult decoded = recordEpoch.Cipher.DecodeCiphertext(macSeqNo, recordType, recordVersion, record, - RECORD_HEADER_LENGTH, length); + TlsDecodeResult decoded; + try + { + decoded = recordEpoch.Cipher.DecodeCiphertext(macSeqNo, recordType, recordVersion, record, + recordHeaderLength, length); + } + catch (TlsFatalAlert fatalAlert) when (AlertDescription.bad_record_mac == fatalAlert.AlertDescription) + { + /* + * RFC 9146 6. DTLS implementations MUST silently discard records with bad MACs or that are otherwise + * invalid. + */ + return -1; + } recordEpoch.ReplayWindow.ReportAuthenticated(seq); @@ -685,7 +734,7 @@ namespace Org.BouncyCastle.Tls ReadEpoch == 0 && length > 0 && ContentType.handshake == recordType - && HandshakeType.hello_verify_request == TlsUtilities.ReadUint8(record, RECORD_HEADER_LENGTH); + && HandshakeType.hello_verify_request == TlsUtilities.ReadUint8(record, recordHeaderLength); if (isHelloVerifyRequest) { @@ -818,6 +867,7 @@ namespace Org.BouncyCastle.Tls return -1; } + case ContentType.tls12_cid: default: return -1; } @@ -833,11 +883,19 @@ namespace Org.BouncyCastle.Tls this.m_retransmitTimeout = null; } + // NOTE: Internal error implies GetReceiveLimit() was not used to allocate result space #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + if (decoded.len > buffer.Length) + throw new TlsFatalAlert(AlertDescription.internal_error); + decoded.buf.AsSpan(decoded.off, decoded.len).CopyTo(buffer); #else + if (decoded.len > len) + throw new TlsFatalAlert(AlertDescription.internal_error); + Array.Copy(decoded.buf, decoded.off, buf, off, decoded.len); #endif + return decoded.len; } @@ -846,13 +904,38 @@ namespace Org.BouncyCastle.Tls { if (m_recordQueue.Available > 0) { - int length = 0; - if (m_recordQueue.Available >= RECORD_HEADER_LENGTH) + int recordLength = RECORD_HEADER_LENGTH; + if (m_recordQueue.Available >= recordLength) { - length = m_recordQueue.ReadUint16(11); + short recordType = m_recordQueue.ReadUint8(0); + int epoch = m_recordQueue.ReadUint16(3); + + DtlsEpoch recordEpoch = null; + if (epoch == m_readEpoch.Epoch) + { + recordEpoch = m_readEpoch; + } + else if (recordType == ContentType.handshake && null != m_retransmitEpoch + && epoch == m_retransmitEpoch.Epoch) + { + recordEpoch = m_retransmitEpoch; + } + + if (null == recordEpoch) + { + m_recordQueue.RemoveData(m_recordQueue.Available); + return -1; + } + + recordLength = recordEpoch.RecordHeaderLengthRead; + if (m_recordQueue.Available >= recordLength) + { + int fragmentLength = m_recordQueue.ReadUint16(recordLength - 2); + recordLength += fragmentLength; + } } - int received = System.Math.Min(m_recordQueue.Available, RECORD_HEADER_LENGTH + length); + int received = System.Math.Min(m_recordQueue.Available, recordLength); m_recordQueue.RemoveData(buf, off, received, 0); return received; } @@ -863,12 +946,33 @@ namespace Org.BouncyCastle.Tls { this.m_inConnection = true; - int fragmentLength = TlsUtilities.ReadUint16(buf, off + 11); - int recordLength = RECORD_HEADER_LENGTH + fragmentLength; - if (received > recordLength) + short recordType = TlsUtilities.ReadUint8(buf, off); + int epoch = TlsUtilities.ReadUint16(buf, off + 3); + + DtlsEpoch recordEpoch = null; + if (epoch == m_readEpoch.Epoch) + { + recordEpoch = m_readEpoch; + } + else if (recordType == ContentType.handshake && null != m_retransmitEpoch + && epoch == m_retransmitEpoch.Epoch) + { + recordEpoch = m_retransmitEpoch; + } + + if (null == recordEpoch) + return -1; + + int recordHeaderLength = recordEpoch.RecordHeaderLengthRead; + if (received >= recordHeaderLength) { - m_recordQueue.AddData(buf, off + recordLength, received - recordLength); - received = recordLength; + int fragmentLength = TlsUtilities.ReadUint16(buf, off + recordHeaderLength - 2); + int recordLength = recordHeaderLength + fragmentLength; + if (received > recordLength) + { + m_recordQueue.AddData(buf, off + recordLength, received - recordLength); + received = recordLength; + } } } @@ -939,22 +1043,31 @@ namespace Org.BouncyCastle.Tls long macSequenceNumber = GetMacSequenceNumber(recordEpoch, recordSequenceNumber); ProtocolVersion recordVersion = m_writeVersion; + int recordHeaderLength = m_writeEpoch.RecordHeaderLengthWrite; + #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER TlsEncodeResult encoded = m_writeEpoch.Cipher.EncodePlaintext(macSequenceNumber, contentType, - recordVersion, RECORD_HEADER_LENGTH, buffer); + recordVersion, recordHeaderLength, buffer); #else TlsEncodeResult encoded = m_writeEpoch.Cipher.EncodePlaintext(macSequenceNumber, contentType, - recordVersion, RECORD_HEADER_LENGTH, buf, off, len); + recordVersion, recordHeaderLength, buf, off, len); #endif - int ciphertextLength = encoded.len - RECORD_HEADER_LENGTH; + int ciphertextLength = encoded.len - recordHeaderLength; TlsUtilities.CheckUint16(ciphertextLength); TlsUtilities.WriteUint8(encoded.recordType, encoded.buf, encoded.off + 0); TlsUtilities.WriteVersion(recordVersion, encoded.buf, encoded.off + 1); TlsUtilities.WriteUint16(recordEpoch, encoded.buf, encoded.off + 3); TlsUtilities.WriteUint48(recordSequenceNumber, encoded.buf, encoded.off + 5); - TlsUtilities.WriteUint16(ciphertextLength, encoded.buf, encoded.off + 11); + + if (recordHeaderLength > RECORD_HEADER_LENGTH) + { + byte[] connectionID = m_context.SecurityParameters.ConnectionIDLocal; + Array.Copy(connectionID, 0, encoded.buf, encoded.off + 11, connectionID.Length); + } + + TlsUtilities.WriteUint16(ciphertextLength, encoded.buf, encoded.off + (recordHeaderLength - 2)); SendDatagram(m_transport, encoded.buf, encoded.off, encoded.len); } |