summary refs log tree commit diff
path: root/crypto/src
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src')
-rw-r--r--crypto/src/asn1/Asn1InputStream.cs11
-rw-r--r--crypto/src/asn1/Asn1RelativeOid.cs84
-rw-r--r--crypto/src/asn1/DerObjectIdentifier.cs78
-rw-r--r--crypto/src/crypto/tls/TlsRsaKeyExchange.cs224
-rw-r--r--crypto/src/math/ec/ECCurve.cs18
-rw-r--r--crypto/src/math/ec/rfc8032/Scalar25519.cs26
-rw-r--r--crypto/src/math/ec/rfc8032/Scalar448.cs26
-rw-r--r--crypto/src/math/ec/rfc8032/ScalarUtilities.cs216
-rw-r--r--crypto/src/tls/crypto/impl/bc/BcDefaultTlsCredentialedDecryptor.cs93
-rw-r--r--crypto/src/util/io/LimitedBuffer.cs1
10 files changed, 569 insertions, 208 deletions
diff --git a/crypto/src/asn1/Asn1InputStream.cs b/crypto/src/asn1/Asn1InputStream.cs
index 96b0a1c66..3b5eaaa95 100644
--- a/crypto/src/asn1/Asn1InputStream.cs
+++ b/crypto/src/asn1/Asn1InputStream.cs
@@ -377,7 +377,9 @@ namespace Org.BouncyCastle.Asn1
             switch (tagNo)
             {
             case Asn1Tags.BmpString:
+            {
                 return CreateDerBmpString(defIn);
+            }
             case Asn1Tags.Boolean:
             {
                 GetBuffer(defIn, tmpBuffers, out var contents);
@@ -390,9 +392,16 @@ namespace Org.BouncyCastle.Asn1
             }
             case Asn1Tags.ObjectIdentifier:
             {
+                DerObjectIdentifier.CheckContentsLength(defIn.Remaining);
                 bool usedBuffer = GetBuffer(defIn, tmpBuffers, out var contents);
                 return DerObjectIdentifier.CreatePrimitive(contents, clone: usedBuffer);
             }
+            case Asn1Tags.RelativeOid:
+            {
+                Asn1RelativeOid.CheckContentsLength(defIn.Remaining);
+                bool usedBuffer = GetBuffer(defIn, tmpBuffers, out var contents);
+                return Asn1RelativeOid.CreatePrimitive(contents, clone: usedBuffer);
+            }
             }
 
             byte[] bytes = defIn.ToArray();
@@ -421,8 +430,6 @@ namespace Org.BouncyCastle.Asn1
                 return Asn1OctetString.CreatePrimitive(bytes);
             case Asn1Tags.PrintableString:
                 return DerPrintableString.CreatePrimitive(bytes);
-            case Asn1Tags.RelativeOid:
-                return Asn1RelativeOid.CreatePrimitive(bytes, false);
             case Asn1Tags.T61String:
                 return DerT61String.CreatePrimitive(bytes);
             case Asn1Tags.UniversalString:
diff --git a/crypto/src/asn1/Asn1RelativeOid.cs b/crypto/src/asn1/Asn1RelativeOid.cs
index f43a85479..1fbf83d5c 100644
--- a/crypto/src/asn1/Asn1RelativeOid.cs
+++ b/crypto/src/asn1/Asn1RelativeOid.cs
@@ -22,6 +22,16 @@ namespace Org.BouncyCastle.Asn1
             }
         }
 
+        /// <summary>Implementation limit on the length of the contents octets for a Relative OID.</summary>
+        /// <remarks>
+        /// We adopt the same value used by OpenJDK for Object Identifier. In theory there is no limit on the length of
+        /// the contents, or the number of subidentifiers, or the length of individual subidentifiers. In practice,
+        /// supporting arbitrary lengths can lead to issues, e.g. denial-of-service attacks when attempting to convert a
+        /// parsed value to its (decimal) string form.
+        /// </remarks>
+        private const int MaxContentsLength = 4096;
+        private const int MaxIdentifierLength = MaxContentsLength * 4 - 1;
+
         public static Asn1RelativeOid FromContents(byte[] contents)
         {
             if (contents == null)
@@ -68,14 +78,18 @@ namespace Org.BouncyCastle.Asn1
         {
             if (identifier == null)
                 throw new ArgumentNullException(nameof(identifier));
-            if (!IsValidIdentifier(identifier, 0))
+            if (identifier.Length <= MaxIdentifierLength && IsValidIdentifier(identifier, from: 0))
             {
-                oid = default;
-                return false;
+                byte[] contents = ParseIdentifier(identifier);
+                if (contents.Length <= MaxContentsLength)
+                {
+                    oid = new Asn1RelativeOid(contents, identifier);
+                    return true;
+                }
             }
 
-            oid = new Asn1RelativeOid(ParseIdentifier(identifier), identifier);
-            return true;
+            oid = default;
+            return false;
         }
 
         private const long LongLimit = (long.MaxValue >> 7) - 0x7F;
@@ -85,31 +99,13 @@ namespace Org.BouncyCastle.Asn1
 
         public Asn1RelativeOid(string identifier)
         {
-            if (identifier == null)
-                throw new ArgumentNullException("identifier");
-            if (!IsValidIdentifier(identifier, 0))
-                throw new FormatException("string " + identifier + " not a relative OID");
-
-            m_contents = ParseIdentifier(identifier);
-            m_identifier = identifier;
-        }
-
-        private Asn1RelativeOid(Asn1RelativeOid oid, string branchID)
-        {
-            if (!IsValidIdentifier(branchID, 0))
-                throw new FormatException("string " + branchID + " not a valid relative OID branch");
-
-            m_contents = Arrays.Concatenate(oid.m_contents, ParseIdentifier(branchID));
-            m_identifier = oid.GetID() + "." + branchID;
-        }
+            CheckIdentifier(identifier);
 
-        private Asn1RelativeOid(byte[] contents, bool clone)
-        {
-            if (!IsValidContents(contents))
-                throw new ArgumentException("invalid relative OID contents", nameof(contents));
+            byte[] contents = ParseIdentifier(identifier);
+            CheckContentsLength(contents.Length);
 
-            m_contents = clone ? Arrays.Clone(contents) : contents;
-            m_identifier = null;
+            m_contents = contents;
+            m_identifier = identifier;
         }
 
         private Asn1RelativeOid(byte[] contents, string identifier)
@@ -120,7 +116,14 @@ namespace Org.BouncyCastle.Asn1
 
         public virtual Asn1RelativeOid Branch(string branchID)
         {
-            return new Asn1RelativeOid(this, branchID);
+            CheckIdentifier(branchID);
+
+            byte[] branchContents = ParseIdentifier(branchID);
+            CheckContentsLength(m_contents.Length + branchContents.Length);
+
+            return new Asn1RelativeOid(
+                contents: Arrays.Concatenate(m_contents, branchContents),
+                identifier: GetID() + "." + branchID);
         }
 
         public string GetID()
@@ -165,9 +168,30 @@ namespace Org.BouncyCastle.Asn1
             return new PrimitiveDerEncoding(tagClass, tagNo, m_contents);
         }
 
+        internal static void CheckContentsLength(int contentsLength)
+        {
+            if (contentsLength > MaxContentsLength)
+                throw new ArgumentException("exceeded relative OID contents length limit");
+        }
+
+        internal static void CheckIdentifier(string identifier)
+        {
+            if (identifier == null)
+                throw new ArgumentNullException(nameof(identifier));
+            if (identifier.Length > MaxIdentifierLength)
+                throw new ArgumentException("exceeded relative OID contents length limit");
+            if (!IsValidIdentifier(identifier, from: 0))
+                throw new FormatException("string " + identifier + " not a valid relative OID");
+        }
+
         internal static Asn1RelativeOid CreatePrimitive(byte[] contents, bool clone)
         {
-            return new Asn1RelativeOid(contents, clone);
+            CheckContentsLength(contents.Length);
+
+            if (!IsValidContents(contents))
+                throw new ArgumentException("invalid relative OID contents", nameof(contents));
+
+            return new Asn1RelativeOid(clone ? Arrays.Clone(contents) : contents, identifier: null);
         }
 
         internal static bool IsValidContents(byte[] contents)
diff --git a/crypto/src/asn1/DerObjectIdentifier.cs b/crypto/src/asn1/DerObjectIdentifier.cs
index 04792cbdd..7e1d5c2ff 100644
--- a/crypto/src/asn1/DerObjectIdentifier.cs
+++ b/crypto/src/asn1/DerObjectIdentifier.cs
@@ -23,6 +23,16 @@ namespace Org.BouncyCastle.Asn1
             }
         }
 
+        /// <summary>Implementation limit on the length of the contents octets for an Object Identifier.</summary>
+        /// <remarks>
+        /// We adopt the same value used by OpenJDK. In theory there is no limit on the length of the contents, or the
+        /// number of subidentifiers, or the length of individual subidentifiers. In practice, supporting arbitrary
+        /// lengths can lead to issues, e.g. denial-of-service attacks when attempting to convert a parsed value to its
+        /// (decimal) string form.
+        /// </remarks>
+        private const int MaxContentsLength = 4096;
+        private const int MaxIdentifierLength = MaxContentsLength * 4 + 1;
+
         public static DerObjectIdentifier FromContents(byte[] contents)
         {
             if (contents == null)
@@ -86,14 +96,18 @@ namespace Org.BouncyCastle.Asn1
         {
             if (identifier == null)
                 throw new ArgumentNullException(nameof(identifier));
-            if (!IsValidIdentifier(identifier))
+            if (identifier.Length <= MaxIdentifierLength && IsValidIdentifier(identifier))
             {
-                oid = default;
-                return false;
+                byte[] contents = ParseIdentifier(identifier);
+                if (contents.Length <= MaxContentsLength)
+                {
+                    oid = new DerObjectIdentifier(contents, identifier);
+                    return true;
+                }
             }
 
-            oid = new DerObjectIdentifier(ParseIdentifier(identifier), identifier);
-            return true;
+            oid = default;
+            return false;
         }
 
         private const long LongLimit = (long.MaxValue >> 7) - 0x7F;
@@ -105,22 +119,13 @@ namespace Org.BouncyCastle.Asn1
 
         public DerObjectIdentifier(string identifier)
         {
-            if (identifier == null)
-                throw new ArgumentNullException("identifier");
-            if (!IsValidIdentifier(identifier))
-                throw new FormatException("string " + identifier + " not an OID");
+            CheckIdentifier(identifier);
 
-            m_contents = ParseIdentifier(identifier);
-            m_identifier = identifier;
-        }
+            byte[] contents = ParseIdentifier(identifier);
+            CheckContentsLength(contents.Length);
 
-        private DerObjectIdentifier(byte[] contents, bool clone)
-        {
-            if (!Asn1RelativeOid.IsValidContents(contents))
-                throw new ArgumentException("invalid OID contents", nameof(contents));
-
-            m_contents = clone ? Arrays.Clone(contents) : contents;
-            m_identifier = null;
+            m_contents = contents;
+            m_identifier = identifier;
         }
 
         private DerObjectIdentifier(byte[] contents, string identifier)
@@ -131,11 +136,13 @@ namespace Org.BouncyCastle.Asn1
 
         public virtual DerObjectIdentifier Branch(string branchID)
         {
-            if (!Asn1RelativeOid.IsValidIdentifier(branchID, 0))
-                throw new FormatException("string " + branchID + " not a valid OID branch");
+            Asn1RelativeOid.CheckIdentifier(branchID);
+
+            byte[] branchContents = Asn1RelativeOid.ParseIdentifier(branchID);
+            CheckContentsLength(m_contents.Length + branchContents.Length);
 
             return new DerObjectIdentifier(
-                contents: Arrays.Concatenate(m_contents, Asn1RelativeOid.ParseIdentifier(branchID)),
+                contents: Arrays.Concatenate(m_contents, branchContents),
                 identifier: GetID() + "." + branchID);
         }
 
@@ -195,9 +202,27 @@ namespace Org.BouncyCastle.Asn1
             return new PrimitiveDerEncoding(tagClass, tagNo, m_contents);
         }
 
+        internal static void CheckContentsLength(int contentsLength)
+        {
+            if (contentsLength > MaxContentsLength)
+                throw new ArgumentException("exceeded OID contents length limit");
+        }
+
+        internal static void CheckIdentifier(string identifier)
+        {
+            if (identifier == null)
+                throw new ArgumentNullException(nameof(identifier));
+            if (identifier.Length > MaxIdentifierLength)
+                throw new ArgumentException("exceeded OID contents length limit");
+            if (!IsValidIdentifier(identifier))
+                throw new FormatException("string " + identifier + " not a valid OID");
+        }
+
         internal static DerObjectIdentifier CreatePrimitive(byte[] contents, bool clone)
         {
-            int index = Arrays.GetHashCode(contents);
+            CheckContentsLength(contents.Length);
+
+            uint index = (uint)Arrays.GetHashCode(contents);
 
             index ^= index >> 20;
             index ^= index >> 10;
@@ -207,7 +232,10 @@ namespace Org.BouncyCastle.Asn1
             if (originalEntry != null && Arrays.AreEqual(contents, originalEntry.m_contents))
                 return originalEntry;
 
-            var newEntry = new DerObjectIdentifier(contents, clone);
+            if (!Asn1RelativeOid.IsValidContents(contents))
+                throw new ArgumentException("invalid OID contents", nameof(contents));
+
+            var newEntry = new DerObjectIdentifier(clone ? Arrays.Clone(contents) : contents, identifier: null);
 
             var exchangedEntry = Interlocked.CompareExchange(ref Cache[index], newEntry, originalEntry);
             if (exchangedEntry != originalEntry)
@@ -228,7 +256,7 @@ namespace Org.BouncyCastle.Asn1
             if (first < '0' || first > '2')
                 return false;
 
-            if (!Asn1RelativeOid.IsValidIdentifier(identifier, 2))
+            if (!Asn1RelativeOid.IsValidIdentifier(identifier, from: 2))
                 return false;
 
             if (first == '2')
diff --git a/crypto/src/crypto/tls/TlsRsaKeyExchange.cs b/crypto/src/crypto/tls/TlsRsaKeyExchange.cs
new file mode 100644
index 000000000..20c2360ea
--- /dev/null
+++ b/crypto/src/crypto/tls/TlsRsaKeyExchange.cs
@@ -0,0 +1,224 @@
+using System;
+using System.Diagnostics;
+
+using Org.BouncyCastle.Crypto.Parameters;
+using Org.BouncyCastle.Crypto.Utilities;
+using Org.BouncyCastle.Math;
+using Org.BouncyCastle.Security;
+using Org.BouncyCastle.Utilities;
+
+namespace Org.BouncyCastle.Crypto.Tls
+{
+    public static class TlsRsaKeyExchange
+    {
+        public const int PreMasterSecretLength = 48;
+
+        public static byte[] DecryptPreMasterSecret(byte[] buf, int off, int len, RsaKeyParameters privateKey,
+            int protocolVersion, SecureRandom secureRandom)
+        {
+            if (buf == null || len < 1 || len > GetInputLimit(privateKey) || off < 0 || off > buf.Length - len)
+                throw new ArgumentException("input not a valid EncryptedPreMasterSecret");
+
+            if (!privateKey.IsPrivate)
+                throw new ArgumentException("must be an RSA private key", nameof(privateKey));
+
+            BigInteger modulus = privateKey.Modulus;
+            int bitLength = modulus.BitLength;
+            if (bitLength < 512)
+                throw new ArgumentException("must be at least 512 bits", nameof(privateKey));
+
+            if ((protocolVersion & 0xFFFF) != protocolVersion)
+                throw new ArgumentException("must be a 16 bit value", nameof(protocolVersion));
+
+            secureRandom = CryptoServicesRegistrar.GetSecureRandom(secureRandom);
+
+            /*
+             * Generate random bytes we can use as a Pre-Master-Secret if the decrypted value is invalid.
+             */
+            byte[] result = new byte[PreMasterSecretLength];
+            secureRandom.NextBytes(result);
+
+            try
+            {
+                BigInteger input = ConvertInput(modulus, buf, off, len);
+                byte[] encoding = RsaBlinded(privateKey, input, secureRandom);
+
+                int pkcs1Length = (bitLength - 1) / 8;
+                int plainTextOffset = encoding.Length - PreMasterSecretLength;
+
+                int badEncodingMask = CheckPkcs1Encoding2(encoding, pkcs1Length, PreMasterSecretLength);
+                int badVersionMask = -(Pack.BE_To_UInt16(encoding, plainTextOffset) ^ protocolVersion) >> 31;
+                int fallbackMask = badEncodingMask | badVersionMask;
+
+                for (int i = 0; i < PreMasterSecretLength; ++i)
+                {
+                    result[i] = (byte)((result[i] & fallbackMask) | (encoding[plainTextOffset + i] & ~fallbackMask));
+                }
+
+                Arrays.Fill(encoding, 0x00);
+            }
+            catch (Exception)
+            {
+                /*
+                 * Decryption should never throw an exception; return a random value instead.
+                 *
+                 * In any case, a TLS server MUST NOT generate an alert if processing an RSA-encrypted premaster
+                 * secret message fails, or the version number is not as expected. Instead, it MUST continue the
+                 * handshake with a randomly generated premaster secret.
+                 */
+            }
+
+            return result;
+        }
+
+        public static int GetInputLimit(RsaKeyParameters privateKey)
+        {
+            return (privateKey.Modulus.BitLength + 7) / 8;
+        }
+
+        private static int CAddTo(int len, int cond, byte[] x, byte[] z)
+        {
+            Debug.Assert(cond == 0 || cond == -1);
+
+            int c = 0;
+            for (int i = len - 1; i >= 0; --i)
+            {
+                c += z[i] + (x[i] & cond);
+                z[i] = (byte)c;
+                c >>= 8;
+            }
+            return c;
+        }
+
+        /**
+         * Check the argument is a valid encoding with type 2 of a plaintext with the given length. Returns 0 if
+         * valid, or -1 if invalid.
+         */
+        private static int CheckPkcs1Encoding2(byte[] buf, int pkcs1Length, int plaintextLength)
+        {
+            // The header should be at least 10 bytes
+            int errorSign = pkcs1Length - plaintextLength - 10;
+
+            int firstPadPos = buf.Length - pkcs1Length;
+            int lastPadPos = buf.Length - 1 - plaintextLength;
+
+            // Any leading bytes should be zero
+            for (int i = 0; i < firstPadPos; ++i)
+            {
+                errorSign |= -buf[i];
+            }
+
+            // The first byte should be 0x02
+            errorSign |= -(buf[firstPadPos] ^ 0x02);
+
+            // All pad bytes before the last one should be non-zero
+            for (int i = firstPadPos + 1; i < lastPadPos; ++i)
+            {
+                errorSign |= buf[i] - 1;
+            }
+
+            // Last pad byte should be zero
+            errorSign |= -buf[lastPadPos];
+
+            return errorSign >> 31;
+        }
+
+        private static BigInteger ConvertInput(BigInteger modulus, byte[] buf, int off, int len)
+        {
+            BigInteger result = BigIntegers.FromUnsignedByteArray(buf, off, len);
+            if (result.CompareTo(modulus) < 0)
+                return result;
+
+            throw new DataLengthException("input too large for RSA cipher.");
+        }
+
+        private static BigInteger Rsa(RsaKeyParameters privateKey, BigInteger input)
+        {
+            return input.ModPow(privateKey.Exponent, privateKey.Modulus);
+        }
+
+        private static byte[] RsaBlinded(RsaKeyParameters privateKey, BigInteger input, SecureRandom secureRandom)
+        {
+            BigInteger modulus = privateKey.Modulus;
+            int resultSize = (modulus.BitLength + 7) / 8;
+
+            if (!(privateKey is RsaPrivateCrtKeyParameters crtKey))
+                return BigIntegers.AsUnsignedByteArray(resultSize, Rsa(privateKey, input));
+
+            BigInteger e = crtKey.PublicExponent;
+            Debug.Assert(e != null);
+
+            BigInteger r = BigIntegers.CreateRandomInRange(BigInteger.One, modulus.Subtract(BigInteger.One),
+                secureRandom);
+            BigInteger blind = r.ModPow(e, modulus);
+            BigInteger unblind = BigIntegers.ModOddInverse(modulus, r);
+
+            BigInteger blindedInput = blind.ModMultiply(input, modulus);
+            BigInteger blindedResult = RsaCrt(crtKey, blindedInput);
+            BigInteger offsetResult = unblind.Add(BigInteger.One).ModMultiply(blindedResult, modulus);
+
+            /*
+             * BigInteger conversion time is not constant, but is only done for blinded or public values.
+             */
+            byte[] blindedResultBytes = BigIntegers.AsUnsignedByteArray(resultSize, blindedResult);
+            byte[] modulusBytes = BigIntegers.AsUnsignedByteArray(resultSize, modulus);
+            byte[] resultBytes = BigIntegers.AsUnsignedByteArray(resultSize, offsetResult);
+
+            /*
+             * A final modular subtraction is done without timing dependencies on the final result. 
+             */
+            int carry = SubFrom(resultSize, blindedResultBytes, resultBytes);
+            CAddTo(resultSize, carry, modulusBytes, resultBytes);
+
+            return resultBytes;
+        }
+
+        private static BigInteger RsaCrt(RsaPrivateCrtKeyParameters crtKey, BigInteger input)
+        {
+            //
+            // we have the extra factors, use the Chinese Remainder Theorem - the author
+            // wishes to express his thanks to Dirk Bonekaemper at rtsffm.com for
+            // advice regarding the expression of this.
+            //
+            BigInteger e = crtKey.PublicExponent;
+            Debug.Assert(e != null);
+
+            BigInteger p = crtKey.P;
+            BigInteger q = crtKey.Q;
+            BigInteger dP = crtKey.DP;
+            BigInteger dQ = crtKey.DQ;
+            BigInteger qInv = crtKey.QInv;
+
+            // mP = ((input mod p) ^ dP)) mod p
+            BigInteger mP = input.Remainder(p).ModPow(dP, p);
+
+            // mQ = ((input mod q) ^ dQ)) mod q
+            BigInteger mQ = input.Remainder(q).ModPow(dQ, q);
+
+            // h = qInv * (mP - mQ) mod p
+            BigInteger h = mP.Subtract(mQ).ModMultiply(qInv, p);
+
+            // m = h * q + mQ
+            BigInteger m = h.Multiply(q).Add(mQ);
+
+            // defence against Arjen Lenstra’s CRT attack
+            BigInteger check = m.ModPow(e, crtKey.Modulus);
+            if (!check.Equals(input))
+                throw new InvalidOperationException("RSA engine faulty decryption/signing detected");
+
+            return m;
+        }
+
+        private static int SubFrom(int len, byte[] x, byte[] z)
+        {
+            int c = 0;
+            for (int i = len - 1; i >= 0; --i)
+            {
+                c += z[i] - x[i];
+                z[i] = (byte)c;
+                c >>= 8;
+            }
+            return c;
+        }
+    }
+}
diff --git a/crypto/src/math/ec/ECCurve.cs b/crypto/src/math/ec/ECCurve.cs
index 245ca1941..ae0d5d69e 100644
--- a/crypto/src/math/ec/ECCurve.cs
+++ b/crypto/src/math/ec/ECCurve.cs
@@ -607,6 +607,13 @@ namespace Org.BouncyCastle.Math.EC
         }
 #endif
 
+        internal static int ImplGetInteger(string envVariable, int defaultValue)
+        {
+            string property = Platform.GetEnvironmentVariable(envVariable);
+
+            return int.TryParse(property, out int value) ? value : defaultValue;
+        }
+
         private class DefaultLookupTable
             : AbstractECLookupTable
         {
@@ -757,13 +764,6 @@ namespace Org.BouncyCastle.Math.EC
                 throw new ArgumentException("Fp q value not prime");
         }
 
-        private static int ImplGetInteger(string envVariable, int defaultValue)
-        {
-            string property = Platform.GetEnvironmentVariable(envVariable);
-
-            return int.TryParse(property, out int value) ? value : defaultValue;
-        }
-
         private static int ImplGetIterations(int bits, int certainty)
         {
             /*
@@ -966,6 +966,10 @@ namespace Org.BouncyCastle.Math.EC
 
         private static IFiniteField BuildField(int m, int k1, int k2, int k3)
         {
+            int maxM = ImplGetInteger("Org.BouncyCastle.EC.F2m_MaxSize", 1142); // 2 * 571
+            if (m > maxM)
+                throw new ArgumentException("F2m m value out of range");
+
             int[] exponents = (k2 | k3) == 0
                 ? new int[]{ 0, k1, m }
                 : new int[]{ 0, k1, k2, k3, m };
diff --git a/crypto/src/math/ec/rfc8032/Scalar25519.cs b/crypto/src/math/ec/rfc8032/Scalar25519.cs
index 67eee6155..08ab80607 100644
--- a/crypto/src/math/ec/rfc8032/Scalar25519.cs
+++ b/crypto/src/math/ec/rfc8032/Scalar25519.cs
@@ -595,7 +595,7 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
 #endif
 
 #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
-        internal static void ReduceBasisVar(ReadOnlySpan<uint> k, Span<uint> z0, Span<uint> z1)
+        internal static bool ReduceBasisVar(ReadOnlySpan<uint> k, Span<uint> z0, Span<uint> z1)
         {
             /*
              * Split scalar k into two half-size scalars z0 and z1, such that z1 * k == z0 mod L.
@@ -606,28 +606,34 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             Span<uint> Nu = stackalloc uint[16];    LSq.CopyTo(Nu);
             Span<uint> Nv = stackalloc uint[16];    Nat256.Square(k, Nv); ++Nv[0];
             Span<uint> p  = stackalloc uint[16];    Nat256.Mul(L, k, p);
+            Span<uint> t  = stackalloc uint[16];
             Span<uint> u0 = stackalloc uint[4];     u0.CopyFrom(L);
             Span<uint> u1 = stackalloc uint[4];
             Span<uint> v0 = stackalloc uint[4];     v0.CopyFrom(k);
             Span<uint> v1 = stackalloc uint[4];     v1[0] = 1U;
 
+            // Conservative upper bound on the number of loop iterations needed
+            int iterations = TargetLength * 4;
             int last = 15;
             int len_Nv = ScalarUtilities.GetBitLengthPositive(last, Nv);
 
             while (len_Nv > TargetLength)
             {
+                if (--iterations < 0)
+                    return false;
+
                 int len_p = ScalarUtilities.GetBitLength(last, p);
                 int s = len_p - len_Nv;
                 s &= ~(s >> 31);
 
                 if ((int)p[last] < 0)
                 {
-                    ScalarUtilities.AddShifted_NP(last, s, Nu, Nv, p);
+                    ScalarUtilities.AddShifted_NP(last, s, Nu, Nv, p, t);
                     ScalarUtilities.AddShifted_UV(last: 3, s, u0, u1, v0, v1);
                 }
                 else
                 {
-                    ScalarUtilities.SubShifted_NP(last, s, Nu, Nv, p);
+                    ScalarUtilities.SubShifted_NP(last, s, Nu, Nv, p, t);
                     ScalarUtilities.SubShifted_UV(last: 3, s, u0, u1, v0, v1);
                 }
 
@@ -645,9 +651,10 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             // v1 * k == v0 mod L
             v0.CopyTo(z0);
             v1.CopyTo(z1);
+            return true;
         }
 #else
-        internal static void ReduceBasisVar(uint[] k, uint[] z0, uint[] z1)
+        internal static bool ReduceBasisVar(uint[] k, uint[] z0, uint[] z1)
         {
             /*
              * Split scalar k into two half-size scalars z0 and z1, such that z1 * k == z0 mod L.
@@ -658,28 +665,34 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             uint[] Nu = new uint[16];       Array.Copy(LSq, Nu, 16);
             uint[] Nv = new uint[16];       Nat256.Square(k, Nv); ++Nv[0];
             uint[] p  = new uint[16];       Nat256.Mul(L, k, p);
+            uint[] t  = new uint[16];
             uint[] u0 = new uint[4];        Array.Copy(L, u0, 4);
             uint[] u1 = new uint[4];
             uint[] v0 = new uint[4];        Array.Copy(k, v0, 4);
             uint[] v1 = new uint[4];        v1[0] = 1U;
 
+            // Conservative upper bound on the number of loop iterations needed
+            int iterations = TargetLength * 4;
             int last = 15;
             int len_Nv = ScalarUtilities.GetBitLengthPositive(last, Nv);
 
             while (len_Nv > TargetLength)
             {
+                if (--iterations < 0)
+                    return false;
+
                 int len_p = ScalarUtilities.GetBitLength(last, p);
                 int s = len_p - len_Nv;
                 s &= ~(s >> 31);
 
                 if ((int)p[last] < 0)
                 {
-                    ScalarUtilities.AddShifted_NP(last, s, Nu, Nv, p);
+                    ScalarUtilities.AddShifted_NP(last, s, Nu, Nv, p, t);
                     ScalarUtilities.AddShifted_UV(last: 3, s, u0, u1, v0, v1);
                 }
                 else
                 {
-                    ScalarUtilities.SubShifted_NP(last, s, Nu, Nv, p);
+                    ScalarUtilities.SubShifted_NP(last, s, Nu, Nv, p, t);
                     ScalarUtilities.SubShifted_UV(last: 3, s, u0, u1, v0, v1);
                 }
 
@@ -697,6 +710,7 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             // v1 * k == v0 mod L
             Array.Copy(v0, z0, 4);
             Array.Copy(v1, z1, 4);
+            return true;
         }
 #endif
 
diff --git a/crypto/src/math/ec/rfc8032/Scalar448.cs b/crypto/src/math/ec/rfc8032/Scalar448.cs
index 124b91250..c3f91eef2 100644
--- a/crypto/src/math/ec/rfc8032/Scalar448.cs
+++ b/crypto/src/math/ec/rfc8032/Scalar448.cs
@@ -1114,7 +1114,7 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
 #endif
 
 #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
-        internal static void ReduceBasisVar(ReadOnlySpan<uint> k, Span<uint> z0, Span<uint> z1)
+        internal static bool ReduceBasisVar(ReadOnlySpan<uint> k, Span<uint> z0, Span<uint> z1)
         {
             /*
              * Split scalar k into two half-size scalars z0 and z1, such that z1 * k == z0 mod L.
@@ -1125,28 +1125,34 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             Span<uint> Nu = stackalloc uint[28];    LSq.CopyTo(Nu);
             Span<uint> Nv = stackalloc uint[28];    Nat448.Square(k, Nv); ++Nv[0];
             Span<uint> p  = stackalloc uint[28];    Nat448.Mul(L, k, p);
+            Span<uint> t  = stackalloc uint[28];
             Span<uint> u0 = stackalloc uint[8];     u0.CopyFrom(L);
             Span<uint> u1 = stackalloc uint[8];
             Span<uint> v0 = stackalloc uint[8];     v0.CopyFrom(k);
             Span<uint> v1 = stackalloc uint[8];     v1[0] = 1U;
 
+            // Conservative upper bound on the number of loop iterations needed
+            int iterations = TargetLength * 4;
             int last = 27;
             int len_Nv = ScalarUtilities.GetBitLengthPositive(last, Nv);
 
             while (len_Nv > TargetLength)
             {
+                if (--iterations < 0)
+                    return false;
+
                 int len_p = ScalarUtilities.GetBitLength(last, p);
                 int s = len_p - len_Nv;
                 s &= ~(s >> 31);
 
                 if ((int)p[last] < 0)
                 {
-                    ScalarUtilities.AddShifted_NP(last, s, Nu, Nv, p);
+                    ScalarUtilities.AddShifted_NP(last, s, Nu, Nv, p, t);
                     ScalarUtilities.AddShifted_UV(last: 7, s, u0, u1, v0, v1);
                 }
                 else
                 {
-                    ScalarUtilities.SubShifted_NP(last, s, Nu, Nv, p);
+                    ScalarUtilities.SubShifted_NP(last, s, Nu, Nv, p, t);
                     ScalarUtilities.SubShifted_UV(last: 7, s, u0, u1, v0, v1);
                 }
 
@@ -1167,9 +1173,10 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             // v1 * k == v0 mod L
             v0.CopyTo(z0);
             v1.CopyTo(z1);
+            return true;
         }
 #else
-        internal static void ReduceBasisVar(uint[] k, uint[] z0, uint[] z1)
+        internal static bool ReduceBasisVar(uint[] k, uint[] z0, uint[] z1)
         {
             /*
              * Split scalar k into two half-size scalars z0 and z1, such that z1 * k == z0 mod L.
@@ -1180,28 +1187,34 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             uint[] Nu = new uint[28];       Array.Copy(LSq, Nu, 28);
             uint[] Nv = new uint[28];       Nat448.Square(k, Nv); ++Nv[0];
             uint[] p  = new uint[28];       Nat448.Mul(L, k, p);
+            uint[] t  = new uint[28];
             uint[] u0 = new uint[8];        Array.Copy(L, u0, 8);
             uint[] u1 = new uint[8];
             uint[] v0 = new uint[8];        Array.Copy(k, v0, 8);
             uint[] v1 = new uint[8];        v1[0] = 1U;
 
+            // Conservative upper bound on the number of loop iterations needed
+            int iterations = TargetLength * 4;
             int last = 27;
             int len_Nv = ScalarUtilities.GetBitLengthPositive(last, Nv);
 
             while (len_Nv > TargetLength)
             {
+                if (--iterations < 0)
+                    return false;
+
                 int len_p = ScalarUtilities.GetBitLength(last, p);
                 int s = len_p - len_Nv;
                 s &= ~(s >> 31);
 
                 if ((int)p[last] < 0)
                 {
-                    ScalarUtilities.AddShifted_NP(last, s, Nu, Nv, p);
+                    ScalarUtilities.AddShifted_NP(last, s, Nu, Nv, p, t);
                     ScalarUtilities.AddShifted_UV(last: 7, s, u0, u1, v0, v1);
                 }
                 else
                 {
-                    ScalarUtilities.SubShifted_NP(last, s, Nu, Nv, p);
+                    ScalarUtilities.SubShifted_NP(last, s, Nu, Nv, p, t);
                     ScalarUtilities.SubShifted_UV(last: 7, s, u0, u1, v0, v1);
                 }
 
@@ -1222,6 +1235,7 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
             // v1 * k == v0 mod L
             Array.Copy(v0, z0, 8);
             Array.Copy(v1, z1, 8);
+            return true;
         }
 #endif
 
diff --git a/crypto/src/math/ec/rfc8032/ScalarUtilities.cs b/crypto/src/math/ec/rfc8032/ScalarUtilities.cs
index c70a4f2e8..41d7f2696 100644
--- a/crypto/src/math/ec/rfc8032/ScalarUtilities.cs
+++ b/crypto/src/math/ec/rfc8032/ScalarUtilities.cs
@@ -12,62 +12,120 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
     {
 #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        internal static void AddShifted_NP(int last, int s, Span<uint> Nu, ReadOnlySpan<uint> Nv, Span<uint> _p)
+        internal static void AddShifted_NP(int last, int s, Span<uint> Nu, ReadOnlySpan<uint> Nv, Span<uint> p, Span<uint> t)
 #else
-        internal static void AddShifted_NP(int last, int s, uint[] Nu, uint[] Nv, uint[] _p)
+        internal static void AddShifted_NP(int last, int s, uint[] Nu, uint[] Nv, uint[] p, uint[] t)
 #endif
         {
-            int sWords = s >> 5, sBits = s & 31;
-
-            ulong cc__p = 0UL;
+            ulong cc_p = 0UL;
             ulong cc_Nu = 0UL;
 
-            if (sBits == 0)
+            if (s == 0)
             {
-                for (int i = sWords; i <= last; ++i)
+                for (int i = 0; i <= last; ++i)
                 {
+                    uint p_i = p[i];
+
                     cc_Nu += Nu[i];
-                    cc_Nu += _p[i - sWords];
+                    cc_Nu += p_i;
 
-                    cc__p += _p[i];
-                    cc__p += Nv[i - sWords];
-                    _p[i]  = (uint)cc__p; cc__p >>= 32;
+                    cc_p += p_i;
+                    cc_p += Nv[i];
+                    p_i   = (uint)cc_p; cc_p >>= 32;
+                    p[i]  = p_i;
 
-                    cc_Nu += _p[i - sWords];
+                    cc_Nu += p_i;
                     Nu[i]  = (uint)cc_Nu; cc_Nu >>= 32;
                 }
             }
-            else
+            else if (s < 32)
             {
                 uint prev_p = 0U;
                 uint prev_q = 0U;
                 uint prev_v = 0U;
 
-                for (int i = sWords; i <= last; ++i)
+                for (int i = 0; i <= last; ++i)
                 {
-                    uint next_p = _p[i - sWords];
-                    uint p_s = (next_p << sBits) | (prev_p >> -sBits);
-                    prev_p = next_p;
+                    uint p_i = p[i];
+                    uint p_s = (p_i << s) | (prev_p >> -s);
+                    prev_p = p_i;
 
                     cc_Nu += Nu[i];
                     cc_Nu += p_s;
 
-                    uint next_v = Nv[i - sWords];
-                    uint v_s = (next_v << sBits) | (prev_v >> -sBits);
+                    uint next_v = Nv[i];
+                    uint v_s = (next_v << s) | (prev_v >> -s);
                     prev_v = next_v;
 
-                    cc__p += _p[i];
-                    cc__p += v_s;
-                    _p[i]  = (uint)cc__p; cc__p >>= 32;
+                    cc_p += p_i;
+                    cc_p += v_s;
+                    p_i   = (uint)cc_p; cc_p >>= 32;
+                    p[i]  = p_i;
 
-                    uint next_q = _p[i - sWords];
-                    uint q_s = (next_q << sBits) | (prev_q >> -sBits);
-                    prev_q = next_q;
+                    uint q_s = (p_i << s) | (prev_q >> -s);
+                    prev_q = p_i;
 
                     cc_Nu += q_s;
                     Nu[i]  = (uint)cc_Nu; cc_Nu >>= 32;
                 }
             }
+            else
+            {
+                // Copy the low limbs of the original p
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+                t[..last].CopyFrom(p);
+#else
+                Array.Copy(p, 0, t, 0, last);
+#endif
+
+                int sWords = s >> 5, sBits = s & 31;
+                if (sBits == 0)
+                {
+                    for (int i = sWords; i <= last; ++i)
+                    {
+                        cc_Nu += Nu[i];
+                        cc_Nu += t[i - sWords];
+
+                        cc_p += p[i];
+                        cc_p += Nv[i - sWords];
+                        p[i]  = (uint)cc_p; cc_p >>= 32;
+
+                        cc_Nu += p[i - sWords];
+                        Nu[i]  = (uint)cc_Nu; cc_Nu >>= 32;
+                    }
+                }
+                else
+                {
+                    uint prev_t = 0U;
+                    uint prev_q = 0U;
+                    uint prev_v = 0U;
+
+                    for (int i = sWords; i <= last; ++i)
+                    {
+                        uint next_t = t[i - sWords];
+                        uint t_s = (next_t << sBits) | (prev_t >> -sBits);
+                        prev_t = next_t;
+
+                        cc_Nu += Nu[i];
+                        cc_Nu += t_s;
+
+                        uint next_v = Nv[i - sWords];
+                        uint v_s = (next_v << sBits) | (prev_v >> -sBits);
+                        prev_v = next_v;
+
+                        cc_p += p[i];
+                        cc_p += v_s;
+                        p[i]  = (uint)cc_p; cc_p >>= 32;
+
+                        uint next_q = p[i - sWords];
+                        uint q_s = (next_q << sBits) | (prev_q >> -sBits);
+                        prev_q = next_q;
+
+                        cc_Nu += q_s;
+                        Nu[i]  = (uint)cc_Nu; cc_Nu >>= 32;
+                    }
+                }
+            }
         }
 
 #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
@@ -171,62 +229,120 @@ namespace Org.BouncyCastle.Math.EC.Rfc8032
 
 #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        internal static void SubShifted_NP(int last, int s, Span<uint> Nu, ReadOnlySpan<uint> Nv, Span<uint> _p)
+        internal static void SubShifted_NP(int last, int s, Span<uint> Nu, ReadOnlySpan<uint> Nv, Span<uint> p, Span<uint> t)
 #else
-        internal static void SubShifted_NP(int last, int s, uint[] Nu, uint[] Nv, uint[] _p)
+        internal static void SubShifted_NP(int last, int s, uint[] Nu, uint[] Nv, uint[] p, uint[] t)
 #endif
         {
-            int sWords = s >> 5, sBits = s & 31;
-
-            long cc__p = 0L;
+            long cc_p = 0L;
             long cc_Nu = 0L;
 
-            if (sBits == 0)
+            if (s == 0)
             {
-                for (int i = sWords; i <= last; ++i)
+                for (int i = 0; i <= last; ++i)
                 {
+                    uint p_i = p[i];
+
                     cc_Nu += Nu[i];
-                    cc_Nu -= _p[i - sWords];
+                    cc_Nu -= p_i;
 
-                    cc__p += _p[i];
-                    cc__p -= Nv[i - sWords];
-                    _p[i]  = (uint)cc__p; cc__p >>= 32;
+                    cc_p += p_i;
+                    cc_p -= Nv[i];
+                    p_i   = (uint)cc_p; cc_p >>= 32;
+                    p[i]  = p_i;
 
-                    cc_Nu -= _p[i - sWords];
+                    cc_Nu -= p_i;
                     Nu[i]  = (uint)cc_Nu; cc_Nu >>= 32;
                 }
             }
-            else
+            else if (s < 32)
             {
                 uint prev_p = 0U;
                 uint prev_q = 0U;
                 uint prev_v = 0U;
 
-                for (int i = sWords; i <= last; ++i)
+                for (int i = 0; i <= last; ++i)
                 {
-                    uint next_p = _p[i - sWords];
-                    uint p_s = (next_p << sBits) | (prev_p >> -sBits);
-                    prev_p = next_p;
+                    uint p_i = p[i];
+                    uint p_s = (p_i << s) | (prev_p >> -s);
+                    prev_p = p_i;
 
                     cc_Nu += Nu[i];
                     cc_Nu -= p_s;
 
-                    uint next_v = Nv[i - sWords];
-                    uint v_s = (next_v << sBits) | (prev_v >> -sBits);
+                    uint next_v = Nv[i];
+                    uint v_s = (next_v << s) | (prev_v >> -s);
                     prev_v = next_v;
 
-                    cc__p += _p[i];
-                    cc__p -= v_s;
-                    _p[i]  = (uint)cc__p; cc__p >>= 32;
+                    cc_p += p_i;
+                    cc_p -= v_s;
+                    p_i   = (uint)cc_p; cc_p >>= 32;
+                    p[i]  = p_i;
 
-                    uint next_q = _p[i - sWords];
-                    uint q_s = (next_q << sBits) | (prev_q >> -sBits);
-                    prev_q = next_q;
+                    uint q_s = (p_i << s) | (prev_q >> -s);
+                    prev_q = p_i;
 
                     cc_Nu -= q_s;
                     Nu[i]  = (uint)cc_Nu; cc_Nu >>= 32;
                 }
             }
+            else
+            {
+                // Copy the low limbs of the original p
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+                t[..last].CopyFrom(p);
+#else
+                Array.Copy(p, 0, t, 0, last);
+#endif
+
+                int sWords = s >> 5, sBits = s & 31;
+                if (sBits == 0)
+                {
+                    for (int i = sWords; i <= last; ++i)
+                    {
+                        cc_Nu += Nu[i];
+                        cc_Nu -= t[i - sWords];
+
+                        cc_p += p[i];
+                        cc_p -= Nv[i - sWords];
+                        p[i]  = (uint)cc_p; cc_p >>= 32;
+
+                        cc_Nu -= p[i - sWords];
+                        Nu[i]  = (uint)cc_Nu; cc_Nu >>= 32;
+                    }
+                }
+                else
+                {
+                    uint prev_t = 0U;
+                    uint prev_q = 0U;
+                    uint prev_v = 0U;
+
+                    for (int i = sWords; i <= last; ++i)
+                    {
+                        uint next_t = t[i - sWords];
+                        uint t_s = (next_t << sBits) | (prev_t >> -sBits);
+                        prev_t = next_t;
+
+                        cc_Nu += Nu[i];
+                        cc_Nu -= t_s;
+
+                        uint next_v = Nv[i - sWords];
+                        uint v_s = (next_v << sBits) | (prev_v >> -sBits);
+                        prev_v = next_v;
+
+                        cc_p += p[i];
+                        cc_p -= v_s;
+                        p[i]  = (uint)cc_p; cc_p >>= 32;
+
+                        uint next_q = p[i - sWords];
+                        uint q_s = (next_q << sBits) | (prev_q >> -sBits);
+                        prev_q = next_q;
+
+                        cc_Nu -= q_s;
+                        Nu[i]  = (uint)cc_Nu; cc_Nu >>= 32;
+                    }
+                }
+            }
         }
 
 #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
diff --git a/crypto/src/tls/crypto/impl/bc/BcDefaultTlsCredentialedDecryptor.cs b/crypto/src/tls/crypto/impl/bc/BcDefaultTlsCredentialedDecryptor.cs
index bbe9af4e6..31f339c92 100644
--- a/crypto/src/tls/crypto/impl/bc/BcDefaultTlsCredentialedDecryptor.cs
+++ b/crypto/src/tls/crypto/impl/bc/BcDefaultTlsCredentialedDecryptor.cs
@@ -1,11 +1,7 @@
 using System;
 
 using Org.BouncyCastle.Crypto;
-using Org.BouncyCastle.Crypto.Encodings;
-using Org.BouncyCastle.Crypto.Engines;
 using Org.BouncyCastle.Crypto.Parameters;
-using Org.BouncyCastle.Security;
-using Org.BouncyCastle.Utilities;
 
 namespace Org.BouncyCastle.Tls.Crypto.Impl.BC
 {
@@ -40,100 +36,33 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl.BC
                 throw new ArgumentException("'privateKey' type not supported: " + privateKey.GetType().FullName);
             }
 
-            this.m_crypto = crypto;
-            this.m_certificate = certificate;
-            this.m_privateKey = privateKey;
+            m_crypto = crypto;
+            m_certificate = certificate;
+            m_privateKey = privateKey;
         }
 
-        public virtual Certificate Certificate
-        {
-            get { return m_certificate; }
-        }
+        public virtual Certificate Certificate => m_certificate;
 
         public virtual TlsSecret Decrypt(TlsCryptoParameters cryptoParams, byte[] ciphertext)
         {
-            // TODO Keep only the decryption itself here - move error handling outside 
             return SafeDecryptPreMasterSecret(cryptoParams, (RsaKeyParameters)m_privateKey, ciphertext);
         }
 
         /*
-         * TODO[tls-ops] Probably need to make RSA encryption/decryption into TlsCrypto functions so
-         * that users can implement "generic" encryption credentials externally
+         * TODO[tls-ops] Probably need to make RSA encryption/decryption into TlsCrypto functions so that users can
+         * implement "generic" encryption credentials externally
          */
+        // TODO[api] Just inline this into Decrypt
         protected virtual TlsSecret SafeDecryptPreMasterSecret(TlsCryptoParameters cryptoParams,
             RsaKeyParameters rsaServerPrivateKey, byte[] encryptedPreMasterSecret)
         {
-            SecureRandom secureRandom = m_crypto.SecureRandom;
-
-            /*
-             * RFC 5246 7.4.7.1.
-             */
             ProtocolVersion expectedVersion = cryptoParams.RsaPreMasterSecretVersion;
 
-            // TODO Provide as configuration option?
-            bool versionNumberCheckDisabled = false;
-
-            /*
-             * Generate 48 random bytes we can use as a Pre-Master-Secret, if the
-             * PKCS1 padding check should fail.
-             */
-            byte[] fallback = new byte[48];
-            secureRandom.NextBytes(fallback);
-
-            byte[] M = Arrays.Clone(fallback);
-            try
-            {
-                Pkcs1Encoding encoding = new Pkcs1Encoding(new RsaBlindedEngine(), fallback);
-                encoding.Init(false, new ParametersWithRandom(rsaServerPrivateKey, secureRandom));
-
-                M = encoding.ProcessBlock(encryptedPreMasterSecret, 0, encryptedPreMasterSecret.Length);
-            }
-            catch (Exception)
-            {
-                /*
-                 * This should never happen since the decryption should never throw an exception
-                 * and return a random value instead.
-                 *
-                 * In any case, a TLS server MUST NOT generate an alert if processing an
-                 * RSA-encrypted premaster secret message fails, or the version number is not as
-                 * expected. Instead, it MUST continue the handshake with a randomly generated
-                 * premaster secret.
-                 */
-            }
-
-            /*
-             * If ClientHello.legacy_version is TLS 1.1 or higher, server implementations MUST check the
-             * version number [..].
-             */
-            if (versionNumberCheckDisabled && !TlsImplUtilities.IsTlsV11(expectedVersion))
-            {
-                /*
-                 * If the version number is TLS 1.0 or earlier, server implementations SHOULD check the
-                 * version number, but MAY have a configuration option to disable the check.
-                 */
-            }
-            else
-            {
-                /*
-                 * Compare the version number in the decrypted Pre-Master-Secret with the legacy_version
-                 * field from the ClientHello. If they don't match, continue the handshake with the
-                 * randomly generated 'fallback' value.
-                 *
-                 * NOTE: The comparison and replacement must be constant-time.
-                 */
-                int mask = (expectedVersion.MajorVersion ^ (M[0] & 0xFF))
-                         | (expectedVersion.MinorVersion ^ (M[1] & 0xFF));
-
-                // 'mask' will be all 1s if the versions matched, or else all 0s.
-                mask = (mask - 1) >> 31;
-
-                for (int i = 0; i < 48; i++)
-                {
-                    M[i] = (byte)((M[i] & mask) | (fallback[i] & ~mask));
-                }
-            }
+            byte[] preMasterSecret = Org.BouncyCastle.Crypto.Tls.TlsRsaKeyExchange.DecryptPreMasterSecret(
+                encryptedPreMasterSecret, 0, encryptedPreMasterSecret.Length, rsaServerPrivateKey,
+                expectedVersion.FullVersion, m_crypto.SecureRandom);
 
-            return m_crypto.CreateSecret(M);
+            return m_crypto.CreateSecret(preMasterSecret);
         }
     }
 }
diff --git a/crypto/src/util/io/LimitedBuffer.cs b/crypto/src/util/io/LimitedBuffer.cs
index 07c9969ad..c99c49c25 100644
--- a/crypto/src/util/io/LimitedBuffer.cs
+++ b/crypto/src/util/io/LimitedBuffer.cs
@@ -47,6 +47,7 @@ namespace Org.BouncyCastle.Utilities.IO
         public override void Write(ReadOnlySpan<byte> buffer)
         {
             buffer.CopyTo(m_buf.AsSpan(m_count));
+            m_count += buffer.Length;
         }
 #endif