diff --git a/crypto/src/crypto/tls/RecordStream.cs b/crypto/src/crypto/tls/RecordStream.cs
index d510ed94e..b1060fd6d 100644
--- a/crypto/src/crypto/tls/RecordStream.cs
+++ b/crypto/src/crypto/tls/RecordStream.cs
@@ -121,6 +121,40 @@ namespace Org.BouncyCastle.Crypto.Tls
this.mPendingCipher = null;
}
+ internal virtual void CheckRecordHeader(byte[] recordHeader)
+ {
+ byte type = TlsUtilities.ReadUint8(recordHeader, TLS_HEADER_TYPE_OFFSET);
+
+ /*
+ * RFC 5246 6. If a TLS implementation receives an unexpected record type, it MUST send an
+ * unexpected_message alert.
+ */
+ CheckType(type, AlertDescription.unexpected_message);
+
+ if (!mRestrictReadVersion)
+ {
+ int version = TlsUtilities.ReadVersionRaw(recordHeader, TLS_HEADER_VERSION_OFFSET);
+ if ((version & 0xffffff00) != 0x0300)
+ throw new TlsFatalAlert(AlertDescription.illegal_parameter);
+ }
+ else
+ {
+ ProtocolVersion version = TlsUtilities.ReadVersion(recordHeader, TLS_HEADER_VERSION_OFFSET);
+ if (mReadVersion == null)
+ {
+ // Will be set later in 'readRecord'
+ }
+ else if (!version.Equals(mReadVersion))
+ {
+ throw new TlsFatalAlert(AlertDescription.illegal_parameter);
+ }
+ }
+
+ int length = TlsUtilities.ReadUint16(recordHeader, TLS_HEADER_LENGTH_OFFSET);
+
+ CheckLength(length, mCiphertextLimit, AlertDescription.record_overflow);
+ }
+
internal virtual bool ReadRecord()
{
byte[] recordHeader = TlsUtilities.ReadAllOrNothing(TLS_HEADER_SIZE, mInput);
@@ -155,6 +189,9 @@ namespace Org.BouncyCastle.Crypto.Tls
}
int length = TlsUtilities.ReadUint16(recordHeader, TLS_HEADER_LENGTH_OFFSET);
+
+ CheckLength(length, mCiphertextLimit, AlertDescription.record_overflow);
+
byte[] plaintext = DecodeAndVerify(type, mInput, length);
mHandler.ProcessRecord(type, plaintext, 0, plaintext.Length);
return true;
@@ -162,8 +199,6 @@ namespace Org.BouncyCastle.Crypto.Tls
internal virtual byte[] DecodeAndVerify(byte type, Stream input, int len)
{
- CheckLength(len, mCiphertextLimit, AlertDescription.record_overflow);
-
byte[] buf = TlsUtilities.ReadFully(len, input);
byte[] decoded = mReadCipher.DecodeCiphertext(mReadSeqNo++, type, buf, 0, buf.Length);
diff --git a/crypto/src/crypto/tls/TlsProtocol.cs b/crypto/src/crypto/tls/TlsProtocol.cs
index 98c6399d3..afdaf0075 100644
--- a/crypto/src/crypto/tls/TlsProtocol.cs
+++ b/crypto/src/crypto/tls/TlsProtocol.cs
@@ -482,6 +482,24 @@ namespace Org.BouncyCastle.Crypto.Tls
return len;
}
+ protected virtual void SafeCheckRecordHeader(byte[] recordHeader)
+ {
+ try
+ {
+ mRecordStream.CheckRecordHeader(recordHeader);
+ }
+ catch (TlsFatalAlert e)
+ {
+ this.FailWithError(AlertLevel.fatal, e.AlertDescription, "Failed to read record", e);
+ throw e;
+ }
+ catch (Exception e)
+ {
+ this.FailWithError(AlertLevel.fatal, AlertDescription.internal_error, "Failed to read record", e);
+ throw e;
+ }
+ }
+
protected virtual void SafeReadRecord()
{
try
@@ -660,13 +678,14 @@ namespace Org.BouncyCastle.Crypto.Tls
// loop while there are enough bytes to read the length of the next record
while (mInputBuffers.Available >= RecordStream.TLS_HEADER_SIZE)
{
- byte[] header = new byte[RecordStream.TLS_HEADER_SIZE];
- mInputBuffers.Peek(header);
+ byte[] recordHeader = new byte[RecordStream.TLS_HEADER_SIZE];
+ mInputBuffers.Peek(recordHeader);
- int totalLength = TlsUtilities.ReadUint16(header, RecordStream.TLS_HEADER_LENGTH_OFFSET) + RecordStream.TLS_HEADER_SIZE;
+ int totalLength = TlsUtilities.ReadUint16(recordHeader, RecordStream.TLS_HEADER_LENGTH_OFFSET) + RecordStream.TLS_HEADER_SIZE;
if (mInputBuffers.Available < totalLength)
{
// not enough bytes to read a whole record
+ SafeCheckRecordHeader(recordHeader);
break;
}
|