diff options
Diffstat (limited to 'crypto')
-rw-r--r-- | crypto/src/crypto/tls/RecordStream.cs | 39 | ||||
-rw-r--r-- | crypto/src/crypto/tls/TlsProtocol.cs | 25 |
2 files changed, 59 insertions, 5 deletions
diff --git a/crypto/src/crypto/tls/RecordStream.cs b/crypto/src/crypto/tls/RecordStream.cs index d510ed94e..b1060fd6d 100644 --- a/crypto/src/crypto/tls/RecordStream.cs +++ b/crypto/src/crypto/tls/RecordStream.cs @@ -121,6 +121,40 @@ namespace Org.BouncyCastle.Crypto.Tls this.mPendingCipher = null; } + internal virtual void CheckRecordHeader(byte[] recordHeader) + { + byte type = TlsUtilities.ReadUint8(recordHeader, TLS_HEADER_TYPE_OFFSET); + + /* + * RFC 5246 6. If a TLS implementation receives an unexpected record type, it MUST send an + * unexpected_message alert. + */ + CheckType(type, AlertDescription.unexpected_message); + + if (!mRestrictReadVersion) + { + int version = TlsUtilities.ReadVersionRaw(recordHeader, TLS_HEADER_VERSION_OFFSET); + if ((version & 0xffffff00) != 0x0300) + throw new TlsFatalAlert(AlertDescription.illegal_parameter); + } + else + { + ProtocolVersion version = TlsUtilities.ReadVersion(recordHeader, TLS_HEADER_VERSION_OFFSET); + if (mReadVersion == null) + { + // Will be set later in 'readRecord' + } + else if (!version.Equals(mReadVersion)) + { + throw new TlsFatalAlert(AlertDescription.illegal_parameter); + } + } + + int length = TlsUtilities.ReadUint16(recordHeader, TLS_HEADER_LENGTH_OFFSET); + + CheckLength(length, mCiphertextLimit, AlertDescription.record_overflow); + } + internal virtual bool ReadRecord() { byte[] recordHeader = TlsUtilities.ReadAllOrNothing(TLS_HEADER_SIZE, mInput); @@ -155,6 +189,9 @@ namespace Org.BouncyCastle.Crypto.Tls } int length = TlsUtilities.ReadUint16(recordHeader, TLS_HEADER_LENGTH_OFFSET); + + CheckLength(length, mCiphertextLimit, AlertDescription.record_overflow); + byte[] plaintext = DecodeAndVerify(type, mInput, length); mHandler.ProcessRecord(type, plaintext, 0, plaintext.Length); return true; @@ -162,8 +199,6 @@ namespace Org.BouncyCastle.Crypto.Tls internal virtual byte[] DecodeAndVerify(byte type, Stream input, int len) { - CheckLength(len, mCiphertextLimit, AlertDescription.record_overflow); - byte[] buf = TlsUtilities.ReadFully(len, input); byte[] decoded = mReadCipher.DecodeCiphertext(mReadSeqNo++, type, buf, 0, buf.Length); diff --git a/crypto/src/crypto/tls/TlsProtocol.cs b/crypto/src/crypto/tls/TlsProtocol.cs index 98c6399d3..afdaf0075 100644 --- a/crypto/src/crypto/tls/TlsProtocol.cs +++ b/crypto/src/crypto/tls/TlsProtocol.cs @@ -482,6 +482,24 @@ namespace Org.BouncyCastle.Crypto.Tls return len; } + protected virtual void SafeCheckRecordHeader(byte[] recordHeader) + { + try + { + mRecordStream.CheckRecordHeader(recordHeader); + } + catch (TlsFatalAlert e) + { + this.FailWithError(AlertLevel.fatal, e.AlertDescription, "Failed to read record", e); + throw e; + } + catch (Exception e) + { + this.FailWithError(AlertLevel.fatal, AlertDescription.internal_error, "Failed to read record", e); + throw e; + } + } + protected virtual void SafeReadRecord() { try @@ -660,13 +678,14 @@ namespace Org.BouncyCastle.Crypto.Tls // loop while there are enough bytes to read the length of the next record while (mInputBuffers.Available >= RecordStream.TLS_HEADER_SIZE) { - byte[] header = new byte[RecordStream.TLS_HEADER_SIZE]; - mInputBuffers.Peek(header); + byte[] recordHeader = new byte[RecordStream.TLS_HEADER_SIZE]; + mInputBuffers.Peek(recordHeader); - int totalLength = TlsUtilities.ReadUint16(header, RecordStream.TLS_HEADER_LENGTH_OFFSET) + RecordStream.TLS_HEADER_SIZE; + int totalLength = TlsUtilities.ReadUint16(recordHeader, RecordStream.TLS_HEADER_LENGTH_OFFSET) + RecordStream.TLS_HEADER_SIZE; if (mInputBuffers.Available < totalLength) { // not enough bytes to read a whole record + SafeCheckRecordHeader(recordHeader); break; } |