summary refs log tree commit diff
path: root/crypto/src/tls/DeferredHash.cs
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src/tls/DeferredHash.cs')
-rw-r--r--crypto/src/tls/DeferredHash.cs249
1 files changed, 249 insertions, 0 deletions
diff --git a/crypto/src/tls/DeferredHash.cs b/crypto/src/tls/DeferredHash.cs
new file mode 100644
index 000000000..43d60d07c
--- /dev/null
+++ b/crypto/src/tls/DeferredHash.cs
@@ -0,0 +1,249 @@
+using System;
+using System.Collections;
+using System.IO;
+
+using Org.BouncyCastle.Tls.Crypto;
+using Org.BouncyCastle.Utilities;
+
+namespace Org.BouncyCastle.Tls
+{
+    /// <summary>Buffers input until the hash algorithm is determined.</summary>
+    internal sealed class DeferredHash
+        : TlsHandshakeHash
+    {
+        private const int BufferingHashLimit = 4;
+
+        private readonly TlsContext m_context;
+
+        private DigestInputBuffer m_buf;
+        private readonly IDictionary m_hashes;
+        private bool m_forceBuffering;
+        private bool m_sealed;
+
+        internal DeferredHash(TlsContext context)
+        {
+            this.m_context = context;
+            this.m_buf = new DigestInputBuffer();
+            this.m_hashes = Platform.CreateHashtable();
+            this.m_forceBuffering = false;
+            this.m_sealed = false;
+        }
+
+        private DeferredHash(TlsContext context, IDictionary hashes)
+        {
+            this.m_context = context;
+            this.m_buf = null;
+            this.m_hashes = hashes;
+            this.m_forceBuffering = false;
+            this.m_sealed = true;
+        }
+
+        /// <exception cref="IOException"/>
+        public void CopyBufferTo(Stream output)
+        {
+            if (m_buf == null)
+            {
+                // If you see this, you need to call forceBuffering() before SealHashAlgorithms()
+                throw new InvalidOperationException("Not buffering");
+            }
+
+            m_buf.CopyTo(output);
+        }
+
+        public void ForceBuffering()
+        {
+            if (m_sealed)
+                throw new InvalidOperationException("Too late to force buffering");
+
+            this.m_forceBuffering = true;
+        }
+
+        public void NotifyPrfDetermined()
+        {
+            SecurityParameters securityParameters = m_context.SecurityParameters;
+
+            switch (securityParameters.PrfAlgorithm)
+            {
+            case PrfAlgorithm.ssl_prf_legacy:
+            case PrfAlgorithm.tls_prf_legacy:
+            {
+                CheckTrackingHash(CryptoHashAlgorithm.md5);
+                CheckTrackingHash(CryptoHashAlgorithm.sha1);
+                break;
+            }
+            default:
+            {
+                CheckTrackingHash(securityParameters.PrfHashAlgorithm);
+                if (TlsUtilities.IsTlsV13(securityParameters.NegotiatedVersion))
+                {
+                    SealHashAlgorithms();
+                }
+                break;
+            }
+            }
+        }
+
+        public void TrackHashAlgorithm(int cryptoHashAlgorithm)
+        {
+            if (m_sealed)
+                throw new InvalidOperationException("Too late to track more hash algorithms");
+
+            CheckTrackingHash(cryptoHashAlgorithm);
+        }
+
+        public void SealHashAlgorithms()
+        {
+            if (m_sealed)
+                throw new InvalidOperationException("Already sealed");
+
+            this.m_sealed = true;
+            CheckStopBuffering();
+        }
+
+        public TlsHandshakeHash StopTracking()
+        {
+            SecurityParameters securityParameters = m_context.SecurityParameters;
+
+            IDictionary newHashes = Platform.CreateHashtable();
+            switch (securityParameters.PrfAlgorithm)
+            {
+            case PrfAlgorithm.ssl_prf_legacy:
+            case PrfAlgorithm.tls_prf_legacy:
+            {
+                CloneHash(newHashes, HashAlgorithm.md5);
+                CloneHash(newHashes, HashAlgorithm.sha1);
+                break;
+            }
+            default:
+            {
+                CloneHash(newHashes, securityParameters.PrfHashAlgorithm);
+                break;
+            }
+            }
+            return new DeferredHash(m_context, newHashes);
+        }
+
+        public TlsHash ForkPrfHash()
+        {
+            CheckStopBuffering();
+
+            SecurityParameters securityParameters = m_context.SecurityParameters;
+
+            TlsHash prfHash;
+            switch (securityParameters.PrfAlgorithm)
+            {
+            case PrfAlgorithm.ssl_prf_legacy:
+            case PrfAlgorithm.tls_prf_legacy:
+            {
+                prfHash = new CombinedHash(m_context, CloneHash(HashAlgorithm.md5), CloneHash(HashAlgorithm.sha1));
+                break;
+            }
+            default:
+            {
+                prfHash = CloneHash(securityParameters.PrfHashAlgorithm);
+                break;
+            }
+            }
+
+            if (m_buf != null)
+            {
+                m_buf.UpdateDigest(prfHash);
+            }
+
+            return prfHash;
+        }
+
+        public byte[] GetFinalHash(int cryptoHashAlgorithm)
+        {
+            TlsHash d = (TlsHash)m_hashes[cryptoHashAlgorithm];
+            if (d == null)
+                throw new InvalidOperationException("CryptoHashAlgorithm." + cryptoHashAlgorithm
+                    + " is not being tracked");
+
+            CheckStopBuffering();
+
+            d = d.CloneHash();
+            if (m_buf != null)
+            {
+                m_buf.UpdateDigest(d);
+            }
+
+            return d.CalculateHash();
+        }
+
+        public void Update(byte[] input, int inOff, int len)
+        {
+            if (m_buf != null)
+            {
+                m_buf.Write(input, inOff, len);
+                return;
+            }
+
+            foreach (TlsHash hash in m_hashes.Values)
+            {
+                hash.Update(input, inOff, len);
+            }
+        }
+
+        public byte[] CalculateHash()
+        {
+            throw new InvalidOperationException("Use 'ForkPrfHash' to get a definite hash");
+        }
+
+        public TlsHash CloneHash()
+        {
+            throw new InvalidOperationException("attempt to clone a DeferredHash");
+        }
+
+        public void Reset()
+        {
+            if (m_buf != null)
+            {
+                m_buf.SetLength(0);
+                return;
+            }
+
+            foreach (TlsHash hash in m_hashes.Values)
+            {
+                hash.Reset();
+            }
+        }
+
+        private void CheckStopBuffering()
+        {
+            if (!m_forceBuffering && m_sealed && m_buf != null && m_hashes.Count <= BufferingHashLimit)
+            {
+                foreach (TlsHash hash in m_hashes.Values)
+                {
+                    m_buf.UpdateDigest(hash);
+                }
+
+                this.m_buf = null;
+            }
+        }
+
+        private void CheckTrackingHash(int cryptoHashAlgorithm)
+        {
+            if (!m_hashes.Contains(cryptoHashAlgorithm))
+            {
+                TlsHash hash = m_context.Crypto.CreateHash(cryptoHashAlgorithm);
+                m_hashes[cryptoHashAlgorithm] = hash;
+            }
+        }
+
+        private TlsHash CloneHash(int cryptoHashAlgorithm)
+        {
+            return ((TlsHash)m_hashes[cryptoHashAlgorithm]).CloneHash();
+        }
+
+        private void CloneHash(IDictionary newHashes, int cryptoHashAlgorithm)
+        {
+            TlsHash hash = CloneHash(cryptoHashAlgorithm);
+            if (m_buf != null)
+            {
+                m_buf.UpdateDigest(hash);
+            }
+            newHashes[cryptoHashAlgorithm] = hash;
+        }
+    }
+}