diff options
author | Peter Dettman <peter.dettman@bouncycastle.org> | 2020-07-30 02:24:54 +0700 |
---|---|---|
committer | Peter Dettman <peter.dettman@bouncycastle.org> | 2020-07-30 02:24:54 +0700 |
commit | d7b5df9df2099487c62342a9bfbc30e40711788b (patch) | |
tree | c188933e1cb1c8d3dbced0c266f5edb285debc23 /crypto | |
parent | DTLS: Exceptions properly abort handshake (diff) | |
download | BouncyCastle.NET-ed25519-d7b5df9df2099487c62342a9bfbc30e40711788b.tar.xz |
DTLS: Improved retransmission timer
Diffstat (limited to 'crypto')
-rw-r--r-- | crypto/src/crypto/tls/DtlsRecordLayer.cs | 369 | ||||
-rw-r--r-- | crypto/src/crypto/tls/DtlsReliableHandshake.cs | 53 |
2 files changed, 222 insertions, 200 deletions
diff --git a/crypto/src/crypto/tls/DtlsRecordLayer.cs b/crypto/src/crypto/tls/DtlsRecordLayer.cs index 266893df0..c1a26b14f 100644 --- a/crypto/src/crypto/tls/DtlsRecordLayer.cs +++ b/crypto/src/crypto/tls/DtlsRecordLayer.cs @@ -45,7 +45,7 @@ namespace Org.BouncyCastle.Crypto.Tls private DtlsHandshakeRetransmit mRetransmit = null; private DtlsEpoch mRetransmitEpoch = null; - private long mRetransmitExpiry = 0; + private Timeout mRetransmitTimeout = null; internal DtlsRecordLayer(DatagramTransport transport, TlsContext context, TlsPeer peer, byte contentType) { @@ -116,7 +116,7 @@ namespace Org.BouncyCastle.Crypto.Tls { this.mRetransmit = retransmit; this.mRetransmitEpoch = mCurrentEpoch; - this.mRetransmitExpiry = DateTimeUtilities.CurrentUnixMs() + RETRANSMIT_TIMEOUT; + this.mRetransmitTimeout = new Timeout(RETRANSMIT_TIMEOUT); } this.mInHandshake = false; @@ -150,196 +150,43 @@ namespace Org.BouncyCastle.Crypto.Tls public virtual int Receive(byte[] buf, int off, int len, int waitMillis) { - // TODO Avoid returning -1 (timeout) until 'waitMillis' has definitely elapsed + long currentTimeMillis = DateTimeUtilities.CurrentUnixMs(); + + Timeout timeout = null; + if (waitMillis > 0) + { + timeout = new Timeout(waitMillis, currentTimeMillis); + } byte[] record = null; - for (;;) + while (waitMillis >= 0) { - int receiveLimit = System.Math.Min(len, GetReceiveLimit()) + RECORD_HEADER_LENGTH; - if (record == null || record.Length < receiveLimit) - { - record = new byte[receiveLimit]; - } - - if (mRetransmit != null && DateTimeUtilities.CurrentUnixMs() > mRetransmitExpiry) + if (mRetransmitTimeout != null && mRetransmitTimeout.RemainingMillis(currentTimeMillis) < 1) { mRetransmit = null; mRetransmitEpoch = null; + mRetransmitTimeout = null; } - int received = ReceiveRecord(record, 0, receiveLimit, waitMillis); - if (received < 0) - { - return received; - } - if (received < RECORD_HEADER_LENGTH) - { - continue; - } - int length = TlsUtilities.ReadUint16(record, 11); - if (received != (length + RECORD_HEADER_LENGTH)) - { - continue; - } - - byte type = TlsUtilities.ReadUint8(record, 0); - - // TODO Support user-specified custom protocols? - switch (type) - { - case ContentType.alert: - case ContentType.application_data: - case ContentType.change_cipher_spec: - case ContentType.handshake: - case ContentType.heartbeat: - break; - default: - // TODO Exception? - continue; - } - - int epoch = TlsUtilities.ReadUint16(record, 3); - - DtlsEpoch recordEpoch = null; - if (epoch == mReadEpoch.Epoch) - { - recordEpoch = mReadEpoch; - } - else if (type == ContentType.handshake && mRetransmitEpoch != null - && epoch == mRetransmitEpoch.Epoch) - { - recordEpoch = mRetransmitEpoch; - } - - if (recordEpoch == null) - { - continue; - } - - long seq = TlsUtilities.ReadUint48(record, 5); - if (recordEpoch.ReplayWindow.ShouldDiscard(seq)) - { - continue; - } - - ProtocolVersion version = TlsUtilities.ReadVersion(record, 1); - if (!version.IsDtls) - { - continue; - } - - if (mReadVersion != null && !mReadVersion.Equals(version)) - { - continue; - } - - byte[] plaintext = recordEpoch.Cipher.DecodeCiphertext( - GetMacSequenceNumber(recordEpoch.Epoch, seq), type, record, RECORD_HEADER_LENGTH, - received - RECORD_HEADER_LENGTH); - - recordEpoch.ReplayWindow.ReportAuthenticated(seq); - - if (plaintext.Length > this.mPlaintextLimit) - { - continue; - } - - if (mReadVersion == null) - { - mReadVersion = version; - } - - switch (type) - { - case ContentType.alert: - { - if (plaintext.Length == 2) - { - byte alertLevel = plaintext[0]; - byte alertDescription = plaintext[1]; - - mPeer.NotifyAlertReceived(alertLevel, alertDescription); - - if (alertLevel == AlertLevel.fatal) - { - Failed(); - throw new TlsFatalAlert(alertDescription); - } - - // TODO Can close_notify be a fatal alert? - if (alertDescription == AlertDescription.close_notify) - { - CloseTransport(); - } - } - - continue; - } - case ContentType.application_data: - { - if (mInHandshake) - { - // TODO Consider buffering application data for new epoch that arrives - // out-of-order with the Finished message - continue; - } - break; - } - case ContentType.change_cipher_spec: - { - // Implicitly receive change_cipher_spec and change to pending cipher state - - for (int i = 0; i < plaintext.Length; ++i) - { - byte message = TlsUtilities.ReadUint8(plaintext, i); - if (message != ChangeCipherSpec.change_cipher_spec) - { - continue; - } - - if (mPendingEpoch != null) - { - mReadEpoch = mPendingEpoch; - } - } - - continue; - } - case ContentType.handshake: - { - if (!mInHandshake) - { - if (mRetransmit != null) - { - mRetransmit.ReceivedHandshakeRecord(epoch, plaintext, 0, plaintext.Length); - } - - // TODO Consider support for HelloRequest - continue; - } - break; - } - case ContentType.heartbeat: + int receiveLimit = System.Math.Min(len, GetReceiveLimit()) + RECORD_HEADER_LENGTH; + if (record == null || record.Length < receiveLimit) { - // TODO[RFC 6520] - continue; - } + record = new byte[receiveLimit]; } - /* - * NOTE: If we receive any non-handshake data in the new epoch implies the peer has - * received our final flight. - */ - if (!mInHandshake && mRetransmit != null) + int received = ReceiveRecord(record, 0, receiveLimit, waitMillis); + int processed = ProcessRecord(received, record, buf, off); + if (processed >= 0) { - this.mRetransmit = null; - this.mRetransmitEpoch = null; + return processed; } - Array.Copy(plaintext, 0, buf, off, plaintext.Length); - return plaintext.Length; + currentTimeMillis = DateTimeUtilities.CurrentUnixMs(); + waitMillis = Timeout.GetWaitMillis(timeout, currentTimeMillis); } + + return -1; } /// <exception cref="IOException"/> @@ -497,6 +344,176 @@ namespace Org.BouncyCastle.Crypto.Tls } } + private int ProcessRecord(int received, byte[] record, byte[] buf, int off) + { + // NOTE: received < 0 (timeout) is covered by this first case + if (received < RECORD_HEADER_LENGTH) + { + return -1; + } + int length = TlsUtilities.ReadUint16(record, 11); + if (received != (length + RECORD_HEADER_LENGTH)) + { + return -1; + } + + byte type = TlsUtilities.ReadUint8(record, 0); + + switch (type) + { + case ContentType.alert: + case ContentType.application_data: + case ContentType.change_cipher_spec: + case ContentType.handshake: + case ContentType.heartbeat: + break; + default: + return -1; + } + + int epoch = TlsUtilities.ReadUint16(record, 3); + + DtlsEpoch recordEpoch = null; + if (epoch == mReadEpoch.Epoch) + { + recordEpoch = mReadEpoch; + } + else if (type == ContentType.handshake && mRetransmitEpoch != null + && epoch == mRetransmitEpoch.Epoch) + { + recordEpoch = mRetransmitEpoch; + } + + if (recordEpoch == null) + { + return -1; + } + + long seq = TlsUtilities.ReadUint48(record, 5); + if (recordEpoch.ReplayWindow.ShouldDiscard(seq)) + { + return -1; + } + + ProtocolVersion version = TlsUtilities.ReadVersion(record, 1); + if (!version.IsDtls) + { + return -1; + } + + if (mReadVersion != null && !mReadVersion.Equals(version)) + { + return -1; + } + + byte[] plaintext = recordEpoch.Cipher.DecodeCiphertext( + GetMacSequenceNumber(recordEpoch.Epoch, seq), type, record, RECORD_HEADER_LENGTH, + received - RECORD_HEADER_LENGTH); + + recordEpoch.ReplayWindow.ReportAuthenticated(seq); + + if (plaintext.Length > this.mPlaintextLimit) + { + return -1; + } + + if (mReadVersion == null) + { + mReadVersion = version; + } + + switch (type) + { + case ContentType.alert: + { + if (plaintext.Length == 2) + { + byte alertLevel = plaintext[0]; + byte alertDescription = plaintext[1]; + + mPeer.NotifyAlertReceived(alertLevel, alertDescription); + + if (alertLevel == AlertLevel.fatal) + { + Failed(); + throw new TlsFatalAlert(alertDescription); + } + + // TODO Can close_notify be a fatal alert? + if (alertDescription == AlertDescription.close_notify) + { + CloseTransport(); + } + } + + return -1; + } + case ContentType.application_data: + { + if (mInHandshake) + { + // TODO Consider buffering application data for new epoch that arrives + // out-of-order with the Finished message + return -1; + } + break; + } + case ContentType.change_cipher_spec: + { + // Implicitly receive change_cipher_spec and change to pending cipher state + + for (int i = 0; i < plaintext.Length; ++i) + { + byte message = TlsUtilities.ReadUint8(plaintext, i); + if (message != ChangeCipherSpec.change_cipher_spec) + { + continue; + } + + if (mPendingEpoch != null) + { + mReadEpoch = mPendingEpoch; + } + } + + return -1; + } + case ContentType.handshake: + { + if (!mInHandshake) + { + if (mRetransmit != null) + { + mRetransmit.ReceivedHandshakeRecord(epoch, plaintext, 0, plaintext.Length); + } + + // TODO Consider support for HelloRequest + return -1; + } + break; + } + case ContentType.heartbeat: + { + // TODO[RFC 6520] + return -1; + } + } + + /* + * NOTE: If we receive any non-handshake data in the new epoch implies the peer has + * received our final flight. + */ + if (!mInHandshake && mRetransmit != null) + { + this.mRetransmit = null; + this.mRetransmitEpoch = null; + this.mRetransmitTimeout = null; + } + + Array.Copy(plaintext, 0, buf, off, plaintext.Length); + return plaintext.Length; + } + private int ReceiveRecord(byte[] buf, int off, int len, int waitMillis) { if (mRecordQueue.Available > 0) diff --git a/crypto/src/crypto/tls/DtlsReliableHandshake.cs b/crypto/src/crypto/tls/DtlsReliableHandshake.cs index 92c222e70..3eeb8a61e 100644 --- a/crypto/src/crypto/tls/DtlsReliableHandshake.cs +++ b/crypto/src/crypto/tls/DtlsReliableHandshake.cs @@ -3,6 +3,7 @@ using System.Collections; using System.IO; using Org.BouncyCastle.Utilities; +using Org.BouncyCastle.Utilities.Date; namespace Org.BouncyCastle.Crypto.Tls { @@ -11,6 +12,9 @@ namespace Org.BouncyCastle.Crypto.Tls private const int MaxReceiveAhead = 16; private const int MessageHeaderLength = 12; + private const int InitialResendMillis = 1000; + private const int MaxResendMillis = 60000; + private readonly DtlsRecordLayer mRecordLayer; private TlsHandshakeHash mHandshakeHash; @@ -18,7 +22,9 @@ namespace Org.BouncyCastle.Crypto.Tls private IDictionary mCurrentInboundFlight = Platform.CreateHashtable(); private IDictionary mPreviousInboundFlight = null; private IList mOutboundFlight = Platform.CreateArrayList(); - private bool mSending = true; + + private int mResendMillis = -1; + private Timeout mResendTimeout = null; private int mMessageSeq = 0, mNextReceiveSeq = 0; @@ -50,10 +56,13 @@ namespace Org.BouncyCastle.Crypto.Tls { TlsUtilities.CheckUint24(body.Length); - if (!mSending) + if (mResendTimeout != null) { CheckInboundFlight(); - mSending = true; + + mResendMillis = -1; + mResendTimeout = null; + mOutboundFlight.Clear(); } @@ -77,18 +86,18 @@ namespace Org.BouncyCastle.Crypto.Tls internal Message ReceiveMessage() { // TODO Add support for "overall" handshake timeout + long currentTimeMillis = DateTimeUtilities.CurrentUnixMs(); - if (mSending) + if (mResendTimeout == null) { - mSending = false; + mResendMillis = InitialResendMillis; + mResendTimeout = new Timeout(mResendMillis, currentTimeMillis); + PrepareInboundFlight(Platform.CreateHashtable()); } byte[] buf = null; - // TODO Check the conditions under which we should reset this - int readTimeoutMillis = 1000; - for (;;) { if (mRecordLayer.IsClosed) @@ -98,37 +107,32 @@ namespace Org.BouncyCastle.Crypto.Tls if (pending != null) return pending; + int waitMillis = System.Math.Max(1, Timeout.GetWaitMillis(mResendTimeout, currentTimeMillis)); + int receiveLimit = mRecordLayer.GetReceiveLimit(); if (buf == null || buf.Length < receiveLimit) { buf = new byte[receiveLimit]; } - int received = mRecordLayer.Receive(buf, 0, receiveLimit, readTimeoutMillis); - - bool resentOutbound; + int received = mRecordLayer.Receive(buf, 0, receiveLimit, waitMillis); if (received < 0) { ResendOutboundFlight(); - resentOutbound = true; } else { - resentOutbound = ProcessRecord(MaxReceiveAhead, mRecordLayer.ReadEpoch, buf, 0, received); + ProcessRecord(MaxReceiveAhead, mRecordLayer.ReadEpoch, buf, 0, received); } - // TODO Review conditions for resend/backoff - if (resentOutbound) - { - readTimeoutMillis = BackOff(readTimeoutMillis); - } + currentTimeMillis = DateTimeUtilities.CurrentUnixMs(); } } internal void Finish() { DtlsHandshakeRetransmit retransmit = null; - if (!mSending) + if (mResendTimeout != null) { CheckInboundFlight(); } @@ -162,7 +166,7 @@ namespace Org.BouncyCastle.Crypto.Tls * TODO[DTLS] implementations SHOULD back off handshake packet size during the * retransmit backoff. */ - return System.Math.Min(timeoutMillis * 2, 60000); + return System.Math.Min(timeoutMillis * 2, MaxResendMillis); } /** @@ -201,7 +205,7 @@ namespace Org.BouncyCastle.Crypto.Tls mCurrentInboundFlight = nextFlight; } - private bool ProcessRecord(int windowSize, int epoch, byte[] buf, int off, int len) + private void ProcessRecord(int windowSize, int epoch, byte[] buf, int off, int len) { bool checkPreviousFlight = false; @@ -271,13 +275,11 @@ namespace Org.BouncyCastle.Crypto.Tls len -= message_length; } - bool result = checkPreviousFlight && CheckAll(mPreviousInboundFlight); - if (result) + if (checkPreviousFlight && CheckAll(mPreviousInboundFlight)) { ResendOutboundFlight(); ResetAll(mPreviousInboundFlight); } - return result; } private void ResendOutboundFlight() @@ -287,6 +289,9 @@ namespace Org.BouncyCastle.Crypto.Tls { WriteMessage((Message)mOutboundFlight[i]); } + + mResendMillis = BackOff(mResendMillis); + mResendTimeout = new Timeout(mResendMillis); } private Message UpdateHandshakeMessagesDigest(Message message) |