summary refs log tree commit diff
path: root/crypto/src/math/ec/rfc8032/Ed25519.cs
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src/math/ec/rfc8032/Ed25519.cs')
-rw-r--r--crypto/src/math/ec/rfc8032/Ed25519.cs314
1 files changed, 178 insertions, 136 deletions
diff --git a/crypto/src/math/ec/rfc8032/Ed25519.cs b/crypto/src/math/ec/rfc8032/Ed25519.cs
index c1820d00f..49f7b23a9 100644
--- a/crypto/src/math/ec/rfc8032/Ed25519.cs
+++ b/crypto/src/math/ec/rfc8032/Ed25519.cs
@@ -73,10 +73,9 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
         private const int PrecompPoints = 1 << (PrecompTeeth - 1);
         private const int PrecompMask = PrecompPoints - 1;
 
-        private static readonly object precompLock = new object();
-        // TODO[ed25519] Convert to PointPrecomp
-        private static PointExtended[] precompBaseTable = null;
-        private static int[] precompBase = null;
+        private static readonly object PrecompLock = new object();
+        private static PointPrecomp[] PrecompBaseWnaf = null;
+        private static int[] PrecompBaseComb = null;
 
         private ref struct PointAccum
         {
@@ -93,7 +92,7 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             internal int[] x, y, z, t;
         }
 
-        private ref struct PointPrecomp
+        private struct PointPrecomp
         {
             internal int[] ypx_h, ymx_h, xyd;
         }
@@ -466,7 +465,7 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             if (!CheckScalarVar(S, nS))
                 return false;
 
-            PointAffine pA; InitAffine(out pA);
+            PointAffine pA; Init(out pA);
             if (!DecodePointVar(pk, pkOff, true, ref pA))
                 return false;
 
@@ -484,14 +483,14 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             uint[] nA = new uint[ScalarUints];
             DecodeScalar(k, 0, nA);
 
-            PointAccum pR; InitAccum(out pR);
+            PointAccum pR; Init(out pR);
             ScalarMultStrausVar(nS, nA, ref pA, ref pR);
 
             byte[] check = new byte[PointBytes];
             return 0 != EncodePoint(ref pR, check, 0) && Arrays.AreEqual(check, R);
         }
 
-        private static void InitAccum(out PointAccum r)
+        private static void Init(out PointAccum r)
         {
             r.x = F.Create();
             r.y = F.Create();
@@ -500,13 +499,13 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             r.v = F.Create();
         }
 
-        private static void InitAffine(out PointAffine r)
+        private static void Init(out PointAffine r)
         {
             r.x = F.Create();
             r.y = F.Create();
         }
 
-        private static void InitExtended(out PointExtended r)
+        private static void Init(out PointExtended r)
         {
             r.x = F.Create();
             r.y = F.Create();
@@ -514,13 +513,47 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             r.t = F.Create();
         }
 
-        private static void InitPrecomp(out PointPrecomp r)
+        private static void Init(out PointPrecomp r)
         {
             r.ypx_h = F.Create();
             r.ymx_h = F.Create();
             r.xyd = F.Create();
         }
 
+        private static void InvertDoubleZs(PointExtended[] points)
+        {
+            int count = points.Length;
+            int[] cs = F.CreateTable(count);
+
+            int[] u = F.Create();
+            F.Copy(points[0].z, 0, u, 0);
+            F.Copy(u, 0, cs, 0);
+
+            int i = 0;
+            while (++i < count)
+            {
+                F.Mul(u, points[i].z, u);
+                F.Copy(u, 0, cs, i * F.Size);
+            }
+
+            F.Add(u, u, u);
+            F.InvVar(u, u);
+            --i;
+
+            int[] t = F.Create();
+
+            while (i > 0)
+            {
+                int j = i--;
+                F.Copy(cs, i * F.Size, t, 0);
+                F.Mul(t, u, t);
+                F.Mul(u, points[j].z, u);
+                F.Copy(t, 0, points[j].z, 0);
+            }
+
+            F.Copy(u, 0, points[0].z, 0);
+        }
+
         private static bool IsNeutralElementVar(int[] x, int[] y)
         {
             return F.IsZeroVar(x) && F.IsOneVar(y);
@@ -686,6 +719,38 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             F.Mul(f, g, r.z);
         }
 
+        private static void PointAddPrecompVar(int sign, ref PointPrecomp p, ref PointAccum r)
+        {
+            int[] a = F.Create();
+            int[] b = F.Create();
+            int[] c = F.Create();
+            int[] e = r.u;
+            int[] f = F.Create();
+            int[] g = F.Create();
+            int[] h = r.v;
+
+            F.Apm(r.y, r.x, b, a);
+            if (sign == 0)
+            {
+                F.Mul(a, p.ymx_h, a);
+                F.Mul(b, p.ypx_h, b);
+            }
+            else
+            {
+                F.Mul(a, p.ypx_h, a);
+                F.Mul(b, p.ymx_h, b);
+            }
+            F.Mul(r.u, r.v, c);
+            F.Mul(c, p.xyd, c);
+            F.CNegate(sign, c);
+            F.Apm(b, a, h, e);
+            F.Apm(r.z, c, g, f);
+            F.Carry(g);
+            F.Mul(e, f, r.x);
+            F.Mul(g, h, r.y);
+            F.Mul(f, g, r.z);
+        }
+
         private static void PointCopy(ref PointAccum p, ref PointExtended r)
         {
             F.Copy(p.x, 0, r.x, 0);
@@ -757,9 +822,9 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             for (int i = 0; i < PrecompPoints; ++i)
             {
                 int cond = ((i ^ index) - 1) >> 31;
-                F.CMov(cond, precompBase, off, p.ypx_h, 0);     off += F.Size;
-                F.CMov(cond, precompBase, off, p.ymx_h, 0);     off += F.Size;
-                F.CMov(cond, precompBase, off, p.xyd, 0);       off += F.Size;
+                F.CMov(cond, PrecompBaseComb, off, p.ypx_h, 0);     off += F.Size;
+                F.CMov(cond, PrecompBaseComb, off, p.ymx_h, 0);     off += F.Size;
+                F.CMov(cond, PrecompBaseComb, off, p.xyd  , 0);     off += F.Size;
             }
         }
 
@@ -792,10 +857,10 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
         {
             Debug.Assert(count > 0);
 
-            PointExtended q; InitExtended(out q);
+            PointExtended q; Init(out q);
             PointCopy(ref p, ref q);
 
-            PointExtended d; InitExtended(out d);
+            PointExtended d; Init(out d);
             PointCopy(ref q, ref d);
             PointAdd(ref q, ref d);
 
@@ -805,10 +870,10 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             int i = 0;
             for (;;)
             {
-                F.Copy(q.x, 0, table, off); off += F.Size;
-                F.Copy(q.y, 0, table, off); off += F.Size;
-                F.Copy(q.z, 0, table, off); off += F.Size;
-                F.Copy(q.t, 0, table, off); off += F.Size;
+                F.Copy(q.x, 0, table, off);     off += F.Size;
+                F.Copy(q.y, 0, table, off);     off += F.Size;
+                F.Copy(q.z, 0, table, off);     off += F.Size;
+                F.Copy(q.t, 0, table, off);     off += F.Size;
 
                 if (++i == count)
                     break;
@@ -819,22 +884,21 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             return table;
         }
 
-        private static PointExtended[] PointPrecomputeVar(ref PointExtended p, int count)
+        private static void PointPrecomputeVar(ref PointAffine p, PointExtended[] points, int count)
         {
             Debug.Assert(count > 0);
 
-            PointExtended d; InitExtended(out d);
-            PointAddVar(false, ref p, ref p, ref d);
+            Init(out points[0]);
+            PointCopy(ref p, ref points[0]);
+
+            PointExtended d; Init(out d);
+            PointAddVar(false, ref points[0], ref points[0], ref d);
 
-            PointExtended[] table = new PointExtended[count];
-            InitExtended(out table[0]);
-            PointCopy(ref p, ref table[0]);
             for (int i = 1; i < count; ++i)
             {
-                InitExtended(out table[i]);
-                PointAddVar(false, ref table[i - 1], ref d, ref table[i]);
+                Init(out points[i]);
+                PointAddVar(false, ref points[i - 1], ref d, ref points[i]);
             }
-            return table;
         }
 
         private static void PointSetNeutral(ref PointAccum p)
@@ -856,47 +920,49 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
 
         public static void Precompute()
         {
-            lock (precompLock)
+            lock (PrecompLock)
             {
-                if (precompBase != null)
+                if (PrecompBaseWnaf != null && PrecompBaseComb != null)
                     return;
 
-                // Precomputed table for the base point in verification ladder
-                {
-                    PointExtended b; InitExtended(out b);
-                    F.Copy(B_x, 0, b.x, 0);
-                    F.Copy(B_y, 0, b.y, 0);
-                    PointExtendXY(ref b);
+                int wnafPoints = 1 << (WnafWidthBase - 2);
+                int combPoints = PrecompBlocks * PrecompPoints;
+                int totalPoints = wnafPoints + combPoints;
 
-                    precompBaseTable = PointPrecomputeVar(ref b, 1 << (WnafWidthBase - 2));
-                }
+                PointExtended[] points = new PointExtended[totalPoints];
 
-                PointAccum p; InitAccum(out p);
+                PointAffine b;
+                b.x = B_x;
+                b.y = B_y;
+
+                PointPrecomputeVar(ref b, points, wnafPoints);
+
+                PointAccum p; Init(out p);
                 F.Copy(B_x, 0, p.x, 0);
                 F.Copy(B_y, 0, p.y, 0);
                 PointExtendXY(ref p);
 
-                precompBase = F.CreateTable(PrecompBlocks * PrecompPoints * 3);
-
-                int off = 0;
-                for (int b = 0; b < PrecompBlocks; ++b)
+                int pointsIndex = wnafPoints;
+                PointExtended[] toothPowers = new PointExtended[PrecompTeeth];
+                for (int tooth = 0; tooth < PrecompTeeth; ++tooth)
                 {
-                    PointExtended[] ds = new PointExtended[PrecompTeeth];
-
-                    PointExtended sum; InitExtended(out sum);
+                    Init(out toothPowers[tooth]);
+                }
+                PointExtended u; Init(out u);
+                for (int block = 0; block < PrecompBlocks; ++block)
+                {
+                    ref PointExtended sum = ref points[pointsIndex++];
+                    Init(out sum);
                     PointSetNeutral(ref sum);
 
-                    for (int t = 0; t < PrecompTeeth; ++t)
+                    for (int tooth = 0; tooth < PrecompTeeth; ++tooth)
                     {
-                        PointExtended q; InitExtended(out q);
-                        PointCopy(ref p, ref q);
-                        PointAddVar(true, ref sum, ref q, ref sum);
+                        PointCopy(ref p, ref u);
+                        PointAddVar(true, ref sum, ref u, ref sum);
                         PointDouble(ref p);
+                        PointCopy(ref p, ref toothPowers[tooth]);
 
-                        InitExtended(out ds[t]);
-                        PointCopy(ref p, ref ds[t]);
-
-                        if (b + t != PrecompBlocks + PrecompTeeth - 2)
+                        if (block + tooth != PrecompBlocks + PrecompTeeth - 2)
                         {
                             for (int s = 1; s < PrecompSpacing; ++s)
                             {
@@ -905,85 +971,63 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
                         }
                     }
 
-                    PointExtended[] points = new PointExtended[PrecompPoints];
-                    int k = 0;
-                    points[k++] = sum;
-
-                    for (int t = 0; t < (PrecompTeeth - 1); ++t)
+                    for (int tooth = 0; tooth < (PrecompTeeth - 1); ++tooth)
                     {
-                        int size = 1 << t;
-                        for (int j = 0; j < size; ++j, ++k)
+                        int size = 1 << tooth;
+                        for (int j = 0; j < size; ++j, ++pointsIndex)
                         {
-                            InitExtended(out points[k]);
-                            PointAddVar(false, ref points[k - size], ref ds[t], ref points[k]);
+                            Init(out points[pointsIndex]);
+                            PointAddVar(false, ref points[pointsIndex - size], ref toothPowers[tooth],
+                                ref points[pointsIndex]);
                         }
                     }
+                }
+                Debug.Assert(pointsIndex == totalPoints);
 
-                    Debug.Assert(k == PrecompPoints);
-
-                    int[] cs = F.CreateTable(PrecompPoints);
-
-                    // TODO[ed25519] A single batch inversion across all blocks?
-                    {
-                        int[] u = F.Create();
-                        F.Copy(points[0].z, 0, u, 0);
-                        F.Copy(u, 0, cs, 0);
-
-                        int i = 0;
-                        while (++i < PrecompPoints)
-                        {
-                            F.Mul(u, points[i].z, u);
-                            F.Copy(u, 0, cs, i * F.Size);
-                        }
-
-                        F.Add(u, u, u);
-                        F.InvVar(u, u);
-                        --i;
-
-                        int[] t = F.Create();
+                InvertDoubleZs(points);
 
-                        while (i > 0)
-                        {
-                            int j = i--;
-                            F.Copy(cs, i * F.Size, t, 0);
-                            F.Mul(t, u, t);
-                            F.Copy(t, 0, cs, j * F.Size);
-                            F.Mul(u, points[j].z, u);
-                        }
+                PrecompBaseWnaf = new PointPrecomp[wnafPoints];
+                for (int i = 0; i < wnafPoints; ++i)
+                {
+                    ref PointExtended q = ref points[i];
+                    ref PointPrecomp r = ref PrecompBaseWnaf[i];
+                    Init(out r);
 
-                        F.Copy(u, 0, cs, 0);
-                    }
+                    F.Mul(q.x, q.z, q.x);
+                    F.Mul(q.y, q.z, q.y);
 
-                    for (int i = 0; i < PrecompPoints; ++i)
-                    {
-                        ref PointExtended q = ref points[i];
+                    F.Apm(q.y, q.x, r.ypx_h, r.ymx_h);
+                    F.Mul(q.x, q.y, r.xyd);
+                    F.Mul(r.xyd, C_d4, r.xyd);
 
-                        int[] x = F.Create();
-                        int[] y = F.Create();
+                    F.Normalize(r.ypx_h);
+                    F.Normalize(r.ymx_h);
+                    F.Normalize(r.xyd);
+                }
 
-                        //F.Add(q.z, q.z, x);
-                        //F.InvVar(x, y);
-                        F.Copy(cs, i * F.Size, y, 0);
+                PrecompBaseComb = F.CreateTable(combPoints * 3);
+                PointPrecomp t; Init(out t);
+                int off = 0;
+                for (int i = wnafPoints; i < totalPoints; ++i)
+                {
+                    ref PointExtended q = ref points[i];
 
-                        F.Mul(q.x, y, x);
-                        F.Mul(q.y, y, y);
+                    F.Mul(q.x, q.z, q.x);
+                    F.Mul(q.y, q.z, q.y);
 
-                        PointPrecomp r; InitPrecomp(out r);
-                        F.Apm(y, x, r.ypx_h, r.ymx_h);
-                        F.Mul(x, y, r.xyd);
-                        F.Mul(r.xyd, C_d4, r.xyd);
+                    F.Apm(q.y, q.x, t.ypx_h, t.ymx_h);
+                    F.Mul(q.x, q.y, t.xyd);
+                    F.Mul(t.xyd, C_d4, t.xyd);
 
-                        F.Normalize(r.ypx_h);
-                        F.Normalize(r.ymx_h);
-                        //F.Normalize(r.xyd);
+                    F.Normalize(t.ypx_h);
+                    F.Normalize(t.ymx_h);
+                    F.Normalize(t.xyd);
 
-                        F.Copy(r.ypx_h, 0, precompBase, off);       off += F.Size;
-                        F.Copy(r.ymx_h, 0, precompBase, off);       off += F.Size;
-                        F.Copy(r.xyd,   0, precompBase, off);       off += F.Size;
-                    }
+                    F.Copy(t.ypx_h, 0, PrecompBaseComb, off);       off += F.Size;
+                    F.Copy(t.ymx_h, 0, PrecompBaseComb, off);       off += F.Size;
+                    F.Copy(t.xyd  , 0, PrecompBaseComb, off);       off += F.Size;
                 }
-
-                Debug.Assert(off == precompBase.Length);
+                Debug.Assert(off == PrecompBaseComb.Length);
             }
         }
 
@@ -1144,7 +1188,7 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             }
 
             int[] table = PointPrecompute(ref p, 8);
-            PointExtended q; InitExtended(out q);
+            PointExtended q; Init(out q);
 
             PointSetNeutral(ref r);
 
@@ -1188,7 +1232,7 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
                 }
             }
 
-            PointPrecomp p; InitPrecomp(out p);
+            PointPrecomp p; Init(out p);
 
             PointSetNeutral(ref r);
 
@@ -1221,7 +1265,7 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
 
         private static void ScalarMultBaseEncoded(byte[] k, byte[] r, int rOff)
         {
-            PointAccum p; InitAccum(out p);
+            PointAccum p; Init(out p);
             ScalarMultBase(k, ref p);
             if (0 == EncodePoint(ref p, r, rOff))
                 throw new InvalidOperationException();
@@ -1232,7 +1276,7 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             byte[] n = new byte[ScalarBytes];
             PruneScalar(k, kOff, n);
 
-            PointAccum p; InitAccum(out p);
+            PointAccum p; Init(out p);
             ScalarMultBase(n, ref p);
 
             if (0 == CheckPoint(p.x, p.y, p.z))
@@ -1248,10 +1292,9 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
 
             sbyte[] ws_p = GetWnafVar(L, width);
 
-            PointExtended q; InitExtended(out q);
-            PointCopy(ref p, ref q);
-
-            PointExtended[] tp = PointPrecomputeVar(ref q, 1 << (width - 2));
+            int count = 1 << (width - 2);
+            PointExtended[] tp = new PointExtended[count];
+            PointPrecomputeVar(ref p, tp, count);
 
             PointSetNeutral(ref r);
 
@@ -1282,10 +1325,9 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             sbyte[] ws_b = GetWnafVar(nb, WnafWidthBase);
             sbyte[] ws_p = GetWnafVar(np, width);
 
-            PointExtended q; InitExtended(out q);
-            PointCopy(ref p, ref q);
-
-            PointExtended[] tp = PointPrecomputeVar(ref q, 1 << (width - 2));
+            int count = 1 << (width - 2);
+            PointExtended[] tp = new PointExtended[count];
+            PointPrecomputeVar(ref p, tp, count);
 
             PointSetNeutral(ref r);
 
@@ -1297,7 +1339,7 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
                     int sign = wb >> 31;
                     int index = (wb ^ sign) >> 1;
 
-                    PointAddVar(sign != 0, ref precompBaseTable[index], ref r);
+                    PointAddPrecompVar(-sign, ref PrecompBaseWnaf[index], ref r);
                 }
 
                 int wp = ws_p[bit];
@@ -1306,7 +1348,7 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
                     int sign = wp >> 31;
                     int index = (wp ^ sign) >> 1;
 
-                    PointAddVar((sign != 0), ref tp[index], ref r);
+                    PointAddVar(sign != 0, ref tp[index], ref r);
                 }
 
                 if (--bit < 0)
@@ -1384,7 +1426,7 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
 
         public static bool ValidatePublicKeyFull(byte[] pk, int pkOff)
         {
-            PointAffine p; InitAffine(out p);
+            PointAffine p; Init(out p);
             if (!DecodePointVar(pk, pkOff, false, ref p))
                 return false;
 
@@ -1394,7 +1436,7 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             if (IsNeutralElementVar(p.x, p.y))
                 return false;
 
-            PointAccum r; InitAccum(out r);
+            PointAccum r; Init(out r);
             ScalarMultOrderVar(ref p, ref r);
 
             F.Normalize(r.x);
@@ -1406,7 +1448,7 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
 
         public static bool ValidatePublicKeyPartial(byte[] pk, int pkOff)
         {
-            PointAffine p; InitAffine(out p);
+            PointAffine p; Init(out p);
             return DecodePointVar(pk, pkOff, false, ref p);
         }