summary refs log tree commit diff
diff options
context:
space:
mode:
authorPeter Dettman <peter.dettman@bouncycastle.org>2017-03-22 23:09:57 +1030
committerPeter Dettman <peter.dettman@bouncycastle.org>2017-03-22 23:09:57 +1030
commitd53ac55ba52d6913e093377e64ec9cc7b0bdbe37 (patch)
tree418f61a731742225bce265d0159da65cb1ae00ec
parentUse new TlsNoCloseNotifyException instead of generic EndOfStreamException (diff)
downloadBouncyCastle.NET-ed25519-d53ac55ba52d6913e093377e64ec9cc7b0bdbe37.tar.xz
Non-blocking TLS validates header of partially-received records
- https://github.com/bcgit/bc-java/issues/133
-rw-r--r--crypto/src/crypto/tls/RecordStream.cs39
-rw-r--r--crypto/src/crypto/tls/TlsProtocol.cs25
2 files changed, 59 insertions, 5 deletions
diff --git a/crypto/src/crypto/tls/RecordStream.cs b/crypto/src/crypto/tls/RecordStream.cs
index d510ed94e..b1060fd6d 100644
--- a/crypto/src/crypto/tls/RecordStream.cs
+++ b/crypto/src/crypto/tls/RecordStream.cs
@@ -121,6 +121,40 @@ namespace Org.BouncyCastle.Crypto.Tls
             this.mPendingCipher = null;
         }
 
+        internal virtual void CheckRecordHeader(byte[] recordHeader)
+        {
+            byte type = TlsUtilities.ReadUint8(recordHeader, TLS_HEADER_TYPE_OFFSET);
+
+            /*
+             * RFC 5246 6. If a TLS implementation receives an unexpected record type, it MUST send an
+             * unexpected_message alert.
+             */
+            CheckType(type, AlertDescription.unexpected_message);
+
+            if (!mRestrictReadVersion)
+            {
+                int version = TlsUtilities.ReadVersionRaw(recordHeader, TLS_HEADER_VERSION_OFFSET);
+                if ((version & 0xffffff00) != 0x0300)
+                    throw new TlsFatalAlert(AlertDescription.illegal_parameter);
+            }
+            else
+            {
+                ProtocolVersion version = TlsUtilities.ReadVersion(recordHeader, TLS_HEADER_VERSION_OFFSET);
+                if (mReadVersion == null)
+                {
+                    // Will be set later in 'readRecord'
+                }
+                else if (!version.Equals(mReadVersion))
+                {
+                    throw new TlsFatalAlert(AlertDescription.illegal_parameter);
+                }
+            }
+
+            int length = TlsUtilities.ReadUint16(recordHeader, TLS_HEADER_LENGTH_OFFSET);
+
+            CheckLength(length, mCiphertextLimit, AlertDescription.record_overflow);
+        }
+
         internal virtual bool ReadRecord()
         {
             byte[] recordHeader = TlsUtilities.ReadAllOrNothing(TLS_HEADER_SIZE, mInput);
@@ -155,6 +189,9 @@ namespace Org.BouncyCastle.Crypto.Tls
             }
 
             int length = TlsUtilities.ReadUint16(recordHeader, TLS_HEADER_LENGTH_OFFSET);
+
+            CheckLength(length, mCiphertextLimit, AlertDescription.record_overflow);
+
             byte[] plaintext = DecodeAndVerify(type, mInput, length);
             mHandler.ProcessRecord(type, plaintext, 0, plaintext.Length);
             return true;
@@ -162,8 +199,6 @@ namespace Org.BouncyCastle.Crypto.Tls
 
         internal virtual byte[] DecodeAndVerify(byte type, Stream input, int len)
         {
-            CheckLength(len, mCiphertextLimit, AlertDescription.record_overflow);
-
             byte[] buf = TlsUtilities.ReadFully(len, input);
             byte[] decoded = mReadCipher.DecodeCiphertext(mReadSeqNo++, type, buf, 0, buf.Length);
 
diff --git a/crypto/src/crypto/tls/TlsProtocol.cs b/crypto/src/crypto/tls/TlsProtocol.cs
index 98c6399d3..afdaf0075 100644
--- a/crypto/src/crypto/tls/TlsProtocol.cs
+++ b/crypto/src/crypto/tls/TlsProtocol.cs
@@ -482,6 +482,24 @@ namespace Org.BouncyCastle.Crypto.Tls
             return len;
         }
 
+        protected virtual void SafeCheckRecordHeader(byte[] recordHeader)
+        {
+            try
+            {
+                mRecordStream.CheckRecordHeader(recordHeader);
+            }
+            catch (TlsFatalAlert e)
+            {
+                this.FailWithError(AlertLevel.fatal, e.AlertDescription, "Failed to read record", e);
+                throw e;
+            }
+            catch (Exception e)
+            {
+                this.FailWithError(AlertLevel.fatal, AlertDescription.internal_error, "Failed to read record", e);
+                throw e;
+            }
+        }
+
         protected virtual void SafeReadRecord()
         {
             try
@@ -660,13 +678,14 @@ namespace Org.BouncyCastle.Crypto.Tls
             // loop while there are enough bytes to read the length of the next record
             while (mInputBuffers.Available >= RecordStream.TLS_HEADER_SIZE)
             {
-                byte[] header = new byte[RecordStream.TLS_HEADER_SIZE];
-                mInputBuffers.Peek(header);
+                byte[] recordHeader = new byte[RecordStream.TLS_HEADER_SIZE];
+                mInputBuffers.Peek(recordHeader);
 
-                int totalLength = TlsUtilities.ReadUint16(header, RecordStream.TLS_HEADER_LENGTH_OFFSET) + RecordStream.TLS_HEADER_SIZE;
+                int totalLength = TlsUtilities.ReadUint16(recordHeader, RecordStream.TLS_HEADER_LENGTH_OFFSET) + RecordStream.TLS_HEADER_SIZE;
                 if (mInputBuffers.Available < totalLength)
                 {
                     // not enough bytes to read a whole record
+                    SafeCheckRecordHeader(recordHeader);
                     break;
                 }