summary refs log tree commit diff
diff options
context:
space:
mode:
authorPeter Dettman <peter.dettman@bouncycastle.org>2020-07-30 01:14:52 +0700
committerPeter Dettman <peter.dettman@bouncycastle.org>2020-07-30 01:14:52 +0700
commit9193844b75819ac2b14622b000c42c1f527632f2 (patch)
tree07b9a07cd324d55368c7807ccc276376f9cbce4d
parentAdd Timeout class for DTLS from bc-java (diff)
downloadBouncyCastle.NET-ed25519-9193844b75819ac2b14622b000c42c1f527632f2.tar.xz
DTLS: Exceptions properly abort handshake
- see https://github.com/bcgit/bc-csharp/issues/258
-rw-r--r--crypto/src/crypto/tls/DtlsRecordLayer.cs339
-rw-r--r--crypto/src/crypto/tls/DtlsReliableHandshake.cs58
-rw-r--r--crypto/src/crypto/tls/DtlsTransport.cs41
-rw-r--r--crypto/src/crypto/tls/TlsUtilities.cs8
4 files changed, 266 insertions, 180 deletions
diff --git a/crypto/src/crypto/tls/DtlsRecordLayer.cs b/crypto/src/crypto/tls/DtlsRecordLayer.cs
index 3cb0e78dd..266893df0 100644
--- a/crypto/src/crypto/tls/DtlsRecordLayer.cs
+++ b/crypto/src/crypto/tls/DtlsRecordLayer.cs
@@ -1,5 +1,6 @@
 using System;
 using System.IO;
+using System.Net.Sockets;
 
 using Org.BouncyCastle.Utilities.Date;
 
@@ -13,6 +14,21 @@ namespace Org.BouncyCastle.Crypto.Tls
         private const long TCP_MSL = 1000L * 60 * 2;
         private const long RETRANSMIT_TIMEOUT = TCP_MSL * 2;
 
+        private static void SendDatagram(DatagramTransport sender, byte[] buf, int off, int len)
+        {
+            //try
+            //{
+            //    sender.Send(buf, off, len);
+            //}
+            //catch (InterruptedIOException e)
+            //{
+            //    e.bytesTransferred = 0;
+            //    throw e;
+            //}
+
+            sender.Send(buf, off, len);
+        }
+
         private readonly DatagramTransport mTransport;
         private readonly TlsContext mContext;
         private readonly TlsPeer mPeer;
@@ -134,6 +150,8 @@ namespace Org.BouncyCastle.Crypto.Tls
 
         public virtual int Receive(byte[] buf, int off, int len, int waitMillis)
         {
+            // TODO Avoid returning -1 (timeout) until 'waitMillis' has definitely elapsed  
+
             byte[] record = null;
 
             for (;;)
@@ -144,191 +162,183 @@ namespace Org.BouncyCastle.Crypto.Tls
                     record = new byte[receiveLimit];
                 }
 
-                try
+                if (mRetransmit != null && DateTimeUtilities.CurrentUnixMs() > mRetransmitExpiry)
                 {
-                    if (mRetransmit != null && DateTimeUtilities.CurrentUnixMs() > mRetransmitExpiry)
-                    {
-                        mRetransmit = null;
-                        mRetransmitEpoch = null;
-                    }
+                    mRetransmit = null;
+                    mRetransmitEpoch = null;
+                }
 
-                    int received = ReceiveRecord(record, 0, receiveLimit, waitMillis);
-                    if (received < 0)
-                    {
-                        return received;
-                    }
-                    if (received < RECORD_HEADER_LENGTH)
-                    {
-                        continue;
-                    }
-                    int length = TlsUtilities.ReadUint16(record, 11);
-                    if (received != (length + RECORD_HEADER_LENGTH))
-                    {
-                        continue;
-                    }
+                int received = ReceiveRecord(record, 0, receiveLimit, waitMillis);
+                if (received < 0)
+                {
+                    return received;
+                }
+                if (received < RECORD_HEADER_LENGTH)
+                {
+                    continue;
+                }
+                int length = TlsUtilities.ReadUint16(record, 11);
+                if (received != (length + RECORD_HEADER_LENGTH))
+                {
+                    continue;
+                }
 
-                    byte type = TlsUtilities.ReadUint8(record, 0);
+                byte type = TlsUtilities.ReadUint8(record, 0);
 
-                    // TODO Support user-specified custom protocols?
-                    switch (type)
-                    {
-                    case ContentType.alert:
-                    case ContentType.application_data:
-                    case ContentType.change_cipher_spec:
-                    case ContentType.handshake:
-                    case ContentType.heartbeat:
-                        break;
-                    default:
-                        // TODO Exception?
-                        continue;
-                    }
+                // TODO Support user-specified custom protocols?
+                switch (type)
+                {
+                case ContentType.alert:
+                case ContentType.application_data:
+                case ContentType.change_cipher_spec:
+                case ContentType.handshake:
+                case ContentType.heartbeat:
+                    break;
+                default:
+                    // TODO Exception?
+                    continue;
+                }
 
-                    int epoch = TlsUtilities.ReadUint16(record, 3);
+                int epoch = TlsUtilities.ReadUint16(record, 3);
 
-                    DtlsEpoch recordEpoch = null;
-                    if (epoch == mReadEpoch.Epoch)
-                    {
-                        recordEpoch = mReadEpoch;
-                    }
-                    else if (type == ContentType.handshake && mRetransmitEpoch != null
-                        && epoch == mRetransmitEpoch.Epoch)
-                    {
-                        recordEpoch = mRetransmitEpoch;
-                    }
+                DtlsEpoch recordEpoch = null;
+                if (epoch == mReadEpoch.Epoch)
+                {
+                    recordEpoch = mReadEpoch;
+                }
+                else if (type == ContentType.handshake && mRetransmitEpoch != null
+                    && epoch == mRetransmitEpoch.Epoch)
+                {
+                    recordEpoch = mRetransmitEpoch;
+                }
 
-                    if (recordEpoch == null)
-                    {
-                        continue;
-                    }
+                if (recordEpoch == null)
+                {
+                    continue;
+                }
 
-                    long seq = TlsUtilities.ReadUint48(record, 5);
-                    if (recordEpoch.ReplayWindow.ShouldDiscard(seq))
-                    {
-                        continue;
-                    }
+                long seq = TlsUtilities.ReadUint48(record, 5);
+                if (recordEpoch.ReplayWindow.ShouldDiscard(seq))
+                {
+                    continue;
+                }
 
-                    ProtocolVersion version = TlsUtilities.ReadVersion(record, 1);
-                    if (!version.IsDtls)
-                    {
-                        continue;
-                    }
+                ProtocolVersion version = TlsUtilities.ReadVersion(record, 1);
+                if (!version.IsDtls)
+                {
+                    continue;
+                }
 
-                    if (mReadVersion != null && !mReadVersion.Equals(version))
-                    {
-                        continue;
-                    }
+                if (mReadVersion != null && !mReadVersion.Equals(version))
+                {
+                    continue;
+                }
 
-                    byte[] plaintext = recordEpoch.Cipher.DecodeCiphertext(
-                        GetMacSequenceNumber(recordEpoch.Epoch, seq), type, record, RECORD_HEADER_LENGTH,
-                        received - RECORD_HEADER_LENGTH);
+                byte[] plaintext = recordEpoch.Cipher.DecodeCiphertext(
+                    GetMacSequenceNumber(recordEpoch.Epoch, seq), type, record, RECORD_HEADER_LENGTH,
+                    received - RECORD_HEADER_LENGTH);
 
-                    recordEpoch.ReplayWindow.ReportAuthenticated(seq);
+                recordEpoch.ReplayWindow.ReportAuthenticated(seq);
 
-                    if (plaintext.Length > this.mPlaintextLimit)
-                    {
-                        continue;
-                    }
+                if (plaintext.Length > this.mPlaintextLimit)
+                {
+                    continue;
+                }
 
-                    if (mReadVersion == null)
-                    {
-                        mReadVersion = version;
-                    }
+                if (mReadVersion == null)
+                {
+                    mReadVersion = version;
+                }
 
-                    switch (type)
-                    {
-                    case ContentType.alert:
+                switch (type)
+                {
+                case ContentType.alert:
+                {
+                    if (plaintext.Length == 2)
                     {
-                        if (plaintext.Length == 2)
+                        byte alertLevel = plaintext[0];
+                        byte alertDescription = plaintext[1];
+
+                        mPeer.NotifyAlertReceived(alertLevel, alertDescription);
+
+                        if (alertLevel == AlertLevel.fatal)
                         {
-                            byte alertLevel = plaintext[0];
-                            byte alertDescription = plaintext[1];
-
-                            mPeer.NotifyAlertReceived(alertLevel, alertDescription);
-
-                            if (alertLevel == AlertLevel.fatal)
-                            {
-                                Failed();
-                                throw new TlsFatalAlert(alertDescription);
-                            }
-
-                            // TODO Can close_notify be a fatal alert?
-                            if (alertDescription == AlertDescription.close_notify)
-                            {
-                                CloseTransport();
-                            }
+                            Failed();
+                            throw new TlsFatalAlert(alertDescription);
                         }
 
+                        // TODO Can close_notify be a fatal alert?
+                        if (alertDescription == AlertDescription.close_notify)
+                        {
+                            CloseTransport();
+                        }
+                    }
+
+                    continue;
+                }
+                case ContentType.application_data:
+                {
+                    if (mInHandshake)
+                    {
+                        // TODO Consider buffering application data for new epoch that arrives
+                        // out-of-order with the Finished message
                         continue;
                     }
-                    case ContentType.application_data:
+                    break;
+                }
+                case ContentType.change_cipher_spec:
+                {
+                    // Implicitly receive change_cipher_spec and change to pending cipher state
+
+                    for (int i = 0; i < plaintext.Length; ++i)
                     {
-                        if (mInHandshake)
+                        byte message = TlsUtilities.ReadUint8(plaintext, i);
+                        if (message != ChangeCipherSpec.change_cipher_spec)
                         {
-                            // TODO Consider buffering application data for new epoch that arrives
-                            // out-of-order with the Finished message
                             continue;
                         }
-                        break;
-                    }
-                    case ContentType.change_cipher_spec:
-                    {
-                        // Implicitly receive change_cipher_spec and change to pending cipher state
 
-                        for (int i = 0; i < plaintext.Length; ++i)
+                        if (mPendingEpoch != null)
                         {
-                            byte message = TlsUtilities.ReadUint8(plaintext, i);
-                            if (message != ChangeCipherSpec.change_cipher_spec)
-                            {
-                                continue;
-                            }
-
-                            if (mPendingEpoch != null)
-                            {
-                                mReadEpoch = mPendingEpoch;
-                            }
+                            mReadEpoch = mPendingEpoch;
                         }
-
-                        continue;
                     }
-                    case ContentType.handshake:
+
+                    continue;
+                }
+                case ContentType.handshake:
+                {
+                    if (!mInHandshake)
                     {
-                        if (!mInHandshake)
+                        if (mRetransmit != null)
                         {
-                            if (mRetransmit != null)
-                            {
-                                mRetransmit.ReceivedHandshakeRecord(epoch, plaintext, 0, plaintext.Length);
-                            }
-
-                            // TODO Consider support for HelloRequest
-                            continue;
+                            mRetransmit.ReceivedHandshakeRecord(epoch, plaintext, 0, plaintext.Length);
                         }
-                        break;
-                    }
-                    case ContentType.heartbeat:
-                    {
-                        // TODO[RFC 6520]
-                        continue;
-                    }
-                    }
 
-                    /*
-                     * NOTE: If we receive any non-handshake data in the new epoch implies the peer has
-                     * received our final flight.
-                     */
-                    if (!mInHandshake && mRetransmit != null)
-                    {
-                        this.mRetransmit = null;
-                        this.mRetransmitEpoch = null;
+                        // TODO Consider support for HelloRequest
+                        continue;
                     }
-
-                    Array.Copy(plaintext, 0, buf, off, plaintext.Length);
-                    return plaintext.Length;
+                    break;
                 }
-                catch (IOException e)
+                case ContentType.heartbeat:
                 {
-                    // NOTE: Assume this is a timeout for the moment
-                    throw e;
+                    // TODO[RFC 6520]
+                    continue;
+                }
                 }
+
+                /*
+                    * NOTE: If we receive any non-handshake data in the new epoch implies the peer has
+                    * received our final flight.
+                    */
+                if (!mInHandshake && mRetransmit != null)
+                {
+                    this.mRetransmit = null;
+                    this.mRetransmitEpoch = null;
+                }
+
+                Array.Copy(plaintext, 0, buf, off, plaintext.Length);
+                return plaintext.Length;
             }
         }
 
@@ -458,6 +468,35 @@ namespace Org.BouncyCastle.Crypto.Tls
             SendRecord(ContentType.alert, error, 0, 2);
         }
 
+        private int ReceiveDatagram(byte[] buf, int off, int len, int waitMillis)
+        {
+            //try
+            //{
+            //    return mTransport.Receive(buf, off, len, waitMillis);
+            //}
+            //catch (SocketTimeoutException e)
+            //{
+            //    return -1;
+            //}
+            //catch (InterruptedIOException e)
+            //{
+            //    e.bytesTransferred = 0;
+            //    throw e;
+            //}
+
+            try
+            {
+                return mTransport.Receive(buf, off, len, waitMillis);
+            }
+            catch (SocketException e)
+            {
+                if (TlsUtilities.IsTimeout(e))
+                    return -1;
+
+                throw e;
+            }
+        }
+
         private int ReceiveRecord(byte[] buf, int off, int len, int waitMillis)
         {
             if (mRecordQueue.Available > 0)
@@ -476,7 +515,7 @@ namespace Org.BouncyCastle.Crypto.Tls
             }
 
             {
-                int received = mTransport.Receive(buf, off, len, waitMillis);
+                int received = ReceiveDatagram(buf, off, len, waitMillis);
                 if (received >= RECORD_HEADER_LENGTH)
                 {
                     int fragmentLength = TlsUtilities.ReadUint16(buf, off + 11);
@@ -524,7 +563,7 @@ namespace Org.BouncyCastle.Crypto.Tls
             TlsUtilities.WriteUint16(ciphertext.Length, record, 11);
             Array.Copy(ciphertext, 0, record, RECORD_HEADER_LENGTH, ciphertext.Length);
 
-            mTransport.Send(record, 0, record.Length);
+            SendDatagram(mTransport, record, 0, record.Length);
         }
 
         private static long GetMacSequenceNumber(int epoch, long sequence_number)
diff --git a/crypto/src/crypto/tls/DtlsReliableHandshake.cs b/crypto/src/crypto/tls/DtlsReliableHandshake.cs
index 8fcc1d7c2..92c222e70 100644
--- a/crypto/src/crypto/tls/DtlsReliableHandshake.cs
+++ b/crypto/src/crypto/tls/DtlsReliableHandshake.cs
@@ -76,6 +76,8 @@ namespace Org.BouncyCastle.Crypto.Tls
 
         internal Message ReceiveMessage()
         {
+            // TODO Add support for "overall" handshake timeout
+
             if (mSending)
             {
                 mSending = false;
@@ -89,41 +91,37 @@ namespace Org.BouncyCastle.Crypto.Tls
 
             for (;;)
             {
-                try
+                if (mRecordLayer.IsClosed)
+                    throw new TlsFatalAlert(AlertDescription.user_canceled);
+
+                Message pending = GetPendingMessage();
+                if (pending != null)
+                    return pending;
+
+                int receiveLimit = mRecordLayer.GetReceiveLimit();
+                if (buf == null || buf.Length < receiveLimit)
                 {
-                    for (;;)
-                    {
-                        if (mRecordLayer.IsClosed)
-                            throw new TlsFatalAlert(AlertDescription.user_canceled);
-
-                        Message pending = GetPendingMessage();
-                        if (pending != null)
-                            return pending;
-
-                        int receiveLimit = mRecordLayer.GetReceiveLimit();
-                        if (buf == null || buf.Length < receiveLimit)
-                        {
-                            buf = new byte[receiveLimit];
-                        }
-
-                        int received = mRecordLayer.Receive(buf, 0, receiveLimit, readTimeoutMillis);
-                        if (received < 0)
-                            break;
-
-                        bool resentOutbound = ProcessRecord(MaxReceiveAhead, mRecordLayer.ReadEpoch, buf, 0, received);
-                        if (resentOutbound)
-                        {
-                            readTimeoutMillis = BackOff(readTimeoutMillis);
-                        }
-                    }
+                    buf = new byte[receiveLimit];
+                }
+
+                int received = mRecordLayer.Receive(buf, 0, receiveLimit, readTimeoutMillis);
+
+                bool resentOutbound;
+                if (received < 0)
+                {
+                    ResendOutboundFlight();
+                    resentOutbound = true;
                 }
-                catch (IOException)
+                else
                 {
-                    // NOTE: Assume this is a timeout for the moment
+                    resentOutbound = ProcessRecord(MaxReceiveAhead, mRecordLayer.ReadEpoch, buf, 0, received);
                 }
 
-                ResendOutboundFlight();
-                readTimeoutMillis = BackOff(readTimeoutMillis);
+                // TODO Review conditions for resend/backoff  
+                if (resentOutbound)
+                {
+                    readTimeoutMillis = BackOff(readTimeoutMillis);
+                }
             }
         }
 
diff --git a/crypto/src/crypto/tls/DtlsTransport.cs b/crypto/src/crypto/tls/DtlsTransport.cs
index 5c607336b..bc09707c1 100644
--- a/crypto/src/crypto/tls/DtlsTransport.cs
+++ b/crypto/src/crypto/tls/DtlsTransport.cs
@@ -1,5 +1,6 @@
 using System;
 using System.IO;
+using System.Net.Sockets;
 
 namespace Org.BouncyCastle.Crypto.Tls
 {
@@ -25,6 +26,15 @@ namespace Org.BouncyCastle.Crypto.Tls
 
         public virtual int Receive(byte[] buf, int off, int len, int waitMillis)
         {
+            if (null == buf)
+                throw new ArgumentNullException("buf");
+            if (off < 0 || off >= buf.Length)
+                throw new ArgumentException("invalid offset: " + off, "off");
+            if (len < 0 || len > buf.Length - off)
+                throw new ArgumentException("invalid length: " + len, "len");
+            if (waitMillis < 0)
+                throw new ArgumentException("cannot be negative", "waitMillis");
+
             try
             {
                 return mRecordLayer.Receive(buf, off, len, waitMillis);
@@ -34,11 +44,23 @@ namespace Org.BouncyCastle.Crypto.Tls
                 mRecordLayer.Fail(fatalAlert.AlertDescription);
                 throw fatalAlert;
             }
+            //catch (InterruptedIOException e)
+            //{
+            //    throw e;
+            //}
             catch (IOException e)
             {
                 mRecordLayer.Fail(AlertDescription.internal_error);
                 throw e;
             }
+            catch (SocketException e)
+            {
+                if (TlsUtilities.IsTimeout(e))
+                    throw e;
+
+                mRecordLayer.Fail(AlertDescription.internal_error);
+                throw new TlsFatalAlert(AlertDescription.internal_error, e);
+            }
             catch (Exception e)
             {
                 mRecordLayer.Fail(AlertDescription.internal_error);
@@ -48,6 +70,13 @@ namespace Org.BouncyCastle.Crypto.Tls
 
         public virtual void Send(byte[] buf, int off, int len)
         {
+            if (null == buf)
+                throw new ArgumentNullException("buf");
+            if (off < 0 || off >= buf.Length)
+                throw new ArgumentException("invalid offset: " + off, "off");
+            if (len < 0 || len > buf.Length - off)
+                throw new ArgumentException("invalid length: " + len, "len");
+
             try
             {
                 mRecordLayer.Send(buf, off, len);
@@ -57,11 +86,23 @@ namespace Org.BouncyCastle.Crypto.Tls
                 mRecordLayer.Fail(fatalAlert.AlertDescription);
                 throw fatalAlert;
             }
+            //catch (InterruptedIOException e)
+            //{
+            //    throw e;
+            //}
             catch (IOException e)
             {
                 mRecordLayer.Fail(AlertDescription.internal_error);
                 throw e;
             }
+            catch (SocketException e)
+            {
+                if (TlsUtilities.IsTimeout(e))
+                    throw e;
+
+                mRecordLayer.Fail(AlertDescription.internal_error);
+                throw new TlsFatalAlert(AlertDescription.internal_error, e);
+            }
             catch (Exception e)
             {
                 mRecordLayer.Fail(AlertDescription.internal_error);
diff --git a/crypto/src/crypto/tls/TlsUtilities.cs b/crypto/src/crypto/tls/TlsUtilities.cs
index 6ee71021f..5aad6b0a1 100644
--- a/crypto/src/crypto/tls/TlsUtilities.cs
+++ b/crypto/src/crypto/tls/TlsUtilities.cs
@@ -1,5 +1,6 @@
 using System;
 using System.Collections;
+using System.Net.Sockets;
 using System.IO;
 using System.Text;
 
@@ -2345,5 +2346,12 @@ namespace Org.BouncyCastle.Crypto.Tls
             }
             return v;
         }
+
+        public static bool IsTimeout(SocketException e)
+        {
+            // TODO Net 2.0+
+            //return SocketError.TimedOut == e.SocketErrorCode;
+            return 10060 == e.ErrorCode;
+        }
     }
 }