summary refs log tree commit diff
path: root/crypto/src/pqc/crypto/saber/Poly.cs
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src/pqc/crypto/saber/Poly.cs')
-rw-r--r--crypto/src/pqc/crypto/saber/Poly.cs25
1 files changed, 18 insertions, 7 deletions
diff --git a/crypto/src/pqc/crypto/saber/Poly.cs b/crypto/src/pqc/crypto/saber/Poly.cs
index eaae6c9a5..1a7312201 100644
--- a/crypto/src/pqc/crypto/saber/Poly.cs
+++ b/crypto/src/pqc/crypto/saber/Poly.cs
@@ -32,9 +32,8 @@ namespace Org.BouncyCastle.Pqc.Crypto.Saber
             byte[] buf = new byte[SABER_L * engine.PolyVecBytes];
             int i;
 
-            IXof digest = new ShakeDigest(128);
-            digest.BlockUpdate(seed, 0, engine.SeedBytes);
-            digest.OutputFinal(buf, 0, buf.Length);
+            engine.Symmetric.Prf(buf, seed, engine.SeedBytes, buf.Length);
+
 
             for (i = 0; i < SABER_L; i++)
             {
@@ -46,13 +45,25 @@ namespace Org.BouncyCastle.Pqc.Crypto.Saber
         {
             byte[] buf = new byte[SABER_L * engine.PolyCoinBytes];
 
-            IXof digest = new ShakeDigest(128);
-            digest.BlockUpdate(seed, 0, engine.NoiseSeedBytes);
-            digest.OutputFinal(buf, 0, buf.Length);
+            engine.Symmetric.Prf(buf, seed, engine.NoiseSeedBytes, buf.Length);
+
 
             for (int i = 0; i < SABER_L; i++)
             {
-                Cbd(s[i], buf, i * engine.PolyCoinBytes);
+                if (!engine.UsingEffectiveMasking)
+                {
+                    Cbd(s[i], buf, i * engine.PolyCoinBytes);
+                }
+                else
+                {
+                    for(int j = 0; j<SABER_N/4; j++)
+                    {
+                        s[i][4*j] = (short) ((((buf[j + i * engine.PolyCoinBytes]) & 0x03) ^ 2) - 2);
+                        s[i][4*j+1] = (short) ((((buf[j + i * engine.PolyCoinBytes] >> 2) & 0x03)  ^ 2) - 2);
+                        s[i][4*j+2] = (short) ((((buf[j + i * engine.PolyCoinBytes] >> 4) & 0x03)  ^ 2) - 2);
+                        s[i][4*j+3] = (short) ((((buf[j + i * engine.PolyCoinBytes] >> 6) & 0x03)  ^ 2) - 2);
+                    }
+                }
             }
         }