diff --git a/crypto/src/tls/crypto/impl/TlsSuiteHmac.cs b/crypto/src/tls/crypto/impl/TlsSuiteHmac.cs
index b4edde760..93f76a161 100644
--- a/crypto/src/tls/crypto/impl/TlsSuiteHmac.cs
+++ b/crypto/src/tls/crypto/impl/TlsSuiteHmac.cs
@@ -5,9 +5,12 @@ using Org.BouncyCastle.Utilities;
namespace Org.BouncyCastle.Tls.Crypto.Impl
{
/// <summary>A generic TLS MAC implementation, acting as an HMAC based on some underlying Digest.</summary>
+ // TODO[api] sealed
public class TlsSuiteHmac
: TlsSuiteMac
{
+ private const long SequenceNumberPlaceholder = -1L;
+
protected static int GetMacSize(TlsCryptoParameters cryptoParams, TlsMac mac)
{
int macSize = mac.MacLength;
@@ -55,22 +58,53 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
public virtual byte[] CalculateMac(long seqNo, short type, byte[] msg, int msgOff, int msgLen)
{
+ return CalculateMac(seqNo, type, null, msg, msgOff, msgLen);
+ }
+
+ // TODO[api] Replace TlsSuiteMac.CalculateMac (non-span version)
+ public virtual byte[] CalculateMac(long seqNo, short type, byte[] connectionID, byte[] msg, int msgOff,
+ int msgLen)
+ {
#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
- return CalculateMac(seqNo, type, msg.AsSpan(msgOff, msgLen));
+ return CalculateMac(seqNo, type, Spans.FromNullableReadOnly(connectionID), msg.AsSpan(msgOff, msgLen));
#else
ProtocolVersion serverVersion = m_cryptoParams.ServerVersion;
- bool isSsl = serverVersion.IsSsl;
- byte[] macHeader = new byte[isSsl ? 11 : 13];
- TlsUtilities.WriteUint64(seqNo, macHeader, 0);
- TlsUtilities.WriteUint8(type, macHeader, 8);
- if (!isSsl)
+ if (serverVersion.IsSsl)
{
+ byte[] macHeader = new byte[11];
+ TlsUtilities.WriteUint64(seqNo, macHeader, 0);
+ TlsUtilities.WriteUint8(type, macHeader, 8);
+ TlsUtilities.WriteUint16(msgLen, macHeader, 9);
+
+ m_mac.Update(macHeader, 0, macHeader.Length);
+ }
+ else if (!Arrays.IsNullOrEmpty(connectionID))
+ {
+ int cidLength = connectionID.Length;
+ byte[] macHeader = new byte[23 + cidLength];
+ TlsUtilities.WriteUint64(SequenceNumberPlaceholder, macHeader, 0);
+ TlsUtilities.WriteUint8(ContentType.tls12_cid, macHeader, 8);
+ TlsUtilities.WriteUint8(cidLength, macHeader, 9);
+ TlsUtilities.WriteUint8(ContentType.tls12_cid, macHeader, 10);
+ TlsUtilities.WriteVersion(serverVersion, macHeader, 11);
+ TlsUtilities.WriteUint64(seqNo, macHeader, 13);
+ Array.Copy(connectionID, 0, macHeader, 21, cidLength);
+ TlsUtilities.WriteUint16(msgLen, macHeader, 21 + cidLength);
+
+ m_mac.Update(macHeader, 0, macHeader.Length);
+ }
+ else
+ {
+ byte[] macHeader = new byte[13];
+ TlsUtilities.WriteUint64(seqNo, macHeader, 0);
+ TlsUtilities.WriteUint8(type, macHeader, 8);
TlsUtilities.WriteVersion(serverVersion, macHeader, 9);
+ TlsUtilities.WriteUint16(msgLen, macHeader, 11);
+
+ m_mac.Update(macHeader, 0, macHeader.Length);
}
- TlsUtilities.WriteUint16(msgLen, macHeader, macHeader.Length - 2);
- m_mac.Update(macHeader, 0, macHeader.Length);
m_mac.Update(msg, msgOff, msgLen);
return Truncate(m_mac.CalculateMac());
@@ -80,19 +114,50 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
public virtual byte[] CalculateMac(long seqNo, short type, ReadOnlySpan<byte> message)
{
+ return CalculateMac(seqNo, type, ReadOnlySpan<byte>.Empty, message);
+ }
+
+ // TODO[api] Replace TlsSuiteMac.CalculateMac (span version)
+ public virtual byte[] CalculateMac(long seqNo, short type, ReadOnlySpan<byte> connectionID,
+ ReadOnlySpan<byte> message)
+ {
ProtocolVersion serverVersion = m_cryptoParams.ServerVersion;
- bool isSsl = serverVersion.IsSsl;
- byte[] macHeader = new byte[isSsl ? 11 : 13];
- TlsUtilities.WriteUint64(seqNo, macHeader, 0);
- TlsUtilities.WriteUint8(type, macHeader, 8);
- if (!isSsl)
+ if (serverVersion.IsSsl)
+ {
+ byte[] macHeader = new byte[11];
+ TlsUtilities.WriteUint64(seqNo, macHeader, 0);
+ TlsUtilities.WriteUint8(type, macHeader, 8);
+ TlsUtilities.WriteUint16(message.Length, macHeader, 9);
+
+ m_mac.Update(macHeader);
+ }
+ else if (!connectionID.IsEmpty)
{
+ int cidLength = connectionID.Length;
+ byte[] macHeader = new byte[23 + cidLength];
+ TlsUtilities.WriteUint64(SequenceNumberPlaceholder, macHeader, 0);
+ TlsUtilities.WriteUint8(ContentType.tls12_cid, macHeader, 8);
+ TlsUtilities.WriteUint8(cidLength, macHeader, 9);
+ TlsUtilities.WriteUint8(ContentType.tls12_cid, macHeader, 10);
+ TlsUtilities.WriteVersion(serverVersion, macHeader, 11);
+ TlsUtilities.WriteUint64(seqNo, macHeader, 13);
+ connectionID.CopyTo(macHeader.AsSpan(21));
+ TlsUtilities.WriteUint16(message.Length, macHeader, 21 + cidLength);
+
+ m_mac.Update(macHeader);
+ }
+ else
+ {
+ byte[] macHeader = new byte[13];
+ TlsUtilities.WriteUint64(seqNo, macHeader, 0);
+ TlsUtilities.WriteUint8(type, macHeader, 8);
TlsUtilities.WriteVersion(serverVersion, macHeader, 9);
+ TlsUtilities.WriteUint16(message.Length, macHeader, 11);
+
+ m_mac.Update(macHeader);
}
- TlsUtilities.WriteUint16(message.Length, macHeader, macHeader.Length - 2);
- m_mac.Update(macHeader);
m_mac.Update(message);
return Truncate(m_mac.CalculateMac());
@@ -102,16 +167,23 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
public virtual byte[] CalculateMacConstantTime(long seqNo, short type, byte[] msg, int msgOff, int msgLen,
int fullLength, byte[] dummyData)
{
+ return CalculateMacConstantTime(seqNo, type, null, msg, msgOff, msgLen, fullLength, dummyData);
+ }
+
+ // TODO[api] Replace TlsSuiteMac.CalculateMacConstantTime
+ public virtual byte[] CalculateMacConstantTime(long seqNo, short type, byte[] connectionID, byte[] msg,
+ int msgOff, int msgLen, int fullLength, byte[] dummyData)
+ {
/*
* Actual MAC only calculated on 'length' bytes...
*/
- byte[] result = CalculateMac(seqNo, type, msg, msgOff, msgLen);
+ byte[] result = CalculateMac(seqNo, type, connectionID, msg, msgOff, msgLen);
/*
* ...but ensure a constant number of complete digest blocks are processed (as many as would
* be needed for 'fullLength' bytes of input).
*/
- int headerLength = TlsImplUtilities.IsSsl(m_cryptoParams) ? 11 : 13;
+ int headerLength = GetHeaderLength(connectionID);
// How many extra full blocks do we need to calculate?
int extra = GetDigestBlockCount(headerLength + fullLength) - GetDigestBlockCount(headerLength + msgLen);
@@ -136,6 +208,22 @@ namespace Org.BouncyCastle.Tls.Crypto.Impl
return (inputLength + m_digestOverhead) / m_digestBlockSize;
}
+ protected virtual int GetHeaderLength(byte[] connectionID)
+ {
+ if (m_cryptoParams.ServerVersion.IsSsl)
+ {
+ return 11;
+ }
+ else if (!Arrays.IsNullOrEmpty(connectionID))
+ {
+ return 23 + connectionID.Length;
+ }
+ else
+ {
+ return 13;
+ }
+ }
+
protected virtual byte[] Truncate(byte[] bs)
{
if (bs.Length <= m_macSize)
|