summary refs log tree commit diff
diff options
context:
space:
mode:
authorPeter Dettman <peter.dettman@bouncycastle.org>2022-09-30 00:13:26 +0700
committerPeter Dettman <peter.dettman@bouncycastle.org>2022-09-30 00:13:26 +0700
commit9f0beca05117e428efd8c180ba12e779c0f1a7e5 (patch)
treedfee78050f5d9749be5735a8c041b26a2f12f1be
parentPreserve mac after DoFinal (diff)
downloadBouncyCastle.NET-ed25519-9f0beca05117e428efd8c180ba12e779c0f1a7e5.tar.xz
Grain128Aead performance, constant-time
-rw-r--r--crypto/src/crypto/engines/Grain128AEADEngine.cs225
1 files changed, 84 insertions, 141 deletions
diff --git a/crypto/src/crypto/engines/Grain128AEADEngine.cs b/crypto/src/crypto/engines/Grain128AEADEngine.cs
index a571cb124..174d010f3 100644
--- a/crypto/src/crypto/engines/Grain128AEADEngine.cs
+++ b/crypto/src/crypto/engines/Grain128AEADEngine.cs
@@ -3,6 +3,7 @@ using System.IO;
 
 using Org.BouncyCastle.Crypto.Modes;
 using Org.BouncyCastle.Crypto.Parameters;
+using Org.BouncyCastle.Crypto.Utilities;
 
 namespace Org.BouncyCastle.Crypto.Engines
 {
@@ -24,10 +25,8 @@ namespace Org.BouncyCastle.Crypto.Engines
         private uint[] nfsr;
         private uint[] authAcc;
         private uint[] authSr;
-        private uint outputZ;
 
         private bool initialised = false;
-        private bool isEven = true; // zero treated as even
         private bool aadFinished = false;
         private MemoryStream aadData = new MemoryStream();
 
@@ -85,7 +84,7 @@ namespace Org.BouncyCastle.Crypto.Engines
         {
             for (int i = 0; i < 320; ++i)
             {
-                outputZ = GetOutput();
+                uint outputZ = GetOutput();
                 nfsr = Shift(nfsr, (GetOutputNFSR() ^ lfsr[0] ^ outputZ) & 1);
                 lfsr = Shift(lfsr, (GetOutputLFSR() ^ outputZ) & 1);
             }
@@ -93,7 +92,7 @@ namespace Org.BouncyCastle.Crypto.Engines
             {
                 for (int remainder = 0; remainder < 8; ++remainder)
                 {
-                    outputZ = GetOutput();
+                    uint outputZ = GetOutput();
                     nfsr = Shift(nfsr, (GetOutputNFSR() ^ lfsr[0] ^ outputZ ^ (uint)((workingKey[quotient]) >> remainder)) & 1);
                     lfsr = Shift(lfsr, (GetOutputLFSR() ^ outputZ ^ (uint)((workingKey[quotient + 8]) >> remainder)) & 1);
                 }
@@ -102,7 +101,7 @@ namespace Org.BouncyCastle.Crypto.Engines
             {
                 for (int remainder = 0; remainder < 32; ++remainder)
                 {
-                    outputZ = GetOutput();
+                    uint outputZ = GetOutput();
                     nfsr = Shift(nfsr, (GetOutputNFSR() ^ lfsr[0]) & 1);
                     lfsr = Shift(lfsr, (GetOutputLFSR()) & 1);
                     authAcc[quotient] |= outputZ << remainder;
@@ -112,7 +111,7 @@ namespace Org.BouncyCastle.Crypto.Engines
             {
                 for (int remainder = 0; remainder < 32; ++remainder)
                 {
-                    outputZ = GetOutput();
+                    uint outputZ = GetOutput();
                     nfsr = Shift(nfsr, (GetOutputNFSR() ^ lfsr[0]) & 1);
                     lfsr = Shift(lfsr, (GetOutputLFSR()) & 1);
                     authSr[quotient] |= outputZ << remainder;
@@ -203,6 +202,7 @@ namespace Org.BouncyCastle.Crypto.Engines
             uint s79 = lfsr[2] >> 15;
             uint s93 = lfsr[2] >> 29;
             uint s94 = lfsr[2] >> 30;
+
             return ((b12 & s8) ^ (s13 & s20) ^ (b95 & s42) ^ (s60 & s79) ^ (b12 & b95 & s94) ^ s93
                 ^ b2 ^ b15 ^ b36 ^ b45 ^ b64 ^ b73 ^ b89) & 1;
         }
@@ -231,28 +231,18 @@ namespace Org.BouncyCastle.Crypto.Engines
          */
         private void SetKey(byte[] keyBytes, byte[] ivBytes)
         {
-            ivBytes[12] = (byte)0xFF;
-            ivBytes[13] = (byte)0xFF;
-            ivBytes[14] = (byte)0xFF;
-            ivBytes[15] = (byte)0x7F;//(byte) 0xFE;
+            ivBytes[12] = 0xFF;
+            ivBytes[13] = 0xFF;
+            ivBytes[14] = 0xFF;
+            ivBytes[15] = 0x7F;
             workingKey = keyBytes;
             workingIV = ivBytes;
 
             /**
              * Load NFSR and LFSR
              */
-            int j = 0;
-            for (int i = 0; i < nfsr.Length; i++)
-            {
-                nfsr[i] = (uint)(((workingKey[j + 3]) << 24) | ((workingKey[j + 2]) << 16)
-                    & 0x00FF0000 | ((workingKey[j + 1]) << 8) & 0x0000FF00
-                    | ((workingKey[j]) & 0x000000FF));
-
-                lfsr[i] = (uint)(((workingIV[j + 3]) << 24) | ((workingIV[j + 2]) << 16)
-                    & 0x00FF0000 | ((workingIV[j + 1]) << 8) & 0x0000FF00
-                    | ((workingIV[j]) & 0x000000FF));
-                j += 4;
-            }
+            Pack.LE_To_UInt32(workingKey, 0, nfsr);
+            Pack.LE_To_UInt32(workingIV, 0, lfsr);
         }
 
         public int ProcessBytes(byte[] input, int inOff, int len, byte[] output, int outOff)
@@ -303,7 +293,6 @@ namespace Org.BouncyCastle.Crypto.Engines
 
         private void Reset(bool clearMac)
         {
-            this.isEven = true;
             if (clearMac)
             {
                 this.mac = null;
@@ -319,78 +308,61 @@ namespace Org.BouncyCastle.Crypto.Engines
         private void GetKeyStream(ReadOnlySpan<byte> input, Span<byte> output)
         {
             int len = input.Length;
-            int mCnt = 0, acCnt = 0, cCnt = 0;
-            byte[] plaintext = new byte[len];
-            for (int i = 0; i < len; ++i)
-            {
-                plaintext[i] = (byte)ReverseByte(input[i]);
-            }
             for (int i = 0; i < len; ++i)
             {
-                byte cc = 0;
-                for (int j = 0; j < 16; ++j)
+                uint cc = 0, input_i = input[i];
+                for (int j = 0; j < 8; ++j)
                 {
-                    outputZ = GetOutput();
+                    uint outputZ = GetOutput();
+                    nfsr = Shift(nfsr, (GetOutputNFSR() ^ lfsr[0]) & 1);
+                    lfsr = Shift(lfsr, (GetOutputLFSR()) & 1);
+
+                    uint input_i_j = (input_i >> j) & 1U;
+                    cc |= (input_i_j ^ outputZ) << j;
+
+                    //if (input_i_j != 0)
+                    //{
+                    //    Accumulate();
+                    //}
+                    uint mask = 0U - input_i_j;
+                    authAcc[0] ^= authSr[0] & mask;
+                    authAcc[1] ^= authSr[1] & mask;
+
+                    AuthShift(GetOutput());
                     nfsr = Shift(nfsr, (GetOutputNFSR() ^ lfsr[0]) & 1);
                     lfsr = Shift(lfsr, (GetOutputLFSR()) & 1);
-                    if (isEven)
-                    {
-                        cc |= (byte)(((((plaintext[mCnt >> 3]) >> (7 - (mCnt & 7))) & 1) ^ outputZ) << (cCnt & 7));
-                        mCnt++;
-                        cCnt++;
-                        isEven = false;
-                    }
-                    else
-                    {
-                        if ((plaintext[acCnt >> 3] & (1 << (7 - (acCnt & 7)))) != 0)
-                        {
-                            Accumulate();
-                        }
-                        AuthShift(outputZ);
-                        acCnt++;
-                        isEven = true;
-                    }
                 }
-                output[i] = cc;
+                output[i] = (byte)cc;
             }
         }
 #else
         private void GetKeyStream(byte[] input, int inOff, int len, byte[] ciphertext, int outOff)
         {
-            int mCnt = 0, acCnt = 0, cCnt = 0;
-            byte[] plaintext = new byte[len];
             for (int i = 0; i < len; ++i)
             {
-                plaintext[i] = (byte)ReverseByte(input[inOff + i]);
-            }
-            for (int i = 0; i < len; ++i)
-            {
-                byte cc = 0;
-                for (int j = 0; j < 16; ++j)
+                uint cc = 0, input_i = input[inOff + i];
+                for (int j = 0; j < 8; ++j)
                 {
-                    outputZ = GetOutput();
+                    uint outputZ = GetOutput();
+                    nfsr = Shift(nfsr, (GetOutputNFSR() ^ lfsr[0]) & 1);
+                    lfsr = Shift(lfsr, (GetOutputLFSR()) & 1);
+
+                    uint input_i_j = (input_i >> j) & 1U;
+                    cc |= (input_i_j ^ outputZ) << j;
+
+                    //if (input_i_j != 0)
+                    //{
+                    //    Accumulate();
+                    //}
+                    uint mask = 0U - input_i_j;
+                    authAcc[0] ^= authSr[0] & mask;
+                    authAcc[1] ^= authSr[1] & mask;
+
+                    AuthShift(GetOutput());
                     nfsr = Shift(nfsr, (GetOutputNFSR() ^ lfsr[0]) & 1);
                     lfsr = Shift(lfsr, (GetOutputLFSR()) & 1);
-                    if (isEven)
-                    {
-                        cc |= (byte)(((((plaintext[mCnt >> 3]) >> (7 - (mCnt & 7))) & 1) ^ outputZ) << (cCnt & 7));
-                        mCnt++;
-                        cCnt++;
-                        isEven = false;
-                    }
-                    else
-                    {
-
-                        if ((plaintext[acCnt >> 3] & (1 << (7 - (acCnt & 7)))) != 0)
-                        {
-                            Accumulate();
-                        }
-                        AuthShift(outputZ);
-                        acCnt++;
-                        isEven = true;
-                    }
                 }
-                ciphertext[outOff + i] = cc;
+                ciphertext[outOff + i] = (byte)cc;
             }
         }
 #endif
@@ -467,63 +439,53 @@ namespace Org.BouncyCastle.Crypto.Engines
         {
             byte[] ader;
             int aderlen;
+            //encodeDer
             if (len < 128)
             {
                 ader = new byte[1 + len];
-                ader[0] = (byte)ReverseByte((uint)len);
+                ader[0] = (byte)len;
                 aderlen = 0;
             }
             else
             {
+                // aderlen is the highest bit position divided by 8
                 aderlen = LenLength(len);
                 ader = new byte[aderlen + 1 + len];
-                ader[0] = (byte)ReverseByte(0x80 | (uint)aderlen);
+                ader[0] = (byte)(0x80 | (uint)aderlen);
                 uint tmp = (uint)len;
                 for (int i = 0; i < aderlen; ++i)
                 {
-                    ader[1 + i] = (byte)ReverseByte(tmp & 0xff);
+                    ader[1 + i] = (byte)tmp;
                     tmp >>= 8;
                 }
             }
             for (int i = 0; i < len; ++i)
             {
-                ader[1 + aderlen + i] = (byte)ReverseByte(input[inOff + i]);
+                ader[1 + aderlen + i] = input[inOff + i];
             }
 
-            int adCnt = 0;
             for (int i = 0; i < ader.Length; ++i)
             {
-                for (int j = 0; j < 16; ++j)
+                uint ader_i = ader[i];
+                for (int j = 0; j < 8; ++j)
                 {
-                    outputZ = GetOutput();
                     nfsr = Shift(nfsr, (GetOutputNFSR() ^ lfsr[0]) & 1);
                     lfsr = Shift(lfsr, (GetOutputLFSR()) & 1);
-                    if ((j & 1) == 1)
-                    {
-                        byte adval = (byte)(ader[adCnt >> 3] & (1 << (7 - (adCnt & 7))));
-                        if (adval != 0)
-                        {
-                            Accumulate();
-                        }
-                        AuthShift(outputZ);
-                        adCnt++;
-                    }
-                }
-            }
-        }
 
-        private int LenLength(int v)
-        {
-            if ((v & 0xff) == v)
-                return 1;
+                    uint ader_i_j = (ader_i >> j) & 1U;
+                    //if (ader_i_j != 0)
+                    //{
+                    //    Accumulate();
+                    //}
+                    uint mask = 0U - ader_i_j;
+                    authAcc[0] ^= authSr[0] & mask;
+                    authAcc[1] ^= authSr[1] & mask;
 
-            if ((v & 0xffff) == v)
-                return 2;
-
-            if ((v & 0xffffff) == v)
-                return 3;
-
-            return 4;
+                    AuthShift(GetOutput());
+                    nfsr = Shift(nfsr, (GetOutputNFSR() ^ lfsr[0]) & 1);
+                    lfsr = Shift(lfsr, (GetOutputLFSR()) & 1);
+                }
+            }
         }
 
         public int DoFinal(byte[] output, int outOff)
@@ -537,21 +499,9 @@ namespace Org.BouncyCastle.Crypto.Engines
                 aadFinished = true;
             }
 
-            this.mac = new byte[8];
-
-            outputZ = GetOutput();
-            nfsr = Shift(nfsr, (GetOutputNFSR() ^ lfsr[0]) & 1);
-            lfsr = Shift(lfsr, (GetOutputLFSR()) & 1);
             Accumulate();
 
-            int cCnt = 0;
-            for (int i = 0; i < 2; ++i)
-            {
-                for (int j = 0; j < 4; ++j)
-                {
-                    mac[cCnt++] = (byte)((authAcc[i] >> (j << 3)) & 0xff);
-                }
-            }
+            this.mac = Pack.UInt32_To_LE(authAcc);
 
             Array.Copy(mac, 0, output, outOff, mac.Length);
 
@@ -570,21 +520,9 @@ namespace Org.BouncyCastle.Crypto.Engines
                 aadFinished = true;
             }
 
-            this.mac = new byte[8];
-
-            outputZ = GetOutput();
-            nfsr = Shift(nfsr, (GetOutputNFSR() ^ lfsr[0]) & 1);
-            lfsr = Shift(lfsr, (GetOutputLFSR()) & 1);
             Accumulate();
 
-            int cCnt = 0;
-            for (int i = 0; i < 2; ++i)
-            {
-                for (int j = 0; j < 4; ++j)
-                {
-                    mac[cCnt++] = (byte)((authAcc[i] >> (j << 3)) & 0xff);
-                }
-            }
+            this.mac = Pack.UInt32_To_LE(authAcc);
 
             mac.CopyTo(output);
 
@@ -609,13 +547,18 @@ namespace Org.BouncyCastle.Crypto.Engines
             return len + 8;
         }
 
-        private uint ReverseByte(uint x)
+        private static int LenLength(int v)
         {
-            x = (uint)(((x & 0x55) << 1) | ((x & (~0x55)) >> 1)) & 0xFF;
-            x = (uint)(((x & 0x33) << 2) | ((x & (~0x33)) >> 2)) & 0xFF;
-            x = (uint)(((x & 0x0f) << 4) | ((x & (~0x0f)) >> 4)) & 0xFF;
-            return x;
+            if ((v & 0xff) == v)
+                return 1;
+
+            if ((v & 0xffff) == v)
+                return 2;
+
+            if ((v & 0xffffff) == v)
+                return 3;
+
+            return 4;
         }
     }
 }
-