diff options
author | Peter Dettman <peter.dettman@bouncycastle.org> | 2020-07-30 01:14:52 +0700 |
---|---|---|
committer | Peter Dettman <peter.dettman@bouncycastle.org> | 2020-07-30 01:14:52 +0700 |
commit | 9193844b75819ac2b14622b000c42c1f527632f2 (patch) | |
tree | 07b9a07cd324d55368c7807ccc276376f9cbce4d | |
parent | Add Timeout class for DTLS from bc-java (diff) | |
download | BouncyCastle.NET-ed25519-9193844b75819ac2b14622b000c42c1f527632f2.tar.xz |
DTLS: Exceptions properly abort handshake
- see https://github.com/bcgit/bc-csharp/issues/258
-rw-r--r-- | crypto/src/crypto/tls/DtlsRecordLayer.cs | 339 | ||||
-rw-r--r-- | crypto/src/crypto/tls/DtlsReliableHandshake.cs | 58 | ||||
-rw-r--r-- | crypto/src/crypto/tls/DtlsTransport.cs | 41 | ||||
-rw-r--r-- | crypto/src/crypto/tls/TlsUtilities.cs | 8 |
4 files changed, 266 insertions, 180 deletions
diff --git a/crypto/src/crypto/tls/DtlsRecordLayer.cs b/crypto/src/crypto/tls/DtlsRecordLayer.cs index 3cb0e78dd..266893df0 100644 --- a/crypto/src/crypto/tls/DtlsRecordLayer.cs +++ b/crypto/src/crypto/tls/DtlsRecordLayer.cs @@ -1,5 +1,6 @@ using System; using System.IO; +using System.Net.Sockets; using Org.BouncyCastle.Utilities.Date; @@ -13,6 +14,21 @@ namespace Org.BouncyCastle.Crypto.Tls private const long TCP_MSL = 1000L * 60 * 2; private const long RETRANSMIT_TIMEOUT = TCP_MSL * 2; + private static void SendDatagram(DatagramTransport sender, byte[] buf, int off, int len) + { + //try + //{ + // sender.Send(buf, off, len); + //} + //catch (InterruptedIOException e) + //{ + // e.bytesTransferred = 0; + // throw e; + //} + + sender.Send(buf, off, len); + } + private readonly DatagramTransport mTransport; private readonly TlsContext mContext; private readonly TlsPeer mPeer; @@ -134,6 +150,8 @@ namespace Org.BouncyCastle.Crypto.Tls public virtual int Receive(byte[] buf, int off, int len, int waitMillis) { + // TODO Avoid returning -1 (timeout) until 'waitMillis' has definitely elapsed + byte[] record = null; for (;;) @@ -144,191 +162,183 @@ namespace Org.BouncyCastle.Crypto.Tls record = new byte[receiveLimit]; } - try + if (mRetransmit != null && DateTimeUtilities.CurrentUnixMs() > mRetransmitExpiry) { - if (mRetransmit != null && DateTimeUtilities.CurrentUnixMs() > mRetransmitExpiry) - { - mRetransmit = null; - mRetransmitEpoch = null; - } + mRetransmit = null; + mRetransmitEpoch = null; + } - int received = ReceiveRecord(record, 0, receiveLimit, waitMillis); - if (received < 0) - { - return received; - } - if (received < RECORD_HEADER_LENGTH) - { - continue; - } - int length = TlsUtilities.ReadUint16(record, 11); - if (received != (length + RECORD_HEADER_LENGTH)) - { - continue; - } + int received = ReceiveRecord(record, 0, receiveLimit, waitMillis); + if (received < 0) + { + return received; + } + if (received < RECORD_HEADER_LENGTH) + { + continue; + } + int length = TlsUtilities.ReadUint16(record, 11); + if (received != (length + RECORD_HEADER_LENGTH)) + { + continue; + } - byte type = TlsUtilities.ReadUint8(record, 0); + byte type = TlsUtilities.ReadUint8(record, 0); - // TODO Support user-specified custom protocols? - switch (type) - { - case ContentType.alert: - case ContentType.application_data: - case ContentType.change_cipher_spec: - case ContentType.handshake: - case ContentType.heartbeat: - break; - default: - // TODO Exception? - continue; - } + // TODO Support user-specified custom protocols? + switch (type) + { + case ContentType.alert: + case ContentType.application_data: + case ContentType.change_cipher_spec: + case ContentType.handshake: + case ContentType.heartbeat: + break; + default: + // TODO Exception? + continue; + } - int epoch = TlsUtilities.ReadUint16(record, 3); + int epoch = TlsUtilities.ReadUint16(record, 3); - DtlsEpoch recordEpoch = null; - if (epoch == mReadEpoch.Epoch) - { - recordEpoch = mReadEpoch; - } - else if (type == ContentType.handshake && mRetransmitEpoch != null - && epoch == mRetransmitEpoch.Epoch) - { - recordEpoch = mRetransmitEpoch; - } + DtlsEpoch recordEpoch = null; + if (epoch == mReadEpoch.Epoch) + { + recordEpoch = mReadEpoch; + } + else if (type == ContentType.handshake && mRetransmitEpoch != null + && epoch == mRetransmitEpoch.Epoch) + { + recordEpoch = mRetransmitEpoch; + } - if (recordEpoch == null) - { - continue; - } + if (recordEpoch == null) + { + continue; + } - long seq = TlsUtilities.ReadUint48(record, 5); - if (recordEpoch.ReplayWindow.ShouldDiscard(seq)) - { - continue; - } + long seq = TlsUtilities.ReadUint48(record, 5); + if (recordEpoch.ReplayWindow.ShouldDiscard(seq)) + { + continue; + } - ProtocolVersion version = TlsUtilities.ReadVersion(record, 1); - if (!version.IsDtls) - { - continue; - } + ProtocolVersion version = TlsUtilities.ReadVersion(record, 1); + if (!version.IsDtls) + { + continue; + } - if (mReadVersion != null && !mReadVersion.Equals(version)) - { - continue; - } + if (mReadVersion != null && !mReadVersion.Equals(version)) + { + continue; + } - byte[] plaintext = recordEpoch.Cipher.DecodeCiphertext( - GetMacSequenceNumber(recordEpoch.Epoch, seq), type, record, RECORD_HEADER_LENGTH, - received - RECORD_HEADER_LENGTH); + byte[] plaintext = recordEpoch.Cipher.DecodeCiphertext( + GetMacSequenceNumber(recordEpoch.Epoch, seq), type, record, RECORD_HEADER_LENGTH, + received - RECORD_HEADER_LENGTH); - recordEpoch.ReplayWindow.ReportAuthenticated(seq); + recordEpoch.ReplayWindow.ReportAuthenticated(seq); - if (plaintext.Length > this.mPlaintextLimit) - { - continue; - } + if (plaintext.Length > this.mPlaintextLimit) + { + continue; + } - if (mReadVersion == null) - { - mReadVersion = version; - } + if (mReadVersion == null) + { + mReadVersion = version; + } - switch (type) - { - case ContentType.alert: + switch (type) + { + case ContentType.alert: + { + if (plaintext.Length == 2) { - if (plaintext.Length == 2) + byte alertLevel = plaintext[0]; + byte alertDescription = plaintext[1]; + + mPeer.NotifyAlertReceived(alertLevel, alertDescription); + + if (alertLevel == AlertLevel.fatal) { - byte alertLevel = plaintext[0]; - byte alertDescription = plaintext[1]; - - mPeer.NotifyAlertReceived(alertLevel, alertDescription); - - if (alertLevel == AlertLevel.fatal) - { - Failed(); - throw new TlsFatalAlert(alertDescription); - } - - // TODO Can close_notify be a fatal alert? - if (alertDescription == AlertDescription.close_notify) - { - CloseTransport(); - } + Failed(); + throw new TlsFatalAlert(alertDescription); } + // TODO Can close_notify be a fatal alert? + if (alertDescription == AlertDescription.close_notify) + { + CloseTransport(); + } + } + + continue; + } + case ContentType.application_data: + { + if (mInHandshake) + { + // TODO Consider buffering application data for new epoch that arrives + // out-of-order with the Finished message continue; } - case ContentType.application_data: + break; + } + case ContentType.change_cipher_spec: + { + // Implicitly receive change_cipher_spec and change to pending cipher state + + for (int i = 0; i < plaintext.Length; ++i) { - if (mInHandshake) + byte message = TlsUtilities.ReadUint8(plaintext, i); + if (message != ChangeCipherSpec.change_cipher_spec) { - // TODO Consider buffering application data for new epoch that arrives - // out-of-order with the Finished message continue; } - break; - } - case ContentType.change_cipher_spec: - { - // Implicitly receive change_cipher_spec and change to pending cipher state - for (int i = 0; i < plaintext.Length; ++i) + if (mPendingEpoch != null) { - byte message = TlsUtilities.ReadUint8(plaintext, i); - if (message != ChangeCipherSpec.change_cipher_spec) - { - continue; - } - - if (mPendingEpoch != null) - { - mReadEpoch = mPendingEpoch; - } + mReadEpoch = mPendingEpoch; } - - continue; } - case ContentType.handshake: + + continue; + } + case ContentType.handshake: + { + if (!mInHandshake) { - if (!mInHandshake) + if (mRetransmit != null) { - if (mRetransmit != null) - { - mRetransmit.ReceivedHandshakeRecord(epoch, plaintext, 0, plaintext.Length); - } - - // TODO Consider support for HelloRequest - continue; + mRetransmit.ReceivedHandshakeRecord(epoch, plaintext, 0, plaintext.Length); } - break; - } - case ContentType.heartbeat: - { - // TODO[RFC 6520] - continue; - } - } - /* - * NOTE: If we receive any non-handshake data in the new epoch implies the peer has - * received our final flight. - */ - if (!mInHandshake && mRetransmit != null) - { - this.mRetransmit = null; - this.mRetransmitEpoch = null; + // TODO Consider support for HelloRequest + continue; } - - Array.Copy(plaintext, 0, buf, off, plaintext.Length); - return plaintext.Length; + break; } - catch (IOException e) + case ContentType.heartbeat: { - // NOTE: Assume this is a timeout for the moment - throw e; + // TODO[RFC 6520] + continue; + } } + + /* + * NOTE: If we receive any non-handshake data in the new epoch implies the peer has + * received our final flight. + */ + if (!mInHandshake && mRetransmit != null) + { + this.mRetransmit = null; + this.mRetransmitEpoch = null; + } + + Array.Copy(plaintext, 0, buf, off, plaintext.Length); + return plaintext.Length; } } @@ -458,6 +468,35 @@ namespace Org.BouncyCastle.Crypto.Tls SendRecord(ContentType.alert, error, 0, 2); } + private int ReceiveDatagram(byte[] buf, int off, int len, int waitMillis) + { + //try + //{ + // return mTransport.Receive(buf, off, len, waitMillis); + //} + //catch (SocketTimeoutException e) + //{ + // return -1; + //} + //catch (InterruptedIOException e) + //{ + // e.bytesTransferred = 0; + // throw e; + //} + + try + { + return mTransport.Receive(buf, off, len, waitMillis); + } + catch (SocketException e) + { + if (TlsUtilities.IsTimeout(e)) + return -1; + + throw e; + } + } + private int ReceiveRecord(byte[] buf, int off, int len, int waitMillis) { if (mRecordQueue.Available > 0) @@ -476,7 +515,7 @@ namespace Org.BouncyCastle.Crypto.Tls } { - int received = mTransport.Receive(buf, off, len, waitMillis); + int received = ReceiveDatagram(buf, off, len, waitMillis); if (received >= RECORD_HEADER_LENGTH) { int fragmentLength = TlsUtilities.ReadUint16(buf, off + 11); @@ -524,7 +563,7 @@ namespace Org.BouncyCastle.Crypto.Tls TlsUtilities.WriteUint16(ciphertext.Length, record, 11); Array.Copy(ciphertext, 0, record, RECORD_HEADER_LENGTH, ciphertext.Length); - mTransport.Send(record, 0, record.Length); + SendDatagram(mTransport, record, 0, record.Length); } private static long GetMacSequenceNumber(int epoch, long sequence_number) diff --git a/crypto/src/crypto/tls/DtlsReliableHandshake.cs b/crypto/src/crypto/tls/DtlsReliableHandshake.cs index 8fcc1d7c2..92c222e70 100644 --- a/crypto/src/crypto/tls/DtlsReliableHandshake.cs +++ b/crypto/src/crypto/tls/DtlsReliableHandshake.cs @@ -76,6 +76,8 @@ namespace Org.BouncyCastle.Crypto.Tls internal Message ReceiveMessage() { + // TODO Add support for "overall" handshake timeout + if (mSending) { mSending = false; @@ -89,41 +91,37 @@ namespace Org.BouncyCastle.Crypto.Tls for (;;) { - try + if (mRecordLayer.IsClosed) + throw new TlsFatalAlert(AlertDescription.user_canceled); + + Message pending = GetPendingMessage(); + if (pending != null) + return pending; + + int receiveLimit = mRecordLayer.GetReceiveLimit(); + if (buf == null || buf.Length < receiveLimit) { - for (;;) - { - if (mRecordLayer.IsClosed) - throw new TlsFatalAlert(AlertDescription.user_canceled); - - Message pending = GetPendingMessage(); - if (pending != null) - return pending; - - int receiveLimit = mRecordLayer.GetReceiveLimit(); - if (buf == null || buf.Length < receiveLimit) - { - buf = new byte[receiveLimit]; - } - - int received = mRecordLayer.Receive(buf, 0, receiveLimit, readTimeoutMillis); - if (received < 0) - break; - - bool resentOutbound = ProcessRecord(MaxReceiveAhead, mRecordLayer.ReadEpoch, buf, 0, received); - if (resentOutbound) - { - readTimeoutMillis = BackOff(readTimeoutMillis); - } - } + buf = new byte[receiveLimit]; + } + + int received = mRecordLayer.Receive(buf, 0, receiveLimit, readTimeoutMillis); + + bool resentOutbound; + if (received < 0) + { + ResendOutboundFlight(); + resentOutbound = true; } - catch (IOException) + else { - // NOTE: Assume this is a timeout for the moment + resentOutbound = ProcessRecord(MaxReceiveAhead, mRecordLayer.ReadEpoch, buf, 0, received); } - ResendOutboundFlight(); - readTimeoutMillis = BackOff(readTimeoutMillis); + // TODO Review conditions for resend/backoff + if (resentOutbound) + { + readTimeoutMillis = BackOff(readTimeoutMillis); + } } } diff --git a/crypto/src/crypto/tls/DtlsTransport.cs b/crypto/src/crypto/tls/DtlsTransport.cs index 5c607336b..bc09707c1 100644 --- a/crypto/src/crypto/tls/DtlsTransport.cs +++ b/crypto/src/crypto/tls/DtlsTransport.cs @@ -1,5 +1,6 @@ using System; using System.IO; +using System.Net.Sockets; namespace Org.BouncyCastle.Crypto.Tls { @@ -25,6 +26,15 @@ namespace Org.BouncyCastle.Crypto.Tls public virtual int Receive(byte[] buf, int off, int len, int waitMillis) { + if (null == buf) + throw new ArgumentNullException("buf"); + if (off < 0 || off >= buf.Length) + throw new ArgumentException("invalid offset: " + off, "off"); + if (len < 0 || len > buf.Length - off) + throw new ArgumentException("invalid length: " + len, "len"); + if (waitMillis < 0) + throw new ArgumentException("cannot be negative", "waitMillis"); + try { return mRecordLayer.Receive(buf, off, len, waitMillis); @@ -34,11 +44,23 @@ namespace Org.BouncyCastle.Crypto.Tls mRecordLayer.Fail(fatalAlert.AlertDescription); throw fatalAlert; } + //catch (InterruptedIOException e) + //{ + // throw e; + //} catch (IOException e) { mRecordLayer.Fail(AlertDescription.internal_error); throw e; } + catch (SocketException e) + { + if (TlsUtilities.IsTimeout(e)) + throw e; + + mRecordLayer.Fail(AlertDescription.internal_error); + throw new TlsFatalAlert(AlertDescription.internal_error, e); + } catch (Exception e) { mRecordLayer.Fail(AlertDescription.internal_error); @@ -48,6 +70,13 @@ namespace Org.BouncyCastle.Crypto.Tls public virtual void Send(byte[] buf, int off, int len) { + if (null == buf) + throw new ArgumentNullException("buf"); + if (off < 0 || off >= buf.Length) + throw new ArgumentException("invalid offset: " + off, "off"); + if (len < 0 || len > buf.Length - off) + throw new ArgumentException("invalid length: " + len, "len"); + try { mRecordLayer.Send(buf, off, len); @@ -57,11 +86,23 @@ namespace Org.BouncyCastle.Crypto.Tls mRecordLayer.Fail(fatalAlert.AlertDescription); throw fatalAlert; } + //catch (InterruptedIOException e) + //{ + // throw e; + //} catch (IOException e) { mRecordLayer.Fail(AlertDescription.internal_error); throw e; } + catch (SocketException e) + { + if (TlsUtilities.IsTimeout(e)) + throw e; + + mRecordLayer.Fail(AlertDescription.internal_error); + throw new TlsFatalAlert(AlertDescription.internal_error, e); + } catch (Exception e) { mRecordLayer.Fail(AlertDescription.internal_error); diff --git a/crypto/src/crypto/tls/TlsUtilities.cs b/crypto/src/crypto/tls/TlsUtilities.cs index 6ee71021f..5aad6b0a1 100644 --- a/crypto/src/crypto/tls/TlsUtilities.cs +++ b/crypto/src/crypto/tls/TlsUtilities.cs @@ -1,5 +1,6 @@ using System; using System.Collections; +using System.Net.Sockets; using System.IO; using System.Text; @@ -2345,5 +2346,12 @@ namespace Org.BouncyCastle.Crypto.Tls } return v; } + + public static bool IsTimeout(SocketException e) + { + // TODO Net 2.0+ + //return SocketError.TimedOut == e.SocketErrorCode; + return 10060 == e.ErrorCode; + } } } |