summary refs log tree commit diff
diff options
context:
space:
mode:
author Peter Dettman <peter.dettman@bouncycastle.org> 2022-11-10 22:06:36 +0700
committer Peter Dettman <peter.dettman@bouncycastle.org> 2022-11-10 22:06:36 +0700
commit 6c27cbbbe403ee97c8aa9ae17bc6e5861dd7ebd5 (patch)
tree bcb26eedbdfe37d6811044458c4917643a32e87e
parent BIKE perf. opts. (diff)
download BouncyCastle.NET-ed25519-6c27cbbbe403ee97c8aa9ae17bc6e5861dd7ebd5.tar.xz
BIKE perf. opts.
- CtrAll with vectorization when available
-rw-r--r-- crypto/src/pqc/crypto/bike/BikeEngine.cs 280
1 files changed, 207 insertions, 73 deletions
diff --git a/crypto/src/pqc/crypto/bike/BikeEngine.cs b/crypto/src/pqc/crypto/bike/BikeEngine.cs
index 2a8882901..d523e71ab 100644
--- a/crypto/src/pqc/crypto/bike/BikeEngine.cs
+++ b/crypto/src/pqc/crypto/bike/BikeEngine.cs
@@ -1,5 +1,8 @@
 using System;
 using System.Diagnostics;
+#if NETCOREAPP1_0_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+using System.Numerics;
+#endif
 
 using Org.BouncyCastle.Crypto;
 using Org.BouncyCastle.Crypto.Digests;
@@ -252,20 +255,24 @@ namespace Org.BouncyCastle.Pqc.Crypto.Bike
             int[] h0CompactCol = GetColumnFromCompactVersion(h0Compact);
             int[] h1CompactCol = GetColumnFromCompactVersion(h1Compact);
 
-            for (int i = 1; i <= nbIter; i++)
+            uint[] black = new uint[(2 * r + 31) >> 5];
+
             {
-                byte[] black = new byte[2 * r];
-                byte[] gray = new byte[2 * r];
+                uint[] gray = new uint[(2 * r + 31) >> 5];
 
                 int T = Threshold(BikeUtilities.GetHammingWeight(s), r);
 
                 BFIter(s, e, T, h0Compact, h1Compact, h0CompactCol, h1CompactCol, black, gray);
+                BFMaskedIter(s, e, black, (hw + 1) / 2 + 1, h0Compact, h1Compact, h0CompactCol, h1CompactCol);
+                BFMaskedIter(s, e, gray, (hw + 1) / 2 + 1, h0Compact, h1Compact, h0CompactCol, h1CompactCol);
+            }
+            for (int i = 1; i < nbIter; i++)
+            {
+                Array.Clear(black, 0, black.Length);
 
-                if (i == 1)
-                {
-                    BFMaskedIter(s, e, black, (hw + 1) / 2 + 1, h0Compact, h1Compact, h0CompactCol, h1CompactCol);
-                    BFMaskedIter(s, e, gray, (hw + 1) / 2 + 1, h0Compact, h1Compact, h0CompactCol, h1CompactCol);
-                }
+                int T = Threshold(BikeUtilities.GetHammingWeight(s), r);
+
+                BFIter2(s, e, T, h0Compact, h1Compact, h0CompactCol, h1CompactCol, black);
             }
 
             if (BikeUtilities.GetHammingWeight(s) == 0)
@@ -286,86 +293,170 @@ namespace Org.BouncyCastle.Pqc.Crypto.Bike
         }
 
         private void BFIter(byte[] s, byte[] e, int T, int[] h0Compact, int[] h1Compact, int[] h0CompactCol,
-            int[] h1CompactCol, byte[] black, byte[] gray)
+            int[] h1CompactCol, uint[] black, uint[] gray)
         {
-            int[] updatedIndices = new int[2 * r];
+            byte[] ctrs = new byte[r];
 
             // calculate for h0compact
-            for (int j = 0; j < r; j++)
             {
-                int ctr = Ctr(h0CompactCol, s, j);
-                if (ctr >= T)
+                CtrAll(h0CompactCol, s, ctrs);
+
                 {
-                    UpdateNewErrorIndex(e, j);
-                    updatedIndices[j] = 1;
-                    black[j] = 1;
+                    int ctrBit1 = ((ctrs[0] - T) >> 31) + 1;
+                    int ctrBit2 = ((ctrs[0] - (T - tau)) >> 31) + 1;
+                    e[0] ^= (byte)ctrBit1;
+                    black[0] |= (uint)ctrBit1;
+                    gray[0] |= (uint)ctrBit2;
                 }
-                else if (ctr >= T - tau)
+                for (int j = 1; j < r; j++)
                 {
-                    gray[j] = 1;
+                    int ctrBit1 = ((ctrs[j] - T) >> 31) + 1;
+                    int ctrBit2 = ((ctrs[j] - (T - tau)) >> 31) + 1;
+                    e[r - j] ^= (byte)ctrBit1;
+                    black[j >> 5] |= (uint)ctrBit1 << (j & 31);
+                    gray[j >> 5] |= (uint)ctrBit2 << (j & 31);
                 }
             }
 
+            Array.Clear(ctrs, 0, r);
+
             // calculate for h1Compact
-            for (int j = 0; j < r; j++)
             {
-                int ctr = Ctr(h1CompactCol, s, j);
-                if (ctr >= T)
+                CtrAll(h1CompactCol, s, ctrs);
+
                 {
-                    UpdateNewErrorIndex(e, r + j);
-                    updatedIndices[r + j] = 1;
-                    black[r + j] = 1;
+                    int ctrBit1 = ((ctrs[0] - T) >> 31) + 1;
+                    int ctrBit2 = ((ctrs[0] - (T - tau)) >> 31) + 1;
+                    e[r] ^= (byte)ctrBit1;
+                    black[r >> 5] |= (uint)ctrBit1 << (r & 31);
+                    gray[r >> 5] |= (uint)ctrBit2 << (r & 31);
                 }
-                else if (ctr >= T - tau)
+                for (int j = 1; j < r; j++)
                 {
-                    gray[r + j] = 1;
+                    int ctrBit1 = ((ctrs[j] - T) >> 31) + 1;
+                    int ctrBit2 = ((ctrs[j] - (T - tau)) >> 31) + 1;
+                    e[r + r - j] ^= (byte)ctrBit1;
+                    black[(r + j) >> 5] |= (uint)ctrBit1 << ((r + j) & 31);
+                    gray[(r + j) >> 5] |= (uint)ctrBit2 << ((r + j) & 31);
                 }
             }
 
             // recompute syndrome
-            for (int i = 0; i < 2 * r; i++)
+            for (int i = 0; i < black.Length; ++i)
             {
-                if (updatedIndices[i] == 1)
+                uint bits = black[i];
+                while (bits != 0)
                 {
-                    RecomputeSyndrome(s, i, h0Compact, h1Compact);
+                    int tz = Integers.NumberOfTrailingZeros((int)bits);
+                    RecomputeSyndrome(s, (i << 5) + tz, h0Compact, h1Compact);
+                    bits ^= 1U << tz;
                 }
             }
         }
 
-        private void BFMaskedIter(byte[] s, byte[] e, byte[] mask, int T, int[] h0Compact, int[] h1Compact,
+        private void BFIter2(byte[] s, byte[] e, int T, int[] h0Compact, int[] h1Compact, int[] h0CompactCol,
+            int[] h1CompactCol, uint[] black)
+        {
+            byte[] ctrs = new byte[r];
+
+            // calculate for h0compact
+            {
+                CtrAll(h0CompactCol, s, ctrs);
+
+                {
+                    int ctrBit1 = ((ctrs[0] - T) >> 31) + 1;
+                    e[0] ^= (byte)ctrBit1;
+                    black[0] |= (uint)ctrBit1;
+                }
+                for (int j = 1; j < r; j++)
+                {
+                    int ctrBit1 = ((ctrs[j] - T) >> 31) + 1;
+                    e[r - j] ^= (byte)ctrBit1;
+                    black[j >> 5] |= (uint)ctrBit1 << (j & 31);
+                }
+            }
+
+            Array.Clear(ctrs, 0, r);
+
+            // calculate for h1compact
+            {
+                CtrAll(h1CompactCol, s, ctrs);
+
+                {
+                    int ctrBit1 = ((ctrs[0] - T) >> 31) + 1;
+                    e[r] ^= (byte)ctrBit1;
+                    black[r >> 5] |= (uint)ctrBit1 << (r & 31);
+                }
+                for (int j = 1; j < r; j++)
+                {
+                    int ctrBit1 = ((ctrs[j] - T) >> 31) + 1;
+                    e[r + r - j] ^= (byte)ctrBit1;
+                    black[(r + j) >> 5] |= (uint)ctrBit1 << ((r + j) & 31);
+                }
+            }
+
+            // recompute syndrome
+            for (int i = 0; i < black.Length; ++i)
+            {
+                uint bits = black[i];
+                while (bits != 0)
+                {
+                    int tz = Integers.NumberOfTrailingZeros((int)bits);
+                    RecomputeSyndrome(s, (i << 5) + tz, h0Compact, h1Compact);
+                    bits ^= 1U << tz;
+                }
+            }
+        }
+
+        private void BFMaskedIter(byte[] s, byte[] e, uint[] mask, int T, int[] h0Compact, int[] h1Compact,
             int[] h0CompactCol, int[] h1CompactCol)
         {
-            int[] updatedIndices = new int[2 * r];
+            uint[] updatedIndices = new uint[(2 * r + 31) >> 5];
 
             for (int j = 0; j < r; j++)
             {
-                if (mask[j] == 1 && Ctr(h0CompactCol, s, j) >= T)
+                if ((mask[j >> 5] & (1U << (j & 31))) != 0)
                 {
-                    UpdateNewErrorIndex(e, j);
-                    updatedIndices[j] = 1;
+                    int ctr = Ctr(h0CompactCol, s, j);
+                    int ctrBit1 = ((ctr - T) >> 31) + 1;
+
+                    int k = -j;
+                    k += (k >> 31) & r;
+                    e[k] ^= (byte)ctrBit1;
+
+                    updatedIndices[j >> 5] |= (uint)ctrBit1 << (j & 31);
                 }
             }
 
             for (int j = 0; j < r; j++)
             {
-                if (mask[r + j] == 1 && Ctr(h1CompactCol, s, j) >= T)
+                if ((mask[(r + j) >> 5] & (1U << ((r + j) & 31))) != 0)
                 {
-                    UpdateNewErrorIndex(e, r + j);
-                    updatedIndices[r + j] = 1;
+                    int ctr = Ctr(h1CompactCol, s, j);
+                    int ctrBit1 = ((ctr - T) >> 31) + 1;
+
+                    int k = -j;
+                    k += (k >> 31) & r;
+                    e[r + k] ^= (byte)ctrBit1;
+
+                    updatedIndices[(r + j) >> 5] |= (uint)ctrBit1 << ((r + j) & 31);
                 }
             }
 
             // recompute syndrome
-            for (int i = 0; i < 2 * r; i++)
+            for (int i = 0; i < updatedIndices.Length; ++i)
             {
-                if (updatedIndices[i] == 1)
+                uint bits = updatedIndices[i];
+                while (bits != 0)
                 {
-                    RecomputeSyndrome(s, i, h0Compact, h1Compact);
+                    int tz = Integers.NumberOfTrailingZeros((int)bits);
+                    RecomputeSyndrome(s, (i << 5) + tz, h0Compact, h1Compact);
+                    bits ^= 1U << tz;
                 }
             }
         }
 
-        private int Threshold(int hammingWeight, int r)
+        private static int Threshold(int hammingWeight, int r)
         {
             double d;
             int floorD;
@@ -397,37 +488,25 @@ namespace Org.BouncyCastle.Pqc.Crypto.Bike
 
             int count = 0;
 
-            int i = 0, limit8 = hw - 8;
-            while (i < limit8)
+            int i = 0, limit = hw - 4;
+            while (i < limit)
             {
                 int sPos0 = hCompactCol[i + 0] + j - r;
                 int sPos1 = hCompactCol[i + 1] + j - r;
                 int sPos2 = hCompactCol[i + 2] + j - r;
                 int sPos3 = hCompactCol[i + 3] + j - r;
-                int sPos4 = hCompactCol[i + 4] + j - r;
-                int sPos5 = hCompactCol[i + 5] + j - r;
-                int sPos6 = hCompactCol[i + 6] + j - r;
-                int sPos7 = hCompactCol[i + 7] + j - r;
 
                 sPos0 += (sPos0 >> 31) & r;
                 sPos1 += (sPos1 >> 31) & r;
                 sPos2 += (sPos2 >> 31) & r;
                 sPos3 += (sPos3 >> 31) & r;
-                sPos4 += (sPos4 >> 31) & r;
-                sPos5 += (sPos5 >> 31) & r;
-                sPos6 += (sPos6 >> 31) & r;
-                sPos7 += (sPos7 >> 31) & r;
 
                 count += s[sPos0];
                 count += s[sPos1];
                 count += s[sPos2];
                 count += s[sPos3];
-                count += s[sPos4];
-                count += s[sPos5];
-                count += s[sPos6];
-                count += s[sPos7];
 
-                i += 8;
+                i += 4;
             }
             while (i < hw)
             {
@@ -439,6 +518,78 @@ namespace Org.BouncyCastle.Pqc.Crypto.Bike
             return count;
         }
 
+        private void CtrAll(int[] hCompactCol, byte[] s, byte[] ctrs)
+        {
+            for (int i = 0; i < hw; ++i)
+            {
+                int col = hCompactCol[i], neg = r - col;
+
+                int j = 0;
+#if NETCOREAPP1_0_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+                if (Vector.IsHardwareAccelerated)
+                {
+                    int jLimit = neg - Vector<byte>.Count;
+                    while (j < jLimit)
+                    {
+                        var vc = new Vector<byte>(ctrs, j);
+                        var vs = new Vector<byte>(s, col + j);
+                        (vc + vs).CopyTo(ctrs, j);
+                        j += Vector<byte>.Count;
+                    }
+                }
+                else
+#endif
+                {
+                    int jLimit = neg - 4;
+                    while (j < jLimit)
+                    {
+                        ctrs[j + 0] += s[col + j + 0];
+                        ctrs[j + 1] += s[col + j + 1];
+                        ctrs[j + 2] += s[col + j + 2];
+                        ctrs[j + 3] += s[col + j + 3];
+                        j += 4;
+                    }
+                }
+
+                while (j < neg)
+                {
+                    ctrs[j] += s[col + j];
+                    ++j;
+                }
+                int k = neg;
+#if NETCOREAPP1_0_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+                if (Vector.IsHardwareAccelerated)
+                {
+                    int kLimit = r - Vector<byte>.Count;
+                    while (k < kLimit)
+                    {
+                        var vc = new Vector<byte>(ctrs, k);
+                        var vs = new Vector<byte>(s, k - neg);
+                        (vc + vs).CopyTo(ctrs, k);
+                        k += Vector<byte>.Count;
+                    }
+                }
+                else
+#endif
+                {
+                    int kLimit = r - 4;
+                    while (k < kLimit)
+                    {
+                        ctrs[k + 0] += s[k + 0 - neg];
+                        ctrs[k + 1] += s[k + 1 - neg];
+                        ctrs[k + 2] += s[k + 2 - neg];
+                        ctrs[k + 3] += s[k + 3 - neg];
+                        k += 4;
+                    }
+                }
+                while (k < r)
+                {
+                    ctrs[k] += s[k - neg];
+                    ++k;
+                }
+            }
+        }
+
         // Convert a polynomial in GF2 to an array of positions of which the coefficients of the polynomial are equals to 1
         private void ConvertToCompact(int[] compactVersion, byte[] h)
         {
@@ -511,22 +662,5 @@ namespace Org.BouncyCastle.Pqc.Crypto.Bike
                 }
             }
         }
-
-        private void UpdateNewErrorIndex(byte[] e, int index)
-        {
-            int newIndex = index;
-            if (index != 0 && index != r)
-            {
-                if (index > r)
-                {
-                    newIndex = 2 * r - index + r;
-                }
-                else
-                {
-                    newIndex = r - index;
-                }
-            }
-            e[newIndex] ^= 1;
-        }
     }
 }