summary refs log tree commit diff
diff options
context:
space:
mode:
authorPeter Dettman <peter.dettman@bouncycastle.org>2023-04-15 16:57:11 +0700
committerPeter Dettman <peter.dettman@bouncycastle.org>2023-04-15 16:57:11 +0700
commitf7699727e8b9cc1dcf28cddabdae69195fafabe0 (patch)
treed03f7ebbbbd5007e159aec0e6d28c0705b8ecdff
parentFix warning (diff)
downloadBouncyCastle.NET-ed25519-f7699727e8b9cc1dcf28cddabdae69195fafabe0.tar.xz
Refactor GCM code
-rw-r--r--crypto/src/crypto/modes/GCMBlockCipher.cs136
-rw-r--r--crypto/src/crypto/modes/gcm/GcmUtilities.cs84
-rw-r--r--crypto/test/src/crypto/test/GCMTest.cs2
3 files changed, 130 insertions, 92 deletions
diff --git a/crypto/src/crypto/modes/GCMBlockCipher.cs b/crypto/src/crypto/modes/GCMBlockCipher.cs
index c592c3af3..16d9f3654 100644
--- a/crypto/src/crypto/modes/GCMBlockCipher.cs
+++ b/crypto/src/crypto/modes/GCMBlockCipher.cs
@@ -402,6 +402,11 @@ namespace Org.BouncyCastle.Crypto.Modes
             {
                 Check.OutputLength(output, outOff, BlockSize, "output buffer too short");
 
+                if (blocksRemaining == 0)
+                    throw new InvalidOperationException("Attempt to process too many blocks");
+
+                --blocksRemaining;
+
                 if (totalLength == 0)
                 {
                     InitCipher();
@@ -443,6 +448,11 @@ namespace Org.BouncyCastle.Crypto.Modes
             {
                 Check.OutputLength(output, BlockSize, "output buffer too short");
 
+                if (blocksRemaining == 0)
+                    throw new InvalidOperationException("Attempt to process too many blocks");
+
+                --blocksRemaining;
+
                 if (totalLength == 0)
                 {
                     InitCipher();
@@ -485,6 +495,12 @@ namespace Org.BouncyCastle.Crypto.Modes
                 {
                     Check.OutputLength(output, outOff, resultLen, "output buffer too short");
 
+                    uint blocksNeeded = (uint)resultLen >> 4;
+                    if (blocksRemaining < blocksNeeded)
+                        throw new InvalidOperationException("Attempt to process too many blocks");
+
+                    blocksRemaining -= blocksNeeded;
+
                     if (totalLength == 0)
                     {
                         InitCipher();
@@ -539,6 +555,12 @@ namespace Org.BouncyCastle.Crypto.Modes
                 {
                     Check.OutputLength(output, outOff, resultLen, "output buffer too short");
 
+                    uint blocksNeeded = (uint)resultLen >> 4;
+                    if (blocksRemaining < blocksNeeded)
+                        throw new InvalidOperationException("Attempt to process too many blocks");
+
+                    blocksRemaining -= blocksNeeded;
+
                     if (totalLength == 0)
                     {
                         InitCipher();
@@ -620,6 +642,12 @@ namespace Org.BouncyCastle.Crypto.Modes
                 {
                     Check.OutputLength(output, resultLen, "output buffer too short");
 
+                    uint blocksNeeded = (uint)resultLen >> 4;
+                    if (blocksRemaining < blocksNeeded)
+                        throw new InvalidOperationException("Attempt to process too many blocks");
+
+                    blocksRemaining -= blocksNeeded;
+
                     if (totalLength == 0)
                     {
                         InitCipher();
@@ -686,6 +714,12 @@ namespace Org.BouncyCastle.Crypto.Modes
                 {
                     Check.OutputLength(output, resultLen, "output buffer too short");
 
+                    uint blocksNeeded = (uint)resultLen >> 4;
+                    if (blocksRemaining < blocksNeeded)
+                        throw new InvalidOperationException("Attempt to process too many blocks");
+
+                    blocksRemaining -= blocksNeeded;
+
                     if (totalLength == 0)
                     {
                         InitCipher();
@@ -800,6 +834,11 @@ namespace Org.BouncyCastle.Crypto.Modes
 
             if (extra > 0)
             {
+                if (blocksRemaining == 0)
+                    throw new InvalidOperationException("Attempt to process too many blocks");
+
+                --blocksRemaining;
+
                 ProcessPartial(bufBlock, 0, extra, output, outOff);
             }
 
@@ -912,6 +951,11 @@ namespace Org.BouncyCastle.Crypto.Modes
 
             if (extra > 0)
             {
+                if (blocksRemaining == 0)
+                    throw new InvalidOperationException("Attempt to process too many blocks");
+
+                --blocksRemaining;
+
                 ProcessPartial(bufBlock.AsSpan(0, extra), output);
             }
 
@@ -1175,6 +1219,8 @@ namespace Org.BouncyCastle.Crypto.Modes
             if (limit < BlockSize * 4)
                 throw new ArgumentOutOfRangeException(nameof(limit));
 
+            var HPowBound = HPow[3];
+
             Span<Vector128<byte>> counters = stackalloc Vector128<byte>[4];
             var ctrBlocks = MemoryMarshal.AsBytes(counters);
 
@@ -1183,6 +1229,9 @@ namespace Org.BouncyCastle.Crypto.Modes
 
             while (input.Length >= limit)
             {
+                var inputBound = input[BlockSize * 4 - 1];
+                var outputBound = output[BlockSize * 4 - 1];
+
                 GetNextCtrBlocks4(ctrBlocks);
 
                 var c0 = MemoryMarshal.Read<Vector128<byte>>(input);
@@ -1200,17 +1249,20 @@ namespace Org.BouncyCastle.Crypto.Modes
                 MemoryMarshal.Write(output[(BlockSize * 2)..], ref p2);
                 MemoryMarshal.Write(output[(BlockSize * 3)..], ref p3);
 
-                c0 = Ssse3.Shuffle(c0, ReverseBytesMask);
-                c1 = Ssse3.Shuffle(c1, ReverseBytesMask);
-                c2 = Ssse3.Shuffle(c2, ReverseBytesMask);
-                c3 = Ssse3.Shuffle(c3, ReverseBytesMask);
+                input = input[(BlockSize * 4)..];
+                output = output[(BlockSize * 4)..];
+
+                var d0 = Ssse3.Shuffle(c0, ReverseBytesMask);
+                var d1 = Ssse3.Shuffle(c1, ReverseBytesMask);
+                var d2 = Ssse3.Shuffle(c2, ReverseBytesMask);
+                var d3 = Ssse3.Shuffle(c3, ReverseBytesMask);
 
-                c0 = Sse2.Xor(c0, S128);
+                d0 = Sse2.Xor(d0, S128);
 
-                GcmUtilities.MultiplyExt(c0.AsUInt64(), HPow[0], out var U0, out var U1, out var U2);
-                GcmUtilities.MultiplyExt(c1.AsUInt64(), HPow[1], out var V0, out var V1, out var V2);
-                GcmUtilities.MultiplyExt(c2.AsUInt64(), HPow[2], out var W0, out var W1, out var W2);
-                GcmUtilities.MultiplyExt(c3.AsUInt64(), HPow[3], out var X0, out var X1, out var X2);
+                GcmUtilities.MultiplyExt(d0.AsUInt64(), HPow[0], out var U0, out var U1, out var U2);
+                GcmUtilities.MultiplyExt(d1.AsUInt64(), HPow[1], out var V0, out var V1, out var V2);
+                GcmUtilities.MultiplyExt(d2.AsUInt64(), HPow[2], out var W0, out var W1, out var W2);
+                GcmUtilities.MultiplyExt(d3.AsUInt64(), HPow[3], out var X0, out var X1, out var X2);
 
                 U0 = Sse2.Xor(U0, V0);
                 U1 = Sse2.Xor(U1, V1);
@@ -1225,9 +1277,6 @@ namespace Org.BouncyCastle.Crypto.Modes
                 U2 = Sse2.Xor(U2, X2);
 
                 S128 = GcmUtilities.Reduce3(U0, U1, U2).AsByte();
-
-                input = input[(BlockSize * 4)..];
-                output = output[(BlockSize * 4)..];
             }
 
             S128 = Ssse3.Shuffle(S128, ReverseBytesMask);
@@ -1365,6 +1414,8 @@ namespace Org.BouncyCastle.Crypto.Modes
             if (!IsFourWaySupported)
                 throw new PlatformNotSupportedException(nameof(EncryptBlocks4));
 
+            var HPowBound = HPow[3];
+
             Span<Vector128<byte>> counters = stackalloc Vector128<byte>[4];
             var ctrBlocks = MemoryMarshal.AsBytes(counters);
 
@@ -1373,6 +1424,8 @@ namespace Org.BouncyCastle.Crypto.Modes
 
             while (input.Length >= BlockSize * 4)
             {
+                var outputBound = output[BlockSize * 4 - 1];
+
                 GetNextCtrBlocks4(ctrBlocks);
 
                 var p0 = MemoryMarshal.Read<Vector128<byte>>(input);
@@ -1390,17 +1443,20 @@ namespace Org.BouncyCastle.Crypto.Modes
                 MemoryMarshal.Write(output[(BlockSize * 2)..], ref c2);
                 MemoryMarshal.Write(output[(BlockSize * 3)..], ref c3);
 
-                c0 = Ssse3.Shuffle(c0, ReverseBytesMask);
-                c1 = Ssse3.Shuffle(c1, ReverseBytesMask);
-                c2 = Ssse3.Shuffle(c2, ReverseBytesMask);
-                c3 = Ssse3.Shuffle(c3, ReverseBytesMask);
+                input = input[(BlockSize * 4)..];
+                output = output[(BlockSize * 4)..];
+
+                var d0 = Ssse3.Shuffle(c0, ReverseBytesMask);
+                var d1 = Ssse3.Shuffle(c1, ReverseBytesMask);
+                var d2 = Ssse3.Shuffle(c2, ReverseBytesMask);
+                var d3 = Ssse3.Shuffle(c3, ReverseBytesMask);
 
-                c0 = Sse2.Xor(c0, S128);
+                d0 = Sse2.Xor(d0, S128);
 
-                GcmUtilities.MultiplyExt(c0.AsUInt64(), HPow[0], out var U0, out var U1, out var U2);
-                GcmUtilities.MultiplyExt(c1.AsUInt64(), HPow[1], out var V0, out var V1, out var V2);
-                GcmUtilities.MultiplyExt(c2.AsUInt64(), HPow[2], out var W0, out var W1, out var W2);
-                GcmUtilities.MultiplyExt(c3.AsUInt64(), HPow[3], out var X0, out var X1, out var X2);
+                GcmUtilities.MultiplyExt(d0.AsUInt64(), HPow[0], out var U0, out var U1, out var U2);
+                GcmUtilities.MultiplyExt(d1.AsUInt64(), HPow[1], out var V0, out var V1, out var V2);
+                GcmUtilities.MultiplyExt(d2.AsUInt64(), HPow[2], out var W0, out var W1, out var W2);
+                GcmUtilities.MultiplyExt(d3.AsUInt64(), HPow[3], out var X0, out var X1, out var X2);
 
                 U0 = Sse2.Xor(U0, V0);
                 U1 = Sse2.Xor(U1, V1);
@@ -1415,9 +1471,6 @@ namespace Org.BouncyCastle.Crypto.Modes
                 U2 = Sse2.Xor(U2, X2);
 
                 S128 = GcmUtilities.Reduce3(U0, U1, U2).AsByte();
-
-                input = input[(BlockSize * 4)..];
-                output = output[(BlockSize * 4)..];
             }
 
             S128 = Ssse3.Shuffle(S128, ReverseBytesMask);
@@ -1428,11 +1481,6 @@ namespace Org.BouncyCastle.Crypto.Modes
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private void GetNextCtrBlock(Span<byte> block)
         {
-            if (blocksRemaining == 0)
-                throw new InvalidOperationException("Attempt to process too many blocks");
-
-            blocksRemaining--;
-
             Pack.UInt32_To_BE(++counter32, counter, 12);
 
             cipher.ProcessBlock(counter, block);
@@ -1441,11 +1489,6 @@ namespace Org.BouncyCastle.Crypto.Modes
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private void GetNextCtrBlocks2(Span<byte> blocks)
         {
-            if (blocksRemaining < 2)
-                throw new InvalidOperationException("Attempt to process too many blocks");
-
-            blocksRemaining -= 2;
-
             Pack.UInt32_To_BE(++counter32, counter, 12);
             cipher.ProcessBlock(counter, blocks);
 
@@ -1456,21 +1499,16 @@ namespace Org.BouncyCastle.Crypto.Modes
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private void GetNextCtrBlocks4(Span<byte> blocks)
         {
-            if (blocksRemaining < 4)
-                throw new InvalidOperationException("Attempt to process too many blocks");
-
-            blocksRemaining -= 4;
+            uint counter0 = counter32;
+            uint counter1 = counter0 + 1U;
+            uint counter2 = counter0 + 2U;
+            uint counter3 = counter0 + 3U;
+            uint counter4 = counter0 + 4U;
+            counter32 = counter4;
 
 #if NETCOREAPP3_0_OR_GREATER
             if (AesEngine_X86.IsSupported && cipher is AesEngine_X86 x86)
             {
-                uint counter0 = counter32;
-                uint counter1 = counter0 + 1U;
-                uint counter2 = counter0 + 2U;
-                uint counter3 = counter0 + 3U;
-                uint counter4 = counter0 + 4U;
-                counter32 = counter4;
-
                 counter.CopyTo(blocks);
                 counter.CopyTo(blocks[BlockSize..]);
                 counter.CopyTo(blocks[(BlockSize * 2)..]);
@@ -1485,16 +1523,16 @@ namespace Org.BouncyCastle.Crypto.Modes
             }
 #endif
 
-            Pack.UInt32_To_BE(++counter32, counter, 12);
+            Pack.UInt32_To_BE(counter1, counter, 12);
             cipher.ProcessBlock(counter, blocks);
 
-            Pack.UInt32_To_BE(++counter32, counter, 12);
+            Pack.UInt32_To_BE(counter2, counter, 12);
             cipher.ProcessBlock(counter, blocks[BlockSize..]);
 
-            Pack.UInt32_To_BE(++counter32, counter, 12);
+            Pack.UInt32_To_BE(counter3, counter, 12);
             cipher.ProcessBlock(counter, blocks[(BlockSize * 2)..]);
 
-            Pack.UInt32_To_BE(++counter32, counter, 12);
+            Pack.UInt32_To_BE(counter4, counter, 12);
             cipher.ProcessBlock(counter, blocks[(BlockSize * 3)..]);
         }
 
diff --git a/crypto/src/crypto/modes/gcm/GcmUtilities.cs b/crypto/src/crypto/modes/gcm/GcmUtilities.cs
index b2c74d7d0..1cc4d262d 100644
--- a/crypto/src/crypto/modes/gcm/GcmUtilities.cs
+++ b/crypto/src/crypto/modes/gcm/GcmUtilities.cs
@@ -75,7 +75,7 @@ namespace Org.BouncyCastle.Crypto.Modes.Gcm
 
         internal static void Multiply(ref FieldElement x, ref FieldElement y)
         {
-            ulong z0, z1, z2, z3;
+            ulong z0, z1, z2;
 
 #if NETCOREAPP3_0_OR_GREATER
             if (Pclmulqdq.IsSupported)
@@ -94,10 +94,18 @@ namespace Org.BouncyCastle.Crypto.Modes.Gcm
                 ulong t1 = Z2.GetElement(0) ^ Z1.GetElement(1);
                 ulong t0 = Z2.GetElement(1);
 
+                Debug.Assert(t0 >> 63 == 0);
+
+                t1 ^= t3 ^ (t3 >>  1) ^ (t3 >>  2) ^ (t3 >>  7);
+                t2 ^=      (t3 << 63) ^ (t3 << 62) ^ (t3 << 57);
+
                 z0 = (t0 << 1) | (t1 >> 63);
                 z1 = (t1 << 1) | (t2 >> 63);
-                z2 = (t2 << 1) | (t3 >> 63);
-                z3 = (t3 << 1);
+                z2 = (t2 << 1);
+
+                z0 ^= z2 ^ (z2 >>  1) ^ (z2 >>  2) ^ (z2 >>  7);
+//              z1 ^=      (z2 << 63) ^ (z2 << 62) ^ (z2 << 57);
+                z1 ^=                   (t2 << 63) ^ (t2 << 58);
             }
             else
 #endif
@@ -113,6 +121,7 @@ namespace Org.BouncyCastle.Crypto.Modes.Gcm
                 ulong y0 = y.n0, y1 = y.n1;
                 ulong x0r = Longs.Reverse(x0), x1r = Longs.Reverse(x1);
                 ulong y0r = Longs.Reverse(y0), y1r = Longs.Reverse(y1);
+                ulong z3;
 
                 ulong h0 = Longs.Reverse(ImplMul64(x0r, y0r));
                 ulong h1 = ImplMul64(x0, y0) << 1;
@@ -125,16 +134,16 @@ namespace Org.BouncyCastle.Crypto.Modes.Gcm
                 z1 = h1 ^ h0 ^ h2 ^ h4;
                 z2 = h2 ^ h1 ^ h3 ^ h5;
                 z3 = h3;
-            }
 
-            Debug.Assert(z3 << 63 == 0);
+                Debug.Assert(z3 << 63 == 0);
 
-            z1 ^= z3 ^ (z3 >>  1) ^ (z3 >>  2) ^ (z3 >>  7);
-//          z2 ^=      (z3 << 63) ^ (z3 << 62) ^ (z3 << 57);
-            z2 ^=                   (z3 << 62) ^ (z3 << 57);
+                z1 ^= z3 ^ (z3 >>  1) ^ (z3 >>  2) ^ (z3 >>  7);
+//              z2 ^=      (z3 << 63) ^ (z3 << 62) ^ (z3 << 57);
+                z2 ^=                   (z3 << 62) ^ (z3 << 57);
 
-            z0 ^= z2 ^ (z2 >>  1) ^ (z2 >>  2) ^ (z2 >>  7);
-            z1 ^=      (z2 << 63) ^ (z2 << 62) ^ (z2 << 57);
+                z0 ^= z2 ^ (z2 >>  1) ^ (z2 >>  2) ^ (z2 >>  7);
+                z1 ^=      (z2 << 63) ^ (z2 << 62) ^ (z2 << 57);
+            }
 
             x.n0 = z0;
             x.n1 = z1;
@@ -176,19 +185,18 @@ namespace Org.BouncyCastle.Crypto.Modes.Gcm
             ulong t1 = Z2.GetElement(0);
             ulong t0 = Z2.GetElement(1);
 
-            ulong z0 = (t0 << 1) | (t1 >> 63);
-            ulong z1 = (t1 << 1) | (t2 >> 63);
-            ulong z2 = (t2 << 1) | (t3 >> 63);
-            ulong z3 = (t3 << 1);
+            Debug.Assert(t0 >> 63 == 0);
 
-            Debug.Assert(z3 << 63 == 0);
+            t1 ^= t3 ^ (t3 >>  1) ^ (t3 >>  2) ^ (t3 >>  7);
+            t2 ^=      (t3 << 63) ^ (t3 << 62) ^ (t3 << 57);
 
-            z1 ^= z3 ^ (z3 >>  1) ^ (z3 >>  2) ^ (z3 >>  7);
-//          z2 ^=      (z3 << 63) ^ (z3 << 62) ^ (z3 << 57);
-            z2 ^=                   (z3 << 62) ^ (z3 << 57);
+            ulong z0 = (t0 << 1) | (t1 >> 63);
+            ulong z1 = (t1 << 1) | (t2 >> 63);
+            ulong z2 = (t2 << 1);
 
             z0 ^= z2 ^ (z2 >>  1) ^ (z2 >>  2) ^ (z2 >>  7);
-            z1 ^=      (z2 << 63) ^ (z2 << 62) ^ (z2 << 57);
+//          z1 ^=      (z2 << 63) ^ (z2 << 62) ^ (z2 << 57);
+            z1 ^=                   (t2 << 63) ^ (t2 << 58);
 
             return Vector128.Create(z1, z0);
         }
@@ -201,19 +209,18 @@ namespace Org.BouncyCastle.Crypto.Modes.Gcm
             ulong t1 = Z2.GetElement(0) ^ Z1.GetElement(1);
             ulong t0 = Z2.GetElement(1);
 
-            ulong z0 = (t0 << 1) | (t1 >> 63);
-            ulong z1 = (t1 << 1) | (t2 >> 63);
-            ulong z2 = (t2 << 1) | (t3 >> 63);
-            ulong z3 = (t3 << 1);
+            Debug.Assert(t0 >> 63 == 0);
 
-            Debug.Assert(z3 << 63 == 0);
+            t1 ^= t3 ^ (t3 >>  1) ^ (t3 >>  2) ^ (t3 >>  7);
+            t2 ^=      (t3 << 63) ^ (t3 << 62) ^ (t3 << 57);
 
-            z1 ^= z3 ^ (z3 >>  1) ^ (z3 >>  2) ^ (z3 >>  7);
-//          z2 ^=      (z3 << 63) ^ (z3 << 62) ^ (z3 << 57);
-            z2 ^=                   (z3 << 62) ^ (z3 << 57);
+            ulong z0 = (t0 << 1) | (t1 >> 63);
+            ulong z1 = (t1 << 1) | (t2 >> 63);
+            ulong z2 = (t2 << 1);
 
             z0 ^= z2 ^ (z2 >>  1) ^ (z2 >>  2) ^ (z2 >>  7);
-            z1 ^=      (z2 << 63) ^ (z2 << 62) ^ (z2 << 57);
+//          z1 ^=      (z2 << 63) ^ (z2 << 62) ^ (z2 << 57);
+            z1 ^=                   (t2 << 63) ^ (t2 << 58);
 
             return Vector128.Create(z1, z0);
         }
@@ -276,23 +283,16 @@ namespace Org.BouncyCastle.Crypto.Modes.Gcm
 
         internal static void Square(ref FieldElement x)
         {
-            ulong z1 = Interleave.Expand64To128Rev(x.n0, out ulong z0);
-            ulong z3 = Interleave.Expand64To128Rev(x.n1, out ulong z2);
-
-            Debug.Assert(z3 << 63 == 0UL);
+            ulong t1 = Interleave.Expand64To128Rev(x.n0, out ulong t0);
+            ulong t3 = Interleave.Expand64To128Rev(x.n1, out ulong t2);
 
-            z1 ^= z3 ^ (z3 >>  1) ^ (z3 >>  2) ^ (z3 >>  7);
-//          z2 ^=      (z3 << 63) ^ (z3 << 62) ^ (z3 << 57);
-            z2 ^=                   (z3 << 62) ^ (z3 << 57);
+            Debug.Assert((t0 | t1 | t2 | t3) << 63 == 0UL);
 
-            Debug.Assert(z2 << 63 == 0UL);
+            var z1 = t1 ^ t3 ^ (t3 >>  1) ^ (t3 >>  2) ^ (t3 >>  7);
+            var z2 = t2 ^                   (t3 << 62) ^ (t3 << 57);
 
-            z0 ^= z2 ^ (z2 >>  1) ^ (z2 >>  2) ^ (z2 >>  7);
-//          z1 ^=      (z2 << 63) ^ (z2 << 62) ^ (z2 << 57);
-            z1 ^=                   (z2 << 62) ^ (z2 << 57);
-
-            x.n0 = z0;
-            x.n1 = z1;
+            x.n0   = t0 ^ z2 ^ (z2 >>  1) ^ (z2 >>  2) ^ (z2 >>  7);
+            x.n1   = z1 ^                   (t2 << 62) ^ (t2 << 57);
         }
 
         internal static void Xor(byte[] x, byte[] y)
diff --git a/crypto/test/src/crypto/test/GCMTest.cs b/crypto/test/src/crypto/test/GCMTest.cs
index aaca6f1ee..952c6ca98 100644
--- a/crypto/test/src/crypto/test/GCMTest.cs
+++ b/crypto/test/src/crypto/test/GCMTest.cs
@@ -561,7 +561,7 @@ namespace Org.BouncyCastle.Crypto.Tests
 
         private void RandomTests(SecureRandom srng, IGcmMultiplier m)
         {
-            for (int i = 0; i < 10; ++i)
+            for (int i = 0; i < 100; ++i)
             {
                 RandomTest(srng, m);
             }