diff --git a/crypto/src/tls/DtlsClientProtocol.cs b/crypto/src/tls/DtlsClientProtocol.cs
index 72484e178..c1bad2e6f 100644
--- a/crypto/src/tls/DtlsClientProtocol.cs
+++ b/crypto/src/tls/DtlsClientProtocol.cs
@@ -525,7 +525,7 @@ namespace Org.BouncyCastle.Tls
ClientHello clientHello = new ClientHello(legacy_version, securityParameters.ClientRandom, session_id,
- TlsUtilities.EmptyBytes, state.offeredCipherSuites, state.clientExtensions, 0);
+ cookie: TlsUtilities.EmptyBytes, state.offeredCipherSuites, state.clientExtensions, 0);
MemoryStream buf = new MemoryStream();
clientHello.Encode(state.clientContext, buf);
diff --git a/crypto/src/tls/DtlsRecordLayer.cs b/crypto/src/tls/DtlsRecordLayer.cs
index efe9e7312..e3567aa46 100644
--- a/crypto/src/tls/DtlsRecordLayer.cs
+++ b/crypto/src/tls/DtlsRecordLayer.cs
@@ -4,7 +4,6 @@ using System.IO;
using System.Net.Sockets;
using Org.BouncyCastle.Tls.Crypto;
-using Org.BouncyCastle.Tls.Crypto.Impl;
using Org.BouncyCastle.Utilities;
using Org.BouncyCastle.Utilities.Date;
@@ -13,43 +12,45 @@ namespace Org.BouncyCastle.Tls
internal class DtlsRecordLayer
: DatagramTransport
{
- private const int RECORD_HEADER_LENGTH = 13;
+ internal const int RecordHeaderLength = 13;
+
private const int MAX_FRAGMENT_LENGTH = 1 << 14;
private const long TCP_MSL = 1000L * 60 * 2;
private const long RETRANSMIT_TIMEOUT = TCP_MSL * 2;
/// <exception cref="IOException"/>
- internal static byte[] ReceiveClientHelloRecord(byte[] data, int dataOff, int dataLen)
+ internal static int ReceiveClientHelloRecord(byte[] data, int dataOff, int dataLen)
{
- if (dataLen < RECORD_HEADER_LENGTH)
- {
- return null;
- }
+ if (dataLen < RecordHeaderLength)
+ return -1;
short contentType = TlsUtilities.ReadUint8(data, dataOff + 0);
if (ContentType.handshake != contentType)
- return null;
+ return -1;
ProtocolVersion version = TlsUtilities.ReadVersion(data, dataOff + 1);
if (!ProtocolVersion.DTLSv10.IsEqualOrEarlierVersionOf(version))
- return null;
+ return -1;
int epoch = TlsUtilities.ReadUint16(data, dataOff + 3);
if (0 != epoch)
- return null;
+ return -1;
//long sequenceNumber = TlsUtilities.ReadUint48(data, dataOff + 5);
int length = TlsUtilities.ReadUint16(data, dataOff + 11);
- if (dataLen < RECORD_HEADER_LENGTH + length)
- return null;
+ if (length < 1 || length > MAX_FRAGMENT_LENGTH)
+ return -1;
- if (length > MAX_FRAGMENT_LENGTH)
- return null;
+ if (dataLen < RecordHeaderLength + length)
+ return -1;
+
+ short msgType = TlsUtilities.ReadUint8(data, dataOff + RecordHeaderLength);
+ if (HandshakeType.client_hello != msgType)
+ return -1;
// NOTE: We ignore/drop any data after the first record
- return TlsUtilities.CopyOfRangeExact(data, dataOff + RECORD_HEADER_LENGTH,
- dataOff + RECORD_HEADER_LENGTH + length);
+ return length;
}
/// <exception cref="IOException"/>
@@ -57,14 +58,14 @@ namespace Org.BouncyCastle.Tls
{
TlsUtilities.CheckUint16(message.Length);
- byte[] record = new byte[RECORD_HEADER_LENGTH + message.Length];
+ byte[] record = new byte[RecordHeaderLength + message.Length];
TlsUtilities.WriteUint8(ContentType.handshake, record, 0);
TlsUtilities.WriteVersion(ProtocolVersion.DTLSv10, record, 1);
TlsUtilities.WriteUint16(0, record, 3);
TlsUtilities.WriteUint48(recordSeq, record, 5);
TlsUtilities.WriteUint16(message.Length, record, 11);
- Array.Copy(message, 0, record, RECORD_HEADER_LENGTH, message.Length);
+ Array.Copy(message, 0, record, RecordHeaderLength, message.Length);
SendDatagram(sender, record, 0, record.Length);
}
@@ -124,8 +125,8 @@ namespace Org.BouncyCastle.Tls
this.m_inHandshake = true;
- this.m_currentEpoch = new DtlsEpoch(0, TlsNullNullCipher.Instance, RECORD_HEADER_LENGTH,
- RECORD_HEADER_LENGTH);
+ this.m_currentEpoch = new DtlsEpoch(0, TlsNullNullCipher.Instance, RecordHeaderLength,
+ RecordHeaderLength);
this.m_pendingEpoch = null;
this.m_readEpoch = m_currentEpoch;
this.m_writeEpoch = m_currentEpoch;
@@ -179,8 +180,8 @@ namespace Org.BouncyCastle.Tls
*/
var securityParameters = m_context.SecurityParameters;
- int recordHeaderLengthRead = RECORD_HEADER_LENGTH + (securityParameters.ConnectionIDPeer?.Length ?? 0);
- int recordHeaderLengthWrite = RECORD_HEADER_LENGTH + (securityParameters.ConnectionIDLocal?.Length ?? 0);
+ int recordHeaderLengthRead = RecordHeaderLength + (securityParameters.ConnectionIDPeer?.Length ?? 0);
+ int recordHeaderLengthWrite = RecordHeaderLength + (securityParameters.ConnectionIDLocal?.Length ?? 0);
// TODO Check for overflow
this.m_pendingEpoch = new DtlsEpoch(m_writeEpoch.Epoch + 1, pendingCipher, recordHeaderLengthRead,
@@ -684,7 +685,7 @@ namespace Org.BouncyCastle.Tls
#endif
{
// NOTE: received < 0 (timeout) is covered by this first case
- if (received < RECORD_HEADER_LENGTH)
+ if (received < RecordHeaderLength)
return -1;
// TODO[dtls13] Deal with opaque record type for 1.3 AEAD ciphers
@@ -729,7 +730,7 @@ namespace Org.BouncyCastle.Tls
int recordHeaderLength = recordEpoch.RecordHeaderLengthRead;
- if (recordHeaderLength > RECORD_HEADER_LENGTH)
+ if (recordHeaderLength > RecordHeaderLength)
{
if (ContentType.tls12_cid != recordType)
return -1;
@@ -990,7 +991,7 @@ namespace Org.BouncyCastle.Tls
{
Debug.Assert(m_recordQueue.Available > 0);
- int recordLength = RECORD_HEADER_LENGTH;
+ int recordLength = RecordHeaderLength;
if (m_recordQueue.Available >= recordLength)
{
short recordType = m_recordQueue.ReadUint8(0);
@@ -1033,7 +1034,7 @@ namespace Org.BouncyCastle.Tls
return ReceivePendingRecord(buf, off, len);
int received = ReceiveDatagram(buf, off, len, waitMillis);
- if (received >= RECORD_HEADER_LENGTH)
+ if (received >= RecordHeaderLength)
{
this.m_inConnection = true;
@@ -1151,7 +1152,7 @@ namespace Org.BouncyCastle.Tls
TlsUtilities.WriteUint16(recordEpoch, encoded.buf, encoded.off + 3);
TlsUtilities.WriteUint48(recordSequenceNumber, encoded.buf, encoded.off + 5);
- if (recordHeaderLength > RECORD_HEADER_LENGTH)
+ if (recordHeaderLength > RecordHeaderLength)
{
byte[] connectionID = m_context.SecurityParameters.ConnectionIDLocal;
Array.Copy(connectionID, 0, encoded.buf, encoded.off + 11, connectionID.Length);
diff --git a/crypto/src/tls/DtlsReliableHandshake.cs b/crypto/src/tls/DtlsReliableHandshake.cs
index 42a98a991..b1107f7a1 100644
--- a/crypto/src/tls/DtlsReliableHandshake.cs
+++ b/crypto/src/tls/DtlsReliableHandshake.cs
@@ -8,47 +8,41 @@ namespace Org.BouncyCastle.Tls
{
internal class DtlsReliableHandshake
{
- private const int MAX_RECEIVE_AHEAD = 16;
- private const int MESSAGE_HEADER_LENGTH = 12;
+ internal const int MessageHeaderLength = 12;
+ private const int MAX_RECEIVE_AHEAD = 16;
private const int MAX_RESEND_MILLIS = 60000;
/// <exception cref="IOException"/>
- internal static DtlsRequest ReadClientRequest(byte[] data, int dataOff, int dataLen, Stream dtlsOutput)
+ internal static MemoryStream ReceiveClientHelloMessage(byte[] msg, int msgOff, int msgLen)
{
// TODO Support the possibility of a fragmented ClientHello datagram
- byte[] message = DtlsRecordLayer.ReceiveClientHelloRecord(data, dataOff, dataLen);
- if (null == message || message.Length < MESSAGE_HEADER_LENGTH)
+ if (msgLen < MessageHeaderLength)
return null;
- long recordSeq = TlsUtilities.ReadUint48(data, dataOff + 5);
-
- short msgType = TlsUtilities.ReadUint8(message, 0);
+ short msgType = TlsUtilities.ReadUint8(msg, msgOff);
if (HandshakeType.client_hello != msgType)
return null;
- int length = TlsUtilities.ReadUint24(message, 1);
- if (message.Length != MESSAGE_HEADER_LENGTH + length)
+ int length = TlsUtilities.ReadUint24(msg, msgOff + 1);
+ if (msgLen != MessageHeaderLength + length)
return null;
// TODO Consider stricter HelloVerifyRequest-related checks
- //int messageSeq = TlsUtilities.ReadUint16(message, 4);
+ //int messageSeq = TlsUtilities.ReadUint16(msg, msgOff + 4);
//if (messageSeq > 1)
// return null;
- int fragmentOffset = TlsUtilities.ReadUint24(message, 6);
+ int fragmentOffset = TlsUtilities.ReadUint24(msg, msgOff + 6);
if (0 != fragmentOffset)
return null;
- int fragmentLength = TlsUtilities.ReadUint24(message, 9);
+ int fragmentLength = TlsUtilities.ReadUint24(msg, msgOff + 9);
if (length != fragmentLength)
return null;
- ClientHello clientHello = ClientHello.Parse(
- new MemoryStream(message, MESSAGE_HEADER_LENGTH, length, false), dtlsOutput);
-
- return new DtlsRequest(recordSeq, message, clientHello);
+ return new MemoryStream(msg, msgOff + MessageHeaderLength, length, false);
}
/// <exception cref="IOException"/>
@@ -58,7 +52,7 @@ namespace Org.BouncyCastle.Tls
int length = 3 + cookie.Length;
- byte[] message = new byte[MESSAGE_HEADER_LENGTH + length];
+ byte[] message = new byte[MessageHeaderLength + length];
TlsUtilities.WriteUint8(HandshakeType.hello_verify_request, message, 0);
TlsUtilities.WriteUint24(length, message, 1);
//TlsUtilities.WriteUint16(0, message, 4);
@@ -66,8 +60,8 @@ namespace Org.BouncyCastle.Tls
TlsUtilities.WriteUint24(length, message, 9);
// HelloVerifyRequest fields
- TlsUtilities.WriteVersion(ProtocolVersion.DTLSv10, message, MESSAGE_HEADER_LENGTH + 0);
- TlsUtilities.WriteOpaque8(cookie, message, MESSAGE_HEADER_LENGTH + 2);
+ TlsUtilities.WriteVersion(ProtocolVersion.DTLSv10, message, MessageHeaderLength + 0);
+ TlsUtilities.WriteOpaque8(cookie, message, MessageHeaderLength + 2);
DtlsRecordLayer.SendHelloVerifyRequestRecord(sender, recordSeq, message);
}
@@ -111,7 +105,7 @@ namespace Org.BouncyCastle.Tls
// Simulate a previous flight consisting of the request ClientHello
DtlsReassembler reassembler = new DtlsReassembler(HandshakeType.client_hello,
- message.Length - MESSAGE_HEADER_LENGTH);
+ message.Length - MessageHeaderLength);
m_currentInboundFlight[messageSeq] = reassembler;
// We sent HelloVerifyRequest with (message) sequence number 0
@@ -215,7 +209,7 @@ namespace Org.BouncyCastle.Tls
default:
{
byte[] body = message.Body;
- byte[] buf = new byte[MESSAGE_HEADER_LENGTH];
+ byte[] buf = new byte[MessageHeaderLength];
TlsUtilities.WriteUint8(msg_type, buf, 0);
TlsUtilities.WriteUint24(body.Length, buf, 1);
TlsUtilities.WriteUint16(message.Seq, buf, 4);
@@ -360,10 +354,10 @@ namespace Org.BouncyCastle.Tls
{
bool checkPreviousFlight = false;
- while (len >= MESSAGE_HEADER_LENGTH)
+ while (len >= MessageHeaderLength)
{
int fragment_length = TlsUtilities.ReadUint24(buf, off + 9);
- int message_length = fragment_length + MESSAGE_HEADER_LENGTH;
+ int message_length = fragment_length + MessageHeaderLength;
if (len < message_length)
{
// NOTE: Truncated message - ignore it
@@ -400,7 +394,7 @@ namespace Org.BouncyCastle.Tls
m_currentInboundFlight[message_seq] = reassembler;
}
- reassembler.ContributeFragment(msg_type, length, buf, off + MESSAGE_HEADER_LENGTH, fragment_offset,
+ reassembler.ContributeFragment(msg_type, length, buf, off + MessageHeaderLength, fragment_offset,
fragment_length);
}
else if (m_previousInboundFlight != null)
@@ -412,7 +406,7 @@ namespace Org.BouncyCastle.Tls
if (m_previousInboundFlight.TryGetValue(message_seq, out var reassembler))
{
- reassembler.ContributeFragment(msg_type, length, buf, off + MESSAGE_HEADER_LENGTH,
+ reassembler.ContributeFragment(msg_type, length, buf, off + MessageHeaderLength,
fragment_offset, fragment_length);
checkPreviousFlight = true;
}
@@ -446,7 +440,7 @@ namespace Org.BouncyCastle.Tls
private void WriteMessage(Message message)
{
int sendLimit = m_recordLayer.GetSendLimit();
- int fragmentLimit = sendLimit - MESSAGE_HEADER_LENGTH;
+ int fragmentLimit = sendLimit - MessageHeaderLength;
// TODO Support a higher minimum fragment size?
if (fragmentLimit < 1)
@@ -471,7 +465,7 @@ namespace Org.BouncyCastle.Tls
/// <exception cref="IOException"/>
private void WriteHandshakeFragment(Message message, int fragment_offset, int fragment_length)
{
- RecordLayerBuffer fragment = new RecordLayerBuffer(MESSAGE_HEADER_LENGTH + fragment_length);
+ RecordLayerBuffer fragment = new RecordLayerBuffer(MessageHeaderLength + fragment_length);
TlsUtilities.WriteUint8(message.Type, fragment);
TlsUtilities.WriteUint24(message.Body.Length, fragment);
TlsUtilities.WriteUint16(message.Seq, fragment);
diff --git a/crypto/src/tls/DtlsVerifier.cs b/crypto/src/tls/DtlsVerifier.cs
index e691685e6..01437d648 100644
--- a/crypto/src/tls/DtlsVerifier.cs
+++ b/crypto/src/tls/DtlsVerifier.cs
@@ -1,89 +1,79 @@
-using System;
-using System.IO;
+using System.IO;
+using Org.BouncyCastle.Security;
using Org.BouncyCastle.Tls.Crypto;
using Org.BouncyCastle.Utilities;
namespace Org.BouncyCastle.Tls
{
+ /// <summary>
+ /// Implements cookie generation/verification for a DTLS server as described in RFC 4347,
+ /// 4.2.1. Denial of Service Countermeasures.
+ /// </summary>
+ /// <remarks>
+ /// RFC 4347 4.2.1 additionally recommends changing the secret frequently. This class does not handle that
+ /// internally, so the instance should be replaced instead.
+ /// </remarks>
public class DtlsVerifier
{
- private static TlsMac CreateCookieMac(TlsCrypto crypto)
- {
- TlsMac mac = crypto.CreateHmac(MacAlgorithm.hmac_sha256);
-
- byte[] secret = new byte[mac.MacLength];
- crypto.SecureRandom.NextBytes(secret);
-
- mac.SetKey(secret, 0, secret.Length);
-
- return mac;
- }
-
- private readonly TlsMac m_cookieMac;
- private readonly TlsMacSink m_cookieMacSink;
+ private readonly TlsCrypto m_crypto;
+ private readonly byte[] m_macKey;
public DtlsVerifier(TlsCrypto crypto)
{
- this.m_cookieMac = CreateCookieMac(crypto);
- this.m_cookieMacSink = new TlsMacSink(m_cookieMac);
+ m_crypto = crypto;
+ m_macKey = SecureRandom.GetNextBytes(crypto.SecureRandom, 32);
}
public virtual DtlsRequest VerifyRequest(byte[] clientID, byte[] data, int dataOff, int dataLen,
DatagramSender sender)
{
- lock (this)
+ try
{
- bool resetCookieMac = true;
+ int msgLen = DtlsRecordLayer.ReceiveClientHelloRecord(data, dataOff, dataLen);
+ if (msgLen < 0)
+ return null;
- try
- {
- m_cookieMac.Update(clientID, 0, clientID.Length);
+ int bodyLength = msgLen - DtlsReliableHandshake.MessageHeaderLength;
+ if (bodyLength < 39) // Minimum (syntactically) valid DTLS ClientHello length
+ return null;
- DtlsRequest request = DtlsReliableHandshake.ReadClientRequest(data, dataOff, dataLen,
- m_cookieMacSink);
- if (null != request)
- {
- byte[] expectedCookie = m_cookieMac.CalculateMac();
- resetCookieMac = false;
+ int msgOff = dataOff + DtlsRecordLayer.RecordHeaderLength;
- // TODO Consider stricter HelloVerifyRequest protocol
- //switch (request.MessageSeq)
- //{
- //case 0:
- //{
- // DtlsReliableHandshake.SendHelloVerifyRequest(sender, request.RecordSeq, expectedCookie);
- // break;
- //}
- //case 1:
- //{
- // if (Arrays.FixedTimeEquals(expectedCookie, request.ClientHello.Cookie))
- // return request;
+ var buf = DtlsReliableHandshake.ReceiveClientHelloMessage(msg: data, msgOff, msgLen);
+ if (buf == null)
+ return null;
- // break;
- //}
- //}
+ var macInput = new MemoryStream(bodyLength);
+ ClientHello clientHello = ClientHello.Parse(buf, dtlsOutput: macInput);
+ if (clientHello == null)
+ return null;
- if (Arrays.FixedTimeEquals(expectedCookie, request.ClientHello.Cookie))
- return request;
+ long recordSeq = TlsUtilities.ReadUint48(data, dataOff + 5);
- DtlsReliableHandshake.SendHelloVerifyRequest(sender, request.RecordSeq, expectedCookie);
- }
- }
- catch (IOException)
- {
- // Ignore
- }
- finally
+ byte[] cookie = clientHello.Cookie;
+
+ TlsMac mac = m_crypto.CreateHmac(MacAlgorithm.hmac_sha256);
+ mac.SetKey(m_macKey, 0, m_macKey.Length);
+ mac.Update(clientID, 0, clientID.Length);
+ macInput.WriteTo(new TlsMacSink(mac));
+ byte[] expectedCookie = mac.CalculateMac();
+
+ if (Arrays.FixedTimeEquals(expectedCookie, cookie))
{
- if (resetCookieMac)
- {
- m_cookieMac.Reset();
- }
+ byte[] message = TlsUtilities.CopyOfRangeExact(data, msgOff, msgOff + msgLen);
+
+ return new DtlsRequest(recordSeq, message, clientHello);
}
- return null;
+ DtlsReliableHandshake.SendHelloVerifyRequest(sender, recordSeq, expectedCookie);
+ }
+ catch (IOException)
+ {
+ // Ignore
}
+
+ return null;
}
}
}
diff --git a/crypto/src/tls/TlsClientProtocol.cs b/crypto/src/tls/TlsClientProtocol.cs
index 6aa1acf2f..d26f60ef1 100644
--- a/crypto/src/tls/TlsClientProtocol.cs
+++ b/crypto/src/tls/TlsClientProtocol.cs
@@ -1771,7 +1771,7 @@ namespace Org.BouncyCastle.Tls
int bindersSize = null == m_clientBinders ? 0 : m_clientBinders.m_bindersSize;
this.m_clientHello = new ClientHello(legacy_version, securityParameters.ClientRandom, legacy_session_id,
- null, offeredCipherSuites, m_clientExtensions, bindersSize);
+ cookie: null, offeredCipherSuites, m_clientExtensions, bindersSize);
SendClientHelloMessage();
}
diff --git a/crypto/test/src/tls/test/DtlsProtocolTest.cs b/crypto/test/src/tls/test/DtlsProtocolTest.cs
index 388003666..7fc49fb51 100644
--- a/crypto/test/src/tls/test/DtlsProtocolTest.cs
+++ b/crypto/test/src/tls/test/DtlsProtocolTest.cs
@@ -1,4 +1,5 @@
using System;
+using System.Text;
using System.Threading;
using NUnit.Framework;
@@ -70,7 +71,36 @@ namespace Org.BouncyCastle.Tls.Tests
try
{
MockDtlsServer server = new MockDtlsServer();
- DtlsTransport dtlsServer = m_serverProtocol.Accept(server, m_serverTransport);
+
+ DtlsRequest request = null;
+
+ // Use DtlsVerifier to require a HelloVerifyRequest cookie exchange before accepting
+ {
+ DtlsVerifier verifier = new DtlsVerifier(server.Crypto);
+
+ // NOTE: Test value only - would typically be the client IP address
+ byte[] clientID = Encoding.UTF8.GetBytes("MockDtlsClient");
+
+ int receiveLimit = m_serverTransport.GetReceiveLimit();
+ int dummyOffset = server.Crypto.SecureRandom.Next(16) + 1;
+ byte[] transportBuf = new byte[dummyOffset + m_serverTransport.GetReceiveLimit()];
+
+ do
+ {
+ if (m_isShutdown)
+ return;
+
+ int length = m_serverTransport.Receive(transportBuf, dummyOffset, receiveLimit, 1000);
+ if (length > 0)
+ {
+ request = verifier.VerifyRequest(clientID, transportBuf, dummyOffset, length,
+ m_serverTransport);
+ }
+ }
+ while (request == null);
+ }
+
+ DtlsTransport dtlsServer = m_serverProtocol.Accept(server, m_serverTransport, request);
byte[] buf = new byte[dtlsServer.GetReceiveLimit()];
while (!m_isShutdown)
{
|