summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crypto/src/math/BigInteger.cs148
1 files changed, 89 insertions, 59 deletions
diff --git a/crypto/src/math/BigInteger.cs b/crypto/src/math/BigInteger.cs
index d61d702c1..42b5b5089 100644
--- a/crypto/src/math/BigInteger.cs
+++ b/crypto/src/math/BigInteger.cs
@@ -1180,7 +1180,7 @@ namespace Org.BouncyCastle.Math
                 ?  1
                 :  sign == 0
                 ?  0
-                :  sign * CompareNoLeadingZeroes(0, magnitude, 0, other.magnitude);
+                :  sign * CompareNoLeadingZeros(0, magnitude, 0, other.magnitude);
         }
 
         /**
@@ -1199,10 +1199,10 @@ namespace Org.BouncyCastle.Math
                 yIndx++;
             }
 
-            return CompareNoLeadingZeroes(xIndx, x, yIndx, y);
+            return CompareNoLeadingZeros(xIndx, x, yIndx, y);
         }
 
-        private static int CompareNoLeadingZeroes(int xIndx, uint[] x, int yIndx, uint[] y)
+        private static int CompareNoLeadingZeros(int xIndx, uint[] x, int yIndx, uint[] y)
         {
             int diff = (x.Length - y.Length) - (xIndx - yIndx);
 
@@ -1243,7 +1243,7 @@ namespace Org.BouncyCastle.Math
 
             Debug.Assert(yStart < y.Length);
 
-            int xyCmp = CompareNoLeadingZeroes(xStart, x, yStart, y);
+            int xyCmp = CompareNoLeadingZeros(xStart, x, yStart, y);
             uint[] count;
 
             if (xyCmp > 0)
@@ -1280,7 +1280,7 @@ namespace Org.BouncyCastle.Math
                 for (;;)
                 {
                     if (cBitLength < xBitLength
-                        || CompareNoLeadingZeroes(xStart, x, cStart, c) >= 0)
+                        || CompareNoLeadingZeros(xStart, x, cStart, c) >= 0)
                     {
                         Subtract(xStart, x, cStart, c);
                         AddMagnitudes(count, iCount);
@@ -1298,7 +1298,7 @@ namespace Org.BouncyCastle.Math
                             if (xBitLength < yBitLength)
                                 return count;
 
-                            xyCmp = CompareNoLeadingZeroes(xStart, x, yStart, y);
+                            xyCmp = CompareNoLeadingZeros(xStart, x, yStart, y);
 
                             if (xyCmp <= 0)
                                 break;
@@ -1632,6 +1632,8 @@ namespace Org.BouncyCastle.Math
             BigInteger montRadix = One.ShiftLeft(32 * n.magnitude.Length).Remainder(n);
             BigInteger minusMontRadix = n.Subtract(montRadix);
 
+            uint[] yAccum = new uint[n.magnitude.Length + 1];
+
             do
             {
                 BigInteger a;
@@ -1642,7 +1644,7 @@ namespace Org.BouncyCastle.Math
                 while (a.sign == 0 || a.CompareTo(n) >= 0
                     || a.IsEqualMagnitude(montRadix) || a.IsEqualMagnitude(minusMontRadix));
 
-                BigInteger y = ModPowMonty(a, r, n, false);
+                BigInteger y = ModPowMonty(yAccum, a, r, n, false);
 
                 if (!y.Equals(montRadix))
                 {
@@ -1652,7 +1654,7 @@ namespace Org.BouncyCastle.Math
                         if (++j == s)
                             return false;
 
-                        y = ModPowMonty(y, Two, n, false);
+                        y = ModSquareMonty(yAccum, y, n);
 
                         if (y.Equals(montRadix))
                             return false;
@@ -1734,12 +1736,12 @@ namespace Org.BouncyCastle.Math
 //				for (;;)
 //				{
 //					// While F is even, do F=F/u, C=C*u, k=k+1.
-//					int zeroes = F.GetLowestSetBit();
-//					if (zeroes > 0)
+//					int zeros = F.GetLowestSetBit();
+//					if (zeros > 0)
 //					{
-//						F = F.ShiftRight(zeroes);
-//						C = C.ShiftLeft(zeroes);
-//						k += zeroes;
+//						F = F.ShiftRight(zeros);
+//						C = C.ShiftLeft(zeros);
+//						k += zeros;
 //					}
 //
 //					// If F = 1, then return B,k.
@@ -1900,7 +1902,8 @@ namespace Org.BouncyCastle.Math
                 }
                 else
                 {
-                    result = ModPowMonty(result, e, m, true);
+                    uint[] yAccum = new uint[m.magnitude.Length + 1];
+                    result = ModPowMonty(yAccum, result, e, m, true);
                 }
             }
 
@@ -1934,17 +1937,17 @@ namespace Org.BouncyCastle.Math
                 oddPowers[i] = ReduceBarrett(oddPowers[i - 1].Multiply(b2), m, mr, yu);
             }
 
-            int[] windowList = GetWindowList(e.magnitude, extraBits);
+            uint[] windowList = GetWindowList(e.magnitude, extraBits);
             Debug.Assert(windowList.Length > 0);
 
-            int window = windowList[0];
-            int mult = window & 0xFF, lastZeroes = window >> 8;
+            uint window = windowList[0];
+            uint mult = window & 0xFFU, lastZeros = window >> 8;
 
             BigInteger y;
             if (mult == 1)
             {
                 y = b2;
-                --lastZeroes;
+                --lastZeros;
             }
             else
             {
@@ -1952,11 +1955,11 @@ namespace Org.BouncyCastle.Math
             }
 
             int windowPos = 1;
-            while ((window = windowList[windowPos++]) != -1)
+            while ((window = windowList[windowPos++]) != uint.MaxValue)
             {
                 mult = window & 0xFF;
 
-                int bits = lastZeroes + BitLen((byte)mult);
+                int bits = (int)lastZeros + BitLen((byte)mult);
                 for (int j = 0; j < bits; ++j)
                 {
                     y = ReduceBarrett(y.Square(), m, mr, yu);
@@ -1964,10 +1967,10 @@ namespace Org.BouncyCastle.Math
 
                 y = ReduceBarrett(y.Multiply(oddPowers[mult >> 1]), m, mr, yu);
 
-                lastZeroes = window >> 8;
+                lastZeros = window >> 8;
             }
 
-            for (int i = 0; i < lastZeroes; ++i)
+            for (int i = 0; i < lastZeros; ++i)
             {
                 y = ReduceBarrett(y.Square(), m, mr, yu);
             }
@@ -2008,7 +2011,7 @@ namespace Org.BouncyCastle.Math
             return x;
         }
 
-        private static BigInteger ModPowMonty(BigInteger b, BigInteger e, BigInteger m, bool convert)
+        private static BigInteger ModPowMonty(uint[] yAccum, BigInteger b, BigInteger e, BigInteger m, bool convert)
         {
             int n = m.magnitude.Length;
             int powR = 32 * n;
@@ -2021,7 +2024,7 @@ namespace Org.BouncyCastle.Math
                 b = b.ShiftLeft(powR).Remainder(m);
             }
 
-            uint[] yAccum = new uint[n + 1];
+            Debug.Assert(yAccum.Length == n + 1);
 
             uint[] zVal = b.magnitude;
             Debug.Assert(zVal.Length <= n);
@@ -2059,17 +2062,17 @@ namespace Org.BouncyCastle.Math
                 MultiplyMonty(yAccum, oddPowers[i], zSquared, m.magnitude, mDash, smallMontyModulus);
             }
 
-            int[] windowList = GetWindowList(e.magnitude, extraBits);
+            uint[] windowList = GetWindowList(e.magnitude, extraBits);
             Debug.Assert(windowList.Length > 1);
 
-            int window = windowList[0];
-            int mult = window & 0xFF, lastZeroes = window >> 8;
+            uint window = windowList[0];
+            uint mult = window & 0xFF, lastZeros = window >> 8;
 
             uint[] yVal;
             if (mult == 1)
             {
                 yVal = zSquared;
-                --lastZeroes;
+                --lastZeros;
             }
             else
             {
@@ -2077,11 +2080,11 @@ namespace Org.BouncyCastle.Math
             }
 
             int windowPos = 1;
-            while ((window = windowList[windowPos++]) != -1)
+            while ((window = windowList[windowPos++]) != uint.MaxValue)
             {
                 mult = window & 0xFF;
 
-                int bits = lastZeroes + BitLen((byte)mult);
+                int bits = (int)lastZeros + BitLen((byte)mult);
                 for (int j = 0; j < bits; ++j)
                 {
                     SquareMonty(yAccum, yVal, m.magnitude, mDash, smallMontyModulus);
@@ -2089,10 +2092,10 @@ namespace Org.BouncyCastle.Math
 
                 MultiplyMonty(yAccum, yVal, oddPowers[mult >> 1], m.magnitude, mDash, smallMontyModulus);
 
-                lastZeroes = window >> 8;
+                lastZeros = window >> 8;
             }
 
-            for (int i = 0; i < lastZeroes; ++i)
+            for (int i = 0; i < lastZeros; ++i)
             {
                 SquareMonty(yAccum, yVal, m.magnitude, mDash, smallMontyModulus);
             }
@@ -2110,22 +2113,49 @@ namespace Org.BouncyCastle.Math
             return new BigInteger(1, yVal, true);
         }
 
-        private static int[] GetWindowList(uint[] mag, int extraBits)
+        private static BigInteger ModSquareMonty(uint[] yAccum, BigInteger b, BigInteger m)
+        {
+            int n = m.magnitude.Length;
+            int powR = 32 * n;
+            bool smallMontyModulus = m.BitLength + 2 <= powR;
+            uint mDash = m.GetMQuote();
+
+            Debug.Assert(yAccum.Length == n + 1);
+
+            uint[] zVal = b.magnitude;
+            Debug.Assert(zVal.Length <= n);
+
+            uint[] yVal = new uint[n];
+            zVal.CopyTo(yVal, n - zVal.Length);
+
+            SquareMonty(yAccum, yVal, m.magnitude, mDash, smallMontyModulus);
+
+            if (smallMontyModulus && CompareTo(0, yVal, 0, m.magnitude) >= 0)
+            {
+                Subtract(0, yVal, 0, m.magnitude);
+            }
+
+            return new BigInteger(1, yVal, true);
+        }
+
+        private static uint[] GetWindowList(uint[] mag, int extraBits)
         {
-            int v = (int)mag[0];
-            Debug.Assert(v != 0);
+            uint v = mag[0];
+            Debug.Assert(v != 0U);
 
-            int leadingBits = BitLen((uint)v);
+            int leadingBits = BitLen(v);
+            int totalBits = ((mag.Length - 1) << 5) + leadingBits;
 
-            int resultSize = (((mag.Length - 1) << 5) + leadingBits) / (1 + extraBits) + 2;
-            int[] result = new int[resultSize];
+            int resultSize = (totalBits + extraBits) / (1 + extraBits) + 1;
+            uint[] result = new uint[resultSize];
             int resultPos = 0;
 
             int bitPos = 33 - leadingBits;
             v <<= bitPos;
 
-            int mult = 1, multLimit = 1 << extraBits;
-            int zeroes = 0;
+            uint mult = 1U;
+            uint multLimit = 1U << extraBits;
+            uint zeros = 0U;
 
             int i = 0;
             for (;;)
@@ -2134,17 +2164,17 @@ namespace Org.BouncyCastle.Math
                 {
                     if (mult < multLimit)
                     {
-                        mult = (mult << 1) | (int)((uint)v >> 31);
+                        mult = (mult << 1) | (v >> 31);
                     }
-                    else if (v < 0)
+                    else if ((int)v < 0)
                     {
-                        result[resultPos++] = CreateWindowEntry(mult, zeroes);
-                        mult = 1;
-                        zeroes = 0;
+                        result[resultPos++] = CreateWindowEntry(mult, zeros);
+                        mult = 1U;
+                        zeros = 0U;
                     }
                     else
                     {
-                        ++zeroes;
+                        ++zeros;
                     }
 
                     v <<= 1;
@@ -2152,35 +2182,35 @@ namespace Org.BouncyCastle.Math
 
                 if (++i == mag.Length)
                 {
-                    result[resultPos++] = CreateWindowEntry(mult, zeroes);
+                    result[resultPos++] = CreateWindowEntry(mult, zeros);
                     break;
                 }
 
-                v = (int)mag[i];
+                v = mag[i];
                 bitPos = 0;
             }
 
-            result[resultPos] = -1;
+            result[resultPos] = uint.MaxValue; // Sentinel value
             return result;
         }
 
-        private static int CreateWindowEntry(int mult, int zeroes)
+        private static uint CreateWindowEntry(uint mult, uint zeros)
         {
             Debug.Assert(mult > 0);
 
 #if NETCOREAPP3_0_OR_GREATER
             int tz = BitOperations.TrailingZeroCount(mult);
             mult >>= tz;
-            zeroes += tz;
+            zeros += (uint)tz;
 #else
-            while ((mult & 1) == 0)
+            while ((mult & 1U) == 0U)
             {
                 mult >>= 1;
-                ++zeroes;
+                ++zeros;
             }
 #endif
 
-            return mult | (zeroes << 8);
+            return mult | (zeros << 8);
         }
 
         /**
@@ -2691,7 +2721,7 @@ namespace Org.BouncyCastle.Math
 
             Debug.Assert(yStart < y.Length);
 
-            int xyCmp = CompareNoLeadingZeroes(xStart, x, yStart, y);
+            int xyCmp = CompareNoLeadingZeros(xStart, x, yStart, y);
 
             if (xyCmp > 0)
             {
@@ -2718,7 +2748,7 @@ namespace Org.BouncyCastle.Math
                 for (;;)
                 {
                     if (cBitLength < xBitLength
-                        || CompareNoLeadingZeroes(xStart, x, cStart, c) >= 0)
+                        || CompareNoLeadingZeros(xStart, x, cStart, c) >= 0)
                     {
                         Subtract(xStart, x, cStart, c);
 
@@ -2735,7 +2765,7 @@ namespace Org.BouncyCastle.Math
                             if (xBitLength < yBitLength)
                                 return x;
 
-                            xyCmp = CompareNoLeadingZeroes(xStart, x, yStart, y);
+                            xyCmp = CompareNoLeadingZeros(xStart, x, yStart, y);
 
                             if (xyCmp <= 0)
                                 break;
@@ -2808,7 +2838,7 @@ namespace Org.BouncyCastle.Math
                 }
             }
 
-            if (CompareNoLeadingZeroes(0, magnitude, 0, n.magnitude) < 0)
+            if (CompareNoLeadingZeros(0, magnitude, 0, n.magnitude) < 0)
                 return this;
 
             uint[] result;
@@ -3103,7 +3133,7 @@ namespace Org.BouncyCastle.Math
             if (this.sign != n.sign)
                 return Add(n.Negate());
 
-            int compare = CompareNoLeadingZeroes(0, magnitude, 0, n.magnitude);
+            int compare = CompareNoLeadingZeros(0, magnitude, 0, n.magnitude);
             if (compare == 0)
                 return Zero;