summary refs log tree commit diff
path: root/crypto/src/tls/OfferedPsks.cs
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src/tls/OfferedPsks.cs')
-rw-r--r--crypto/src/tls/OfferedPsks.cs79
1 files changed, 68 insertions, 11 deletions
diff --git a/crypto/src/tls/OfferedPsks.cs b/crypto/src/tls/OfferedPsks.cs
index 597ec195c..5419a19d1 100644
--- a/crypto/src/tls/OfferedPsks.cs
+++ b/crypto/src/tls/OfferedPsks.cs
@@ -2,6 +2,7 @@
 using System.Collections;
 using System.IO;
 
+using Org.BouncyCastle.Tls.Crypto;
 using Org.BouncyCastle.Utilities;
 
 namespace Org.BouncyCastle.Tls
@@ -11,12 +12,17 @@ namespace Org.BouncyCastle.Tls
         private readonly IList m_identities;
         private readonly IList m_binders;
 
-        public OfferedPsks(IList identities, IList binders)
+        public OfferedPsks(IList identities)
+            : this(identities, null)
+        {
+        }
+
+        private OfferedPsks(IList identities, IList binders)
         {
             if (null == identities || identities.Count < 1)
                 throw new ArgumentException("cannot be null or empty", "identities");
-            if (null == binders || identities.Count != binders.Count)
-                throw new ArgumentException("must be non-null and the same length as 'identities'", "binders");
+            if (null != binders && identities.Count != binders.Count)
+                throw new ArgumentException("must be the same length as 'identities' (or null)", "binders");
 
             this.m_identities = identities;
             this.m_binders = binders;
@@ -37,14 +43,14 @@ namespace Org.BouncyCastle.Tls
         {
             // identities
             {
-                int totalLengthIdentities = 0;
+                int lengthOfIdentitiesList = 0;
                 foreach (PskIdentity identity in m_identities)
                 {
-                    totalLengthIdentities += 2 + identity.Identity.Length + 4;
+                    lengthOfIdentitiesList += identity.GetEncodedLength();
                 }
 
-                TlsUtilities.CheckUint16(totalLengthIdentities);
-                TlsUtilities.WriteUint16(totalLengthIdentities, output);
+                TlsUtilities.CheckUint16(lengthOfIdentitiesList);
+                TlsUtilities.WriteUint16(lengthOfIdentitiesList, output);
 
                 foreach (PskIdentity identity in m_identities)
                 {
@@ -53,15 +59,16 @@ namespace Org.BouncyCastle.Tls
             }
 
             // binders
+            if (null != m_binders)
             {
-                int totalLengthBinders = 0;
+                int lengthOfBindersList = 0;
                 foreach (byte[] binder in m_binders)
                 {
-                    totalLengthBinders += 1 + binder.Length;
+                    lengthOfBindersList += 1 + binder.Length;
                 }
 
-                TlsUtilities.CheckUint16(totalLengthBinders);
-                TlsUtilities.WriteUint16(totalLengthBinders, output);
+                TlsUtilities.CheckUint16(lengthOfBindersList);
+                TlsUtilities.WriteUint16(lengthOfBindersList, output);
 
                 foreach (byte[] binder in m_binders)
                 {
@@ -71,6 +78,56 @@ namespace Org.BouncyCastle.Tls
         }
 
         /// <exception cref="IOException"/>
+        internal static void EncodeBinders(Stream output, TlsCrypto crypto, TlsHandshakeHash handshakeHash,
+            TlsPsk[] psks, TlsSecret[] earlySecrets, int expectedLengthOfBindersList)
+        {
+            TlsUtilities.CheckUint16(expectedLengthOfBindersList);
+            TlsUtilities.WriteUint16(expectedLengthOfBindersList, output);
+
+            int lengthOfBindersList = 0;
+            for (int i = 0; i < psks.Length; ++i)
+            {
+                TlsPsk psk = psks[i];
+                TlsSecret earlySecret = earlySecrets[i];
+
+                // TODO[tls13-psk] Handle resumption PSKs
+                bool isExternalPsk = true;
+                int pskCryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(psk.PrfAlgorithm);
+
+                // TODO[tls13-psk] Cache the transcript hashes per algorithm to avoid duplicates for multiple PSKs
+                TlsHash hash = crypto.CreateHash(pskCryptoHashAlgorithm);
+                handshakeHash.CopyBufferTo(new TlsHashSink(hash));
+                byte[] transcriptHash = hash.CalculateHash();
+
+                byte[] binder = TlsUtilities.CalculatePskBinder(crypto, isExternalPsk, pskCryptoHashAlgorithm,
+                    earlySecret, transcriptHash);
+
+                lengthOfBindersList += 1 + binder.Length;
+                TlsUtilities.WriteOpaque8(binder, output);
+            }
+
+            if (expectedLengthOfBindersList != lengthOfBindersList)
+                throw new TlsFatalAlert(AlertDescription.internal_error);
+        }
+
+        /// <exception cref="IOException"/>
+        internal static int GetLengthOfBindersList(TlsPsk[] psks)
+        {
+            int lengthOfBindersList = 0;
+            for (int i = 0; i < psks.Length; ++i)
+            {
+                TlsPsk psk = psks[i];
+
+                int prfAlgorithm = psk.PrfAlgorithm;
+                int prfCryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(prfAlgorithm);
+
+                lengthOfBindersList += 1 + TlsCryptoUtilities.GetHashOutputSize(prfCryptoHashAlgorithm);
+            }
+            TlsUtilities.CheckUint16(lengthOfBindersList);
+            return lengthOfBindersList;
+        }
+
+        /// <exception cref="IOException"/>
         public static OfferedPsks Parse(Stream input)
         {
             IList identities = Platform.CreateArrayList();