From bd2fe5262f97293908320481e0eeefb0a92b364c Mon Sep 17 00:00:00 2001 From: Peter Dettman Date: Sun, 17 Oct 2021 21:46:19 +0700 Subject: Server-side PSK selection --- crypto/src/tls/OfferedPsks.cs | 36 +++++++++++++++------ crypto/src/tls/TlsUtilities.cs | 71 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 10 deletions(-) diff --git a/crypto/src/tls/OfferedPsks.cs b/crypto/src/tls/OfferedPsks.cs index 9eddd2e23..1cc8a2a68 100644 --- a/crypto/src/tls/OfferedPsks.cs +++ b/crypto/src/tls/OfferedPsks.cs @@ -26,6 +26,22 @@ namespace Org.BouncyCastle.Tls } } + internal class SelectedConfig + { + internal readonly int m_index; + internal readonly TlsPsk m_psk; + internal readonly short[] m_pskKeyExchangeModes; + internal readonly TlsSecret m_earlySecret; + + internal SelectedConfig(int index, TlsPsk psk, short[] pskKeyExchangeModes, TlsSecret earlySecret) + { + this.m_index = index; + this.m_psk = psk; + this.m_pskKeyExchangeModes = pskKeyExchangeModes; + this.m_earlySecret = earlySecret; + } + } + private readonly IList m_identities; private readonly IList m_binders; private readonly int m_bindersSize; @@ -49,16 +65,6 @@ namespace Org.BouncyCastle.Tls this.m_bindersSize = bindersSize; } - internal byte[] GetBinderForIdentity(PskIdentity matchIdentity) - { - for (int i = 0, count = m_identities.Count; i < count; ++i) - { - if (matchIdentity.Equals(m_identities[i])) - return (byte[])m_binders[i]; - } - return null; - } - public IList Binders { get { return m_binders; } @@ -74,6 +80,16 @@ namespace Org.BouncyCastle.Tls get { return m_identities; } } + public int GetIndexOfIdentity(PskIdentity pskIdentity) + { + for (int i = 0, count = m_identities.Count; i < count; ++i) + { + if (pskIdentity.Equals(m_identities[i])) + return i; + } + return -1; + } + /// public void Encode(Stream output) { diff --git a/crypto/src/tls/TlsUtilities.cs b/crypto/src/tls/TlsUtilities.cs index b03548398..d1e046965 100644 --- a/crypto/src/tls/TlsUtilities.cs +++ b/crypto/src/tls/TlsUtilities.cs @@ -5526,6 +5526,77 @@ namespace Org.BouncyCastle.Tls return result; } + internal static OfferedPsks.SelectedConfig SelectPreSharedKey(TlsServerContext serverContext, TlsServer server, + IDictionary clientHelloExtensions, HandshakeMessageInput clientHelloMessage, TlsHandshakeHash handshakeHash, + bool afterHelloRetryRequest) + { + bool handshakeHashUpdated = false; + + OfferedPsks offeredPsks = TlsExtensionsUtilities.GetPreSharedKeyClientHello(clientHelloExtensions); + if (null != offeredPsks) + { + short[] pskKeyExchangeModes = TlsExtensionsUtilities.GetPskKeyExchangeModesExtension( + clientHelloExtensions); + if (IsNullOrEmpty(pskKeyExchangeModes)) + throw new TlsFatalAlert(AlertDescription.missing_extension); + + // TODO[tls13] Add support for psk_ke? + if (Arrays.Contains(pskKeyExchangeModes, PskKeyExchangeMode.psk_dhe_ke)) + { + // TODO[tls13] Prefer to get the exact index from the server? + TlsPskExternal psk = server.GetExternalPsk(offeredPsks.Identities); + if (null != psk) + { + int index = offeredPsks.GetIndexOfIdentity(new PskIdentity(psk.Identity, 0L)); + if (index >= 0) + { + byte[] binder = (byte[])offeredPsks.Binders[index]; + + TlsCrypto crypto = serverContext.Crypto; + TlsSecret earlySecret = GetPskEarlySecret(crypto, psk); + + // TODO[tls13-psk] Handle resumption PSKs + bool isExternalPsk = true; + int pskCryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(psk.PrfAlgorithm); + + byte[] transcriptHash; + { + handshakeHashUpdated = true; + int bindersSize = offeredPsks.BindersSize; + clientHelloMessage.UpdateHashPrefix(handshakeHash, bindersSize); + + if (afterHelloRetryRequest) + { + transcriptHash = handshakeHash.GetFinalHash(pskCryptoHashAlgorithm); + } + else + { + TlsHash hash = crypto.CreateHash(pskCryptoHashAlgorithm); + handshakeHash.CopyBufferTo(new TlsHashSink(hash)); + transcriptHash = hash.CalculateHash(); + } + + clientHelloMessage.UpdateHashSuffix(handshakeHash, bindersSize); + } + + byte[] calculatedBinder = CalculatePskBinder(crypto, isExternalPsk, pskCryptoHashAlgorithm, + earlySecret, transcriptHash); + + if (Arrays.ConstantTimeAreEqual(calculatedBinder, binder)) + return new OfferedPsks.SelectedConfig(index, psk, pskKeyExchangeModes, earlySecret); + } + } + } + } + + if (!handshakeHashUpdated) + { + clientHelloMessage.UpdateHash(handshakeHash); + } + + return null; + } + internal static TlsSecret GetPskEarlySecret(TlsCrypto crypto, TlsPsk psk) { int cryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(psk.PrfAlgorithm); -- cgit 1.4.1