summary refs log tree commit diff
path: root/crypto/src
diff options
context:
space:
mode:
authorPeter Dettman <peter.dettman@bouncycastle.org>2017-08-13 13:13:18 +0700
committerPeter Dettman <peter.dettman@bouncycastle.org>2017-08-13 13:13:18 +0700
commit3207ab5de93280623f50bb320577f8bbe3d38354 (patch)
treeac880c300e8e93e9099733f71b1b40e672efc33c /crypto/src
parentUse ffdhe2048 from RFC 7919 as TLS default DH group (diff)
downloadBouncyCastle.NET-ed25519-3207ab5de93280623f50bb320577f8bbe3d38354.tar.xz
Support receiving DTLS records containing multiple handshake messages
- see https://github.com/bcgit/bc-csharp/issues/85
Diffstat (limited to 'crypto/src')
-rw-r--r--crypto/src/crypto/tls/DtlsRecordLayer.cs5
-rw-r--r--crypto/src/crypto/tls/DtlsReliableHandshake.cs289
2 files changed, 141 insertions, 153 deletions
diff --git a/crypto/src/crypto/tls/DtlsRecordLayer.cs b/crypto/src/crypto/tls/DtlsRecordLayer.cs
index 3c3e1821f..39e018810 100644
--- a/crypto/src/crypto/tls/DtlsRecordLayer.cs
+++ b/crypto/src/crypto/tls/DtlsRecordLayer.cs
@@ -52,6 +52,11 @@ namespace Org.BouncyCastle.Crypto.Tls
             this.mPlaintextLimit = plaintextLimit;
         }
 
+        internal virtual int ReadEpoch
+        {
+            get { return mReadEpoch.Epoch; }
+        }
+
         internal virtual ProtocolVersion ReadVersion
         {
             get { return mReadVersion; }
diff --git a/crypto/src/crypto/tls/DtlsReliableHandshake.cs b/crypto/src/crypto/tls/DtlsReliableHandshake.cs
index 18a41769a..396ea7483 100644
--- a/crypto/src/crypto/tls/DtlsReliableHandshake.cs
+++ b/crypto/src/crypto/tls/DtlsReliableHandshake.cs
@@ -8,7 +8,8 @@ namespace Org.BouncyCastle.Crypto.Tls
 {
     internal class DtlsReliableHandshake
     {
-        private const int MAX_RECEIVE_AHEAD = 10;
+        private const int MaxReceiveAhead = 16;
+        private const int MessageHeaderLength = 12;
 
         private readonly DtlsRecordLayer mRecordLayer;
 
@@ -78,21 +79,7 @@ namespace Org.BouncyCastle.Crypto.Tls
             if (mSending)
             {
                 mSending = false;
-                PrepareInboundFlight();
-            }
-
-            // Check if we already have the next message waiting
-            {
-                DtlsReassembler next = (DtlsReassembler)mCurrentInboundFlight[mNextReceiveSeq];
-                if (next != null)
-                {
-                    byte[] body = next.GetBodyIfComplete();
-                    if (body != null)
-                    {
-                        mPreviousInboundFlight = null;
-                        return UpdateHandshakeMessagesDigest(new Message(mNextReceiveSeq++, next.MsgType, body));
-                    }
-                }
+                PrepareInboundFlight(Platform.CreateHashtable());
             }
 
             byte[] buf = null;
@@ -102,110 +89,38 @@ namespace Org.BouncyCastle.Crypto.Tls
 
             for (;;)
             {
-                int receiveLimit = mRecordLayer.GetReceiveLimit();
-                if (buf == null || buf.Length < receiveLimit)
-                {
-                    buf = new byte[receiveLimit];
-                }
-
-                // TODO Handle records containing multiple handshake messages
-
                 try
                 {
-                    for (; ; )
+                    for (;;)
                     {
+                        Message pending = GetPendingMessage();
+                        if (pending != null)
+                            return pending;
+
+                        int receiveLimit = mRecordLayer.GetReceiveLimit();
+                        if (buf == null || buf.Length < receiveLimit)
+                        {
+                            buf = new byte[receiveLimit];
+                        }
+
                         int received = mRecordLayer.Receive(buf, 0, receiveLimit, readTimeoutMillis);
                         if (received < 0)
-                        {
                             break;
-                        }
-                        if (received < 12)
-                        {
-                            continue;
-                        }
-                        int fragment_length = TlsUtilities.ReadUint24(buf, 9);
-                        if (received != (fragment_length + 12))
-                        {
-                            continue;
-                        }
-                        int seq = TlsUtilities.ReadUint16(buf, 4);
-                        if (seq > (mNextReceiveSeq + MAX_RECEIVE_AHEAD))
-                        {
-                            continue;
-                        }
-                        byte msg_type = TlsUtilities.ReadUint8(buf, 0);
-                        int length = TlsUtilities.ReadUint24(buf, 1);
-                        int fragment_offset = TlsUtilities.ReadUint24(buf, 6);
-                        if (fragment_offset + fragment_length > length)
-                        {
-                            continue;
-                        }
 
-                        if (seq < mNextReceiveSeq)
-                        {
-                            /*
-                             * NOTE: If we Receive the previous flight of incoming messages in full
-                             * again, retransmit our last flight
-                             */
-                            if (mPreviousInboundFlight != null)
-                            {
-                                DtlsReassembler reassembler = (DtlsReassembler)mPreviousInboundFlight[seq];
-                                if (reassembler != null)
-                                {
-                                    reassembler.ContributeFragment(msg_type, length, buf, 12, fragment_offset,
-                                        fragment_length);
-
-                                    if (CheckAll(mPreviousInboundFlight))
-                                    {
-                                        ResendOutboundFlight();
-
-                                        /*
-                                         * TODO[DTLS] implementations SHOULD back off handshake packet
-                                         * size during the retransmit backoff.
-                                         */
-                                        readTimeoutMillis = System.Math.Min(readTimeoutMillis * 2, 60000);
-
-                                        ResetAll(mPreviousInboundFlight);
-                                    }
-                                }
-                            }
-                        }
-                        else
+                        bool resentOutbound = ProcessRecord(MaxReceiveAhead, mRecordLayer.ReadEpoch, buf, 0, received);
+                        if (resentOutbound)
                         {
-                            DtlsReassembler reassembler = (DtlsReassembler)mCurrentInboundFlight[seq];
-                            if (reassembler == null)
-                            {
-                                reassembler = new DtlsReassembler(msg_type, length);
-                                mCurrentInboundFlight[seq] = reassembler;
-                            }
-
-                            reassembler.ContributeFragment(msg_type, length, buf, 12, fragment_offset, fragment_length);
-
-                            if (seq == mNextReceiveSeq)
-                            {
-                                byte[] body = reassembler.GetBodyIfComplete();
-                                if (body != null)
-                                {
-                                    mPreviousInboundFlight = null;
-                                    return UpdateHandshakeMessagesDigest(new Message(mNextReceiveSeq++,
-                                        reassembler.MsgType, body));
-                                }
-                            }
+                            readTimeoutMillis = BackOff(readTimeoutMillis);
                         }
                     }
                 }
-                catch (IOException)
+                catch (IOException e)
                 {
                     // NOTE: Assume this is a timeout for the moment
                 }
 
                 ResendOutboundFlight();
-
-                /*
-                 * TODO[DTLS] implementations SHOULD back off handshake packet size during the
-                 * retransmit backoff.
-                 */
-                readTimeoutMillis = System.Math.Min(readTimeoutMillis * 2, 60000);
+                readTimeoutMillis = BackOff(readTimeoutMillis);
             }
         }
 
@@ -216,15 +131,20 @@ namespace Org.BouncyCastle.Crypto.Tls
             {
                 CheckInboundFlight();
             }
-            else if (mCurrentInboundFlight != null)
+            else
             {
-                /*
-                 * RFC 6347 4.2.4. In addition, for at least twice the default MSL defined for [TCP],
-                 * when in the FINISHED state, the node that transmits the last flight (the server in an
-                 * ordinary handshake or the client in a resumed handshake) MUST respond to a retransmit
-                 * of the peer's last flight with a retransmit of the last flight.
-                 */
-                retransmit = new Retransmit(this);
+                PrepareInboundFlight(null);
+
+                if (mPreviousInboundFlight != null)
+                {
+                    /*
+                     * RFC 6347 4.2.4. In addition, for at least twice the default MSL defined for [TCP],
+                     * when in the FINISHED state, the node that transmits the last flight (the server in an
+                     * ordinary handshake or the client in a resumed handshake) MUST respond to a retransmit
+                     * of the peer's last flight with a retransmit of the last flight.
+                     */
+                    retransmit = new Retransmit(this);
+                }
             }
 
             mRecordLayer.HandshakeSuccessful(retransmit);
@@ -235,44 +155,13 @@ namespace Org.BouncyCastle.Crypto.Tls
             mHandshakeHash.Reset();
         }
 
-        private void HandleRetransmittedHandshakeRecord(int epoch, byte[] buf, int off, int len)
+        private int BackOff(int timeoutMillis)
         {
             /*
-             * TODO Need to handle the case where the previous inbound flight contains
-             * messages from two epochs.
+             * TODO[DTLS] implementations SHOULD back off handshake packet size during the
+             * retransmit backoff.
              */
-            if (len < 12)
-                return;
-            int fragment_length = TlsUtilities.ReadUint24(buf, off + 9);
-            if (len != (fragment_length + 12))
-                return;
-            int seq = TlsUtilities.ReadUint16(buf, off + 4);
-            if (seq >= mNextReceiveSeq)
-                return;
-
-            byte msg_type = TlsUtilities.ReadUint8(buf, off);
-
-            // TODO This is a hack that only works until we try to support renegotiation
-            int expectedEpoch = msg_type == HandshakeType.finished ? 1 : 0;
-            if (epoch != expectedEpoch)
-                return;
-
-            int length = TlsUtilities.ReadUint24(buf, off + 1);
-            int fragment_offset = TlsUtilities.ReadUint24(buf, off + 6);
-            if (fragment_offset + fragment_length > length)
-                return;
-
-            DtlsReassembler reassembler = (DtlsReassembler)mCurrentInboundFlight[seq];
-            if (reassembler != null)
-            {
-                reassembler.ContributeFragment(msg_type, length, buf, off + 12, fragment_offset,
-                    fragment_length);
-                if (CheckAll(mCurrentInboundFlight))
-                {
-                    ResendOutboundFlight();
-                    ResetAll(mCurrentInboundFlight);
-                }
-            }
+            return System.Math.Min(timeoutMillis * 2, 60000);
         }
 
         /**
@@ -289,11 +178,105 @@ namespace Org.BouncyCastle.Crypto.Tls
             }
         }
 
-        private void PrepareInboundFlight()
+        private Message GetPendingMessage()
+        {
+            DtlsReassembler next = (DtlsReassembler)mCurrentInboundFlight[mNextReceiveSeq];
+            if (next != null)
+            {
+                byte[] body = next.GetBodyIfComplete();
+                if (body != null)
+                {
+                    mPreviousInboundFlight = null;
+                    return UpdateHandshakeMessagesDigest(new Message(mNextReceiveSeq++, next.MsgType, body));
+                }
+            }
+            return null;
+        }
+
+        private void PrepareInboundFlight(IDictionary nextFlight)
         {
             ResetAll(mCurrentInboundFlight);
             mPreviousInboundFlight = mCurrentInboundFlight;
-            mCurrentInboundFlight = Platform.CreateHashtable();
+            mCurrentInboundFlight = nextFlight;
+        }
+
+        private bool ProcessRecord(int windowSize, int epoch, byte[] buf, int off, int len)
+        {
+            bool checkPreviousFlight = false;
+
+            while (len >= MessageHeaderLength)
+            {
+                int fragment_length = TlsUtilities.ReadUint24(buf, off + 9);
+                int message_length = fragment_length + MessageHeaderLength;
+                if (len < message_length)
+                {
+                    // NOTE: Truncated message - ignore it
+                    break;
+                }
+
+                int length = TlsUtilities.ReadUint24(buf, off + 1);
+                int fragment_offset = TlsUtilities.ReadUint24(buf, off + 6);
+                if (fragment_offset + fragment_length > length)
+                {
+                    // NOTE: Malformed fragment - ignore it and the rest of the record
+                    break;
+                }
+
+                /*
+                 * NOTE: This very simple epoch check will only work until we want to support
+                 * renegotiation (and we're not likely to do that anyway).
+                 */
+                byte msg_type = TlsUtilities.ReadUint8(buf, off + 0);
+                int expectedEpoch = msg_type == HandshakeType.finished ? 1 : 0;
+                if (epoch != expectedEpoch)
+                {
+                    break;
+                }
+
+                int message_seq = TlsUtilities.ReadUint16(buf, off + 4);
+                if (message_seq >= (mNextReceiveSeq + windowSize))
+                {
+                    // NOTE: Too far ahead - ignore
+                }
+                else if (message_seq >= mNextReceiveSeq)
+                {
+                    DtlsReassembler reassembler = (DtlsReassembler)mCurrentInboundFlight[message_seq];
+                    if (reassembler == null)
+                    {
+                        reassembler = new DtlsReassembler(msg_type, length);
+                        mCurrentInboundFlight[message_seq] = reassembler;
+                    }
+
+                    reassembler.ContributeFragment(msg_type, length, buf, off + MessageHeaderLength, fragment_offset,
+                        fragment_length);
+                }
+                else if (mPreviousInboundFlight != null)
+                {
+                    /*
+                     * NOTE: If we receive the previous flight of incoming messages in full again,
+                     * retransmit our last flight
+                     */
+
+                    DtlsReassembler reassembler = (DtlsReassembler)mPreviousInboundFlight[message_seq];
+                    if (reassembler != null)
+                    {
+                        reassembler.ContributeFragment(msg_type, length, buf, off + MessageHeaderLength, fragment_offset,
+                            fragment_length);
+                        checkPreviousFlight = true;
+                    }
+                }
+
+                off += message_length;
+                len -= message_length;
+            }
+
+            bool result = checkPreviousFlight && CheckAll(mPreviousInboundFlight);
+            if (result)
+            {
+                ResendOutboundFlight();
+                ResetAll(mPreviousInboundFlight);
+            }
+            return result;
         }
 
         private void ResendOutboundFlight()
@@ -310,7 +293,7 @@ namespace Org.BouncyCastle.Crypto.Tls
             if (message.Type != HandshakeType.hello_request)
             {
                 byte[] body = message.Body;
-                byte[] buf = new byte[12];
+                byte[] buf = new byte[MessageHeaderLength];
                 TlsUtilities.WriteUint8(message.Type, buf, 0);
                 TlsUtilities.WriteUint24(body.Length, buf, 1);
                 TlsUtilities.WriteUint16(message.Seq, buf, 4);
@@ -325,7 +308,7 @@ namespace Org.BouncyCastle.Crypto.Tls
         private void WriteMessage(Message message)
         {
             int sendLimit = mRecordLayer.GetSendLimit();
-            int fragmentLimit = sendLimit - 12;
+            int fragmentLimit = sendLimit - MessageHeaderLength;
 
             // TODO Support a higher minimum fragment size?
             if (fragmentLimit < 1)
@@ -349,7 +332,7 @@ namespace Org.BouncyCastle.Crypto.Tls
 
         private void WriteHandshakeFragment(Message message, int fragment_offset, int fragment_length)
         {
-            RecordLayerBuffer fragment = new RecordLayerBuffer(12 + fragment_length);
+            RecordLayerBuffer fragment = new RecordLayerBuffer(MessageHeaderLength + fragment_length);
             TlsUtilities.WriteUint8(message.Type, fragment);
             TlsUtilities.WriteUint24(message.Body.Length, fragment);
             TlsUtilities.WriteUint16(message.Seq, fragment);
@@ -444,7 +427,7 @@ namespace Org.BouncyCastle.Crypto.Tls
 
             public void ReceivedHandshakeRecord(int epoch, byte[] buf, int off, int len)
             {
-                mOuter.HandleRetransmittedHandshakeRecord(epoch, buf, off, len);
+                mOuter.ProcessRecord(0, epoch, buf, off, len);
             }
         }
     }