summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crypto/src/tls/DtlsRecordLayer.cs173
-rw-r--r--crypto/src/tls/DtlsTransport.cs99
2 files changed, 211 insertions, 61 deletions
diff --git a/crypto/src/tls/DtlsRecordLayer.cs b/crypto/src/tls/DtlsRecordLayer.cs
index e68470adb..a786da127 100644
--- a/crypto/src/tls/DtlsRecordLayer.cs
+++ b/crypto/src/tls/DtlsRecordLayer.cs
@@ -1,4 +1,5 @@
 using System;
+using System.Diagnostics;
 using System.IO;
 using System.Net.Sockets;
 
@@ -340,6 +341,31 @@ namespace Org.BouncyCastle.Tls
 #endif
         }
 
+        /// <exception cref="IOException"/>
+        internal int ReceivePending(byte[] buf, int off, int len)
+        {
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+            return ReceivePending(buf.AsSpan(off, len));
+#else
+            if (m_recordQueue.Available > 0)
+            {
+                int receiveLimit = m_recordQueue.Available;
+                byte[] record = new byte[receiveLimit];
+
+                do
+                {
+                    int received = ReceivePendingRecord(record, 0, receiveLimit);
+                    int processed = ProcessRecord(received, record, buf, off, len);
+                    if (processed >= 0)
+                        return processed;
+                }
+                while (m_recordQueue.Available > 0);
+            }
+
+            return -1;
+#endif
+        }
+
 #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
         /// <exception cref="IOException"/>
         public virtual int Receive(Span<byte> buffer, int waitMillis)
@@ -406,6 +432,27 @@ namespace Org.BouncyCastle.Tls
 
             return -1;
         }
+
+        /// <exception cref="IOException"/>
+        internal int ReceivePending(Span<byte> buffer)
+        {
+            if (m_recordQueue.Available > 0)
+            {
+                int receiveLimit = m_recordQueue.Available;
+                byte[] record = new byte[receiveLimit];
+
+                do
+                {
+                    int received = ReceivePendingRecord(record, 0, receiveLimit);
+                    int processed = ProcessRecord(received, record, buffer);
+                    if (processed >= 0)
+                        return processed;
+                }
+                while (m_recordQueue.Available > 0);
+            }
+
+            return -1;
+        }
 #endif
 
         /// <exception cref="IOException"/>
@@ -905,84 +952,88 @@ namespace Org.BouncyCastle.Tls
         }
 
         /// <exception cref="IOException"/>
-        private int ReceiveRecord(byte[] buf, int off, int len, int waitMillis)
+        private int ReceivePendingRecord(byte[] buf, int off, int len)
         {
-            if (m_recordQueue.Available > 0)
-            {
-                int recordLength = RECORD_HEADER_LENGTH;
-                if (m_recordQueue.Available >= recordLength)
-                {
-                    short recordType = m_recordQueue.ReadUint8(0);
-                    int epoch = m_recordQueue.ReadUint16(3);
+            Debug.Assert(m_recordQueue.Available > 0);
 
-                    DtlsEpoch recordEpoch = null;
-                    if (epoch == m_readEpoch.Epoch)
-                    {
-                        recordEpoch = m_readEpoch;
-                    }
-                    else if (recordType == ContentType.handshake && null != m_retransmitEpoch
-                        && epoch == m_retransmitEpoch.Epoch)
-                    {
-                        recordEpoch = m_retransmitEpoch;
-                    }
+            int recordLength = RECORD_HEADER_LENGTH;
+            if (m_recordQueue.Available >= recordLength)
+            {
+                short recordType = m_recordQueue.ReadUint8(0);
+                int epoch = m_recordQueue.ReadUint16(3);
 
-                    if (null == recordEpoch)
-                    {
-                        m_recordQueue.RemoveData(m_recordQueue.Available);
-                        return -1;
-                    }
+                DtlsEpoch recordEpoch = null;
+                if (epoch == m_readEpoch.Epoch)
+                {
+                    recordEpoch = m_readEpoch;
+                }
+                else if (recordType == ContentType.handshake && null != m_retransmitEpoch
+                    && epoch == m_retransmitEpoch.Epoch)
+                {
+                    recordEpoch = m_retransmitEpoch;
+                }
 
-                    recordLength = recordEpoch.RecordHeaderLengthRead;
-                    if (m_recordQueue.Available >= recordLength)
-                    {
-                        int fragmentLength = m_recordQueue.ReadUint16(recordLength - 2);
-                        recordLength += fragmentLength;
-                    }
+                if (null == recordEpoch)
+                {
+                    m_recordQueue.RemoveData(m_recordQueue.Available);
+                    return -1;
                 }
 
-                int received = System.Math.Min(m_recordQueue.Available, recordLength);
-                m_recordQueue.RemoveData(buf, off, received, 0);
-                return received;
+                recordLength = recordEpoch.RecordHeaderLengthRead;
+                if (m_recordQueue.Available >= recordLength)
+                {
+                    int fragmentLength = m_recordQueue.ReadUint16(recordLength - 2);
+                    recordLength += fragmentLength;
+                }
             }
 
+            int received = System.Math.Min(m_recordQueue.Available, recordLength);
+            m_recordQueue.RemoveData(buf, off, received, 0);
+            return received;
+        }
+
+        /// <exception cref="IOException"/>
+        private int ReceiveRecord(byte[] buf, int off, int len, int waitMillis)
+        {
+            if (m_recordQueue.Available > 0)
+                return ReceivePendingRecord(buf, off, len);
+
+            int received = ReceiveDatagram(buf, off, len, waitMillis);
+            if (received >= RECORD_HEADER_LENGTH)
             {
-                int received = ReceiveDatagram(buf, off, len, waitMillis);
-                if (received >= RECORD_HEADER_LENGTH)
-                {
-                    this.m_inConnection = true;
+                this.m_inConnection = true;
 
-                    short recordType = TlsUtilities.ReadUint8(buf, off);
-                    int epoch = TlsUtilities.ReadUint16(buf, off + 3);
+                short recordType = TlsUtilities.ReadUint8(buf, off);
+                int epoch = TlsUtilities.ReadUint16(buf, off + 3);
 
-                    DtlsEpoch recordEpoch = null;
-                    if (epoch == m_readEpoch.Epoch)
-                    {
-                        recordEpoch = m_readEpoch;
-                    }
-                    else if (recordType == ContentType.handshake && null != m_retransmitEpoch
-                        && epoch == m_retransmitEpoch.Epoch)
-                    {
-                        recordEpoch = m_retransmitEpoch;
-                    }
+                DtlsEpoch recordEpoch = null;
+                if (epoch == m_readEpoch.Epoch)
+                {
+                    recordEpoch = m_readEpoch;
+                }
+                else if (recordType == ContentType.handshake && null != m_retransmitEpoch
+                    && epoch == m_retransmitEpoch.Epoch)
+                {
+                    recordEpoch = m_retransmitEpoch;
+                }
 
-                    if (null == recordEpoch)
-                        return -1;
+                if (null == recordEpoch)
+                    return -1;
 
-                    int recordHeaderLength = recordEpoch.RecordHeaderLengthRead;
-                    if (received >= recordHeaderLength)
+                int recordHeaderLength = recordEpoch.RecordHeaderLengthRead;
+                if (received >= recordHeaderLength)
+                {
+                    int fragmentLength = TlsUtilities.ReadUint16(buf, off + recordHeaderLength - 2);
+                    int recordLength = recordHeaderLength + fragmentLength;
+                    if (received > recordLength)
                     {
-                        int fragmentLength = TlsUtilities.ReadUint16(buf, off + recordHeaderLength - 2);
-                        int recordLength = recordHeaderLength + fragmentLength;
-                        if (received > recordLength)
-                        {
-                            m_recordQueue.AddData(buf, off + recordLength, received - recordLength);
-                            received = recordLength;
-                        }
+                        m_recordQueue.AddData(buf, off + recordLength, received - recordLength);
+                        received = recordLength;
                     }
                 }
-
-                return received;
             }
+
+            return received;
         }
 
         private void ResetHeartbeat()
diff --git a/crypto/src/tls/DtlsTransport.cs b/crypto/src/tls/DtlsTransport.cs
index 2d950ede0..30cd364d2 100644
--- a/crypto/src/tls/DtlsTransport.cs
+++ b/crypto/src/tls/DtlsTransport.cs
@@ -86,6 +86,61 @@ namespace Org.BouncyCastle.Tls
 #endif
         }
 
+        /// <exception cref="IOException"/>
+        public virtual int ReceivePending(byte[] buf, int off, int len)
+        {
+            if (null == buf)
+                throw new ArgumentNullException("buf");
+            if (off < 0 || off >= buf.Length)
+                throw new ArgumentException("invalid offset: " + off, "off");
+            if (len < 0 || len > buf.Length - off)
+                throw new ArgumentException("invalid length: " + len, "len");
+
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+            return ReceivePending(buf.AsSpan(off, len));
+#else
+            try
+            {
+                return m_recordLayer.ReceivePending(buf, off, len);
+            }
+            catch (TlsFatalAlert fatalAlert)
+            {
+                if (m_ignoreCorruptRecords && AlertDescription.bad_record_mac == fatalAlert.AlertDescription)
+                    return -1;
+
+                m_recordLayer.Fail(fatalAlert.AlertDescription);
+                throw;
+            }
+            catch (TlsTimeoutException)
+            {
+                throw;
+            }
+            catch (SocketException e)
+            {
+                if (TlsUtilities.IsTimeout(e))
+                    throw;
+
+                m_recordLayer.Fail(AlertDescription.internal_error);
+                throw new TlsFatalAlert(AlertDescription.internal_error, e);
+            }
+            // TODO[tls-port] Can we support interrupted IO on .NET?
+            //catch (InterruptedIOException)
+            //{
+            //    throw;
+            //}
+            catch (IOException)
+            {
+                m_recordLayer.Fail(AlertDescription.internal_error);
+                throw;
+            }
+            catch (Exception e)
+            {
+                m_recordLayer.Fail(AlertDescription.internal_error);
+                throw new TlsFatalAlert(AlertDescription.internal_error, e);
+            }
+#endif
+        }
+
 #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
         /// <exception cref="IOException"/>
         public virtual int Receive(Span<byte> buffer, int waitMillis)
@@ -133,6 +188,50 @@ namespace Org.BouncyCastle.Tls
                 throw new TlsFatalAlert(AlertDescription.internal_error, e);
             }
         }
+
+        /// <exception cref="IOException"/>
+        public virtual int ReceivePending(Span<byte> buffer)
+        {
+            try
+            {
+                return m_recordLayer.ReceivePending(buffer);
+            }
+            catch (TlsFatalAlert fatalAlert)
+            {
+                if (m_ignoreCorruptRecords && AlertDescription.bad_record_mac == fatalAlert.AlertDescription)
+                    return -1;
+
+                m_recordLayer.Fail(fatalAlert.AlertDescription);
+                throw;
+            }
+            catch (TlsTimeoutException)
+            {
+                throw;
+            }
+            catch (SocketException e)
+            {
+                if (TlsUtilities.IsTimeout(e))
+                    throw;
+
+                m_recordLayer.Fail(AlertDescription.internal_error);
+                throw new TlsFatalAlert(AlertDescription.internal_error, e);
+            }
+            // TODO[tls-port] Can we support interrupted IO on .NET?
+            //catch (InterruptedIOException)
+            //{
+            //    throw;
+            //}
+            catch (IOException)
+            {
+                m_recordLayer.Fail(AlertDescription.internal_error);
+                throw;
+            }
+            catch (Exception e)
+            {
+                m_recordLayer.Fail(AlertDescription.internal_error);
+                throw new TlsFatalAlert(AlertDescription.internal_error, e);
+            }
+        }
 #endif
 
         /// <exception cref="IOException"/>