diff options
Diffstat (limited to 'crypto/src/crypto/tls/DtlsReliableHandshake.cs')
-rw-r--r-- | crypto/src/crypto/tls/DtlsReliableHandshake.cs | 443 |
1 files changed, 443 insertions, 0 deletions
diff --git a/crypto/src/crypto/tls/DtlsReliableHandshake.cs b/crypto/src/crypto/tls/DtlsReliableHandshake.cs new file mode 100644 index 000000000..8e4439e67 --- /dev/null +++ b/crypto/src/crypto/tls/DtlsReliableHandshake.cs @@ -0,0 +1,443 @@ +using System; +using System.Collections; +using System.IO; + +using Org.BouncyCastle.Utilities; + +namespace Org.BouncyCastle.Crypto.Tls +{ + internal class DtlsReliableHandshake + { + private const int MAX_RECEIVE_AHEAD = 10; + + private readonly DtlsRecordLayer mRecordLayer; + + private TlsHandshakeHash mHandshakeHash; + + private IDictionary mCurrentInboundFlight = Platform.CreateHashtable(); + private IDictionary mPreviousInboundFlight = null; + private IList mOutboundFlight = Platform.CreateArrayList(); + private bool mSending = true; + + private int mMessageSeq = 0, mNextReceiveSeq = 0; + + internal DtlsReliableHandshake(TlsContext context, DtlsRecordLayer transport) + { + this.mRecordLayer = transport; + this.mHandshakeHash = new DeferredHash(); + this.mHandshakeHash.Init(context); + } + + internal void NotifyHelloComplete() + { + this.mHandshakeHash = mHandshakeHash.NotifyPrfDetermined(); + } + + internal TlsHandshakeHash HandshakeHash + { + get { return mHandshakeHash; } + } + + internal TlsHandshakeHash PrepareToFinish() + { + TlsHandshakeHash result = mHandshakeHash; + this.mHandshakeHash = mHandshakeHash.StopTracking(); + return result; + } + + internal void SendMessage(byte msg_type, byte[] body) + { + TlsUtilities.CheckUint24(body.Length); + + if (!mSending) + { + CheckInboundFlight(); + mSending = true; + mOutboundFlight.Clear(); + } + + Message message = new Message(mMessageSeq++, msg_type, body); + + mOutboundFlight.Add(message); + + WriteMessage(message); + UpdateHandshakeMessagesDigest(message); + } + + internal byte[] ReceiveMessageBody(byte msg_type) + { + Message message = ReceiveMessage(); + if (message.Type != msg_type) + throw new TlsFatalAlert(AlertDescription.unexpected_message); + + return message.Body; + } + + internal Message ReceiveMessage() + { + if (mSending) + { + mSending = false; + PrepareInboundFlight(); + } + + // Check if we already have the next message waiting + { + DtlsReassembler next = (DtlsReassembler)mCurrentInboundFlight[mNextReceiveSeq]; + if (next != null) + { + byte[] body = next.GetBodyIfComplete(); + if (body != null) + { + mPreviousInboundFlight = null; + return UpdateHandshakeMessagesDigest(new Message(mNextReceiveSeq++, next.MsgType, body)); + } + } + } + + byte[] buf = null; + + // TODO Check the conditions under which we should reset this + int readTimeoutMillis = 1000; + + for (;;) + { + int receiveLimit = mRecordLayer.GetReceiveLimit(); + if (buf == null || buf.Length < receiveLimit) + { + buf = new byte[receiveLimit]; + } + + // TODO Handle records containing multiple handshake messages + + try + { + for (; ; ) + { + int received = mRecordLayer.Receive(buf, 0, receiveLimit, readTimeoutMillis); + if (received < 0) + { + break; + } + if (received < 12) + { + continue; + } + int fragment_length = TlsUtilities.ReadUint24(buf, 9); + if (received != (fragment_length + 12)) + { + continue; + } + int seq = TlsUtilities.ReadUint16(buf, 4); + if (seq > (mNextReceiveSeq + MAX_RECEIVE_AHEAD)) + { + continue; + } + byte msg_type = TlsUtilities.ReadUint8(buf, 0); + int length = TlsUtilities.ReadUint24(buf, 1); + int fragment_offset = TlsUtilities.ReadUint24(buf, 6); + if (fragment_offset + fragment_length > length) + { + continue; + } + + if (seq < mNextReceiveSeq) + { + /* + * NOTE: If we Receive the previous flight of incoming messages in full + * again, retransmit our last flight + */ + if (mPreviousInboundFlight != null) + { + DtlsReassembler reassembler = (DtlsReassembler)mPreviousInboundFlight[seq]; + if (reassembler != null) + { + reassembler.ContributeFragment(msg_type, length, buf, 12, fragment_offset, + fragment_length); + + if (CheckAll(mPreviousInboundFlight)) + { + ResendOutboundFlight(); + + /* + * TODO[DTLS] implementations SHOULD back off handshake packet + * size during the retransmit backoff. + */ + readTimeoutMillis = System.Math.Min(readTimeoutMillis * 2, 60000); + + ResetAll(mPreviousInboundFlight); + } + } + } + } + else + { + DtlsReassembler reassembler = (DtlsReassembler)mCurrentInboundFlight[seq]; + if (reassembler == null) + { + reassembler = new DtlsReassembler(msg_type, length); + mCurrentInboundFlight[seq] = reassembler; + } + + reassembler.ContributeFragment(msg_type, length, buf, 12, fragment_offset, fragment_length); + + if (seq == mNextReceiveSeq) + { + byte[] body = reassembler.GetBodyIfComplete(); + if (body != null) + { + mPreviousInboundFlight = null; + return UpdateHandshakeMessagesDigest(new Message(mNextReceiveSeq++, + reassembler.MsgType, body)); + } + } + } + } + } + catch (IOException) + { + // NOTE: Assume this is a timeout for the moment + } + + ResendOutboundFlight(); + + /* + * TODO[DTLS] implementations SHOULD back off handshake packet size during the + * retransmit backoff. + */ + readTimeoutMillis = System.Math.Min(readTimeoutMillis * 2, 60000); + } + } + + internal void Finish() + { + DtlsHandshakeRetransmit retransmit = null; + if (!mSending) + { + CheckInboundFlight(); + } + else if (mCurrentInboundFlight != null) + { + /* + * RFC 6347 4.2.4. In addition, for at least twice the default MSL defined for [TCP], + * when in the FINISHED state, the node that transmits the last flight (the server in an + * ordinary handshake or the client in a resumed handshake) MUST respond to a retransmit + * of the peer's last flight with a retransmit of the last flight. + */ + retransmit = new Retransmit(this); + } + + mRecordLayer.HandshakeSuccessful(retransmit); + } + + internal void ResetHandshakeMessagesDigest() + { + mHandshakeHash.Reset(); + } + + private void HandleRetransmittedHandshakeRecord(int epoch, byte[] buf, int off, int len) + { + /* + * TODO Need to handle the case where the previous inbound flight contains + * messages from two epochs. + */ + if (len < 12) + return; + int fragment_length = TlsUtilities.ReadUint24(buf, off + 9); + if (len != (fragment_length + 12)) + return; + int seq = TlsUtilities.ReadUint16(buf, off + 4); + if (seq >= mNextReceiveSeq) + return; + + byte msg_type = TlsUtilities.ReadUint8(buf, off); + + // TODO This is a hack that only works until we try to support renegotiation + int expectedEpoch = msg_type == HandshakeType.finished ? 1 : 0; + if (epoch != expectedEpoch) + return; + + int length = TlsUtilities.ReadUint24(buf, off + 1); + int fragment_offset = TlsUtilities.ReadUint24(buf, off + 6); + if (fragment_offset + fragment_length > length) + return; + + DtlsReassembler reassembler = (DtlsReassembler)mCurrentInboundFlight[seq]; + if (reassembler != null) + { + reassembler.ContributeFragment(msg_type, length, buf, off + 12, fragment_offset, + fragment_length); + if (CheckAll(mCurrentInboundFlight)) + { + ResendOutboundFlight(); + ResetAll(mCurrentInboundFlight); + } + } + } + + /** + * Check that there are no "extra" messages left in the current inbound flight + */ + private void CheckInboundFlight() + { + foreach (int key in mCurrentInboundFlight.Keys) + { + if (key >= mNextReceiveSeq) + { + // TODO Should this be considered an error? + } + } + } + + private void PrepareInboundFlight() + { + ResetAll(mCurrentInboundFlight); + mPreviousInboundFlight = mCurrentInboundFlight; + mCurrentInboundFlight = Platform.CreateHashtable(); + } + + private void ResendOutboundFlight() + { + mRecordLayer.ResetWriteEpoch(); + for (int i = 0; i < mOutboundFlight.Count; ++i) + { + WriteMessage((Message)mOutboundFlight[i]); + } + } + + private Message UpdateHandshakeMessagesDigest(Message message) + { + if (message.Type != HandshakeType.hello_request) + { + byte[] body = message.Body; + byte[] buf = new byte[12]; + TlsUtilities.WriteUint8(message.Type, buf, 0); + TlsUtilities.WriteUint24(body.Length, buf, 1); + TlsUtilities.WriteUint16(message.Seq, buf, 4); + TlsUtilities.WriteUint24(0, buf, 6); + TlsUtilities.WriteUint24(body.Length, buf, 9); + mHandshakeHash.BlockUpdate(buf, 0, buf.Length); + mHandshakeHash.BlockUpdate(body, 0, body.Length); + } + return message; + } + + private void WriteMessage(Message message) + { + int sendLimit = mRecordLayer.GetSendLimit(); + int fragmentLimit = sendLimit - 12; + + // TODO Support a higher minimum fragment size? + if (fragmentLimit < 1) + { + // TODO Should we be throwing an exception here? + throw new TlsFatalAlert(AlertDescription.internal_error); + } + + int length = message.Body.Length; + + // NOTE: Must still send a fragment if body is empty + int fragment_offset = 0; + do + { + int fragment_length = System.Math.Min(length - fragment_offset, fragmentLimit); + WriteHandshakeFragment(message, fragment_offset, fragment_length); + fragment_offset += fragment_length; + } + while (fragment_offset < length); + } + + private void WriteHandshakeFragment(Message message, int fragment_offset, int fragment_length) + { + RecordLayerBuffer fragment = new RecordLayerBuffer(12 + fragment_length); + TlsUtilities.WriteUint8(message.Type, fragment); + TlsUtilities.WriteUint24(message.Body.Length, fragment); + TlsUtilities.WriteUint16(message.Seq, fragment); + TlsUtilities.WriteUint24(fragment_offset, fragment); + TlsUtilities.WriteUint24(fragment_length, fragment); + fragment.Write(message.Body, fragment_offset, fragment_length); + + fragment.SendToRecordLayer(mRecordLayer); + } + + private static bool CheckAll(IDictionary inboundFlight) + { + foreach (DtlsReassembler r in inboundFlight.Values) + { + if (r.GetBodyIfComplete() == null) + { + return false; + } + } + return true; + } + + private static void ResetAll(IDictionary inboundFlight) + { + foreach (DtlsReassembler r in inboundFlight.Values) + { + r.Reset(); + } + } + + internal class Message + { + private readonly int mMessageSeq; + private readonly byte mMsgType; + private readonly byte[] mBody; + + internal Message(int message_seq, byte msg_type, byte[] body) + { + this.mMessageSeq = message_seq; + this.mMsgType = msg_type; + this.mBody = body; + } + + public int Seq + { + get { return mMessageSeq; } + } + + public byte Type + { + get { return mMsgType; } + } + + public byte[] Body + { + get { return mBody; } + } + } + + internal class RecordLayerBuffer + : MemoryStream + { + internal RecordLayerBuffer(int size) + : base(size) + { + } + + internal void SendToRecordLayer(DtlsRecordLayer recordLayer) + { + recordLayer.Send(GetBuffer(), 0, (int)Length); + this.Close(); + } + } + + internal class Retransmit + : DtlsHandshakeRetransmit + { + private readonly DtlsReliableHandshake mOuter; + + internal Retransmit(DtlsReliableHandshake outer) + { + this.mOuter = outer; + } + + public void ReceivedHandshakeRecord(int epoch, byte[] buf, int off, int len) + { + mOuter.HandleRetransmittedHandshakeRecord(epoch, buf, off, len); + } + } + } +} |