From d7557597d18a313c7e573b11e48ba8648d8a50a9 Mon Sep 17 00:00:00 2001 From: Peter Dettman Date: Wed, 17 May 2023 18:16:36 +0700 Subject: DTLS: Improve DtlsVerifier performance --- crypto/src/tls/DtlsClientProtocol.cs | 2 +- crypto/src/tls/DtlsRecordLayer.cs | 55 +++++++------- crypto/src/tls/DtlsReliableHandshake.cs | 50 ++++++------- crypto/src/tls/DtlsVerifier.cs | 108 ++++++++++++--------------- crypto/src/tls/TlsClientProtocol.cs | 2 +- crypto/test/src/tls/test/DtlsProtocolTest.cs | 32 +++++++- 6 files changed, 132 insertions(+), 117 deletions(-) diff --git a/crypto/src/tls/DtlsClientProtocol.cs b/crypto/src/tls/DtlsClientProtocol.cs index 72484e178..c1bad2e6f 100644 --- a/crypto/src/tls/DtlsClientProtocol.cs +++ b/crypto/src/tls/DtlsClientProtocol.cs @@ -525,7 +525,7 @@ namespace Org.BouncyCastle.Tls ClientHello clientHello = new ClientHello(legacy_version, securityParameters.ClientRandom, session_id, - TlsUtilities.EmptyBytes, state.offeredCipherSuites, state.clientExtensions, 0); + cookie: TlsUtilities.EmptyBytes, state.offeredCipherSuites, state.clientExtensions, 0); MemoryStream buf = new MemoryStream(); clientHello.Encode(state.clientContext, buf); diff --git a/crypto/src/tls/DtlsRecordLayer.cs b/crypto/src/tls/DtlsRecordLayer.cs index efe9e7312..e3567aa46 100644 --- a/crypto/src/tls/DtlsRecordLayer.cs +++ b/crypto/src/tls/DtlsRecordLayer.cs @@ -4,7 +4,6 @@ using System.IO; using System.Net.Sockets; using Org.BouncyCastle.Tls.Crypto; -using Org.BouncyCastle.Tls.Crypto.Impl; using Org.BouncyCastle.Utilities; using Org.BouncyCastle.Utilities.Date; @@ -13,43 +12,45 @@ namespace Org.BouncyCastle.Tls internal class DtlsRecordLayer : DatagramTransport { - private const int RECORD_HEADER_LENGTH = 13; + internal const int RecordHeaderLength = 13; + private const int MAX_FRAGMENT_LENGTH = 1 << 14; private const long TCP_MSL = 1000L * 60 * 2; private const long RETRANSMIT_TIMEOUT = TCP_MSL * 2; /// - internal static byte[] ReceiveClientHelloRecord(byte[] data, int dataOff, int dataLen) + internal static int ReceiveClientHelloRecord(byte[] data, int dataOff, int dataLen) { - if (dataLen < RECORD_HEADER_LENGTH) - { - return null; - } + if (dataLen < RecordHeaderLength) + return -1; short contentType = TlsUtilities.ReadUint8(data, dataOff + 0); if (ContentType.handshake != contentType) - return null; + return -1; ProtocolVersion version = TlsUtilities.ReadVersion(data, dataOff + 1); if (!ProtocolVersion.DTLSv10.IsEqualOrEarlierVersionOf(version)) - return null; + return -1; int epoch = TlsUtilities.ReadUint16(data, dataOff + 3); if (0 != epoch) - return null; + return -1; //long sequenceNumber = TlsUtilities.ReadUint48(data, dataOff + 5); int length = TlsUtilities.ReadUint16(data, dataOff + 11); - if (dataLen < RECORD_HEADER_LENGTH + length) - return null; + if (length < 1 || length > MAX_FRAGMENT_LENGTH) + return -1; - if (length > MAX_FRAGMENT_LENGTH) - return null; + if (dataLen < RecordHeaderLength + length) + return -1; + + short msgType = TlsUtilities.ReadUint8(data, dataOff + RecordHeaderLength); + if (HandshakeType.client_hello != msgType) + return -1; // NOTE: We ignore/drop any data after the first record - return TlsUtilities.CopyOfRangeExact(data, dataOff + RECORD_HEADER_LENGTH, - dataOff + RECORD_HEADER_LENGTH + length); + return length; } /// @@ -57,14 +58,14 @@ namespace Org.BouncyCastle.Tls { TlsUtilities.CheckUint16(message.Length); - byte[] record = new byte[RECORD_HEADER_LENGTH + message.Length]; + byte[] record = new byte[RecordHeaderLength + message.Length]; TlsUtilities.WriteUint8(ContentType.handshake, record, 0); TlsUtilities.WriteVersion(ProtocolVersion.DTLSv10, record, 1); TlsUtilities.WriteUint16(0, record, 3); TlsUtilities.WriteUint48(recordSeq, record, 5); TlsUtilities.WriteUint16(message.Length, record, 11); - Array.Copy(message, 0, record, RECORD_HEADER_LENGTH, message.Length); + Array.Copy(message, 0, record, RecordHeaderLength, message.Length); SendDatagram(sender, record, 0, record.Length); } @@ -124,8 +125,8 @@ namespace Org.BouncyCastle.Tls this.m_inHandshake = true; - this.m_currentEpoch = new DtlsEpoch(0, TlsNullNullCipher.Instance, RECORD_HEADER_LENGTH, - RECORD_HEADER_LENGTH); + this.m_currentEpoch = new DtlsEpoch(0, TlsNullNullCipher.Instance, RecordHeaderLength, + RecordHeaderLength); this.m_pendingEpoch = null; this.m_readEpoch = m_currentEpoch; this.m_writeEpoch = m_currentEpoch; @@ -179,8 +180,8 @@ namespace Org.BouncyCastle.Tls */ var securityParameters = m_context.SecurityParameters; - int recordHeaderLengthRead = RECORD_HEADER_LENGTH + (securityParameters.ConnectionIDPeer?.Length ?? 0); - int recordHeaderLengthWrite = RECORD_HEADER_LENGTH + (securityParameters.ConnectionIDLocal?.Length ?? 0); + int recordHeaderLengthRead = RecordHeaderLength + (securityParameters.ConnectionIDPeer?.Length ?? 0); + int recordHeaderLengthWrite = RecordHeaderLength + (securityParameters.ConnectionIDLocal?.Length ?? 0); // TODO Check for overflow this.m_pendingEpoch = new DtlsEpoch(m_writeEpoch.Epoch + 1, pendingCipher, recordHeaderLengthRead, @@ -684,7 +685,7 @@ namespace Org.BouncyCastle.Tls #endif { // NOTE: received < 0 (timeout) is covered by this first case - if (received < RECORD_HEADER_LENGTH) + if (received < RecordHeaderLength) return -1; // TODO[dtls13] Deal with opaque record type for 1.3 AEAD ciphers @@ -729,7 +730,7 @@ namespace Org.BouncyCastle.Tls int recordHeaderLength = recordEpoch.RecordHeaderLengthRead; - if (recordHeaderLength > RECORD_HEADER_LENGTH) + if (recordHeaderLength > RecordHeaderLength) { if (ContentType.tls12_cid != recordType) return -1; @@ -990,7 +991,7 @@ namespace Org.BouncyCastle.Tls { Debug.Assert(m_recordQueue.Available > 0); - int recordLength = RECORD_HEADER_LENGTH; + int recordLength = RecordHeaderLength; if (m_recordQueue.Available >= recordLength) { short recordType = m_recordQueue.ReadUint8(0); @@ -1033,7 +1034,7 @@ namespace Org.BouncyCastle.Tls return ReceivePendingRecord(buf, off, len); int received = ReceiveDatagram(buf, off, len, waitMillis); - if (received >= RECORD_HEADER_LENGTH) + if (received >= RecordHeaderLength) { this.m_inConnection = true; @@ -1151,7 +1152,7 @@ namespace Org.BouncyCastle.Tls TlsUtilities.WriteUint16(recordEpoch, encoded.buf, encoded.off + 3); TlsUtilities.WriteUint48(recordSequenceNumber, encoded.buf, encoded.off + 5); - if (recordHeaderLength > RECORD_HEADER_LENGTH) + if (recordHeaderLength > RecordHeaderLength) { byte[] connectionID = m_context.SecurityParameters.ConnectionIDLocal; Array.Copy(connectionID, 0, encoded.buf, encoded.off + 11, connectionID.Length); diff --git a/crypto/src/tls/DtlsReliableHandshake.cs b/crypto/src/tls/DtlsReliableHandshake.cs index 42a98a991..b1107f7a1 100644 --- a/crypto/src/tls/DtlsReliableHandshake.cs +++ b/crypto/src/tls/DtlsReliableHandshake.cs @@ -8,47 +8,41 @@ namespace Org.BouncyCastle.Tls { internal class DtlsReliableHandshake { - private const int MAX_RECEIVE_AHEAD = 16; - private const int MESSAGE_HEADER_LENGTH = 12; + internal const int MessageHeaderLength = 12; + private const int MAX_RECEIVE_AHEAD = 16; private const int MAX_RESEND_MILLIS = 60000; /// - internal static DtlsRequest ReadClientRequest(byte[] data, int dataOff, int dataLen, Stream dtlsOutput) + internal static MemoryStream ReceiveClientHelloMessage(byte[] msg, int msgOff, int msgLen) { // TODO Support the possibility of a fragmented ClientHello datagram - byte[] message = DtlsRecordLayer.ReceiveClientHelloRecord(data, dataOff, dataLen); - if (null == message || message.Length < MESSAGE_HEADER_LENGTH) + if (msgLen < MessageHeaderLength) return null; - long recordSeq = TlsUtilities.ReadUint48(data, dataOff + 5); - - short msgType = TlsUtilities.ReadUint8(message, 0); + short msgType = TlsUtilities.ReadUint8(msg, msgOff); if (HandshakeType.client_hello != msgType) return null; - int length = TlsUtilities.ReadUint24(message, 1); - if (message.Length != MESSAGE_HEADER_LENGTH + length) + int length = TlsUtilities.ReadUint24(msg, msgOff + 1); + if (msgLen != MessageHeaderLength + length) return null; // TODO Consider stricter HelloVerifyRequest-related checks - //int messageSeq = TlsUtilities.ReadUint16(message, 4); + //int messageSeq = TlsUtilities.ReadUint16(msg, msgOff + 4); //if (messageSeq > 1) // return null; - int fragmentOffset = TlsUtilities.ReadUint24(message, 6); + int fragmentOffset = TlsUtilities.ReadUint24(msg, msgOff + 6); if (0 != fragmentOffset) return null; - int fragmentLength = TlsUtilities.ReadUint24(message, 9); + int fragmentLength = TlsUtilities.ReadUint24(msg, msgOff + 9); if (length != fragmentLength) return null; - ClientHello clientHello = ClientHello.Parse( - new MemoryStream(message, MESSAGE_HEADER_LENGTH, length, false), dtlsOutput); - - return new DtlsRequest(recordSeq, message, clientHello); + return new MemoryStream(msg, msgOff + MessageHeaderLength, length, false); } /// @@ -58,7 +52,7 @@ namespace Org.BouncyCastle.Tls int length = 3 + cookie.Length; - byte[] message = new byte[MESSAGE_HEADER_LENGTH + length]; + byte[] message = new byte[MessageHeaderLength + length]; TlsUtilities.WriteUint8(HandshakeType.hello_verify_request, message, 0); TlsUtilities.WriteUint24(length, message, 1); //TlsUtilities.WriteUint16(0, message, 4); @@ -66,8 +60,8 @@ namespace Org.BouncyCastle.Tls TlsUtilities.WriteUint24(length, message, 9); // HelloVerifyRequest fields - TlsUtilities.WriteVersion(ProtocolVersion.DTLSv10, message, MESSAGE_HEADER_LENGTH + 0); - TlsUtilities.WriteOpaque8(cookie, message, MESSAGE_HEADER_LENGTH + 2); + TlsUtilities.WriteVersion(ProtocolVersion.DTLSv10, message, MessageHeaderLength + 0); + TlsUtilities.WriteOpaque8(cookie, message, MessageHeaderLength + 2); DtlsRecordLayer.SendHelloVerifyRequestRecord(sender, recordSeq, message); } @@ -111,7 +105,7 @@ namespace Org.BouncyCastle.Tls // Simulate a previous flight consisting of the request ClientHello DtlsReassembler reassembler = new DtlsReassembler(HandshakeType.client_hello, - message.Length - MESSAGE_HEADER_LENGTH); + message.Length - MessageHeaderLength); m_currentInboundFlight[messageSeq] = reassembler; // We sent HelloVerifyRequest with (message) sequence number 0 @@ -215,7 +209,7 @@ namespace Org.BouncyCastle.Tls default: { byte[] body = message.Body; - byte[] buf = new byte[MESSAGE_HEADER_LENGTH]; + byte[] buf = new byte[MessageHeaderLength]; TlsUtilities.WriteUint8(msg_type, buf, 0); TlsUtilities.WriteUint24(body.Length, buf, 1); TlsUtilities.WriteUint16(message.Seq, buf, 4); @@ -360,10 +354,10 @@ namespace Org.BouncyCastle.Tls { bool checkPreviousFlight = false; - while (len >= MESSAGE_HEADER_LENGTH) + while (len >= MessageHeaderLength) { int fragment_length = TlsUtilities.ReadUint24(buf, off + 9); - int message_length = fragment_length + MESSAGE_HEADER_LENGTH; + int message_length = fragment_length + MessageHeaderLength; if (len < message_length) { // NOTE: Truncated message - ignore it @@ -400,7 +394,7 @@ namespace Org.BouncyCastle.Tls m_currentInboundFlight[message_seq] = reassembler; } - reassembler.ContributeFragment(msg_type, length, buf, off + MESSAGE_HEADER_LENGTH, fragment_offset, + reassembler.ContributeFragment(msg_type, length, buf, off + MessageHeaderLength, fragment_offset, fragment_length); } else if (m_previousInboundFlight != null) @@ -412,7 +406,7 @@ namespace Org.BouncyCastle.Tls if (m_previousInboundFlight.TryGetValue(message_seq, out var reassembler)) { - reassembler.ContributeFragment(msg_type, length, buf, off + MESSAGE_HEADER_LENGTH, + reassembler.ContributeFragment(msg_type, length, buf, off + MessageHeaderLength, fragment_offset, fragment_length); checkPreviousFlight = true; } @@ -446,7 +440,7 @@ namespace Org.BouncyCastle.Tls private void WriteMessage(Message message) { int sendLimit = m_recordLayer.GetSendLimit(); - int fragmentLimit = sendLimit - MESSAGE_HEADER_LENGTH; + int fragmentLimit = sendLimit - MessageHeaderLength; // TODO Support a higher minimum fragment size? if (fragmentLimit < 1) @@ -471,7 +465,7 @@ namespace Org.BouncyCastle.Tls /// private void WriteHandshakeFragment(Message message, int fragment_offset, int fragment_length) { - RecordLayerBuffer fragment = new RecordLayerBuffer(MESSAGE_HEADER_LENGTH + fragment_length); + RecordLayerBuffer fragment = new RecordLayerBuffer(MessageHeaderLength + fragment_length); TlsUtilities.WriteUint8(message.Type, fragment); TlsUtilities.WriteUint24(message.Body.Length, fragment); TlsUtilities.WriteUint16(message.Seq, fragment); diff --git a/crypto/src/tls/DtlsVerifier.cs b/crypto/src/tls/DtlsVerifier.cs index e691685e6..01437d648 100644 --- a/crypto/src/tls/DtlsVerifier.cs +++ b/crypto/src/tls/DtlsVerifier.cs @@ -1,89 +1,79 @@ -using System; -using System.IO; +using System.IO; +using Org.BouncyCastle.Security; using Org.BouncyCastle.Tls.Crypto; using Org.BouncyCastle.Utilities; namespace Org.BouncyCastle.Tls { + /// + /// Implements cookie generation/verification for a DTLS server as described in RFC 4347, + /// 4.2.1. Denial of Service Countermeasures. + /// + /// + /// RFC 4347 4.2.1 additionally recommends changing the secret frequently. This class does not handle that + /// internally, so the instance should be replaced instead. + /// public class DtlsVerifier { - private static TlsMac CreateCookieMac(TlsCrypto crypto) - { - TlsMac mac = crypto.CreateHmac(MacAlgorithm.hmac_sha256); - - byte[] secret = new byte[mac.MacLength]; - crypto.SecureRandom.NextBytes(secret); - - mac.SetKey(secret, 0, secret.Length); - - return mac; - } - - private readonly TlsMac m_cookieMac; - private readonly TlsMacSink m_cookieMacSink; + private readonly TlsCrypto m_crypto; + private readonly byte[] m_macKey; public DtlsVerifier(TlsCrypto crypto) { - this.m_cookieMac = CreateCookieMac(crypto); - this.m_cookieMacSink = new TlsMacSink(m_cookieMac); + m_crypto = crypto; + m_macKey = SecureRandom.GetNextBytes(crypto.SecureRandom, 32); } public virtual DtlsRequest VerifyRequest(byte[] clientID, byte[] data, int dataOff, int dataLen, DatagramSender sender) { - lock (this) + try { - bool resetCookieMac = true; + int msgLen = DtlsRecordLayer.ReceiveClientHelloRecord(data, dataOff, dataLen); + if (msgLen < 0) + return null; - try - { - m_cookieMac.Update(clientID, 0, clientID.Length); + int bodyLength = msgLen - DtlsReliableHandshake.MessageHeaderLength; + if (bodyLength < 39) // Minimum (syntactically) valid DTLS ClientHello length + return null; - DtlsRequest request = DtlsReliableHandshake.ReadClientRequest(data, dataOff, dataLen, - m_cookieMacSink); - if (null != request) - { - byte[] expectedCookie = m_cookieMac.CalculateMac(); - resetCookieMac = false; + int msgOff = dataOff + DtlsRecordLayer.RecordHeaderLength; - // TODO Consider stricter HelloVerifyRequest protocol - //switch (request.MessageSeq) - //{ - //case 0: - //{ - // DtlsReliableHandshake.SendHelloVerifyRequest(sender, request.RecordSeq, expectedCookie); - // break; - //} - //case 1: - //{ - // if (Arrays.FixedTimeEquals(expectedCookie, request.ClientHello.Cookie)) - // return request; + var buf = DtlsReliableHandshake.ReceiveClientHelloMessage(msg: data, msgOff, msgLen); + if (buf == null) + return null; - // break; - //} - //} + var macInput = new MemoryStream(bodyLength); + ClientHello clientHello = ClientHello.Parse(buf, dtlsOutput: macInput); + if (clientHello == null) + return null; - if (Arrays.FixedTimeEquals(expectedCookie, request.ClientHello.Cookie)) - return request; + long recordSeq = TlsUtilities.ReadUint48(data, dataOff + 5); - DtlsReliableHandshake.SendHelloVerifyRequest(sender, request.RecordSeq, expectedCookie); - } - } - catch (IOException) - { - // Ignore - } - finally + byte[] cookie = clientHello.Cookie; + + TlsMac mac = m_crypto.CreateHmac(MacAlgorithm.hmac_sha256); + mac.SetKey(m_macKey, 0, m_macKey.Length); + mac.Update(clientID, 0, clientID.Length); + macInput.WriteTo(new TlsMacSink(mac)); + byte[] expectedCookie = mac.CalculateMac(); + + if (Arrays.FixedTimeEquals(expectedCookie, cookie)) { - if (resetCookieMac) - { - m_cookieMac.Reset(); - } + byte[] message = TlsUtilities.CopyOfRangeExact(data, msgOff, msgOff + msgLen); + + return new DtlsRequest(recordSeq, message, clientHello); } - return null; + DtlsReliableHandshake.SendHelloVerifyRequest(sender, recordSeq, expectedCookie); + } + catch (IOException) + { + // Ignore } + + return null; } } } diff --git a/crypto/src/tls/TlsClientProtocol.cs b/crypto/src/tls/TlsClientProtocol.cs index 6aa1acf2f..d26f60ef1 100644 --- a/crypto/src/tls/TlsClientProtocol.cs +++ b/crypto/src/tls/TlsClientProtocol.cs @@ -1771,7 +1771,7 @@ namespace Org.BouncyCastle.Tls int bindersSize = null == m_clientBinders ? 0 : m_clientBinders.m_bindersSize; this.m_clientHello = new ClientHello(legacy_version, securityParameters.ClientRandom, legacy_session_id, - null, offeredCipherSuites, m_clientExtensions, bindersSize); + cookie: null, offeredCipherSuites, m_clientExtensions, bindersSize); SendClientHelloMessage(); } diff --git a/crypto/test/src/tls/test/DtlsProtocolTest.cs b/crypto/test/src/tls/test/DtlsProtocolTest.cs index 388003666..7fc49fb51 100644 --- a/crypto/test/src/tls/test/DtlsProtocolTest.cs +++ b/crypto/test/src/tls/test/DtlsProtocolTest.cs @@ -1,4 +1,5 @@ using System; +using System.Text; using System.Threading; using NUnit.Framework; @@ -70,7 +71,36 @@ namespace Org.BouncyCastle.Tls.Tests try { MockDtlsServer server = new MockDtlsServer(); - DtlsTransport dtlsServer = m_serverProtocol.Accept(server, m_serverTransport); + + DtlsRequest request = null; + + // Use DtlsVerifier to require a HelloVerifyRequest cookie exchange before accepting + { + DtlsVerifier verifier = new DtlsVerifier(server.Crypto); + + // NOTE: Test value only - would typically be the client IP address + byte[] clientID = Encoding.UTF8.GetBytes("MockDtlsClient"); + + int receiveLimit = m_serverTransport.GetReceiveLimit(); + int dummyOffset = server.Crypto.SecureRandom.Next(16) + 1; + byte[] transportBuf = new byte[dummyOffset + m_serverTransport.GetReceiveLimit()]; + + do + { + if (m_isShutdown) + return; + + int length = m_serverTransport.Receive(transportBuf, dummyOffset, receiveLimit, 1000); + if (length > 0) + { + request = verifier.VerifyRequest(clientID, transportBuf, dummyOffset, length, + m_serverTransport); + } + } + while (request == null); + } + + DtlsTransport dtlsServer = m_serverProtocol.Accept(server, m_serverTransport, request); byte[] buf = new byte[dtlsServer.GetReceiveLimit()]; while (!m_isShutdown) { -- cgit 1.4.1