diff options
Diffstat (limited to 'crypto/src/tls/DtlsReliableHandshake.cs')
-rw-r--r-- | crypto/src/tls/DtlsReliableHandshake.cs | 185 |
1 files changed, 101 insertions, 84 deletions
diff --git a/crypto/src/tls/DtlsReliableHandshake.cs b/crypto/src/tls/DtlsReliableHandshake.cs index e27d72762..58b9301fd 100644 --- a/crypto/src/tls/DtlsReliableHandshake.cs +++ b/crypto/src/tls/DtlsReliableHandshake.cs @@ -142,11 +142,9 @@ namespace Org.BouncyCastle.Tls get { return m_handshakeHash; } } - internal TlsHandshakeHash PrepareToFinish() + internal void PrepareToFinish() { - TlsHandshakeHash result = m_handshakeHash; - this.m_handshakeHash = m_handshakeHash.StopTracking(); - return result; + m_handshakeHash.StopTracking(); } /// <exception cref="IOException"/> @@ -173,69 +171,63 @@ namespace Org.BouncyCastle.Tls } /// <exception cref="IOException"/> + internal Message ReceiveMessage() + { + Message message = ImplReceiveMessage(); + UpdateHandshakeMessagesDigest(message); + return message; + } + + /// <exception cref="IOException"/> internal byte[] ReceiveMessageBody(short msg_type) { - Message message = ReceiveMessage(); + Message message = ImplReceiveMessage(); if (message.Type != msg_type) throw new TlsFatalAlert(AlertDescription.unexpected_message); + UpdateHandshakeMessagesDigest(message); return message.Body; } /// <exception cref="IOException"/> - internal Message ReceiveMessage() + internal Message ReceiveMessageDelayedDigest(short msg_type) { - long currentTimeMillis = DateTimeUtilities.CurrentUnixMs(); - - if (null == m_resendTimeout) - { - m_resendMillis = INITIAL_RESEND_MILLIS; - m_resendTimeout = new Timeout(m_resendMillis, currentTimeMillis); - - PrepareInboundFlight(Platform.CreateHashtable()); - } + Message message = ImplReceiveMessage(); + if (message.Type != msg_type) + throw new TlsFatalAlert(AlertDescription.unexpected_message); - byte[] buf = null; + return message; + } - for (;;) + /// <exception cref="IOException"/> + internal Message UpdateHandshakeMessagesDigest(Message message) + { + short msg_type = message.Type; + switch (msg_type) { - if (m_recordLayer.IsClosed) - throw new TlsFatalAlert(AlertDescription.user_canceled); - - Message pending = GetPendingMessage(); - if (pending != null) - return pending; - - if (Timeout.HasExpired(m_handshakeTimeout, currentTimeMillis)) - throw new TlsTimeoutException("Handshake timed out"); - - int waitMillis = Timeout.GetWaitMillis(m_handshakeTimeout, currentTimeMillis); - waitMillis = Timeout.ConstrainWaitMillis(waitMillis, m_resendTimeout, currentTimeMillis); - - // NOTE: Ensure a finite wait, of at least 1ms - if (waitMillis < 1) - { - waitMillis = 1; - } - - int receiveLimit = m_recordLayer.GetReceiveLimit(); - if (buf == null || buf.Length < receiveLimit) - { - buf = new byte[receiveLimit]; - } - - int received = m_recordLayer.Receive(buf, 0, receiveLimit, waitMillis); - if (received < 0) - { - ResendOutboundFlight(); - } - else - { - ProcessRecord(MAX_RECEIVE_AHEAD, m_recordLayer.ReadEpoch, buf, 0, received); - } + case HandshakeType.hello_request: + case HandshakeType.hello_verify_request: + case HandshakeType.key_update: + break; - currentTimeMillis = DateTimeUtilities.CurrentUnixMs(); + // TODO[dtls13] Not included in the transcript for (D)TLS 1.3+ + case HandshakeType.new_session_ticket: + default: + { + byte[] body = message.Body; + byte[] buf = new byte[MESSAGE_HEADER_LENGTH]; + TlsUtilities.WriteUint8(msg_type, buf, 0); + TlsUtilities.WriteUint24(body.Length, buf, 1); + TlsUtilities.WriteUint16(message.Seq, buf, 4); + TlsUtilities.WriteUint24(0, buf, 6); + TlsUtilities.WriteUint24(body.Length, buf, 9); + m_handshakeHash.Update(buf, 0, buf.Length); + m_handshakeHash.Update(body, 0, body.Length); + break; + } } + + return message; } internal void Finish() @@ -297,12 +289,68 @@ namespace Org.BouncyCastle.Tls if (body != null) { m_previousInboundFlight = null; - return UpdateHandshakeMessagesDigest(new Message(m_next_receive_seq++, next.MsgType, body)); + return new Message(m_next_receive_seq++, next.MsgType, body); } } return null; } + /// <exception cref="IOException"/> + private Message ImplReceiveMessage() + { + long currentTimeMillis = DateTimeUtilities.CurrentUnixMs(); + + if (null == m_resendTimeout) + { + m_resendMillis = INITIAL_RESEND_MILLIS; + m_resendTimeout = new Timeout(m_resendMillis, currentTimeMillis); + + PrepareInboundFlight(Platform.CreateHashtable()); + } + + byte[] buf = null; + + for (; ; ) + { + if (m_recordLayer.IsClosed) + throw new TlsFatalAlert(AlertDescription.user_canceled); + + Message pending = GetPendingMessage(); + if (pending != null) + return pending; + + if (Timeout.HasExpired(m_handshakeTimeout, currentTimeMillis)) + throw new TlsTimeoutException("Handshake timed out"); + + int waitMillis = Timeout.GetWaitMillis(m_handshakeTimeout, currentTimeMillis); + waitMillis = Timeout.ConstrainWaitMillis(waitMillis, m_resendTimeout, currentTimeMillis); + + // NOTE: Ensure a finite wait, of at least 1ms + if (waitMillis < 1) + { + waitMillis = 1; + } + + int receiveLimit = m_recordLayer.GetReceiveLimit(); + if (buf == null || buf.Length < receiveLimit) + { + buf = new byte[receiveLimit]; + } + + int received = m_recordLayer.Receive(buf, 0, receiveLimit, waitMillis); + if (received < 0) + { + ResendOutboundFlight(); + } + else + { + ProcessRecord(MAX_RECEIVE_AHEAD, m_recordLayer.ReadEpoch, buf, 0, received); + } + + currentTimeMillis = DateTimeUtilities.CurrentUnixMs(); + } + } + private void PrepareInboundFlight(IDictionary nextFlight) { ResetAll(m_currentInboundFlight); @@ -400,37 +448,6 @@ namespace Org.BouncyCastle.Tls } /// <exception cref="IOException"/> - private Message UpdateHandshakeMessagesDigest(Message message) - { - short msg_type = message.Type; - switch (msg_type) - { - case HandshakeType.hello_request: - case HandshakeType.hello_verify_request: - case HandshakeType.key_update: - break; - - // TODO[dtls13] Not included in the transcript for (D)TLS 1.3+ - case HandshakeType.new_session_ticket: - default: - { - byte[] body = message.Body; - byte[] buf = new byte[MESSAGE_HEADER_LENGTH]; - TlsUtilities.WriteUint8(msg_type, buf, 0); - TlsUtilities.WriteUint24(body.Length, buf, 1); - TlsUtilities.WriteUint16(message.Seq, buf, 4); - TlsUtilities.WriteUint24(0, buf, 6); - TlsUtilities.WriteUint24(body.Length, buf, 9); - m_handshakeHash.Update(buf, 0, buf.Length); - m_handshakeHash.Update(body, 0, body.Length); - break; - } - } - - return message; - } - - /// <exception cref="IOException"/> private void WriteMessage(Message message) { int sendLimit = m_recordLayer.GetSendLimit(); |