diff --git a/crypto/src/tls/DtlsRecordLayer.cs b/crypto/src/tls/DtlsRecordLayer.cs
index a786da127..82fc3db64 100644
--- a/crypto/src/tls/DtlsRecordLayer.cs
+++ b/crypto/src/tls/DtlsRecordLayer.cs
@@ -274,8 +274,14 @@ namespace Org.BouncyCastle.Tls
/// <exception cref="IOException"/>
public virtual int Receive(byte[] buf, int off, int len, int waitMillis)
{
+ return Receive(buf, off, len, waitMillis, null);
+ }
+
+ /// <exception cref="IOException"/>
+ internal int Receive(byte[] buf, int off, int len, int waitMillis, Action recordCallback)
+ {
#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
- return Receive(buf.AsSpan(off, len), waitMillis);
+ return Receive(buf.AsSpan(off, len), waitMillis, recordCallback);
#else
long currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
@@ -329,7 +335,7 @@ namespace Org.BouncyCastle.Tls
}
int received = ReceiveRecord(record, 0, receiveLimit, waitMillis);
- int processed = ProcessRecord(received, record, buf, off, len);
+ int processed = ProcessRecord(received, record, buf, off, len, recordCallback);
if (processed >= 0)
return processed;
@@ -342,10 +348,10 @@ namespace Org.BouncyCastle.Tls
}
/// <exception cref="IOException"/>
- internal int ReceivePending(byte[] buf, int off, int len)
+ internal int ReceivePending(byte[] buf, int off, int len, Action recordCallback)
{
#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
- return ReceivePending(buf.AsSpan(off, len));
+ return ReceivePending(buf.AsSpan(off, len), recordCallback);
#else
if (m_recordQueue.Available > 0)
{
@@ -355,7 +361,7 @@ namespace Org.BouncyCastle.Tls
do
{
int received = ReceivePendingRecord(record, 0, receiveLimit);
- int processed = ProcessRecord(received, record, buf, off, len);
+ int processed = ProcessRecord(received, record, buf, off, len, recordCallback);
if (processed >= 0)
return processed;
}
@@ -370,6 +376,12 @@ namespace Org.BouncyCastle.Tls
/// <exception cref="IOException"/>
public virtual int Receive(Span<byte> buffer, int waitMillis)
{
+ return Receive(buffer, waitMillis, null);
+ }
+
+ /// <exception cref="IOException"/>
+ internal int Receive(Span<byte> buffer, int waitMillis, Action recordCallback)
+ {
long currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
Timeout timeout = Timeout.ForWaitMillis(waitMillis, currentTimeMillis);
@@ -422,7 +434,7 @@ namespace Org.BouncyCastle.Tls
}
int received = ReceiveRecord(record, 0, receiveLimit, waitMillis);
- int processed = ProcessRecord(received, record, buffer);
+ int processed = ProcessRecord(received, record, buffer, recordCallback);
if (processed >= 0)
return processed;
@@ -434,7 +446,7 @@ namespace Org.BouncyCastle.Tls
}
/// <exception cref="IOException"/>
- internal int ReceivePending(Span<byte> buffer)
+ internal int ReceivePending(Span<byte> buffer, Action recordCallback)
{
if (m_recordQueue.Available > 0)
{
@@ -444,7 +456,7 @@ namespace Org.BouncyCastle.Tls
do
{
int received = ReceivePendingRecord(record, 0, receiveLimit);
- int processed = ProcessRecord(received, record, buffer);
+ int processed = ProcessRecord(received, record, buffer, recordCallback);
if (processed >= 0)
return processed;
}
@@ -665,9 +677,9 @@ namespace Org.BouncyCastle.Tls
// TODO Include 'currentTimeMillis' as an argument, use with Timeout, resetHeartbeat
/// <exception cref="IOException"/>
#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
- private int ProcessRecord(int received, byte[] record, Span<byte> buffer)
+ private int ProcessRecord(int received, byte[] record, Span<byte> buffer, Action recordCallback)
#else
- private int ProcessRecord(int received, byte[] record, byte[] buf, int off, int len)
+ private int ProcessRecord(int received, byte[] record, byte[] buf, int off, int len, Action recordCallback)
#endif
{
// NOTE: received < 0 (timeout) is covered by this first case
@@ -772,8 +784,6 @@ namespace Org.BouncyCastle.Tls
return -1;
}
- recordEpoch.ReplayWindow.ReportAuthenticated(seq);
-
if (decoded.len > m_plaintextLimit)
return -1;
@@ -805,6 +815,23 @@ namespace Org.BouncyCastle.Tls
}
}
+ recordEpoch.ReplayWindow.ReportAuthenticated(seq, out var isLatestConfirmed);
+
+ /*
+ * NOTE: The record has passed record layer validation and will be dispatched according to the decoded
+ * content type.
+ */
+ if (recordCallback != null)
+ {
+ // TODO Make the callback more general than just peer address update
+ if (ContentType.tls12_cid == recordType &&
+ isLatestConfirmed &&
+ recordEpoch == m_readEpoch)
+ {
+ recordCallback();
+ }
+ }
+
switch (decoded.contentType)
{
case ContentType.alert:
diff --git a/crypto/src/tls/DtlsReplayWindow.cs b/crypto/src/tls/DtlsReplayWindow.cs
index a08114c2a..5267c5934 100644
--- a/crypto/src/tls/DtlsReplayWindow.cs
+++ b/crypto/src/tls/DtlsReplayWindow.cs
@@ -42,7 +42,9 @@ namespace Org.BouncyCastle.Tls
/// <summary>Report that a received record with the given sequence number passed authentication checks.
/// </summary>
/// <param name="seq">the 48-bit DTLSPlainText.sequence_number field of an authenticated record.</param>
- internal void ReportAuthenticated(long seq)
+ /// <param name="isLatestConfirmed">indicates whether <paramref name="seq"/> is now the latest confirmed
+ /// sequence number.</param>
+ internal void ReportAuthenticated(long seq, out bool isLatestConfirmed)
{
if ((seq & ValidSeqMask) != seq)
throw new ArgumentException("out of range", "seq");
@@ -54,6 +56,7 @@ namespace Org.BouncyCastle.Tls
{
m_bitmap |= (1UL << (int)diff);
}
+ isLatestConfirmed = false;
}
else
{
@@ -68,6 +71,7 @@ namespace Org.BouncyCastle.Tls
m_bitmap |= 1UL;
}
m_latestConfirmedSeq = seq;
+ isLatestConfirmed = true;
}
}
diff --git a/crypto/src/tls/DtlsTransport.cs b/crypto/src/tls/DtlsTransport.cs
index 30cd364d2..b452b8c89 100644
--- a/crypto/src/tls/DtlsTransport.cs
+++ b/crypto/src/tls/DtlsTransport.cs
@@ -31,6 +31,12 @@ namespace Org.BouncyCastle.Tls
/// <exception cref="IOException"/>
public virtual int Receive(byte[] buf, int off, int len, int waitMillis)
{
+ return Receive(buf, off, len, waitMillis, null);
+ }
+
+ /// <exception cref="IOException"/>
+ public virtual int Receive(byte[] buf, int off, int len, int waitMillis, Action recordCallback)
+ {
if (null == buf)
throw new ArgumentNullException("buf");
if (off < 0 || off >= buf.Length)
@@ -39,14 +45,14 @@ namespace Org.BouncyCastle.Tls
throw new ArgumentException("invalid length: " + len, "len");
#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
- return Receive(buf.AsSpan(off, len), waitMillis);
+ return Receive(buf.AsSpan(off, len), waitMillis, recordCallback);
#else
if (waitMillis < 0)
throw new ArgumentException("cannot be negative", "waitMillis");
try
{
- return m_recordLayer.Receive(buf, off, len, waitMillis);
+ return m_recordLayer.Receive(buf, off, len, waitMillis, recordCallback);
}
catch (TlsFatalAlert fatalAlert)
{
@@ -87,7 +93,7 @@ namespace Org.BouncyCastle.Tls
}
/// <exception cref="IOException"/>
- public virtual int ReceivePending(byte[] buf, int off, int len)
+ public virtual int ReceivePending(byte[] buf, int off, int len, Action recordCallback = null)
{
if (null == buf)
throw new ArgumentNullException("buf");
@@ -97,11 +103,11 @@ namespace Org.BouncyCastle.Tls
throw new ArgumentException("invalid length: " + len, "len");
#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
- return ReceivePending(buf.AsSpan(off, len));
+ return ReceivePending(buf.AsSpan(off, len), recordCallback);
#else
try
{
- return m_recordLayer.ReceivePending(buf, off, len);
+ return m_recordLayer.ReceivePending(buf, off, len, recordCallback);
}
catch (TlsFatalAlert fatalAlert)
{
@@ -145,12 +151,18 @@ namespace Org.BouncyCastle.Tls
/// <exception cref="IOException"/>
public virtual int Receive(Span<byte> buffer, int waitMillis)
{
+ return Receive(buffer, waitMillis, null);
+ }
+
+ /// <exception cref="IOException"/>
+ public virtual int Receive(Span<byte> buffer, int waitMillis, Action recordCallback)
+ {
if (waitMillis < 0)
throw new ArgumentException("cannot be negative", nameof(waitMillis));
try
{
- return m_recordLayer.Receive(buffer, waitMillis);
+ return m_recordLayer.Receive(buffer, waitMillis, recordCallback);
}
catch (TlsFatalAlert fatalAlert)
{
@@ -190,11 +202,11 @@ namespace Org.BouncyCastle.Tls
}
/// <exception cref="IOException"/>
- public virtual int ReceivePending(Span<byte> buffer)
+ public virtual int ReceivePending(Span<byte> buffer, Action recordCallback = null)
{
try
{
- return m_recordLayer.ReceivePending(buffer);
+ return m_recordLayer.ReceivePending(buffer, recordCallback);
}
catch (TlsFatalAlert fatalAlert)
{
|