diff options
Diffstat (limited to 'crypto/src/tls/TlsUtilities.cs')
-rw-r--r-- | crypto/src/tls/TlsUtilities.cs | 71 |
1 files changed, 71 insertions, 0 deletions
diff --git a/crypto/src/tls/TlsUtilities.cs b/crypto/src/tls/TlsUtilities.cs index b03548398..d1e046965 100644 --- a/crypto/src/tls/TlsUtilities.cs +++ b/crypto/src/tls/TlsUtilities.cs @@ -5526,6 +5526,77 @@ namespace Org.BouncyCastle.Tls return result; } + internal static OfferedPsks.SelectedConfig SelectPreSharedKey(TlsServerContext serverContext, TlsServer server, + IDictionary clientHelloExtensions, HandshakeMessageInput clientHelloMessage, TlsHandshakeHash handshakeHash, + bool afterHelloRetryRequest) + { + bool handshakeHashUpdated = false; + + OfferedPsks offeredPsks = TlsExtensionsUtilities.GetPreSharedKeyClientHello(clientHelloExtensions); + if (null != offeredPsks) + { + short[] pskKeyExchangeModes = TlsExtensionsUtilities.GetPskKeyExchangeModesExtension( + clientHelloExtensions); + if (IsNullOrEmpty(pskKeyExchangeModes)) + throw new TlsFatalAlert(AlertDescription.missing_extension); + + // TODO[tls13] Add support for psk_ke? + if (Arrays.Contains(pskKeyExchangeModes, PskKeyExchangeMode.psk_dhe_ke)) + { + // TODO[tls13] Prefer to get the exact index from the server? + TlsPskExternal psk = server.GetExternalPsk(offeredPsks.Identities); + if (null != psk) + { + int index = offeredPsks.GetIndexOfIdentity(new PskIdentity(psk.Identity, 0L)); + if (index >= 0) + { + byte[] binder = (byte[])offeredPsks.Binders[index]; + + TlsCrypto crypto = serverContext.Crypto; + TlsSecret earlySecret = GetPskEarlySecret(crypto, psk); + + // TODO[tls13-psk] Handle resumption PSKs + bool isExternalPsk = true; + int pskCryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(psk.PrfAlgorithm); + + byte[] transcriptHash; + { + handshakeHashUpdated = true; + int bindersSize = offeredPsks.BindersSize; + clientHelloMessage.UpdateHashPrefix(handshakeHash, bindersSize); + + if (afterHelloRetryRequest) + { + transcriptHash = handshakeHash.GetFinalHash(pskCryptoHashAlgorithm); + } + else + { + TlsHash hash = crypto.CreateHash(pskCryptoHashAlgorithm); + handshakeHash.CopyBufferTo(new TlsHashSink(hash)); + transcriptHash = hash.CalculateHash(); + } + + clientHelloMessage.UpdateHashSuffix(handshakeHash, bindersSize); + } + + byte[] calculatedBinder = CalculatePskBinder(crypto, isExternalPsk, pskCryptoHashAlgorithm, + earlySecret, transcriptHash); + + if (Arrays.ConstantTimeAreEqual(calculatedBinder, binder)) + return new OfferedPsks.SelectedConfig(index, psk, pskKeyExchangeModes, earlySecret); + } + } + } + } + + if (!handshakeHashUpdated) + { + clientHelloMessage.UpdateHash(handshakeHash); + } + + return null; + } + internal static TlsSecret GetPskEarlySecret(TlsCrypto crypto, TlsPsk psk) { int cryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(psk.PrfAlgorithm); |