diff --git a/crypto/src/crypto/encodings/OaepEncoding.cs b/crypto/src/crypto/encodings/OaepEncoding.cs
index 9ddaec779..45bb61de9 100644
--- a/crypto/src/crypto/encodings/OaepEncoding.cs
+++ b/crypto/src/crypto/encodings/OaepEncoding.cs
@@ -14,42 +14,42 @@ namespace Org.BouncyCastle.Crypto.Encodings
public class OaepEncoding
: IAsymmetricBlockCipher
{
- private byte[] defHash;
- private IDigest mgf1Hash;
+ private static int GetMgf1NoMemoLimit(IDigest d)
+ {
+ if (d is IMemoable)
+ return d.GetByteLength() - 1;
+
+ return int.MaxValue;
+ }
+
+ private readonly IAsymmetricBlockCipher engine;
+ private readonly IDigest mgf1Hash;
+ private readonly int mgf1NoMemoLimit;
+ private readonly byte[] defHash;
- private IAsymmetricBlockCipher engine;
private SecureRandom random;
private bool forEncryption;
- public OaepEncoding(
- IAsymmetricBlockCipher cipher)
+ public OaepEncoding(IAsymmetricBlockCipher cipher)
: this(cipher, new Sha1Digest(), null)
{
}
- public OaepEncoding(
- IAsymmetricBlockCipher cipher,
- IDigest hash)
+ public OaepEncoding(IAsymmetricBlockCipher cipher, IDigest hash)
: this(cipher, hash, null)
{
}
- public OaepEncoding(
- IAsymmetricBlockCipher cipher,
- IDigest hash,
- byte[] encodingParams)
+ public OaepEncoding(IAsymmetricBlockCipher cipher, IDigest hash, byte[] encodingParams)
: this(cipher, hash, hash, encodingParams)
{
}
- public OaepEncoding(
- IAsymmetricBlockCipher cipher,
- IDigest hash,
- IDigest mgf1Hash,
- byte[] encodingParams)
+ public OaepEncoding(IAsymmetricBlockCipher cipher, IDigest hash, IDigest mgf1Hash, byte[] encodingParams)
{
this.engine = cipher;
this.mgf1Hash = mgf1Hash;
+ this.mgf1NoMemoLimit = GetMgf1NoMemoLimit(mgf1Hash);
this.defHash = new byte[hash.GetDigestSize()];
hash.Reset();
@@ -68,18 +68,16 @@ namespace Org.BouncyCastle.Crypto.Encodings
public void Init(bool forEncryption, ICipherParameters parameters)
{
+ SecureRandom initRandom = null;
if (parameters is ParametersWithRandom withRandom)
{
- this.random = withRandom.Random;
+ initRandom = withRandom.Random;
}
- else
- {
- this.random = forEncryption ? CryptoServicesRegistrar.GetSecureRandom() : null;
- }
-
- engine.Init(forEncryption, parameters);
+ this.random = forEncryption ? CryptoServicesRegistrar.GetSecureRandom(initRandom) : null;
this.forEncryption = forEncryption;
+
+ engine.Init(forEncryption, parameters);
}
public int GetInputBlockSize()
@@ -110,29 +108,19 @@ namespace Org.BouncyCastle.Crypto.Encodings
}
}
- public byte[] ProcessBlock(
- byte[] inBytes,
- int inOff,
- int inLen)
+ public byte[] ProcessBlock(byte[] inBytes, int inOff, int inLen)
{
- if (forEncryption)
- {
- return EncodeBlock(inBytes, inOff, inLen);
- }
- else
- {
- return DecodeBlock(inBytes, inOff, inLen);
- }
+ return forEncryption
+ ? EncodeBlock(inBytes, inOff, inLen)
+ : DecodeBlock(inBytes, inOff, inLen);
}
- private byte[] EncodeBlock(
- byte[] inBytes,
- int inOff,
- int inLen)
+ private byte[] EncodeBlock(byte[] inBytes, int inOff, int inLen)
{
- Check.DataLength(inLen > GetInputBlockSize(), "input data too long");
+ int inputBlockSize = GetInputBlockSize();
+ Check.DataLength(inLen > inputBlockSize, "input data too long");
- byte[] block = new byte[GetInputBlockSize() + 1 + 2 * defHash.Length];
+ byte[] block = new byte[inputBlockSize + 1 + 2 * defHash.Length];
//
// copy in the message
@@ -156,33 +144,19 @@ namespace Org.BouncyCastle.Crypto.Encodings
//
// generate the seed.
//
- byte[] seed = SecureRandom.GetNextBytes(random, defHash.Length);
-
- //
- // mask the message block.
- //
- byte[] mask = MaskGeneratorFunction(seed, 0, seed.Length, block.Length - defHash.Length);
+ random.NextBytes(block, 0, defHash.Length);
- for (int i = defHash.Length; i != block.Length; i++)
- {
- block[i] ^= mask[i - defHash.Length];
- }
+ mgf1Hash.Reset();
//
- // add in the seed
+ // mask the message block.
//
- Array.Copy(seed, 0, block, 0, defHash.Length);
+ MaskGeneratorFunction(block, 0, defHash.Length, block, defHash.Length, block.Length - defHash.Length);
//
// mask the seed.
//
- mask = MaskGeneratorFunction(
- block, defHash.Length, block.Length - defHash.Length, defHash.Length);
-
- for (int i = 0; i != defHash.Length; i++)
- {
- block[i] ^= mask[i];
- }
+ MaskGeneratorFunction(block, defHash.Length, block.Length - defHash.Length, block, 0, defHash.Length);
return engine.ProcessBlock(block, 0, block.Length);
}
@@ -191,52 +165,37 @@ namespace Org.BouncyCastle.Crypto.Encodings
* @exception InvalidCipherTextException if the decrypted block turns out to
* be badly formatted.
*/
- private byte[] DecodeBlock(
- byte[] inBytes,
- int inOff,
- int inLen)
+ private byte[] DecodeBlock(byte[] inBytes, int inOff, int inLen)
{
- byte[] data = engine.ProcessBlock(inBytes, inOff, inLen);
- byte[] block = new byte[engine.GetOutputBlockSize()];
+ // i.e. wrong when block.length < (2 * defHash.length) + 1
+ int wrongMask = GetOutputBlockSize() >> 31;
//
// as we may have zeros in our leading bytes for the block we produced
// on encryption, we need to make sure our decrypted block comes back
// the same size.
//
- // i.e. wrong when block.length < (2 * defHash.length) + 1
- int wrongMask = (block.Length - ((2 * defHash.Length) + 1)) >> 31;
-
- if (data.Length <= block.Length)
- {
- Array.Copy(data, 0, block, block.Length - data.Length, data.Length);
- }
- else
+ byte[] block = new byte[engine.GetOutputBlockSize()];
{
- Array.Copy(data, 0, block, 0, block.Length);
- wrongMask |= 1;
+ byte[] data = engine.ProcessBlock(inBytes, inOff, inLen);
+ wrongMask |= (block.Length - data.Length) >> 31;
+
+ int copyLen = System.Math.Min(block.Length, data.Length);
+ Array.Copy(data, 0, block, block.Length - copyLen, copyLen);
+ Array.Clear(data, 0, data.Length);
}
+ mgf1Hash.Reset();
+
//
// unmask the seed.
//
- byte[] mask = MaskGeneratorFunction(
- block, defHash.Length, block.Length - defHash.Length, defHash.Length);
-
- for (int i = 0; i != defHash.Length; i++)
- {
- block[i] ^= mask[i];
- }
+ MaskGeneratorFunction(block, defHash.Length, block.Length - defHash.Length, block, 0, defHash.Length);
//
// unmask the message block.
//
- mask = MaskGeneratorFunction(block, 0, defHash.Length, block.Length - defHash.Length);
-
- for (int i = defHash.Length; i != block.Length; i++)
- {
- block[i] ^= mask[i - defHash.Length];
- }
+ MaskGeneratorFunction(block, 0, defHash.Length, block, defHash.Length, block.Length - defHash.Length);
//
// check the hash of the encoding params.
@@ -268,7 +227,7 @@ namespace Org.BouncyCastle.Crypto.Encodings
if (wrongMask != 0)
{
- Arrays.Fill(block, 0);
+ Array.Clear(block, 0, block.Length);
throw new InvalidCipherTextException("data wrong");
}
@@ -285,61 +244,106 @@ namespace Org.BouncyCastle.Crypto.Encodings
return output;
}
- private byte[] MaskGeneratorFunction(byte[] Z, int zOff, int zLen, int length)
+ private void MaskGeneratorFunction(byte[] z, int zOff, int zLen, byte[] mask, int maskOff, int maskLen)
{
if (mgf1Hash is IXof xof)
{
- byte[] mask = new byte[length];
- xof.BlockUpdate(Z, zOff, zLen);
- xof.OutputFinal(mask, 0, length);
- return mask;
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+ Span<byte> buf = maskLen <= 512
+ ? stackalloc byte[maskLen]
+ : new byte[maskLen];
+ xof.BlockUpdate(z, zOff, zLen);
+ xof.OutputFinal(buf);
+ Bytes.XorTo(maskLen, buf, mask.AsSpan(maskOff));
+#else
+ byte[] buf = new byte[maskLen];
+ xof.BlockUpdate(z, zOff, zLen);
+ xof.OutputFinal(buf, 0, maskLen);
+ Bytes.XorTo(maskLen, buf, 0, mask, maskOff);
+#endif
+ }
+ else
+ {
+ MaskGeneratorFunction1(z, zOff, zLen, mask, maskOff, maskLen);
}
-
- return MaskGeneratorFunction1(Z, zOff, zLen, length);
}
/**
* mask generator function, as described in PKCS1v2.
*/
- private byte[] MaskGeneratorFunction1(
- byte[] Z,
- int zOff,
- int zLen,
- int length)
+ private void MaskGeneratorFunction1(byte[] z, int zOff, int zLen, byte[] mask, int maskOff, int maskLen)
{
- byte[] mask = new byte[length];
- byte[] hashBuf = new byte[mgf1Hash.GetDigestSize()];
+ int digestSize = mgf1Hash.GetDigestSize();
+
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+ Span<byte> hash = digestSize <= 128
+ ? stackalloc byte[digestSize]
+ : new byte[digestSize];
+ Span<byte> C = stackalloc byte[4];
+#else
+ byte[] hash = new byte[digestSize];
byte[] C = new byte[4];
+#endif
int counter = 0;
- mgf1Hash.Reset();
-
- while (counter < (length / hashBuf.Length))
- {
- Pack.UInt32_To_BE((uint)counter, C);
+ int maskEnd = maskOff + maskLen;
+ int maskLimit = maskEnd - digestSize;
+ int maskPos = maskOff;
- mgf1Hash.BlockUpdate(Z, zOff, zLen);
- mgf1Hash.BlockUpdate(C, 0, C.Length);
- mgf1Hash.DoFinal(hashBuf, 0);
+ mgf1Hash.BlockUpdate(z, zOff, zLen);
- Array.Copy(hashBuf, 0, mask, counter * hashBuf.Length, hashBuf.Length);
-
- counter++;
+ if (zLen > mgf1NoMemoLimit)
+ {
+ IMemoable memoable = (IMemoable)mgf1Hash;
+ IMemoable memo = memoable.Copy();
+
+ while (maskPos < maskLimit)
+ {
+ Pack.UInt32_To_BE((uint)counter++, C);
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+ mgf1Hash.BlockUpdate(C);
+ mgf1Hash.DoFinal(hash);
+ memoable.Reset(memo);
+ Bytes.XorTo(digestSize, hash, mask.AsSpan(maskPos));
+#else
+ mgf1Hash.BlockUpdate(C, 0, C.Length);
+ mgf1Hash.DoFinal(hash, 0);
+ memoable.Reset(memo);
+ Bytes.XorTo(digestSize, hash, 0, mask, maskPos);
+#endif
+ maskPos += digestSize;
+ }
}
-
- if ((counter * hashBuf.Length) < length)
+ else
{
- Pack.UInt32_To_BE((uint)counter, C);
-
- mgf1Hash.BlockUpdate(Z, zOff, zLen);
- mgf1Hash.BlockUpdate(C, 0, C.Length);
- mgf1Hash.DoFinal(hashBuf, 0);
-
- Array.Copy(hashBuf, 0, mask, counter * hashBuf.Length, mask.Length - (counter * hashBuf.Length));
+ while (maskPos < maskLimit)
+ {
+ Pack.UInt32_To_BE((uint)counter++, C);
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+ mgf1Hash.BlockUpdate(C);
+ mgf1Hash.DoFinal(hash);
+ mgf1Hash.BlockUpdate(z, zOff, zLen);
+ Bytes.XorTo(digestSize, hash, mask.AsSpan(maskPos));
+#else
+ mgf1Hash.BlockUpdate(C, 0, C.Length);
+ mgf1Hash.DoFinal(hash, 0);
+ mgf1Hash.BlockUpdate(z, zOff, zLen);
+ Bytes.XorTo(digestSize, hash, 0, mask, maskPos);
+#endif
+ maskPos += digestSize;
+ }
}
- return mask;
+ Pack.UInt32_To_BE((uint)counter, C);
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+ mgf1Hash.BlockUpdate(C);
+ mgf1Hash.DoFinal(hash);
+ Bytes.XorTo(maskEnd - maskPos, hash, mask.AsSpan(maskPos));
+#else
+ mgf1Hash.BlockUpdate(C, 0, C.Length);
+ mgf1Hash.DoFinal(hash, 0);
+ Bytes.XorTo(maskEnd - maskPos, hash, 0, mask, maskPos);
+#endif
}
}
}
-
|