diff --git a/crypto/src/tls/OfferedPsks.cs b/crypto/src/tls/OfferedPsks.cs
index 9eddd2e23..1cc8a2a68 100644
--- a/crypto/src/tls/OfferedPsks.cs
+++ b/crypto/src/tls/OfferedPsks.cs
@@ -26,6 +26,22 @@ namespace Org.BouncyCastle.Tls
}
}
+ internal class SelectedConfig
+ {
+ internal readonly int m_index;
+ internal readonly TlsPsk m_psk;
+ internal readonly short[] m_pskKeyExchangeModes;
+ internal readonly TlsSecret m_earlySecret;
+
+ internal SelectedConfig(int index, TlsPsk psk, short[] pskKeyExchangeModes, TlsSecret earlySecret)
+ {
+ this.m_index = index;
+ this.m_psk = psk;
+ this.m_pskKeyExchangeModes = pskKeyExchangeModes;
+ this.m_earlySecret = earlySecret;
+ }
+ }
+
private readonly IList m_identities;
private readonly IList m_binders;
private readonly int m_bindersSize;
@@ -49,16 +65,6 @@ namespace Org.BouncyCastle.Tls
this.m_bindersSize = bindersSize;
}
- internal byte[] GetBinderForIdentity(PskIdentity matchIdentity)
- {
- for (int i = 0, count = m_identities.Count; i < count; ++i)
- {
- if (matchIdentity.Equals(m_identities[i]))
- return (byte[])m_binders[i];
- }
- return null;
- }
-
public IList Binders
{
get { return m_binders; }
@@ -74,6 +80,16 @@ namespace Org.BouncyCastle.Tls
get { return m_identities; }
}
+ public int GetIndexOfIdentity(PskIdentity pskIdentity)
+ {
+ for (int i = 0, count = m_identities.Count; i < count; ++i)
+ {
+ if (pskIdentity.Equals(m_identities[i]))
+ return i;
+ }
+ return -1;
+ }
+
/// <exception cref="IOException"/>
public void Encode(Stream output)
{
diff --git a/crypto/src/tls/TlsUtilities.cs b/crypto/src/tls/TlsUtilities.cs
index b03548398..d1e046965 100644
--- a/crypto/src/tls/TlsUtilities.cs
+++ b/crypto/src/tls/TlsUtilities.cs
@@ -5526,6 +5526,77 @@ namespace Org.BouncyCastle.Tls
return result;
}
+ internal static OfferedPsks.SelectedConfig SelectPreSharedKey(TlsServerContext serverContext, TlsServer server,
+ IDictionary clientHelloExtensions, HandshakeMessageInput clientHelloMessage, TlsHandshakeHash handshakeHash,
+ bool afterHelloRetryRequest)
+ {
+ bool handshakeHashUpdated = false;
+
+ OfferedPsks offeredPsks = TlsExtensionsUtilities.GetPreSharedKeyClientHello(clientHelloExtensions);
+ if (null != offeredPsks)
+ {
+ short[] pskKeyExchangeModes = TlsExtensionsUtilities.GetPskKeyExchangeModesExtension(
+ clientHelloExtensions);
+ if (IsNullOrEmpty(pskKeyExchangeModes))
+ throw new TlsFatalAlert(AlertDescription.missing_extension);
+
+ // TODO[tls13] Add support for psk_ke?
+ if (Arrays.Contains(pskKeyExchangeModes, PskKeyExchangeMode.psk_dhe_ke))
+ {
+ // TODO[tls13] Prefer to get the exact index from the server?
+ TlsPskExternal psk = server.GetExternalPsk(offeredPsks.Identities);
+ if (null != psk)
+ {
+ int index = offeredPsks.GetIndexOfIdentity(new PskIdentity(psk.Identity, 0L));
+ if (index >= 0)
+ {
+ byte[] binder = (byte[])offeredPsks.Binders[index];
+
+ TlsCrypto crypto = serverContext.Crypto;
+ TlsSecret earlySecret = GetPskEarlySecret(crypto, psk);
+
+ // TODO[tls13-psk] Handle resumption PSKs
+ bool isExternalPsk = true;
+ int pskCryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(psk.PrfAlgorithm);
+
+ byte[] transcriptHash;
+ {
+ handshakeHashUpdated = true;
+ int bindersSize = offeredPsks.BindersSize;
+ clientHelloMessage.UpdateHashPrefix(handshakeHash, bindersSize);
+
+ if (afterHelloRetryRequest)
+ {
+ transcriptHash = handshakeHash.GetFinalHash(pskCryptoHashAlgorithm);
+ }
+ else
+ {
+ TlsHash hash = crypto.CreateHash(pskCryptoHashAlgorithm);
+ handshakeHash.CopyBufferTo(new TlsHashSink(hash));
+ transcriptHash = hash.CalculateHash();
+ }
+
+ clientHelloMessage.UpdateHashSuffix(handshakeHash, bindersSize);
+ }
+
+ byte[] calculatedBinder = CalculatePskBinder(crypto, isExternalPsk, pskCryptoHashAlgorithm,
+ earlySecret, transcriptHash);
+
+ if (Arrays.ConstantTimeAreEqual(calculatedBinder, binder))
+ return new OfferedPsks.SelectedConfig(index, psk, pskKeyExchangeModes, earlySecret);
+ }
+ }
+ }
+ }
+
+ if (!handshakeHashUpdated)
+ {
+ clientHelloMessage.UpdateHash(handshakeHash);
+ }
+
+ return null;
+ }
+
internal static TlsSecret GetPskEarlySecret(TlsCrypto crypto, TlsPsk psk)
{
int cryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(psk.PrfAlgorithm);
|