From b2e186669793e61ec36a50ec35c00f781fa5d3c8 Mon Sep 17 00:00:00 2001 From: Peter Dettman Date: Mon, 26 Jul 2021 21:30:34 +0700 Subject: More work on PSK binders --- crypto/src/tls/ClientHello.cs | 13 ++++-- crypto/src/tls/DtlsClientProtocol.cs | 2 +- crypto/src/tls/HandshakeMessageOutput.cs | 13 ++---- crypto/src/tls/OfferedPsks.cs | 79 +++++++++++++++++++++++++++----- crypto/src/tls/PskIdentity.cs | 5 ++ crypto/src/tls/TlsClientProtocol.cs | 15 +++--- crypto/src/tls/TlsProtocol.cs | 49 ++++++++++++++++++-- crypto/src/tls/TlsUtilities.cs | 11 ++--- 8 files changed, 149 insertions(+), 38 deletions(-) (limited to 'crypto/src') diff --git a/crypto/src/tls/ClientHello.cs b/crypto/src/tls/ClientHello.cs index 50a33ac39..700d424cd 100644 --- a/crypto/src/tls/ClientHello.cs +++ b/crypto/src/tls/ClientHello.cs @@ -15,9 +15,10 @@ namespace Org.BouncyCastle.Tls private readonly byte[] m_cookie; private readonly int[] m_cipherSuites; private readonly IDictionary m_extensions; + private readonly int m_bindersSize; public ClientHello(ProtocolVersion version, byte[] random, byte[] sessionID, byte[] cookie, - int[] cipherSuites, IDictionary extensions) + int[] cipherSuites, IDictionary extensions, int bindersSize) { this.m_version = version; this.m_random = random; @@ -25,6 +26,12 @@ namespace Org.BouncyCastle.Tls this.m_cookie = cookie; this.m_cipherSuites = cipherSuites; this.m_extensions = extensions; + this.m_bindersSize = bindersSize; + } + + public int BindersSize + { + get { return m_bindersSize; } } public int[] CipherSuites @@ -78,7 +85,7 @@ namespace Org.BouncyCastle.Tls TlsUtilities.WriteUint8ArrayWithUint8Length(new short[]{ CompressionMethod.cls_null }, output); - TlsProtocol.WriteExtensions(output, m_extensions); + TlsProtocol.WriteExtensions(output, m_extensions, m_bindersSize); } /// Parse a from a . @@ -161,7 +168,7 @@ namespace Org.BouncyCastle.Tls extensions = TlsProtocol.ReadExtensionsDataClientHello(extBytes); } - return new ClientHello(clientVersion, random, sessionID, cookie, cipherSuites, extensions); + return new ClientHello(clientVersion, random, sessionID, cookie, cipherSuites, extensions, 0); } } } diff --git a/crypto/src/tls/DtlsClientProtocol.cs b/crypto/src/tls/DtlsClientProtocol.cs index dea35a28b..cd2fff709 100644 --- a/crypto/src/tls/DtlsClientProtocol.cs +++ b/crypto/src/tls/DtlsClientProtocol.cs @@ -513,7 +513,7 @@ namespace Org.BouncyCastle.Tls ClientHello clientHello = new ClientHello(legacy_version, securityParameters.ClientRandom, session_id, - TlsUtilities.EmptyBytes, state.offeredCipherSuites, state.clientExtensions); + TlsUtilities.EmptyBytes, state.offeredCipherSuites, state.clientExtensions, 0); MemoryStream buf = new MemoryStream(); clientHello.Encode(state.clientContext, buf); diff --git a/crypto/src/tls/HandshakeMessageOutput.cs b/crypto/src/tls/HandshakeMessageOutput.cs index 97e9a84af..ff45ce6f3 100644 --- a/crypto/src/tls/HandshakeMessageOutput.cs +++ b/crypto/src/tls/HandshakeMessageOutput.cs @@ -59,12 +59,10 @@ namespace Org.BouncyCastle.Tls Platform.Dispose(this); } - internal void PrepareClientHello(TlsHandshakeHash handshakeHash, int totalBindersLength) + internal void PrepareClientHello(TlsHandshakeHash handshakeHash, int bindersSize) { - TlsUtilities.CheckUint16(totalBindersLength); - // Patch actual length back in - int bodyLength = (int)Length - 4 + totalBindersLength; + int bodyLength = (int)Length - 4 + bindersSize; TlsUtilities.CheckUint24(bodyLength); Seek(1L, SeekOrigin.Begin); @@ -83,8 +81,7 @@ namespace Org.BouncyCastle.Tls Seek(0L, SeekOrigin.End); } - internal void SendClientHello(TlsClientProtocol clientProtocol, TlsHandshakeHash handshakeHash, - int totalBindersLength) + internal void SendClientHello(TlsClientProtocol clientProtocol, TlsHandshakeHash handshakeHash, int bindersSize) { #if PORTABLE byte[] buf = ToArray(); @@ -94,9 +91,9 @@ namespace Org.BouncyCastle.Tls int count = (int)Length; #endif - if (totalBindersLength > 0) + if (bindersSize > 0) { - handshakeHash.Update(buf, count - totalBindersLength, totalBindersLength); + handshakeHash.Update(buf, count - bindersSize, bindersSize); } clientProtocol.WriteHandshakeMessage(buf, 0, count); diff --git a/crypto/src/tls/OfferedPsks.cs b/crypto/src/tls/OfferedPsks.cs index 597ec195c..5419a19d1 100644 --- a/crypto/src/tls/OfferedPsks.cs +++ b/crypto/src/tls/OfferedPsks.cs @@ -2,6 +2,7 @@ using System.Collections; using System.IO; +using Org.BouncyCastle.Tls.Crypto; using Org.BouncyCastle.Utilities; namespace Org.BouncyCastle.Tls @@ -11,12 +12,17 @@ namespace Org.BouncyCastle.Tls private readonly IList m_identities; private readonly IList m_binders; - public OfferedPsks(IList identities, IList binders) + public OfferedPsks(IList identities) + : this(identities, null) + { + } + + private OfferedPsks(IList identities, IList binders) { if (null == identities || identities.Count < 1) throw new ArgumentException("cannot be null or empty", "identities"); - if (null == binders || identities.Count != binders.Count) - throw new ArgumentException("must be non-null and the same length as 'identities'", "binders"); + if (null != binders && identities.Count != binders.Count) + throw new ArgumentException("must be the same length as 'identities' (or null)", "binders"); this.m_identities = identities; this.m_binders = binders; @@ -37,14 +43,14 @@ namespace Org.BouncyCastle.Tls { // identities { - int totalLengthIdentities = 0; + int lengthOfIdentitiesList = 0; foreach (PskIdentity identity in m_identities) { - totalLengthIdentities += 2 + identity.Identity.Length + 4; + lengthOfIdentitiesList += identity.GetEncodedLength(); } - TlsUtilities.CheckUint16(totalLengthIdentities); - TlsUtilities.WriteUint16(totalLengthIdentities, output); + TlsUtilities.CheckUint16(lengthOfIdentitiesList); + TlsUtilities.WriteUint16(lengthOfIdentitiesList, output); foreach (PskIdentity identity in m_identities) { @@ -53,15 +59,16 @@ namespace Org.BouncyCastle.Tls } // binders + if (null != m_binders) { - int totalLengthBinders = 0; + int lengthOfBindersList = 0; foreach (byte[] binder in m_binders) { - totalLengthBinders += 1 + binder.Length; + lengthOfBindersList += 1 + binder.Length; } - TlsUtilities.CheckUint16(totalLengthBinders); - TlsUtilities.WriteUint16(totalLengthBinders, output); + TlsUtilities.CheckUint16(lengthOfBindersList); + TlsUtilities.WriteUint16(lengthOfBindersList, output); foreach (byte[] binder in m_binders) { @@ -70,6 +77,56 @@ namespace Org.BouncyCastle.Tls } } + /// + internal static void EncodeBinders(Stream output, TlsCrypto crypto, TlsHandshakeHash handshakeHash, + TlsPsk[] psks, TlsSecret[] earlySecrets, int expectedLengthOfBindersList) + { + TlsUtilities.CheckUint16(expectedLengthOfBindersList); + TlsUtilities.WriteUint16(expectedLengthOfBindersList, output); + + int lengthOfBindersList = 0; + for (int i = 0; i < psks.Length; ++i) + { + TlsPsk psk = psks[i]; + TlsSecret earlySecret = earlySecrets[i]; + + // TODO[tls13-psk] Handle resumption PSKs + bool isExternalPsk = true; + int pskCryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(psk.PrfAlgorithm); + + // TODO[tls13-psk] Cache the transcript hashes per algorithm to avoid duplicates for multiple PSKs + TlsHash hash = crypto.CreateHash(pskCryptoHashAlgorithm); + handshakeHash.CopyBufferTo(new TlsHashSink(hash)); + byte[] transcriptHash = hash.CalculateHash(); + + byte[] binder = TlsUtilities.CalculatePskBinder(crypto, isExternalPsk, pskCryptoHashAlgorithm, + earlySecret, transcriptHash); + + lengthOfBindersList += 1 + binder.Length; + TlsUtilities.WriteOpaque8(binder, output); + } + + if (expectedLengthOfBindersList != lengthOfBindersList) + throw new TlsFatalAlert(AlertDescription.internal_error); + } + + /// + internal static int GetLengthOfBindersList(TlsPsk[] psks) + { + int lengthOfBindersList = 0; + for (int i = 0; i < psks.Length; ++i) + { + TlsPsk psk = psks[i]; + + int prfAlgorithm = psk.PrfAlgorithm; + int prfCryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(prfAlgorithm); + + lengthOfBindersList += 1 + TlsCryptoUtilities.GetHashOutputSize(prfCryptoHashAlgorithm); + } + TlsUtilities.CheckUint16(lengthOfBindersList); + return lengthOfBindersList; + } + /// public static OfferedPsks Parse(Stream input) { diff --git a/crypto/src/tls/PskIdentity.cs b/crypto/src/tls/PskIdentity.cs index 9b24527bb..082907419 100644 --- a/crypto/src/tls/PskIdentity.cs +++ b/crypto/src/tls/PskIdentity.cs @@ -21,6 +21,11 @@ namespace Org.BouncyCastle.Tls this.m_obfuscatedTicketAge = obfuscatedTicketAge; } + public int GetEncodedLength() + { + return 6 + m_identity.Length; + } + public byte[] Identity { get { return m_identity; } diff --git a/crypto/src/tls/TlsClientProtocol.cs b/crypto/src/tls/TlsClientProtocol.cs index 7a92220dc..8fb1a39b7 100644 --- a/crypto/src/tls/TlsClientProtocol.cs +++ b/crypto/src/tls/TlsClientProtocol.cs @@ -1474,6 +1474,8 @@ namespace Org.BouncyCastle.Tls /// protected virtual void Send13ClientHelloRetry() { + // TODO[tls13-psk] Create a new ClientHello object and handle any changes to the bindersSize + IDictionary clientHelloExtensions = m_clientHello.Extensions; clientHelloExtensions.Remove(ExtensionType.cookie); @@ -1679,8 +1681,12 @@ namespace Org.BouncyCastle.Tls + // TODO[tls13-psk] Calculate the total length of the binders that will be added. + int bindersSize = 0; + //int bindersSize = 2 + lengthOfBindersList; + this.m_clientHello = new ClientHello(legacy_version, securityParameters.ClientRandom, legacy_session_id, - null, offeredCipherSuites, m_clientExtensions); + null, offeredCipherSuites, m_clientExtensions, bindersSize); SendClientHelloMessage(); } @@ -1691,14 +1697,11 @@ namespace Org.BouncyCastle.Tls HandshakeMessageOutput message = new HandshakeMessageOutput(HandshakeType.client_hello); m_clientHello.Encode(m_tlsClientContext, message); - // TODO[tls13-psk] Calculate the total length of the binders that will be added. - int totalBindersLength = 0; - - message.PrepareClientHello(m_handshakeHash, totalBindersLength); + message.PrepareClientHello(m_handshakeHash, m_clientHello.BindersSize); // TODO[tls13-psk] Calculate any PSK binders and write them to 'message' here. - message.SendClientHello(this, m_handshakeHash, totalBindersLength); + message.SendClientHello(this, m_handshakeHash, m_clientHello.BindersSize); } /// diff --git a/crypto/src/tls/TlsProtocol.cs b/crypto/src/tls/TlsProtocol.cs index d4960e3c8..f05c09a1b 100644 --- a/crypto/src/tls/TlsProtocol.cs +++ b/crypto/src/tls/TlsProtocol.cs @@ -1826,25 +1826,46 @@ namespace Org.BouncyCastle.Tls /// internal static void WriteExtensions(Stream output, IDictionary extensions) + { + WriteExtensions(output, extensions, 0); + } + + /// + internal static void WriteExtensions(Stream output, IDictionary extensions, int bindersSize) { if (null == extensions || extensions.Count < 1) return; - byte[] extBytes = WriteExtensionsData(extensions); + byte[] extBytes = WriteExtensionsData(extensions, bindersSize); - TlsUtilities.WriteOpaque16(extBytes, output); + int lengthWithBinders = extBytes.Length + bindersSize; + TlsUtilities.CheckUint16(lengthWithBinders); + TlsUtilities.WriteUint16(lengthWithBinders, output); + output.Write(extBytes, 0, extBytes.Length); } /// internal static byte[] WriteExtensionsData(IDictionary extensions) + { + return WriteExtensionsData(extensions, 0); + } + + /// + internal static byte[] WriteExtensionsData(IDictionary extensions, int bindersSize) { MemoryStream buf = new MemoryStream(); - WriteExtensionsData(extensions, buf); + WriteExtensionsData(extensions, buf, bindersSize); return buf.ToArray(); } /// internal static void WriteExtensionsData(IDictionary extensions, MemoryStream buf) + { + WriteExtensionsData(extensions, buf, 0); + } + + /// + internal static void WriteExtensionsData(IDictionary extensions, MemoryStream buf, int bindersSize) { /* * NOTE: There are reports of servers that don't accept a zero-length extension as the last @@ -1852,6 +1873,23 @@ namespace Org.BouncyCastle.Tls */ WriteSelectedExtensions(buf, extensions, true); WriteSelectedExtensions(buf, extensions, false); + WritePreSharedKeyExtension(buf, extensions, bindersSize); + } + + /// + internal static void WritePreSharedKeyExtension(MemoryStream buf, IDictionary extensions, int bindersSize) + { + byte[] extension_data = (byte[])extensions[ExtensionType.pre_shared_key]; + if (null != extension_data) + { + TlsUtilities.CheckUint16(ExtensionType.pre_shared_key); + TlsUtilities.WriteUint16(ExtensionType.pre_shared_key, buf); + + int lengthWithBinders = extension_data.Length + bindersSize; + TlsUtilities.CheckUint16(lengthWithBinders); + TlsUtilities.WriteUint16(lengthWithBinders, buf); + buf.Write(extension_data, 0, extension_data.Length); + } } /// @@ -1860,6 +1898,11 @@ namespace Org.BouncyCastle.Tls foreach (Int32 key in extensions.Keys) { int extension_type = key; + + // NOTE: Must be last; handled by 'WritePreSharedKeyExtension' + if (ExtensionType.pre_shared_key == extension_type) + continue; + byte[] extension_data = (byte[])extensions[key]; if (selectEmpty == (extension_data.Length == 0)) diff --git a/crypto/src/tls/TlsUtilities.cs b/crypto/src/tls/TlsUtilities.cs index c0ccfe9be..72c41ef05 100644 --- a/crypto/src/tls/TlsUtilities.cs +++ b/crypto/src/tls/TlsUtilities.cs @@ -1504,21 +1504,20 @@ namespace Org.BouncyCastle.Tls return Prf(sp, preMasterSecret, asciiLabel, seed, 48); } - internal static byte[] CalculatePskBinder(TlsCrypto crypto, bool isExternalPsk, int pskPRFAlgorithm, + internal static byte[] CalculatePskBinder(TlsCrypto crypto, bool isExternalPsk, int pskCryptoHashAlgorithm, TlsSecret earlySecret, byte[] transcriptHash) { - int prfCryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(pskPRFAlgorithm); - int prfHashLength = TlsCryptoUtilities.GetHashOutputSize(prfCryptoHashAlgorithm); + int prfHashLength = TlsCryptoUtilities.GetHashOutputSize(pskCryptoHashAlgorithm); string label = isExternalPsk ? "ext binder" : "res binder"; - byte[] emptyTranscriptHash = crypto.CreateHash(prfCryptoHashAlgorithm).CalculateHash(); + byte[] emptyTranscriptHash = crypto.CreateHash(pskCryptoHashAlgorithm).CalculateHash(); - TlsSecret binderKey = DeriveSecret(prfCryptoHashAlgorithm, prfHashLength, earlySecret, label, + TlsSecret binderKey = DeriveSecret(pskCryptoHashAlgorithm, prfHashLength, earlySecret, label, emptyTranscriptHash); try { - return CalculateFinishedHmac(prfCryptoHashAlgorithm, prfHashLength, binderKey, transcriptHash); + return CalculateFinishedHmac(pskCryptoHashAlgorithm, prfHashLength, binderKey, transcriptHash); } finally { -- cgit 1.4.1