diff --git a/crypto/src/tls/ClientHello.cs b/crypto/src/tls/ClientHello.cs
index 50a33ac39..700d424cd 100644
--- a/crypto/src/tls/ClientHello.cs
+++ b/crypto/src/tls/ClientHello.cs
@@ -15,9 +15,10 @@ namespace Org.BouncyCastle.Tls
private readonly byte[] m_cookie;
private readonly int[] m_cipherSuites;
private readonly IDictionary m_extensions;
+ private readonly int m_bindersSize;
public ClientHello(ProtocolVersion version, byte[] random, byte[] sessionID, byte[] cookie,
- int[] cipherSuites, IDictionary extensions)
+ int[] cipherSuites, IDictionary extensions, int bindersSize)
{
this.m_version = version;
this.m_random = random;
@@ -25,6 +26,12 @@ namespace Org.BouncyCastle.Tls
this.m_cookie = cookie;
this.m_cipherSuites = cipherSuites;
this.m_extensions = extensions;
+ this.m_bindersSize = bindersSize;
+ }
+
+ public int BindersSize
+ {
+ get { return m_bindersSize; }
}
public int[] CipherSuites
@@ -78,7 +85,7 @@ namespace Org.BouncyCastle.Tls
TlsUtilities.WriteUint8ArrayWithUint8Length(new short[]{ CompressionMethod.cls_null }, output);
- TlsProtocol.WriteExtensions(output, m_extensions);
+ TlsProtocol.WriteExtensions(output, m_extensions, m_bindersSize);
}
/// <summary>Parse a <see cref="ClientHello"/> from a <see cref="MemoryStream"/>.</summary>
@@ -161,7 +168,7 @@ namespace Org.BouncyCastle.Tls
extensions = TlsProtocol.ReadExtensionsDataClientHello(extBytes);
}
- return new ClientHello(clientVersion, random, sessionID, cookie, cipherSuites, extensions);
+ return new ClientHello(clientVersion, random, sessionID, cookie, cipherSuites, extensions, 0);
}
}
}
diff --git a/crypto/src/tls/DtlsClientProtocol.cs b/crypto/src/tls/DtlsClientProtocol.cs
index dea35a28b..cd2fff709 100644
--- a/crypto/src/tls/DtlsClientProtocol.cs
+++ b/crypto/src/tls/DtlsClientProtocol.cs
@@ -513,7 +513,7 @@ namespace Org.BouncyCastle.Tls
ClientHello clientHello = new ClientHello(legacy_version, securityParameters.ClientRandom, session_id,
- TlsUtilities.EmptyBytes, state.offeredCipherSuites, state.clientExtensions);
+ TlsUtilities.EmptyBytes, state.offeredCipherSuites, state.clientExtensions, 0);
MemoryStream buf = new MemoryStream();
clientHello.Encode(state.clientContext, buf);
diff --git a/crypto/src/tls/HandshakeMessageOutput.cs b/crypto/src/tls/HandshakeMessageOutput.cs
index 97e9a84af..ff45ce6f3 100644
--- a/crypto/src/tls/HandshakeMessageOutput.cs
+++ b/crypto/src/tls/HandshakeMessageOutput.cs
@@ -59,12 +59,10 @@ namespace Org.BouncyCastle.Tls
Platform.Dispose(this);
}
- internal void PrepareClientHello(TlsHandshakeHash handshakeHash, int totalBindersLength)
+ internal void PrepareClientHello(TlsHandshakeHash handshakeHash, int bindersSize)
{
- TlsUtilities.CheckUint16(totalBindersLength);
-
// Patch actual length back in
- int bodyLength = (int)Length - 4 + totalBindersLength;
+ int bodyLength = (int)Length - 4 + bindersSize;
TlsUtilities.CheckUint24(bodyLength);
Seek(1L, SeekOrigin.Begin);
@@ -83,8 +81,7 @@ namespace Org.BouncyCastle.Tls
Seek(0L, SeekOrigin.End);
}
- internal void SendClientHello(TlsClientProtocol clientProtocol, TlsHandshakeHash handshakeHash,
- int totalBindersLength)
+ internal void SendClientHello(TlsClientProtocol clientProtocol, TlsHandshakeHash handshakeHash, int bindersSize)
{
#if PORTABLE
byte[] buf = ToArray();
@@ -94,9 +91,9 @@ namespace Org.BouncyCastle.Tls
int count = (int)Length;
#endif
- if (totalBindersLength > 0)
+ if (bindersSize > 0)
{
- handshakeHash.Update(buf, count - totalBindersLength, totalBindersLength);
+ handshakeHash.Update(buf, count - bindersSize, bindersSize);
}
clientProtocol.WriteHandshakeMessage(buf, 0, count);
diff --git a/crypto/src/tls/OfferedPsks.cs b/crypto/src/tls/OfferedPsks.cs
index 597ec195c..5419a19d1 100644
--- a/crypto/src/tls/OfferedPsks.cs
+++ b/crypto/src/tls/OfferedPsks.cs
@@ -2,6 +2,7 @@
using System.Collections;
using System.IO;
+using Org.BouncyCastle.Tls.Crypto;
using Org.BouncyCastle.Utilities;
namespace Org.BouncyCastle.Tls
@@ -11,12 +12,17 @@ namespace Org.BouncyCastle.Tls
private readonly IList m_identities;
private readonly IList m_binders;
- public OfferedPsks(IList identities, IList binders)
+ public OfferedPsks(IList identities)
+ : this(identities, null)
+ {
+ }
+
+ private OfferedPsks(IList identities, IList binders)
{
if (null == identities || identities.Count < 1)
throw new ArgumentException("cannot be null or empty", "identities");
- if (null == binders || identities.Count != binders.Count)
- throw new ArgumentException("must be non-null and the same length as 'identities'", "binders");
+ if (null != binders && identities.Count != binders.Count)
+ throw new ArgumentException("must be the same length as 'identities' (or null)", "binders");
this.m_identities = identities;
this.m_binders = binders;
@@ -37,14 +43,14 @@ namespace Org.BouncyCastle.Tls
{
// identities
{
- int totalLengthIdentities = 0;
+ int lengthOfIdentitiesList = 0;
foreach (PskIdentity identity in m_identities)
{
- totalLengthIdentities += 2 + identity.Identity.Length + 4;
+ lengthOfIdentitiesList += identity.GetEncodedLength();
}
- TlsUtilities.CheckUint16(totalLengthIdentities);
- TlsUtilities.WriteUint16(totalLengthIdentities, output);
+ TlsUtilities.CheckUint16(lengthOfIdentitiesList);
+ TlsUtilities.WriteUint16(lengthOfIdentitiesList, output);
foreach (PskIdentity identity in m_identities)
{
@@ -53,15 +59,16 @@ namespace Org.BouncyCastle.Tls
}
// binders
+ if (null != m_binders)
{
- int totalLengthBinders = 0;
+ int lengthOfBindersList = 0;
foreach (byte[] binder in m_binders)
{
- totalLengthBinders += 1 + binder.Length;
+ lengthOfBindersList += 1 + binder.Length;
}
- TlsUtilities.CheckUint16(totalLengthBinders);
- TlsUtilities.WriteUint16(totalLengthBinders, output);
+ TlsUtilities.CheckUint16(lengthOfBindersList);
+ TlsUtilities.WriteUint16(lengthOfBindersList, output);
foreach (byte[] binder in m_binders)
{
@@ -71,6 +78,56 @@ namespace Org.BouncyCastle.Tls
}
/// <exception cref="IOException"/>
+ internal static void EncodeBinders(Stream output, TlsCrypto crypto, TlsHandshakeHash handshakeHash,
+ TlsPsk[] psks, TlsSecret[] earlySecrets, int expectedLengthOfBindersList)
+ {
+ TlsUtilities.CheckUint16(expectedLengthOfBindersList);
+ TlsUtilities.WriteUint16(expectedLengthOfBindersList, output);
+
+ int lengthOfBindersList = 0;
+ for (int i = 0; i < psks.Length; ++i)
+ {
+ TlsPsk psk = psks[i];
+ TlsSecret earlySecret = earlySecrets[i];
+
+ // TODO[tls13-psk] Handle resumption PSKs
+ bool isExternalPsk = true;
+ int pskCryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(psk.PrfAlgorithm);
+
+ // TODO[tls13-psk] Cache the transcript hashes per algorithm to avoid duplicates for multiple PSKs
+ TlsHash hash = crypto.CreateHash(pskCryptoHashAlgorithm);
+ handshakeHash.CopyBufferTo(new TlsHashSink(hash));
+ byte[] transcriptHash = hash.CalculateHash();
+
+ byte[] binder = TlsUtilities.CalculatePskBinder(crypto, isExternalPsk, pskCryptoHashAlgorithm,
+ earlySecret, transcriptHash);
+
+ lengthOfBindersList += 1 + binder.Length;
+ TlsUtilities.WriteOpaque8(binder, output);
+ }
+
+ if (expectedLengthOfBindersList != lengthOfBindersList)
+ throw new TlsFatalAlert(AlertDescription.internal_error);
+ }
+
+ /// <exception cref="IOException"/>
+ internal static int GetLengthOfBindersList(TlsPsk[] psks)
+ {
+ int lengthOfBindersList = 0;
+ for (int i = 0; i < psks.Length; ++i)
+ {
+ TlsPsk psk = psks[i];
+
+ int prfAlgorithm = psk.PrfAlgorithm;
+ int prfCryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(prfAlgorithm);
+
+ lengthOfBindersList += 1 + TlsCryptoUtilities.GetHashOutputSize(prfCryptoHashAlgorithm);
+ }
+ TlsUtilities.CheckUint16(lengthOfBindersList);
+ return lengthOfBindersList;
+ }
+
+ /// <exception cref="IOException"/>
public static OfferedPsks Parse(Stream input)
{
IList identities = Platform.CreateArrayList();
diff --git a/crypto/src/tls/PskIdentity.cs b/crypto/src/tls/PskIdentity.cs
index 9b24527bb..082907419 100644
--- a/crypto/src/tls/PskIdentity.cs
+++ b/crypto/src/tls/PskIdentity.cs
@@ -21,6 +21,11 @@ namespace Org.BouncyCastle.Tls
this.m_obfuscatedTicketAge = obfuscatedTicketAge;
}
+ public int GetEncodedLength()
+ {
+ return 6 + m_identity.Length;
+ }
+
public byte[] Identity
{
get { return m_identity; }
diff --git a/crypto/src/tls/TlsClientProtocol.cs b/crypto/src/tls/TlsClientProtocol.cs
index 7a92220dc..8fb1a39b7 100644
--- a/crypto/src/tls/TlsClientProtocol.cs
+++ b/crypto/src/tls/TlsClientProtocol.cs
@@ -1474,6 +1474,8 @@ namespace Org.BouncyCastle.Tls
/// <exception cref="IOException"/>
protected virtual void Send13ClientHelloRetry()
{
+ // TODO[tls13-psk] Create a new ClientHello object and handle any changes to the bindersSize
+
IDictionary clientHelloExtensions = m_clientHello.Extensions;
clientHelloExtensions.Remove(ExtensionType.cookie);
@@ -1679,8 +1681,12 @@ namespace Org.BouncyCastle.Tls
+ // TODO[tls13-psk] Calculate the total length of the binders that will be added.
+ int bindersSize = 0;
+ //int bindersSize = 2 + lengthOfBindersList;
+
this.m_clientHello = new ClientHello(legacy_version, securityParameters.ClientRandom, legacy_session_id,
- null, offeredCipherSuites, m_clientExtensions);
+ null, offeredCipherSuites, m_clientExtensions, bindersSize);
SendClientHelloMessage();
}
@@ -1691,14 +1697,11 @@ namespace Org.BouncyCastle.Tls
HandshakeMessageOutput message = new HandshakeMessageOutput(HandshakeType.client_hello);
m_clientHello.Encode(m_tlsClientContext, message);
- // TODO[tls13-psk] Calculate the total length of the binders that will be added.
- int totalBindersLength = 0;
-
- message.PrepareClientHello(m_handshakeHash, totalBindersLength);
+ message.PrepareClientHello(m_handshakeHash, m_clientHello.BindersSize);
// TODO[tls13-psk] Calculate any PSK binders and write them to 'message' here.
- message.SendClientHello(this, m_handshakeHash, totalBindersLength);
+ message.SendClientHello(this, m_handshakeHash, m_clientHello.BindersSize);
}
/// <exception cref="IOException"/>
diff --git a/crypto/src/tls/TlsProtocol.cs b/crypto/src/tls/TlsProtocol.cs
index d4960e3c8..f05c09a1b 100644
--- a/crypto/src/tls/TlsProtocol.cs
+++ b/crypto/src/tls/TlsProtocol.cs
@@ -1827,31 +1827,69 @@ namespace Org.BouncyCastle.Tls
/// <exception cref="IOException"/>
internal static void WriteExtensions(Stream output, IDictionary extensions)
{
+ WriteExtensions(output, extensions, 0);
+ }
+
+ /// <exception cref="IOException"/>
+ internal static void WriteExtensions(Stream output, IDictionary extensions, int bindersSize)
+ {
if (null == extensions || extensions.Count < 1)
return;
- byte[] extBytes = WriteExtensionsData(extensions);
+ byte[] extBytes = WriteExtensionsData(extensions, bindersSize);
- TlsUtilities.WriteOpaque16(extBytes, output);
+ int lengthWithBinders = extBytes.Length + bindersSize;
+ TlsUtilities.CheckUint16(lengthWithBinders);
+ TlsUtilities.WriteUint16(lengthWithBinders, output);
+ output.Write(extBytes, 0, extBytes.Length);
}
/// <exception cref="IOException"/>
internal static byte[] WriteExtensionsData(IDictionary extensions)
{
+ return WriteExtensionsData(extensions, 0);
+ }
+
+ /// <exception cref="IOException"/>
+ internal static byte[] WriteExtensionsData(IDictionary extensions, int bindersSize)
+ {
MemoryStream buf = new MemoryStream();
- WriteExtensionsData(extensions, buf);
+ WriteExtensionsData(extensions, buf, bindersSize);
return buf.ToArray();
}
/// <exception cref="IOException"/>
internal static void WriteExtensionsData(IDictionary extensions, MemoryStream buf)
{
+ WriteExtensionsData(extensions, buf, 0);
+ }
+
+ /// <exception cref="IOException"/>
+ internal static void WriteExtensionsData(IDictionary extensions, MemoryStream buf, int bindersSize)
+ {
/*
* NOTE: There are reports of servers that don't accept a zero-length extension as the last
* one, so we write out any zero-length ones first as a best-effort workaround.
*/
WriteSelectedExtensions(buf, extensions, true);
WriteSelectedExtensions(buf, extensions, false);
+ WritePreSharedKeyExtension(buf, extensions, bindersSize);
+ }
+
+ /// <exception cref="IOException"/>
+ internal static void WritePreSharedKeyExtension(MemoryStream buf, IDictionary extensions, int bindersSize)
+ {
+ byte[] extension_data = (byte[])extensions[ExtensionType.pre_shared_key];
+ if (null != extension_data)
+ {
+ TlsUtilities.CheckUint16(ExtensionType.pre_shared_key);
+ TlsUtilities.WriteUint16(ExtensionType.pre_shared_key, buf);
+
+ int lengthWithBinders = extension_data.Length + bindersSize;
+ TlsUtilities.CheckUint16(lengthWithBinders);
+ TlsUtilities.WriteUint16(lengthWithBinders, buf);
+ buf.Write(extension_data, 0, extension_data.Length);
+ }
}
/// <exception cref="IOException"/>
@@ -1860,6 +1898,11 @@ namespace Org.BouncyCastle.Tls
foreach (Int32 key in extensions.Keys)
{
int extension_type = key;
+
+ // NOTE: Must be last; handled by 'WritePreSharedKeyExtension'
+ if (ExtensionType.pre_shared_key == extension_type)
+ continue;
+
byte[] extension_data = (byte[])extensions[key];
if (selectEmpty == (extension_data.Length == 0))
diff --git a/crypto/src/tls/TlsUtilities.cs b/crypto/src/tls/TlsUtilities.cs
index c0ccfe9be..72c41ef05 100644
--- a/crypto/src/tls/TlsUtilities.cs
+++ b/crypto/src/tls/TlsUtilities.cs
@@ -1504,21 +1504,20 @@ namespace Org.BouncyCastle.Tls
return Prf(sp, preMasterSecret, asciiLabel, seed, 48);
}
- internal static byte[] CalculatePskBinder(TlsCrypto crypto, bool isExternalPsk, int pskPRFAlgorithm,
+ internal static byte[] CalculatePskBinder(TlsCrypto crypto, bool isExternalPsk, int pskCryptoHashAlgorithm,
TlsSecret earlySecret, byte[] transcriptHash)
{
- int prfCryptoHashAlgorithm = TlsCryptoUtilities.GetHashForPrf(pskPRFAlgorithm);
- int prfHashLength = TlsCryptoUtilities.GetHashOutputSize(prfCryptoHashAlgorithm);
+ int prfHashLength = TlsCryptoUtilities.GetHashOutputSize(pskCryptoHashAlgorithm);
string label = isExternalPsk ? "ext binder" : "res binder";
- byte[] emptyTranscriptHash = crypto.CreateHash(prfCryptoHashAlgorithm).CalculateHash();
+ byte[] emptyTranscriptHash = crypto.CreateHash(pskCryptoHashAlgorithm).CalculateHash();
- TlsSecret binderKey = DeriveSecret(prfCryptoHashAlgorithm, prfHashLength, earlySecret, label,
+ TlsSecret binderKey = DeriveSecret(pskCryptoHashAlgorithm, prfHashLength, earlySecret, label,
emptyTranscriptHash);
try
{
- return CalculateFinishedHmac(prfCryptoHashAlgorithm, prfHashLength, binderKey, transcriptHash);
+ return CalculateFinishedHmac(pskCryptoHashAlgorithm, prfHashLength, binderKey, transcriptHash);
}
finally
{
|