diff options
-rw-r--r-- | crypto/src/crypto/tls/TlsDHUtilities.cs | 24 | ||||
-rw-r--r-- | crypto/src/crypto/tls/TlsEccUtilities.cs | 15 | ||||
-rw-r--r-- | crypto/src/crypto/tls/TlsExtensionsUtilities.cs | 94 | ||||
-rw-r--r-- | crypto/src/crypto/tls/TlsUtilities.cs | 35 |
4 files changed, 125 insertions, 43 deletions
diff --git a/crypto/src/crypto/tls/TlsDHUtilities.cs b/crypto/src/crypto/tls/TlsDHUtilities.cs index 7a44670fd..6df61cbed 100644 --- a/crypto/src/crypto/tls/TlsDHUtilities.cs +++ b/crypto/src/crypto/tls/TlsDHUtilities.cs @@ -204,36 +204,20 @@ namespace Org.BouncyCastle.Crypto.Tls public static byte[] CreateNegotiatedDheGroupsServerExtension(byte dheGroup) { - return new byte[]{ dheGroup }; + return TlsUtilities.EncodeUint8(dheGroup); } public static byte[] ReadNegotiatedDheGroupsClientExtension(byte[] extensionData) { - if (extensionData == null) - throw new ArgumentNullException("extensionData"); - - MemoryStream buf = new MemoryStream(extensionData, false); - - byte length = TlsUtilities.ReadUint8(buf); - if (length < 1) + byte[] dheGroups = TlsUtilities.DecodeUint8ArrayWithUint8Length(extensionData); + if (dheGroups.Length < 1) throw new TlsFatalAlert(AlertDescription.decode_error); - - byte[] dheGroups = TlsUtilities.ReadUint8Array(length, buf); - - TlsProtocol.AssertEmpty(buf); - return dheGroups; } public static byte ReadNegotiatedDheGroupsServerExtension(byte[] extensionData) { - if (extensionData == null) - throw new ArgumentNullException("extensionData"); - - if (extensionData.Length != 1) - throw new TlsFatalAlert(AlertDescription.decode_error); - - return extensionData[0]; + return TlsUtilities.DecodeUint8(extensionData); } public static DHParameters GetParametersForDHEGroup(short dheGroup) diff --git a/crypto/src/crypto/tls/TlsEccUtilities.cs b/crypto/src/crypto/tls/TlsEccUtilities.cs index a5c8fa910..fb31e1b07 100644 --- a/crypto/src/crypto/tls/TlsEccUtilities.cs +++ b/crypto/src/crypto/tls/TlsEccUtilities.cs @@ -90,19 +90,7 @@ namespace Org.BouncyCastle.Crypto.Tls public static byte[] ReadSupportedPointFormatsExtension(byte[] extensionData) { - if (extensionData == null) - throw new ArgumentNullException("extensionData"); - - MemoryStream buf = new MemoryStream(extensionData, false); - - byte length = TlsUtilities.ReadUint8(buf); - if (length < 1) - throw new TlsFatalAlert(AlertDescription.decode_error); - - byte[] ecPointFormats = TlsUtilities.ReadUint8Array(length, buf); - - TlsProtocol.AssertEmpty(buf); - + byte[] ecPointFormats = TlsUtilities.DecodeUint8ArrayWithUint8Length(extensionData); if (!Arrays.Contains(ecPointFormats, ECPointFormat.uncompressed)) { /* @@ -111,7 +99,6 @@ namespace Org.BouncyCastle.Crypto.Tls */ throw new TlsFatalAlert(AlertDescription.illegal_parameter); } - return ecPointFormats; } diff --git a/crypto/src/crypto/tls/TlsExtensionsUtilities.cs b/crypto/src/crypto/tls/TlsExtensionsUtilities.cs index 7f6a26ef2..4b3d9e0c5 100644 --- a/crypto/src/crypto/tls/TlsExtensionsUtilities.cs +++ b/crypto/src/crypto/tls/TlsExtensionsUtilities.cs @@ -13,6 +13,18 @@ namespace Org.BouncyCastle.Crypto.Tls return extensions == null ? Platform.CreateHashtable() : extensions; } + /// <exception cref="IOException"></exception> + public static void AddClientCertificateTypeExtensionClient(IDictionary extensions, byte[] certificateTypes) + { + extensions[ExtensionType.client_certificate_type] = CreateCertificateTypeExtensionClient(certificateTypes); + } + + /// <exception cref="IOException"></exception> + public static void AddClientCertificateTypeExtensionServer(IDictionary extensions, byte certificateType) + { + extensions[ExtensionType.client_certificate_type] = CreateCertificateTypeExtensionServer(certificateType); + } + public static void AddEncryptThenMacExtension(IDictionary extensions) { extensions[ExtensionType.encrypt_then_mac] = CreateEncryptThenMacExtension(); @@ -42,6 +54,18 @@ namespace Org.BouncyCastle.Crypto.Tls } /// <exception cref="IOException"></exception> + public static void AddServerCertificateTypeExtensionClient(IDictionary extensions, byte[] certificateTypes) + { + extensions[ExtensionType.server_certificate_type] = CreateCertificateTypeExtensionClient(certificateTypes); + } + + /// <exception cref="IOException"></exception> + public static void AddServerCertificateTypeExtensionServer(IDictionary extensions, byte certificateType) + { + extensions[ExtensionType.server_certificate_type] = CreateCertificateTypeExtensionServer(certificateType); + } + + /// <exception cref="IOException"></exception> public static void AddServerNameExtension(IDictionary extensions, ServerNameList serverNameList) { extensions[ExtensionType.server_name] = CreateServerNameExtension(serverNameList); @@ -59,6 +83,20 @@ namespace Org.BouncyCastle.Crypto.Tls } /// <exception cref="IOException"></exception> + public static byte[] GetClientCertificateTypeExtensionClient(IDictionary extensions) + { + byte[] extensionData = TlsUtilities.GetExtensionData(extensions, ExtensionType.client_certificate_type); + return extensionData == null ? null : ReadCertificateTypeExtensionClient(extensionData); + } + + /// <exception cref="IOException"></exception> + public static short GetClientCertificateTypeExtensionServer(IDictionary extensions) + { + byte[] extensionData = TlsUtilities.GetExtensionData(extensions, ExtensionType.client_certificate_type); + return extensionData == null ? (short)-1 : (short)ReadCertificateTypeExtensionServer(extensionData); + } + + /// <exception cref="IOException"></exception> public static HeartbeatExtension GetHeartbeatExtension(IDictionary extensions) { byte[] extensionData = TlsUtilities.GetExtensionData(extensions, ExtensionType.heartbeat); @@ -80,6 +118,20 @@ namespace Org.BouncyCastle.Crypto.Tls } /// <exception cref="IOException"></exception> + public static byte[] GetServerCertificateTypeExtensionClient(IDictionary extensions) + { + byte[] extensionData = TlsUtilities.GetExtensionData(extensions, ExtensionType.server_certificate_type); + return extensionData == null ? null : ReadCertificateTypeExtensionClient(extensionData); + } + + /// <exception cref="IOException"></exception> + public static short GetServerCertificateTypeExtensionServer(IDictionary extensions) + { + byte[] extensionData = TlsUtilities.GetExtensionData(extensions, ExtensionType.server_certificate_type); + return extensionData == null ? (short)-1 : (short)ReadCertificateTypeExtensionServer(extensionData); + } + + /// <exception cref="IOException"></exception> public static ServerNameList GetServerNameExtension(IDictionary extensions) { byte[] extensionData = TlsUtilities.GetExtensionData(extensions, ExtensionType.server_name); @@ -114,6 +166,21 @@ namespace Org.BouncyCastle.Crypto.Tls return extensionData == null ? false : ReadTruncatedHMacExtension(extensionData); } + /// <exception cref="IOException"></exception> + public static byte[] CreateCertificateTypeExtensionClient(byte[] certificateTypes) + { + if (certificateTypes == null || certificateTypes.Length < 1 || certificateTypes.Length > 255) + throw new TlsFatalAlert(AlertDescription.internal_error); + + return TlsUtilities.EncodeUint8ArrayWithUint8Length(certificateTypes); + } + + /// <exception cref="IOException"></exception> + public static byte[] CreateCertificateTypeExtensionServer(byte certificateType) + { + return TlsUtilities.EncodeUint8(certificateType); + } + public static byte[] CreateEmptyExtensionData() { return TlsUtilities.EmptyBytes; @@ -145,7 +212,7 @@ namespace Org.BouncyCastle.Crypto.Tls /// <exception cref="IOException"></exception> public static byte[] CreateMaxFragmentLengthExtension(byte maxFragmentLength) { - return new byte[]{ maxFragmentLength }; + return TlsUtilities.EncodeUint8(maxFragmentLength); } /// <exception cref="IOException"></exception> @@ -201,6 +268,21 @@ namespace Org.BouncyCastle.Crypto.Tls } /// <exception cref="IOException"></exception> + public static byte[] ReadCertificateTypeExtensionClient(byte[] extensionData) + { + byte[] certificateTypes = TlsUtilities.DecodeUint8ArrayWithUint8Length(extensionData); + if (certificateTypes.Length < 1) + throw new TlsFatalAlert(AlertDescription.decode_error); + return certificateTypes; + } + + /// <exception cref="IOException"></exception> + public static byte ReadCertificateTypeExtensionServer(byte[] extensionData) + { + return TlsUtilities.DecodeUint8(extensionData); + } + + /// <exception cref="IOException"></exception> public static bool ReadEncryptThenMacExtension(byte[] extensionData) { return ReadEmptyExtensionData(extensionData); @@ -228,15 +310,9 @@ namespace Org.BouncyCastle.Crypto.Tls } /// <exception cref="IOException"></exception> - public static short ReadMaxFragmentLengthExtension(byte[] extensionData) + public static byte ReadMaxFragmentLengthExtension(byte[] extensionData) { - if (extensionData == null) - throw new ArgumentNullException("extensionData"); - - if (extensionData.Length != 1) - throw new TlsFatalAlert(AlertDescription.decode_error); - - return extensionData[0]; + return TlsUtilities.DecodeUint8(extensionData); } /// <exception cref="IOException"></exception> diff --git a/crypto/src/crypto/tls/TlsUtilities.cs b/crypto/src/crypto/tls/TlsUtilities.cs index 48e51a7b6..48eb9d375 100644 --- a/crypto/src/crypto/tls/TlsUtilities.cs +++ b/crypto/src/crypto/tls/TlsUtilities.cs @@ -324,12 +324,47 @@ namespace Org.BouncyCastle.Crypto.Tls WriteUint16Array(uints, buf, offset + 2); } + public static byte DecodeUint8(byte[] buf) + { + if (buf == null) + throw new ArgumentNullException("buf"); + if (buf.Length != 1) + throw new TlsFatalAlert(AlertDescription.decode_error); + return ReadUint8(buf, 0); + } + + public static byte[] DecodeUint8ArrayWithUint8Length(byte[] buf) + { + if (buf == null) + throw new ArgumentNullException("buf"); + + int count = ReadUint8(buf, 0); + if (buf.Length != (count + 1)) + throw new TlsFatalAlert(AlertDescription.decode_error); + + byte[] uints = new byte[count]; + for (int i = 0; i < count; ++i) + { + uints[i] = ReadUint8(buf, i + 1); + } + return uints; + } + public static byte[] EncodeOpaque8(byte[] buf) { CheckUint8(buf.Length); return Arrays.Prepend(buf, (byte)buf.Length); } + public static byte[] EncodeUint8(byte val) + { + CheckUint8(val); + + byte[] extensionData = new byte[1]; + WriteUint8(val, extensionData, 0); + return extensionData; + } + public static byte[] EncodeUint8ArrayWithUint8Length(byte[] uints) { byte[] result = new byte[1 + uints.Length]; |