diff options
Diffstat (limited to 'crypto/src/tls/DtlsRecordLayer.cs')
-rw-r--r-- | crypto/src/tls/DtlsRecordLayer.cs | 173 |
1 files changed, 112 insertions, 61 deletions
diff --git a/crypto/src/tls/DtlsRecordLayer.cs b/crypto/src/tls/DtlsRecordLayer.cs index e68470adb..a786da127 100644 --- a/crypto/src/tls/DtlsRecordLayer.cs +++ b/crypto/src/tls/DtlsRecordLayer.cs @@ -1,4 +1,5 @@ using System; +using System.Diagnostics; using System.IO; using System.Net.Sockets; @@ -340,6 +341,31 @@ namespace Org.BouncyCastle.Tls #endif } + /// <exception cref="IOException"/> + internal int ReceivePending(byte[] buf, int off, int len) + { +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + return ReceivePending(buf.AsSpan(off, len)); +#else + if (m_recordQueue.Available > 0) + { + int receiveLimit = m_recordQueue.Available; + byte[] record = new byte[receiveLimit]; + + do + { + int received = ReceivePendingRecord(record, 0, receiveLimit); + int processed = ProcessRecord(received, record, buf, off, len); + if (processed >= 0) + return processed; + } + while (m_recordQueue.Available > 0); + } + + return -1; +#endif + } + #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER /// <exception cref="IOException"/> public virtual int Receive(Span<byte> buffer, int waitMillis) @@ -406,6 +432,27 @@ namespace Org.BouncyCastle.Tls return -1; } + + /// <exception cref="IOException"/> + internal int ReceivePending(Span<byte> buffer) + { + if (m_recordQueue.Available > 0) + { + int receiveLimit = m_recordQueue.Available; + byte[] record = new byte[receiveLimit]; + + do + { + int received = ReceivePendingRecord(record, 0, receiveLimit); + int processed = ProcessRecord(received, record, buffer); + if (processed >= 0) + return processed; + } + while (m_recordQueue.Available > 0); + } + + return -1; + } #endif /// <exception cref="IOException"/> @@ -905,84 +952,88 @@ namespace Org.BouncyCastle.Tls } /// <exception cref="IOException"/> - private int ReceiveRecord(byte[] buf, int off, int len, int waitMillis) + private int ReceivePendingRecord(byte[] buf, int off, int len) { - if (m_recordQueue.Available > 0) - { - int recordLength = RECORD_HEADER_LENGTH; - if (m_recordQueue.Available >= recordLength) - { - short recordType = m_recordQueue.ReadUint8(0); - int epoch = m_recordQueue.ReadUint16(3); + Debug.Assert(m_recordQueue.Available > 0); - DtlsEpoch recordEpoch = null; - if (epoch == m_readEpoch.Epoch) - { - recordEpoch = m_readEpoch; - } - else if (recordType == ContentType.handshake && null != m_retransmitEpoch - && epoch == m_retransmitEpoch.Epoch) - { - recordEpoch = m_retransmitEpoch; - } + int recordLength = RECORD_HEADER_LENGTH; + if (m_recordQueue.Available >= recordLength) + { + short recordType = m_recordQueue.ReadUint8(0); + int epoch = m_recordQueue.ReadUint16(3); - if (null == recordEpoch) - { - m_recordQueue.RemoveData(m_recordQueue.Available); - return -1; - } + DtlsEpoch recordEpoch = null; + if (epoch == m_readEpoch.Epoch) + { + recordEpoch = m_readEpoch; + } + else if (recordType == ContentType.handshake && null != m_retransmitEpoch + && epoch == m_retransmitEpoch.Epoch) + { + recordEpoch = m_retransmitEpoch; + } - recordLength = recordEpoch.RecordHeaderLengthRead; - if (m_recordQueue.Available >= recordLength) - { - int fragmentLength = m_recordQueue.ReadUint16(recordLength - 2); - recordLength += fragmentLength; - } + if (null == recordEpoch) + { + m_recordQueue.RemoveData(m_recordQueue.Available); + return -1; } - int received = System.Math.Min(m_recordQueue.Available, recordLength); - m_recordQueue.RemoveData(buf, off, received, 0); - return received; + recordLength = recordEpoch.RecordHeaderLengthRead; + if (m_recordQueue.Available >= recordLength) + { + int fragmentLength = m_recordQueue.ReadUint16(recordLength - 2); + recordLength += fragmentLength; + } } + int received = System.Math.Min(m_recordQueue.Available, recordLength); + m_recordQueue.RemoveData(buf, off, received, 0); + return received; + } + + /// <exception cref="IOException"/> + private int ReceiveRecord(byte[] buf, int off, int len, int waitMillis) + { + if (m_recordQueue.Available > 0) + return ReceivePendingRecord(buf, off, len); + + int received = ReceiveDatagram(buf, off, len, waitMillis); + if (received >= RECORD_HEADER_LENGTH) { - int received = ReceiveDatagram(buf, off, len, waitMillis); - if (received >= RECORD_HEADER_LENGTH) - { - this.m_inConnection = true; + this.m_inConnection = true; - short recordType = TlsUtilities.ReadUint8(buf, off); - int epoch = TlsUtilities.ReadUint16(buf, off + 3); + short recordType = TlsUtilities.ReadUint8(buf, off); + int epoch = TlsUtilities.ReadUint16(buf, off + 3); - DtlsEpoch recordEpoch = null; - if (epoch == m_readEpoch.Epoch) - { - recordEpoch = m_readEpoch; - } - else if (recordType == ContentType.handshake && null != m_retransmitEpoch - && epoch == m_retransmitEpoch.Epoch) - { - recordEpoch = m_retransmitEpoch; - } + DtlsEpoch recordEpoch = null; + if (epoch == m_readEpoch.Epoch) + { + recordEpoch = m_readEpoch; + } + else if (recordType == ContentType.handshake && null != m_retransmitEpoch + && epoch == m_retransmitEpoch.Epoch) + { + recordEpoch = m_retransmitEpoch; + } - if (null == recordEpoch) - return -1; + if (null == recordEpoch) + return -1; - int recordHeaderLength = recordEpoch.RecordHeaderLengthRead; - if (received >= recordHeaderLength) + int recordHeaderLength = recordEpoch.RecordHeaderLengthRead; + if (received >= recordHeaderLength) + { + int fragmentLength = TlsUtilities.ReadUint16(buf, off + recordHeaderLength - 2); + int recordLength = recordHeaderLength + fragmentLength; + if (received > recordLength) { - int fragmentLength = TlsUtilities.ReadUint16(buf, off + recordHeaderLength - 2); - int recordLength = recordHeaderLength + fragmentLength; - if (received > recordLength) - { - m_recordQueue.AddData(buf, off + recordLength, received - recordLength); - received = recordLength; - } + m_recordQueue.AddData(buf, off + recordLength, received - recordLength); + received = recordLength; } } - - return received; } + + return received; } private void ResetHeartbeat() |