summary refs log tree commit diff
diff options
context:
space:
mode:
authorPeter Dettman <peter.dettman@bouncycastle.org>2024-01-08 23:08:15 +0700
committerPeter Dettman <peter.dettman@bouncycastle.org>2024-01-08 23:08:15 +0700
commitb6730581b87743c4b0bcbcd5163cb577d22f2f31 (patch)
tree603bee70ba2ffcc4058e804b599b1a41e87cd2a5
parentAdd some convenience methods to BigInteger (diff)
downloadBouncyCastle.NET-ed25519-b6730581b87743c4b0bcbcd5163cb577d22f2f31.tar.xz
Fix ordering changes in Pkcs12Store
-rw-r--r--crypto/src/pkcs/PKCS12StoreBuilder.cs8
-rw-r--r--crypto/src/pkcs/Pkcs12Store.cs197
2 files changed, 115 insertions, 90 deletions
diff --git a/crypto/src/pkcs/PKCS12StoreBuilder.cs b/crypto/src/pkcs/PKCS12StoreBuilder.cs
index 63d7fb56a..be8f29886 100644
--- a/crypto/src/pkcs/PKCS12StoreBuilder.cs
+++ b/crypto/src/pkcs/PKCS12StoreBuilder.cs
@@ -11,7 +11,7 @@ namespace Org.BouncyCastle.Pkcs
 		private DerObjectIdentifier	certAlgorithm = PkcsObjectIdentifiers.PbewithShaAnd40BitRC2Cbc;
 		private DerObjectIdentifier keyPrfAlgorithm = null;
 		private bool useDerEncoding = false;
-		private bool reverseCertificate = false;
+		private bool reverseCertificates = false;
 
 		public Pkcs12StoreBuilder()
 		{
@@ -19,12 +19,12 @@ namespace Org.BouncyCastle.Pkcs
 
 		public Pkcs12Store Build()
 		{
-			return new Pkcs12Store(keyAlgorithm, keyPrfAlgorithm, certAlgorithm, useDerEncoding, reverseCertificate);
+			return new Pkcs12Store(keyAlgorithm, keyPrfAlgorithm, certAlgorithm, useDerEncoding, reverseCertificates);
 		}
 
-		public Pkcs12StoreBuilder SetReverseCertificates(bool reverseCertificate)
+		public Pkcs12StoreBuilder SetReverseCertificates(bool reverseCertificates)
 		{
-			this.reverseCertificate = reverseCertificate;
+			this.reverseCertificates = reverseCertificates;
 			return this;
 		}
 
diff --git a/crypto/src/pkcs/Pkcs12Store.cs b/crypto/src/pkcs/Pkcs12Store.cs
index 27cef8090..98247e2a8 100644
--- a/crypto/src/pkcs/Pkcs12Store.cs
+++ b/crypto/src/pkcs/Pkcs12Store.cs
@@ -22,24 +22,26 @@ namespace Org.BouncyCastle.Pkcs
 
         private readonly Dictionary<string, AsymmetricKeyEntry> m_keys =
             new Dictionary<string, AsymmetricKeyEntry>(StringComparer.OrdinalIgnoreCase);
+        private readonly List<string> m_keysOrder = new List<string>();
+
         private readonly Dictionary<string, string> m_localIds = new Dictionary<string, string>();
+
         private readonly Dictionary<string, X509CertificateEntry> m_certs =
             new Dictionary<string, X509CertificateEntry>(StringComparer.OrdinalIgnoreCase);
+        private readonly List<string> m_certsOrder = new List<string>();
+
         private readonly Dictionary<CertID, X509CertificateEntry> m_chainCerts =
             new Dictionary<CertID, X509CertificateEntry>();
+        private readonly List<CertID> m_chainCertsOrder = new List<CertID>();
+
         private readonly Dictionary<string, X509CertificateEntry> m_keyCerts =
             new Dictionary<string, X509CertificateEntry>();
-        private readonly List<string> m_keysOrder =
-            new List<string>();
-        private readonly List<string> m_certsOrder =
-            new List<string>();
-        private readonly List<CertID> m_chainCertOrder =
-            new List<CertID>();
+
         private readonly DerObjectIdentifier keyAlgorithm;
         private readonly DerObjectIdentifier keyPrfAlgorithm;
         private readonly DerObjectIdentifier certAlgorithm;
         private readonly bool useDerEncoding;
-        private readonly bool isReverse;
+        private readonly bool reverseCertificates;
 
         private AsymmetricKeyEntry unmarkedKeyEntry = null;
 
@@ -82,13 +84,13 @@ namespace Org.BouncyCastle.Pkcs
         }
 
         internal Pkcs12Store(DerObjectIdentifier keyAlgorithm, DerObjectIdentifier keyPrfAlgorithm,
-            DerObjectIdentifier certAlgorithm, bool useDerEncoding, bool isReverse)
+            DerObjectIdentifier certAlgorithm, bool useDerEncoding, bool reverseCertificates)
         {
             this.keyAlgorithm = keyAlgorithm;
             this.keyPrfAlgorithm = keyPrfAlgorithm;
             this.certAlgorithm = certAlgorithm;
             this.useDerEncoding = useDerEncoding;
-            this.isReverse = isReverse;
+            this.reverseCertificates = reverseCertificates;
         }
 
         protected virtual void LoadKeyBag(PrivateKeyInfo privKeyInfo, Asn1Set bagAttributes)
@@ -127,14 +129,13 @@ namespace Org.BouncyCastle.Pkcs
                             attributes[aOid] = attr;
                         }
 
-                        if (aOid.Equals(PkcsObjectIdentifiers.Pkcs9AtFriendlyName))
+                        if (PkcsObjectIdentifiers.Pkcs9AtFriendlyName.Equals(aOid))
                         {
                             alias = ((DerBmpString)attr).GetString();
                             // TODO Do these in a separate loop, just collect aliases here
-                            m_keys[alias] = keyEntry;
-                            m_keysOrder.Add(alias);
+                            Map(m_keys, m_keysOrder, alias, keyEntry);
                         }
-                        else if (aOid.Equals(PkcsObjectIdentifiers.Pkcs9AtLocalKeyID))
+                        else if (PkcsObjectIdentifiers.Pkcs9AtLocalKeyID.Equals(aOid))
                         {
                             localId = (Asn1OctetString)attr;
                         }
@@ -148,8 +149,7 @@ namespace Org.BouncyCastle.Pkcs
 
                 if (alias == null)
                 {
-                    m_keys[name] = keyEntry;
-                    m_keysOrder.Add(name);
+                    Map(m_keys, m_keysOrder, name, keyEntry);
                 }
                 else
                 {
@@ -225,14 +225,13 @@ namespace Org.BouncyCastle.Pkcs
                 }
             }
 
-            m_keys.Clear();
-            m_keysOrder.Clear();
+            Clear(m_keys, m_keysOrder);
             m_localIds.Clear();
             unmarkedKeyEntry = null;
 
             var certBags = new List<SafeBag>();
 
-            if (info.ContentType.Equals(PkcsObjectIdentifiers.Data))
+            if (PkcsObjectIdentifiers.Data.Equals(info.ContentType))
             {
                 Asn1OctetString content = Asn1OctetString.GetInstance(info.Content);
                 AuthenticatedSafe authSafe = AuthenticatedSafe.GetInstance(content.GetOctets());
@@ -243,11 +242,11 @@ namespace Org.BouncyCastle.Pkcs
                     DerObjectIdentifier oid = ci.ContentType;
 
                     byte[] octets = null;
-                    if (oid.Equals(PkcsObjectIdentifiers.Data))
+                    if (PkcsObjectIdentifiers.Data.Equals(oid))
                     {
                         octets = Asn1OctetString.GetInstance(ci.Content).GetOctets();
                     }
-                    else if (oid.Equals(PkcsObjectIdentifiers.EncryptedData))
+                    else if (PkcsObjectIdentifiers.EncryptedData.Equals(oid))
                     {
                         if (password != null)
                         {
@@ -269,16 +268,16 @@ namespace Org.BouncyCastle.Pkcs
                         {
                             SafeBag b = SafeBag.GetInstance(subSeq);
 
-                            if (b.BagID.Equals(PkcsObjectIdentifiers.CertBag))
+                            if (PkcsObjectIdentifiers.CertBag.Equals(b.BagID))
                             {
                                 certBags.Add(b);
                             }
-                            else if (b.BagID.Equals(PkcsObjectIdentifiers.Pkcs8ShroudedKeyBag))
+                            else if (PkcsObjectIdentifiers.Pkcs8ShroudedKeyBag.Equals(b.BagID))
                             {
                                 LoadPkcs8ShroudedKeyBag(EncryptedPrivateKeyInfo.GetInstance(b.BagValue),
                                     b.BagAttributes, password, wrongPkcs12Zero);
                             }
-                            else if (b.BagID.Equals(PkcsObjectIdentifiers.KeyBag))
+                            else if (PkcsObjectIdentifiers.KeyBag.Equals(b.BagID))
                             {
                                 LoadKeyBag(PrivateKeyInfo.GetInstance(b.BagValue), b.BagAttributes);
                             }
@@ -291,12 +290,10 @@ namespace Org.BouncyCastle.Pkcs
                 }
             }
 
-            m_certs.Clear();
-            m_chainCerts.Clear();
+            Clear(m_certs, m_certsOrder);
+            Clear(m_chainCerts, m_chainCertsOrder);
             m_keyCerts.Clear();
-            m_certsOrder.Clear();
-            m_chainCertOrder.Clear();
-            
+
             foreach (SafeBag b in certBags)
             {
                 CertBag certBag = CertBag.GetInstance(b.BagValue);
@@ -345,11 +342,11 @@ namespace Org.BouncyCastle.Pkcs
                                 attributes[aOid] = attr;
                             }
 
-                            if (aOid.Equals(PkcsObjectIdentifiers.Pkcs9AtFriendlyName))
+                            if (PkcsObjectIdentifiers.Pkcs9AtFriendlyName.Equals(aOid))
                             {
                                 alias = ((DerBmpString)attr).GetString();
                             }
-                            else if (aOid.Equals(PkcsObjectIdentifiers.Pkcs9AtLocalKeyID))
+                            else if (PkcsObjectIdentifiers.Pkcs9AtLocalKeyID.Equals(aOid))
                             {
                                 localId = (Asn1OctetString)attr;
                             }
@@ -359,10 +356,7 @@ namespace Org.BouncyCastle.Pkcs
 
                 CertID certID = new CertID(cert);
                 X509CertificateEntry certEntry = new X509CertificateEntry(cert, attributes);
-
-                m_chainCerts[certID] = certEntry;
-                m_chainCertOrder.Add(certID);
-                // m_certOrder.Add(certID);
+                Map(m_chainCerts, m_chainCertsOrder, certID, certEntry);
 
                 if (unmarkedKeyEntry != null)
                 {
@@ -371,11 +365,11 @@ namespace Org.BouncyCastle.Pkcs
                         string name = Hex.ToHexString(certID.ID);
 
                         m_keyCerts[name] = certEntry;
-                        m_keys[name] = unmarkedKeyEntry;
+                        Map(m_keys, m_keysOrder, name, unmarkedKeyEntry);
                     }
                     else
                     {
-                        m_keys["unmarked"] = unmarkedKeyEntry;
+                        Map(m_keys, m_keysOrder, "unmarked", unmarkedKeyEntry);
                     }
                 }
                 else
@@ -390,8 +384,7 @@ namespace Org.BouncyCastle.Pkcs
                     if (alias != null)
                     {
                         // TODO There may have been more than one alias
-                        m_certs[alias] = certEntry;
-                        m_certsOrder.Add(alias);
+                        Map(m_certs, m_certsOrder, alias, certEntry);
                     }
                 }
             }
@@ -565,8 +558,8 @@ namespace Org.BouncyCastle.Pkcs
             if (m_keys.ContainsKey(alias))
                 throw new ArgumentException("There is a key entry with the name " + alias + ".");
 
-            m_certs[alias] = certEntry;
-            m_chainCerts[new CertID(certEntry)] = certEntry;
+            Map(m_certs, m_certsOrder, alias, certEntry);
+            Map(m_chainCerts, m_chainCertsOrder, new CertID(certEntry), certEntry);
         }
 
         public void SetKeyEntry(string alias, AsymmetricKeyEntry keyEntry, X509CertificateEntry[] chain)
@@ -583,18 +576,15 @@ namespace Org.BouncyCastle.Pkcs
                 DeleteEntry(alias);
             }
 
-            m_keys[alias] = keyEntry;
-            m_keysOrder.Add(alias);
+            Map(m_keys, m_keysOrder, alias, keyEntry);
 
             if (chain.Length > 0)
             {
-                m_certs[alias] = chain[0];
-                m_certsOrder.Add(alias);
+                Map(m_certs, m_certsOrder, alias, chain[0]);
+
                 foreach (var certificateEntry in chain)
                 {
-                    CertID certId = new CertID(certificateEntry);
-                    m_chainCerts[certId] = certificateEntry;
-                    m_chainCertOrder.Add(certId);
+                    Map(m_chainCerts, m_chainCertsOrder, new CertID(certificateEntry), certificateEntry);
                 }
             }
         }
@@ -604,24 +594,18 @@ namespace Org.BouncyCastle.Pkcs
             if (alias == null)
                 throw new ArgumentNullException(nameof(alias));
 
-            if (CollectionUtilities.Remove(m_certs, alias, out var certEntry))
+            if (Remove(m_certs, m_certsOrder, alias, out var certEntry))
             {
-                CertID certId = new CertID(certEntry);
-                m_chainCerts.Remove(certId);
-                m_chainCertOrder.Remove(certId);
-                m_certsOrder.Remove(alias);
+                Remove(m_chainCerts, m_chainCertsOrder, new CertID(certEntry));
             }
 
-            if (m_keys.Remove(alias))
+            if (Remove(m_keys, m_keysOrder, alias))
             {
-                m_keys.Remove(alias);
                 if (CollectionUtilities.Remove(m_localIds, alias, out var id))
                 {
                     if (CollectionUtilities.Remove(m_keyCerts, id, out var keyCertEntry))
                     {
-                        CertID certId = new CertID(certEntry);
-                        m_chainCertOrder.Remove(certId);
-                        m_chainCerts.Remove(certId);
+                        Remove(m_chainCerts, m_chainCertsOrder, new CertID(keyCertEntry));
                     }
                 }
             }
@@ -667,9 +651,9 @@ namespace Org.BouncyCastle.Pkcs
             // handle the keys
             //
             Asn1EncodableVector keyBags = new Asn1EncodableVector(m_keys.Count);
-            for (uint i = isReverse ? (uint)m_keysOrder.Count-1 : 0;
+            for (uint i = reverseCertificates ? (uint)m_keysOrder.Count-1 : 0;
                  i < m_keysOrder.Count;
-                 i = isReverse ? i-1 : i+1)
+                 i = reverseCertificates ? i-1 : i+1)
             {
                 var name = m_keysOrder[(int)i];
                 var privKey = m_keys[name];
@@ -756,11 +740,11 @@ namespace Org.BouncyCastle.Pkcs
             AlgorithmIdentifier cAlgId = new AlgorithmIdentifier(certAlgorithm, cParams.ToAsn1Object());
             var doneCerts = new HashSet<X509Certificate>();
 
-            for (uint i = isReverse ? (uint)m_keysOrder.Count-1 : 0;
+            for (uint i = reverseCertificates ? (uint)m_keysOrder.Count-1 : 0;
                  i < m_keysOrder.Count;
-                 i = isReverse ? i-1 : i+1)
+                 i = reverseCertificates ? i-1 : i+1)
             {
-                String name = m_keysOrder[(int)i];
+                string name = m_keysOrder[(int)i];
                 X509CertificateEntry certEntry = GetCertificate(name);
                 CertBag cBag = new CertBag(
                     PkcsObjectIdentifiers.X509Certificate,
@@ -807,18 +791,15 @@ namespace Org.BouncyCastle.Pkcs
 
                 doneCerts.Add(certEntry.Certificate);
             }
-            
-            // foreach (var certEntry in m_certs)
-            for (uint j = isReverse ? (uint)m_certsOrder.Count-1 : 0;
+
+            for (uint j = reverseCertificates ? (uint)m_certsOrder.Count-1 : 0;
                  j < m_certsOrder.Count;
-                 j = isReverse ? j-1 : j+1)
+                 j = reverseCertificates ? j-1 : j+1)
             {
-                var certId = m_certsOrder[(int)j];
-                var cert = m_certs[certId];
-                // var certId = certEntry.Key;
-                // var cert = certEntry.Value;
+                var alias = m_certsOrder[(int)j];
+                var cert = m_certs[alias];
 
-                if (m_keys.ContainsKey(certId))
+                if (m_keys.ContainsKey(alias))
                     continue;
 
                 CertBag cBag = new CertBag(
@@ -852,7 +833,7 @@ namespace Org.BouncyCastle.Pkcs
                     fName.Add(
                         new DerSequence(
                             PkcsObjectIdentifiers.Pkcs9AtFriendlyName,
-                            new DerSet(new DerBmpString(certId))));
+                            new DerSet(new DerBmpString(alias))));
                 }
 
                 // the Oracle PKCS12 parser looks for a trusted key usage for named certificates as well
@@ -888,27 +869,25 @@ namespace Org.BouncyCastle.Pkcs
 
                 doneCerts.Add(cert.Certificate);
             }
-            
-            // foreach (var chainCertEntry in m_chainCerts)
-            for (uint i = isReverse ? (uint)m_chainCertOrder.Count-1 : 0;
-                 i < m_chainCertOrder.Count;
-                 i = isReverse ? i-1 : i+1)
+
+            for (uint i = reverseCertificates ? (uint)m_chainCertsOrder.Count-1 : 0;
+                 i < m_chainCertsOrder.Count;
+                 i = reverseCertificates ? i-1 : i+1)
             {
-                var certId = m_chainCertOrder[(int)i];
-                var cert = m_chainCerts[certId];
-                // var certId = chainCertEntry.Key;
-                // var cert = chainCertEntry.Value;
+                CertID certID = m_chainCertsOrder[(int)i];
+                X509CertificateEntry certEntry = m_chainCerts[certID];
+                X509Certificate cert = certEntry.Certificate;
 
-                if (doneCerts.Contains(cert.Certificate))
+                if (doneCerts.Contains(cert))
                     continue;
 
                 CertBag cBag = new CertBag(
                     PkcsObjectIdentifiers.X509Certificate,
-                    new DerOctetString(cert.Certificate.GetEncoded()));
+                    new DerOctetString(cert.GetEncoded()));
 
                 Asn1EncodableVector fName = new Asn1EncodableVector();
 
-                foreach (var oid in cert.BagAttributeKeys)
+                foreach (var oid in certEntry.BagAttributeKeys)
                 {
                     // a certificate not immediately linked to a key doesn't require
                     // a localKeyID and will confuse some PKCS12 implementations.
@@ -917,7 +896,7 @@ namespace Org.BouncyCastle.Pkcs
                     if (PkcsObjectIdentifiers.Pkcs9AtLocalKeyID.Equals(oid))
                         continue;
 
-                    fName.Add(new DerSequence(oid, new DerSet(cert[oid])));
+                    fName.Add(new DerSequence(oid, new DerSet(certEntry[oid])));
                 }
 
                 certBags.Add(new SafeBag(PkcsObjectIdentifiers.CertBag, cBag.ToAsn1Object(), DerSet.FromVector(fName)));
@@ -1001,7 +980,7 @@ namespace Org.BouncyCastle.Pkcs
             if (cipher == null)
                 throw new Exception("Unknown encryption algorithm: " + algId.Algorithm);
 
-            if (algId.Algorithm.Equals(PkcsObjectIdentifiers.IdPbeS2))
+            if (PkcsObjectIdentifiers.IdPbeS2.Equals(algId.Algorithm))
             {
                 PbeS2Parameters pbeParameters = PbeS2Parameters.GetInstance(algId.Parameters);
                 ICipherParameters cipherParams = PbeUtilities.GenerateCipherParameters(
@@ -1018,5 +997,51 @@ namespace Org.BouncyCastle.Pkcs
                 return cipher.DoFinal(data);
             }
         }
+
+        private static void Clear<K, V>(Dictionary<K, V> d, List<K> o)
+        {
+            d.Clear();
+            o.Clear();
+        }
+
+        private static void Map<K, V>(Dictionary<K, V> d, List<K> o, K k, V v)
+        {
+            if (d.ContainsKey(k))
+            {
+                RemoveOrdering(d.Comparer, o, k);
+            }
+
+            o.Add(k);
+            d[k] = v;
+        }
+
+        private static bool Remove<K, V>(Dictionary<K, V> d, List<K> o, K k)
+        {
+            bool result = d.Remove(k);
+            if (result)
+            {
+                RemoveOrdering(d.Comparer, o, k);
+            }
+            return result;
+        }
+
+        private static bool Remove<K, V>(Dictionary<K, V> d, List<K> o, K k, out V v)
+        {
+            bool result = CollectionUtilities.Remove(d, k, out v);
+            if (result)
+            {
+                RemoveOrdering(d.Comparer, o, k);
+            }
+            return result;
+        }
+
+        private static void RemoveOrdering<K>(IEqualityComparer<K> c, List<K> o, K k)
+        {
+            int index = o.FindIndex(e => c.Equals(k, e));
+            if (index >= 0)
+            {
+                o.RemoveAt(index);
+            }
+        }
     }
 }