diff --git a/crypto/src/crypto/tls/TlsDHUtilities.cs b/crypto/src/crypto/tls/TlsDHUtilities.cs
index 7a44670fd..6df61cbed 100644
--- a/crypto/src/crypto/tls/TlsDHUtilities.cs
+++ b/crypto/src/crypto/tls/TlsDHUtilities.cs
@@ -204,36 +204,20 @@ namespace Org.BouncyCastle.Crypto.Tls
public static byte[] CreateNegotiatedDheGroupsServerExtension(byte dheGroup)
{
- return new byte[]{ dheGroup };
+ return TlsUtilities.EncodeUint8(dheGroup);
}
public static byte[] ReadNegotiatedDheGroupsClientExtension(byte[] extensionData)
{
- if (extensionData == null)
- throw new ArgumentNullException("extensionData");
-
- MemoryStream buf = new MemoryStream(extensionData, false);
-
- byte length = TlsUtilities.ReadUint8(buf);
- if (length < 1)
+ byte[] dheGroups = TlsUtilities.DecodeUint8ArrayWithUint8Length(extensionData);
+ if (dheGroups.Length < 1)
throw new TlsFatalAlert(AlertDescription.decode_error);
-
- byte[] dheGroups = TlsUtilities.ReadUint8Array(length, buf);
-
- TlsProtocol.AssertEmpty(buf);
-
return dheGroups;
}
public static byte ReadNegotiatedDheGroupsServerExtension(byte[] extensionData)
{
- if (extensionData == null)
- throw new ArgumentNullException("extensionData");
-
- if (extensionData.Length != 1)
- throw new TlsFatalAlert(AlertDescription.decode_error);
-
- return extensionData[0];
+ return TlsUtilities.DecodeUint8(extensionData);
}
public static DHParameters GetParametersForDHEGroup(short dheGroup)
diff --git a/crypto/src/crypto/tls/TlsEccUtilities.cs b/crypto/src/crypto/tls/TlsEccUtilities.cs
index a5c8fa910..fb31e1b07 100644
--- a/crypto/src/crypto/tls/TlsEccUtilities.cs
+++ b/crypto/src/crypto/tls/TlsEccUtilities.cs
@@ -90,19 +90,7 @@ namespace Org.BouncyCastle.Crypto.Tls
public static byte[] ReadSupportedPointFormatsExtension(byte[] extensionData)
{
- if (extensionData == null)
- throw new ArgumentNullException("extensionData");
-
- MemoryStream buf = new MemoryStream(extensionData, false);
-
- byte length = TlsUtilities.ReadUint8(buf);
- if (length < 1)
- throw new TlsFatalAlert(AlertDescription.decode_error);
-
- byte[] ecPointFormats = TlsUtilities.ReadUint8Array(length, buf);
-
- TlsProtocol.AssertEmpty(buf);
-
+ byte[] ecPointFormats = TlsUtilities.DecodeUint8ArrayWithUint8Length(extensionData);
if (!Arrays.Contains(ecPointFormats, ECPointFormat.uncompressed))
{
/*
@@ -111,7 +99,6 @@ namespace Org.BouncyCastle.Crypto.Tls
*/
throw new TlsFatalAlert(AlertDescription.illegal_parameter);
}
-
return ecPointFormats;
}
diff --git a/crypto/src/crypto/tls/TlsExtensionsUtilities.cs b/crypto/src/crypto/tls/TlsExtensionsUtilities.cs
index 7f6a26ef2..4b3d9e0c5 100644
--- a/crypto/src/crypto/tls/TlsExtensionsUtilities.cs
+++ b/crypto/src/crypto/tls/TlsExtensionsUtilities.cs
@@ -13,6 +13,18 @@ namespace Org.BouncyCastle.Crypto.Tls
return extensions == null ? Platform.CreateHashtable() : extensions;
}
+ /// <exception cref="IOException"></exception>
+ public static void AddClientCertificateTypeExtensionClient(IDictionary extensions, byte[] certificateTypes)
+ {
+ extensions[ExtensionType.client_certificate_type] = CreateCertificateTypeExtensionClient(certificateTypes);
+ }
+
+ /// <exception cref="IOException"></exception>
+ public static void AddClientCertificateTypeExtensionServer(IDictionary extensions, byte certificateType)
+ {
+ extensions[ExtensionType.client_certificate_type] = CreateCertificateTypeExtensionServer(certificateType);
+ }
+
public static void AddEncryptThenMacExtension(IDictionary extensions)
{
extensions[ExtensionType.encrypt_then_mac] = CreateEncryptThenMacExtension();
@@ -42,6 +54,18 @@ namespace Org.BouncyCastle.Crypto.Tls
}
/// <exception cref="IOException"></exception>
+ public static void AddServerCertificateTypeExtensionClient(IDictionary extensions, byte[] certificateTypes)
+ {
+ extensions[ExtensionType.server_certificate_type] = CreateCertificateTypeExtensionClient(certificateTypes);
+ }
+
+ /// <exception cref="IOException"></exception>
+ public static void AddServerCertificateTypeExtensionServer(IDictionary extensions, byte certificateType)
+ {
+ extensions[ExtensionType.server_certificate_type] = CreateCertificateTypeExtensionServer(certificateType);
+ }
+
+ /// <exception cref="IOException"></exception>
public static void AddServerNameExtension(IDictionary extensions, ServerNameList serverNameList)
{
extensions[ExtensionType.server_name] = CreateServerNameExtension(serverNameList);
@@ -59,6 +83,20 @@ namespace Org.BouncyCastle.Crypto.Tls
}
/// <exception cref="IOException"></exception>
+ public static byte[] GetClientCertificateTypeExtensionClient(IDictionary extensions)
+ {
+ byte[] extensionData = TlsUtilities.GetExtensionData(extensions, ExtensionType.client_certificate_type);
+ return extensionData == null ? null : ReadCertificateTypeExtensionClient(extensionData);
+ }
+
+ /// <exception cref="IOException"></exception>
+ public static short GetClientCertificateTypeExtensionServer(IDictionary extensions)
+ {
+ byte[] extensionData = TlsUtilities.GetExtensionData(extensions, ExtensionType.client_certificate_type);
+ return extensionData == null ? (short)-1 : (short)ReadCertificateTypeExtensionServer(extensionData);
+ }
+
+ /// <exception cref="IOException"></exception>
public static HeartbeatExtension GetHeartbeatExtension(IDictionary extensions)
{
byte[] extensionData = TlsUtilities.GetExtensionData(extensions, ExtensionType.heartbeat);
@@ -80,6 +118,20 @@ namespace Org.BouncyCastle.Crypto.Tls
}
/// <exception cref="IOException"></exception>
+ public static byte[] GetServerCertificateTypeExtensionClient(IDictionary extensions)
+ {
+ byte[] extensionData = TlsUtilities.GetExtensionData(extensions, ExtensionType.server_certificate_type);
+ return extensionData == null ? null : ReadCertificateTypeExtensionClient(extensionData);
+ }
+
+ /// <exception cref="IOException"></exception>
+ public static short GetServerCertificateTypeExtensionServer(IDictionary extensions)
+ {
+ byte[] extensionData = TlsUtilities.GetExtensionData(extensions, ExtensionType.server_certificate_type);
+ return extensionData == null ? (short)-1 : (short)ReadCertificateTypeExtensionServer(extensionData);
+ }
+
+ /// <exception cref="IOException"></exception>
public static ServerNameList GetServerNameExtension(IDictionary extensions)
{
byte[] extensionData = TlsUtilities.GetExtensionData(extensions, ExtensionType.server_name);
@@ -114,6 +166,21 @@ namespace Org.BouncyCastle.Crypto.Tls
return extensionData == null ? false : ReadTruncatedHMacExtension(extensionData);
}
+ /// <exception cref="IOException"></exception>
+ public static byte[] CreateCertificateTypeExtensionClient(byte[] certificateTypes)
+ {
+ if (certificateTypes == null || certificateTypes.Length < 1 || certificateTypes.Length > 255)
+ throw new TlsFatalAlert(AlertDescription.internal_error);
+
+ return TlsUtilities.EncodeUint8ArrayWithUint8Length(certificateTypes);
+ }
+
+ /// <exception cref="IOException"></exception>
+ public static byte[] CreateCertificateTypeExtensionServer(byte certificateType)
+ {
+ return TlsUtilities.EncodeUint8(certificateType);
+ }
+
public static byte[] CreateEmptyExtensionData()
{
return TlsUtilities.EmptyBytes;
@@ -145,7 +212,7 @@ namespace Org.BouncyCastle.Crypto.Tls
/// <exception cref="IOException"></exception>
public static byte[] CreateMaxFragmentLengthExtension(byte maxFragmentLength)
{
- return new byte[]{ maxFragmentLength };
+ return TlsUtilities.EncodeUint8(maxFragmentLength);
}
/// <exception cref="IOException"></exception>
@@ -201,6 +268,21 @@ namespace Org.BouncyCastle.Crypto.Tls
}
/// <exception cref="IOException"></exception>
+ public static byte[] ReadCertificateTypeExtensionClient(byte[] extensionData)
+ {
+ byte[] certificateTypes = TlsUtilities.DecodeUint8ArrayWithUint8Length(extensionData);
+ if (certificateTypes.Length < 1)
+ throw new TlsFatalAlert(AlertDescription.decode_error);
+ return certificateTypes;
+ }
+
+ /// <exception cref="IOException"></exception>
+ public static byte ReadCertificateTypeExtensionServer(byte[] extensionData)
+ {
+ return TlsUtilities.DecodeUint8(extensionData);
+ }
+
+ /// <exception cref="IOException"></exception>
public static bool ReadEncryptThenMacExtension(byte[] extensionData)
{
return ReadEmptyExtensionData(extensionData);
@@ -228,15 +310,9 @@ namespace Org.BouncyCastle.Crypto.Tls
}
/// <exception cref="IOException"></exception>
- public static short ReadMaxFragmentLengthExtension(byte[] extensionData)
+ public static byte ReadMaxFragmentLengthExtension(byte[] extensionData)
{
- if (extensionData == null)
- throw new ArgumentNullException("extensionData");
-
- if (extensionData.Length != 1)
- throw new TlsFatalAlert(AlertDescription.decode_error);
-
- return extensionData[0];
+ return TlsUtilities.DecodeUint8(extensionData);
}
/// <exception cref="IOException"></exception>
diff --git a/crypto/src/crypto/tls/TlsUtilities.cs b/crypto/src/crypto/tls/TlsUtilities.cs
index 48e51a7b6..48eb9d375 100644
--- a/crypto/src/crypto/tls/TlsUtilities.cs
+++ b/crypto/src/crypto/tls/TlsUtilities.cs
@@ -324,12 +324,47 @@ namespace Org.BouncyCastle.Crypto.Tls
WriteUint16Array(uints, buf, offset + 2);
}
+ public static byte DecodeUint8(byte[] buf)
+ {
+ if (buf == null)
+ throw new ArgumentNullException("buf");
+ if (buf.Length != 1)
+ throw new TlsFatalAlert(AlertDescription.decode_error);
+ return ReadUint8(buf, 0);
+ }
+
+ public static byte[] DecodeUint8ArrayWithUint8Length(byte[] buf)
+ {
+ if (buf == null)
+ throw new ArgumentNullException("buf");
+
+ int count = ReadUint8(buf, 0);
+ if (buf.Length != (count + 1))
+ throw new TlsFatalAlert(AlertDescription.decode_error);
+
+ byte[] uints = new byte[count];
+ for (int i = 0; i < count; ++i)
+ {
+ uints[i] = ReadUint8(buf, i + 1);
+ }
+ return uints;
+ }
+
public static byte[] EncodeOpaque8(byte[] buf)
{
CheckUint8(buf.Length);
return Arrays.Prepend(buf, (byte)buf.Length);
}
+ public static byte[] EncodeUint8(byte val)
+ {
+ CheckUint8(val);
+
+ byte[] extensionData = new byte[1];
+ WriteUint8(val, extensionData, 0);
+ return extensionData;
+ }
+
public static byte[] EncodeUint8ArrayWithUint8Length(byte[] uints)
{
byte[] result = new byte[1 + uints.Length];
|