summary refs log tree commit diff
path: root/crypto/src/tls/DtlsReliableHandshake.cs
diff options
context:
space:
mode:
authorPeter Dettman <peter.dettman@bouncycastle.org>2021-07-12 15:15:36 +0700
committerPeter Dettman <peter.dettman@bouncycastle.org>2021-07-12 15:15:36 +0700
commit68c795fe81277f73aeb90d8ad4c6f4305f32c906 (patch)
tree59643344aafef91bbd4c4a3a7973deba3d837a00 /crypto/src/tls/DtlsReliableHandshake.cs
parentTLS test tweaks (diff)
downloadBouncyCastle.NET-ed25519-68c795fe81277f73aeb90d8ad4c6f4305f32c906.tar.xz
Port of new TLS API from bc-java
Diffstat (limited to 'crypto/src/tls/DtlsReliableHandshake.cs')
-rw-r--r--crypto/src/tls/DtlsReliableHandshake.cs558
1 files changed, 558 insertions, 0 deletions
diff --git a/crypto/src/tls/DtlsReliableHandshake.cs b/crypto/src/tls/DtlsReliableHandshake.cs
new file mode 100644
index 000000000..b2f8f130a
--- /dev/null
+++ b/crypto/src/tls/DtlsReliableHandshake.cs
@@ -0,0 +1,558 @@
+using System;
+using System.Collections;
+using System.IO;
+
+using Org.BouncyCastle.Utilities;
+using Org.BouncyCastle.Utilities.Date;
+
+namespace Org.BouncyCastle.Tls
+{
+    internal class DtlsReliableHandshake
+    {
+        private const int MAX_RECEIVE_AHEAD = 16;
+        private const int MESSAGE_HEADER_LENGTH = 12;
+
+        internal const int INITIAL_RESEND_MILLIS = 1000;
+        private const int MAX_RESEND_MILLIS = 60000;
+
+        /// <exception cref="IOException"/>
+        internal static DtlsRequest ReadClientRequest(byte[] data, int dataOff, int dataLen, Stream dtlsOutput)
+        {
+            // TODO Support the possibility of a fragmented ClientHello datagram
+
+            byte[] message = DtlsRecordLayer.ReceiveClientHelloRecord(data, dataOff, dataLen);
+            if (null == message || message.Length < MESSAGE_HEADER_LENGTH)
+                return null;
+
+            long recordSeq = TlsUtilities.ReadUint48(data, dataOff + 5);
+
+            short msgType = TlsUtilities.ReadUint8(message, 0);
+            if (HandshakeType.client_hello != msgType)
+                return null;
+
+            int length = TlsUtilities.ReadUint24(message, 1);
+            if (message.Length != MESSAGE_HEADER_LENGTH + length)
+                return null;
+
+            // TODO Consider stricter HelloVerifyRequest-related checks
+            //int messageSeq = TlsUtilities.ReadUint16(message, 4);
+            //if (messageSeq > 1)
+            //    return null;
+
+            int fragmentOffset = TlsUtilities.ReadUint24(message, 6);
+            if (0 != fragmentOffset)
+                return null;
+
+            int fragmentLength = TlsUtilities.ReadUint24(message, 9);
+            if (length != fragmentLength)
+                return null;
+
+            ClientHello clientHello = ClientHello.Parse(
+                new MemoryStream(message, MESSAGE_HEADER_LENGTH, length, false), dtlsOutput);
+
+            return new DtlsRequest(recordSeq, message, clientHello);
+        }
+
+        /// <exception cref="IOException"/>
+        internal static void SendHelloVerifyRequest(DatagramSender sender, long recordSeq, byte[] cookie)
+        {
+            TlsUtilities.CheckUint8(cookie.Length);
+
+            int length = 3 + cookie.Length;
+
+            byte[] message = new byte[MESSAGE_HEADER_LENGTH + length];
+            TlsUtilities.WriteUint8(HandshakeType.hello_verify_request, message, 0);
+            TlsUtilities.WriteUint24(length, message, 1);
+            //TlsUtilities.WriteUint16(0, message, 4);
+            //TlsUtilities.WriteUint24(0, message, 6);
+            TlsUtilities.WriteUint24(length, message, 9);
+
+            // HelloVerifyRequest fields
+            TlsUtilities.WriteVersion(ProtocolVersion.DTLSv10, message, MESSAGE_HEADER_LENGTH + 0);
+            TlsUtilities.WriteOpaque8(cookie, message, MESSAGE_HEADER_LENGTH + 2);
+
+            DtlsRecordLayer.SendHelloVerifyRequestRecord(sender, recordSeq, message);
+        }
+
+        /*
+         * No 'final' modifiers so that it works in earlier JDKs
+         */
+        private DtlsRecordLayer m_recordLayer;
+        private Timeout m_handshakeTimeout;
+
+        private TlsHandshakeHash m_handshakeHash;
+
+        private IDictionary m_currentInboundFlight = Platform.CreateHashtable();
+        private IDictionary m_previousInboundFlight = null;
+        private IList m_outboundFlight = Platform.CreateArrayList();
+
+        private int m_resendMillis = -1;
+        private Timeout m_resendTimeout = null;
+
+        private int m_next_send_seq = 0, m_next_receive_seq = 0;
+
+        internal DtlsReliableHandshake(TlsContext context, DtlsRecordLayer transport, int timeoutMillis,
+            DtlsRequest request)
+        {
+            this.m_recordLayer = transport;
+            this.m_handshakeHash = new DeferredHash(context);
+            this.m_handshakeTimeout = Timeout.ForWaitMillis(timeoutMillis);
+
+            if (null != request)
+            {
+                this.m_resendMillis = INITIAL_RESEND_MILLIS;
+                this.m_resendTimeout = new Timeout(m_resendMillis);
+
+                long recordSeq = request.RecordSeq;
+                int messageSeq = request.MessageSeq;
+                byte[] message = request.Message;
+
+                m_recordLayer.ResetAfterHelloVerifyRequestServer(recordSeq);
+
+                // Simulate a previous flight consisting of the request ClientHello
+                DtlsReassembler reassembler = new DtlsReassembler(HandshakeType.client_hello,
+                    message.Length - MESSAGE_HEADER_LENGTH);
+                m_currentInboundFlight[messageSeq] = reassembler;
+
+                // We sent HelloVerifyRequest with (message) sequence number 0
+                this.m_next_send_seq = 1;
+                this.m_next_receive_seq = messageSeq + 1;
+
+                m_handshakeHash.Update(message, 0, message.Length);
+            }
+        }
+
+        internal void ResetAfterHelloVerifyRequestClient()
+        {
+            this.m_currentInboundFlight = Platform.CreateHashtable();
+            this.m_previousInboundFlight = null;
+            this.m_outboundFlight = Platform.CreateArrayList();
+
+            this.m_resendMillis = -1;
+            this.m_resendTimeout = null;
+
+            // We're waiting for ServerHello, always with (message) sequence number 1
+            this.m_next_receive_seq = 1;
+
+            m_handshakeHash.Reset();
+        }
+
+        internal TlsHandshakeHash HandshakeHash
+        {
+            get { return m_handshakeHash; }
+        }
+
+        internal TlsHandshakeHash PrepareToFinish()
+        {
+            TlsHandshakeHash result = m_handshakeHash;
+            this.m_handshakeHash = m_handshakeHash.StopTracking();
+            return result;
+        }
+
+        /// <exception cref="IOException"/>
+        internal void SendMessage(short msg_type, byte[] body)
+        {
+            TlsUtilities.CheckUint24(body.Length);
+
+            if (null != m_resendTimeout)
+            {
+                CheckInboundFlight();
+
+                this.m_resendMillis = -1;
+                this.m_resendTimeout = null;
+
+                m_outboundFlight.Clear();
+            }
+
+            Message message = new Message(m_next_send_seq++, msg_type, body);
+
+            m_outboundFlight.Add(message);
+
+            WriteMessage(message);
+            UpdateHandshakeMessagesDigest(message);
+        }
+
+        /// <exception cref="IOException"/>
+        internal byte[] ReceiveMessageBody(short msg_type)
+        {
+            Message message = ReceiveMessage();
+            if (message.Type != msg_type)
+                throw new TlsFatalAlert(AlertDescription.unexpected_message);
+
+            return message.Body;
+        }
+
+        /// <exception cref="IOException"/>
+        internal Message ReceiveMessage()
+        {
+            long currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
+
+            if (null == m_resendTimeout)
+            {
+                m_resendMillis = INITIAL_RESEND_MILLIS;
+                m_resendTimeout = new Timeout(m_resendMillis, currentTimeMillis);
+
+                PrepareInboundFlight(Platform.CreateHashtable());
+            }
+
+            byte[] buf = null;
+
+            for (;;)
+            {
+                if (m_recordLayer.IsClosed)
+                    throw new TlsFatalAlert(AlertDescription.user_canceled);
+
+                Message pending = GetPendingMessage();
+                if (pending != null)
+                    return pending;
+
+                if (Timeout.HasExpired(m_handshakeTimeout, currentTimeMillis))
+                    throw new TlsTimeoutException("Handshake timed out");
+
+                int waitMillis = Timeout.GetWaitMillis(m_handshakeTimeout, currentTimeMillis);
+                waitMillis = Timeout.ConstrainWaitMillis(waitMillis, m_resendTimeout, currentTimeMillis);
+
+                // NOTE: Ensure a finite wait, of at least 1ms
+                if (waitMillis < 1)
+                {
+                    waitMillis = 1;
+                }
+
+                int receiveLimit = m_recordLayer.GetReceiveLimit();
+                if (buf == null || buf.Length < receiveLimit)
+                {
+                    buf = new byte[receiveLimit];
+                }
+
+                int received = m_recordLayer.Receive(buf, 0, receiveLimit, waitMillis);
+                if (received < 0)
+                {
+                    ResendOutboundFlight();
+                }
+                else
+                {
+                    ProcessRecord(MAX_RECEIVE_AHEAD, m_recordLayer.ReadEpoch, buf, 0, received);
+                }
+
+                currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
+            }
+        }
+
+        internal void Finish()
+        {
+            DtlsHandshakeRetransmit retransmit = null;
+            if (null != m_resendTimeout)
+            {
+                CheckInboundFlight();
+            }
+            else
+            {
+                PrepareInboundFlight(null);
+
+                if (m_previousInboundFlight != null)
+                {
+                    /*
+                     * RFC 6347 4.2.4. In addition, for at least twice the default MSL defined for [TCP],
+                     * when in the FINISHED state, the node that transmits the last flight (the server in an
+                     * ordinary handshake or the client in a resumed handshake) MUST respond to a retransmit
+                     * of the peer's last flight with a retransmit of the last flight.
+                     */
+                    retransmit = new Retransmit(this);
+                }
+            }
+
+            m_recordLayer.HandshakeSuccessful(retransmit);
+        }
+
+        internal static int BackOff(int timeoutMillis)
+        {
+            /*
+             * TODO[DTLS] implementations SHOULD back off handshake packet size during the
+             * retransmit backoff.
+             */
+            return System.Math.Min(timeoutMillis * 2, MAX_RESEND_MILLIS);
+        }
+
+        /**
+         * Check that there are no "extra" messages left in the current inbound flight
+         */
+        private void CheckInboundFlight()
+        {
+            foreach (int key in m_currentInboundFlight.Keys)
+            {
+                if (key >= m_next_receive_seq)
+                {
+                    // TODO Should this be considered an error?
+                }
+            }
+        }
+
+        /// <exception cref="IOException"/>
+        private Message GetPendingMessage()
+        {
+            DtlsReassembler next = (DtlsReassembler)m_currentInboundFlight[m_next_receive_seq];
+            if (next != null)
+            {
+                byte[] body = next.GetBodyIfComplete();
+                if (body != null)
+                {
+                    m_previousInboundFlight = null;
+                    return UpdateHandshakeMessagesDigest(new Message(m_next_receive_seq++, next.MsgType, body));
+                }
+            }
+            return null;
+        }
+
+        private void PrepareInboundFlight(IDictionary nextFlight)
+        {
+            ResetAll(m_currentInboundFlight);
+            m_previousInboundFlight = m_currentInboundFlight;
+            m_currentInboundFlight = nextFlight;
+        }
+
+        /// <exception cref="IOException"/>
+        private void ProcessRecord(int windowSize, int epoch, byte[] buf, int off, int len)
+        {
+            bool checkPreviousFlight = false;
+
+            while (len >= MESSAGE_HEADER_LENGTH)
+            {
+                int fragment_length = TlsUtilities.ReadUint24(buf, off + 9);
+                int message_length = fragment_length + MESSAGE_HEADER_LENGTH;
+                if (len < message_length)
+                {
+                    // NOTE: Truncated message - ignore it
+                    break;
+                }
+
+                int length = TlsUtilities.ReadUint24(buf, off + 1);
+                int fragment_offset = TlsUtilities.ReadUint24(buf, off + 6);
+                if (fragment_offset + fragment_length > length)
+                {
+                    // NOTE: Malformed fragment - ignore it and the rest of the record
+                    break;
+                }
+
+                /*
+                 * NOTE: This very simple epoch check will only work until we want to support
+                 * renegotiation (and we're not likely to do that anyway).
+                 */
+                short msg_type = TlsUtilities.ReadUint8(buf, off + 0);
+                int expectedEpoch = msg_type == HandshakeType.finished ? 1 : 0;
+                if (epoch != expectedEpoch)
+                    break;
+
+                int message_seq = TlsUtilities.ReadUint16(buf, off + 4);
+                if (message_seq >= (m_next_receive_seq + windowSize))
+                {
+                    // NOTE: Too far ahead - ignore
+                }
+                else if (message_seq >= m_next_receive_seq)
+                {
+                    DtlsReassembler reassembler = (DtlsReassembler)m_currentInboundFlight[message_seq];
+                    if (reassembler == null)
+                    {
+                        reassembler = new DtlsReassembler(msg_type, length);
+                        m_currentInboundFlight[message_seq] = reassembler;
+                    }
+
+                    reassembler.ContributeFragment(msg_type, length, buf, off + MESSAGE_HEADER_LENGTH, fragment_offset,
+                        fragment_length);
+                }
+                else if (m_previousInboundFlight != null)
+                {
+                    /*
+                     * NOTE: If we receive the previous flight of incoming messages in full again,
+                     * retransmit our last flight
+                     */
+
+                    DtlsReassembler reassembler = (DtlsReassembler)m_previousInboundFlight[message_seq];
+                    if (reassembler != null)
+                    {
+                        reassembler.ContributeFragment(msg_type, length, buf, off + MESSAGE_HEADER_LENGTH,
+                            fragment_offset, fragment_length);
+                        checkPreviousFlight = true;
+                    }
+                }
+
+                off += message_length;
+                len -= message_length;
+            }
+
+            if (checkPreviousFlight && CheckAll(m_previousInboundFlight))
+            {
+                ResendOutboundFlight();
+                ResetAll(m_previousInboundFlight);
+            }
+        }
+
+        /// <exception cref="IOException"/>
+        private void ResendOutboundFlight()
+        {
+            m_recordLayer.ResetWriteEpoch();
+            foreach (Message message in m_outboundFlight)
+            {
+                WriteMessage(message);
+            }
+
+            m_resendMillis = BackOff(m_resendMillis);
+            m_resendTimeout = new Timeout(m_resendMillis);
+        }
+
+        /// <exception cref="IOException"/>
+        private Message UpdateHandshakeMessagesDigest(Message message)
+        {
+            short msg_type = message.Type;
+            switch (msg_type)
+            {
+            case HandshakeType.hello_request:
+            case HandshakeType.hello_verify_request:
+            case HandshakeType.key_update:
+            case HandshakeType.new_session_ticket:
+                break;
+
+            default:
+            {
+                byte[] body = message.Body;
+                byte[] buf = new byte[MESSAGE_HEADER_LENGTH];
+                TlsUtilities.WriteUint8(msg_type, buf, 0);
+                TlsUtilities.WriteUint24(body.Length, buf, 1);
+                TlsUtilities.WriteUint16(message.Seq, buf, 4);
+                TlsUtilities.WriteUint24(0, buf, 6);
+                TlsUtilities.WriteUint24(body.Length, buf, 9);
+                m_handshakeHash.Update(buf, 0, buf.Length);
+                m_handshakeHash.Update(body, 0, body.Length);
+                break;
+            }
+            }
+
+            return message;
+        }
+
+        /// <exception cref="IOException"/>
+        private void WriteMessage(Message message)
+        {
+            int sendLimit = m_recordLayer.GetSendLimit();
+            int fragmentLimit = sendLimit - MESSAGE_HEADER_LENGTH;
+
+            // TODO Support a higher minimum fragment size?
+            if (fragmentLimit < 1)
+            {
+                // TODO Should we be throwing an exception here?
+                throw new TlsFatalAlert(AlertDescription.internal_error);
+            }
+
+            int length = message.Body.Length;
+
+            // NOTE: Must still send a fragment if body is empty
+            int fragment_offset = 0;
+            do
+            {
+                int fragment_length = System.Math.Min(length - fragment_offset, fragmentLimit);
+                WriteHandshakeFragment(message, fragment_offset, fragment_length);
+                fragment_offset += fragment_length;
+            }
+            while (fragment_offset < length);
+        }
+
+        /// <exception cref="IOException"/>
+        private void WriteHandshakeFragment(Message message, int fragment_offset, int fragment_length)
+        {
+            RecordLayerBuffer fragment = new RecordLayerBuffer(MESSAGE_HEADER_LENGTH + fragment_length);
+            TlsUtilities.WriteUint8(message.Type, fragment);
+            TlsUtilities.WriteUint24(message.Body.Length, fragment);
+            TlsUtilities.WriteUint16(message.Seq, fragment);
+            TlsUtilities.WriteUint24(fragment_offset, fragment);
+            TlsUtilities.WriteUint24(fragment_length, fragment);
+            fragment.Write(message.Body, fragment_offset, fragment_length);
+
+            fragment.SendToRecordLayer(m_recordLayer);
+        }
+
+        private static bool CheckAll(IDictionary inboundFlight)
+        {
+            foreach (DtlsReassembler r in inboundFlight.Values)
+            {
+                if (r.GetBodyIfComplete() == null)
+                    return false;
+            }
+            return true;
+        }
+
+        private static void ResetAll(IDictionary inboundFlight)
+        {
+            foreach (DtlsReassembler r in inboundFlight.Values)
+            {
+                r.Reset();
+            }
+        }
+
+        internal class Message
+        {
+            private readonly int m_message_seq;
+            private readonly short m_msg_type;
+            private readonly byte[] m_body;
+
+            internal Message(int message_seq, short msg_type, byte[] body)
+            {
+                this.m_message_seq = message_seq;
+                this.m_msg_type = msg_type;
+                this.m_body = body;
+            }
+
+            public int Seq
+            {
+                get { return m_message_seq; }
+            }
+
+            public short Type
+            {
+                get { return m_msg_type; }
+            }
+
+            public byte[] Body
+            {
+                get { return m_body; }
+            }
+        }
+
+        internal class RecordLayerBuffer
+            : MemoryStream
+        {
+            internal RecordLayerBuffer(int size)
+                : base(size)
+            {
+            }
+
+            internal void SendToRecordLayer(DtlsRecordLayer recordLayer)
+            {
+#if PORTABLE
+                byte[] buf = ToArray();
+                int bufLen = buf.Length;
+#else
+                byte[] buf = GetBuffer();
+                int bufLen = (int)Length;
+#endif
+
+                recordLayer.Send(buf, 0, bufLen);
+                Platform.Dispose(this);
+            }
+        }
+
+        internal class Retransmit
+            : DtlsHandshakeRetransmit
+        {
+            private readonly DtlsReliableHandshake m_outer;
+
+            internal Retransmit(DtlsReliableHandshake outer)
+            {
+                this.m_outer = outer;
+            }
+
+            public void ReceivedHandshakeRecord(int epoch, byte[] buf, int off, int len)
+            {
+                m_outer.ProcessRecord(0, epoch, buf, off, len);
+            }
+        }
+    }
+}