diff options
Diffstat (limited to 'crypto/src/tls/DtlsReliableHandshake.cs')
-rw-r--r-- | crypto/src/tls/DtlsReliableHandshake.cs | 50 |
1 files changed, 22 insertions, 28 deletions
diff --git a/crypto/src/tls/DtlsReliableHandshake.cs b/crypto/src/tls/DtlsReliableHandshake.cs index 42a98a991..b1107f7a1 100644 --- a/crypto/src/tls/DtlsReliableHandshake.cs +++ b/crypto/src/tls/DtlsReliableHandshake.cs @@ -8,47 +8,41 @@ namespace Org.BouncyCastle.Tls { internal class DtlsReliableHandshake { - private const int MAX_RECEIVE_AHEAD = 16; - private const int MESSAGE_HEADER_LENGTH = 12; + internal const int MessageHeaderLength = 12; + private const int MAX_RECEIVE_AHEAD = 16; private const int MAX_RESEND_MILLIS = 60000; /// <exception cref="IOException"/> - internal static DtlsRequest ReadClientRequest(byte[] data, int dataOff, int dataLen, Stream dtlsOutput) + internal static MemoryStream ReceiveClientHelloMessage(byte[] msg, int msgOff, int msgLen) { // TODO Support the possibility of a fragmented ClientHello datagram - byte[] message = DtlsRecordLayer.ReceiveClientHelloRecord(data, dataOff, dataLen); - if (null == message || message.Length < MESSAGE_HEADER_LENGTH) + if (msgLen < MessageHeaderLength) return null; - long recordSeq = TlsUtilities.ReadUint48(data, dataOff + 5); - - short msgType = TlsUtilities.ReadUint8(message, 0); + short msgType = TlsUtilities.ReadUint8(msg, msgOff); if (HandshakeType.client_hello != msgType) return null; - int length = TlsUtilities.ReadUint24(message, 1); - if (message.Length != MESSAGE_HEADER_LENGTH + length) + int length = TlsUtilities.ReadUint24(msg, msgOff + 1); + if (msgLen != MessageHeaderLength + length) return null; // TODO Consider stricter HelloVerifyRequest-related checks - //int messageSeq = TlsUtilities.ReadUint16(message, 4); + //int messageSeq = TlsUtilities.ReadUint16(msg, msgOff + 4); //if (messageSeq > 1) // return null; - int fragmentOffset = TlsUtilities.ReadUint24(message, 6); + int fragmentOffset = TlsUtilities.ReadUint24(msg, msgOff + 6); if (0 != fragmentOffset) return null; - int fragmentLength = TlsUtilities.ReadUint24(message, 9); + int fragmentLength = TlsUtilities.ReadUint24(msg, msgOff + 9); if (length != fragmentLength) return null; - ClientHello clientHello = ClientHello.Parse( - new MemoryStream(message, MESSAGE_HEADER_LENGTH, length, false), dtlsOutput); - - return new DtlsRequest(recordSeq, message, clientHello); + return new MemoryStream(msg, msgOff + MessageHeaderLength, length, false); } /// <exception cref="IOException"/> @@ -58,7 +52,7 @@ namespace Org.BouncyCastle.Tls int length = 3 + cookie.Length; - byte[] message = new byte[MESSAGE_HEADER_LENGTH + length]; + byte[] message = new byte[MessageHeaderLength + length]; TlsUtilities.WriteUint8(HandshakeType.hello_verify_request, message, 0); TlsUtilities.WriteUint24(length, message, 1); //TlsUtilities.WriteUint16(0, message, 4); @@ -66,8 +60,8 @@ namespace Org.BouncyCastle.Tls TlsUtilities.WriteUint24(length, message, 9); // HelloVerifyRequest fields - TlsUtilities.WriteVersion(ProtocolVersion.DTLSv10, message, MESSAGE_HEADER_LENGTH + 0); - TlsUtilities.WriteOpaque8(cookie, message, MESSAGE_HEADER_LENGTH + 2); + TlsUtilities.WriteVersion(ProtocolVersion.DTLSv10, message, MessageHeaderLength + 0); + TlsUtilities.WriteOpaque8(cookie, message, MessageHeaderLength + 2); DtlsRecordLayer.SendHelloVerifyRequestRecord(sender, recordSeq, message); } @@ -111,7 +105,7 @@ namespace Org.BouncyCastle.Tls // Simulate a previous flight consisting of the request ClientHello DtlsReassembler reassembler = new DtlsReassembler(HandshakeType.client_hello, - message.Length - MESSAGE_HEADER_LENGTH); + message.Length - MessageHeaderLength); m_currentInboundFlight[messageSeq] = reassembler; // We sent HelloVerifyRequest with (message) sequence number 0 @@ -215,7 +209,7 @@ namespace Org.BouncyCastle.Tls default: { byte[] body = message.Body; - byte[] buf = new byte[MESSAGE_HEADER_LENGTH]; + byte[] buf = new byte[MessageHeaderLength]; TlsUtilities.WriteUint8(msg_type, buf, 0); TlsUtilities.WriteUint24(body.Length, buf, 1); TlsUtilities.WriteUint16(message.Seq, buf, 4); @@ -360,10 +354,10 @@ namespace Org.BouncyCastle.Tls { bool checkPreviousFlight = false; - while (len >= MESSAGE_HEADER_LENGTH) + while (len >= MessageHeaderLength) { int fragment_length = TlsUtilities.ReadUint24(buf, off + 9); - int message_length = fragment_length + MESSAGE_HEADER_LENGTH; + int message_length = fragment_length + MessageHeaderLength; if (len < message_length) { // NOTE: Truncated message - ignore it @@ -400,7 +394,7 @@ namespace Org.BouncyCastle.Tls m_currentInboundFlight[message_seq] = reassembler; } - reassembler.ContributeFragment(msg_type, length, buf, off + MESSAGE_HEADER_LENGTH, fragment_offset, + reassembler.ContributeFragment(msg_type, length, buf, off + MessageHeaderLength, fragment_offset, fragment_length); } else if (m_previousInboundFlight != null) @@ -412,7 +406,7 @@ namespace Org.BouncyCastle.Tls if (m_previousInboundFlight.TryGetValue(message_seq, out var reassembler)) { - reassembler.ContributeFragment(msg_type, length, buf, off + MESSAGE_HEADER_LENGTH, + reassembler.ContributeFragment(msg_type, length, buf, off + MessageHeaderLength, fragment_offset, fragment_length); checkPreviousFlight = true; } @@ -446,7 +440,7 @@ namespace Org.BouncyCastle.Tls private void WriteMessage(Message message) { int sendLimit = m_recordLayer.GetSendLimit(); - int fragmentLimit = sendLimit - MESSAGE_HEADER_LENGTH; + int fragmentLimit = sendLimit - MessageHeaderLength; // TODO Support a higher minimum fragment size? if (fragmentLimit < 1) @@ -471,7 +465,7 @@ namespace Org.BouncyCastle.Tls /// <exception cref="IOException"/> private void WriteHandshakeFragment(Message message, int fragment_offset, int fragment_length) { - RecordLayerBuffer fragment = new RecordLayerBuffer(MESSAGE_HEADER_LENGTH + fragment_length); + RecordLayerBuffer fragment = new RecordLayerBuffer(MessageHeaderLength + fragment_length); TlsUtilities.WriteUint8(message.Type, fragment); TlsUtilities.WriteUint24(message.Body.Length, fragment); TlsUtilities.WriteUint16(message.Seq, fragment); |