diff options
Diffstat (limited to 'crypto/src/tls/DtlsRecordLayer.cs')
-rw-r--r-- | crypto/src/tls/DtlsRecordLayer.cs | 167 |
1 files changed, 158 insertions, 9 deletions
diff --git a/crypto/src/tls/DtlsRecordLayer.cs b/crypto/src/tls/DtlsRecordLayer.cs index 7ec77c5da..bab6892b7 100644 --- a/crypto/src/tls/DtlsRecordLayer.cs +++ b/crypto/src/tls/DtlsRecordLayer.cs @@ -242,6 +242,9 @@ namespace Org.BouncyCastle.Tls /// <exception cref="IOException"/> public virtual int Receive(byte[] buf, int off, int len, int waitMillis) { +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + return Receive(buf.AsSpan(off, len), waitMillis); +#else long currentTimeMillis = DateTimeUtilities.CurrentUnixMs(); Timeout timeout = Timeout.ForWaitMillis(waitMillis, currentTimeMillis); @@ -305,11 +308,85 @@ namespace Org.BouncyCastle.Tls } return -1; +#endif } +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + /// <exception cref="IOException"/> + public virtual int Receive(Span<byte> buffer, int waitMillis) + { + long currentTimeMillis = DateTimeUtilities.CurrentUnixMs(); + + Timeout timeout = Timeout.ForWaitMillis(waitMillis, currentTimeMillis); + byte[] record = null; + + while (waitMillis >= 0) + { + if (null != m_retransmitTimeout && m_retransmitTimeout.RemainingMillis(currentTimeMillis) < 1) + { + m_retransmit = null; + m_retransmitEpoch = null; + m_retransmitTimeout = null; + } + + if (Timeout.HasExpired(m_heartbeatTimeout, currentTimeMillis)) + { + if (null != m_heartbeatInFlight) + throw new TlsTimeoutException("Heartbeat timed out"); + + this.m_heartbeatInFlight = HeartbeatMessage.Create(m_context, + HeartbeatMessageType.heartbeat_request, m_heartbeat.GeneratePayload()); + this.m_heartbeatTimeout = new Timeout(m_heartbeat.TimeoutMillis, currentTimeMillis); + + this.m_heartbeatResendMillis = DtlsReliableHandshake.INITIAL_RESEND_MILLIS; + this.m_heartbeatResendTimeout = new Timeout(m_heartbeatResendMillis, currentTimeMillis); + + SendHeartbeatMessage(m_heartbeatInFlight); + } + else if (Timeout.HasExpired(m_heartbeatResendTimeout, currentTimeMillis)) + { + this.m_heartbeatResendMillis = DtlsReliableHandshake.BackOff(m_heartbeatResendMillis); + this.m_heartbeatResendTimeout = new Timeout(m_heartbeatResendMillis, currentTimeMillis); + + SendHeartbeatMessage(m_heartbeatInFlight); + } + + waitMillis = Timeout.ConstrainWaitMillis(waitMillis, m_heartbeatTimeout, currentTimeMillis); + waitMillis = Timeout.ConstrainWaitMillis(waitMillis, m_heartbeatResendTimeout, currentTimeMillis); + + // NOTE: Guard against bad logic giving a negative value + if (waitMillis < 0) + { + waitMillis = 1; + } + + int receiveLimit = System.Math.Min(buffer.Length, GetReceiveLimit()) + RECORD_HEADER_LENGTH; + if (null == record || record.Length < receiveLimit) + { + record = new byte[receiveLimit]; + } + + int received = ReceiveRecord(record, 0, receiveLimit, waitMillis); + int processed = ProcessRecord(received, record, buffer); + if (processed >= 0) + { + return processed; + } + + currentTimeMillis = DateTimeUtilities.CurrentUnixMs(); + waitMillis = Timeout.GetWaitMillis(timeout, currentTimeMillis); + } + + return -1; + } +#endif + /// <exception cref="IOException"/> public virtual void Send(byte[] buf, int off, int len) { +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + Send(buf.AsSpan(off, len)); +#else short contentType = ContentType.application_data; if (m_inHandshake || m_writeEpoch == m_retransmitEpoch) @@ -338,7 +415,7 @@ namespace Org.BouncyCastle.Tls // Implicitly send change_cipher_spec and change to pending cipher state // TODO Send change_cipher_spec and finished records in single datagram? - byte[] data = new byte[]{ 1 }; + byte[] data = new byte[1]{ 1 }; SendRecord(ContentType.change_cipher_spec, data, 0, data.Length); this.m_writeEpoch = nextEpoch; @@ -346,7 +423,51 @@ namespace Org.BouncyCastle.Tls } SendRecord(contentType, buf, off, len); +#endif + } + +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + /// <exception cref="IOException"/> + public virtual void Send(ReadOnlySpan<byte> buffer) + { + short contentType = ContentType.application_data; + + if (m_inHandshake || m_writeEpoch == m_retransmitEpoch) + { + contentType = ContentType.handshake; + + short handshakeType = TlsUtilities.ReadUint8(buffer); + if (handshakeType == HandshakeType.finished) + { + DtlsEpoch nextEpoch = null; + if (m_inHandshake) + { + nextEpoch = m_pendingEpoch; + } + else if (m_writeEpoch == m_retransmitEpoch) + { + nextEpoch = m_currentEpoch; + } + + if (nextEpoch == null) + { + // TODO + throw new InvalidOperationException(); + } + + // Implicitly send change_cipher_spec and change to pending cipher state + + // TODO Send change_cipher_spec and finished records in single datagram? + ReadOnlySpan<byte> data = stackalloc byte[1]{ 1 }; + SendRecord(ContentType.change_cipher_spec, data); + + this.m_writeEpoch = nextEpoch; + } + } + + SendRecord(contentType, buffer); } +#endif /// <exception cref="IOException"/> public virtual void Close() @@ -432,11 +553,13 @@ namespace Org.BouncyCastle.Tls { m_peer.NotifyAlertRaised(alertLevel, alertDescription, message, cause); - byte[] error = new byte[2]; - error[0] = (byte)alertLevel; - error[1] = (byte)alertDescription; - +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + ReadOnlySpan<byte> error = stackalloc byte[2]{ (byte)alertLevel, (byte)alertDescription }; + SendRecord(ContentType.alert, error); +#else + byte[] error = new byte[2]{ (byte)alertLevel, (byte)alertDescription }; SendRecord(ContentType.alert, error, 0, 2); +#endif } /// <exception cref="IOException"/> @@ -467,7 +590,11 @@ namespace Org.BouncyCastle.Tls // TODO Include 'currentTimeMillis' as an argument, use with Timeout, resetHeartbeat /// <exception cref="IOException"/> +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + private int ProcessRecord(int received, byte[] record, Span<byte> buffer) +#else private int ProcessRecord(int received, byte[] record, byte[] buf, int off) +#endif { // NOTE: received < 0 (timeout) is covered by this first case if (received < RECORD_HEADER_LENGTH) @@ -700,7 +827,11 @@ namespace Org.BouncyCastle.Tls this.m_retransmitTimeout = null; } +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + decoded.buf.AsSpan(decoded.off, decoded.len).CopyTo(buffer); +#else Array.Copy(decoded.buf, decoded.off, buf, off, decoded.len); +#endif return decoded.len; } @@ -712,9 +843,7 @@ namespace Org.BouncyCastle.Tls int length = 0; if (m_recordQueue.Available >= RECORD_HEADER_LENGTH) { - byte[] lengthBytes = new byte[2]; - m_recordQueue.Read(lengthBytes, 0, 2, 11); - length = TlsUtilities.ReadUint16(lengthBytes, 0); + length = m_recordQueue.ReadUint16(11); } int received = System.Math.Min(m_recordQueue.Available, RECORD_HEADER_LENGTH + length); @@ -754,9 +883,16 @@ namespace Org.BouncyCastle.Tls { MemoryStream output = new MemoryStream(); heartbeatMessage.Encode(output); - byte[] buf = output.ToArray(); +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + if (!output.TryGetBuffer(out var buffer)) + throw new InvalidOperationException(); + + SendRecord(ContentType.heartbeat, buffer); +#else + byte[] buf = output.ToArray(); SendRecord(ContentType.heartbeat, buf, 0, buf.Length); +#endif } /* @@ -766,12 +902,20 @@ namespace Org.BouncyCastle.Tls * be possible reordering of records (which might surprise a reliable transport implementation). */ /// <exception cref="IOException"/> +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + private void SendRecord(short contentType, ReadOnlySpan<byte> buffer) +#else private void SendRecord(short contentType, byte[] buf, int off, int len) +#endif { // Never send anything until a valid ClientHello has been received if (m_writeVersion == null) return; +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + int len = buffer.Length; +#endif + if (len > m_plaintextLimit) throw new TlsFatalAlert(AlertDescription.internal_error); @@ -789,8 +933,13 @@ namespace Org.BouncyCastle.Tls long macSequenceNumber = GetMacSequenceNumber(recordEpoch, recordSequenceNumber); ProtocolVersion recordVersion = m_writeVersion; +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + TlsEncodeResult encoded = m_writeEpoch.Cipher.EncodePlaintext(macSequenceNumber, contentType, + recordVersion, RECORD_HEADER_LENGTH, buffer); +#else TlsEncodeResult encoded = m_writeEpoch.Cipher.EncodePlaintext(macSequenceNumber, contentType, recordVersion, RECORD_HEADER_LENGTH, buf, off, len); +#endif int ciphertextLength = encoded.len - RECORD_HEADER_LENGTH; TlsUtilities.CheckUint16(ciphertextLength); |