diff --git a/crypto/src/crypto/tls/TlsUtilities.cs b/crypto/src/crypto/tls/TlsUtilities.cs
index ffb2fc3e6..462ec4074 100644
--- a/crypto/src/crypto/tls/TlsUtilities.cs
+++ b/crypto/src/crypto/tls/TlsUtilities.cs
@@ -4,6 +4,8 @@ using System.IO;
using System.Text;
using Org.BouncyCastle.Asn1;
+using Org.BouncyCastle.Asn1.Nist;
+using Org.BouncyCastle.Asn1.Pkcs;
using Org.BouncyCastle.Asn1.X509;
using Org.BouncyCastle.Crypto.Digests;
using Org.BouncyCastle.Crypto.Macs;
@@ -660,37 +662,54 @@ namespace Org.BouncyCastle.Crypto.Tls
}
}
- internal static byte[] PRF(byte[] secret, string asciiLabel, byte[] seed, int size)
+ public static byte[] PRF(TlsContext context, byte[] secret, string asciiLabel, byte[] seed, int size)
{
- byte[] label = Strings.ToAsciiByteArray(asciiLabel);
+ ProtocolVersion version = context.ServerVersion;
- int s_half = (secret.Length + 1) / 2;
- byte[] s1 = new byte[s_half];
- byte[] s2 = new byte[s_half];
- Array.Copy(secret, 0, s1, 0, s_half);
- Array.Copy(secret, secret.Length - s_half, s2, 0, s_half);
+ if (version.IsSsl)
+ throw new InvalidOperationException("No PRF available for SSLv3 session");
+
+ byte[] label = Strings.ToByteArray(asciiLabel);
+ byte[] labelSeed = Concat(label, seed);
- byte[] ls = Concat(label, seed);
+ int prfAlgorithm = context.SecurityParameters.PrfAlgorithm;
- byte[] buf = new byte[size];
- byte[] prf = new byte[size];
- HMacHash(new MD5Digest(), s1, ls, prf);
- HMacHash(new Sha1Digest(), s2, ls, buf);
- for (int i = 0; i < size; i++)
+ if (prfAlgorithm == PrfAlgorithm.tls_prf_legacy)
{
- buf[i] ^= prf[i];
+ return PRF_legacy(secret, label, labelSeed, size);
}
+
+ IDigest prfDigest = CreatePrfHash(prfAlgorithm);
+ byte[] buf = new byte[size];
+ HMacHash(prfDigest, secret, labelSeed, buf);
return buf;
}
- internal static byte[] PRF_1_2(IDigest digest, byte[] secret, string asciiLabel, byte[] seed, int size)
+ public static byte[] PRF_legacy(byte[] secret, string asciiLabel, byte[] seed, int size)
{
- byte[] label = Strings.ToAsciiByteArray(asciiLabel);
+ byte[] label = Strings.ToByteArray(asciiLabel);
byte[] labelSeed = Concat(label, seed);
- byte[] buf = new byte[size];
- HMacHash(digest, secret, labelSeed, buf);
- return buf;
+ return PRF_legacy(secret, label, labelSeed, size);
+ }
+
+ internal static byte[] PRF_legacy(byte[] secret, byte[] label, byte[] labelSeed, int size)
+ {
+ int s_half = (secret.Length + 1) / 2;
+ byte[] s1 = new byte[s_half];
+ byte[] s2 = new byte[s_half];
+ Array.Copy(secret, 0, s1, 0, s_half);
+ Array.Copy(secret, secret.Length - s_half, s2, 0, s_half);
+
+ byte[] b1 = new byte[size];
+ byte[] b2 = new byte[size];
+ HMacHash(CreateHash(HashAlgorithm.md5), s1, labelSeed, b1);
+ HMacHash(CreateHash(HashAlgorithm.sha1), s2, labelSeed, b2);
+ for (int i = 0; i < size; i++)
+ {
+ b1[i] ^= b2[i];
+ }
+ return b1;
}
internal static byte[] Concat(byte[] a, byte[] b)
@@ -782,6 +801,64 @@ namespace Org.BouncyCastle.Crypto.Tls
}
}
+ public static IDigest CreatePrfHash(int prfAlgorithm)
+ {
+ switch (prfAlgorithm)
+ {
+ case PrfAlgorithm.tls_prf_legacy:
+ return new CombinedHash();
+ default:
+ return CreateHash(GetHashAlgorithmForPrfAlgorithm(prfAlgorithm));
+ }
+ }
+
+ public static IDigest ClonePrfHash(int prfAlgorithm, IDigest hash)
+ {
+ switch (prfAlgorithm)
+ {
+ case PrfAlgorithm.tls_prf_legacy:
+ return new CombinedHash((CombinedHash)hash);
+ default:
+ return CloneHash(GetHashAlgorithmForPrfAlgorithm(prfAlgorithm), hash);
+ }
+ }
+
+ public static byte GetHashAlgorithmForPrfAlgorithm(int prfAlgorithm)
+ {
+ switch (prfAlgorithm)
+ {
+ case PrfAlgorithm.tls_prf_legacy:
+ throw new ArgumentException("legacy PRF not a valid algorithm", "prfAlgorithm");
+ case PrfAlgorithm.tls_prf_sha256:
+ return HashAlgorithm.sha256;
+ case PrfAlgorithm.tls_prf_sha384:
+ return HashAlgorithm.sha384;
+ default:
+ throw new ArgumentException("unknown PrfAlgorithm", "prfAlgorithm");
+ }
+ }
+
+ public static DerObjectIdentifier GetOidForHashAlgorithm(byte hashAlgorithm)
+ {
+ switch (hashAlgorithm)
+ {
+ case HashAlgorithm.md5:
+ return PkcsObjectIdentifiers.MD5;
+ case HashAlgorithm.sha1:
+ return X509ObjectIdentifiers.IdSha1;
+ case HashAlgorithm.sha224:
+ return NistObjectIdentifiers.IdSha224;
+ case HashAlgorithm.sha256:
+ return NistObjectIdentifiers.IdSha256;
+ case HashAlgorithm.sha384:
+ return NistObjectIdentifiers.IdSha384;
+ case HashAlgorithm.sha512:
+ return NistObjectIdentifiers.IdSha512;
+ default:
+ throw new ArgumentException("unknown HashAlgorithm", "hashAlgorithm");
+ }
+ }
+
private static IList VectorOfOne(object obj)
{
IList v = Platform.CreateArrayList(1);
|