From fcb56e8c3cbeb83ef6af6ab4f9681d7d8318e299 Mon Sep 17 00:00:00 2001 From: Peter Dettman Date: Sat, 25 Mar 2023 16:11:42 +0700 Subject: RFC 9146: Add ReceivePending methods --- crypto/src/tls/DtlsRecordLayer.cs | 173 ++++++++++++++++++++++++-------------- crypto/src/tls/DtlsTransport.cs | 99 ++++++++++++++++++++++ 2 files changed, 211 insertions(+), 61 deletions(-) diff --git a/crypto/src/tls/DtlsRecordLayer.cs b/crypto/src/tls/DtlsRecordLayer.cs index e68470adb..a786da127 100644 --- a/crypto/src/tls/DtlsRecordLayer.cs +++ b/crypto/src/tls/DtlsRecordLayer.cs @@ -1,4 +1,5 @@ using System; +using System.Diagnostics; using System.IO; using System.Net.Sockets; @@ -340,6 +341,31 @@ namespace Org.BouncyCastle.Tls #endif } + /// + internal int ReceivePending(byte[] buf, int off, int len) + { +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + return ReceivePending(buf.AsSpan(off, len)); +#else + if (m_recordQueue.Available > 0) + { + int receiveLimit = m_recordQueue.Available; + byte[] record = new byte[receiveLimit]; + + do + { + int received = ReceivePendingRecord(record, 0, receiveLimit); + int processed = ProcessRecord(received, record, buf, off, len); + if (processed >= 0) + return processed; + } + while (m_recordQueue.Available > 0); + } + + return -1; +#endif + } + #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER /// public virtual int Receive(Span buffer, int waitMillis) @@ -406,6 +432,27 @@ namespace Org.BouncyCastle.Tls return -1; } + + /// + internal int ReceivePending(Span buffer) + { + if (m_recordQueue.Available > 0) + { + int receiveLimit = m_recordQueue.Available; + byte[] record = new byte[receiveLimit]; + + do + { + int received = ReceivePendingRecord(record, 0, receiveLimit); + int processed = ProcessRecord(received, record, buffer); + if (processed >= 0) + return processed; + } + while (m_recordQueue.Available > 0); + } + + return -1; + } #endif /// @@ -905,84 +952,88 @@ namespace Org.BouncyCastle.Tls } /// - private int ReceiveRecord(byte[] buf, int off, int len, int waitMillis) + private int ReceivePendingRecord(byte[] buf, int off, int len) { - if (m_recordQueue.Available > 0) - { - int recordLength = RECORD_HEADER_LENGTH; - if (m_recordQueue.Available >= recordLength) - { - short recordType = m_recordQueue.ReadUint8(0); - int epoch = m_recordQueue.ReadUint16(3); + Debug.Assert(m_recordQueue.Available > 0); - DtlsEpoch recordEpoch = null; - if (epoch == m_readEpoch.Epoch) - { - recordEpoch = m_readEpoch; - } - else if (recordType == ContentType.handshake && null != m_retransmitEpoch - && epoch == m_retransmitEpoch.Epoch) - { - recordEpoch = m_retransmitEpoch; - } + int recordLength = RECORD_HEADER_LENGTH; + if (m_recordQueue.Available >= recordLength) + { + short recordType = m_recordQueue.ReadUint8(0); + int epoch = m_recordQueue.ReadUint16(3); - if (null == recordEpoch) - { - m_recordQueue.RemoveData(m_recordQueue.Available); - return -1; - } + DtlsEpoch recordEpoch = null; + if (epoch == m_readEpoch.Epoch) + { + recordEpoch = m_readEpoch; + } + else if (recordType == ContentType.handshake && null != m_retransmitEpoch + && epoch == m_retransmitEpoch.Epoch) + { + recordEpoch = m_retransmitEpoch; + } - recordLength = recordEpoch.RecordHeaderLengthRead; - if (m_recordQueue.Available >= recordLength) - { - int fragmentLength = m_recordQueue.ReadUint16(recordLength - 2); - recordLength += fragmentLength; - } + if (null == recordEpoch) + { + m_recordQueue.RemoveData(m_recordQueue.Available); + return -1; } - int received = System.Math.Min(m_recordQueue.Available, recordLength); - m_recordQueue.RemoveData(buf, off, received, 0); - return received; + recordLength = recordEpoch.RecordHeaderLengthRead; + if (m_recordQueue.Available >= recordLength) + { + int fragmentLength = m_recordQueue.ReadUint16(recordLength - 2); + recordLength += fragmentLength; + } } + int received = System.Math.Min(m_recordQueue.Available, recordLength); + m_recordQueue.RemoveData(buf, off, received, 0); + return received; + } + + /// + private int ReceiveRecord(byte[] buf, int off, int len, int waitMillis) + { + if (m_recordQueue.Available > 0) + return ReceivePendingRecord(buf, off, len); + + int received = ReceiveDatagram(buf, off, len, waitMillis); + if (received >= RECORD_HEADER_LENGTH) { - int received = ReceiveDatagram(buf, off, len, waitMillis); - if (received >= RECORD_HEADER_LENGTH) - { - this.m_inConnection = true; + this.m_inConnection = true; - short recordType = TlsUtilities.ReadUint8(buf, off); - int epoch = TlsUtilities.ReadUint16(buf, off + 3); + short recordType = TlsUtilities.ReadUint8(buf, off); + int epoch = TlsUtilities.ReadUint16(buf, off + 3); - DtlsEpoch recordEpoch = null; - if (epoch == m_readEpoch.Epoch) - { - recordEpoch = m_readEpoch; - } - else if (recordType == ContentType.handshake && null != m_retransmitEpoch - && epoch == m_retransmitEpoch.Epoch) - { - recordEpoch = m_retransmitEpoch; - } + DtlsEpoch recordEpoch = null; + if (epoch == m_readEpoch.Epoch) + { + recordEpoch = m_readEpoch; + } + else if (recordType == ContentType.handshake && null != m_retransmitEpoch + && epoch == m_retransmitEpoch.Epoch) + { + recordEpoch = m_retransmitEpoch; + } - if (null == recordEpoch) - return -1; + if (null == recordEpoch) + return -1; - int recordHeaderLength = recordEpoch.RecordHeaderLengthRead; - if (received >= recordHeaderLength) + int recordHeaderLength = recordEpoch.RecordHeaderLengthRead; + if (received >= recordHeaderLength) + { + int fragmentLength = TlsUtilities.ReadUint16(buf, off + recordHeaderLength - 2); + int recordLength = recordHeaderLength + fragmentLength; + if (received > recordLength) { - int fragmentLength = TlsUtilities.ReadUint16(buf, off + recordHeaderLength - 2); - int recordLength = recordHeaderLength + fragmentLength; - if (received > recordLength) - { - m_recordQueue.AddData(buf, off + recordLength, received - recordLength); - received = recordLength; - } + m_recordQueue.AddData(buf, off + recordLength, received - recordLength); + received = recordLength; } } - - return received; } + + return received; } private void ResetHeartbeat() diff --git a/crypto/src/tls/DtlsTransport.cs b/crypto/src/tls/DtlsTransport.cs index 2d950ede0..30cd364d2 100644 --- a/crypto/src/tls/DtlsTransport.cs +++ b/crypto/src/tls/DtlsTransport.cs @@ -86,6 +86,61 @@ namespace Org.BouncyCastle.Tls #endif } + /// + public virtual int ReceivePending(byte[] buf, int off, int len) + { + if (null == buf) + throw new ArgumentNullException("buf"); + if (off < 0 || off >= buf.Length) + throw new ArgumentException("invalid offset: " + off, "off"); + if (len < 0 || len > buf.Length - off) + throw new ArgumentException("invalid length: " + len, "len"); + +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + return ReceivePending(buf.AsSpan(off, len)); +#else + try + { + return m_recordLayer.ReceivePending(buf, off, len); + } + catch (TlsFatalAlert fatalAlert) + { + if (m_ignoreCorruptRecords && AlertDescription.bad_record_mac == fatalAlert.AlertDescription) + return -1; + + m_recordLayer.Fail(fatalAlert.AlertDescription); + throw; + } + catch (TlsTimeoutException) + { + throw; + } + catch (SocketException e) + { + if (TlsUtilities.IsTimeout(e)) + throw; + + m_recordLayer.Fail(AlertDescription.internal_error); + throw new TlsFatalAlert(AlertDescription.internal_error, e); + } + // TODO[tls-port] Can we support interrupted IO on .NET? + //catch (InterruptedIOException) + //{ + // throw; + //} + catch (IOException) + { + m_recordLayer.Fail(AlertDescription.internal_error); + throw; + } + catch (Exception e) + { + m_recordLayer.Fail(AlertDescription.internal_error); + throw new TlsFatalAlert(AlertDescription.internal_error, e); + } +#endif + } + #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER /// public virtual int Receive(Span buffer, int waitMillis) @@ -133,6 +188,50 @@ namespace Org.BouncyCastle.Tls throw new TlsFatalAlert(AlertDescription.internal_error, e); } } + + /// + public virtual int ReceivePending(Span buffer) + { + try + { + return m_recordLayer.ReceivePending(buffer); + } + catch (TlsFatalAlert fatalAlert) + { + if (m_ignoreCorruptRecords && AlertDescription.bad_record_mac == fatalAlert.AlertDescription) + return -1; + + m_recordLayer.Fail(fatalAlert.AlertDescription); + throw; + } + catch (TlsTimeoutException) + { + throw; + } + catch (SocketException e) + { + if (TlsUtilities.IsTimeout(e)) + throw; + + m_recordLayer.Fail(AlertDescription.internal_error); + throw new TlsFatalAlert(AlertDescription.internal_error, e); + } + // TODO[tls-port] Can we support interrupted IO on .NET? + //catch (InterruptedIOException) + //{ + // throw; + //} + catch (IOException) + { + m_recordLayer.Fail(AlertDescription.internal_error); + throw; + } + catch (Exception e) + { + m_recordLayer.Fail(AlertDescription.internal_error); + throw new TlsFatalAlert(AlertDescription.internal_error, e); + } + } #endif /// -- cgit 1.4.1