summary refs log tree commit diff
path: root/crypto/src/tls/DtlsRecordLayer.cs
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src/tls/DtlsRecordLayer.cs')
-rw-r--r--crypto/src/tls/DtlsRecordLayer.cs167
1 files changed, 158 insertions, 9 deletions
diff --git a/crypto/src/tls/DtlsRecordLayer.cs b/crypto/src/tls/DtlsRecordLayer.cs
index 7ec77c5da..bab6892b7 100644
--- a/crypto/src/tls/DtlsRecordLayer.cs
+++ b/crypto/src/tls/DtlsRecordLayer.cs
@@ -242,6 +242,9 @@ namespace Org.BouncyCastle.Tls
         /// <exception cref="IOException"/>
         public virtual int Receive(byte[] buf, int off, int len, int waitMillis)
         {
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+            return Receive(buf.AsSpan(off, len), waitMillis);
+#else
             long currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
 
             Timeout timeout = Timeout.ForWaitMillis(waitMillis, currentTimeMillis);
@@ -305,11 +308,85 @@ namespace Org.BouncyCastle.Tls
             }
 
             return -1;
+#endif
         }
 
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+        /// <exception cref="IOException"/>
+        public virtual int Receive(Span<byte> buffer, int waitMillis)
+        {
+            long currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
+
+            Timeout timeout = Timeout.ForWaitMillis(waitMillis, currentTimeMillis);
+            byte[] record = null;
+
+            while (waitMillis >= 0)
+            {
+                if (null != m_retransmitTimeout && m_retransmitTimeout.RemainingMillis(currentTimeMillis) < 1)
+                {
+                    m_retransmit = null;
+                    m_retransmitEpoch = null;
+                    m_retransmitTimeout = null;
+                }
+
+                if (Timeout.HasExpired(m_heartbeatTimeout, currentTimeMillis))
+                {
+                    if (null != m_heartbeatInFlight)
+                        throw new TlsTimeoutException("Heartbeat timed out");
+
+                    this.m_heartbeatInFlight = HeartbeatMessage.Create(m_context,
+                        HeartbeatMessageType.heartbeat_request, m_heartbeat.GeneratePayload());
+                    this.m_heartbeatTimeout = new Timeout(m_heartbeat.TimeoutMillis, currentTimeMillis);
+
+                    this.m_heartbeatResendMillis = DtlsReliableHandshake.INITIAL_RESEND_MILLIS;
+                    this.m_heartbeatResendTimeout = new Timeout(m_heartbeatResendMillis, currentTimeMillis);
+
+                    SendHeartbeatMessage(m_heartbeatInFlight);
+                }
+                else if (Timeout.HasExpired(m_heartbeatResendTimeout, currentTimeMillis))
+                {
+                    this.m_heartbeatResendMillis = DtlsReliableHandshake.BackOff(m_heartbeatResendMillis);
+                    this.m_heartbeatResendTimeout = new Timeout(m_heartbeatResendMillis, currentTimeMillis);
+
+                    SendHeartbeatMessage(m_heartbeatInFlight);
+                }
+
+                waitMillis = Timeout.ConstrainWaitMillis(waitMillis, m_heartbeatTimeout, currentTimeMillis);
+                waitMillis = Timeout.ConstrainWaitMillis(waitMillis, m_heartbeatResendTimeout, currentTimeMillis);
+
+                // NOTE: Guard against bad logic giving a negative value 
+                if (waitMillis < 0)
+                {
+                    waitMillis = 1;
+                }
+
+                int receiveLimit = System.Math.Min(buffer.Length, GetReceiveLimit()) + RECORD_HEADER_LENGTH;
+                if (null == record || record.Length < receiveLimit)
+                {
+                    record = new byte[receiveLimit];
+                }
+
+                int received = ReceiveRecord(record, 0, receiveLimit, waitMillis);
+                int processed = ProcessRecord(received, record, buffer);
+                if (processed >= 0)
+                {
+                    return processed;
+                }
+
+                currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
+                waitMillis = Timeout.GetWaitMillis(timeout, currentTimeMillis);
+            }
+
+            return -1;
+        }
+#endif
+
         /// <exception cref="IOException"/>
         public virtual void Send(byte[] buf, int off, int len)
         {
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+            Send(buf.AsSpan(off, len));
+#else
             short contentType = ContentType.application_data;
 
             if (m_inHandshake || m_writeEpoch == m_retransmitEpoch)
@@ -338,7 +415,7 @@ namespace Org.BouncyCastle.Tls
                     // Implicitly send change_cipher_spec and change to pending cipher state
 
                     // TODO Send change_cipher_spec and finished records in single datagram?
-                    byte[] data = new byte[]{ 1 };
+                    byte[] data = new byte[1]{ 1 };
                     SendRecord(ContentType.change_cipher_spec, data, 0, data.Length);
 
                     this.m_writeEpoch = nextEpoch;
@@ -346,7 +423,51 @@ namespace Org.BouncyCastle.Tls
             }
 
             SendRecord(contentType, buf, off, len);
+#endif
+        }
+
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+        /// <exception cref="IOException"/>
+        public virtual void Send(ReadOnlySpan<byte> buffer)
+        {
+            short contentType = ContentType.application_data;
+
+            if (m_inHandshake || m_writeEpoch == m_retransmitEpoch)
+            {
+                contentType = ContentType.handshake;
+
+                short handshakeType = TlsUtilities.ReadUint8(buffer);
+                if (handshakeType == HandshakeType.finished)
+                {
+                    DtlsEpoch nextEpoch = null;
+                    if (m_inHandshake)
+                    {
+                        nextEpoch = m_pendingEpoch;
+                    }
+                    else if (m_writeEpoch == m_retransmitEpoch)
+                    {
+                        nextEpoch = m_currentEpoch;
+                    }
+
+                    if (nextEpoch == null)
+                    {
+                        // TODO
+                        throw new InvalidOperationException();
+                    }
+
+                    // Implicitly send change_cipher_spec and change to pending cipher state
+
+                    // TODO Send change_cipher_spec and finished records in single datagram?
+                    ReadOnlySpan<byte> data = stackalloc byte[1]{ 1 };
+                    SendRecord(ContentType.change_cipher_spec, data);
+
+                    this.m_writeEpoch = nextEpoch;
+                }
+            }
+
+            SendRecord(contentType, buffer);
         }
+#endif
 
         /// <exception cref="IOException"/>
         public virtual void Close()
@@ -432,11 +553,13 @@ namespace Org.BouncyCastle.Tls
         {
             m_peer.NotifyAlertRaised(alertLevel, alertDescription, message, cause);
 
-            byte[] error = new byte[2];
-            error[0] = (byte)alertLevel;
-            error[1] = (byte)alertDescription;
-
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+            ReadOnlySpan<byte> error = stackalloc byte[2]{ (byte)alertLevel, (byte)alertDescription };
+            SendRecord(ContentType.alert, error);
+#else
+            byte[] error = new byte[2]{ (byte)alertLevel, (byte)alertDescription };
             SendRecord(ContentType.alert, error, 0, 2);
+#endif
         }
 
         /// <exception cref="IOException"/>
@@ -467,7 +590,11 @@ namespace Org.BouncyCastle.Tls
 
         // TODO Include 'currentTimeMillis' as an argument, use with Timeout, resetHeartbeat
         /// <exception cref="IOException"/>
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+        private int ProcessRecord(int received, byte[] record, Span<byte> buffer)
+#else
         private int ProcessRecord(int received, byte[] record, byte[] buf, int off)
+#endif
         {
             // NOTE: received < 0 (timeout) is covered by this first case
             if (received < RECORD_HEADER_LENGTH)
@@ -700,7 +827,11 @@ namespace Org.BouncyCastle.Tls
                 this.m_retransmitTimeout = null;
             }
 
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+            decoded.buf.AsSpan(decoded.off, decoded.len).CopyTo(buffer);
+#else
             Array.Copy(decoded.buf, decoded.off, buf, off, decoded.len);
+#endif
             return decoded.len;
         }
 
@@ -712,9 +843,7 @@ namespace Org.BouncyCastle.Tls
                 int length = 0;
                 if (m_recordQueue.Available >= RECORD_HEADER_LENGTH)
                 {
-                    byte[] lengthBytes = new byte[2];
-                    m_recordQueue.Read(lengthBytes, 0, 2, 11);
-                    length = TlsUtilities.ReadUint16(lengthBytes, 0);
+                    length = m_recordQueue.ReadUint16(11);
                 }
 
                 int received = System.Math.Min(m_recordQueue.Available, RECORD_HEADER_LENGTH + length);
@@ -754,9 +883,16 @@ namespace Org.BouncyCastle.Tls
         {
             MemoryStream output = new MemoryStream();
             heartbeatMessage.Encode(output);
-            byte[] buf = output.ToArray();
 
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+            if (!output.TryGetBuffer(out var buffer))
+                throw new InvalidOperationException();
+
+            SendRecord(ContentType.heartbeat, buffer);
+#else
+            byte[] buf = output.ToArray();
             SendRecord(ContentType.heartbeat, buf, 0, buf.Length);
+#endif
         }
 
         /*
@@ -766,12 +902,20 @@ namespace Org.BouncyCastle.Tls
          * be possible reordering of records (which might surprise a reliable transport implementation).
          */
         /// <exception cref="IOException"/>
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+        private void SendRecord(short contentType, ReadOnlySpan<byte> buffer)
+#else
         private void SendRecord(short contentType, byte[] buf, int off, int len)
+#endif
         {
             // Never send anything until a valid ClientHello has been received
             if (m_writeVersion == null)
                 return;
 
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+            int len = buffer.Length;
+#endif
+
             if (len > m_plaintextLimit)
                 throw new TlsFatalAlert(AlertDescription.internal_error);
 
@@ -789,8 +933,13 @@ namespace Org.BouncyCastle.Tls
                 long macSequenceNumber = GetMacSequenceNumber(recordEpoch, recordSequenceNumber);
                 ProtocolVersion recordVersion = m_writeVersion;
 
+#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
+                TlsEncodeResult encoded = m_writeEpoch.Cipher.EncodePlaintext(macSequenceNumber, contentType,
+                    recordVersion, RECORD_HEADER_LENGTH, buffer);
+#else
                 TlsEncodeResult encoded = m_writeEpoch.Cipher.EncodePlaintext(macSequenceNumber, contentType,
                     recordVersion, RECORD_HEADER_LENGTH, buf, off, len);
+#endif
 
                 int ciphertextLength = encoded.len - RECORD_HEADER_LENGTH;
                 TlsUtilities.CheckUint16(ciphertextLength);