summary refs log tree commit diff
path: root/crypto/src/math/BigInteger.cs
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src/math/BigInteger.cs')
-rw-r--r--crypto/src/math/BigInteger.cs219
1 files changed, 133 insertions, 86 deletions
diff --git a/crypto/src/math/BigInteger.cs b/crypto/src/math/BigInteger.cs
index 7da886c4f..42b5b5089 100644
--- a/crypto/src/math/BigInteger.cs
+++ b/crypto/src/math/BigInteger.cs
@@ -139,6 +139,8 @@ namespace Org.BouncyCastle.Math
         public static readonly BigInteger Two;
         public static readonly BigInteger Three;
         public static readonly BigInteger Four;
+        public static readonly BigInteger Five;
+        public static readonly BigInteger Six;
         public static readonly BigInteger Ten;
 
 #if !NETCOREAPP3_0_OR_GREATER
@@ -181,27 +183,34 @@ namespace Org.BouncyCastle.Math
         static BigInteger()
         {
             Zero = new BigInteger(0, ZeroMagnitude, false);
-            Zero.nBits = 0; Zero.nBitLength = 0;
+            Zero.nBits = 0;
+            Zero.nBitLength = 0;
 
             SMALL_CONSTANTS[0] = Zero;
             for (uint i = 1; i < SMALL_CONSTANTS.Length; ++i)
             {
-                SMALL_CONSTANTS[i] = CreateUValueOf(i);
+                var sc = CreateUValueOf(i);
+                sc.nBits = Integers.PopCount(i);
+                sc.nBitLength = BitLen(i);
+
+                SMALL_CONSTANTS[i] = sc;
             }
 
             One = SMALL_CONSTANTS[1];
             Two = SMALL_CONSTANTS[2];
             Three = SMALL_CONSTANTS[3];
             Four = SMALL_CONSTANTS[4];
+            Five = SMALL_CONSTANTS[5];
+            Six = SMALL_CONSTANTS[6];
             Ten = SMALL_CONSTANTS[10];
 
-            radix2 = ValueOf(2);
+            radix2 = Two;
             radix2E = radix2.Pow(chunk2);
 
             radix8 = ValueOf(8);
             radix8E = radix8.Pow(chunk8);
 
-            radix10 = ValueOf(10);
+            radix10 = Ten;
             radix10E = radix10.Pow(chunk10);
 
             radix16 = ValueOf(16);
@@ -1171,7 +1180,7 @@ namespace Org.BouncyCastle.Math
                 ?  1
                 :  sign == 0
                 ?  0
-                :  sign * CompareNoLeadingZeroes(0, magnitude, 0, other.magnitude);
+                :  sign * CompareNoLeadingZeros(0, magnitude, 0, other.magnitude);
         }
 
         /**
@@ -1190,10 +1199,10 @@ namespace Org.BouncyCastle.Math
                 yIndx++;
             }
 
-            return CompareNoLeadingZeroes(xIndx, x, yIndx, y);
+            return CompareNoLeadingZeros(xIndx, x, yIndx, y);
         }
 
-        private static int CompareNoLeadingZeroes(int xIndx, uint[] x, int yIndx, uint[] y)
+        private static int CompareNoLeadingZeros(int xIndx, uint[] x, int yIndx, uint[] y)
         {
             int diff = (x.Length - y.Length) - (xIndx - yIndx);
 
@@ -1234,7 +1243,7 @@ namespace Org.BouncyCastle.Math
 
             Debug.Assert(yStart < y.Length);
 
-            int xyCmp = CompareNoLeadingZeroes(xStart, x, yStart, y);
+            int xyCmp = CompareNoLeadingZeros(xStart, x, yStart, y);
             uint[] count;
 
             if (xyCmp > 0)
@@ -1271,7 +1280,7 @@ namespace Org.BouncyCastle.Math
                 for (;;)
                 {
                     if (cBitLength < xBitLength
-                        || CompareNoLeadingZeroes(xStart, x, cStart, c) >= 0)
+                        || CompareNoLeadingZeros(xStart, x, cStart, c) >= 0)
                     {
                         Subtract(xStart, x, cStart, c);
                         AddMagnitudes(count, iCount);
@@ -1289,7 +1298,7 @@ namespace Org.BouncyCastle.Math
                             if (xBitLength < yBitLength)
                                 return count;
 
-                            xyCmp = CompareNoLeadingZeroes(xStart, x, yStart, y);
+                            xyCmp = CompareNoLeadingZeros(xStart, x, yStart, y);
 
                             if (xyCmp <= 0)
                                 break;
@@ -1623,6 +1632,8 @@ namespace Org.BouncyCastle.Math
             BigInteger montRadix = One.ShiftLeft(32 * n.magnitude.Length).Remainder(n);
             BigInteger minusMontRadix = n.Subtract(montRadix);
 
+            uint[] yAccum = new uint[n.magnitude.Length + 1];
+
             do
             {
                 BigInteger a;
@@ -1633,7 +1644,7 @@ namespace Org.BouncyCastle.Math
                 while (a.sign == 0 || a.CompareTo(n) >= 0
                     || a.IsEqualMagnitude(montRadix) || a.IsEqualMagnitude(minusMontRadix));
 
-                BigInteger y = ModPowMonty(a, r, n, false);
+                BigInteger y = ModPowMonty(yAccum, a, r, n, false);
 
                 if (!y.Equals(montRadix))
                 {
@@ -1643,7 +1654,7 @@ namespace Org.BouncyCastle.Math
                         if (++j == s)
                             return false;
 
-                        y = ModPowMonty(y, Two, n, false);
+                        y = ModSquareMonty(yAccum, y, n);
 
                         if (y.Equals(montRadix))
                             return false;
@@ -1725,12 +1736,12 @@ namespace Org.BouncyCastle.Math
 //				for (;;)
 //				{
 //					// While F is even, do F=F/u, C=C*u, k=k+1.
-//					int zeroes = F.GetLowestSetBit();
-//					if (zeroes > 0)
+//					int zeros = F.GetLowestSetBit();
+//					if (zeros > 0)
 //					{
-//						F = F.ShiftRight(zeroes);
-//						C = C.ShiftLeft(zeroes);
-//						k += zeroes;
+//						F = F.ShiftRight(zeros);
+//						C = C.ShiftLeft(zeros);
+//						k += zeros;
 //					}
 //
 //					// If F = 1, then return B,k.
@@ -1891,7 +1902,8 @@ namespace Org.BouncyCastle.Math
                 }
                 else
                 {
-                    result = ModPowMonty(result, e, m, true);
+                    uint[] yAccum = new uint[m.magnitude.Length + 1];
+                    result = ModPowMonty(yAccum, result, e, m, true);
                 }
             }
 
@@ -1925,17 +1937,17 @@ namespace Org.BouncyCastle.Math
                 oddPowers[i] = ReduceBarrett(oddPowers[i - 1].Multiply(b2), m, mr, yu);
             }
 
-            int[] windowList = GetWindowList(e.magnitude, extraBits);
+            uint[] windowList = GetWindowList(e.magnitude, extraBits);
             Debug.Assert(windowList.Length > 0);
 
-            int window = windowList[0];
-            int mult = window & 0xFF, lastZeroes = window >> 8;
+            uint window = windowList[0];
+            uint mult = window & 0xFFU, lastZeros = window >> 8;
 
             BigInteger y;
             if (mult == 1)
             {
                 y = b2;
-                --lastZeroes;
+                --lastZeros;
             }
             else
             {
@@ -1943,11 +1955,11 @@ namespace Org.BouncyCastle.Math
             }
 
             int windowPos = 1;
-            while ((window = windowList[windowPos++]) != -1)
+            while ((window = windowList[windowPos++]) != uint.MaxValue)
             {
                 mult = window & 0xFF;
 
-                int bits = lastZeroes + BitLen((byte)mult);
+                int bits = (int)lastZeros + BitLen((byte)mult);
                 for (int j = 0; j < bits; ++j)
                 {
                     y = ReduceBarrett(y.Square(), m, mr, yu);
@@ -1955,10 +1967,10 @@ namespace Org.BouncyCastle.Math
 
                 y = ReduceBarrett(y.Multiply(oddPowers[mult >> 1]), m, mr, yu);
 
-                lastZeroes = window >> 8;
+                lastZeros = window >> 8;
             }
 
-            for (int i = 0; i < lastZeroes; ++i)
+            for (int i = 0; i < lastZeros; ++i)
             {
                 y = ReduceBarrett(y.Square(), m, mr, yu);
             }
@@ -1999,7 +2011,7 @@ namespace Org.BouncyCastle.Math
             return x;
         }
 
-        private static BigInteger ModPowMonty(BigInteger b, BigInteger e, BigInteger m, bool convert)
+        private static BigInteger ModPowMonty(uint[] yAccum, BigInteger b, BigInteger e, BigInteger m, bool convert)
         {
             int n = m.magnitude.Length;
             int powR = 32 * n;
@@ -2012,7 +2024,7 @@ namespace Org.BouncyCastle.Math
                 b = b.ShiftLeft(powR).Remainder(m);
             }
 
-            uint[] yAccum = new uint[n + 1];
+            Debug.Assert(yAccum.Length == n + 1);
 
             uint[] zVal = b.magnitude;
             Debug.Assert(zVal.Length <= n);
@@ -2050,17 +2062,17 @@ namespace Org.BouncyCastle.Math
                 MultiplyMonty(yAccum, oddPowers[i], zSquared, m.magnitude, mDash, smallMontyModulus);
             }
 
-            int[] windowList = GetWindowList(e.magnitude, extraBits);
+            uint[] windowList = GetWindowList(e.magnitude, extraBits);
             Debug.Assert(windowList.Length > 1);
 
-            int window = windowList[0];
-            int mult = window & 0xFF, lastZeroes = window >> 8;
+            uint window = windowList[0];
+            uint mult = window & 0xFF, lastZeros = window >> 8;
 
             uint[] yVal;
             if (mult == 1)
             {
                 yVal = zSquared;
-                --lastZeroes;
+                --lastZeros;
             }
             else
             {
@@ -2068,11 +2080,11 @@ namespace Org.BouncyCastle.Math
             }
 
             int windowPos = 1;
-            while ((window = windowList[windowPos++]) != -1)
+            while ((window = windowList[windowPos++]) != uint.MaxValue)
             {
                 mult = window & 0xFF;
 
-                int bits = lastZeroes + BitLen((byte)mult);
+                int bits = (int)lastZeros + BitLen((byte)mult);
                 for (int j = 0; j < bits; ++j)
                 {
                     SquareMonty(yAccum, yVal, m.magnitude, mDash, smallMontyModulus);
@@ -2080,10 +2092,10 @@ namespace Org.BouncyCastle.Math
 
                 MultiplyMonty(yAccum, yVal, oddPowers[mult >> 1], m.magnitude, mDash, smallMontyModulus);
 
-                lastZeroes = window >> 8;
+                lastZeros = window >> 8;
             }
 
-            for (int i = 0; i < lastZeroes; ++i)
+            for (int i = 0; i < lastZeros; ++i)
             {
                 SquareMonty(yAccum, yVal, m.magnitude, mDash, smallMontyModulus);
             }
@@ -2101,22 +2113,49 @@ namespace Org.BouncyCastle.Math
             return new BigInteger(1, yVal, true);
         }
 
-        private static int[] GetWindowList(uint[] mag, int extraBits)
+        private static BigInteger ModSquareMonty(uint[] yAccum, BigInteger b, BigInteger m)
+        {
+            int n = m.magnitude.Length;
+            int powR = 32 * n;
+            bool smallMontyModulus = m.BitLength + 2 <= powR;
+            uint mDash = m.GetMQuote();
+
+            Debug.Assert(yAccum.Length == n + 1);
+
+            uint[] zVal = b.magnitude;
+            Debug.Assert(zVal.Length <= n);
+
+            uint[] yVal = new uint[n];
+            zVal.CopyTo(yVal, n - zVal.Length);
+
+            SquareMonty(yAccum, yVal, m.magnitude, mDash, smallMontyModulus);
+
+            if (smallMontyModulus && CompareTo(0, yVal, 0, m.magnitude) >= 0)
+            {
+                Subtract(0, yVal, 0, m.magnitude);
+            }
+
+            return new BigInteger(1, yVal, true);
+        }
+
+        private static uint[] GetWindowList(uint[] mag, int extraBits)
         {
-            int v = (int)mag[0];
-            Debug.Assert(v != 0);
+            uint v = mag[0];
+            Debug.Assert(v != 0U);
 
-            int leadingBits = BitLen((uint)v);
+            int leadingBits = BitLen(v);
+            int totalBits = ((mag.Length - 1) << 5) + leadingBits;
 
-            int resultSize = (((mag.Length - 1) << 5) + leadingBits) / (1 + extraBits) + 2;
-            int[] result = new int[resultSize];
+            int resultSize = (totalBits + extraBits) / (1 + extraBits) + 1;
+            uint[] result = new uint[resultSize];
             int resultPos = 0;
 
             int bitPos = 33 - leadingBits;
             v <<= bitPos;
 
-            int mult = 1, multLimit = 1 << extraBits;
-            int zeroes = 0;
+            uint mult = 1U;
+            uint multLimit = 1U << extraBits;
+            uint zeros = 0U;
 
             int i = 0;
             for (;;)
@@ -2125,17 +2164,17 @@ namespace Org.BouncyCastle.Math
                 {
                     if (mult < multLimit)
                     {
-                        mult = (mult << 1) | (int)((uint)v >> 31);
+                        mult = (mult << 1) | (v >> 31);
                     }
-                    else if (v < 0)
+                    else if ((int)v < 0)
                     {
-                        result[resultPos++] = CreateWindowEntry(mult, zeroes);
-                        mult = 1;
-                        zeroes = 0;
+                        result[resultPos++] = CreateWindowEntry(mult, zeros);
+                        mult = 1U;
+                        zeros = 0U;
                     }
                     else
                     {
-                        ++zeroes;
+                        ++zeros;
                     }
 
                     v <<= 1;
@@ -2143,35 +2182,35 @@ namespace Org.BouncyCastle.Math
 
                 if (++i == mag.Length)
                 {
-                    result[resultPos++] = CreateWindowEntry(mult, zeroes);
+                    result[resultPos++] = CreateWindowEntry(mult, zeros);
                     break;
                 }
 
-                v = (int)mag[i];
+                v = mag[i];
                 bitPos = 0;
             }
 
-            result[resultPos] = -1;
+            result[resultPos] = uint.MaxValue; // Sentinel value
             return result;
         }
 
-        private static int CreateWindowEntry(int mult, int zeroes)
+        private static uint CreateWindowEntry(uint mult, uint zeros)
         {
             Debug.Assert(mult > 0);
 
 #if NETCOREAPP3_0_OR_GREATER
             int tz = BitOperations.TrailingZeroCount(mult);
             mult >>= tz;
-            zeroes += tz;
+            zeros += (uint)tz;
 #else
-            while ((mult & 1) == 0)
+            while ((mult & 1U) == 0U)
             {
                 mult >>= 1;
-                ++zeroes;
+                ++zeros;
             }
 #endif
 
-            return mult | (zeroes << 8);
+            return mult | (zeros << 8);
         }
 
         /**
@@ -2682,7 +2721,7 @@ namespace Org.BouncyCastle.Math
 
             Debug.Assert(yStart < y.Length);
 
-            int xyCmp = CompareNoLeadingZeroes(xStart, x, yStart, y);
+            int xyCmp = CompareNoLeadingZeros(xStart, x, yStart, y);
 
             if (xyCmp > 0)
             {
@@ -2709,7 +2748,7 @@ namespace Org.BouncyCastle.Math
                 for (;;)
                 {
                     if (cBitLength < xBitLength
-                        || CompareNoLeadingZeroes(xStart, x, cStart, c) >= 0)
+                        || CompareNoLeadingZeros(xStart, x, cStart, c) >= 0)
                     {
                         Subtract(xStart, x, cStart, c);
 
@@ -2726,7 +2765,7 @@ namespace Org.BouncyCastle.Math
                             if (xBitLength < yBitLength)
                                 return x;
 
-                            xyCmp = CompareNoLeadingZeroes(xStart, x, yStart, y);
+                            xyCmp = CompareNoLeadingZeros(xStart, x, yStart, y);
 
                             if (xyCmp <= 0)
                                 break;
@@ -2799,7 +2838,7 @@ namespace Org.BouncyCastle.Math
                 }
             }
 
-            if (CompareNoLeadingZeroes(0, magnitude, 0, n.magnitude) < 0)
+            if (CompareNoLeadingZeros(0, magnitude, 0, n.magnitude) < 0)
                 return this;
 
             uint[] result;
@@ -3094,7 +3133,7 @@ namespace Org.BouncyCastle.Math
             if (this.sign != n.sign)
                 return Add(n.Negate());
 
-            int compare = CompareNoLeadingZeroes(0, magnitude, 0, n.magnitude);
+            int compare = CompareNoLeadingZeros(0, magnitude, 0, n.magnitude);
             if (compare == 0)
                 return Zero;
 
@@ -3607,47 +3646,55 @@ namespace Org.BouncyCastle.Math
             sb.Append(s);
         }
 
+        private static BigInteger CreateUValueOf(uint value)
+        {
+            if (value == 0)
+                return Zero;
+
+            return new BigInteger(1, new uint[]{ value }, false);
+        }
+
         private static BigInteger CreateUValueOf(ulong value)
         {
             uint msw = (uint)(value >> 32);
             uint lsw = (uint)value;
 
-            if (msw != 0)
-                return new BigInteger(1, new uint[]{ msw, lsw }, false);
-
-            if (lsw != 0)
-            {
-                BigInteger n = new BigInteger(1, new uint[]{ lsw }, false);
-                // Check for a power of two
-                if ((lsw & -lsw) == lsw)
-                {
-                    n.nBits = 1;
-                }
-                return n;
-            }
+            if (msw == 0)
+                return CreateUValueOf(lsw);
 
-            return Zero;
+            return new BigInteger(1, new uint[]{ msw, lsw }, false);
         }
 
-        private static BigInteger CreateValueOf(long value)
+        public static BigInteger ValueOf(int value)
         {
-            if (value < 0)
+            if (value >= 0)
             {
-                if (value == long.MinValue)
-                    return CreateValueOf(~value).Not();
+                if (value < SMALL_CONSTANTS.Length)
+                    return SMALL_CONSTANTS[value];
 
-                return CreateValueOf(-value).Negate();
+                return CreateUValueOf((uint)value);
             }
 
-            return CreateUValueOf((ulong)value);
+            if (value == int.MinValue)
+                return CreateUValueOf((uint)~value).Not();
+
+            return ValueOf(-value).Negate();
         }
 
         public static BigInteger ValueOf(long value)
         {
-            if (value >= 0 && value < SMALL_CONSTANTS.Length)
-                return SMALL_CONSTANTS[value];
+            if (value >= 0L)
+            {
+                if (value < SMALL_CONSTANTS.Length)
+                    return SMALL_CONSTANTS[value];
+
+                return CreateUValueOf((ulong)value);
+            }
+
+            if (value == long.MinValue)
+                return CreateUValueOf((ulong)~value).Not();
 
-            return CreateValueOf(value);
+            return ValueOf(-value).Negate();
         }
 
         public int GetLowestSetBit()