diff options
Diffstat (limited to 'crypto/src/tls/DtlsRecordLayer.cs')
-rw-r--r-- | crypto/src/tls/DtlsRecordLayer.cs | 55 |
1 files changed, 28 insertions, 27 deletions
diff --git a/crypto/src/tls/DtlsRecordLayer.cs b/crypto/src/tls/DtlsRecordLayer.cs index efe9e7312..e3567aa46 100644 --- a/crypto/src/tls/DtlsRecordLayer.cs +++ b/crypto/src/tls/DtlsRecordLayer.cs @@ -4,7 +4,6 @@ using System.IO; using System.Net.Sockets; using Org.BouncyCastle.Tls.Crypto; -using Org.BouncyCastle.Tls.Crypto.Impl; using Org.BouncyCastle.Utilities; using Org.BouncyCastle.Utilities.Date; @@ -13,43 +12,45 @@ namespace Org.BouncyCastle.Tls internal class DtlsRecordLayer : DatagramTransport { - private const int RECORD_HEADER_LENGTH = 13; + internal const int RecordHeaderLength = 13; + private const int MAX_FRAGMENT_LENGTH = 1 << 14; private const long TCP_MSL = 1000L * 60 * 2; private const long RETRANSMIT_TIMEOUT = TCP_MSL * 2; /// <exception cref="IOException"/> - internal static byte[] ReceiveClientHelloRecord(byte[] data, int dataOff, int dataLen) + internal static int ReceiveClientHelloRecord(byte[] data, int dataOff, int dataLen) { - if (dataLen < RECORD_HEADER_LENGTH) - { - return null; - } + if (dataLen < RecordHeaderLength) + return -1; short contentType = TlsUtilities.ReadUint8(data, dataOff + 0); if (ContentType.handshake != contentType) - return null; + return -1; ProtocolVersion version = TlsUtilities.ReadVersion(data, dataOff + 1); if (!ProtocolVersion.DTLSv10.IsEqualOrEarlierVersionOf(version)) - return null; + return -1; int epoch = TlsUtilities.ReadUint16(data, dataOff + 3); if (0 != epoch) - return null; + return -1; //long sequenceNumber = TlsUtilities.ReadUint48(data, dataOff + 5); int length = TlsUtilities.ReadUint16(data, dataOff + 11); - if (dataLen < RECORD_HEADER_LENGTH + length) - return null; + if (length < 1 || length > MAX_FRAGMENT_LENGTH) + return -1; - if (length > MAX_FRAGMENT_LENGTH) - return null; + if (dataLen < RecordHeaderLength + length) + return -1; + + short msgType = TlsUtilities.ReadUint8(data, dataOff + RecordHeaderLength); + if (HandshakeType.client_hello != msgType) + return -1; // NOTE: We ignore/drop any data after the first record - return TlsUtilities.CopyOfRangeExact(data, dataOff + RECORD_HEADER_LENGTH, - dataOff + RECORD_HEADER_LENGTH + length); + return length; } /// <exception cref="IOException"/> @@ -57,14 +58,14 @@ namespace Org.BouncyCastle.Tls { TlsUtilities.CheckUint16(message.Length); - byte[] record = new byte[RECORD_HEADER_LENGTH + message.Length]; + byte[] record = new byte[RecordHeaderLength + message.Length]; TlsUtilities.WriteUint8(ContentType.handshake, record, 0); TlsUtilities.WriteVersion(ProtocolVersion.DTLSv10, record, 1); TlsUtilities.WriteUint16(0, record, 3); TlsUtilities.WriteUint48(recordSeq, record, 5); TlsUtilities.WriteUint16(message.Length, record, 11); - Array.Copy(message, 0, record, RECORD_HEADER_LENGTH, message.Length); + Array.Copy(message, 0, record, RecordHeaderLength, message.Length); SendDatagram(sender, record, 0, record.Length); } @@ -124,8 +125,8 @@ namespace Org.BouncyCastle.Tls this.m_inHandshake = true; - this.m_currentEpoch = new DtlsEpoch(0, TlsNullNullCipher.Instance, RECORD_HEADER_LENGTH, - RECORD_HEADER_LENGTH); + this.m_currentEpoch = new DtlsEpoch(0, TlsNullNullCipher.Instance, RecordHeaderLength, + RecordHeaderLength); this.m_pendingEpoch = null; this.m_readEpoch = m_currentEpoch; this.m_writeEpoch = m_currentEpoch; @@ -179,8 +180,8 @@ namespace Org.BouncyCastle.Tls */ var securityParameters = m_context.SecurityParameters; - int recordHeaderLengthRead = RECORD_HEADER_LENGTH + (securityParameters.ConnectionIDPeer?.Length ?? 0); - int recordHeaderLengthWrite = RECORD_HEADER_LENGTH + (securityParameters.ConnectionIDLocal?.Length ?? 0); + int recordHeaderLengthRead = RecordHeaderLength + (securityParameters.ConnectionIDPeer?.Length ?? 0); + int recordHeaderLengthWrite = RecordHeaderLength + (securityParameters.ConnectionIDLocal?.Length ?? 0); // TODO Check for overflow this.m_pendingEpoch = new DtlsEpoch(m_writeEpoch.Epoch + 1, pendingCipher, recordHeaderLengthRead, @@ -684,7 +685,7 @@ namespace Org.BouncyCastle.Tls #endif { // NOTE: received < 0 (timeout) is covered by this first case - if (received < RECORD_HEADER_LENGTH) + if (received < RecordHeaderLength) return -1; // TODO[dtls13] Deal with opaque record type for 1.3 AEAD ciphers @@ -729,7 +730,7 @@ namespace Org.BouncyCastle.Tls int recordHeaderLength = recordEpoch.RecordHeaderLengthRead; - if (recordHeaderLength > RECORD_HEADER_LENGTH) + if (recordHeaderLength > RecordHeaderLength) { if (ContentType.tls12_cid != recordType) return -1; @@ -990,7 +991,7 @@ namespace Org.BouncyCastle.Tls { Debug.Assert(m_recordQueue.Available > 0); - int recordLength = RECORD_HEADER_LENGTH; + int recordLength = RecordHeaderLength; if (m_recordQueue.Available >= recordLength) { short recordType = m_recordQueue.ReadUint8(0); @@ -1033,7 +1034,7 @@ namespace Org.BouncyCastle.Tls return ReceivePendingRecord(buf, off, len); int received = ReceiveDatagram(buf, off, len, waitMillis); - if (received >= RECORD_HEADER_LENGTH) + if (received >= RecordHeaderLength) { this.m_inConnection = true; @@ -1151,7 +1152,7 @@ namespace Org.BouncyCastle.Tls TlsUtilities.WriteUint16(recordEpoch, encoded.buf, encoded.off + 3); TlsUtilities.WriteUint48(recordSequenceNumber, encoded.buf, encoded.off + 5); - if (recordHeaderLength > RECORD_HEADER_LENGTH) + if (recordHeaderLength > RecordHeaderLength) { byte[] connectionID = m_context.SecurityParameters.ConnectionIDLocal; Array.Copy(connectionID, 0, encoded.buf, encoded.off + 11, connectionID.Length); |