summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crypto/src/pqc/crypto/bike/BikeRing.cs128
1 files changed, 69 insertions, 59 deletions
diff --git a/crypto/src/pqc/crypto/bike/BikeRing.cs b/crypto/src/pqc/crypto/bike/BikeRing.cs
index 9babe280e..e424c9c3d 100644
--- a/crypto/src/pqc/crypto/bike/BikeRing.cs
+++ b/crypto/src/pqc/crypto/bike/BikeRing.cs
@@ -1,6 +1,9 @@
 using System;
 using System.Collections.Generic;
 using System.Diagnostics;
+#if NETSTANDARD1_0_OR_GREATER || NETCOREAPP1_0_OR_GREATER
+using System.Runtime.CompilerServices;
+#endif
 #if NETCOREAPP3_0_OR_GREATER
 using System.Runtime.Intrinsics;
 using System.Runtime.Intrinsics.X86;
@@ -15,11 +18,12 @@ namespace Org.BouncyCastle.Pqc.Crypto.Bike
 {
     internal sealed class BikeRing
     {
+        private const int PermutationCutoff = 64;
+
         private readonly int m_bits;
         private readonly int m_size;
         private readonly int m_sizeExt;
-        private readonly int m_permutationCutoff;
-        private readonly Dictionary<int, ushort[]> m_permutations = new Dictionary<int, ushort[]>();
+        private readonly Dictionary<int, int> m_halfPowers = new Dictionary<int, int>();
 
         internal BikeRing(int r)
         {
@@ -29,13 +33,12 @@ namespace Org.BouncyCastle.Pqc.Crypto.Bike
             m_bits = r;
             m_size = (r + 63) >> 6;
             m_sizeExt = m_size * 2;
-            m_permutationCutoff = r >> 5;
 
             foreach (int n in EnumerateSquarePowersInv(r))
             {
-                if (n > m_permutationCutoff && !m_permutations.ContainsKey(n))
+                if (n >= PermutationCutoff && !m_halfPowers.ContainsKey(n))
                 {
-                    m_permutations[n] = GenerateSquarePowerPermutation(r, n);
+                    m_halfPowers[n] = GenerateHalfPower(r, n);
                 }
             }
         }
@@ -184,15 +187,15 @@ namespace Org.BouncyCastle.Pqc.Crypto.Bike
 
         internal void SquareN(ulong[] x, int n, ulong[] z)
         {
-
             Debug.Assert(n > 0);
 
             /*
              * In these polynomial rings, 'SquareN' for some 'n' is equivalent to a fixed permutation of the
-             * coefficients. For 'SquareN' with 'n' above some cutoff value, this permutation is precomputed and then
-             * applied in place of explicit squaring for that 'n'. This is particularly relevant to calls during 'Inv'.
+             * coefficients. Calls to 'Inv' generate calls to 'SquareN' with a predictable sequence of 'n' values.
+             * For such 'n' above some cutoff value, we precalculate a small constant and then apply the permutation in
+             * place of explicit squaring for that 'n'.
              */
-            if (n > m_permutationCutoff)
+            if (n >= PermutationCutoff)
             {
                 ImplPermute(x, n, z);
                 return;
@@ -209,6 +212,24 @@ namespace Org.BouncyCastle.Pqc.Crypto.Bike
             }
         }
 
+#if NETSTANDARD1_0_OR_GREATER || NETCOREAPP1_0_OR_GREATER
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+#endif
+        private static int ImplModAdd(int m, int x, int y)
+        {
+            int t = x + y - m;
+            return t + ((t >> 31) & m);
+        }
+
+#if NETSTANDARD1_0_OR_GREATER || NETCOREAPP1_0_OR_GREATER
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+#endif
+        private static int ImplModHalf(int m, int x)
+        {
+            int t = -(x & 1);
+            return (x + (m & t)) >> 1;
+        }
+
         private void ImplMultiplyAcc(ulong[] x, ulong[] y, ulong[] zz)
         {
 #if NETCOREAPP3_0_OR_GREATER
@@ -318,52 +339,51 @@ namespace Org.BouncyCastle.Pqc.Crypto.Bike
 
         private void ImplPermute(ulong[] x, int n, ulong[] z)
         {
-            var permutation = m_permutations[n];
+            int r = m_bits;
+
+            var pow_1 = m_halfPowers[n];
+            var pow_2 = ImplModAdd(r, pow_1, pow_1);
+            var pow_4 = ImplModAdd(r, pow_2, pow_2);
+            var pow_8 = ImplModAdd(r, pow_4, pow_4);
+
+            int p0 = r - pow_8;
+            int p1 = ImplModAdd(r, p0, pow_1);
+            int p2 = ImplModAdd(r, p0, pow_2);
+            int p3 = ImplModAdd(r, p1, pow_2);
+            int p4 = ImplModAdd(r, p0, pow_4);
+            int p5 = ImplModAdd(r, p1, pow_4);
+            int p6 = ImplModAdd(r, p2, pow_4);
+            int p7 = ImplModAdd(r, p3, pow_4);
 
-            int i = 0, limit64 = m_bits - 64;
-            while (i < limit64)
+            for (int i = 0; i < Size; ++i)
             {
                 ulong z_i = 0UL;
 
                 for (int j = 0; j < 64; j += 8)
                 {
-                    int k = i + j;
-                    int p0 = permutation[k + 0];
-                    int p1 = permutation[k + 1];
-                    int p2 = permutation[k + 2];
-                    int p3 = permutation[k + 3];
-                    int p4 = permutation[k + 4];
-                    int p5 = permutation[k + 5];
-                    int p6 = permutation[k + 6];
-                    int p7 = permutation[k + 7];
-
-                    z_i |= ((x[p0 >> 6] >> (p0 & 63)) & 1) << (j + 0);
-                    z_i |= ((x[p1 >> 6] >> (p1 & 63)) & 1) << (j + 1);
-                    z_i |= ((x[p2 >> 6] >> (p2 & 63)) & 1) << (j + 2);
-                    z_i |= ((x[p3 >> 6] >> (p3 & 63)) & 1) << (j + 3);
-                    z_i |= ((x[p4 >> 6] >> (p4 & 63)) & 1) << (j + 4);
-                    z_i |= ((x[p5 >> 6] >> (p5 & 63)) & 1) << (j + 5);
-                    z_i |= ((x[p6 >> 6] >> (p6 & 63)) & 1) << (j + 6);
-                    z_i |= ((x[p7 >> 6] >> (p7 & 63)) & 1) << (j + 7);
+                    p0 = ImplModAdd(r, p0, pow_8);
+                    p1 = ImplModAdd(r, p1, pow_8);
+                    p2 = ImplModAdd(r, p2, pow_8);
+                    p3 = ImplModAdd(r, p3, pow_8);
+                    p4 = ImplModAdd(r, p4, pow_8);
+                    p5 = ImplModAdd(r, p5, pow_8);
+                    p6 = ImplModAdd(r, p6, pow_8);
+                    p7 = ImplModAdd(r, p7, pow_8);
+
+                    z_i |= ((x[p0 >> 6] >> p0) & 1) << (j + 0);
+                    z_i |= ((x[p1 >> 6] >> p1) & 1) << (j + 1);
+                    z_i |= ((x[p2 >> 6] >> p2) & 1) << (j + 2);
+                    z_i |= ((x[p3 >> 6] >> p3) & 1) << (j + 3);
+                    z_i |= ((x[p4 >> 6] >> p4) & 1) << (j + 4);
+                    z_i |= ((x[p5 >> 6] >> p5) & 1) << (j + 5);
+                    z_i |= ((x[p6 >> 6] >> p6) & 1) << (j + 6);
+                    z_i |= ((x[p7 >> 6] >> p7) & 1) << (j + 7);
                 }
 
-                z[i >> 6] = z_i;
-
-                i += 64;
+                z[i] = z_i;
             }
-            Debug.Assert(i < m_bits);
-            {
-                ulong z_i = 0UL;
 
-                for (int j = i; j < m_bits; ++j)
-                {
-                    int p = permutation[j];
-
-                    z_i |= ((x[p >> 6] >> (p & 63)) & 1) << (j & 63);
-                }
-
-                z[i >> 6] = z_i;
-            }
+            z[Size - 1] &= ulong.MaxValue >> -r;
         }
 
         private static IEnumerable<int> EnumerateSquarePowersInv(int r)
@@ -383,24 +403,14 @@ namespace Org.BouncyCastle.Pqc.Crypto.Bike
             }
         }
 
-        private static ushort[] GenerateSquarePowerPermutation(int r, int n)
+        private static int GenerateHalfPower(int r, int n)
         {
             int p = 1;
-            for (int i = 0; i < n; ++i)
-            {
-                int m = -(p & 1);
-                p += r & m;
-                p >>= 1;
-            }
-
-            var permutation = new ushort[r];
-            permutation[0] = 0;
-            permutation[1] = (ushort)p;
-            for (int i = 2; i < r; ++i)
+            for (int k = 0; k < n; ++k)
             {
-                permutation[i] = (ushort)(((uint)i * (uint)p) % (uint)r);
+                p = ImplModHalf(r, p);
             }
-            return permutation;
+            return p;
         }
 
         private static void ImplMulwAcc(ulong[] u, ulong x, ulong y, ulong[] z, int zOff)