diff options
Diffstat (limited to 'crypto/src/security/AgreementUtilities.cs')
-rw-r--r-- | crypto/src/security/AgreementUtilities.cs | 193 |
1 files changed, 147 insertions, 46 deletions
diff --git a/crypto/src/security/AgreementUtilities.cs b/crypto/src/security/AgreementUtilities.cs index 041aeeed2..41dcb7435 100644 --- a/crypto/src/security/AgreementUtilities.cs +++ b/crypto/src/security/AgreementUtilities.cs @@ -12,35 +12,86 @@ using Org.BouncyCastle.Utilities.Collections; namespace Org.BouncyCastle.Security { - /// <remarks> - /// Utility class for creating IBasicAgreement objects from their names/Oids - /// </remarks> - public static class AgreementUtilities + /// <remarks> + /// Utility class for creating IBasicAgreement objects from their names/Oids + /// </remarks> + public static class AgreementUtilities { - private static readonly IDictionary<string, string> Algorithms = - new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase); + private static readonly Dictionary<DerObjectIdentifier, string> AlgorithmOidMap = + new Dictionary<DerObjectIdentifier, string>(); static AgreementUtilities() { - Algorithms[X9ObjectIdentifiers.DHSinglePassCofactorDHSha1KdfScheme.Id] = "ECCDHWITHSHA1KDF"; - Algorithms[X9ObjectIdentifiers.DHSinglePassStdDHSha1KdfScheme.Id] = "ECDHWITHSHA1KDF"; - Algorithms[X9ObjectIdentifiers.MqvSinglePassSha1KdfScheme.Id] = "ECMQVWITHSHA1KDF"; + AlgorithmOidMap[X9ObjectIdentifiers.DHSinglePassCofactorDHSha1KdfScheme] = "ECCDHWITHSHA1KDF"; + AlgorithmOidMap[X9ObjectIdentifiers.DHSinglePassStdDHSha1KdfScheme] = "ECDHWITHSHA1KDF"; + AlgorithmOidMap[X9ObjectIdentifiers.MqvSinglePassSha1KdfScheme] = "ECMQVWITHSHA1KDF"; + + AlgorithmOidMap[EdECObjectIdentifiers.id_X25519] = "X25519"; + AlgorithmOidMap[EdECObjectIdentifiers.id_X448] = "X448"; + +#if DEBUG + //foreach (var key in AlgorithmMap.Keys) + //{ + // if (DerObjectIdentifier.TryFromID(key, out var ignore)) + // throw new Exception("OID mapping belongs in AlgorithmOidMap: " + key); + //} + + //var mechanisms = new HashSet<string>(AlgorithmMap.Values); + var mechanisms = new HashSet<string>(); + mechanisms.UnionWith(AlgorithmOidMap.Values); + + foreach (var mechanism in mechanisms) + { + //if (AlgorithmMap.TryGetValue(mechanism, out var check)) + //{ + // if (mechanism != check) + // throw new Exception("Mechanism mapping MUST be to self: " + mechanism); + //} + //else + { + if (!mechanism.Equals(mechanism.ToUpperInvariant())) + throw new Exception("Unmapped mechanism MUST be uppercase: " + mechanism); + } + } +#endif + } - Algorithms[EdECObjectIdentifiers.id_X25519.Id] = "X25519"; - Algorithms[EdECObjectIdentifiers.id_X448.Id] = "X448"; + public static string GetAlgorithmName(DerObjectIdentifier oid) + { + return CollectionUtilities.GetValueOrNull(AlgorithmOidMap, oid); } - public static IBasicAgreement GetBasicAgreement( - DerObjectIdentifier oid) + public static IBasicAgreement GetBasicAgreement(DerObjectIdentifier oid) { - return GetBasicAgreement(oid.Id); - } + if (oid == null) + throw new ArgumentNullException(nameof(oid)); - public static IBasicAgreement GetBasicAgreement( - string algorithm) + if (AlgorithmOidMap.TryGetValue(oid, out var mechanism)) + { + var basicAgreement = GetBasicAgreementForMechanism(mechanism); + if (basicAgreement != null) + return basicAgreement; + } + + throw new SecurityUtilityException("Basic Agreement OID not recognised."); + } + + public static IBasicAgreement GetBasicAgreement(string algorithm) { - string mechanism = GetMechanism(algorithm); + if (algorithm == null) + throw new ArgumentNullException(nameof(algorithm)); + + string mechanism = GetMechanism(algorithm) ?? algorithm.ToUpperInvariant(); + var basicAgreement = GetBasicAgreementForMechanism(mechanism); + if (basicAgreement != null) + return basicAgreement; + + throw new SecurityUtilityException("Basic Agreement " + algorithm + " not recognised."); + } + + private static IBasicAgreement GetBasicAgreementForMechanism(string mechanism) + { if (mechanism == "DH" || mechanism == "DIFFIEHELLMAN") return new DHBasicAgreement(); @@ -48,71 +99,121 @@ namespace Org.BouncyCastle.Security return new ECDHBasicAgreement(); if (mechanism == "ECDHC" || mechanism == "ECCDH") - return new ECDHCBasicAgreement(); + return new ECDHCBasicAgreement(); if (mechanism == "ECMQV") return new ECMqvBasicAgreement(); - throw new SecurityUtilityException("Basic Agreement " + algorithm + " not recognised."); + return null; } public static IBasicAgreement GetBasicAgreementWithKdf(DerObjectIdentifier agreeAlgOid, DerObjectIdentifier wrapAlgOid) { - return GetBasicAgreementWithKdf(agreeAlgOid.Id, wrapAlgOid.Id); + return GetBasicAgreementWithKdf(agreeAlgOid, wrapAlgOid?.Id); } + // TODO[api] Change parameter name to 'agreeAlgOid' public static IBasicAgreement GetBasicAgreementWithKdf(DerObjectIdentifier oid, string wrapAlgorithm) { - return GetBasicAgreementWithKdf(oid.Id, wrapAlgorithm); - } + if (oid == null) + throw new ArgumentNullException(nameof(oid)); + if (wrapAlgorithm == null) + throw new ArgumentNullException(nameof(wrapAlgorithm)); + + if (AlgorithmOidMap.TryGetValue(oid, out var mechanism)) + { + var basicAgreement = GetBasicAgreementWithKdfForMechanism(mechanism, wrapAlgorithm); + if (basicAgreement != null) + return basicAgreement; + } + + throw new SecurityUtilityException("Basic Agreement (with KDF) OID not recognised."); + } - public static IBasicAgreement GetBasicAgreementWithKdf(string agreeAlgorithm, string wrapAlgorithm) + public static IBasicAgreement GetBasicAgreementWithKdf(string agreeAlgorithm, string wrapAlgorithm) { - string mechanism = GetMechanism(agreeAlgorithm); + if (agreeAlgorithm == null) + throw new ArgumentNullException(nameof(agreeAlgorithm)); + if (wrapAlgorithm == null) + throw new ArgumentNullException(nameof(wrapAlgorithm)); + + string mechanism = GetMechanism(agreeAlgorithm) ?? agreeAlgorithm.ToUpperInvariant(); + + var basicAgreement = GetBasicAgreementWithKdfForMechanism(mechanism, wrapAlgorithm); + if (basicAgreement != null) + return basicAgreement; + + throw new SecurityUtilityException("Basic Agreement (with KDF) " + agreeAlgorithm + " not recognised."); + } + private static IBasicAgreement GetBasicAgreementWithKdfForMechanism(string mechanism, string wrapAlgorithm) + { // 'DHWITHSHA1KDF' retained for backward compatibility - if (mechanism == "DHWITHSHA1KDF" || mechanism == "ECDHWITHSHA1KDF") - return new ECDHWithKdfBasicAgreement(wrapAlgorithm, new ECDHKekGenerator(new Sha1Digest())); + if (mechanism == "DHWITHSHA1KDF" || mechanism == "ECDHWITHSHA1KDF") + return new ECDHWithKdfBasicAgreement(wrapAlgorithm, new ECDHKekGenerator(new Sha1Digest())); - if (mechanism == "ECCDHWITHSHA1KDF") - return new ECDHCWithKdfBasicAgreement(wrapAlgorithm, new ECDHKekGenerator(new Sha1Digest())); + if (mechanism == "ECCDHWITHSHA1KDF") + return new ECDHCWithKdfBasicAgreement(wrapAlgorithm, new ECDHKekGenerator(new Sha1Digest())); - if (mechanism == "ECMQVWITHSHA1KDF") - return new ECMqvWithKdfBasicAgreement(wrapAlgorithm, new ECDHKekGenerator(new Sha1Digest())); + if (mechanism == "ECMQVWITHSHA1KDF") + return new ECMqvWithKdfBasicAgreement(wrapAlgorithm, new ECDHKekGenerator(new Sha1Digest())); - throw new SecurityUtilityException("Basic Agreement (with KDF) " + agreeAlgorithm + " not recognised."); - } + return null; + } - public static IRawAgreement GetRawAgreement( - DerObjectIdentifier oid) + public static IRawAgreement GetRawAgreement(DerObjectIdentifier oid) { - return GetRawAgreement(oid.Id); + if (oid == null) + throw new ArgumentNullException(nameof(oid)); + + if (AlgorithmOidMap.TryGetValue(oid, out var mechanism)) + { + var rawAgreement = GetRawAgreementForMechanism(mechanism); + if (rawAgreement != null) + return rawAgreement; + } + + throw new SecurityUtilityException("Raw Agreement OID not recognised."); } public static IRawAgreement GetRawAgreement(string algorithm) { - string mechanism = GetMechanism(algorithm); + if (algorithm == null) + throw new ArgumentNullException(nameof(algorithm)); + + string mechanism = GetMechanism(algorithm) ?? algorithm.ToUpperInvariant(); + + var rawAgreement = GetRawAgreementForMechanism(mechanism); + if (rawAgreement != null) + return rawAgreement; + + throw new SecurityUtilityException("Raw Agreement " + algorithm + " not recognised."); + } + private static IRawAgreement GetRawAgreementForMechanism(string mechanism) + { if (mechanism == "X25519") return new X25519Agreement(); if (mechanism == "X448") return new X448Agreement(); - throw new SecurityUtilityException("Raw Agreement " + algorithm + " not recognised."); + return null; } - public static string GetAlgorithmName(DerObjectIdentifier oid) - { - return CollectionUtilities.GetValueOrNull(Algorithms, oid.Id); - } - - private static string GetMechanism(string algorithm) + private static string GetMechanism(string algorithm) { - var mechanism = CollectionUtilities.GetValueOrKey(Algorithms, algorithm); + //if (AlgorithmMap.TryGetValue(algorithm, out var mechanism1)) + // return mechanism1; + + if (DerObjectIdentifier.TryFromID(algorithm, out var oid)) + { + if (AlgorithmOidMap.TryGetValue(oid, out var mechanism2)) + return mechanism2; + } - return mechanism.ToUpperInvariant(); + return null; } } } |