From f270f202a171903881ecd348834147a7dcd63884 Mon Sep 17 00:00:00 2001 From: Peter Dettman Date: Fri, 24 Mar 2023 13:37:24 +0700 Subject: RFC 9146: TlsAeadCipher support for connection ID --- crypto/src/tls/crypto/impl/TlsAeadCipher.cs | 105 +++++++++++++++++++-------- crypto/src/tls/crypto/impl/TlsBlockCipher.cs | 18 ++--- 2 files changed, 85 insertions(+), 38 deletions(-) diff --git a/crypto/src/tls/crypto/impl/TlsAeadCipher.cs b/crypto/src/tls/crypto/impl/TlsAeadCipher.cs index f3ff44285..f5ebd7eba 100644 --- a/crypto/src/tls/crypto/impl/TlsAeadCipher.cs +++ b/crypto/src/tls/crypto/impl/TlsAeadCipher.cs @@ -1,6 +1,8 @@ using System; using System.IO; +using Org.BouncyCastle.Utilities; + namespace Org.BouncyCastle.Tls.Crypto.Impl { /// A generic TLS 1.2 AEAD cipher. @@ -13,6 +15,7 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl private const int NONCE_RFC5288 = 1; private const int NONCE_RFC7905 = 2; + private const long SequenceNumberPlaceholder = -1L; protected readonly TlsCryptoParameters m_cryptoParams; protected readonly int m_keySize; @@ -22,6 +25,8 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl protected readonly TlsAeadCipherImpl m_decryptCipher, m_encryptCipher; protected readonly byte[] m_decryptNonce, m_encryptNonce; + protected readonly byte[] m_decryptConnectionID, m_encryptConnectionID; + protected readonly bool m_decryptUseInnerPlaintext, m_encryptUseInnerPlaintext; protected readonly bool m_isTlsV13; protected readonly int m_nonceMode; @@ -39,6 +44,12 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl this.m_isTlsV13 = TlsImplUtilities.IsTlsV13(negotiatedVersion); this.m_nonceMode = GetNonceMode(m_isTlsV13, aeadType); + m_decryptConnectionID = securityParameters.ConnectionIDPeer; + m_encryptConnectionID = securityParameters.ConnectionIDLocal; + + m_decryptUseInnerPlaintext = m_isTlsV13 || !Arrays.IsNullOrEmpty(m_decryptConnectionID); + m_encryptUseInnerPlaintext = m_isTlsV13 || !Arrays.IsNullOrEmpty(m_encryptConnectionID); + switch (m_nonceMode) { case NONCE_RFC5288: @@ -136,7 +147,7 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl public override int GetCiphertextDecodeLimit(int plaintextLimit) { - int innerPlaintextLimit = plaintextLimit + (m_isTlsV13 ? 1 : 0); + int innerPlaintextLimit = plaintextLimit + (m_decryptUseInnerPlaintext ? 1 : 0); return innerPlaintextLimit + m_macSize + m_record_iv_length; } @@ -144,7 +155,7 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl public override int GetCiphertextEncodeLimit(int plaintextLength, int plaintextLimit) { int innerPlaintextLimit = plaintextLength; - if (m_isTlsV13) + if (m_encryptUseInnerPlaintext) { // TODO[tls13] Add support for padding int maxPadding = 0; @@ -155,11 +166,18 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl return innerPlaintextLimit + m_macSize + m_record_iv_length; } - public override int GetPlaintextLimit(int ciphertextLimit) + public override int GetPlaintextDecodeLimit(int ciphertextLimit) + { + int innerPlaintextLimit = ciphertextLimit - m_macSize - m_record_iv_length; + + return innerPlaintextLimit - (m_decryptUseInnerPlaintext ? 1 : 0); + } + + public override int GetPlaintextEncodeLimit(int ciphertextLimit) { int innerPlaintextLimit = ciphertextLimit - m_macSize - m_record_iv_length; - return innerPlaintextLimit - (m_isTlsV13 ? 1 : 0); + return innerPlaintextLimit - (m_encryptUseInnerPlaintext ? 1 : 0); } public override TlsEncodeResult EncodePlaintext(long seqNo, short contentType, ProtocolVersion recordVersion, @@ -189,10 +207,10 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl throw new TlsFatalAlert(AlertDescription.internal_error); } - int extraLength = m_isTlsV13 ? 1 : 0; + // TODO[tls13, cid] If we support adding padding to (D)TLSInnerPlaintext, this will need review + int innerPlaintextLength = plaintextLength + (m_encryptUseInnerPlaintext ? 1 : 0); - // TODO[tls13] If we support adding padding to TLSInnerPlaintext, this will need review - int encryptionLength = m_encryptCipher.GetOutputSize(plaintextLength + extraLength); + int encryptionLength = m_encryptCipher.GetOutputSize(innerPlaintextLength); int ciphertextLength = m_record_iv_length + encryptionLength; byte[] output = new byte[headerAllocation + ciphertextLength]; @@ -204,22 +222,25 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl outputPos += m_record_iv_length; } - short recordType = m_isTlsV13 ? ContentType.application_data : contentType; + short recordType = contentType; + if (m_encryptUseInnerPlaintext) + { + recordType = m_isTlsV13 ? ContentType.application_data : ContentType.tls12_cid; + } byte[] additionalData = GetAdditionalData(seqNo, recordType, recordVersion, ciphertextLength, - plaintextLength); + innerPlaintextLength, m_encryptConnectionID); try { Array.Copy(plaintext, plaintextOffset, output, outputPos, plaintextLength); - if (m_isTlsV13) + if (m_encryptUseInnerPlaintext) { output[outputPos + plaintextLength] = (byte)contentType; } m_encryptCipher.Init(nonce, m_macSize, additionalData); - outputPos += m_encryptCipher.DoFinal(output, outputPos, plaintextLength + extraLength, output, - outputPos); + outputPos += m_encryptCipher.DoFinal(output, outputPos, innerPlaintextLength, output, outputPos); } catch (IOException) { @@ -264,10 +285,10 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl throw new TlsFatalAlert(AlertDescription.internal_error); } - int extraLength = m_isTlsV13 ? 1 : 0; + // TODO[tls13, cid] If we support adding padding to (D)TLSInnerPlaintext, this will need review + int innerPlaintextLength = plaintext.Length + (m_encryptUseInnerPlaintext ? 1 : 0); - // TODO[tls13] If we support adding padding to TLSInnerPlaintext, this will need review - int encryptionLength = m_encryptCipher.GetOutputSize(plaintext.Length + extraLength); + int encryptionLength = m_encryptCipher.GetOutputSize(innerPlaintextLength); int ciphertextLength = m_record_iv_length + encryptionLength; byte[] output = new byte[headerAllocation + ciphertextLength]; @@ -279,22 +300,25 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl outputPos += m_record_iv_length; } - short recordType = m_isTlsV13 ? ContentType.application_data : contentType; + short recordType = contentType; + if (m_encryptUseInnerPlaintext) + { + recordType = m_isTlsV13 ? ContentType.application_data : ContentType.tls12_cid; + } byte[] additionalData = GetAdditionalData(seqNo, recordType, recordVersion, ciphertextLength, - plaintext.Length); + innerPlaintextLength, m_encryptConnectionID); try { plaintext.CopyTo(output.AsSpan(outputPos)); - if (m_isTlsV13) + if (m_encryptUseInnerPlaintext) { output[outputPos + plaintext.Length] = (byte)contentType; } m_encryptCipher.Init(nonce, m_macSize, additionalData); - outputPos += m_encryptCipher.DoFinal(output, outputPos, plaintext.Length + extraLength, output, - outputPos); + outputPos += m_encryptCipher.DoFinal(output, outputPos, innerPlaintextLength, output, outputPos); } catch (IOException) { @@ -318,7 +342,7 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl public override TlsDecodeResult DecodeCiphertext(long seqNo, short recordType, ProtocolVersion recordVersion, byte[] ciphertext, int ciphertextOffset, int ciphertextLength) { - if (GetPlaintextLimit(ciphertextLength) < 0) + if (GetPlaintextDecodeLimit(ciphertextLength) < 0) throw new TlsFatalAlert(AlertDescription.decode_error); byte[] nonce = new byte[m_decryptNonce.Length + m_record_iv_length]; @@ -343,10 +367,10 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl int encryptionOffset = ciphertextOffset + m_record_iv_length; int encryptionLength = ciphertextLength - m_record_iv_length; - int plaintextLength = m_decryptCipher.GetOutputSize(encryptionLength); + int innerPlaintextLength = m_decryptCipher.GetOutputSize(encryptionLength); byte[] additionalData = GetAdditionalData(seqNo, recordType, recordVersion, ciphertextLength, - plaintextLength); + innerPlaintextLength, m_decryptConnectionID); int outputPos; try @@ -373,27 +397,27 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl throw new TlsFatalAlert(AlertDescription.bad_record_mac, e); } - if (outputPos != plaintextLength) + if (outputPos != innerPlaintextLength) { // NOTE: The additional data mechanism for AEAD ciphers requires exact output size prediction. throw new TlsFatalAlert(AlertDescription.internal_error); } short contentType = recordType; - if (m_isTlsV13) + int plaintextLength = innerPlaintextLength; + + if (m_decryptUseInnerPlaintext) { // Strip padding and read true content type from TLSInnerPlaintext - int pos = plaintextLength; for (;;) { - if (--pos < 0) + if (--plaintextLength < 0) throw new TlsFatalAlert(AlertDescription.unexpected_message); - byte octet = ciphertext[encryptionOffset + pos]; + byte octet = ciphertext[encryptionOffset + plaintextLength]; if (0 != octet) { contentType = (short)(octet & 0xFF); - plaintextLength = pos; break; } } @@ -445,6 +469,29 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl } } + protected virtual byte[] GetAdditionalData(long seqNo, short recordType, ProtocolVersion recordVersion, + int ciphertextLength, int plaintextLength, byte[] connectionID) + { + if (Arrays.IsNullOrEmpty(connectionID)) + return GetAdditionalData(seqNo, recordType, recordVersion, ciphertextLength, plaintextLength); + + /* + * seq_num_placeholder + tls12_cid + cid_length + tls12_cid + DTLSCiphertext.version + epoch + * + sequence_number + cid + length_of_DTLSInnerPlaintext + */ + int cidLength = connectionID.Length; + byte[] additional_data = new byte[23 + cidLength]; + TlsUtilities.WriteUint64(SequenceNumberPlaceholder, additional_data, 0); + TlsUtilities.WriteUint8(ContentType.tls12_cid, additional_data, 8); + TlsUtilities.WriteUint8(cidLength, additional_data, 9); + TlsUtilities.WriteUint8(ContentType.tls12_cid, additional_data, 10); + TlsUtilities.WriteVersion(recordVersion, additional_data, 11); + TlsUtilities.WriteUint64(seqNo, additional_data, 13); + Array.Copy(connectionID, 0, additional_data, 21, cidLength); + TlsUtilities.WriteUint16(plaintextLength, additional_data, 21 + cidLength); + return additional_data; + } + protected virtual void RekeyCipher(SecurityParameters securityParameters, TlsAeadCipherImpl cipher, byte[] nonce, bool serverSecret) { diff --git a/crypto/src/tls/crypto/impl/TlsBlockCipher.cs b/crypto/src/tls/crypto/impl/TlsBlockCipher.cs index 11eb75186..63d4826b3 100644 --- a/crypto/src/tls/crypto/impl/TlsBlockCipher.cs +++ b/crypto/src/tls/crypto/impl/TlsBlockCipher.cs @@ -443,43 +443,43 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl m_encryptThenMac ? 0 : macSize); bool badMac = (totalPad == 0); - int dec_output_length = blocks_length - totalPad; + int innerPlaintextLength = blocks_length - totalPad; if (!m_encryptThenMac) { - dec_output_length -= macSize; + innerPlaintextLength -= macSize; byte[] expectedMac = m_readMac.CalculateMacConstantTime(seqNo, recordType, m_decryptConnectionID, - ciphertext, offset, dec_output_length, blocks_length - macSize, m_randomData); + ciphertext, offset, innerPlaintextLength, blocks_length - macSize, m_randomData); badMac |= !TlsUtilities.ConstantTimeAreEqual(macSize, expectedMac, 0, ciphertext, - offset + dec_output_length); + offset + innerPlaintextLength); } if (badMac) throw new TlsFatalAlert(AlertDescription.bad_record_mac); short contentType = recordType; + int plaintextLength = innerPlaintextLength; + if (m_decryptUseInnerPlaintext) { // Strip padding and read true content type from DTLSInnerPlaintext - int pos = dec_output_length; for (;;) { - if (--pos < 0) + if (--plaintextLength < 0) throw new TlsFatalAlert(AlertDescription.unexpected_message); - byte octet = ciphertext[offset + pos]; + byte octet = ciphertext[offset + plaintextLength]; if (0 != octet) { contentType = (short)(octet & 0xFF); - dec_output_length = pos; break; } } } - return new TlsDecodeResult(ciphertext, offset, dec_output_length, contentType); + return new TlsDecodeResult(ciphertext, offset, plaintextLength, contentType); } public override bool UsesOpaqueRecordType -- cgit 1.4.1