summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crypto/src/tls/crypto/impl/AbstractTlsCipher.cs5
-rw-r--r--crypto/src/tls/crypto/impl/TlsAeadCipher.cs8
-rw-r--r--crypto/src/tls/crypto/impl/TlsBlockCipher.cs169
3 files changed, 136 insertions, 46 deletions
diff --git a/crypto/src/tls/crypto/impl/AbstractTlsCipher.cs b/crypto/src/tls/crypto/impl/AbstractTlsCipher.cs
index 03d6ddba2..5315eb965 100644
--- a/crypto/src/tls/crypto/impl/AbstractTlsCipher.cs
+++ b/crypto/src/tls/crypto/impl/AbstractTlsCipher.cs
@@ -12,7 +12,10 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
         public abstract int GetCiphertextEncodeLimit(int plaintextLength, int plaintextLimit);
 
         // TODO[api] Remove this method from TlsCipher
-        public abstract int GetPlaintextLimit(int ciphertextLimit);
+        public virtual int GetPlaintextLimit(int ciphertextLimit)
+        {
+            return GetPlaintextEncodeLimit(ciphertextLimit);
+        }
 
         // TODO[api] Add to TlsCipher
         public virtual int GetPlaintextDecodeLimit(int ciphertextLimit)
diff --git a/crypto/src/tls/crypto/impl/TlsAeadCipher.cs b/crypto/src/tls/crypto/impl/TlsAeadCipher.cs
index f238a3afb..f3ff44285 100644
--- a/crypto/src/tls/crypto/impl/TlsAeadCipher.cs
+++ b/crypto/src/tls/crypto/impl/TlsAeadCipher.cs
@@ -136,7 +136,9 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
 
         public override int GetCiphertextDecodeLimit(int plaintextLimit)
         {
-            return plaintextLimit + m_macSize + m_record_iv_length + (m_isTlsV13 ? 1 : 0);
+            int innerPlaintextLimit = plaintextLimit + (m_isTlsV13 ? 1 : 0);
+
+            return innerPlaintextLimit + m_macSize + m_record_iv_length;
         }
 
         public override int GetCiphertextEncodeLimit(int plaintextLength, int plaintextLimit)
@@ -155,7 +157,9 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
 
         public override int GetPlaintextLimit(int ciphertextLimit)
         {
-            return ciphertextLimit - m_macSize - m_record_iv_length - (m_isTlsV13 ? 1 : 0);
+            int innerPlaintextLimit = ciphertextLimit - m_macSize - m_record_iv_length;
+
+            return  innerPlaintextLimit - (m_isTlsV13 ? 1 : 0);
         }
 
         public override TlsEncodeResult EncodePlaintext(long seqNo, short contentType, ProtocolVersion recordVersion,
diff --git a/crypto/src/tls/crypto/impl/TlsBlockCipher.cs b/crypto/src/tls/crypto/impl/TlsBlockCipher.cs
index 1e6889982..11eb75186 100644
--- a/crypto/src/tls/crypto/impl/TlsBlockCipher.cs
+++ b/crypto/src/tls/crypto/impl/TlsBlockCipher.cs
@@ -18,7 +18,9 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
         protected readonly bool m_useExtraPadding;
 
         protected readonly TlsBlockCipherImpl m_decryptCipher, m_encryptCipher;
-        protected readonly TlsSuiteMac m_readMac, m_writeMac;
+        protected readonly TlsSuiteHmac m_readMac, m_writeMac;
+        protected readonly byte[] m_decryptConnectionID, m_encryptConnectionID;
+        protected readonly bool m_decryptUseInnerPlaintext, m_encryptUseInnerPlaintext;
 
         /// <exception cref="IOException"/>
         public TlsBlockCipher(TlsCryptoParameters cryptoParams, TlsBlockCipherImpl encryptCipher,
@@ -30,6 +32,12 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
             if (TlsImplUtilities.IsTlsV13(negotiatedVersion))
                 throw new TlsFatalAlert(AlertDescription.internal_error);
 
+            m_decryptConnectionID = securityParameters.ConnectionIDPeer;
+            m_encryptConnectionID = securityParameters.ConnectionIDLocal;
+
+            m_decryptUseInnerPlaintext = !Arrays.IsNullOrEmpty(m_decryptConnectionID);
+            m_encryptUseInnerPlaintext = !Arrays.IsNullOrEmpty(m_encryptConnectionID);
+
             this.m_cryptoParams = cryptoParams;
             this.m_randomData = cryptoParams.NonceGenerator.GenerateNonce(256);
 
@@ -151,49 +159,49 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
         {
             int blockSize = m_decryptCipher.GetBlockSize();
             int macSize = m_readMac.Size;
-            int maxPadding = 256;
+            int maxExtraPadding = 256;
 
-            return GetCiphertextLength(blockSize, macSize, maxPadding, plaintextLimit);
+            int innerPlaintextLimit = plaintextLimit + (m_decryptUseInnerPlaintext ? 1 : 0);
+
+            return GetCiphertextLength(blockSize, macSize, maxExtraPadding, innerPlaintextLimit);
         }
 
         public override int GetCiphertextEncodeLimit(int plaintextLength, int plaintextLimit)
         {
             int blockSize = m_encryptCipher.GetBlockSize();
             int macSize = m_writeMac.Size;
-            int maxPadding = m_useExtraPadding ? 256 : blockSize;
+            int maxExtraPadding = m_useExtraPadding ? 256 : blockSize;
+
+            int innerPlaintextLimit = plaintextLength;
+            if (m_encryptUseInnerPlaintext)
+            {
+                // TODO[cid] Add support for padding
+                int maxPadding = 0;
+
+                innerPlaintextLimit = 1 + System.Math.Min(plaintextLimit, plaintextLength + maxPadding);
+            }
 
-            return GetCiphertextLength(blockSize, macSize, maxPadding, plaintextLength);
+            return GetCiphertextLength(blockSize, macSize, maxExtraPadding, innerPlaintextLimit);
         }
 
-        public override int GetPlaintextLimit(int ciphertextLimit)
+        public override int GetPlaintextDecodeLimit(int ciphertextLimit)
         {
-            int blockSize = m_encryptCipher.GetBlockSize();
-            int macSize = m_writeMac.Size;
+            int blockSize = m_decryptCipher.GetBlockSize();
+            int macSize = m_readMac.Size;
 
-            int plaintextLimit = ciphertextLimit;
+            int innerPlaintextLimit = GetPlaintextLength(blockSize, macSize, ciphertextLimit);
 
-            // Leave room for the MAC, and require block-alignment
-            if (m_encryptThenMac)
-            {
-                plaintextLimit -= macSize;
-                plaintextLimit -= plaintextLimit % blockSize;
-            }
-            else
-            {
-                plaintextLimit -= plaintextLimit % blockSize;
-                plaintextLimit -= macSize;
-            }
+            return innerPlaintextLimit - (m_decryptUseInnerPlaintext ? 1 : 0);
+        }
 
-            // Minimum 1 byte of padding
-            --plaintextLimit;
+        public override int GetPlaintextEncodeLimit(int ciphertextLimit)
+        {
+            int blockSize = m_encryptCipher.GetBlockSize();
+            int macSize = m_writeMac.Size;
 
-            // An explicit IV consumes 1 block
-            if (m_useExplicitIV)
-            {
-                plaintextLimit -= blockSize;
-            }
+            int innerPlaintextLimit = GetPlaintextLength(blockSize, macSize, ciphertextLimit);
 
-            return plaintextLimit;
+            return innerPlaintextLimit - (m_encryptUseInnerPlaintext ? 1 : 0);
         }
 
         public override TlsEncodeResult EncodePlaintext(long seqNo, short contentType, ProtocolVersion recordVersion,
@@ -205,7 +213,10 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
             int blockSize = m_encryptCipher.GetBlockSize();
             int macSize = m_writeMac.Size;
 
-            int enc_input_length = len;
+            // TODO[cid] If we support adding padding to DTLSInnerPlaintext, this will need review
+            int innerPlaintextLength = len + (m_encryptUseInnerPlaintext ? 1 : 0);
+
+            int enc_input_length = innerPlaintextLength;
             if (!m_encryptThenMac)
             {
                 enc_input_length += macSize;
@@ -220,7 +231,7 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
                 padding_length += actualExtraPadBlocks * blockSize;
             }
 
-            int totalSize = len + macSize + padding_length;
+            int totalSize = innerPlaintextLength + macSize + padding_length;
             if (m_useExplicitIV)
             {
                 totalSize += blockSize;
@@ -237,12 +248,22 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
                 outOff += blockSize;
             }
 
+            int innerPlaintextOffset = outOff;
+
             Array.Copy(plaintext, offset, outBuf, outOff, len);
             outOff += len;
 
+            short recordType = contentType;
+            if (m_encryptUseInnerPlaintext)
+            {
+                outBuf[outOff++] = (byte)contentType;
+                recordType = ContentType.tls12_cid;
+            }
+
             if (!m_encryptThenMac)
             {
-                byte[] mac = m_writeMac.CalculateMac(seqNo, contentType, plaintext, offset, len);
+                byte[] mac = m_writeMac.CalculateMac(seqNo, recordType, m_encryptConnectionID, outBuf,
+                    innerPlaintextOffset, innerPlaintextLength);
                 Array.Copy(mac, 0, outBuf, outOff, mac.Length);
                 outOff += mac.Length;
             }
@@ -257,7 +278,7 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
 
             if (m_encryptThenMac)
             {
-                byte[] mac = m_writeMac.CalculateMac(seqNo, contentType, outBuf, headerAllocation,
+                byte[] mac = m_writeMac.CalculateMac(seqNo, recordType, outBuf, headerAllocation,
                     outOff - headerAllocation);
                 Array.Copy(mac, 0, outBuf, outOff, mac.Length);
                 outOff += mac.Length;
@@ -266,7 +287,7 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
             if (outOff != outBuf.Length)
                 throw new TlsFatalAlert(AlertDescription.internal_error);
 
-            return new TlsEncodeResult(outBuf, 0, outBuf.Length, contentType);
+            return new TlsEncodeResult(outBuf, 0, outBuf.Length, recordType);
 #endif
         }
 
@@ -277,7 +298,10 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
             int blockSize = m_encryptCipher.GetBlockSize();
             int macSize = m_writeMac.Size;
 
-            int enc_input_length = plaintext.Length;
+            // TODO[cid] If we support adding padding to DTLSInnerPlaintext, this will need review
+            int innerPlaintextLength = plaintext.Length + (m_encryptUseInnerPlaintext ? 1 : 0);
+
+            int enc_input_length = innerPlaintextLength;
             if (!m_encryptThenMac)
             {
                 enc_input_length += macSize;
@@ -292,7 +316,7 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
                 padding_length += actualExtraPadBlocks * blockSize;
             }
 
-            int totalSize = plaintext.Length + macSize + padding_length;
+            int totalSize = innerPlaintextLength + macSize + padding_length;
             if (m_useExplicitIV)
             {
                 totalSize += blockSize;
@@ -309,12 +333,22 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
                 outOff += blockSize;
             }
 
+            int innerPlaintextOffset = outOff;
+
             plaintext.CopyTo(outBuf.AsSpan(outOff));
             outOff += plaintext.Length;
 
+            short recordType = contentType;
+            if (m_encryptUseInnerPlaintext)
+            {
+                outBuf[outOff++] = (byte)contentType;
+                recordType = ContentType.tls12_cid;
+            }
+
             if (!m_encryptThenMac)
             {
-                byte[] mac = m_writeMac.CalculateMac(seqNo, contentType, plaintext);
+                byte[] mac = m_writeMac.CalculateMac(seqNo, recordType, m_encryptConnectionID,
+                    outBuf.AsSpan(innerPlaintextOffset, innerPlaintextLength));
                 mac.CopyTo(outBuf.AsSpan(outOff));
                 outOff += mac.Length;
             }
@@ -329,8 +363,8 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
 
             if (m_encryptThenMac)
             {
-                byte[] mac = m_writeMac.CalculateMac(seqNo, contentType, outBuf, headerAllocation,
-                    outOff - headerAllocation);
+                byte[] mac = m_writeMac.CalculateMac(seqNo, recordType, m_encryptConnectionID,
+                    outBuf.AsSpan(headerAllocation, outOff - headerAllocation));
                 Array.Copy(mac, 0, outBuf, outOff, mac.Length);
                 outOff += mac.Length;
             }
@@ -338,7 +372,7 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
             if (outOff != outBuf.Length)
                 throw new TlsFatalAlert(AlertDescription.internal_error);
 
-            return new TlsEncodeResult(outBuf, 0, outBuf.Length, contentType);
+            return new TlsEncodeResult(outBuf, 0, outBuf.Length, recordType);
         }
 #endif
 
@@ -377,7 +411,8 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
 
             if (m_encryptThenMac)
             {
-                byte[] expectedMac = m_readMac.CalculateMac(seqNo, recordType, ciphertext, offset, len - macSize);
+                byte[] expectedMac = m_readMac.CalculateMac(seqNo, recordType, m_decryptConnectionID, ciphertext,
+                    offset, len - macSize);
 
                 bool checkMac = TlsUtilities.ConstantTimeAreEqual(macSize, expectedMac, 0, ciphertext,
                     offset + len - macSize);
@@ -414,8 +449,8 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
             {
                 dec_output_length -= macSize;
 
-                byte[] expectedMac = m_readMac.CalculateMacConstantTime(seqNo, recordType, ciphertext, offset,
-                    dec_output_length, blocks_length - macSize, m_randomData);
+                byte[] expectedMac = m_readMac.CalculateMacConstantTime(seqNo, recordType, m_decryptConnectionID,
+                    ciphertext, offset, dec_output_length, blocks_length - macSize, m_randomData);
 
                 badMac |= !TlsUtilities.ConstantTimeAreEqual(macSize, expectedMac, 0, ciphertext,
                     offset + dec_output_length);
@@ -424,7 +459,27 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
             if (badMac)
                 throw new TlsFatalAlert(AlertDescription.bad_record_mac);
 
-            return new TlsDecodeResult(ciphertext, offset, dec_output_length, recordType);
+            short contentType = recordType;
+            if (m_decryptUseInnerPlaintext)
+            {
+                // Strip padding and read true content type from DTLSInnerPlaintext
+                int pos = dec_output_length;
+                for (;;)
+                {
+                    if (--pos < 0)
+                        throw new TlsFatalAlert(AlertDescription.unexpected_message);
+
+                    byte octet = ciphertext[offset + pos];
+                    if (0 != octet)
+                    {
+                        contentType = (short)(octet & 0xFF);
+                        dec_output_length = pos;
+                        break;
+                    }
+                }
+            }
+
+            return new TlsDecodeResult(ciphertext, offset, dec_output_length, contentType);
         }
 
         public override bool UsesOpaqueRecordType
@@ -514,5 +569,33 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
 
             return ciphertextLength;
         }
+
+        protected virtual int GetPlaintextLength(int blockSize, int macSize, int ciphertextLength)
+        {
+            int plaintextLength = ciphertextLength;
+
+            // Leave room for the MAC, and require block-alignment
+            if (m_encryptThenMac)
+            {
+                plaintextLength -= macSize;
+                plaintextLength -= plaintextLength % blockSize;
+            }
+            else
+            {
+                plaintextLength -= plaintextLength % blockSize;
+                plaintextLength -= macSize;
+            }
+
+            // Minimum 1 byte of padding
+            --plaintextLength;
+
+            // An explicit IV consumes 1 block
+            if (m_useExplicitIV)
+            {
+                plaintextLength -= blockSize;
+            }
+
+            return plaintextLength;
+        }
     }
 }