summary refs log tree commit diff
path: root/crypto
diff options
context:
space:
mode:
Diffstat (limited to 'crypto')
-rw-r--r--crypto/src/tls/DeferredHash.cs21
-rw-r--r--crypto/src/tls/DtlsReliableHandshake.cs185
-rw-r--r--crypto/src/tls/DtlsServerProtocol.cs13
-rw-r--r--crypto/src/tls/TlsClientProtocol.cs2
-rw-r--r--crypto/src/tls/TlsHandshakeHash.cs2
-rw-r--r--crypto/src/tls/TlsServerProtocol.cs4
6 files changed, 122 insertions, 105 deletions
diff --git a/crypto/src/tls/DeferredHash.cs b/crypto/src/tls/DeferredHash.cs
index bba3019a1..f97e9c088 100644
--- a/crypto/src/tls/DeferredHash.cs
+++ b/crypto/src/tls/DeferredHash.cs
@@ -16,7 +16,7 @@ namespace Org.BouncyCastle.Tls
         private readonly TlsContext m_context;
 
         private DigestInputBuffer m_buf;
-        private readonly IDictionary m_hashes;
+        private IDictionary m_hashes;
         private bool m_forceBuffering;
         private bool m_sealed;
 
@@ -29,21 +29,12 @@ namespace Org.BouncyCastle.Tls
             this.m_sealed = false;
         }
 
-        private DeferredHash(TlsContext context, IDictionary hashes)
-        {
-            this.m_context = context;
-            this.m_buf = null;
-            this.m_hashes = hashes;
-            this.m_forceBuffering = false;
-            this.m_sealed = true;
-        }
-
         /// <exception cref="IOException"/>
         public void CopyBufferTo(Stream output)
         {
             if (m_buf == null)
             {
-                // If you see this, you need to call forceBuffering() before SealHashAlgorithms()
+                // If you see this, you need to call ForceBuffering() before SealHashAlgorithms()
                 throw new InvalidOperationException("Not buffering");
             }
 
@@ -96,7 +87,7 @@ namespace Org.BouncyCastle.Tls
             CheckStopBuffering();
         }
 
-        public TlsHandshakeHash StopTracking()
+        public void StopTracking()
         {
             SecurityParameters securityParameters = m_context.SecurityParameters;
 
@@ -116,7 +107,11 @@ namespace Org.BouncyCastle.Tls
                 break;
             }
             }
-            return new DeferredHash(m_context, newHashes);
+
+            this.m_buf = null;
+            this.m_hashes = newHashes;
+            this.m_forceBuffering = false;
+            this.m_sealed = true;
         }
 
         public TlsHash ForkPrfHash()
diff --git a/crypto/src/tls/DtlsReliableHandshake.cs b/crypto/src/tls/DtlsReliableHandshake.cs
index e27d72762..58b9301fd 100644
--- a/crypto/src/tls/DtlsReliableHandshake.cs
+++ b/crypto/src/tls/DtlsReliableHandshake.cs
@@ -142,11 +142,9 @@ namespace Org.BouncyCastle.Tls
             get { return m_handshakeHash; }
         }
 
-        internal TlsHandshakeHash PrepareToFinish()
+        internal void PrepareToFinish()
         {
-            TlsHandshakeHash result = m_handshakeHash;
-            this.m_handshakeHash = m_handshakeHash.StopTracking();
-            return result;
+            m_handshakeHash.StopTracking();
         }
 
         /// <exception cref="IOException"/>
@@ -173,69 +171,63 @@ namespace Org.BouncyCastle.Tls
         }
 
         /// <exception cref="IOException"/>
+        internal Message ReceiveMessage()
+        {
+            Message message = ImplReceiveMessage();
+            UpdateHandshakeMessagesDigest(message);
+            return message;
+        }
+
+        /// <exception cref="IOException"/>
         internal byte[] ReceiveMessageBody(short msg_type)
         {
-            Message message = ReceiveMessage();
+            Message message = ImplReceiveMessage();
             if (message.Type != msg_type)
                 throw new TlsFatalAlert(AlertDescription.unexpected_message);
 
+            UpdateHandshakeMessagesDigest(message);
             return message.Body;
         }
 
         /// <exception cref="IOException"/>
-        internal Message ReceiveMessage()
+        internal Message ReceiveMessageDelayedDigest(short msg_type)
         {
-            long currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
-
-            if (null == m_resendTimeout)
-            {
-                m_resendMillis = INITIAL_RESEND_MILLIS;
-                m_resendTimeout = new Timeout(m_resendMillis, currentTimeMillis);
-
-                PrepareInboundFlight(Platform.CreateHashtable());
-            }
+            Message message = ImplReceiveMessage();
+            if (message.Type != msg_type)
+                throw new TlsFatalAlert(AlertDescription.unexpected_message);
 
-            byte[] buf = null;
+            return message;
+        }
 
-            for (;;)
+        /// <exception cref="IOException"/>
+        internal Message UpdateHandshakeMessagesDigest(Message message)
+        {
+            short msg_type = message.Type;
+            switch (msg_type)
             {
-                if (m_recordLayer.IsClosed)
-                    throw new TlsFatalAlert(AlertDescription.user_canceled);
-
-                Message pending = GetPendingMessage();
-                if (pending != null)
-                    return pending;
-
-                if (Timeout.HasExpired(m_handshakeTimeout, currentTimeMillis))
-                    throw new TlsTimeoutException("Handshake timed out");
-
-                int waitMillis = Timeout.GetWaitMillis(m_handshakeTimeout, currentTimeMillis);
-                waitMillis = Timeout.ConstrainWaitMillis(waitMillis, m_resendTimeout, currentTimeMillis);
-
-                // NOTE: Ensure a finite wait, of at least 1ms
-                if (waitMillis < 1)
-                {
-                    waitMillis = 1;
-                }
-
-                int receiveLimit = m_recordLayer.GetReceiveLimit();
-                if (buf == null || buf.Length < receiveLimit)
-                {
-                    buf = new byte[receiveLimit];
-                }
-
-                int received = m_recordLayer.Receive(buf, 0, receiveLimit, waitMillis);
-                if (received < 0)
-                {
-                    ResendOutboundFlight();
-                }
-                else
-                {
-                    ProcessRecord(MAX_RECEIVE_AHEAD, m_recordLayer.ReadEpoch, buf, 0, received);
-                }
+            case HandshakeType.hello_request:
+            case HandshakeType.hello_verify_request:
+            case HandshakeType.key_update:
+                break;
 
-                currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
+            // TODO[dtls13] Not included in the transcript for (D)TLS 1.3+
+            case HandshakeType.new_session_ticket:
+            default:
+            {
+                byte[] body = message.Body;
+                byte[] buf = new byte[MESSAGE_HEADER_LENGTH];
+                TlsUtilities.WriteUint8(msg_type, buf, 0);
+                TlsUtilities.WriteUint24(body.Length, buf, 1);
+                TlsUtilities.WriteUint16(message.Seq, buf, 4);
+                TlsUtilities.WriteUint24(0, buf, 6);
+                TlsUtilities.WriteUint24(body.Length, buf, 9);
+                m_handshakeHash.Update(buf, 0, buf.Length);
+                m_handshakeHash.Update(body, 0, body.Length);
+                break;
+            }
             }
+
+            return message;
         }
 
         internal void Finish()
@@ -297,12 +289,68 @@ namespace Org.BouncyCastle.Tls
                 if (body != null)
                 {
                     m_previousInboundFlight = null;
-                    return UpdateHandshakeMessagesDigest(new Message(m_next_receive_seq++, next.MsgType, body));
+                    return new Message(m_next_receive_seq++, next.MsgType, body);
                 }
             }
             return null;
         }
 
+        /// <exception cref="IOException"/>
+        private Message ImplReceiveMessage()
+        {
+            long currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
+
+            if (null == m_resendTimeout)
+            {
+                m_resendMillis = INITIAL_RESEND_MILLIS;
+                m_resendTimeout = new Timeout(m_resendMillis, currentTimeMillis);
+
+                PrepareInboundFlight(Platform.CreateHashtable());
+            }
+
+            byte[] buf = null;
+
+            for (; ; )
+            {
+                if (m_recordLayer.IsClosed)
+                    throw new TlsFatalAlert(AlertDescription.user_canceled);
+
+                Message pending = GetPendingMessage();
+                if (pending != null)
+                    return pending;
+
+                if (Timeout.HasExpired(m_handshakeTimeout, currentTimeMillis))
+                    throw new TlsTimeoutException("Handshake timed out");
+
+                int waitMillis = Timeout.GetWaitMillis(m_handshakeTimeout, currentTimeMillis);
+                waitMillis = Timeout.ConstrainWaitMillis(waitMillis, m_resendTimeout, currentTimeMillis);
+
+                // NOTE: Ensure a finite wait, of at least 1ms
+                if (waitMillis < 1)
+                {
+                    waitMillis = 1;
+                }
+
+                int receiveLimit = m_recordLayer.GetReceiveLimit();
+                if (buf == null || buf.Length < receiveLimit)
+                {
+                    buf = new byte[receiveLimit];
+                }
+
+                int received = m_recordLayer.Receive(buf, 0, receiveLimit, waitMillis);
+                if (received < 0)
+                {
+                    ResendOutboundFlight();
+                }
+                else
+                {
+                    ProcessRecord(MAX_RECEIVE_AHEAD, m_recordLayer.ReadEpoch, buf, 0, received);
+                }
+
+                currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
+            }
+        }
+
         private void PrepareInboundFlight(IDictionary nextFlight)
         {
             ResetAll(m_currentInboundFlight);
@@ -400,37 +448,6 @@ namespace Org.BouncyCastle.Tls
         }
 
         /// <exception cref="IOException"/>
-        private Message UpdateHandshakeMessagesDigest(Message message)
-        {
-            short msg_type = message.Type;
-            switch (msg_type)
-            {
-            case HandshakeType.hello_request:
-            case HandshakeType.hello_verify_request:
-            case HandshakeType.key_update:
-                break;
-
-            // TODO[dtls13] Not included in the transcript for (D)TLS 1.3+
-            case HandshakeType.new_session_ticket:
-            default:
-            {
-                byte[] body = message.Body;
-                byte[] buf = new byte[MESSAGE_HEADER_LENGTH];
-                TlsUtilities.WriteUint8(msg_type, buf, 0);
-                TlsUtilities.WriteUint24(body.Length, buf, 1);
-                TlsUtilities.WriteUint16(message.Seq, buf, 4);
-                TlsUtilities.WriteUint24(0, buf, 6);
-                TlsUtilities.WriteUint24(body.Length, buf, 9);
-                m_handshakeHash.Update(buf, 0, buf.Length);
-                m_handshakeHash.Update(body, 0, body.Length);
-                break;
-            }
-            }
-
-            return message;
-        }
-
-        /// <exception cref="IOException"/>
         private void WriteMessage(Message message)
         {
             int sendLimit = m_recordLayer.GetSendLimit();
diff --git a/crypto/src/tls/DtlsServerProtocol.cs b/crypto/src/tls/DtlsServerProtocol.cs
index 99c47ba1b..b49122423 100644
--- a/crypto/src/tls/DtlsServerProtocol.cs
+++ b/crypto/src/tls/DtlsServerProtocol.cs
@@ -297,12 +297,17 @@ namespace Org.BouncyCastle.Tls
              * parameters).
              */
             {
-                TlsHandshakeHash certificateVerifyHash = handshake.PrepareToFinish();
-
                 if (ExpectCertificateVerifyMessage(state))
                 {
-                    byte[] certificateVerifyBody = handshake.ReceiveMessageBody(HandshakeType.certificate_verify);
-                    ProcessCertificateVerify(state, certificateVerifyBody, certificateVerifyHash);
+                    clientMessage = handshake.ReceiveMessageDelayedDigest(HandshakeType.certificate_verify);
+                    byte[] certificateVerifyBody = clientMessage.Body;
+                    ProcessCertificateVerify(state, certificateVerifyBody, handshake.HandshakeHash);
+                    handshake.PrepareToFinish();
+                    handshake.UpdateHandshakeMessagesDigest(clientMessage);
+                }
+                else
+                {
+                    handshake.PrepareToFinish();
                 }
             }
 
diff --git a/crypto/src/tls/TlsClientProtocol.cs b/crypto/src/tls/TlsClientProtocol.cs
index 19e2eda3d..cb59289ae 100644
--- a/crypto/src/tls/TlsClientProtocol.cs
+++ b/crypto/src/tls/TlsClientProtocol.cs
@@ -609,7 +609,7 @@ namespace Org.BouncyCastle.Tls
                         this.m_connectionState = CS_CLIENT_CERTIFICATE_VERIFY;
                     }
 
-                    this.m_handshakeHash = m_handshakeHash.StopTracking();
+                    m_handshakeHash.StopTracking();
 
                     SendChangeCipherSpec();
                     SendFinishedMessage();
diff --git a/crypto/src/tls/TlsHandshakeHash.cs b/crypto/src/tls/TlsHandshakeHash.cs
index aa33c680d..88aeaaa32 100644
--- a/crypto/src/tls/TlsHandshakeHash.cs
+++ b/crypto/src/tls/TlsHandshakeHash.cs
@@ -20,7 +20,7 @@ namespace Org.BouncyCastle.Tls
 
         void SealHashAlgorithms();
 
-        TlsHandshakeHash StopTracking();
+        void StopTracking();
 
         TlsHash ForkPrfHash();
 
diff --git a/crypto/src/tls/TlsServerProtocol.cs b/crypto/src/tls/TlsServerProtocol.cs
index 22700a277..0ab8a7a98 100644
--- a/crypto/src/tls/TlsServerProtocol.cs
+++ b/crypto/src/tls/TlsServerProtocol.cs
@@ -1322,7 +1322,7 @@ namespace Org.BouncyCastle.Tls
             TlsUtilities.VerifyCertificateVerifyClient(m_tlsServerContext, m_certificateRequest, certificateVerify,
                 m_handshakeHash);
 
-            this.m_handshakeHash = m_handshakeHash.StopTracking();
+            m_handshakeHash.StopTracking();
         }
 
         /// <exception cref="IOException"/>
@@ -1357,7 +1357,7 @@ namespace Org.BouncyCastle.Tls
 
             if (!ExpectCertificateVerifyMessage())
             {
-                this.m_handshakeHash = m_handshakeHash.StopTracking();
+                m_handshakeHash.StopTracking();
             }
         }