summary refs log tree commit diff
path: root/crypto/src/math/ec/rfc8032/Scalar25519.cs
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src/math/ec/rfc8032/Scalar25519.cs')
-rw-r--r--crypto/src/math/ec/rfc8032/Scalar25519.cs26
1 files changed, 20 insertions, 6 deletions
diff --git a/crypto/src/math/ec/rfc8032/Scalar25519.cs b/crypto/src/math/ec/rfc8032/Scalar25519.cs
index 67eee6155..08ab80607 100644
--- a/crypto/src/math/ec/rfc8032/Scalar25519.cs
+++ b/crypto/src/math/ec/rfc8032/Scalar25519.cs
@@ -595,7 +595,7 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
 #endif
 
 #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
-        internal static void ReduceBasisVar(ReadOnlySpan<uint> k, Span<uint> z0, Span<uint> z1)
+        internal static bool ReduceBasisVar(ReadOnlySpan<uint> k, Span<uint> z0, Span<uint> z1)
         {
             /*
              * Split scalar k into two half-size scalars z0 and z1, such that z1 * k == z0 mod L.
@@ -606,28 +606,34 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             Span<uint> Nu = stackalloc uint[16];    LSq.CopyTo(Nu);
             Span<uint> Nv = stackalloc uint[16];    Nat256.Square(k, Nv); ++Nv[0];
             Span<uint> p  = stackalloc uint[16];    Nat256.Mul(L, k, p);
+            Span<uint> t  = stackalloc uint[16];
             Span<uint> u0 = stackalloc uint[4];     u0.CopyFrom(L);
             Span<uint> u1 = stackalloc uint[4];
             Span<uint> v0 = stackalloc uint[4];     v0.CopyFrom(k);
             Span<uint> v1 = stackalloc uint[4];     v1[0] = 1U;
 
+            // Conservative upper bound on the number of loop iterations needed
+            int iterations = TargetLength * 4;
             int last = 15;
             int len_Nv = ScalarUtilities.GetBitLengthPositive(last, Nv);
 
             while (len_Nv > TargetLength)
             {
+                if (--iterations < 0)
+                    return false;
+
                 int len_p = ScalarUtilities.GetBitLength(last, p);
                 int s = len_p - len_Nv;
                 s &= ~(s >> 31);
 
                 if ((int)p[last] < 0)
                 {
-                    ScalarUtilities.AddShifted_NP(last, s, Nu, Nv, p);
+                    ScalarUtilities.AddShifted_NP(last, s, Nu, Nv, p, t);
                     ScalarUtilities.AddShifted_UV(last: 3, s, u0, u1, v0, v1);
                 }
                 else
                 {
-                    ScalarUtilities.SubShifted_NP(last, s, Nu, Nv, p);
+                    ScalarUtilities.SubShifted_NP(last, s, Nu, Nv, p, t);
                     ScalarUtilities.SubShifted_UV(last: 3, s, u0, u1, v0, v1);
                 }
 
@@ -645,9 +651,10 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             // v1 * k == v0 mod L
             v0.CopyTo(z0);
             v1.CopyTo(z1);
+            return true;
         }
 #else
-        internal static void ReduceBasisVar(uint[] k, uint[] z0, uint[] z1)
+        internal static bool ReduceBasisVar(uint[] k, uint[] z0, uint[] z1)
         {
             /*
              * Split scalar k into two half-size scalars z0 and z1, such that z1 * k == z0 mod L.
@@ -658,28 +665,34 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             uint[] Nu = new uint[16];       Array.Copy(LSq, Nu, 16);
             uint[] Nv = new uint[16];       Nat256.Square(k, Nv); ++Nv[0];
             uint[] p  = new uint[16];       Nat256.Mul(L, k, p);
+            uint[] t  = new uint[16];
             uint[] u0 = new uint[4];        Array.Copy(L, u0, 4);
             uint[] u1 = new uint[4];
             uint[] v0 = new uint[4];        Array.Copy(k, v0, 4);
             uint[] v1 = new uint[4];        v1[0] = 1U;
 
+            // Conservative upper bound on the number of loop iterations needed
+            int iterations = TargetLength * 4;
             int last = 15;
             int len_Nv = ScalarUtilities.GetBitLengthPositive(last, Nv);
 
             while (len_Nv > TargetLength)
             {
+                if (--iterations < 0)
+                    return false;
+
                 int len_p = ScalarUtilities.GetBitLength(last, p);
                 int s = len_p - len_Nv;
                 s &= ~(s >> 31);
 
                 if ((int)p[last] < 0)
                 {
-                    ScalarUtilities.AddShifted_NP(last, s, Nu, Nv, p);
+                    ScalarUtilities.AddShifted_NP(last, s, Nu, Nv, p, t);
                     ScalarUtilities.AddShifted_UV(last: 3, s, u0, u1, v0, v1);
                 }
                 else
                 {
-                    ScalarUtilities.SubShifted_NP(last, s, Nu, Nv, p);
+                    ScalarUtilities.SubShifted_NP(last, s, Nu, Nv, p, t);
                     ScalarUtilities.SubShifted_UV(last: 3, s, u0, u1, v0, v1);
                 }
 
@@ -697,6 +710,7 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             // v1 * k == v0 mod L
             Array.Copy(v0, z0, 4);
             Array.Copy(v1, z1, 4);
+            return true;
         }
 #endif