summary refs log tree commit diff
path: root/crypto
diff options
context:
space:
mode:
authorPeter Dettman <peter.dettman@bouncycastle.org>2023-11-15 18:21:02 +0700
committerPeter Dettman <peter.dettman@bouncycastle.org>2023-11-15 18:21:02 +0700
commit64131c72df2dc95d4d46d32187923daf28d37f2c (patch)
tree0d838d4a1dfc74785acd70c164c66f03e109fdad /crypto
parentMark RSA key exchange cipher suites to be removed from default list (diff)
downloadBouncyCastle.NET-ed25519-64131c72df2dc95d4d46d32187923daf28d37f2c.tar.xz
Improvements to OaepEncoding
Diffstat (limited to 'crypto')
-rw-r--r--crypto/src/crypto/encodings/OaepEncoding.cs258
1 files changed, 131 insertions, 127 deletions
diff --git a/crypto/src/crypto/encodings/OaepEncoding.cs b/crypto/src/crypto/encodings/OaepEncoding.cs
index 9ddaec779..45bb61de9 100644
--- a/crypto/src/crypto/encodings/OaepEncoding.cs
+++ b/crypto/src/crypto/encodings/OaepEncoding.cs
@@ -14,42 +14,42 @@ namespace Org.BouncyCastle.Crypto.Encodings
     public class OaepEncoding
         : IAsymmetricBlockCipher
     {
-        private byte[] defHash;
-        private IDigest mgf1Hash;
+        private static int GetMgf1NoMemoLimit(IDigest d)
+        {
+            if (d is IMemoable)
+                return d.GetByteLength() - 1;
+
+            return int.MaxValue;
+        }
+
+        private readonly IAsymmetricBlockCipher engine;
+        private readonly IDigest mgf1Hash;
+        private readonly int mgf1NoMemoLimit;
+        private readonly byte[] defHash;
 
-        private IAsymmetricBlockCipher engine;
         private SecureRandom random;
         private bool forEncryption;
 
-        public OaepEncoding(
-            IAsymmetricBlockCipher cipher)
+        public OaepEncoding(IAsymmetricBlockCipher cipher)
             : this(cipher, new Sha1Digest(), null)
         {
         }
 
-        public OaepEncoding(
-            IAsymmetricBlockCipher	cipher,
-            IDigest					hash)
+        public OaepEncoding(IAsymmetricBlockCipher cipher, IDigest hash)
             : this(cipher, hash, null)
         {
         }
 
-        public OaepEncoding(
-            IAsymmetricBlockCipher	cipher,
-            IDigest					hash,
-            byte[]					encodingParams)
+        public OaepEncoding(IAsymmetricBlockCipher cipher, IDigest hash, byte[] encodingParams)
             : this(cipher, hash, hash, encodingParams)
         {
         }
 
-        public OaepEncoding(
-            IAsymmetricBlockCipher	cipher,
-            IDigest					hash,
-            IDigest					mgf1Hash,
-            byte[]					encodingParams)
+        public OaepEncoding(IAsymmetricBlockCipher cipher, IDigest hash, IDigest mgf1Hash, byte[] encodingParams)
         {
             this.engine = cipher;
             this.mgf1Hash = mgf1Hash;
+            this.mgf1NoMemoLimit = GetMgf1NoMemoLimit(mgf1Hash);
             this.defHash = new byte[hash.GetDigestSize()];
 
             hash.Reset();
@@ -68,18 +68,16 @@ namespace Org.BouncyCastle.Crypto.Encodings
 
         public void Init(bool forEncryption, ICipherParameters parameters)
         {
+            SecureRandom initRandom = null;
             if (parameters is ParametersWithRandom withRandom)
             {
-                this.random = withRandom.Random;
+                initRandom = withRandom.Random;
             }
-            else
-            {
-                this.random = forEncryption ? CryptoServicesRegistrar.GetSecureRandom() : null;
-            }
-
-            engine.Init(forEncryption, parameters);
 
+            this.random = forEncryption ? CryptoServicesRegistrar.GetSecureRandom(initRandom) : null;
             this.forEncryption = forEncryption;
+
+            engine.Init(forEncryption, parameters);
         }
 
         public int GetInputBlockSize()
@@ -110,29 +108,19 @@ namespace Org.BouncyCastle.Crypto.Encodings
             }
         }
 
-        public byte[] ProcessBlock(
-            byte[]	inBytes,
-            int		inOff,
-            int		inLen)
+        public byte[] ProcessBlock(byte[] inBytes, int inOff, int inLen)
         {
-            if (forEncryption)
-            {
-                return EncodeBlock(inBytes, inOff, inLen);
-            }
-            else
-            {
-                return DecodeBlock(inBytes, inOff, inLen);
-            }
+            return forEncryption
+                ? EncodeBlock(inBytes, inOff, inLen)
+                : DecodeBlock(inBytes, inOff, inLen);
         }
 
-        private byte[] EncodeBlock(
-            byte[]	inBytes,
-            int		inOff,
-            int		inLen)
+        private byte[] EncodeBlock(byte[] inBytes, int inOff, int inLen)
         {
-            Check.DataLength(inLen > GetInputBlockSize(), "input data too long");
+            int inputBlockSize = GetInputBlockSize();
+            Check.DataLength(inLen > inputBlockSize, "input data too long");
 
-            byte[] block = new byte[GetInputBlockSize() + 1 + 2 * defHash.Length];
+            byte[] block = new byte[inputBlockSize + 1 + 2 * defHash.Length];
 
             //
             // copy in the message
@@ -156,33 +144,19 @@ namespace Org.BouncyCastle.Crypto.Encodings
             //
             // generate the seed.
             //
-            byte[] seed = SecureRandom.GetNextBytes(random, defHash.Length);
-
-            //
-            // mask the message block.
-            //
-            byte[] mask = MaskGeneratorFunction(seed, 0, seed.Length, block.Length - defHash.Length);
+            random.NextBytes(block, 0, defHash.Length);
 
-            for (int i = defHash.Length; i != block.Length; i++)
-            {
-                block[i] ^= mask[i - defHash.Length];
-            }
+            mgf1Hash.Reset();
 
             //
-            // add in the seed
+            // mask the message block.
             //
-            Array.Copy(seed, 0, block, 0, defHash.Length);
+            MaskGeneratorFunction(block, 0, defHash.Length, block, defHash.Length, block.Length - defHash.Length);
 
             //
             // mask the seed.
             //
-            mask = MaskGeneratorFunction(
-                block, defHash.Length, block.Length - defHash.Length, defHash.Length);
-
-            for (int i = 0; i != defHash.Length; i++)
-            {
-                block[i] ^= mask[i];
-            }
+            MaskGeneratorFunction(block, defHash.Length, block.Length - defHash.Length, block, 0, defHash.Length);
 
             return engine.ProcessBlock(block, 0, block.Length);
         }
@@ -191,52 +165,37 @@ namespace Org.BouncyCastle.Crypto.Encodings
         * @exception InvalidCipherTextException if the decrypted block turns out to
         * be badly formatted.
         */
-        private byte[] DecodeBlock(
-            byte[]	inBytes,
-            int		inOff,
-            int		inLen)
+        private byte[] DecodeBlock(byte[] inBytes, int inOff, int inLen)
         {
-            byte[] data = engine.ProcessBlock(inBytes, inOff, inLen);
-            byte[] block = new byte[engine.GetOutputBlockSize()];
+            // i.e. wrong when block.length < (2 * defHash.length) + 1
+            int wrongMask = GetOutputBlockSize() >> 31;
 
             //
             // as we may have zeros in our leading bytes for the block we produced
             // on encryption, we need to make sure our decrypted block comes back
             // the same size.
             //
-            // i.e. wrong when block.length < (2 * defHash.length) + 1
-            int wrongMask = (block.Length - ((2 * defHash.Length) + 1)) >> 31;
-
-            if (data.Length <= block.Length)
-            {
-                Array.Copy(data, 0, block, block.Length - data.Length, data.Length);
-            }
-            else
+            byte[] block = new byte[engine.GetOutputBlockSize()];
             {
-                Array.Copy(data, 0, block, 0, block.Length);
-                wrongMask |= 1;
+                byte[] data = engine.ProcessBlock(inBytes, inOff, inLen);
+                wrongMask |= (block.Length - data.Length) >> 31;
+
+                int copyLen = System.Math.Min(block.Length, data.Length);
+                Array.Copy(data, 0, block, block.Length - copyLen, copyLen);
+                Array.Clear(data, 0, data.Length);
             }
 
+            mgf1Hash.Reset();
+
             //
             // unmask the seed.
             //
-            byte[] mask = MaskGeneratorFunction(
-                block, defHash.Length, block.Length - defHash.Length, defHash.Length);
-
-            for (int i = 0; i != defHash.Length; i++)
-            {
-                block[i] ^= mask[i];
-            }
+            MaskGeneratorFunction(block, defHash.Length, block.Length - defHash.Length, block, 0, defHash.Length);
 
             //
             // unmask the message block.
             //
-            mask = MaskGeneratorFunction(block, 0, defHash.Length, block.Length - defHash.Length);
-
-            for (int i = defHash.Length; i != block.Length; i++)
-            {
-                block[i] ^= mask[i - defHash.Length];
-            }
+            MaskGeneratorFunction(block, 0, defHash.Length, block, defHash.Length, block.Length - defHash.Length);
 
             //
             // check the hash of the encoding params.
@@ -268,7 +227,7 @@ namespace Org.BouncyCastle.Crypto.Encodings
 
             if (wrongMask != 0)
             {
-                Arrays.Fill(block, 0);
+                Array.Clear(block, 0, block.Length);
                 throw new InvalidCipherTextException("data wrong");
             }
 
@@ -285,61 +244,106 @@ namespace Org.BouncyCastle.Crypto.Encodings
             return output;
         }
 
-        private byte[] MaskGeneratorFunction(byte[] Z, int zOff, int zLen, int length)
+        private void MaskGeneratorFunction(byte[] z, int zOff, int zLen, byte[] mask, int maskOff, int maskLen)
         {
             if (mgf1Hash is IXof xof)
             {
-                byte[] mask = new byte[length];
-                xof.BlockUpdate(Z, zOff, zLen);
-                xof.OutputFinal(mask, 0, length);
-                return mask;
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+                Span<byte> buf = maskLen <= 512
+                    ? stackalloc byte[maskLen]
+                    : new byte[maskLen];
+                xof.BlockUpdate(z, zOff, zLen);
+                xof.OutputFinal(buf);
+                Bytes.XorTo(maskLen, buf, mask.AsSpan(maskOff));
+#else
+                byte[] buf = new byte[maskLen];
+                xof.BlockUpdate(z, zOff, zLen);
+                xof.OutputFinal(buf, 0, maskLen);
+                Bytes.XorTo(maskLen, buf, 0, mask, maskOff);
+#endif
+            }
+            else
+            {
+                MaskGeneratorFunction1(z, zOff, zLen, mask, maskOff, maskLen);
             }
-
-            return MaskGeneratorFunction1(Z, zOff, zLen, length);
         }
 
         /**
         * mask generator function, as described in PKCS1v2.
         */
-        private byte[] MaskGeneratorFunction1(
-            byte[]	Z,
-            int		zOff,
-            int		zLen,
-            int		length)
+        private void MaskGeneratorFunction1(byte[] z, int zOff, int zLen, byte[] mask, int maskOff, int maskLen)
         {
-            byte[] mask = new byte[length];
-            byte[] hashBuf = new byte[mgf1Hash.GetDigestSize()];
+            int digestSize = mgf1Hash.GetDigestSize();
+
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+            Span<byte> hash = digestSize <= 128
+                ? stackalloc byte[digestSize]
+                : new byte[digestSize];
+            Span<byte> C = stackalloc byte[4];
+#else
+            byte[] hash = new byte[digestSize];
             byte[] C = new byte[4];
+#endif
             int counter = 0;
 
-            mgf1Hash.Reset();
-
-            while (counter < (length / hashBuf.Length))
-            {
-                Pack.UInt32_To_BE((uint)counter, C);
+            int maskEnd = maskOff + maskLen;
+            int maskLimit = maskEnd - digestSize;
+            int maskPos = maskOff;
 
-                mgf1Hash.BlockUpdate(Z, zOff, zLen);
-                mgf1Hash.BlockUpdate(C, 0, C.Length);
-                mgf1Hash.DoFinal(hashBuf, 0);
+            mgf1Hash.BlockUpdate(z, zOff, zLen);
 
-                Array.Copy(hashBuf, 0, mask, counter * hashBuf.Length, hashBuf.Length);
-
-                counter++;
+            if (zLen > mgf1NoMemoLimit)
+            {
+                IMemoable memoable = (IMemoable)mgf1Hash;
+                IMemoable memo = memoable.Copy();
+
+                while (maskPos < maskLimit)
+                {
+                    Pack.UInt32_To_BE((uint)counter++, C);
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+                    mgf1Hash.BlockUpdate(C);
+                    mgf1Hash.DoFinal(hash);
+                    memoable.Reset(memo);
+                    Bytes.XorTo(digestSize, hash, mask.AsSpan(maskPos));
+#else
+                    mgf1Hash.BlockUpdate(C, 0, C.Length);
+                    mgf1Hash.DoFinal(hash, 0);
+                    memoable.Reset(memo);
+                    Bytes.XorTo(digestSize, hash, 0, mask, maskPos);
+#endif
+                    maskPos += digestSize;
+                }
             }
-
-            if ((counter * hashBuf.Length) < length)
+            else
             {
-                Pack.UInt32_To_BE((uint)counter, C);
-
-                mgf1Hash.BlockUpdate(Z, zOff, zLen);
-                mgf1Hash.BlockUpdate(C, 0, C.Length);
-                mgf1Hash.DoFinal(hashBuf, 0);
-
-                Array.Copy(hashBuf, 0, mask, counter * hashBuf.Length, mask.Length - (counter * hashBuf.Length));
+                while (maskPos < maskLimit)
+                {
+                    Pack.UInt32_To_BE((uint)counter++, C);
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+                    mgf1Hash.BlockUpdate(C);
+                    mgf1Hash.DoFinal(hash);
+                    mgf1Hash.BlockUpdate(z, zOff, zLen);
+                    Bytes.XorTo(digestSize, hash, mask.AsSpan(maskPos));
+#else
+                    mgf1Hash.BlockUpdate(C, 0, C.Length);
+                    mgf1Hash.DoFinal(hash, 0);
+                    mgf1Hash.BlockUpdate(z, zOff, zLen);
+                    Bytes.XorTo(digestSize, hash, 0, mask, maskPos);
+#endif
+                    maskPos += digestSize;
+                }
             }
 
-            return mask;
+            Pack.UInt32_To_BE((uint)counter, C);
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+            mgf1Hash.BlockUpdate(C);
+            mgf1Hash.DoFinal(hash);
+            Bytes.XorTo(maskEnd - maskPos, hash, mask.AsSpan(maskPos));
+#else
+            mgf1Hash.BlockUpdate(C, 0, C.Length);
+            mgf1Hash.DoFinal(hash, 0);
+            Bytes.XorTo(maskEnd - maskPos, hash, 0, mask, maskPos);
+#endif
         }
     }
 }
-