summary refs log tree commit diff
diff options
context:
space:
mode:
authorPeter Dettman <peter.dettman@bouncycastle.org>2023-03-26 00:22:51 +0700
committerPeter Dettman <peter.dettman@bouncycastle.org>2023-04-13 17:16:20 +0700
commita575400e49d34228b3fed4f365a01b1ad03c3e1c (patch)
treee321e65398e4eb219b048b0e301ac8631e096108
parentRFC 9146: TODOs for API changes when possible (diff)
downloadBouncyCastle.NET-ed25519-a575400e49d34228b3fed4f365a01b1ad03c3e1c.tar.xz
RFC 9146: Add simple record callback for testing purposes
-rw-r--r--crypto/src/tls/DtlsRecordLayer.cs51
-rw-r--r--crypto/src/tls/DtlsReplayWindow.cs6
-rw-r--r--crypto/src/tls/DtlsTransport.cs28
3 files changed, 64 insertions, 21 deletions
diff --git a/crypto/src/tls/DtlsRecordLayer.cs b/crypto/src/tls/DtlsRecordLayer.cs
index a786da127..82fc3db64 100644
--- a/crypto/src/tls/DtlsRecordLayer.cs
+++ b/crypto/src/tls/DtlsRecordLayer.cs
@@ -274,8 +274,14 @@ namespace Org.BouncyCastle.Tls
         /// <exception cref="IOException"/>
         public virtual int Receive(byte[] buf, int off, int len, int waitMillis)
         {
+            return Receive(buf, off, len, waitMillis, null);
+        }
+
+        /// <exception cref="IOException"/>
+        internal int Receive(byte[] buf, int off, int len, int waitMillis, Action recordCallback)
+        {
 #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
-            return Receive(buf.AsSpan(off, len), waitMillis);
+            return Receive(buf.AsSpan(off, len), waitMillis, recordCallback);
 #else
             long currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
 
@@ -329,7 +335,7 @@ namespace Org.BouncyCastle.Tls
                 }
 
                 int received = ReceiveRecord(record, 0, receiveLimit, waitMillis);
-                int processed = ProcessRecord(received, record, buf, off, len);
+                int processed = ProcessRecord(received, record, buf, off, len, recordCallback);
                 if (processed >= 0)
                     return processed;
 
@@ -342,10 +348,10 @@ namespace Org.BouncyCastle.Tls
         }
 
         /// <exception cref="IOException"/>
-        internal int ReceivePending(byte[] buf, int off, int len)
+        internal int ReceivePending(byte[] buf, int off, int len, Action recordCallback)
         {
 #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
-            return ReceivePending(buf.AsSpan(off, len));
+            return ReceivePending(buf.AsSpan(off, len), recordCallback);
 #else
             if (m_recordQueue.Available > 0)
             {
@@ -355,7 +361,7 @@ namespace Org.BouncyCastle.Tls
                 do
                 {
                     int received = ReceivePendingRecord(record, 0, receiveLimit);
-                    int processed = ProcessRecord(received, record, buf, off, len);
+                    int processed = ProcessRecord(received, record, buf, off, len, recordCallback);
                     if (processed >= 0)
                         return processed;
                 }
@@ -370,6 +376,12 @@ namespace Org.BouncyCastle.Tls
         /// <exception cref="IOException"/>
         public virtual int Receive(Span<byte> buffer, int waitMillis)
         {
+            return Receive(buffer, waitMillis, null);
+        }
+
+        /// <exception cref="IOException"/>
+        internal int Receive(Span<byte> buffer, int waitMillis, Action recordCallback)
+        {
             long currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
 
             Timeout timeout = Timeout.ForWaitMillis(waitMillis, currentTimeMillis);
@@ -422,7 +434,7 @@ namespace Org.BouncyCastle.Tls
                 }
 
                 int received = ReceiveRecord(record, 0, receiveLimit, waitMillis);
-                int processed = ProcessRecord(received, record, buffer);
+                int processed = ProcessRecord(received, record, buffer, recordCallback);
                 if (processed >= 0)
                     return processed;
 
@@ -434,7 +446,7 @@ namespace Org.BouncyCastle.Tls
         }
 
         /// <exception cref="IOException"/>
-        internal int ReceivePending(Span<byte> buffer)
+        internal int ReceivePending(Span<byte> buffer, Action recordCallback)
         {
             if (m_recordQueue.Available > 0)
             {
@@ -444,7 +456,7 @@ namespace Org.BouncyCastle.Tls
                 do
                 {
                     int received = ReceivePendingRecord(record, 0, receiveLimit);
-                    int processed = ProcessRecord(received, record, buffer);
+                    int processed = ProcessRecord(received, record, buffer, recordCallback);
                     if (processed >= 0)
                         return processed;
                 }
@@ -665,9 +677,9 @@ namespace Org.BouncyCastle.Tls
         // TODO Include 'currentTimeMillis' as an argument, use with Timeout, resetHeartbeat
         /// <exception cref="IOException"/>
 #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
-        private int ProcessRecord(int received, byte[] record, Span<byte> buffer)
+        private int ProcessRecord(int received, byte[] record, Span<byte> buffer, Action recordCallback)
 #else
-        private int ProcessRecord(int received, byte[] record, byte[] buf, int off, int len)
+        private int ProcessRecord(int received, byte[] record, byte[] buf, int off, int len, Action recordCallback)
 #endif
         {
             // NOTE: received < 0 (timeout) is covered by this first case
@@ -772,8 +784,6 @@ namespace Org.BouncyCastle.Tls
                 return -1;
             }
 
-            recordEpoch.ReplayWindow.ReportAuthenticated(seq);
-
             if (decoded.len > m_plaintextLimit)
                 return -1;
 
@@ -805,6 +815,23 @@ namespace Org.BouncyCastle.Tls
                 }
             }
 
+            recordEpoch.ReplayWindow.ReportAuthenticated(seq, out var isLatestConfirmed);
+
+            /*
+             * NOTE: The record has passed record layer validation and will be dispatched according to the decoded
+             * content type.
+             */
+            if (recordCallback != null)
+            {
+                // TODO Make the callback more general than just peer address update
+                if (ContentType.tls12_cid == recordType &&
+                    isLatestConfirmed &&
+                    recordEpoch == m_readEpoch)
+                {
+                    recordCallback();
+                }
+            }
+
             switch (decoded.contentType)
             {
             case ContentType.alert:
diff --git a/crypto/src/tls/DtlsReplayWindow.cs b/crypto/src/tls/DtlsReplayWindow.cs
index a08114c2a..5267c5934 100644
--- a/crypto/src/tls/DtlsReplayWindow.cs
+++ b/crypto/src/tls/DtlsReplayWindow.cs
@@ -42,7 +42,9 @@ namespace Org.BouncyCastle.Tls
         /// <summary>Report that a received record with the given sequence number passed authentication checks.
         /// </summary>
         /// <param name="seq">the 48-bit DTLSPlainText.sequence_number field of an authenticated record.</param>
-        internal void ReportAuthenticated(long seq)
+        /// <param name="isLatestConfirmed">indicates whether <paramref name="seq"/> is now the latest confirmed
+        /// sequence number.</param>
+        internal void ReportAuthenticated(long seq, out bool isLatestConfirmed)
         {
             if ((seq & ValidSeqMask) != seq)
                 throw new ArgumentException("out of range", "seq");
@@ -54,6 +56,7 @@ namespace Org.BouncyCastle.Tls
                 {
                     m_bitmap |= (1UL << (int)diff);
                 }
+                isLatestConfirmed = false;
             }
             else
             {
@@ -68,6 +71,7 @@ namespace Org.BouncyCastle.Tls
                     m_bitmap |= 1UL;
                 }
                 m_latestConfirmedSeq = seq;
+                isLatestConfirmed = true;
             }
         }
 
diff --git a/crypto/src/tls/DtlsTransport.cs b/crypto/src/tls/DtlsTransport.cs
index 30cd364d2..b452b8c89 100644
--- a/crypto/src/tls/DtlsTransport.cs
+++ b/crypto/src/tls/DtlsTransport.cs
@@ -31,6 +31,12 @@ namespace Org.BouncyCastle.Tls
         /// <exception cref="IOException"/>
         public virtual int Receive(byte[] buf, int off, int len, int waitMillis)
         {
+            return Receive(buf, off, len, waitMillis, null);
+        }
+
+        /// <exception cref="IOException"/>
+        public virtual int Receive(byte[] buf, int off, int len, int waitMillis, Action recordCallback)
+        {
             if (null == buf)
                 throw new ArgumentNullException("buf");
             if (off < 0 || off >= buf.Length)
@@ -39,14 +45,14 @@ namespace Org.BouncyCastle.Tls
                 throw new ArgumentException("invalid length: " + len, "len");
 
 #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
-            return Receive(buf.AsSpan(off, len), waitMillis);
+            return Receive(buf.AsSpan(off, len), waitMillis, recordCallback);
 #else
             if (waitMillis < 0)
                 throw new ArgumentException("cannot be negative", "waitMillis");
 
             try
             {
-                return m_recordLayer.Receive(buf, off, len, waitMillis);
+                return m_recordLayer.Receive(buf, off, len, waitMillis, recordCallback);
             }
             catch (TlsFatalAlert fatalAlert)
             {
@@ -87,7 +93,7 @@ namespace Org.BouncyCastle.Tls
         }
 
         /// <exception cref="IOException"/>
-        public virtual int ReceivePending(byte[] buf, int off, int len)
+        public virtual int ReceivePending(byte[] buf, int off, int len, Action recordCallback = null)
         {
             if (null == buf)
                 throw new ArgumentNullException("buf");
@@ -97,11 +103,11 @@ namespace Org.BouncyCastle.Tls
                 throw new ArgumentException("invalid length: " + len, "len");
 
 #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
-            return ReceivePending(buf.AsSpan(off, len));
+            return ReceivePending(buf.AsSpan(off, len), recordCallback);
 #else
             try
             {
-                return m_recordLayer.ReceivePending(buf, off, len);
+                return m_recordLayer.ReceivePending(buf, off, len, recordCallback);
             }
             catch (TlsFatalAlert fatalAlert)
             {
@@ -145,12 +151,18 @@ namespace Org.BouncyCastle.Tls
         /// <exception cref="IOException"/>
         public virtual int Receive(Span<byte> buffer, int waitMillis)
         {
+            return Receive(buffer, waitMillis, null);
+        }
+
+        /// <exception cref="IOException"/>
+        public virtual int Receive(Span<byte> buffer, int waitMillis, Action recordCallback)
+        {
             if (waitMillis < 0)
                 throw new ArgumentException("cannot be negative", nameof(waitMillis));
 
             try
             {
-                return m_recordLayer.Receive(buffer, waitMillis);
+                return m_recordLayer.Receive(buffer, waitMillis, recordCallback);
             }
             catch (TlsFatalAlert fatalAlert)
             {
@@ -190,11 +202,11 @@ namespace Org.BouncyCastle.Tls
         }
 
         /// <exception cref="IOException"/>
-        public virtual int ReceivePending(Span<byte> buffer)
+        public virtual int ReceivePending(Span<byte> buffer, Action recordCallback = null)
         {
             try
             {
-                return m_recordLayer.ReceivePending(buffer);
+                return m_recordLayer.ReceivePending(buffer, recordCallback);
             }
             catch (TlsFatalAlert fatalAlert)
             {