diff options
-rw-r--r-- | crypto/src/tls/crypto/impl/TlsSuiteHmac.cs | 122 |
1 files changed, 105 insertions, 17 deletions
diff --git a/crypto/src/tls/crypto/impl/TlsSuiteHmac.cs b/crypto/src/tls/crypto/impl/TlsSuiteHmac.cs index b4edde760..93f76a161 100644 --- a/crypto/src/tls/crypto/impl/TlsSuiteHmac.cs +++ b/crypto/src/tls/crypto/impl/TlsSuiteHmac.cs @@ -5,9 +5,12 @@ using Org.BouncyCastle.Utilities; namespace Org.BouncyCastle.Tls.Crypto.Impl { /// <summary>A generic TLS MAC implementation, acting as an HMAC based on some underlying Digest.</summary> + // TODO[api] sealed public class TlsSuiteHmac : TlsSuiteMac { + private const long SequenceNumberPlaceholder = -1L; + protected static int GetMacSize(TlsCryptoParameters cryptoParams, TlsMac mac) { int macSize = mac.MacLength; @@ -55,22 +58,53 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl public virtual byte[] CalculateMac(long seqNo, short type, byte[] msg, int msgOff, int msgLen) { + return CalculateMac(seqNo, type, null, msg, msgOff, msgLen); + } + + // TODO[api] Replace TlsSuiteMac.CalculateMac (non-span version) + public virtual byte[] CalculateMac(long seqNo, short type, byte[] connectionID, byte[] msg, int msgOff, + int msgLen) + { #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER - return CalculateMac(seqNo, type, msg.AsSpan(msgOff, msgLen)); + return CalculateMac(seqNo, type, Spans.FromNullableReadOnly(connectionID), msg.AsSpan(msgOff, msgLen)); #else ProtocolVersion serverVersion = m_cryptoParams.ServerVersion; - bool isSsl = serverVersion.IsSsl; - byte[] macHeader = new byte[isSsl ? 11 : 13]; - TlsUtilities.WriteUint64(seqNo, macHeader, 0); - TlsUtilities.WriteUint8(type, macHeader, 8); - if (!isSsl) + if (serverVersion.IsSsl) { + byte[] macHeader = new byte[11]; + TlsUtilities.WriteUint64(seqNo, macHeader, 0); + TlsUtilities.WriteUint8(type, macHeader, 8); + TlsUtilities.WriteUint16(msgLen, macHeader, 9); + + m_mac.Update(macHeader, 0, macHeader.Length); + } + else if (!Arrays.IsNullOrEmpty(connectionID)) + { + int cidLength = connectionID.Length; + byte[] macHeader = new byte[23 + cidLength]; + TlsUtilities.WriteUint64(SequenceNumberPlaceholder, macHeader, 0); + TlsUtilities.WriteUint8(ContentType.tls12_cid, macHeader, 8); + TlsUtilities.WriteUint8(cidLength, macHeader, 9); + TlsUtilities.WriteUint8(ContentType.tls12_cid, macHeader, 10); + TlsUtilities.WriteVersion(serverVersion, macHeader, 11); + TlsUtilities.WriteUint64(seqNo, macHeader, 13); + Array.Copy(connectionID, 0, macHeader, 21, cidLength); + TlsUtilities.WriteUint16(msgLen, macHeader, 21 + cidLength); + + m_mac.Update(macHeader, 0, macHeader.Length); + } + else + { + byte[] macHeader = new byte[13]; + TlsUtilities.WriteUint64(seqNo, macHeader, 0); + TlsUtilities.WriteUint8(type, macHeader, 8); TlsUtilities.WriteVersion(serverVersion, macHeader, 9); + TlsUtilities.WriteUint16(msgLen, macHeader, 11); + + m_mac.Update(macHeader, 0, macHeader.Length); } - TlsUtilities.WriteUint16(msgLen, macHeader, macHeader.Length - 2); - m_mac.Update(macHeader, 0, macHeader.Length); m_mac.Update(msg, msgOff, msgLen); return Truncate(m_mac.CalculateMac()); @@ -80,19 +114,50 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER public virtual byte[] CalculateMac(long seqNo, short type, ReadOnlySpan<byte> message) { + return CalculateMac(seqNo, type, ReadOnlySpan<byte>.Empty, message); + } + + // TODO[api] Replace TlsSuiteMac.CalculateMac (span version) + public virtual byte[] CalculateMac(long seqNo, short type, ReadOnlySpan<byte> connectionID, + ReadOnlySpan<byte> message) + { ProtocolVersion serverVersion = m_cryptoParams.ServerVersion; - bool isSsl = serverVersion.IsSsl; - byte[] macHeader = new byte[isSsl ? 11 : 13]; - TlsUtilities.WriteUint64(seqNo, macHeader, 0); - TlsUtilities.WriteUint8(type, macHeader, 8); - if (!isSsl) + if (serverVersion.IsSsl) + { + byte[] macHeader = new byte[11]; + TlsUtilities.WriteUint64(seqNo, macHeader, 0); + TlsUtilities.WriteUint8(type, macHeader, 8); + TlsUtilities.WriteUint16(message.Length, macHeader, 9); + + m_mac.Update(macHeader); + } + else if (!connectionID.IsEmpty) { + int cidLength = connectionID.Length; + byte[] macHeader = new byte[23 + cidLength]; + TlsUtilities.WriteUint64(SequenceNumberPlaceholder, macHeader, 0); + TlsUtilities.WriteUint8(ContentType.tls12_cid, macHeader, 8); + TlsUtilities.WriteUint8(cidLength, macHeader, 9); + TlsUtilities.WriteUint8(ContentType.tls12_cid, macHeader, 10); + TlsUtilities.WriteVersion(serverVersion, macHeader, 11); + TlsUtilities.WriteUint64(seqNo, macHeader, 13); + connectionID.CopyTo(macHeader.AsSpan(21)); + TlsUtilities.WriteUint16(message.Length, macHeader, 21 + cidLength); + + m_mac.Update(macHeader); + } + else + { + byte[] macHeader = new byte[13]; + TlsUtilities.WriteUint64(seqNo, macHeader, 0); + TlsUtilities.WriteUint8(type, macHeader, 8); TlsUtilities.WriteVersion(serverVersion, macHeader, 9); + TlsUtilities.WriteUint16(message.Length, macHeader, 11); + + m_mac.Update(macHeader); } - TlsUtilities.WriteUint16(message.Length, macHeader, macHeader.Length - 2); - m_mac.Update(macHeader); m_mac.Update(message); return Truncate(m_mac.CalculateMac()); @@ -102,16 +167,23 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl public virtual byte[] CalculateMacConstantTime(long seqNo, short type, byte[] msg, int msgOff, int msgLen, int fullLength, byte[] dummyData) { + return CalculateMacConstantTime(seqNo, type, null, msg, msgOff, msgLen, fullLength, dummyData); + } + + // TODO[api] Replace TlsSuiteMac.CalculateMacConstantTime + public virtual byte[] CalculateMacConstantTime(long seqNo, short type, byte[] connectionID, byte[] msg, + int msgOff, int msgLen, int fullLength, byte[] dummyData) + { /* * Actual MAC only calculated on 'length' bytes... */ - byte[] result = CalculateMac(seqNo, type, msg, msgOff, msgLen); + byte[] result = CalculateMac(seqNo, type, connectionID, msg, msgOff, msgLen); /* * ...but ensure a constant number of complete digest blocks are processed (as many as would * be needed for 'fullLength' bytes of input). */ - int headerLength = TlsImplUtilities.IsSsl(m_cryptoParams) ? 11 : 13; + int headerLength = GetHeaderLength(connectionID); // How many extra full blocks do we need to calculate? int extra = GetDigestBlockCount(headerLength + fullLength) - GetDigestBlockCount(headerLength + msgLen); @@ -136,6 +208,22 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl return (inputLength + m_digestOverhead) / m_digestBlockSize; } + protected virtual int GetHeaderLength(byte[] connectionID) + { + if (m_cryptoParams.ServerVersion.IsSsl) + { + return 11; + } + else if (!Arrays.IsNullOrEmpty(connectionID)) + { + return 23 + connectionID.Length; + } + else + { + return 13; + } + } + protected virtual byte[] Truncate(byte[] bs) { if (bs.Length <= m_macSize) |