diff --git a/crypto/src/tls/DtlsRecordLayer.cs b/crypto/src/tls/DtlsRecordLayer.cs
index e68470adb..a786da127 100644
--- a/crypto/src/tls/DtlsRecordLayer.cs
+++ b/crypto/src/tls/DtlsRecordLayer.cs
@@ -1,4 +1,5 @@
using System;
+using System.Diagnostics;
using System.IO;
using System.Net.Sockets;
@@ -340,6 +341,31 @@ namespace Org.BouncyCastle.Tls
#endif
}
+ /// <exception cref="IOException"/>
+ internal int ReceivePending(byte[] buf, int off, int len)
+ {
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+ return ReceivePending(buf.AsSpan(off, len));
+#else
+ if (m_recordQueue.Available > 0)
+ {
+ int receiveLimit = m_recordQueue.Available;
+ byte[] record = new byte[receiveLimit];
+
+ do
+ {
+ int received = ReceivePendingRecord(record, 0, receiveLimit);
+ int processed = ProcessRecord(received, record, buf, off, len);
+ if (processed >= 0)
+ return processed;
+ }
+ while (m_recordQueue.Available > 0);
+ }
+
+ return -1;
+#endif
+ }
+
#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
/// <exception cref="IOException"/>
public virtual int Receive(Span<byte> buffer, int waitMillis)
@@ -406,6 +432,27 @@ namespace Org.BouncyCastle.Tls
return -1;
}
+
+ /// <exception cref="IOException"/>
+ internal int ReceivePending(Span<byte> buffer)
+ {
+ if (m_recordQueue.Available > 0)
+ {
+ int receiveLimit = m_recordQueue.Available;
+ byte[] record = new byte[receiveLimit];
+
+ do
+ {
+ int received = ReceivePendingRecord(record, 0, receiveLimit);
+ int processed = ProcessRecord(received, record, buffer);
+ if (processed >= 0)
+ return processed;
+ }
+ while (m_recordQueue.Available > 0);
+ }
+
+ return -1;
+ }
#endif
/// <exception cref="IOException"/>
@@ -905,84 +952,88 @@ namespace Org.BouncyCastle.Tls
}
/// <exception cref="IOException"/>
- private int ReceiveRecord(byte[] buf, int off, int len, int waitMillis)
+ private int ReceivePendingRecord(byte[] buf, int off, int len)
{
- if (m_recordQueue.Available > 0)
- {
- int recordLength = RECORD_HEADER_LENGTH;
- if (m_recordQueue.Available >= recordLength)
- {
- short recordType = m_recordQueue.ReadUint8(0);
- int epoch = m_recordQueue.ReadUint16(3);
+ Debug.Assert(m_recordQueue.Available > 0);
- DtlsEpoch recordEpoch = null;
- if (epoch == m_readEpoch.Epoch)
- {
- recordEpoch = m_readEpoch;
- }
- else if (recordType == ContentType.handshake && null != m_retransmitEpoch
- && epoch == m_retransmitEpoch.Epoch)
- {
- recordEpoch = m_retransmitEpoch;
- }
+ int recordLength = RECORD_HEADER_LENGTH;
+ if (m_recordQueue.Available >= recordLength)
+ {
+ short recordType = m_recordQueue.ReadUint8(0);
+ int epoch = m_recordQueue.ReadUint16(3);
- if (null == recordEpoch)
- {
- m_recordQueue.RemoveData(m_recordQueue.Available);
- return -1;
- }
+ DtlsEpoch recordEpoch = null;
+ if (epoch == m_readEpoch.Epoch)
+ {
+ recordEpoch = m_readEpoch;
+ }
+ else if (recordType == ContentType.handshake && null != m_retransmitEpoch
+ && epoch == m_retransmitEpoch.Epoch)
+ {
+ recordEpoch = m_retransmitEpoch;
+ }
- recordLength = recordEpoch.RecordHeaderLengthRead;
- if (m_recordQueue.Available >= recordLength)
- {
- int fragmentLength = m_recordQueue.ReadUint16(recordLength - 2);
- recordLength += fragmentLength;
- }
+ if (null == recordEpoch)
+ {
+ m_recordQueue.RemoveData(m_recordQueue.Available);
+ return -1;
}
- int received = System.Math.Min(m_recordQueue.Available, recordLength);
- m_recordQueue.RemoveData(buf, off, received, 0);
- return received;
+ recordLength = recordEpoch.RecordHeaderLengthRead;
+ if (m_recordQueue.Available >= recordLength)
+ {
+ int fragmentLength = m_recordQueue.ReadUint16(recordLength - 2);
+ recordLength += fragmentLength;
+ }
}
+ int received = System.Math.Min(m_recordQueue.Available, recordLength);
+ m_recordQueue.RemoveData(buf, off, received, 0);
+ return received;
+ }
+
+ /// <exception cref="IOException"/>
+ private int ReceiveRecord(byte[] buf, int off, int len, int waitMillis)
+ {
+ if (m_recordQueue.Available > 0)
+ return ReceivePendingRecord(buf, off, len);
+
+ int received = ReceiveDatagram(buf, off, len, waitMillis);
+ if (received >= RECORD_HEADER_LENGTH)
{
- int received = ReceiveDatagram(buf, off, len, waitMillis);
- if (received >= RECORD_HEADER_LENGTH)
- {
- this.m_inConnection = true;
+ this.m_inConnection = true;
- short recordType = TlsUtilities.ReadUint8(buf, off);
- int epoch = TlsUtilities.ReadUint16(buf, off + 3);
+ short recordType = TlsUtilities.ReadUint8(buf, off);
+ int epoch = TlsUtilities.ReadUint16(buf, off + 3);
- DtlsEpoch recordEpoch = null;
- if (epoch == m_readEpoch.Epoch)
- {
- recordEpoch = m_readEpoch;
- }
- else if (recordType == ContentType.handshake && null != m_retransmitEpoch
- && epoch == m_retransmitEpoch.Epoch)
- {
- recordEpoch = m_retransmitEpoch;
- }
+ DtlsEpoch recordEpoch = null;
+ if (epoch == m_readEpoch.Epoch)
+ {
+ recordEpoch = m_readEpoch;
+ }
+ else if (recordType == ContentType.handshake && null != m_retransmitEpoch
+ && epoch == m_retransmitEpoch.Epoch)
+ {
+ recordEpoch = m_retransmitEpoch;
+ }
- if (null == recordEpoch)
- return -1;
+ if (null == recordEpoch)
+ return -1;
- int recordHeaderLength = recordEpoch.RecordHeaderLengthRead;
- if (received >= recordHeaderLength)
+ int recordHeaderLength = recordEpoch.RecordHeaderLengthRead;
+ if (received >= recordHeaderLength)
+ {
+ int fragmentLength = TlsUtilities.ReadUint16(buf, off + recordHeaderLength - 2);
+ int recordLength = recordHeaderLength + fragmentLength;
+ if (received > recordLength)
{
- int fragmentLength = TlsUtilities.ReadUint16(buf, off + recordHeaderLength - 2);
- int recordLength = recordHeaderLength + fragmentLength;
- if (received > recordLength)
- {
- m_recordQueue.AddData(buf, off + recordLength, received - recordLength);
- received = recordLength;
- }
+ m_recordQueue.AddData(buf, off + recordLength, received - recordLength);
+ received = recordLength;
}
}
-
- return received;
}
+
+ return received;
}
private void ResetHeartbeat()
diff --git a/crypto/src/tls/DtlsTransport.cs b/crypto/src/tls/DtlsTransport.cs
index 2d950ede0..30cd364d2 100644
--- a/crypto/src/tls/DtlsTransport.cs
+++ b/crypto/src/tls/DtlsTransport.cs
@@ -86,6 +86,61 @@ namespace Org.BouncyCastle.Tls
#endif
}
+ /// <exception cref="IOException"/>
+ public virtual int ReceivePending(byte[] buf, int off, int len)
+ {
+ if (null == buf)
+ throw new ArgumentNullException("buf");
+ if (off < 0 || off >= buf.Length)
+ throw new ArgumentException("invalid offset: " + off, "off");
+ if (len < 0 || len > buf.Length - off)
+ throw new ArgumentException("invalid length: " + len, "len");
+
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+ return ReceivePending(buf.AsSpan(off, len));
+#else
+ try
+ {
+ return m_recordLayer.ReceivePending(buf, off, len);
+ }
+ catch (TlsFatalAlert fatalAlert)
+ {
+ if (m_ignoreCorruptRecords && AlertDescription.bad_record_mac == fatalAlert.AlertDescription)
+ return -1;
+
+ m_recordLayer.Fail(fatalAlert.AlertDescription);
+ throw;
+ }
+ catch (TlsTimeoutException)
+ {
+ throw;
+ }
+ catch (SocketException e)
+ {
+ if (TlsUtilities.IsTimeout(e))
+ throw;
+
+ m_recordLayer.Fail(AlertDescription.internal_error);
+ throw new TlsFatalAlert(AlertDescription.internal_error, e);
+ }
+ // TODO[tls-port] Can we support interrupted IO on .NET?
+ //catch (InterruptedIOException)
+ //{
+ // throw;
+ //}
+ catch (IOException)
+ {
+ m_recordLayer.Fail(AlertDescription.internal_error);
+ throw;
+ }
+ catch (Exception e)
+ {
+ m_recordLayer.Fail(AlertDescription.internal_error);
+ throw new TlsFatalAlert(AlertDescription.internal_error, e);
+ }
+#endif
+ }
+
#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
/// <exception cref="IOException"/>
public virtual int Receive(Span<byte> buffer, int waitMillis)
@@ -133,6 +188,50 @@ namespace Org.BouncyCastle.Tls
throw new TlsFatalAlert(AlertDescription.internal_error, e);
}
}
+
+ /// <exception cref="IOException"/>
+ public virtual int ReceivePending(Span<byte> buffer)
+ {
+ try
+ {
+ return m_recordLayer.ReceivePending(buffer);
+ }
+ catch (TlsFatalAlert fatalAlert)
+ {
+ if (m_ignoreCorruptRecords && AlertDescription.bad_record_mac == fatalAlert.AlertDescription)
+ return -1;
+
+ m_recordLayer.Fail(fatalAlert.AlertDescription);
+ throw;
+ }
+ catch (TlsTimeoutException)
+ {
+ throw;
+ }
+ catch (SocketException e)
+ {
+ if (TlsUtilities.IsTimeout(e))
+ throw;
+
+ m_recordLayer.Fail(AlertDescription.internal_error);
+ throw new TlsFatalAlert(AlertDescription.internal_error, e);
+ }
+ // TODO[tls-port] Can we support interrupted IO on .NET?
+ //catch (InterruptedIOException)
+ //{
+ // throw;
+ //}
+ catch (IOException)
+ {
+ m_recordLayer.Fail(AlertDescription.internal_error);
+ throw;
+ }
+ catch (Exception e)
+ {
+ m_recordLayer.Fail(AlertDescription.internal_error);
+ throw new TlsFatalAlert(AlertDescription.internal_error, e);
+ }
+ }
#endif
/// <exception cref="IOException"/>
|