summary refs log tree commit diff
diff options
context:
space:
mode:
authorPeter Dettman <peter.dettman@bouncycastle.org>2023-11-13 15:15:57 +0700
committerPeter Dettman <peter.dettman@bouncycastle.org>2023-11-13 15:15:57 +0700
commit2d7d20cfdb15c5ee05d4d8135600419848cdee9d (patch)
treed8e27e4ebf9923581d48dc4786123bf57fac0888
parentMove CRT fault countermeasure into RsaCoreEngine (diff)
downloadBouncyCastle.NET-ed25519-2d7d20cfdb15c5ee05d4d8135600419848cdee9d.tar.xz
Improvements to PKCS1Encoding
-rw-r--r--crypto/src/crypto/encodings/Pkcs1Encoding.cs276
1 files changed, 138 insertions, 138 deletions
diff --git a/crypto/src/crypto/encodings/Pkcs1Encoding.cs b/crypto/src/crypto/encodings/Pkcs1Encoding.cs
index 299d0ddb0..dac560d76 100644
--- a/crypto/src/crypto/encodings/Pkcs1Encoding.cs
+++ b/crypto/src/crypto/encodings/Pkcs1Encoding.cs
@@ -1,7 +1,7 @@
 using System;
+using System.Threading;
 
 using Org.BouncyCastle.Crypto.Parameters;
-using Org.BouncyCastle.Crypto.Digests;
 using Org.BouncyCastle.Security;
 using Org.BouncyCastle.Utilities;
 
@@ -31,17 +31,18 @@ namespace Org.BouncyCastle.Crypto.Encodings
          */
         public static bool StrictLengthEnabled
         {
-            get { return strictLengthEnabled[0]; }
-            set { strictLengthEnabled[0] = value; }
+            get { return Convert.ToBoolean(Interlocked.Read(ref m_strictLengthEnabled)); }
+            set { Interlocked.Exchange(ref m_strictLengthEnabled, Convert.ToInt64(value)); }
         }
 
-        private static readonly bool[] strictLengthEnabled;
+        private static long m_strictLengthEnabled = 0;
 
         static Pkcs1Encoding()
         {
             string strictProperty = Platform.GetEnvironmentVariable(StrictLengthEnabledProperty);
+            bool strictLengthEnabled = strictProperty == null || Platform.EqualsIgnoreCase("true", strictProperty);
 
-            strictLengthEnabled = new bool[]{ strictProperty == null || Platform.EqualsIgnoreCase("true", strictProperty) };
+            m_strictLengthEnabled = Convert.ToInt64(strictLengthEnabled);
         }
 
 
@@ -59,8 +60,7 @@ namespace Org.BouncyCastle.Crypto.Encodings
          *
          * @param cipher
          */
-        public Pkcs1Encoding(
-            IAsymmetricBlockCipher cipher)
+        public Pkcs1Encoding(IAsymmetricBlockCipher cipher)
         {
             this.engine = cipher;
             this.useStrictLength = StrictLengthEnabled;
@@ -139,46 +139,38 @@ namespace Org.BouncyCastle.Crypto.Encodings
                 :	baseBlockSize - HeaderLength;
         }
 
-        public byte[] ProcessBlock(
-            byte[]	input,
-            int		inOff,
-            int		length)
+        public byte[] ProcessBlock(byte[] input, int inOff, int length)
         {
             return forEncryption
                 ?	EncodeBlock(input, inOff, length)
                 :	DecodeBlock(input, inOff, length);
         }
 
-        private byte[] EncodeBlock(
-            byte[]	input,
-            int		inOff,
-            int		inLen)
+        private byte[] EncodeBlock(byte[] input, int inOff, int inLen)
         {
             if (inLen > GetInputBlockSize())
                 throw new ArgumentException("input data too large", "inLen");
 
             byte[] block = new byte[engine.GetInputBlockSize()];
 
+            int lastPadPos = block.Length - 1 - inLen;
             if (forPrivateKey)
             {
-                block[0] = 0x01;                        // type code 1
+                block[0] = 0x01;                                // type code 1
 
-                for (int i = 1; i != block.Length - inLen - 1; i++)
+                for (int i = 1; i < lastPadPos; ++i)
                 {
-                    block[i] = (byte)0xFF;
+                    block[i] = 0xFF;
                 }
             }
             else
             {
-                random.NextBytes(block);                // random fill
+                random.NextBytes(block);                        // random fill
 
-                block[0] = 0x02;                        // type code 2
+                block[0] = 0x02;                                // type code 2
 
-                //
-                // a zero byte marks the end of the padding, so all
-                // the pad bytes must be non-zero.
-                //
-                for (int i = 1; i != block.Length - inLen - 1; i++)
+                // a zero byte marks the end of the padding, so all the pad bytes must be non-zero.
+                for (int i = 1; i < lastPadPos; ++i)
                 {
                     while (block[i] == 0)
                     {
@@ -187,57 +179,92 @@ namespace Org.BouncyCastle.Crypto.Encodings
                 }
             }
 
-            block[block.Length - inLen - 1] = 0x00;       // mark the end of the padding
+            block[lastPadPos] = 0x00;                           // mark the end of the padding
             Array.Copy(input, inOff, block, block.Length - inLen, inLen);
 
             return engine.ProcessBlock(block, 0, block.Length);
         }
 
         /**
-         * Checks if the argument is a correctly PKCS#1.5 encoded Plaintext
-         * for encryption.
-         * 
-         * @param encoded The Plaintext.
-         * @param pLen Expected length of the plaintext.
-         * @return Either 0, if the encoding is correct, or -1, if it is incorrect.
+         * Check the argument is a valid encoding with type 1. Returns the plaintext length if valid, or -1 if invalid.
          */
-        private static int CheckPkcs1Encoding(byte[] encoded, int pLen)
+        private static int CheckPkcs1Encoding1(byte[] buf)
         {
-            int correct = 0;
-            /*
-             * Check if the first two bytes are 0 2
-             */
-            correct |= (encoded[0] ^ 2);
+            int foundZeroMask = 0;
+            int lastPadPos = 0;
 
-            /*
-             * Now the padding check, check for no 0 byte in the padding
-             */
-            int plen = encoded.Length - (
-                      pLen /* Length of the PMS */
-                    +  1 /* Final 0-byte before PMS */
-            );
+            // The first byte should be 0x01
+            int badPadSign = -(buf[0] ^ 0x01);
 
-            for (int i = 1; i < plen; i++)
+            // There must be a zero terminator for the padding somewhere
+            for (int i = 1; i < buf.Length; ++i)
             {
-                int tmp = encoded[i];
-                tmp |= tmp >> 1;
-                tmp |= tmp >> 2;
-                tmp |= tmp >> 4;
-                correct |= (tmp & 1) - 1;
+                int padByte = buf[i];
+                int is0x00Mask = ((padByte ^ 0x00) - 1) >> 31;
+                int is0xFFMask = ((padByte ^ 0xFF) - 1) >> 31;
+                lastPadPos ^= i & ~foundZeroMask & is0x00Mask;
+                foundZeroMask |= is0x00Mask;
+                badPadSign |= ~(foundZeroMask | is0xFFMask);
             }
 
-            /*
-             * Make sure the padding ends with a 0 byte.
-             */
-            correct |= encoded[encoded.Length - (pLen + 1)];
+            // The header should be at least 10 bytes
+            badPadSign |= lastPadPos - 9;
 
-            /*
-             * Return 0 or 1, depending on the result.
-             */
-            correct |= correct >> 1;
-            correct |= correct >> 2;
-            correct |= correct >> 4;
-            return ~((correct & 1) - 1);
+            int plaintextLength = buf.Length - 1 - lastPadPos;
+            return plaintextLength | badPadSign >> 31;
+        }
+
+        /**
+         * Check the argument is a valid encoding with type 2. Returns the plaintext length if valid, or -1 if invalid.
+         */
+        private static int CheckPkcs1Encoding2(byte[] buf)
+        {
+            int foundZeroMask = 0;
+            int lastPadPos = 0;
+
+            // The first byte should be 0x02
+            int badPadSign = -(buf[0] ^ 0x02);
+
+            // There must be a zero terminator for the padding somewhere
+            for (int i = 1; i < buf.Length; ++i)
+            {
+                int padByte = buf[i];
+                int is0x00Mask = ((padByte ^ 0x00) - 1) >> 31;
+                lastPadPos ^= i & ~foundZeroMask & is0x00Mask;
+                foundZeroMask |= is0x00Mask;
+            }
+
+            // The header should be at least 10 bytes
+            badPadSign |= lastPadPos - 9;
+
+            int plaintextLength = buf.Length - 1 - lastPadPos;
+            return plaintextLength | badPadSign >> 31;
+        }
+
+        /**
+         * Check the argument is a valid encoding with type 2 of a plaintext with the given length. Returns 0 if
+         * valid, or -1 if invalid.
+         */
+        private static int CheckPkcs1Encoding2(byte[] buf, int plaintextLength)
+        {
+            // The first byte should be 0x02
+            int badPadSign = -(buf[0] ^ 0x02);
+
+            int lastPadPos = buf.Length - 1 - plaintextLength;
+
+            // The header should be at least 10 bytes
+            badPadSign |= lastPadPos - 9;
+
+            // All pad bytes before the last one should be non-zero
+            for (int i = 1; i < lastPadPos; ++i)
+            {
+                badPadSign |= buf[i] - 1;
+            }
+
+            // Last pad byte should be zero
+            badPadSign |= -buf[lastPadPos];
+
+            return badPadSign >> 31;
         }
 
         /**
@@ -255,31 +282,42 @@ namespace Org.BouncyCastle.Crypto.Encodings
             if (!forPrivateKey)
                 throw new InvalidCipherTextException("sorry, this method is only for decryption, not for signing");
 
-            byte[] block = engine.ProcessBlock(input, inOff, inLen);
-            byte[] fallbackResult = fallback;
-            if (fallbackResult == null)
+            int plaintextLength = this.pLen;
+
+            byte[] random = fallback;
+            if (fallback == null)
             {
-                fallbackResult = SecureRandom.GetNextBytes(SecureRandom.ArbitraryRandom, pLen);
+                random = SecureRandom.GetNextBytes(this.random, plaintextLength);
             }
 
-            byte[] data = (useStrictLength & (block.Length != engine.GetOutputBlockSize())) ? blockBuffer : block;
+            int badPadMask = 0;
+            int strictBlockSize = engine.GetOutputBlockSize();
+            byte[] block = engine.ProcessBlock(input, inOff, inLen);
 
-		    /*
-		     * Check the padding.
-		     */
-            int correct = CheckPkcs1Encoding(data, this.pLen);
+            byte[] data = block;
+            if (block.Length != strictBlockSize)
+            {
+                if (useStrictLength || block.Length < strictBlockSize)
+                {
+                    data = blockBuffer;
+                }
+            }
+
+            badPadMask |= CheckPkcs1Encoding2(data, plaintextLength);
 
-		    /*
-		     * Now, to a constant time constant memory copy of the decrypted value
-		     * or the random value, depending on the validity of the padding.
-		     */
-            byte[] result = new byte[this.pLen];
-            for (int i = 0; i < this.pLen; i++)
+            /*
+             * Now, to a constant time constant memory copy of the decrypted value
+             * or the random value, depending on the validity of the padding.
+             */
+            int dataOff = data.Length - plaintextLength; 
+            byte[] result = new byte[plaintextLength];
+            for (int i = 0; i < plaintextLength; ++i)
             {
-                result[i] = (byte)((data[i + (data.Length - pLen)] & (~correct)) | (fallbackResult[i] & correct));
+                result[i] = (byte)((data[dataOff + i] & ~badPadMask) | (random[i] & badPadMask));
             }
 
-            Arrays.Fill(data, 0);
+            Arrays.Fill(block, 0);
+            Arrays.Fill(blockBuffer, 0, System.Math.Max(0, blockBuffer.Length - block.Length), 0);
 
             return result;
         }
@@ -287,82 +325,44 @@ namespace Org.BouncyCastle.Crypto.Encodings
         /**
         * @exception InvalidCipherTextException if the decrypted block is not in Pkcs1 format.
         */
-        private byte[] DecodeBlock(
-            byte[]	input,
-            int		inOff,
-            int		inLen)
+        private byte[] DecodeBlock(byte[] input, int inOff, int inLen)
         {
             /*
              * If the length of the expected plaintext is known, we use a constant-time decryption.
              * If the decryption fails, we return a random value.
              */
-            if (this.pLen != -1)
-            {
-                return this.DecodeBlockOrRandom(input, inOff, inLen);
-            }
+            if (forPrivateKey && this.pLen != -1)
+                return DecodeBlockOrRandom(input, inOff, inLen);
 
+            int strictBlockSize = engine.GetOutputBlockSize();
             byte[] block = engine.ProcessBlock(input, inOff, inLen);
-            bool incorrectLength = (useStrictLength & (block.Length != engine.GetOutputBlockSize()));
 
-            byte[] data;
-            if (block.Length < GetOutputBlockSize())
+            bool incorrectLength = useStrictLength & (block.Length != strictBlockSize);
+
+            byte[] data = block;
+            if (block.Length < strictBlockSize)
             {
                 data = blockBuffer;
             }
-            else
-            {
-                data = block;
-            }
-
-            byte expectedType = (byte)(forPrivateKey ? 2 : 1);
-            byte type = data[0];
-
-            bool badType = (type != expectedType);
 
-            //
-            // find and extract the message block.
-            //
-            int start = FindStart(type, data);
+            int plaintextLength = forPrivateKey ? CheckPkcs1Encoding2(data) : CheckPkcs1Encoding1(data);
 
-            start++;           // data should start at the next byte
-
-            if (badType | (start < HeaderLength))
-            {
-                Arrays.Fill(data, 0);
-                throw new InvalidCipherTextException("block incorrect");
-            }
-
-            // if we get this far, it's likely to be a genuine encoding error
-            if (incorrectLength)
+            try
             {
-                Arrays.Fill(data, 0);
-                throw new InvalidCipherTextException("block incorrect size");
+                if (plaintextLength < 0)
+                    throw new InvalidCipherTextException("block incorrect");
+                if (incorrectLength)
+                    throw new InvalidCipherTextException("block incorrect size");
+
+                byte[] result = new byte[plaintextLength];
+                Array.Copy(data, data.Length - plaintextLength, result, 0, plaintextLength);
+                return result;
             }
-
-            byte[] result = new byte[data.Length - start];
-
-            Array.Copy(data, start, result, 0, result.Length);
-
-            return result;
-        }
-
-        private int FindStart(byte type, byte[] block)
-        {
-            int start = -1;
-            bool padErr = false;
-
-            for (int i = 1; i != block.Length; i++)
+            finally
             {
-                byte pad = block[i];
-
-                if (pad == 0 & start < 0)
-                {
-                    start = i;
-                }
-                padErr |= ((type == 1) & (start < 0) & (pad != (byte)0xff));
+                Arrays.Fill(block, 0);
+                Arrays.Fill(blockBuffer, 0, System.Math.Max(0, blockBuffer.Length - block.Length), 0);
             }
-
-            return padErr ? -1 : start;
         }
     }
 }