summary refs log tree commit diff
path: root/crypto/src/pkcs/Pkcs12Store.cs
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src/pkcs/Pkcs12Store.cs')
-rw-r--r--crypto/src/pkcs/Pkcs12Store.cs72
1 files changed, 32 insertions, 40 deletions
diff --git a/crypto/src/pkcs/Pkcs12Store.cs b/crypto/src/pkcs/Pkcs12Store.cs
index aede1653a..e05805b88 100644
--- a/crypto/src/pkcs/Pkcs12Store.cs
+++ b/crypto/src/pkcs/Pkcs12Store.cs
@@ -25,8 +25,8 @@ namespace Org.BouncyCastle.Pkcs
         private readonly Dictionary<string, string> m_localIds = new Dictionary<string, string>();
         private readonly Dictionary<string, X509CertificateEntry> m_certs =
             new Dictionary<string, X509CertificateEntry>(StringComparer.OrdinalIgnoreCase);
-        private readonly Dictionary<CertId, X509CertificateEntry> m_chainCerts =
-            new Dictionary<CertId, X509CertificateEntry>();
+        private readonly Dictionary<CertID, X509CertificateEntry> m_chainCerts =
+            new Dictionary<CertID, X509CertificateEntry>();
         private readonly Dictionary<string, X509CertificateEntry> m_keyCerts =
             new Dictionary<string, X509CertificateEntry>();
         private readonly DerObjectIdentifier keyAlgorithm;
@@ -45,45 +45,33 @@ namespace Org.BouncyCastle.Pkcs
                 SubjectPublicKeyInfoFactory.CreateSubjectPublicKeyInfo(pubKey));
         }
 
-        internal class CertId
+        internal struct CertID
+            : IEquatable<CertID>
         {
-            private readonly byte[] id;
+            private readonly byte[] m_id;
 
-            internal CertId(
-                AsymmetricKeyParameter pubKey)
+            internal CertID(X509CertificateEntry certEntry)
+                : this(certEntry.Certificate)
             {
-                this.id = CreateSubjectKeyID(pubKey).GetKeyIdentifier();
             }
 
-            internal CertId(
-                byte[] id)
+            internal CertID(X509Certificate cert)
+                : this(CreateSubjectKeyID(cert.GetPublicKey()).GetKeyIdentifier())
             {
-                this.id = id;
             }
 
-            internal byte[] Id
+            internal CertID(byte[] id)
             {
-                get { return id; }
+                m_id = id;
             }
 
-            public override int GetHashCode()
-            {
-                return Arrays.GetHashCode(id);
-            }
-
-            public override bool Equals(
-                object obj)
-            {
-                if (obj == this)
-                    return true;
+            internal byte[] ID => m_id;
 
-                CertId other = obj as CertId;
+            public bool Equals(CertID other) => Arrays.AreEqual(m_id, other.m_id);
 
-                if (other == null)
-                    return false;
+            public override bool Equals(object obj) => obj is CertID other && Equals(other);
 
-                return Arrays.AreEqual(id, other.id);
-            }
+            public override int GetHashCode() => Arrays.GetHashCode(m_id);
         }
 
         internal Pkcs12Store(DerObjectIdentifier keyAlgorithm, DerObjectIdentifier keyPrfAlgorithm,
@@ -356,16 +344,16 @@ namespace Org.BouncyCastle.Pkcs
                     }
                 }
 
-                CertId certId = new CertId(cert.GetPublicKey());
+                CertID certID = new CertID(cert);
                 X509CertificateEntry certEntry = new X509CertificateEntry(cert, attributes);
 
-                m_chainCerts[certId] = certEntry;
+                m_chainCerts[certID] = certEntry;
 
                 if (unmarkedKeyEntry != null)
                 {
                     if (m_keyCerts.Count == 0)
                     {
-                        string name = Hex.ToHexString(certId.Id);
+                        string name = Hex.ToHexString(certID.ID);
 
                         m_keyCerts[name] = certEntry;
                         m_keys[name] = unmarkedKeyEntry;
@@ -502,7 +490,7 @@ namespace Org.BouncyCastle.Pkcs
                     byte[] keyID = aki.GetKeyIdentifier();
                     if (keyID != null)
                     {
-                        nextC = CollectionUtilities.GetValueOrNull(m_chainCerts, new CertId(keyID));
+                        nextC = CollectionUtilities.GetValueOrNull(m_chainCerts, new CertID(keyID));
                     }
                 }
 
@@ -562,7 +550,7 @@ namespace Org.BouncyCastle.Pkcs
                 throw new ArgumentException("There is a key entry with the name " + alias + ".");
 
             m_certs[alias] = certEntry;
-            m_chainCerts[new CertId(certEntry.Certificate.GetPublicKey())] = certEntry;
+            m_chainCerts[new CertID(certEntry)] = certEntry;
         }
 
         public void SetKeyEntry(string alias, AsymmetricKeyEntry keyEntry, X509CertificateEntry[] chain)
@@ -571,7 +559,7 @@ namespace Org.BouncyCastle.Pkcs
                 throw new ArgumentNullException(nameof(alias));
             if (keyEntry == null)
                 throw new ArgumentNullException(nameof(keyEntry));
-            if (keyEntry.Key.IsPrivate && chain == null)
+            if (keyEntry.Key.IsPrivate && Arrays.IsNullOrEmpty(chain))
                 throw new ArgumentException("No certificate chain for private key");
 
             if (m_keys.ContainsKey(alias))
@@ -580,11 +568,15 @@ namespace Org.BouncyCastle.Pkcs
             }
 
             m_keys[alias] = keyEntry;
-            m_certs[alias] = chain[0];
 
-            for (int i = 0; i != chain.Length; i++)
+            if (chain.Length > 0)
             {
-                m_chainCerts[new CertId(chain[i].Certificate.GetPublicKey())] = chain[i];
+                m_certs[alias] = chain[0];
+
+                foreach (var certificateEntry in chain)
+                {
+                    m_chainCerts[new CertID(certificateEntry)] = certificateEntry;
+                }
             }
         }
 
@@ -593,18 +585,18 @@ namespace Org.BouncyCastle.Pkcs
             if (alias == null)
                 throw new ArgumentNullException(nameof(alias));
 
-            if (CollectionUtilities.Remove(m_certs, alias, out var cert))
+            if (CollectionUtilities.Remove(m_certs, alias, out var certEntry))
             {
-                m_chainCerts.Remove(new CertId(cert.Certificate.GetPublicKey()));
+                m_chainCerts.Remove(new CertID(certEntry));
             }
 
             if (m_keys.Remove(alias))
             {
                 if (CollectionUtilities.Remove(m_localIds, alias, out var id))
                 {
-                    if (CollectionUtilities.Remove(m_keyCerts, id, out var keyCert))
+                    if (CollectionUtilities.Remove(m_keyCerts, id, out var keyCertEntry))
                     {
-                        m_chainCerts.Remove(new CertId(keyCert.Certificate.GetPublicKey()));
+                        m_chainCerts.Remove(new CertID(keyCertEntry));
                     }
                 }
             }