summary refs log tree commit diff
path: root/crypto/src/tls/DtlsRecordLayer.cs
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src/tls/DtlsRecordLayer.cs')
-rw-r--r--crypto/src/tls/DtlsRecordLayer.cs51
1 files changed, 39 insertions, 12 deletions
diff --git a/crypto/src/tls/DtlsRecordLayer.cs b/crypto/src/tls/DtlsRecordLayer.cs
index a786da127..82fc3db64 100644
--- a/crypto/src/tls/DtlsRecordLayer.cs
+++ b/crypto/src/tls/DtlsRecordLayer.cs
@@ -274,8 +274,14 @@ namespace Org.BouncyCastle.Tls
         /// <exception cref="IOException"/>
         public virtual int Receive(byte[] buf, int off, int len, int waitMillis)
         {
+            return Receive(buf, off, len, waitMillis, null);
+        }
+
+        /// <exception cref="IOException"/>
+        internal int Receive(byte[] buf, int off, int len, int waitMillis, Action recordCallback)
+        {
 #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
-            return Receive(buf.AsSpan(off, len), waitMillis);
+            return Receive(buf.AsSpan(off, len), waitMillis, recordCallback);
 #else
             long currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
 
@@ -329,7 +335,7 @@ namespace Org.BouncyCastle.Tls
                 }
 
                 int received = ReceiveRecord(record, 0, receiveLimit, waitMillis);
-                int processed = ProcessRecord(received, record, buf, off, len);
+                int processed = ProcessRecord(received, record, buf, off, len, recordCallback);
                 if (processed >= 0)
                     return processed;
 
@@ -342,10 +348,10 @@ namespace Org.BouncyCastle.Tls
         }
 
         /// <exception cref="IOException"/>
-        internal int ReceivePending(byte[] buf, int off, int len)
+        internal int ReceivePending(byte[] buf, int off, int len, Action recordCallback)
         {
 #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
-            return ReceivePending(buf.AsSpan(off, len));
+            return ReceivePending(buf.AsSpan(off, len), recordCallback);
 #else
             if (m_recordQueue.Available > 0)
             {
@@ -355,7 +361,7 @@ namespace Org.BouncyCastle.Tls
                 do
                 {
                     int received = ReceivePendingRecord(record, 0, receiveLimit);
-                    int processed = ProcessRecord(received, record, buf, off, len);
+                    int processed = ProcessRecord(received, record, buf, off, len, recordCallback);
                     if (processed >= 0)
                         return processed;
                 }
@@ -370,6 +376,12 @@ namespace Org.BouncyCastle.Tls
         /// <exception cref="IOException"/>
         public virtual int Receive(Span<byte> buffer, int waitMillis)
         {
+            return Receive(buffer, waitMillis, null);
+        }
+
+        /// <exception cref="IOException"/>
+        internal int Receive(Span<byte> buffer, int waitMillis, Action recordCallback)
+        {
             long currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
 
             Timeout timeout = Timeout.ForWaitMillis(waitMillis, currentTimeMillis);
@@ -422,7 +434,7 @@ namespace Org.BouncyCastle.Tls
                 }
 
                 int received = ReceiveRecord(record, 0, receiveLimit, waitMillis);
-                int processed = ProcessRecord(received, record, buffer);
+                int processed = ProcessRecord(received, record, buffer, recordCallback);
                 if (processed >= 0)
                     return processed;
 
@@ -434,7 +446,7 @@ namespace Org.BouncyCastle.Tls
         }
 
         /// <exception cref="IOException"/>
-        internal int ReceivePending(Span<byte> buffer)
+        internal int ReceivePending(Span<byte> buffer, Action recordCallback)
         {
             if (m_recordQueue.Available > 0)
             {
@@ -444,7 +456,7 @@ namespace Org.BouncyCastle.Tls
                 do
                 {
                     int received = ReceivePendingRecord(record, 0, receiveLimit);
-                    int processed = ProcessRecord(received, record, buffer);
+                    int processed = ProcessRecord(received, record, buffer, recordCallback);
                     if (processed >= 0)
                         return processed;
                 }
@@ -665,9 +677,9 @@ namespace Org.BouncyCastle.Tls
         // TODO Include 'currentTimeMillis' as an argument, use with Timeout, resetHeartbeat
         /// <exception cref="IOException"/>
 #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
-        private int ProcessRecord(int received, byte[] record, Span<byte> buffer)
+        private int ProcessRecord(int received, byte[] record, Span<byte> buffer, Action recordCallback)
 #else
-        private int ProcessRecord(int received, byte[] record, byte[] buf, int off, int len)
+        private int ProcessRecord(int received, byte[] record, byte[] buf, int off, int len, Action recordCallback)
 #endif
         {
             // NOTE: received < 0 (timeout) is covered by this first case
@@ -772,8 +784,6 @@ namespace Org.BouncyCastle.Tls
                 return -1;
             }
 
-            recordEpoch.ReplayWindow.ReportAuthenticated(seq);
-
             if (decoded.len > m_plaintextLimit)
                 return -1;
 
@@ -805,6 +815,23 @@ namespace Org.BouncyCastle.Tls
                 }
             }
 
+            recordEpoch.ReplayWindow.ReportAuthenticated(seq, out var isLatestConfirmed);
+
+            /*
+             * NOTE: The record has passed record layer validation and will be dispatched according to the decoded
+             * content type.
+             */
+            if (recordCallback != null)
+            {
+                // TODO Make the callback more general than just peer address update
+                if (ContentType.tls12_cid == recordType &&
+                    isLatestConfirmed &&
+                    recordEpoch == m_readEpoch)
+                {
+                    recordCallback();
+                }
+            }
+
             switch (decoded.contentType)
             {
             case ContentType.alert: