summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crypto/src/tls/ClientHello.cs13
-rw-r--r--crypto/src/tls/DtlsClientProtocol.cs2
-rw-r--r--crypto/src/tls/HandshakeMessageOutput.cs13
-rw-r--r--crypto/src/tls/OfferedPsks.cs79
-rw-r--r--crypto/src/tls/PskIdentity.cs5
-rw-r--r--crypto/src/tls/TlsClientProtocol.cs15
-rw-r--r--crypto/src/tls/TlsProtocol.cs49
-rw-r--r--crypto/src/tls/TlsUtilities.cs11
8 files changed, 149 insertions, 38 deletions
diff --git a/crypto/src/tls/ClientHello.cs b/crypto/src/tls/ClientHello.cs
index 50a33ac39..700d424cd 100644
--- a/crypto/src/tls/ClientHello.cs
+++ b/crypto/src/tls/ClientHello.cs
@@ -15,9 +15,10 @@ namespace Org.BouncyCastle.Tls
         private readonly byte[] m_cookie;
         private readonly int[] m_cipherSuites;
         private readonly IDictionary m_extensions;
+        private readonly int m_bindersSize;
 
         public ClientHello(ProtocolVersion version, byte[] random, byte[] sessionID, byte[] cookie,
-            int[] cipherSuites, IDictionary extensions)
+            int[] cipherSuites, IDictionary extensions, int bindersSize)
         {
             this.m_version = version;
             this.m_random = random;
@@ -25,6 +26,12 @@ namespace Org.BouncyCastle.Tls
             this.m_cookie = cookie;
             this.m_cipherSuites = cipherSuites;
             this.m_extensions = extensions;
+            this.m_bindersSize = bindersSize;
+        }
+
+        public int BindersSize
+        {
+            get { return m_bindersSize; }
         }
 
         public int[] CipherSuites
@@ -78,7 +85,7 @@ namespace Org.BouncyCastle.Tls
 
             TlsUtilities.WriteUint8ArrayWithUint8Length(new short[]{ CompressionMethod.cls_null }, output);
 
-            TlsProtocol.WriteExtensions(output, m_extensions);
+            TlsProtocol.WriteExtensions(output, m_extensions, m_bindersSize);
         }
 
         /// <summary>Parse a <see cref="ClientHello"/> from a <see cref="MemoryStream"/>.</summary>
@@ -161,7 +168,7 @@ namespace Org.BouncyCastle.Tls
                 extensions = TlsProtocol.ReadExtensionsDataClientHello(extBytes);
             }
 
-            return new ClientHello(clientVersion, random, sessionID, cookie, cipherSuites, extensions);
+            return new ClientHello(clientVersion, random, sessionID, cookie, cipherSuites, extensions, 0);
         }
     }
 }
diff --git a/crypto/src/tls/DtlsClientProtocol.cs b/crypto/src/tls/DtlsClientProtocol.cs
index dea35a28b..cd2fff709 100644
--- a/crypto/src/tls/DtlsClientProtocol.cs
+++ b/crypto/src/tls/DtlsClientProtocol.cs
@@ -513,7 +513,7 @@ namespace Org.BouncyCastle.Tls
 
 
             ClientHello clientHello = new ClientHello(legacy_version, securityParameters.ClientRandom, session_id,
-                TlsUtilities.EmptyBytes, state.offeredCipherSuites, state.clientExtensions);
+                TlsUtilities.EmptyBytes, state.offeredCipherSuites, state.clientExtensions, 0);
 
             MemoryStream buf = new MemoryStream();
             clientHello.Encode(state.clientContext, buf);
diff --git a/crypto/src/tls/HandshakeMessageOutput.cs b/crypto/src/tls/HandshakeMessageOutput.cs
index 97e9a84af..ff45ce6f3 100644
--- a/crypto/src/tls/HandshakeMessageOutput.cs
+++ b/crypto/src/tls/HandshakeMessageOutput.cs
@@ -59,12 +59,10 @@ namespace Org.BouncyCastle.Tls
             Platform.Dispose(this);
         }
 
-        internal void PrepareClientHello(TlsHandshakeHash handshakeHash, int totalBindersLength)
+        internal void PrepareClientHello(TlsHandshakeHash handshakeHash, int bindersSize)
         {
-            TlsUtilities.CheckUint16(totalBindersLength);
-
             // Patch actual length back in
-            int bodyLength = (int)Length - 4 + totalBindersLength;
+            int bodyLength = (int)Length - 4 + bindersSize;
             TlsUtilities.CheckUint24(bodyLength);
 
             Seek(1L, SeekOrigin.Begin);
@@ -83,8 +81,7 @@ namespace Org.BouncyCastle.Tls
             Seek(0L, SeekOrigin.End);
         }
 
-        internal void SendClientHello(TlsClientProtocol clientProtocol, TlsHandshakeHash handshakeHash,
-            int totalBindersLength)
+        internal void SendClientHello(TlsClientProtocol clientProtocol, TlsHandshakeHash handshakeHash, int bindersSize)
         {
 #if PORTABLE
             byte[] buf = ToArray();
@@ -94,9 +91,9 @@ namespace Org.BouncyCastle.Tls
             int count = (int)Length;
 #endif
 
-            if (totalBindersLength > 0)
+            if (bindersSize > 0)
             {
-                handshakeHash.Update(buf, count - totalBindersLength, totalBindersLength);
+                handshakeHash.Update(buf, count - bindersSize, bindersSize);
             }
 
             clientProtocol.WriteHandshakeMessage(buf, 0, count);
diff --git a/crypto/src/tls/OfferedPsks.cs b/crypto/src/tls/OfferedPsks.cs
index 597ec195c..5419a19d1 100644
--- a/crypto/src/tls/OfferedPsks.cs
+++ b/crypto/src/tls/OfferedPsks.cs
@@ -2,6 +2,7 @@
 using System.Collections;
 using System.IO;
 
+using Org.BouncyCastle.Tls.Crypto;
 using Org.BouncyCastle.Utilities;
 
 namespace Org.BouncyCastle.Tls
@@ -11,12 +12,17 @@ namespace Org.BouncyCastle.Tls
         private readonly IList m_identities;
         private readonly IList m_binders;
 
-        public OfferedPsks(IList identities, IList binders)
+        public OfferedPsks(IList identities)
+            : this(identities, null)
+        {
+        }
+
+        private OfferedPsks(IList identities, IList binders)
         {
             if (null == identities || identities.Count < 1)
                 throw new ArgumentException("cannot be null or empty", "identities");
-            if (null == binders || identities.Count != binders.Count)
-                throw new ArgumentException("must be non-null and the same length as 'identities'", "binders");
+            if (null != binders && identities.Count != binders.Count)
+                throw new ArgumentException("must be the same length as 'identities' (or null)", "binders");
 
             this.m_identities = identities;
             this.m_binders = binders;
@@ -37,14 +43,14 @@ namespace Org.BouncyCastle.Tls
         {
             // identities
             {
-                int totalLengthIdentities = 0;
+                int lengthOfIdentitiesList = 0;
                 foreach (PskIdentity identity in m_identities)
                 {
-                    totalLengthIdentities += 2 + identity.Identity.Length + 4;
+                    lengthOfIdentitiesList += identity.GetEncodedLength();
                 }
 
-                TlsUtilities.CheckUint16(totalLengthIdentities);
-                TlsUtilities.WriteUint16(totalLengthIdentities, output);
+                TlsUtilities.CheckUint16(lengthOfIdentitiesList);
+                TlsUtilities.WriteUint16(lengthOfIdentitiesList, output);
 
                 foreach (PskIdentity identity in m_identities)
                 {
@@ -53,15 +59,16 @@ namespace Org.BouncyCastle.Tls
             }
 
             // binders
+            if (null != m_binders)
             {
-                int totalLengthBinders = 0;
+                int lengthOfBindersList = 0;
                 foreach (byte[] binder in m_binders)
                 {
-                    totalLengthBinders += 1 + binder.Length;
+                    lengthOfBindersList += 1 + binder.Length;
                 }
 
-                TlsUtilities.CheckUint16(totalLengthBinders);
-                TlsUtilities.WriteUint16(totalLengthBinders, output);
+                TlsUtilities.CheckUint16(lengthOfBindersList);
+                TlsUtilities.WriteUint16(lengthOfBindersList, output);
 
                 foreach (byte[] binder in m_binders)
                 {
@@ -71,6 +78,56 @@ namespace Org.BouncyCastle.Tls
         }
 
         /// <exception cref="IOException"/>
+        internal static void EncodeBinders(Stream output, TlsCrypto crypto, TlsHandshakeHash handshakeHash,
+            TlsPsk[] psks, TlsSecret[] earlySecrets, int expectedLengthOfBindersList)
+        {
+            TlsUtilities.CheckUint16(expectedLengthOfBindersList);
+            TlsUtilities.WriteUint16(expectedLengthOfBindersList, output);
+
+            int lengthOfBindersList = 0;
+            for (int i = 0; i < psks.Length; ++i)
+            {
+                TlsPsk psk = psks[i];
+                TlsSecret earlySecret = earlySecrets[i];
+
+                // TODO[tls13-psk] Handle resumption PSKs
+                bool isExternalPsk = true;
+                int pskCryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(psk.PrfAlgorithm);
+
+                // TODO[tls13-psk] Cache the transcript hashes per algorithm to avoid duplicates for multiple PSKs
+                TlsHash hash = crypto.CreateHash(pskCryptoHashAlgorithm);
+                handshakeHash.CopyBufferTo(new TlsHashSink(hash));
+                byte[] transcriptHash = hash.CalculateHash();
+
+                byte[] binder = TlsUtilities.CalculatePskBinder(crypto, isExternalPsk, pskCryptoHashAlgorithm,
+                    earlySecret, transcriptHash);
+
+                lengthOfBindersList += 1 + binder.Length;
+                TlsUtilities.WriteOpaque8(binder, output);
+            }
+
+            if (expectedLengthOfBindersList != lengthOfBindersList)
+                throw new TlsFatalAlert(AlertDescription.internal_error);
+        }
+
+        /// <exception cref="IOException"/>
+        internal static int GetLengthOfBindersList(TlsPsk[] psks)
+        {
+            int lengthOfBindersList = 0;
+            for (int i = 0; i < psks.Length; ++i)
+            {
+                TlsPsk psk = psks[i];
+
+                int prfAlgorithm = psk.PrfAlgorithm;
+                int prfCryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(prfAlgorithm);
+
+                lengthOfBindersList += 1 + TlsCryptoUtilities.GetHashOutputSize(prfCryptoHashAlgorithm);
+            }
+            TlsUtilities.CheckUint16(lengthOfBindersList);
+            return lengthOfBindersList;
+        }
+
+        /// <exception cref="IOException"/>
         public static OfferedPsks Parse(Stream input)
         {
             IList identities = Platform.CreateArrayList();
diff --git a/crypto/src/tls/PskIdentity.cs b/crypto/src/tls/PskIdentity.cs
index 9b24527bb..082907419 100644
--- a/crypto/src/tls/PskIdentity.cs
+++ b/crypto/src/tls/PskIdentity.cs
@@ -21,6 +21,11 @@ namespace Org.BouncyCastle.Tls
             this.m_obfuscatedTicketAge = obfuscatedTicketAge;
         }
 
+        public int GetEncodedLength()
+        {
+            return 6 + m_identity.Length;
+        }
+
         public byte[] Identity
         {
             get { return m_identity; }
diff --git a/crypto/src/tls/TlsClientProtocol.cs b/crypto/src/tls/TlsClientProtocol.cs
index 7a92220dc..8fb1a39b7 100644
--- a/crypto/src/tls/TlsClientProtocol.cs
+++ b/crypto/src/tls/TlsClientProtocol.cs
@@ -1474,6 +1474,8 @@ namespace Org.BouncyCastle.Tls
         /// <exception cref="IOException"/>
         protected virtual void Send13ClientHelloRetry()
         {
+            // TODO[tls13-psk] Create a new ClientHello object and handle any changes to the bindersSize
+
             IDictionary clientHelloExtensions = m_clientHello.Extensions;
 
             clientHelloExtensions.Remove(ExtensionType.cookie);
@@ -1679,8 +1681,12 @@ namespace Org.BouncyCastle.Tls
 
 
 
+            // TODO[tls13-psk] Calculate the total length of the binders that will be added.
+            int bindersSize = 0;
+            //int bindersSize = 2 + lengthOfBindersList;
+
             this.m_clientHello = new ClientHello(legacy_version, securityParameters.ClientRandom, legacy_session_id,
-                null, offeredCipherSuites, m_clientExtensions);
+                null, offeredCipherSuites, m_clientExtensions, bindersSize);
 
             SendClientHelloMessage();
         }
@@ -1691,14 +1697,11 @@ namespace Org.BouncyCastle.Tls
             HandshakeMessageOutput message = new HandshakeMessageOutput(HandshakeType.client_hello);
             m_clientHello.Encode(m_tlsClientContext, message);
 
-            // TODO[tls13-psk] Calculate the total length of the binders that will be added.
-            int totalBindersLength = 0;
-
-            message.PrepareClientHello(m_handshakeHash, totalBindersLength);
+            message.PrepareClientHello(m_handshakeHash, m_clientHello.BindersSize);
 
             // TODO[tls13-psk] Calculate any PSK binders and write them to 'message' here. 
 
-            message.SendClientHello(this, m_handshakeHash, totalBindersLength);
+            message.SendClientHello(this, m_handshakeHash, m_clientHello.BindersSize);
         }
 
         /// <exception cref="IOException"/>
diff --git a/crypto/src/tls/TlsProtocol.cs b/crypto/src/tls/TlsProtocol.cs
index d4960e3c8..f05c09a1b 100644
--- a/crypto/src/tls/TlsProtocol.cs
+++ b/crypto/src/tls/TlsProtocol.cs
@@ -1827,31 +1827,69 @@ namespace Org.BouncyCastle.Tls
         /// <exception cref="IOException"/>
         internal static void WriteExtensions(Stream output, IDictionary extensions)
         {
+            WriteExtensions(output, extensions, 0);
+        }
+
+        /// <exception cref="IOException"/>
+        internal static void WriteExtensions(Stream output, IDictionary extensions, int bindersSize)
+        {
             if (null == extensions || extensions.Count < 1)
                 return;
 
-            byte[] extBytes = WriteExtensionsData(extensions);
+            byte[] extBytes = WriteExtensionsData(extensions, bindersSize);
 
-            TlsUtilities.WriteOpaque16(extBytes, output);
+            int lengthWithBinders = extBytes.Length + bindersSize;
+            TlsUtilities.CheckUint16(lengthWithBinders);
+            TlsUtilities.WriteUint16(lengthWithBinders, output);
+            output.Write(extBytes, 0, extBytes.Length);
         }
 
         /// <exception cref="IOException"/>
         internal static byte[] WriteExtensionsData(IDictionary extensions)
         {
+            return WriteExtensionsData(extensions, 0);
+        }
+
+        /// <exception cref="IOException"/>
+        internal static byte[] WriteExtensionsData(IDictionary extensions, int bindersSize)
+        {
             MemoryStream buf = new MemoryStream();
-            WriteExtensionsData(extensions, buf);
+            WriteExtensionsData(extensions, buf, bindersSize);
             return buf.ToArray();
         }
 
         /// <exception cref="IOException"/>
         internal static void WriteExtensionsData(IDictionary extensions, MemoryStream buf)
         {
+            WriteExtensionsData(extensions, buf, 0);
+        }
+
+        /// <exception cref="IOException"/>
+        internal static void WriteExtensionsData(IDictionary extensions, MemoryStream buf, int bindersSize)
+        {
             /*
              * NOTE: There are reports of servers that don't accept a zero-length extension as the last
              * one, so we write out any zero-length ones first as a best-effort workaround.
              */
             WriteSelectedExtensions(buf, extensions, true);
             WriteSelectedExtensions(buf, extensions, false);
+            WritePreSharedKeyExtension(buf, extensions, bindersSize);
+        }
+
+        /// <exception cref="IOException"/>
+        internal static void WritePreSharedKeyExtension(MemoryStream buf, IDictionary extensions, int bindersSize)
+        {
+            byte[] extension_data = (byte[])extensions[ExtensionType.pre_shared_key];
+            if (null != extension_data)
+            {
+                TlsUtilities.CheckUint16(ExtensionType.pre_shared_key);
+                TlsUtilities.WriteUint16(ExtensionType.pre_shared_key, buf);
+
+                int lengthWithBinders = extension_data.Length + bindersSize;
+                TlsUtilities.CheckUint16(lengthWithBinders);
+                TlsUtilities.WriteUint16(lengthWithBinders, buf);
+                buf.Write(extension_data, 0, extension_data.Length);
+            }
         }
 
         /// <exception cref="IOException"/>
@@ -1860,6 +1898,11 @@ namespace Org.BouncyCastle.Tls
             foreach (Int32 key in extensions.Keys)
             {
                 int extension_type = key;
+
+                // NOTE: Must be last; handled by 'WritePreSharedKeyExtension'
+                if (ExtensionType.pre_shared_key == extension_type)
+                    continue;
+
                 byte[] extension_data = (byte[])extensions[key];
 
                 if (selectEmpty == (extension_data.Length == 0))
diff --git a/crypto/src/tls/TlsUtilities.cs b/crypto/src/tls/TlsUtilities.cs
index c0ccfe9be..72c41ef05 100644
--- a/crypto/src/tls/TlsUtilities.cs
+++ b/crypto/src/tls/TlsUtilities.cs
@@ -1504,21 +1504,20 @@ namespace Org.BouncyCastle.Tls
             return Prf(sp, preMasterSecret, asciiLabel, seed, 48);
         }
 
-        internal static byte[] CalculatePskBinder(TlsCrypto crypto, bool isExternalPsk, int pskPRFAlgorithm,
+        internal static byte[] CalculatePskBinder(TlsCrypto crypto, bool isExternalPsk, int pskCryptoHashAlgorithm,
             TlsSecret earlySecret, byte[] transcriptHash)
         {
-            int prfCryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(pskPRFAlgorithm);
-            int prfHashLength = TlsCryptoUtilities.GetHashOutputSize(prfCryptoHashAlgorithm);
+            int prfHashLength = TlsCryptoUtilities.GetHashOutputSize(pskCryptoHashAlgorithm);
 
             string label = isExternalPsk ? "ext binder" : "res binder";
-            byte[] emptyTranscriptHash = crypto.CreateHash(prfCryptoHashAlgorithm).CalculateHash();
+            byte[] emptyTranscriptHash = crypto.CreateHash(pskCryptoHashAlgorithm).CalculateHash();
 
-            TlsSecret binderKey = DeriveSecret(prfCryptoHashAlgorithm, prfHashLength, earlySecret, label,
+            TlsSecret binderKey = DeriveSecret(pskCryptoHashAlgorithm, prfHashLength, earlySecret, label,
                 emptyTranscriptHash);
 
             try
             {
-                return CalculateFinishedHmac(prfCryptoHashAlgorithm, prfHashLength, binderKey, transcriptHash);
+                return CalculateFinishedHmac(pskCryptoHashAlgorithm, prfHashLength, binderKey, transcriptHash);
             }
             finally
             {