diff options
Diffstat (limited to '')
-rw-r--r-- | crypto/src/tls/OfferedPsks.cs | 79 |
1 files changed, 68 insertions, 11 deletions
diff --git a/crypto/src/tls/OfferedPsks.cs b/crypto/src/tls/OfferedPsks.cs index 597ec195c..5419a19d1 100644 --- a/crypto/src/tls/OfferedPsks.cs +++ b/crypto/src/tls/OfferedPsks.cs @@ -2,6 +2,7 @@ using System.Collections; using System.IO; +using Org.BouncyCastle.Tls.Crypto; using Org.BouncyCastle.Utilities; namespace Org.BouncyCastle.Tls @@ -11,12 +12,17 @@ namespace Org.BouncyCastle.Tls private readonly IList m_identities; private readonly IList m_binders; - public OfferedPsks(IList identities, IList binders) + public OfferedPsks(IList identities) + : this(identities, null) + { + } + + private OfferedPsks(IList identities, IList binders) { if (null == identities || identities.Count < 1) throw new ArgumentException("cannot be null or empty", "identities"); - if (null == binders || identities.Count != binders.Count) - throw new ArgumentException("must be non-null and the same length as 'identities'", "binders"); + if (null != binders && identities.Count != binders.Count) + throw new ArgumentException("must be the same length as 'identities' (or null)", "binders"); this.m_identities = identities; this.m_binders = binders; @@ -37,14 +43,14 @@ namespace Org.BouncyCastle.Tls { // identities { - int totalLengthIdentities = 0; + int lengthOfIdentitiesList = 0; foreach (PskIdentity identity in m_identities) { - totalLengthIdentities += 2 + identity.Identity.Length + 4; + lengthOfIdentitiesList += identity.GetEncodedLength(); } - TlsUtilities.CheckUint16(totalLengthIdentities); - TlsUtilities.WriteUint16(totalLengthIdentities, output); + TlsUtilities.CheckUint16(lengthOfIdentitiesList); + TlsUtilities.WriteUint16(lengthOfIdentitiesList, output); foreach (PskIdentity identity in m_identities) { @@ -53,15 +59,16 @@ namespace Org.BouncyCastle.Tls } // binders + if (null != m_binders) { - int totalLengthBinders = 0; + int lengthOfBindersList = 0; foreach (byte[] binder in m_binders) { - totalLengthBinders += 1 + binder.Length; + lengthOfBindersList += 1 + binder.Length; } - TlsUtilities.CheckUint16(totalLengthBinders); - TlsUtilities.WriteUint16(totalLengthBinders, output); + TlsUtilities.CheckUint16(lengthOfBindersList); + TlsUtilities.WriteUint16(lengthOfBindersList, output); foreach (byte[] binder in m_binders) { @@ -71,6 +78,56 @@ namespace Org.BouncyCastle.Tls } /// <exception cref="IOException"/> + internal static void EncodeBinders(Stream output, TlsCrypto crypto, TlsHandshakeHash handshakeHash, + TlsPsk[] psks, TlsSecret[] earlySecrets, int expectedLengthOfBindersList) + { + TlsUtilities.CheckUint16(expectedLengthOfBindersList); + TlsUtilities.WriteUint16(expectedLengthOfBindersList, output); + + int lengthOfBindersList = 0; + for (int i = 0; i < psks.Length; ++i) + { + TlsPsk psk = psks[i]; + TlsSecret earlySecret = earlySecrets[i]; + + // TODO[tls13-psk] Handle resumption PSKs + bool isExternalPsk = true; + int pskCryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(psk.PrfAlgorithm); + + // TODO[tls13-psk] Cache the transcript hashes per algorithm to avoid duplicates for multiple PSKs + TlsHash hash = crypto.CreateHash(pskCryptoHashAlgorithm); + handshakeHash.CopyBufferTo(new TlsHashSink(hash)); + byte[] transcriptHash = hash.CalculateHash(); + + byte[] binder = TlsUtilities.CalculatePskBinder(crypto, isExternalPsk, pskCryptoHashAlgorithm, + earlySecret, transcriptHash); + + lengthOfBindersList += 1 + binder.Length; + TlsUtilities.WriteOpaque8(binder, output); + } + + if (expectedLengthOfBindersList != lengthOfBindersList) + throw new TlsFatalAlert(AlertDescription.internal_error); + } + + /// <exception cref="IOException"/> + internal static int GetLengthOfBindersList(TlsPsk[] psks) + { + int lengthOfBindersList = 0; + for (int i = 0; i < psks.Length; ++i) + { + TlsPsk psk = psks[i]; + + int prfAlgorithm = psk.PrfAlgorithm; + int prfCryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(prfAlgorithm); + + lengthOfBindersList += 1 + TlsCryptoUtilities.GetHashOutputSize(prfCryptoHashAlgorithm); + } + TlsUtilities.CheckUint16(lengthOfBindersList); + return lengthOfBindersList; + } + + /// <exception cref="IOException"/> public static OfferedPsks Parse(Stream input) { IList identities = Platform.CreateArrayList(); |