diff options
-rw-r--r-- | crypto/src/crypto/encodings/OaepEncoding.cs | 258 |
1 files changed, 131 insertions, 127 deletions
diff --git a/crypto/src/crypto/encodings/OaepEncoding.cs b/crypto/src/crypto/encodings/OaepEncoding.cs index 9ddaec779..45bb61de9 100644 --- a/crypto/src/crypto/encodings/OaepEncoding.cs +++ b/crypto/src/crypto/encodings/OaepEncoding.cs @@ -14,42 +14,42 @@ namespace Org.BouncyCastle.Crypto.Encodings public class OaepEncoding : IAsymmetricBlockCipher { - private byte[] defHash; - private IDigest mgf1Hash; + private static int GetMgf1NoMemoLimit(IDigest d) + { + if (d is IMemoable) + return d.GetByteLength() - 1; + + return int.MaxValue; + } + + private readonly IAsymmetricBlockCipher engine; + private readonly IDigest mgf1Hash; + private readonly int mgf1NoMemoLimit; + private readonly byte[] defHash; - private IAsymmetricBlockCipher engine; private SecureRandom random; private bool forEncryption; - public OaepEncoding( - IAsymmetricBlockCipher cipher) + public OaepEncoding(IAsymmetricBlockCipher cipher) : this(cipher, new Sha1Digest(), null) { } - public OaepEncoding( - IAsymmetricBlockCipher cipher, - IDigest hash) + public OaepEncoding(IAsymmetricBlockCipher cipher, IDigest hash) : this(cipher, hash, null) { } - public OaepEncoding( - IAsymmetricBlockCipher cipher, - IDigest hash, - byte[] encodingParams) + public OaepEncoding(IAsymmetricBlockCipher cipher, IDigest hash, byte[] encodingParams) : this(cipher, hash, hash, encodingParams) { } - public OaepEncoding( - IAsymmetricBlockCipher cipher, - IDigest hash, - IDigest mgf1Hash, - byte[] encodingParams) + public OaepEncoding(IAsymmetricBlockCipher cipher, IDigest hash, IDigest mgf1Hash, byte[] encodingParams) { this.engine = cipher; this.mgf1Hash = mgf1Hash; + this.mgf1NoMemoLimit = GetMgf1NoMemoLimit(mgf1Hash); this.defHash = new byte[hash.GetDigestSize()]; hash.Reset(); @@ -68,18 +68,16 @@ namespace Org.BouncyCastle.Crypto.Encodings public void Init(bool forEncryption, ICipherParameters parameters) { + SecureRandom initRandom = null; if (parameters is ParametersWithRandom withRandom) { - this.random = withRandom.Random; + initRandom = withRandom.Random; } - else - { - this.random = forEncryption ? CryptoServicesRegistrar.GetSecureRandom() : null; - } - - engine.Init(forEncryption, parameters); + this.random = forEncryption ? CryptoServicesRegistrar.GetSecureRandom(initRandom) : null; this.forEncryption = forEncryption; + + engine.Init(forEncryption, parameters); } public int GetInputBlockSize() @@ -110,29 +108,19 @@ namespace Org.BouncyCastle.Crypto.Encodings } } - public byte[] ProcessBlock( - byte[] inBytes, - int inOff, - int inLen) + public byte[] ProcessBlock(byte[] inBytes, int inOff, int inLen) { - if (forEncryption) - { - return EncodeBlock(inBytes, inOff, inLen); - } - else - { - return DecodeBlock(inBytes, inOff, inLen); - } + return forEncryption + ? EncodeBlock(inBytes, inOff, inLen) + : DecodeBlock(inBytes, inOff, inLen); } - private byte[] EncodeBlock( - byte[] inBytes, - int inOff, - int inLen) + private byte[] EncodeBlock(byte[] inBytes, int inOff, int inLen) { - Check.DataLength(inLen > GetInputBlockSize(), "input data too long"); + int inputBlockSize = GetInputBlockSize(); + Check.DataLength(inLen > inputBlockSize, "input data too long"); - byte[] block = new byte[GetInputBlockSize() + 1 + 2 * defHash.Length]; + byte[] block = new byte[inputBlockSize + 1 + 2 * defHash.Length]; // // copy in the message @@ -156,33 +144,19 @@ namespace Org.BouncyCastle.Crypto.Encodings // // generate the seed. // - byte[] seed = SecureRandom.GetNextBytes(random, defHash.Length); - - // - // mask the message block. - // - byte[] mask = MaskGeneratorFunction(seed, 0, seed.Length, block.Length - defHash.Length); + random.NextBytes(block, 0, defHash.Length); - for (int i = defHash.Length; i != block.Length; i++) - { - block[i] ^= mask[i - defHash.Length]; - } + mgf1Hash.Reset(); // - // add in the seed + // mask the message block. // - Array.Copy(seed, 0, block, 0, defHash.Length); + MaskGeneratorFunction(block, 0, defHash.Length, block, defHash.Length, block.Length - defHash.Length); // // mask the seed. // - mask = MaskGeneratorFunction( - block, defHash.Length, block.Length - defHash.Length, defHash.Length); - - for (int i = 0; i != defHash.Length; i++) - { - block[i] ^= mask[i]; - } + MaskGeneratorFunction(block, defHash.Length, block.Length - defHash.Length, block, 0, defHash.Length); return engine.ProcessBlock(block, 0, block.Length); } @@ -191,52 +165,37 @@ namespace Org.BouncyCastle.Crypto.Encodings * @exception InvalidCipherTextException if the decrypted block turns out to * be badly formatted. */ - private byte[] DecodeBlock( - byte[] inBytes, - int inOff, - int inLen) + private byte[] DecodeBlock(byte[] inBytes, int inOff, int inLen) { - byte[] data = engine.ProcessBlock(inBytes, inOff, inLen); - byte[] block = new byte[engine.GetOutputBlockSize()]; + // i.e. wrong when block.length < (2 * defHash.length) + 1 + int wrongMask = GetOutputBlockSize() >> 31; // // as we may have zeros in our leading bytes for the block we produced // on encryption, we need to make sure our decrypted block comes back // the same size. // - // i.e. wrong when block.length < (2 * defHash.length) + 1 - int wrongMask = (block.Length - ((2 * defHash.Length) + 1)) >> 31; - - if (data.Length <= block.Length) - { - Array.Copy(data, 0, block, block.Length - data.Length, data.Length); - } - else + byte[] block = new byte[engine.GetOutputBlockSize()]; { - Array.Copy(data, 0, block, 0, block.Length); - wrongMask |= 1; + byte[] data = engine.ProcessBlock(inBytes, inOff, inLen); + wrongMask |= (block.Length - data.Length) >> 31; + + int copyLen = System.Math.Min(block.Length, data.Length); + Array.Copy(data, 0, block, block.Length - copyLen, copyLen); + Array.Clear(data, 0, data.Length); } + mgf1Hash.Reset(); + // // unmask the seed. // - byte[] mask = MaskGeneratorFunction( - block, defHash.Length, block.Length - defHash.Length, defHash.Length); - - for (int i = 0; i != defHash.Length; i++) - { - block[i] ^= mask[i]; - } + MaskGeneratorFunction(block, defHash.Length, block.Length - defHash.Length, block, 0, defHash.Length); // // unmask the message block. // - mask = MaskGeneratorFunction(block, 0, defHash.Length, block.Length - defHash.Length); - - for (int i = defHash.Length; i != block.Length; i++) - { - block[i] ^= mask[i - defHash.Length]; - } + MaskGeneratorFunction(block, 0, defHash.Length, block, defHash.Length, block.Length - defHash.Length); // // check the hash of the encoding params. @@ -268,7 +227,7 @@ namespace Org.BouncyCastle.Crypto.Encodings if (wrongMask != 0) { - Arrays.Fill(block, 0); + Array.Clear(block, 0, block.Length); throw new InvalidCipherTextException("data wrong"); } @@ -285,61 +244,106 @@ namespace Org.BouncyCastle.Crypto.Encodings return output; } - private byte[] MaskGeneratorFunction(byte[] Z, int zOff, int zLen, int length) + private void MaskGeneratorFunction(byte[] z, int zOff, int zLen, byte[] mask, int maskOff, int maskLen) { if (mgf1Hash is IXof xof) { - byte[] mask = new byte[length]; - xof.BlockUpdate(Z, zOff, zLen); - xof.OutputFinal(mask, 0, length); - return mask; +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + Span<byte> buf = maskLen <= 512 + ? stackalloc byte[maskLen] + : new byte[maskLen]; + xof.BlockUpdate(z, zOff, zLen); + xof.OutputFinal(buf); + Bytes.XorTo(maskLen, buf, mask.AsSpan(maskOff)); +#else + byte[] buf = new byte[maskLen]; + xof.BlockUpdate(z, zOff, zLen); + xof.OutputFinal(buf, 0, maskLen); + Bytes.XorTo(maskLen, buf, 0, mask, maskOff); +#endif + } + else + { + MaskGeneratorFunction1(z, zOff, zLen, mask, maskOff, maskLen); } - - return MaskGeneratorFunction1(Z, zOff, zLen, length); } /** * mask generator function, as described in PKCS1v2. */ - private byte[] MaskGeneratorFunction1( - byte[] Z, - int zOff, - int zLen, - int length) + private void MaskGeneratorFunction1(byte[] z, int zOff, int zLen, byte[] mask, int maskOff, int maskLen) { - byte[] mask = new byte[length]; - byte[] hashBuf = new byte[mgf1Hash.GetDigestSize()]; + int digestSize = mgf1Hash.GetDigestSize(); + +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + Span<byte> hash = digestSize <= 128 + ? stackalloc byte[digestSize] + : new byte[digestSize]; + Span<byte> C = stackalloc byte[4]; +#else + byte[] hash = new byte[digestSize]; byte[] C = new byte[4]; +#endif int counter = 0; - mgf1Hash.Reset(); - - while (counter < (length / hashBuf.Length)) - { - Pack.UInt32_To_BE((uint)counter, C); + int maskEnd = maskOff + maskLen; + int maskLimit = maskEnd - digestSize; + int maskPos = maskOff; - mgf1Hash.BlockUpdate(Z, zOff, zLen); - mgf1Hash.BlockUpdate(C, 0, C.Length); - mgf1Hash.DoFinal(hashBuf, 0); + mgf1Hash.BlockUpdate(z, zOff, zLen); - Array.Copy(hashBuf, 0, mask, counter * hashBuf.Length, hashBuf.Length); - - counter++; + if (zLen > mgf1NoMemoLimit) + { + IMemoable memoable = (IMemoable)mgf1Hash; + IMemoable memo = memoable.Copy(); + + while (maskPos < maskLimit) + { + Pack.UInt32_To_BE((uint)counter++, C); +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + mgf1Hash.BlockUpdate(C); + mgf1Hash.DoFinal(hash); + memoable.Reset(memo); + Bytes.XorTo(digestSize, hash, mask.AsSpan(maskPos)); +#else + mgf1Hash.BlockUpdate(C, 0, C.Length); + mgf1Hash.DoFinal(hash, 0); + memoable.Reset(memo); + Bytes.XorTo(digestSize, hash, 0, mask, maskPos); +#endif + maskPos += digestSize; + } } - - if ((counter * hashBuf.Length) < length) + else { - Pack.UInt32_To_BE((uint)counter, C); - - mgf1Hash.BlockUpdate(Z, zOff, zLen); - mgf1Hash.BlockUpdate(C, 0, C.Length); - mgf1Hash.DoFinal(hashBuf, 0); - - Array.Copy(hashBuf, 0, mask, counter * hashBuf.Length, mask.Length - (counter * hashBuf.Length)); + while (maskPos < maskLimit) + { + Pack.UInt32_To_BE((uint)counter++, C); +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + mgf1Hash.BlockUpdate(C); + mgf1Hash.DoFinal(hash); + mgf1Hash.BlockUpdate(z, zOff, zLen); + Bytes.XorTo(digestSize, hash, mask.AsSpan(maskPos)); +#else + mgf1Hash.BlockUpdate(C, 0, C.Length); + mgf1Hash.DoFinal(hash, 0); + mgf1Hash.BlockUpdate(z, zOff, zLen); + Bytes.XorTo(digestSize, hash, 0, mask, maskPos); +#endif + maskPos += digestSize; + } } - return mask; + Pack.UInt32_To_BE((uint)counter, C); +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + mgf1Hash.BlockUpdate(C); + mgf1Hash.DoFinal(hash); + Bytes.XorTo(maskEnd - maskPos, hash, mask.AsSpan(maskPos)); +#else + mgf1Hash.BlockUpdate(C, 0, C.Length); + mgf1Hash.DoFinal(hash, 0); + Bytes.XorTo(maskEnd - maskPos, hash, 0, mask, maskPos); +#endif } } } - |