diff --git a/crypto/src/crypto/tls/DtlsRecordLayer.cs b/crypto/src/crypto/tls/DtlsRecordLayer.cs
index 3c3e1821f..39e018810 100644
--- a/crypto/src/crypto/tls/DtlsRecordLayer.cs
+++ b/crypto/src/crypto/tls/DtlsRecordLayer.cs
@@ -52,6 +52,11 @@ namespace Org.BouncyCastle.Crypto.Tls
this.mPlaintextLimit = plaintextLimit;
}
+ internal virtual int ReadEpoch
+ {
+ get { return mReadEpoch.Epoch; }
+ }
+
internal virtual ProtocolVersion ReadVersion
{
get { return mReadVersion; }
diff --git a/crypto/src/crypto/tls/DtlsReliableHandshake.cs b/crypto/src/crypto/tls/DtlsReliableHandshake.cs
index 18a41769a..396ea7483 100644
--- a/crypto/src/crypto/tls/DtlsReliableHandshake.cs
+++ b/crypto/src/crypto/tls/DtlsReliableHandshake.cs
@@ -8,7 +8,8 @@ namespace Org.BouncyCastle.Crypto.Tls
{
internal class DtlsReliableHandshake
{
- private const int MAX_RECEIVE_AHEAD = 10;
+ private const int MaxReceiveAhead = 16;
+ private const int MessageHeaderLength = 12;
private readonly DtlsRecordLayer mRecordLayer;
@@ -78,21 +79,7 @@ namespace Org.BouncyCastle.Crypto.Tls
if (mSending)
{
mSending = false;
- PrepareInboundFlight();
- }
-
- // Check if we already have the next message waiting
- {
- DtlsReassembler next = (DtlsReassembler)mCurrentInboundFlight[mNextReceiveSeq];
- if (next != null)
- {
- byte[] body = next.GetBodyIfComplete();
- if (body != null)
- {
- mPreviousInboundFlight = null;
- return UpdateHandshakeMessagesDigest(new Message(mNextReceiveSeq++, next.MsgType, body));
- }
- }
+ PrepareInboundFlight(Platform.CreateHashtable());
}
byte[] buf = null;
@@ -102,110 +89,38 @@ namespace Org.BouncyCastle.Crypto.Tls
for (;;)
{
- int receiveLimit = mRecordLayer.GetReceiveLimit();
- if (buf == null || buf.Length < receiveLimit)
- {
- buf = new byte[receiveLimit];
- }
-
- // TODO Handle records containing multiple handshake messages
-
try
{
- for (; ; )
+ for (;;)
{
+ Message pending = GetPendingMessage();
+ if (pending != null)
+ return pending;
+
+ int receiveLimit = mRecordLayer.GetReceiveLimit();
+ if (buf == null || buf.Length < receiveLimit)
+ {
+ buf = new byte[receiveLimit];
+ }
+
int received = mRecordLayer.Receive(buf, 0, receiveLimit, readTimeoutMillis);
if (received < 0)
- {
break;
- }
- if (received < 12)
- {
- continue;
- }
- int fragment_length = TlsUtilities.ReadUint24(buf, 9);
- if (received != (fragment_length + 12))
- {
- continue;
- }
- int seq = TlsUtilities.ReadUint16(buf, 4);
- if (seq > (mNextReceiveSeq + MAX_RECEIVE_AHEAD))
- {
- continue;
- }
- byte msg_type = TlsUtilities.ReadUint8(buf, 0);
- int length = TlsUtilities.ReadUint24(buf, 1);
- int fragment_offset = TlsUtilities.ReadUint24(buf, 6);
- if (fragment_offset + fragment_length > length)
- {
- continue;
- }
- if (seq < mNextReceiveSeq)
- {
- /*
- * NOTE: If we Receive the previous flight of incoming messages in full
- * again, retransmit our last flight
- */
- if (mPreviousInboundFlight != null)
- {
- DtlsReassembler reassembler = (DtlsReassembler)mPreviousInboundFlight[seq];
- if (reassembler != null)
- {
- reassembler.ContributeFragment(msg_type, length, buf, 12, fragment_offset,
- fragment_length);
-
- if (CheckAll(mPreviousInboundFlight))
- {
- ResendOutboundFlight();
-
- /*
- * TODO[DTLS] implementations SHOULD back off handshake packet
- * size during the retransmit backoff.
- */
- readTimeoutMillis = System.Math.Min(readTimeoutMillis * 2, 60000);
-
- ResetAll(mPreviousInboundFlight);
- }
- }
- }
- }
- else
+ bool resentOutbound = ProcessRecord(MaxReceiveAhead, mRecordLayer.ReadEpoch, buf, 0, received);
+ if (resentOutbound)
{
- DtlsReassembler reassembler = (DtlsReassembler)mCurrentInboundFlight[seq];
- if (reassembler == null)
- {
- reassembler = new DtlsReassembler(msg_type, length);
- mCurrentInboundFlight[seq] = reassembler;
- }
-
- reassembler.ContributeFragment(msg_type, length, buf, 12, fragment_offset, fragment_length);
-
- if (seq == mNextReceiveSeq)
- {
- byte[] body = reassembler.GetBodyIfComplete();
- if (body != null)
- {
- mPreviousInboundFlight = null;
- return UpdateHandshakeMessagesDigest(new Message(mNextReceiveSeq++,
- reassembler.MsgType, body));
- }
- }
+ readTimeoutMillis = BackOff(readTimeoutMillis);
}
}
}
- catch (IOException)
+ catch (IOException e)
{
// NOTE: Assume this is a timeout for the moment
}
ResendOutboundFlight();
-
- /*
- * TODO[DTLS] implementations SHOULD back off handshake packet size during the
- * retransmit backoff.
- */
- readTimeoutMillis = System.Math.Min(readTimeoutMillis * 2, 60000);
+ readTimeoutMillis = BackOff(readTimeoutMillis);
}
}
@@ -216,15 +131,20 @@ namespace Org.BouncyCastle.Crypto.Tls
{
CheckInboundFlight();
}
- else if (mCurrentInboundFlight != null)
+ else
{
- /*
- * RFC 6347 4.2.4. In addition, for at least twice the default MSL defined for [TCP],
- * when in the FINISHED state, the node that transmits the last flight (the server in an
- * ordinary handshake or the client in a resumed handshake) MUST respond to a retransmit
- * of the peer's last flight with a retransmit of the last flight.
- */
- retransmit = new Retransmit(this);
+ PrepareInboundFlight(null);
+
+ if (mPreviousInboundFlight != null)
+ {
+ /*
+ * RFC 6347 4.2.4. In addition, for at least twice the default MSL defined for [TCP],
+ * when in the FINISHED state, the node that transmits the last flight (the server in an
+ * ordinary handshake or the client in a resumed handshake) MUST respond to a retransmit
+ * of the peer's last flight with a retransmit of the last flight.
+ */
+ retransmit = new Retransmit(this);
+ }
}
mRecordLayer.HandshakeSuccessful(retransmit);
@@ -235,44 +155,13 @@ namespace Org.BouncyCastle.Crypto.Tls
mHandshakeHash.Reset();
}
- private void HandleRetransmittedHandshakeRecord(int epoch, byte[] buf, int off, int len)
+ private int BackOff(int timeoutMillis)
{
/*
- * TODO Need to handle the case where the previous inbound flight contains
- * messages from two epochs.
+ * TODO[DTLS] implementations SHOULD back off handshake packet size during the
+ * retransmit backoff.
*/
- if (len < 12)
- return;
- int fragment_length = TlsUtilities.ReadUint24(buf, off + 9);
- if (len != (fragment_length + 12))
- return;
- int seq = TlsUtilities.ReadUint16(buf, off + 4);
- if (seq >= mNextReceiveSeq)
- return;
-
- byte msg_type = TlsUtilities.ReadUint8(buf, off);
-
- // TODO This is a hack that only works until we try to support renegotiation
- int expectedEpoch = msg_type == HandshakeType.finished ? 1 : 0;
- if (epoch != expectedEpoch)
- return;
-
- int length = TlsUtilities.ReadUint24(buf, off + 1);
- int fragment_offset = TlsUtilities.ReadUint24(buf, off + 6);
- if (fragment_offset + fragment_length > length)
- return;
-
- DtlsReassembler reassembler = (DtlsReassembler)mCurrentInboundFlight[seq];
- if (reassembler != null)
- {
- reassembler.ContributeFragment(msg_type, length, buf, off + 12, fragment_offset,
- fragment_length);
- if (CheckAll(mCurrentInboundFlight))
- {
- ResendOutboundFlight();
- ResetAll(mCurrentInboundFlight);
- }
- }
+ return System.Math.Min(timeoutMillis * 2, 60000);
}
/**
@@ -289,11 +178,105 @@ namespace Org.BouncyCastle.Crypto.Tls
}
}
- private void PrepareInboundFlight()
+ private Message GetPendingMessage()
+ {
+ DtlsReassembler next = (DtlsReassembler)mCurrentInboundFlight[mNextReceiveSeq];
+ if (next != null)
+ {
+ byte[] body = next.GetBodyIfComplete();
+ if (body != null)
+ {
+ mPreviousInboundFlight = null;
+ return UpdateHandshakeMessagesDigest(new Message(mNextReceiveSeq++, next.MsgType, body));
+ }
+ }
+ return null;
+ }
+
+ private void PrepareInboundFlight(IDictionary nextFlight)
{
ResetAll(mCurrentInboundFlight);
mPreviousInboundFlight = mCurrentInboundFlight;
- mCurrentInboundFlight = Platform.CreateHashtable();
+ mCurrentInboundFlight = nextFlight;
+ }
+
+ private bool ProcessRecord(int windowSize, int epoch, byte[] buf, int off, int len)
+ {
+ bool checkPreviousFlight = false;
+
+ while (len >= MessageHeaderLength)
+ {
+ int fragment_length = TlsUtilities.ReadUint24(buf, off + 9);
+ int message_length = fragment_length + MessageHeaderLength;
+ if (len < message_length)
+ {
+ // NOTE: Truncated message - ignore it
+ break;
+ }
+
+ int length = TlsUtilities.ReadUint24(buf, off + 1);
+ int fragment_offset = TlsUtilities.ReadUint24(buf, off + 6);
+ if (fragment_offset + fragment_length > length)
+ {
+ // NOTE: Malformed fragment - ignore it and the rest of the record
+ break;
+ }
+
+ /*
+ * NOTE: This very simple epoch check will only work until we want to support
+ * renegotiation (and we're not likely to do that anyway).
+ */
+ byte msg_type = TlsUtilities.ReadUint8(buf, off + 0);
+ int expectedEpoch = msg_type == HandshakeType.finished ? 1 : 0;
+ if (epoch != expectedEpoch)
+ {
+ break;
+ }
+
+ int message_seq = TlsUtilities.ReadUint16(buf, off + 4);
+ if (message_seq >= (mNextReceiveSeq + windowSize))
+ {
+ // NOTE: Too far ahead - ignore
+ }
+ else if (message_seq >= mNextReceiveSeq)
+ {
+ DtlsReassembler reassembler = (DtlsReassembler)mCurrentInboundFlight[message_seq];
+ if (reassembler == null)
+ {
+ reassembler = new DtlsReassembler(msg_type, length);
+ mCurrentInboundFlight[message_seq] = reassembler;
+ }
+
+ reassembler.ContributeFragment(msg_type, length, buf, off + MessageHeaderLength, fragment_offset,
+ fragment_length);
+ }
+ else if (mPreviousInboundFlight != null)
+ {
+ /*
+ * NOTE: If we receive the previous flight of incoming messages in full again,
+ * retransmit our last flight
+ */
+
+ DtlsReassembler reassembler = (DtlsReassembler)mPreviousInboundFlight[message_seq];
+ if (reassembler != null)
+ {
+ reassembler.ContributeFragment(msg_type, length, buf, off + MessageHeaderLength, fragment_offset,
+ fragment_length);
+ checkPreviousFlight = true;
+ }
+ }
+
+ off += message_length;
+ len -= message_length;
+ }
+
+ bool result = checkPreviousFlight && CheckAll(mPreviousInboundFlight);
+ if (result)
+ {
+ ResendOutboundFlight();
+ ResetAll(mPreviousInboundFlight);
+ }
+ return result;
}
private void ResendOutboundFlight()
@@ -310,7 +293,7 @@ namespace Org.BouncyCastle.Crypto.Tls
if (message.Type != HandshakeType.hello_request)
{
byte[] body = message.Body;
- byte[] buf = new byte[12];
+ byte[] buf = new byte[MessageHeaderLength];
TlsUtilities.WriteUint8(message.Type, buf, 0);
TlsUtilities.WriteUint24(body.Length, buf, 1);
TlsUtilities.WriteUint16(message.Seq, buf, 4);
@@ -325,7 +308,7 @@ namespace Org.BouncyCastle.Crypto.Tls
private void WriteMessage(Message message)
{
int sendLimit = mRecordLayer.GetSendLimit();
- int fragmentLimit = sendLimit - 12;
+ int fragmentLimit = sendLimit - MessageHeaderLength;
// TODO Support a higher minimum fragment size?
if (fragmentLimit < 1)
@@ -349,7 +332,7 @@ namespace Org.BouncyCastle.Crypto.Tls
private void WriteHandshakeFragment(Message message, int fragment_offset, int fragment_length)
{
- RecordLayerBuffer fragment = new RecordLayerBuffer(12 + fragment_length);
+ RecordLayerBuffer fragment = new RecordLayerBuffer(MessageHeaderLength + fragment_length);
TlsUtilities.WriteUint8(message.Type, fragment);
TlsUtilities.WriteUint24(message.Body.Length, fragment);
TlsUtilities.WriteUint16(message.Seq, fragment);
@@ -444,7 +427,7 @@ namespace Org.BouncyCastle.Crypto.Tls
public void ReceivedHandshakeRecord(int epoch, byte[] buf, int off, int len)
{
- mOuter.HandleRetransmittedHandshakeRecord(epoch, buf, off, len);
+ mOuter.ProcessRecord(0, epoch, buf, off, len);
}
}
}
diff --git a/crypto/test/src/crypto/tls/test/DtlsTestSuite.cs b/crypto/test/src/crypto/tls/test/DtlsTestSuite.cs
index a1ba62dde..f191ef005 100644
--- a/crypto/test/src/crypto/tls/test/DtlsTestSuite.cs
+++ b/crypto/test/src/crypto/tls/test/DtlsTestSuite.cs
@@ -215,5 +215,14 @@ namespace Org.BouncyCastle.Crypto.Tls.Tests
c.serverMinimumVersion = ProtocolVersion.DTLSv10;
return c;
}
+
+ public static void RunTests()
+ {
+ foreach (TestCaseData data in Suite())
+ {
+ Console.WriteLine(data.TestName);
+ new DtlsTestCase().RunTest((TlsTestConfig)data.Arguments[0]);
+ }
+ }
}
}
diff --git a/crypto/test/src/crypto/tls/test/TlsTestSuite.cs b/crypto/test/src/crypto/tls/test/TlsTestSuite.cs
index 5dd9cf0f5..849e738af 100644
--- a/crypto/test/src/crypto/tls/test/TlsTestSuite.cs
+++ b/crypto/test/src/crypto/tls/test/TlsTestSuite.cs
@@ -201,5 +201,14 @@ namespace Org.BouncyCastle.Crypto.Tls.Tests
c.serverMinimumVersion = ProtocolVersion.SSLv3;
return c;
}
+
+ public static void RunTests()
+ {
+ foreach (TestCaseData data in Suite())
+ {
+ Console.WriteLine(data.TestName);
+ new TlsTestCase().RunTest((TlsTestConfig)data.Arguments[0]);
+ }
+ }
}
}
|