summary refs log tree commit diff
path: root/crypto/src/tls/TlsUtilities.cs
diff options
context:
space:
mode:
authorPeter Dettman <peter.dettman@bouncycastle.org>2021-07-25 00:12:25 +0700
committerPeter Dettman <peter.dettman@bouncycastle.org>2021-07-25 00:12:25 +0700
commit934e1a1047ef192cf2fb5b5b5822561118a10c20 (patch)
tree47dfadf8dd6e7a2675a16a347b1c6691bb2812e1 /crypto/src/tls/TlsUtilities.cs
parentCalculate HMAC without extracting TlsSecret (diff)
downloadBouncyCastle.NET-ed25519-934e1a1047ef192cf2fb5b5b5822561118a10c20.tar.xz
Refactoring around TLS HKDF
Diffstat (limited to 'crypto/src/tls/TlsUtilities.cs')
-rw-r--r--crypto/src/tls/TlsUtilities.cs93
1 files changed, 67 insertions, 26 deletions
diff --git a/crypto/src/tls/TlsUtilities.cs b/crypto/src/tls/TlsUtilities.cs
index 52b554801..adead624b 100644
--- a/crypto/src/tls/TlsUtilities.cs
+++ b/crypto/src/tls/TlsUtilities.cs
@@ -1459,6 +1459,23 @@ namespace Org.BouncyCastle.Tls
             return Arrays.ConcatenateAll(cr, sr, contextLength, context);
         }
 
+        private static byte[] CalculateFinishedHmac(SecurityParameters securityParameters, TlsSecret baseKey,
+            byte[] transcriptHash)
+        {
+            int cryptoHashAlgorithm = TlsCryptoUtilities.GetHash(securityParameters.PrfHashAlgorithm);
+            TlsSecret finishedKey = TlsCryptoUtilities.HkdfExpandLabel(baseKey, cryptoHashAlgorithm, "finished",
+                EmptyBytes, securityParameters.PrfHashLength);
+
+            try
+            {
+                return finishedKey.CalculateHmac(cryptoHashAlgorithm, transcriptHash, 0, transcriptHash.Length);
+            }
+            finally
+            {
+                finishedKey.Destroy();
+            }
+        }
+
         internal static TlsSecret CalculateMasterSecret(TlsContext context, TlsSecret preMasterSecret)
         {
             SecurityParameters sp = context.SecurityParameters;
@@ -1479,6 +1496,28 @@ namespace Org.BouncyCastle.Tls
             return Prf(sp, preMasterSecret, asciiLabel, seed, 48);
         }
 
+        internal static byte[] CalculatePskBinder(TlsContext context, bool isExternalPsk, TlsSecret earlySecret,
+            byte[] transcriptHash)
+        {
+            TlsCrypto crypto = context.Crypto;
+            SecurityParameters securityParameters = context.SecurityParameters;
+            int cryptoHashAlgorithm = TlsCryptoUtilities.GetHash(securityParameters.PrfHashAlgorithm);
+
+            string label = isExternalPsk ? "ext binder" : "res binder";
+            byte[] emptyTranscriptHash = crypto.CreateHash(cryptoHashAlgorithm).CalculateHash();
+
+            TlsSecret baseKey = DeriveSecret(securityParameters, earlySecret, label, emptyTranscriptHash);
+
+            try
+            {
+                return CalculateFinishedHmac(securityParameters, baseKey, transcriptHash);
+            }
+            finally
+            {
+                baseKey.Destroy();
+            }
+        }
+
         internal static byte[] CalculateVerifyData(TlsContext context, TlsHandshakeHash handshakeHash, bool isServer)
         {
             SecurityParameters securityParameters = context.SecurityParameters;
@@ -1489,12 +1528,9 @@ namespace Org.BouncyCastle.Tls
                 TlsSecret baseKey = isServer
                     ?   securityParameters.BaseKeyServer
                     :   securityParameters.BaseKeyClient;
-
-                TlsSecret finishedKey = DeriveSecret(securityParameters, baseKey, "finished", EmptyBytes);
-                int cryptoHashAlgorithm = TlsCryptoUtilities.GetHash(securityParameters.PrfHashAlgorithm);
                 byte[] transcriptHash = GetCurrentPrfHash(handshakeHash);
 
-                return finishedKey.CalculateHmac(cryptoHashAlgorithm, transcriptHash, 0, transcriptHash.Length);
+                return CalculateFinishedHmac(securityParameters, baseKey, transcriptHash);
             }
 
             if (negotiatedVersion.IsSsl)
@@ -1513,43 +1549,43 @@ namespace Org.BouncyCastle.Tls
 
         internal static void Establish13PhaseSecrets(TlsContext context)
         {
+            TlsCrypto crypto = context.Crypto;
             SecurityParameters securityParameters = context.SecurityParameters;
             int cryptoHashAlgorithm = TlsCryptoUtilities.GetHash(securityParameters.PrfHashAlgorithm);
-            int hashLen = securityParameters.PrfHashLength;
-            byte[] zeroes = new byte[hashLen];
+            TlsSecret zeros = crypto.HkdfInit(cryptoHashAlgorithm);
+            byte[] emptyTranscriptHash = crypto.CreateHash(cryptoHashAlgorithm).CalculateHash();
 
-            byte[] psk = securityParameters.Psk;
-            if (null == psk)
-            {
-                psk = zeroes;
-            }
-            else
+            TlsSecret preSharedKey = securityParameters.PreSharedKey;
+            if (null == preSharedKey)
             {
-                securityParameters.m_psk = null;
+                preSharedKey = zeros;
             }
 
-            byte[] ecdhe = zeroes;
+            TlsSecret earlySecret = crypto.HkdfInit(cryptoHashAlgorithm)
+                .HkdfExtract(cryptoHashAlgorithm, preSharedKey);
+
             TlsSecret sharedSecret = securityParameters.SharedSecret;
-            if (null != sharedSecret)
+            if (null == sharedSecret)
             {
-                securityParameters.m_sharedSecret = null;
-                ecdhe = sharedSecret.Extract();
+                sharedSecret = zeros;
             }
 
-            TlsCrypto crypto = context.Crypto;
+            TlsSecret handshakeSecret = DeriveSecret(securityParameters, earlySecret, "derived", emptyTranscriptHash)
+                .HkdfExtract(cryptoHashAlgorithm, sharedSecret);
 
-            byte[] emptyTranscriptHash = crypto.CreateHash(cryptoHashAlgorithm).CalculateHash();
+            if (sharedSecret != zeros)
+            {
+                sharedSecret.Destroy();
+            }
 
-            TlsSecret earlySecret = crypto.HkdfInit(cryptoHashAlgorithm)
-                .HkdfExtract(cryptoHashAlgorithm, psk);
-            TlsSecret handshakeSecret = DeriveSecret(securityParameters, earlySecret, "derived", emptyTranscriptHash)
-                .HkdfExtract(cryptoHashAlgorithm, ecdhe);
             TlsSecret masterSecret = DeriveSecret(securityParameters, handshakeSecret, "derived", emptyTranscriptHash)
-                .HkdfExtract(cryptoHashAlgorithm, zeroes);
+                .HkdfExtract(cryptoHashAlgorithm, zeros);
 
             securityParameters.m_earlySecret = earlySecret;
             securityParameters.m_handshakeSecret = handshakeSecret;
             securityParameters.m_masterSecret = masterSecret;
+            securityParameters.m_preSharedKey = null;
+            securityParameters.m_sharedSecret = null;
         }
 
         private static void Establish13TrafficSecrets(TlsContext context, byte[] transcriptHash, TlsSecret phaseSecret,
@@ -5170,8 +5206,13 @@ namespace Org.BouncyCastle.Tls
         internal static TlsSecret DeriveSecret(SecurityParameters securityParameters, TlsSecret secret, string label,
             byte[] transcriptHash)
         {
-            return TlsCryptoUtilities.HkdfExpandLabel(secret, securityParameters.PrfHashAlgorithm, label,
-                transcriptHash, securityParameters.PrfHashLength);
+            short prfHashAlgorithm = securityParameters.PrfHashAlgorithm;
+            int prfHashLength = securityParameters.PrfHashLength;
+
+            if (transcriptHash.Length != prfHashLength)
+                throw new TlsFatalAlert(AlertDescription.internal_error);
+
+            return TlsCryptoUtilities.HkdfExpandLabel(secret, prfHashAlgorithm, label, transcriptHash, prfHashLength);
         }
 
         internal static TlsSecret GetSessionMasterSecret(TlsCrypto crypto, TlsSecret masterSecret)