summary refs log tree commit diff
diff options
context:
space:
mode:
authorPeter Dettman <peter.dettman@bouncycastle.org>2020-07-30 02:24:54 +0700
committerPeter Dettman <peter.dettman@bouncycastle.org>2020-07-30 02:24:54 +0700
commitd7b5df9df2099487c62342a9bfbc30e40711788b (patch)
treec188933e1cb1c8d3dbced0c266f5edb285debc23
parentDTLS: Exceptions properly abort handshake (diff)
downloadBouncyCastle.NET-ed25519-d7b5df9df2099487c62342a9bfbc30e40711788b.tar.xz
DTLS: Improved retransmission timer
-rw-r--r--crypto/src/crypto/tls/DtlsRecordLayer.cs369
-rw-r--r--crypto/src/crypto/tls/DtlsReliableHandshake.cs53
2 files changed, 222 insertions, 200 deletions
diff --git a/crypto/src/crypto/tls/DtlsRecordLayer.cs b/crypto/src/crypto/tls/DtlsRecordLayer.cs
index 266893df0..c1a26b14f 100644
--- a/crypto/src/crypto/tls/DtlsRecordLayer.cs
+++ b/crypto/src/crypto/tls/DtlsRecordLayer.cs
@@ -45,7 +45,7 @@ namespace Org.BouncyCastle.Crypto.Tls
 
         private DtlsHandshakeRetransmit mRetransmit = null;
         private DtlsEpoch mRetransmitEpoch = null;
-        private long mRetransmitExpiry = 0;
+        private Timeout mRetransmitTimeout = null;
 
         internal DtlsRecordLayer(DatagramTransport transport, TlsContext context, TlsPeer peer, byte contentType)
         {
@@ -116,7 +116,7 @@ namespace Org.BouncyCastle.Crypto.Tls
             {
                 this.mRetransmit = retransmit;
                 this.mRetransmitEpoch = mCurrentEpoch;
-                this.mRetransmitExpiry = DateTimeUtilities.CurrentUnixMs() + RETRANSMIT_TIMEOUT;
+                this.mRetransmitTimeout = new Timeout(RETRANSMIT_TIMEOUT);
             }
 
             this.mInHandshake = false;
@@ -150,196 +150,43 @@ namespace Org.BouncyCastle.Crypto.Tls
 
         public virtual int Receive(byte[] buf, int off, int len, int waitMillis)
         {
-            // TODO Avoid returning -1 (timeout) until 'waitMillis' has definitely elapsed  
+            long currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
+
+            Timeout timeout = null;
+            if (waitMillis > 0)
+            {
+                timeout = new Timeout(waitMillis, currentTimeMillis);
+            }
 
             byte[] record = null;
 
-            for (;;)
+            while (waitMillis >= 0)
             {
-                int receiveLimit = System.Math.Min(len, GetReceiveLimit()) + RECORD_HEADER_LENGTH;
-                if (record == null || record.Length < receiveLimit)
-                {
-                    record = new byte[receiveLimit];
-                }
-
-                if (mRetransmit != null && DateTimeUtilities.CurrentUnixMs() > mRetransmitExpiry)
+                if (mRetransmitTimeout != null && mRetransmitTimeout.RemainingMillis(currentTimeMillis) < 1)
                 {
                     mRetransmit = null;
                     mRetransmitEpoch = null;
+                    mRetransmitTimeout = null;
                 }
 
-                int received = ReceiveRecord(record, 0, receiveLimit, waitMillis);
-                if (received < 0)
-                {
-                    return received;
-                }
-                if (received < RECORD_HEADER_LENGTH)
-                {
-                    continue;
-                }
-                int length = TlsUtilities.ReadUint16(record, 11);
-                if (received != (length + RECORD_HEADER_LENGTH))
-                {
-                    continue;
-                }
-
-                byte type = TlsUtilities.ReadUint8(record, 0);
-
-                // TODO Support user-specified custom protocols?
-                switch (type)
-                {
-                case ContentType.alert:
-                case ContentType.application_data:
-                case ContentType.change_cipher_spec:
-                case ContentType.handshake:
-                case ContentType.heartbeat:
-                    break;
-                default:
-                    // TODO Exception?
-                    continue;
-                }
-
-                int epoch = TlsUtilities.ReadUint16(record, 3);
-
-                DtlsEpoch recordEpoch = null;
-                if (epoch == mReadEpoch.Epoch)
-                {
-                    recordEpoch = mReadEpoch;
-                }
-                else if (type == ContentType.handshake && mRetransmitEpoch != null
-                    && epoch == mRetransmitEpoch.Epoch)
-                {
-                    recordEpoch = mRetransmitEpoch;
-                }
-
-                if (recordEpoch == null)
-                {
-                    continue;
-                }
-
-                long seq = TlsUtilities.ReadUint48(record, 5);
-                if (recordEpoch.ReplayWindow.ShouldDiscard(seq))
-                {
-                    continue;
-                }
-
-                ProtocolVersion version = TlsUtilities.ReadVersion(record, 1);
-                if (!version.IsDtls)
-                {
-                    continue;
-                }
-
-                if (mReadVersion != null && !mReadVersion.Equals(version))
-                {
-                    continue;
-                }
-
-                byte[] plaintext = recordEpoch.Cipher.DecodeCiphertext(
-                    GetMacSequenceNumber(recordEpoch.Epoch, seq), type, record, RECORD_HEADER_LENGTH,
-                    received - RECORD_HEADER_LENGTH);
-
-                recordEpoch.ReplayWindow.ReportAuthenticated(seq);
-
-                if (plaintext.Length > this.mPlaintextLimit)
-                {
-                    continue;
-                }
-
-                if (mReadVersion == null)
-                {
-                    mReadVersion = version;
-                }
-
-                switch (type)
-                {
-                case ContentType.alert:
-                {
-                    if (plaintext.Length == 2)
-                    {
-                        byte alertLevel = plaintext[0];
-                        byte alertDescription = plaintext[1];
-
-                        mPeer.NotifyAlertReceived(alertLevel, alertDescription);
-
-                        if (alertLevel == AlertLevel.fatal)
-                        {
-                            Failed();
-                            throw new TlsFatalAlert(alertDescription);
-                        }
-
-                        // TODO Can close_notify be a fatal alert?
-                        if (alertDescription == AlertDescription.close_notify)
-                        {
-                            CloseTransport();
-                        }
-                    }
-
-                    continue;
-                }
-                case ContentType.application_data:
-                {
-                    if (mInHandshake)
-                    {
-                        // TODO Consider buffering application data for new epoch that arrives
-                        // out-of-order with the Finished message
-                        continue;
-                    }
-                    break;
-                }
-                case ContentType.change_cipher_spec:
-                {
-                    // Implicitly receive change_cipher_spec and change to pending cipher state
-
-                    for (int i = 0; i < plaintext.Length; ++i)
-                    {
-                        byte message = TlsUtilities.ReadUint8(plaintext, i);
-                        if (message != ChangeCipherSpec.change_cipher_spec)
-                        {
-                            continue;
-                        }
-
-                        if (mPendingEpoch != null)
-                        {
-                            mReadEpoch = mPendingEpoch;
-                        }
-                    }
-
-                    continue;
-                }
-                case ContentType.handshake:
-                {
-                    if (!mInHandshake)
-                    {
-                        if (mRetransmit != null)
-                        {
-                            mRetransmit.ReceivedHandshakeRecord(epoch, plaintext, 0, plaintext.Length);
-                        }
-
-                        // TODO Consider support for HelloRequest
-                        continue;
-                    }
-                    break;
-                }
-                case ContentType.heartbeat:
+                int receiveLimit = System.Math.Min(len, GetReceiveLimit()) + RECORD_HEADER_LENGTH;
+                if (record == null || record.Length < receiveLimit)
                 {
-                    // TODO[RFC 6520]
-                    continue;
-                }
+                    record = new byte[receiveLimit];
                 }
 
-                /*
-                    * NOTE: If we receive any non-handshake data in the new epoch implies the peer has
-                    * received our final flight.
-                    */
-                if (!mInHandshake && mRetransmit != null)
+                int received = ReceiveRecord(record, 0, receiveLimit, waitMillis);
+                int processed = ProcessRecord(received, record, buf, off);
+                if (processed >= 0)
                 {
-                    this.mRetransmit = null;
-                    this.mRetransmitEpoch = null;
+                    return processed;
                 }
 
-                Array.Copy(plaintext, 0, buf, off, plaintext.Length);
-                return plaintext.Length;
+                currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
+                waitMillis = Timeout.GetWaitMillis(timeout, currentTimeMillis);
             }
+
+            return -1;
         }
 
         /// <exception cref="IOException"/>
@@ -497,6 +344,176 @@ namespace Org.BouncyCastle.Crypto.Tls
             }
         }
 
+        private int ProcessRecord(int received, byte[] record, byte[] buf, int off)
+        {
+            // NOTE: received < 0 (timeout) is covered by this first case
+            if (received < RECORD_HEADER_LENGTH)
+            {
+                return -1;
+            }
+            int length = TlsUtilities.ReadUint16(record, 11);
+            if (received != (length + RECORD_HEADER_LENGTH))
+            {
+                return -1;
+            }
+
+            byte type = TlsUtilities.ReadUint8(record, 0);
+
+            switch (type)
+            {
+            case ContentType.alert:
+            case ContentType.application_data:
+            case ContentType.change_cipher_spec:
+            case ContentType.handshake:
+            case ContentType.heartbeat:
+                break;
+            default:
+                return -1;
+            }
+
+            int epoch = TlsUtilities.ReadUint16(record, 3);
+
+            DtlsEpoch recordEpoch = null;
+            if (epoch == mReadEpoch.Epoch)
+            {
+                recordEpoch = mReadEpoch;
+            }
+            else if (type == ContentType.handshake && mRetransmitEpoch != null
+                && epoch == mRetransmitEpoch.Epoch)
+            {
+                recordEpoch = mRetransmitEpoch;
+            }
+
+            if (recordEpoch == null)
+            {
+                return -1;
+            }
+
+            long seq = TlsUtilities.ReadUint48(record, 5);
+            if (recordEpoch.ReplayWindow.ShouldDiscard(seq))
+            {
+                return -1;
+            }
+
+            ProtocolVersion version = TlsUtilities.ReadVersion(record, 1);
+            if (!version.IsDtls)
+            {
+                return -1;
+            }
+
+            if (mReadVersion != null && !mReadVersion.Equals(version))
+            {
+                return -1;
+            }
+
+            byte[] plaintext = recordEpoch.Cipher.DecodeCiphertext(
+                GetMacSequenceNumber(recordEpoch.Epoch, seq), type, record, RECORD_HEADER_LENGTH,
+                received - RECORD_HEADER_LENGTH);
+
+            recordEpoch.ReplayWindow.ReportAuthenticated(seq);
+
+            if (plaintext.Length > this.mPlaintextLimit)
+            {
+                return -1;
+            }
+
+            if (mReadVersion == null)
+            {
+                mReadVersion = version;
+            }
+
+            switch (type)
+            {
+            case ContentType.alert:
+            {
+                if (plaintext.Length == 2)
+                {
+                    byte alertLevel = plaintext[0];
+                    byte alertDescription = plaintext[1];
+
+                    mPeer.NotifyAlertReceived(alertLevel, alertDescription);
+
+                    if (alertLevel == AlertLevel.fatal)
+                    {
+                        Failed();
+                        throw new TlsFatalAlert(alertDescription);
+                    }
+
+                    // TODO Can close_notify be a fatal alert?
+                    if (alertDescription == AlertDescription.close_notify)
+                    {
+                        CloseTransport();
+                    }
+                }
+
+                return -1;
+            }
+            case ContentType.application_data:
+            {
+                if (mInHandshake)
+                {
+                    // TODO Consider buffering application data for new epoch that arrives
+                    // out-of-order with the Finished message
+                    return -1;
+                }
+                break;
+            }
+            case ContentType.change_cipher_spec:
+            {
+                // Implicitly receive change_cipher_spec and change to pending cipher state
+
+                for (int i = 0; i < plaintext.Length; ++i)
+                {
+                    byte message = TlsUtilities.ReadUint8(plaintext, i);
+                    if (message != ChangeCipherSpec.change_cipher_spec)
+                    {
+                        continue;
+                    }
+
+                    if (mPendingEpoch != null)
+                    {
+                        mReadEpoch = mPendingEpoch;
+                    }
+                }
+
+                return -1;
+            }
+            case ContentType.handshake:
+            {
+                if (!mInHandshake)
+                {
+                    if (mRetransmit != null)
+                    {
+                        mRetransmit.ReceivedHandshakeRecord(epoch, plaintext, 0, plaintext.Length);
+                    }
+
+                    // TODO Consider support for HelloRequest
+                    return -1;
+                }
+                break;
+            }
+            case ContentType.heartbeat:
+            {
+                // TODO[RFC 6520]
+                return -1;
+            }
+            }
+
+            /*
+             * NOTE: If we receive any non-handshake data in the new epoch implies the peer has
+             * received our final flight.
+             */
+            if (!mInHandshake && mRetransmit != null)
+            {
+                this.mRetransmit = null;
+                this.mRetransmitEpoch = null;
+                this.mRetransmitTimeout = null;
+            }
+
+            Array.Copy(plaintext, 0, buf, off, plaintext.Length);
+            return plaintext.Length;
+        }
+
         private int ReceiveRecord(byte[] buf, int off, int len, int waitMillis)
         {
             if (mRecordQueue.Available > 0)
diff --git a/crypto/src/crypto/tls/DtlsReliableHandshake.cs b/crypto/src/crypto/tls/DtlsReliableHandshake.cs
index 92c222e70..3eeb8a61e 100644
--- a/crypto/src/crypto/tls/DtlsReliableHandshake.cs
+++ b/crypto/src/crypto/tls/DtlsReliableHandshake.cs
@@ -3,6 +3,7 @@ using System.Collections;
 using System.IO;
 
 using Org.BouncyCastle.Utilities;
+using Org.BouncyCastle.Utilities.Date;
 
 namespace Org.BouncyCastle.Crypto.Tls
 {
@@ -11,6 +12,9 @@ namespace Org.BouncyCastle.Crypto.Tls
         private const int MaxReceiveAhead = 16;
         private const int MessageHeaderLength = 12;
 
+        private const int InitialResendMillis = 1000;
+        private const int MaxResendMillis = 60000;
+
         private readonly DtlsRecordLayer mRecordLayer;
 
         private TlsHandshakeHash mHandshakeHash;
@@ -18,7 +22,9 @@ namespace Org.BouncyCastle.Crypto.Tls
         private IDictionary mCurrentInboundFlight = Platform.CreateHashtable();
         private IDictionary mPreviousInboundFlight = null;
         private IList mOutboundFlight = Platform.CreateArrayList();
-        private bool mSending = true;
+
+        private int mResendMillis = -1;
+        private Timeout mResendTimeout = null;
 
         private int mMessageSeq = 0, mNextReceiveSeq = 0;
 
@@ -50,10 +56,13 @@ namespace Org.BouncyCastle.Crypto.Tls
         {
             TlsUtilities.CheckUint24(body.Length);
 
-            if (!mSending)
+            if (mResendTimeout != null)
             {
                 CheckInboundFlight();
-                mSending = true;
+
+                mResendMillis = -1;
+                mResendTimeout = null;
+
                 mOutboundFlight.Clear();
             }
 
@@ -77,18 +86,18 @@ namespace Org.BouncyCastle.Crypto.Tls
         internal Message ReceiveMessage()
         {
             // TODO Add support for "overall" handshake timeout
+            long currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
 
-            if (mSending)
+            if (mResendTimeout == null)
             {
-                mSending = false;
+                mResendMillis = InitialResendMillis;
+                mResendTimeout = new Timeout(mResendMillis, currentTimeMillis);
+
                 PrepareInboundFlight(Platform.CreateHashtable());
             }
 
             byte[] buf = null;
 
-            // TODO Check the conditions under which we should reset this
-            int readTimeoutMillis = 1000;
-
             for (;;)
             {
                 if (mRecordLayer.IsClosed)
@@ -98,37 +107,32 @@ namespace Org.BouncyCastle.Crypto.Tls
                 if (pending != null)
                     return pending;
 
+                int waitMillis = System.Math.Max(1, Timeout.GetWaitMillis(mResendTimeout, currentTimeMillis));
+
                 int receiveLimit = mRecordLayer.GetReceiveLimit();
                 if (buf == null || buf.Length < receiveLimit)
                 {
                     buf = new byte[receiveLimit];
                 }
 
-                int received = mRecordLayer.Receive(buf, 0, receiveLimit, readTimeoutMillis);
-
-                bool resentOutbound;
+                int received = mRecordLayer.Receive(buf, 0, receiveLimit, waitMillis);
                 if (received < 0)
                 {
                     ResendOutboundFlight();
-                    resentOutbound = true;
                 }
                 else
                 {
-                    resentOutbound = ProcessRecord(MaxReceiveAhead, mRecordLayer.ReadEpoch, buf, 0, received);
+                    ProcessRecord(MaxReceiveAhead, mRecordLayer.ReadEpoch, buf, 0, received);
                 }
 
-                // TODO Review conditions for resend/backoff  
-                if (resentOutbound)
-                {
-                    readTimeoutMillis = BackOff(readTimeoutMillis);
-                }
+                currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
             }
         }
 
         internal void Finish()
         {
             DtlsHandshakeRetransmit retransmit = null;
-            if (!mSending)
+            if (mResendTimeout != null)
             {
                 CheckInboundFlight();
             }
@@ -162,7 +166,7 @@ namespace Org.BouncyCastle.Crypto.Tls
              * TODO[DTLS] implementations SHOULD back off handshake packet size during the
              * retransmit backoff.
              */
-            return System.Math.Min(timeoutMillis * 2, 60000);
+            return System.Math.Min(timeoutMillis * 2, MaxResendMillis);
         }
 
         /**
@@ -201,7 +205,7 @@ namespace Org.BouncyCastle.Crypto.Tls
             mCurrentInboundFlight = nextFlight;
         }
 
-        private bool ProcessRecord(int windowSize, int epoch, byte[] buf, int off, int len)
+        private void ProcessRecord(int windowSize, int epoch, byte[] buf, int off, int len)
         {
             bool checkPreviousFlight = false;
 
@@ -271,13 +275,11 @@ namespace Org.BouncyCastle.Crypto.Tls
                 len -= message_length;
             }
 
-            bool result = checkPreviousFlight && CheckAll(mPreviousInboundFlight);
-            if (result)
+            if (checkPreviousFlight && CheckAll(mPreviousInboundFlight))
             {
                 ResendOutboundFlight();
                 ResetAll(mPreviousInboundFlight);
             }
-            return result;
         }
 
         private void ResendOutboundFlight()
@@ -287,6 +289,9 @@ namespace Org.BouncyCastle.Crypto.Tls
             {
                 WriteMessage((Message)mOutboundFlight[i]);
             }
+
+            mResendMillis = BackOff(mResendMillis);
+            mResendTimeout = new Timeout(mResendMillis);
         }
 
         private Message UpdateHandshakeMessagesDigest(Message message)