summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crypto/src/tls/crypto/impl/TlsSuiteHmac.cs122
1 files changed, 105 insertions, 17 deletions
diff --git a/crypto/src/tls/crypto/impl/TlsSuiteHmac.cs b/crypto/src/tls/crypto/impl/TlsSuiteHmac.cs
index b4edde760..93f76a161 100644
--- a/crypto/src/tls/crypto/impl/TlsSuiteHmac.cs
+++ b/crypto/src/tls/crypto/impl/TlsSuiteHmac.cs
@@ -5,9 +5,12 @@ using Org.BouncyCastle.Utilities;
 namespace Org.BouncyCastle.Tls.Crypto.Impl
 {
     /// <summary>A generic TLS MAC implementation, acting as an HMAC based on some underlying Digest.</summary>
+    // TODO[api] sealed
     public class TlsSuiteHmac
         : TlsSuiteMac
     {
+        private const long SequenceNumberPlaceholder = -1L;
+
         protected static int GetMacSize(TlsCryptoParameters cryptoParams, TlsMac mac)
         {
             int macSize = mac.MacLength;
@@ -55,22 +58,53 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
 
         public virtual byte[] CalculateMac(long seqNo, short type, byte[] msg, int msgOff, int msgLen)
         {
+            return CalculateMac(seqNo, type, null, msg, msgOff, msgLen);
+        }
+
+        // TODO[api] Replace TlsSuiteMac.CalculateMac (non-span version)
+        public virtual byte[] CalculateMac(long seqNo, short type, byte[] connectionID, byte[] msg, int msgOff,
+            int msgLen)
+        {
 #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
-            return CalculateMac(seqNo, type, msg.AsSpan(msgOff, msgLen));
+            return CalculateMac(seqNo, type, Spans.FromNullableReadOnly(connectionID), msg.AsSpan(msgOff, msgLen));
 #else
             ProtocolVersion serverVersion = m_cryptoParams.ServerVersion;
-            bool isSsl = serverVersion.IsSsl;
 
-            byte[] macHeader = new byte[isSsl ? 11 : 13];
-            TlsUtilities.WriteUint64(seqNo, macHeader, 0);
-            TlsUtilities.WriteUint8(type, macHeader, 8);
-            if (!isSsl)
+            if (serverVersion.IsSsl)
             {
+                byte[] macHeader = new byte[11];
+                TlsUtilities.WriteUint64(seqNo, macHeader, 0);
+                TlsUtilities.WriteUint8(type, macHeader, 8);
+                TlsUtilities.WriteUint16(msgLen, macHeader, 9);
+
+                m_mac.Update(macHeader, 0, macHeader.Length);
+            }
+            else if (!Arrays.IsNullOrEmpty(connectionID))
+            {
+                int cidLength = connectionID.Length;
+                byte[] macHeader = new byte[23 + cidLength];
+                TlsUtilities.WriteUint64(SequenceNumberPlaceholder, macHeader, 0);
+                TlsUtilities.WriteUint8(ContentType.tls12_cid, macHeader, 8);
+                TlsUtilities.WriteUint8(cidLength, macHeader, 9);
+                TlsUtilities.WriteUint8(ContentType.tls12_cid, macHeader, 10);
+                TlsUtilities.WriteVersion(serverVersion, macHeader, 11);
+                TlsUtilities.WriteUint64(seqNo, macHeader, 13);
+                Array.Copy(connectionID, 0, macHeader, 21, cidLength);
+                TlsUtilities.WriteUint16(msgLen, macHeader, 21 + cidLength);
+
+                m_mac.Update(macHeader, 0, macHeader.Length);
+            }
+            else
+            {
+                byte[] macHeader = new byte[13];
+                TlsUtilities.WriteUint64(seqNo, macHeader, 0);
+                TlsUtilities.WriteUint8(type, macHeader, 8);
                 TlsUtilities.WriteVersion(serverVersion, macHeader, 9);
+                TlsUtilities.WriteUint16(msgLen, macHeader, 11);
+
+                m_mac.Update(macHeader, 0, macHeader.Length);
             }
-            TlsUtilities.WriteUint16(msgLen, macHeader, macHeader.Length - 2);
 
-            m_mac.Update(macHeader, 0, macHeader.Length);
             m_mac.Update(msg, msgOff, msgLen);
 
             return Truncate(m_mac.CalculateMac());
@@ -80,19 +114,50 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
 #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
         public virtual byte[] CalculateMac(long seqNo, short type, ReadOnlySpan<byte> message)
         {
+            return CalculateMac(seqNo, type, ReadOnlySpan<byte>.Empty, message);
+        }
+
+        // TODO[api] Replace TlsSuiteMac.CalculateMac (span version)
+        public virtual byte[] CalculateMac(long seqNo, short type, ReadOnlySpan<byte> connectionID,
+            ReadOnlySpan<byte> message)
+        {
             ProtocolVersion serverVersion = m_cryptoParams.ServerVersion;
-            bool isSsl = serverVersion.IsSsl;
 
-            byte[] macHeader = new byte[isSsl ? 11 : 13];
-            TlsUtilities.WriteUint64(seqNo, macHeader, 0);
-            TlsUtilities.WriteUint8(type, macHeader, 8);
-            if (!isSsl)
+            if (serverVersion.IsSsl)
+            {
+                byte[] macHeader = new byte[11];
+                TlsUtilities.WriteUint64(seqNo, macHeader, 0);
+                TlsUtilities.WriteUint8(type, macHeader, 8);
+                TlsUtilities.WriteUint16(message.Length, macHeader, 9);
+
+                m_mac.Update(macHeader);
+            }
+            else if (!connectionID.IsEmpty)
             {
+                int cidLength = connectionID.Length;
+                byte[] macHeader = new byte[23 + cidLength];
+                TlsUtilities.WriteUint64(SequenceNumberPlaceholder, macHeader, 0);
+                TlsUtilities.WriteUint8(ContentType.tls12_cid, macHeader, 8);
+                TlsUtilities.WriteUint8(cidLength, macHeader, 9);
+                TlsUtilities.WriteUint8(ContentType.tls12_cid, macHeader, 10);
+                TlsUtilities.WriteVersion(serverVersion, macHeader, 11);
+                TlsUtilities.WriteUint64(seqNo, macHeader, 13);
+                connectionID.CopyTo(macHeader.AsSpan(21));
+                TlsUtilities.WriteUint16(message.Length, macHeader, 21 + cidLength);
+
+                m_mac.Update(macHeader);
+            }
+            else
+            {
+                byte[] macHeader = new byte[13];
+                TlsUtilities.WriteUint64(seqNo, macHeader, 0);
+                TlsUtilities.WriteUint8(type, macHeader, 8);
                 TlsUtilities.WriteVersion(serverVersion, macHeader, 9);
+                TlsUtilities.WriteUint16(message.Length, macHeader, 11);
+
+                m_mac.Update(macHeader);
             }
-            TlsUtilities.WriteUint16(message.Length, macHeader, macHeader.Length - 2);
 
-            m_mac.Update(macHeader);
             m_mac.Update(message);
 
             return Truncate(m_mac.CalculateMac());
@@ -102,16 +167,23 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
         public virtual byte[] CalculateMacConstantTime(long seqNo, short type, byte[] msg, int msgOff, int msgLen,
             int fullLength, byte[] dummyData)
         {
+            return CalculateMacConstantTime(seqNo, type, null, msg, msgOff, msgLen, fullLength, dummyData);
+        }
+
+        // TODO[api] Replace TlsSuiteMac.CalculateMacConstantTime
+        public virtual byte[] CalculateMacConstantTime(long seqNo, short type, byte[] connectionID, byte[] msg,
+            int msgOff, int msgLen, int fullLength, byte[] dummyData)
+        {
             /*
              * Actual MAC only calculated on 'length' bytes...
              */
-            byte[] result = CalculateMac(seqNo, type, msg, msgOff, msgLen);
+            byte[] result = CalculateMac(seqNo, type, connectionID, msg, msgOff, msgLen);
 
             /*
              * ...but ensure a constant number of complete digest blocks are processed (as many as would
              * be needed for 'fullLength' bytes of input).
              */
-            int headerLength = TlsImplUtilities.IsSsl(m_cryptoParams) ? 11 : 13;
+            int headerLength = GetHeaderLength(connectionID);
 
             // How many extra full blocks do we need to calculate?
             int extra = GetDigestBlockCount(headerLength + fullLength) - GetDigestBlockCount(headerLength + msgLen);
@@ -136,6 +208,22 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
             return (inputLength + m_digestOverhead) / m_digestBlockSize;
         }
 
+        protected virtual int GetHeaderLength(byte[] connectionID)
+        {
+            if (m_cryptoParams.ServerVersion.IsSsl)
+            {
+                return 11;
+            }
+            else if (!Arrays.IsNullOrEmpty(connectionID))
+            {
+                return 23 + connectionID.Length;
+            }
+            else
+            {
+                return 13;
+            }
+        }
+
         protected virtual byte[] Truncate(byte[] bs)
         {
             if (bs.Length <= m_macSize)