diff --git a/crypto/src/tls/TlsUtilities.cs b/crypto/src/tls/TlsUtilities.cs
index b03548398..d1e046965 100644
--- a/crypto/src/tls/TlsUtilities.cs
+++ b/crypto/src/tls/TlsUtilities.cs
@@ -5526,6 +5526,77 @@ namespace Org.BouncyCastle.Tls
return result;
}
+ internal static OfferedPsks.SelectedConfig SelectPreSharedKey(TlsServerContext serverContext, TlsServer server,
+ IDictionary clientHelloExtensions, HandshakeMessageInput clientHelloMessage, TlsHandshakeHash handshakeHash,
+ bool afterHelloRetryRequest)
+ {
+ bool handshakeHashUpdated = false;
+
+ OfferedPsks offeredPsks = TlsExtensionsUtilities.GetPreSharedKeyClientHello(clientHelloExtensions);
+ if (null != offeredPsks)
+ {
+ short[] pskKeyExchangeModes = TlsExtensionsUtilities.GetPskKeyExchangeModesExtension(
+ clientHelloExtensions);
+ if (IsNullOrEmpty(pskKeyExchangeModes))
+ throw new TlsFatalAlert(AlertDescription.missing_extension);
+
+ // TODO[tls13] Add support for psk_ke?
+ if (Arrays.Contains(pskKeyExchangeModes, PskKeyExchangeMode.psk_dhe_ke))
+ {
+ // TODO[tls13] Prefer to get the exact index from the server?
+ TlsPskExternal psk = server.GetExternalPsk(offeredPsks.Identities);
+ if (null != psk)
+ {
+ int index = offeredPsks.GetIndexOfIdentity(new PskIdentity(psk.Identity, 0L));
+ if (index >= 0)
+ {
+ byte[] binder = (byte[])offeredPsks.Binders[index];
+
+ TlsCrypto crypto = serverContext.Crypto;
+ TlsSecret earlySecret = GetPskEarlySecret(crypto, psk);
+
+ // TODO[tls13-psk] Handle resumption PSKs
+ bool isExternalPsk = true;
+ int pskCryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(psk.PrfAlgorithm);
+
+ byte[] transcriptHash;
+ {
+ handshakeHashUpdated = true;
+ int bindersSize = offeredPsks.BindersSize;
+ clientHelloMessage.UpdateHashPrefix(handshakeHash, bindersSize);
+
+ if (afterHelloRetryRequest)
+ {
+ transcriptHash = handshakeHash.GetFinalHash(pskCryptoHashAlgorithm);
+ }
+ else
+ {
+ TlsHash hash = crypto.CreateHash(pskCryptoHashAlgorithm);
+ handshakeHash.CopyBufferTo(new TlsHashSink(hash));
+ transcriptHash = hash.CalculateHash();
+ }
+
+ clientHelloMessage.UpdateHashSuffix(handshakeHash, bindersSize);
+ }
+
+ byte[] calculatedBinder = CalculatePskBinder(crypto, isExternalPsk, pskCryptoHashAlgorithm,
+ earlySecret, transcriptHash);
+
+ if (Arrays.ConstantTimeAreEqual(calculatedBinder, binder))
+ return new OfferedPsks.SelectedConfig(index, psk, pskKeyExchangeModes, earlySecret);
+ }
+ }
+ }
+ }
+
+ if (!handshakeHashUpdated)
+ {
+ clientHelloMessage.UpdateHash(handshakeHash);
+ }
+
+ return null;
+ }
+
internal static TlsSecret GetPskEarlySecret(TlsCrypto crypto, TlsPsk psk)
{
int cryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(psk.PrfAlgorithm);
|