diff --git a/crypto/src/crypto/tls/DtlsRecordLayer.cs b/crypto/src/crypto/tls/DtlsRecordLayer.cs
index 266893df0..c1a26b14f 100644
--- a/crypto/src/crypto/tls/DtlsRecordLayer.cs
+++ b/crypto/src/crypto/tls/DtlsRecordLayer.cs
@@ -45,7 +45,7 @@ namespace Org.BouncyCastle.Crypto.Tls
private DtlsHandshakeRetransmit mRetransmit = null;
private DtlsEpoch mRetransmitEpoch = null;
- private long mRetransmitExpiry = 0;
+ private Timeout mRetransmitTimeout = null;
internal DtlsRecordLayer(DatagramTransport transport, TlsContext context, TlsPeer peer, byte contentType)
{
@@ -116,7 +116,7 @@ namespace Org.BouncyCastle.Crypto.Tls
{
this.mRetransmit = retransmit;
this.mRetransmitEpoch = mCurrentEpoch;
- this.mRetransmitExpiry = DateTimeUtilities.CurrentUnixMs() + RETRANSMIT_TIMEOUT;
+ this.mRetransmitTimeout = new Timeout(RETRANSMIT_TIMEOUT);
}
this.mInHandshake = false;
@@ -150,196 +150,43 @@ namespace Org.BouncyCastle.Crypto.Tls
public virtual int Receive(byte[] buf, int off, int len, int waitMillis)
{
- // TODO Avoid returning -1 (timeout) until 'waitMillis' has definitely elapsed
+ long currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
+
+ Timeout timeout = null;
+ if (waitMillis > 0)
+ {
+ timeout = new Timeout(waitMillis, currentTimeMillis);
+ }
byte[] record = null;
- for (;;)
+ while (waitMillis >= 0)
{
- int receiveLimit = System.Math.Min(len, GetReceiveLimit()) + RECORD_HEADER_LENGTH;
- if (record == null || record.Length < receiveLimit)
- {
- record = new byte[receiveLimit];
- }
-
- if (mRetransmit != null && DateTimeUtilities.CurrentUnixMs() > mRetransmitExpiry)
+ if (mRetransmitTimeout != null && mRetransmitTimeout.RemainingMillis(currentTimeMillis) < 1)
{
mRetransmit = null;
mRetransmitEpoch = null;
+ mRetransmitTimeout = null;
}
- int received = ReceiveRecord(record, 0, receiveLimit, waitMillis);
- if (received < 0)
- {
- return received;
- }
- if (received < RECORD_HEADER_LENGTH)
- {
- continue;
- }
- int length = TlsUtilities.ReadUint16(record, 11);
- if (received != (length + RECORD_HEADER_LENGTH))
- {
- continue;
- }
-
- byte type = TlsUtilities.ReadUint8(record, 0);
-
- // TODO Support user-specified custom protocols?
- switch (type)
- {
- case ContentType.alert:
- case ContentType.application_data:
- case ContentType.change_cipher_spec:
- case ContentType.handshake:
- case ContentType.heartbeat:
- break;
- default:
- // TODO Exception?
- continue;
- }
-
- int epoch = TlsUtilities.ReadUint16(record, 3);
-
- DtlsEpoch recordEpoch = null;
- if (epoch == mReadEpoch.Epoch)
- {
- recordEpoch = mReadEpoch;
- }
- else if (type == ContentType.handshake && mRetransmitEpoch != null
- && epoch == mRetransmitEpoch.Epoch)
- {
- recordEpoch = mRetransmitEpoch;
- }
-
- if (recordEpoch == null)
- {
- continue;
- }
-
- long seq = TlsUtilities.ReadUint48(record, 5);
- if (recordEpoch.ReplayWindow.ShouldDiscard(seq))
- {
- continue;
- }
-
- ProtocolVersion version = TlsUtilities.ReadVersion(record, 1);
- if (!version.IsDtls)
- {
- continue;
- }
-
- if (mReadVersion != null && !mReadVersion.Equals(version))
- {
- continue;
- }
-
- byte[] plaintext = recordEpoch.Cipher.DecodeCiphertext(
- GetMacSequenceNumber(recordEpoch.Epoch, seq), type, record, RECORD_HEADER_LENGTH,
- received - RECORD_HEADER_LENGTH);
-
- recordEpoch.ReplayWindow.ReportAuthenticated(seq);
-
- if (plaintext.Length > this.mPlaintextLimit)
- {
- continue;
- }
-
- if (mReadVersion == null)
- {
- mReadVersion = version;
- }
-
- switch (type)
- {
- case ContentType.alert:
- {
- if (plaintext.Length == 2)
- {
- byte alertLevel = plaintext[0];
- byte alertDescription = plaintext[1];
-
- mPeer.NotifyAlertReceived(alertLevel, alertDescription);
-
- if (alertLevel == AlertLevel.fatal)
- {
- Failed();
- throw new TlsFatalAlert(alertDescription);
- }
-
- // TODO Can close_notify be a fatal alert?
- if (alertDescription == AlertDescription.close_notify)
- {
- CloseTransport();
- }
- }
-
- continue;
- }
- case ContentType.application_data:
- {
- if (mInHandshake)
- {
- // TODO Consider buffering application data for new epoch that arrives
- // out-of-order with the Finished message
- continue;
- }
- break;
- }
- case ContentType.change_cipher_spec:
- {
- // Implicitly receive change_cipher_spec and change to pending cipher state
-
- for (int i = 0; i < plaintext.Length; ++i)
- {
- byte message = TlsUtilities.ReadUint8(plaintext, i);
- if (message != ChangeCipherSpec.change_cipher_spec)
- {
- continue;
- }
-
- if (mPendingEpoch != null)
- {
- mReadEpoch = mPendingEpoch;
- }
- }
-
- continue;
- }
- case ContentType.handshake:
- {
- if (!mInHandshake)
- {
- if (mRetransmit != null)
- {
- mRetransmit.ReceivedHandshakeRecord(epoch, plaintext, 0, plaintext.Length);
- }
-
- // TODO Consider support for HelloRequest
- continue;
- }
- break;
- }
- case ContentType.heartbeat:
+ int receiveLimit = System.Math.Min(len, GetReceiveLimit()) + RECORD_HEADER_LENGTH;
+ if (record == null || record.Length < receiveLimit)
{
- // TODO[RFC 6520]
- continue;
- }
+ record = new byte[receiveLimit];
}
- /*
- * NOTE: If we receive any non-handshake data in the new epoch implies the peer has
- * received our final flight.
- */
- if (!mInHandshake && mRetransmit != null)
+ int received = ReceiveRecord(record, 0, receiveLimit, waitMillis);
+ int processed = ProcessRecord(received, record, buf, off);
+ if (processed >= 0)
{
- this.mRetransmit = null;
- this.mRetransmitEpoch = null;
+ return processed;
}
- Array.Copy(plaintext, 0, buf, off, plaintext.Length);
- return plaintext.Length;
+ currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
+ waitMillis = Timeout.GetWaitMillis(timeout, currentTimeMillis);
}
+
+ return -1;
}
/// <exception cref="IOException"/>
@@ -497,6 +344,176 @@ namespace Org.BouncyCastle.Crypto.Tls
}
}
+ private int ProcessRecord(int received, byte[] record, byte[] buf, int off)
+ {
+ // NOTE: received < 0 (timeout) is covered by this first case
+ if (received < RECORD_HEADER_LENGTH)
+ {
+ return -1;
+ }
+ int length = TlsUtilities.ReadUint16(record, 11);
+ if (received != (length + RECORD_HEADER_LENGTH))
+ {
+ return -1;
+ }
+
+ byte type = TlsUtilities.ReadUint8(record, 0);
+
+ switch (type)
+ {
+ case ContentType.alert:
+ case ContentType.application_data:
+ case ContentType.change_cipher_spec:
+ case ContentType.handshake:
+ case ContentType.heartbeat:
+ break;
+ default:
+ return -1;
+ }
+
+ int epoch = TlsUtilities.ReadUint16(record, 3);
+
+ DtlsEpoch recordEpoch = null;
+ if (epoch == mReadEpoch.Epoch)
+ {
+ recordEpoch = mReadEpoch;
+ }
+ else if (type == ContentType.handshake && mRetransmitEpoch != null
+ && epoch == mRetransmitEpoch.Epoch)
+ {
+ recordEpoch = mRetransmitEpoch;
+ }
+
+ if (recordEpoch == null)
+ {
+ return -1;
+ }
+
+ long seq = TlsUtilities.ReadUint48(record, 5);
+ if (recordEpoch.ReplayWindow.ShouldDiscard(seq))
+ {
+ return -1;
+ }
+
+ ProtocolVersion version = TlsUtilities.ReadVersion(record, 1);
+ if (!version.IsDtls)
+ {
+ return -1;
+ }
+
+ if (mReadVersion != null && !mReadVersion.Equals(version))
+ {
+ return -1;
+ }
+
+ byte[] plaintext = recordEpoch.Cipher.DecodeCiphertext(
+ GetMacSequenceNumber(recordEpoch.Epoch, seq), type, record, RECORD_HEADER_LENGTH,
+ received - RECORD_HEADER_LENGTH);
+
+ recordEpoch.ReplayWindow.ReportAuthenticated(seq);
+
+ if (plaintext.Length > this.mPlaintextLimit)
+ {
+ return -1;
+ }
+
+ if (mReadVersion == null)
+ {
+ mReadVersion = version;
+ }
+
+ switch (type)
+ {
+ case ContentType.alert:
+ {
+ if (plaintext.Length == 2)
+ {
+ byte alertLevel = plaintext[0];
+ byte alertDescription = plaintext[1];
+
+ mPeer.NotifyAlertReceived(alertLevel, alertDescription);
+
+ if (alertLevel == AlertLevel.fatal)
+ {
+ Failed();
+ throw new TlsFatalAlert(alertDescription);
+ }
+
+ // TODO Can close_notify be a fatal alert?
+ if (alertDescription == AlertDescription.close_notify)
+ {
+ CloseTransport();
+ }
+ }
+
+ return -1;
+ }
+ case ContentType.application_data:
+ {
+ if (mInHandshake)
+ {
+ // TODO Consider buffering application data for new epoch that arrives
+ // out-of-order with the Finished message
+ return -1;
+ }
+ break;
+ }
+ case ContentType.change_cipher_spec:
+ {
+ // Implicitly receive change_cipher_spec and change to pending cipher state
+
+ for (int i = 0; i < plaintext.Length; ++i)
+ {
+ byte message = TlsUtilities.ReadUint8(plaintext, i);
+ if (message != ChangeCipherSpec.change_cipher_spec)
+ {
+ continue;
+ }
+
+ if (mPendingEpoch != null)
+ {
+ mReadEpoch = mPendingEpoch;
+ }
+ }
+
+ return -1;
+ }
+ case ContentType.handshake:
+ {
+ if (!mInHandshake)
+ {
+ if (mRetransmit != null)
+ {
+ mRetransmit.ReceivedHandshakeRecord(epoch, plaintext, 0, plaintext.Length);
+ }
+
+ // TODO Consider support for HelloRequest
+ return -1;
+ }
+ break;
+ }
+ case ContentType.heartbeat:
+ {
+ // TODO[RFC 6520]
+ return -1;
+ }
+ }
+
+ /*
+ * NOTE: If we receive any non-handshake data in the new epoch implies the peer has
+ * received our final flight.
+ */
+ if (!mInHandshake && mRetransmit != null)
+ {
+ this.mRetransmit = null;
+ this.mRetransmitEpoch = null;
+ this.mRetransmitTimeout = null;
+ }
+
+ Array.Copy(plaintext, 0, buf, off, plaintext.Length);
+ return plaintext.Length;
+ }
+
private int ReceiveRecord(byte[] buf, int off, int len, int waitMillis)
{
if (mRecordQueue.Available > 0)
diff --git a/crypto/src/crypto/tls/DtlsReliableHandshake.cs b/crypto/src/crypto/tls/DtlsReliableHandshake.cs
index 92c222e70..3eeb8a61e 100644
--- a/crypto/src/crypto/tls/DtlsReliableHandshake.cs
+++ b/crypto/src/crypto/tls/DtlsReliableHandshake.cs
@@ -3,6 +3,7 @@ using System.Collections;
using System.IO;
using Org.BouncyCastle.Utilities;
+using Org.BouncyCastle.Utilities.Date;
namespace Org.BouncyCastle.Crypto.Tls
{
@@ -11,6 +12,9 @@ namespace Org.BouncyCastle.Crypto.Tls
private const int MaxReceiveAhead = 16;
private const int MessageHeaderLength = 12;
+ private const int InitialResendMillis = 1000;
+ private const int MaxResendMillis = 60000;
+
private readonly DtlsRecordLayer mRecordLayer;
private TlsHandshakeHash mHandshakeHash;
@@ -18,7 +22,9 @@ namespace Org.BouncyCastle.Crypto.Tls
private IDictionary mCurrentInboundFlight = Platform.CreateHashtable();
private IDictionary mPreviousInboundFlight = null;
private IList mOutboundFlight = Platform.CreateArrayList();
- private bool mSending = true;
+
+ private int mResendMillis = -1;
+ private Timeout mResendTimeout = null;
private int mMessageSeq = 0, mNextReceiveSeq = 0;
@@ -50,10 +56,13 @@ namespace Org.BouncyCastle.Crypto.Tls
{
TlsUtilities.CheckUint24(body.Length);
- if (!mSending)
+ if (mResendTimeout != null)
{
CheckInboundFlight();
- mSending = true;
+
+ mResendMillis = -1;
+ mResendTimeout = null;
+
mOutboundFlight.Clear();
}
@@ -77,18 +86,18 @@ namespace Org.BouncyCastle.Crypto.Tls
internal Message ReceiveMessage()
{
// TODO Add support for "overall" handshake timeout
+ long currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
- if (mSending)
+ if (mResendTimeout == null)
{
- mSending = false;
+ mResendMillis = InitialResendMillis;
+ mResendTimeout = new Timeout(mResendMillis, currentTimeMillis);
+
PrepareInboundFlight(Platform.CreateHashtable());
}
byte[] buf = null;
- // TODO Check the conditions under which we should reset this
- int readTimeoutMillis = 1000;
-
for (;;)
{
if (mRecordLayer.IsClosed)
@@ -98,37 +107,32 @@ namespace Org.BouncyCastle.Crypto.Tls
if (pending != null)
return pending;
+ int waitMillis = System.Math.Max(1, Timeout.GetWaitMillis(mResendTimeout, currentTimeMillis));
+
int receiveLimit = mRecordLayer.GetReceiveLimit();
if (buf == null || buf.Length < receiveLimit)
{
buf = new byte[receiveLimit];
}
- int received = mRecordLayer.Receive(buf, 0, receiveLimit, readTimeoutMillis);
-
- bool resentOutbound;
+ int received = mRecordLayer.Receive(buf, 0, receiveLimit, waitMillis);
if (received < 0)
{
ResendOutboundFlight();
- resentOutbound = true;
}
else
{
- resentOutbound = ProcessRecord(MaxReceiveAhead, mRecordLayer.ReadEpoch, buf, 0, received);
+ ProcessRecord(MaxReceiveAhead, mRecordLayer.ReadEpoch, buf, 0, received);
}
- // TODO Review conditions for resend/backoff
- if (resentOutbound)
- {
- readTimeoutMillis = BackOff(readTimeoutMillis);
- }
+ currentTimeMillis = DateTimeUtilities.CurrentUnixMs();
}
}
internal void Finish()
{
DtlsHandshakeRetransmit retransmit = null;
- if (!mSending)
+ if (mResendTimeout != null)
{
CheckInboundFlight();
}
@@ -162,7 +166,7 @@ namespace Org.BouncyCastle.Crypto.Tls
* TODO[DTLS] implementations SHOULD back off handshake packet size during the
* retransmit backoff.
*/
- return System.Math.Min(timeoutMillis * 2, 60000);
+ return System.Math.Min(timeoutMillis * 2, MaxResendMillis);
}
/**
@@ -201,7 +205,7 @@ namespace Org.BouncyCastle.Crypto.Tls
mCurrentInboundFlight = nextFlight;
}
- private bool ProcessRecord(int windowSize, int epoch, byte[] buf, int off, int len)
+ private void ProcessRecord(int windowSize, int epoch, byte[] buf, int off, int len)
{
bool checkPreviousFlight = false;
@@ -271,13 +275,11 @@ namespace Org.BouncyCastle.Crypto.Tls
len -= message_length;
}
- bool result = checkPreviousFlight && CheckAll(mPreviousInboundFlight);
- if (result)
+ if (checkPreviousFlight && CheckAll(mPreviousInboundFlight))
{
ResendOutboundFlight();
ResetAll(mPreviousInboundFlight);
}
- return result;
}
private void ResendOutboundFlight()
@@ -287,6 +289,9 @@ namespace Org.BouncyCastle.Crypto.Tls
{
WriteMessage((Message)mOutboundFlight[i]);
}
+
+ mResendMillis = BackOff(mResendMillis);
+ mResendTimeout = new Timeout(mResendMillis);
}
private Message UpdateHandshakeMessagesDigest(Message message)
|