summary refs log tree commit diff
path: root/crypto/src/tls/DtlsRecordLayer.cs
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src/tls/DtlsRecordLayer.cs')
-rw-r--r--crypto/src/tls/DtlsRecordLayer.cs55
1 files changed, 28 insertions, 27 deletions
diff --git a/crypto/src/tls/DtlsRecordLayer.cs b/crypto/src/tls/DtlsRecordLayer.cs
index efe9e7312..e3567aa46 100644
--- a/crypto/src/tls/DtlsRecordLayer.cs
+++ b/crypto/src/tls/DtlsRecordLayer.cs
@@ -4,7 +4,6 @@ using System.IO;
 using System.Net.Sockets;
 
 using Org.BouncyCastle.Tls.Crypto;
-using Org.BouncyCastle.Tls.Crypto.Impl;
 using Org.BouncyCastle.Utilities;
 using Org.BouncyCastle.Utilities.Date;
 
@@ -13,43 +12,45 @@ namespace Org.BouncyCastle.Tls
     internal class DtlsRecordLayer
         : DatagramTransport
     {
-        private const int RECORD_HEADER_LENGTH = 13;
+        internal const int RecordHeaderLength = 13;
+
         private const int MAX_FRAGMENT_LENGTH = 1 << 14;
         private const long TCP_MSL = 1000L * 60 * 2;
         private const long RETRANSMIT_TIMEOUT = TCP_MSL * 2;
 
         /// <exception cref="IOException"/>
-        internal static byte[] ReceiveClientHelloRecord(byte[] data, int dataOff, int dataLen)
+        internal static int ReceiveClientHelloRecord(byte[] data, int dataOff, int dataLen)
         {
-            if (dataLen < RECORD_HEADER_LENGTH)
-            {
-                return null;
-            }
+            if (dataLen < RecordHeaderLength)
+                return -1;
 
             short contentType = TlsUtilities.ReadUint8(data, dataOff + 0);
             if (ContentType.handshake != contentType)
-                return null;
+                return -1;
 
             ProtocolVersion version = TlsUtilities.ReadVersion(data, dataOff + 1);
             if (!ProtocolVersion.DTLSv10.IsEqualOrEarlierVersionOf(version))
-                return null;
+                return -1;
 
             int epoch = TlsUtilities.ReadUint16(data, dataOff + 3);
             if (0 != epoch)
-                return null;
+                return -1;
 
             //long sequenceNumber = TlsUtilities.ReadUint48(data, dataOff + 5);
 
             int length = TlsUtilities.ReadUint16(data, dataOff + 11);
-            if (dataLen < RECORD_HEADER_LENGTH + length)
-                return null;
+            if (length < 1 || length > MAX_FRAGMENT_LENGTH)
+                return -1;
 
-            if (length > MAX_FRAGMENT_LENGTH)
-                return null;
+            if (dataLen < RecordHeaderLength + length)
+                return -1;
+
+            short msgType = TlsUtilities.ReadUint8(data, dataOff + RecordHeaderLength);
+            if (HandshakeType.client_hello != msgType)
+                return -1;
 
             // NOTE: We ignore/drop any data after the first record 
-            return TlsUtilities.CopyOfRangeExact(data, dataOff + RECORD_HEADER_LENGTH,
-                dataOff + RECORD_HEADER_LENGTH + length);
+            return length;
         }
 
         /// <exception cref="IOException"/>
@@ -57,14 +58,14 @@ namespace Org.BouncyCastle.Tls
         {
             TlsUtilities.CheckUint16(message.Length);
 
-            byte[] record = new byte[RECORD_HEADER_LENGTH + message.Length];
+            byte[] record = new byte[RecordHeaderLength + message.Length];
             TlsUtilities.WriteUint8(ContentType.handshake, record, 0);
             TlsUtilities.WriteVersion(ProtocolVersion.DTLSv10, record, 1);
             TlsUtilities.WriteUint16(0, record, 3);
             TlsUtilities.WriteUint48(recordSeq, record, 5);
             TlsUtilities.WriteUint16(message.Length, record, 11);
 
-            Array.Copy(message, 0, record, RECORD_HEADER_LENGTH, message.Length);
+            Array.Copy(message, 0, record, RecordHeaderLength, message.Length);
 
             SendDatagram(sender, record, 0, record.Length);
         }
@@ -124,8 +125,8 @@ namespace Org.BouncyCastle.Tls
 
             this.m_inHandshake = true;
 
-            this.m_currentEpoch = new DtlsEpoch(0, TlsNullNullCipher.Instance, RECORD_HEADER_LENGTH,
-                RECORD_HEADER_LENGTH);
+            this.m_currentEpoch = new DtlsEpoch(0, TlsNullNullCipher.Instance, RecordHeaderLength,
+                RecordHeaderLength);
             this.m_pendingEpoch = null;
             this.m_readEpoch = m_currentEpoch;
             this.m_writeEpoch = m_currentEpoch;
@@ -179,8 +180,8 @@ namespace Org.BouncyCastle.Tls
              */
 
             var securityParameters = m_context.SecurityParameters;
-            int recordHeaderLengthRead = RECORD_HEADER_LENGTH + (securityParameters.ConnectionIDPeer?.Length ?? 0);
-            int recordHeaderLengthWrite = RECORD_HEADER_LENGTH + (securityParameters.ConnectionIDLocal?.Length ?? 0);
+            int recordHeaderLengthRead = RecordHeaderLength + (securityParameters.ConnectionIDPeer?.Length ?? 0);
+            int recordHeaderLengthWrite = RecordHeaderLength + (securityParameters.ConnectionIDLocal?.Length ?? 0);
 
             // TODO Check for overflow
             this.m_pendingEpoch = new DtlsEpoch(m_writeEpoch.Epoch + 1, pendingCipher, recordHeaderLengthRead,
@@ -684,7 +685,7 @@ namespace Org.BouncyCastle.Tls
 #endif
         {
             // NOTE: received < 0 (timeout) is covered by this first case
-            if (received < RECORD_HEADER_LENGTH)
+            if (received < RecordHeaderLength)
                 return -1;
 
             // TODO[dtls13] Deal with opaque record type for 1.3 AEAD ciphers
@@ -729,7 +730,7 @@ namespace Org.BouncyCastle.Tls
 
 
             int recordHeaderLength = recordEpoch.RecordHeaderLengthRead;
-            if (recordHeaderLength > RECORD_HEADER_LENGTH)
+            if (recordHeaderLength > RecordHeaderLength)
             {
                 if (ContentType.tls12_cid != recordType)
                     return -1;
@@ -990,7 +991,7 @@ namespace Org.BouncyCastle.Tls
         {
             Debug.Assert(m_recordQueue.Available > 0);
 
-            int recordLength = RECORD_HEADER_LENGTH;
+            int recordLength = RecordHeaderLength;
             if (m_recordQueue.Available >= recordLength)
             {
                 short recordType = m_recordQueue.ReadUint8(0);
@@ -1033,7 +1034,7 @@ namespace Org.BouncyCastle.Tls
                 return ReceivePendingRecord(buf, off, len);
 
             int received = ReceiveDatagram(buf, off, len, waitMillis);
-            if (received >= RECORD_HEADER_LENGTH)
+            if (received >= RecordHeaderLength)
             {
                 this.m_inConnection = true;
 
@@ -1151,7 +1152,7 @@ namespace Org.BouncyCastle.Tls
                 TlsUtilities.WriteUint16(recordEpoch, encoded.buf, encoded.off + 3);
                 TlsUtilities.WriteUint48(recordSequenceNumber, encoded.buf, encoded.off + 5);
 
-                if (recordHeaderLength > RECORD_HEADER_LENGTH)
+                if (recordHeaderLength > RecordHeaderLength)
                 {
                     byte[] connectionID = m_context.SecurityParameters.ConnectionIDLocal;
                     Array.Copy(connectionID, 0, encoded.buf, encoded.off + 11, connectionID.Length);