summary refs log tree commit diff
path: root/crypto/src/asn1/pkcs/PrivateKeyInfo.cs
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src/asn1/pkcs/PrivateKeyInfo.cs')
-rw-r--r--crypto/src/asn1/pkcs/PrivateKeyInfo.cs181
1 files changed, 64 insertions, 117 deletions
diff --git a/crypto/src/asn1/pkcs/PrivateKeyInfo.cs b/crypto/src/asn1/pkcs/PrivateKeyInfo.cs
index 9535dbcae..7397b7061 100644
--- a/crypto/src/asn1/pkcs/PrivateKeyInfo.cs
+++ b/crypto/src/asn1/pkcs/PrivateKeyInfo.cs
@@ -1,8 +1,6 @@
 using System;
 
 using Org.BouncyCastle.Asn1.X509;
-using Org.BouncyCastle.Math;
-using Org.BouncyCastle.Utilities.Collections;
 
 namespace Org.BouncyCastle.Asn1.Pkcs
 {
@@ -46,166 +44,115 @@ namespace Org.BouncyCastle.Asn1.Pkcs
     public class PrivateKeyInfo
         : Asn1Encodable
     {
-        private readonly DerInteger version;
-        private readonly AlgorithmIdentifier privateKeyAlgorithm;
-        private readonly Asn1OctetString privateKey;
-        private readonly Asn1Set attributes;
-        private readonly DerBitString publicKey;
-
-        public static PrivateKeyInfo GetInstance(Asn1TaggedObject obj, bool explicitly)
-        {
-            return GetInstance(Asn1Sequence.GetInstance(obj, explicitly));
-        }
-
-        public static PrivateKeyInfo GetInstance(
-            object obj)
+        public static PrivateKeyInfo GetInstance(object obj)
         {
             if (obj == null)
                 return null;
-            if (obj is PrivateKeyInfo)
-                return (PrivateKeyInfo)obj;
+            if (obj is PrivateKeyInfo privateKeyInfo)
+                return privateKeyInfo;
             return new PrivateKeyInfo(Asn1Sequence.GetInstance(obj));
         }
 
-        private static int GetVersionValue(DerInteger version)
-        {
-            BigInteger bigValue = version.Value;
-            if (bigValue.CompareTo(BigInteger.Zero) < 0 || bigValue.CompareTo(BigInteger.One) > 0)
-                throw new ArgumentException("invalid version for private key info", "version");
-
-            return bigValue.IntValue;
-        }
-
-        public PrivateKeyInfo(
-            AlgorithmIdentifier privateKeyAlgorithm,
-            Asn1Encodable privateKey)
-            : this(privateKeyAlgorithm, privateKey, null, null)
-        {
-        }
-
-        public PrivateKeyInfo(
-            AlgorithmIdentifier privateKeyAlgorithm,
-            Asn1Encodable privateKey,
-            Asn1Set attributes)
-            : this(privateKeyAlgorithm, privateKey, attributes, null)
+        public static PrivateKeyInfo GetInstance(Asn1TaggedObject obj, bool explicitly)
         {
+            return new PrivateKeyInfo(Asn1Sequence.GetInstance(obj, explicitly));
         }
 
-        public PrivateKeyInfo(
-            AlgorithmIdentifier privateKeyAlgorithm,
-            Asn1Encodable privateKey,
-            Asn1Set attributes,
-            byte[] publicKey)
-        {
-            this.version = new DerInteger(publicKey != null ? BigInteger.One : BigInteger.Zero);
-            this.privateKeyAlgorithm = privateKeyAlgorithm;
-            this.privateKey = new DerOctetString(privateKey);
-            this.attributes = attributes;
-            this.publicKey = publicKey == null ? null : new DerBitString(publicKey);
-        }
+        private readonly DerInteger m_version;
+        private readonly AlgorithmIdentifier m_privateKeyAlgorithm;
+        private readonly Asn1OctetString m_privateKey;
+        private readonly Asn1Set m_attributes;
+        private readonly DerBitString m_publicKey;
 
         private PrivateKeyInfo(Asn1Sequence seq)
         {
-            var e = seq.GetEnumerator();
+            int count = seq.Count, pos = 0;
+            if (count < 3 || count > 5)
+                throw new ArgumentException("Bad sequence size: " + count, nameof(seq));
 
-            this.version = DerInteger.GetInstance(CollectionUtilities.RequireNext(e));
+            m_version = DerInteger.GetInstance(seq[pos++]);
+            m_privateKeyAlgorithm = AlgorithmIdentifier.GetInstance(seq[pos++]);
+            m_privateKey = Asn1OctetString.GetInstance(seq[pos++]);
+            m_attributes = Asn1Utilities.ReadOptionalContextTagged(seq, ref pos, 0, false, Asn1Set.GetInstance);
+            m_publicKey = Asn1Utilities.ReadOptionalContextTagged(seq, ref pos, 1, false, DerBitString.GetInstance);
 
-            int versionValue = GetVersionValue(version);
+            if (pos != count)
+                throw new ArgumentException("Unexpected elements in sequence", nameof(seq));
 
-            this.privateKeyAlgorithm = AlgorithmIdentifier.GetInstance(CollectionUtilities.RequireNext(e));
-            this.privateKey = Asn1OctetString.GetInstance(CollectionUtilities.RequireNext(e));
+            int versionValue = GetVersionValue(m_version);
 
-            int lastTag = -1;
-            while (e.MoveNext())
-            {
-                Asn1TaggedObject tagged = (Asn1TaggedObject)e.Current;
-
-                int tag = tagged.TagNo;
-                if (tag <= lastTag)
-                    throw new ArgumentException("invalid optional field in private key info", "seq");
-
-                lastTag = tag;
-
-                switch (tag)
-                {
-                case 0:
-                {
-                    this.attributes = Asn1Set.GetInstance(tagged, false);
-                    break;
-                }
-                case 1:
-                {
-                    if (versionValue < 1)
-                        throw new ArgumentException("'publicKey' requires version v2(1) or later", "seq");
-
-                    this.publicKey = DerBitString.GetInstance(tagged, false);
-                    break;
-                }
-                default:
-                {
-                    throw new ArgumentException("unknown optional field in private key info", "seq");
-                }
-                }
-            }
+            if (m_publicKey != null && versionValue < 1)
+                throw new ArgumentException("'publicKey' requires version v2(1) or later", nameof(seq));
         }
 
-        public virtual DerInteger Version
+        public PrivateKeyInfo(AlgorithmIdentifier privateKeyAlgorithm, Asn1Encodable privateKey)
+            : this(privateKeyAlgorithm, privateKey, null, null)
         {
-            get { return version; }
         }
 
-        public virtual Asn1Set Attributes
+        public PrivateKeyInfo(AlgorithmIdentifier privateKeyAlgorithm, Asn1Encodable privateKey, Asn1Set attributes)
+            : this(privateKeyAlgorithm, privateKey, attributes, null)
         {
-            get { return attributes; }
         }
 
-        /// <summary>Return true if a public key is present, false otherwise.</summary>
-        public virtual bool HasPublicKey
+        public PrivateKeyInfo(AlgorithmIdentifier privateKeyAlgorithm, Asn1Encodable privateKey, Asn1Set attributes,
+            byte[] publicKey)
         {
-            get { return publicKey != null; }
+            m_version = new DerInteger(publicKey != null ? 1 : 0);
+            m_privateKeyAlgorithm = privateKeyAlgorithm ?? throw new ArgumentNullException(nameof(privateKeyAlgorithm));
+            m_privateKey = new DerOctetString(privateKey);
+            m_attributes = attributes;
+            m_publicKey = publicKey == null ? null : new DerBitString(publicKey);
         }
 
-        public virtual AlgorithmIdentifier PrivateKeyAlgorithm
-        {
-            get { return privateKeyAlgorithm; }
-        }
+        public virtual DerInteger Version => m_version;
 
-        public virtual Asn1OctetString PrivateKey => privateKey;
+        public virtual Asn1Set Attributes => m_attributes;
+
+        /// <summary>Return true if a public key is present, false otherwise.</summary>
+        public virtual bool HasPublicKey => m_publicKey != null;
+
+        public virtual AlgorithmIdentifier PrivateKeyAlgorithm => m_privateKeyAlgorithm;
+
+        public virtual Asn1OctetString PrivateKey => m_privateKey;
 
         [Obsolete("Use 'PrivateKey' instead")]
-        public virtual Asn1OctetString PrivateKeyData
-        {
-            get { return privateKey; }
-        }
+        public virtual Asn1OctetString PrivateKeyData => m_privateKey;
 
-        public virtual int PrivateKeyLength => privateKey.GetOctetsLength();
+        public virtual int PrivateKeyLength => m_privateKey.GetOctetsLength();
 
-        public virtual Asn1Object ParsePrivateKey()
-        {
-            return Asn1Object.FromByteArray(privateKey.GetOctets());
-        }
+        public virtual Asn1Object ParsePrivateKey() => Asn1Object.FromByteArray(m_privateKey.GetOctets());
 
         /// <summary>For when the public key is an ASN.1 encoding.</summary>
         public virtual Asn1Object ParsePublicKey()
         {
-            return publicKey == null ? null : Asn1Object.FromByteArray(publicKey.GetOctets());
+            return m_publicKey == null ? null : Asn1Object.FromByteArray(m_publicKey.GetOctets());
         }
 
-        public virtual DerBitString PublicKey => publicKey;
+        public virtual DerBitString PublicKey => m_publicKey;
 
         /// <summary>Return the public key as a raw bit string.</summary>
         [Obsolete("Use 'PublicKey' instead")]
-        public virtual DerBitString PublicKeyData
-        {
-            get { return publicKey; }
-        }
+        public virtual DerBitString PublicKeyData => m_publicKey;
 
         public override Asn1Object ToAsn1Object()
         {
-            Asn1EncodableVector v = new Asn1EncodableVector(version, privateKeyAlgorithm, privateKey);
-            v.AddOptionalTagged(false, 0, attributes);
-            v.AddOptionalTagged(false, 1, publicKey);
+            Asn1EncodableVector v = new Asn1EncodableVector(5);
+            v.Add(m_version, m_privateKeyAlgorithm, m_privateKey);
+            v.AddOptionalTagged(false, 0, m_attributes);
+            v.AddOptionalTagged(false, 1, m_publicKey);
             return new DerSequence(v);
         }
+
+        private static int GetVersionValue(DerInteger version)
+        {
+            if (version.TryGetIntPositiveValueExact(out int value))
+            {
+                if (value >= 0 && value <= 1)
+                    return value;
+            }
+
+            throw new ArgumentException("Invalid version for PrivateKeyInfo", nameof(version));
+        }
     }
 }