diff --git a/crypto/src/tls/crypto/impl/TlsAeadCipher.cs b/crypto/src/tls/crypto/impl/TlsAeadCipher.cs
index f3ff44285..f5ebd7eba 100644
--- a/crypto/src/tls/crypto/impl/TlsAeadCipher.cs
+++ b/crypto/src/tls/crypto/impl/TlsAeadCipher.cs
@@ -1,6 +1,8 @@
using System;
using System.IO;
+using Org.BouncyCastle.Utilities;
+
namespace Org.BouncyCastle.Tls.Crypto.Impl
{
/// <summary>A generic TLS 1.2 AEAD cipher.</summary>
@@ -13,6 +15,7 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
private const int NONCE_RFC5288 = 1;
private const int NONCE_RFC7905 = 2;
+ private const long SequenceNumberPlaceholder = -1L;
protected readonly TlsCryptoParameters m_cryptoParams;
protected readonly int m_keySize;
@@ -22,6 +25,8 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
protected readonly TlsAeadCipherImpl m_decryptCipher, m_encryptCipher;
protected readonly byte[] m_decryptNonce, m_encryptNonce;
+ protected readonly byte[] m_decryptConnectionID, m_encryptConnectionID;
+ protected readonly bool m_decryptUseInnerPlaintext, m_encryptUseInnerPlaintext;
protected readonly bool m_isTlsV13;
protected readonly int m_nonceMode;
@@ -39,6 +44,12 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
this.m_isTlsV13 = TlsImplUtilities.IsTlsV13(negotiatedVersion);
this.m_nonceMode = GetNonceMode(m_isTlsV13, aeadType);
+ m_decryptConnectionID = securityParameters.ConnectionIDPeer;
+ m_encryptConnectionID = securityParameters.ConnectionIDLocal;
+
+ m_decryptUseInnerPlaintext = m_isTlsV13 || !Arrays.IsNullOrEmpty(m_decryptConnectionID);
+ m_encryptUseInnerPlaintext = m_isTlsV13 || !Arrays.IsNullOrEmpty(m_encryptConnectionID);
+
switch (m_nonceMode)
{
case NONCE_RFC5288:
@@ -136,7 +147,7 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
public override int GetCiphertextDecodeLimit(int plaintextLimit)
{
- int innerPlaintextLimit = plaintextLimit + (m_isTlsV13 ? 1 : 0);
+ int innerPlaintextLimit = plaintextLimit + (m_decryptUseInnerPlaintext ? 1 : 0);
return innerPlaintextLimit + m_macSize + m_record_iv_length;
}
@@ -144,7 +155,7 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
public override int GetCiphertextEncodeLimit(int plaintextLength, int plaintextLimit)
{
int innerPlaintextLimit = plaintextLength;
- if (m_isTlsV13)
+ if (m_encryptUseInnerPlaintext)
{
// TODO[tls13] Add support for padding
int maxPadding = 0;
@@ -155,11 +166,18 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
return innerPlaintextLimit + m_macSize + m_record_iv_length;
}
- public override int GetPlaintextLimit(int ciphertextLimit)
+ public override int GetPlaintextDecodeLimit(int ciphertextLimit)
+ {
+ int innerPlaintextLimit = ciphertextLimit - m_macSize - m_record_iv_length;
+
+ return innerPlaintextLimit - (m_decryptUseInnerPlaintext ? 1 : 0);
+ }
+
+ public override int GetPlaintextEncodeLimit(int ciphertextLimit)
{
int innerPlaintextLimit = ciphertextLimit - m_macSize - m_record_iv_length;
- return innerPlaintextLimit - (m_isTlsV13 ? 1 : 0);
+ return innerPlaintextLimit - (m_encryptUseInnerPlaintext ? 1 : 0);
}
public override TlsEncodeResult EncodePlaintext(long seqNo, short contentType, ProtocolVersion recordVersion,
@@ -189,10 +207,10 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
throw new TlsFatalAlert(AlertDescription.internal_error);
}
- int extraLength = m_isTlsV13 ? 1 : 0;
+ // TODO[tls13, cid] If we support adding padding to (D)TLSInnerPlaintext, this will need review
+ int innerPlaintextLength = plaintextLength + (m_encryptUseInnerPlaintext ? 1 : 0);
- // TODO[tls13] If we support adding padding to TLSInnerPlaintext, this will need review
- int encryptionLength = m_encryptCipher.GetOutputSize(plaintextLength + extraLength);
+ int encryptionLength = m_encryptCipher.GetOutputSize(innerPlaintextLength);
int ciphertextLength = m_record_iv_length + encryptionLength;
byte[] output = new byte[headerAllocation + ciphertextLength];
@@ -204,22 +222,25 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
outputPos += m_record_iv_length;
}
- short recordType = m_isTlsV13 ? ContentType.application_data : contentType;
+ short recordType = contentType;
+ if (m_encryptUseInnerPlaintext)
+ {
+ recordType = m_isTlsV13 ? ContentType.application_data : ContentType.tls12_cid;
+ }
byte[] additionalData = GetAdditionalData(seqNo, recordType, recordVersion, ciphertextLength,
- plaintextLength);
+ innerPlaintextLength, m_encryptConnectionID);
try
{
Array.Copy(plaintext, plaintextOffset, output, outputPos, plaintextLength);
- if (m_isTlsV13)
+ if (m_encryptUseInnerPlaintext)
{
output[outputPos + plaintextLength] = (byte)contentType;
}
m_encryptCipher.Init(nonce, m_macSize, additionalData);
- outputPos += m_encryptCipher.DoFinal(output, outputPos, plaintextLength + extraLength, output,
- outputPos);
+ outputPos += m_encryptCipher.DoFinal(output, outputPos, innerPlaintextLength, output, outputPos);
}
catch (IOException)
{
@@ -264,10 +285,10 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
throw new TlsFatalAlert(AlertDescription.internal_error);
}
- int extraLength = m_isTlsV13 ? 1 : 0;
+ // TODO[tls13, cid] If we support adding padding to (D)TLSInnerPlaintext, this will need review
+ int innerPlaintextLength = plaintext.Length + (m_encryptUseInnerPlaintext ? 1 : 0);
- // TODO[tls13] If we support adding padding to TLSInnerPlaintext, this will need review
- int encryptionLength = m_encryptCipher.GetOutputSize(plaintext.Length + extraLength);
+ int encryptionLength = m_encryptCipher.GetOutputSize(innerPlaintextLength);
int ciphertextLength = m_record_iv_length + encryptionLength;
byte[] output = new byte[headerAllocation + ciphertextLength];
@@ -279,22 +300,25 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
outputPos += m_record_iv_length;
}
- short recordType = m_isTlsV13 ? ContentType.application_data : contentType;
+ short recordType = contentType;
+ if (m_encryptUseInnerPlaintext)
+ {
+ recordType = m_isTlsV13 ? ContentType.application_data : ContentType.tls12_cid;
+ }
byte[] additionalData = GetAdditionalData(seqNo, recordType, recordVersion, ciphertextLength,
- plaintext.Length);
+ innerPlaintextLength, m_encryptConnectionID);
try
{
plaintext.CopyTo(output.AsSpan(outputPos));
- if (m_isTlsV13)
+ if (m_encryptUseInnerPlaintext)
{
output[outputPos + plaintext.Length] = (byte)contentType;
}
m_encryptCipher.Init(nonce, m_macSize, additionalData);
- outputPos += m_encryptCipher.DoFinal(output, outputPos, plaintext.Length + extraLength, output,
- outputPos);
+ outputPos += m_encryptCipher.DoFinal(output, outputPos, innerPlaintextLength, output, outputPos);
}
catch (IOException)
{
@@ -318,7 +342,7 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
public override TlsDecodeResult DecodeCiphertext(long seqNo, short recordType, ProtocolVersion recordVersion,
byte[] ciphertext, int ciphertextOffset, int ciphertextLength)
{
- if (GetPlaintextLimit(ciphertextLength) < 0)
+ if (GetPlaintextDecodeLimit(ciphertextLength) < 0)
throw new TlsFatalAlert(AlertDescription.decode_error);
byte[] nonce = new byte[m_decryptNonce.Length + m_record_iv_length];
@@ -343,10 +367,10 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
int encryptionOffset = ciphertextOffset + m_record_iv_length;
int encryptionLength = ciphertextLength - m_record_iv_length;
- int plaintextLength = m_decryptCipher.GetOutputSize(encryptionLength);
+ int innerPlaintextLength = m_decryptCipher.GetOutputSize(encryptionLength);
byte[] additionalData = GetAdditionalData(seqNo, recordType, recordVersion, ciphertextLength,
- plaintextLength);
+ innerPlaintextLength, m_decryptConnectionID);
int outputPos;
try
@@ -373,27 +397,27 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
throw new TlsFatalAlert(AlertDescription.bad_record_mac, e);
}
- if (outputPos != plaintextLength)
+ if (outputPos != innerPlaintextLength)
{
// NOTE: The additional data mechanism for AEAD ciphers requires exact output size prediction.
throw new TlsFatalAlert(AlertDescription.internal_error);
}
short contentType = recordType;
- if (m_isTlsV13)
+ int plaintextLength = innerPlaintextLength;
+
+ if (m_decryptUseInnerPlaintext)
{
// Strip padding and read true content type from TLSInnerPlaintext
- int pos = plaintextLength;
for (;;)
{
- if (--pos < 0)
+ if (--plaintextLength < 0)
throw new TlsFatalAlert(AlertDescription.unexpected_message);
- byte octet = ciphertext[encryptionOffset + pos];
+ byte octet = ciphertext[encryptionOffset + plaintextLength];
if (0 != octet)
{
contentType = (short)(octet & 0xFF);
- plaintextLength = pos;
break;
}
}
@@ -445,6 +469,29 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
}
}
+ protected virtual byte[] GetAdditionalData(long seqNo, short recordType, ProtocolVersion recordVersion,
+ int ciphertextLength, int plaintextLength, byte[] connectionID)
+ {
+ if (Arrays.IsNullOrEmpty(connectionID))
+ return GetAdditionalData(seqNo, recordType, recordVersion, ciphertextLength, plaintextLength);
+
+ /*
+ * seq_num_placeholder + tls12_cid + cid_length + tls12_cid + DTLSCiphertext.version + epoch
+ * + sequence_number + cid + length_of_DTLSInnerPlaintext
+ */
+ int cidLength = connectionID.Length;
+ byte[] additional_data = new byte[23 + cidLength];
+ TlsUtilities.WriteUint64(SequenceNumberPlaceholder, additional_data, 0);
+ TlsUtilities.WriteUint8(ContentType.tls12_cid, additional_data, 8);
+ TlsUtilities.WriteUint8(cidLength, additional_data, 9);
+ TlsUtilities.WriteUint8(ContentType.tls12_cid, additional_data, 10);
+ TlsUtilities.WriteVersion(recordVersion, additional_data, 11);
+ TlsUtilities.WriteUint64(seqNo, additional_data, 13);
+ Array.Copy(connectionID, 0, additional_data, 21, cidLength);
+ TlsUtilities.WriteUint16(plaintextLength, additional_data, 21 + cidLength);
+ return additional_data;
+ }
+
protected virtual void RekeyCipher(SecurityParameters securityParameters, TlsAeadCipherImpl cipher,
byte[] nonce, bool serverSecret)
{
diff --git a/crypto/src/tls/crypto/impl/TlsBlockCipher.cs b/crypto/src/tls/crypto/impl/TlsBlockCipher.cs
index 11eb75186..63d4826b3 100644
--- a/crypto/src/tls/crypto/impl/TlsBlockCipher.cs
+++ b/crypto/src/tls/crypto/impl/TlsBlockCipher.cs
@@ -443,43 +443,43 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
m_encryptThenMac ? 0 : macSize);
bool badMac = (totalPad == 0);
- int dec_output_length = blocks_length - totalPad;
+ int innerPlaintextLength = blocks_length - totalPad;
if (!m_encryptThenMac)
{
- dec_output_length -= macSize;
+ innerPlaintextLength -= macSize;
byte[] expectedMac = m_readMac.CalculateMacConstantTime(seqNo, recordType, m_decryptConnectionID,
- ciphertext, offset, dec_output_length, blocks_length - macSize, m_randomData);
+ ciphertext, offset, innerPlaintextLength, blocks_length - macSize, m_randomData);
badMac |= !TlsUtilities.ConstantTimeAreEqual(macSize, expectedMac, 0, ciphertext,
- offset + dec_output_length);
+ offset + innerPlaintextLength);
}
if (badMac)
throw new TlsFatalAlert(AlertDescription.bad_record_mac);
short contentType = recordType;
+ int plaintextLength = innerPlaintextLength;
+
if (m_decryptUseInnerPlaintext)
{
// Strip padding and read true content type from DTLSInnerPlaintext
- int pos = dec_output_length;
for (;;)
{
- if (--pos < 0)
+ if (--plaintextLength < 0)
throw new TlsFatalAlert(AlertDescription.unexpected_message);
- byte octet = ciphertext[offset + pos];
+ byte octet = ciphertext[offset + plaintextLength];
if (0 != octet)
{
contentType = (short)(octet & 0xFF);
- dec_output_length = pos;
break;
}
}
}
- return new TlsDecodeResult(ciphertext, offset, dec_output_length, contentType);
+ return new TlsDecodeResult(ciphertext, offset, plaintextLength, contentType);
}
public override bool UsesOpaqueRecordType
|