summary refs log tree commit diff
diff options
context:
space:
mode:
authorPeter Dettman <peter.dettman@bouncycastle.org>2023-03-24 13:37:24 +0700
committerPeter Dettman <peter.dettman@bouncycastle.org>2023-04-13 17:16:20 +0700
commitf270f202a171903881ecd348834147a7dcd63884 (patch)
tree263a25e687e4163d47246baf92844295dc6619dc
parentRFC 9146: DtlsEpoch tracks record header lengths (diff)
downloadBouncyCastle.NET-ed25519-f270f202a171903881ecd348834147a7dcd63884.tar.xz
RFC 9146: TlsAeadCipher support for connection ID
-rw-r--r--crypto/src/tls/crypto/impl/TlsAeadCipher.cs105
-rw-r--r--crypto/src/tls/crypto/impl/TlsBlockCipher.cs18
2 files changed, 85 insertions, 38 deletions
diff --git a/crypto/src/tls/crypto/impl/TlsAeadCipher.cs b/crypto/src/tls/crypto/impl/TlsAeadCipher.cs
index f3ff44285..f5ebd7eba 100644
--- a/crypto/src/tls/crypto/impl/TlsAeadCipher.cs
+++ b/crypto/src/tls/crypto/impl/TlsAeadCipher.cs
@@ -1,6 +1,8 @@
 using System;
 using System.IO;
 
+using Org.BouncyCastle.Utilities;
+
 namespace Org.BouncyCastle.Tls.Crypto.Impl
 {
     /// <summary>A generic TLS 1.2 AEAD cipher.</summary>
@@ -13,6 +15,7 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
 
         private const int NONCE_RFC5288 = 1;
         private const int NONCE_RFC7905 = 2;
+        private const long SequenceNumberPlaceholder = -1L;
 
         protected readonly TlsCryptoParameters m_cryptoParams;
         protected readonly int m_keySize;
@@ -22,6 +25,8 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
 
         protected readonly TlsAeadCipherImpl m_decryptCipher, m_encryptCipher;
         protected readonly byte[] m_decryptNonce, m_encryptNonce;
+        protected readonly byte[] m_decryptConnectionID, m_encryptConnectionID;
+        protected readonly bool m_decryptUseInnerPlaintext, m_encryptUseInnerPlaintext;
 
         protected readonly bool m_isTlsV13;
         protected readonly int m_nonceMode;
@@ -39,6 +44,12 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
             this.m_isTlsV13 = TlsImplUtilities.IsTlsV13(negotiatedVersion);
             this.m_nonceMode = GetNonceMode(m_isTlsV13, aeadType);
 
+            m_decryptConnectionID = securityParameters.ConnectionIDPeer;
+            m_encryptConnectionID = securityParameters.ConnectionIDLocal;
+
+            m_decryptUseInnerPlaintext = m_isTlsV13 || !Arrays.IsNullOrEmpty(m_decryptConnectionID);
+            m_encryptUseInnerPlaintext = m_isTlsV13 || !Arrays.IsNullOrEmpty(m_encryptConnectionID);
+
             switch (m_nonceMode)
             {
             case NONCE_RFC5288:
@@ -136,7 +147,7 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
 
         public override int GetCiphertextDecodeLimit(int plaintextLimit)
         {
-            int innerPlaintextLimit = plaintextLimit + (m_isTlsV13 ? 1 : 0);
+            int innerPlaintextLimit = plaintextLimit + (m_decryptUseInnerPlaintext ? 1 : 0);
 
             return innerPlaintextLimit + m_macSize + m_record_iv_length;
         }
@@ -144,7 +155,7 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
         public override int GetCiphertextEncodeLimit(int plaintextLength, int plaintextLimit)
         {
             int innerPlaintextLimit = plaintextLength;
-            if (m_isTlsV13)
+            if (m_encryptUseInnerPlaintext)
             {
                 // TODO[tls13] Add support for padding
                 int maxPadding = 0;
@@ -155,11 +166,18 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
             return innerPlaintextLimit + m_macSize + m_record_iv_length;
         }
 
-        public override int GetPlaintextLimit(int ciphertextLimit)
+        public override int GetPlaintextDecodeLimit(int ciphertextLimit)
+        {
+            int innerPlaintextLimit = ciphertextLimit - m_macSize - m_record_iv_length;
+
+            return innerPlaintextLimit - (m_decryptUseInnerPlaintext ? 1 : 0);
+        }
+
+        public override int GetPlaintextEncodeLimit(int ciphertextLimit)
         {
             int innerPlaintextLimit = ciphertextLimit - m_macSize - m_record_iv_length;
 
-            return  innerPlaintextLimit - (m_isTlsV13 ? 1 : 0);
+            return innerPlaintextLimit - (m_encryptUseInnerPlaintext ? 1 : 0);
         }
 
         public override TlsEncodeResult EncodePlaintext(long seqNo, short contentType, ProtocolVersion recordVersion,
@@ -189,10 +207,10 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
                 throw new TlsFatalAlert(AlertDescription.internal_error);
             }
 
-            int extraLength = m_isTlsV13 ? 1 : 0;
+            // TODO[tls13, cid] If we support adding padding to (D)TLSInnerPlaintext, this will need review
+            int innerPlaintextLength = plaintextLength + (m_encryptUseInnerPlaintext ? 1 : 0);
 
-            // TODO[tls13] If we support adding padding to TLSInnerPlaintext, this will need review
-            int encryptionLength = m_encryptCipher.GetOutputSize(plaintextLength + extraLength);
+            int encryptionLength = m_encryptCipher.GetOutputSize(innerPlaintextLength);
             int ciphertextLength = m_record_iv_length + encryptionLength;
 
             byte[] output = new byte[headerAllocation + ciphertextLength];
@@ -204,22 +222,25 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
                 outputPos += m_record_iv_length;
             }
 
-            short recordType = m_isTlsV13 ? ContentType.application_data : contentType;
+            short recordType = contentType;
+            if (m_encryptUseInnerPlaintext)
+            {
+                recordType = m_isTlsV13 ? ContentType.application_data : ContentType.tls12_cid;
+            }
 
             byte[] additionalData = GetAdditionalData(seqNo, recordType, recordVersion, ciphertextLength,
-                plaintextLength);
+                innerPlaintextLength, m_encryptConnectionID);
 
             try
             {
                 Array.Copy(plaintext, plaintextOffset, output, outputPos, plaintextLength);
-                if (m_isTlsV13)
+                if (m_encryptUseInnerPlaintext)
                 {
                     output[outputPos + plaintextLength] = (byte)contentType;
                 }
 
                 m_encryptCipher.Init(nonce, m_macSize, additionalData);
-                outputPos += m_encryptCipher.DoFinal(output, outputPos, plaintextLength + extraLength, output,
-                    outputPos);
+                outputPos += m_encryptCipher.DoFinal(output, outputPos, innerPlaintextLength, output, outputPos);
             }
             catch (IOException)
             {
@@ -264,10 +285,10 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
                 throw new TlsFatalAlert(AlertDescription.internal_error);
             }
 
-            int extraLength = m_isTlsV13 ? 1 : 0;
+            // TODO[tls13, cid] If we support adding padding to (D)TLSInnerPlaintext, this will need review
+            int innerPlaintextLength = plaintext.Length + (m_encryptUseInnerPlaintext ? 1 : 0);
 
-            // TODO[tls13] If we support adding padding to TLSInnerPlaintext, this will need review
-            int encryptionLength = m_encryptCipher.GetOutputSize(plaintext.Length + extraLength);
+            int encryptionLength = m_encryptCipher.GetOutputSize(innerPlaintextLength);
             int ciphertextLength = m_record_iv_length + encryptionLength;
 
             byte[] output = new byte[headerAllocation + ciphertextLength];
@@ -279,22 +300,25 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
                 outputPos += m_record_iv_length;
             }
 
-            short recordType = m_isTlsV13 ? ContentType.application_data : contentType;
+            short recordType = contentType;
+            if (m_encryptUseInnerPlaintext)
+            {
+                recordType = m_isTlsV13 ? ContentType.application_data : ContentType.tls12_cid;
+            }
 
             byte[] additionalData = GetAdditionalData(seqNo, recordType, recordVersion, ciphertextLength,
-                plaintext.Length);
+                innerPlaintextLength, m_encryptConnectionID);
 
             try
             {
                 plaintext.CopyTo(output.AsSpan(outputPos));
-                if (m_isTlsV13)
+                if (m_encryptUseInnerPlaintext)
                 {
                     output[outputPos + plaintext.Length] = (byte)contentType;
                 }
 
                 m_encryptCipher.Init(nonce, m_macSize, additionalData);
-                outputPos += m_encryptCipher.DoFinal(output, outputPos, plaintext.Length + extraLength, output,
-                    outputPos);
+                outputPos += m_encryptCipher.DoFinal(output, outputPos, innerPlaintextLength, output, outputPos);
             }
             catch (IOException)
             {
@@ -318,7 +342,7 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
         public override TlsDecodeResult DecodeCiphertext(long seqNo, short recordType, ProtocolVersion recordVersion,
             byte[] ciphertext, int ciphertextOffset, int ciphertextLength)
         {
-            if (GetPlaintextLimit(ciphertextLength) < 0)
+            if (GetPlaintextDecodeLimit(ciphertextLength) < 0)
                 throw new TlsFatalAlert(AlertDescription.decode_error);
 
             byte[] nonce = new byte[m_decryptNonce.Length + m_record_iv_length];
@@ -343,10 +367,10 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
 
             int encryptionOffset = ciphertextOffset + m_record_iv_length;
             int encryptionLength = ciphertextLength - m_record_iv_length;
-            int plaintextLength = m_decryptCipher.GetOutputSize(encryptionLength);
+            int innerPlaintextLength = m_decryptCipher.GetOutputSize(encryptionLength);
 
             byte[] additionalData = GetAdditionalData(seqNo, recordType, recordVersion, ciphertextLength,
-                plaintextLength);
+                innerPlaintextLength, m_decryptConnectionID);
 
             int outputPos;
             try
@@ -373,27 +397,27 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
                 throw new TlsFatalAlert(AlertDescription.bad_record_mac, e);
             }
 
-            if (outputPos != plaintextLength)
+            if (outputPos != innerPlaintextLength)
             {
                 // NOTE: The additional data mechanism for AEAD ciphers requires exact output size prediction.
                 throw new TlsFatalAlert(AlertDescription.internal_error);
             }
 
             short contentType = recordType;
-            if (m_isTlsV13)
+            int plaintextLength = innerPlaintextLength;
+
+            if (m_decryptUseInnerPlaintext)
             {
                 // Strip padding and read true content type from TLSInnerPlaintext
-                int pos = plaintextLength;
                 for (;;)
                 {
-                    if (--pos < 0)
+                    if (--plaintextLength < 0)
                         throw new TlsFatalAlert(AlertDescription.unexpected_message);
 
-                    byte octet = ciphertext[encryptionOffset + pos];
+                    byte octet = ciphertext[encryptionOffset + plaintextLength];
                     if (0 != octet)
                     {
                         contentType = (short)(octet & 0xFF);
-                        plaintextLength = pos;
                         break;
                     }
                 }
@@ -445,6 +469,29 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
             }
         }
 
+        protected virtual byte[] GetAdditionalData(long seqNo, short recordType, ProtocolVersion recordVersion,
+            int ciphertextLength, int plaintextLength, byte[] connectionID)
+        {
+            if (Arrays.IsNullOrEmpty(connectionID))
+                return GetAdditionalData(seqNo, recordType, recordVersion, ciphertextLength, plaintextLength);
+
+            /*
+             * seq_num_placeholder + tls12_cid + cid_length + tls12_cid + DTLSCiphertext.version + epoch
+             *     + sequence_number + cid + length_of_DTLSInnerPlaintext
+             */
+            int cidLength = connectionID.Length;
+            byte[] additional_data = new byte[23 + cidLength];
+            TlsUtilities.WriteUint64(SequenceNumberPlaceholder, additional_data, 0);
+            TlsUtilities.WriteUint8(ContentType.tls12_cid, additional_data, 8);
+            TlsUtilities.WriteUint8(cidLength, additional_data, 9);
+            TlsUtilities.WriteUint8(ContentType.tls12_cid, additional_data, 10);
+            TlsUtilities.WriteVersion(recordVersion, additional_data, 11);
+            TlsUtilities.WriteUint64(seqNo, additional_data, 13);
+            Array.Copy(connectionID, 0, additional_data, 21, cidLength);
+            TlsUtilities.WriteUint16(plaintextLength, additional_data, 21 + cidLength);
+            return additional_data;
+        }
+
         protected virtual void RekeyCipher(SecurityParameters securityParameters, TlsAeadCipherImpl cipher,
             byte[] nonce, bool serverSecret)
         {
diff --git a/crypto/src/tls/crypto/impl/TlsBlockCipher.cs b/crypto/src/tls/crypto/impl/TlsBlockCipher.cs
index 11eb75186..63d4826b3 100644
--- a/crypto/src/tls/crypto/impl/TlsBlockCipher.cs
+++ b/crypto/src/tls/crypto/impl/TlsBlockCipher.cs
@@ -443,43 +443,43 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
                 m_encryptThenMac ? 0 : macSize);
             bool badMac = (totalPad == 0);
 
-            int dec_output_length = blocks_length - totalPad;
+            int innerPlaintextLength = blocks_length - totalPad;
 
             if (!m_encryptThenMac)
             {
-                dec_output_length -= macSize;
+                innerPlaintextLength -= macSize;
 
                 byte[] expectedMac = m_readMac.CalculateMacConstantTime(seqNo, recordType, m_decryptConnectionID,
-                    ciphertext, offset, dec_output_length, blocks_length - macSize, m_randomData);
+                    ciphertext, offset, innerPlaintextLength, blocks_length - macSize, m_randomData);
 
                 badMac |= !TlsUtilities.ConstantTimeAreEqual(macSize, expectedMac, 0, ciphertext,
-                    offset + dec_output_length);
+                    offset + innerPlaintextLength);
             }
 
             if (badMac)
                 throw new TlsFatalAlert(AlertDescription.bad_record_mac);
 
             short contentType = recordType;
+            int plaintextLength = innerPlaintextLength;
+
             if (m_decryptUseInnerPlaintext)
             {
                 // Strip padding and read true content type from DTLSInnerPlaintext
-                int pos = dec_output_length;
                 for (;;)
                 {
-                    if (--pos < 0)
+                    if (--plaintextLength < 0)
                         throw new TlsFatalAlert(AlertDescription.unexpected_message);
 
-                    byte octet = ciphertext[offset + pos];
+                    byte octet = ciphertext[offset + plaintextLength];
                     if (0 != octet)
                     {
                         contentType = (short)(octet & 0xFF);
-                        dec_output_length = pos;
                         break;
                     }
                 }
             }
 
-            return new TlsDecodeResult(ciphertext, offset, dec_output_length, contentType);
+            return new TlsDecodeResult(ciphertext, offset, plaintextLength, contentType);
         }
 
         public override bool UsesOpaqueRecordType