summary refs log tree commit diff
path: root/crypto/src/tls/DtlsReliableHandshake.cs
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src/tls/DtlsReliableHandshake.cs')
-rw-r--r--crypto/src/tls/DtlsReliableHandshake.cs185
1 files changed, 101 insertions, 84 deletions
diff --git a/crypto/src/tls/DtlsReliableHandshake.cs b/crypto/src/tls/DtlsReliableHandshake.cs
index e27d72762..58b9301fd 100644
--- a/crypto/src/tls/DtlsReliableHandshake.cs
+++ b/crypto/src/tls/DtlsReliableHandshake.cs
@@ -142,11 +142,9 @@ namespace Org.BouncyCastle.Tls
             get { return m_handshakeHash; }
         }
 
-        internal TlsHandshakeHash PrepareToFinish()
+        internal void PrepareToFinish()
         {
-            TlsHandshakeHash result = m_handshakeHash;
-            this.m_handshakeHash = m_handshakeHash.StopTracking();
-            return result;
+            m_handshakeHash.StopTracking();
         }
 
         /// <exception cref="IOException"/>
@@ -173,69 +171,63 @@ namespace Org.BouncyCastle.Tls
         }
 
         /// <exception cref="IOException"/>
+        internal Message ReceiveMessage()
+        {
+            Message message = ImplReceiveMessage();
+            UpdateHandshakeMessagesDigest(message);
+            return message;
+        }
+
+        /// <exception cref="IOException"/>
         internal byte[] ReceiveMessageBody(short msg_type)
         {
-            Message message = ReceiveMessage();
+            Message message = ImplReceiveMessage();
             if (message.Type != msg_type)
                 throw new TlsFatalAlert(AlertDescription.unexpected_message);
 
+            UpdateHandshakeMessagesDigest(message);
             return message.Body;
         }
 
         /// <exception cref="IOException"/>
-        internal Message ReceiveMessage()
+        internal Message ReceiveMessageDelayedDigest(short msg_type)
         {
-            long currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
-
-            if (null == m_resendTimeout)
-            {
-                m_resendMillis = INITIAL_RESEND_MILLIS;
-                m_resendTimeout = new Timeout(m_resendMillis, currentTimeMillis);
-
-                PrepareInboundFlight(Platform.CreateHashtable());
-            }
+            Message message = ImplReceiveMessage();
+            if (message.Type != msg_type)
+                throw new TlsFatalAlert(AlertDescription.unexpected_message);
 
-            byte[] buf = null;
+            return message;
+        }
 
-            for (;;)
+        /// <exception cref="IOException"/>
+        internal Message UpdateHandshakeMessagesDigest(Message message)
+        {
+            short msg_type = message.Type;
+            switch (msg_type)
             {
-                if (m_recordLayer.IsClosed)
-                    throw new TlsFatalAlert(AlertDescription.user_canceled);
-
-                Message pending = GetPendingMessage();
-                if (pending != null)
-                    return pending;
-
-                if (Timeout.HasExpired(m_handshakeTimeout, currentTimeMillis))
-                    throw new TlsTimeoutException("Handshake timed out");
-
-                int waitMillis = Timeout.GetWaitMillis(m_handshakeTimeout, currentTimeMillis);
-                waitMillis = Timeout.ConstrainWaitMillis(waitMillis, m_resendTimeout, currentTimeMillis);
-
-                // NOTE: Ensure a finite wait, of at least 1ms
-                if (waitMillis < 1)
-                {
-                    waitMillis = 1;
-                }
-
-                int receiveLimit = m_recordLayer.GetReceiveLimit();
-                if (buf == null || buf.Length < receiveLimit)
-                {
-                    buf = new byte[receiveLimit];
-                }
-
-                int received = m_recordLayer.Receive(buf, 0, receiveLimit, waitMillis);
-                if (received < 0)
-                {
-                    ResendOutboundFlight();
-                }
-                else
-                {
-                    ProcessRecord(MAX_RECEIVE_AHEAD, m_recordLayer.ReadEpoch, buf, 0, received);
-                }
+            case HandshakeType.hello_request:
+            case HandshakeType.hello_verify_request:
+            case HandshakeType.key_update:
+                break;
 
-                currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
+            // TODO[dtls13] Not included in the transcript for (D)TLS 1.3+
+            case HandshakeType.new_session_ticket:
+            default:
+            {
+                byte[] body = message.Body;
+                byte[] buf = new byte[MESSAGE_HEADER_LENGTH];
+                TlsUtilities.WriteUint8(msg_type, buf, 0);
+                TlsUtilities.WriteUint24(body.Length, buf, 1);
+                TlsUtilities.WriteUint16(message.Seq, buf, 4);
+                TlsUtilities.WriteUint24(0, buf, 6);
+                TlsUtilities.WriteUint24(body.Length, buf, 9);
+                m_handshakeHash.Update(buf, 0, buf.Length);
+                m_handshakeHash.Update(body, 0, body.Length);
+                break;
+            }
             }
+
+            return message;
         }
 
         internal void Finish()
@@ -297,12 +289,68 @@ namespace Org.BouncyCastle.Tls
                 if (body != null)
                 {
                     m_previousInboundFlight = null;
-                    return UpdateHandshakeMessagesDigest(new Message(m_next_receive_seq++, next.MsgType, body));
+                    return new Message(m_next_receive_seq++, next.MsgType, body);
                 }
             }
             return null;
         }
 
+        /// <exception cref="IOException"/>
+        private Message ImplReceiveMessage()
+        {
+            long currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
+
+            if (null == m_resendTimeout)
+            {
+                m_resendMillis = INITIAL_RESEND_MILLIS;
+                m_resendTimeout = new Timeout(m_resendMillis, currentTimeMillis);
+
+                PrepareInboundFlight(Platform.CreateHashtable());
+            }
+
+            byte[] buf = null;
+
+            for (; ; )
+            {
+                if (m_recordLayer.IsClosed)
+                    throw new TlsFatalAlert(AlertDescription.user_canceled);
+
+                Message pending = GetPendingMessage();
+                if (pending != null)
+                    return pending;
+
+                if (Timeout.HasExpired(m_handshakeTimeout, currentTimeMillis))
+                    throw new TlsTimeoutException("Handshake timed out");
+
+                int waitMillis = Timeout.GetWaitMillis(m_handshakeTimeout, currentTimeMillis);
+                waitMillis = Timeout.ConstrainWaitMillis(waitMillis, m_resendTimeout, currentTimeMillis);
+
+                // NOTE: Ensure a finite wait, of at least 1ms
+                if (waitMillis < 1)
+                {
+                    waitMillis = 1;
+                }
+
+                int receiveLimit = m_recordLayer.GetReceiveLimit();
+                if (buf == null || buf.Length < receiveLimit)
+                {
+                    buf = new byte[receiveLimit];
+                }
+
+                int received = m_recordLayer.Receive(buf, 0, receiveLimit, waitMillis);
+                if (received < 0)
+                {
+                    ResendOutboundFlight();
+                }
+                else
+                {
+                    ProcessRecord(MAX_RECEIVE_AHEAD, m_recordLayer.ReadEpoch, buf, 0, received);
+                }
+
+                currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
+            }
+        }
+
         private void PrepareInboundFlight(IDictionary nextFlight)
         {
             ResetAll(m_currentInboundFlight);
@@ -400,37 +448,6 @@ namespace Org.BouncyCastle.Tls
         }
 
         /// <exception cref="IOException"/>
-        private Message UpdateHandshakeMessagesDigest(Message message)
-        {
-            short msg_type = message.Type;
-            switch (msg_type)
-            {
-            case HandshakeType.hello_request:
-            case HandshakeType.hello_verify_request:
-            case HandshakeType.key_update:
-                break;
-
-            // TODO[dtls13] Not included in the transcript for (D)TLS 1.3+
-            case HandshakeType.new_session_ticket:
-            default:
-            {
-                byte[] body = message.Body;
-                byte[] buf = new byte[MESSAGE_HEADER_LENGTH];
-                TlsUtilities.WriteUint8(msg_type, buf, 0);
-                TlsUtilities.WriteUint24(body.Length, buf, 1);
-                TlsUtilities.WriteUint16(message.Seq, buf, 4);
-                TlsUtilities.WriteUint24(0, buf, 6);
-                TlsUtilities.WriteUint24(body.Length, buf, 9);
-                m_handshakeHash.Update(buf, 0, buf.Length);
-                m_handshakeHash.Update(body, 0, body.Length);
-                break;
-            }
-            }
-
-            return message;
-        }
-
-        /// <exception cref="IOException"/>
         private void WriteMessage(Message message)
         {
             int sendLimit = m_recordLayer.GetSendLimit();