summary refs log tree commit diff
path: root/crypto/src/pqc/crypto/sphincsplus/SPHINCSPlusKeyPairGenerator.cs
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src/pqc/crypto/sphincsplus/SPHINCSPlusKeyPairGenerator.cs')
-rw-r--r--crypto/src/pqc/crypto/sphincsplus/SPHINCSPlusKeyPairGenerator.cs28
1 files changed, 24 insertions, 4 deletions
diff --git a/crypto/src/pqc/crypto/sphincsplus/SPHINCSPlusKeyPairGenerator.cs b/crypto/src/pqc/crypto/sphincsplus/SPHINCSPlusKeyPairGenerator.cs
index 9e5724027..dbb93a812 100644
--- a/crypto/src/pqc/crypto/sphincsplus/SPHINCSPlusKeyPairGenerator.cs
+++ b/crypto/src/pqc/crypto/sphincsplus/SPHINCSPlusKeyPairGenerator.cs
@@ -1,6 +1,8 @@
 
+using System;
 using Org.BouncyCastle.Crypto;
 using Org.BouncyCastle.Security;
+using static Org.BouncyCastle.Pqc.Crypto.SphincsPlus.SPHINCSPlusEngine;
 
 namespace Org.BouncyCastle.Pqc.Crypto.SphincsPlus
 {
@@ -13,15 +15,33 @@ namespace Org.BouncyCastle.Pqc.Crypto.SphincsPlus
         public void Init(KeyGenerationParameters param)
         {
             random = param.Random;
-            parameters = ((SPHINCSPlusKeyGenerationParameters) param).Parameters;
+            parameters = ((SPHINCSPlusKeyGenerationParameters)param).Parameters;
         }
 
         public AsymmetricCipherKeyPair GenerateKeyPair()
         {
             SPHINCSPlusEngine engine = parameters.GetEngine();
-
-            SK sk = new SK(SecRand(engine.N), SecRand(engine.N));
-            byte[] pkSeed = SecRand(engine.N);
+            byte[] pkSeed;
+            SK sk;
+
+            if (engine is SPHINCSPlusEngine.HarakaSEngine)
+            {
+                // required to pass kat tests
+                byte[] tmparray = SecRand(engine.N * 3);
+                byte[] skseed = new byte[engine.N];
+                byte[] skprf = new byte[engine.N];
+                pkSeed = new byte[engine.N];
+                Array.Copy(tmparray, 0, skseed, 0, engine.N);
+                Array.Copy(tmparray, engine.N, skprf, 0, engine.N);
+                Array.Copy(tmparray, engine.N << 1, pkSeed, 0, engine.N);
+                sk = new SK(skseed, skprf);
+            }
+            else
+            {
+                sk = new SK(SecRand(engine.N), SecRand(engine.N));
+                pkSeed = SecRand(engine.N);
+            }
+            engine.init(pkSeed);
             // TODO
             PK pk = new PK(pkSeed, new HT(engine, sk.seed, pkSeed).HTPubKey);