summary refs log tree commit diff
path: root/crypto
diff options
context:
space:
mode:
authorPeter Dettman <peter.dettman@bouncycastle.org>2017-06-08 19:00:24 +0700
committerPeter Dettman <peter.dettman@bouncycastle.org>2017-06-08 19:00:24 +0700
commit08629c3971a95d19e431c5cdb38cccbb8e5f79c0 (patch)
tree9a6fee9da4f67b1ba7a2d36b78656af23ea3112b /crypto
parentAdd latest extension type values from IANA registry (diff)
downloadBouncyCastle.NET-ed25519-08629c3971a95d19e431c5cdb38cccbb8e5f79c0.tar.xz
Add explicit limit for sequence numbers
Diffstat (limited to 'crypto')
-rw-r--r--crypto/src/crypto/tls/RecordStream.cs36
1 files changed, 30 insertions, 6 deletions
diff --git a/crypto/src/crypto/tls/RecordStream.cs b/crypto/src/crypto/tls/RecordStream.cs
index 46673cf7e..5d556ad06 100644
--- a/crypto/src/crypto/tls/RecordStream.cs
+++ b/crypto/src/crypto/tls/RecordStream.cs
@@ -21,7 +21,7 @@ namespace Org.BouncyCastle.Crypto.Tls
         private Stream mOutput;
         private TlsCompression mPendingCompression = null, mReadCompression = null, mWriteCompression = null;
         private TlsCipher mPendingCipher = null, mReadCipher = null, mWriteCipher = null;
-        private long mReadSeqNo = 0, mWriteSeqNo = 0;
+        private SequenceNumber mReadSeqNo = new SequenceNumber(), mWriteSeqNo = new SequenceNumber();
         private MemoryStream mBuffer = new MemoryStream();
 
         private TlsHandshakeHash mHandshakeHash = null;
@@ -100,7 +100,7 @@ namespace Org.BouncyCastle.Crypto.Tls
 
             this.mWriteCompression = this.mPendingCompression;
             this.mWriteCipher = this.mPendingCipher;
-            this.mWriteSeqNo = 0;
+            this.mWriteSeqNo = new SequenceNumber();
         }
 
         internal virtual void ReceivedReadCipherSpec()
@@ -110,7 +110,7 @@ namespace Org.BouncyCastle.Crypto.Tls
 
             this.mReadCompression = this.mPendingCompression;
             this.mReadCipher = this.mPendingCipher;
-            this.mReadSeqNo = 0;
+            this.mReadSeqNo = new SequenceNumber();
         }
 
         internal virtual void FinaliseHandshake()
@@ -203,7 +203,9 @@ namespace Org.BouncyCastle.Crypto.Tls
         internal virtual byte[] DecodeAndVerify(byte type, Stream input, int len)
         {
             byte[] buf = TlsUtilities.ReadFully(len, input);
-            byte[] decoded = mReadCipher.DecodeCiphertext(mReadSeqNo++, type, buf, 0, buf.Length);
+
+            long seqNo = mReadSeqNo.NextValue(AlertDescription.unexpected_message);
+            byte[] decoded = mReadCipher.DecodeCiphertext(seqNo, type, buf, 0, buf.Length);
 
             CheckLength(decoded.Length, mCompressedLimit, AlertDescription.record_overflow);
 
@@ -262,10 +264,12 @@ namespace Org.BouncyCastle.Crypto.Tls
 
             Stream cOut = mWriteCompression.Compress(mBuffer);
 
+            long seqNo = mWriteSeqNo.NextValue(AlertDescription.internal_error);
+
             byte[] ciphertext;
             if (cOut == mBuffer)
             {
-                ciphertext = mWriteCipher.EncodePlaintext(mWriteSeqNo++, type, plaintext, plaintextOffset, plaintextLength);
+                ciphertext = mWriteCipher.EncodePlaintext(seqNo, type, plaintext, plaintextOffset, plaintextLength);
             }
             else
             {
@@ -279,7 +283,7 @@ namespace Org.BouncyCastle.Crypto.Tls
                  */
                 CheckLength(compressed.Length, plaintextLength + 1024, AlertDescription.internal_error);
 
-                ciphertext = mWriteCipher.EncodePlaintext(mWriteSeqNo++, type, compressed, 0, compressed.Length);
+                ciphertext = mWriteCipher.EncodePlaintext(seqNo, type, compressed, 0, compressed.Length);
             }
 
             /*
@@ -384,5 +388,25 @@ namespace Org.BouncyCastle.Crypto.Tls
                 mOuter.mHandshakeHash.BlockUpdate(buf, off, len);
             }
         }
+
+        private class SequenceNumber
+        {
+            private long value = 0L;
+            private bool exhausted = false;
+
+            internal long NextValue(byte alertDescription)
+            {
+                if (exhausted)
+                {
+                    throw new TlsFatalAlert(alertDescription);
+                }
+                long result = value;
+                if (++value == 0)
+                {
+                    exhausted = true;
+                }
+                return result;
+            }
+        }
     }
 }