From a575400e49d34228b3fed4f365a01b1ad03c3e1c Mon Sep 17 00:00:00 2001 From: Peter Dettman Date: Sun, 26 Mar 2023 00:22:51 +0700 Subject: RFC 9146: Add simple record callback for testing purposes --- crypto/src/tls/DtlsRecordLayer.cs | 51 +++++++++++++++++++++++++++++--------- crypto/src/tls/DtlsReplayWindow.cs | 6 ++++- crypto/src/tls/DtlsTransport.cs | 28 +++++++++++++++------ 3 files changed, 64 insertions(+), 21 deletions(-) diff --git a/crypto/src/tls/DtlsRecordLayer.cs b/crypto/src/tls/DtlsRecordLayer.cs index a786da127..82fc3db64 100644 --- a/crypto/src/tls/DtlsRecordLayer.cs +++ b/crypto/src/tls/DtlsRecordLayer.cs @@ -274,8 +274,14 @@ namespace Org.BouncyCastle.Tls /// public virtual int Receive(byte[] buf, int off, int len, int waitMillis) { + return Receive(buf, off, len, waitMillis, null); + } + + /// + internal int Receive(byte[] buf, int off, int len, int waitMillis, Action recordCallback) + { #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER - return Receive(buf.AsSpan(off, len), waitMillis); + return Receive(buf.AsSpan(off, len), waitMillis, recordCallback); #else long currentTimeMillis = DateTimeUtilities.CurrentUnixMs(); @@ -329,7 +335,7 @@ namespace Org.BouncyCastle.Tls } int received = ReceiveRecord(record, 0, receiveLimit, waitMillis); - int processed = ProcessRecord(received, record, buf, off, len); + int processed = ProcessRecord(received, record, buf, off, len, recordCallback); if (processed >= 0) return processed; @@ -342,10 +348,10 @@ namespace Org.BouncyCastle.Tls } /// - internal int ReceivePending(byte[] buf, int off, int len) + internal int ReceivePending(byte[] buf, int off, int len, Action recordCallback) { #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER - return ReceivePending(buf.AsSpan(off, len)); + return ReceivePending(buf.AsSpan(off, len), recordCallback); #else if (m_recordQueue.Available > 0) { @@ -355,7 +361,7 @@ namespace Org.BouncyCastle.Tls do { int received = ReceivePendingRecord(record, 0, receiveLimit); - int processed = ProcessRecord(received, record, buf, off, len); + int processed = ProcessRecord(received, record, buf, off, len, recordCallback); if (processed >= 0) return processed; } @@ -369,6 +375,12 @@ namespace Org.BouncyCastle.Tls #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER /// public virtual int Receive(Span buffer, int waitMillis) + { + return Receive(buffer, waitMillis, null); + } + + /// + internal int Receive(Span buffer, int waitMillis, Action recordCallback) { long currentTimeMillis = DateTimeUtilities.CurrentUnixMs(); @@ -422,7 +434,7 @@ namespace Org.BouncyCastle.Tls } int received = ReceiveRecord(record, 0, receiveLimit, waitMillis); - int processed = ProcessRecord(received, record, buffer); + int processed = ProcessRecord(received, record, buffer, recordCallback); if (processed >= 0) return processed; @@ -434,7 +446,7 @@ namespace Org.BouncyCastle.Tls } /// - internal int ReceivePending(Span buffer) + internal int ReceivePending(Span buffer, Action recordCallback) { if (m_recordQueue.Available > 0) { @@ -444,7 +456,7 @@ namespace Org.BouncyCastle.Tls do { int received = ReceivePendingRecord(record, 0, receiveLimit); - int processed = ProcessRecord(received, record, buffer); + int processed = ProcessRecord(received, record, buffer, recordCallback); if (processed >= 0) return processed; } @@ -665,9 +677,9 @@ namespace Org.BouncyCastle.Tls // TODO Include 'currentTimeMillis' as an argument, use with Timeout, resetHeartbeat /// #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER - private int ProcessRecord(int received, byte[] record, Span buffer) + private int ProcessRecord(int received, byte[] record, Span buffer, Action recordCallback) #else - private int ProcessRecord(int received, byte[] record, byte[] buf, int off, int len) + private int ProcessRecord(int received, byte[] record, byte[] buf, int off, int len, Action recordCallback) #endif { // NOTE: received < 0 (timeout) is covered by this first case @@ -772,8 +784,6 @@ namespace Org.BouncyCastle.Tls return -1; } - recordEpoch.ReplayWindow.ReportAuthenticated(seq); - if (decoded.len > m_plaintextLimit) return -1; @@ -805,6 +815,23 @@ namespace Org.BouncyCastle.Tls } } + recordEpoch.ReplayWindow.ReportAuthenticated(seq, out var isLatestConfirmed); + + /* + * NOTE: The record has passed record layer validation and will be dispatched according to the decoded + * content type. + */ + if (recordCallback != null) + { + // TODO Make the callback more general than just peer address update + if (ContentType.tls12_cid == recordType && + isLatestConfirmed && + recordEpoch == m_readEpoch) + { + recordCallback(); + } + } + switch (decoded.contentType) { case ContentType.alert: diff --git a/crypto/src/tls/DtlsReplayWindow.cs b/crypto/src/tls/DtlsReplayWindow.cs index a08114c2a..5267c5934 100644 --- a/crypto/src/tls/DtlsReplayWindow.cs +++ b/crypto/src/tls/DtlsReplayWindow.cs @@ -42,7 +42,9 @@ namespace Org.BouncyCastle.Tls /// Report that a received record with the given sequence number passed authentication checks. /// /// the 48-bit DTLSPlainText.sequence_number field of an authenticated record. - internal void ReportAuthenticated(long seq) + /// indicates whether is now the latest confirmed + /// sequence number. + internal void ReportAuthenticated(long seq, out bool isLatestConfirmed) { if ((seq & ValidSeqMask) != seq) throw new ArgumentException("out of range", "seq"); @@ -54,6 +56,7 @@ namespace Org.BouncyCastle.Tls { m_bitmap |= (1UL << (int)diff); } + isLatestConfirmed = false; } else { @@ -68,6 +71,7 @@ namespace Org.BouncyCastle.Tls m_bitmap |= 1UL; } m_latestConfirmedSeq = seq; + isLatestConfirmed = true; } } diff --git a/crypto/src/tls/DtlsTransport.cs b/crypto/src/tls/DtlsTransport.cs index 30cd364d2..b452b8c89 100644 --- a/crypto/src/tls/DtlsTransport.cs +++ b/crypto/src/tls/DtlsTransport.cs @@ -30,6 +30,12 @@ namespace Org.BouncyCastle.Tls /// public virtual int Receive(byte[] buf, int off, int len, int waitMillis) + { + return Receive(buf, off, len, waitMillis, null); + } + + /// + public virtual int Receive(byte[] buf, int off, int len, int waitMillis, Action recordCallback) { if (null == buf) throw new ArgumentNullException("buf"); @@ -39,14 +45,14 @@ namespace Org.BouncyCastle.Tls throw new ArgumentException("invalid length: " + len, "len"); #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER - return Receive(buf.AsSpan(off, len), waitMillis); + return Receive(buf.AsSpan(off, len), waitMillis, recordCallback); #else if (waitMillis < 0) throw new ArgumentException("cannot be negative", "waitMillis"); try { - return m_recordLayer.Receive(buf, off, len, waitMillis); + return m_recordLayer.Receive(buf, off, len, waitMillis, recordCallback); } catch (TlsFatalAlert fatalAlert) { @@ -87,7 +93,7 @@ namespace Org.BouncyCastle.Tls } /// - public virtual int ReceivePending(byte[] buf, int off, int len) + public virtual int ReceivePending(byte[] buf, int off, int len, Action recordCallback = null) { if (null == buf) throw new ArgumentNullException("buf"); @@ -97,11 +103,11 @@ namespace Org.BouncyCastle.Tls throw new ArgumentException("invalid length: " + len, "len"); #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER - return ReceivePending(buf.AsSpan(off, len)); + return ReceivePending(buf.AsSpan(off, len), recordCallback); #else try { - return m_recordLayer.ReceivePending(buf, off, len); + return m_recordLayer.ReceivePending(buf, off, len, recordCallback); } catch (TlsFatalAlert fatalAlert) { @@ -144,13 +150,19 @@ namespace Org.BouncyCastle.Tls #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER /// public virtual int Receive(Span buffer, int waitMillis) + { + return Receive(buffer, waitMillis, null); + } + + /// + public virtual int Receive(Span buffer, int waitMillis, Action recordCallback) { if (waitMillis < 0) throw new ArgumentException("cannot be negative", nameof(waitMillis)); try { - return m_recordLayer.Receive(buffer, waitMillis); + return m_recordLayer.Receive(buffer, waitMillis, recordCallback); } catch (TlsFatalAlert fatalAlert) { @@ -190,11 +202,11 @@ namespace Org.BouncyCastle.Tls } /// - public virtual int ReceivePending(Span buffer) + public virtual int ReceivePending(Span buffer, Action recordCallback = null) { try { - return m_recordLayer.ReceivePending(buffer); + return m_recordLayer.ReceivePending(buffer, recordCallback); } catch (TlsFatalAlert fatalAlert) { -- cgit 1.4.1