summary refs log tree commit diff
path: root/crypto/src
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src')
-rw-r--r--crypto/src/tls/OfferedPsks.cs24
-rw-r--r--crypto/src/tls/TlsUtilities.cs113
2 files changed, 134 insertions, 3 deletions
diff --git a/crypto/src/tls/OfferedPsks.cs b/crypto/src/tls/OfferedPsks.cs
index 5419a19d1..dfa2be034 100644
--- a/crypto/src/tls/OfferedPsks.cs
+++ b/crypto/src/tls/OfferedPsks.cs
@@ -9,6 +9,20 @@ namespace Org.BouncyCastle.Tls
 {
     public sealed class OfferedPsks
     {
+        internal class Config
+        {
+            internal readonly TlsPsk[] m_psks;
+            internal readonly TlsSecret[] m_earlySecrets;
+            internal int m_bindersSize;
+
+            internal Config(TlsPsk[] psks, TlsSecret[] earlySecrets, int bindersSize)
+            {
+                this.m_psks = psks;
+                this.m_earlySecrets = earlySecrets;
+                this.m_bindersSize = bindersSize;
+            }
+        }
+
         private readonly IList m_identities;
         private readonly IList m_binders;
 
@@ -79,8 +93,12 @@ namespace Org.BouncyCastle.Tls
 
         /// <exception cref="IOException"/>
         internal static void EncodeBinders(Stream output, TlsCrypto crypto, TlsHandshakeHash handshakeHash,
-            TlsPsk[] psks, TlsSecret[] earlySecrets, int expectedLengthOfBindersList)
+            Config config)
         {
+            TlsPsk[] psks = config.m_psks;
+            TlsSecret[] earlySecrets = config.m_earlySecrets;
+            int expectedLengthOfBindersList = config.m_bindersSize - 2;
+
             TlsUtilities.CheckUint16(expectedLengthOfBindersList);
             TlsUtilities.WriteUint16(expectedLengthOfBindersList, output);
 
@@ -111,7 +129,7 @@ namespace Org.BouncyCastle.Tls
         }
 
         /// <exception cref="IOException"/>
-        internal static int GetLengthOfBindersList(TlsPsk[] psks)
+        internal static int GetBindersSize(TlsPsk[] psks)
         {
             int lengthOfBindersList = 0;
             for (int i = 0; i < psks.Length; ++i)
@@ -124,7 +142,7 @@ namespace Org.BouncyCastle.Tls
                 lengthOfBindersList += 1 + TlsCryptoUtilities.GetHashOutputSize(prfCryptoHashAlgorithm);
             }
             TlsUtilities.CheckUint16(lengthOfBindersList);
-            return lengthOfBindersList;
+            return 2 + lengthOfBindersList;
         }
 
         /// <exception cref="IOException"/>
diff --git a/crypto/src/tls/TlsUtilities.cs b/crypto/src/tls/TlsUtilities.cs
index 72c41ef05..e48a44452 100644
--- a/crypto/src/tls/TlsUtilities.cs
+++ b/crypto/src/tls/TlsUtilities.cs
@@ -1127,6 +1127,11 @@ namespace Org.BouncyCastle.Tls
             return null == s || s.Length < 1;
         }
 
+        public static bool IsNullOrEmpty(IList v)
+        {
+            return null == v || v.Count < 1;
+        }
+
         public static bool IsSignatureAlgorithmsExtensionAllowed(ProtocolVersion version)
         {
             return null != version
@@ -1992,6 +1997,46 @@ namespace Org.BouncyCastle.Tls
             }
         }
 
+        internal static int GetPrfAlgorithm13(int cipherSuite)
+        {
+            // NOTE: GetPrfAlgorithms13 relies on the number of distinct return values
+            switch (cipherSuite)
+            {
+            case CipherSuite.TLS_AES_128_CCM_SHA256:
+            case CipherSuite.TLS_AES_128_CCM_8_SHA256:
+            case CipherSuite.TLS_AES_128_GCM_SHA256:
+            case CipherSuite.TLS_CHACHA20_POLY1305_SHA256:
+                return PrfAlgorithm.tls13_hkdf_sha256;
+
+            case CipherSuite.TLS_AES_256_GCM_SHA384:
+                return PrfAlgorithm.tls13_hkdf_sha384;
+
+            case CipherSuite.TLS_SM4_CCM_SM3:
+            case CipherSuite.TLS_SM4_GCM_SM3:
+                return PrfAlgorithm.tls13_hkdf_sm3;
+
+            default:
+                return -1;
+            }
+        }
+
+        internal static int[] GetPrfAlgorithms13(int[] cipherSuites)
+        {
+            int[] result = new int[System.Math.Min(3, cipherSuites.Length)];
+
+            int count = 0;
+            for (int i = 0; i < cipherSuites.Length; ++i)
+            {
+                int prfAlgorithm = GetPrfAlgorithm13(cipherSuites[i]);
+                if (prfAlgorithm >= 0 && !Arrays.Contains(result, prfAlgorithm))
+                {
+                    result[count++] = prfAlgorithm;
+                }
+            }
+
+            return Truncate(result, count);
+        }
+
         internal static byte[] CalculateSignatureHash(TlsContext context, SignatureAndHashAlgorithm algorithm,
             byte[] extraSignatureInput, DigestInputBuffer buf)
         {
@@ -4641,6 +4686,16 @@ namespace Org.BouncyCastle.Tls
             return t;
         }
 
+        internal static int[] Truncate(int[] a, int n)
+        {
+            if (n >= a.Length)
+                return a;
+
+            int[] t = new int[n];
+            Array.Copy(a, 0, t, 0, n);
+            return t;
+        }
+
         /// <exception cref="IOException"/>
         internal static TlsCredentialedAgreement RequireAgreementCredentials(TlsCredentials credentials)
         {
@@ -5380,5 +5435,63 @@ namespace Org.BouncyCastle.Tls
 #endif
         }
 #endif
+
+        /// <exception cref="IOException"/>
+        internal static OfferedPsks.Config GetOfferedPsksConfig(TlsClientContext clientContext, TlsClient client)
+        {
+            TlsPskExternal[] pskExternals = GetPskExternalsClient(client);
+            if (null == pskExternals)
+                return null;
+
+            TlsSecret[] pskEarlySecrets = GetPskEarlySecrets(clientContext.Crypto, pskExternals);
+
+            int bindersSize = OfferedPsks.GetBindersSize(pskExternals);
+
+            return new OfferedPsks.Config(pskExternals, pskEarlySecrets, bindersSize);
+        }
+
+        internal static TlsSecret GetPskEarlySecret(TlsCrypto crypto, TlsPsk psk)
+        {
+            int cryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(psk.PrfAlgorithm);
+
+            return crypto
+                .HkdfInit(cryptoHashAlgorithm)
+                .HkdfExtract(cryptoHashAlgorithm, psk.Key);
+        }
+
+        internal static TlsSecret[] GetPskEarlySecrets(TlsCrypto crypto, TlsPsk[] psks)
+        {
+            int count = psks.Length;
+            TlsSecret[] earlySecrets = new TlsSecret[count];
+            for (int i = 0; i < count; ++i)
+            {
+                earlySecrets[i] = GetPskEarlySecret(crypto, psks[i]);
+            }
+            return earlySecrets;
+        }
+
+        /// <exception cref="IOException"/>
+        internal static TlsPskExternal[] GetPskExternalsClient(TlsClient client)
+        {
+            // TODO[tl13-psk] Ensure PSK hash algorithms are supported by cipher suites
+
+            IList externalPsks = client.GetExternalPsks();
+            if (IsNullOrEmpty(externalPsks))
+                return null;
+
+            int count = externalPsks.Count;
+            TlsPskExternal[] result = new TlsPskExternal[count];
+
+            for (int i = 0; i < count; ++i)
+            {
+                TlsPskExternal element = externalPsks[i] as TlsPskExternal;
+                if (null == element)
+                    throw new TlsFatalAlert(AlertDescription.internal_error);
+
+                result[i] = element;
+            }
+
+            return result;
+        }
     }
 }