diff --git a/crypto/src/tls/OfferedPsks.cs b/crypto/src/tls/OfferedPsks.cs
index 5419a19d1..dfa2be034 100644
--- a/crypto/src/tls/OfferedPsks.cs
+++ b/crypto/src/tls/OfferedPsks.cs
@@ -9,6 +9,20 @@ namespace Org.BouncyCastle.Tls
{
public sealed class OfferedPsks
{
+ internal class Config
+ {
+ internal readonly TlsPsk[] m_psks;
+ internal readonly TlsSecret[] m_earlySecrets;
+ internal int m_bindersSize;
+
+ internal Config(TlsPsk[] psks, TlsSecret[] earlySecrets, int bindersSize)
+ {
+ this.m_psks = psks;
+ this.m_earlySecrets = earlySecrets;
+ this.m_bindersSize = bindersSize;
+ }
+ }
+
private readonly IList m_identities;
private readonly IList m_binders;
@@ -79,8 +93,12 @@ namespace Org.BouncyCastle.Tls
/// <exception cref="IOException"/>
internal static void EncodeBinders(Stream output, TlsCrypto crypto, TlsHandshakeHash handshakeHash,
- TlsPsk[] psks, TlsSecret[] earlySecrets, int expectedLengthOfBindersList)
+ Config config)
{
+ TlsPsk[] psks = config.m_psks;
+ TlsSecret[] earlySecrets = config.m_earlySecrets;
+ int expectedLengthOfBindersList = config.m_bindersSize - 2;
+
TlsUtilities.CheckUint16(expectedLengthOfBindersList);
TlsUtilities.WriteUint16(expectedLengthOfBindersList, output);
@@ -111,7 +129,7 @@ namespace Org.BouncyCastle.Tls
}
/// <exception cref="IOException"/>
- internal static int GetLengthOfBindersList(TlsPsk[] psks)
+ internal static int GetBindersSize(TlsPsk[] psks)
{
int lengthOfBindersList = 0;
for (int i = 0; i < psks.Length; ++i)
@@ -124,7 +142,7 @@ namespace Org.BouncyCastle.Tls
lengthOfBindersList += 1 + TlsCryptoUtilities.GetHashOutputSize(prfCryptoHashAlgorithm);
}
TlsUtilities.CheckUint16(lengthOfBindersList);
- return lengthOfBindersList;
+ return 2 + lengthOfBindersList;
}
/// <exception cref="IOException"/>
diff --git a/crypto/src/tls/TlsUtilities.cs b/crypto/src/tls/TlsUtilities.cs
index 72c41ef05..e48a44452 100644
--- a/crypto/src/tls/TlsUtilities.cs
+++ b/crypto/src/tls/TlsUtilities.cs
@@ -1127,6 +1127,11 @@ namespace Org.BouncyCastle.Tls
return null == s || s.Length < 1;
}
+ public static bool IsNullOrEmpty(IList v)
+ {
+ return null == v || v.Count < 1;
+ }
+
public static bool IsSignatureAlgorithmsExtensionAllowed(ProtocolVersion version)
{
return null != version
@@ -1992,6 +1997,46 @@ namespace Org.BouncyCastle.Tls
}
}
+ internal static int GetPrfAlgorithm13(int cipherSuite)
+ {
+ // NOTE: GetPrfAlgorithms13 relies on the number of distinct return values
+ switch (cipherSuite)
+ {
+ case CipherSuite.TLS_AES_128_CCM_SHA256:
+ case CipherSuite.TLS_AES_128_CCM_8_SHA256:
+ case CipherSuite.TLS_AES_128_GCM_SHA256:
+ case CipherSuite.TLS_CHACHA20_POLY1305_SHA256:
+ return PrfAlgorithm.tls13_hkdf_sha256;
+
+ case CipherSuite.TLS_AES_256_GCM_SHA384:
+ return PrfAlgorithm.tls13_hkdf_sha384;
+
+ case CipherSuite.TLS_SM4_CCM_SM3:
+ case CipherSuite.TLS_SM4_GCM_SM3:
+ return PrfAlgorithm.tls13_hkdf_sm3;
+
+ default:
+ return -1;
+ }
+ }
+
+ internal static int[] GetPrfAlgorithms13(int[] cipherSuites)
+ {
+ int[] result = new int[System.Math.Min(3, cipherSuites.Length)];
+
+ int count = 0;
+ for (int i = 0; i < cipherSuites.Length; ++i)
+ {
+ int prfAlgorithm = GetPrfAlgorithm13(cipherSuites[i]);
+ if (prfAlgorithm >= 0 && !Arrays.Contains(result, prfAlgorithm))
+ {
+ result[count++] = prfAlgorithm;
+ }
+ }
+
+ return Truncate(result, count);
+ }
+
internal static byte[] CalculateSignatureHash(TlsContext context, SignatureAndHashAlgorithm algorithm,
byte[] extraSignatureInput, DigestInputBuffer buf)
{
@@ -4641,6 +4686,16 @@ namespace Org.BouncyCastle.Tls
return t;
}
+ internal static int[] Truncate(int[] a, int n)
+ {
+ if (n >= a.Length)
+ return a;
+
+ int[] t = new int[n];
+ Array.Copy(a, 0, t, 0, n);
+ return t;
+ }
+
/// <exception cref="IOException"/>
internal static TlsCredentialedAgreement RequireAgreementCredentials(TlsCredentials credentials)
{
@@ -5380,5 +5435,63 @@ namespace Org.BouncyCastle.Tls
#endif
}
#endif
+
+ /// <exception cref="IOException"/>
+ internal static OfferedPsks.Config GetOfferedPsksConfig(TlsClientContext clientContext, TlsClient client)
+ {
+ TlsPskExternal[] pskExternals = GetPskExternalsClient(client);
+ if (null == pskExternals)
+ return null;
+
+ TlsSecret[] pskEarlySecrets = GetPskEarlySecrets(clientContext.Crypto, pskExternals);
+
+ int bindersSize = OfferedPsks.GetBindersSize(pskExternals);
+
+ return new OfferedPsks.Config(pskExternals, pskEarlySecrets, bindersSize);
+ }
+
+ internal static TlsSecret GetPskEarlySecret(TlsCrypto crypto, TlsPsk psk)
+ {
+ int cryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(psk.PrfAlgorithm);
+
+ return crypto
+ .HkdfInit(cryptoHashAlgorithm)
+ .HkdfExtract(cryptoHashAlgorithm, psk.Key);
+ }
+
+ internal static TlsSecret[] GetPskEarlySecrets(TlsCrypto crypto, TlsPsk[] psks)
+ {
+ int count = psks.Length;
+ TlsSecret[] earlySecrets = new TlsSecret[count];
+ for (int i = 0; i < count; ++i)
+ {
+ earlySecrets[i] = GetPskEarlySecret(crypto, psks[i]);
+ }
+ return earlySecrets;
+ }
+
+ /// <exception cref="IOException"/>
+ internal static TlsPskExternal[] GetPskExternalsClient(TlsClient client)
+ {
+ // TODO[tl13-psk] Ensure PSK hash algorithms are supported by cipher suites
+
+ IList externalPsks = client.GetExternalPsks();
+ if (IsNullOrEmpty(externalPsks))
+ return null;
+
+ int count = externalPsks.Count;
+ TlsPskExternal[] result = new TlsPskExternal[count];
+
+ for (int i = 0; i < count; ++i)
+ {
+ TlsPskExternal element = externalPsks[i] as TlsPskExternal;
+ if (null == element)
+ throw new TlsFatalAlert(AlertDescription.internal_error);
+
+ result[i] = element;
+ }
+
+ return result;
+ }
}
}
|