summary refs log tree commit diff
path: root/crypto/src/pqc/crypto/lms/HSSSignature.cs
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src/pqc/crypto/lms/HSSSignature.cs')
-rw-r--r--crypto/src/pqc/crypto/lms/HSSSignature.cs112
1 files changed, 53 insertions, 59 deletions
diff --git a/crypto/src/pqc/crypto/lms/HSSSignature.cs b/crypto/src/pqc/crypto/lms/HSSSignature.cs
index 21f0397c8..946d0ef89 100644
--- a/crypto/src/pqc/crypto/lms/HSSSignature.cs
+++ b/crypto/src/pqc/crypto/lms/HSSSignature.cs
@@ -1,22 +1,25 @@
 using System;
 using System.IO;
+using System.Text;
 
 using Org.BouncyCastle.Utilities;
 using Org.BouncyCastle.Utilities.IO;
 
 namespace Org.BouncyCastle.Pqc.Crypto.Lms
 {
+    // TODO[api] Make internal
     public sealed class HssSignature
         : IEncodable
     {
         private readonly int m_lMinus1;
-        private readonly LmsSignedPubKey[] m_signedPubKey;
+        private readonly LmsSignedPubKey[] m_signedPubKeys;
         private readonly LmsSignature m_signature;
 
+        // TODO[api] signedPubKeys
         public HssSignature(int lMinus1, LmsSignedPubKey[] signedPubKey, LmsSignature signature)
         {
             m_lMinus1 = lMinus1;
-            m_signedPubKey = signedPubKey;
+            m_signedPubKeys = signedPubKey;
             m_signature = signature;
         }
 
@@ -29,90 +32,81 @@ namespace Org.BouncyCastle.Pqc.Crypto.Lms
         public static HssSignature GetInstance(object src, int L)
         {
             if (src is HssSignature hssSignature)
-            {
                 return hssSignature;
-            }
-            else if (src is BinaryReader binaryReader)
-            {
-                int lminus = BinaryReaders.ReadInt32BigEndian(binaryReader);
-                if (lminus != L - 1)
-                    throw new Exception("nspk exceeded maxNspk");
 
-                LmsSignedPubKey[] signedPubKeys = new LmsSignedPubKey[lminus];
-                if (lminus != 0)
-                {
-                    for (int t = 0; t < signedPubKeys.Length; t++)
-                    {
-                        signedPubKeys[t] = new LmsSignedPubKey(LmsSignature.GetInstance(src),
-                            LmsPublicKeyParameters.GetInstance(src));
-                    }
-                }
+            if (src is BinaryReader binaryReader)
+                return Parse(L, binaryReader);
 
-                LmsSignature sig = LmsSignature.GetInstance(src);
+            if (src is Stream stream)
+                return Parse(L, stream, leaveOpen: true);
 
-                return new HssSignature(lminus, signedPubKeys, sig);
-            }
-            else if (src is byte[] bytes)
+            if (src is byte[] bytes)
+                return Parse(L, new MemoryStream(bytes, false), leaveOpen: false);
+
+            throw new ArgumentException($"cannot parse {src}");
+        }
+
+        internal static HssSignature Parse(int L, BinaryReader binaryReader)
+        {
+            int lMinus1 = BinaryReaders.ReadInt32BigEndian(binaryReader);
+            if (lMinus1 != L - 1)
+                throw new Exception("nspk exceeded maxNspk");
+
+            var signedPubKeys = new LmsSignedPubKey[lMinus1];
+            for (int t = 0; t < lMinus1; t++)
             {
-                BinaryReader input = null;
-                try // 1.5 / 1.6 compatibility
-                {
-                    input = new BinaryReader(new MemoryStream(bytes));
-                    return GetInstance(input, L);
-                }
-                finally
-                {
-                    if (input != null) input.Close();
-                }
+                var signature = LmsSignature.Parse(binaryReader);
+                var publicKey = LmsPublicKeyParameters.Parse(binaryReader);
+
+                signedPubKeys[t] = new LmsSignedPubKey(signature, publicKey);
             }
-            else if (src is MemoryStream memoryStream)
+
             {
-                return GetInstance(Streams.ReadAll(memoryStream), L);
+                var signature = LmsSignature.Parse(binaryReader);
+
+                return new HssSignature(lMinus1, signedPubKeys, signature);
             }
+        }
 
-            throw new ArgumentException($"cannot parse {src}");
+        private static HssSignature Parse(int L, Stream stream, bool leaveOpen)
+        {
+            using (var binaryReader = new BinaryReader(stream, Encoding.UTF8, leaveOpen))
+            {
+                return Parse(L, binaryReader);
+            }
         }
 
+        [Obsolete("Use 'LMinus1' instead")]
         public int GetLMinus1()
         {
             return m_lMinus1;
         }
 
-        // FIXME
-        public LmsSignedPubKey[] GetSignedPubKeys()
-        {
-            return m_signedPubKey;
-        }
+        public int LMinus1 => m_lMinus1;
+
+        public LmsSignedPubKey[] GetSignedPubKeys() => (LmsSignedPubKey[])m_signedPubKeys?.Clone();
+
+        internal LmsSignedPubKey[] SignedPubKeys => m_signedPubKeys;
 
         public LmsSignature Signature => m_signature;
 
+        // TODO[api] Fix parameter name
         public override bool Equals(object other)
         {
             if (this == other)
                 return true;
-            if (!(other is HssSignature that))
-                return false;
-
-            if (this.m_lMinus1 != that.m_lMinus1)
-                return false;
-
-            if (this.m_signedPubKey.Length != that.m_signedPubKey.Length)
-                return false;
-
-            for (int t = 0; t < m_signedPubKey.Length; t++)
-            {
-                if (!this.m_signedPubKey[t].Equals(that.m_signedPubKey[t]))
-                    return false;
-            }
 
-            return Equals(this.m_signature, that.m_signature);
+            return other is HssSignature that
+                && this.m_lMinus1 == that.m_lMinus1
+                && Arrays.AreEqual(this.m_signedPubKeys, that.m_signedPubKeys)
+                && Objects.Equals(this.m_signature, that.m_signature);
         }
 
         public override int GetHashCode()
         {
             int result = m_lMinus1;
-            result = 31 * result + m_signedPubKey.GetHashCode();
-            result = 31 * result + (m_signature != null ? m_signature.GetHashCode() : 0);
+            result = 31 * result + Arrays.GetHashCode(m_signedPubKeys);
+            result = 31 * result + Objects.GetHashCode(m_signature);
             return result;
         }
 
@@ -120,9 +114,9 @@ namespace Org.BouncyCastle.Pqc.Crypto.Lms
         {
             Composer composer = Composer.Compose();
             composer.U32Str(m_lMinus1);
-            if (m_signedPubKey != null)
+            if (m_signedPubKeys != null)
             {
-                foreach (LmsSignedPubKey sigPub in m_signedPubKey)
+                foreach (LmsSignedPubKey sigPub in m_signedPubKeys)
                 {
                     composer.Bytes(sigPub);
                 }