summary refs log tree commit diff
path: root/crypto/src/security/AgreementUtilities.cs
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src/security/AgreementUtilities.cs')
-rw-r--r--crypto/src/security/AgreementUtilities.cs193
1 files changed, 147 insertions, 46 deletions
diff --git a/crypto/src/security/AgreementUtilities.cs b/crypto/src/security/AgreementUtilities.cs
index 041aeeed2..41dcb7435 100644
--- a/crypto/src/security/AgreementUtilities.cs
+++ b/crypto/src/security/AgreementUtilities.cs
@@ -12,35 +12,86 @@ using Org.BouncyCastle.Utilities.Collections;
 
 namespace Org.BouncyCastle.Security
 {
-	/// <remarks>
-	///  Utility class for creating IBasicAgreement objects from their names/Oids
-	/// </remarks>
-	public static class AgreementUtilities
+    /// <remarks>
+    ///  Utility class for creating IBasicAgreement objects from their names/Oids
+    /// </remarks>
+    public static class AgreementUtilities
 	{
-		private static readonly IDictionary<string, string> Algorithms =
-			new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
+        private static readonly Dictionary<DerObjectIdentifier, string> AlgorithmOidMap =
+            new Dictionary<DerObjectIdentifier, string>();
 
         static AgreementUtilities()
 		{
-            Algorithms[X9ObjectIdentifiers.DHSinglePassCofactorDHSha1KdfScheme.Id] = "ECCDHWITHSHA1KDF";
-			Algorithms[X9ObjectIdentifiers.DHSinglePassStdDHSha1KdfScheme.Id] = "ECDHWITHSHA1KDF";
-			Algorithms[X9ObjectIdentifiers.MqvSinglePassSha1KdfScheme.Id] = "ECMQVWITHSHA1KDF";
+            AlgorithmOidMap[X9ObjectIdentifiers.DHSinglePassCofactorDHSha1KdfScheme] = "ECCDHWITHSHA1KDF";
+            AlgorithmOidMap[X9ObjectIdentifiers.DHSinglePassStdDHSha1KdfScheme] = "ECDHWITHSHA1KDF";
+            AlgorithmOidMap[X9ObjectIdentifiers.MqvSinglePassSha1KdfScheme] = "ECMQVWITHSHA1KDF";
+
+            AlgorithmOidMap[EdECObjectIdentifiers.id_X25519] = "X25519";
+            AlgorithmOidMap[EdECObjectIdentifiers.id_X448] = "X448";
+
+#if DEBUG
+            //foreach (var key in AlgorithmMap.Keys)
+            //{
+            //    if (DerObjectIdentifier.TryFromID(key, out var ignore))
+            //        throw new Exception("OID mapping belongs in AlgorithmOidMap: " + key);
+            //}
+
+            //var mechanisms = new HashSet<string>(AlgorithmMap.Values);
+            var mechanisms = new HashSet<string>();
+            mechanisms.UnionWith(AlgorithmOidMap.Values);
+
+            foreach (var mechanism in mechanisms)
+            {
+                //if (AlgorithmMap.TryGetValue(mechanism, out var check))
+                //{
+                //    if (mechanism != check)
+                //        throw new Exception("Mechanism mapping MUST be to self: " + mechanism);
+                //}
+                //else
+                {
+                    if (!mechanism.Equals(mechanism.ToUpperInvariant()))
+                        throw new Exception("Unmapped mechanism MUST be uppercase: " + mechanism);
+                }
+            }
+#endif
+        }
 
-            Algorithms[EdECObjectIdentifiers.id_X25519.Id] = "X25519";
-            Algorithms[EdECObjectIdentifiers.id_X448.Id] = "X448";
+        public static string GetAlgorithmName(DerObjectIdentifier oid)
+        {
+            return CollectionUtilities.GetValueOrNull(AlgorithmOidMap, oid);
         }
 
-        public static IBasicAgreement GetBasicAgreement(
-			DerObjectIdentifier oid)
+        public static IBasicAgreement GetBasicAgreement(DerObjectIdentifier oid)
 		{
-			return GetBasicAgreement(oid.Id);
-		}
+            if (oid == null)
+                throw new ArgumentNullException(nameof(oid));
 
-		public static IBasicAgreement GetBasicAgreement(
-			string algorithm)
+            if (AlgorithmOidMap.TryGetValue(oid, out var mechanism))
+            {
+                var basicAgreement = GetBasicAgreementForMechanism(mechanism);
+                if (basicAgreement != null)
+                    return basicAgreement;
+            }
+
+            throw new SecurityUtilityException("Basic Agreement OID not recognised.");
+        }
+
+        public static IBasicAgreement GetBasicAgreement(string algorithm)
 		{
-            string mechanism = GetMechanism(algorithm);
+            if (algorithm == null)
+                throw new ArgumentNullException(nameof(algorithm));
+
+            string mechanism = GetMechanism(algorithm) ?? algorithm.ToUpperInvariant();
 
+            var basicAgreement = GetBasicAgreementForMechanism(mechanism);
+            if (basicAgreement != null)
+                return basicAgreement;
+
+            throw new SecurityUtilityException("Basic Agreement " + algorithm + " not recognised.");
+		}
+
+		private static IBasicAgreement GetBasicAgreementForMechanism(string mechanism)
+		{
             if (mechanism == "DH" || mechanism == "DIFFIEHELLMAN")
 				return new DHBasicAgreement();
 
@@ -48,71 +99,121 @@ namespace Org.BouncyCastle.Security
 				return new ECDHBasicAgreement();
 
             if (mechanism == "ECDHC" || mechanism == "ECCDH")
-                    return new ECDHCBasicAgreement();
+                return new ECDHCBasicAgreement();
 
 			if (mechanism == "ECMQV")
 				return new ECMqvBasicAgreement();
 
-			throw new SecurityUtilityException("Basic Agreement " + algorithm + " not recognised.");
+            return null;
 		}
 
         public static IBasicAgreement GetBasicAgreementWithKdf(DerObjectIdentifier agreeAlgOid,
 			DerObjectIdentifier wrapAlgOid)
         {
-            return GetBasicAgreementWithKdf(agreeAlgOid.Id, wrapAlgOid.Id);
+            return GetBasicAgreementWithKdf(agreeAlgOid, wrapAlgOid?.Id);
         }
 
+        // TODO[api] Change parameter name to 'agreeAlgOid'
         public static IBasicAgreement GetBasicAgreementWithKdf(DerObjectIdentifier oid, string wrapAlgorithm)
 		{
-			return GetBasicAgreementWithKdf(oid.Id, wrapAlgorithm);
-		}
+            if (oid == null)
+                throw new ArgumentNullException(nameof(oid));
+            if (wrapAlgorithm == null)
+                throw new ArgumentNullException(nameof(wrapAlgorithm));
+
+            if (AlgorithmOidMap.TryGetValue(oid, out var mechanism))
+            {
+                var basicAgreement = GetBasicAgreementWithKdfForMechanism(mechanism, wrapAlgorithm);
+                if (basicAgreement != null)
+                    return basicAgreement;
+            }
+
+            throw new SecurityUtilityException("Basic Agreement (with KDF) OID not recognised.");
+        }
 
-		public static IBasicAgreement GetBasicAgreementWithKdf(string agreeAlgorithm, string wrapAlgorithm)
+        public static IBasicAgreement GetBasicAgreementWithKdf(string agreeAlgorithm, string wrapAlgorithm)
 		{
-            string mechanism = GetMechanism(agreeAlgorithm);
+            if (agreeAlgorithm == null)
+                throw new ArgumentNullException(nameof(agreeAlgorithm));
+            if (wrapAlgorithm == null)
+                throw new ArgumentNullException(nameof(wrapAlgorithm));
+
+            string mechanism = GetMechanism(agreeAlgorithm) ?? agreeAlgorithm.ToUpperInvariant();
+
+            var basicAgreement = GetBasicAgreementWithKdfForMechanism(mechanism, wrapAlgorithm);
+            if (basicAgreement != null)
+                return basicAgreement;
+
+            throw new SecurityUtilityException("Basic Agreement (with KDF) " + agreeAlgorithm + " not recognised.");
+		}
 
+        private static IBasicAgreement GetBasicAgreementWithKdfForMechanism(string mechanism, string wrapAlgorithm)
+        {
             // 'DHWITHSHA1KDF' retained for backward compatibility
-			if (mechanism == "DHWITHSHA1KDF" || mechanism == "ECDHWITHSHA1KDF")
-				return new ECDHWithKdfBasicAgreement(wrapAlgorithm, new ECDHKekGenerator(new Sha1Digest()));
+            if (mechanism == "DHWITHSHA1KDF" || mechanism == "ECDHWITHSHA1KDF")
+                return new ECDHWithKdfBasicAgreement(wrapAlgorithm, new ECDHKekGenerator(new Sha1Digest()));
 
-			if (mechanism == "ECCDHWITHSHA1KDF")
-				return new ECDHCWithKdfBasicAgreement(wrapAlgorithm, new ECDHKekGenerator(new Sha1Digest()));
+            if (mechanism == "ECCDHWITHSHA1KDF")
+                return new ECDHCWithKdfBasicAgreement(wrapAlgorithm, new ECDHKekGenerator(new Sha1Digest()));
 
-			if (mechanism == "ECMQVWITHSHA1KDF")
-				return new ECMqvWithKdfBasicAgreement(wrapAlgorithm, new ECDHKekGenerator(new Sha1Digest()));
+            if (mechanism == "ECMQVWITHSHA1KDF")
+                return new ECMqvWithKdfBasicAgreement(wrapAlgorithm, new ECDHKekGenerator(new Sha1Digest()));
 
-			throw new SecurityUtilityException("Basic Agreement (with KDF) " + agreeAlgorithm + " not recognised.");
-		}
+            return null;
+        }
 
-        public static IRawAgreement GetRawAgreement(
-            DerObjectIdentifier oid)
+        public static IRawAgreement GetRawAgreement(DerObjectIdentifier oid)
         {
-            return GetRawAgreement(oid.Id);
+            if (oid == null)
+                throw new ArgumentNullException(nameof(oid));
+
+            if (AlgorithmOidMap.TryGetValue(oid, out var mechanism))
+            {
+                var rawAgreement = GetRawAgreementForMechanism(mechanism);
+                if (rawAgreement != null)
+                    return rawAgreement;
+            }
+
+            throw new SecurityUtilityException("Raw Agreement OID not recognised.");
         }
 
         public static IRawAgreement GetRawAgreement(string algorithm)
         {
-            string mechanism = GetMechanism(algorithm);
+            if (algorithm == null)
+                throw new ArgumentNullException(nameof(algorithm));
+
+            string mechanism = GetMechanism(algorithm) ?? algorithm.ToUpperInvariant();
+
+            var rawAgreement = GetRawAgreementForMechanism(mechanism);
+            if (rawAgreement != null)
+                return rawAgreement;
+
+            throw new SecurityUtilityException("Raw Agreement " + algorithm + " not recognised.");
+        }
 
+        private static IRawAgreement GetRawAgreementForMechanism(string mechanism)
+        {
             if (mechanism == "X25519")
                 return new X25519Agreement();
 
             if (mechanism == "X448")
                 return new X448Agreement();
 
-            throw new SecurityUtilityException("Raw Agreement " + algorithm + " not recognised.");
+            return null;
         }
 
-		public static string GetAlgorithmName(DerObjectIdentifier oid)
-		{
-			return CollectionUtilities.GetValueOrNull(Algorithms, oid.Id);
-		}
-
-		private static string GetMechanism(string algorithm)
+        private static string GetMechanism(string algorithm)
         {
-			var mechanism = CollectionUtilities.GetValueOrKey(Algorithms, algorithm);
+            //if (AlgorithmMap.TryGetValue(algorithm, out var mechanism1))
+            //    return mechanism1;
+
+            if (DerObjectIdentifier.TryFromID(algorithm, out var oid))
+            {
+                if (AlgorithmOidMap.TryGetValue(oid, out var mechanism2))
+                    return mechanism2;
+            }
 
-			return mechanism.ToUpperInvariant();
+            return null;
         }
 	}
 }